In [None]:
import wandb
import json
from collections import OrderedDict

In [None]:
def _format_cell(v):
    """Return the display form of a single table cell (see ``tabilize``)."""
    import numbers

    if isinstance(v, list):
        # Lists (e.g. layer sizes) are shown as tuples: [64, 64] -> (64, 64).
        return tuple(v)
    # Only real numbers are compared with 999; comparing strings or None with an
    # int would raise TypeError.
    if isinstance(v, numbers.Real) and v > 999:
        return "{:.1e}".format(v)
    return v


def tabilize(results):
    r"""Print ``results`` as the body rows of a LaTeX ``tabular``.

    Each key becomes the row label, right-padded with spaces to the longest
    label. Each value is a sequence of cells appended with `` & `` separators.
    Every row except the last is terminated with `` \\``.

    Cell formatting:
      * lists are shown as tuples, e.g. ``[64, 64]`` -> ``(64, 64)``;
      * real numbers greater than 999 use ``"{:.1e}"``, e.g. ``5000`` -> ``5.0e+03``;
      * anything else (strings, ``None``, booleans, small numbers) uses ``str``.

    Args:
        results: mapping from row label (str) to a sequence of cell values.

    Returns:
        None. The rows are written to stdout with ``print``; an empty mapping
        prints nothing.
    """
    if not results:
        return  # nothing to print; also avoids max() on an empty sequence

    width = max(len(label) for label in results)
    last_row = len(results) - 1

    for i_row, (label, row) in enumerate(results.items()):
        cells = ''.join(' & ' + str(_format_cell(v)) for v in row)
        terminator = ' \\\\' if i_row < last_row else ''
        print(label.ljust(width) + cells + terminator)

In [None]:
# Client for the W&B public API; its `run(...)` method is used below to fetch
# each representative run's config. timeout=120 overrides the client default
# (presumably in seconds — confirm against the wandb.Api docs).
api = wandb.Api(timeout=120)

In [None]:
# Environments (table columns) and algorithms (one table each) to report.
env_ids = ['deep-sea-treasure-concave-v0', 'minecart-v0', 'mo-reacher-v4']
algs = ['SN-MO-DQN', 'SN-MO-A2C', 'SN-MO-PPO']
# rep_runs[alg][env_id] -> identifier later passed to api.run(); None until filled.
rep_runs = {alg: {env_id: None for env_id in env_ids} for alg in algs}

# best_parents.json is read as {env_id: {alg: [run_id, ...]}}; the first run id
# of each list is taken as the representative run (presumably the best one —
# confirm against the script that writes this file).
with open('data/best_parents.json') as f:
    files = json.load(f)

for env_id, env_dict in files.items():
    for alg, run_ids in env_dict.items():
        rep_runs[alg][env_id] = run_ids[0]

# Fail fast with a readable message instead of handing None to api.run() later.
missing = [(alg, env_id)
           for alg, per_env in rep_runs.items()
           for env_id, run_id in per_env.items()
           if run_id is None]
if missing:
    raise ValueError(f'data/best_parents.json has no run for (alg, env_id): {missing}')

In [None]:
# Reference list of the rows (LaTeX display names) that may appear in the
# hyper-parameter tables. Raw strings keep LaTeX escapes such as \_ and \rho
# literal (in a plain string '\r' is a carriage return and '\_' is an invalid
# escape sequence).
# NOTE(review): the tabulation loop below rebinds `configs` to a fresh
# OrderedDict per algorithm, so these empty lists are not read there.
configs = {
    'scale': [],
    r'$\rho$': [],
    r'pretrain\_iters': [],
    r'num\_referents': [],
    r'online\_steps': [],
    r'pretraining\_steps': [],
    r'critic\_hidden': [],
    r'lr\_critic': [],
    r'actor\_hidden': [],
    r'lr\_actor': [],
    r'n\_steps': [],
    r'gae\_lambda': [],
    r'normalise\_advantage': [],
    r'e\_coef': [],
    r'v\_coef': [],
    r'max\_grad\_norm': [],
    r'clip\_coef': [],
    r'num\_envs': [],
    r'anneal\_lr': [],
    r'clip\_range\_vf': [],
    r'update\_epochs': [],
    r'num\_minibatches': [],
    r'batch\_size': [],
    r'buffer\_size': [],
    r'soft\_update': [],
    r'pre\_learning\_start': [],
    r'pre\_epsilon\_start': [],
    r'pre\_epsilon\_end': [],
    r'pre\_exploration\_frac': [],
    r'online\_learning\_start': [],
    r'online\_epsilon\_start': [],
    r'online\_epsilon\_end': [],
    r'online\_exploration\_frac': [],
}
# Raw W&B config key -> LaTeX row label, in table row order. Several raw keys
# share one label ('critic_hidden'/'hidden_layers', 'lr_critic'/'lr',
# 'soft_update'/'tau'), presumably because different algorithms name the same
# hyper-parameter differently — TODO confirm.
config_names = OrderedDict([
    ('scale', 'scale'),
    ('aug', r'$\rho$'),
    ('pretrain_iters', r'pretrain\_iters'),
    ('num_referents', r'num\_referents'),
    ('online_steps', r'online\_steps'),
    ('pretraining_steps', r'pretraining\_steps'),
    ('critic_hidden', r'critic\_hidden'),
    ('hidden_layers', r'critic\_hidden'),
    ('lr_critic', r'lr\_critic'),
    ('lr', r'lr\_critic'),
    ('actor_hidden', r'actor\_hidden'),
    ('lr_actor', r'lr\_actor'),
    ('n_steps', r'n\_steps'),
    ('gae_lambda', r'gae\_lambda'),
    ('normalize_advantage', r'normalise\_advantage'),
    ('e_coef', r'e\_coef'),
    ('v_coef', r'v\_coef'),
    ('max_grad_norm', r'max\_grad\_norm'),
    ('clip_coef', r'clip\_coef'),
    ('num_envs', r'num\_envs'),
    ('anneal_lr', r'anneal\_lr'),
    ('clip_range_vf', r'clip\_range\_vf'),
    ('update_epochs', r'update\_epochs'),
    ('num_minibatches', r'num\_minibatches'),
    ('batch_size', r'batch\_size'),
    ('buffer_size', r'buffer\_size'),
    ('soft_update', r'soft\_update'),
    ('tau', r'soft\_update'),
    ('epsilon_start', r'epsilon\_start'),
    ('epsilon_end', r'epsilon\_end'),
    ('exploration_frac', r'exploration\_frac'),
    ('pre_learning_start', r'pre\_learning\_start'),
    ('pre_epsilon_start', r'pre\_epsilon\_start'),
    ('pre_epsilon_end', r'pre\_epsilon\_end'),
    ('pre_exploration_frac', r'pre\_exploration\_frac'),
    ('online_learning_start', r'online\_learning\_start'),
    ('online_epsilon_start', r'online\_epsilon\_start'),
    ('online_epsilon_end', r'online\_epsilon\_end'),
    ('online_exploration_frac', r'online\_exploration\_frac'),
])

In [None]:
# Print one LaTeX hyper-parameter table per algorithm: one row per display name
# from config_names (first-seen order), one cell per environment in env_ids order.
for alg in algs:
    print(f'Tabilizing {alg}')
    configs = OrderedDict()
    for env in env_ids:
        alg_env_run = rep_runs[alg][env]
        config = api.run(alg_env_run).config
        # Several raw keys map to the same display row (e.g. 'lr' and 'lr_critic').
        # Keep only the first one found so each run adds at most one cell per row;
        # otherwise a config holding both aliases would add extra, misaligned cells.
        seen_rows = set()
        for param_name, param in config_names.items():
            if param_name not in config or param in seen_rows:
                continue
            seen_rows.add(param)
            configs.setdefault(param, []).append(config[param_name])
        # NOTE(review): a parameter absent from some environments' configs still
        # yields a shorter row whose cells no longer line up with the env columns.

    tabilize(configs)
    print("------------")