Fails restoring weights #41508

Open
Finebouche opened this issue Nov 29, 2023 · 6 comments
Labels
bug (Something that is supposed to be working; but isn't) · P2 (Important issue, but not time-critical) · rllib (RLlib related issues) · tune (Tune-related issues)

Comments

@Finebouche

Finebouche commented Nov 29, 2023

What happened + What you expected to happen

The code of examples/restore_1_of_n_agents_from_checkpoint.py does not seem to be working (at least in my case).

The weights are not restored but re-initialized. The way I see this is that, instead of getting the same policy reward means (in Wandb) as before, I get re-initialized values.

Maybe the example is not up to date, or maybe I am doing something wrong here. I am using tune.Tuner().fit() and not tune.train() as in the example, but I am not sure why this would fail...

Versions / Dependencies

Python 3.10
Ray 2.8

Reproduction script

import os

import ray
from ray import train, tune
from ray.air.integrations.wandb import WandbLoggerCallback
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.policy.policy import Policy, PolicySpec
from ray.tune.registry import get_trainable_cls

config = (
    get_trainable_cls("PPO").get_default_config()
    ...
    ...
    .multi_agent(
        policies={
            "prey": PolicySpec(
                policy_class=None,  # infer automatically from Algorithm
                observation_space=env.observation_space[0],  # if None, infer automatically from env
                action_space=env.action_space[0],  # if None, infer automatically from env
                config={"gamma": 0.85},  # use main config plus <- this override here
            ),
            "predator": PolicySpec(
                policy_class=None,
                observation_space=env.observation_space[0],
                action_space=env.action_space[0],
                config={"gamma": 0.85},
            ),
        },
        policy_mapping_fn=lambda agent_id, *args, **kwargs: (
            "prey" if env.agents[agent_id].agent_type == 0 else "predator"
        ),
        policies_to_train=["prey", "predator"],
    )
)

path_to_checkpoint = "/blablabla/ray_results/PPO_2023-11-29_02-51-09/PPO_CustomEnvironment_c4c87_00000_0_2023-11-29_02-51-09/checkpoint_000008"

def restore_weights(path, policy_type):
    # Load a single policy from the "policies/<policy_id>" sub-directory of the
    # algorithm checkpoint and return its weights.
    checkpoint_path = os.path.join(path, f"policies/{policy_type}")
    restored_policy = Policy.from_checkpoint(checkpoint_path)
    return restored_policy.get_weights()

class RestoreWeightsCallback(DefaultCallbacks):
    def on_algorithm_init(self, *, algorithm: "Algorithm", **kwargs) -> None:
        # Right after the Algorithm is built, overwrite the freshly initialized
        # policy weights with the ones loaded from the checkpoint.
        algorithm.set_weights({"predator": restore_weights(path_to_checkpoint, "predator")})
        algorithm.set_weights({"prey": restore_weights(path_to_checkpoint, "prey")})

config.callbacks(RestoreWeightsCallback)

ray.init()

# Define experiment    
tuner = tune.Tuner(
    "PPO",                                  
    param_space=config,                         
    run_config=train.RunConfig(         
        stop={                                    
            "training_iteration": 1,
            "timesteps_total": 20000,
        },
        verbose=3,
        callbacks=[WandbLoggerCallback(       
            project="ppo_marl", 
            group="PPO",
            api_key="blabla",
            log_config=True,
        )],
        checkpoint_config=train.CheckpointConfig(        
            checkpoint_at_end=True,
            checkpoint_frequency=1
        ),
    ),
)

# Run the experiment 
results = tuner.fit()

ray.shutdown()

Issue Severity

High: It blocks me from completing my task.

@Finebouche added the bug (Something that is supposed to be working; but isn't) and triage (Needs triage (eg: priority, bug/not-bug, and owning component)) labels Nov 29, 2023
@Finebouche changed the title from "[<Ray component: Core|RLlib|etc...>]" to "'RestoreWeightsCallback' example script" Nov 29, 2023
@Finebouche changed the title from "'RestoreWeightsCallback' example script" to "'RestoreWeightsCallback' example script doesn't seem to work anymore ?" Nov 29, 2023
@Finebouche changed the title from "'RestoreWeightsCallback' example script doesn't seem to work anymore ?" to "Fails restoring weights" Nov 29, 2023
@Finebouche

I should add that I checked that the checkpoint was correctly saved. If I use


path_to_checkpoint = "/blablabla/ray_results/PPO_2023-11-29_02-51-09/PPO_CustomEnvironment_c4c87_00000_0_2023-11-29_02-51-09/checkpoint_000008"

from ray.rllib.algorithms.algorithm import Algorithm

algo = Algorithm.from_checkpoint(path_to_checkpoint)

and then use algo.compute_single_action(), run the environment for several steps, and visualize the agents, I get the correct output.

It's really only when trying to keep training those previous policies with the method described above that it fails.
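A minimal sketch of that verification rollout, continuing from the snippet above. It assumes a gymnasium-style MultiAgentEnv with dict observations keyed by agent ID and reuses the env and the "prey"/"predator" mapping from the reproduction script; the loop details are illustrative.

# Sketch of the verification rollout: assumes `env` is a gymnasium-style
# MultiAgentEnv with dict observations keyed by agent ID.
obs, _ = env.reset()
done = False
while not done:
    actions = {
        agent_id: algo.compute_single_action(
            agent_obs,
            policy_id="prey" if env.agents[agent_id].agent_type == 0 else "predator",
        )
        for agent_id, agent_obs in obs.items()
    }
    obs, rewards, terminateds, truncateds, infos = env.step(actions)
    done = terminateds["__all__"] or truncateds["__all__"]

If the agents behave as in the original run, the checkpoint itself is fine, which matches the observation above.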

@Finebouche

Finebouche commented Nov 30, 2023

I feel like it might be due to me using tune.Tuner().fit() and not tune.run(). In this other example with train(), it seems to work for the person who tried it. Could it be that fit() reinitializes the weights? Can you actually prevent that?
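For reference, a sketch of what the tune.run() variant could look like with the same settings as the Tuner call in the reproduction script (an assumed translation, not a verified fix; the WandB callback is omitted for brevity):

# Sketch: tune.run() with the same stop criteria and checkpoint settings as the
# tune.Tuner() example above.
analysis = tune.run(
    "PPO",
    config=config.to_dict(),  # AlgorithmConfig -> plain config dict for tune.run
    stop={"training_iteration": 1, "timesteps_total": 20000},
    checkpoint_freq=1,
    checkpoint_at_end=True,
)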

@Finebouche

Related to #40777, #32751, #36761, #36830, #41290 and #40347, all about loading previously trained Models/Policies.

@Finebouche

The trick of passing the checkpoint via the "start_from_checkpoint" parameter to tune.Tuner()'s param_space, found here, doesn't work either :/

@Finebouche

Finebouche commented Dec 1, 2023

I was able to use tune.run() instead of tune.Tuner().fit(), but it still doesn't seem to work. The way I assess that is by visualizing an episode run in 3 environments:

  1. The initial one I want to retrieve
  2. An environment after the attempt to restore the weights
  3. An environment after one step

2. and 3. show similar behavior, different from 1.

A side problem is that tune.run is absent from the documentation, so I first thought it was being deprecated. I finally found the info I needed in the function's implementation in the repo, but that wasn't straightforward at all.

Questions still remain:

  1. Is tune.run absent from the docs because it's being deprecated?
  2. Why does weight retrieval still not work with tune.run or tune.Tuner().fit() + callbacks, while it does work with Algorithm.from_checkpoint(path_to_checkpoint)? (See the diagnostic sketch below.)
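A hypothetical diagnostic that could help narrow this down: log, inside the callback, whether set_weights() actually changes the freshly initialized weights. This check is a suggestion, not part of the original example; it reuses restore_weights() and path_to_checkpoint from the reproduction script above.

import numpy as np

class RestoreAndCheckCallback(DefaultCallbacks):
    def on_algorithm_init(self, *, algorithm, **kwargs) -> None:
        for policy_id in ["prey", "predator"]:
            # Compare the policy's weights before and after set_weights().
            before = algorithm.get_weights([policy_id])[policy_id]
            algorithm.set_weights(
                {policy_id: restore_weights(path_to_checkpoint, policy_id)}
            )
            after = algorithm.get_weights([policy_id])[policy_id]
            changed = any(
                not np.array_equal(before[k], after[k]) for k in before
            )
            print(f"{policy_id}: weights changed after restore -> {changed}")

If this prints False, the weights never get overwritten; if it prints True but the behavior still looks re-initialized, something later in the setup may be overwriting the restored weights again.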

@anyscalesam added the rllib (RLlib related issues) label Dec 2, 2023
@Finebouche

Finebouche commented Dec 11, 2023

Also linked to: #40626, #40777 and #37515.

The documentation should clearly explain how to do this.

@sven1977 added the P2 (Important issue, but not time-critical) label and removed the triage (Needs triage (eg: priority, bug/not-bug, and owning component)) label Dec 18, 2023
@justinvyu added the tune (Tune-related issues) label Jan 11, 2024