In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from esme.data import read_fai

In [None]:
# Load every runtime table listed in snakemake.input['runtime'] and tag its
# rows with the run settings encoded in the file name.
runtime_frames = []

for path in snakemake.input['runtime']:
    frame = pd.read_csv(path, sep='\t')
    # The text after 'runtime_' is a '_'-separated list of 'key=value' tokens
    # (illustrative: 'runtime_a=1_b=2.txt' -> {'a': '1', 'b': '2'}).
    # The '.txt' suffix and any '.lightning' marker are removed before the
    # tokens are split on '='.
    settings = dict(
        token.replace('.lightning', '').split('=')
        for token in path.split('runtime_')[1].replace('.txt', '').split('_')
    )
    # Each setting becomes a constant, string-valued column on this file's rows.
    for key, value in settings.items():
        frame[key] = value
    runtime_frames.append(frame)

# ignore_index=True: every per-file frame carries its own row labels from
# read_csv, so a plain concat would repeat labels across files. A fresh
# 0..n-1 index keeps label-based alignment in later steps unambiguous.
df = pd.concat(runtime_frames, ignore_index=True)

In [None]:
# Derive the columns used for plotting and export.
# The 'e' test has to read the model labels *before* the 'e' is stripped from
# them, so the mask is computed up front.
has_e_marker = df['model'].str.contains('e')

df = df.assign(
    # Labels containing 'e' are reported as 'Efficient', all others as 'Original'.
    Implementation=np.where(has_e_marker, 'Efficient', 'Original'),
    # Drop the 'e' so both implementations share the same model label.
    model=df['model'].str.replace('e', ''),
    # A 'none' setting in 'q' is displayed as 'bf16'.
    q=df['q'].str.replace('none', 'bf16'),
    # Two divisions by 60 turn 's' into hours (so 's' is presumably seconds).
    hours=df['s'] / 60 / 60,
)

In [None]:
# Write the per-run runtime table with human-readable column names.
display_names = {
    'model': 'Model Size',
    'c': 'Checkpointing',
    'lora': 'Lora',
    'deepspeed': 'DeepSpeed',
    'hours': 'Runtime (hour)',
}
# Columns written to the stats file, in output order.
stats_columns = ['Model Size', 'Checkpointing', 'Lora', 'Implementation', 'Runtime (hour)']

# Cells whose value is exactly 'none' are blanked in the exported table.
_df = df.rename(columns=display_names).replace({'none': ''})
_df.loc[:, stats_columns].to_csv(snakemake.output['stats'], index=False)

In [None]:
# Split the runs by their 'lora' setting: rows with any value other than
# 'none' go to df_finetune, and df keeps only the 'none' rows.
# NOTE: df is rebound here, so this cell must run exactly once per load;
# a second run would leave df_finetune empty.
# .copy() makes both results independent frames. Later cells add columns to
# them, and that should not depend on pandas' view-versus-copy tracking of a
# filtered selection.
df_finetune = df[df['lora'] != 'none'].copy()
df = df[df['lora'] == 'none'].copy()

In [None]:
# Bar chart of the fine-tuning runs in df_finetune: runtime per model label,
# one bar colour per 'q' value.
fig = plt.figure(figsize=(4, 4), dpi=300)
sns.set_theme(style="whitegrid")

# '_kb' is a plain copy of 'hours' used as the plotted column.
df_finetune = df_finetune.assign(_kb=df_finetune['hours'])

ax = sns.barplot(df_finetune, x='model', y='_kb', hue='q')
ax.set_ylabel('One Epoch on UniProtKB (hours)', fontsize=11)
ax.set_xlabel('Finetuning of Model Size')
# Keep the legend entries but drop the legend title.
ax.legend(title=None)

fig.savefig(snakemake.output['fig_lora'], bbox_inches='tight', dpi=300, transparent=True)

In [None]:
# Bar chart of the runs in df: runtime per model label, one bar colour per
# 'Implementation' value.
fig = plt.figure(figsize=(4, 4), dpi=300)
sns.set_theme(style="whitegrid")

# '_kb' is a plain copy of 'hours' used as the plotted column.
df = df.assign(_kb=df['hours'])

ax = sns.barplot(df, x='model', y='_kb', hue='Implementation')
ax.set_ylabel('One Epoch on UniProtKB (hours)', fontsize=11)
# Axis label spelling corrected (was 'Traning').
ax.set_xlabel('Training of Model Size')
# NOTE(review): the other two figures in this notebook call legend(title=None);
# this one keeps the default legend title. Confirm that is intended.

fig.savefig(snakemake.output['fig'], bbox_inches='tight', dpi=300, transparent=True)

In [None]:
# Summarise the two sequence indexes (row count and total length) and write
# the summary as a tab-separated table.
# read_fai is a project helper; presumably it parses a FASTA index (.fai)
# into a table with one row per sequence. Confirm against esme.data.
df_fai_uniprotkb = read_fai(snakemake.input['fai_uniprotkb'])
df_fai_uniprot50 = read_fai(snakemake.input['fai_uniprot50'])

# Token count of a dataset = sum of its 'length' column.
num_token_uniprotkb = df_fai_uniprotkb['length'].sum()
num_token_uniprot50 = df_fai_uniprot50['length'].sum()

# Dataset label -> index table, in the order the rows are written.
fai_by_dataset = {
    'UniProtKB': df_fai_uniprotkb,
    'UniProt50': df_fai_uniprot50,
}

token_stats = pd.DataFrame({
    'tokens': [num_token_uniprotkb, num_token_uniprot50],
    'dataset': list(fai_by_dataset),
    'num_proteins': [fai.shape[0] for fai in fai_by_dataset.values()],
})
token_stats.to_csv(snakemake.output['token_stats'], index=False, sep='\t')

In [None]:
# Bar chart of the estimated runtime on the second dataset: the measured
# hours are rescaled by the ratio of the two datasets' token counts.
fig = plt.figure(figsize=(4, 4), dpi=300)
sns.set_theme(style="whitegrid")

# hours * tokens(uniprot50) / tokens(uniprotkb), evaluated left to right.
df = df.assign(_50=df['hours'] * num_token_uniprot50 / num_token_uniprotkb)

ax = sns.barplot(df, x='model', y='_50', hue='Implementation')
# NOTE(review): the label says 'UniRef50' while the token table names this
# dataset 'UniProt50'. Confirm which name is intended.
ax.set_ylabel('Estimated One Epoch on UniRef50 (hours)', fontsize=11)
ax.set_xlabel('Training of Model Size')
# Keep the legend entries but drop the legend title.
ax.legend(title=None)

fig.savefig(snakemake.output['fig_estimate'], bbox_inches='tight', dpi=300, transparent=True)