In [21]:
import os
import torch
from chemprop import models
from chemprop.data import build_dataloader
import numpy as np
import pandas as pd
from lightning import pytorch as pl
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score
from src.utils import load_known_rxns
from src.featurizer import RCVNReactionMolGraphFeaturizer, MultiHotAtomFeaturizer, MultiHotBondFeaturizer
from src.data import RxnRCDatapoint, RxnRCDataset

In [None]:
# Identifiers that are interpolated into the data / model paths used in this notebook.
ds_name, toc = 'sprhea', 'sp_folded_pt'
gs_name = 'rc_gnn_two_channel_mean_agg_binaryffn_pred_neg_1'

# Sampling settings that appear in the split file names.
seed, neg_multiple = 1234, 1

# Lookup of known reactions; indexed by the 'feature' field of each test row later on.
known_rxns = load_known_rxns("../data/{}/known_rxns_{}.json".format(ds_name, toc))

In [None]:
# Prefix shared by the checkpoint directories of all splits. The seed comes from the
# config cell so that it cannot drift from the seed used in the test-data file names.
model_pref = f"/projects/p30041/spn1560/hiec/artifacts/model_evals/gnn/{gs_name}_{ds_name}_{toc}_epochs_seed_{seed}_split_"
n_splits = 5
featurizer = RCVNReactionMolGraphFeaturizer(
    atom_featurizer=MultiHotAtomFeaturizer.no_stereo(),
    bond_featurizer=MultiHotBondFeaturizer()
)

# Per-split outputs, keyed by the 0-based split index `i`. Without this, every
# iteration would overwrite the previous split's predictions.
results_by_split = {}

for i in range(n_splits):
    # NOTE(review): checkpoint directories are numbered i+1 ("1_of_5" ...) while the
    # test-data files below use i ("0_split_idx" ...) — confirm these refer to the same split.
    model_dir = model_pref + f"{i+1}_of_{n_splits}/version_0/checkpoints/"

    # List the checkpoint *directory* (the original listed `model_path`, which is not
    # defined yet on the first iteration). Sort so the chosen file does not depend on
    # the arbitrary order returned by os.listdir.
    checkpoint_files = sorted(os.listdir(model_dir))
    if not checkpoint_files:
        raise FileNotFoundError(f"No checkpoint file found in {model_dir}")
    model_path = os.path.join(model_dir, checkpoint_files[0])

    mpnn = models.MPNN.load_from_file(model_path, map_location=torch.device('cpu'))
    test_data_path = f"/scratch/spn1560/{ds_name}_{toc}_{n_splits}_splits_{seed}_seed_{neg_multiple}_neg_multiple_{i}_split_idx_test.npy"
    test_data = np.load(test_data_path)

    # One datapoint per test row: the reaction is looked up by the row's 'feature'
    # field, the label comes from 'y' and the extra descriptor from 'sample_embed'.
    datapoints_test = [
        RxnRCDatapoint.from_smi(
            known_rxns[row['feature']],
            y=np.array([row['y']]),
            x_d=row['sample_embed'],
        )
        for row in test_data
    ]

    dataset_test = RxnRCDataset(datapoints_test, featurizer=featurizer)

    data_loader_test = build_dataloader(dataset_test, shuffle=False)

    # Test
    with torch.inference_mode():
        trainer = pl.Trainer(
            logger=None,
            enable_progress_bar=True,
            accelerator="cpu",
            devices=1
        )
        test_preds = trainer.predict(mpnn, data_loader_test)

    results_by_split[i] = {"y_true": dataset_test.Y, "preds": test_preds}

# After the loop `dataset_test` and `test_preds` hold the LAST split only; the
# outputs of every split are available in `results_by_split`.

In [None]:
# trainer.predict returns one output per batch, so the batches must be joined
# before any per-sample operation (the original applied argmax to the raw list).
pred_scores = torch.cat([torch.as_tensor(batch) for batch in test_preds], dim=0).numpy()

if pred_scores.ndim == 1 or pred_scores.shape[1] == 1:
    # Single output column: argmax over axis 1 would always be 0, so threshold instead.
    # NOTE(review): assumes the column is the positive-class probability — confirm.
    threshold = 0.5
    y_pred = (pred_scores.reshape(-1,) >= threshold).astype(int)
else:
    # One column per class: pick the highest-scoring class.
    y_pred = np.argmax(pred_scores, axis=1).reshape(-1,)

y_true = dataset_test.Y.reshape(-1,)

# Guard against predictions and labels coming from different splits / loaders.
assert y_pred.shape == y_true.shape, (y_pred.shape, y_true.shape)

# Metric name -> callable(y_true, y_pred); macro averaging weights both classes equally.
scorers = {
    'f1': lambda y_true, y_pred: f1_score(y_true, y_pred, average='macro'),
    'precision': lambda y_true, y_pred: precision_score(y_true, y_pred, average='macro'),
    'recall': lambda y_true, y_pred: recall_score(y_true, y_pred, average='macro'),
    'accuracy': accuracy_score
}

scores = {name: scorer(y_true, y_pred) for name, scorer in scorers.items()}

print(scores)

In [4]:
# Check for data leak between train and test negatives:
# load the train and test arrays of the same split file pair.
train, test = (
    np.load(f"/scratch/spn1560/sprhea_sp_folded_pt_5_splits_1234_seed_1_neg_multiple_1_split_idx_{part}.npy")
    for part in ("train", "test")
)

In [22]:
# The split guide is tab-separated even though the file carries a .csv extension.
guide_path = "/scratch/spn1560/sprhea_sp_folded_pt_5_splits_1234_seed_1_neg_multiple.csv"
guide = pd.read_csv(guide_path, sep='\t')
guide.head()

Unnamed: 0.1,Unnamed: 0,train/test,split_idx,X1,X2,y
0,0,train,0,0,0,1.0
1,1,train,0,1,0,1.0
2,2,train,0,2,0,1.0
3,3,train,0,3,0,1.0
4,4,train,0,3,1,1.0


In [23]:
# Rows of the guide that are both in the training partition and labelled y == 0.
is_train_row = guide['train/test'] == 'train'
is_negative_row = guide['y'] == 0
guide.loc[is_train_row & is_negative_row]

Unnamed: 0.1,Unnamed: 0,train/test,split_idx,X1,X2,y


In [11]:
# True if at least one entry of the loaded test array has label y == 0.
(test['y'] == 0).any()

False

In [20]:
def sample_more(rng):
    """Draw three integers from [0, 1000) using the supplied generator."""
    return rng.integers(low=0, high=1000, size=3)


def f(seed):
    """Seed one generator and take two consecutive draws from it.

    Both draws share the same generator, so the second continues the stream
    where the first stopped.

    Returns:
        Tuple ``(first_draw, second_draw)`` of integer arrays of length 3.
    """
    generator = np.random.default_rng(seed)
    first_draw, second_draw = (sample_more(generator) for _ in range(2))
    return first_draw, second_draw


f(1234)

(array([979, 976, 987]), array([380, 171, 923]))