In [1]:
import os
import ray
import time
import math
import numpy as np
import pandas as pd
from ray import tune
import seaborn as sns
from typing import Any
import gymnasium as gym
from copy import deepcopy
import plotly.express as px
from gymnasium import spaces
from pettingzoo import AECEnv
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from ray.rllib.env import PettingZooEnv
from ray.tune.logger import pretty_print
from models.MOGTorchModel import MOGTorchModel
from ray.rllib.algorithms.ppo import PPOConfig
from policies.ppo_sb3_loss import CustomLossPolicy
# from models.PyFlytModel_MOG import PyFlytModel_MOG
# from models.PyFlytModel_ENN import PyFlytModel_ENN
from ray.rllib.utils.framework import try_import_torch
from policies.ppo_torch_policy import SimpleTorchPolicy
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from models.SimpleTorchModel import SimpleCustomTorchModel
from add_ons.normalize_advantages import NormalizeAdvantagesCallback
from ray.rllib.algorithms.ppo.ppo_torch_policy import PPOTorchPolicy

import PyFlyt.gym_envs
from ray.tune.registry import register_env
from PyFlyt.gym_envs import FlattenWaypointEnv
from PyFlyt.gym_envs.quadx_envs import quadx_hover_env, quadx_waypoints_env
from PyFlyt.pz_envs.fixedwing_envs.ma_fixedwing_dogfight_env import MAFixedwingDogfightEnv

pybullet build time: Nov 28 2023 23:45:17


In [2]:
path = os.getcwd()
torch, nn = try_import_torch()
ray.init()

2024-07-08 15:18:17,569	INFO worker.py:1771 -- Started a local Ray instance.


0,1
Python version:,3.10.12
Ray version:,2.31.0


[36m(RolloutWorker pid=24268)[0m pybullet build time: Nov 28 2023 23:45:17


[36m(RolloutWorker pid=24268)[0m [A                             [A




[36m(RolloutWorker pid=24265)[0m [A                             [A[32m [repeated 22x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)[0m
[36m(RolloutWorker pid=24270)[0m [A                             [A[32m [repeated 44x across cluster][0m
[36m(RolloutWorker pid=24269)[0m [A                             [A[32m [repeated 45x across cluster][0m
[36m(RolloutWorker pid=24268)[0m [A                             [A[32m [repeated 42x across cluster][0m
[36m(RolloutWorker pid=24272)[0m [A                             [A[32m [repeated 36x across cluster][0m
[36m(RolloutWorker pid=24273)[0m [A                             [A[32m [repeated 33x across cluster][0m
[36m(RolloutWorker pid=24266)[0m [A                             [A[32m [repeated 40x across cluster][0m
[36m(RolloutWork

In [3]:
# class CustomRewardWrapper(gym.RewardWrapper):
#     def __init__(self, env):
#         super().__init__(env)

#     def reward(self, reward):
#         return reward / 100

In [4]:
class CustomDogfightEnv(MultiAgentEnv):
    def __init__(self, 
                 config, 
                 env: AECEnv = None):

        super().__init__()
        if env is None:
            self.env = MAFixedwingDogfightEnv()
        else:
            self.env = env
        self.env.reset()
        
        self.agent_ids = self.env.possible_agents
        self.observation_space = self.env.observation_space(self.env.agents[0])
        self.action_space = self.env.action_space(self.env.agents[0])

        # self.custom_reward_wrapper = CustomRewardWrapper(self.env)

        assert all(
            self.env.observation_space(agent) == self.observation_space
            for agent in self.env.agents
        ), (
            "Observation spaces for all agents must be identical. Perhaps "
            "SuperSuit's pad_observations wrapper can help (useage: "
            "`supersuit.aec_wrappers.pad_observations(env)`"
        )

        assert all(
            self.env.action_space(agent) == self.action_space
            for agent in self.env.agents
        ), (
            "Action spaces for all agents must be identical. Perhaps "
            "SuperSuit's pad_action_space wrapper can help (usage: "
            "`supersuit.aec_wrappers.pad_action_space(env)`"
        )
        self._agent_ids = set(self.env.agents)


    def reset(self, seed=None, options=None):
        observations, infos = self.env.reset()
        
        return observations, infos

    def step(self, action_dict):
        observations, rewards, terminations, truncations, infos = self.env.step(action_dict)

        # Ensure "__all__" keys are present in terminations and truncations dictionaries
        terminations["__all__"] = any(terminations.values())
        truncations["__all__"] = any(truncations.values())

        # processed_rewards = {
        #     agent_id: self.custom_reward_wrapper.reward(reward)
        #     for agent_id, reward in rewards.items()
        # }

        return observations, rewards, terminations, truncations, infos


def env_creator(config):
    return CustomDogfightEnv(config)
register_env('MAFixedwingDogfightEnv', env_creator)


In [5]:
def policy_mapping_fn(agent_id, episode, worker, **kwargs):
    # Check if agent_id is a digit
    if agent_id.isdigit():
        return 'policy_1' if int(agent_id) % 2 == 0 else 'policy_2'
    # Handle agent_ids like 'uav_0', 'uav_1', etc.
    return 'policy_1' if int(agent_id.split('_')[1]) % 2 == 0 else 'policy_2'

In [6]:
env_config = {
    'spawn_height': 5.0,
    'damage_per_hit': 0.02,
    'lethal_distance': 15.0,
    'lethal_angle_radians': 0.1,
    'assisted_flight': True,
    'sparse_reward': False,
    'flight_dome_size': 150.0,
    'max_duration_seconds': 60.0,
    'agent_hz': 30,
    'render_mode': None,
}

In [None]:
%%time

env_example = env_creator(env_config)
obs_space = env_example.observation_space
action_space = env_example.action_space

config = PPOConfig().training(
    gamma = 0.99,
    lambda_ = 0.95,
    # kl_coeff = 0.5,
    num_sgd_iter = 30,
    # lr_schedule = [[0, 0.0003], [15_000_000, 0.00025], [30_000_000, 0.0002], [50_000_000, 0.0001]],
    lr = 0.0003,
    vf_loss_coeff = 0.5,
    # vf_clip_param = 1.0,
    clip_param = 0.3,
    grad_clip_by ='norm', 
    train_batch_size = 2_000, 
    sgd_minibatch_size = 500,
    grad_clip = 0.5,
    # kl_coeff = 0.01,
    # entropy_coeff = 0.001,
    optimizer = {
        'weight_decay': 0.01
    },
    model = {'custom_model': 'SimpleCustomTorchModel', #SimpleCustomTorchModel MOGTorchModel
           'vf_share_layers': False,
           'fcnet_hiddens': [256,256],
           'fcnet_activation': 'LeakyReLU',
           'custom_model_config': {
                'num_gaussians': 3,
                'num_layers': 2,
                # 'num_outputs': action_space.shape[0],
                # 'parquet_file_name': 'logs/critic_logging_sigma.parquet',
           }
            }
).environment(
    env = 'MAFixedwingDogfightEnv',
    env_config = env_config
).rollouts(
num_rollout_workers = 10
).resources(num_gpus = 1
).multi_agent(
    policies = {
        'policy_1': (CustomLossPolicy, obs_space, action_space, {}),
        'policy_2': (CustomLossPolicy, obs_space, action_space, {}),
    },
    policy_mapping_fn=policy_mapping_fn
)

# .callbacks(NormalizeAdvantagesCallback
# )

# analysis = tune.run(
#     'PPO',
#     config=config.to_dict(),
#     stop={'training_iteration':300},
#     checkpoint_freq=10,
#     checkpoint_at_end=True,
#     # local_dir='./ray_results'
# )


algo = config.build()

num_iterations = 1500
results = []

for i in range(num_iterations):
    result = algo.train()
    if i % 10 == 0:
        # print(f"Iteration: {i}, Mean Reward: {result['env_runners']['episode_reward_mean']} episode length: {result['env_runners']['episode_len_mean']}")
        print(f"Iteration: {i}, Policy 1 Mean Reward: {result['env_runners']['policy_reward_mean']['policy_1']} loss: {result['info']['learner']['policy_1']['learner_stats']['total_loss']}\n"
              f"Iteration: {i}, Policy 2 Mean Reward: {result['env_runners']['policy_reward_mean']['policy_2']} loss: {result['info']['learner']['policy_2']['learner_stats']['total_loss']}\n"
              f"Iteration: {i}, episode length: {result['env_runners']['episode_len_mean']}\n"
        )

    results.append([result['env_runners']['episode_reward_mean'], result['env_runners']['episode_len_mean']])

results_df = pd.DataFrame(results)

ray.shutdown()


[A                             [A


`UnifiedLogger` will be removed in Ray 2.7.
  return UnifiedLogger(config, logdir, loggers=None)
The `JsonLogger interface is deprecated in favor of the `ray.tune.json.JsonLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `CSVLogger interface is deprecated in favor of the `ray.tune.csv.CSVLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `TBXLogger interface is deprecated in favor of the `ray.tune.tensorboardx.TBXLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))


Iteration: 0, Policy 1 Mean Reward: -2852.137939586214 loss: 1.907180411616961
Iteration: 0, Policy 2 Mean Reward: -3263.3495558606783 loss: 2.306467440724373
Iteration: 0, episode length: 70.36363636363636

Iteration: 10, Policy 1 Mean Reward: -4029.384919649011 loss: 2.2346677710612615
Iteration: 10, Policy 2 Mean Reward: -2891.475060658207 loss: 1.7446814725796382
Iteration: 10, episode length: 109.53

Iteration: 20, Policy 1 Mean Reward: -3909.8087809667854 loss: 2.392683094739914
Iteration: 20, Policy 2 Mean Reward: -2791.551334061982 loss: 1.3002947315573692
Iteration: 20, episode length: 121.52

Iteration: 30, Policy 1 Mean Reward: -2989.4313666594203 loss: 1.6213805745045344
Iteration: 30, Policy 2 Mean Reward: -3940.6409864259967 loss: 1.3107389941811562
Iteration: 30, episode length: 147.63

Iteration: 40, Policy 1 Mean Reward: -2407.5257810985186 loss: 1.1884374196330707
Iteration: 40, Policy 2 Mean Reward: -4197.204428402961 loss: 1.8176770637432733
Iteration: 40, episode l

In [None]:
algo.get_weights()

In [None]:
data = pd.read_parquet(path + '/logs/critic_logging_sigma.parquet')

In [None]:
data.tail(25)

In [None]:
fig = go.Figure()

fig.add_trace(go.Scatter(y = data['logits_array_max'], mode = 'lines', name = 'logits_array_max'))
fig.add_trace(go.Scatter(y = data['logits_array_min'], mode = 'lines', name = 'logits_array_min'))
fig.add_trace(go.Scatter(y = data['alphas_array_min'], mode = 'lines', name = 'alphas_array_min'))
fig.add_trace(go.Scatter(y = data['alphas_array_max'], mode = 'lines', name = 'alphas_array_max'))

fig.update_layout(title_text="Component Analysis", xaxis_title = 'Iterations', yaxis_title = 'y-axis',
                 font = dict(
                     family = 'Times New Roman',
                     size = 18
                 ),
                 width = 900,
                 height = 600,
                 showlegend = True,
                 )
# fig.write_image('/Slicing_clip_param.png')
fig.show()

In [None]:
data.tail()

In [None]:
data['surrogate_loss'].min()

In [None]:
results_df = pd.DataFrame(results)
experiment_type = 'enn_2dim'
results_df.to_csv(path + '/logs/test_runs/'+experiment_type+'.csv')

In [None]:
results

In [None]:
plt.plot(results)
plt.title('Training Progress - Mean Reward per Episode')
plt.xlabel('Iteration')
plt.ylabel('Mean Reward')
# plt.savefig('Basic PPO - HalfCheetah-v4')
plt.show()

In [None]:
algo.logdir

In [None]:
env = FlattenWaypointEnv(gym.make(id='PyFlyt/QuadX-Waypoints-v1', flight_mode=-1), context_length=1)

obs_list = []
obs, info = env.reset()
# env.env.env.env.env.drones[0].set_mode(-1)
targets = env.unwrapped.waypoints.targets
points = np.concatenate((obs[10:13].reshape(-1,3), targets))
obs = {'default': obs}
obs_list += [obs]

reward_list = []
action_list = []
start = time.time()
for i in range(10*40):
    compute_action = algo.compute_actions(obs)
    action = compute_action['default']
    # obs, reward, terminated, truncated, info = env.step(np.zeros((4))+.79)
    obs, reward, terminated, truncated, info = env.step(action)

    obs = {'default': obs}
    
    obs_list += [obs]
    
    reward_list += [reward]
    action_list += [action]
    
    if terminated or info['num_targets_reached'] == 4:
        break

arrays = [d['default'] for d in obs_list]
obs_array = np.vstack(arrays)
reward_array = np.array(reward_list)
action_array = np.array(action_list) 
env.close()

In [None]:
plotly_figure = px.scatter_3d(x=obs_array[:,10], y=obs_array[:,11], z=obs_array[:,12], opacity=.6, color=np.arange(len(obs_array)))
plotly_figure.add_scatter3d(x=targets[:,0], y=targets[:,1], z=targets[:,2], marker={'color':'green', 'symbol':'square-open', 'size':25, 'line':{'width':10}}, mode='markers')
plotly_figure.write_html(path+'/3D_renders/3d_drone_space4_'+experiment_type+'.html')

In [None]:
import seaborn as sns

In [None]:
dataframes = {}
for filename in os.listdir(path+'/logs/test_runs'):
    if filename.endswith('.csv'):
        file_path = os.path.join(path+'/logs/test_runs', filename)
        df = pd.read_csv(file_path)
        key = os.path.splitext(filename)[0]
        dataframes[key] = df


data_list = []
labels = []
output_desired = 'length' #else will give length

for key, df in dataframes.items():
    if output_desired == 'reward':
        data_list.append(df.iloc[:,0])
        labels.append(f"reward for {key}")
    else:
        data_list.append(df.iloc[:,1])
        labels.append(f"length for {key}")

for data in data_list:
    sns.kdeplot(data, fill = True)

plt.legend(title = 'Modes', labels = labels)
plt.title(f"{output_desired}")
plt.show()

In [None]:
dataframes = {}
for filename in os.listdir(path+'/logs/test_runs'):
    if filename.endswith('.csv'):
        file_path = os.path.join(path+'/logs/test_runs', filename)
        df = pd.read_csv(file_path)
        key = os.path.splitext(filename)[0]
        dataframes[key] = df


reward = []
labels = []
output_desired = 'reward' #else will give length

for key, df in dataframes.items():
    plt.scatter(df.iloc[:,0], df.iloc[:,1])
    labels.append(f"length for {key}")

plt.legend(title = 'Different runs', labels = labels)
plt.title(f"{output_desired} over time")
plt.show()