In [2]:
import pandas as pd

import pytorch_lightning as pl
from pytorch_lightning import Trainer

from massspecgym.data import MassSpecDataset, RetrievalDataset, MassSpecDataModule
from massspecgym.transforms import SpecTokenizer, MolFingerprinter, SpecBinner
from massspecgym.models.retrieval import DeepSetsRetrieval, RandomRetrieval, FingerprintFFNRetrieval
from massspecgym.models.de_novo import DummyDeNovo

%load_ext autoreload
%autoreload 2

In [3]:
# Seed the random number generators through Lightning so that runs are repeatable.
pl.seed_everything(0)

# Debug switch read by the path-selection cell: True points the loaders at the
# small example files under ../data/debug, False keeps the default benchmark data.
DEBUG = False

Seed set to 0


In [4]:
# None for every path means "use the default benchmark paths".
mgf_pth = candidates_pth = split_pth = None

if DEBUG:
    # Debug mode: switch to the bundled 5-spectra example files instead.
    mgf_pth = "../data/debug/example_5_spectra.mgf"
    candidates_pth = "../data/debug/example_5_spectra_candidates.json"
    split_pth = "../data/debug/example_5_spectra_split.tsv"

## Deep Sets model on the fingerprint retrieval task

In [4]:
# Retrieval dataset: spectra go through SpecTokenizer (n_peaks=60),
# molecules go through MolFingerprinter with its default settings.
dataset = RetrievalDataset(
    pth=mgf_pth,
    candidates_pth=candidates_pth,
    spec_transform=SpecTokenizer(n_peaks=60),
    mol_transform=MolFingerprinter(),
)

# Data module wrapping the dataset together with the split file
data_module = MassSpecDataModule(dataset=dataset, split_pth=split_pth, batch_size=2)

# Model under test (alternative kept for reference: RandomRetrieval())
model = DeepSetsRetrieval()

# Weights & Biases logger.
# You may need to run `wandb init` first to use the wandb logger.
# Alternatively pass logger=None to the Trainer below to disable wandb.
project, name = "MassSpecGymRetrieval", "DeepSets"
logger = pl.loggers.WandbLogger(project=project, name=name, tags=[], log_model=False)

# Trainer: CPU only, 50 epochs, metrics logged every 50 steps
trainer = Trainer(
    accelerator="cpu",
    max_epochs=50,
    logger=logger,
    log_every_n_steps=50,
)

MassSpecGym.tsv:   0%|          | 0.00/302M [00:00<?, ?B/s]

  self.metadata = pd.read_csv(self.pth, sep="\t")
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Users/anton/miniconda3/envs/massspecgym/lib/python3.11/site-packages/pytorch_lightning/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.


## Fingerprint FFN model on the fingerprint retrieval task

In [5]:
# Size of the molecular fingerprint; shared by the fingerprinter and the model output
fp_size = 4096

# Load dataset: spectra go through SpecBinner, molecules are turned into
# fingerprints of length fp_size
dataset = RetrievalDataset(
    pth=mgf_pth,
    spec_transform=SpecBinner(),
    mol_transform=MolFingerprinter(fp_size=fp_size),
    candidates_pth=candidates_pth,
)

# Init data module
data_module = MassSpecDataModule(
    dataset=dataset,
    split_pth=split_pth,
    batch_size=64
)

# Init model
# NOTE(review): in_channels=1000 presumably equals the number of bins produced by
# the default SpecBinner() -- confirm and update it if the binner settings change.
model = FingerprintFFNRetrieval(in_channels=1000, out_channels=fp_size)

# Init logger
# You may need to run wandb init first to use the wandb logger
# Alternatively set logger = None in Trainer below not to use wandb
project = "MassSpecGymRetrieval"
# Run name specific to this model, so the run is not confused with the Deep Sets
# run that is logged to the same project.
name = "FingerprintFFN"
logger = pl.loggers.WandbLogger(
    project=project,
    name=name,
    tags=[],
    log_model=False,
)

# Init trainer
trainer = Trainer(
    accelerator="cpu", max_epochs=50, logger=logger, log_every_n_steps=50
)

  self.metadata = pd.read_csv(self.pth, sep="\t")
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Users/anton/miniconda3/envs/massspecgym/lib/python3.11/site-packages/pytorch_lightning/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.


## Dummy model on the de novo generation task

In [4]:
# Load dataset: spectra go through SpecTokenizer (n_peaks=60)
dataset = MassSpecDataset(
    pth=mgf_pth,
    spec_transform=SpecTokenizer(n_peaks=60),
    mol_transform=None  # Leave molecules as they are for the generation task
)

# Init data module
data_module = MassSpecDataModule(
    dataset=dataset,
    split_pth=split_pth,
    batch_size=2
)

# Init model
model = DummyDeNovo()

# Init logger
# You may need to run wandb init first to use the wandb logger
# Alternatively set logger = None in Trainer below not to use wandb
project = "MassSpecGymDeNovo"
name = "DummyBaseline"
logger = pl.loggers.WandbLogger(
    project=project,
    name=name,
    tags=[],
    log_model=False,
)

# Init trainer
trainer = Trainer(
    accelerator="cpu", max_epochs=50, logger=logger, log_every_n_steps=50
)

  self.metadata = pd.read_csv(self.pth, sep="\t")
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Users/anton/miniconda3/envs/massspecgym/lib/python3.11/site-packages/pytorch_lightning/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.


## Train

In [10]:
# Validate before training.
# `model`, `data_module` and `trainer` are whichever objects the most recently
# executed setup cell above assigned (all three setup cells bind the same names).
data_module.prepare_data()  # Explicit call needed for validate before fit
data_module.setup()  # Explicit call needed for validate before fit
trainer.validate(model, datamodule=data_module)

# Train the model with the same data module
trainer.fit(model, datamodule=data_module)

/Users/anton/miniconda3/envs/massspecgym/lib/python3.11/site-packages/pytorch_lightning/loggers/wandb.py:396: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
/Users/anton/miniconda3/envs/massspecgym/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Validation: |          | 0/? [00:00<?, ?it/s]

  binned_intensities /= np.max(binned_intensities)
/Users/anton/miniconda3/envs/massspecgym/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...

  | Name                    | Type             | Params
-------------------------------------------------------------
0 | ffn                     | MLP              | 2.6 M 
1 | loss_fn                 | CosSimLoss       | 0     
2 | val_fingerprint_cos_sim | CosineSimilarity | 0     
3 | val_hit_rate@1          | RetrievalHitRate | 0     
4 | val_hit_rate@5          | RetrievalHitRate | 0     
5 | val_hit_rate@10         | RetrievalHitRate | 0     
-------------------------------------------------------------
2.6 M     Trainable params
0         Non-trainable params
2.6 M     Total params
10.459    Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/anton/miniconda3/envs/massspecgym/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


## Test

In [7]:
# Run the test loop for the current model on the data module's test dataloader
trainer.test(model=model, datamodule=data_module)

/Users/anton/miniconda3/envs/massspecgym/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]

/Users/anton/miniconda3/envs/massspecgym/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...
