In [1]:
# Enable IPython's autoreload extension (mode 2) so edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from torch.utils.data import DataLoader
import pandas as pd
from tqdm.notebook import tqdm

In [10]:
from massspecgym.data.datasets import SimulationDataset
from massspecgym.transforms import SpecToMzsInts, MolToPyG, StandardMeta, MolToFingerprints
from massspecgym.simulation_utils.misc_utils import print_shapes

In [11]:
# --- Data location and metadata fields ---------------------------------------
tsv_pth = "../data/MassSpecGym_4.tsv"
meta_keys = [
    "adduct",
    "precursor_mz",
    "instrument_type",
    "collision_energy",
]

# --- Featurization / filtering options ---------------------------------------
fp_types = ["morgan", "maccs", "rdkit"]
adducts = ["[M+H]+"]
instrument_types = ["QTOF", "QFT", "Orbitrap", "ITFT"]
# arbitrary (original author's note).
# NOTE(review): the describe() output later in this notebook reports a maximum
# collision_energy of ~358.4 — confirm how StandardMeta treats larger values.
max_collision_energy = 200.0
mz_from, mz_to = 10.0, 1500.0

# --- Transforms: spectrum, molecule (fingerprints), metadata -----------------
spec_transform = SpecToMzsInts(mz_from=mz_from, mz_to=mz_to)
mol_transform = MolToFingerprints(fp_types=fp_types)
meta_transform = StandardMeta(
    adducts=adducts,
    instrument_types=instrument_types,
    max_collision_energy=max_collision_energy,
)

# --- Dataset built from the TSV with the three transforms above --------------
ds = SimulationDataset(
    tsv_pth=tsv_pth,
    meta_keys=meta_keys,
    spec_transform=spec_transform,
    mol_transform=mol_transform,
    meta_transform=meta_transform,
    cache_feats=False,
)

In [15]:
# Load the full tab-separated table, draw a fixed-seed sample of 10,000 rows
# (without replacement), and write it out as a smaller TSV for quick runs.
df = pd.read_table(tsv_pth)
small_df = df.sample(10000, random_state=42, replace=False)
small_df.to_csv(
    "../data/MassSpecGym_4_small.tsv",
    index=False,
    sep="\t",
)

In [5]:
# Sanity check: fetch the first dataset entry and print, for every field,
# its name, shape and Python type.
item = ds[0]
print_shapes(item)

spec_mzs - (8,) - <class 'torch.Tensor'>
spec_ints - (8,) - <class 'torch.Tensor'>
fps - (4263,) - <class 'torch.Tensor'>
precursor_mz - () - <class 'torch.Tensor'>
adduct - () - <class 'torch.Tensor'>
instrument_type - () - <class 'torch.Tensor'>
collision_energy - () - <class 'torch.Tensor'>
spec_id - () - <class 'torch.Tensor'>
weight - () - <class 'torch.Tensor'>


In [6]:
# Sequential (unshuffled), single-process loader over `ds`.
# Batches are assembled by the dataset's own collate function.
dl = DataLoader(
    ds,
    batch_size=8,
    shuffle=False,
    drop_last=False,
    num_workers=0,
    collate_fn=ds.collate_fn,
)

In [7]:
# Pull the first batch from the loader and print the name, shape and type
# of each collated field.
batch = next(iter(dl))
print_shapes(batch)

spec_mzs - (92,) - <class 'torch.Tensor'>
spec_ints - (92,) - <class 'torch.Tensor'>
fps - (8, 4263) - <class 'torch.Tensor'>
precursor_mz - (8,) - <class 'torch.Tensor'>
adduct - (8,) - <class 'torch.Tensor'>
instrument_type - (8,) - <class 'torch.Tensor'>
collision_energy - (8,) - <class 'torch.Tensor'>
spec_id - (8,) - <class 'torch.Tensor'>
weight - (8,) - <class 'torch.Tensor'>
spec_batch_idxs - (92,) - <class 'torch.Tensor'>


In [12]:
# Summary statistics of the collision_energy column in the dataset's entry table.
print(ds.entry_df["collision_energy"].describe())

count    119029.000000
mean         38.425283
std          27.709673
min           0.000000
25%          20.000000
50%          30.000000
75%          55.707348
max         358.400160
Name: collision_energy, dtype: float64


In [13]:
# Smoke test: walk through every batch once (with a progress bar) to make
# sure the whole dataset loads and collates; the batches are discarded.
for batch in tqdm(dl, total=len(dl)):
    continue

  0%|          | 0/14945 [00:00<?, ?it/s]

In [11]:
# Switch the molecule transform to MolToPyG and rebuild the dataset.
# NOTE: this rebinds both `mol_transform` and `ds` defined in an earlier cell;
# the spectrum and metadata transforms are reused unchanged.
mol_transform = MolToPyG()
ds = SimulationDataset(
    tsv_pth=tsv_pth,
    meta_keys=meta_keys,
    frag_pth=None,
    spec_transform=spec_transform,
    meta_transform=meta_transform,
    mol_transform=mol_transform,
    cache_feats=False,
)

In [12]:
# Sanity check on the rebuilt dataset: fetch the first entry and print the
# name, shape and type of each field.
item = ds[0]
print_shapes(item)

spec_mzs - (40,) - <class 'torch.Tensor'>
spec_ints - (40,) - <class 'torch.Tensor'>
mol - (16, 32) - <class 'torch_geometric.data.data.Data'>
precursor_mz - () - <class 'torch.Tensor'>
adduct - () - <class 'torch.Tensor'>
instrument_type - () - <class 'torch.Tensor'>
collision_energy - () - <class 'torch.Tensor'>
weight - () - <class 'torch.Tensor'>


In [13]:
# Rebuild the loader so it wraps the current `ds`: sequential order,
# single process, batches assembled by the dataset's own collate function.
dl = DataLoader(
    ds,
    batch_size=8,
    shuffle=False,
    drop_last=False,
    num_workers=0,
    collate_fn=ds.collate_fn,
)

In [14]:
# Pull the first batch from the rebuilt loader and print the name, shape and
# type of each collated field.
batch = next(iter(dl))
print_shapes(batch)

spec_mzs - (192,) - <class 'torch.Tensor'>
spec_ints - (192,) - <class 'torch.Tensor'>
mol - (178, 378) - <class 'torch_geometric.data.batch.DataBatch'>
precursor_mz - (8,) - <class 'torch.Tensor'>
adduct - (8,) - <class 'torch.Tensor'>
instrument_type - (8,) - <class 'torch.Tensor'>
collision_energy - (8,) - <class 'torch.Tensor'>
weight - (8,) - <class 'torch.Tensor'>
spec_batch_idxs - (192,) - <class 'torch.Tensor'>


In [15]:
# Smoke test: walk through every batch once (with a progress bar) to make
# sure the whole dataset loads and collates; the batches are discarded.
for batch in tqdm(dl, total=len(dl)):
    continue

  0%|          | 0/407 [00:00<?, ?it/s]