In [None]:
import torch
import safetensors.torch as safetensors
import pandas as pd
from tqdm import tqdm
from more_itertools import collapse
from workflow.meltome.meltome import MeltomeDataModule, MeltomeModel

In [None]:
# Device used for the model and for every batch; supplied via the Snakemake rule's params.
device = snakemake.params['device']

# The wildcard value 'none' is a sentinel that is mapped to None.
# NOTE(review): `quantization` is not referenced again in the cells visible here —
# confirm whether it is meant to be forwarded to the model.
quantization = snakemake.wildcards['quantize']
if quantization == 'none':
    quantization = None

model = MeltomeModel(snakemake.input['model'], device=device)

# LoRA weights are only loaded when the 'lora' wildcard is something other than 'none'.
wld_lora = snakemake.wildcards['lora']
if wld_lora != 'none':
    lora_path = snakemake.input['lora_weights']
    model.plm.load_lora(lora_path)

# Load the head weights from the safetensors file into model.head.
head_path = snakemake.input['head_weights']
safetensors.load_model(model.head, head_path)

model = model.to(device)

In [None]:
# Dataset location comes from the Snakemake rule's inputs.
dataset_path = snakemake.input['dataset']

datamodule = MeltomeDataModule(
    dataset_path,
    token_per_batch=50_000,
    num_workers=4,
    # presumably leaves room for two special tokens in a 4096-token window — TODO confirm
    truncate_len=4096 - 2,
)

In [None]:
def predict(dl):
    """Run the model over a dataloader and tabulate predictions against labels.

    Uses the notebook-level ``model`` and ``device`` defined in earlier cells.

    Args:
        dl: Iterable of batch dicts providing the keys ``'token'``,
            ``'cu_lens'``, ``'max_len'``, ``'indices'`` and ``'label'``.

    Returns:
        pandas.DataFrame with one row per flattened value and two columns:
        ``'Predicted Melting Point'`` (model outputs) and
        ``'True Melting Point'`` (batch labels).
    """
    preds, targets = [], []

    # Put the model in eval mode and disable gradient tracking for the whole pass.
    model.eval()
    with torch.no_grad():
        for batch in tqdm(dl):
            output = model(
                batch['token'].to(device),
                (batch['cu_lens'].to(device), batch['max_len']),
                batch['indices'].to(device)
            )
            preds.append(output.cpu().float().numpy())
            targets.append(batch['label'].cpu().float().numpy())

    # collapse() flattens the list of per-batch arrays into one flat list of scalars.
    preds = list(collapse(preds))
    targets = list(collapse(targets))

    # Model outputs go under 'Predicted', dataset labels under 'True'
    # (the two were previously swapped, mislabeling the exported CSV).
    return pd.DataFrame({
        'Predicted Melting Point': preds,
        'True Melting Point': targets
    })

In [None]:
# Predict on the test split and write the table to the rule's 'predictions' output.
test_loader = datamodule.test_dataloader()
df = predict(test_loader)

predictions_path = snakemake.output['predictions']
df.to_csv(predictions_path, index=False)

In [None]:
# import matplotlib.pyplot as plt
# import seaborn as sns
# from scipy.stats import spearmanr

# plt.figure(figsize=(4, 4), dpi=300)
# plt.plot([0, 100], [0, 100], color='black', linestyle='--', alpha=0.5)
# sns.scatterplot(data=df, x='Predicted Melting Point', y='True Melting Point', alpha=0.5)
# stats = spearmanr(df['Predicted Melting Point'], df['True Melting Point'])
# plt.text(0.05, 0.9, r'$\rho$' + f': {stats.correlation:.2f}', transform=plt.gca().transAxes)
# sns.despine()
# plt.savefig(snakemake.output['scatter_test'], bbox_inches='tight', dpi=300, transparent=True)