In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
# Load per-protein pseudo-perplexities from every input file into one long frame.
# Each file is read as a header-less, tab-separated (protein_id, score) table and
# tagged with the model it came from, parsed from the 'model=' part of its path.
frames = list()

for path in snakemake.input:
    # Model tag: the text after 'model=' up to the next '_' (may still end in '.tsv').
    tag = path.split('model=')[1].split('_')[0]

    # Any 'e' in the tag marks the efficient variant; stripping it and the extension
    # leaves an identifier shared by the original and efficient runs of the same
    # checkpoint (e.g. '1v3').
    # NOTE(review): the 'e' heuristic breaks if a future tag contains an 'e'.
    submodel = tag.replace('.tsv', '').replace('e', '')

    # Files have no header and are tab-separated, so read each one once with an
    # explicit layout.
    scores = pd.read_csv(path, header=None, names=['protein_id', 'Pseudo-Perplexity'], sep='\t')
    scores['efficient'] = 'Efficient' if 'e' in tag else 'Original'
    # Collapse the numbered ESM-1v checkpoints (1v1-1v5) onto a single panel label.
    scores['model'] = (submodel.replace('1v5', '1v').replace('1v4', '1v')
                       .replace('1v3', '1v').replace('1v2', '1v').replace('1v1', '1v'))
    # Keep the uncollapsed identifier so each checkpoint can later be paired with
    # its own efficient counterpart instead of every checkpoint sharing the label.
    scores['submodel'] = submodel

    frames.append(scores)

df = pd.concat(frames, ignore_index=True)

In [None]:
from sklearn.metrics import mean_squared_error
from more_itertools import flatten

# Scatter each model's original vs. efficient pseudo-perplexity per protein (one
# panel per model, dashed y = x reference line) and write summary statistics.
models = ['1b', '1v', '8M', '35M', '150M', '650M', '3B', "15B", 'c300m', 'c600m']

# Two rows of panels; round the column count up so an odd number of models still
# gets one panel each instead of the last model being dropped by zip() below.
n_rows = 2
n_cols = int(np.ceil(len(models) / n_rows))

fig, axs = plt.subplots(n_rows, n_cols, figsize=(len(models), 3),
                        dpi=300, sharex=True, sharey=True, squeeze=False)

axes = list(flatten(axs))

# Panels beyond the number of models (odd len(models)) stay blank.
for spare_ax in axes[len(models):]:
    spare_ax.set_visible(False)

model_label = {
    '1b': 'ESM-1b', 
    '1v': 'ESM-1v', 
    '8M': 'ESM2-8M', 
    '35M': 'ESM2-35M', 
    '150M': 'ESM2-150M',
    '650M': 'ESM2-650M',
    '3B': 'ESM2-3B', 
    '15B': 'ESM2-15B', 
    'c300m': 'ESMC-300M', 
    'c600m': 'ESMC-600M'
}

# Shared limits for every panel (computed over all loaded scores), so the y = x
# diagonal drawn in axes coordinates is the identity line in data coordinates.
lim = (df['Pseudo-Perplexity'].min(), df['Pseudo-Perplexity'].max())

# Pair scores per protein. If a 'submodel' column is present (several checkpoints
# sharing one model label, e.g. ESM-1v), match on it too; otherwise duplicated
# protein IDs within a model would be joined as a cross product.
join_keys = ['protein_id'] + (['submodel'] if 'submodel' in df.columns else [])


def _variant_scores(frame, model, variant, column, keys):
    """Return one model/variant's pseudo-perplexities as a one-column frame.

    frame   -- long frame with 'efficient', 'model', 'Pseudo-Perplexity' and ``keys`` columns
    model   -- value of the 'model' column to select
    variant -- value of the 'efficient' column to select ('Original' or 'Efficient')
    column  -- name given to the score column in the result
    keys    -- columns used as the index, i.e. the join key
    """
    subset = frame[(frame['efficient'] == variant) & (frame['model'] == model)]
    return subset.set_index(keys)[['Pseudo-Perplexity']].rename(columns={'Pseudo-Perplexity': column})


stats = list()

for ax, model in zip(axes, models):
    ax.plot([0, 1], [0, 1], color='black', linestyle='--', transform=ax.transAxes, zorder=1)

    original = _variant_scores(df, model, 'Original', 'original', join_keys)
    efficient = _variant_scores(df, model, 'Efficient', 'efficient', join_keys)
    paired = original.join(efficient, how='inner').reset_index()

    if paired.empty:
        # Fail with a clear message instead of sklearn's generic zero-sample error.
        raise ValueError(f'No protein was scored by both the original and efficient {model!r} model')

    sns.scatterplot(data=paired, x='original', y='efficient', ax=ax, alpha=.25, zorder=2)
    ax.text(0.05, 0.95, model_label[model], va='center', transform=ax.transAxes, fontsize=10)
    ax.set_xlabel(None)
    ax.set_ylabel(None)

    ax.set_xlim(lim)
    ax.set_ylim(lim)

    stats.append({
        'model': model,
        'original': paired['original'].mean(),
        'efficient': paired['efficient'].mean(),
        'original_std': paired['original'].std(),
        'efficient_std': paired['efficient'].std(),
        'mse': mean_squared_error(paired['efficient'], paired['original']),
    })

# despine acts on every axes of the figure, so a single call after the loop suffices.
sns.despine(fig=fig)

fig.text(0.5, -0.02, 'Original Pseudo-Perplexity', ha='center')
fig.text(0.08, 0.5, 'Efficient Pseudo-Perplexity', va='center', rotation='vertical')

fig.savefig(snakemake.output['fig'], bbox_inches='tight', dpi=300)
pd.DataFrame(stats).to_csv(snakemake.output['stats'], index=False)