In [1]:
import ray.tune as tune
from ray.rllib.algorithms.ppo import PPOConfig
import ray
from matplotlib import pyplot as plt
import torch
from torch import _dynamo
from ray.air import session
import numpy as np

# Tell TorchDynamo how to treat `kl_divergence` when tracing compiled code.
# NOTE(review): the two calls below look contradictory -- presumably the later
# `disallow_in_graph` is the one that takes effect, and `allow_in_graph` only
# registers the function first so that it can be disallowed. Confirm against
# the installed torch version before removing or reordering either call.
torch._dynamo.allow_in_graph(torch.distributions.kl.kl_divergence)
torch._dynamo.disallow_in_graph(torch.distributions.kl.kl_divergence)

# PPO config mirroring our atari-ppo.yaml, with the RLModule and Learner
# APIs switched on.
config = (
    PPOConfig()
    .rl_module(_enable_rl_module_api=True)
    .training(_enable_learner_api=True)
)

# Atari Breakout with clipped rewards.
config.environment("ALE/Breakout-v5", clip_rewards=True)

# Optimisation hyperparameters.
config.training(
    clip_param=0.1,
    kl_coeff=0.5,
    lambda_=0.95,
    vf_loss_coeff=0.01,
    num_sgd_iter=10,
    sgd_minibatch_size=500,
    train_batch_size=5000,
)

# Sampling set-up: 7 rollout workers with 16 environments each.
# Kept as the cell's last expression so the config repr is still displayed.
config.rollouts(
    observation_filter="NoFilter",
    batch_mode="truncate_episodes",
    num_envs_per_worker=16,
    num_rollout_workers=7,
)




<ray.rllib.algorithms.ppo.ppo.PPOConfig at 0x7f8261157ca0>

In [2]:
# Grid-search torch.compile on/off independently for the rollout workers and
# for the learner, which yields a 2 x 2 grid of trials.
config.framework(
    **{
        compile_flag: tune.grid_search([True, False])
        for compile_flag in ("torch_compile_worker", "torch_compile_learner")
    }
)
config.resources(num_gpus_per_learner_worker=1)

# Stop every trial once it has reported a single training iteration.
run_config = ray.air.RunConfig(stop={"training_iteration": 1})

# Only execute one trial at a time
tune_config = tune.TuneConfig(max_concurrent_trials=1, num_samples=1)


def trainable(config):
    """Tune trainable that measures the mean PPO training-iteration time.

    Builds a PPO algorithm from ``config``, runs one un-timed warm-up
    iteration, then runs 10 further iterations and reports the mean of their
    ``training_iteration_time_ms`` timer to Tune under the key
    ``mean_training_iter_time_ms``.

    Args:
        config: Trial configuration dict passed in by Tune; it is converted
            back into a ``PPOConfig`` with ``PPOConfig.from_dict``.
    """
    # Number of timed iterations that go into the reported mean.
    num_timed_iters = 10

    algo = PPOConfig.from_dict(config).build()
    try:
        # Warm up: this first iteration is deliberately left out of the
        # timing (presumably to absorb one-off start-up/compilation costs).
        algo.step()

        # Time `num_timed_iters` iterations to get an estimate.
        training_iter_times_ms = [
            algo.step()["timers"]["training_iteration_time_ms"]
            for _ in range(num_timed_iters)
        ]
    finally:
        # Always release the algorithm's workers and resources, even if a
        # step raised. This is done before reporting because the run's stop
        # criterion may end the trial as soon as the result is reported.
        algo.stop()

    mean_training_iter_time_ms = np.mean(training_iter_times_ms)

    # Send the score to Tune.
    session.report({"mean_training_iter_time_ms": mean_training_iter_time_ms})

# Resources requested per trial, as one placement group:
#   * the first bundle is for the Learner worker: one CPU and one GPU;
#   * one further single-CPU bundle per rollout worker (CPU only, no GPU).
# The number of rollout bundles is read from the config so the request stays
# in sync with `num_rollout_workers`.
tuner = tune.Tuner(
    tune.with_resources(
        trainable,
        tune.PlacementGroupFactory(
            [{"CPU": 1, "GPU": 1}]
            + [{"CPU": 1} for _ in range(config.num_rollout_workers)]
        ),
    ),
    run_config=run_config,
    tune_config=tune_config,
    param_space=config,
)

# Run all trials of the grid; this blocks until the whole sweep has finished.
results = tuner.fit()

0,1
Current time:,2023-06-06 23:51:31
Running for:,00:42:01.15
Memory:,16.6/62.0 GiB

Trial name,status,loc,torch_compile_learner,torch_compile_worker,iter,total time (s),mean_training_iter_time_ms
trainable_ALE_Breakout-v5_dd410_00000,TERMINATED,10.0.57.149:148584,True,True,1,622.481,55886.0
trainable_ALE_Breakout-v5_dd410_00001,TERMINATED,10.0.57.149:148584,False,True,1,633.554,56773.6
trainable_ALE_Breakout-v5_dd410_00002,TERMINATED,10.0.57.149:148584,True,False,1,620.004,55581.0
trainable_ALE_Breakout-v5_dd410_00003,TERMINATED,10.0.57.149:148584,False,False,1,634.319,56846.1


[2m[36m(ImplicitFunc pid=148584)[0m Install gputil for GPU system monitoring.
[2m[36m(trainable pid=148584)[0m A.L.E: Arcade Learning Environment (version 0.8.0+919230b)
[2m[36m(trainable pid=148584)[0m [Powered by Stella]
[2m[36m(trainable pid=148584)[0m `UnifiedLogger` will be removed in Ray 2.7.
[2m[36m(trainable pid=148584)[0m   return UnifiedLogger(config, logdir, loggers=None)
[2m[36m(trainable pid=148584)[0m The `JsonLogger interface is deprecated in favor of the `ray.tune.json.JsonLoggerCallback` interface and will be removed in Ray 2.7.
[2m[36m(trainable pid=148584)[0m   self._loggers.append(cls(self.config, self.logdir, self.trial))
[2m[36m(trainable pid=148584)[0m The `CSVLogger interface is deprecated in favor of the `ray.tune.csv.CSVLoggerCallback` interface and will be removed in Ray 2.7.
[2m[36m(trainable pid=148584)[0m   self._loggers.append(cls(self.config, self.logdir, self.trial))
[2m[36m(trainable pid=148584)[0m The `TBXLogger interface

In [3]:
# Fetch the outcome of the Tune run (`results` from `tuner.fit()`) as a dataframe.
df = results.get_dataframe()