In [None]:
import wandb
import os
import json
import numpy as np
from utils.pareto import extreme_prune

In [None]:
# Public-API client. timeout=120 (seconds) replaces the client's default
# request timeout - presumably because scanning full run histories is slow.
api = wandb.Api(timeout=120)

# All runs logged to the W&B project "wilrop/MORL-Baselines".
baseline_runs = api.runs(path="wilrop/MORL-Baselines")

In [None]:
num_evals = 100  # target number of evaluation checkpoints kept per run


def _subsample_history(history, n_keep):
    """Return evenly spaced entries of ``history``, at most ``n_keep`` of them.

    If ``history`` has fewer than ``n_keep`` entries, all of them are kept.
    Otherwise the first ``len(history) % n_keep`` entries are dropped, and every
    ``len(history) // n_keep``-th entry of the remainder is kept. This yields
    exactly ``n_keep`` entries, the last of which is ``history[-1]``.
    """
    if len(history) < n_keep:
        return list(history)
    interval = len(history) // n_keep
    offset = len(history) % n_keep
    # Keeps indices i where (i + 1 - offset) is a *positive* multiple of interval.
    # The old modulo test also matched the non-positive multiples, which
    # over-sampled the early part of the run.
    return history[offset + interval - 1::interval]


def _load_eval_fronts(run, history, data_path):
    """Load the evaluation front referenced by each history row of ``run``.

    Each front file is downloaded into ``data_path`` unless it is already cached
    there. Returns a list of ``(global_step, front)`` tuples sorted by
    ``global_step``, where ``front`` is ``np.array`` of the JSON file's ``'data'`` field.
    """
    eval_fronts = []
    for h in history:
        pf_path = h['eval/front']['path']
        local_path = os.path.join(data_path, pf_path)
        # Only contact W&B when the file is not cached. Real download errors
        # now propagate instead of being silently swallowed.
        if not os.path.exists(local_path):
            run.file(pf_path).download(root=data_path)
        with open(local_path, 'r') as f:
            front = np.array(json.load(f)['data'])
        eval_fronts.append((h['global_step'], front))
    return sorted(eval_fronts, key=lambda x: x[0])


def _save_fronts(eval_fronts, out_dir):
    """Save per-step cumulative pruned fronts and the final merged front to ``out_dir``.

    ``eval_fronts`` must be a non-empty list of ``(global_step, front)`` tuples
    in step order. For each step, ``front_{global_step}.npy`` holds
    ``extreme_prune`` of all fronts seen so far. ``final_front.npy`` holds
    ``extreme_prune`` of all fronts concatenated.
    """
    os.makedirs(out_dir, exist_ok=True)
    partial_front = None
    for global_step, eval_front in eval_fronts:
        # Start from the first front alone, rather than concatenating it with itself.
        if partial_front is None:
            combined = eval_front
        else:
            combined = np.concatenate([partial_front, eval_front], axis=0)
        partial_front = extreme_prune(combined)
        np.save(os.path.join(out_dir, f'front_{global_step}.npy'), partial_front)

    merged_fronts = np.concatenate([front for _, front in eval_fronts], axis=0)
    np.save(os.path.join(out_dir, 'final_front.npy'), extreme_prune(merged_fronts))


for run in baseline_runs:
    algo = run.config['algo']
    env_id = run.config['env_id']
    seed = run.config['seed']
    print(f'Processing {env_id} - {algo} - {seed}')

    history_dicts = list(run.scan_history(keys=['eval/front', 'global_step']))
    selected_history = _subsample_history(history_dicts, num_evals)
    data_path = os.path.join('data', env_id, algo, str(seed))
    eval_fronts = _load_eval_fronts(run, selected_history, data_path)

    if not eval_fronts:
        # No logged fronts: the pruning/concatenation below would raise on empty input.
        print('No evaluation fronts logged; skipping.')
        print('---------')
        continue

    _save_fronts(eval_fronts, os.path.join('fronts', env_id, algo, str(seed)))
    print('---------')