In [1]:
from __future__ import absolute_import, division, print_function, unicode_literals
from IPython.core.display import display, HTML
from pprint import pprint
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import ConnectionPatch
import pandas as pd
import argparse, logging, tempfile, json, sys, pandas

# Default figure size, and allow up to 50 open figures before Matplotlib warns.
plt.rcParams.update({
    'figure.figsize': [10, 5],
    'figure.max_open_warning': 50,
})
# Widen the notebook container to the full browser width.
display(HTML("<style>.container { width:100% !important; }</style>"))

from utils import PM_HOME
sys.path.insert(0, PM_HOME)
from analysis.trace_utils import *

In [2]:
# %%capture
# Model whose profiler trace is analysed; uncomment the alternative to switch.
model_name = "MLPerf_1"
# model_name = "DLRM_default_1"

trace_file = '../data/{}.json'.format(model_name)
# Other trace variants that were tried (uncommenting one overrides the line above):
# trace_file = '../data/MLPerf_1_prof_cpu.json' # Not supported: only Kineto has the right trace file structure (schemaVersion etc) and external ID in events
# trace_file = '../data/MLPerf_1_prof_cuda.json' # Same as above
# trace_file = '../data/MLPerf_1_prof_cuda_kineto.json' # Good
# trace_file = '../data/MLPerf_1_prof_kineto.json' # No device events
# trace_file = '../data/MLPerf_1_prof_new.json' # Good

# Iteration count handed to trim_trace_by_num_iter; also used in the QPS computation.
iters = 10

trimmed_trace_file = trim_trace_by_num_iter(trace_file, iters=iters)
with open(trimmed_trace_file) as trace_fp:
    trace = json.load(trace_fp)

## DLRM with data loading

In [3]:
# Build the caller/callee event hierarchy from the trimmed trace.
roots, cc, corrected_start_time, corrected_end_time, sum_skipped_intervals = process_event_hierarchy(
    trace['traceEvents'], skip_module=False, module_marker="DLRM ")
print('Num of events: {}, num of root events: {}, num of caller/callee pairs: {}'.format(len(trace['traceEvents']), len(roots), len(cc)))
print('Sum of dataloading time: {}'.format(sum_skipped_intervals))
print("Corrected start time: ", corrected_start_time, ", corrected end time: ", corrected_end_time)
# Host runtime excludes the skipped (dataloading) intervals.
host_runtime = corrected_end_time - corrected_start_time - sum_skipped_intervals
# ---
# device_runtime, device_start_delay = get_device_runtime_and_start_delay(cc, corrected_start_time)
# print("Device start delay: ", device_start_delay)
# ---
# Device runtime is taken to equal host runtime (the per-device computation above is disabled).
device_runtime = host_runtime
# ---
print("Host runtime: ", host_runtime)
print("Device runtime: ", device_runtime)
# The operator list `ops` is built in the next cell with the notebook's own get_operators.

# Samples per second over the `iters` profiled iterations (runtimes are in microseconds).
BATCH_SIZE = 2048  # samples per iteration -- TODO confirm it matches the profiled run
US_PER_SECOND = 1000000
QPS = US_PER_SECOND / host_runtime * iters * BATCH_SIZE
print(f"QPS: {QPS:.2f}")

Num of events: 18610, num of root events: 30, num of caller/callee pairs: 6289
Sum of dataloading time: 34507
Corrected start time:  1620971976952391 , corrected end time:  1620971977079261
Host runtime:  92363
Device runtime:  92363
QPS: 221733.81


In [4]:
def get_operators(roots, ops):
    """Collect the outermost operator events of an event tree, in traversal order.

    Walks ``roots`` depth-first. An event whose ``category()`` is ``"Operator"``
    and which is not a module (per ``is_module``) is appended to ``ops``, and its
    subtree is not explored further. Every other event (modules, annotations, ...)
    is descended into via its ``children``.

    Args:
        roots: iterable of trace events exposing ``category()`` and ``children``.
        ops: list that receives the selected events (mutated in place).
    """
    for event in roots:
        # NOTE(review): the previous comment asked for "parent is a module, or is a
        # root operator", but the condition encoding it
        # (``parent is not None or parent is None``) was always true. The parent was
        # therefore never checked. The tautology is dropped with behavior unchanged.
        # Add a real parent check (e.g. ``is_module(event.parent)``) only if the
        # resulting ops list is meant to change.
        if event.category() == "Operator" and not is_module(event):
            ops.append(event)
        else:
            get_operators(event.children, ops)

# Rebuild the flat operator list using the get_operators defined above.
ops = list()
get_operators(roots, ops)

In [5]:
op_names = [op.name() for op in ops]
pprint(op_names)

['aten::zeros',
 'aten::empty',
 'aten::ones',
 'aten::zeros',
 'aten::empty',
 'aten::to',
 'aten::to',
 'aten::to',
 'aten::zeros',
 'aten::empty',
 'aten::linear',
 'aten::relu',
 'aten::linear',
 'aten::relu',
 'aten::linear',
 'aten::relu',
 'aten::zeros',
 'aten::empty',
 'aten::to',
 'aten::to',
 'LookupFunction',
 'aten::zeros',
 'aten::empty',
 'aten::view',
 'aten::cat',
 'aten::transpose',
 'aten::bmm',
 'aten::empty',
 'aten::to',
 'aten::detach_',
 'aten::empty',
 'aten::to',
 'aten::detach_',
 'aten::slice',
 'aten::to',
 'aten::to',
 'aten::index',
 'aten::cat',
 'aten::zeros',
 'aten::empty',
 'aten::linear',
 'aten::relu',
 'aten::linear',
 'aten::relu',
 'aten::linear',
 'aten::relu',
 'aten::linear',
 'aten::relu',
 'aten::linear',
 'aten::sigmoid',
 'aten::zeros',
 'aten::empty',
 'aten::to',
 'aten::binary_cross_entropy',
 'aten::detach',
 'aten::to',
 'aten::zeros',
 'aten::empty',
 'aten::zeros',
 'Optimizer.zero_grad#SGD.zero_grad',
 'aten::ones_like',
 'BinaryC

### Host runtime breakdown

In [6]:
# Print the host-side breakdown up to 3 levels deep, showing at most 3 entries per level.
depth_limit, truncate_count = 3, 3
host_runtime_breakdown = get_host_runtime_breakdown(roots, cc, host_runtime)
print_host_results(host_runtime_breakdown, depth_limit, truncate_count)

Two iteration runtime:                92363 (in us, same below)
     ## BENCHMARK ##:                           (91658.0, 99.24%, 10)
          DLRM backward:                                  (56180.0, 61.29%, 10)
               AddmmBackward:                                       (14715.0, 26.19%, 80)
                    aten::mm:                                                 (10255.0, 69.69%, 150)
                    aten::t:                                                   (2518.0, 17.11%, 230)
                    aten::conj:                                                 (146.0, 0.99%, 150)
                    Unaccounted:                                               (1796.0, 12.21%)
               Optimizer.zero_grad#SGD.zero_grad:                    (5185.0, 9.23%, 10)
                    aten::zero_:                                               (3563.0, 68.72%, 170)
                    aten::empty:                                                 (20.0, 0.39%, 10)
         

In [7]:
# Per the printed output: op external id -> {stream -> {(kernel name, ...): {'count', 'runtime'}}}
op_device_runtime = get_device_runtime(ops, cc)
pprint(op_device_runtime)

{53627: {7: {('Memcpy HtoD (Pinned -> Device)', (-1,)): {'count': 1,
                                                         'runtime': 19.0}}},
 53630: {7: {('Memcpy HtoD (Pinned -> Device)', (-1,)): {'count': 1,
                                                         'runtime': 19.0}}},
 53633: {7: {('Memcpy HtoD (Pinned -> Device)', (-1,)): {'count': 1,
                                                         'runtime': 11.0}}},
 53641: {7: {('unrolled_elementwise_kernel', (-1,)): {'count': 1,
                                                      'runtime': 14.0},
             ('volta_sgemm_128x32_tn', (-1,)): {'count': 1, 'runtime': 14.0}}},
 53649: {7: {('vectorized_elementwise_kernel', (-1,)): {'count': 1,
                                                        'runtime': 12.0}}},
 53654: {7: {('Memset (Device)', (-1,)): {'count': 1, 'runtime': 1.0},
             ('unrolled_elementwise_kernel', (-1,)): {'count': 1,
                                                      'runtime'

 55013: {7: {('Memset (Device)', (-1,)): {'count': 1, 'runtime': 1.0},
             ('volta_sgemm_128x32_nn', (-1,)): {'count': 1, 'runtime': 192.0},
             ('volta_sgemm_128x32_nt', (-1,)): {'count': 1, 'runtime': 200.0}}},
 55029: {7: {('Memset (Device)', (-1,)): {'count': 1, 'runtime': 1.0},
             ('reduce_kernel', (-1,)): {'count': 1, 'runtime': 15.0}}},
 55032: {7: {('vectorized_elementwise_kernel', (-1,)): {'count': 1,
                                                        'runtime': 3.0}}},
 55038: {7: {('vectorized_elementwise_kernel', (-1,)): {'count': 1,
                                                        'runtime': 8.0}}},
 55040: {7: {('vectorized_elementwise_kernel', (-1,)): {'count': 1,
                                                        'runtime': 34.0}}},
 55043: {7: {('Memset (Device)', (-1,)): {'count': 2, 'runtime': 2.0},
             ('volta_sgemm_32x128_nn', (-1,)): {'count': 1, 'runtime': 392.0},
             ('volta_sgemm_32x128_nt', (-1,)):

 56004: {7: {('unrolled_elementwise_kernel', (-1,)): {'count': 1,
                                                      'runtime': 9.0},
             ('vectorized_elementwise_kernel', (-1,)): {'count': 1,
                                                        'runtime': 3.0}}},
 56013: {7: {('vectorized_elementwise_kernel', (-1,)): {'count': 1,
                                                        'runtime': 3.0}}},
 56016: {7: {('void gemmk1_kernel', (-1,)): {'count': 1, 'runtime': 6.0},
             ('void gemvNSP_kernel', (-1,)): {'count': 1, 'runtime': 11.0},
             ('void splitKreduce_kernel', (-1,)): {'count': 1,
                                                   'runtime': 4.0}}},
 56032: {7: {('reduce_kernel', (-1,)): {'count': 1, 'runtime': 7.0}}},
 56035: {7: {('vectorized_elementwise_kernel', (-1,)): {'count': 1,
                                                        'runtime': 3.0}}},
 56041: {7: {('vectorized_elementwise_kernel', (-1,)): {'count': 1,
            

 57885: {7: {('Memcpy HtoD (Pinned -> Device)', (-1,)): {'count': 1,
                                                         'runtime': 11.0}}},
 57893: {7: {('unrolled_elementwise_kernel', (-1,)): {'count': 1,
                                                      'runtime': 14.0},
             ('volta_sgemm_128x32_tn', (-1,)): {'count': 1, 'runtime': 14.0}}},
 57901: {7: {('vectorized_elementwise_kernel', (-1,)): {'count': 1,
                                                        'runtime': 12.0}}},
 57906: {7: {('Memset (Device)', (-1,)): {'count': 1, 'runtime': 1.0},
             ('unrolled_elementwise_kernel', (-1,)): {'count': 1,
                                                      'runtime': 8.0},
             ('volta_sgemm_128x32_tn', (-1,)): {'count': 1, 'runtime': 68.0}}},
 57914: {7: {('vectorized_elementwise_kernel', (-1,)): {'count': 1,
                                                        'runtime': 7.0}}},
 57919: {7: {('unrolled_elementwise_kernel', (-1,)): {'count'

                                                        'runtime': 3.0}}},
 58411: {7: {('vectorized_elementwise_kernel', (-1,)): {'count': 1,
                                                        'runtime': 11.0}}},
 58414: {7: {('Memset (Device)', (-1,)): {'count': 2, 'runtime': 2.0},
             ('volta_sgemm_128x32_nn', (-1,)): {'count': 1, 'runtime': 60.0},
             ('volta_sgemm_32x128_nt', (-1,)): {'count': 1, 'runtime': 106.0}}},
 58430: {7: {('Memset (Device)', (-1,)): {'count': 1, 'runtime': 1.0},
             ('reduce_kernel', (-1,)): {'count': 1, 'runtime': 10.0}}},
 58433: {7: {('vectorized_elementwise_kernel', (-1,)): {'count': 1,
                                                        'runtime': 3.0}}},
 58439: {7: {('vectorized_elementwise_kernel', (-1,)): {'count': 1,
                                                        'runtime': 4.0}}},
 58441: {7: {('vectorized_elementwise_kernel', (-1,)): {'count': 1,
                                                      

             ('volta_sgemm_32x128_nt', (-1,)): {'count': 1, 'runtime': 120.0}}},
 60314: {7: {('Memset (Device)', (-1,)): {'count': 1, 'runtime': 1.0},
             ('reduce_kernel', (-1,)): {'count': 1, 'runtime': 15.0}}},
 60317: {7: {('vectorized_elementwise_kernel', (-1,)): {'count': 1,
                                                        'runtime': 2.0}}},
 60323: {7: {('vectorized_elementwise_kernel', (-1,)): {'count': 1,
                                                        'runtime': 4.0}}},
 60325: {7: {('vectorized_elementwise_kernel', (-1,)): {'count': 1,
                                                        'runtime': 18.0}}},
 60328: {7: {('Memset (Device)', (-1,)): {'count': 1, 'runtime': 1.0},
             ('volta_sgemm_128x32_nn', (-1,)): {'count': 1, 'runtime': 192.0},
             ('volta_sgemm_128x32_nt', (-1,)): {'count': 1, 'runtime': 201.0}}},
 60344: {7: {('Memset (Device)', (-1,)): {'count': 1, 'runtime': 1.0},
             ('reduce_kernel', (-1,)): {'cou

                                                        'runtime': 5.0}}},
 61264: {7: {('Memcpy DtoH (Device -> Pageable)', (-1,)): {'count': 1,
                                                           'runtime': 1.0}}},
 61275: {7: {('vectorized_elementwise_kernel', (-1,)): {'count': 16,
                                                        'runtime': 49.0}}},
 61315: {7: {('vectorized_elementwise_kernel', (-1,)): {'count': 1,
                                                        'runtime': 2.0}}},
 61319: {7: {('unrolled_elementwise_kernel', (-1,)): {'count': 1,
                                                      'runtime': 9.0},
             ('vectorized_elementwise_kernel', (-1,)): {'count': 1,
                                                        'runtime': 3.0}}},
 61328: {7: {('vectorized_elementwise_kernel', (-1,)): {'count': 1,
                                                        'runtime': 3.0}}},
 61331: {7: {('void gemmk1_kernel', (-1,)): {'count': 1, 'runtime

 62693: {7: {('vectorized_elementwise_kernel', (-1,)): {'count': 1,
                                                        'runtime': 18.0}}},
 62696: {7: {('void splitKreduce_kernel', (-1,)): {'count': 1, 'runtime': 5.0},
             ('volta_sgemm_32x128_nt', (-1,)): {'count': 1, 'runtime': 15.0}}},
 62706: {7: {('Memset (Device)', (-1,)): {'count': 1, 'runtime': 1.0},
             ('reduce_kernel', (-1,)): {'count': 1, 'runtime': 11.0}}},
 62709: {7: {('vectorized_elementwise_kernel', (-1,)): {'count': 1,
                                                        'runtime': 3.0}}},
 62715: {7: {('vectorized_elementwise_kernel', (-1,)): {'count': 1,
                                                        'runtime': 3.0}}},
 62720: {7: {('vectorized_elementwise_kernel', (-1,)): {'count': 16,
                                                        'runtime': 80.0}}},
 63194: {7: {('Memcpy HtoD (Pinned -> Device)', (-1,)): {'count': 1,
                                                     

 63655: {7: {('volta_sgemm_32x128_nt', (-1,)): {'count': 1, 'runtime': 105.0},
             ('volta_sgemm_64x64_nn', (-1,)): {'count': 1, 'runtime': 139.0}}},
 63671: {7: {('unrolled_elementwise_kernel', (-1,)): {'count': 1,
                                                      'runtime': 117.0}}},
 63682: {7: {('unrolled_elementwise_kernel', (-1,)): {'count': 1,
                                                      'runtime': 10.0}}},
 63683: {7: {('void batched_embedding_backward_sgd_kernel_1', (-1,)): {'count': 1,
                                                                       'runtime': 113.0}}},
 63696: {7: {('vectorized_elementwise_kernel', (-1,)): {'count': 1,
                                                        'runtime': 7.0}}},
 63699: {7: {('void splitKreduce_kernel', (-1,)): {'count': 1, 'runtime': 5.0},
             ('volta_sgemm_128x32_nn', (-1,)): {'count': 1, 'runtime': 21.0},
             ('volta_sgemm_64x32_sliced1x4_nt', (-1,)): {'count': 1,
               

In [8]:
# Print every device call under each event of the hierarchy, starting at the roots.
print_all_device_results(
    roots,
    op_device_runtime,
    device_runtime,
    depth=0,
)

## BENCHMARK ##
    DLRM forward
        aten::to
            Memcpy HtoD (Pinned -> Device):                       ( (-1,), 1, 19.0 )
        aten::to
            Memcpy HtoD (Pinned -> Device):                       ( (-1,), 1, 19.0 )
        aten::to
            Memcpy HtoD (Pinned -> Device):                       ( (-1,), 1, 11.0 )
        module::forward_pass::bottom_mlp
            aten::linear
                unrolled_elementwise_kernel:                              ( (-1,), 1, 14.0 )
                volta_sgemm_128x32_tn:                                    ( (-1,), 1, 14.0 )
            aten::relu
                vectorized_elementwise_kernel:                            ( (-1,), 1, 12.0 )
            aten::linear
                unrolled_elementwise_kernel:                              ( (-1,), 1, 8.0 )
                Memset (Device):                                          ( (-1,), 1, 1.0 )
                volta_sgemm_128x32_tn:                                    ( (-1,), 1

                volta_sgemm_32x128_tn:                                    ( (-1,), 1, 190.0 )
            aten::to
                Memcpy HtoD (Pageable -> Device):                         ( (-1,), 1, 1.0 )
            aten::to
                Memcpy HtoD (Pageable -> Device):                         ( (-1,), 1, 1.0 )
            aten::index
                index_elementwise_kernel:                                 ( (-1,), 1, 21.0 )
            aten::cat
                CatArrayBatchedCopy:                                      ( (-1,), 1, 16.0 )
        module::forward_pass::top_mlp
            aten::linear
                unrolled_elementwise_kernel:                              ( (-1,), 1, 18.0 )
                volta_sgemm_128x32_tn:                                    ( (-1,), 1, 203.0 )
            aten::relu
                vectorized_elementwise_kernel:                            ( (-1,), 1, 24.0 )
            aten::linear
                unrolled_elementwise_kernel:             

            vectorized_elementwise_kernel:                        ( (-1,), 1, 3.0 )
        torch::autograd::AccumulateGrad
            vectorized_elementwise_kernel:                        ( (-1,), 1, 16.0 )
        ReluBackward0
            vectorized_elementwise_kernel:                        ( (-1,), 1, 34.0 )
        AddmmBackward
            Memset (Device):                                      ( (-1,), 2, 2.0 )
            volta_sgemm_32x128_nn:                                ( (-1,), 1, 200.0 )
            volta_sgemm_32x128_nt:                                ( (-1,), 1, 170.0 )
        aten::sum
            Memset (Device):                                      ( (-1,), 1, 1.0 )
            reduce_kernel:                                        ( (-1,), 1, 22.0 )
        torch::autograd::AccumulateGrad
            vectorized_elementwise_kernel:                        ( (-1,), 1, 3.0 )
        torch::autograd::AccumulateGrad
            vectorized_elementwise_kernel:             

### Device runtime breakdown

In [9]:
dt_breakdown = device_runtime_breakdown(roots, op_device_runtime, depth=0)
truncate_count = 10  # passed as truncate_count to the printing/plotting cells below
# One flattened {label: {'runtime', 'subs'}} dict per stream.
flatten = {}
for stream_id, stream_breakdown in dt_breakdown.items():
    print("Stream: {}".format(stream_id))
    stream_flat = flatten.setdefault(stream_id, {})
    get_major_device_results(device_runtime, stream_breakdown, stream_flat)
pprint(flatten)

Stream: 7
{7: {'total': {'runtime': 49980.0,
               'subs': {('## BENCHMARK ##', (-1,)): 49980.0}},
     ('## BENCHMARK ##', (-1,)): {'runtime': 4993.0,
                                  'subs': {('DLRM backward', (-1,)): 3318.0,
                                           ('DLRM forward', (-1,)): 1657.0,
                                           ('DLRM loss compute', (-1,)): 17.0,
                                           ('aten::to', (-1,)): 1.0}},
     ('AddmmBackward', (-1,)): {'runtime': 19.0,
                                'subs': {('void gemmk1_kernel', (-1,)): 5.0,
                                         ('void gemvNSP_kernel', (-1,)): 10.0,
                                         ('void splitKreduce_kernel', (-1,)): 4.0}},
     ('BinaryCrossEntropyBackward', (-1,)): {'runtime': 12.0,
                                             'subs': {('unrolled_elementwise_kernel', (-1,)): 9.0,
                                                      ('vectorized_elementwise_kernel

In [10]:
for stream_id, stream_breakdown in dt_breakdown.items():
    print("Stream: {}".format(stream_id))
    print_major_device_results(
        device_runtime, stream_breakdown, flatten[stream_id], truncate_count=truncate_count
    )

Stream: 7
    Total device time: 92363 (in us, same below)
    Device idle time: 42383.0 (45.89%)
    Device active time: 49980.0 (54.11%)
      ## BENCHMARK ##:                                     (49980.0, 100.00%, 10)                                                        (-1,)
        DLRM backward:                                         (3318.0, 66.45%, 1)                                                        (-1,)
          AddmmBackward:                                            (1955.0, 58.92%, 8)                                                        (-1,)
            void gemvNSP_kernel:                                           (10.0, 52.63%, 1)                                                        (-1,)
            void gemmk1_kernel:                                             (5.0, 26.32%, 1)                                                        (-1,)
            void splitKreduce_kernel:                                       (4.0, 21.05%, 1)                         

In [None]:
# Full raw per-stream breakdown (large output).
pprint(dt_breakdown, width=80)

In [None]:
from matplotlib import cm
# Pie-wedge colors from the qualitative Set1 palette, which has only 9 entries
# (indices 0-8). The previous list also used 25 and 11. Integer indices past the
# end map to the colormap's "over" color (the last entry by default), so three
# wedges were drawn in the same grey. Every index below is valid and distinct;
# Axes.pie cycles through the colors if a chart has more than 9 wedges.
cs = cm.Set1([1, 3, 2, 0, 4, 5, 6, 7, 8])

def plot_pie_chart(flatten, key="total", truncate_count=100, depth=0):
    """Draw a pie chart of ``flatten[key]`` and recurse into its sub-entries.

    Args:
        flatten: dict mapping a label (``"total"`` or a tuple key) to
            ``{"runtime": number, "subs": {label: runtime}}``, as filled per stream
            by ``get_major_device_results``.
        key: entry of ``flatten`` to chart; ``"total"`` is the root.
        truncate_count: maximum number of individual wedges. The remaining (smallest)
            sub-entries are merged into one "Others" wedge. ``None`` disables merging.
        depth: recursion depth. Figures are shown only after the top-level call
            (depth 0) has drawn every chart.
    """
    d = flatten[key]

    # Largest contributors first; slices are drawn counter-clockwise.
    stats = sorted(d["subs"].items(), key=lambda x: x[1], reverse=True)
    # Drill-down targets are taken before truncation so every sub-chart is still drawn.
    all_labels = [x[0] for x in stats]
    if truncate_count is not None and len(stats) > truncate_count:
        others = sum(x[1] for x in stats[truncate_count:])
        stats = stats[:truncate_count] + [("Others", others)]
    labels = [x[0] for x in stats]
    runtime = [x[1] for x in stats]
    explode = np.zeros(len(runtime))
    if len(explode) > 2:
        explode[1] = 0.1

    _, ax1 = plt.subplots(figsize=(12, 6))
    wedges, texts = ax1.pie(runtime, explode=explode, shadow=True, startangle=90, colors=cs)
    ax1.axis('equal')  # Equal aspect ratio ensures that pie is drawn as a circle.
    ax1.set_title(key)

    # Guard against an entry with zero runtime so the legend does not divide by zero.
    total = d["runtime"]
    percents = ["{:.2f}%".format(r / total * 100 if total else 0.0) for r in runtime]
    ax1.legend(wedges, zip(labels, percents),
          title="Breakdown",
          loc="center left",
          bbox_to_anchor=(1, 0, 0.5, 1),
          fontsize=14)

    for label in all_labels:
        if label in flatten:
            plot_pie_chart(flatten, key=label, truncate_count=truncate_count, depth=depth+1)

    if depth == 0:
        plt.show()

banner = "########################"
for stream, stream_flat in flatten.items():
    print(banner)
    print("STREAM: {}".format(stream))
    print(banner)
    plot_pie_chart(stream_flat, truncate_count=truncate_count)

In [None]:
def kernel_name_to_legend(name):
    """Map a device kernel / memory-op name to a short legend label.

    Rules are checked in order and the first substring match wins. For example, a
    name containing both 'splitK' and 'reduce_kernel' becomes 'splitK'. Names that
    match no rule are truncated to their first 8 characters.
    """
    rules = (
        ('gemm', 'gemm'),
        ('gemv', 'gemv'),
        ('Memset', 'Memset'),
        ('Memcpy', 'Memcpy'),
        ('vectorized_elementwise', 'elwt'),
        ('unrolled_elementwise', 'permute'),
        ('embedding_forward', 'ELF'),
        ('embedding_backward', 'ELB'),
        ('splitK', 'splitK'),
        ('reduce_kernel', 'reduce'),
        ('CatArray', 'CatCopy'),
        ('indexing', 'IdxBwd'),
    )
    return next((label for needle, label in rules if needle in name), name[:8])

def plot_bar_chart(flatten, key="total", truncate_count=100, depth=0):
    """Plot and save the device active-time breakdown of one stream.

    Produces two figures:

    * ``active_time_breakdown.pdf`` - a bar splitting the notebook-global
      ``device_runtime`` into active vs. idle time, above a bar splitting the
      active time across operators.
    * ``dominating_op_breakdown.pdf`` - for every operator, one stacked bar per
      distinct kernel mix, showing each kernel's share of that operator's time.

    Args:
        flatten: per-stream dict as filled by ``get_major_device_results``:
            ``"total"`` plus tuple keys ``(name, ...)``, each mapping to
            ``{"runtime": number, "subs": {kernel_key: runtime}}``.
        key, truncate_count, depth: currently unused; accepted for call compatibility.

    Also reads the globals ``device_runtime`` and ``kernel_name_to_legend``, and sets
    ``plt.rcParams['figure.figsize']`` as a side effect.
    """
    def skip(k):
        # "total", DLRM phase labels and modules are not operators.
        return k == 'total' or 'DLRM ' in k[0] or 'module' in k[0]

    # ---- Figure 1: active/idle split and per-operator share of active time ----
    # "#"-suffixed variants (e.g. "Optimizer.zero_grad#SGD.zero_grad") are merged
    # under the part before '#'.
    per_op = {}
    total = 0.0
    for k, v in flatten.items():
        if skip(k):
            continue
        k0 = k[0].split('#')[0]
        per_op[k0] = per_op.get(k0, 0.0) + v['runtime']
        total += v['runtime']

    tmp = sorted(per_op.items(), key=lambda x: x[1], reverse=True)
    op = [x[0] for x in tmp]
    # Guard against division by zero when every operator has zero runtime.
    p = [x[1] / total if total else 0.0 for x in tmp]
    active_frac = flatten['total']['runtime'] / device_runtime
    df0 = pd.DataFrame({
        'Active time': [active_frac],
        'Idle time': [1 - active_frac]
    })
    df = pd.DataFrame([p], columns=op)

    # Build the figure explicitly. Creating a 2x1 plt.subplots grid and then
    # replacing its axes with subplot2grid leaves the two original (empty) axes on
    # the figure in Matplotlib versions that no longer auto-remove overlapping axes.
    plt.figure(figsize=(12, 3))
    ax0 = plt.subplot2grid(shape=(2, 12), loc=(0, 1), colspan=10)  # Uneven sizes of subplots
    ax1 = plt.subplot2grid(shape=(2, 12), loc=(1, 0), colspan=12)

    df0.plot(stacked=True, title=" ", kind='barh', width=0.05, ax=ax0, cmap='Set2')
    vals = ax0.get_xticks()
    ax0.set_xticklabels(['{:,.0%}'.format(x) for x in vals])
    ax0.set_xlim((0.0, 1.0))
    ax0.set_yticks([])
    ax0.set_yticklabels([])
    ax0.set_ylim((-0.03, 0.03))
    ax0.legend(loc="lower center", ncol=2, bbox_to_anchor=(0.5, -0.7), frameon=False, fontsize=10.5)

    df.plot(stacked=True, kind='barh', width=0.05, ax=ax1, cmap='tab20b')  # https://matplotlib.org/stable/tutorials/colors/colormaps.html
    vals = ax1.get_xticks()
    ax1.set_xticklabels(['{:,.0%}'.format(x) for x in vals])
    ax1.set_xlim((0.0, 1.0))
    ax1.set_yticks([])
    ax1.set_yticklabels([])
    ax1.set_ylim((-0.03, 0.03))
    ax1.legend(loc="lower center", ncol=4, bbox_to_anchor=(0.5, -1.7), frameon=False, fontsize=10.5)

    # Space between subplots
    plt.subplots_adjust(hspace=0.7)

    # Dotted lines linking the active-time segment to the full per-operator bar.
    con1 = ConnectionPatch(xyA=(0,-0.025), xyB=(0,0.025), coordsA="data", coordsB="data", axesA=ax0, axesB=ax1, linestyle='dotted')
    con2 = ConnectionPatch(xyA=(active_frac,-0.025), xyB=(1,0.025), coordsA="data", coordsB="data", axesA=ax0, axesB=ax1, linestyle='dotted')
    ax1.add_artist(con1)
    ax1.add_artist(con2)

    plt.tight_layout()
    plt.rcParams['figure.figsize'] = [12, 3]
    plt.savefig('active_time_breakdown.pdf', bbox_inches='tight')

    # ---- Figure 2: kernel mix of every distinct variant of each operator ----
    t = {}
    for k, v in flatten.items():
        if skip(k):
            continue
        k0 = k[0].split('#')[0]
        subs_total = sum(v['subs'].values())
        # Kernel runtimes as fractions of this entry's total (0 when the total is 0).
        t.setdefault(k0, []).append(
            {kk: (vv / subs_total if subs_total else 0.0) for kk, vv in v['subs'].items()})
    # Operators ordered by their share of active time, largest first.
    so = [list(x) for x in sorted(t.items(), key=lambda x: per_op[x[0]], reverse=True)]
    pprint([len(s[1]) for s in so])

    for s in so:
        # Keep only the first variant of each kernel combination. sgemm tile
        # variants (e.g. _tn vs _nt) are treated as the same kernel.
        seen = set()
        unique = []
        for b in s[1]:
            names = []
            for x in b.keys():
                if 'volta_sgemm' in x[0]:
                    names.append('volta_sgemm')
                elif 'maxwell_sgemm' in x[0]:
                    names.append('maxwell_sgemm')
                elif isinstance(x, str):
                    names.append(x)
                else:
                    names.append(x[0])
            signature = tuple(sorted(names))
            if signature not in seen:
                seen.add(signature)
                unique.append(b)
        s[1] = unique
    pprint([len(s[1]) for s in so])

    nrows, ncols = 3, 8
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(20, 13))
    xx, yy = 0, 0
    rects = []
    for s in so:
        num_variants = len(s[1])
        for idx, b in enumerate(s[1]):
            ax = plt.subplot(nrows, ncols, 1+xx+idx+yy*ncols)
            p = sorted(b.items(), key=lambda x: x[1], reverse=True)
            kernel_name = [kernel_name_to_legend(x[0] if isinstance(x[0], str) else x[0][0]) for x in p]
            perc = [x[1] for x in p]
            df = pd.DataFrame([perc], columns=kernel_name)
            df.plot(stacked=True, kind='bar', ax=ax, cmap='tab20b')
            ax.get_legend().remove()
            if xx+idx == 0:
                ax.set_yticks([0, 0.2, 0.4, 0.6, 0.8, 1.0])
                ax.set_yticklabels(['{:,.0%}'.format(x) for x in [0, 0.2, 0.4, 0.6, 0.8, 1.0]])
            else:
                ax.set_yticks([])
                ax.set_yticklabels([])
            ax.set_xticklabels([])
            ax.set_ylim((0.0, 1.0))
            ax.legend(bbox_to_anchor=(1.02, 1.04), loc='upper left', frameon=False, fontsize=14)
            # Title each operator once, on its middle variant. Integer division is
            # needed so operators with an even number of variants also get a title;
            # a float midpoint such as 0.5 never equals an index.
            if idx == (num_variants - 1) // 2:
                ax.set_title(s[0], loc='left', fontsize=(16 if len(s[0]) < 20 else 12))

        # Border around all variants of this operator (figure coordinates).
        llc_x = 1.0 / ncols * (xx) + (0.014 if xx != 0 else 0)
        llc_y = 1.0 / nrows * (nrows-yy-1) + 0.01
        rx = 1.0 / ncols * (num_variants) + (0.014 if xx == 0 else 0)
        ry = 1.0 / nrows
        rects.append(plt.Rectangle(
            (llc_x, llc_y), rx, ry, fill=False, color="k", lw=2, zorder=1000, transform=fig.transFigure, figure=fig
        ))

        # Subplot position
        xx += num_variants
        if xx >= ncols:
            xx = 0
            yy += 1

    # Hardcoded for now: assumes the last grid cell is left empty.
    axes[2,7].set_axis_off()

    fig.patches.extend(rects)
    plt.tight_layout()
    plt.rcParams['figure.figsize'] = [20, 13]
    plt.savefig('dominating_op_breakdown.pdf', bbox_inches='tight')

separator = "########################"
for stream, stream_flat in flatten.items():
    print(separator)
    print("STREAM: {}".format(stream))
    print(separator)
    plot_bar_chart(stream_flat, truncate_count=truncate_count)

In [None]:
# Flattened per-stream breakdown used by the charts above.
pprint(flatten, indent=1)

In [None]:
def histogram(df, perc=True, is_abs=False, bins=(0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.4, 0.5, 0.6, 0.8, 1.0, 1.5, 2.0, 3.0, 4.0)):
    """Count values into half-open bins ``[lo, hi)`` and print each bin's share.

    Args:
        df: iterable of numbers (a list, array or Series; generators are accepted too).
        perc: if True, bin edges are printed as percentages (edge * 100);
            otherwise they are printed as plain numbers with two decimals.
        is_abs: if True, mirror ``bins`` around zero: negated edges from
            ``-bins[-1]`` up to ``-bins[1]``, followed by ``bins`` itself.
        bins: ascending bin edges.

    Returns:
        dict mapping ``(lo, hi)`` to the number of values in ``[lo, hi)``.
        Values outside ``[bins[0], bins[-1])`` are not counted in any bin, but
        they do count toward the denominator of the printed shares. With no
        input values, every printed share is 0.
    """
    values = list(df)
    count = len(values)
    bins = list(bins)
    if is_abs:
        bins = [-b for b in reversed(bins[1:])] + bins

    edges = list(zip(bins[:-1], bins[1:]))
    ret = {edge: 0 for edge in edges}
    for x in values:
        for lo, hi in edges:
            if lo <= x < hi:
                ret[(lo, hi)] += 1
                break

    for (lo, hi), c in sorted(ret.items(), key=lambda item: item[0]):
        # Empty input previously raised ZeroDivisionError here.
        share = c / count * 100 if count else 0.0
        if perc:
            print("{:.0f}% - {:.0f}%: {:.2f}%".format(lo * 100, hi * 100, share))
        else:
            print("{:.2f} - {:.2f}: {:.2f}%".format(lo, hi, share))
    return ret

In [None]:
# Type 1 overhead: between two op calls
# Type 2 overhead: before the first device call, op-specific
# Type 3 overhead: after the last device call, op-specific
# Type 4 overhead: kernel launches themselves, kernel-launch-type-specific
# Type 5 overhead: sum of gaps between kernel launches, op-specific
overheads = {'independent': {}}
overheads['independent']['t1'] = []  # Independent from names
overheads['independent']['t4'] = {}  # Independent from names; keyed by launch call name
launches_dict = {}  # op name -> launch call names of its first occurrence with launches

# Runtime calls counted as launches.
launch_names = ("cudaMemcpyAsync", "cudaLaunchKernel", "cudaStreamSynchronize")
# Adjacent-op gaps at or above this value are treated as dataloading and skipped.
max_t1_gap = 200

for i, op in enumerate(ops):
    name = op.name()
    op_overheads = overheads.setdefault(name, {})
    for kind in ('t2', 't3', 't5'):
        op_overheads.setdefault(kind, [])

    launches = [x for x in get_event_all_kernel_launches(op) if x.name() in launch_names]

    if launches:
        op_overheads['t2'].append(launches[0].start_time() - op.start_time())
        op_overheads['t3'].append(op.end_time() - launches[-1].end_time())
        if len(launches) > 1:
            op_overheads['t5'].extend(
                launches[j].start_time() - launches[j - 1].end_time() for j in range(1, len(launches)))
        else:
            op_overheads['t5'].append(0)

        # T4 is launch-type-dependent
        for x in launches:
            overheads['independent']['t4'].setdefault(x.name(), []).append(x.duration())

        if name not in launches_dict:
            launches_dict[name] = [x.name() for x in launches]
    else:
        # If an op doesn't have kernel calls it has only T5 overheads.
        # t5 was initialised above. Do not re-create it here, or samples already
        # recorded for other ops with the same name are lost.
        op_overheads['t5'].append(op.duration())

    if i == 0:
        continue
    prev_op = ops[i - 1]

    # Only consider adjacent ops under the SAME MODULE
    if prev_op.parent != op.parent:
        continue

    gap = op.start_time() - prev_op.end_time()
    if gap < max_t1_gap:  # Skip dataloading gaps
        overheads['independent']['t1'].append(gap)  # Some pairs of ops are actually inserted by a runtime call which has been filtered from ops. TODO: fix it.

# T1: mean ~= 21, std ~= 20
t1_gaps = overheads['independent']['t1']
histogram(t1_gaps, perc=False, bins=[0, 5, 10, 15, 20, 25, 30, 40, 50, 60, 70, 80, 90, 100, 200, 100000])
print(np.mean(t1_gaps), np.std(t1_gaps))

# T2, T3, T5: (mean, std) per op name, omitting names with no samples
per_kind_stats = {}
for kind in ('t2', 't3', 't5'):
    per_kind_stats[kind] = {
        op_name: (np.mean(stats[kind]), np.std(stats[kind]))
        for op_name, stats in overheads.items()
        if op_name != 'independent' and len(stats[kind]) > 0
    }
    pprint(per_kind_stats[kind])
t2, t3, t5 = per_kind_stats['t2'], per_kind_stats['t3'], per_kind_stats['t5']

# T4: (mean, std) of launch-call durations per launch type
for launch_type, durations in overheads['independent']['t4'].items():
    print(launch_type, np.mean(durations), np.std(durations))
    
# Summary written to disk: t1/t4 are name-independent, t2/t3/t5 are per op name.
independent_overheads = overheads['independent']
o = {
    "t1": (np.mean(independent_overheads['t1']), np.std(independent_overheads['t1'])),
    "t2": t2,
    "t3": t3,
    "t4": {
        launch: (np.mean(durs), np.std(durs))
        for launch, durs in independent_overheads['t4'].items()
    },
    "t5": t5,
    "launches": launches_dict,
}

overheads_path = "overheads_{}.json".format(model_name)
with open(overheads_path, "w") as f:
    json.dump(o, f)

### Multistream analysis

In [None]:
# Every device-side executor recorded for a caller/callee pair, ordered by start time.
all_kernels = sorted(
    (callee["executor"]
     for caller in cc.values()
     for callee in caller["callees"].values()
     if callee["executor"] is not None),
    key=lambda x: x.start_time())

# Sweep the kernels in start order while tracking the end of the current busy interval.
#   idle_time:  sum of gaps between busy intervals (time before the first kernel
#               and after the last one is not included).
#   overlapped: for each kernel, the part of it that overlaps the busy interval
#               formed by earlier kernels.
idle_time = 0
overlapped = 0
if all_kernels:
    last_end = all_kernels[0].start_time() + all_kernels[0].duration()
    # The first kernel opens the busy interval; it cannot overlap anything.
    for k in all_kernels[1:]:
        k_start = k.start_time()
        k_end = k_start + k.duration()
        if k_start > last_end:
            idle_time += k_start - last_end
            last_end = k_end
        else:
            # Measure the overlap against the interval *before* extending it.
            # Extending first made min(...) always equal k_end, so each kernel's
            # full duration was counted as overlap.
            overlapped += min(last_end, k_end) - k_start
            last_end = max(last_end, k_end)

print("device_runtime", device_runtime)
print("idle_time:", idle_time)
print("overlapped_time", overlapped)