In [None]:
import pandas as pd
import numpy as np
import math, sys
from matplotlib import pyplot as plt
from IPython.core.display import display, HTML
from pprint import pprint
from utils import histogram, PM_HOME
sys.path.insert(0, PM_HOME)
from analysis.inference import *
# Notebook display setup: widen the Jupyter container to 90% of the window,
# then lift pandas' row/column truncation (None = no limit) for both axes.
display(HTML("<style>.container { width:90% !important; }</style>"))
pd.set_option('display.max_columns', None, 'display.max_rows', None)

### Concat

In [None]:
# Concat: take the error component of infer_concat()'s 2-tuple result and
# summarize it with utils.histogram.
(_, error) = infer_concat()
_ = histogram(error)

### Memcpy

In [None]:
# Memcpy: take the error component of infer_memcpy()'s 2-tuple result and
# summarize it with utils.histogram.
(_, error) = infer_memcpy()
_ = histogram(error)

### Transpose

In [None]:
# Transpose: model-based inference for the "transpose" op; histogram the errors.
(_, error) = infer_from_model(op_type="transpose")
_ = histogram(error)

### Tril

In [None]:
# Tril: one error histogram with backward=False, then one with backward=True.
for backward in (False, True):
    _, error = infer_from_model(op_type="tril", backward=backward)
    _ = histogram(error)

### Embedding Lookup Forward

In [None]:
# all sizes (forward): one error histogram without hit-rate estimation,
# then one with it.
for hit_rate_estimation in (False, True):
    _, error = infer_el(backward=False, big=False, hit_rate_estimation=hit_rate_estimation)
    _ = histogram(error)

In [None]:
# big sizes (forward): one error histogram without hit-rate estimation,
# then one with it. This matches the other embedding-lookup cells, which all
# evaluate hit_rate_estimation=False and hit_rate_estimation=True. Looping over
# both values keeps the pair from drifting into two identical runs.
for hit_rate_estimation in (False, True):
    _, error = infer_el(backward=False, big=True, hit_rate_estimation=hit_rate_estimation)
    _ = histogram(error)

### Embedding Lookup Backward

In [None]:
# all sizes (backward): one error histogram without hit-rate estimation,
# then one with it.
for hit_rate_estimation in (False, True):
    _, error = infer_el(backward=True, big=False, hit_rate_estimation=hit_rate_estimation)
    _ = histogram(error)

In [None]:
# big sizes (backward): one error histogram without hit-rate estimation,
# then one with it.
for hit_rate_estimation in (False, True):
    _, error = infer_el(backward=True, big=True, hit_rate_estimation=hit_rate_estimation)
    _ = histogram(error)

### Fully Connected

In [None]:
# Fully connected: model-based inference for the "fully_connected" op;
# histogram the errors.
(_, error) = infer_from_model(op_type="fully_connected")
_ = histogram(error)

### Legacy Code

In [None]:
# t = max(t_Mem, t_MemLatency, t_Compute)
# num_main_loop_iteration = K / blkK


# t_Mem = t_Prologue + (t_MainLoop_Mem + t_Epilogue) * (# CTA per SM)
#     t_Prologue = (t_DRAM_2_Regs_output) + (t_Regs_2_SMEM_output + t_SMEM_2_Regs_input) [t_GLS + t_SAS]
#         t_DRAM_2_Regs_output = LAT_DRAM + (blkM * blkN) / (BW_DRAM / # SM) --------------- [TBD]
#             LAT_DRAM --------------------------------- Look up the datasheet. [Good]
# 
#         t_Regs_2_SMEM_output = LAT_SMEM + (blkM * blkN) / (BW_SMEM) ---------------------- [TBD]
#             LAT_SMEM --------------------------------- Look up the datasheet. [Good]
# 
#         t_SMEM_2_Regs_input = (blkWM + blkWN) * blkK * (# warps) / (BW_SMEM)
#             blkWM,WN --------------------------------- [TBD]
#             (# warps) -------------------------------- [TBD]
# 
#         (# CTA) -------------------------------------- Calculated
# 
#     t_Epilogue = (blkM * blkN) / BW_DRAM ------------------------------------------------- [Good]
#     t_MainLoop_Mem = max(trf_L1, trf_L2, trf_DRAM) * num_main_loop_iteration --------------[Good]
#         trf_L1,L2,DRAM ------------------------------- Calculated [TODO]


# t_DRAM_LAT = t_Prologue + (t_GLS + max(t_CS / blkK, t_SAS / blkK) * num_main_loop_iteration + t_Epilogue) * (# CTA per SM) / (# Active CTA)
#     t_Prologue: [same as above]
#     t_CS = (blkM * blkN * blkK) / BW_MAC
#     t_SAS = (LAT_SMEM + (blkM * blkN) + (blkWM + blkWN) * blkK * (# warps)) / (BW_SMEM)
#     t_Epilogue: [same as above]

# t_Compute_or_SMEM = t_Prologue + (t_MainLoop_Compute + t_Epilogue) * (# CTA per SM)
#     t_Prologue: [same as above]
#     t_MainLoop_Compute = max(t_CS, t_SAS) * num_main_loop_iteration
#         t_CS: [same as above]
#         t_SAS: [same as above]
#     t_Epilogue: [same as above]


# blkM,N,K --------------------------------------------- Predicted from shape [TODO]
# (# SM) ----------------------------------------------- Given
# BW_L1/L2/DRAM/MAC ------------------------------------ Measured
# (# CTA per SM) = (# CTA / # SM) ---------------------- [Good]
# (# Active CTA per SM) -------------------------------- Predicted from kernel names and shapes [TODO]

In [None]:
# # Decision tree
# from sklearn import tree
# from sklearn.preprocessing import LabelEncoder
# clf = tree.DecisionTreeClassifier()
# lb = LabelEncoder()

# name_and_size = fc_data[(fc_data['kernel_name'] != 'gemv2T_kernel') & (fc_data['kernel_name'] != 'splitKreduce_kernel')][['batch_size', 'M', 'N', 'K', 'kernel_name']]
# name_and_size["kernel_name_code"] = lb.fit_transform(name_and_size["kernel_name"])
# X = name_and_size[['batch_size', 'M', 'N', 'K']]
# y = name_and_size['kernel_name_code']
# clf = clf.fit(X, y)
# # plt.figure(figsize=(12, 8))
# # tree.plot_tree(clf)
# # plt.show()

# def predict_kernel_name(*shape):
#     n = clf.predict([*shape])
#     return list(lb.inverse_transform(n))[0]
# predict_kernel_name([1, 64, 64, 4096])

In [None]:
# def get_blkMNK_from_kernel_name(kernel_name):
#     splitted = kernel_name.split('_')
#     x = splitted[2]
#     blkN = int(x.split('x')[0])
#     blkM = int(x.split('x')[1])
#     # blkK is either 4 or 8. Assuming 4 for (<= 64) and 8 for others. No evidence. TODO: Confirm this.
#     if blkM <= 64:
#         blkK = 4
#     else:
#         blkK = 8
#     return blkM, blkN, blkK

# # batch_size, M, N, K = 1, 64, 64, 4096
# # Assuming blkM,N,K are known. Find someone to confirm.
# # A potential idea: use B, M, N, K to predict the kernel to be used.
# def get_blks(batch_size, M, N, K, clf=None):
#     if clf is None:
#         row = trunc[(trunc['batch_size'] == batch_size) & (trunc['M'] == M) & (trunc['N'] == N) & (trunc['K'] == K)]
#         if row.empty:
#             return -1, -1, -1, -1, -1, -1
#         kernel_name = str(row['kernel_name'])
#     else:
#         kernel_name = list(lb.inverse_transform(clf.predict([[batch_size, M, N, K]])))[0]
#     blkM, blkN, blkK = get_blkMNK_from_kernel_name(kernel_name)
        
#     # Confirmed
#     block_x = div_round_up(N, blkN)
#     block_y = div_round_up(M, blkM)
    
#     # # Using batch_size as block_z for now. Usually works for BMM.
#     # block_z = batch_size if batch_size > 1 else int(row['block_z'])
#     block_z = batch_size

#     return blkM, blkN, blkK, block_x, block_y, block_z

# # Verify blkMNK and block_xyz correctness. TODO: Need a test set for prediction of blkM and blkN.
# predicted_blks = bmm_data.apply(lambda x: get_blks(x['batch_size'], x['M'], x['N'], x['K'], clf), axis=1)
# predicted_blks = pd.DataFrame(predicted_blks.tolist(), index=predicted_blks.index, columns =['blkM', 'blkN', 'blkK', 'block_x', 'block_y', 'block_z'])
# actual_blkMNK = bmm_data.apply(lambda x: get_blkMNK_from_kernel_name(x['kernel_name']), axis=1)
# actual_blkMNK = pd.DataFrame(actual_blkMNK.tolist(), index=actual_blkMNK.index, columns =['blkM', 'blkN', 'blkK'])

# block_x_error = sum(predicted_blks['block_x'] != bmm_data['block_x']) / len(predicted_blks) * 100.0
# block_y_error = sum(predicted_blks['block_y'] != bmm_data['block_y']) / len(predicted_blks) * 100.0
# block_z_error = sum(predicted_blks['block_z'] != bmm_data['block_z']) / len(predicted_blks) * 100.0

# blkM_error = sum(predicted_blks['blkM'] != actual_blkMNK['blkM']) / len(predicted_blks) * 100.0
# blkN_error = sum(predicted_blks['blkN'] != actual_blkMNK['blkN']) / len(predicted_blks) * 100.0
# blkK_error = sum(predicted_blks['blkK'] != actual_blkMNK['blkK']) / len(predicted_blks) * 100.0

# print("Block xyz error rates: {}%, {}%, {}%".format(block_x_error, block_y_error, block_z_error))
# print("blk xyz error rates: {}%, {}%, {}%".format(blkM_error, blkN_error, blkK_error))

In [None]:
# # kernels and regs/smem usages are mapped one by one
# # Occupancy (CTA per SM): warps per SM, blocks per SM, regs per SM, SMEM per SM
# # bmm: only 128x64 and 128x128 kernels
# # addmm: includes all others

# info = {}
# # for kernel in fc_data['kernel_name'].unique():
# #     print(kernel)
# #     df = fc_data[fc_data['kernel_name'] == kernel]
# #     print('df length:', len(df))
# #     print(df['thread_x'].unique(), df['regs'].unique(), df['smem'].unique())
# # print("************")

# for kernel in fc_data['kernel_name'].unique():
#     df = fc_data[fc_data['kernel_name'] == kernel]
#     num_thread_x = int(df['thread_x'].unique()[0])
#     num_warps_per_CTA = num_thread_x // 32
#     num_regs_per_thread = int(df['regs'].unique()[0])
#     SMEM_per_CTA = int(df['smem'].unique()[0])
#     blkM, blkN, _ = get_blkMNK_from_kernel_name(kernel)
#     tmp = num_warps_per_CTA
#     if kernel.endswith('sliced1x4_tn'):
#         tmp = tmp // 4
#     blkWM, blkWN = blkM, blkN
#     while tmp > 1:
#         if blkWM > blkWN:
#             blkWM = blkWM // 2
#         else:
#             blkWN = blkWN // 2
#         tmp = tmp // 2
#     tmp = {}
#     tmp['num_warps_per_CTA'] = num_warps_per_CTA
#     tmp['num_active_CTA'] = min(maximum_warps_per_SM // num_warps_per_CTA, maximum_CTA_per_SM, regs_size // (num_thread_x * num_regs_per_thread), SMEM_size // SMEM_per_CTA)
#     tmp['blkWM'] = blkWM
#     tmp['blkWN'] = blkWN
#     info[kernel] = tmp

# #     print(kernel)
# #     occ = df['achieved_occupancy']
# #     histogram(occ / (tmp['num_active_CTA'] / maximum_CTA_per_SM), perc=False, bins=[0.0, 0.5, 0.8, 0.9, 1.0, 1.1, 1.2, 1.5, 2.0])
# #     print(len(occ))

# # occ = fc_data['achieved_occupancy']
# # num_active_CTA_df = fc_data['kernel_name'].map(lambda x: info[x]['num_active_CTA'])
# # histogram(occ / (num_active_CTA_df / maximum_CTA_per_SM), perc=False, bins=[0.0, 0.5, 0.8, 0.9, 1.0, 1.1, 1.2, 1.5, 2.0])
# # print(len(occ))
    
# pprint(info)

In [None]:
# # L2 traffic: almost never bound on this. Set to 0. Evidence: L2 utilization has never been higher than 2.
# L2_traffic = 0

# def get_DRAM_traffic(batch_size, M, N, K, blkM, blkN, num_CTA, num_active_CTA):
#     # Num of waves
#     num_waves = div_round_up(num_CTA // batch_size, num_SM * num_active_CTA) # Num of CTA waves per batch

#     A = M * K * 4
#     B = N * K * 4

#     num_block_per_A_col = div_round_up(M, blkM) # Num of CTA block in a block column of A
#     num_block_per_B_row = div_round_up(N, blkN) # Num of CTA block in a block row of B

#     # Possibility 1: Along rows
#     if num_SM * num_active_CTA < num_block_per_B_row:
#         DRAM_traffic_row = (B * num_block_per_A_col + A + blkM * num_waves * K * 4)
#     else:
#         DRAM_traffic_row = (B * num_waves + A + blkM * num_waves * K * 4)

#     # Possibility 2: Along columns
#     if num_SM * num_active_CTA < num_block_per_A_col: # One CTA batch doesn't fill a full column of A
#         DRAM_traffic_col = (A * num_block_per_B_row + B + blkN * num_waves * K * 4)
#     else:
#         DRAM_traffic_col = (A * num_waves + B + blkN * num_waves * K * 4)

#     # Take the min of them
#     DRAM_traffic = batch_size * min(DRAM_traffic_row, DRAM_traffic_col)

#     return DRAM_traffic
# # get_DRAM_traffic(batch_size=1, M=64, N=64, K=4096, blkM=blkM, blkN=blkN, num_CTA=4)
# # dram_trf = get_DRAM_traffic(batch_size=512, M=1024, N=512, K=4096, blkM=64, blkN=128, num_CTA=32768)
# # df = bmm_data[(bmm_data['batch_size'] == 512) & (bmm_data['M'] == 1024) & (bmm_data['N'] == 512) & (bmm_data['K'] == 4096)]
# # print(dram_trf / float(df['dram_read_transactions']))

In [None]:
# def estimate_runtime(batch_size, M, N, K, clf=None, use_th_peak=True):
#     kernel_name = predict_kernel_name([batch_size, M, N, K])
#     blkM, blkN, blkK, block_x, block_y, block_z = get_blks(batch_size, M, N, K, clf=clf) # TODO: fix block_z for batch_size = 1
#     blkWM, blkWN = info[kernel_name]['blkWM'], info[kernel_name]['blkWN']
#     num_CTA = block_x * block_y * block_z
#     throughput = peak_throughput if use_th_peak else corrected_peak_throughput
#     num_main_loop_iteration = div_round_up(num_CTA, num_SM)
#     num_warps = info[kernel_name]['num_warps_per_CTA']
#     num_active_CTA = info[kernel_name]['num_active_CTA']
#     DRAM_traffic = get_DRAM_traffic(batch_size, M, N, K, blkM, blkN, num_CTA, num_active_CTA) # Total

#     LAT_DRAM = 1029 / frequency # See https://arxiv.org/pdf/1804.06826.pdf page 20
#     LAT_SMEM = 19 / frequency # See https://arxiv.org/pdf/1804.06826.pdf page 19
# #     t_DRAM_2_Regs_output = LAT_DRAM + (DRAM_traffic / (block_x * block_y * block_z)) * 4 / (peak_DRAM_BW / num_SM) / 1000
# #     t_Regs_2_SMEM_output = LAT_SMEM + (blkM * blkN) * 4 / (peak_SMEM_BW) / 1000 # Good
# #     t_SMEM_2_Regs_input = (blkWM + blkWN) * blkK * 4 * (num_warps) / peak_SMEM_BW / 1000 # Good

# #     t_GLS = t_DRAM_2_Regs_output + t_Regs_2_SMEM_output
#     t_SAS = (blkWM + blkWN) * blkK * 4 * (num_warps) / peak_SMEM_BW / 1000
#     t_CS = blkM * blkN * blkK / throughput / 1000

#     t_Prologue = (LAT_DRAM + (blkM * blkN) * 4 / (peak_DRAM_BW / num_SM) / 1000) + \
#                     (LAT_SMEM + (blkM * blkN) * 4 / (peak_SMEM_BW / 2) / 1000) + \
#                     ((blkWM + blkWN) * blkK * 4 * (num_warps) / (peak_SMEM_BW / 2) / 1000)
#     t_Epilogue = (blkM * blkN) * 4 / peak_DRAM_BW / 1000

#     # Bounded by CS or SMEM
#     t_MainLoop_Compute = max(t_CS, t_SAS) * div_round_up(K, blkK) # Good
#     t_Compute_or_SMEM = t_Prologue + (t_MainLoop_Compute + t_Epilogue) * num_main_loop_iteration # Good

#     # Bounded by latency
#     t_MainLoop_Latency = max(t_CS, t_SAS) * num_active_CTA * div_round_up(K, blkK)
#     t_Latency = t_Prologue + (DRAM_traffic / peak_DRAM_BW / 1000 + (t_MainLoop_Latency + t_Epilogue) * num_main_loop_iteration) / num_active_CTA

#     # Bounded by memory traffic (only considering DRAM for now)
#     t_MainLoop_Mem = max(L2_traffic / peak_L2_BW / 1000, DRAM_traffic / peak_DRAM_BW / 1000)
#     t_Mem = t_Prologue + t_MainLoop_Mem + t_Epilogue * num_main_loop_iteration

#     # Final time
#     t_final = max(t_Compute_or_SMEM, t_Latency, t_Mem)
    
#     # Bound factor
#     if t_final == t_Compute_or_SMEM:
#         bound = 'compute'
#     elif t_final == t_Latency:
#         bound = 'latency'
#     else:
#         bound = 'memory'
    
#     return t_final, bound, DRAM_traffic
# estimate_runtime(batch_size=1, M=64, N=64, K=4096, clf=clf, use_th_peak=True)