diff --git a/examples/torch/mtsac_metaworld_ml1_pick_place.py b/examples/torch/mtsac_metaworld_ml1_pick_place.py
index bf22008595..f8952ee37c 100755
--- a/examples/torch/mtsac_metaworld_ml1_pick_place.py
+++ b/examples/torch/mtsac_metaworld_ml1_pick_place.py
@@ -84,24 +84,30 @@ def mtsac_metaworld_ml1_pick_place(ctxt=None, seed=1, _gpu=None):
     epochs = timesteps // batch_size
     epoch_cycles = epochs // num_evaluation_points
     epochs = epochs // epoch_cycles
-    mtsac = MTSAC(policy=policy,
-                  qf1=qf1,
-                  qf2=qf2,
-                  gradient_steps_per_itr=150,
-                  max_episode_length=150,
-                  eval_env=ml1_test_envs,
-                  env_spec=ml1_train_envs.spec,
-                  num_tasks=50,
-                  steps_per_epoch=epoch_cycles,
-                  replay_buffer=replay_buffer,
-                  min_buffer_size=1500,
-                  target_update_tau=5e-3,
-                  discount=0.99,
-                  buffer_batch_size=1280)
+    mtsac = MTSAC(
+        policy=policy,
+        qf1=qf1,
+        qf2=qf2,
+        gradient_steps_per_itr=150,
+        max_episode_length=150,
+        max_episode_length_eval=150,
+        eval_env=ml1_test_envs,
+        env_spec=ml1_train_envs.spec,
+        num_tasks=50,
+        steps_per_epoch=epoch_cycles,
+        replay_buffer=replay_buffer,
+        min_buffer_size=1500,
+        target_update_tau=5e-3,
+        discount=0.99,
+        buffer_batch_size=1280,
+    )
     if _gpu is not None:
         set_gpu_mode(True, _gpu)
     mtsac.to()
-    runner.setup(algo=mtsac, env=ml1_train_envs, sampler_cls=LocalSampler)
+    runner.setup(algo=mtsac,
+                 env=ml1_train_envs,
+                 sampler_cls=LocalSampler,
+                 n_workers=1)
     runner.train(n_epochs=epochs, batch_size=batch_size)


diff --git a/examples/torch/mtsac_metaworld_mt10.py b/examples/torch/mtsac_metaworld_mt10.py
index 73d0fa4695..bb650410df 100755
--- a/examples/torch/mtsac_metaworld_mt10.py
+++ b/examples/torch/mtsac_metaworld_mt10.py
@@ -83,6 +83,7 @@ def mtsac_metaworld_mt10(ctxt=None, seed=1, _gpu=None):
                   qf2=qf2,
                   gradient_steps_per_itr=150,
                   max_episode_length=150,
+                  max_episode_length_eval=150,
                   eval_env=mt10_test_envs,
                   env_spec=mt10_train_envs.spec,
                   num_tasks=10,
@@ -95,7 +96,10 @@ def mtsac_metaworld_mt10(ctxt=None, seed=1, _gpu=None):
     if _gpu is not None:
         set_gpu_mode(True, _gpu)
     mtsac.to()
-    runner.setup(algo=mtsac, env=mt10_train_envs, sampler_cls=LocalSampler)
+    runner.setup(algo=mtsac,
+                 env=mt10_train_envs,
+                 sampler_cls=LocalSampler,
+                 n_workers=1)
     runner.train(n_epochs=epochs, batch_size=batch_size)


diff --git a/examples/torch/mtsac_metaworld_mt50.py b/examples/torch/mtsac_metaworld_mt50.py
index d4e02eff9b..c0064d7166 100755
--- a/examples/torch/mtsac_metaworld_mt50.py
+++ b/examples/torch/mtsac_metaworld_mt50.py
@@ -85,9 +85,10 @@ def mtsac_metaworld_mt50(ctxt=None, seed=1, use_gpu=False, _gpu=0):
                   qf2=qf2,
                   gradient_steps_per_itr=150,
                   max_episode_length=150,
+                  max_episode_length_eval=150,
                   eval_env=mt50_test_envs,
                   env_spec=mt50_train_envs.spec,
-                  num_tasks=10,
+                  num_tasks=50,
                   steps_per_epoch=epoch_cycles,
                   replay_buffer=replay_buffer,
                   min_buffer_size=7500,
@@ -96,7 +97,10 @@ def mtsac_metaworld_mt50(ctxt=None, seed=1, use_gpu=False, _gpu=0):
                   buffer_batch_size=6400)
     set_gpu_mode(use_gpu, _gpu)
     mtsac.to()
-    runner.setup(algo=mtsac, env=mt50_train_envs, sampler_cls=LocalSampler)
+    runner.setup(algo=mtsac,
+                 env=mt50_train_envs,
+                 sampler_cls=LocalSampler,
+                 n_workers=1)
     runner.train(n_epochs=epochs, batch_size=batch_size)


diff --git a/src/garage/torch/algos/mtsac.py b/src/garage/torch/algos/mtsac.py
index 287f672263..70eed10a1c 100644
--- a/src/garage/torch/algos/mtsac.py
+++ b/src/garage/torch/algos/mtsac.py
@@ -33,6 +33,9 @@ class MTSAC(SAC):
             agent is being trained in.
         num_tasks (int): The number of tasks being learned.
        max_episode_length (int): The max episode length of the algorithm.
+        max_episode_length_eval (int or None): Maximum length of episodes used
+            for off-policy evaluation. If None, defaults to
+            `max_episode_length`.
         eval_env (Environment): The environment used for collecting evaluation
             episodes.
         gradient_steps_per_itr (int): Number of optimization steps that should
@@ -75,10 +78,12 @@ def __init__(
             qf2,
             replay_buffer,
             env_spec,
+            *,
             num_tasks,
             max_episode_length,
             eval_env,
             gradient_steps_per_itr,
+            max_episode_length_eval=None,
             fixed_alpha=None,
             target_entropy=None,
             initial_log_entropy=0.,
@@ -100,6 +105,7 @@ def __init__(
                          replay_buffer=replay_buffer,
                          env_spec=env_spec,
                          max_episode_length=max_episode_length,
+                         max_episode_length_eval=max_episode_length_eval,
                          gradient_steps_per_itr=gradient_steps_per_itr,
                          fixed_alpha=fixed_alpha,
                          target_entropy=target_entropy,
@@ -153,6 +159,11 @@ def _get_log_alpha(self, samples_data):
                 terminal: :math:`(N, 1)`
                 next_observation: :math:`(N, O^*)`

+        Raises:
+            ValueError: If the number of tasks (num_tasks) passed to this
+                algorithm does not match the length of the task one-hot id
+                in the observation vector.
+
         Returns:
             torch.Tensor: log_alpha. shape is (1, self.buffer_batch_size)

@@ -160,6 +171,13 @@ def _get_log_alpha(self, samples_data):
         obs = samples_data['observation']
         log_alpha = self._log_alpha
         one_hots = obs[:, -self._num_tasks:]
+        if (log_alpha.shape[0] != one_hots.shape[1]
+                or one_hots.shape[1] != self._num_tasks
+                or log_alpha.shape[0] != self._num_tasks):
+            raise ValueError(
+                'The number of tasks in the environment does '
+                'not match self._num_tasks. Are you sure that you passed '
+                'the correct number of tasks?')
         ret = torch.mm(one_hots, log_alpha.unsqueeze(0).t()).squeeze()
         return ret

@@ -183,6 +201,7 @@ def _evaluate_policy(self, epoch):
             obtain_evaluation_episodes(
                 self.policy,
                 self._eval_env,
+                self._max_episode_length_eval,
                 num_eps=self._num_evaluation_episodes))
         eval_eps = EpisodeBatch.concatenate(*eval_eps)
         last_return = log_multitask_performance(epoch, eval_eps,
diff --git a/src/garage/torch/algos/sac.py b/src/garage/torch/algos/sac.py
index 34c9cc5063..98cf302474 100644
--- a/src/garage/torch/algos/sac.py
+++ b/src/garage/torch/algos/sac.py
@@ -48,7 +48,6 @@ class SAC(RLAlgorithm):
             for off-policy evaluation. If None, defaults to
             `max_episode_length`.
         gradient_steps_per_itr (int): Number of optimization steps that should
-        max_episode_length(int): Max episode length of the environment.
         gradient_steps_per_itr(int): Number of optimization steps that should
             occur before the training step is over and a new batch of
             transitions is collected by the sampler.
diff --git a/tests/garage/torch/algos/test_mtsac.py b/tests/garage/torch/algos/test_mtsac.py
index 64dcfea5de..6cda1dc135 100644
--- a/tests/garage/torch/algos/test_mtsac.py
+++ b/tests/garage/torch/algos/test_mtsac.py
@@ -70,6 +70,61 @@ def test_mtsac_get_log_alpha(monkeypatch):
     assert log_alpha.size() == torch.Size([mtsac._buffer_batch_size])


+@pytest.mark.mujoco
+def test_mtsac_get_log_alpha_incorrect_num_tasks(monkeypatch):
+    """Check that _get_log_alpha validates num_tasks.
+
+    If the num_tasks passed to the algorithm does not match the number of
+    tasks in the environment, _get_log_alpha should raise a ValueError.
+    MTSAC uses disentangled alphas, meaning that it keeps one log_alpha
+    entry per task, selected with the task one-hot in the observation.
+    """
+    env_names = ['CartPole-v0', 'CartPole-v1']
+    task_envs = [GymEnv(name) for name in env_names]
+    env = MultiEnvWrapper(task_envs, sample_strategy=round_robin_strategy)
+    deterministic.set_seed(0)
+    policy = TanhGaussianMLPPolicy(
+        env_spec=env.spec,
+        hidden_sizes=[1, 1],
+        hidden_nonlinearity=torch.nn.ReLU,
+        output_nonlinearity=None,
+        min_std=np.exp(-20.),
+        max_std=np.exp(2.),
+    )
+
+    qf1 = ContinuousMLPQFunction(env_spec=env.spec,
+                                 hidden_sizes=[1, 1],
+                                 hidden_nonlinearity=F.relu)
+
+    qf2 = ContinuousMLPQFunction(env_spec=env.spec,
+                                 hidden_sizes=[1, 1],
+                                 hidden_nonlinearity=F.relu)
+    replay_buffer = PathBuffer(capacity_in_transitions=int(1e6), )
+
+    buffer_batch_size = 2
+    mtsac = MTSAC(policy=policy,
+                  qf1=qf1,
+                  qf2=qf2,
+                  gradient_steps_per_itr=150,
+                  max_episode_length=150,
+                  eval_env=env,
+                  env_spec=env.spec,
+                  num_tasks=4,
+                  steps_per_epoch=5,
+                  replay_buffer=replay_buffer,
+                  min_buffer_size=1e3,
+                  target_update_tau=5e-3,
+                  discount=0.99,
+                  buffer_batch_size=buffer_batch_size)
+    monkeypatch.setattr(mtsac, '_log_alpha', torch.Tensor([1., 2.]))
+    error_string = ('The number of tasks in the environment does '
+                    'not match self._num_tasks. Are you sure that you passed '
+                    'the correct number of tasks?')
+    obs = torch.Tensor([env.reset()[0]] * buffer_batch_size)
+    with pytest.raises(ValueError, match=error_string):
+        mtsac._get_log_alpha(dict(observation=obs))
+
+
 @pytest.mark.mujoco
 def test_mtsac_inverted_double_pendulum():
     """Performance regression test of MTSAC on 2 InvDoublePendulum envs."""
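
For reference, a minimal standalone sketch (not part of the patch; the sizes, values, and variable names below are made up for illustration) of the per-task alpha lookup that the new shape check in _get_log_alpha guards, assuming only torch:

import torch

# Hypothetical number of tasks, for illustration only.
num_tasks = 3

# One log-temperature per task, as MTSAC stores them in _log_alpha.
log_alpha = torch.tensor([0.1, 0.2, 0.3])

# Task one-hots; in MTSAC these come from the last `num_tasks` entries of
# each observation. Here, four samples drawn from tasks 0, 2, 1 and 0.
one_hots = torch.eye(num_tasks)[torch.tensor([0, 2, 1, 0])]

# (N, num_tasks) x (num_tasks, 1) -> (N,): the log-alpha of each sample's task.
per_sample_log_alpha = torch.mm(one_hots, log_alpha.unsqueeze(0).t()).squeeze()
print(per_sample_log_alpha)  # tensor([0.1000, 0.3000, 0.2000, 0.1000])

# If _log_alpha and the one-hot width disagree (as in the new test), torch.mm
# either fails with a shape error or silently reads the wrong observation
# slice; the added ValueError reports that misconfiguration up front.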