In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
# Load every per-run result CSV and tag it with metadata parsed from its path,
# which is expected to contain `model=<name>_` and `q=<quantization>_` segments.
# The per-run frames are then averaged per (model, precision, variant, study).
frames = []

for path in snakemake.input:
    raw_model = path.split('model=')[1].split('_')[0]
    raw_quant = path.split('q=')[1].split('_')[0]

    # An 'e' anywhere in the model token marks the efficient variant. An unquantized
    # ('none') run is labelled bf16 for efficient models and 32bit for original ones.
    is_efficient = 'e' in raw_model
    efficient_label = 'Efficient' if is_efficient else 'Original'
    quant_label = raw_quant.replace('none', 'bf16' if is_efficient else '32bit')

    # Strip the efficiency marker, fold the numbered 1v1..1v5 checkpoints into a single
    # '1v' model, and rename the c300m / c600m checkpoints to display names.
    model_label = raw_model.replace('e', '')
    for token, display in (
        ('1v5', '1v'), ('1v4', '1v'), ('1v3', '1v'), ('1v2', '1v'), ('1v1', '1v'),
        ('c300m', '300M'), ('c600m', '600M'),
    ):
        model_label = model_label.replace(token, display)

    frames.append(
        pd.read_csv(path).assign(
            efficient=efficient_label,
            quantize=quant_label,
            model=model_label,
            legend=f'{quant_label} ({efficient_label})',
        )
    )

# Several files can map to the same label (e.g. the 1v checkpoints), so average
# their per-study correlations into one row per study.
df = (
    pd.concat(frames)
    .groupby(['model', 'quantize', 'efficient', 'study_id', 'legend'])[['correlation']]
    .mean()
    .reset_index()
)

In [None]:
# Summary table: mean and standard deviation (pandas default ddof=1) of the
# per-study correlations for every (model, quantize, efficient) combination.
summary = (
    df.groupby(['model', 'quantize', 'efficient'])['correlation']
    .agg(['mean', 'std'])
    .add_prefix('correlation_')
    .reset_index()
)
summary.to_csv(snakemake.output['table'], index=False)

In [None]:
# Boxplots of per-study Spearman correlation: one panel per model family, each panel's
# width proportional to how many models it shows, and one legend shared by all panels.
groups = [['1b', '1v'], ['8M', '35M', '150M', '650M', '3B', '15B'], ['300M', '600M']]
group_labels = ['ESM1', 'ESM2', 'ESM-C']
hue_order = ['32bit (Original)', 'bf16 (Efficient)', '8bit (Efficient)', '4bit (Efficient)']

group_sizes = [len(group) for group in groups]
total_size = sum(group_sizes)
width_ratios = [size / total_size for size in group_sizes]
fig, axes = plt.subplots(
    1, len(groups), figsize=(8, 4), dpi=300, gridspec_kw={'width_ratios': width_ratios}, sharey=True
)
for i, (group, ax, group_label) in enumerate(zip(groups, axes, group_labels)):
    sns.boxplot(
        data=df[df['model'].isin(group)], x='model', y='correlation', hue='legend', ax=ax,
        order=group, hue_order=hue_order
    )
    ax.set(xlabel=group_label, ylabel="Spearman's Rank Correlation Coef." if i == 0 else '')
    # Drop the per-panel legend; a single shared legend is drawn on the figure below.
    panel_legend = ax.get_legend()
    if panel_legend is not None:
        panel_legend.remove()
    ax.spines[['right', 'top']].set_visible(False)

# get_legend_handles_labels reads labelled artists on the axes (not the removed legend
# object). All panels share hue_order, so the first panel's handles serve the figure.
handles, labels = axes[0].get_legend_handles_labels()
fig.legend(
    handles, labels, loc='upper center', ncol=len(hue_order), labelspacing=0,
    handletextpad=0.35, columnspacing=1.25, bbox_to_anchor=(0.525, 1.05),
)
fig.tight_layout()
fig.savefig(snakemake.output['fig'], bbox_inches='tight', dpi=300, transparent=True)

In [None]:
# Per-study scatter grid: x = correlation of the original model, y = correlation of the
# efficient variant at a given precision (one row per precision, one column per model).
# Points on the dashed diagonal are studies where both runs agree exactly.
models = ['1b', '1v', '8M', '35M', '150M', '650M', '3B', "15B", '300M', '600M']
precisions = ['bf16', '8bit', '4bit']

fig, axs = plt.subplots(
    len(precisions), len(models), figsize=(2 * len(models), 8), dpi=300, sharey=True, squeeze=False
)

original_runs = df[df['efficient'] == 'Original']
efficient_runs = df[df['efficient'] == 'Efficient']

for row, q in enumerate(precisions):
    for col, model in enumerate(models):
        ax = axs[row, col]

        # Pair the two runs by study_id instead of by row position, so a study missing
        # from either side is dropped rather than shifting every later pair.
        # validate= raises if a study appears more than once on either side.
        paired = pd.merge(
            original_runs.loc[original_runs['model'] == model, ['study_id', 'correlation']],
            efficient_runs.loc[
                (efficient_runs['model'] == model) & (efficient_runs['quantize'] == q),
                ['study_id', 'correlation'],
            ],
            on='study_id',
            suffixes=('_original', '_efficient'),
            validate='one_to_one',
        )
        ax.scatter(paired['correlation_original'], paired['correlation_efficient'], alpha=.3)
        ax.set_xlim((0, 1))
        ax.set_ylim((0, 1))
        ax.plot([0, 1], [0, 1], color='black', linestyle='--')

        # Row labels on the first column, column labels along the bottom row.
        if col == 0:
            ax.set_ylabel(q)
        if row == len(precisions) - 1:
            ax.set_xlabel(model)

# Hide top/right spines on every panel of this figure in one call.
sns.despine(fig=fig)

fig.text(0.5, 0.04, 'Model Size', ha='center')
fig.text(0.08, 0.5, 'Quantization', va='center', rotation='vertical')
fig.savefig(snakemake.output['fig_scatter'], bbox_inches='tight', dpi=300)