In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import os
import json
import pickle
from replay_buffer.ber import BlockReplayBuffer
from ray.rllib.utils.replay_buffers import ReplayBuffer
from replay_buffer.mpber import MultiAgentPrioritizedBlockReplayBuffer
from ray.rllib.utils.replay_buffers.multi_agent_prioritized_replay_buffer import MultiAgentPrioritizedReplayBuffer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def generate_configs(base_path):
    """Load the pickled config of every experiment found under ``base_path``.

    The expected layout is ``<base_path>/<env>/<experiment>/<experiment>_config.pyl``.
    Entries that are not directories are ignored at both levels. An experiment
    directory without its config file is reported on stdout and left out.

    Args:
        base_path: Directory whose sub-directories are treated as environments.

    Returns:
        A dict mapping each environment directory name to a dict that maps
        experiment directory names to the unpickled config object. An
        environment with no loadable experiment maps to an empty dict.

    Note:
        ``pickle.load`` can execute arbitrary code; only use this on trusted files.
    """
    loaded = {}
    for env_name in os.listdir(base_path):
        env_dir = os.path.join(base_path, env_name)
        if not os.path.isdir(env_dir):
            continue

        per_env = loaded.setdefault(env_name, {})
        for exp_name in os.listdir(env_dir):
            exp_dir = os.path.join(env_dir, exp_name)
            if not os.path.isdir(exp_dir):
                continue

            pickle_path = os.path.join(exp_dir, exp_name + "_config.pyl")
            if not os.path.exists(pickle_path):
                print("config not existing for %s in %s" % (exp_name, env_name))
                continue

            with open(pickle_path, "rb") as handle:
                per_env[exp_name] = pickle.load(handle)
    return loaded

# Root directory that holds one sub-directory per environment, each containing one
# sub-directory per experiment. Set the JADE_CHECKPOINTS_DIR environment variable to
# point the notebook at another location; the default is the original path.
base_path = os.environ.get("JADE_CHECKPOINTS_DIR", "/home/seventheli/JADE_checkpoints/")
configs = generate_configs(base_path)

config not existing for DQN_ER_SpaceInvadersNoFrameskip-v4_511703 in SpaceInvaders


In [3]:
# Location of the metrics CSV that is loaded into `data` with pd.read_csv.
data_path = "./jade_apex.csv"

In [4]:
# Short label used in the tables for each replay-buffer class stored in the configs.
mapper = dict(
    [
        (MultiAgentPrioritizedReplayBuffer, "DPER"),
        (MultiAgentPrioritizedBlockReplayBuffer, "DPBER"),
    ]
)

In [5]:
# Load the metrics table; the "exp" column holds the id of the experiment of each row.
data = pd.read_csv(data_path)

# The third "_"-separated token of the experiment id is the environment id
# (e.g. "BeamRiderNoFrameskip-v4"); drop the suffix to keep only the game name.
data["env"] = (
    data["exp"].str.split("_").str[2].str.replace("NoFrameskip-v4", "", regex=False)
)

# Resolve the replay-buffer settings once per experiment instead of once per row:
# the lookup depends only on (env, exp) and the table has many rows per experiment.
replay_buffer_configs = {
    exp: configs[env][exp].get("replay_buffer_config")
    for env, exp in data[["env", "exp"]].drop_duplicates().itertuples(index=False)
}

# Short label of the buffer class; raises KeyError for a class missing from `mapper`.
data["buffer"] = data["exp"].map(lambda exp: mapper[replay_buffer_configs[exp]["type"]])
# "sub_buffer_size" of the buffer config, defaulting to 1 when the key is absent.
data["size"] = data["exp"].map(
    lambda exp: replay_buffer_configs[exp].get("sub_buffer_size", 1)
)

# Persist the enriched table back to the file it was read from.
data.to_csv(data_path, index=False)

# Keep the preview as the last expression of the cell so that it is rendered.
data.head()

In [6]:
# Runs whose last reported total time does not exceed this many seconds are dropped.
MIN_TOTAL_TIME_S = 300000

# For every (env, buffer, size) combination keep the experiment with the most rows.
max_exp = data.groupby(["env", "buffer", "size"])["exp"].apply(
    lambda exps: exps.value_counts().idxmax()
)
filtered_data = data[data["exp"].isin(max_exp.values)].reset_index(drop=True)

# One summary row per (env, size, buffer) group.
summary = (
    filtered_data.groupby(["env", "size", "buffer"])
    .agg(
        {
            "episode_reward_mean": "last",  # last non-null mean reward in row order
            # Sum over all rows of the group. NOTE(review): the values look like a
            # running counter, so "last" or "max" may have been intended -- confirm
            # before relying on this column.
            "num_env_steps_sampled": "sum",
            "time_total_s": "last",  # last non-null total time in row order
            "exp": "unique",  # array of the experiment ids present in the group
        }
    )
    .reset_index()
)

# Number of distinct experiments per group and whether there is exactly one.
summary["num_experiments"] = summary["exp"].apply(len)
summary["single_experiment"] = summary["num_experiments"] == 1

# remove too short: only the first experiment id of each long-enough group is kept
long_enough = summary["time_total_s"] > MIN_TOTAL_TIME_S
kept_experiments = [exps[0] for exps in summary.loc[long_enough, "exp"]]
filtered_data = filtered_data[filtered_data["exp"].isin(kept_experiments)]

# Last expression of the cell so the single-experiment check is actually rendered.
summary["single_experiment"].value_counts()

In [7]:
# Largest total time recorded for each (env, buffer) pair among the retained runs.
(
    filtered_data
    .groupby(["env", "buffer"])["time_total_s"]
    .max()
    .reset_index()
)

Unnamed: 0,env,buffer,time_total_s
0,BeamRider,DPBER,360056.84698
1,BeamRider,DPER,310839.486917
2,Breakout,DPBER,360005.137348
3,Breakout,DPER,311496.41598
4,Centipede,DPER,310665.513887
5,ChopperCommand,DPBER,360056.667421
6,ChopperCommand,DPER,312789.381193
7,CrazyClimber,DPBER,360035.796872
8,Defender,DPBER,360005.586834
9,Qbert,DPBER,360048.02394


In [8]:
# Spread (max - min) of the mean episode reward per environment, pooled over every
# retained experiment of that environment. The column is selected before aggregating
# so the grouping column stays out of the computation; this avoids the deprecation
# warning that DataFrameGroupBy.apply emits on pandas >= 2.2.
baseline_diff_corrected = filtered_data.groupby("env")["episode_reward_mean"].agg(
    lambda rewards: rewards.max() - rewards.min()
)

# Using the corrected baseline difference to detect anomalies

# Function to detect anomalies based on the corrected baseline
def detect_anomalies_baseline(exp_data, env_baseline_diff, ratio=0.20):
    """Tell whether a run's reward spread is small compared with its environment's.

    Args:
        exp_data: Table with an ``episode_reward_mean`` column for a single run.
        env_baseline_diff: Reward spread (max - min) of the whole environment.
        ratio: Fraction of the environment spread below which the run is flagged.
            Defaults to 0.20, the value that was previously hard-coded.

    Returns:
        A truthy value when ``|max - min|`` of the run's ``episode_reward_mean`` is
        strictly less than ``|env_baseline_diff * ratio|``, a falsy value otherwise.
        If the run's spread is NaN the comparison is false, so the run is not flagged.
    """
    rewards = exp_data['episode_reward_mean']
    exp_diff = abs(rewards.max() - rewards.min())
    return exp_diff < abs(env_baseline_diff * ratio)

# Ids of the runs flagged by detect_anomalies_baseline, visited environment by
# environment and compared against the reward spread of their own environment.
anomalies_baseline_corrected = [
    exp_name
    for env_name, env_rows in filtered_data.groupby('env')
    for exp_name, exp_rows in env_rows.groupby('exp')
    if detect_anomalies_baseline(exp_rows, baseline_diff_corrected[env_name])
]

In [9]:
# Drop every row of a flagged run, then renumber the remaining rows from zero.
filtered_data = (
    filtered_data
    .loc[lambda frame: ~frame['exp'].isin(anomalies_baseline_corrected)]
    .reset_index(drop=True)
)

In [10]:
# Write the filtered table to disk without the row index.
filtered_data.to_csv("jade_apex_filtered.csv", index=False)

In [11]:
# Show the filtered frame as the cell output.
filtered_data

Unnamed: 0,episode_reward_max,episode_reward_min,episode_reward_mean,episodes_this_iter,num_env_steps_sampled,num_env_steps_trained,num_agent_steps_sampled,num_agent_steps_trained,num_weight_syncs,num_target_updates,...,learner_overall_throughput,target_net_update_time_ms,episodes_total,training_iteration,time_this_iter_s,time_total_s,size,exp,env,buffer
0,,,,0,25760,0,25760,0,161,,...,,,0,1,31.125418,31.125418,32,APEX_DDQN_BeamRiderNoFrameskip-v4_DPBER_514031,BeamRider,DPBER
1,756.0,352.0,552.000000,4,53280,512,53280,512,329,,...,7.782,,4,2,33.528656,64.654073,32,APEX_DDQN_BeamRiderNoFrameskip-v4_DPBER_514031,BeamRider,DPBER
2,756.0,352.0,602.666667,5,78400,882176,78400,882176,477,17.0,...,8859.413,11.982,9,3,98.113581,162.767654,32,APEX_DDQN_BeamRiderNoFrameskip-v4_DPBER_514031,BeamRider,DPBER
3,756.0,308.0,562.769231,4,104000,1701376,104000,1701376,637,33.0,...,10339.313,4.106,13,4,87.369431,250.137085,32,APEX_DDQN_BeamRiderNoFrameskip-v4_DPBER_514031,BeamRider,DPBER
4,756.0,308.0,534.250000,3,129440,2515456,129440,2515456,791,50.0,...,13411.639,4.650,16,5,88.745094,338.882179,32,APEX_DDQN_BeamRiderNoFrameskip-v4_DPBER_514031,BeamRider,DPBER
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
161475,2150.0,515.0,1315.100000,6,127910720,839138304,127910720,839138304,781035,16723.0,...,4803.690,10.027,29278,5071,33.193845,359898.757336,8,APEX_DDQN_SpaceInvadersNoFrameskip-v4_DPBER_51...,SpaceInvaders,DPBER
161476,2150.0,515.0,1312.300000,3,127935840,839302144,127935840,839302144,781189,16727.0,...,2319.787,11.075,29281,5072,33.767208,359932.524544,8,APEX_DDQN_SpaceInvadersNoFrameskip-v4_DPBER_51...,SpaceInvaders,DPBER
161477,2260.0,515.0,1298.600000,4,127960960,839467520,127960960,839467520,781343,16730.0,...,6187.669,5.276,29285,5073,32.214917,359964.739461,8,APEX_DDQN_SpaceInvadersNoFrameskip-v4_DPBER_51...,SpaceInvaders,DPBER
161478,2260.0,365.0,1306.400000,6,127986080,839630848,127986080,839630848,781497,16733.0,...,3916.615,8.996,29291,5074,34.269591,359999.009052,8,APEX_DDQN_SpaceInvadersNoFrameskip-v4_DPBER_51...,SpaceInvaders,DPBER
