In [None]:
import re

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

In [None]:
# Each runtime CSV carries its run configuration in the path as
# underscore-separated `model=<name>` and `q=<quantization>` tokens.
# Parse those tokens by key rather than by position, so underscores elsewhere
# in the path (e.g. directory names) cannot shift them.
# The leading `^|[_/]` anchor keeps e.g. `seq=` from being read as `q=`.
RUN_TOKEN = re.compile(r'(?:^|[_/])(model|q)=([^_/]+)')

runtime_frames = []
for path in snakemake.input['runtime']:
    # Later matches win, so the file-name tokens take precedence over directories.
    tokens = dict(RUN_TOKEN.findall(path))
    if 'model' not in tokens or 'q' not in tokens:
        raise ValueError(f'Cannot parse model/quantization from file name: {path}')
    run = pd.read_csv(path)
    run['model'] = tokens['model']
    run['Quantization'] = tokens['q']
    runtime_frames.append(run)

# ignore_index: every CSV has its own 0..n index; keep the combined index unique.
df = pd.concat(runtime_frames, ignore_index=True)

In [None]:
# A model name ending in 'e' is the efficient implementation; any other is the original.
df['Implementation'] = np.where(df['model'].str.endswith('e'), 'Efficient', 'Original')
# Unquantized runs are tagged 'none' in the file names and are reported as bf16.
df['Quantization'] = df['Quantization'].str.replace('none', 'bf16', regex=False)
# Strip only the trailing 'e' marker (matching the endswith test above) so both
# implementations of a model share one size label; other 'e' characters are kept.
df['Model Size'] = df['model'].str.replace(r'e$', '', regex=True)

In [None]:
# Export the successful runs (runtime == -1 marks an OOM failure) as a table
# with human-readable column names. Quantization has already been mapped
# from 'none' to 'bf16' in the previous cell, so no further replacement is needed.
runtime_table = (
    df.loc[df['runtime'] != -1,
           ['Model Size', 'length', 'Quantization', 'runtime', 'Implementation']]
    .rename(columns={'length': 'Sequence Length', 'runtime': 'Runtime (sec)'})
)
runtime_table.to_csv(snakemake.output['table'], index=False)

In [None]:
# Assign each model size, smallest to largest, the next colour of seaborn's
# current default palette.
MODEL_SIZE_ORDER = ['8M', '35M', '150M', '650M', '3B', '15B']
palette = dict(zip(MODEL_SIZE_ORDER, sns.color_palette()))

In [None]:
# Runtime vs. sequence length for the bf16 runs. Runs that ran out of memory
# are stored with runtime == -1; they are labelled 'OOM' instead of plotted.
bf16_runs = df[df['Quantization'] == 'bf16']
bf16_ok = bf16_runs[bf16_runs['runtime'] != -1]

fig, ax = plt.subplots(figsize=(6, 2), dpi=300)
# Pass the shared palette so line colours match the OOM label colours below.
sns.lineplot(data=bf16_ok, x='length', y='runtime', hue='Model Size',
             style='Implementation', palette=palette, ax=ax)
sns.despine(ax=ax)
ax.set_xlabel('Sequence Length')
ax.set_ylabel('Runtime (Seconds)')
ax.set_xlim((0, 3550))
ax.legend(loc=(-0.11, 1.03), ncol=5, fontsize=8)

# Highest recorded bf16 runtime per model; max() also sees the -1 sentinel, so a
# model with no successful run still gets an entry (and no KeyError below).
oom_limit = bf16_runs.groupby('model')['runtime'].max()
# One label per model, at its shortest failing sequence length.
oom_runs = (bf16_runs[bf16_runs['runtime'] == -1]
            .sort_values('length', kind='stable')
            .drop_duplicates('model'))
for model, size, length in zip(oom_runs['model'], oom_runs['Model Size'], oom_runs['length']):
    # Colour by Model Size (the palette key), not by the raw model name,
    # which carries an 'e' suffix for efficient implementations.
    ax.text(length + 50, oom_limit[model] - .5, 'OOM',
            color=palette[size], ha='center')

fig.savefig(snakemake.output['fig'], bbox_inches='tight', dpi=300)

In [None]:
# Runtime distribution of the efficient implementation per model size and
# quantization. OOM runs (runtime == -1) are excluded, as in the exported
# table and the line plot, so the sentinel does not distort the boxes.
efficient_ok = df[(df['Implementation'] == 'Efficient') & (df['runtime'] != -1)]

fig, ax = plt.subplots(figsize=(6, 2), dpi=300)
sns.boxplot(data=efficient_ok, x='Model Size', y='runtime', hue='Quantization', ax=ax)
sns.despine(ax=ax)
ax.set_ylabel('Runtime (Seconds)')
fig.savefig(snakemake.output['fig_quantize'], bbox_inches='tight', dpi=300)