In [None]:
import os

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Apply seaborn's "whitegrid" theme globally so every figure below shares one look.
sns.set_theme(style="whitegrid")


In [None]:
# Runs with a longer sequence length than this are dropped from the analysis.
MAX_SEQUENCE_LENGTH = 2000


def _parse_wildcards(path):
    """Parse ``key=value`` pairs from a ``...memory_usage_<k1>=<v1>_<k2>=<v2>.csv`` file name.

    Only the file name (not the directory) is inspected, and only the final
    extension is stripped, so directories or values containing ``memory_usage_``
    or ``.csv`` cannot corrupt the parse. Values are returned as strings.

    Raises
    ------
    IndexError
        If the file name does not contain ``memory_usage_``.
    ValueError
        If a ``_``-separated field has no ``=``.
    """
    stem = os.path.splitext(os.path.basename(path))[0]
    encoded = stem.split('memory_usage_', 1)[1]
    # maxsplit=1 so a value that itself contains '=' stays intact.
    return dict(field.split('=', 1) for field in encoded.split('_'))


frames = []
for path in snakemake.input:
    frame = pd.read_csv(path)
    # Every row of a file belongs to the configuration encoded in its name.
    for key, value in _parse_wildcards(path).items():
        frame[key] = value
    frames.append(frame[frame['length'] <= MAX_SEQUENCE_LENGTH])

# ignore_index: per-file row labels would otherwise repeat across files.
df = pd.concat(frames, ignore_index=True)

In [None]:
# Model names carrying an 'e' denote the efficient implementation: record that in
# its own column, strip the marker from the name, and give quantization readable labels.
# All three expressions below are evaluated against the incoming frame.
df = df.assign(
    efficient=np.where(df['model'].str.contains('e'), 'Efficient', 'Baseline'),
    model=df['model'].str.replace('e', ''),
    q=df['q'].replace({'none': 'bf16', '8bit': 'int8'}),
)

In [None]:
# Summary table written to snakemake.output['table']: human-readable headers,
# 'none' rendered as an empty cell, and runs recorded as -1 GB left out.
_df = df.rename(
    columns={
        'model': 'Model Size',
        'c': 'Checkpointing',
        'lora': 'Lora',
        'd': 'DeepSpeed',
        'mem_gb': 'Memory (GB)',
        'efficient': 'Implementation',
        'length': 'Sequence Length',
    }
).replace({'none': ''})

_df.loc[
    _df['Memory (GB)'] != -1,
    ['Sequence Length', 'Model Size', 'Implementation',
     'Checkpointing', 'Lora', 'DeepSpeed', 'Memory (GB)'],
].to_csv(snakemake.output['table'], index=False)

In [None]:
# Columns that together identify one benchmark configuration.
config_keys = ['model', 'b', 'q', 'c', 'lora', 'd', 'efficient']

# A run that went out of memory is recorded as mem_gb == -1. `length` is NOT one of
# config_keys, so if a configuration hits -1 at ANY sequence length, its mem_gb is
# set to NaN at EVERY length (identical to the previous set_index/.loc loop, but
# vectorized instead of one MultiIndex lookup per OOM row).
# NOTE(review): confirm this whole-configuration propagation is intended; the plots
# below then show such configurations only through their "OOM" labels.
df = df.reset_index(drop=True)
is_oom_config = (
    df['mem_gb'].eq(-1)
    .groupby([df[key] for key in config_keys], dropna=False)
    .transform('any')
)
if is_oom_config.any():
    df.loc[is_oom_config, 'mem_gb'] = np.nan

# Key columns first, matching the column order set_index(...).reset_index() produced.
df = df[config_keys + df.columns.drop(config_keys).tolist()]

In [None]:
# Separate LoRA runs (plotted in their own panel) from the non-LoRA runs.
# Both right-hand expressions are evaluated on the incoming frame before rebinding.
df_lora, df = df[df['lora'] != 'none'], df[df['lora'] == 'none']

In [None]:
# Fixed model-size -> colour mapping shared by both panels and the OOM labels.
# zip() stops at the shorter input, pairing sizes with seaborn's default colours in order.
palette = dict(zip(['8M', '35M', '150M', '650M', '3B', '15B'], sns.color_palette()))

In [None]:
# Label each run with a single memory-saving method for the x-axis.
# Conditions are tested in priority order (first match wins): DeepSpeed offload,
# then checkpointing, then the efficient implementation; otherwise 'Baseline'.
df['method'] = np.select(
    [
        df['d'] == 'True',
        df['c'] == 'True',
        df['efficient'] == 'Efficient',
    ],
    [
        '+Zero-Offload \n(DeepSpeed Stage 2)',
        '+Checkpointing',
        '+FlashAttention',
    ],
    default='Baseline',
)

In [None]:
# Two panels sharing the memory y-axis: a wide left one (ax1) and a half-width right one (ax2).
fig, (ax1, ax2) = plt.subplots(
    nrows=1,
    ncols=2,
    figsize=(8, 4),
    dpi=300,
    sharey=True,
    gridspec_kw=dict(width_ratios=[2, 1]),
)

def errorbar(x):
    """Error-bar callback for seaborn: span the full observed range.

    Parameters
    ----------
    x : array-like with ``min``/``max`` methods
        The values seaborn aggregates into one point.

    Returns
    -------
    tuple
        ``(low, high)`` -- the minimum and maximum of ``x``.
    """
    low = x.min()
    high = x.max()
    return low, high

# y-position (GB) of the dashed reference line; runs whose mem_gb is NaN are
# labelled "OOM" just above it.
# NOTE(review): presumably the usable memory of the benchmark GPU -- confirm.
oom_limit = 81.9
# Explicit category orders. Passing them to pointplot guarantees tick i shows
# order[i], which the OOM label x-positions (order.index(...)) rely on.
method_order = ['Baseline', '+FlashAttention', '+Checkpointing',
                '+Zero-Offload \n(DeepSpeed Stage 2)']
quant_order = ['bf16', 'int8']


def _annotate_oom(ax, df_oom, group_col, order, y, x_offset, spread, rotation):
    """Write one coloured 'OOM' label per model per x-category on ``ax``.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    df_oom : rows to annotate; needs ``model`` and ``group_col`` columns.
    group_col : column holding the x-axis category (e.g. ``'method'`` or ``'q'``).
    order : x-axis category order; category ``c`` sits at x = ``order.index(c)``.
    y : vertical position of the labels.
    x_offset : shift added to every label's x position.
    spread : horizontal step between labels of different models in one category.
    rotation : text rotation in degrees.

    Raises
    ------
    ValueError
        If a category present in ``df_oom`` is missing from ``order``.
    """
    for category, category_df in df_oom.groupby(group_col):
        models = category_df.drop_duplicates('model')
        n = models.shape[0]
        for i, row in enumerate(models.itertuples()):
            # Centre the n labels of this category around its tick (plus x_offset).
            ax.text(x_offset + order.index(category) + (i - n / 2) * spread, y, 'OOM',
                    color=palette[row.model], ha='center', rotation=rotation)


# Left panel: non-LoRA runs by memory-saving method; bars span min..max over lengths.
sns.pointplot(data=df, x='method', y='mem_gb', hue='model', order=method_order,
              palette=palette, ax=ax1, errorbar=errorbar, dodge=True)
_annotate_oom(ax1, df[df['mem_gb'].isna()], 'method', method_order,
              y=oom_limit + 1, x_offset=.2, spread=.15, rotation=60)
ax1.axhline(y=oom_limit, linestyle='--', color='black')
ax1.tick_params(axis='x', labelrotation=10)
ax1.set_xlabel(None)
ax1.set_ylabel('Memory Usage (GB)')
ax1.legend(title=None)

# Right panel: LoRA runs by quantization; OOM (NaN) rows are annotated, not plotted.
sns.pointplot(data=df_lora[~df_lora['mem_gb'].isna()], x='q', y='mem_gb',
              hue='model', dodge=True, order=quant_order, palette=palette,
              legend=False, ax=ax2, errorbar=errorbar)
ax2.axhline(y=oom_limit, linestyle='--', color='black')
ax2.set_xlabel('LoRA Quantization')
ax2.set_ylabel('Memory Usage (GB)')
_annotate_oom(ax2, df_lora[df_lora['mem_gb'].isna()], 'q', quant_order,
              y=oom_limit + 1, x_offset=0, spread=.01, rotation=45)

# sharey=True: setting the limit on ax1 applies to both panels.
ax1.set_ylim(-1, oom_limit + 5)
fig.savefig(snakemake.output['fig_memory'], dpi=300, bbox_inches='tight', transparent=True)