-
Notifications
You must be signed in to change notification settings - Fork 308
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Various MTSAC bug fixes #1975
Various MTSAC bug fixes #1975
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -83,6 +83,7 @@ def mtsac_metaworld_mt10(ctxt=None, seed=1, _gpu=None): | |
qf2=qf2, | ||
gradient_steps_per_itr=150, | ||
max_episode_length=150, | ||
max_episode_length_eval=150, | ||
eval_env=mt10_test_envs, | ||
env_spec=mt10_train_envs.spec, | ||
num_tasks=10, | ||
|
@@ -95,7 +96,10 @@ def mtsac_metaworld_mt10(ctxt=None, seed=1, _gpu=None): | |
if _gpu is not None: | ||
set_gpu_mode(True, _gpu) | ||
mtsac.to() | ||
runner.setup(algo=mtsac, env=mt10_train_envs, sampler_cls=LocalSampler) | ||
runner.setup(algo=mtsac, | ||
env=mt10_train_envs, | ||
sampler_cls=LocalSampler, | ||
n_workers=1) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why limit the number of workers here, and can't we use ray? |
||
runner.train(n_epochs=epochs, batch_size=batch_size) | ||
|
||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -70,6 +70,61 @@ def test_mtsac_get_log_alpha(monkeypatch): | |
assert log_alpha.size() == torch.Size([mtsac._buffer_batch_size]) | ||
|
||
|
||
@pytest.mark.mujoco | ||
def test_mtsac_get_log_alpha_incorrect_num_tasks(monkeypatch): | ||
"""Check that if the num_tasks passed does not match the number of tasks | ||
|
||
in the environment, then the algorithm should raise an exception. | ||
|
||
MTSAC uses disentangled alphas, meaning that | ||
|
||
""" | ||
env_names = ['CartPole-v0', 'CartPole-v1'] | ||
task_envs = [GymEnv(name) for name in env_names] | ||
env = MultiEnvWrapper(task_envs, sample_strategy=round_robin_strategy) | ||
deterministic.set_seed(0) | ||
policy = TanhGaussianMLPPolicy( | ||
env_spec=env.spec, | ||
hidden_sizes=[1, 1], | ||
hidden_nonlinearity=torch.nn.ReLU, | ||
output_nonlinearity=None, | ||
min_std=np.exp(-20.), | ||
max_std=np.exp(2.), | ||
) | ||
|
||
qf1 = ContinuousMLPQFunction(env_spec=env.spec, | ||
hidden_sizes=[1, 1], | ||
hidden_nonlinearity=F.relu) | ||
|
||
qf2 = ContinuousMLPQFunction(env_spec=env.spec, | ||
hidden_sizes=[1, 1], | ||
hidden_nonlinearity=F.relu) | ||
replay_buffer = PathBuffer(capacity_in_transitions=int(1e6), ) | ||
|
||
buffer_batch_size = 2 | ||
mtsac = MTSAC(policy=policy, | ||
qf1=qf1, | ||
qf2=qf2, | ||
gradient_steps_per_itr=150, | ||
max_episode_length=150, | ||
eval_env=env, | ||
env_spec=env.spec, | ||
num_tasks=4, | ||
steps_per_epoch=5, | ||
replay_buffer=replay_buffer, | ||
min_buffer_size=1e3, | ||
target_update_tau=5e-3, | ||
discount=0.99, | ||
buffer_batch_size=buffer_batch_size) | ||
monkeypatch.setattr(mtsac, '_log_alpha', torch.Tensor([1., 2.])) | ||
error_string = ('The number of tasks in the environment does ' | ||
'not match self._num_tasks. Are you sure that you passed ' | ||
'The correct number of tasks?') | ||
obs = torch.Tensor([env.reset()[0]] * buffer_batch_size) | ||
with pytest.raises(ValueError, match=error_string): | ||
mtsac._get_log_alpha(dict(observation=obs)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. no need to test a private method? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is true, however, the tests are in place to verify the correctness of this implementation. I feel more comfortable having the tests than not. At the same time, it makes no sense to have this as a publicly exposed field because it has no use outside of the algorithm. |
||
|
||
|
||
@pytest.mark.mujoco | ||
def test_mtsac_inverted_double_pendulum(): | ||
"""Performance regression test of MTSAC on 2 InvDoublePendulum envs.""" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please only call args as args and kwargs as kwargs.