In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from workflow import model_color_map

# Global seaborn look for every figure below: default "notebook" scaling
# with a light grid on a white background.
sns.set_theme(context="notebook", style="whitegrid")


In [None]:
def _parse_wildcards(path):
    """Parse the ``key=value`` pairs encoded in a ``memory_usage_*.csv`` file name.

    Example: ``out/memory_usage_a=1_b=x.csv`` -> ``{'a': '1', 'b': 'x'}``.

    Only the text after the *last* ``memory_usage_`` is parsed, so directory
    names containing that token do not matter. A trailing ``.csv`` is removed.
    Pairs are separated by ``_`` (so values must not contain ``_``). Each pair
    is split on its first ``=`` only, so values may contain ``=``. All values
    are returned as strings.

    Raises:
        ValueError: if the name has no ``memory_usage_`` marker or a pair has no ``=``.
    """
    _, marker, stem = path.rpartition('memory_usage_')
    if not marker:
        raise ValueError(f"unexpected input file name (no 'memory_usage_' marker): {path!r}")
    if stem.endswith('.csv'):
        stem = stem[:-len('.csv')]
    return dict(pair.split('=', 1) for pair in stem.split('_'))


# One frame per input CSV, tagged with the wildcard values from its file name.
frames = []
for path in snakemake.input:
    frame = pd.read_csv(path)
    for key, value in _parse_wildcards(path).items():
        frame[key] = value
    frames.append(frame)

# ignore_index=True gives a clean RangeIndex instead of repeating each file's 0..n-1.
df = pd.concat(frames, ignore_index=True).rename(
    columns={'q': 'quantization', 'mem_gb': 'Memory Usage (GB)'}
)

In [None]:
# Drop rows whose memory reading is -1 (NOTE(review): presumably a sentinel
# for a failed / out-of-memory run — confirm). Take an explicit copy so the
# column assignments below act on an independent frame, not a filtered view.
df = df.loc[df['Memory Usage (GB)'] != -1].copy()

# NOTE(review): this removes *every* 'e' in the model name, not just a
# prefix/suffix — confirm that is intended.
df['Model'] = df['model'].str.replace('e', '', regex=False)
# Unquantized runs are labelled 'none'; show them as 16-bit.
df['Precision'] = df['quantization'].str.replace('none', '16bit', regex=False)
# `b` needs an int cast (filename wildcards are parsed as strings). Coerce
# `length` the same way so the product is numeric wherever `length` came from;
# for an already-numeric column this is a no-op.
df['Number of Tokens'] = pd.to_numeric(df['length']) * df['b'].astype(int)

In [None]:
# Custom palette. The project's model keys are paired with it in reverse order;
# zip() stops at the shorter sequence, so models beyond the palette size get no entry.
colors = [
    "#04a3bd",
    "#f0be3d",
    "#931e18",
    "#da7901",
    "#247d3f",
    "#20235b",
]
model_color_map = dict(zip(model_color_map, colors[::-1]))

In [None]:
# FacetGrid creates (and sizes) its own figure, so no separate plt.figure()
# call is made; one here would only leave an extra, empty figure behind.
oom_limit = 81.9  # NOTE(review): presumably the GPU memory capacity in GB — confirm

# One panel per model; lines coloured by precision.
g = (
    sns.FacetGrid(df, col='Model', hue='Precision')
    .map(sns.lineplot, 'Number of Tokens', 'Memory Usage (GB)')
    .add_legend()
)

# Dashed reference line at the memory limit on every facet.
for ax in g.axes.flat:
    ax.axhline(y=oom_limit, linestyle='--', color='black')

# Set the y-range on the grid's own axes rather than via plt.ylim(), which
# only targets whichever axes happens to be "current".
g.set(ylim=(-0.01, oom_limit + 1))

# Save the grid's figure explicitly instead of the implicit current figure.
g.savefig(snakemake.output['fig_memory_len_batch'], dpi=300,
          bbox_inches='tight', transparent=True)