In [1]:
import pandas as pd

import torch
from pytorch_lightning import Trainer

from massspecgym.data import MassSpecDataset, RetrievalDataset, MassSpecDataModule
from massspecgym.transforms import SpecTokenizer, MolFingerprinter

%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


## Debug data loading

In [2]:
# Build the dataset from the debug MGF file, with SpecTokenizer(n_peaks=60) as
# the spectrum transform and MolFingerprinter() as the molecule transform.
dataset = MassSpecDataset(
    mgf_pth="../data/debug/example_5_spectra.mgf",
    spec_transform=SpecTokenizer(n_peaks=60),
    mol_transform=MolFingerprinter(),
)

# Wrap the dataset in a data module configured with the debug split file and a
# batch size of 2.
data_module = MassSpecDataModule(
    dataset=dataset,
    split_pth="../data/debug/example_5_spectra_split.tsv",
    batch_size=2,
)

In [3]:
# Run the data module's prepare/setup hooks, then summarise only the first
# training batch: tensors are reported by shape and dtype, everything else is
# printed verbatim.
data_module.prepare_data()
data_module.setup()
dataloader = data_module.train_dataloader()
for batch in dataloader:
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            print(key, value.shape, value.dtype)
            continue
        print(key, value)
    break  # a single batch is enough for this check

spec torch.Size([2, 60, 2]) torch.float64
mol torch.Size([2, 2048]) torch.int32
precursor_mz torch.Size([2]) torch.float64
adduct ['[M+H]+', '[M+H]+']


## Debug retrieval data loading

In [4]:
# Build the retrieval dataset from the debug MGF file plus the candidates JSON,
# with SpecTokenizer(n_peaks=60) as the spectrum transform and
# MolFingerprinter() as the molecule transform.
dataset = RetrievalDataset(
    candidates_pth="../data/debug/example_5_spectra_candidates.json",
    mgf_pth="../data/debug/example_5_spectra.mgf",
    spec_transform=SpecTokenizer(n_peaks=60),
    mol_transform=MolFingerprinter(),
)

# Wrap the dataset in a data module configured with the debug split file and a
# batch size of 2.
data_module = MassSpecDataModule(
    dataset=dataset,
    split_pth="../data/debug/example_5_spectra_split.tsv",
    batch_size=2,
)

In [5]:
# Run the data module's prepare/setup hooks, then summarise only the first
# training batch. Per entry: tensors show shape, dtype and first element; lists
# show length, element type and first element; anything else is printed as-is.
data_module.prepare_data()
data_module.setup()
dataloader = data_module.train_dataloader()
for batch in dataloader:
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            details = (value.shape, value.dtype, value[0])
        elif isinstance(value, list):
            details = (len(value), type(value[0]), value[0])
        else:
            details = (value,)
        print(key, *details)
    break  # a single batch is enough for this check

spec torch.Size([2, 60, 2]) torch.float64 tensor([[4.2034e+01, 6.9330e-02],
        [5.8029e+01, 1.1370e-01],
        [5.9032e+01, 2.1160e-02],
        [6.0033e+01, 1.2680e-02],
        [6.3023e+01, 2.0060e-02],
        [6.4018e+01, 2.3740e-02],
        [6.5038e+01, 9.2300e-03],
        [6.6033e+01, 2.1340e-02],
        [6.9992e+01, 1.9380e-02],
        [7.8034e+01, 4.0970e-02],
        [7.9029e+01, 1.2570e-02],
        [8.6060e+01, 9.9590e-02],
        [8.7064e+01, 3.7690e-02],
        [8.8064e+01, 1.4990e-02],
        [9.1018e+01, 3.7890e-02],
        [9.3044e+01, 7.2130e-02],
        [1.0405e+02, 5.2920e-02],
        [1.0504e+02, 4.7760e-02],
        [1.0505e+02, 3.4940e-02],
        [1.0603e+02, 1.1420e-02],
        [1.2006e+02, 7.7780e-02],
        [1.2104e+02, 6.7510e-02],
        [1.2501e+02, 2.9480e-02],
        [1.3303e+02, 2.0740e-02],
        [1.3304e+02, 3.9790e-02],
        [1.3502e+02, 9.0800e-03],
        [1.4805e+02, 4.7210e-02],
        [1.5401e+02, 2.1390e-02],
      

## Debug de novo data loading

In [11]:
# Build the dataset from the debug MGF file with SpecTokenizer(n_peaks=60) as
# the spectrum transform. No mol_transform is passed in this cell; the
# MolFingerprinter() argument used by the earlier cells is left disabled.
dataset = MassSpecDataset(
    mgf_pth="../data/debug/example_5_spectra.mgf",
    spec_transform=SpecTokenizer(n_peaks=60),
)

# Wrap the dataset in a data module configured with the debug split file and a
# batch size of 2.
data_module = MassSpecDataModule(
    dataset=dataset,
    split_pth="../data/debug/example_5_spectra_split.tsv",
    batch_size=2,
)

In [12]:
# Run the data module's prepare/setup hooks, then inspect only the first
# training batch. Tensors are reported by shape, dtype and first element; lists
# by length, element type and first element; other values are printed verbatim.
data_module.prepare_data()
data_module.setup()
dataloader = data_module.train_dataloader()
for batch in dataloader:
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            details = (value.shape, value.dtype, value[0])
        elif isinstance(value, list):
            details = (len(value), type(value[0]), value[0])
        else:
            details = (value,)
        print(key, *details)
    break  # stop after the first batch

spec torch.Size([2, 60, 2]) torch.float64 tensor([[7.0065e+01, 2.8620e-02],
        [8.4081e+01, 1.5110e-02],
        [1.0908e+02, 1.3850e-02],
        [1.1409e+02, 2.8860e-02],
        [1.2302e+02, 1.4149e-01],
        [1.6106e+02, 2.5200e-02],
        [2.0209e+02, 1.4350e-02],
        [2.3011e+02, 3.3670e-02],
        [2.4113e+02, 2.6664e-01],
        [2.6116e+02, 3.0620e-02],
        [2.7114e+02, 4.1501e-01],
        [2.8915e+02, 1.0000e+00],
        [3.4618e+02, 2.6100e-01],
        [3.5114e+02, 1.2000e-02],
        [4.1019e+02, 8.7500e-02],
        [0.0000e+00, 0.0000e+00],
        [0.0000e+00, 0.0000e+00],
        [0.0000e+00, 0.0000e+00],
        [0.0000e+00, 0.0000e+00],
        [0.0000e+00, 0.0000e+00],
        [0.0000e+00, 0.0000e+00],
        [0.0000e+00, 0.0000e+00],
        [0.0000e+00, 0.0000e+00],
        [0.0000e+00, 0.0000e+00],
        [0.0000e+00, 0.0000e+00],
        [0.0000e+00, 0.0000e+00],
        [0.0000e+00, 0.0000e+00],
        [0.0000e+00, 0.0000e+00],
      