In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from torch.utils.data import DataLoader
import pandas as pd
from tqdm.notebook import tqdm

In [3]:
from massspecgym.data.datasets import SimulationDataset
from massspecgym.transforms import SpecToMzsInts, MolToPyG, StandardMeta, MolToFingerprints
from massspecgym.simulation_utils.misc_utils import print_shapes

In [4]:
tsv_pth = "../data/MassSpecGym_4.tsv"
meta_keys = ["adduct","precursor_mz","instrument_type","collision_energy"]
fp_types = ["morgan","maccs","rdkit"]
adducts = ["[M+H]+"]
instrument_types = ["QTOF","QFT","Orbitrap","ITFT"]
max_collision_energy = 200. # arbitrary
mz_from = 10.
mz_to = 1500.
spec_transform = SpecToMzsInts(
    mz_from=mz_from,
    mz_to=mz_to
)
mol_transform = MolToFingerprints(
    fp_types=fp_types
)
meta_transform = StandardMeta(
    adducts=adducts,
    instrument_types=instrument_types,
    max_collision_energy=max_collision_energy
)

ds = SimulationDataset(
    tsv_pth=tsv_pth,
    meta_keys=meta_keys,
    spec_transform=spec_transform,
    mol_transform=mol_transform,
    meta_transform=meta_transform,
    cache_feats=False)

In [5]:
df = pd.read_csv(tsv_pth, sep="\t")
sim_df = df[df["simulation_challenge"]]
print(sim_df.shape)

(123042, 15)


In [12]:
train_sim_df = sim_df[sim_df["fold"]=="train"]
val_sim_df = sim_df[sim_df["fold"]=="val"]
test_sim_df = sim_df[sim_df["fold"]=="test"]
print("> sim: everything")
print(sim_df.shape[0], sim_df["inchikey"].nunique())
print(sim_df[["instrument_type","collision_energy"]].isna().sum())
print("> sim: train split")
print(train_sim_df.shape[0], train_sim_df["inchikey"].nunique())
print(train_sim_df[["instrument_type","collision_energy"]].isna().sum())
print("> sim: val split")
print(val_sim_df.shape[0], val_sim_df["inchikey"].nunique())
print(val_sim_df[["instrument_type","collision_energy"]].isna().sum())
print("> sim: test split")
print(test_sim_df.shape[0], test_sim_df["inchikey"].nunique())
print(test_sim_df[["instrument_type","collision_energy"]].isna().sum())

> sim: everything
123042 18567
instrument_type     4013
collision_energy    4013
dtype: int64
> sim: train split
102808 13636
instrument_type     3467
collision_energy    3467
dtype: int64
> sim: val split
10033 2484
instrument_type     299
collision_energy    299
dtype: int64
> sim: test split
10201 2447
instrument_type     247
collision_energy    247
dtype: int64


In [13]:
print(sim_df["precursor_mz"].max())

995.556


In [16]:
sim_df["mzs"].iloc[0]

'91.0542,125.0233,154.0499,155.0577,185.0961,200.107,229.0859,246.1125'

In [17]:
mz_max = sim_df["mzs"].apply(lambda l: max(float(x) for x in l.split(",")))
mz_min = sim_df["mzs"].apply(lambda l: min(float(x) for x in l.split(",")))
print((mz_max > 1000.).sum())
print((mz_min > 1000.).sum())
print((mz_max > 1500.).sum())
print((mz_min > 1500.).sum())

1243
0
1130
0


In [13]:
sim_df = sim_df[~sim_df["collision_energy"].isna()]
print(sim_df.shape)
mz_max = sim_df["mzs"].apply(lambda l: float(max(l.split(","))))
precursor_mz = sim_df["precursor_mz"]
print(mz_max.describe())
print(precursor_mz.describe())
print((mz_max < 1000.).sum())
print((precursor_mz < 1000.).sum())

(119029, 15)
count    119029.000000
mean        154.944212
std         131.616448
min          31.604450
25%          91.054000
50%          96.080292
75%         162.092400
max         998.565700
Name: mzs, dtype: float64
count    119029.000000
mean        338.704167
std         136.954683
min          60.081046
25%         247.107718
50%         322.166240
75%         406.081620
max         995.556000
Name: precursor_mz, dtype: float64
119029
119029


In [15]:
df = pd.read_csv(tsv_pth, sep="\t")
small_df = df.sample(n=10000, replace=False, random_state=42)
small_df.to_csv("../data/MassSpecGym_4_small.tsv", sep="\t", index=False)

In [5]:
item = ds[0]
print_shapes(item)

spec_mzs - (8,) - <class 'torch.Tensor'>
spec_ints - (8,) - <class 'torch.Tensor'>
fps - (4263,) - <class 'torch.Tensor'>
precursor_mz - () - <class 'torch.Tensor'>
adduct - () - <class 'torch.Tensor'>
instrument_type - () - <class 'torch.Tensor'>
collision_energy - () - <class 'torch.Tensor'>
spec_id - () - <class 'torch.Tensor'>
weight - () - <class 'torch.Tensor'>


In [6]:
dl = DataLoader(
    ds,
    num_workers=0,
    shuffle=False,
    batch_size=8,
    drop_last=False,
    collate_fn=ds.collate_fn)

In [7]:
batch = next(iter(dl))
print_shapes(batch)

spec_mzs - (92,) - <class 'torch.Tensor'>
spec_ints - (92,) - <class 'torch.Tensor'>
fps - (8, 4263) - <class 'torch.Tensor'>
precursor_mz - (8,) - <class 'torch.Tensor'>
adduct - (8,) - <class 'torch.Tensor'>
instrument_type - (8,) - <class 'torch.Tensor'>
collision_energy - (8,) - <class 'torch.Tensor'>
spec_id - (8,) - <class 'torch.Tensor'>
weight - (8,) - <class 'torch.Tensor'>
spec_batch_idxs - (92,) - <class 'torch.Tensor'>


In [12]:
print(ds.entry_df["collision_energy"].describe())

count    119029.000000
mean         38.425283
std          27.709673
min           0.000000
25%          20.000000
50%          30.000000
75%          55.707348
max         358.400160
Name: collision_energy, dtype: float64


In [13]:
for batch in tqdm(iter(dl),total=len(dl)):
    pass

  0%|          | 0/14945 [00:00<?, ?it/s]

In [11]:
mol_transform = MolToPyG()
ds = SimulationDataset(
    tsv_pth=tsv_pth,
    frag_pth=None,
    meta_keys=meta_keys,
    spec_transform=spec_transform,
    mol_transform=mol_transform,
    meta_transform=meta_transform,
    cache_feats=False)

In [12]:
item = ds[0]
print_shapes(item)

spec_mzs - (40,) - <class 'torch.Tensor'>
spec_ints - (40,) - <class 'torch.Tensor'>
mol - (16, 32) - <class 'torch_geometric.data.data.Data'>
precursor_mz - () - <class 'torch.Tensor'>
adduct - () - <class 'torch.Tensor'>
instrument_type - () - <class 'torch.Tensor'>
collision_energy - () - <class 'torch.Tensor'>
weight - () - <class 'torch.Tensor'>


In [13]:
dl = DataLoader(
    ds,
    num_workers=0,
    shuffle=False,
    batch_size=8,
    drop_last=False,
    collate_fn=ds.collate_fn)

In [14]:
batch = next(iter(dl))
print_shapes(batch)

spec_mzs - (192,) - <class 'torch.Tensor'>
spec_ints - (192,) - <class 'torch.Tensor'>
mol - (178, 378) - <class 'torch_geometric.data.batch.DataBatch'>
precursor_mz - (8,) - <class 'torch.Tensor'>
adduct - (8,) - <class 'torch.Tensor'>
instrument_type - (8,) - <class 'torch.Tensor'>
collision_energy - (8,) - <class 'torch.Tensor'>
weight - (8,) - <class 'torch.Tensor'>
spec_batch_idxs - (192,) - <class 'torch.Tensor'>


In [15]:
for batch in tqdm(iter(dl),total=len(dl)):
    pass

  0%|          | 0/407 [00:00<?, ?it/s]