In [None]:
from __future__ import absolute_import, division, print_function, unicode_literals
from IPython.core.display import display, HTML
from pprint import pprint
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import ConnectionPatch
import pandas as pd
import json, sys, re, os

# plt.rcParams['figure.figsize'] = [10, 5]
# plt.rcParams['figure.max_open_warning'] = 50
# display(HTML("<style>.container { width:100% !important; }</style>"))

sys.path.insert(0, os.getcwd() + "/../../")
from analysis.utils import PM_HOME, GPU_NAME, CPU_EVENT_OVERHEAD, GPU_EVENT_OVERHEAD, KERNEL_LAUNCH_LENGTH
from analysis.trace_utils import *

In [None]:
# %%capture
# Experiment configuration: which trace (model / GPU count / batch size) to analyse
# and how many iterations of it to keep.
model_name = "DLRM_default" # "DLRM_MLPerf"
num_gpus = 1
batch_size = 2048
iters = 10

trace_file = f'{PM_HOME}/data/{GPU_NAME}/e2e/{model_name}/{num_gpus}_{batch_size}.json'
print(trace_file)

# Trim the trace (presumably to the first `iters` iterations — see
# trim_trace_by_num_iter) into a scratch file, then load the trimmed copy.
trimmed_trace_file = trim_trace_by_num_iter(
    trace_file,
    iters=iters,
    trimmed_file=f'/tmp/{model_name}_{num_gpus}_{batch_size}.json',
)
with open(trimmed_trace_file) as trace_fh:
    trace = json.load(trace_fh)

print(trimmed_trace_file)

## DLRM with data loading

In [None]:
roots, cc, _, corrected_start_time, corrected_end_time, sum_skipped_intervals = process_event_hierarchy(trace['traceEvents'], skip_module=False)
print('Num of events: {}, num of root events: {}, num of caller/callee pairs: {}'.format(len(trace['traceEvents']), len(roots), len(cc)))
print('Sum of dataloading time: {}'.format(sum_skipped_intervals))
print("Corrected start time: ", corrected_start_time, ", corrected end time: ", corrected_end_time)
# Host time of the trimmed window with the skipped (data-loading) intervals removed.
host_runtime = corrected_end_time - corrected_start_time - sum_skipped_intervals
# The per-device measurement below is disabled; device runtime is approximated by host runtime.
# device_runtime, device_start_delay = get_device_runtime_and_start_delay(cc, corrected_start_time)
# print("Device start delay: ", device_start_delay)
device_runtime = host_runtime
print("Host runtime: ", host_runtime)
print("Device runtime: ", device_runtime)
# get_operators fills `ops` in place with the operator events found under `roots`.
ops = []
get_operators(roots, ops)
# Samples per second: `iters` iterations of `batch_size` samples in `host_runtime`.
# The 1e6 factor assumes trace times are in microseconds — TODO confirm.
# Use the configured batch_size so QPS stays correct for other batch sizes.
QPS = 1000000 / host_runtime * iters * batch_size
print(f"QPS: {QPS:.2f}")

### Device runtime breakdown

In [None]:
# Device calls and stats for each op, keyed by the op's ex_id.
op_device_runtime = get_device_runtime(ops, cc)
dt_breakdown = device_runtime_breakdown(roots, op_device_runtime, depth=0)
truncate_count = 10

# Per stream, collect the major device-time contributors into flatten[stream].
flatten = {}
for stream, stream_breakdown in dt_breakdown.items():
    stream_flat = flatten[stream] = {}
    get_major_device_results(device_runtime, stream_breakdown, stream_flat)
print(op_device_runtime)

In [None]:
# Report the major device-time contributors, one stream at a time.
for stream, stream_breakdown in dt_breakdown.items():
    print(f"Stream: {stream}")
    print_major_device_results(device_runtime, stream_breakdown, flatten[stream], truncate_count=truncate_count)

In [None]:
# from matplotlib import cm
# cs=cm.Set1([1, 3, 2, 25, 4, 5, 6, 7, 8, 11])

# def plot_pie_chart(flatten, key="total", truncate_count=100, depth=0):
#     d = flatten[key]
    
#     # Pie chart, where the slices will be ordered and plotted counter-clockwise:
#     stats = sorted(d["subs"].items(), key=lambda x: x[1], reverse=True)
#     labels = [x[0] for x in stats]
#     runtime = [x[1] for x in stats]
#     explode = np.zeros(len(runtime))
#     if len(explode) > 2:
#         explode[1] = 0.1

#     fig1, ax1 = plt.subplots(figsize=(12, 6))
#     wedges, texts = ax1.pie(runtime, explode=explode, shadow=True, startangle=90, colors=cs)
#     ax1.axis('equal')  # Equal aspect ratio ensures that pie is drawn as a circle.
#     ax1.set_title(key)
    
#     ax1.legend(wedges, zip(labels, ["{:.2f}%".format(r / d["runtime"] * 100) for r in runtime]),
#           title="Breakdown",
#           loc="center left",
#           bbox_to_anchor=(1, 0, 0.5, 1),
#           fontsize=14)
    
#     for label in labels:
#         if label in flatten:
#             plot_pie_chart(flatten, key=label, truncate_count=truncate_count, depth=depth+1)

#     if depth == 0:
#         plt.show()

# for stream, v in flatten.items():
#     print("########################")
#     print("STREAM: {}".format(stream))
#     print("########################")
#     plot_pie_chart(v, truncate_count=truncate_count)

In [None]:
# Ordered (substring, legend label) rules. The first substring contained in the
# kernel name decides the label, so the order of this table matters.
_KERNEL_LEGEND_RULES = (
    ('gemm', 'gemm'),
    ('gemv', 'gemv'),
    ('Memset', 'Memset'),
    ('Memcpy', 'Memcpy'),
    ('vectorized_elementwise', 'elwt'),
    ('unrolled_elementwise', 'permute'),
    ('embedding_forward', 'ELF'),
    ('embedding_backward', 'ELB'),
    ('splitK', 'splitK'),
    ('reduce_kernel', 'reduce'),
    ('CatArray', 'CatCopy'),
    ('indexing', 'IdxBwd'),
)


def kernel_name_to_legend(name):
    """Map a raw kernel name to a short legend label.

    Returns the label of the first rule in _KERNEL_LEGEND_RULES whose substring
    occurs in `name`; if none matches, returns the first 8 characters of `name`.
    """
    for pattern, label in _KERNEL_LEGEND_RULES:
        if pattern in name:
            return label
    return name[:8]

# Per-batch runtime reported in the training log (the name suggests it was measured
# without the profiler — confirm). Stays -1 when the log or the line is missing.
runtime_no_pf = -1
log_file = "{}/data/{}/e2e/{}/{}_{}.log".format(PM_HOME, GPU_NAME, model_name, num_gpus, batch_size)
if os.path.exists(log_file):
    # `with` guarantees the log file handle is closed once it has been scanned.
    with open(log_file, 'r') as log_fh:
        for line in log_fh:
            if re.search("Overall per-batch", line):
                # The 5th space-separated field is the per-batch time. The last matching
                # line wins. `* 1000` presumably converts ms -> us (confirm the log's unit);
                # `* iters` scales to the analysed iterations.
                runtime_no_pf = float(line.split(' ')[4]) * 1000 * iters # us

def _is_label_key(k):
    """Return True for breakdown keys that are grouping labels rather than ops.

    Matches the 'total' entry and any key whose first element contains 'DLRM ' or 'module'.
    """
    return k == 'total' or 'DLRM ' in k[0] or 'module' in k[0]


def _op_base_name(k):
    """Return the op name of breakdown key `k` with any '#...' suffix stripped."""
    return k[0].split('#')[0]


def _aggregate_op_runtime(flatten):
    """Sum v['runtime'] over the non-label entries of `flatten`, per base op name."""
    per_op = {}
    for k, v in flatten.items():
        if _is_label_key(k):
            continue
        op_name = _op_base_name(k)
        per_op[op_name] = per_op.get(op_name, 0.0) + v['runtime']
    return per_op


def _op_kernel_shares(flatten):
    """Map each base op name to a list of {kernel key: fraction of that entry's sub-time}.

    Entries whose sub-kernel times sum to 0 (including entries with no sub-kernels)
    are skipped: normalizing would divide by zero and the bar could not be drawn.
    """
    shares = {}
    for k, v in flatten.items():
        if _is_label_key(k):
            continue
        subs = v['subs']
        total = sum(subs.values())
        if total == 0:
            continue
        shares.setdefault(_op_base_name(k), []).append({kk: vv / total for kk, vv in subs.items()})
    return shares


def _kernel_signature(shares):
    """Return the sorted tuple of kernel names in `shares`, used to detect duplicate variants.

    Keys may be plain strings or tuples whose first element is the kernel name.
    Names containing 'volta_sgemm' / 'maxwell_sgemm' collapse to that prefix, so
    differently-suffixed sgemm kernels count as the same combination.
    """
    names = []
    for kernel_key in shares:
        # Resolve the name first so string keys are normalized too.
        name = kernel_key if isinstance(kernel_key, str) else kernel_key[0]
        if 'volta_sgemm' in name:
            name = 'volta_sgemm'
        elif 'maxwell_sgemm' in name:
            name = 'maxwell_sgemm'
        names.append(name)
    return tuple(sorted(names))


def _dedupe_variants(variants):
    """Keep the first variant of each distinct kernel signature, preserving order."""
    seen = set()
    unique = []
    for shares in variants:
        signature = _kernel_signature(shares)
        if signature not in seen:
            seen.add(signature)
            unique.append(shares)
    return unique


def plot_bar_chart(flatten, key="total", truncate_count=100, depth=0):
    """Plot, for each dominating op, stacked bars of its per-kernel share of device time.

    Each op gets one bar per distinct kernel combination seen across its entries;
    duplicate combinations are dropped. Ops are ordered by total runtime (largest
    first), laid out left-to-right and row-by-row on a 3x8 grid, and each op's
    group of bars is framed by a rectangle. The figure is saved as PDF and PNG under
    `{PM_HOME}/data/{GPU_NAME}/e2e/{model_name}/`.

    Args:
        flatten: one stream's breakdown as filled by `get_major_device_results`.
            Presumably it maps 'total' or (name, ...) tuples to dicts with 'runtime'
            and 'subs' entries — confirm against analysis.trace_utils.
        key, truncate_count, depth: currently unused.
    """
    per_op = _aggregate_op_runtime(flatten)
    ops_sorted = sorted(_op_kernel_shares(flatten).items(), key=lambda item: per_op[item[0]], reverse=True)
    pprint([len(variants) for _, variants in ops_sorted])  # variants per op before de-duplication
    ops_sorted = [(op_name, _dedupe_variants(variants)) for op_name, variants in ops_sorted]
    pprint([len(variants) for _, variants in ops_sorted])  # ... and after

    nrows, ncols = 3, 8
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(20, 13))
    flat_axes = axes.flatten()  # row-major: slot = row * ncols + col
    used_slots = set()
    yticks = [0, 0.2, 0.4, 0.6, 0.8, 1.0]
    xx, yy = 0, 0  # column / row of the next op's first bar
    rects = []
    for op_name, variants in ops_sorted:
        num_variants = len(variants)
        for idx, shares in enumerate(variants):
            # NOTE(review): an op whose bars do not fit in the rest of the row spills into
            # the next row, where it can overlap the op placed after it.
            slot = xx + idx + yy * ncols
            ax = flat_axes[slot]
            used_slots.add(slot)
            ranked = sorted(shares.items(), key=lambda kv: kv[1], reverse=True)
            kernel_names = [kernel_name_to_legend(k if isinstance(k, str) else k[0]) for k, _ in ranked]
            df = pd.DataFrame([[frac for _, frac in ranked]], columns=kernel_names)
            df.plot(stacked=True, kind='bar', ax=ax, cmap='tab20b')
            ax.get_legend().remove()
            if xx + idx == 0:
                # Only the first column shows the percentage axis.
                ax.set_yticks(yticks)
                ax.set_yticklabels(['{:,.0%}'.format(y) for y in yticks])
            else:
                ax.set_yticks([])
                ax.set_yticklabels([])
            ax.set_xticklabels([])
            ax.set_ylim((0.0, 1.0))
            ax.legend(bbox_to_anchor=(1.02, 1.04), loc='upper left', frameon=False, fontsize=14)
            # Title the middle bar of the op's group. Floor division ensures groups with
            # an even number of bars also get a title.
            if idx == (num_variants - 1) // 2:
                ax.set_title(op_name, loc='left', fontsize=(16 if len(op_name) < 20 else 12))

        # Frame the op's group of bars (figure coordinates; offsets are hand-tuned).
        llc_x = 1.0 / ncols * xx + (0.014 if xx != 0 else 0)
        llc_y = 1.0 / nrows * (nrows - yy - 1) + 0.01
        rx = 1.0 / ncols * num_variants + (0.014 if xx == 0 else 0)
        ry = 1.0 / nrows
        rects.append(plt.Rectangle(
            (llc_x, llc_y), rx, ry, fill=False, color="k", lw=2, zorder=1000, transform=fig.transFigure, figure=fig
        ))

        # Advance to the next free position, wrapping to a new row when full.
        xx += num_variants
        if xx >= ncols:
            xx = 0
            yy += 1

    # Hide every grid cell that received no bar.
    for slot, ax in enumerate(flat_axes):
        if slot not in used_slots:
            ax.set_axis_off()

    fig.patches.extend(rects)
    fig.tight_layout()
    # Changes the default size of figures created later; it does not resize `fig`.
    plt.rcParams['figure.figsize'] = [20, 13]
    fig.savefig('{}/data/{}/e2e/{}/dominating_op_breakdown.pdf'.format(PM_HOME, GPU_NAME, model_name), bbox_inches='tight')
    fig.savefig('{}/data/{}/e2e/{}/dominating_op_breakdown.png'.format(PM_HOME, GPU_NAME, model_name), bbox_inches='tight')

# Draw and save the dominating-op breakdown for every device stream.
# NOTE(review): plot_bar_chart writes to the same file paths for every stream, so
# with several streams only the last stream's figure remains on disk.
for stream, stream_flatten in flatten.items():
    print("########################")
    print(f"STREAM: {stream}")
    print("########################")
    plot_bar_chart(stream_flatten, truncate_count=truncate_count)

In [None]:
# Collect per-op overhead samples from the host-side trace (time units follow the
# trace timestamps — presumably microseconds, confirm).
# Type 1 overhead: between two op calls
# Type 2 overhead: before the first device call, op-specific
# Type 3 overhead: after the last device call, op-specific
# Type 4 overhead: kernel launches themselves, kernel-launch-type-specific
# Type 5 overhead: gaps between kernel launches, op-specific (one sample per gap;
#                  for ops without launches, one sample for the op's whole duration)
#
# Resulting structures:
#   overheads['independent']['t1']      -> list of T1 samples (not keyed by op name)
#   overheads['independent']['t4']      -> {launch call name: list of T4 samples}
#   overheads[op_name]['t2'/'t3'/'t5']  -> lists of samples for that op name
#   launches_dict[op_name]              -> launch call names of the first occurrence
#                                          of op_name that had any launches
overheads = {'independent': {}}
overheads['independent']['t1'] = [] # Independent from names
overheads['independent']['t4'] = {} # Independent from names
launches_dict = {}

for i, op in enumerate(ops):
    name = op.name()
    # Ensure every op name has (possibly empty) T2/T3/T5 sample lists.
    if name not in overheads.keys():
        overheads[name] = {}

    if 't2' not in overheads[name].keys():
        overheads[name]['t2'] = []
    if 't3' not in overheads[name].keys():
        overheads[name]['t3'] = []
    if 't5' not in overheads[name].keys():
        overheads[name]['t5'] = []

    sub_event_count = get_sub_event_count(op)
    # Get the number of events before each kernel launch (to subtract corresponding amount of CPU overheads from estimated time)
    # Only cudaMemcpyAsync / cudaLaunchKernel calls are kept as launches. Event counts
    # attached to any other launch call are carried over to the next kept launch.
    tmp_launches = get_event_all_kernel_launches(op)
    launches = []
    count = 0
    for x, y in tmp_launches:
        count += y
        if x.name() in ["cudaMemcpyAsync", "cudaLaunchKernel"]:
            launches.append((x, count))
            count = 0

    if len(launches) > 0:
        # Each launches entry is (launch event, number of events counted since the previous kept launch).
        overheads[name]['t2'].append(launches[0][0].start_time() - op.start_time() - launches[0][1] * CPU_EVENT_OVERHEAD) # T2 has all overheads before the first launch
        # Events after the last launch = all sub-events minus those preceding each launch
        # and minus the launch events themselves (+1 per launch).
        trailing_sub_event_count = sub_event_count - sum([y+1 for _, y in launches]) # And kernel launches themselves
        overheads[name]['t3'].append(op.end_time() - launches[-1][0].end_time() - trailing_sub_event_count * CPU_EVENT_OVERHEAD) # T3 has all overheads after the last launch
        if len(launches) > 1:
            # One T5 sample per pair of consecutive launches (the comprehension's `i`
            # is local to it in Python 3 and does not clobber the outer loop index).
            overheads[name]['t5'].extend([launches[i][0].start_time() - launches[i-1][0].end_time() - launches[i][1] * CPU_EVENT_OVERHEAD for i in range(1, len(launches))]) # T5 has all overheads between each pair of launches
        else:
            overheads[name]['t5'].append(0)

        # T4 is launch-type-dependent
        # NOTE(review): every T4 sample is the same constant
        # (KERNEL_LAUNCH_LENGTH - CPU_EVENT_OVERHEAD - GPU_EVENT_OVERHEAD), so its std is always 0.
        for x, _ in launches:
            if x.name() not in overheads['independent']['t4']:
                overheads['independent']['t4'][x.name()] = []
            overheads['independent']['t4'][x.name()].append(KERNEL_LAUNCH_LENGTH - CPU_EVENT_OVERHEAD - GPU_EVENT_OVERHEAD) # T4 has 1 overhead

        # Record the launch sequence only for the first occurrence of each op name.
        if op.name() not in launches_dict.keys():
            launches_dict[op.name()] = []
            for x, _ in launches:
                launches_dict[op.name()].append(x.name())
    else:
        # NOTE(review): the next two checks are redundant — both entries were already
        # created at the top of the loop body.
        if name not in overheads.keys():
            overheads[name] = {}
        # If an op doesn't have kernel calls it has only one T5 overhead representing its CPU duration
        if 't5' not in overheads[name].keys():
            overheads[name]['t5'] = []
        if name == "aten::to":
            # NOTE(review): this `continue` also skips the T1 gap between the previous op
            # and this aten::to — confirm that is intended.
            continue # Some aten::to doesn't have children
        else:
            overheads[name]['t5'].append(op.duration() - sub_event_count * CPU_EVENT_OVERHEAD) # Remove cpu overhead for all sub events

    if i > 0:
        prev_op = ops[i-1]

        # Only consider adjacent ops under the SAME MODULE
        if prev_op.parent != op.parent:
            continue

        gap = op.start_time() - prev_op.end_time()
        # NOTE(review): 200 is a hard-coded threshold (presumably us); larger gaps are
        # assumed to be data loading and are not counted as T1.
        if gap < 200: # Skip dataloading gaps
            overheads['independent']['t1'].append(gap - CPU_EVENT_OVERHEAD) # Some pairs of ops are actually inserted by a runtime call which has been filtered from ops. TODO: fix it.

# # T1: mean ~= 21, std ~= 20
# from analysis.utils import histogram
# histogram(overheads['independent']['t1'], perc=False, bins=[0, 5, 10, 15, 20, 25, 30, 40, 50, 60, 70, 80, 90, 100, 200, 100000])
# print(np.mean(overheads['independent']['t1']), np.std(overheads['independent']['t1']))

def _mean_std(samples):
    """Return (mean, std) of a list of overhead samples, as computed by numpy."""
    return np.mean(samples), np.std(samples)


def _per_op_mean_std(overheads, overhead_type):
    """Map each op name to (mean, std) of its `overhead_type` samples.

    The 'independent' bucket is excluded, as are ops with no samples of that type.
    """
    return {
        op_name: _mean_std(per_type[overhead_type])
        for op_name, per_type in overheads.items()
        if op_name != 'independent' and len(per_type[overhead_type]) > 0
    }


# T2, T3, T5: per-op statistics.
t2 = _per_op_mean_std(overheads, 't2')
t3 = _per_op_mean_std(overheads, 't3')
t5 = _per_op_mean_std(overheads, 't5')
# pprint(t2); pprint(t3); pprint(t5)

# # T4
# for t, l in overheads['independent']['t4'].items():
#     print(t, np.mean(l), np.std(l))

# Summary of all overhead types, plus the recorded launch sequences per op.
o = {
    "t1": _mean_std(overheads['independent']['t1']),
    "t2": t2,
    "t3": t3,
    "t4": {
        launch_type: _mean_std(samples)
        for launch_type, samples in overheads['independent']['t4'].items()
    },
    "t5": t5,
    "launches": launches_dict,
}

# with open("overheads_{}.json".format(model_name), "w") as f:
#     json.dump(o, f)

### Multistream analysis

In [None]:
# # Not finished yet
# all_kernels = []
# for _, c in cc.items():
#     for _, v in c["callees"].items():
#         if v["executor"] is not None:
#             all_kernels.append(v["executor"])
# all_kernels = sorted(all_kernels, key=lambda x: x.start_time())

# idle_time = 0
# last_end = all_kernels[0].start_time() + all_kernels[0].duration()
# overlapped = 0
# for k in all_kernels:
#     if k.start_time() > last_end:
#         idle_time += k.start_time() - last_end
#         last_end = k.start_time() + k.duration()
#     else:
#         last_end = max(last_end, k.start_time() + k.duration())
#         overlapped += min(last_end, k.start_time() + k.duration()) - k.start_time()

# print("device_runtime", device_runtime)
# print("idle_time:", idle_time)
# print("overlapped_time", overlapped)