In [1]:
import pandas as pd

import torch
from pytorch_lightning import Trainer

from massspecgym.datasets import MassSpecDataset, RetrievalDataset, MassSpecDataModule
from massspecgym.transforms import SpecTokenizer, MolFingerprinter
%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


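## Debug data loading

Load the 5-spectrum debug MGF with `MassSpecDataset`, wrap it in `MassSpecDataModule`, and print the shapes and dtypes of one training batch.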

In [2]:
# Single-spectrum dataset over the 5-spectrum debug MGF. Spectra are passed
# through SpecTokenizer(n_peaks=60) and molecules through MolFingerprinter.
dataset = MassSpecDataset(mgf_pth='../data/debug/example_5_spectra.mgf',
                          spec_transform=SpecTokenizer(n_peaks=60),
                          mol_transform=MolFingerprinter())

# Data module over `dataset`, split according to the debug split TSV and
# served in batches of 2 samples.
data_module = MassSpecDataModule(dataset=dataset,
                                 split_pth='../data/debug/example_5_spectra_split.tsv',
                                 batch_size=2)

In [3]:
# Run the data module's prepare/setup hooks, then print the key, shape and
# dtype of every tensor in the first training batch (non-tensor entries are
# printed as-is).
data_module.prepare_data()
data_module.setup()
dataloader = data_module.train_dataloader()

# Only the first batch is needed. Fail loudly instead of silently printing
# nothing when the train split yields no batches.
batch = next(iter(dataloader), None)
if batch is None:
    raise RuntimeError('train dataloader yielded no batches')

for k, v in batch.items():
    if isinstance(v, torch.Tensor):
        print(k, v.shape, v.dtype)
    else:
        print(k, v)

spec torch.Size([2, 60, 2]) torch.float64
mol torch.Size([2, 2048]) torch.int32
precursor_mz torch.Size([2]) torch.float64
adduct ['[M+H]+', '[M+H]+']


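## Debug retrieval data loading

Repeat the batch check with `RetrievalDataset`, which additionally reads retrieval candidates from a JSON file; its batches carry extra `candidates`, `labels` and `batch_ptr` entries.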

In [12]:
# Retrieval dataset over the 5-spectrum debug MGF that additionally reads the
# retrieval candidates from the debug candidates JSON. Spectra are passed
# through SpecTokenizer(n_peaks=60) and molecules through MolFingerprinter.
dataset = RetrievalDataset(candidates_pth='../data/debug/example_5_spectra_candidates.json',
                           mgf_pth='../data/debug/example_5_spectra.mgf',
                           spec_transform=SpecTokenizer(n_peaks=60),
                           mol_transform=MolFingerprinter())

# Data module over the retrieval dataset: debug split TSV, batches of 2.
# NOTE: this rebinds the notebook-level names `dataset` and `data_module`,
# so any earlier objects bound to those names are replaced from here on.
data_module = MassSpecDataModule(dataset=dataset,
                                 split_pth='../data/debug/example_5_spectra_split.tsv',
                                 batch_size=2)

In [16]:
# Run the data module's prepare/setup hooks, then summarize every entry of the
# first retrieval training batch: tensors by shape and dtype; lists (such as
# the candidate and label lists, which can be long) by length and the type of
# their first element instead of printing them in full.
data_module.prepare_data()
data_module.setup()
dataloader = data_module.train_dataloader()

# Only the first batch is needed. Fail loudly instead of silently printing
# nothing when the train split yields no batches.
batch = next(iter(dataloader), None)
if batch is None:
    raise RuntimeError('train dataloader yielded no batches')

for k, v in batch.items():
    if isinstance(v, torch.Tensor):
        print(k, v.shape, v.dtype)
    elif isinstance(v, list):
        # An empty list has no first element; report None rather than
        # raising IndexError on v[0].
        print(k, len(v), type(v[0]) if v else None)
    else:
        print(k, v)

spec torch.Size([2, 60, 2]) torch.float64
mol torch.Size([2, 2048]) torch.int32
precursor_mz torch.Size([2]) torch.float64
adduct 2 <class 'str'>
candidates 97 <class 'numpy.ndarray'>
labels 97 <class 'bool'>
batch_ptr 2 <class 'int'>
