In [1]:
import os
import gym
import ray
import time
import math
import numpy as np
import pandas as pd
from ray import tune
from ray.tune.registry import register_env
from ray.rllib.algorithms.ppo import PPOConfig, PPO
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from models.SimpleTorchModel import SimpleCustomTorchModel
from policies.policy_w_custom_model import SimpleTorchPolicy
from add_ons.normalize_advantages import NormalizeAdvantagesCallback
from PyFlyt.pz_envs.fixedwing_envs.ma_fixedwing_dogfight_env import MAFixedwingDogfightEnv

from ray.tune.tune_config import TuneConfig
from ray.tune.tuner import Tuner
from ray.air.config import ScalingConfig

  import distutils.spawn
pybullet build time: Nov 28 2023 23:45:17


In [2]:
# NOTE(review): `path` is not used by any visible cell — remove if nothing else needs it.
path = os.getcwd()
# RLlib's torch import helper; NOTE(review): `torch`/`nn` are not used in the visible cells.
torch, nn = try_import_torch()
# ignore_reinit_error keeps this cell re-runnable: a second bare ray.init() in the
# same kernel raises instead of reusing the running Ray instance.
ray.init(ignore_reinit_error=True)

2024-07-02 16:36:16,154	INFO worker.py:1771 -- Started a local Ray instance.


0,1
Python version:,3.10.12
Ray version:,2.31.0


[36m(RolloutWorker pid=179764)[0m pybullet build time: Nov 28 2023 23:45:17


[36m(RolloutWorker pid=179759)[0m [A                             [A
[36m(RolloutWorker pid=179758)[0m [A                             [A[32m [repeated 65x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)[0m
[36m(RolloutWorker pid=179758)[0m [A                             [A[32m [repeated 130x across cluster][0m
[36m(RolloutWorker pid=179757)[0m [A                             [A[32m [repeated 139x across cluster][0m
[36m(RolloutWorker pid=179761)[0m [A                             [A[32m [repeated 115x across cluster][0m
[36m(RolloutWorker pid=179760)[0m [A                             [A[32m [repeated 126x across cluster][0m
[36m(RolloutWorker pid=179759)[0m [A                             [A[32m [repeated 112x across cluster][0m
[36m(RolloutWorker pid=179760)[0m [A     

In [3]:
def env_creator(env_config):
    """Build the multi-agent fixed-wing dogfight environment for RLlib.

    Args:
        env_config: RLlib ``EnvContext`` / dict of keyword arguments forwarded to
            ``MAFixedwingDogfightEnv``, or ``None``. Keys given here override the
            default ``assisted_flight=True``.

    Returns:
        A new ``MAFixedwingDogfightEnv`` instance.
    """
    env_kwargs = {"assisted_flight": True}
    # Forward the user config so `.environment(env_config=...)` can actually tune the env;
    # `None` (used when building a throwaway env to read spaces) means "defaults only".
    env_kwargs.update(env_config or {})
    return MAFixedwingDogfightEnv(**env_kwargs)


register_env("MAFixedwingDogfightEnv", env_creator)

In [4]:
def policy_mapping_fn(agent_id, episode, worker, **kwargs):
    """Assign an agent to one of two policies by the parity of its numeric index.

    Accepts plain indices (``3`` or ``"3"``) and prefixed ids such as ``"uav_0"``;
    the index is taken from the last ``_``-separated segment of ``str(agent_id)``.
    Even indices map to ``'policy_1'``, odd indices to ``'policy_2'``.

    Args:
        agent_id: Agent identifier (int or str).
        episode: Unused; part of RLlib's mapping-function signature.
        worker: Unused; part of RLlib's mapping-function signature.
        **kwargs: Extra arguments RLlib may pass; ignored.

    Returns:
        ``'policy_1'`` or ``'policy_2'``.

    Raises:
        ValueError: If no integer index can be extracted from ``agent_id``.
    """
    # str() lets integer ids through (int has no .isdigit()); rsplit keeps working when
    # the prefix itself contains underscores (e.g. "team_red_2").
    suffix = str(agent_id).rsplit("_", 1)[-1]
    try:
        index = int(suffix)
    except ValueError:
        raise ValueError(
            f"Cannot extract a numeric index from agent_id {agent_id!r}"
        ) from None
    return "policy_1" if index % 2 == 0 else "policy_2"

In [5]:
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.evaluation import Episode

In [6]:
class CustomMetricsCallback(DefaultCallbacks):
    """RLlib callbacks that log each finished episode's total reward and length.

    NOTE(review): despite the ``_mean`` suffix, each value stored here is a single
    episode's total reward / length, not a mean. RLlib presumably aggregates custom
    metrics across episodes and appends its own suffixes (giving keys such as
    ``episode_reward_mean_mean``) — consider renaming, and confirm these are not
    just duplicates of RLlib's built-in episode statistics.
    """

    def on_episode_end(self, *, worker, base_env, policies, episode: Episode, **kwargs):
        """Record ``episode.total_reward`` and ``episode.length`` as custom metrics."""
        metrics = episode.custom_metrics
        metrics.update({
            "episode_reward_mean": episode.total_reward,
            "episode_len_mean": episode.length,
        })

In [7]:
%%time

env_name = "MAFixedwingDogfightEnv"
env = env_creator(None)
obs_space = env.observation_space()
action_space = env.action_space()

config = PPOConfig().training(
    gamma = 0.99,
    lambda_ = 0.95,
    # kl_coeff = 0.5,
    num_sgd_iter = 30,
    lr_schedule = [[0, 0.0003], [15_000_000, 0.00025], [30_000_000, 0.0002], [50_000_000, 0.0001]],
    vf_loss_coeff = 0.5,
    vf_clip_param = 15.0,
    clip_param = 0.2,
    grad_clip_by ='norm', 
    train_batch_size = 65_000, 
    sgd_minibatch_size = 4_096,
    grad_clip = 0.5,
    model = {'custom_model': 'SimpleCustomTorchModel', 
           'vf_share_layers': False,
           'fcnet_hiddens': [256,256],
           'fcnet_activation': 'LeakyReLU',
             #this isn't used for some models, but doesn't hurt to keep it
           'custom_model_config': {
                'num_gaussians': 2,
               'num_outputs': action_space.shape[0]
           }
            }
).environment(env = env_name
).rollouts(
num_rollout_workers = 10
).resources(num_gpus = 1
).callbacks(CustomMetricsCallback
).multi_agent(
    policies = {
        'policy_1': (SimpleTorchPolicy, obs_space, action_space, {}),
        'policy_2': (SimpleTorchPolicy, obs_space, action_space, {}),
    },
    policy_mapping_fn=policy_mapping_fn
)


algo = config.build()

num_iterations = 5
results = []

for i in range(num_iterations):
    result = algo.train()
    print(f"Iteration: {i}, Result Keys: {result.keys()}")  # Print all keys in the result dictionary
    if 'episode_reward_mean' in result:
        print(f"Iteration: {i}, Mean Reward: {result['episode_reward_mean']}")
        results.append([result['episode_reward_mean'], result['episode_len_mean']])
    else:
        print(f"Iteration: {i}, Mean Reward not found in the result")
        results.append([None, None])

ray.shutdown()

`UnifiedLogger` will be removed in Ray 2.7.
  return UnifiedLogger(config, logdir, loggers=None)
The `JsonLogger interface is deprecated in favor of the `ray.tune.json.JsonLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `CSVLogger interface is deprecated in favor of the `ray.tune.csv.CSVLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `TBXLogger interface is deprecated in favor of the `ray.tune.tensorboardx.TBXLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
  prep = cls(observation_space, options)
  self.critic_fcnet = TorchFC(obs_space, action_space, 1, model_config, name + "_critic")
  self.critic_fcnet = TorchFC(obs_space, action_space, 1, model_config, name + "_critic")
  self._preprocessor = get_preprocessor(obs_space)(


Iteration: 0, Result Keys: dict_keys(['custom_metrics', 'episode_media', 'info', 'env_runners', 'num_healthy_workers', 'num_in_flight_async_sample_reqs', 'num_remote_worker_restarts', 'num_agent_steps_sampled', 'num_agent_steps_trained', 'num_env_steps_sampled', 'num_env_steps_trained', 'num_env_steps_sampled_this_iter', 'num_env_steps_trained_this_iter', 'num_env_steps_sampled_throughput_per_sec', 'num_env_steps_trained_throughput_per_sec', 'timesteps_total', 'num_env_steps_sampled_lifetime', 'num_agent_steps_sampled_lifetime', 'num_steps_trained_this_iter', 'agent_timesteps_total', 'timers', 'counters', 'done', 'training_iteration', 'trial_id', 'date', 'timestamp', 'time_this_iter_s', 'time_total_s', 'pid', 'hostname', 'node_ip', 'config', 'time_since_restore', 'iterations_since_restore', 'perf'])
Iteration: 0, Mean Reward not found in the result
Iteration: 1, Result Keys: dict_keys(['custom_metrics', 'episode_media', 'info', 'env_runners', 'num_healthy_workers', 'num_in_flight_async

[36m(RolloutWorker pid=179758)[0m pybullet build time: Nov 28 2023 23:45:17[32m [repeated 9x across cluster][0m


CPU times: user 45.8 s, sys: 8.02 s, total: 53.8 s
Wall time: 4min 8s
