In [None]:
from __future__ import absolute_import, division, print_function, unicode_literals
from IPython.display import display, HTML
from matplotlib.colors import ListedColormap
from matplotlib.patches import Rectangle
from itertools import chain
import pandas as pd
import json, sys, re, os, glob
import numpy as np
import matplotlib.pyplot as plt

# Notebook-wide plotting defaults, plus a full-width cell layout for the wide figures below.
plt.rcParams.update({'figure.figsize': [10, 5], 'figure.max_open_warning': 50})
display(HTML("<style>.container { width:100% !important; }</style>"))

# Put the directory two levels up on the import path so the `analysis` package resolves.
sys.path.insert(0, os.getcwd() + "/../../")
from analysis.utils import PM_HOME, GPU_NAME, op_name_contains
from analysis.trace_utils import *

### Multi-GPU

In [None]:
# %%capture
num_gpus = 4
batch_size = 4096
iters = 30
model_names = ["DLRM_Heavy_EL", "DLRM_Light_EL", "GPT2"]

# Per-rank trace files end in "_distributed_<rank>.json" with ranks 0..num_gpus-1, hence the
# character class [0-(num_gpus-1)]. A glob character class only covers single-digit ranks,
# so this pattern assumes num_gpus <= 10.
workloads = {
    "DLRM_Heavy_EL": '{}/data/{}/e2e/DLRM_open_source/2022/44-247-742-494-148-1-595-558-16-109-550-511-704-583-206-564-316-71-572-125-106-592-584-142-546-582-591-586-574-736-716-253-267-41-587-102/f/size_lookup_greedy/barrier_bucketed_allreduce/25/{}_{}_distributed_[0-{}].json'.format(PM_HOME, GPU_NAME, num_gpus, batch_size, num_gpus - 1),
    "DLRM_Light_EL": '{}/data/{}/e2e/DLRM_open_source/2021/417-307-201-412-722-106-536-234-20-523-373-231-791-140-719-309-107-216-209-336-638-69-711-206-740-655-58-337-475-136-550-56-53-425-530-808-280-595/f/size_lookup_greedy/barrier_bucketed_allreduce/25/{}_{}_distributed_[0-{}].json'.format(PM_HOME, GPU_NAME, num_gpus, batch_size, num_gpus - 1),
    "GPT2": '{}/data/{}/e2e/gpt2/barrier_bucketed_allreduce/25/4_64_distributed_[0-{}].json'.format(PM_HOME, GPU_NAME, num_gpus - 1),
}

# One entry per rank, in rank order: traces[m][i] and actual_time[m][i] describe rank i.
traces = {m: [] for m in model_names}
actual_time = {m: [] for m in model_names}

for model_name in workloads.keys():
    # glob returns files in arbitrary order; sort so that list position == rank
    # (valid for the single-digit rank suffixes matched above).
    trace_files = sorted(glob.glob(workloads[model_name]))

    for tf in trace_files:
        # Dropping the "_<rank>" suffix means every rank of a run maps to the same log file.
        log_file = os.path.splitext(tf)[0][:-2] + '.log'
        runtime_no_pf = -1
        if os.path.exists(log_file):
            with open(log_file, 'r') as f:
                for line in f:
                    if re.search("Overall per-batch", line):
                        runtime_no_pf = float(line.split(' ')[4]) * 1000 * iters # us
                        break
        if runtime_no_pf < 0:
            print("WARNING: no 'Overall per-batch' runtime found for {}; using -1".format(tf), file=sys.stderr)
        # Append unconditionally (-1 when unknown, as the single-GPU cell does) so that
        # actual_time[model_name][i] always belongs to traces[model_name][i].
        actual_time[model_name].append(runtime_no_pf)

        trimmed_trace_file = os.path.splitext(tf)[0] + '_trimmed_{}.json'.format(iters)
        with open(trimmed_trace_file) as f:
            trace = json.load(f)
            traces[model_name].append(trace)

In [None]:
# Heuristic stream classifier: assumes every profiled model runs at least one aten::linear
# on its computation stream(s).
def is_comp_stream(tmp):
    """Return True if any entry of a per-stream result dict is an ``aten::linear`` op.

    Args:
        tmp: per-stream results filled by ``get_major_device_results``. Op entries are keyed
            by tuples whose first element is the op name; a string key such as ``"total"``
            is harmless because ``key[0]`` is then just its first character.

    Returns:
        bool: True on the first matching key (short-circuits), False otherwise or if empty.
    """
    return any("aten::linear" in key[0] for key in tmp)

# For every model and rank: rebuild the op hierarchy from the trace, attribute device time
# to ops on each stream, and label each stream computation_N / communication_N.
ranks = {}
for model_name in model_names:
    ranks[model_name] = []
    # The iteration marker depends only on the model, so compute it once per model.
    module_marker = "DLRM " if "DLRM" in model_name else "## Forward ##"
    for idx, trace in enumerate(traces[model_name]):
        roots, cc, streams, corrected_start_time, corrected_end_time, sum_skipped_intervals = process_event_hierarchy(trace['traceEvents'], skip_module=False, module_marker=module_marker)
        host_runtime = corrected_end_time - corrected_start_time - sum_skipped_intervals
        # The host-side runtime (minus skipped intervals) is what get_major_device_results
        # receives as the device runtime.
        device_runtime = host_runtime
        ops = []
        get_operators(roots, ops)
        op_device_runtime = get_device_runtime(ops, cc) # dict: op ex_id -> all its device calls and stats
        dt_breakdown = device_runtime_breakdown(roots, op_device_runtime, depth=0)
        flatten = {}
        comp_count = 0
        comm_count = 0
        for stream, v in dt_breakdown.items():
            tmp = {}
            get_major_device_results(device_runtime, v, tmp)
            # Labels are numbered in dt_breakdown's stream order, separately per kind.
            if is_comp_stream(tmp):
                s = "computation_{}".format(comp_count)
                comp_count += 1
            else:
                s = "communication_{}".format(comm_count)
                comm_count += 1
            flatten[s] = tmp
        # Relies on actual_time[model_name] being aligned with traces[model_name] (one per rank).
        ranks[model_name].append((model_name, flatten, actual_time[model_name][idx]))

In [None]:
# Category labels for the breakdown charts; a category's position here also selects its
# colour in the multi-GPU chart.
TYPES = ["Embedding Lookup", "GEMM", "Memory", "Communication", "Others", "Idle"]


def f1(x):
    """Embedding-lookup ops: FBGEMM kernels and the split-lookup SGD autograd node.

    Matching is delegated to ``op_name_contains`` (analysis.utils).
    """
    return op_name_contains(x, ["fbgemm", "torch::autograd::CppNode<SplitLookupFunction_sgd_Op>"])


def f2(x):
    """Matrix-multiply (GEMM) ops, forward and backward."""
    return op_name_contains(x, [
        "aten::linear", "aten::addmm", "AddmmBackward",
        "aten::bmm", "BmmBackward",
        "aten::matmul", "MmBackward",
    ])


def f3(x):
    """Data-movement / layout ops: dtype/device copies, concatenation, transposes."""
    return op_name_contains(x, ["aten::to", "aten::cat", "Transpose", "transpose"])


def get_type(x):
    """Classify an op name into a TYPES label.

    Checks embedding lookup, then GEMM, then memory, and returns the first category that
    matches, falling back to "Others". "Communication" and "Idle" are never returned here;
    the plotting code assigns them itself.
    """
    for matches, label in ((f1, "Embedding Lookup"), (f2, "GEMM"), (f3, "Memory")):
        if matches(x):
            return label
    return "Others"

def legend_helper(tmp, *args):
    """Gather legend handles and labels from several Axes, e.g. for ``fig.legend(**...)``.

    Args:
        tmp: a Figure (all of its axes are used), a list/tuple/ndarray of Axes (an ndarray
            such as the one returned by ``plt.subplots`` is flattened), or a single Axes.
        *args: additional Axes; only consulted when ``tmp`` is a single Axes.

    Returns:
        dict with ``'handles'`` and ``'labels'`` lists concatenated in axes order. Both
        lists are empty when there are no axes or no labelled artists.
    """
    if isinstance(tmp, plt.Figure):
        axes = tmp.axes
    elif isinstance(tmp, np.ndarray):
        axes = np.ravel(tmp)
    elif isinstance(tmp, (list, tuple)):
        axes = tmp
    else:
        axes = chain([tmp], args)

    handles, labels = [], []
    for ax in axes:
        ax_handles, ax_labels = ax.get_legend_handles_labels()
        handles.extend(ax_handles)
        labels.extend(ax_labels)

    return {
        'handles': handles,
        'labels': labels,
    }

In [None]:

def plot_bar_chart(ranks):
    ncols = len(model_names)
    nrows = num_gpus * 2
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(12 * len(model_names) - 1, 9)) # 2 Streams per model
    cmap = plt.get_cmap('tab20b', 6) # 6 colors: 5 types + idle
    colors = cmap(np.linspace(0, 1, 6))
    colors[-1, :] = np.array([0.5, 0.5, 0.5, 1]) # Grey
    handles, labels = None, None
    for idz, (model_name, tts) in enumerate(ranks.items()):
        for idy, t in enumerate(tts):
            _, d, runtime_no_pf = t
            for idx, (stream, flatten) in enumerate(d.items()):
                ax = axes[idy * 2 + idx][idz]
                is_communication = "communication" in stream
                per_op = {}
                total = 0.0
                # ranks["GPT2"][0][1]["computation_0"]["total"]["subs"]
                for k, v in flatten["total"]["subs"].items():
                    k0 = k[0]
                    if k0 not in per_op.keys():
                        per_op[k0] = 0.0
                    per_op[k0] += v
                    total += v

                tmp = sorted(per_op.items(), key=lambda x: x[1], reverse=True)
                op = [x[0] for x in tmp]
                p = [x[1] / total for x in tmp]
                df0 = pd.DataFrame({
                    'Active time': [flatten['total']['runtime'] / runtime_no_pf],
                    'Idle time': [1 - flatten['total']['runtime'] / runtime_no_pf],
                })
                tmp_df = pd.DataFrame({
                    "type": ["Communication"] * len(op) if is_communication else [get_type(x) for x in op],
                    "perc": p,
                }).groupby(['type'], as_index=False).sum()
                df = pd.DataFrame([tmp_df["perc"].tolist()], columns=tmp_df["type"].tolist())
                active_time = df0['Active time'].item()
                idle_time = df0['Idle time'].item()
                df = df * active_time
                df['Idle'] = idle_time

                new_colors = [colors[TYPES.index(x)] for x in df.columns.values.tolist()]
                new_cmp = ListedColormap(new_colors)
                df.plot(stacked=True, legend=False, kind='barh', width=0.05, cmap=new_cmp, ax=ax)
                ax.set_xlim((0.0, 1.0))
                ax.set_yticks([])
                ax.set_yticklabels([])
                ax.set_ylim((-0.03, 0.03))
                ax.grid(axis='x')
                ax.set_title("Communication" if is_communication else "Computation", fontsize=16)
                if idy * 2 + idx == nrows - 1:
                    ax.set_xticks(ax.get_xticks())
                    ax.set_xticklabels(['{:,.0%}'.format(x) for x in ax.get_xticks()], fontsize=16)
                else:
                    ax.tick_params(labelbottom = False, bottom = False)

                if idz == 0 and idy == 0:
                    h, l = ax.get_legend_handles_labels()
                    if handles is None:
                        handles, labels = h[:-1], l[:-1]
                    else:
                        handles += h
                        labels += l
            ax.text(-0.02, 0.04, "Rank {}".format(idy), horizontalalignment='center', verticalalignment='center', size=16, rotation=90)
            rec = Rectangle((-0.04, -0.035), 1.05, 0.19, fill=False, lw=1.5, linestyle="dotted")
            rec.set_clip_on(False)
            ax.add_patch(rec)

        ax.text(0.5, 0.75, "{} ({:.2f} ms/iter)".format(model_name, max(actual_time[model_name]) / 1000.0 / iters), horizontalalignment='center', verticalalignment='center', size=18)
    fig.legend(handles, labels, loc="lower center", ncol=6, bbox_to_anchor=(0.5, 0.01), frameon=False, fontsize=18)
    plt.subplots_adjust(hspace=0.6, wspace=0.2)
    plt.savefig('{}/data/{}/e2e/multi_gpu_time_breakdown.pdf'.format(PM_HOME, GPU_NAME), bbox_inches='tight')
    plt.savefig('{}/data/{}/e2e/multi_gpu_time_breakdown.png'.format(PM_HOME, GPU_NAME), bbox_inches='tight')

# Render and save the multi-GPU breakdown. Passing `ranks` by keyword makes an accidental
# call to the later single-GPU plot_bar_chart (which takes `flattens`) fail immediately.
plot_bar_chart(ranks=ranks)

### Single-GPU

In [None]:
# %%capture
num_gpus = 1
batch_size = 2048
iters = 10

flattens = []
for model_name in ['DLRM_default', 'DLRM_MLPerf', 'DLRM_DDP']:
    module_marker = "DLRM " if "DLRM" in model_name else "## Forward ##"
    trace_file = '{}/data/{}/e2e/{}/{}_{}.json'.format(PM_HOME, GPU_NAME, model_name, num_gpus, batch_size)

    # -1 marks "runtime unknown" when the log is missing or has no matching line.
    runtime_no_pf = -1
    log_file = "{}/data/{}/e2e/{}/{}_{}.log".format(PM_HOME, GPU_NAME, model_name, num_gpus, batch_size)
    if os.path.exists(log_file):
        with open(log_file, 'r') as f:
            for line in f:
                # No break: the last matching line wins.
                if re.search("Overall per-batch", line):
                    runtime_no_pf = float(line.split(' ')[4]) * 1000 * iters # us

    trimmed_trace_file = trim_trace_by_num_iter(trace_file, iters=iters, trimmed_file='/tmp/{}_{}_{}'.format(model_name, num_gpus, batch_size))
    with open(trimmed_trace_file) as f:
        trace = json.load(f)

    hierarchy = process_event_hierarchy(trace['traceEvents'], skip_module=False, module_marker=module_marker)
    # The multi-GPU cell unpacks six values (with `streams` in third position) where this
    # cell used to unpack five. Slicing from both ends works with either return signature.
    roots, cc = hierarchy[:2]
    corrected_start_time, corrected_end_time, sum_skipped_intervals = hierarchy[-3:]
    host_runtime = corrected_end_time - corrected_start_time - sum_skipped_intervals
    device_runtime = host_runtime
    ops = []
    get_operators(roots, ops)
    op_device_runtime = get_device_runtime(ops, cc) # dict: op ex_id -> all its device calls and stats
    dt_breakdown = device_runtime_breakdown(roots, op_device_runtime, depth=0)
    # Keyed by the raw stream id from dt_breakdown (the plot below selects one of them).
    flatten = {}
    for stream, v in dt_breakdown.items():
        flatten[stream] = {}
        get_major_device_results(device_runtime, v, flatten[stream])
    flattens.append((model_name, flatten, runtime_no_pf))

In [None]:
def plot_bar_chart(flattens, truncate=20):
    fig, axes = plt.subplots(nrows=len(flattens), ncols=1, figsize=(12, 5))

    cmap = plt.get_cmap('tab20b', truncate+1)
    gray = np.array([0.5, 0.5, 0.5, 1])
    colors = cmap(np.linspace(0, 1, truncate+1))
    colors[-1, :] = gray
    legend_dict = None

    for idx, tuple in enumerate(flattens):
        model_name, flatten, runtime_no_pf = tuple
        flatten = flatten[7]
        per_op = {}
        total = 0.0
        for k, v in flatten.items():
            if k == 'total' or 'DLRM ' in k[0] or 'module' in k[0]: # Skip all labels
                continue
            k0 = k[0] if '#' not in k[0] else k[0].split('#')[0]
            if k0 not in per_op.keys():
                per_op[k0] = 0.0
            per_op[k0] += v['runtime']
            total += v['runtime']
            
        tmp = sorted(per_op.items(), key=lambda x: x[1], reverse=True)
        op = [x[0] for x in tmp]
        p = [x[1] / total for x in tmp]
        df0 = pd.DataFrame({
            'Active time': [flatten['total']['runtime'] / runtime_no_pf],
            'Idle time': [1 - flatten['total']['runtime'] / runtime_no_pf]
        })
        if len(p) > truncate:
            p[truncate-1] = sum(p[(truncate-1):])
            p = p[:truncate]
            op[truncate-1] = "others"
            op = op[:truncate]
        df = pd.DataFrame([p], columns=op)
        active_time = df0['Active time'].item()
        idle_time = df0['Idle time'].item()
        df = df * active_time
        df['Idle'] = idle_time
        
        if idx == 0:
            new_colors = colors
            legend_dict = {k: v for k, v in zip(df.columns.values.tolist(), colors)}
        else:
            assert legend_dict is not None
            columns = df.columns.values.tolist()
            new_colors = [legend_dict[c] if c in legend_dict.keys() else plt.get_cmap('tab20b')(truncate + idx) for c in columns]

        new_cmp = ListedColormap(new_colors)
        df.plot(stacked=True, legend=False, kind='barh', width=0.05, cmap=new_cmp, ax=axes[idx])
        vals = axes[idx].get_xticks()
        axes[idx].set_xticklabels(['{:,.0%}'.format(x) for x in vals])
        axes[idx].set_xlim((0.0, 1.0))
        axes[idx].set_yticks([])
        axes[idx].set_yticklabels([])
        axes[idx].set_ylim((-0.03, 0.03))
        axes[idx].set_title(model_name, fontsize=14)

        ax2 = axes[idx].twiny()
        ax2.set_xticks([x * active_time for x in np.arange(0, 1.2, 0.2)])
        ax2.set_xbound(axes[idx].get_xbound())
        ax2.set_xticklabels(['{:,.0%}'.format(x) for x in np.arange(0, 1.2, 0.2)])

    handels, labels = axes[0].get_legend_handles_labels()
    fig.legend(handels, labels, loc="lower center", ncol=5, bbox_to_anchor=(0.5, -0.3), frameon=False, fontsize=10.5)
    plt.tight_layout()
    plt.rcParams['figure.figsize'] = [12, 5]
    plt.savefig('{}/data/{}/e2e/active_time_breakdown.pdf'.format(PM_HOME, GPU_NAME, model_name), bbox_inches='tight')
    plt.savefig('{}/data/{}/e2e/active_time_breakdown.png'.format(PM_HOME, GPU_NAME, model_name), bbox_inches='tight')

# Render and save the single-GPU breakdown. Passing `flattens` by keyword makes an
# accidental call to the multi-GPU plot_bar_chart (which takes `ranks`) fail immediately.
plot_bar_chart(flattens=flattens)