In [None]:
import wandb
import os
import numpy as np
from ipro.environments.bounding_boxes import get_bounding_box
from ipro.utils.pareto import extreme_prune

In [None]:
# Client for the W&B public API, used below to fetch runs, their configs and
# logged files. The timeout is the per-request HTTP timeout in seconds.
WANDB_API_TIMEOUT = 120
api = wandb.Api(timeout=WANDB_API_TIMEOUT)

In [None]:
def extract_eval_fronts_from_json(run, maximals):
    output = run.file('output.log')
    file = output.download(root=f'data/{env_id}/{alg}/', replace=True)
    path = file.name

    all_fronts = []
    partial_front = None
    eval_fronts = None

    with open(path) as f:  # Open as a text file
        # Read the file contents and generate a list of lines
        lines = f.readlines()

        for line in lines:
            if 'Running with config:' in line:
                if partial_front is not None:  # Add to the list if not the first.
                    all_fronts.append(eval_fronts)

                # Reset the partial front and eval_fronts.
                partial_front = np.copy(maximals)
                eval_fronts = [partial_front]
            elif 'Found' in line:
                split = line.split('[')[2].split(']')[0].split(' ')
                point = []
                for val in split:
                    if not val == '':
                        point.append(float(val))
                partial_front = np.copy(np.vstack((partial_front, point)))
                eval_fronts.append(partial_front)
    
    all_fronts.append(eval_fronts)  # Add the last one.
    
    return all_fronts

In [None]:
# W&B run paths to export, keyed by environment id and then algorithm name.
files = {
    "deep-sea-treasure-concave-v0": {
        "SN-MO-PPO": 'wilrop/IPRO_dst_bayes_no_pretrain_8/0pq4alhl',
        "SN-MO-DQN": 'wilrop/IPRO_dst_bayes_no_pretrain_9/daz0fqca',
    }
}

# Environments whose bounding box is looked up under a different id.
BOUNDING_BOX_ENV_IDS = {'mo-reacher-concave-v0': 'mo-reacher-v4'}

# Export every parsed front to fronts/<env_id>/<alg>/<seed>/front_<step>.npy,
# plus the pruned last front as final_front.npy.
for env_id, alg_dict in files.items():
    print(f'Processing {env_id}')
    _, maximals, _ = get_bounding_box(BOUNDING_BOX_ENV_IDS.get(env_id, env_id))
    for alg, run_path in alg_dict.items():
        print(f'Processing {alg}')
        run = api.run(run_path)
        # Multiplier turning a snapshot index into the step used in file names.
        online_steps = run.config['oracle']['online_steps']
        # NOTE: extract_eval_fronts_from_json builds its download directory
        # from the `env_id` / `alg` loop variables (read as globals), so these
        # names must stay as they are.
        all_eval_fronts = extract_eval_fronts_from_json(run, maximals)

        for seed, eval_fronts in enumerate(all_eval_fronts):
            print(f'Processing seed {seed + 1}/{len(all_eval_fronts)}')
            seed_dir = os.path.join('fronts', env_id, alg, str(seed))
            os.makedirs(seed_dir, exist_ok=True)  # Once per seed, not per front.
            for i, front in enumerate(eval_fronts):
                np.save(os.path.join(seed_dir, f'front_{i * online_steps}.npy'), front)

            final_front = extreme_prune(np.copy(eval_fronts[-1]))
            np.save(os.path.join(seed_dir, 'final_front.npy'), final_front)
    print('---------')
print('Finished!')