In [7]:
# imports
import json
import glob
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import os
import pandas as pd

from functools import partial
from matplotlib import cm
from pandas.io.json import json_normalize

# Root directory of experiment logs; each experiment is a sub-directory named by
# its exp_id. Relative path, so it is resolved against the kernel's working directory.
LOGS_PATH = '../logs'

In [6]:
# Loading utilities: experiment metadata (metadata.json) and per-rank JSON-lines logs

def wrapped_json_loads(line, p):
    """Parse one JSON log line, reporting the offending file if it is malformed.

    NUL characters are removed before parsing; all other characters (including
    non-ASCII ones) are passed to ``json.loads`` unchanged.

    Args:
        line: A single line of text read from a log file.
        p: Path of the file ``line`` came from; only used in the error report.

    Returns:
        The decoded JSON value.

    Raises:
        ValueError: (``json.JSONDecodeError``) if the cleaned line is not valid JSON.
    """
    # Strip NUL characters (presumably left behind by interrupted writes — TODO
    # confirm). A plain str.replace does this without the UTF-32 round-trip,
    # which corrupted or rejected every non-ASCII character.
    cleaned = line.replace('\x00', '')
    try:
        return json.loads(cleaned)
    except ValueError:
        print('this is the culprit', p, cleaned)
        print('fix this in the input and rerun')
        raise
        
        
def get_exp_path(exp_id):
    """Return the directory holding experiment ``exp_id`` (``LOGS_PATH/exp_id``).

    Raises:
        AssertionError: if ``exp_id`` is empty or None.
    """
    assert exp_id
    exp_dir = os.path.join(LOGS_PATH, exp_id)
    return exp_dir
        

def load_metadata(exp_path):
    """Read and return the parsed ``metadata.json`` file inside ``exp_path``."""
    with open(os.path.join(exp_path, 'metadata.json')) as meta_file:
        return json.load(meta_file)


def load_worker_logs(exp_path, filter_funcs):
    """Yield ``(rank_name, records)`` for every ``rank_*.log`` file in ``exp_path``.

    Each file is read line by line, every line is parsed with
    ``wrapped_json_loads``, and only records accepted by every predicate in
    ``filter_funcs`` are kept.

    Args:
        exp_path: Experiment directory containing the per-rank log files.
        filter_funcs: Iterable of predicates ``record -> bool``, applied in order.

    Yields:
        ``(rank_name, records)``: ``rank_name`` is the file name without its
        ``.log`` extension (e.g. ``'rank_0'``), ``records`` is a list.
    """
    for log_path in glob.iglob(os.path.join(exp_path, 'rank_*.log')):
        # splitext removes exactly the '.log' suffix; str.strip('.log') would
        # remove any of the characters '.', 'l', 'o', 'g' from both ends
        # (e.g. 'rank_g.log' -> 'rank_').
        rank_name = os.path.splitext(os.path.basename(log_path))[0]
        with open(log_path) as log_file:
            records = (wrapped_json_loads(line, log_path) for line in log_file)
            for keep in filter_funcs:
                records = filter(keep, records)
            # Materialise while the file is still open.
            records = list(records)
        yield rank_name, records
        

def load_data(exp_id=None, exp_path=None, filter_funcs=None):
    """Load an experiment's metadata and its (filtered) per-rank logs into one dict.

    At least one of ``exp_id`` / ``exp_path`` must be given. ``exp_path`` is
    used when truthy; otherwise the path is ``LOGS_PATH/exp_id``.

    Args:
        exp_id: Experiment id under ``LOGS_PATH``.
        exp_path: Explicit experiment directory (takes precedence).
        filter_funcs: Optional predicates forwarded to ``load_worker_logs``.

    Returns:
        ``{'meta': <metadata>, 'rank_0': [records], 'rank_1': [records], ...}``
    """
    assert exp_id or exp_path
    exp_path = exp_path or os.path.join(LOGS_PATH, exp_id)
    data = {'meta': load_metadata(exp_path)}
    data.update(load_worker_logs(exp_path, filter_funcs or []))
    return data

In [10]:
# Utils related to validation set performance - useful for evaluating training-time evolution

def make_df(exp_id, log_key, meta_keys):
    """Build a DataFrame of every per-rank log record that contains ``log_key``.

    Each row is one record augmented with ``'rank'`` (the rank file name, e.g.
    ``'rank_0'``) and the requested metadata fields. Fields of the record itself
    take precedence, since the record is unpacked last. Rows are sorted by
    ``'time'``, and for each ``(rank, Update)`` pair only the latest record is
    kept.

    Args:
        exp_id: Experiment id under ``LOGS_PATH``.
        log_key: Only records containing this key are kept (e.g. ``'ValAcc'``).
        meta_keys: Metadata fields copied onto every row.

    Returns:
        ``(df, meta)``: the DataFrame and the experiment's metadata dict.
    """
    data = load_data(exp_id)
    meta = data['meta']
    meta_fields = {mk: meta[mk] for mk in meta_keys}
    rows = (
        {'rank': rank_key, **meta_fields, **log}
        for rank_key in data if 'rank_' in rank_key
        for log in data[rank_key] if log_key in log
    )
    df = (
        pd.DataFrame(rows)
        # Stable sort so that, on equal timestamps, keep='last' deterministically
        # keeps the record that appeared last in the log.
        .sort_values(by='time', ascending=True, kind='mergesort')
        # De-duplicate on the 'rank' column added above, which every row has;
        # a 'Rank' column would only exist if the records carried one themselves.
        .drop_duplicates(subset=['rank', 'Update'], keep='last')
    )

    if 'ValAcc' in df.columns:
        df['ValAcc'] = df['ValAcc'].astype(float)
    return df, meta


# Metric-specific variants of make_df: each selects the records carrying one log
# key and tags every row with the run's 'seed' from the experiment metadata.
make_valacc_df = partial(make_df, meta_keys=['seed'], log_key='ValAcc')
make_valloss_df = partial(make_df, meta_keys=['seed'], log_key='ValLoss')
make_trainloss_df = partial(make_df, meta_keys=['seed'], log_key='TrainLoss')

def plot_valacc_evolution(exp_id, ax=None, color='blue', legend_label=None, color_idx=None, acc_range=True,
                          smooth=0, color_palette=cm.tab20):
    """Plot validation accuracy against the update counter, averaged across series.

    Args:
        exp_id: Experiment id under ``LOGS_PATH``.
        ax: Axes to draw on; a new figure/axes is created when None.
        color: Line/fill colour, used unless ``color_idx`` is given.
        legend_label: Line label; defaults to ``'Validation Accuracy'``.
        color_idx: If not None, ``color_palette(color_idx)`` overrides ``color``
            (index 0 is honoured).
        acc_range: When True, shade the min-max band across series.
        smooth: Rolling-mean window (in rows) applied to each series before
            aggregation; 0 disables smoothing.
        color_palette: Colormap callable used together with ``color_idx``.

    Returns:
        The Axes that was drawn on.
    """
    df, _ = make_valacc_df(exp_id)
    # One column per distinct 'name' value (presumably one per worker — TODO
    # confirm), indexed by (Epoch, Update).
    val_accs = pd.pivot_table(df, index=['Epoch', 'Update'], columns=['name'], values='ValAcc')
    if smooth:
        val_accs = val_accs.rolling(smooth, min_periods=1).mean()
    # Row-wise statistics across the series; NaNs (missing series) are skipped.
    val_stats = pd.DataFrame({
        'mean': val_accs.mean(axis=1),
        'min': val_accs.min(axis=1),
        'max': val_accs.max(axis=1),
    }).reset_index()
    if ax is None:
        _, ax = plt.subplots(1)
    # Compare against None so that palette index 0 is not ignored.
    if color_idx is not None:
        color = color_palette(color_idx)
    ax.plot(val_stats['Update'], val_stats['mean'], label=legend_label or 'Validation Accuracy', color=color)
    if acc_range:
        ax.fill_between(val_stats['Update'], val_stats['min'], val_stats['max'], facecolor=color, alpha=0.3)
    ax.set_ylabel('Validation Accuracy', fontweight='bold')
    ax.set_xlabel('Update', fontweight='bold')
    return ax

# Backward-compatible short name.
plot_evolution = plot_valacc_evolution

In [9]:
# Utils related to test set performance - useful for evaluating post-training performance

def load_test_acc_data(exp_path):
    """Load one experiment's metadata plus its test-accuracy records.

    Only log records containing ``'TestAcc'`` or ``'AggTestAcc'`` are loaded.
    Afterwards:
      * each ``rank_*`` entry is replaced by its first ``'TestAcc'`` record
        (it stays a list if that rank has none);
      * the first ``'AggTestAcc'`` record of a rank is stored under ``'agg'``.
        If several ranks have one, the rank visited last wins.

    Returns:
        The dict produced by ``load_data``, modified as described above.
    """
    def keep_test_records(log):
        return 'TestAcc' in log or 'AggTestAcc' in log

    data = load_data(exp_path=exp_path, filter_funcs=[keep_test_records])

    agg_record = None
    for rank_key in [k for k in data if 'rank_' in k]:
        records = data[rank_key]
        first_agg = next((rec for rec in records if 'AggTestAcc' in rec), None)
        if first_agg is not None:
            agg_record = first_agg
        first_test = next((rec for rec in records if 'TestAcc' in rec), None)
        if first_test is not None:
            data[rank_key] = first_test

    if agg_record:
        data['agg'] = agg_record
    return data


def glob_and_load_all_test_accs(exp_root='mnist'):
    """Load test-accuracy data for every experiment directory matching ``exp_root``.

    An experiment is a directory directly under ``LOGS_PATH`` whose name
    contains ``exp_root`` and does not contain ``'bak'`` (presumably backup
    copies). Each match is loaded with ``load_test_acc_data``.

    Args:
        exp_root: Substring that experiment directory names must contain.

    Returns:
        A list with one ``load_test_acc_data`` result per experiment, in sorted
        path order so repeated runs produce the same ordering.
    """
    exp_dirs = sorted(
        path
        for path in glob.iglob(os.path.join(LOGS_PATH, '*'))
        # Match on the directory name only, so the logs root itself can never
        # satisfy (or veto) the filter; skip stray files, which have no
        # metadata.json to load.
        if os.path.isdir(path)
        and exp_root in os.path.basename(path)
        and 'bak' not in os.path.basename(path)
    )
    return [load_test_acc_data(exp_dir) for exp_dir in exp_dirs]


def create_clean_df(exp_root='mnist'):
    """Return one row per experiment with its remaining metadata and test accuracies.

    Steps:
      * load every experiment matching ``exp_root`` via
        ``glob_and_load_all_test_accs``;
      * flatten the nested dicts with ``json_normalize`` (dotted column names);
      * drop a fixed list of metadata columns;
      * keep only test-accuracy columns from the log-derived fields;
      * strip ``'meta.'`` and ``'.log.'`` from column names;
      * discard experiments without a ``rank_0`` test accuracy.
    """
    data = glob_and_load_all_test_accs(exp_root)
    # NOTE(review): pandas.io.json.json_normalize is deprecated since pandas 1.0
    # (use pd.json_normalize) — switch once the pinned pandas version allows it.
    df = json_normalize(data)
    # Metadata fields excluded from the comparison table. errors='ignore' keeps
    # this working when none of the loaded experiments recorded one of them,
    # instead of raising KeyError.
    df = df.drop(columns=[
        'meta.anneal_factor',
        'meta.anneal_milestones',
        'meta.async_cuda',
        'meta.checkpoint_interval',
        'meta.comm_backend',
        'meta.eval_on_gpu',
        'meta.from_checkpoint',
        'meta.gpu',
        'meta.log_level',
        'meta.log_interval',
        'meta.log_path',
        'meta.num_threads',
        'meta.rank',
        'meta.momentum_correction',
        'meta.experiment',
        'meta.learning_rate',
        'meta.momentum',
        'meta.nesterov',
        'meta.timestamp',
        'meta.batch_size',
        'meta.epochs',
    ], errors='ignore')
    cols = [
        col for col in df.columns
        if 'data' not in col
        and 'master' not in col
        # 'AggTestAcc' contains 'TestAcc', so this keeps both kinds of column.
        and ('log' not in col or 'TestAcc' in col)
    ]
    df = df[cols]
    renamed = {col: col.replace('meta.', '').replace('.log.', '.') for col in cols}
    df = df.rename(columns=renamed)
    df = df.dropna(subset=['rank_0.TestAcc'])
    return df