In [None]:
from __future__ import absolute_import, division, print_function, unicode_literals
from IPython.core.display import display, HTML
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import pandas as pd
import json, sys, re, os

# Notebook-wide matplotlib defaults: figure size and the threshold at which
# matplotlib warns about too many simultaneously open figures.
plt.rcParams.update({
    'figure.figsize': [10, 5],
    'figure.max_open_warning': 50,
})
# Inject CSS so the notebook container uses the full browser width.
display(HTML("<style>.container { width:100% !important; }</style>"))

# Put the directory two levels above the current working directory first on
# the import path, ahead of the `analysis` imports that follow.
sys.path.insert(0, os.getcwd() + "/../../")
from analysis.utils import PM_HOME, GPU_NAME
from analysis.trace_utils import *

In [None]:
# %%capture
# Run configuration; these values select which trace/log files are read below.
num_gpus = 1
batch_size = 2048
iters = 10

# One (model_name, per-stream flattened breakdown, runtime_no_pf) tuple per model.
flattens = []
for model_name in ['DLRM_default', 'DLRM_MLPerf', 'DLRM_DDP']:
    module_marker = "DLRM " if "DLRM" in model_name else "## Forward ##"
    trace_file = '{}/data/{}/e2e/{}/{}_{}.json'.format(PM_HOME, GPU_NAME, model_name, num_gpus, batch_size)

    # Runtime parsed from the run's log file. It stays at the -1 sentinel when
    # the log file or the "Overall per-batch" line is missing.
    runtime_no_pf = -1
    log_file = "{}/data/{}/e2e/{}/{}_{}.log".format(PM_HOME, GPU_NAME, model_name, num_gpus, batch_size)
    if os.path.exists(log_file):
        # Use a context manager so the log's file handle is always closed.
        with open(log_file, 'r') as log_fh:
            for line in log_fh:
                # No break: if several lines match, the last one wins.
                if re.search("Overall per-batch", line):
                    # Fifth space-separated token, scaled to `iters` iterations.
                    # NOTE(review): the factor 1000 presumably converts ms to us -- confirm against the log format.
                    runtime_no_pf = float(line.split(' ')[4]) * 1000 * iters # us

    # Cut the trace down to `iters` iterations and load the trimmed copy.
    trimmed_trace_file = trim_trace_by_num_iter(trace_file, iters=iters, trimmed_file='/tmp/{}_{}_{}'.format(model_name, num_gpus, batch_size))
    with open(trimmed_trace_file) as f:
        trace = json.load(f)

    roots, cc, corrected_start_time, corrected_end_time, sum_skipped_intervals = process_event_hierarchy(trace['traceEvents'], skip_module=False, module_marker=module_marker)
    # Host-side span of the trace with the skipped intervals removed.
    host_runtime = corrected_end_time - corrected_start_time - sum_skipped_intervals
    # The host runtime is also used as the device-runtime reference below.
    device_runtime = host_runtime
    ops = []
    get_operators(roots, ops)
    op_device_runtime = get_device_runtime(ops, cc) # dict: op ex_id -> all its device calls and stats
    dt_breakdown = device_runtime_breakdown(roots, op_device_runtime, depth=0)
    # Flatten every stream's breakdown into its own dict, keyed by stream.
    flatten = {}
    for stream, stream_breakdown in dt_breakdown.items():
        flatten[stream] = {}
        get_major_device_results(device_runtime, stream_breakdown, flatten[stream])
    flattens.append((model_name, flatten, runtime_no_pf))

In [None]:
def _collect_op_shares(stream_breakdown, truncate):
    """Aggregate one stream's flattened breakdown into per-operator runtime shares.

    Args:
        stream_breakdown: mapping from keys to dicts holding a 'runtime' value.
            The 'total' key and any key whose first element contains 'DLRM '
            or 'module' are skipped; for the rest, the first element of the
            key (up to an optional '#') is used as the operator name.
        truncate: maximum number of entries returned. When there are more
            operators, the smallest ones are folded into a final "others" entry.

    Returns:
        (names, shares): operator names sorted by descending runtime and the
        matching fractions of the summed runtime (all 0.0 if that sum is 0).
    """
    per_op = {}
    total = 0.0
    for key, stats in stream_breakdown.items():
        if key == 'total' or 'DLRM ' in key[0] or 'module' in key[0]:  # Skip all labels
            continue
        # Merge entries that differ only by their '#...' suffix.
        op_name = key[0].split('#')[0]
        per_op[op_name] = per_op.get(op_name, 0.0) + stats['runtime']
        total += stats['runtime']

    ranked = sorted(per_op.items(), key=lambda item: item[1], reverse=True)
    names = [item[0] for item in ranked]
    # Guard the normalisation so an all-zero breakdown does not divide by zero.
    shares = [item[1] / total if total else 0.0 for item in ranked]
    if len(shares) > truncate:
        shares[truncate - 1] = sum(shares[(truncate - 1):])
        shares = shares[:truncate]
        names[truncate - 1] = "others"
        names = names[:truncate]
    return names, shares


def plot_bar_chart(flattens, truncate=20, stream=7):
    """Plot one stacked horizontal bar per model showing where its runtime goes.

    Each bar is split into per-operator device time (scaled so that the
    operators together span the model's active fraction) plus an "Idle"
    segment. The bottom axis is labelled as a share of the whole runtime, the
    top axis as a share of the active time. The figure is saved as
    ``active_time_breakdown.pdf`` and ``.png`` under
    ``{PM_HOME}/data/{GPU_NAME}/e2e/``.

    Args:
        flattens: sequence of ``(model_name, flatten, runtime_no_pf)`` tuples.
            ``flatten[stream]`` must be a mapping with a ``'total'`` entry and
            per-operator entries, each holding a ``'runtime'`` value;
            ``runtime_no_pf`` is the reference runtime the active time is
            divided by and must be positive.
        truncate: maximum number of operator segments per bar; smaller
            operators are merged into "others".
        stream: key selecting which stream's breakdown is plotted. Defaults to
            7, the value that used to be hard-coded.
            NOTE(review): presumably the id of the main compute stream -- confirm.

    Raises:
        ValueError: if ``flattens`` is empty or any ``runtime_no_pf`` is not
            positive (e.g. a "not found" sentinel such as -1), since the
            active/idle ratios would be meaningless.
    """
    if not flattens:
        raise ValueError("plot_bar_chart needs at least one (model_name, flatten, runtime_no_pf) entry")
    for model_name, _, runtime_no_pf in flattens:
        if runtime_no_pf <= 0:
            raise ValueError(
                "runtime_no_pf for {} is {}; a positive reference runtime is required".format(model_name, runtime_no_pf)
            )

    # squeeze=False keeps `axes` an array even when a single model is plotted.
    fig, axes = plt.subplots(nrows=len(flattens), ncols=1, figsize=(12, 5), squeeze=False)
    axes = axes[:, 0]

    palette = plt.get_cmap('tab20b', truncate + 1)(np.linspace(0, 1, truncate + 1))
    gray = (0.5, 0.5, 0.5, 1.0)
    # Column name -> colour, shared by all subplots so the single legend
    # (taken from the first subplot) stays valid. "Idle" is always gray.
    legend_dict = {'Idle': gray}
    fractions = np.linspace(0.0, 1.0, 6)  # 0%, 20%, ..., 100%
    percent_labels = ['{:,.0%}'.format(x) for x in fractions]

    for idx, (model_name, flatten, runtime_no_pf) in enumerate(flattens):
        stream_breakdown = flatten[stream]
        names, shares = _collect_op_shares(stream_breakdown, truncate)

        active_time = stream_breakdown['total']['runtime'] / runtime_no_pf
        idle_time = 1 - active_time

        # Operator shares are fractions of active time; rescale them to
        # fractions of the whole runtime and append the idle remainder.
        df = pd.DataFrame([shares], columns=names) * active_time
        df['Idle'] = idle_time
        columns = df.columns.values.tolist()

        if idx == 0:
            # The first model fixes the colour of every operator it contains.
            for position, column in enumerate(columns[:-1]):
                legend_dict[column] = tuple(float(channel) for channel in palette[position])
        # Operators not present in the first model share one fallback colour per subplot.
        fallback = plt.get_cmap('tab20b')(truncate + idx)
        # Pass explicit per-column colours: a colormap would be resampled by
        # pandas and could assign colours that disagree with legend_dict.
        bar_colors = [legend_dict.get(column, fallback) for column in columns]

        ax = axes[idx]
        df.plot(stacked=True, legend=False, kind='barh', width=0.05, color=bar_colors, ax=ax)
        # Fix the limits and tick positions before labelling so every label
        # is attached to the tick it describes.
        ax.set_xlim((0.0, 1.0))
        ax.set_xticks(fractions)
        ax.set_xticklabels(percent_labels)
        ax.set_yticks([])
        ax.set_yticklabels([])
        ax.set_ylim((-0.03, 0.03))
        ax.set_title(model_name, fontsize=14)

        # Secondary top axis: the same bar, labelled as a share of active time.
        ax2 = ax.twiny()
        ax2.set_xticks([x * active_time for x in fractions])
        ax2.set_xbound(ax.get_xbound())
        ax2.set_xticklabels(percent_labels)

    handles, labels = axes[0].get_legend_handles_labels()
    fig.legend(handles, labels, loc="lower center", ncol=5, bbox_to_anchor=(0.5, -0.3), frameon=False, fontsize=10.5)
    plt.tight_layout()
    # Kept from the original: changes the default size of figures created afterwards.
    plt.rcParams['figure.figsize'] = [12, 5]
    for extension in ('pdf', 'png'):
        fig.savefig('{}/data/{}/e2e/active_time_breakdown.{}'.format(PM_HOME, GPU_NAME, extension), bbox_inches='tight')

# Draw (and save) the active-time breakdown for every model collected in `flattens`.
plot_bar_chart(flattens)