[RLlib] Checkpoint metrics loading with Tune is broken in 2.47.0 #53877

Open

@maybe-otski

Description

What happened + What you expected to happen

When running PPO via tune.Tuner(...).fit() and loading an existing checkpoint, the reported metrics are incorrect: num_env_steps_sampled_lifetime, for example, is not initialized from the checkpoint. As a side effect, the learning-rate schedule restarts from the beginning in the next Tune run instead of continuing from the checkpointed value.

This appears to be a regression in the Ray 2.47.0 release, as it works with 2.46.0.
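
For context, the restore path exercised by the reproduction script below boils down to calling Algorithm.restore_from_path() on a freshly built algorithm. A minimal sketch (checkpoint_dir is a hypothetical placeholder for a checkpoint written by an earlier run):

from ray.rllib.algorithms.ppo import PPOConfig

config = PPOConfig().environment("CartPole-v1")
algo = config.build()
# Hypothetical placeholder: point this at a checkpoint from an earlier run.
checkpoint_dir = "/path/to/previous/checkpoint"
algo.restore_from_path(checkpoint_dir)
result = algo.train()
# Expected: continues from the checkpointed counter. With the regression
# (observed via the Tune repro below), the pre-checkpoint steps are missing.
print(result["num_env_steps_sampled_lifetime"])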

Expected result

Using the reproduction script, the output should be:

Tune run 1 (initial):
Num env steps sampled lifetime:  4000
Learning rate 0.000255

Tune run 2 (load from checkpoint):
Num env steps sampled lifetime:  8000
Learning rate 1e-05

This is the output produced with Ray 2.46.0.
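
For reference, the learning-rate values above follow from interpolating the schedule [[0, 0.0005], [8000, 0.00001]] on num_env_steps_sampled_lifetime. A back-of-the-envelope sketch, assuming linear interpolation between the schedule's anchor points:

# lr schedule anchors: (0, 0.0005) and (8000, 0.00001), linear in between.
def lr_at(step, t0=0, lr0=0.0005, t1=8000, lr1=0.00001):
    frac = min(max((step - t0) / (t1 - t0), 0.0), 1.0)
    return lr0 + frac * (lr1 - lr0)

print(lr_at(4000))  # ~0.000255 -> after run 1 (4000 lifetime steps)
print(lr_at(8000))  # 1e-05     -> after run 2, if lifetime steps are restored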

Actual result

Tune run 1 (initial):
Num env steps sampled lifetime:  4000
Learning rate 0.000255

Tune run 2 (load from checkpoint):
Num env steps sampled lifetime:  4000.0
Learning rate 0.000255
Traceback (most recent call last):
  File "/home/otski/external_src/ray/rllib/examples/debugging/temp.py", line 67, in <module>
    assert int(lifetime_steps_after_second) == 2*int(lifetime_steps_after_first), \
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: Expected lifetime steps to be incremented, now: 4000 vs. 4000

Versions / Dependencies

Python 3.12.8
ray 2.47.0
Pop!_OS 22.04 LTS x86_64

Reproduction script

from pathlib import Path

import ray
from ray import tune
from ray.rllib.algorithms.ppo import PPOConfig


def _run_with_tune(algo_config):
    # Run a single training iteration with Tune and checkpoint at the end.
    tuner = tune.Tuner(
        algo_config.algo_class,
        param_space=algo_config.to_dict(),
        run_config=tune.RunConfig(
            storage_path=Path.cwd() / "ray_results",
            name="example",
            stop={"iterations_since_restore": 1},
            checkpoint_config=tune.CheckpointConfig(
                checkpoint_at_end=True,
            ),
        ),
    )
    return tuner.fit()

if __name__ == "__main__":
    ray.init(ignore_reinit_error=True)

    config = (
        PPOConfig()
        .api_stack(
            enable_rl_module_and_learner=True,
            enable_env_runner_and_connector_v2=True)
        .framework("torch")
        .environment("CartPole-v1")
        .env_runners(
            num_env_runners=1,
        )
        .training(
            # Linear lr schedule: 0.0005 at step 0, decaying to 1e-05 by step 8000.
            lr=[
                [0, 0.0005],
                [8000, 0.00001],
            ]
        )
    )
    tune_results = _run_with_tune(config)

    checkpoint_dir = tune_results[0].checkpoint.path
    lifetime_steps_after_first = tune_results[0].metrics['num_env_steps_sampled_lifetime']
    learning_rate_after_first = tune_results[0].metrics['learners']['default_policy']['default_optimizer_learning_rate']

    # Load the first run checkpoint for the second run
    config.callbacks(
        on_algorithm_init=(
            lambda algorithm, _dir=str(checkpoint_dir), **kw: algorithm.restore_from_path(_dir)
        )
    )
    tune_results = _run_with_tune(config)
    lifetime_steps_after_second = tune_results[0].metrics['num_env_steps_sampled_lifetime']
    learning_rate_after_second = tune_results[0].metrics['learners']['default_policy']['default_optimizer_learning_rate']

    print('Tune run 1 (initial):')
    print('Num env steps sampled lifetime: ', lifetime_steps_after_first)
    print('Learning rate', learning_rate_after_first)
    print()
    print('Tune run 2 (load from checkpoint):')
    print('Num env steps sampled lifetime: ', lifetime_steps_after_second)
    print('Learning rate', learning_rate_after_second)

    assert int(lifetime_steps_after_second) == 2*int(lifetime_steps_after_first), \
        f"Expected lifetime steps to be incremented, now: {int(lifetime_steps_after_first)} vs. {int(lifetime_steps_after_second)}"
    assert abs(learning_rate_after_first - learning_rate_after_second) > 1e-7, \
        f"Expected learning rates to differ, now: {learning_rate_after_first} vs. {learning_rate_after_second}"

Issue Severity

None

Labels

bug, regression, rllib, stability, triage