[RLlib] Fixed bug in restoring a GPU-trained algorithm #35024

Merged

29 changes: 29 additions & 0 deletions rllib/BUILD
@@ -2459,6 +2459,16 @@ py_test(
args = ["TestCheckpointRestorePG"]
)


py_test(
name = "tests/test_checkpoint_restore_pg_gpu",
main = "tests/test_algorithm_checkpoint_restore.py",
tags = ["team:rllib", "tests_dir", "gpu"],
size = "large",
srcs = ["tests/test_algorithm_checkpoint_restore.py"],
args = ["TestCheckpointRestorePG"]
)

py_test(
name = "tests/test_checkpoint_restore_off_policy",
main = "tests/test_algorithm_checkpoint_restore.py",
@@ -2468,6 +2478,16 @@ py_test(
args = ["TestCheckpointRestoreOffPolicy"]
)


py_test(
name = "tests/test_checkpoint_restore_off_policy_gpu",
main = "tests/test_algorithm_checkpoint_restore.py",
tags = ["team:rllib", "tests_dir", "gpu"],
size = "large",
srcs = ["tests/test_algorithm_checkpoint_restore.py"],
args = ["TestCheckpointRestoreOffPolicy"]
)

py_test(
name = "tests/test_checkpoint_restore_evolution_algos",
main = "tests/test_algorithm_checkpoint_restore.py",
@@ -2477,6 +2497,15 @@ py_test(
args = ["TestCheckpointRestoreEvolutionAlgos"]
)

py_test(
name = "tests/test_checkpoint_restore_evolution_algos_gpu",
main = "tests/test_algorithm_checkpoint_restore.py",
tags = ["team:rllib", "tests_dir", "gpu"],
size = "medium",
srcs = ["tests/test_algorithm_checkpoint_restore.py"],
args = ["TestCheckpointRestoreEvolutionAlgos"]
)

py_test(
name = "policy/tests/test_policy_checkpoint_restore",
main = "policy/tests/test_policy_checkpoint_restore.py",
8 changes: 7 additions & 1 deletion rllib/policy/torch_policy.py
@@ -775,7 +775,13 @@ def set_state(self, state: PolicyState) -> None:
if optimizer_vars:
assert len(optimizer_vars) == len(self._optimizers)
for o, s in zip(self._optimizers, optimizer_vars):
optim_state_dict = convert_to_torch_tensor(s, device=self.device)
# Torch optimizer param_groups include things like betas, lr, etc. These
# parameters should be left as scalars and not converted to tensors;
# otherwise, torch.optim.step() will complain.
optim_state_dict = {"param_groups": s["param_groups"]}
optim_state_dict["state"] = convert_to_torch_tensor(
s["state"], device=self.device
)
o.load_state_dict(optim_state_dict)
# Set exploration's state.
if hasattr(self, "exploration") and "_exploration_state" in state:
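To illustrate the change in isolation, here is a minimal, standalone sketch of the same restore pattern (not RLlib code; the helper name `restore_optimizer_state` and the exact shape of `saved_state` are assumptions). Only the per-parameter `state` tensors are moved to the target device; the `param_groups` entries such as `lr` and `betas` stay as plain Python scalars, which `torch.optim.Optimizer.load_state_dict()` and `step()` expect:

```python
import torch


def restore_optimizer_state(optimizer, saved_state, device):
    """Hypothetical helper mirroring the fix: move only 'state' to `device`."""
    # Keep "param_groups" untouched: lr, betas, eps, etc. must stay scalars.
    state_dict = {"param_groups": saved_state["param_groups"]}
    # Move only the per-parameter "state" tensors (exp_avg, exp_avg_sq, ...).
    # RLlib's convert_to_torch_tensor additionally handles numpy arrays here.
    state_dict["state"] = {
        param_id: {
            name: value.to(device) if torch.is_tensor(value) else value
            for name, value in param_state.items()
        }
        for param_id, param_state in saved_state["state"].items()
    }
    optimizer.load_state_dict(state_dict)


# Usage sketch: populate some optimizer state, then restore it onto a device.
model = torch.nn.Linear(4, 2)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
model(torch.randn(8, 4)).sum().backward()
opt.step()
restore_optimizer_state(opt, opt.state_dict(), torch.device("cpu"))
```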
8 changes: 7 additions & 1 deletion rllib/policy/torch_policy_v2.py
@@ -993,7 +993,13 @@ def set_state(self, state: PolicyState) -> None:
if optimizer_vars:
assert len(optimizer_vars) == len(self._optimizers)
for o, s in zip(self._optimizers, optimizer_vars):
optim_state_dict = convert_to_torch_tensor(s, device=self.device)
# Torch optimizer param_groups include things like betas, lr, etc. These
# parameters should be left as scalars and not converted to tensors;
# otherwise, torch.optim.step() will complain.
optim_state_dict = {"param_groups": s["param_groups"]}
optim_state_dict["state"] = convert_to_torch_tensor(
s["state"], device=self.device
)
o.load_state_dict(optim_state_dict)
# Set exploration's state.
if hasattr(self, "exploration") and "_exploration_state" in state:
40 changes: 33 additions & 7 deletions rllib/tests/test_algorithm_checkpoint_restore.py
@@ -16,6 +16,7 @@
from ray.rllib.algorithms.ars import ARSConfig
from ray.rllib.algorithms.a3c import A3CConfig
from ray.tune.registry import get_trainable_cls
import os


def get_mean_action(alg, obs):
@@ -32,7 +33,12 @@ def get_mean_action(alg, obs):
# explore=None if we compare the mean of the distribution of actions for the
# same observation to be the same.
algorithms_and_configs = {
"A3C": (A3CConfig().exploration(explore=False).rollouts(num_rollout_workers=1)),
"A3C": (
A3CConfig()
.exploration(explore=False)
.rollouts(num_rollout_workers=1)
.resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
),
"APEX_DDPG": (
ApexDDPGConfig()
.exploration(explore=False)
@@ -42,51 +48,65 @@ def get_mean_action(alg, obs):
optimizer={"num_replay_buffer_shards": 1},
num_steps_sampled_before_learning_starts=0,
)
.resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
),
"ARS": (
ARSConfig()
.exploration(explore=False)
.rollouts(num_rollout_workers=2, observation_filter="MeanStdFilter")
.training(num_rollouts=10, noise_size=2500000)
.resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
),
"DDPG": (
DDPGConfig()
.exploration(explore=False)
.reporting(min_sample_timesteps_per_iteration=100)
.training(num_steps_sampled_before_learning_starts=0)
.resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
),
"DQN": (
DQNConfig()
.exploration(explore=False)
.training(num_steps_sampled_before_learning_starts=0)
.resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
),
"ES": (
ESConfig()
.exploration(explore=False)
.training(episodes_per_batch=10, train_batch_size=100, noise_size=2500000)
.rollouts(observation_filter="MeanStdFilter", num_rollout_workers=2)
.resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
),
"PPO": (
# See the comment before the `algorithms_and_configs` dict.
# explore is set to None for PPO in favor of RLModule API support.
PPOConfig()
.training(num_sgd_iter=5, train_batch_size=1000)
.rollouts(num_rollout_workers=2)
.resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
),
"SimpleQ": (
SimpleQConfig()
.exploration(explore=False)
.training(num_steps_sampled_before_learning_starts=0)
.resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
),
"SAC": (
SACConfig()
.exploration(explore=False)
.training(num_steps_sampled_before_learning_starts=0)
.resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
),
}


def ckpt_restore_test(algo_name, tf2=False, object_store=False, replay_buffer=False):
def ckpt_restore_test(
algo_name,
tf2=False,
object_store=False,
replay_buffer=False,
run_restored_algorithm=True,
):
config = algorithms_and_configs[algo_name].to_dict()
# If required, store replay buffer data in checkpoints as well.
if replay_buffer:
@@ -172,22 +192,28 @@ def ckpt_restore_test(algo_name, tf2=False, object_store=False, replay_buffer=Fa
raise AssertionError(
"algo={} [a1={} a2={}]".format(algo_name, a1, a2)
)
# Stop both algos.
# Stop algo 1.
alg1.stop()

if run_restored_algorithm:
# Check that algo 2 can still run.
print("Starting second run on Algo 2...")
alg2.train()
alg2.stop()


class TestCheckpointRestorePG(unittest.TestCase):
@classmethod
def setUpClass(cls):
ray.init(num_cpus=5)
ray.init()

@classmethod
def tearDownClass(cls):
ray.shutdown()

def test_a3c_checkpoint_restore(self):
ckpt_restore_test("A3C")
# TODO(Kourosh) A3C cannot run a restored algorithm for some reason.
ckpt_restore_test("A3C", run_restored_algorithm=False)

def test_ppo_checkpoint_restore(self):
ckpt_restore_test("PPO", object_store=True)
@@ -196,7 +222,7 @@ def test_ppo_checkpoint_restore(self):
class TestCheckpointRestoreOffPolicy(unittest.TestCase):
@classmethod
def setUpClass(cls):
ray.init(num_cpus=5)
ray.init()

@classmethod
def tearDownClass(cls):
@@ -221,7 +247,7 @@ def test_simpleq_checkpoint_restore(self):
class TestCheckpointRestoreEvolutionAlgos(unittest.TestCase):
@classmethod
def setUpClass(cls):
ray.init(num_cpus=5)
ray.init()

@classmethod
def tearDownClass(cls):
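As a side note on the config changes above: the repeated `.resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))` line is what lets the same test file back both the CPU targets and the new `gpu`-tagged BUILD targets. A minimal sketch of the pattern, with the environment (CartPole-v1) chosen only for illustration and assuming CI exports `RLLIB_NUM_GPUS=1` on GPU runners:

```python
import os

from ray.rllib.algorithms.ppo import PPOConfig

# Request GPUs only when the environment provides them; on a CPU-only machine
# RLLIB_NUM_GPUS is unset and the config falls back to 0 GPUs.
config = (
    PPOConfig()
    .environment("CartPole-v1")
    .rollouts(num_rollout_workers=2)
    .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
)
algo = config.build()  # same code path on CPU-only CI and on GPU runners
```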