In [1]:
# Load IPython's autoreload extension so edits to imported project modules
# are picked up without restarting the kernel.
%load_ext autoreload
# Mode 2: reload modules automatically before executing each cell.
%autoreload 2

In [2]:
import torch
from torch.utils.data import DataLoader
import pandas as pd
from tqdm.notebook import tqdm

In [3]:
from massspecgym.data.datasets import SimulationDataset
from massspecgym.transforms import SpecToMzsInts, MolToPyG, StandardMeta, MolToFingerprints
from massspecgym.simulation_utils.misc_utils import print_shapes

In [4]:
# --- Data sources ---------------------------------------------------------
tsv_pth = "../data/MassSpecGym_small.tsv"
frag_pth = None  # no fragment file is configured in this notebook

# --- Metadata configuration -----------------------------------------------
# Metadata columns requested from the dataset.
meta_keys = [
    "adduct",
    "precursor_mz",
    "instrument_type",
    "collision_energy",
]
adducts = ["[M+H]+"]
instruments = ["QTOF", "QFT", "Orbitrap", "ITFT"]
# Arbitrary cap handed to StandardMeta.
# NOTE(review): the collision_energy summary printed later in this notebook
# reports a maximum of ~206, which exceeds this value -- confirm how
# StandardMeta treats energies above max_ce.
max_ce = 100.0

# --- Transforms -----------------------------------------------------------
spec_transform = SpecToMzsInts()
mol_transform = MolToFingerprints()
meta_transform = StandardMeta(adducts=adducts, instruments=instruments, max_ce=max_ce)

In [5]:
# Build the dataset with fingerprint molecule features.
# `frag_pth` is taken from the configuration cell (currently None) instead of
# being hard-coded here, so changing the config actually takes effect.
ds = SimulationDataset(
    tsv_pth=tsv_pth,
    frag_pth=frag_pth,
    meta_keys=meta_keys,
    spec_transform=spec_transform,
    mol_transform=mol_transform,
    meta_transform=meta_transform,
    frag_transform=None,
    cache_feats=False)

In [6]:
# Fetch the first dataset entry and print each field's name, shape and type.
item = ds[0]
print_shapes(item)

spec_mzs - (40,) - <class 'torch.Tensor'>
spec_ints - (40,) - <class 'torch.Tensor'>
fps - (4263,) - <class 'torch.Tensor'>
precursor_mz - () - <class 'torch.Tensor'>
adduct - () - <class 'torch.Tensor'>
instrument_type - () - <class 'torch.Tensor'>
collision_energy - () - <class 'torch.Tensor'>
weight - () - <class 'torch.Tensor'>


In [7]:
# Deterministic, in-process loader: fixed order, no worker processes, and the
# final partial batch is kept. Batching uses the dataset's own collate_fn.
dl = DataLoader(
    dataset=ds,
    batch_size=8,
    shuffle=False,
    drop_last=False,
    num_workers=0,
    collate_fn=ds.collate_fn,
)

In [8]:
# Pull the first collated batch from the loader and print each field's
# name, shape and type.
batch = next(iter(dl))
print_shapes(batch)

spec_mzs - (192,) - <class 'torch.Tensor'>
spec_ints - (192,) - <class 'torch.Tensor'>
fps - (8, 4263) - <class 'torch.Tensor'>
precursor_mz - (8,) - <class 'torch.Tensor'>
adduct - (8,) - <class 'torch.Tensor'>
instrument_type - (8,) - <class 'torch.Tensor'>
collision_energy - (8,) - <class 'torch.Tensor'>
weight - (8,) - <class 'torch.Tensor'>
spec_batch_idxs - (192,) - <class 'torch.Tensor'>


In [9]:
# Summary statistics of the collision_energy column of the dataset's entry table.
print(ds.entry_df["collision_energy"].describe())

count    3253.000000
mean       38.993048
std        28.941670
min         0.000000
25%        20.000000
50%        30.000000
75%        58.780832
max       206.025504
Name: collision_energy, dtype: float64


In [10]:
# Smoke test: iterate over every batch to check that the whole dataset can be
# loaded and collated. tqdm reads the total from len(dl) when given the loader
# directly, so no explicit iter()/total= is needed.
for batch in tqdm(dl):
    pass

  0%|          | 0/407 [00:00<?, ?it/s]

In [11]:
# Rebuild the dataset with the PyG molecule transform in place of fingerprints.
# NOTE: this rebinds the notebook-level names `mol_transform` and `ds`.
# `frag_pth` is taken from the configuration cell (currently None) instead of
# being hard-coded here, so changing the config actually takes effect.
mol_transform = MolToPyG()
ds = SimulationDataset(
    tsv_pth=tsv_pth,
    frag_pth=frag_pth,
    meta_keys=meta_keys,
    spec_transform=spec_transform,
    mol_transform=mol_transform,
    meta_transform=meta_transform,
    frag_transform=None,
    cache_feats=False)

In [12]:
# Fetch the first entry of the current dataset and print each field's
# name, shape and type.
item = ds[0]
print_shapes(item)

spec_mzs - (40,) - <class 'torch.Tensor'>
spec_ints - (40,) - <class 'torch.Tensor'>
mol - (16, 32) - <class 'torch_geometric.data.data.Data'>
precursor_mz - () - <class 'torch.Tensor'>
adduct - () - <class 'torch.Tensor'>
instrument_type - () - <class 'torch.Tensor'>
collision_energy - () - <class 'torch.Tensor'>
weight - () - <class 'torch.Tensor'>


In [13]:
# Deterministic, in-process loader over the current dataset: fixed order, no
# worker processes, and the final partial batch is kept. Batching uses the
# dataset's own collate_fn.
dl = DataLoader(
    dataset=ds,
    batch_size=8,
    shuffle=False,
    drop_last=False,
    num_workers=0,
    collate_fn=ds.collate_fn,
)

In [14]:
# Pull the first collated batch from the loader and print each field's
# name, shape and type.
batch = next(iter(dl))
print_shapes(batch)

spec_mzs - (192,) - <class 'torch.Tensor'>
spec_ints - (192,) - <class 'torch.Tensor'>
mol - (178, 378) - <class 'torch_geometric.data.batch.DataBatch'>
precursor_mz - (8,) - <class 'torch.Tensor'>
adduct - (8,) - <class 'torch.Tensor'>
instrument_type - (8,) - <class 'torch.Tensor'>
collision_energy - (8,) - <class 'torch.Tensor'>
weight - (8,) - <class 'torch.Tensor'>
spec_batch_idxs - (192,) - <class 'torch.Tensor'>


In [15]:
# Smoke test: iterate over every batch to check that the whole dataset can be
# loaded and collated. tqdm reads the total from len(dl) when given the loader
# directly, so no explicit iter()/total= is needed.
for batch in tqdm(dl):
    pass

  0%|          | 0/407 [00:00<?, ?it/s]