In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import yaml
import json

In [None]:
# Notebook display tweaks: stretch the `.container` element to 100% width and
# let pandas show up to 999 characters per column before truncating.
# `display` and `HTML` are imported from the public `IPython.display` module;
# importing them from `IPython.core.display` is deprecated.
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))
pd.set_option('display.max_colwidth', 999)

In [None]:
def load_yaml(path):
    """Parse the YAML file at ``path`` and return the resulting Python object.

    The file is decoded as UTF-8 explicitly so the result does not depend on
    the platform's default encoding. Parsing uses ``yaml.safe_load``.

    Raises whatever ``open`` or ``yaml.safe_load`` raise (e.g. a missing file).
    """
    with open(path, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)
def load_json(path):
    """Parse the JSON file at ``path`` and return the resulting Python object.

    The file is decoded as UTF-8 explicitly so the result does not depend on
    the platform's default encoding.

    Raises whatever ``open`` or ``json.load`` raise (e.g. a missing file or
    malformed JSON).
    """
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)

In [None]:
# Root directory (relative path) that is scanned for model run folders.
model_path = '../models/'

In [None]:
# Accumulators filled by find_metrics_in_dir. They are appended to in pairs,
# so entry i of both lists comes from the same run directory.
# Re-running this cell resets both lists.
config_list = []
metric_list = []

def find_metrics_in_dir(base_path):
    """Recursively collect the config and test metrics of every run under ``base_path``.

    A directory counts as a run when it contains a ``checkpoints`` directory.
    For each run, ``config.yaml`` and ``metrics_test.json`` are loaded and
    appended to the module-level ``config_list`` and ``metric_list``. Any other
    directory (except one named ``checkpoints``) is searched recursively.

    Parameters
    ----------
    base_path : str
        Directory to scan. A trailing slash is optional.

    Notes
    -----
    * Entries are visited in sorted order so the resulting row order is
      reproducible across machines.
    * A run that lacks either file is skipped with a printed notice instead of
      aborting the whole scan.
    * Both files are loaded before anything is appended, so the two lists can
      never get out of step.
    """
    for entry in sorted(os.listdir(base_path)):
        if entry == '.gitignore':
            continue
        subdir_path = os.path.join(base_path, entry)

        if os.path.isdir(os.path.join(subdir_path, 'checkpoints')):
            config_path = os.path.join(subdir_path, 'config.yaml')
            metrics_path = os.path.join(subdir_path, 'metrics_test.json')

            missing = [p for p in (config_path, metrics_path) if not os.path.isfile(p)]
            if missing:
                print('Skipping %s (missing: %s)' % (subdir_path, ', '.join(missing)))
                continue

            # Load both before appending so a failure cannot leave one list longer.
            config = load_yaml(config_path)
            metrics = load_json(metrics_path)
            config_list.append(config)
            metric_list.append(metrics)

        elif os.path.isdir(subdir_path) and entry != 'checkpoints':
            find_metrics_in_dir(subdir_path)

In [None]:
# Fill config_list / metric_list from disk. The function appends to those
# module-level lists, so reset them first when re-running this cell;
# otherwise every run is recorded twice.
find_metrics_in_dir(model_path)

In [None]:
# Sanity check: the two lists are filled in pairs, so both counts should match.
len(config_list), len(metric_list)

In [None]:
# Inspect the first loaded config (raises IndexError if no run was found).
config_list[0]

In [None]:
def expand_config_item(config_item):
    """Flatten one nested run configuration into a single-level dict.

    Parameters
    ----------
    config_item : dict
        Nested configuration with the keys read below: ``data.base``,
        ``loader.train.batch_size``, ``model_args``, ``optimizer.type``,
        ``optimizer.opt_args``, ``optimizer.max_grad_norm``,
        ``scheduler.n_warmup_steps`` and ``scheduler.nb_epochs``.

    Returns
    -------
    dict
        Flat mapping containing the fixed fields plus every entry of
        ``model_args`` and ``opt_args``. Keys are written in the order shown
        in the code, so a later source overwrites an earlier key of the same
        name. ``learn_positional_encoding`` is always present; it is ``None``
        when ``model_args`` does not define it.

    Raises
    ------
    KeyError
        If one of the fixed keys listed above is missing.
    """
    # An empty YAML section parses to None; treat it as "no entries".
    model_args = config_item['model_args'] or {}
    optimizer_cfg = config_item['optimizer']
    opt_args = optimizer_cfg['opt_args'] or {}
    scheduler_cfg = config_item['scheduler']

    expanded = {
        'base': config_item['data']['base'],
        'train_batch_size': config_item['loader']['train']['batch_size'],
    }
    expanded.update(model_args)
    expanded['optimizer'] = optimizer_cfg['type']
    expanded.update(opt_args)
    expanded['n_warmup_steps'] = scheduler_cfg['n_warmup_steps']
    expanded['nb_epochs'] = scheduler_cfg['nb_epochs']
    expanded['max_grad_norm'] = optimizer_cfg['max_grad_norm']
    # Keep this column present for every run so configs that omit the key still
    # line up in the resulting table (previously a missing key raised KeyError).
    expanded['learn_positional_encoding'] = model_args.get('learn_positional_encoding')
    return expanded

In [None]:
# Flatten every loaded config, then stack the flat dicts into a table with
# one row per run.
expanded = list(map(expand_config_item, config_list))
config_df = pd.DataFrame(expanded)

In [None]:
# Show the flattened configurations: one row per run.
config_df

In [None]:
def get_nice_metrics(metric_list):
    """Turn the loaded metric dicts into two DataFrames with one row per run.

    Parameters
    ----------
    metric_list : list of dict
        Each entry must provide ``entry['correct']`` (a mapping, whose keys
        become columns) and ``entry['meta']['n_beams']``.

    Returns
    -------
    list
        ``[correct, n_beams]`` where ``correct`` holds the ``'correct'``
        mappings and ``n_beams`` is a single-column frame named ``'n_beams'``.
        For an empty ``metric_list`` both frames are empty and ``n_beams``
        still has its column (previously this raised a length-mismatch error
        when renaming the column).
    """
    correct = pd.DataFrame([entry['correct'] for entry in metric_list])
    # Naming the column at construction time also works with zero rows.
    n_beams = pd.DataFrame({'n_beams': [entry['meta']['n_beams'] for entry in metric_list]})

    return [correct, n_beams]

In [None]:
# Place the config columns next to the metric columns; the frames are aligned
# on their index.
merged = pd.concat([config_df, *get_nice_metrics(metric_list)], axis='columns')

In [None]:
# Drop all configuration columns that have no variation across runs, because a
# column holding a single value does not help to compare them.
def _is_constant(column):
    """Return True when every row of ``column`` holds the same value.

    NaN is counted as a value of its own, so a setting that is present in only
    some runs is treated as varying rather than constant.
    """
    try:
        return column.nunique(dropna=False) == 1
    except TypeError:
        # Unhashable cells (e.g. list-valued settings) cannot be counted
        # directly; compare their string form instead.
        return column.astype(str).nunique() == 1


drop_cols = [c for c in list(config_df) + ['n_beams'] if _is_constant(merged[c])]
metric_df = merged.drop(columns=drop_cols)

In [None]:
# Show the merged table after the single-valued configuration columns were dropped.
metric_df

In [None]:
def plot_single_value(df, groupby_col, metric_cols, metrics=None, num_columns = 4, panel_size=(8, 4)):
    """Plot each metric, aggregated per value of ``groupby_col``, in its own subplot.

    Parameters
    ----------
    df : pandas.DataFrame
        Table containing ``groupby_col`` and the metric columns.
    groupby_col : str
        Column whose values form the x axis.
    metric_cols : list of str
        Metric columns to plot.
    metrics : dict or list, optional
        Aggregation per metric. A dict maps metric column -> aggregation and is
        used as given. A list is paired positionally with ``metric_cols``.
        When omitted or empty, every metric in ``metric_cols`` uses ``'mean'``.
    num_columns : int
        Maximum number of subplots per row.
    panel_size : tuple of (float, float)
        Width and height in inches of a single subplot.

    Fixes relative to the earlier version: list-valued ``metrics`` are paired
    with ``zip`` (``enumerate(metric_cols, metrics)`` raised ``TypeError``),
    the number of rows is the ceiling of metrics / columns, the figure width
    scales with the columns and the height with the rows, and a fresh figure
    is created instead of reusing whatever figure happens to be current.
    """
    if not metrics:
        metrics = {c: 'mean' for c in metric_cols}
    elif isinstance(metrics, list):
        metrics = dict(zip(metric_cols, metrics))

    # Size the grid from what is actually plotted (the loop runs over `metrics`).
    num_metrics = len(metrics)
    if num_metrics == 0:
        return  # nothing to draw

    # Never use more columns than there are plots; always at least one.
    num_columns = max(1, min(num_columns, num_metrics))
    num_rows = -(-num_metrics // num_columns)  # ceiling division

    panel_width, panel_height = panel_size
    fig = plt.figure(figsize=(panel_width * num_columns, panel_height * num_rows))
    grouped_by_target = df.groupby(groupby_col)
    for i, (metric, function) in enumerate(metrics.items()):
        ax = fig.add_subplot(num_rows, num_columns, 1+i)
        ax.set_title('%s'%metric)
        ax.set_ylabel('%s'%metric)
        grouped_by_target.agg({metric : function}).plot(ax=ax, legend=False, marker='x')
    fig.tight_layout()
    plt.show()
    
    

In [None]:
def plot_crossed_values(df, groupby_col, cross_col, metric_cols, metrics=None, num_columns = 4, panel_size=(8, 4)):
    """Plot each metric against ``groupby_col`` with one line per ``cross_col`` value.

    Parameters
    ----------
    df : pandas.DataFrame
        Table containing ``groupby_col``, ``cross_col`` and the metric columns.
    groupby_col : str
        Column whose values form the x axis.
    cross_col : str
        Column whose distinct values each get their own line and legend entry.
    metric_cols : list of str
        Metric columns to plot, one subplot each.
    metrics : dict or list, optional
        Aggregation per metric. A dict maps metric column -> aggregation and is
        used as given. A list is paired positionally with ``metric_cols``.
        When omitted or empty, every metric in ``metric_cols`` uses ``'mean'``.
    num_columns : int
        Maximum number of subplots per row.
    panel_size : tuple of (float, float)
        Width and height in inches of a single subplot.

    Fixes relative to the earlier version: list-valued ``metrics`` are paired
    with ``zip`` (``enumerate(metric_cols, metrics)`` raised ``TypeError``),
    the requested aggregation is applied (it was always ``'mean'``), only the
    plotted metric column is aggregated so unrelated non-numeric columns
    cannot break the aggregation, the number of rows is the ceiling of
    metrics / columns, the figure width scales with the columns and the height
    with the rows, and a fresh figure is created instead of reusing the
    current one.
    """
    if not metrics:
        metrics = {c: 'mean' for c in metric_cols}
    elif isinstance(metrics, list):
        metrics = dict(zip(metric_cols, metrics))

    # Size the grid from what is actually plotted (the loop runs over `metrics`).
    num_metrics = len(metrics)
    if num_metrics == 0:
        return  # nothing to draw

    # Never use more columns than there are plots; always at least one.
    num_columns = max(1, min(num_columns, num_metrics))
    num_rows = -(-num_metrics // num_columns)  # ceiling division

    panel_width, panel_height = panel_size
    fig = plt.figure(figsize=(panel_width * num_columns, panel_height * num_rows))

    grouped_by_cross = df.groupby(cross_col)

    for i, (metric, function) in enumerate(metrics.items()):
        ax = fig.add_subplot(num_rows, num_columns, 1+i)
        ax.set_title('%s'%metric)
        ax.set_ylabel('%s'%metric)
        for cross_name, cross_df in grouped_by_cross:
            # Aggregate just this metric with its own aggregation function.
            aggregated = cross_df.groupby(groupby_col)[metric].agg(function)
            aggregated.plot(ax=ax, legend=False, label='%s: '%cross_col + str(cross_name), marker='x')
        ax.legend()

    fig.tight_layout()
    plt.show()
    
    

In [None]:
# Metric columns passed to the plotting calls below.
# Presumably these are keys of each metrics file's 'correct' entry - TODO confirm.
metric_cols = ['correct_product', 'correct_factorization']

In [None]:
# Each metric against 'base', with one line per 'n_beams' value.
# NOTE(review): raises KeyError if 'base' or 'n_beams' was removed from
# metric_df for having a single unique value across all runs.
plot_crossed_values(metric_df, 'base', 'n_beams', metric_cols)

In [None]:
# Each metric against 'base', aggregated over all runs sharing that value.
# NOTE(review): raises KeyError if 'base' was removed from metric_df for
# having a single unique value across all runs.
plot_single_value(metric_df, 'base', metric_cols)