In [25]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [26]:
import os
import numpy
import torch
from pprint import pprint
from torch.utils.data import DataLoader

import massspecgym
from massspecgym.data.datasets import SimulationDataset
from massspecgym.runner import load_config
from massspecgym.data.transforms import SpecToMzsInts, MolToPyG, StandardMeta, MolToFingerprints
from massspecgym.simulation_utils.misc_utils import print_shapes

In [27]:
# Repository root. Override via the MASSSPECGYM_ROOT environment variable
# instead of editing this user-specific absolute path.
REPO_DIR = os.environ.get("MASSSPECGYM_ROOT", "/h/adamo/MassSpecGym_copy")

# Load the base template plus the benchmark-specific config.
# NOTE(review): presumably later files override keys from earlier ones —
# confirm against massspecgym.runner.load_config.
config_d = load_config(
    os.path.join(REPO_DIR, "config", "template.yml"),
    os.path.join(REPO_DIR, "config", "new_bench_fp.yml"),
)

In [28]:
# Spectrum transform, parameterized by the configured m/z range.
spec_transform = SpecToMzsInts(mz_from=config_d["mz_from"], mz_to=config_d["mz_to"])

# Molecule transform depends on the model family: "fp" and "prec_only" use
# fingerprints (types from the config), "gnn" uses a PyG graph; anything
# else is rejected.
model_type = config_d["model_type"]
if model_type == "gnn":
    mol_transform = MolToPyG()
elif model_type in ("fp", "prec_only"):
    mol_transform = MolToFingerprints(fp_types=config_d["fp_types"])
else:
    raise ValueError(f"model_type {config_d['model_type']} not supported")

# Metadata transform: adduct / instrument vocabularies and the collision
# energy cap all come straight from the config.
meta_transform = StandardMeta(
    adducts=config_d["adducts"],
    instrument_types=config_d["instrument_types"],
    max_collision_energy=config_d["max_collision_energy"],
)

In [29]:
# Build the simulation dataset from the TSV named in the config, resolved
# against the repository root (override via MASSSPECGYM_ROOT rather than
# editing the user-specific default path).
repo_dir = os.environ.get("MASSSPECGYM_ROOT", "/h/adamo/MassSpecGym_copy")
tsv_path = os.path.join(repo_dir, config_d["tsv_pth"])
ds = SimulationDataset(
    pth=tsv_path,
    meta_keys=config_d["meta_keys"],
    spec_transform=spec_transform,
    mol_transform=mol_transform,
    meta_transform=meta_transform,
)

In [30]:
# Report dataset size: number of items, spectra rows, metadata rows, and the
# number of distinct molecules (unique InChIKeys).
n_items = len(ds)
n_spectra = ds.spectra.shape[0]
n_metadata = ds.metadata.shape[0]
print(n_items)
print(n_spectra)
print(n_metadata)
print(ds.metadata["inchikey"].nunique())

# The three row counts are expected to agree (they do in the recorded
# output); fail loudly if the dataset's tables ever get out of sync.
assert n_items == n_spectra == n_metadata, (
    f"row counts disagree: len(ds)={n_items}, spectra={n_spectra}, metadata={n_metadata}"
)

119029
119029
119029
16974


In [31]:
# Count rows per value of the simulation_challenge flag.
challenge_counts = ds.metadata["simulation_challenge"].value_counts()
print(challenge_counts)

simulation_challenge
True    119029
Name: count, dtype: int64


In [32]:
# Inspect a single sample: print the shape and type of each of its fields.
sample_idx = 0
item = ds[sample_idx]
print_shapes(item)

spec_mzs - (8,) - <class 'torch.Tensor'>
spec_ints - (8,) - <class 'torch.Tensor'>
fps - (4263,) - <class 'torch.Tensor'>
precursor_mz - () - <class 'torch.Tensor'>
adduct - () - <class 'torch.Tensor'>
instrument_type - () - <class 'torch.Tensor'>
collision_energy - () - <class 'torch.Tensor'>
mol_freq - () - <class 'torch.Tensor'>
identifier - None - <class 'str'>




In [33]:
# Small, deterministic, single-process loader for inspecting collated
# batches; batching is delegated to the dataset's own collate_fn.
dl = DataLoader(
    ds,
    batch_size=3,
    collate_fn=ds.collate_fn,
    shuffle=False,
    drop_last=False,
    num_workers=0,
)

In [34]:
# Draw the first batch (keeping the iterator around for further draws) and
# print the shape and type of every collated field.
batch = next(dl_iter := iter(dl))
print_shapes(batch)

spec_mzs - (26,) - <class 'torch.Tensor'>
spec_ints - (26,) - <class 'torch.Tensor'>
spec_batch_idxs - (26,) - <class 'torch.Tensor'>
fps - (3, 4263) - <class 'torch.Tensor'>
adduct - (3,) - <class 'torch.Tensor'>
instrument_type - (3,) - <class 'torch.Tensor'>
collision_energy - (3,) - <class 'torch.Tensor'>
precursor_mz - (3,) - <class 'torch.Tensor'>
mol_freq - (3,) - <class 'torch.Tensor'>
identifier - 3 - <class 'list'>
