In [78]:
from __future__ import absolute_import, division, print_function, unicode_literals
from aiplatform.monitoring.atc import antipattern_detection, trace_utils
from pprint import pprint
from sklearn.linear_model import LinearRegression as LR
from IPython.core.display import display, HTML
from scipy.stats.mstats import gmean 
import argparse, logging, tempfile, json, sys
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

# Notebook display configuration.
# Do not truncate DataFrame rows; show up to 100 columns.
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', 100)
# Raise matplotlib's "too many open figures" warning threshold to 100.
plt.rcParams['figure.max_open_warning'] = 100
# Turn off pandas' chained-assignment warning (fc_forward_predictor assigns
# into filtered DataFrame slices).
pd.options.mode.chained_assignment = None

# Inject CSS so the notebook container uses the full browser width.
display(HTML("<style>.container { width:100% !important; }</style>"))

In [62]:
# Util functions
def div_round_up(x, y):
    """Return ``x`` divided by ``y``, rounded up (ceiling division for positive integers)."""
    # Biasing the numerator by y - 1 turns floor division into a ceiling.
    biased_numerator = x + y - 1
    return biased_numerator // y

def error_mean_std(predicted, actual):
    """Return ``(mean, std)`` of the relative error ``|predicted - actual| / actual``.

    Both arguments must support ``-``, ``/`` and the result of the subtraction
    must expose ``.abs()``, ``.mean()`` and ``.std()`` (e.g. pandas Series).
    """
    difference = predicted - actual
    relative_error = difference.abs() / actual
    return relative_error.mean(), relative_error.std()

def p2f(x):
    """Convert a percentage string such as ``"12.5%"`` into a fraction (``0.125``)."""
    number_text = x.strip('%')
    return float(number_text) / 100

def remove_bw_suffix(x):
    """Parse a bandwidth string and return its value in GB/s.

    Recognised suffixes are ``GB/s``, ``MB/s``, ``KB/s`` and ``B/s`` (decimal
    multiples). A string without any of these suffixes is interpreted as a
    plain number of B/s, as before.

    The suffix is removed by slicing rather than ``str.strip``: ``strip``
    treats its argument as a *set of characters*, which only happened to work
    for purely numeric prefixes and failed for ``KB/s`` or padded input.

    Args:
        x: Text such as ``"1.5GB/s"`` or ``"250MB/s"``.

    Returns:
        The bandwidth as a float in GB/s.

    Raises:
        ValueError: If the part before the suffix is not a valid number.
    """
    text = x.strip()
    # Longest suffixes first: every unit also ends with "B/s".
    units = (
        ("GB/s", 1),
        ("MB/s", 1000),
        ("KB/s", 1000000),
        ("B/s", 1000000000),
    )
    for suffix, units_per_gb in units:
        if text.endswith(suffix):
            return float(text[:-len(suffix)]) / units_per_gb
    # No recognised suffix: keep the historical "bytes per second" reading.
    return float(text) / 1000000000

def strip_util(x):
    """Parse a number written inside parentheses, e.g. ``"(42.5)"`` -> ``42.5``."""
    # Strip opening parentheses first, then closing ones, from both ends.
    without_open = x.strip('(')
    without_close = without_open.strip(')')
    return float(without_close)

def preprocessing(stats):
    """Return a copy of ``stats`` without the rows that contain missing values."""
    return stats.dropna()

def histogram(df, buckets, percentage=True):
    """Print, per consecutive bucket pair, the share of values strictly inside it.

    Args:
        df: Values to bin (indexable by a boolean mask, e.g. a pandas Series).
        buckets: Ordered bucket boundaries; each adjacent pair forms one bin.
        percentage: When true, values are compared by absolute value and the
            boundaries are printed multiplied by 100 with a ``%`` sign.
            Otherwise raw values and raw boundaries are used.

    Values equal to a boundary fall in no bin because both comparisons are
    strict.
    """
    for lower, upper in zip(buckets[:-1], buckets[1:]):
        if percentage:
            inside = df[(df.abs() < upper) & (df.abs() > lower)]
            line = "{}-{}%, {:.2f}%".format(lower * 100, upper * 100, 100 * len(inside) / len(df))
        else:
            inside = df[(df < upper) & (df > lower)]
            line = "{}-{}, {:.2f}%".format(lower, upper, 100 * len(inside) / len(df))
        print(line)
                
# From Louis. Trim a long trace so that it eases the ATC processing
def trim_trace(file_name, start, end, out_file_name="trace_trimmed.json"):
    """Write the events of a JSON trace that fall inside a fractional time window.

    The trace is a JSON list of event dicts, each carrying a ``'ts'``
    timestamp. ``start`` and ``end`` are fractions of the full timestamp range
    ``[min ts, max ts]``. Events whose ``ts`` lies strictly between the two
    resulting bounds are kept, so events exactly on a bound are dropped.

    Args:
        file_name: Path of the input trace JSON file.
        start: Window start as a fraction in ``[0, 1]``.
        end: Window end as a fraction in ``[0, 1]``, not smaller than ``start``.
        out_file_name: Path of the trimmed trace to write. Defaults to
            ``"trace_trimmed.json"`` in the current directory, the location
            the function always used before this parameter existed.

    Raises:
        AssertionError: If the fractions are out of range or out of order.
        KeyError: If an event has no ``'ts'`` entry.
    """
    assert (0 <= start and start <= 1 and 0 <= end and end <= 1 and start <= end)

    # Read everything up front so the input handle is closed before the
    # (potentially slow) filtering and writing starts.
    with open(file_name) as trace_file:
        trace = json.load(trace_file)

    min_time = sys.maxsize
    max_time = 0
    for event in trace:
        min_time = min(min_time, event['ts'])
        max_time = max(max_time, event['ts'])

    print("time range: {} {}".format(min_time, max_time))
    time_range = max_time - min_time
    offset_start = start * time_range
    offset_end = end * time_range
    # Offset from the original start to the trimmed end.
    max_time = min_time + offset_end
    # Move the lower bound forward to the trimmed start.
    min_time += offset_start
    print("trimmed time range: {} {}".format(min_time, max_time))

    trimmed_trace = [x for x in trace if x['ts'] > min_time and x['ts'] < max_time]
    with open(out_file_name, 'w') as out_file:
        json.dump(trimmed_trace, out_file)

# Code copied from //aiplatform/monitoring/atc
def run_ATC():
    """Run the ATC trace analysis on ``./trace_trimmed.json``.

    Loads the trimmed trace, parses it into a fresh temporary directory, then
    asks ``trace_utils`` to extract insights and ``antipattern_detection`` to
    save all antipatterns into that same directory.

    Returns:
        Path of the temporary output directory (prefix ``base-trace_``).
    """
    # Configure logging to stdout, then silence every level globally.
    log_format = "[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s"
    logging.basicConfig(level=logging.INFO, format=log_format, stream=sys.stdout)
    logger: logging.Logger = logging.getLogger("atc")
    logger.setLevel(logging.INFO)
    logging.disable(sys.maxsize)

    trace_path = "./trace_trimmed.json"
    loaded_trace = trace_utils.load_trace_json_file(trace_path)
    base_trace_dir: str = tempfile.mkdtemp(prefix="base-trace_")

    parsed = trace_utils.parse_trace_json(loaded_trace, base_trace_dir)
    (
        iteration_start,
        iteration_end,
        all_events,
        per_process_events,
        per_thread_events,
    ) = parsed

    trace_utils.extract_insights_from_trace(
        base_trace_dir,
        all_events,
        per_process_events,
        per_thread_events,
        iteration_start,
        iteration_end,
    )
    antipattern_detection.save_all_antipatterns(
        all_events,
        per_process_events,
        per_thread_events,
        output_dir=base_trace_dir,
    )

    logger.info("output directory for base trace: {}".format(base_trace_dir))
    return base_trace_dir

def list_to_tuple(lst):
    """Recursively convert a (nested) list into a (nested) tuple.

    ``None`` is passed through unchanged; non-list elements are kept as-is.
    """
    if lst is None:
        return None
    converted = []
    for item in lst:
        if isinstance(item, list):
            converted.append(list_to_tuple(item))
        else:
            converted.append(item)
    return tuple(converted)

class Event:
    """One trace event plus the links needed to arrange events in a tree.

    The raw event dict is kept in ``self.event``; accessors read the ``ts``,
    ``dur``, ``cat``, ``name`` and ``args`` entries from it.

    Attributes set by the constructor:
        event: The wrapped event dict.
        parent: Enclosing ``Event`` (``None`` until assigned by the caller).
        children: Events nested inside this one (filled by the caller).
        has_device_calls: Flag maintained by the caller; starts ``False``.
    """

    def __init__(self, e, dummy=False):
        """Wrap ``e``; with ``dummy=True`` ignore ``e`` and wrap a placeholder event."""
        if dummy:
            self.event = {
                "name": "dummy",
                "ts": -1,
                "dur": -1,
                "cat": "Runtime",
                "args": {}
            }
        else:
            # isinstance also accepts dict subclasses such as OrderedDict.
            assert isinstance(e, dict)
            self.event = e
        self.parent = None
        self.children = []
        self.has_device_calls = False

    def __str__(self):
        return json.dumps(self.event, sort_keys=True, indent=4, separators=(',', ': '))

    def start_time(self):
        """Return the ``ts`` entry, or ``None`` when absent."""
        if "ts" not in self.event:
            return None
        return self.event["ts"]

    def duration(self):
        """Return the ``dur`` entry, or ``None`` when absent."""
        if "dur" not in self.event:
            return None
        return self.event["dur"]

    def category(self):
        """Return the ``cat`` entry; raise ``TypeError`` when absent."""
        if "cat" not in self.event:
            raise TypeError("Unknown event type!")
        return self.event["cat"]

    def name(self):
        """Return the ``name`` entry; raise ``TypeError`` when absent."""
        if "name" not in self.event:
            raise TypeError("Name lost!")
        return self.event["name"]

    def is_sub_of(self, other):
        """Return whether this event's interval lies within ``other``'s (bounds inclusive)."""
        assert (self.start_time() is not None and
                self.duration() is not None and
                other.start_time() is not None and
                other.duration() is not None)
        outer_start = other.start_time()
        outer_end = other.start_time() + other.duration()
        inner_start = self.start_time()
        inner_end = self.start_time() + self.duration()
        return outer_start <= inner_start and outer_end >= inner_end

    def input_shape(self):
        """Return ``args["Input dims"]`` as nested tuples, or ``(-1,)`` when absent."""
        if "args" not in self.event or "Input dims" not in self.event["args"]:
            return (-1,)
        return list_to_tuple(self.event["args"]["Input dims"])

    def output_shape(self):
        """Return ``args["Output dims"]`` as nested tuples, or ``(-1,)`` when absent."""
        if "args" not in self.event or "Output dims" not in self.event["args"]:
            return (-1,)
        return list_to_tuple(self.event["args"]["Output dims"])

    def external_id(self):
        """Return the event's external id.

        Operators are looked up under ``"External id"`` first and every other
        category under ``"external id"`` first; if the preferred spelling is
        missing the other one is used instead of failing with ``KeyError``.

        Returns:
            ``None`` when the event has no ``args``; otherwise the id.

        Raises:
            TypeError: If ``args`` exists but carries neither spelling.
        """
        if "args" not in self.event:
            return None

        args = self.event["args"]
        if "External id" not in args and "external id" not in args:
            raise TypeError("External id lost!")

        if self.category() == "Operator":
            preferred, fallback = "External id", "external id"
        else:
            preferred, fallback = "external id", "External id"
        return args[preferred] if preferred in args else args[fallback]

    def correlation_id(self):
        """Return ``args["correlation"]``.

        Returns ``None`` for events without ``args`` and for Operator events.

        Raises:
            TypeError: If a non-Operator event with ``args`` lacks the entry.
        """
        if "args" not in self.event or self.category() == "Operator":
            return None

        if "correlation" not in self.event["args"]:
            raise TypeError("Correlation id lost!")
        return self.event["args"]["correlation"]

    def device(self):
        """Return ``args["Device"]`` (preferred) or ``args["device"]``, else ``None``."""
        if "args" not in self.event:
            return None
        args = self.event["args"]
        if "Device" in args:
            return args["Device"]
        if "device" in args:
            return args["device"]
        return None

    def stream(self):
        """Return ``args["stream"]``, or ``None`` when absent."""
        if "args" not in self.event or "stream" not in self.event["args"]:
            return None
        return self.event["args"]["stream"]

In [45]:
# Construct a forest to represent the event hierarchy as well as a data structure to hold the relation between ops and device calls
########## cc #########
# {
#     ex_id1 : {
#         caller: - (an op that has one or multiple device calls)
#         callees: {
#             cr_id1: {
#                 launcher: - (cudaKernelLaunch)
#                 executor: - (device kernel)
#             }
#             ...
#         }
#     }
#     ...
# }
def process_event_hierarchy(two, skip_module=False, module_marker="## "):
    """Build the host-event forest and the op -> device-call relation of a trace.

    Args:
        two: Raw trace event dicts. Events without a ``"dur"`` entry are
            ignored.
        skip_module: When true, no module-based start detection is done and
            non-host events whose name starts with ``module_marker`` are
            skipped.
        module_marker: Name prefix that identifies module events.

    Returns:
        A tuple ``(roots, cc, corrected_start_time, corrected_end_time)``:

        * ``roots``: host events (category Operator/Runtime) without a parent;
          nested host events hang off ``children``.
        * ``cc``: ``{external_id: {"caller": Event or None,
          "callees": {correlation_id: {"launcher": Event or None,
          "executor": Event or None}}}}``. Operators become callers, Runtime
          events launchers, and every other category executors.
        * ``corrected_start_time``: start of the
          ``profiler::_record_function_enter`` event found just before the
          first module event, or the first event's start otherwise.
        * ``corrected_end_time``: end time of the last root that was added.

    Raises:
        IndexError: If no event has a duration.
        ValueError: If a host event partially overlaps an earlier one.
    """

    # Get the "grandest child" event of a given leaf
    # e.g. |------------ A --------------| The leaf event in the frontier currently being accessed
    #         |------------B-----------|
    #            |-----C------| The current "grandest child" of A, since D hasn't been added as A's child yet
    #               |---D---| The event currently being processed
    def get_grandest_child_event(leaf, event, depth=1):
        if not event.is_sub_of(leaf):
            return None
        for c in leaf.children:
            grandest = get_grandest_child_event(c, event, depth + 1)
            if grandest is not None:
                return grandest
        return leaf

    roots = []  # All the root events that have no parents
    leaves = []  # The event frontier of the processing
    unaccounted = []  # Unaccounted events (collected but not returned)
    cc = {}  # caller / callee: key = external id, value = { caller event, callee events }

    # Return (creating on demand) the launcher/executor record of one device call.
    def callee_slot(ex_id, cr_id):
        entry = cc.setdefault(ex_id, {})
        entry.setdefault("caller", None)
        callees = entry.setdefault("callees", {})
        return callees.setdefault(cr_id, {"launcher": None, "executor": None})

    # Keep only events that have a duration (single pass instead of the
    # former quadratic "not in list-of-dicts" filter), ordered by start time
    # with longer events first on ties so parents precede their children.
    sorted_events = [Event(e) for e in two if "dur" in e]
    sorted_events.sort(key=lambda x: (x.start_time(), -x.duration()))

    # Remove all leftovers from the last iteration and next iteration
    start_idx = 0
    end_idx = len(sorted_events) - 1
    corrected_start_time = sorted_events[0].start_time()
    corrected_end_time = sorted_events[-1].start_time()

    # Start the analysis from the first module detected, if module is not to be skipped
    if not skip_module:
        for idx, x in enumerate(sorted_events):
            ######## IMPORTANT ########
            # Find the start of an iteration started with "##" without ":". The first module should be "## zero_grad ##" though,
            # but the current ATC code couldn't start the extraction exactly at there.
            # Change TORCH_AUTOGRAD_GRAPHROOT in ATC's trace_utils.py does the trick
            if x.name().startswith(module_marker) and ":" not in x.name():
                # The actual start time is the start time of the profiler enter call right before "zero_grad"
                for idy, y in enumerate(reversed(sorted_events[:idx])):
                    if y.name() == "profiler::_record_function_enter":
                        # Index of the event right after the profiler enter call.
                        start_idx = idx - idy
                        corrected_start_time = y.start_time()
                        break
                break

    # End the analysis at the last event that has a duration. Set the corrected end time later.
    for idx, x in enumerate(reversed(sorted_events)):
        if x.duration() is not None:
            end_idx = idx
            break
    # NOTE(review): every remaining event has a duration, so end_idx is 0 and
    # this exclusive slice bound drops the final event. Kept as-is because
    # downstream numbers depend on it; confirm whether that is intended.
    sorted_events = sorted_events[start_idx:(len(sorted_events) - 1 - end_idx)]

    for x in sorted_events:
        # Get start, duration and end time of the current event
        event_start = x.start_time()
        event_duration = x.duration()
        external_id = x.external_id()
        correlation_id = x.correlation_id()

        # Runtime events e.g. cudaLaunchKernel counted as host events
        if x.category() == "Operator" or x.category() == "Runtime":
            if event_start is None or event_duration is None:
                print("Unaccounted event: {}".format(x.event))
                unaccounted.append(x)
                continue
            # Put all OPERATOR events with no device info into unaccounted (0 means None in the trace file)
            # This usually work for events like aten::pin_memory, etc
            if x.device() == 0:
                unaccounted.append(x)
                continue

            event_end = event_start + event_duration
            corrected_end_time = max(event_end, corrected_end_time)

            # Find parent of the current event from the frontier
            parent_found = False
            for leaf in leaves:
                leaf_end = leaf.start_time() + leaf.duration()

                # The current event is sub to leaf
                if event_end <= leaf_end:
                    # Add this event to the GRANDEST CHILD of the leaf that can sub it
                    grandest = get_grandest_child_event(leaf, x)
                    x.parent = grandest
                    grandest.children.append(x)
                    parent_found = True
                    break
                # The current event has no overlap with leaf
                elif event_start >= leaf_end:
                    continue
                # Crossover shouldn't happen
                else:
                    pprint(str(x))
                    raise ValueError("\tCrossover happens!")

            # Every host event joins the frontier; parentless ones are roots too.
            if not parent_found:
                roots.append(x)
            leaves.append(x)

            # Add op to caller or unaccounted
            if x.category() == "Operator":
                if external_id != 0:
                    if external_id not in cc:
                        cc[external_id] = {}
                    cc[external_id]["caller"] = x
                    cc[external_id]["callees"] = {}
            else:  # Runtime
                # Not consider some events without ex_id and cr_id, e.g. cudaEventCreateWithFlags
                if external_id != 0 and correlation_id != 0:
                    callee_slot(external_id, correlation_id)["launcher"] = x
        else:
            # Skip modules if needed
            if skip_module and x.name().startswith(module_marker):
                continue
            # "cat" = "Memcpy" or "Kernel", i.e. callee
            # Doesn't consider some events without ex_id and cr_id, e.g. cudaEventCreateWithFlags
            if external_id != 0 and correlation_id != 0:
                callee_slot(external_id, correlation_id)["executor"] = x

    # Set the corrected_end_time to be the last event's end time
    for x in reversed(roots):
        if x.duration() is not None:
            corrected_end_time = x.start_time() + x.duration()
            break

    # Update 'has_device_calls' for all events in the tree: a childless event
    # has device calls when one of its callees has an executor; a parent has
    # them when any child does.
    def update_has_device_calls(nodes):
        for r in nodes:
            ex_id = r.external_id()
            if len(r.children) == 0:
                if ex_id in cc:
                    if any(v["executor"] is not None for v in cc[ex_id]["callees"].values()):
                        r.has_device_calls = True
            else:
                update_has_device_calls(r.children)
                if any(c.has_device_calls for c in r.children):
                    r.has_device_calls = True
    update_has_device_calls(roots)

    return roots, cc, corrected_start_time, corrected_end_time

# Get root operators, not including modules
def get_operators(roots, ops):
    """Append the outermost non-module operators of the forest to ``ops``.

    An event is collected when its category is ``"Operator"``, its name does
    not start with ``"## "``, and it either has no parent or its parent's name
    starts with ``"## "``. Any other event is not collected itself; the
    search continues in its children instead.

    Args:
        roots: Events to inspect (each exposing ``category()``, ``name()``,
            ``parent`` and ``children``).
        ops: Output list, extended in place.
    """
    for node in roots:
        collect = False
        if node.category() == "Operator" and not node.name().startswith("## "):
            parent = node.parent
            collect = parent is None or parent.name().startswith("## ")
        if collect:
            ops.append(node)
        else:
            get_operators(node.children, ops)

In [46]:
%%capture
# Raw trace analysed by this notebook.
trace_file = "./libgpumon_activities_425511.json"
# Keep only the last 10% of the trace's time range; trim_trace writes the
# result to trace_trimmed.json, which run_ATC() reads.
trim_trace(trace_file, 0.90, 1.0)
# Run ATC on the trimmed trace; the return value is its output directory.
base_trace_dir = run_ATC()
print(base_trace_dir)
# Load the two_iteration_trace.json file from the ATC output directory.
with open(base_trace_dir + "/two_iteration_trace.json") as two:
    two_iteration_stats = json.load(two)

# Build the event forest / caller-callee map, then collect the outermost
# operators into `ops` (used by the prediction loop at the end).
ops = []
roots, cc, corrected_start_time, corrected_end_time = process_event_hierarchy(two_iteration_stats, skip_module=False)
get_operators(roots, ops)

In [54]:
# Sizes of C*B and C'*A
def get_addmm_backward_size(op):
    """Return one ``(rows, inner, cols)`` triple per ``mm`` child of ``op``.

    For each child named ``"mm"`` the triple is taken from its input shape as
    ``(shape[0][0], shape[0][1], shape[1][1])``.
    """
    mm_shapes = (
        child.input_shape() for child in op.children if child.name() == "mm"
    )
    return [(shape[0][0], shape[0][1], shape[1][1]) for shape in mm_shapes]

# Never seen BmmBackward0 having only one bmm. Possible though.
def get_bmm_backward_size(op):
    """Return one ``(batch, rows, inner, cols)`` tuple per ``bmm`` child of ``op``.

    For each child named ``"bmm"`` the tuple is taken from its input shape as
    ``(shape[0][0], shape[0][1], shape[0][2], shape[1][2])``.
    """
    bmm_shapes = (
        child.input_shape() for child in op.children if child.name() == "bmm"
    )
    return [
        (shape[0][0], shape[0][1], shape[0][2], shape[1][2])
        for shape in bmm_shapes
    ]

# Not working in new traces as the output_nr op calls are removed
def get_embedding_lookup_forward_size(op):
    """Derive ``(B, E, T, L, D, rows_per_block)`` for an embedding lookup op.

    The dimensions are computed from the input shapes of the first four
    children named ``"output_nr"``, in encounter order. ``rows_per_block`` is
    ``block[1]`` of the executor kernel recorded in the module-level ``cc``
    map for the last ``"cudaLaunchKernel"`` child, or ``None`` if there is no
    such child.

    Raises:
        IndexError: If fewer than four ``"output_nr"`` children are present.
    """
    shapes = []
    rows_per_block = None
    for child in op.children:
        child_name = child.name()
        if child_name == "output_nr":
            shapes.append(child.input_shape())
        if child_name == "cudaLaunchKernel":
            kernel = cc[op.external_id()]["callees"][child.correlation_id()]["executor"]
            rows_per_block = kernel.event["args"]["block"][1]

    first, second = shapes[0][0], shapes[1][0]
    D = int(first[1])
    T = int(second[0])
    E = int(first[0] / T)
    B = int((shapes[3][0][0] - 1) / T)
    L = int(shapes[2][0][0] / B / T)
    return B, E, T, L, D, rows_per_block

# Not working in new traces as the size op calls are removed
def get_embedding_lookup_backward_size(op):
    """Derive ``(B, E, T, None, D, rows_per_block)`` for an embedding backward op.

    The dimensions are computed from the input shapes of the first three
    children named ``"size"``, in encounter order. The fourth slot is a
    placeholder for the bag size, which this function does not compute; it is
    returned as ``None`` so the tuple lines up with the forward variant.
    Previously the function returned the undefined name ``_`` there, which
    raises ``NameError`` outside IPython and yields an unrelated "last
    output" value inside it.

    ``rows_per_block`` is ``block[1]`` of the executor kernel recorded in the
    module-level ``cc`` map for the last ``"cudaLaunchKernel"`` child, or
    ``-1`` if there is no such child.

    Raises:
        IndexError: If fewer than three ``"size"`` children are present.
    """
    sizes = []
    rows_per_block = -1
    for x in op.children:
        if x.name() == "size":
            sizes.append(x.input_shape())
        if x.name() == "cudaLaunchKernel":
            ex = cc[op.external_id()]["callees"][x.correlation_id()]["executor"]
            rows_per_block = ex.event["args"]["block"][1]
    T = int(sizes[0][0][0])
    E = int(sizes[1][0][0] / T)
    D = int(sizes[1][0][1])
    B = int((sizes[2][0][0] - 1) / T)
    return B, E, T, None, D, rows_per_block

In [15]:
# Load the measurement table used by fc_forward_predictor.
fc_stats = pd.read_csv("./fully_connected_forward.csv", delimiter=',')
# Drop rows that contain missing values.
fc_stats = preprocessing(fc_stats)
# Keep only rows whose kernel_name starts with "volta" and renumber the rows
# from 0 (fc_forward_predictor builds its masks on a default 0..n-1 index).
fc_stats = fc_stats[fc_stats["kernel_name"].str.startswith("volta")].reset_index(drop=True)

### Performance models

In [16]:
# Hardware parameters fed to the performance models below.
# NOTE(review): L2_size is presumably in bytes (it is compared against
# num_embeddings * embedding_dim * 4 in embedding_forward_predictor); the
# meaning of the trailing "* 4" factor is not evident here — confirm.
L2_size = 6 * 1024 * 1024 * 4
# Presumably the number of streaming multiprocessors of the target GPU — confirm.
num_SM = 80
peak_dram_bw = 809 # GB/s
peak_l2_bw = 2888 # GB/s
peak_throughput = 12200 # GFLOPS

In [87]:
def embedding_forward_predictor(peak_dram_bw, peak_l2_bw, L2_size, num_SM, **kwargs):
    """Estimate the runtime of an embedding-lookup forward kernel.

    The estimate is the larger of the DRAM-bound and L2-bound times, each
    computed as ``traffic / peak bandwidth / 1000`` (presumably microseconds
    when traffic is in bytes and bandwidth in GB/s — confirm).

    Args:
        peak_dram_bw: Peak DRAM bandwidth.
        peak_l2_bw: Peak L2 bandwidth.
        L2_size: L2 cache capacity, in the same unit as the table sizes
            computed here (``num_embeddings * embedding_dim * 4``).
        num_SM: Number of SMs.
        **kwargs: Required keys ``batch_size``, ``num_embeddings``,
            ``num_tables``, ``bag_size``, ``embedding_dim``,
            ``rows_per_block`` and ``shmem`` (evaluated for truthiness).

    Returns:
        The predicted runtime as a float.

    Raises:
        KeyError: If a required keyword is missing.
    """
    # hit_rate = C(X, L) / C(E, L), X = avg_num_rows_per_table
    def hit_rate(X, E, L):
        ret = 1.0
        e = E
        x = X
        for _ in range(L):
            ret *= x / e
            x -= 1
            e -= 1
        return ret

    batch_size = kwargs["batch_size"]
    num_tables = kwargs["num_tables"]
    rows_per_block = kwargs["rows_per_block"]
    num_embeddings = kwargs["num_embeddings"]
    embedding_dim = kwargs["embedding_dim"]
    shmem = kwargs["shmem"]
    bag_size = kwargs["bag_size"]

    # Sizes rounded up to whole 32-byte units (4 bytes per element).
    indices_bytes = div_round_up(bag_size * 4, 32) * 32
    row_bytes = div_round_up(embedding_dim * 4, 32) * 32

    # Average number of rows per table in L2
    num_warps_per_sm = rows_per_block  # Number of warps per sm
    num_warps_simul = num_SM * num_warps_per_sm  # Total number of warps simultaneously running on the device
    num_tables_simul = (num_warps_simul + batch_size - 1) // batch_size  # Number of tables simultaneously being accessed on the device
    avg_table_size = min(L2_size // num_tables_simul, num_embeddings * embedding_dim * 4)  # Average table size that reside on the device
    indices_size = 0 if shmem else indices_bytes
    avg_num_rows_per_table = (avg_table_size - indices_size) // 4 // embedding_dim

    # Hit rate
    # NOTE(review): the formula comment above names the third argument L,
    # but batch_size is what gets passed here — confirm this is intended.
    hr = hit_rate(avg_num_rows_per_table, num_embeddings, batch_size)

    # num_thread_x
    num_thread_x = max(embedding_dim / 4, 1024 / rows_per_block)

    # Traffics
    table_offsets_traffic = 32
    offsets_traffic = 32
    indices_dram_traffic = indices_bytes
    if shmem:
        indices_l2_traffic = 0
    else:  # no_shmem
        indices_l2_traffic = embedding_dim // (4 * num_thread_x) * indices_bytes
    table_traffic = bag_size * row_bytes
    output_traffic = row_bytes

    # avg_table_size all as dram traffic
    # 21, 26, 13, 7, 4, 4, 24, (0.2 ± 0.21)
    total_l2_traffic = ((table_offsets_traffic + offsets_traffic + indices_l2_traffic) * batch_size + \
                        hr * (table_traffic * batch_size - avg_table_size)) * num_tables
    total_dram_traffic = ((indices_dram_traffic + output_traffic) * batch_size + \
                          (1 - hr) * (table_traffic * batch_size - avg_table_size) + avg_table_size) * num_tables

    return max(total_dram_traffic / peak_dram_bw / 1000.0, total_l2_traffic / peak_l2_bw / 1000.0)
        
# Example call. Author's note "4340 vs 4789": the first number matches the
# value printed below; the second is presumably the measured runtime — confirm.
embedding_forward_predictor(peak_dram_bw, peak_l2_bw, L2_size, num_SM, batch_size=4096, num_embeddings=500000, num_tables=197, bag_size=32, embedding_dim=32, rows_per_block=128, shmem=True)

4340.7676440049445

In [75]:
def embedding_backward_sgd_predictor(peak_dram_bw, **kwargs):
    """Estimate the runtime of the SGD embedding backward kernel.

    The result is the larger of a memory-bound term
    (``traffic / peak_dram_bw / 1000``) and a compute-bound term
    (``MACs / peak_throughput / 1000``). ``peak_throughput`` is read from
    module scope; it is not a parameter.

    Args:
        peak_dram_bw: Peak DRAM bandwidth.
        **kwargs: Required keys ``shmem`` (evaluated for truthiness),
            ``bag_size``, ``embedding_dim``, ``batch_size`` and
            ``num_tables``.
    """
    use_shmem = kwargs["shmem"]
    bag_size = kwargs["bag_size"]
    embedding_dim = kwargs["embedding_dim"]
    batch_size = kwargs["batch_size"]
    num_tables = kwargs["num_tables"]

    # One embedding row, rounded up to whole 32-byte units.
    row_bytes = div_round_up(embedding_dim * 4, 32) * 32

    if use_shmem:  # 40% GMAE...
        indices_bytes = div_round_up(bag_size * 4, 32) * 32
        grad_output_bytes = row_bytes
    else:  # backward_sgd_no_shmem
        indices_bytes = bag_size * 32
        grad_output_bytes = (bag_size * row_bytes) * 2

    # Traffic per warp = t_offsets + t_table_offsets + t_indices + t_weights + t_grad_outputs
    weights_bytes = 2 * bag_size * row_bytes
    bytes_per_warp = 32 + 32 + indices_bytes + weights_bytes + grad_output_bytes

    # One warp per (sample, table) pair.
    bytes_total = batch_size * num_tables * bytes_per_warp

    # Total compute
    mac_per_warp = bag_size * 4 * (embedding_dim // 4)
    mac_total = batch_size * num_tables * mac_per_warp

    memory_bound = bytes_total / peak_dram_bw / 1000
    compute_bound = mac_total / peak_throughput / 1000
    return max(memory_bound, compute_bound)

# Example call. Author's note "86291 vs 77508": the first number matches the
# value printed below; the second is presumably the measured runtime — confirm.
embedding_backward_sgd_predictor(peak_dram_bw, batch_size=2048, num_embeddings=200000, num_tables=128, bag_size=128, embedding_dim=128, rows_per_block=32, shmem=False)

86291.71294932015

In [77]:
def embedding_backward_rowwise_adagrad_approx_predictor(peak_dram_bw, **kwargs):
    """Estimate the runtime of the approximate row-wise Adagrad backward kernel.

    Purely memory-bound model: total traffic divided by ``peak_dram_bw`` and
    by 1000.

    Args:
        peak_dram_bw: Peak DRAM bandwidth.
        **kwargs: Required keys ``bag_size``, ``embedding_dim``,
            ``batch_size`` and ``num_tables``.
    """
    bag_size = kwargs["bag_size"]
    embedding_dim = kwargs["embedding_dim"]

    # Index list and one embedding row, each rounded up to 32-byte units.
    indices_bytes = div_round_up(bag_size * 4, 32) * 32
    row_bytes = div_round_up(embedding_dim * 4, 32) * 32

    # Traffic = warp * traffic per warp
    bytes_per_warp = (
        32
        + 32
        + 2 * indices_bytes
        + bag_size * row_bytes
        + (bag_size + 1) * row_bytes
        + bag_size * row_bytes
    )
    bytes_total = kwargs["batch_size"] * kwargs["num_tables"] * bytes_per_warp

    return bytes_total / peak_dram_bw / 1000.0

# Example call. Author's note "64226 vs 63304": the first number matches the
# value printed below; the second is presumably the measured runtime — confirm.
embedding_backward_rowwise_adagrad_approx_predictor(peak_dram_bw, batch_size=2048, num_embeddings=200000, num_tables=128, bag_size=128, embedding_dim=128, rows_per_block=32)

64226.252103831896

In [11]:
def fc_forward_predictor(peak_dram_bw, peak_throughput, df, **kwargs):
    """Predict the runtime of a (batched) matrix multiply from measured records.

    If ``df`` contains a row matching every keyword exactly, that row's
    ``kernel_runtime`` is returned. Otherwise an effective throughput and an
    effective DRAM bandwidth are interpolated from neighbouring records and
    the prediction is ``max(FLOP / effective_flops, bytes / effective_bw) / 1000``.

    Args:
        peak_dram_bw: Bandwidth used when a neighbour lies beyond the largest
            measured value of a dimension.
        peak_throughput: Throughput used in that same case.
        df: Measurement table with columns ``batch_size``, ``M``, ``N``,
            ``K`` and ``kernel_runtime``. The boolean masks below are built
            with a default 0..n-1 index, so ``df`` is assumed to carry such an
            index — TODO confirm for other callers.
        **kwargs: The query; ``batch_size``, ``M``, ``N`` and ``K`` are required.

    Returns:
        The matching record's runtime, or the interpolated prediction.
    """
    # Rows of df that match every keyword exactly.
    def get_record(df, **kwargs):
        row_count = df.shape[0]
        condition = pd.Series([True] * row_count)
        for k, v in kwargs.items():
            condition = condition & (df[k] == v)
        return df[condition]

    # Neighbouring records for a query without an exact match. Returns a list
    # of (rows, limits) pairs where limits maps each unmatched dimension to
    # its (low, high) bracket; -1 marks a missing side.
    def get_closest(df, **kwargs):
        no_match = {}
        row_count = df.shape[0]
        condition = pd.Series([True] * row_count)
        for k, v in kwargs.items():
            # Dimensions whose value occurs in df are filtered exactly; the
            # others are bracketed below.
            if v in df[k].unique():
                condition = condition & (df[k] == v)
            else:
                no_match[k] = v

        # With matched dimensions
        data_points = [(df[condition], {})]

        # For each of the non-matched dimension
        for k, v in no_match.items():
            tmp = []
            for dp, limits in data_points:
                # NOTE(review): if dp is empty, uni_val[0] below raises
                # IndexError — confirm whether that can happen in practice.
                uni_val = sorted(dp[k].unique())

                low, high = -1, -1
                if v < uni_val[0]:
                    high = uni_val[0]
                elif v > uni_val[-1]:
                    low = uni_val[-1]
                else:
                    for idx in range(len(uni_val[:-1])):
                        if uni_val[idx] < v and uni_val[idx+1] > v:
                            high = uni_val[idx+1]
                            low = uni_val[idx]
                            break
                assert not (low == -1 and high == -1)

                less_tmp = dp[dp[k] == (low if low != -1 else uni_val[0])]
                more_tmp = dp[dp[k] == (high if high != -1 else uni_val[-1])]
                # A missing side is marked by overwriting column k with a
                # sentinel (0 below the range, sys.maxsize above it); the
                # caller detects these sentinels.
                if low == -1:
                    less_tmp[k] = 0
                if high == -1:
                    more_tmp[k] = sys.maxsize # Big enough for BW in GB/s or throughput in GFLOPS
                tmp_limits = limits.copy()
                tmp_limits[k] = (low, high)
                tmp.append((less_tmp, tmp_limits))
                tmp.append((more_tmp, tmp_limits))
            data_points = tmp

        return data_points

    #####################
    #       |     X |
    #    |==O=======|
    #    |  |       |
    #    |  |-------O-
    #    |     X     |
    #    O           |
    #    |===========O
    #    |  X        |
    #####################

    # Exact hit: return the measured runtime of the first matching row.
    record = get_record(df, **kwargs)
    if not record.empty:
        return record["kernel_runtime"].iloc[0]
    data_points = get_closest(df, **kwargs)

    effective_flops = 0.0
    effective_bw = 0.0
    # NOTE(review): batch_size, M, N and K are not read again below (the
    # code uses kwargs[...] directly); they only force an early KeyError.
    batch_size = kwargs["batch_size"]
    M = kwargs["M"]
    N = kwargs["N"]
    K = kwargs["K"]
    for dp, limits in data_points:
        dp_flops_contrib = 0.0
        dp_bw_contrib = 0.0

        # An idea, if zero occurs, it's always the bottleneck. A zero dominates all peaks.
        zero_exists = False
        peak_exists = False
        for k, v in limits.items():
            metric = dp[k].iloc[0]
            if metric == 0:
                zero_exists = True
            elif metric == sys.maxsize:
                peak_exists = True

        for k, v in limits.items():
            low, high = v
            if high == -1: # Reaching the peak, taking average
                ratio_l, ratio_h = 0.5, 0.5
            elif low == -1: # Reaching the bottom, set low as 0
                ratio_l, ratio_h = kwargs[k] / high, (high - kwargs[k]) / high
            else: # Normal, weighted
                ratio_l, ratio_h = (kwargs[k] - low) / (high - low), (high - kwargs[k]) / (high - low)
            # Edge cases: when more than one metric is MAX/0

            # NOTE(review): if none of the four branches below applies,
            # throughput/dram_bw/ratio keep their value from the previous
            # iteration (or are unbound on the first) — confirm this cannot
            # happen.
            metric = dp[k].iloc[0]
            if zero_exists:
                throughput = 0
                dram_bw = 0
                ratio = 0
            elif peak_exists:
                throughput = peak_throughput
                dram_bw = peak_dram_bw
                ratio = ratio_h
            elif metric == low:
                throughput = (dp["batch_size"] * dp["M"] * dp["N"] * dp["K"]).iloc[0] / dp["kernel_runtime"].iloc[0] / 1000 # GFLOPS
                dram_bw = (dp["batch_size"] * (dp["M"] * dp["K"] + dp["K"] * dp["N"] + dp["M"] * dp["N"])).iloc[0] / dp["kernel_runtime"].iloc[0] / 1000 * 4 # GB/s
                ratio = ratio_l 
            elif metric == high:
                throughput = (dp["batch_size"] * dp["M"] * dp["N"] * dp["K"]).iloc[0] / dp["kernel_runtime"].iloc[0] / 1000 # GFLOPS
                dram_bw = (dp["batch_size"] * (dp["M"] * dp["K"] + dp["K"] * dp["N"] + dp["M"] * dp["N"])).iloc[0] / dp["kernel_runtime"].iloc[0] / 1000 * 4 # GB/s
                ratio = ratio_h

            dp_flops_contrib += throughput * ratio
            dp_bw_contrib += dram_bw * ratio

        # Average this data point's contribution over the unmatched dimensions.
        effective_flops += dp_flops_contrib / len(limits.items())
        effective_bw += dp_bw_contrib / len(limits.items())

    # Data points come in (lower, upper) pairs, hence the division by half
    # their count.
    effective_flops /= len(data_points) / 2
    effective_bw /= len(data_points) / 2

    # Work and traffic of the queried problem (4 bytes per element).
    FLOP = kwargs["batch_size"] * kwargs["M"] * kwargs["N"] * kwargs["K"]
    DRAM_bytes = kwargs["batch_size"] * (kwargs["M"] * kwargs["K"] + kwargs["K"] * kwargs["N"] + kwargs["M"] * kwargs["N"]) * 4
    predicted_runtime = max(FLOP / effective_flops, DRAM_bytes / effective_bw) / 1000

    return predicted_runtime

# Example call. Author's note "5829 vs 7548": the first number matches the
# value printed below; the second is presumably the measured runtime — confirm.
fc_forward_predictor(peak_dram_bw, peak_throughput, fc_stats, batch_size=256, M=512, N=1000, K=400)

5829.673816702252

In [83]:
# Sum the predicted runtime of every collected operator. Operators whose name
# matches none of the cases below contribute 0.
total_time = 0.0
for op in ops:
    t = 0.0
    # Forward matrix multiply: dimensions come from the 2nd and 3rd input shapes.
    if op.name() == "addmm":
        size = op.input_shape()
        M, K, N = size[1][0], size[1][1], size[2][1]
        t = fc_forward_predictor(peak_dram_bw, peak_throughput, fc_stats, batch_size=1, M=M, N=N, K=K)
#         print("addmm", M, N, K, t)
    # Forward batched matrix multiply: dimensions come from the first two input shapes.
    if op.name() == "bmm":
        size = op.input_shape()
        batch_size, M, K, N = size[0][0], size[0][1], size[0][2], size[1][2]
        t = fc_forward_predictor(peak_dram_bw, peak_throughput, fc_stats, batch_size=batch_size, M=M, N=N, K=K)
#         print("bmm", batch_size, M, N, K, t)
    # Embedding lookup forward.
    if op.name() == "LookupFunction":
        B, E, T, L, D, rows_per_block = get_embedding_lookup_forward_size(op)
        lks = []
        for c in op.children:
            if c.name() == "cudaLaunchKernel":
                lks.append(c)
        # Executor kernel of the first callee recorded for the first launch's external id.
        callees = cc[lks[0].external_id()]["callees"]
        # NOTE(review): shmem is the text after the first comma of the kernel
        # name, i.e. a string; the predictors only test it for truthiness, so
        # any non-empty value selects the shmem branch — confirm.
        shmem = list(callees.values())[0]["executor"].name().split(',')[1].strip()
        t = embedding_forward_predictor(peak_dram_bw, peak_l2_bw, L2_size, num_SM, batch_size=B, num_embeddings=E, num_tables=T, bag_size=L, embedding_dim=D, rows_per_block=rows_per_block, shmem=shmem)
#         print("Embedding forward", t)
    # Embedding lookup backward (SGD or row-wise Adagrad, chosen by kernel name).
    if op.name() == "LookupFunctionBackward":
        B, E, T, _, D, rows_per_block = get_embedding_lookup_backward_size(op)
        L = 38 # TODO: Cannot get it from trace. Hard code it.
        sgd = False
        lks = []
        for c in op.children:
            if c.name() == "cudaLaunchKernel":
                lks.append(c)
        callees = cc[lks[0].external_id()]["callees"]
        kernel_name = list(callees.values())[0]["executor"].name()
        if "sgd" in kernel_name:
            sgd = True
        # NOTE(review): same string-vs-boolean concern as in the forward case.
        shmem = kernel_name.split(',')[1].strip()
        if sgd:
            t = embedding_backward_sgd_predictor(peak_dram_bw, batch_size=B, num_embeddings=E, num_tables=T, bag_size=L, embedding_dim=D, rows_per_block=rows_per_block, shmem=shmem)
        else:
            t = embedding_backward_rowwise_adagrad_approx_predictor(peak_dram_bw, batch_size=B, num_embeddings=E, num_tables=T, bag_size=L, embedding_dim=D, rows_per_block=rows_per_block)
#         print("Embedding backward", t)
    # Backward of addmm: one forward-style prediction per "mm" child, summed.
    if op.name() == "AddmmBackward":
        sizes = get_addmm_backward_size(op)
        for size in sizes:
            M, K, N = size
            t += fc_forward_predictor(peak_dram_bw, peak_throughput, fc_stats, batch_size=1, M=M, N=N, K=K)
#         print("AddmmBackward", t)
    # Backward of bmm: one forward-style prediction per "bmm" child, summed.
    if op.name() == "BmmBackward0":
        sizes = get_bmm_backward_size(op)
        for size in sizes:
            batch_size, M, K, N = size
            t += fc_forward_predictor(peak_dram_bw, peak_throughput, fc_stats, batch_size=batch_size, M=M, N=N, K=K)
#         print("BmmBackward0", t)
    total_time += t
print("total_time:", total_time)

total_time: 12117.71546046202
