In [None]:
import wandb
import pandas as pd
import numpy as np
from collections import defaultdict

In [None]:
# Experiment selection: the algorithms and environments whose data is extracted.
algs = [
    "SN-MO-PPO",
    "SN-MO-A2C",
    "SN-MO-DQN",
]
env_ids = [
    "deep-sea-treasure-concave-v0",
    "minecart-v0",
    "mo-reacher-v4",
]

# W&B API client and the collection of runs of the project to read from.
api = wandb.Api(timeout=120)
runs = api.runs("wilrop/IPRO_runs")

In [None]:
# Group runs based on Parent runs, env_id and algorithm.
# Resulting layout: run_hists[env_id][alg][parent_run_id] -> list of runs.
run_hists = {env_id: {alg: defaultdict(list) for alg in algs} for env_id in env_ids}

for run in runs:
    # .get() instead of [] so that runs lacking these config entries are dropped by
    # the membership filter below rather than aborting the whole loop with a KeyError.
    env_id = run.config.get('env_id')
    alg = run.config.get('alg_name')
    if env_id not in env_ids or alg not in algs:
        continue

    # Strict lookup on purpose: a selected run without a parent id should fail loudly
    # instead of being silently collected under a bogus group key.
    parent_run = run.config['parent_run_id']
    run_hists[env_id][alg][parent_run].append(run)
    print(f'Added run to {alg} - {parent_run}')

In [None]:
# Keep only the best runs.
# For every (environment, algorithm) pair, select the parent whose child runs have
# the highest mean summary hypervolume. Only complete groups are eligible.
REQUIRED_RUNS = 5  # a parent qualifies only with exactly this many child runs

# best_data[env_id][alg] is (parent_run_id, list of runs), or None when no parent
# of that pair has a complete group.
best_data = {env_id: {alg: None for alg in algs} for env_id in env_ids}
# Zero-initialised here; the cell that extracts the iteration counts rebuilds it.
max_iterations = {env_id: 0 for env_id in env_ids}

for env_id in run_hists:
    for alg in run_hists[env_id]:
        best_hv = -np.inf  # any finite mean beats the sentinel
        # The loop variable is called `child_runs`, not `runs`, so that the project-wide
        # `runs` collection iterated by the grouping cell is not overwritten when this
        # cell executes (which would break re-running the grouping cell).
        for parent, child_runs in run_hists[env_id][alg].items():
            hvs = [run.summary['outer/hypervolume'] for run in child_runs]
            mean_hv = np.mean(hvs)
            if mean_hv > best_hv:
                if len(child_runs) == REQUIRED_RUNS:
                    best_hv = mean_hv
                    best_data[env_id][alg] = (parent, child_runs)
                else:
                    # Only reported when the incomplete group would otherwise have won.
                    print(f"Skipping {alg} - {env_id} - {parent} with mean {mean_hv} and {len(child_runs)} runs")

# Print results
for env_id in best_data:
    for alg, best in best_data[env_id].items():
        if best is None:
            # Unpacking None would raise a TypeError; report the gap instead.
            print(f"No parent with {REQUIRED_RUNS} runs found for {alg} - {env_id}")
            continue
        parent, child_runs = best
        hvs = np.array([run.summary['outer/hypervolume'] for run in child_runs])
        print(f"Best run for {alg} - {env_id} - {parent}")
        print(f"HVs: {hvs} - Mean: {np.mean(hvs)}")

In [None]:
# Extract the maximum number of iterations.
# For each environment, store the largest iteration count (last logged `_step` + 1)
# found among the selected runs of all algorithms.
max_iterations = {env_id: 0 for env_id in env_ids}
for env_id in best_data:
    for alg in best_data[env_id]:
        best = best_data[env_id][alg]
        if best is None:
            # No run group was selected for this pair, so there is nothing to measure.
            continue
        # Unpacked into `best_runs` rather than `runs` so the project-wide `runs`
        # collection used by the grouping cell is not overwritten.
        parent, best_runs = best
        for run in best_runs:
            # int() keeps plain Python integers in the dictionary instead of numpy scalars.
            iters = int(run.history(keys=['_step']).iloc[-1]['_step']) + 1
            max_iterations[env_id] = max(max_iterations[env_id], iters)
print(max_iterations)

In [None]:
def fill_iterations(hypervolumes, coverages, max_iter):
    """Pad the hypervolume and coverage lists in place to at least max_iter entries.

    Each list shorter than max_iter is extended by repeating its own last value.
    Each list is padded independently, so both reach max_iter even when they
    start with different lengths. Lists that already hold max_iter or more
    entries are left untouched.

    Args:
        hypervolumes (list): List of hypervolumes. Modified in place.
        coverages (list): List of coverages. Modified in place.
        max_iter (int): Maximum number of iterations.

    Returns:
        None

    Raises:
        ValueError: If a list is empty but would need padding, since there is no
            last value to repeat. Neither list is modified in that case.
    """
    series = (('hypervolumes', hypervolumes), ('coverages', coverages))

    # Validate both lists before mutating either, so a failure leaves them untouched.
    for name, values in series:
        if not values and max_iter > 0:
            raise ValueError(f"Cannot pad empty list '{name}' to {max_iter} entries")

    for _, values in series:
        # int() guards against numpy integer scalars being passed as max_iter.
        missing = int(max_iter - len(values))
        if missing > 0:
            values.extend([values[-1]] * missing)

In [None]:
# Make dictionaries with the data and save to csv.
# For every (environment, algorithm) pair one hypervolume CSV and one coverage CSV
# are written, each with the columns [<alg>, 'Step', 'Seed'] and the rows of all
# selected runs stacked on top of each other.
for env_id in best_data:
    max_iter = max_iterations[env_id]
    for alg in best_data[env_id]:
        best = best_data[env_id][alg]
        if best is None:
            # No run group was selected for this pair; unpacking None would raise.
            print(f"No data to save for {env_id} - {alg}")
            continue
        # Unpacked into `best_runs` rather than `runs` so the project-wide `runs`
        # collection used by the grouping cell is not overwritten.
        parent, best_runs = best
        hv_dict = {alg: [], 'Step': [], 'Seed': []}
        cov_dict = {alg: [], 'Step': [], 'Seed': []}

        # The position of a run inside its group is used as its seed label.
        for seed, run in enumerate(best_runs):
            # NOTE(review): run.history() may return a down-sampled history for long
            # runs; confirm the runs are short enough for every row to be returned.
            hist = run.history(keys=['outer/hypervolume', 'outer/coverage'])
            hypervolumes = hist['outer/hypervolume'].values.tolist()
            # Coverage values are clamped to the interval [0, 1].
            coverages = list(np.clip(hist['outer/coverage'].values.tolist(), 0, 1))
            step_size = run.config['online_steps']
            # Shorter runs are padded with their last value up to max_iter entries.
            fill_iterations(hypervolumes, coverages, max_iter)

            # Size the Step and Seed columns from the data itself rather than from
            # max_iter, so all columns of a frame always have the same length.
            n_rows = len(hypervolumes)
            global_steps = (np.arange(n_rows) * step_size).tolist()
            seeds = [seed] * n_rows

            hv_dict[alg].extend(hypervolumes)
            cov_dict[alg].extend(coverages)
            hv_dict['Step'].extend(global_steps)
            cov_dict['Step'].extend(global_steps)
            hv_dict['Seed'].extend(seeds)
            cov_dict['Seed'].extend(seeds)

        hv_df = pd.DataFrame.from_dict(hv_dict)
        cov_df = pd.DataFrame.from_dict(cov_dict)
        print(f"Saving data for {env_id} - {alg}")
        hv_df.to_csv(f'../utils/results/{alg}_{env_id}_hv.csv', index=False)
        cov_df.to_csv(f'../utils/results/{alg}_{env_id}_cov.csv', index=False)