In [None]:
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from esme.alphabet import padding_idx
from esme.data import FastaTokenDataset, FastaDataset

In [None]:
# Dataloader that groups sequences into batches of 16 entries.
# `snakemake` is not defined in this notebook; it is expected to be injected
# by the Snakemake notebook runner.
# NOTE(review): drop_last=True presumably discards a trailing incomplete
# batch — confirm against esme's `to_dataloader`.
protein_dataset = FastaDataset(snakemake.input['fasta'])
dl_protein = protein_dataset.to_dataloader(batch_size=16, drop_last=True)

In [None]:
# Dataloader built from the same FASTA file, configured with
# token_per_batch=100_000 instead of a fixed number of sequences.
# NOTE(review): token_per_batch presumably caps the tokens packed into one
# batch and drop_last=True presumably discards an incomplete final batch —
# confirm against esme's `FastaTokenDataset`.
token_dataset = FastaTokenDataset(
    snakemake.input['fasta'],
    token_per_batch=100_000,
    drop_last=True,
)
dl_token = token_dataset.to_dataloader()

In [None]:
# Record, for every batch yielded by `dl_protein`, its size along the first
# dimension and the number of non-padding entries; one row per batch.
protein_records = []

# tqdm wraps the dataloader itself rather than the enumerate object:
# enumerate() has no len(), so wrapping it leaves tqdm with a bare counter,
# whereas wrapping the dataloader lets tqdm show a total and ETA whenever the
# dataloader defines a length.
for i, tokens in enumerate(tqdm(dl_protein)):
    protein_records.append({
        'step': i,
        # size of the batch's first dimension; assumed to be the number of
        # proteins in the batch — TODO confirm
        'protein': tokens.shape[0],
        # number of entries different from the padding index
        'tokens': (tokens != padding_idx).sum().item()
    })

# Build the frame once from the collected records (keeps the list and the
# DataFrame under separate names).
df_protein = pd.DataFrame(protein_records)

In [None]:
# Two-panel figure for the per-protein batching statistics in `df_protein`:
# left = tokens per batch, right = proteins per batch.
fig, axs = plt.subplots(1, 2, figsize=(8, 4), dpi=300)
ax_tokens, ax_proteins = axs

fig.suptitle('Batch Size by Proteins')

# Summary statistics of the token counts, shown as an annotation in the left panel.
token_counts = df_protein['tokens']
token_summary = (
    f"min: {token_counts.min()}\n"
    f"max: {token_counts.max()}\n"
    f"mean: {token_counts.mean():.0f}\n"
    f"std: {token_counts.std():.0f}"
)

# Left panel: tokens per batch over steps.
sns.lineplot(data=df_protein, x='step', y='tokens', ax=ax_tokens)
ax_tokens.set_ylabel('Number of Tokens per Batch')
# Position is given in axes-fraction coordinates via transAxes.
ax_tokens.text(0.6, 0.75, token_summary, transform=ax_tokens.transAxes)

# Right panel: proteins per batch over steps.
sns.lineplot(data=df_protein, x='step', y='protein', ax=ax_proteins)
ax_proteins.set_ylabel('Number of Proteins per Batch')

sns.despine()
plt.savefig(snakemake.output['fig_protein'], bbox_inches='tight', dpi=300, transparent=True)

In [None]:
# Record, for every batch yielded by `dl_token`, its size along the first
# dimension and the number of non-padding entries; one row per batch.
token_records = []

# tqdm wraps the dataloader itself rather than the enumerate object:
# enumerate() has no len(), so wrapping it leaves tqdm with a bare counter,
# whereas wrapping the dataloader lets tqdm show a total and ETA whenever the
# dataloader defines a length (tqdm falls back to a counter otherwise).
for i, tokens in enumerate(tqdm(dl_token)):
    token_records.append({
        'step': i,
        # size of the batch's first dimension; assumed to be the number of
        # proteins in the batch — TODO confirm
        'protein': tokens.shape[0],
        # number of entries different from the padding index
        'tokens': (tokens != padding_idx).sum().item()
    })

# Build the frame once from the collected records (keeps the list and the
# DataFrame under separate names).
df_token = pd.DataFrame(token_records)

In [None]:
# Two-panel figure for the per-token batching statistics in `df_token`:
# left = tokens per batch, right = proteins per batch.
fig, axs = plt.subplots(1, 2, figsize=(8, 4), dpi=300)
ax_tokens, ax_proteins = axs

fig.suptitle('Batch Size by Token')

# Summary statistics annotated in each panel.
token_counts = df_token['tokens']
token_summary = (
    f"min: {token_counts.min()}\n"
    f"max: {token_counts.max()}\n"
    f"mean: {token_counts.mean():.0f}\n"
    f"std: {token_counts.std():.0f}"
)
protein_counts = df_token['protein']
protein_summary = (
    f"min: {protein_counts.min()}\n"
    f"max: {protein_counts.max()}\n"
    f"mean: {protein_counts.mean():.2f}\n"
    f"std: {protein_counts.std():.2f}"
)

# Left panel: tokens per batch over steps.
sns.lineplot(data=df_token, x='step', y='tokens', ax=ax_tokens)
ax_tokens.set_ylabel('Number of Tokens per Batch')
# Position is given in axes-fraction coordinates via transAxes.
ax_tokens.text(0.1, 0.1, token_summary, transform=ax_tokens.transAxes)
# NOTE(review): the upper limit sits just above 100_000, presumably to match
# the token_per_batch value used when building `dl_token` — keep in sync.
ax_tokens.set_ylim(0, 100_500)

# Right panel: proteins per batch over steps.
sns.lineplot(data=df_token, x='step', y='protein', ax=ax_proteins)
ax_proteins.set_ylabel('Number of Proteins per Batch')
ax_proteins.text(0.1, 0.1, protein_summary, transform=ax_proteins.transAxes)

sns.despine()
plt.savefig(snakemake.output['fig_tokens'], bbox_inches='tight', dpi=300, transparent=True)