In [None]:
from __future__ import absolute_import, division, print_function, unicode_literals
from IPython.display import display, HTML
from collections import defaultdict
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import json, sys, os

# Notebook-wide display settings: default figure size, a higher threshold
# before matplotlib warns about many open figures, and a full-width container.
plt.rcParams.update({
    'figure.figsize': [10, 5],
    'figure.max_open_warning': 50,
})
display(HTML("<style>.container { width:100% !important; }</style>"))
# Two directories above the current working directory (trailing slash kept);
# later cells build repository paths such as `data/A100/...` from it.
PM_HOME = "{}/../../".format(os.getcwd())

In [None]:
# Root directory of the DLRM embedding-bag datasets. Set the DLRM_DATASET_ROOT
# environment variable to point elsewhere instead of editing the notebook;
# the default is the original hard-coded location.
DATASET_ROOT = os.environ.get(
    "DLRM_DATASET_ROOT", "/nvme/deep-learning/dlrm_datasets/embedding_bag"
)

In [None]:
# Global reuse_factor of all tables all batches.
# Reads the lines of the stats file that start with '[' (list literals):
# the first one is parsed as integer bin edges (`bins`); every later one is
# parsed as quoted floats into `pdf` (the last such line wins).
# NOTE(review): file layout inferred from this parsing — confirm with the producer.
locality_file_path = "{}/2021/locality_stats.txt".format(DATASET_ROOT)
bins, pdf = None, None
with open(locality_file_path) as f:
    for line in f:
        if not line.startswith('['):
            continue
        # strip() first so CRLF endings / trailing spaces don't break the bracket strip
        body = line.strip().lstrip('[').rstrip(']')
        items = [item.strip() for item in body.split(',')]
        if bins is None:
            bins = [int(item) for item in items]
        else:
            pdf = [float(item.strip("'")) for item in items]
if bins is None:
    raise ValueError("no '[...]' bin line found in {}".format(locality_file_path))
# Close the last bin with one extra edge at twice the final edge.
bins.append(bins[-1] * 2)

In [None]:
# Per-table reuse_factor of all batches: one cumulative curve per table,
# built from the table's "bin_<i>" entries in the 2021 FBGEMM table config.
table_config_path = "{}/2021/fbgemm_t856_bs65536_configs.json".format(DATASET_ROOT)
with open(table_config_path) as f:
    table_configs = json.load(f)["tables"]

# Loop-invariant plot inputs: one x position (2**i) per bin edge; each curve
# starts at 0, so it has one more point than there are bins.
num_bins = len(bins) - 1
x_edges = [2 ** idx for idx in range(len(bins))]

fig, ax = plt.subplots()
for t in table_configs:
    cumulative = np.cumsum([t["bin_{}".format(idx)] for idx in range(num_bins)])
    ax.plot(x_edges, [0] + cumulative.tolist())
ax.set_xscale('log')

In [None]:
# Per-table per-batch reuse_factor.
# Each CSV row is parsed as: '-'-joined table ids in 'num_embeddings',
# '_'-joined per-table reuse-factor vectors in 'reuse_factors' (each vector
# '-'-joined floats), and a 'dataset_path' whose parent directory name is the
# dataset year (2021 or 2022). NOTE(review): schema inferred from the parsing — confirm.
MAX_RF_ROWS = 6000  # only the first MAX_RF_ROWS rows of the CSV are analysed
rf_file = '{}/data/A100/kernel/embedding_lookup_fbgemm_dlrm_datasets_rf.csv'.format(PM_HOME)
df = pd.read_csv(rf_file).head(MAX_RF_ROWS)

# stats[0] -> 2021, stats[1] -> 2022; table id -> list of per-batch rf vectors
stats = [defaultdict(list), defaultdict(list)]
for table_ids, rf_groups, dataset_path in zip(df['num_embeddings'], df['reuse_factors'], df['dataset_path']):
    year = int(dataset_path.split('/')[-2])
    table_ids = table_ids.split('-')
    rf_groups = rf_groups.split('_')
    # The year is per row, so pair only ids with rf vectors: zipping the year in
    # as a one-element list would stop after the first table of every row.
    assert len(table_ids) == len(rf_groups), \
        "table/reuse-factor count mismatch for {}".format(dataset_path)
    for t_idx, rfs in zip(table_ids, rf_groups):
        stats[year - 2021][int(t_idx)].append([float(x) for x in rfs.split('-')])

# Tables with the most batches first: list of (table_id, [rf_vector, ...]) per year.
stats = [sorted(per_year.items(), key=lambda kv: -len(kv[1])) for per_year in stats]

In [None]:
# Per-batch (gray) vs. overall (red) cumulative reuse-factor curves for the
# tables with the most batches; row 0 uses stats[0] (2021), row 1 stats[1] (2022).
# NOTE(review): `table_configs` comes from the 2021 FBGEMM config only, so the
# red curves on the 2022 row look up 2022 table ids in 2021 data — confirm this
# is intended or load the matching 2022 config.
NUM_TOP_TABLES = 5
num_bins = len(bins) - 1
x_edges = [2 ** idx for idx in range(len(bins))]

fig, axs = plt.subplots(2, NUM_TOP_TABLES, figsize=(30, 11))
for y, row_axes in enumerate(axs):
    for i, ax in enumerate(row_axes):
        if i >= len(stats[y]):
            # Fewer tables than panels for this year: leave the panel blank.
            ax.axis('off')
            continue
        tid, batch_rfs = stats[y][i]
        t = table_configs[tid]
        overall = np.cumsum([t["bin_{}".format(idx)] for idx in range(num_bins)])
        ax.plot(x_edges, [0] + overall.tolist(), color='red')
        for s in batch_rfs:
            ax.plot(x_edges, [0] + np.cumsum(s).tolist(), color='gray', linewidth=0.2)
        ax.set_title("Table {}".format(tid), fontsize=28)
        ax.set_xscale('log')
        ax.tick_params(axis='x', labelsize=20)
        ax.tick_params(axis='y', labelsize=20)
plt.tight_layout()
plt.savefig('./batch_vs_overall_rf.pdf', bbox_inches='tight')
plt.savefig('./batch_vs_overall_rf.png', bbox_inches='tight')

In [None]:
# Histogram of the per-table "pooling_factor" for each dataset year, with a
# dashed line marking tables whose pooling factor is >= heavy_lookup ("heavy").
fig, axs = plt.subplots(2, 1, figsize=(10, 5))
half_width = 0.5
# Half-open bin edges; the narrow [1, 1.00001) bin isolates a pooling factor of exactly 1.
bs = [0, 1, 1.00001, 2, 5, 10, 15, 20, 30, 40, 50, 80, 100, 1000]
heavy_lookup = 20
heavy_border = bs.index(heavy_lookup)
num_bars = len(bs) - 1
for ax, year in zip(axs, [2021, 2022]):
    with open("{}/{}/merged_simple_configs.json".format(DATASET_ROOT, year), "r") as f:
        configs = json.load(f)["tables"]
    num_tables = len(configs)
    Ls = [float(x["pooling_factor"]) for x in configs]

    # Bar b counts bs[b] <= L < bs[b+1]; values outside [bs[0], bs[-1]) get no bar.
    counts = [0] * num_bars
    for L in Ls:
        for b in range(num_bars):
            if bs[b] <= L < bs[b + 1]:
                counts[b] += 1
                break

    # Count heavy tables from the raw values so tables beyond the last edge
    # (L >= bs[-1]) are included, consistent with the num_tables denominator.
    heavy_counts = sum(1 for L in Ls if L >= heavy_lookup)
    heavy_pct = 100.0 * heavy_counts / num_tables if num_tables else 0.0

    ax.bar(
        list(range(num_bars)), counts,
        width=1,
        color=plt.get_cmap("tab20c")(18),
        edgecolor="black"
    )
    # Ticks sit on bar edges so each label names a bin boundary.
    ax.set_xticks([x - half_width for x in range(len(bs))])
    ax.set_xticklabels(bs)
    ax.set_title("{} ({} tables)".format(year, num_tables))
    ax.axvline(
        heavy_border - half_width,
        linestyle="--",
        color="gray"
    )
    ax.text(
        heavy_border, max(counts) * 0.95,
        "|> {} ({:.2f}%) heavy tables".format(heavy_counts, heavy_pct)
    )
plt.tight_layout()
plt.savefig('./dataset_histogram.pdf', bbox_inches='tight')
plt.savefig('./dataset_histogram.png', bbox_inches='tight')