Various MTSAC bug fixes
Fixes the examples to use the correct num_tasks

Fixes max_episode_length_eval so that it is actually used by the algorithm
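
The max_episode_length_eval fix is the subtler of the two: before this change MTSAC did not accept the argument at all, so evaluation episodes could not be capped separately from training. The diffs below add the parameter, forward it on to the SAC base class, and use it in _evaluate_policy. Both docstrings say it falls back to max_episode_length when left as None; the snippet below is only a sketch of that documented fallback (resolve_eval_length is a hypothetical name, not garage code).

def resolve_eval_length(max_episode_length, max_episode_length_eval=None):
    # Sketch of the documented None-fallback; the real logic lives in the
    # SAC base class, not in a helper with this name.
    if max_episode_length_eval is None:
        return max_episode_length
    return max_episode_length_eval

assert resolve_eval_length(150) == 150        # default: reuse the training cap
assert resolve_eval_length(150, 200) == 200   # an explicit eval cap wins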

Co-authored-by: Tianhong Dai <tianhongdai914@gmail.com>
2 people authored and mergify-bot committed Aug 28, 2020
1 parent a2fb966 commit c83d34e
Showing 6 changed files with 106 additions and 19 deletions.
36 changes: 21 additions & 15 deletions examples/torch/mtsac_metaworld_ml1_pick_place.py
@@ -84,24 +84,30 @@ def mtsac_metaworld_ml1_pick_place(ctxt=None, seed=1, _gpu=None):
     epochs = timesteps // batch_size
     epoch_cycles = epochs // num_evaluation_points
     epochs = epochs // epoch_cycles
-    mtsac = MTSAC(policy=policy,
-                  qf1=qf1,
-                  qf2=qf2,
-                  gradient_steps_per_itr=150,
-                  max_episode_length=150,
-                  eval_env=ml1_test_envs,
-                  env_spec=ml1_train_envs.spec,
-                  num_tasks=50,
-                  steps_per_epoch=epoch_cycles,
-                  replay_buffer=replay_buffer,
-                  min_buffer_size=1500,
-                  target_update_tau=5e-3,
-                  discount=0.99,
-                  buffer_batch_size=1280)
+    mtsac = MTSAC(
+        policy=policy,
+        qf1=qf1,
+        qf2=qf2,
+        gradient_steps_per_itr=150,
+        max_episode_length=150,
+        max_episode_length_eval=150,
+        eval_env=ml1_test_envs,
+        env_spec=ml1_train_envs.spec,
+        num_tasks=50,
+        steps_per_epoch=epoch_cycles,
+        replay_buffer=replay_buffer,
+        min_buffer_size=1500,
+        target_update_tau=5e-3,
+        discount=0.99,
+        buffer_batch_size=1280,
+    )
     if _gpu is not None:
         set_gpu_mode(True, _gpu)
     mtsac.to()
-    runner.setup(algo=mtsac, env=ml1_train_envs, sampler_cls=LocalSampler)
+    runner.setup(algo=mtsac,
+                 env=ml1_train_envs,
+                 sampler_cls=LocalSampler,
+                 n_workers=1)
     runner.train(n_epochs=epochs, batch_size=batch_size)


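The epoch arithmetic at the top of the hunk above is easy to misread, so here is a worked example. The values of timesteps, batch_size, and num_evaluation_points are defined above the shown context, so the numbers below are assumptions used only for illustration.

timesteps = 10_000_000          # assumed total environment-step budget
num_evaluation_points = 500     # assumed number of evaluation points
batch_size = 150 * 50           # assumed: episode length * number of tasks

epochs = timesteps // batch_size                # 1333
epoch_cycles = epochs // num_evaluation_points  # 2
epochs = epochs // epoch_cycles                 # 666

# With steps_per_epoch=epoch_cycles and runner.train(n_epochs=epochs,
# batch_size=batch_size), roughly the requested budget gets sampled:
assert epochs * epoch_cycles * batch_size == 9_990_000   # close to timesteps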
6 changes: 5 additions & 1 deletion examples/torch/mtsac_metaworld_mt10.py
@@ -83,6 +83,7 @@ def mtsac_metaworld_mt10(ctxt=None, seed=1, _gpu=None):
                   qf2=qf2,
                   gradient_steps_per_itr=150,
                   max_episode_length=150,
+                  max_episode_length_eval=150,
                   eval_env=mt10_test_envs,
                   env_spec=mt10_train_envs.spec,
                   num_tasks=10,
@@ -95,7 +96,10 @@ def mtsac_metaworld_mt10(ctxt=None, seed=1, _gpu=None):
     if _gpu is not None:
         set_gpu_mode(True, _gpu)
     mtsac.to()
-    runner.setup(algo=mtsac, env=mt10_train_envs, sampler_cls=LocalSampler)
+    runner.setup(algo=mtsac,
+                 env=mt10_train_envs,
+                 sampler_cls=LocalSampler,
+                 n_workers=1)
     runner.train(n_epochs=epochs, batch_size=batch_size)


8 changes: 6 additions & 2 deletions examples/torch/mtsac_metaworld_mt50.py
@@ -85,9 +85,10 @@ def mtsac_metaworld_mt50(ctxt=None, seed=1, use_gpu=False, _gpu=0):
                   qf2=qf2,
                   gradient_steps_per_itr=150,
                   max_episode_length=150,
+                  max_episode_length_eval=150,
                   eval_env=mt50_test_envs,
                   env_spec=mt50_train_envs.spec,
-                  num_tasks=10,
+                  num_tasks=50,
                   steps_per_epoch=epoch_cycles,
                   replay_buffer=replay_buffer,
                   min_buffer_size=7500,
@@ -96,7 +97,10 @@ def mtsac_metaworld_mt50(ctxt=None, seed=1, use_gpu=False, _gpu=0):
                   buffer_batch_size=6400)
     set_gpu_mode(use_gpu, _gpu)
     mtsac.to()
-    runner.setup(algo=mtsac, env=mt50_train_envs, sampler_cls=LocalSampler)
+    runner.setup(algo=mtsac,
+                 env=mt50_train_envs,
+                 sampler_cls=LocalSampler,
+                 n_workers=1)
     runner.train(n_epochs=epochs, batch_size=batch_size)


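The num_tasks correction in this file matters because MTSAC recovers the task id by slicing the last num_tasks entries off each observation, and the multi-task wrapper appends a one-hot of length 50 for MT50. A small illustration follows; the base observation size here is an assumption, not a value taken from Metaworld. The shape check added to mtsac.py below turns the resulting mismatch into an explicit ValueError instead of a silent error.

import numpy as np

num_tasks = 50
env_obs = np.zeros(39)           # assumed base observation; the size is illustrative
one_hot = np.eye(num_tasks)[7]   # e.g. the 8th of the 50 MT50 tasks
obs = np.concatenate([env_obs, one_hot])

assert obs[-50:].argmax() == 7   # the correct slice recovers the task id
assert obs[-10:].sum() == 0      # the old num_tasks=10 slice sees no task at all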
19 changes: 19 additions & 0 deletions src/garage/torch/algos/mtsac.py
@@ -33,6 +33,9 @@ class MTSAC(SAC):
             agent is being trained in.
         num_tasks (int): The number of tasks being learned.
         max_episode_length (int): The max episode length of the algorithm.
+        max_episode_length_eval (int or None): Maximum length of episodes used
+            for off-policy evaluation. If None, defaults to
+            `max_episode_length`.
         eval_env (Environment): The environment used for collecting evaluation
             episodes.
         gradient_steps_per_itr (int): Number of optimization steps that should
@@ -75,10 +78,12 @@ def __init__(
             qf2,
             replay_buffer,
             env_spec,
+            *,
             num_tasks,
             max_episode_length,
             eval_env,
             gradient_steps_per_itr,
+            max_episode_length_eval=None,
             fixed_alpha=None,
             target_entropy=None,
             initial_log_entropy=0.,
@@ -100,6 +105,7 @@ def __init__(
             replay_buffer=replay_buffer,
             env_spec=env_spec,
             max_episode_length=max_episode_length,
+            max_episode_length_eval=max_episode_length_eval,
             gradient_steps_per_itr=gradient_steps_per_itr,
             fixed_alpha=fixed_alpha,
             target_entropy=target_entropy,
@@ -153,13 +159,25 @@ def _get_log_alpha(self, samples_data):
                 terminal: :math:`(N, 1)`
                 next_observation: :math:`(N, O^*)`

+        Raises:
+            ValueError: If the number of tasks, num_tasks, passed to
+                this algorithm doesn't match the length of the task
+                one-hot id in the observation vector.
+
         Returns:
             torch.Tensor: log_alpha. shape is (1, self.buffer_batch_size)

         """
         obs = samples_data['observation']
         log_alpha = self._log_alpha
         one_hots = obs[:, -self._num_tasks:]
+        if (log_alpha.shape[0] != one_hots.shape[1]
+                or one_hots.shape[1] != self._num_tasks
+                or log_alpha.shape[0] != self._num_tasks):
+            raise ValueError(
+                'The number of tasks in the environment does '
+                'not match self._num_tasks. Are you sure that you passed '
+                'the correct number of tasks?')
         ret = torch.mm(one_hots, log_alpha.unsqueeze(0).t()).squeeze()
         return ret

@@ -183,6 +201,7 @@ def _evaluate_policy(self, epoch):
             obtain_evaluation_episodes(
                 self.policy,
                 self._eval_env,
+                self._max_episode_length_eval,
                 num_eps=self._num_evaluation_episodes))
         eval_eps = EpisodeBatch.concatenate(*eval_eps)
         last_return = log_multitask_performance(epoch, eval_eps,
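For readers unfamiliar with the disentangled-alpha bookkeeping that _get_log_alpha performs, here is a standalone sketch with toy numbers (not garage code) of how the one-hot task id selects the per-task temperature, and what the new shape check protects against.

import torch

num_tasks = 3
log_alpha = torch.tensor([0.1, 0.2, 0.3])                     # one log-alpha per task
one_hots = torch.eye(num_tasks)[torch.tensor([0, 2, 1, 0])]   # task ids of a batch of 4

# The same operation as in _get_log_alpha: (batch, tasks) @ (tasks, 1) -> (batch,)
selected = torch.mm(one_hots, log_alpha.unsqueeze(0).t()).squeeze()
print(selected)   # tensor([0.1000, 0.3000, 0.2000, 0.1000])

# If num_tasks disagreed with len(log_alpha) or with the width of the one-hot
# slice, the product above would fail or silently mix tasks; that is the
# situation the added ValueError now reports.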
1 change: 0 additions & 1 deletion src/garage/torch/algos/sac.py
@@ -48,7 +48,6 @@ class SAC(RLAlgorithm):
             for off-policy evaluation. If None, defaults to
             `max_episode_length`.
         gradient_steps_per_itr (int): Number of optimization steps that should
-        max_episode_length(int): Max episode length of the environment.
         gradient_steps_per_itr(int): Number of optimization steps that should
             occur before the training step is over and a new batch of
             transitions is collected by the sampler.
55 changes: 55 additions & 0 deletions tests/garage/torch/algos/test_mtsac.py
@@ -70,6 +70,61 @@ def test_mtsac_get_log_alpha(monkeypatch):
     assert log_alpha.size() == torch.Size([mtsac._buffer_batch_size])


+@pytest.mark.mujoco
+def test_mtsac_get_log_alpha_incorrect_num_tasks(monkeypatch):
+    """Check that _get_log_alpha raises when num_tasks is wrong.
+
+    If the num_tasks passed to the algorithm does not match the number of
+    tasks in the environment, _get_log_alpha should raise a ValueError.
+    MTSAC uses disentangled alphas, meaning one temperature per task,
+    selected by the one-hot task id appended to each observation.
+    """
+    env_names = ['CartPole-v0', 'CartPole-v1']
+    task_envs = [GymEnv(name) for name in env_names]
+    env = MultiEnvWrapper(task_envs, sample_strategy=round_robin_strategy)
+    deterministic.set_seed(0)
+    policy = TanhGaussianMLPPolicy(
+        env_spec=env.spec,
+        hidden_sizes=[1, 1],
+        hidden_nonlinearity=torch.nn.ReLU,
+        output_nonlinearity=None,
+        min_std=np.exp(-20.),
+        max_std=np.exp(2.),
+    )
+
+    qf1 = ContinuousMLPQFunction(env_spec=env.spec,
+                                 hidden_sizes=[1, 1],
+                                 hidden_nonlinearity=F.relu)
+
+    qf2 = ContinuousMLPQFunction(env_spec=env.spec,
+                                 hidden_sizes=[1, 1],
+                                 hidden_nonlinearity=F.relu)
+    replay_buffer = PathBuffer(capacity_in_transitions=int(1e6))
+
+    buffer_batch_size = 2
+    mtsac = MTSAC(policy=policy,
+                  qf1=qf1,
+                  qf2=qf2,
+                  gradient_steps_per_itr=150,
+                  max_episode_length=150,
+                  eval_env=env,
+                  env_spec=env.spec,
+                  num_tasks=4,
+                  steps_per_epoch=5,
+                  replay_buffer=replay_buffer,
+                  min_buffer_size=1e3,
+                  target_update_tau=5e-3,
+                  discount=0.99,
+                  buffer_batch_size=buffer_batch_size)
+    monkeypatch.setattr(mtsac, '_log_alpha', torch.Tensor([1., 2.]))
+    error_string = ('The number of tasks in the environment does '
+                    'not match self._num_tasks. Are you sure that you passed '
+                    'the correct number of tasks?')
+    obs = torch.Tensor([env.reset()[0]] * buffer_batch_size)
+    with pytest.raises(ValueError, match=error_string):
+        mtsac._get_log_alpha(dict(observation=obs))
+
+
 @pytest.mark.mujoco
 def test_mtsac_inverted_double_pendulum():
     """Performance regression test of MTSAC on 2 InvDoublePendulum envs."""
