In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
# Load every per-run memory CSV listed by Snakemake and tag each one with the
# model name and quantization level parsed from its path: the path is split on
# '_', the 3rd field holds 'model=<name>' and the 4th holds 'q=<level>'.
frames = []
for path in snakemake.input['memory']:
    fields = path.split('_')
    frames.append(
        pd.read_csv(path).assign(
            model=fields[2].replace('model=', ''),
            Quantization=fields[3].replace('q=', ''),
        )
    )

if not frames:
    raise ValueError("snakemake.input['memory'] is empty: no memory CSVs to load")

# ignore_index gives a unique RangeIndex instead of repeating each CSV's own
# 0..n-1 index, so boolean masks derived from df in later cells align safely.
df = pd.concat(frames, ignore_index=True)

In [None]:
# A trailing 'e' on the model name marks the efficient variant. 'Model Size'
# strips only that trailing suffix (e.g. '8Me' -> '8M'), so the check and the
# strip agree and any other 'e' in a name is left untouched.
df['Efficient'] = df['model'].str.endswith('e')
df['Model Size'] = df['model'].str.replace(r'e$', '', regex=True)

In [None]:
# Export the successfully measured runs (OOM runs are stored as mem_gb == -1)
# as a table with human-readable headers; 'none' values are relabeled 'bf16'.
_df = (
    df.replace('none', 'bf16')
    .rename(columns={'length': 'Sequence Length', 'mem_gb': 'Memory (GB)'})
)

# Filter on _df's own (renamed) column instead of a mask built from df, so the
# selection does not rely on the two frames sharing an identical index.
(
    _df[_df['Memory (GB)'] != -1]
    [['Model Size', 'Sequence Length', 'Quantization', 'Memory (GB)', 'Efficient']]
    .to_csv(snakemake.output['table'], index=False)
)

In [None]:
# Fixed model-size -> color mapping, assigned from seaborn's default color
# cycle in increasing model-size order.
palette = dict(zip(['8M', '35M', '150M', '650M', '3B', '15B'],
                   sns.color_palette()))

In [None]:
# Memory usage vs. protein length for the unquantized runs (Quantization ==
# 'none'). Runs that ran out of memory are stored as mem_gb == -1: they are
# excluded from the lines and instead marked with an 'OOM' label at the limit.
oom_limit = 80  # GB; drawn as a dashed horizontal reference line
bf16_runs = df[df['Quantization'] == 'none']

fig, ax = plt.subplots(figsize=(6, 2), dpi=300)
# Pass the fixed palette so line colors match the OOM label colors below;
# otherwise seaborn colors hue levels by their order of appearance in the data.
sns.lineplot(data=bf16_runs[bf16_runs['mem_gb'] != -1],
             x='length', y='mem_gb', hue='Model Size', palette=palette,
             style='Efficient', style_order=[True, False], ax=ax)
sns.despine(ax=ax)
ax.set_xlabel('Protein Length (Number of Residues)')
ax.set_ylabel('Memory Usage (GB)')
# NOTE(review): 3510 is a hard-coded x upper bound, presumably just past the
# longest benchmarked length — confirm.
ax.set_xlim(0, 3510)

ax.axhline(y=oom_limit, linestyle='--', color='black')
ax.set_ylim(-1, oom_limit + 5)

# Annotate only OOMs from the runs actually plotted (bf16), and look colors up
# by 'Model Size': the raw 'model' name keeps the efficient-variant 'e' suffix,
# which is not a palette key.
oom_runs = bf16_runs[bf16_runs['mem_gb'] == -1]
for length, size in zip(oom_runs['length'], oom_runs['Model Size']):
    ax.text(length, oom_limit + 1, 'OOM',
            color=palette[size], ha='center', rotation=60)

# Legend intentionally suppressed for this figure.
ax.legend([], [], frameon=False)
fig.savefig(snakemake.output['fig'], bbox_inches='tight', dpi=300)

In [None]:
# Memory usage of the efficient model variants per quantization level, split
# into two panels: the 3B/15B models on the right, all smaller ones on the left.
fig, axs = plt.subplots(1, 2, figsize=(6, 2), dpi=300, width_ratios=[4, 2])
fig.subplots_adjust(wspace=0.15)

# Drop OOM runs (mem_gb == -1), as the other cells do, so the -1 sentinel does
# not distort the boxes; 'none' is relabeled 'bf16' for the legend.
efficient_runs = (
    df[df['Efficient'] & (df['mem_gb'] != -1)]
    .replace({'none': 'bf16'})
)
is_large = efficient_runs['Model Size'].isin({'3B', '15B'})
for panel_ax, panel_data in zip(axs, [efficient_runs[~is_large], efficient_runs[is_large]]):
    sns.boxplot(data=panel_data, x='Model Size', y='mem_gb',
                hue='Quantization', ax=panel_ax)
    panel_ax.set_xlabel(None)
    panel_ax.set_ylabel(None)

axs[0].legend(fontsize=9, title='Quantization', title_fontsize=9)
# The right panel shares the left panel's legend; remove its own copy.
if axs[1].get_legend() is not None:
    axs[1].get_legend().remove()
axs[0].set_ylabel('Memory Usage (GB)')
# Axis label centered under the left panel (the original placed text at the
# center category x=1.5 using data-dependent y coordinates).
axs[0].set_xlabel('Model Parameter Size')
sns.despine(fig=fig)
oom_limit = 49.1  # NOTE(review): not used by any visible cell — confirm intent
fig.savefig(snakemake.output['fig_quantize'], bbox_inches='tight', dpi=300)