In [1]:
import sys
sys.path.append("../")
from pathlib import Path

import torch
import numpy as np
from sklearn.metrics import roc_curve, auc

from const import gnps
from train import ModelTester
from data import Tokenizer
from utils import embedding, load_transformer_model, cosine_similarity

# Root folder holding one sub-directory of ROC artefacts per evaluated dataset.
# Created up front (including missing parents); a pre-existing folder is fine.
ROC_DIR = Path("/data1/xp/code/specEmbedding/ROC")
ROC_DIR.mkdir(parents=True, exist_ok=True)

In [2]:
# Spectra files to evaluate, keyed as database -> instrument -> split.
# Each leaf is a (query_path, reference_path) pair, in that order.
spectra_paths = {
    "gnps": {
        "orbitrap": {
            "train": (
                gnps.ORBITRAP_TRAIN_QUERY,
                gnps.ORBITRAP_TEST_REF,
            ),
            "test": (
                gnps.ORBITRAP_TEST_QUERY,
                gnps.ORBITRAP_TEST_REF,
            ),
        },
        "qtof": {
            "test": (
                gnps.QTOF_TEST_QUERY,
                gnps.QTOF_TEST_REF,
            ),
        },
        "other": {
            "test": (
                gnps.OTHER_TEST_QUERY,
                gnps.OTHER_TEST_REF,
            ),
        },
    },
}

# Orbitrap training reference spectra, loaded once up front.
# allow_pickle=True unpickles arbitrary objects, so only load trusted files.
gnps_train_ref = np.load(gnps.ORBITRAP_TRAIN_REF, allow_pickle=True)

In [3]:
# ---- Run configuration ------------------------------------------------------
show_progress_bar = True   # forwarded to Tokenizer, ModelTester and embedding()
is_augment = True          # forwarded to load_transformer_model()
model_backbone = "transformer"
loss_type = "SupConWithTanimotoLoss"   # forwarded to load_transformer_model()
replica_suffix = "-replication-{}"
k_metric = [5, 1, 10]
batch_size = None
loader_batch_size = 512

# ---- Device selection -------------------------------------------------------
# Prefer the second GPU, as before, but only when it actually exists: on a
# single-GPU machine "cuda:1" is not a valid device, so fall back to "cuda:0".
# Without CUDA the model runs on the CPU.
if torch.cuda.is_available():
    gpu_index = 1 if torch.cuda.device_count() > 1 else 0
    device = torch.device(f"cuda:{gpu_index}")
else:
    device = torch.device("cpu")

# ---- Model and helpers ------------------------------------------------------
# NOTE(review): the meaning of the first Tokenizer argument (100) is not
# visible here -- confirm against data.Tokenizer.
tokenizer = Tokenizer(100, show_progress_bar)
model = load_transformer_model(device, loss_type, is_augment)

tester = ModelTester(model, device, show_progress_bar)

In [4]:
# For every database / instrument / split entry in `spectra_paths`:
#   1. load the query and reference spectra,
#   2. embed both sets,
#   3. score every query against every reference,
#   4. compute a ROC curve on a pre-selected set of (query, reference) pairs,
#   5. print the AUC and save the curve next to the index file.
for db, db_metadata in spectra_paths.items():
    for desc, path_metadata in db_metadata.items():
        for info, paths in path_metadata.items():
            print("-" * 40, f"{db}-{desc}-{info}", "-" * 40)
            query_path, ref_path = paths
            # allow_pickle=True unpickles arbitrary objects; only load trusted files.
            ref_spectra = np.load(ref_path, allow_pickle=True)
            query_spectra = np.load(query_path, allow_pickle=True)
            if db == "gnps" and desc == "orbitrap":
                # Orbitrap runs also search against the training reference spectra.
                ref_spectra = np.hstack((gnps_train_ref, ref_spectra))

            # Use the configured loader batch size instead of a duplicated literal,
            # so changing `loader_batch_size` actually takes effect here.
            query_embedding, query_smiles = embedding(
                tester, tokenizer, loader_batch_size, query_spectra, show_progress_bar
            )
            ref_embedding, ref_smiles = embedding(
                tester, tokenizer, loader_batch_size, ref_spectra, show_progress_bar
            )
            cosine_score = cosine_similarity(query_embedding, ref_embedding)

            # Pairwise label matrix: entry (i, j) is True when query i and
            # reference j carry the same SMILES (column vector vs. row vector
            # broadcast).
            mask = np.equal(query_smiles.reshape(-1, 1), ref_smiles.reshape(-1, 1).T)

            dir_ = ROC_DIR / f"{db}-{desc}-{info}"
            dir_.mkdir(parents=True, exist_ok=True)
            # "random_indices.npy" is not created here: it must already exist and
            # hold a pickled dict with "row" and "col" entries selecting the
            # (query, reference) pairs to evaluate. np.load raises if it is missing.
            indices = np.load(dir_ / "random_indices.npy", allow_pickle=True).item()
            random_rows = indices["row"]
            random_cols = indices["col"]

            fpr, tpr, _ = roc_curve(
                mask[random_rows, random_cols],
                cosine_score[random_rows, random_cols],
            )
            print(auc(fpr, tpr))
            # Store the curve as a pickled dict.
            np.save(dir_ / "SpecEmbedding.npy", {"fpr": fpr, "tpr": tpr})

---------------------------------------- gnps-orbitrap-train ----------------------------------------


tokenization: 100%|██████████| 6851/6851 [00:00<00:00, 7751.78it/s]
get smiles: 100%|██████████| 6851/6851 [00:00<00:00, 3047531.73it/s]
embedding: 100%|██████████| 14/14 [00:01<00:00, 12.51it/s]
tokenization: 100%|██████████| 155415/155415 [00:16<00:00, 9186.04it/s]
get smiles: 100%|██████████| 155415/155415 [00:00<00:00, 1906284.38it/s]
embedding: 100%|██████████| 304/304 [00:13<00:00, 21.82it/s]


0.9976587574353124
---------------------------------------- gnps-orbitrap-test ----------------------------------------


tokenization: 100%|██████████| 1686/1686 [00:00<00:00, 8735.52it/s]
get smiles: 100%|██████████| 1686/1686 [00:00<00:00, 3952820.87it/s]
embedding: 100%|██████████| 4/4 [00:00<00:00, 30.33it/s]
tokenization: 100%|██████████| 155415/155415 [00:16<00:00, 9689.21it/s] 
get smiles: 100%|██████████| 155415/155415 [00:00<00:00, 3112280.22it/s]
embedding: 100%|██████████| 304/304 [00:12<00:00, 24.77it/s]


0.9955562276460298
---------------------------------------- gnps-qtof-test ----------------------------------------


tokenization: 100%|██████████| 7520/7520 [00:01<00:00, 6441.19it/s]
get smiles: 100%|██████████| 7520/7520 [00:00<00:00, 3395173.96it/s]
embedding: 100%|██████████| 15/15 [00:00<00:00, 25.50it/s]
tokenization: 100%|██████████| 37040/37040 [00:05<00:00, 6637.72it/s]
get smiles: 100%|██████████| 37040/37040 [00:00<00:00, 2323199.85it/s]
embedding: 100%|██████████| 73/73 [00:02<00:00, 25.06it/s]


0.9730223671969256
---------------------------------------- gnps-other-test ----------------------------------------


tokenization: 100%|██████████| 6451/6451 [00:00<00:00, 7609.37it/s]
get smiles: 100%|██████████| 6451/6451 [00:00<00:00, 2856572.54it/s]
embedding: 100%|██████████| 13/13 [00:00<00:00, 25.78it/s]
tokenization: 100%|██████████| 44241/44241 [00:06<00:00, 7240.36it/s]
get smiles: 100%|██████████| 44241/44241 [00:00<00:00, 2614113.09it/s]
embedding: 100%|██████████| 87/87 [00:03<00:00, 25.09it/s]


0.9562787950883166
