In [5]:
import sys
sys.path.append("../")
from pathlib import Path

import torch
import numpy as np
from sklearn.metrics import roc_curve, auc

from const import gnps
from utils import load_model, cosine_similarity

ROC_DIR = Path("/data1/xp/code/specEmbedding/ROC")
ROC_DIR.mkdir(exist_ok=True, parents=True)

In [6]:
spectra_paths = {
    "gnps":{
        "orbitrap": {
            "train": (gnps.ORBITRAP_TRAIN_QUERY, gnps.ORBITRAP_TEST_REF),
            "test": (gnps.ORBITRAP_TEST_QUERY, gnps.ORBITRAP_TEST_REF)
        },
        "qtof": {
            "test": (gnps.QTOF_TEST_QUERY, gnps.QTOF_TEST_REF)
        },
        "other": {
            "test": (gnps.OTHER_TEST_QUERY, gnps.OTHER_TEST_REF)
        }
    },
}

gnps_train_ref = np.load(gnps.ORBITRAP_TRAIN_REF, allow_pickle=True)

In [7]:
model = load_model()

for db, db_metadata in spectra_paths.items():
    for desc, path_metadata in db_metadata.items():
        for info, paths in path_metadata.items():
            print("-" * 40, f"{db}-{desc}-{info}", "-" * 40)
            query_path, ref_path = paths
            ref_spectra = np.load(ref_path, allow_pickle=True)
            query_spectra = np.load(query_path, allow_pickle=True)
            if db == "gnps" and desc == "orbitrap":
                ref_spectra = np.hstack((gnps_train_ref, ref_spectra))
            
            query_smiles = np.array([
                s.get("smiles")
                for s in query_spectra
            ])
            ref_smiles = np.array([
                s.get("smiles")
                for s in ref_spectra
            ])
            query_embedding = model.get_embedding_array(query_spectra)
            ref_embedding = model.get_embedding_array(ref_spectra)
            cosine_score = cosine_similarity(query_embedding, ref_embedding)
            mask = np.equal(query_smiles.reshape(-1, 1), ref_smiles.reshape(-1, 1).T)
            dir_ = ROC_DIR / f"{db}-{desc}-{info}"
            dir_.mkdir(parents=True, exist_ok=True)
            indices = np.load(dir_ / "random_indices.npy", allow_pickle=True).item()
            random_rows = indices["row"]
            random_cols = indices["col"]
            fpr, tpr, _ = roc_curve(mask[random_rows, random_cols], cosine_score[random_rows, random_cols])
            print(auc(fpr, tpr))
            np.save(dir_ / "MS2DeepScore.npy", {"fpr": fpr, "tpr": tpr})

---------------------------------------- gnps-orbitrap-train ----------------------------------------


6851it [00:10, 676.80it/s]
155415it [04:20, 597.47it/s] 


0.9631361697331491
---------------------------------------- gnps-orbitrap-test ----------------------------------------


1686it [00:02, 763.69it/s]
155415it [03:30, 738.43it/s] 


0.8932237741071474
---------------------------------------- gnps-qtof-test ----------------------------------------


7520it [00:08, 909.16it/s] 
37040it [00:49, 745.83it/s] 


0.9198308195266631
---------------------------------------- gnps-other-test ----------------------------------------


6451it [00:07, 807.69it/s] 
44241it [00:59, 748.07it/s] 


0.8745027429866667
