In [None]:
import pandas as pd
import torch
from torch import nn
import safetensors.torch as sf 
from esme import ESM
from esme.pooling import BinaryLearnedAggregation
from esme.trainer import RegressionTrainer
from workflow.gb1_aav.gb1 import Gb1DataModule

# Request the 'medium' float32 matmul precision setting from PyTorch for the
# whole run; this is a global switch and is set before any model is built.
torch.set_float32_matmul_precision('medium')

# Target device for the model and all tensors, as supplied by the Snakemake rule.
device = snakemake.params['device']

In [None]:
# Sequences are truncated to 4096 only for models whose wildcard starts with
# '1ve' or '1be'; every other model gets no truncation (None).
model_name = snakemake.wildcards['model']
truncate_len = 4096 if model_name.startswith(('1ve', '1be')) else None

# Data module built from the input FASTA; batches are sized by token budget
# and loaded with as many workers as the rule has threads.
datamodule = Gb1DataModule(
    snakemake.input['fasta'],
    token_per_batch=10_000,
    num_workers=snakemake.threads,
    truncate_len=truncate_len,
)

In [None]:
# Load the pretrained ESM backbone from the checkpoint given by the rule,
# placing it on the configured device.
_model = ESM.from_pretrained(
    snakemake.input['model'],
    device=device,
)

In [None]:
# The 'lora' wildcard is the literal string 'none' when no adapter was used;
# any other value means LoRA weights must be loaded into the backbone.
wld_lora = snakemake.wildcards['lora']
uses_lora = wld_lora != 'none'
if uses_lora:
    _model.load_lora(snakemake.input['lora_weights'])

In [None]:
# Build the aggregation head from the backbone's attention-head count and
# embedding width, then restore its weights from the safetensors file.
head = BinaryLearnedAggregation(_model.attention_heads, _model.embed_dim).to(device)
head_state = sf.load_file(snakemake.input['head_weights'], device=device)
head.load_state_dict(head_state)

In [None]:
# Wrap backbone + head in the regression trainer and move it to the device.
# NOTE(review): reduction=None presumably keeps one output per sequence
# rather than an aggregated value - confirm against RegressionTrainer.
model = RegressionTrainer(
    _model,
    head,
    reduction=None,
).to(device)

# Switch to evaluation mode before running predictions.
model.eval()

In [None]:
# Per-batch model outputs and the matching labels, concatenated afterwards.
preds, targets = [], []

# Gradients are not needed for evaluation.
with torch.no_grad():
    for batch in datamodule.test_dataloader():
        tokens = batch['token'].to(device)
        cu_lens = batch['cu_lens'].to(device)

        # The model takes the tokens plus a (cu_lens, max_len) pair.
        # NOTE(review): presumably describes sequence boundaries in a packed
        # batch - confirm against the data module.
        output = model(tokens, (cu_lens, batch['max_len']))

        # Move outputs to the CPU so they can be concatenated and exported.
        preds.append(output.cpu())
        targets.append(batch['label'])

# Cast to float32 before converting to NumPy, then write predictions next to
# their targets as a two-column CSV.
pred_values = torch.cat(preds).float().numpy()
target_values = torch.cat(targets).float().numpy()
df = pd.DataFrame({'pred': pred_values, 'target': target_values})
df.to_csv(snakemake.output['pred'], index=False)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import spearmanr

# Scatter plot of predictions against measured values, drawn on an explicit
# Axes so every call below targets this figure rather than pyplot's implicit
# "current" figure.
fig, ax = plt.subplots(figsize=(4, 4), dpi=300)
sns.scatterplot(data=df, x='pred', y='target', alpha=0.5, ax=ax)
ax.set_xlabel('Predicted')
ax.set_ylabel('GB1 Measurement')
# Both axes are fixed to the same range, slightly wider than [0, 1].
ax.set_xlim(-.05, 1.05)
ax.set_ylim(-.05, 1.05)
sns.despine(ax=ax)

# Rank correlation between predictions and targets, annotated in the upper
# left of the plot (position is given in data coordinates).
spearman = spearmanr(df['pred'], df['target'])
# The " = " separator keeps the label from rendering as e.g. "ρ0.87".
ax.text(.05, .95, r'$\rho$' f' = {spearman.correlation:.2f}')

fig.savefig(snakemake.output['fig'], bbox_inches='tight')