In [9]:
import os
import sys
sys.path.append("../")
from pathlib import Path

import torch
import numpy as np
from sklearn.metrics import roc_curve, auc
from gensim.models import Word2Vec

from const import gnps
from type import TokenizerConfig
from data import Tokenizer
from utils import load_model, embedding, cosine_similarity

# Root directory for ROC artefacts: each evaluation split reads its pre-sampled
# pair indices from, and writes its curve to, a sub-directory of this path.
# Set the ROC_DIR environment variable to relocate it; the default is unchanged.
ROC_DIR = Path(os.environ.get("ROC_DIR", "/data1/xp/code/specEmbedding/ROC"))
ROC_DIR.mkdir(exist_ok=True, parents=True)

In [10]:
# Preferred CUDA device ordinal. If the machine has fewer GPUs than this, use the
# highest available ordinal instead of failing with "invalid device ordinal".
# Fall back to CPU when CUDA is unavailable.
GPU_INDEX = 1
if torch.cuda.is_available():
    device = torch.device(f"cuda:{min(GPU_INDEX, torch.cuda.device_count() - 1)}")
else:
    device = torch.device("cpu")
model = load_model(device)

# Shared flag: passed to the tokenizer here and to every embedding call later.
show_progress_bar = False

tokenizer_config = TokenizerConfig(
    max_len=100,    # presumably the max number of peaks/tokens per spectrum -- TODO confirm in data.Tokenizer
    n_decimals=2,   # presumably the m/z rounding precision used for tokens -- TODO confirm
    show_progress_bar=show_progress_bar,
)
tokenizer = Tokenizer(**tokenizer_config)

In [11]:
# Evaluation splits as (query_path, ref_path) pairs, nested database -> instrument -> split.
# Note: the orbitrap "train" split deliberately pairs the train queries with the *test*
# reference file. The evaluation loop prepends `gnps_train_ref` for every gnps/orbitrap
# split, so both orbitrap splits are searched against the combined train + test references.
# (The `gnps=` keyword below is only a dict key; `gnps.` still refers to the const module.)
spectra_paths = dict(
    gnps=dict(
        orbitrap=dict(
            train=(gnps.ORBITRAP_TRAIN_QUERY, gnps.ORBITRAP_TEST_REF),
            test=(gnps.ORBITRAP_TEST_QUERY, gnps.ORBITRAP_TEST_REF),
        ),
        qtof=dict(test=(gnps.QTOF_TEST_QUERY, gnps.QTOF_TEST_REF)),
        other=dict(test=(gnps.OTHER_TEST_QUERY, gnps.OTHER_TEST_REF)),
    ),
)

# Orbitrap training references, merged into the orbitrap reference library at evaluation time.
# allow_pickle=True can execute arbitrary code on load -- only use with trusted local files.
gnps_train_ref = np.load(gnps.ORBITRAP_TRAIN_REF, allow_pickle=True)

In [12]:
# Fourth positional argument of utils.embedding; assumed to be the inference
# batch size -- TODO confirm against utils.embedding.
EMBEDDING_BATCH_SIZE = 512

for db, db_metadata in spectra_paths.items():
    for desc, path_metadata in db_metadata.items():
        for info, (query_path, ref_path) in path_metadata.items():
            run_name = f"{db}-{desc}-{info}"
            print("-" * 40, run_name, "-" * 40)

            ref_spectra = np.load(ref_path, allow_pickle=True)
            query_spectra = np.load(query_path, allow_pickle=True)
            # Orbitrap queries are searched against the combined train + test references.
            if db == "gnps" and desc == "orbitrap":
                ref_spectra = np.hstack((gnps_train_ref, ref_spectra))

            query_embedding, query_smiles = embedding(
                model, device, tokenizer, EMBEDDING_BATCH_SIZE, query_spectra, show_progress_bar
            )
            ref_embedding, ref_smiles = embedding(
                model, device, tokenizer, EMBEDDING_BATCH_SIZE, ref_spectra, show_progress_bar
            )
            cosine_score = cosine_similarity(query_embedding, ref_embedding)

            # The (row, col) pairs are pre-sampled and stored per split, so every
            # method's ROC is computed on the same subset of query/reference pairs.
            dir_ = ROC_DIR / run_name
            indices_path = dir_ / "random_indices.npy"
            if not indices_path.exists():
                raise FileNotFoundError(
                    f"{indices_path} not found; generate the sampled (row, col) "
                    "indices for this split before running this cell."
                )
            indices = np.load(indices_path, allow_pickle=True).item()
            random_rows = np.asarray(indices["row"])
            random_cols = np.asarray(indices["col"])

            # A pair is positive when query and reference share the same SMILES.
            # Compare only at the sampled pairs instead of materialising the full
            # n_query x n_ref match matrix; the resulting labels are identical.
            query_smiles_flat = np.asarray(query_smiles).reshape(-1)
            ref_smiles_flat = np.asarray(ref_smiles).reshape(-1)
            labels = query_smiles_flat[random_rows] == ref_smiles_flat[random_cols]
            scores = cosine_score[random_rows, random_cols]

            fpr, tpr, _ = roc_curve(labels, scores)
            print(auc(fpr, tpr))
            dir_.mkdir(parents=True, exist_ok=True)
            np.save(dir_ / "MSBERT.npy", {"fpr": fpr, "tpr": tpr})

---------------------------------------- gnps-orbitrap-train ----------------------------------------
0.8433973991534613
---------------------------------------- gnps-orbitrap-test ----------------------------------------
0.833769179697961
---------------------------------------- gnps-qtof-test ----------------------------------------
0.9411283553498268
---------------------------------------- gnps-other-test ----------------------------------------
0.924461481967707
