In [8]:
from pathlib import Path
import pandas as pd
import torch
import pytorch_lightning as pl
from pytorch_lightning import Trainer

import massspecgym.utils as utils
from massspecgym.data import MassSpecDataset, RetrievalDataset, MassSpecDataModule
from massspecgym.data.transforms import SpecTokenizer, MolFingerprinter, SpecBinner
from massspecgym.models.retrieval import DeepSetsRetrieval, RandomRetrieval, FingerprintFFNRetrieval, FromDictRetrieval
from massspecgym.models.de_novo import DummyDeNovo, RandomDeNovo, SmilesTransformer

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [9]:
# Fix the seed once, before any dataset or model is built.
pl.seed_everything(seed=0)

# Switch read by the path-selection cell: True selects the 5-spectra debug
# files, False leaves the paths unset so the default benchmark data is used.
DEBUG = False

Seed set to 0


In [10]:
# Input files shared by every section below. With DEBUG on, point at the
# 5-spectra example files under ../data/debug; otherwise keep each path as
# None so that the default benchmark paths are used.
mgf_pth = Path("../data/debug/example_5_spectra.mgf") if DEBUG else None
candidates_pth = (
    Path("../data/debug/example_5_spectra_candidates.json") if DEBUG else None
)
split_pth = Path("../data/debug/example_5_spectra_split.tsv") if DEBUG else None

## Deep Sets model on the fingerprint retrieval task

In [39]:
# --- Data ---
# mgf_pth / candidates_pth / split_pth come from the DEBUG switch earlier in
# the notebook (debug files with 5 spectra, or None for the benchmark defaults).
dataset = RetrievalDataset(
    pth=mgf_pth,
    candidates_pth=candidates_pth,
    spec_transform=SpecTokenizer(n_peaks=60),
    mol_transform=MolFingerprinter(),
)
data_module = MassSpecDataModule(dataset=dataset, split_pth=split_pth, batch_size=2)

# --- Model ---
# RandomRetrieval() can be swapped in here instead.
model = DeepSetsRetrieval()

# --- Logging ---
# The wandb logger may require running `wandb init` beforehand; to skip wandb,
# pass logger=None to the Trainer below.
project = "MassSpecGymRetrieval"
name = "DeepSets"
logger = pl.loggers.WandbLogger(project=project, name=name, tags=[], log_model=False)

# --- Trainer ---
trainer = Trainer(
    accelerator="cpu",
    max_epochs=50,
    logger=logger,
    log_every_n_steps=1,
)

ParserError: Error tokenizing data. C error: Calling read(nbytes) on source failed. Try engine='python'.

## Random baseline on the fingerprint retrieval task

In [None]:
# Size passed to MolFingerprinter for this section.
fp_size = 4096

# --- Data ---
# Spectra go through SpecBinner with its default arguments; molecules are
# fingerprinted with fp_size.
dataset = RetrievalDataset(
    pth=mgf_pth,
    candidates_pth=candidates_pth,
    spec_transform=SpecBinner(),
    mol_transform=MolFingerprinter(fp_size=fp_size),
)
data_module = MassSpecDataModule(dataset=dataset, split_pth=split_pth, batch_size=64)

# --- Model ---
model = RandomRetrieval()

# --- Logging ---
# The wandb logger may require running `wandb init` beforehand; to skip wandb,
# pass logger=None to the Trainer below.
project = "MassSpecGymRetrieval"
name = "RandomFFN"
logger = pl.loggers.WandbLogger(project=project, name=name, tags=[], log_model=False)

# --- Trainer ---
trainer = Trainer(
    accelerator="cpu",
    max_epochs=50,
    logger=logger,
    log_every_n_steps=50,
)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Users/anton/miniconda3/envs/massspecgym/lib/python3.11/site-packages/pytorch_lightning/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.


## Fingerprint FFN model on the fingerprint retrieval task

In [None]:
# Size passed to MolFingerprinter; also used as the model's out_channels.
fp_size = 4096

# --- Data ---
dataset = RetrievalDataset(
    pth=mgf_pth,
    candidates_pth=candidates_pth,
    spec_transform=SpecBinner(),
    mol_transform=MolFingerprinter(fp_size=fp_size),
)
data_module = MassSpecDataModule(dataset=dataset, split_pth=split_pth, batch_size=64)

# --- Model ---
# NOTE(review): in_channels=1005 presumably equals the length of the vector
# produced by SpecBinner() with default arguments - confirm before changing
# the binner settings.
model = FingerprintFFNRetrieval(in_channels=1005, out_channels=fp_size)

# --- Logging ---
# The wandb logger may require running `wandb init` beforehand; to skip wandb,
# pass logger=None to the Trainer below.
project = "MassSpecGymRetrieval"
name = "FingerprintFFN_debug"
logger = pl.loggers.WandbLogger(project=project, name=name, tags=[], log_model=False)

# --- Trainer ---
trainer = Trainer(
    accelerator="cpu",
    max_epochs=50,
    logger=logger,
    log_every_n_steps=50,
)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Users/anton/miniconda3/envs/massspecgym/lib/python3.11/site-packages/pytorch_lightning/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.


## MIST on the fingerprint retrieval task

In [None]:
# Size passed to MolFingerprinter for this section.
fp_size = 4096

# --- Data ---
dataset = RetrievalDataset(
    pth=mgf_pth,
    candidates_pth=candidates_pth,
    spec_transform=SpecBinner(),
    mol_transform=MolFingerprinter(fp_size=fp_size),
)
data_module = MassSpecDataModule(dataset=dataset, split_pth=split_pth, batch_size=64)

# --- Model ---
# The pickle is read relative to the current working directory.
# NOTE: unpickling can execute arbitrary code - only load this file if it
# comes from a trusted source.
# NOTE(review): presumably this frame holds precomputed MIST predictions
# (per the section title) - confirm where the file is produced.
df = pd.read_pickle('fp_preds_MassSpecGym_df.pkl')
# Map each entry of the 'name' column to the 'fp_predict' entry of the same row.
dct = {key: fp for key, fp in zip(df['name'], df['fp_predict'])}
model = FromDictRetrieval(dct=dct)

# --- Logging ---
# The wandb logger may require running `wandb init` beforehand; to skip wandb,
# pass logger=None to the Trainer below.
project = "MassSpecGymRetrieval"
name = "MIST"
logger = pl.loggers.WandbLogger(project=project, name=name, tags=[], log_model=False)

# --- Trainer ---
trainer = Trainer(
    accelerator="cpu",
    max_epochs=50,
    logger=logger,
    log_every_n_steps=50,
)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Users/anton/miniconda3/envs/massspecgym/lib/python3.11/site-packages/pytorch_lightning/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.


## Dummy model on the de novo generation task

In [None]:
# --- Data ---
# mgf_pth / split_pth come from the DEBUG switch earlier in the notebook
# (debug files with 5 spectra, or None for the benchmark defaults).
# No molecule transform is applied here (mol_transform=None).
dataset = MassSpecDataset(
    pth=mgf_pth,
    spec_transform=SpecTokenizer(n_peaks=60),
    mol_transform=None,
)
data_module = MassSpecDataModule(dataset=dataset, split_pth=split_pth, batch_size=2)

# --- Model ---
# RandomDeNovo() can be swapped in here instead.
model = DummyDeNovo()

# --- Logging ---
# The wandb logger may require running `wandb init` beforehand; to skip wandb,
# pass logger=None to the Trainer below.
# NOTE(review): the run name below reads "RandomBasline" (sic) while the
# active model is DummyDeNovo - confirm the intended run name.
project = "MassSpecGymDeNovo"
name = "RandomBasline"
logger = pl.loggers.WandbLogger(project=project, name=name, tags=[], log_model=False)

# --- Trainer ---
trainer = Trainer(
    accelerator="cpu",
    max_epochs=50,
    logger=logger,
    log_every_n_steps=1,
)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Users/anton/miniconda3/envs/massspecgym/lib/python3.11/site-packages/pytorch_lightning/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.


## De novo SMILES transformer

In [19]:
from massspecgym.models.tokenizers import SelfiesTokenizer, SmilesBPETokenizer

# --- Data ---
dataset = MassSpecDataset(
    pth=mgf_pth,
    spec_transform=SpecTokenizer(n_peaks=60),
    mol_transform=None,
)
data_module = MassSpecDataModule(dataset=dataset, split_pth=split_pth, batch_size=32)

# --- Tokenizer ---
# SmilesBPETokenizer(max_len=100) is the alternative that was tried here.
tokenizer = SelfiesTokenizer(max_len=100)

# --- Model ---
model = SmilesTransformer(
    input_dim=2,
    d_model=512,
    nhead=8,
    num_encoder_layers=4,
    num_decoder_layers=4,
    dropout=0.0,
    smiles_tokenizer=tokenizer,
    k_predictions=1,
    pre_norm=False,
    max_smiles_len=100,
    # validate_only_loss=True,
    log_only_loss_at_stages=['train'],
)

# --- Logging and trainer ---
# NOTE(review): the WandbLogger is constructed but the Trainer below is given
# logger=None, so this logger object is not handed to the Trainer - confirm
# whether that is intentional for this debug/overfitting run.
project = "MassSpecGymDeNovo"
name = "SmilesTransformer_debug_overfitting"
logger = pl.loggers.WandbLogger(project=project, name=name, tags=[], log_model=False)
trainer = Trainer(
    accelerator="cpu",
    max_epochs=100,
    logger=None,
    log_every_n_steps=1,
    check_val_every_n_epoch=1,
    val_check_interval=10,
)

# --- Validate once before training ---
# prepare_data() and setup() are called by hand because validate() runs
# before fit() here.
data_module.prepare_data()
data_module.setup()
trainer.validate(model=model, datamodule=data_module)

# --- Train ---
trainer.fit(model=model, datamodule=data_module)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Users/roman/miniconda/envs/massspecgym/lib/python3.11/site-packages/pytorch_lightning/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
/Users/roman/miniconda/envs/massspecgym/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Validation: |          | 0/? [00:00<?, ?it/s]

[13:43:04] Invalid InChI prefix in generating InChI Key
[13:43:04] Invalid InChI prefix in generating InChI Key
[13:43:04] Invalid InChI prefix in generating InChI Key
[13:43:04] Invalid InChI prefix in generating InChI Key
[13:43:04] Invalid InChI prefix in generating InChI Key
[13:43:04] Invalid InChI prefix in generating InChI Key
[13:43:04] Invalid InChI prefix in generating InChI Key
[13:43:04] Invalid InChI prefix in generating InChI Key
[13:43:12] Invalid InChI prefix in generating InChI Key
[13:43:12] Invalid InChI prefix in generating InChI Key
[13:43:12] Invalid InChI prefix in generating InChI Key
[13:43:12] Invalid InChI prefix in generating InChI Key
[13:43:12] Invalid InChI prefix in generating InChI Key
[13:43:12] Invalid InChI prefix in generating InChI Key
[13:43:12] Invalid InChI prefix in generating InChI Key
[13:43:12] Invalid InChI prefix in generating InChI Key
[13:43:12] Invalid InChI prefix in generating InChI Key
[13:43:12] Invalid InChI prefix in generating In

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/roman/miniconda/envs/massspecgym/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
[13:44:42] Invalid InChI prefix in generating InChI Key
[13:44:42] Invalid InChI prefix in generating InChI Key
[13:44:42] Invalid InChI prefix in generating InChI Key
[13:44:42] Invalid InChI prefix in generating InChI Key
[13:44:42] Invalid InChI prefix in generating InChI Key
[13:44:42] Invalid InChI prefix in generating InChI Key
[13:44:42] Invalid InChI prefix in generating InChI Key
[13:44:42] Invalid InChI prefix in generating InChI Key
[13:44:53] Invalid InChI prefix in generating InChI Key
[13:44:53] Invalid InChI prefix in generating InChI Key
[13:44:53] Invalid InChI prefix in generating InChI Key
[13:44:53] Invalid InChI prefix in generating InChI Key


Training: |          | 0/? [00:00<?, ?it/s]

/Users/roman/miniconda/envs/massspecgym/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


In [None]:
import tokenizers

# Word-level vocabulary whose only entry is the unknown token (id 0).
# NOTE(review): "[UwNK]" looks like a typo for "[UNK]" - confirm.
unk_token = "[UwNK]"
vocabulary = {unk_token: 0}
word_level = tokenizers.models.WordLevel(vocabulary, unk_token=unk_token)
# NOTE(review): this rebinds the notebook-global `tokenizer`, which the SMILES
# transformer cell also assigns - confirm the overwrite is harmless.
tokenizer = tokenizers.Tokenizer(word_level)

# Encode once before any pre-tokenizer is configured; the result is discarded.
tokenizer.encode("this is a test text")

# Configure whitespace splitting, then show the ids for the same sentence.
tokenizer.pre_tokenizer = tokenizers.pre_tokenizers.WhitespaceSplit()
tokenizer.encode("this is a test text").ids

[0, 0, 0, 0, 0]

In [None]:
import selfies as sf

# Encode one SMILES string as SELFIES with strict=False; the returned string
# is shown as the cell output.
sf.encoder(
    "C1=CC=C2C(=C1)O[P-]3(O2)(OC4=CC=CC=C4O3)(N=[N+]=[N-])Cl",
    strict=False,
)

'[C][=C][C][=C][C][=Branch1][Ring2][=C][Ring1][=Branch1][O][P-1][Branch1][Ring2][O][Ring1][=Branch1][Branch1][=N][O][C][=C][C][=C][C][=C][Ring1][=Branch1][O][Ring1][#Branch2][Branch1][Ring2][N][=N+1][=N-1][Cl]'

In [None]:
# Take the first training batch and print its 'mol' entry next to what
# model.decode_smiles returns for the same spectra.
with torch.inference_mode():
    train_batches = iter(data_module.train_dataloader())
    batch = next(train_batches)
    print(batch["mol"])
    decoded = model.decode_smiles(batch["spec"].float())
    print(decoded)

['COc1ncc2cc(C(=O)Nc3c(Cl)ccc(C(=O)NCc4cc(Cl)ccc4)c3)c(=O)[nH]c2n1', 'CNC(=O)O[C@H]1COc2c(cc(N3CCN(C4COC4)CC3)cc2)[C@@H]1NC(=O)c1ccc(F)cc1', 'C/C1=C/CC[C@@]2(C)O[C@@H]2[C@H]2OC(=O)[C@H](CN(C)C)[C@@H]2CC1']
[['COc1ncc2cc(C(=O)Nc3c(Cl)ccc(C(=O)NCc4cc(Cl)cc4)c(=O)[nH]c2n1'], ['CNC(=O)O[C@H]1COc2c(c(N3CCN(C4COC4)CC3)cc2)[C@@H]1NC(=O)cc1ccccc(F)c1'], ['C/C1=C/CC[C@@]2(C)O[C@@H]2[C@H]2OC(=O)[C@H](CN(C)C)[C@@H]2CC1']]


## Train

In [None]:
# Run one validation pass before any training. prepare_data() and setup()
# are called by hand because validate() is invoked before fit() here.
data_module.prepare_data()
data_module.setup()
trainer.validate(model=model, datamodule=data_module)

# Fit with whichever model / data_module / trainer the most recently executed
# setup cell defined.
trainer.fit(model=model, datamodule=data_module)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33manton-bushuiev[0m. Use [1m`wandb login --relogin`[0m to force relogin


/Users/anton/miniconda3/envs/massspecgym/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Validation: |          | 0/? [00:00<?, ?it/s]


  | Name                    | Type             | Params
-------------------------------------------------------------
0 | ffn                     | MLP              | 2.6 M 
1 | loss_fn                 | CosSimLoss       | 0     
2 | val_fingerprint_cos_sim | CosineSimilarity | 0     
3 | val_hit_rate@1          | RetrievalHitRate | 0     
4 | val_hit_rate@5          | RetrievalHitRate | 0     
5 | val_hit_rate@20         | RetrievalHitRate | 0     
6 | val_mces_at_1           | MeanMetric       | 0     
-------------------------------------------------------------
2.6 M     Trainable params
0         Non-trainable params
2.6 M     Total params
10.469    Total estimated model params size (MB)


────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
 val_fingerprint_cos_sim   0.023980841040611267
     val_hit_rate@1                 0.0
     val_hit_rate@20                0.0
     val_hit_rate@5                 0.0
        val_loss            0.9760191440582275
      val_mces_at_1                17.5
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/anton/miniconda3/envs/massspecgym/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
/Users/anton/miniconda3/envs/massspecgym/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
/Users/anton/miniconda3/envs/massspecgym/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=50` reached.


## Test

In [None]:
# Evaluate on the test dataloader; the returned list of metric dicts is the
# cell output.
trainer.test(model=model, datamodule=data_module)

/Users/anton/miniconda3/envs/massspecgym/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
test_fingerprint_cos_sim   0.056853797286748886
     test_hit_rate@1                0.0
    test_hit_rate@20                0.0
     test_hit_rate@5                0.0
        test_loss           0.9431462287902832
     test_mces_at_1                14.0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_fingerprint_cos_sim': 0.056853797286748886,
  'test_loss': 0.9431462287902832,
  'test_hit_rate@1': 0.0,
  'test_hit_rate@5': 0.0,
  'test_hit_rate@20': 0.0,
  'test_mces_at_1': 14.0}]