In [1]:
# Enable IPython's autoreload extension so edits to imported project modules
# (e.g. massspecgym) are picked up without restarting the kernel.
%load_ext autoreload
# Mode 2: reload all modules before executing each cell.
%autoreload 2

In [2]:
import os
import numpy
import torch
from pprint import pprint
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm
import numpy as np
import pandas as pd

from massspecgym.data.datasets import RetrievalSimulationDataset
from massspecgym.runner import load_config, get_split_ss
from massspecgym.data.transforms import SpecToMzsInts, MolToPyG, StandardMeta, MolToFingerprints
from massspecgym.simulation_utils.misc_utils import print_shapes
from massspecgym.models.simulation.fp import FPSimulationMassSpecGymModel

In [5]:
# Root of the MassSpecGym checkout. Defaults to the original author's location;
# set the MSG_PROJECT_DIR environment variable to run this notebook elsewhere
# instead of editing hard-coded absolute paths.
PROJECT_DIR = os.environ.get("MSG_PROJECT_DIR", "/h/adamo/MassSpecGym_copy")

# Load the base template together with the retrieval debug config
# (presumably the second file overrides the first — confirm in load_config).
config_d = load_config(
    os.path.join(PROJECT_DIR, "config", "template.yml"),
    os.path.join(PROJECT_DIR, "config", "debug_ret.yml"),
)

In [6]:
spec_transform = SpecToMzsInts(
    mz_from=config_d["mz_from"],
    mz_to=config_d["mz_to"],
)
if config_d["model_type"] in ["fp", "prec_only"]:
    mol_transform = MolToFingerprints(
        fp_types=config_d["fp_types"]
    )
elif config_d["model_type"] == "gnn":
    mol_transform = MolToPyG()
else:
    raise ValueError(f"model_type {config_d['model_type']} not supported")
meta_transform = StandardMeta(
    adducts=config_d["adducts"],
    instrument_types=config_d["instrument_types"],
    max_collision_energy=config_d["max_collision_energy"]
)

In [7]:
# Retrieval-simulation dataset: spectra + metadata plus a candidates mapping
# keyed by SMILES. Both data paths in the config are joined onto the same
# project root.
ds = RetrievalSimulationDataset(
    **{
        path_key: os.path.join("/h/adamo/MassSpecGym_copy", config_d[path_key])
        for path_key in ("pth", "candidates_pth")
    },
    meta_keys=config_d["meta_keys"],
    spec_transform=spec_transform,
    mol_transform=mol_transform,
    meta_transform=meta_transform,
)

In [8]:
# Size summary: dataset items, spectra, metadata rows, distinct InChIKeys,
# and the number of entries in the candidates mapping.
print(len(ds))
print(ds.spectra.shape[0])
print(ds.metadata.shape[0])
print(ds.metadata["inchikey"].nunique())
print(len(ds.candidates))

# Make the eyeballed consistency check explicit: each dataset item presumably
# corresponds to exactly one spectrum and one metadata row.
assert len(ds) == ds.spectra.shape[0] == ds.metadata.shape[0], (
    f"size mismatch: len(ds)={len(ds)}, spectra={ds.spectra.shape[0]}, "
    f"metadata={ds.metadata.shape[0]}"
)

11952
11952
11952
6283
32010


In [9]:
# Distribution of the challenge flag and the train/val/test fold labels.
for column in ("simulation_challenge", "fold"):
    print(ds.metadata[column].value_counts())

simulation_challenge
True    11952
Name: count, dtype: int64
fold
train    9927
test     1040
val       985
Name: count, dtype: int64


In [10]:
# Inspect the fields (and their shapes/types) of the first dataset item.
print_shapes(item := ds[0])

spec_mzs - (19,) - <class 'torch.Tensor'>
spec_ints - (19,) - <class 'torch.Tensor'>
fps - (4263,) - <class 'torch.Tensor'>
precursor_mz - () - <class 'torch.Tensor'>
adduct - () - <class 'torch.Tensor'>
instrument_type - () - <class 'torch.Tensor'>
collision_energy - () - <class 'torch.Tensor'>
smiles - None - <class 'str'>
mol_freq - () - <class 'torch.Tensor'>
identifier - None - <class 'str'>
candidates_smiles - 256 - <class 'list'>
candidates_labels - (256,) - <class 'torch.Tensor'>
candidates_mol_feats - 256 - <class 'list'>


In [12]:
# Count retrieval candidates per spectrum; a SMILES absent from the
# candidates mapping counts as zero candidates.
smileses = ds.metadata["smiles"].tolist()
num_cands = pd.Series([len(ds.candidates.get(s, [])) for s in smileses])
no_cands = num_cands.eq(0)
print(num_cands.describe())
print("missing candidates:", no_cands.sum())


count    11952.000000
mean       254.596553
std         14.776761
min         12.000000
25%        256.000000
50%        256.000000
75%        256.000000
max        256.000000
dtype: float64
missing candidates: 0


In [49]:
train_ss, val_ss, test_ss = get_split_ss(ds, config_d["split_type"])

# Iterate the validation split (author's note: test candidates are missing).
# NOTE(review): in the recorded run get_split_ss reported 99341 / 9734 / 0
# spectra for train / val / test, while len(ds) is 11952 and
# ds.metadata["fold"] has 985 "val" rows; the candidates check also reported
# 0 missing. The test split looks empty rather than candidate-less, and the
# split may not be derived from this dataset's metadata — confirm.
dl = DataLoader(
    dataset=val_ss,
    batch_size=3,
    shuffle=False,
    drop_last=False,
    num_workers=0,
    collate_fn=ds.collate_fn,
)

>>> Number of Spectra
99341 9734 0
>>> Number of Unique Molecules
12321 2334 0


In [50]:
# Pull one collated validation batch to inspect its structure.
batch = next(dl_iter := iter(dl))
print_shapes(batch)

spec_mzs - (37,) - <class 'torch.Tensor'>
spec_ints - (37,) - <class 'torch.Tensor'>
spec_batch_idxs - (37,) - <class 'torch.Tensor'>
fps - (3, 4263) - <class 'torch.Tensor'>
adduct - (3,) - <class 'torch.Tensor'>
instrument_type - (3,) - <class 'torch.Tensor'>
collision_energy - (3,) - <class 'torch.Tensor'>
precursor_mz - (3,) - <class 'torch.Tensor'>
smiles - 3 - <class 'list'>
mol_freq - (3,) - <class 'torch.Tensor'>
identifier - 3 - <class 'list'>
candidates_data - None - <class 'dict'>


In [51]:
print_shapes(batch["candidates_data"])

fps - (768, 4263) - <class 'torch.Tensor'>
smiles - 768 - <class 'list'>
batch_ptr - (3,) - <class 'torch.Tensor'>
labels - (768,) - <class 'torch.Tensor'>


In [52]:
# Optional full pass over the validation loader, e.g. to time data loading.
# Disabled by default: an earlier interrupted run measured ~9 s per batch over
# 3245 batches (hours with num_workers=0).
RUN_FULL_LOADER_PASS = False

if RUN_FULL_LOADER_PASS:
    # Use a separate loop variable so the `batch` inspected in the previous
    # cells (and used by test_step below) is not overwritten.
    for _batch in tqdm(dl, total=len(dl)):
        pass

  1%|          | 39/3245 [05:45<7:53:08,  8.85s/it]


KeyboardInterrupt: 

In [53]:
from massspecgym.models.simulation.fp import FPSimulationMassSpecGymModel
from massspecgym.models.base import MassSpecGymModel, Stage


In [54]:
# Instantiate the fingerprint simulation model directly from the loaded
# config; every config entry is forwarded as a keyword argument.
pl_model = FPSimulationMassSpecGymModel(
    **config_d
)

In [55]:
# test_step is called by hand here, outside a Trainer. Lightning's Trainer
# would switch the model to eval mode and disable gradient tracking before
# testing, so do the same. Otherwise any dropout / batch-norm layers stay in
# training mode and autograd state is kept for all outputs.
# The "self.log() ... trainer reference is not registered" warning is
# expected because no Trainer is attached.
pl_model.eval()
with torch.no_grad():
    outputs = pl_model.test_step(batch=batch, batch_idx=0)

/h/adamo/miniconda3/envs/MSG2/lib/python3.11/site-packages/pytorch_lightning/core/module.py:436: You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet. This is most likely because the model hasn't been passed to the `Trainer`


In [56]:
print_shapes(outputs)

loss - () - <class 'torch.Tensor'>
pred_mzs - (89313,) - <class 'torch.Tensor'>
pred_logprobs 

- (89313,) - <class 'torch.Tensor'>
pred_batch_idxs - (89313,) - <class 'torch.Tensor'>
true_mzs - (34,) - <class 'torch.Tensor'>
true_logprobs - (34,) - <class 'torch.Tensor'>
true_batch_idxs - (34,) - <class 'torch.Tensor'>
retrieval_scores - (768,) - <class 'torch.Tensor'>
retrieval_labels - (768,) - <class 'torch.Tensor'>
retrieval_batch_ptr - (3,) - <class 'torch.Tensor'>


In [None]:
# Feed this batch's retrieval scores/labels/batch pointers into the model's
# retrieval metrics under the TEST stage (read back in the next cell).
pl_model.evaluate_retrieval_step(
    **{name: outputs[f"retrieval_{name}"] for name in ("scores", "labels", "batch_ptr")},
    stage=Stage.TEST,
)

/h/adamo/miniconda3/envs/MSG2/lib/python3.11/site-packages/pytorch_lightning/core/module.py:436: You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet. This is most likely because the model hasn't been passed to the `Trainer`


In [None]:
# The metric attribute name is the stage prefix followed by the metric name;
# compute() returns the value accumulated so far.
getattr(pl_model, f"{Stage.TEST.to_pref()}hit_rate@1").compute()

tensor(0.)