[RLlib] SAC: log_alpha not being learnt when on GPU. (#11298)

sven1977 committed Oct 12, 2020
1 parent 7dcfd25 commit f5e2cda

Showing 2 changed files with 20 additions and 12 deletions.
7 changes: 3 additions & 4 deletions rllib/agents/sac/sac_torch_model.py
@@ -131,10 +131,9 @@ def build_q_net(name_):
         else:
             self.twin_q_net = None

-        self.log_alpha = torch.tensor(
-            data=[np.log(initial_alpha)],
-            dtype=torch.float32,
-            requires_grad=True)
+        log_alpha = nn.Parameter(
+            torch.from_numpy(np.array([np.log(initial_alpha)])).float())
+        self.register_parameter("log_alpha", log_alpha)

         # Auto-calculate the target entropy.
         if target_entropy is None or target_entropy == "auto":
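Why this fixes the bug (a note on the PyTorch mechanics, not part of the original commit message): a plain tensor attribute, even with requires_grad=True, is invisible to nn.Module, so model.to(device) leaves it on the CPU and model.parameters() never yields it; an nn.Parameter registered on the module is moved along with the model and picked up by optimizers built from its parameters. A minimal sketch, not RLlib code, illustrating the difference:

    import numpy as np
    import torch
    import torch.nn as nn

    initial_alpha = 1.0

    class PlainTensorModel(nn.Module):
        def __init__(self):
            super().__init__()
            # The old pattern: a plain tensor attribute. nn.Module does not
            # track it, so .to()/.cuda() skip it and parameters() omits it.
            self.log_alpha = torch.tensor(
                data=[np.log(initial_alpha)],
                dtype=torch.float32,
                requires_grad=True)

    class ParameterModel(nn.Module):
        def __init__(self):
            super().__init__()
            # The new pattern: a registered nn.Parameter.
            log_alpha = nn.Parameter(
                torch.from_numpy(np.array([np.log(initial_alpha)])).float())
            self.register_parameter("log_alpha", log_alpha)

    plain, param = PlainTensorModel(), ParameterModel()
    print(len(list(plain.parameters())))  # 0 -> an optimizer built from
                                          #      parameters() never sees it
    print(len(list(param.parameters())))  # 1 -> log_alpha is learnt
    if torch.cuda.is_available():
        plain.cuda()
        param.cuda()
        print(plain.log_alpha.device)  # cpu: the attribute was left behind
        print(param.log_alpha.device)  # cuda:0: moved with the module

register_parameter("log_alpha", log_alpha) is equivalent to assigning self.log_alpha = nn.Parameter(...); both register the tensor with the module.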
25 changes: 17 additions & 8 deletions rllib/agents/sac/tests/test_sac.py
@@ -109,6 +109,7 @@ def test_sac_loss_function(self):
             "_model.0.weight",
             "default_policy/value_out/bias": "_value_branch."
             "_model.0.bias",
+            "default_policy/log_alpha": "log_alpha",
             # Target net.
             "default_policy/sequential_2/action_1/kernel": "action_model."
             "action_0._model.0.weight",
@@ -130,6 +131,7 @@ def test_sac_loss_function(self):
             "_model.0.weight",
             "default_policy/value_out_1/bias": "_value_branch."
             "_model.0.bias",
+            "default_policy/log_alpha_1": "log_alpha",
         }

         env = SimpleEnv
@@ -186,7 +188,7 @@ def test_sac_loss_function(self):
             # Actually convert to torch tensors (by accessing everything).
             input_ = policy._lazy_tensor_dict(input_)
             input_ = {k: input_[k] for k in input_.keys()}
-            log_alpha = policy.model.log_alpha.detach().numpy()[0]
+            log_alpha = policy.model.log_alpha.detach().cpu().numpy()[0]

             # Only run the expectation once, should be the same anyways
             # for all frameworks.
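The remaining test changes all follow one pattern: Tensor.numpy() refuses tensors that require grad and tensors on a CUDA device, so a value that may now live on the GPU must be detached and copied to host memory first. A minimal sketch, not part of the diff:

    import torch

    t = torch.ones(3, requires_grad=True)
    if torch.cuda.is_available():
        t = t.cuda()
    # t.numpy() would raise here: numpy() cannot handle tensors that
    # require grad, nor tensors living on a CUDA device.
    arr = t.detach().cpu().numpy()  # safe on CPU and GPU alike
    print(arr)  # [1. 1. 1.]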
@@ -259,7 +261,7 @@ def test_sac_loss_function(self):
             ]
             for tf_g, torch_g in zip(tf_a_grads, torch_a_grads):
                 if tf_g.shape != torch_g.shape:
-                    check(tf_g, np.transpose(torch_g))
+                    check(tf_g, np.transpose(torch_g.detach().cpu()))
                 else:
                     check(tf_g, torch_g)

@@ -283,7 +285,7 @@ def test_sac_loss_function(self):
             torch_c_grads = [v.grad for v in policy.model.q_variables()]
             for tf_g, torch_g in zip(tf_c_grads, torch_c_grads):
                 if tf_g.shape != torch_g.shape:
-                    check(tf_g, np.transpose(torch_g))
+                    check(tf_g, np.transpose(torch_g.detach().cpu()))
                 else:
                     check(tf_g, torch_g)
             # Compare (unchanged(!) actor grads) with tf ones.
@@ -292,7 +294,7 @@ def test_sac_loss_function(self):
             ]
             for tf_g, torch_g in zip(tf_a_grads, torch_a_grads):
                 if tf_g.shape != torch_g.shape:
-                    check(tf_g, np.transpose(torch_g))
+                    check(tf_g, np.transpose(torch_g.detach().cpu()))
                 else:
                     check(tf_g, torch_g)

@@ -354,7 +356,10 @@ def test_sac_loss_function(self):
                 tf_var = tf_weights[tf_key]
                 torch_var = policy.model.state_dict()[map_[tf_key]]
                 if tf_var.shape != torch_var.shape:
-                    check(tf_var, np.transpose(torch_var), rtol=0.05)
+                    check(
+                        tf_var,
+                        np.transpose(torch_var.detach().cpu()),
+                        rtol=0.05)
                 else:
                     check(tf_var, torch_var, rtol=0.05)
                 # And alpha.
@@ -366,7 +371,10 @@ def test_sac_loss_function(self):
                     torch_var = policy.target_model.state_dict()[map_[
                         tf_key]]
                     if tf_var.shape != torch_var.shape:
-                        check(tf_var, np.transpose(torch_var), rtol=0.05)
+                        check(
+                            tf_var,
+                            np.transpose(torch_var.detach().cpu()),
+                            rtol=0.05)
                     else:
                         check(tf_var, torch_var, rtol=0.05)
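For context on the shape checks above (standard framework behavior, not stated in the diff): TF stores a Dense kernel as (in_features, out_features), while torch.nn.Linear.weight has shape (out_features, in_features), so a shape mismatch signals a kernel that must be transposed before comparison. A small sketch:

    import numpy as np
    import torch.nn as nn

    layer = nn.Linear(in_features=4, out_features=2)
    tf_kernel = np.zeros((4, 2))                  # TF layout: (in, out)
    torch_weight = layer.weight.detach().numpy()  # torch layout: (2, 4)

    assert tf_kernel.shape != torch_weight.shape
    assert tf_kernel.shape == np.transpose(torch_weight).shape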

@@ -510,9 +518,10 @@ def _sac_loss_helper(self, train_batch, weights, ks, log_alpha, fw, gamma,
     def _translate_weights_to_torch(self, weights_dict, map_):
         model_dict = {
             map_[k]: convert_to_torch_tensor(
-                np.transpose(v) if re.search("kernel", k) else v)
+                np.transpose(v) if re.search("kernel", k) else np.array([v])
+                if re.search("log_alpha", k) else v)
             for k, v in weights_dict.items()
-            if re.search("(sequential(/|_1)|value_out/)", k)
+            if re.search("(sequential(/|_1)|value_out/|log_alpha)", k)
         }
         return model_dict

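Two subtleties in the _translate_weights_to_torch change: Python conditional expressions chain right-to-left, and the TF log_alpha weight is a scalar while the torch parameter has shape (1,), hence the np.array([v]) wrapping. A sketch, not part of the diff, of how the expression resolves on hypothetical inputs:

    import re
    import numpy as np

    def translate_value(k, v):
        # Fully parenthesized form of the chained conditional in the diff.
        return (np.transpose(v) if re.search("kernel", k) else
                (np.array([v]) if re.search("log_alpha", k) else v))

    print(translate_value("default_policy/log_alpha", 0.0))       # [0.]
    print(translate_value("sequential/dense/kernel", np.eye(2)))  # transposed
    print(translate_value("default_policy/value_out/bias", 1.0))  # unchanged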
