In [1]:
from __future__ import absolute_import, division, print_function, unicode_literals
from pprint import pprint
from IPython.core.display import display, HTML
from scipy.stats.mstats import gmean 
import argparse, logging, tempfile, json, sys, os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

# Notebook-wide display configuration.
pd.set_option('display.max_rows', None) # No row limit when a DataFrame is displayed
pd.set_option('display.max_columns', 100) # Show up to 100 columns
plt.rcParams['figure.max_open_warning'] = 100 # Only warn once more than 100 figures are open
pd.options.mode.chained_assignment = None # Silence pandas' chained-assignment (SettingWithCopy) warning

# Let the notebook container use the full browser width.
display(HTML("<style>.container { width:100% !important; }</style>"))
from trace_utils import *
from exec_graph_utils import *
from ml_predictors.mlp import get_pretrained_net, inference

In [2]:
# Utility functions
def abs_err(pred, real):
    """Return the absolute relative error |pred - real| / |real|.

    Works element-wise on anything that supports arithmetic and abs()
    (scalars, numpy arrays, pandas Series).
    """
    relative = (pred - real) / real
    return abs(relative)

def err(pred, real):
    """Return the signed relative error (pred - real) / real.

    Positive values mean over-prediction, negative values under-prediction.
    """
    deviation = pred - real
    return deviation / real

def gmae(x):
    """Return the geometric mean of the absolute values of `x`.

    Computed as exp(mean(log(|x|))); `x` must support abs() and the
    result of np.log must expose .mean() (numpy array or pandas Series).
    """
    log_magnitudes = np.log(abs(x))
    mean_log = log_magnitudes.mean()
    return np.exp(mean_log)

def histogram(df, perc=True, bins=[0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.4, 0.5, 0.6, 0.8, 1.0, 1.5, 2.0, 3.0, 4.0]):
    """Count the values of `df` per bin and print the share of each bin.

    Args:
        df: Iterable of numbers with a length (list, numpy array, pandas Series).
        perc: If True the bin edges are printed as percentages (edge * 100),
            otherwise they are printed as plain numbers.
        bins: Ascending bin edges. Bin i is the half-open interval
            [bins[i], bins[i+1]). The default list is never mutated.

    Returns:
        dict mapping (lower_edge, upper_edge) -> number of values in that bin.
        Values that fall outside every bin (including NaN) are not counted,
        although they still contribute to the denominator of the printed share.
    """
    count = len(df)
    # Build the (lower, upper) edge pairs once instead of once per value.
    edges = [(bins[idx - 1], bins[idx]) for idx in range(1, len(bins))]
    ret = {edge: 0 for edge in edges}

    for x in df:
        for lower, upper in edges:
            if lower <= x < upper:
                ret[(lower, upper)] += 1
                break

    for b, c in sorted(ret.items(), key=lambda item: item[0]):
        # An empty input used to raise ZeroDivisionError; report 0% instead.
        share = c / count * 100 if count else 0.0
        if perc:
            print("{:.0f}% - {:.0f}%: {:.2f}%".format(b[0] * 100, b[1] * 100, share))
        else:
            print("{:.2f} - {:.2f}: {:.2f}%".format(b[0], b[1], share))
    return ret

def strip_unit(x):
    """Convert the throughput columns of a row from '<number><unit>' strings to GB/s floats.

    Recognized units are GB/s, MB/s, KB/s and B/s. Columns that are missing
    from the row are skipped, and values that are not strings (e.g. already
    converted by a previous run of this function) are left untouched.

    Args:
        x: A mapping-like row (dict or pandas Series) that supports .keys()
            and item assignment. It is modified in place.

    Returns:
        The same row object `x`.

    Raises:
        Exception: If a string value does not end in one of the known units.
    """
    throughput_cols = ('dram_read_throughput', 'dram_write_throughput',
                       'gld_requested_throughput', 'gld_throughput',
                       'gst_requested_throughput', 'gst_throughput',
                       'l2_read_throughput', 'l2_write_throughput',
                       'shared_load_throughput', 'shared_store_throughput')
    # (suffix, divisor that converts the number to GB/s). Plain 'B/s' must come
    # last because every other unit also ends in 'B/s'.
    units = (('GB/s', 1.0), ('MB/s', 1e3), ('KB/s', 1e6), ('B/s', 1e9))

    for col in throughput_cols:
        if col not in x.keys():
            continue
        value = x[col]
        if not hasattr(value, 'endswith'):
            # Not a string: nothing to strip.
            continue
        for suffix, divisor in units:
            if value.endswith(suffix):
                # Slice the suffix off explicitly; str.rstrip() removes a *set*
                # of characters, which is why 'KB/s' used to end up as '<n>K'.
                x[col] = float(value[:-len(suffix)]) / divisor
                break
        else:
            raise Exception("Unrecognizable unit!")
    return x
    
def p2f(x):
    """Turn the percentage columns of a row into fractions (e.g. '50%' -> 0.5).

    Columns missing from the row are skipped. The row is modified in place
    and returned.
    """
    percentage_cols = ('flop_dp_efficiency', 'flop_sp_efficiency', 'gld_efficiency',
                       'gst_efficiency', 'shared_efficiency', 'sm_efficiency',
                       'warp_execution_efficiency')
    for name in percentage_cols:
        if name not in x.keys():
            continue
        # str() first so that values that are already numeric are accepted too.
        number_text = str(x[name]).rstrip('%')
        x[name] = float(number_text) / 100.0
    return x

def strip_parenthesis(x):
    """Remove surrounding parentheses from the utilization columns of a row.

    Leading/trailing '(' characters are stripped first, then leading/trailing
    ')' characters. Columns missing from the row are skipped. The row is
    modified in place and returned.
    """
    for name in ('dram_utilization', 'l2_utilization', 'tex_utilization'):
        if name not in x.keys():
            continue
        without_open = x[name].strip('(')
        x[name] = without_open.strip(')')
    return x

def process_smem(x):
    """Convert the 'smem' column of a row from a '<number><unit>' string to bytes.

    Supported units are MB, KB and B (1 KB = 1024 B). Rows without an 'smem'
    column are returned unchanged. The row is modified in place and returned.

    Raises:
        Exception: If the value does not end in a supported unit.
    """
    if 'smem' not in x.keys():
        return x

    raw = x['smem']
    if raw.endswith('MB'):
        megabytes = float(raw.rstrip('MB'))
        x['smem'] = int(megabytes * 1024 * 1024)
    elif raw.endswith('KB'):
        kilobytes = float(raw.rstrip('KB'))
        x['smem'] = int(kilobytes * 1024)
    elif raw.endswith('B'):
        x['smem'] = int(raw.rstrip('B'))
    else:
        raise Exception("Unrecognizable unit!")
    return x
        
def preprocessing(df):
    """Normalize a raw profiling DataFrame and drop the auxiliary GEMM kernels.

    Every row is passed, in order, through p2f, strip_unit, strip_parenthesis
    and process_smem. Rows whose 'kernel_name' is 'gemv2T_kernel' or
    'splitKreduce_kernel' are then removed.

    Returns:
        The transformed and filtered DataFrame.
    """
    for row_transform in (p2f, strip_unit, strip_parenthesis, process_smem):
        df = df.apply(func=row_transform, axis=1)

    kernel = df['kernel_name']
    keep = (kernel != 'gemv2T_kernel') & (kernel != 'splitKreduce_kernel')
    return df[keep]

def div_round_up(x, y):
    """Return ceil(x / y) as an int, for a positive divisor `y`.

    Floor division keeps the computation in integer arithmetic, so the result
    stays exact for integers beyond 2**53 (going through true division rounds
    such values to the nearest float first).

    Args:
        x: Dividend.
        y: Divisor, expected to be positive.

    Returns:
        int: The smallest integer n such that n * y >= x (for integer inputs).
    """
    return int((x + y - 1) // y)

In [3]:
# Model whose execution graph and overhead statistics are analysed.
# Alternative: model_name = "MLPerf_1"
model_name = "DLRM_default_1"

# Both input files are derived from the model name.
exec_graph_file = "../data/{}_graph.json".format(model_name)
overheads_file = "./overheads_{}.json".format(model_name)

with open(exec_graph_file) as f:
    graph_json = json.load(f)
graph = ExecutionGraph(graph_json)

with open(overheads_file) as f:
    overheads = json.load(f)

In [4]:
# Print a per-op summary of the execution graph (op name, count, unique inputs).
graph.print_op_stats(detail=False, clean=True, json_format=False)

### OP STATS ###
op: AddmmBackward
  count: 6
  unique inputs:
  input: []
op: BmmBackward0
  count: 1
  unique inputs:
  input: []
op: CatBackward
  count: 2
  unique inputs:
  input: []
op: IndexBackward
  count: 1
  unique inputs:
  input: []
op: LookupFunction
  count: 1
  unique inputs:
  input: []
op: LookupFunctionBackward
  count: 1
  unique inputs:
  input: []
op: MseLossBackward
  count: 1
  unique inputs:
  input: []
op: Optimizer.step#SGD.step
  count: 1
  unique inputs:
  input: []
op: Optimizer.zero_grad#SGD.zero_grad
  count: 1
  unique inputs:
  input: []
op: ReluBackward0
  count: 5
  unique inputs:
  input: []
op: SigmoidBackward
  count: 1
  unique inputs:
  input: []
op: SliceBackward
  count: 1
  unique inputs:
  input: []
op: TBackward
  count: 6
  unique inputs:
  input: []
op: TransposeBackward0
  count: 1
  unique inputs:
  input: []
op: ViewBackward
  count: 1
  unique inputs:
  input: []
op: aten::add
  count: 2
  unique inputs:
  input: [{'dtype': 'float', '

### Performance models

In [5]:
# Hardware parameters used by the analytical predictors below.
# NOTE(review): L2_size evaluates to 25165824 and is compared against table sizes
# computed as num_embeddings * embedding_dim * 4, so it is presumably in bytes;
# the reason for the trailing factor of 4 is not stated here - confirm.
L2_size = 6 * 1024 * 1024 * 4
num_SM = 80 # Number of SMs on the device
peak_dram_bw = 809 # GB/s
peak_l2_bw = 2888 # GB/s
peak_throughput = 12200 # GFLOPS

In [6]:
def embedding_forward_predictor(peak_dram_bw, peak_l2_bw, L2_size, num_SM, **kwargs):
    """Estimate the runtime of the batched embedding-lookup forward kernel.

    The model splits the memory traffic of the kernel into a DRAM part and an
    L2 part, using an estimated L2 hit rate for the table rows, and returns
    the slower of the two transfer times.

    Args:
        peak_dram_bw: Peak DRAM bandwidth.
        peak_l2_bw: Peak L2 bandwidth.
        L2_size: L2 cache capacity, in the same unit as the table sizes
            (4 * num_embeddings * embedding_dim).
        num_SM: Number of SMs on the device.
        **kwargs: Must contain batch_size, num_tables, rows_per_block,
            num_embeddings, embedding_dim and bag_size.

    Returns:
        max(dram_traffic / peak_dram_bw, l2_traffic / peak_l2_bw) / 1000.
        With traffic in bytes and bandwidth in GB/s this is presumably
        microseconds - TODO confirm.
    """
    B = kwargs["batch_size"]
    T = kwargs["num_tables"]
    E = kwargs["num_embeddings"]
    D = kwargs["embedding_dim"]
    L = kwargs["bag_size"]

    # How much of each table can reside in L2 at once.
    warps_in_flight = num_SM * kwargs["rows_per_block"]  # warps running simultaneously on the device
    tables_in_flight = (warps_in_flight + B - 1) // B    # tables being accessed simultaneously
    avg_table_size = min(L2_size // tables_in_flight, E * D * 4)
    rows_cached = avg_table_size // 4 // D

    # Hit rate = C(rows_cached, L) / C(E, L), evaluated as a running product.
    hr = 1.0
    for drawn in range(L):
        hr *= (rows_cached - drawn) / (E - drawn)

    # Per-sample traffic, rounded up to 32-byte transactions where needed.
    row_traffic = div_round_up(D * 4, 32) * 32
    table_offsets_traffic = 32
    offsets_traffic = 32
    indices_dram_traffic = div_round_up(L * 4, 32) * 32
    table_traffic = L * row_traffic
    output_traffic = row_traffic

    # Table traffic beyond the first fill of the cached part of the table;
    # it is split between L2 (hits) and DRAM (misses) by the hit rate.
    repeated_table_traffic = table_traffic * B - avg_table_size

    total_l2_traffic = ((table_offsets_traffic + offsets_traffic) * B + \
                        hr * repeated_table_traffic) * T
    # The cached part of each table (avg_table_size) is counted as DRAM traffic once.
    total_dram_traffic = ((indices_dram_traffic + output_traffic) * B + \
                          (1 - hr) * repeated_table_traffic + avg_table_size) * T

    dram_time = total_dram_traffic / peak_dram_bw / 1000.0
    l2_time = total_l2_traffic / peak_l2_bw / 1000.0
    return max(dram_time, l2_time)
        
# Sanity check on a single configuration: the model predicts ~4340 (see the output below);
# 4789 is presumably the measured kernel time for the same configuration - confirm.
embedding_forward_predictor(peak_dram_bw, peak_l2_bw, L2_size, num_SM, batch_size=4096, num_embeddings=500000, num_tables=197, bag_size=32, embedding_dim=32, rows_per_block=128)

4340.7676440049445

In [7]:
# Load the embedding-lookup measurements and normalize units/percentages.
elf_raw = pd.read_csv('../data/embedding_lookup_1_shmem.csv', delimiter=',')
elf_data = preprocessing(elf_raw)

# Keep only the batched embedding kernels ...
is_embedding_kernel = elf_data["kernel_name"].str.contains("batched_embedding")
elf_data = elf_data[is_embedding_kernel]

# ... that were run with a batch size larger than one.
elf_data = elf_data[elf_data['batch_size'] > 1]

In [8]:
def _predict_elf_time(row):
    """Run the forward predictor on one measurement row.

    The row entries at positions 1..6 are forwarded as keyword arguments
    (presumably the six kernel parameters the predictor reads - verify
    against the CSV column order).
    """
    return embedding_forward_predictor(peak_dram_bw, peak_l2_bw, L2_size, num_SM, **row[1:7])


def _report_error(error, header):
    """Print the error histogram, then the header line and the summary statistics."""
    histogram(error)
    print(header)
    print("GMAE: {:.2f}%, mean: {:.2f}%, std: {:.2f}%".format(gmae(error) * 100.0, error.mean() * 100.0, error.std() * 100.0))


# Prediction error over every measured configuration.
elf_time_all = elf_data.apply(_predict_elf_time, axis=1)
error_all = abs_err(elf_time_all, elf_data['kernel_runtime'])
_report_error(error_all, "==== All sizes ====")

# Prediction error restricted to large tables (>= 100000 rows).
elf_data_big = elf_data[elf_data['num_embeddings'] >= 100000]
elf_time_big = elf_data_big.apply(_predict_elf_time, axis=1)
error_big = abs_err(elf_time_big, elf_data_big['kernel_runtime'])
_report_error(error_big, "==== Big sizes ====")

0% - 5%: 25.29%
5% - 10%: 14.39%
10% - 15%: 8.97%
15% - 20%: 9.42%
20% - 25%: 9.71%
25% - 30%: 6.02%
30% - 40%: 7.04%
40% - 50%: 5.28%
50% - 60%: 4.06%
60% - 80%: 9.28%
80% - 100%: 0.43%
100% - 150%: 0.12%
150% - 200%: 0.00%
200% - 300%: 0.00%
300% - 400%: 0.00%
==== All sizes ====
GMAE: 11.66%, mean: 21.98%, std: 21.00%
0% - 5%: 27.79%
5% - 10%: 21.77%
10% - 15%: 10.75%
15% - 20%: 10.89%
20% - 25%: 16.29%
25% - 30%: 5.75%
30% - 40%: 5.61%
40% - 50%: 1.15%
50% - 60%: 0.00%
60% - 80%: 0.00%
80% - 100%: 0.00%
100% - 150%: 0.00%
150% - 200%: 0.00%
200% - 300%: 0.00%
300% - 400%: 0.00%
==== Big sizes ====
GMAE: 8.68%, mean: 13.13%, std: 9.99%


In [9]:
def embedding_backward_sgd_predictor(peak_dram_bw, **kwargs):
    """Estimate the runtime of the embedding backward (SGD) kernel.

    The kernel is modelled as either DRAM-bandwidth bound or compute bound,
    whichever takes longer. The compute bound uses the module-level
    `peak_throughput`.

    Args:
        peak_dram_bw: Peak DRAM bandwidth.
        **kwargs: Must contain batch_size, num_tables, bag_size and
            embedding_dim; any other keys are ignored.

    Returns:
        max(traffic / peak_dram_bw, mac / peak_throughput) / 1000. With traffic
        in bytes and bandwidth in GB/s this is presumably microseconds - TODO confirm.
    """
    B = kwargs["batch_size"]
    T = kwargs["num_tables"]
    L = kwargs["bag_size"]
    D = kwargs["embedding_dim"]

    # Sizes rounded up to whole 32-byte transactions.
    indices_traffic = div_round_up(L * 4, 32) * 32
    row_traffic = div_round_up(D * 4, 32) * 32
    grad_output_traffic = row_traffic

    # Traffic per warp = t_offsets + t_table_offsets + t_indices + t_weights + t_grad_outputs
    offsets_traffic = 32
    table_offsets_traffic = 64
    weights_traffic = 2 * L * row_traffic
    total_traffic_per_warp = offsets_traffic + table_offsets_traffic + \
                             indices_traffic + weights_traffic + grad_output_traffic

    # One warp per (sample, table) pair.
    num_warps = B * T
    total_traffic = num_warps * total_traffic_per_warp

    # Compute work
    mac_per_warp = L * 4 * (D // 4)
    total_mac = num_warps * mac_per_warp

    memory_time = total_traffic / peak_dram_bw / 1000
    compute_time = total_mac / peak_throughput / 1000
    return max(memory_time, compute_time)

# Sanity check on a single configuration: the model predicts ~42834 (see the output below);
# 44601 is presumably the measured kernel time for the same configuration - confirm.
# num_embeddings, rows_per_block and shmem are accepted via **kwargs but not read by the predictor.
embedding_backward_sgd_predictor(peak_dram_bw, batch_size=2048, num_embeddings=200000, num_tables=128, bag_size=128, embedding_dim=128, rows_per_block=32, shmem=True)

42834.783248454885

In [10]:
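# Number of rows of each embedding table in the MLPerf_1 configuration
# (the same values are hard-coded as `Es` in get_kernel_time):
# [14885288    29419    15123     7291    19899        3     6463     1310
#        61 10155909   618195   218994       10     2208     9779       71
#         4      963       14 16967044  4154705 13180313   289595    10828
#        95       34]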

In [11]:
def get_kernel_time(op, addmm_list, bmm_list, tril_list):
    """Return the predicted GPU kernel times of a single op.

    Forward ops push themselves onto the matching stack so that the
    corresponding backward op, seen later in the trace, can pop them and
    reuse their input shapes. The three lists are therefore mutated and must
    be shared across calls.

    Args:
        op: Graph node; `name`, `children` and `input_shapes` are read.
        addmm_list: Stack of addmm nodes seen in the forward pass.
        bmm_list: Stack of aten::bmm nodes seen in the forward pass.
        tril_list: Stack of aten::index nodes seen in the forward pass.

    Returns:
        list: One predicted time per modelled kernel, in launch order. The list
        is empty for ops that are not modelled. Times come from `inference`
        (pretrained predictors) or from the analytical embedding predictors -
        presumably all in the same unit; TODO confirm.
    """
    kernel_times = []
    if op.name == "aten::linear":
        # aten::linear is expanded into its children: a transpose and an addmm.
        for child in op.children:
            if "aten::t" in child.name:
                M, N = child.input_shapes[0][0], child.input_shapes[0][1]
                t = inference("transpose", "1-{}-{}".format(M, N))
                kernel_times.append(t)
            else: # addmm (every child that is not a transpose is treated as addmm)
                addmm_list.append(child) # Remembered for the matching AddmmBackward
                M, K, N = child.input_shapes[1][0], child.input_shapes[1][1], child.input_shapes[2][1]
                t = inference("fully_connected", "1-{}-{}-{}".format(M, N, K))
                kernel_times.append(t)
    if op.name == "AddmmBackward":
        addmm_op = addmm_list.pop() # Most recent forward addmm (LIFO)
        M, K, N = addmm_op.input_shapes[1][0], addmm_op.input_shapes[1][1], addmm_op.input_shapes[2][1]
        # Shapes of the two GEMMs of the backward pass.
        m1, k1, n1 = M, N, K
        m2, k2, n2 = N, M, K
        t1 = inference("fully_connected", "1-{}-{}-{}".format(m1, n1, k1))
        kernel_times.append(t1)
        t2 = 0
        # NOTE(review): the second GEMM is only counted when M != N, so a single
        # kernel time is returned for square problems - confirm this is intended.
        if M != N:
            t2 = inference("fully_connected", "1-{}-{}-{}".format(m2, n2, k2))
            kernel_times.append(t2)
        t = t1 + t2
    if op.name == "aten::bmm":
        bmm_list.append(op) # Remembered for the matching BmmBackward0
        batch_size, M, K, N = op.input_shapes[0][0], op.input_shapes[0][1], op.input_shapes[0][2], op.input_shapes[1][2]
        t = inference("fully_connected", "{}-{}-{}-{}".format(batch_size, M, N, K))
        kernel_times.append(t)
    if op.name == "BmmBackward0":
        bmm_op = bmm_list.pop() # Most recent forward bmm (LIFO)
        batch_size, M, K, N = bmm_op.input_shapes[0][0], bmm_op.input_shapes[0][1], bmm_op.input_shapes[0][2], bmm_op.input_shapes[1][2]
        # Shapes of the two batched GEMMs of the backward pass.
        m1, k1, n1 = N, M, K
        m2, k2, n2 = M, N, K
        t1 = inference("fully_connected", "{}-{}-{}-{}".format(batch_size, m1, n1, k1))
        t2 = inference("fully_connected", "{}-{}-{}-{}".format(batch_size, m2, n2, k2))
        kernel_times.append(t1)
        kernel_times.append(t2)
        t = t1 + t2
    if op.name == "LookupFunction":
        # The embedding configuration is hard-coded per model (selected by the global
        # model_name); for MLPerf_1 the mean table size is used as num_embeddings.
        if model_name == "MLPerf_1":
            Es = [14885288, 29419, 15123, 7291, 19899, 3, 6463, 1310, 61, 10155909, 618195, 218994, 10, 2208, 9779, 71, 4, 963, 14, 16967044, 4154705, 13180313, 289595, 10828, 95, 34] 
            B, E, T, L, D, rows_per_block = 2048, int(np.mean(Es)), len(Es), 1, 128, 2
        else:
            B, E, T, L, D, rows_per_block = 2048, 1000000, 8, 100, 64, 4
        t = embedding_forward_predictor(peak_dram_bw, peak_l2_bw, L2_size, num_SM, batch_size=B, num_embeddings=E, num_tables=T, bag_size=L, embedding_dim=D, rows_per_block=rows_per_block)
        kernel_times.append(t)
    if op.name == "LookupFunctionBackward":
        # Same hard-coded configuration as in the LookupFunction branch above.
        if model_name == "MLPerf_1":
            Es = [14885288, 29419, 15123, 7291, 19899, 3, 6463, 1310, 61, 10155909, 618195, 218994, 10, 2208, 9779, 71, 4, 963, 14, 16967044, 4154705, 13180313, 289595, 10828, 95, 34] 
            B, E, T, L, D, rows_per_block = 2048, int(np.mean(Es)), len(Es), 1, 128, 2
        else:
            B, E, T, L, D, rows_per_block = 2048, 1000000, 8, 100, 64, 4
        t = embedding_backward_sgd_predictor(peak_dram_bw, batch_size=B, num_embeddings=E, num_tables=T, bag_size=L, embedding_dim=D, rows_per_block=rows_per_block)
        kernel_times.append(t)
    if op.name == "aten::t":
        kernel_times.append(0) # T is handled under addmm
    if op.name == "aten::relu":
        pass # No kernel time is modelled for this op
    if op.name == "aten::sigmoid":
        pass # No kernel time is modelled for this op
    if op.name == "aten::add":
        pass # No kernel time is modelled for this op
    if op.name == "aten::index":
        tril_list.append(op) # Remembered for the matching IndexBackward
        batch_size, M, N = op.input_shapes[0][0], op.input_shapes[0][1], op.input_shapes[0][2]
        total_output_element = op.input_shapes[1][1][0]
        # diag = 1 when the number of indexed elements equals M * (1 + N) / 2,
        # i.e. presumably when the diagonal is included in the triangle - confirm.
        if total_output_element == int(M * (1+N) / 2):
            diag = 1
        else:
            diag = 0
        t = inference("tril", "{}-{}-{}-{}".format(batch_size, M, N, diag), backward=False)
        kernel_times.append(t)
    if op.name == "IndexBackward": # See all kernels as a whole
        tril_op = tril_list.pop() # Most recent forward aten::index (LIFO)
        batch_size, M, N = tril_op.input_shapes[0][0], tril_op.input_shapes[0][1], tril_op.input_shapes[0][2]
        total_output_element = tril_op.input_shapes[1][1][0]
        if total_output_element == int(M * (1+N) / 2):
            diag = 1
        else:
            diag = 0
        t = inference("tril", "{}-{}-{}-{}".format(batch_size, M, N, diag), backward=True)
        kernel_times.append(t)
    return kernel_times

In [12]:
# Inspect the per-op overhead statistics loaded from the overheads JSON file.
pprint(overheads)

{'launches': {'AddmmBackward': ['cudaLaunchKernel',
                                'cudaLaunchKernel',
                                'cudaLaunchKernel'],
              'BmmBackward0': ['cudaLaunchKernel', 'cudaLaunchKernel'],
              'IndexBackward': ['cudaLaunchKernel',
                                'cudaLaunchKernel',
                                'cudaLaunchKernel',
                                'cudaLaunchKernel',
                                'cudaLaunchKernel',
                                'cudaLaunchKernel',
                                'cudaLaunchKernel',
                                'cudaLaunchKernel',
                                'cudaLaunchKernel',
                                'cudaLaunchKernel',
                                'cudaLaunchKernel'],
              'LookupFunction': ['cudaLaunchKernel'],
              'LookupFunctionBackward': ['cudaLaunchKernel'],
              'MseLossBackward': ['cudaLaunchKernel', 'cudaLaunchKernel'],
       

In [13]:
# End-to-end time estimate: walk the ops in node-id order while keeping a CPU
# timeline (op and launch overheads) and a GPU timeline (predicted kernel times).
# Overhead entries are indexed with [0]; presumably that element is the mean - confirm.
nodes = graph.get_nodes(clean=True) # dict
sorted_nodes = sorted(nodes.items(), key=lambda x: x[0])
# Stacks shared with get_kernel_time so that backward ops can find their forward op.
addmm_list = []
bmm_list = []
tril_list = []
cpu_time = 0 # Current position of the CPU timeline
gpu_time = 0 # Time at which the GPU finishes the last kernel issued so far
gpu_active_time = 0 # Sum of all predicted kernel times

# Ops whose kernels are modelled one by one through get_kernel_time.
consider = ["aten::linear", "AddmmBackward", "aten::bmm", "BmmBackward0", "LookupFunction", "LookupFunctionBackward", "IndexBackward", "aten::index"]
# Ops treated as a single unit, without looking at their individual kernel launches.
whole = ["Optimizer.zero_grad#SGD.zero_grad", "Optimizer.step#SGD.step"]
# Ops that are ignored completely.
skip = ["aten::random_", "aten::item"]

# NOTE(review): the loop variable `id` shadows the builtin of the same name.
for id, op in sorted_nodes:
    if op.name in skip:
        continue
    # Only operators whose parent is not itself an operator are accounted for.
    is_op = (op.type == NodeType.OPERATOR and op.parent.type != NodeType.OPERATOR)
    if is_op:
        cpu_time += overheads["t1"][0] # T1: between two ops
        if op.name in overheads["launches"].keys(): # Has kernel calls
            cpu_time += overheads["t2"][op.name][0] # T2: before the first kernel call
            launches = overheads["launches"][op.name]
            if op.name in consider:
                t = get_kernel_time(op, addmm_list, bmm_list, tril_list) # Get kernel time

                for idx, l in enumerate(launches):
                    t4 = overheads["t4"][l][0] # Kernel launches
                    t5 = overheads["t5"][op.name][0] # Avg overhead between

                    # Contribution of CPU overheads on GPU idle time
                    gpu_time = max(gpu_time + 1, cpu_time + t4/2) # Where the kernel starts: either launch right after last kernel, or at the middle of the kernel launch
                    
                    # Launches without a predicted kernel time only cost CPU time.
                    if idx < len(t):
                        gpu_time += t[idx]
                    cpu_time += t4
                    if idx < len(launches) - 1:
                        cpu_time += t5
                    
                gpu_active_time += np.sum(t)
            else:
                if op.name in whole:
                    # Take the op as a whole without considering all its kernel calls
                    # NOTE(review): T2 was already added above and T3 is added again after
                    # this branch, so both appear to be counted twice for these ops - confirm.
                    cpu_time += overheads["t2"][op.name][0] + overheads["t3"][op.name][0] + overheads["t5"][op.name][0] * (len(launches) - 1)
                else:
                    # Only consider CPU time then: op_cpu_time = T2 + (T4 sum) + (T5 sum) + T3
                    cpu_time += overheads["t5"][op.name][0] * (len(launches) - 1) # T5
                    cpu_time += np.sum([overheads["t4"][x][0] for x in launches]) # T4
            cpu_time += overheads["t3"][op.name][0] # T3: after the first kernel call
        else:
            cpu_time += overheads["t5"][op.name][0] # Ops that have no kernel calls only have T5 overheads (total CPU overheads)
        print(op.name, cpu_time)
            
# The iteration ends when both timelines have finished.
total_time = max(gpu_time, cpu_time)
        
print("Total time is: {}".format(total_time))
print("GPU time is: {}".format(gpu_active_time))

aten::empty 17.59292929292929
aten::ones 54.18585858585858
aten::zeros 85.77878787878788
aten::empty 103.37171717171718
aten::to 360.13071789321793
aten::to 616.8897186147186
aten::to 873.6487193362193
aten::zeros 905.2416486291486
aten::empty 922.8345779220779
aten::linear 1083.3642237725373
aten::relu 1163.970511344232
aten::linear 1324.5001571946914
aten::relu 1405.106444766386
aten::zeros 1436.6993740593152
aten::empty 1454.2923033522445
aten::to 1711.0513040737453
aten::to 1967.810304795246
LookupFunction 2094.8965923669407
aten::zeros 2126.48952165987
aten::empty 2144.0824509527997
aten::view 2163.6753802457292
aten::cat 2240.611667817424
aten::transpose 2268.2045971103535
aten::bmm 2352.890884682048
aten::empty 2370.4838139749777
aten::to 2627.2428146964785
aten::detach_ 2646.835743989408
aten::empty 2664.4286732823375
aten::to 2921.1876740038383
aten::detach_ 2940.780603296768
aten::slice 2970.3735325896973
aten::to 3227.132533311198
aten::to 3483.891534032699
aten::index 3582.

# 