In [None]:
import os
import json
import pandas as pd
import functools
from datetime import datetime
import numpy as np
import matplotlib.pyplot as plt
import pickle
import plotnine as gg
import tempfile
import wandb

In [None]:
# Weights & Biases API client; the cells below use it to look up runs.
api = wandb.Api()

In [None]:
def filter_df(df, **settings):
    """Return the rows of ``df`` whose columns equal the given values.

    Each keyword names a column and gives the value that column must equal.
    The filters are applied one after another, so every condition must hold.
    With no keywords, ``df`` itself is returned.
    """
    return functools.reduce(
        lambda frame, item: frame[frame[item[0]] == item[1]],
        settings.items(),
        df,
    )

In [None]:
# Full path of a single run to inspect.
path = 'vladfi/slippi-ai/7u7ohb9e'

In [None]:
%%time

# Look up the run identified by `path`.
run = api.run(path)

In [None]:
%%time

# List the runs of the whole project and show how many there are.
runs = api.runs(path='vladfi/slippi-ai')
len(runs)

In [None]:
# Load the record of one run to inspect its fields.
# NOTE(review): the index 400 is hard-coded; confirm the project always has at least 401 runs.
config = runs[400].load()

In [None]:
# Same project, restricted to runs whose group is 'imitation'.
imitation_runs = api.runs(path='vladfi/slippi-ai', filters={'group': 'imitation'})
len(imitation_runs)

In [None]:
# Group recorded for the run loaded above.
config['group']

In [None]:
# NOTE(review): `path`, `group` and `target_name` are not referenced again in the visible cells.
path = 'vladfi/slippi-ai'
group = 'imitation'
target_name = 'top12_d21_imitation_v5'

In [None]:
# all of these are started at the 7-day mark
# Short label -> full run path of each run compared below.
multichar_run_ids = {
    'default': 'vladfi/slippi-ai/wpm2b7eo',
    'mirror': 'vladfi/slippi-ai/ffl6lk6o',
    '3x768': 'vladfi/slippi-ai/yfgly4xl',
}

In [None]:
_losses_key = lambda char: 'eval_characters.losses.' + char
_counts_key = lambda char: 'eval_characters.counts.' + char

In [None]:
def get_chars(run) -> list[str]:
    """List the characters named in a run's dataset config.

    Loads the run, reads the comma-separated ``allowed_characters`` string
    from ``config.dataset`` and splits it into individual names.
    """
    dataset_config = run.load()['config']['dataset']
    allowed = dataset_config['allowed_characters']
    return allowed.split(',')

In [None]:
# Look up each run listed in `multichar_run_ids`, keyed by the same short label.
# (The original referenced an undefined `run_ids`, which raised NameError.)
runs = {k: api.run(v) for k, v in multichar_run_ids.items()}

In [None]:
chars = get_chars(runs['default'])

In [None]:
keys = []

for char in chars:
    keys.append(_losses_key(char))
    keys.append(_counts_key(char))

In [None]:
dfs = {k: run.history(keys=keys) for k, run in runs.items()}

In [None]:
# Work with the history of the '3x768' run; `df` is the same object stored in `dfs`, not a copy.
df = dfs['3x768']

In [None]:
_mean_key = lambda char: f'{char}.mean'

# Add one mean column per character: logged loss divided by logged count, row by row.
for char in chars:
    losses = df[_losses_key(char)]
    counts = df[_counts_key(char)]
    df[_mean_key(char)] = losses / counts

In [None]:
# Long-format table for plotting: one row per (step, character) holding that
# character's mean eval loss.
to_plot = pd.DataFrame(
    {'step': df['_step'], **{c: df[_mean_key(c)] for c in chars}}
).melt(
    id_vars='step',
    value_vars=chars,
    var_name='char',
    value_name='eval_loss',
)

In [None]:
# One line per character: mean eval loss against the logged step.
(
    gg.ggplot(
        to_plot,
        gg.aes(x="step", y="eval_loss", group="char", color="char"),
    )
    + gg.geom_line()
)

In [None]:
char_specific_run_paths = {
    'falco': 'vladfi/slippi-ai/7u7ohb9e',
}

char_specific_runs = {k: api.run(v) for k, v in char_specific_run_paths.items()}

In [None]:
config = run.load()

In [None]:
# `lastStep` entry of the run's history metadata.
config['historyKeys']['lastStep']

In [None]:
# Pretty-print the per-key history metadata to see what is available.
print(json.dumps(config['historyKeys']['keys'], indent=2))

In [None]:
# NOTE(review): presumably `previousValue` is the last value logged for the key — confirm against W&B.
config['historyKeys']['keys']['eval.policy.loss']['previousValue']

In [None]:
from slippi_ai.saving import upgrade_config

def get_network_name(net: dict):
    """Return a short label for a network config.

    Every network type other than ``'tx_like'`` is labelled by its ``name``.
    A ``'tx_like'`` network is labelled ``'<num_layers>x<hidden_size>'``, with
    both numbers read from the sub-config stored under that name.
    """
    kind = net['name']
    if kind == 'tx_like':
        params = net[kind]
        return f'{params["num_layers"]}x{params["hidden_size"]}'
    return kind

def filter_best(df, group_keys, value_key='loss'):
    """Keep the lowest-``value_key`` row of each group, highest value first.

    Rows are grouped by ``group_keys``; from each group only the row holding
    the minimum ``value_key`` is kept. The surviving rows are returned sorted
    by ``value_key`` in descending order.
    """
    best_index = df.groupby(group_keys)[value_key].idxmin()
    best_rows = df.loc[best_index]
    return best_rows.sort_values(value_key, ascending=False)

def get_last_step_info(run):
    """Summarize a run from the data returned by ``run.load()``.

    Returns a dict with these entries:
      name          -- ``run.name``
      final_losses  -- ``{character: final eval loss}``
      last_step     -- ``historyKeys.lastStep``
      delay         -- ``policy.delay`` read from the raw (non-upgraded) config
      network       -- label produced by ``get_network_name``
      controller    -- ``(axis_spacing, shoulder_spacing)``
      allowed_names -- ``dataset.allowed_names`` from the upgraded config

    When exactly one character is selected, its final loss is the last value
    of ``eval.policy.loss``. Otherwise each character's final loss is its last
    per-character loss divided by its last per-character count; characters
    whose count is 0 are left out.
    """
    loaded = run.load()
    train_config = upgrade_config(loaded['config'])

    allowed_chars = train_config['dataset']['allowed_characters']
    history_keys = loaded['historyKeys']['keys']

    def last_value(key):
        # Value stored under 'previousValue' in the history metadata for `key`.
        return history_keys[key]['previousValue']

    # Work out which characters to report on.
    if allowed_chars == 'doc,mario':
        # NOTE(review): this pair is reported as 'doc' only; the reason is not visible here.
        chars = ['doc']
    elif allowed_chars == 'all':
        # Recover the character names from the per-character loss keys.
        prefix = 'eval_characters.losses.'
        chars = [
            key.removeprefix(prefix)
            for key in history_keys
            if key.startswith(prefix)
        ]
    else:
        chars = allowed_chars.split(',')

    if len(chars) == 1:
        final_losses = {chars[0]: last_value('eval.policy.loss')}
    else:
        final_losses = {}
        for char in chars:
            count = last_value(_counts_key(char))
            if count != 0:
                final_losses[char] = last_value(_losses_key(char)) / count

    spacing = train_config['embed']['controller']
    if run.name == 'fox_d18_imitation_taf7':
        # NOTE(review): hard-coded override for this one run; the reason is not visible here.
        controller = (16, 4)
    else:
        controller = (spacing['axis_spacing'], spacing['shoulder_spacing'])

    return {
        'name': run.name,
        'final_losses': final_losses,
        'last_step': loaded['historyKeys']['lastStep'],
        'delay': loaded['config']['policy']['delay'],
        'network': get_network_name(train_config['network']),
        'controller': controller,
        'allowed_names': train_config['dataset']['allowed_names'],
    }

In [None]:
run_infos = []
errors = []

for run in imitation_runs:
    # Skip runs that are still in progress.
    if run.state == 'running':
        continue
    try:
        run_infos.append(get_last_step_info(run))
    except Exception as e:
        # Record the failure instead of aborting the whole loop; inspect
        # `errors` afterwards to see which runs could not be summarized and why.
        errors.append((run, e))

# Number of runs summarized, number of runs that failed.
print(len(run_infos), len(errors))

In [None]:
# Flatten to one record per (run, character): copy the run-level fields and
# attach that character's loss plus the number of characters the run reported.
by_character_run_infos = []

for info in run_infos:
    per_char_losses = info['final_losses']
    shared = {k: v for k, v in info.items() if k != 'final_losses'}
    for char, final_loss in per_char_losses.items():
        by_character_run_infos.append({
            **shared,
            'char': char,
            'loss': final_loss,
            'num_chars': len(per_char_losses),
        })

df = pd.DataFrame(by_character_run_infos)

In [None]:
min_indices = df.groupby(['name', 'char'])['loss'].idxmin()
df = df.loc[min_indices]

In [None]:
df.sort_values(by='loss', inplace=True, ascending=False)

In [None]:
# Rows for fox at delay 21 whose `allowed_names` is 'all'.
filter_df(df, char='fox', delay=21, allowed_names='all')

In [None]:
# Rows whose `allowed_names` is 'all', without the columns no longer needed.
all_names = (
    filter_df(df, allowed_names='all')
    .drop(columns=['allowed_names', 'last_step'])
)

# Best row per (controller, char, num_chars, delay), then ordered for reading.
reduced_df = (
    filter_best(all_names, ['controller', 'char', 'num_chars', 'delay'])
    .sort_values(by=['char', 'delay', 'controller', 'loss'], ascending=False)
)

In [None]:
# Show the reduced table one character at a time: fox, then falco, then popo.
filter_df(reduced_df, char='fox')

In [None]:
filter_df(reduced_df, char='falco')

In [None]:
filter_df(reduced_df, char='popo')