In [None]:
from __future__ import absolute_import, division, print_function, unicode_literals
from IPython.core.display import display, HTML
from pprint import pprint
import numpy as np
import matplotlib.pyplot as plt
import json, sys, math, glob, os
from scipy.stats import t as ttest

# Notebook-wide display defaults: default figure size, the number of open
# figures tolerated before matplotlib warns, and a full-width page container.
plt.rcParams.update({
    'figure.figsize': [10, 5],
    'figure.max_open_warning': 50,
})
display(HTML("<style>.container { width:100% !important; }</style>"))

# Put the directory two levels above the current working directory at the front
# of the import path before importing from the `analysis` package.
# NOTE(review): this depends on the kernel's working directory — presumably the
# notebook's own folder; confirm before running from elsewhere.
sys.path.insert(0, os.getcwd() + "/../../")
from analysis.utils import PM_HOME, GPU_NAME
# NOTE(review): wildcard import — gather_overhead_raw, used further down, is
# presumably provided by this module; confirm.
from analysis.trace_utils import *

In [None]:
# --- Statistical test settings ---
alpha = 0.05  # p-values below this reject the null hypothesis in the heat maps
eps = 1e-7    # small constant added under the square root in get_p_value

# --- Plot layout / workload filtering ---
ncols = 10                    # subplot columns in the per-op heat-map grids
workload_appearance_min = 5   # ops seen in this many workloads or fewer are dropped
workload_appearance_max = 10  # at most this many workloads are compared per op

# --- Input discovery ---
iters = 100                   # number embedded in the overhead file names
root_dir = '{}/data/{}/e2e'.format(PM_HOME, GPU_NAME)
# One glob per file family, searched one directory level below root_dir.
overhead_stats_files, overhead_raw_files = (
    glob.glob(pattern.format(root_dir, iters))
    for pattern in (
        "{}/*/*overhead_stats_{}.json",
        "{}/*/*overhead_raw_{}.csv",
    )
)


def get_p_value(a, b):
    """
        Two-sided p-value for the difference between two sample means,
        computed from summary statistics only.

        a: tuple of 3 (mean, std, count)
        b: tuple of 3 (mean, std, count)

        Returns the p-value of the t statistic under a Student t distribution
        with count_a + count_b - 2 degrees of freedom.

        NOTE(review): the standard error is the unpooled form
        sqrt(sa^2/na + sb^2/nb) while the degrees of freedom are the pooled
        na + nb - 2. The two agree with the equal-variance test when
        na == nb; for unequal counts confirm this pairing is intended.
    """
    mean_a, std_a, n_a = a
    mean_b, std_b, n_b = b
    # eps keeps the denominator non-zero when both standard deviations are 0.
    std_err = math.sqrt(std_a * std_a / n_a + std_b * std_b / n_b + eps)
    t_stat = (mean_a - mean_b) / std_err
    # degrees of freedom
    df = n_a + n_b - 2
    # Use the survival function instead of 1 - cdf: for large |t| the cdf
    # rounds to 1.0 and the subtraction would collapse the p-value to 0.
    return 2.0 * ttest.sf(abs(t_stat), df)
# print(get_p_value((1.3, 0.5, 22), (1.6, 0.3, 24))) # 0.0188

def get_two_sample_ttest_hotmap(data):
    """
        Build a square 0/1 matrix of pairwise t-test outcomes.

        data: sequence of items whose element [1] is a (mean, std, count)
              tuple accepted by get_p_value; element [0] is not read here.

        Entry [i, j] (and its mirror [j, i]) is 0 when get_p_value reports
        p < alpha for items i and j, i.e. the null hypothesis of equal means
        is rejected; every other entry stays 1.
    """
    size = len(data)
    stats = [entry[1] for entry in data]
    hotmap = np.ones((size, size))
    for row in range(size):
        for col in range(size):
            p = get_p_value(stats[row], stats[col])
            if p < alpha:
                hotmap[row, col] = 0
                hotmap[col, row] = 0
    return hotmap

# def kl_divergence():
    


In [None]:
# Aggregated overhead statistics, one container per overhead type:
#   t1           -> {(model_name, batch_size): stats}
#   t4           -> {runtime_function: {(model_name, batch_size): stats}}
#   t2 / t3 / t5 -> {op_name: {shapes: {(model_name, batch_size): stats}}}
o_stats = {overhead_type: {} for overhead_type in ('t1', 't2', 't3', 't4', 't5')}

# Only the per-op statistics stored under this shape key are collected.
shapes = "(((-1,),), ((-1,),))"

for file in overhead_stats_files:
    # The parent directory name is the model; the second '_'-separated token
    # of the file name is the batch size (kept as a string).
    path_parts = file.split('/')
    model_name = path_parts[-2]
    batch_size = path_parts[-1].split('_')[1]
    workload = (model_name, batch_size)

    with open(file) as f:
        overhead = json.load(f)

    for t, per_type in o_stats.items(): # Overhead types
        if t == 't1':
            # A single stats entry per workload.
            per_type[workload] = overhead[t]
            continue
        if t == 't4':
            # Keyed by runtime function, then by workload.
            for runtime_f, s in overhead[t].items():
                per_type.setdefault(runtime_f, {})[workload] = s
            continue
        for op_name, s in overhead[t].items():
            # Strip the autograd wrapper prefix so backward ops are keyed by
            # the text that follows it.
            bw_truncated_name = op_name.split("autograd::engine::evaluate_function: ")[-1]
            per_shape = per_type.setdefault(bw_truncated_name, {shapes: {}})[shapes]
            # Entries whose first stat (the mean, per the layout documented
            # further down) is 0 are skipped.
            if s[shapes][0] != 0:
                per_shape[workload] = s[shapes]

for t in ['t2', 't3', 't5']:
    # Remove trivial stats: ops that appear in workload_appearance_min or
    # fewer model-batch workloads. Names are collected first because the dict
    # cannot be resized while it is being iterated.
    del_names = [
        k for k, v in o_stats[t].items()
        if len(v[shapes]) <= workload_appearance_min
    ]
    for k in del_names:
        del o_stats[t][k]

# Raw overhead samples built from the CSV files; later cells read its 'type',
# 'model_name', 'batch_size' and 'time' columns.
# NOTE(review): gather_overhead_raw is not defined in this notebook —
# presumably it comes from the analysis.trace_utils wildcard import; confirm.
df = gather_overhead_raw(overhead_raw_files)

In [None]:
# Rows of the raw frame whose overhead type is 't1', split into one frame per
# (model_name, batch_size) workload.
is_t1 = df['type'] == 't1'
t1 = df[is_t1]
gb = t1.groupby(['model_name', 'batch_size'])
t1s = [gb.get_group(workload_key) for workload_key in gb.groups]
# t1_hist = t1.hist(bins=15)
# t1s_hist = [x.hist(bins=15) for x in t1s]

In [None]:
from scipy.special import rel_entr

# Histograms: occurrence count of each distinct 'time' value, ordered by value,
# for all T1 samples and for each workload separately.
t1_hist = t1['time'].value_counts().sort_index()
t1s_hist = [group['time'].value_counts().sort_index() for group in t1s]

# Values present in the overall histogram AND in every per-workload histogram.
value_set = set(t1_hist.index.values.flatten())
for hist in t1s_hist:
    value_set &= set(hist.index.values.flatten())


def clean(x, value_set):
    """Keep only the entries of histogram x whose value is in value_set and
    rescale the remaining counts so they sum to 1."""
    extra = [v for v in x.index.values.flatten() if v not in value_set]
    tmp = x.drop(extra)
    return tmp / sum(tmp.values)


t1_pdf = clean(t1_hist, value_set)
t1s_pdfs = [clean(hist, value_set) for hist in t1s_hist]

# KL divergence of the overall T1 distribution from each workload's
# distribution, evaluated on the shared support only.
KLs = [sum(rel_entr(t1_pdf.values, pdf.values)) for pdf in t1s_pdfs]

plt.figure()
ax = plt.gca()
positions = np.arange(len(KLs))
ax.bar(positions, KLs)
ax.set_xticks(positions)
labels = [
    "-".join([group.iloc[0]['model_name'], group.iloc[0]['batch_size']])
    for group in t1s
]
ax.set_xticklabels(labels, rotation=75)

In [None]:
# One summary row per workload: (label, mean time, std of time, sample count).
s = [
    (
        "-".join([x.iloc[0]['model_name'], x.iloc[0]['batch_size']]),
        x['time'].mean(),
        x['time'].std(),
        len(x),
    )
    for x in t1s
]
labels = [row[0] for row in s]
means = [row[1] for row in s]

# Bar chart of the mean T1 time per workload.
plt.figure()
ax = plt.gca()
ax.bar(labels, means)
ax.set_xticks(np.arange(len(s)))
ax.set_xticklabels(labels, rotation=75)

### T1

In [None]:
# Pairwise t-test heat map of the T1 overhead across workloads whose model
# name contains 'DLRM'. A cell is 0 where the two workloads differ
# significantly (p < alpha) and 1 otherwise.
plt.figure()
ax = plt.gca()
# Sort by the (model_name, batch_size) key; an alternative ordering would be
# by x[1][0] * x[1][2].
tmp = sorted(o_stats['t1'].items(), key=lambda x: x[0])
tmp = [t for t in tmp if 'DLRM' in t[0][0]]
# A stray argument-less `tmp.append()` used to sit here; list.append requires
# exactly one argument, so it raised TypeError and the cell could not run.
hotmap = get_two_sample_ttest_hotmap(tmp)
ax.imshow(hotmap, cmap='hot', interpolation='nearest', vmin=0, vmax=1)
ax.title.set_text("T1")
ax.set_xticks(range(0, len(tmp)))
ax.set_yticks(range(0, len(tmp)))
# Only the y axis carries the (model_name, batch_size) labels.
ax.set_yticklabels([t[0] for t in tmp])
plt.tight_layout()

### T2, T3, T5

In [None]:
"""
o_stats = {
    op1: {
        shapes: {
            (model_name, batch_size): [mean, std, count]
        }
    },
    op2: {
        shapes: {
            (model_name, batch_size): [mean, std, count]
        }
    },
    ...
}
"""
# One figure per overhead type, holding a grid (ncols wide) of pairwise t-test
# heat maps — one small heat map per op.
for t in ['t2', 't3', 't5']:
    nrows = math.ceil(len(o_stats[t]) / ncols)
    if nrows == 0:
        # No op of this type survived the filtering; plt.subplots rejects a
        # grid with zero rows, so there is nothing to draw.
        continue
    # figsize is passed here because rcParams['figure.figsize'] only affects
    # figures created after it is set; setting it after the loop never sized
    # these grids and leaked into later cells.
    # squeeze=False keeps axs two-dimensional even when nrows == 1.
    fig, axs = plt.subplots(
        nrows=nrows, ncols=ncols, figsize=(ncols * 2, nrows), squeeze=False
    )

    # axs.flat walks the grid row by row; it has at least as many axes as
    # there are ops, so zip stops once the ops run out.
    for ax, (op_name, s) in zip(axs.flat, o_stats[t].items()):
        # Rank workloads by mean * count, largest first, and keep the top
        # workload_appearance_max of them.
        ranked = sorted(s[shapes].items(), key=lambda vv: -vv[1][0] * vv[1][2])
        tmp = ranked[:workload_appearance_max]
        hotmap = get_two_sample_ttest_hotmap(tmp)
        ax.imshow(hotmap, cmap='hot', interpolation='nearest', vmin=0, vmax=1)
        ax.set_title(op_name[:15], fontsize=8)
        ax.set_xticks(range(0, len(tmp)))
        ax.set_yticks(range(0, len(tmp)))
        ax.set_xticklabels([""] * len(tmp))
        ax.set_yticklabels([""] * len(tmp))
    plt.tight_layout()