In [None]:
import pandas as pd

import pytorch_lightning as pl
from pytorch_lightning import Trainer

from massspecgym.data import MassSpecDataset, RetrievalDataset, MassSpecDataModule
from massspecgym.transforms import SpecTokenizer, MolFingerprinter, SpecBinner
from massspecgym.models.retrieval import (
    DeepSetsRetrieval,
    RandomRetrieval,
    FingerprintFFNRetrieval,
)
from massspecgym.models.de_novo import DummyDeNovo, RandomDeNovo

# Enable IPython's autoreload extension; mode 2 reloads imported modules
# before each cell execution so edits to them are picked up without a
# kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
# Notebook-wide configuration.
# SEED is passed to Lightning's seed_everything below for reproducibility.
SEED = 0
pl.seed_everything(SEED)

# Set DEBUG to True to run on the small debug files instead of the default
# benchmark data (the paths are chosen in the next cell).
DEBUG = False

In [None]:
# Choose the input files for every dataset built below.
if not DEBUG:
    # Use default benchmark paths
    mgf_pth = candidates_pth = split_pth = None
else:
    # Debug mode: point at the example_5_spectra files under ../data/debug
    mgf_pth, candidates_pth, split_pth = (
        "../data/debug/example_5_spectra.mgf",
        "../data/debug/example_5_spectra_candidates.json",
        "../data/debug/example_5_spectra_split.tsv",
    )

## Deep Sets model on the fingerprint retrieval task

In [None]:
# Load dataset
# The paths are selected by the DEBUG switch near the top of the notebook
# (debug files with 5 spectra when DEBUG is True, otherwise None).
# NOTE: the spectra file is passed as `pth`, matching the other dataset
# constructions in this notebook (it was previously passed as `mgf_pth`).
dataset = RetrievalDataset(
    pth=mgf_pth,
    spec_transform=SpecTokenizer(n_peaks=60),
    mol_transform=MolFingerprinter(),
    candidates_pth=candidates_pth,
)

# Init data module
data_module = MassSpecDataModule(dataset=dataset, split_pth=split_pth, batch_size=2)

# Init model
model = DeepSetsRetrieval()
# model = RandomRetrieval()

# Init logger
# You may need to run wandb init first to use the wandb logger
# Alternatively set logger = None in Trainer below not to use wandb
project = "MassSpecGymRetrieval"
name = "DeepSets"
logger = pl.loggers.WandbLogger(
    project=project,
    name=name,
    tags=[],
    log_model=False,
)

# Init trainer
# log_every_n_steps=1 logs on every step
trainer = Trainer(accelerator="cpu", max_epochs=50, logger=logger, log_every_n_steps=1)

## Fingerprint FFN model on the fingerprint retrieval task

In [None]:
# Size of the molecular fingerprint; shared by the fingerprinter and the
# model's output layer so the two always agree.
fp_size = 4096

# Load dataset
# The paths are selected by the DEBUG switch near the top of the notebook.
dataset = RetrievalDataset(
    pth=mgf_pth,
    spec_transform=SpecBinner(),
    mol_transform=MolFingerprinter(fp_size=fp_size),
    candidates_pth=candidates_pth,
)

# Init data module
data_module = MassSpecDataModule(dataset=dataset, split_pth=split_pth, batch_size=64)

# Init model
# NOTE(review): in_channels=1000 presumably has to equal the number of bins
# produced by SpecBinner() with its default settings — confirm.
model = FingerprintFFNRetrieval(in_channels=1000, out_channels=fp_size)

# Init logger
# You may need to run wandb init first to use the wandb logger
# Alternatively set logger = None in Trainer below not to use wandb
project = "MassSpecGymRetrieval"
# Run name identifies this model in W&B; it was previously "DeepSets",
# copied over from the Deep Sets cell.
name = "FingerprintFFN"
logger = pl.loggers.WandbLogger(
    project=project,
    name=name,
    tags=[],
    log_model=False,
)

# Init trainer
trainer = Trainer(accelerator="cpu", max_epochs=50, logger=logger, log_every_n_steps=50)

In [None]:
# Run the test loop for the current model on the current data module.
# NOTE(review): no trainer.fit(...) call appears between this model's
# construction and this cell, so the model is evaluated as initialized —
# confirm this is intended.
trainer.test(model=model, datamodule=data_module)

## Random model on the de novo generation task

In [None]:
# Build the de novo dataset: spectra go through SpecTokenizer(n_peaks=60),
# molecules get no transform. The paths are selected by the DEBUG switch
# near the top of the notebook.
dataset = MassSpecDataset(
    pth=mgf_pth,
    spec_transform=SpecTokenizer(n_peaks=60),
    mol_transform=None,
)

# Wrap the dataset in a data module
data_module = MassSpecDataModule(
    dataset=dataset,
    split_pth=split_pth,
    batch_size=2,
)

# Baseline model (swap in the commented line to use the dummy model instead)
model = RandomDeNovo(formula_known=False, max_top_k=10)
# model = DummyDeNovo()

# Weights & Biases logger.
# `wandb init` may have to be run once before this works; to skip wandb
# entirely, pass logger=None to the Trainer below.
project, name = "MassSpecGymDeNovo", "RandomBasline"
logger = pl.loggers.WandbLogger(project=project, name=name, tags=[], log_model=False)

# Single-epoch CPU trainer with validation switched off
# (limit_val_batches=0 and num_sanity_val_steps=0)
trainer = Trainer(
    num_sanity_val_steps=0,
    limit_val_batches=0,
    log_every_n_steps=1000,
    logger=logger,
    max_epochs=1,
    accelerator="cpu",
)

### Train (only chemical element statistics are allowed to be checked)

In [None]:
# prepare_data() and setup() are invoked by hand because validating before
# fit needs them; the validate call itself is currently disabled.
data_module.prepare_data()
data_module.setup()
# trainer.validate(model, datamodule=data_module)

# Fit the model on the data module
trainer.fit(model=model, datamodule=data_module)

### Test

In [None]:
# Run the test loop for the fitted model on the same data module
trainer.test(model=model, datamodule=data_module)