Various MTSAC bug fixes #1975

Merged: 1 commit, Aug 28, 2020
36 changes: 21 additions & 15 deletions examples/torch/mtsac_metaworld_ml1_pick_place.py
@@ -84,24 +84,30 @@ def mtsac_metaworld_ml1_pick_place(ctxt=None, seed=1, _gpu=None):
epochs = timesteps // batch_size
epoch_cycles = epochs // num_evaluation_points
epochs = epochs // epoch_cycles
mtsac = MTSAC(policy=policy,
qf1=qf1,
qf2=qf2,
gradient_steps_per_itr=150,
max_episode_length=150,
eval_env=ml1_test_envs,
env_spec=ml1_train_envs.spec,
num_tasks=50,
steps_per_epoch=epoch_cycles,
replay_buffer=replay_buffer,
min_buffer_size=1500,
target_update_tau=5e-3,
discount=0.99,
buffer_batch_size=1280)
mtsac = MTSAC(
policy=policy,
qf1=qf1,
qf2=qf2,
gradient_steps_per_itr=150,
max_episode_length=150,
max_episode_length_eval=150,
eval_env=ml1_test_envs,
env_spec=ml1_train_envs.spec,
num_tasks=50,
steps_per_epoch=epoch_cycles,
replay_buffer=replay_buffer,
[Review comment from Member]: please only call args as args and kwargs as kwargs.

min_buffer_size=1500,
target_update_tau=5e-3,
discount=0.99,
buffer_batch_size=1280,
)
if _gpu is not None:
set_gpu_mode(True, _gpu)
mtsac.to()
runner.setup(algo=mtsac, env=ml1_train_envs, sampler_cls=LocalSampler)
runner.setup(algo=mtsac,
env=ml1_train_envs,
sampler_cls=LocalSampler,
n_workers=1)
runner.train(n_epochs=epochs, batch_size=batch_size)
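
As a minimal sketch of what the keyword-only marker (*) added to the MTSAC signature in this PR (see the src/garage/torch/algos/mtsac.py diff below) and the review note above about args and kwargs mean for callers: the class here is a stand-in, not the real garage.torch.algos.MTSAC, and its parameter list is abbreviated from the visible diff.

# Sketch only: illustrates the keyword-only boundary, not the full MTSAC API.
class FakeMTSAC:  # stand-in for garage.torch.algos.MTSAC
    def __init__(self, policy, qf1, qf2, replay_buffer, env_spec, *,
                 num_tasks, max_episode_length, max_episode_length_eval=None,
                 eval_env=None, gradient_steps_per_itr=1):
        # Everything after * must be passed by keyword; if no separate eval
        # horizon is given, fall back to the training horizon.
        self._max_episode_length_eval = (max_episode_length_eval
                                         if max_episode_length_eval is not None
                                         else max_episode_length)

# FakeMTSAC(policy, qf1, qf2, buf, spec, 50, 150)                    -> TypeError
# FakeMTSAC(policy, qf1, qf2, buf, spec, num_tasks=50,
#           max_episode_length=150, max_episode_length_eval=150)     -> OK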


6 changes: 5 additions & 1 deletion examples/torch/mtsac_metaworld_mt10.py
@@ -83,6 +83,7 @@ def mtsac_metaworld_mt10(ctxt=None, seed=1, _gpu=None):
qf2=qf2,
gradient_steps_per_itr=150,
max_episode_length=150,
max_episode_length_eval=150,
eval_env=mt10_test_envs,
env_spec=mt10_train_envs.spec,
num_tasks=10,
@@ -95,7 +96,10 @@ def mtsac_metaworld_mt10(ctxt=None, seed=1, _gpu=None):
if _gpu is not None:
set_gpu_mode(True, _gpu)
mtsac.to()
runner.setup(algo=mtsac, env=mt10_train_envs, sampler_cls=LocalSampler)
runner.setup(algo=mtsac,
env=mt10_train_envs,
sampler_cls=LocalSampler,
n_workers=1)
[Review comment from Contributor]: Why limit the number of workers here, and can't we use ray?

runner.train(n_epochs=epochs, batch_size=batch_size)
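
The review question above about workers and ray refers to this runner.setup call. Purely to illustrate the alternative being asked about, and assuming RaySampler is importable from garage.sampler in this release, a parallel-sampling variant might look like the sketch below; it is not a change made by this PR.

# Hypothetical variant discussed in review, not part of this PR.
# LocalSampler runs its workers inside the main process; a RaySampler-based
# setup would distribute episode collection across worker processes instead.
from garage.sampler import RaySampler  # assumption: available in this release

runner.setup(algo=mtsac,
             env=mt10_train_envs,
             sampler_cls=RaySampler,
             n_workers=4)  # illustrative worker count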


8 changes: 6 additions & 2 deletions examples/torch/mtsac_metaworld_mt50.py
@@ -85,9 +85,10 @@ def mtsac_metaworld_mt50(ctxt=None, seed=1, use_gpu=False, _gpu=0):
qf2=qf2,
gradient_steps_per_itr=150,
max_episode_length=150,
max_episode_length_eval=150,
eval_env=mt50_test_envs,
env_spec=mt50_train_envs.spec,
num_tasks=10,
num_tasks=50,
steps_per_epoch=epoch_cycles,
replay_buffer=replay_buffer,
min_buffer_size=7500,
@@ -96,7 +97,10 @@ def mtsac_metaworld_mt50(ctxt=None, seed=1, use_gpu=False, _gpu=0):
buffer_batch_size=6400)
set_gpu_mode(use_gpu, _gpu)
mtsac.to()
runner.setup(algo=mtsac, env=mt50_train_envs, sampler_cls=LocalSampler)
runner.setup(algo=mtsac,
env=mt50_train_envs,
sampler_cls=LocalSampler,
n_workers=1)
runner.train(n_epochs=epochs, batch_size=batch_size)


19 changes: 19 additions & 0 deletions src/garage/torch/algos/mtsac.py
@@ -33,6 +33,9 @@ class MTSAC(SAC):
agent is being trained in.
num_tasks (int): The number of tasks being learned.
max_episode_length (int): The max episode length of the algorithm.
max_episode_length_eval (int or None): Maximum length of episodes used
for off-policy evaluation. If None, defaults to
`max_episode_length`.
eval_env (Environment): The environment used for collecting evaluation
episodes.
gradient_steps_per_itr (int): Number of optimization steps that should
@@ -75,10 +78,12 @@ def __init__(
qf2,
replay_buffer,
env_spec,
*,
num_tasks,
max_episode_length,
eval_env,
gradient_steps_per_itr,
max_episode_length_eval=None,
fixed_alpha=None,
target_entropy=None,
initial_log_entropy=0.,
@@ -100,6 +105,7 @@ def __init__(
replay_buffer=replay_buffer,
env_spec=env_spec,
max_episode_length=max_episode_length,
max_episode_length_eval=max_episode_length_eval,
gradient_steps_per_itr=gradient_steps_per_itr,
fixed_alpha=fixed_alpha,
target_entropy=target_entropy,
@@ -153,13 +159,25 @@ def _get_log_alpha(self, samples_data):
terminal: :math:`(N, 1)`
next_observation: :math:`(N, O^*)`

Raises:
ValueError: If the number of tasks (num_tasks) passed to
this algorithm doesn't match the length of the task
one-hot id in the observation vector.

Returns:
torch.Tensor: log_alpha. shape is (1, self.buffer_batch_size)

"""
obs = samples_data['observation']
log_alpha = self._log_alpha
one_hots = obs[:, -self._num_tasks:]
if (log_alpha.shape[0] != one_hots.shape[1]
or one_hots.shape[1] != self._num_tasks
or log_alpha.shape[0] != self._num_tasks):
raise ValueError(
'The number of tasks in the environment does '
'not match self._num_tasks. Are you sure that you passed '
'the correct number of tasks?')
ret = torch.mm(one_hots, log_alpha.unsqueeze(0).t()).squeeze()
return ret
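
To make the shapes concrete, here is a small self-contained sketch of the one-hot selection that _get_log_alpha performs, using made-up tensors (two tasks, a batch of three transitions); it mirrors the torch.mm line above but is not code from the library.

import torch

num_tasks = 2
# One learnable log-alpha per task (disentangled temperatures).
log_alpha = torch.tensor([0.1, 0.7])            # shape: (num_tasks,)
# The last num_tasks entries of each observation are the one-hot task id.
one_hots = torch.tensor([[1., 0.],              # transition from task 0
                         [0., 1.],              # transition from task 1
                         [1., 0.]])             # transition from task 0
# Same consistency check as in the method above.
assert log_alpha.shape[0] == one_hots.shape[1] == num_tasks
# (N, num_tasks) x (num_tasks, 1) -> (N, 1) -> squeeze -> (N,)
ret = torch.mm(one_hots, log_alpha.unsqueeze(0).t()).squeeze()
print(ret)  # tensor([0.1000, 0.7000, 0.1000])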

@@ -183,6 +201,7 @@ def _evaluate_policy(self, epoch):
obtain_evaluation_episodes(
self.policy,
self._eval_env,
self._max_episode_length_eval,
num_eps=self._num_evaluation_episodes))
eval_eps = EpisodeBatch.concatenate(*eval_eps)
last_return = log_multitask_performance(epoch, eval_eps,
1 change: 0 additions & 1 deletion src/garage/torch/algos/sac.py
@@ -48,7 +48,6 @@ class SAC(RLAlgorithm):
for off-policy evaluation. If None, defaults to
`max_episode_length`.
gradient_steps_per_itr (int): Number of optimization steps that should
max_episode_length(int): Max episode length of the environment.
gradient_steps_per_itr(int): Number of optimization steps that should
occur before the training step is over and a new batch of
transitions is collected by the sampler.
55 changes: 55 additions & 0 deletions tests/garage/torch/algos/test_mtsac.py
@@ -70,6 +70,61 @@ def test_mtsac_get_log_alpha(monkeypatch):
assert log_alpha.size() == torch.Size([mtsac._buffer_batch_size])


@pytest.mark.mujoco
def test_mtsac_get_log_alpha_incorrect_num_tasks(monkeypatch):
"""Check that if the num_tasks passed does not match the number of tasks

in the environment, then the algorithm should raise an exception.

MTSAC uses disentangled alphas, meaning that

"""
env_names = ['CartPole-v0', 'CartPole-v1']
task_envs = [GymEnv(name) for name in env_names]
env = MultiEnvWrapper(task_envs, sample_strategy=round_robin_strategy)
deterministic.set_seed(0)
policy = TanhGaussianMLPPolicy(
env_spec=env.spec,
hidden_sizes=[1, 1],
hidden_nonlinearity=torch.nn.ReLU,
output_nonlinearity=None,
min_std=np.exp(-20.),
max_std=np.exp(2.),
)

qf1 = ContinuousMLPQFunction(env_spec=env.spec,
hidden_sizes=[1, 1],
hidden_nonlinearity=F.relu)

qf2 = ContinuousMLPQFunction(env_spec=env.spec,
hidden_sizes=[1, 1],
hidden_nonlinearity=F.relu)
replay_buffer = PathBuffer(capacity_in_transitions=int(1e6), )

buffer_batch_size = 2
mtsac = MTSAC(policy=policy,
qf1=qf1,
qf2=qf2,
gradient_steps_per_itr=150,
max_episode_length=150,
eval_env=env,
env_spec=env.spec,
num_tasks=4,
steps_per_epoch=5,
replay_buffer=replay_buffer,
min_buffer_size=1e3,
target_update_tau=5e-3,
discount=0.99,
buffer_batch_size=buffer_batch_size)
monkeypatch.setattr(mtsac, '_log_alpha', torch.Tensor([1., 2.]))
error_string = ('The number of tasks in the environment does '
'not match self._num_tasks. Are you sure that you passed '
'the correct number of tasks?')
obs = torch.Tensor([env.reset()[0]] * buffer_batch_size)
with pytest.raises(ValueError, match=error_string):
mtsac._get_log_alpha(dict(observation=obs))
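
For context: the wrapped environment above has two tasks, so both the task one-hot id and the patched _log_alpha have length 2, while num_tasks=4 is deliberately wrong; this is exactly the inconsistency the new shape check in _get_log_alpha is meant to catch.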
[Review comment from Member]: no need to test a private method?

[Reply from Member, PR Author]: This is true, however, the tests are in place to verify the correctness of this implementation. I feel more comfortable having the tests than not. At the same time, it makes no sense to have this as a publicly exposed field because it has no use outside of the algorithm.



@pytest.mark.mujoco
def test_mtsac_inverted_double_pendulum():
"""Performance regression test of MTSAC on 2 InvDoublePendulum envs."""