In [None]:
import wandb
import os
import json
import numpy as np
from environments.bounding_boxes import get_bounding_box
from utils.pareto import extreme_prune

In [None]:
# Shared W&B public-API client; the driver loop below fetches runs with `api.run(run_path)`.
# NOTE(review): `timeout=120` is presumably seconds for the underlying HTTP requests — confirm against wandb.Api docs.
api = wandb.Api(timeout=120)

In [None]:
def extract_eval_fronts_from_json(run, maximals, root=None):
    """Rebuild the cumulative evaluation fronts from a run's ``output.log``.

    Downloads ``output.log`` from the W&B run and scans it for lines containing
    ``'Found'``. On each such line, the text between the second ``[`` and the
    next ``]`` is parsed as a whitespace-separated vector of floats. That point
    is appended to the running front.

    Args:
        run: W&B run object; only ``run.file('output.log')`` is used.
        maximals: Initial points every front starts from (stacked row-wise
            with ``np.vstack``).
        root: Directory to download the log into. When omitted, defaults to
            ``data/{env_id}/{alg}/{seed}``, built from the notebook-level loop
            variables ``env_id``, ``alg`` and ``seed``. This default is kept for
            backward compatibility; pass ``root`` to avoid the hidden state.

    Returns:
        A list of arrays. The first is a copy of ``maximals``; each later array
        holds every point found so far.
    """
    if root is None:
        # Backward-compatible fallback: relies on globals set by the driver loop.
        root = f'data/{env_id}/{alg}/{seed}'

    partial_front = np.copy(maximals)
    eval_fronts = [partial_front]

    # wandb's `File.download` returns an already-open text handle. Read through
    # it and let `with` close it, rather than reopening the path by name and
    # leaking the original handle.
    with run.file('output.log').download(root=root, replace=True) as log:
        for line in log:
            if 'Found' not in line:
                continue
            # NOTE(review): assumes the point is the second bracketed group on the line.
            inner = line.split('[')[2].split(']')[0]
            point = [float(token) for token in inner.split()]
            # vstack already returns a new array, so no extra copy is needed.
            partial_front = np.vstack((partial_front, point))
            eval_fronts.append(partial_front)
    return eval_fronts

In [None]:
def extract_eval_fronts_from_summary(run, maximals):
    """Rebuild the cumulative evaluation fronts from a run's summary.

    Every summary key containing ``'pareto_point'`` is treated as one point.
    Its index is the integer after the key's last ``'_'``. Points are appended
    in ascending index order (ties keep summary order) to a running front that
    starts as a copy of ``maximals``.

    Returns:
        A list of arrays. The first is a copy of ``maximals``; each later array
        holds every point appended so far.
    """
    # TODO: add back the fallback that reads points from output.log (JSON) here.
    indexed_points = [
        (int(key.split('_')[-1]), np.array(value))
        for key, value in run.summary.items()
        if 'pareto_point' in key
    ]
    indexed_points.sort(key=lambda pair: pair[0])

    current = np.copy(maximals)
    fronts = [current]
    for _index, vector in indexed_points:
        current = np.copy(np.vstack((current, vector)))
        fronts.append(current)
    return fronts

In [None]:
# Driver: export the per-step and final Pareto fronts of the selected best runs.
# NOTE: `env_id`, `alg` and `seed` are deliberately left as notebook-level names,
# because `extract_eval_fronts_from_json` reads them as globals when no `root`
# argument is given.
with open('data/best_runs.json') as f:
    # NOTE(review): presumably {env_id: {alg: [run_path, ...]}}; inferred from the loop below.
    files = json.load(f)
new_runs = True  # Set this to True for new runs (points in the run summary); False parses output.log

# Only these environments are processed. Each one maps to the env id whose
# bounding box supplies the initial maximal points.
BOUNDING_BOX_ENV_IDS = {'mo-reacher-concave-v0': 'mo-reacher-v4'}
# Only these algorithms are processed.
ALGS_TO_PROCESS = {'SN-MO-DQN'}

for env_id, alg_dict in files.items():
    print(f'Processing {env_id}')
    if env_id not in BOUNDING_BOX_ENV_IDS:
        continue
    _, maximals, _ = get_bounding_box(BOUNDING_BOX_ENV_IDS[env_id])
    for alg, runs in alg_dict.items():
        if alg not in ALGS_TO_PROCESS:
            continue
        print(f'Processing {alg}')
        for run_path in runs:
            run = api.run(run_path)
            print(run.id)
            seed = run.config['seed']
            online_steps = run.config['online_steps']
            if new_runs:
                eval_fronts = extract_eval_fronts_from_summary(run, maximals)
            else:
                eval_fronts = extract_eval_fronts_from_json(run, maximals)

            # Every file for this run goes in the same directory; create it once.
            front_dir = os.path.join('fronts', env_id, alg, str(seed))
            os.makedirs(front_dir, exist_ok=True)
            # Front i is saved under the step count i * online_steps.
            for i, front in enumerate(eval_fronts):
                step = i * online_steps
                np.save(os.path.join(front_dir, f'front_{step}.npy'), front)

            # Both extractors start the list with `maximals`, so [-1] always exists.
            final_front = extreme_prune(np.copy(eval_fronts[-1]))
            np.save(os.path.join(front_dir, 'final_front.npy'), final_front)
    print('---------')
print('Finished!')