In [2]:
from pathlib import Path
import pandas as pd
import torch
import pytorch_lightning as pl
from pytorch_lightning import Trainer

import massspecgym.utils as utils
from massspecgym.data import MassSpecDataset, RetrievalDataset, MassSpecDataModule
from massspecgym.data.transforms import SpecTokenizer, MolFingerprinter, SpecBinner
from massspecgym.models.retrieval import DeepSetsRetrieval, RandomRetrieval, FingerprintFFNRetrieval, FromDictRetrieval
from massspecgym.models.de_novo import DummyDeNovo, RandomDeNovo, SmilesTransformer
from massspecgym.models.tokenizers import SmilesBPETokenizer, SelfiesTokenizer

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
# Notebook-wide configuration.
# SEED is handed to Lightning's seed_everything so repeated runs start from
# the same random state.
SEED = 0
pl.seed_everything(SEED)

# DEBUG decides which input files the path-selection cell below points to.
DEBUG = True

Seed set to 0


In [4]:
# Choose the spectra file, the retrieval-candidates file and the split file
# that the experiments below read.
if not DEBUG:
    # Full-size files stored under ../data/MassSpecGym_with_test; no split file is passed.
    # Alternative: set mgf_pth, candidates_pth and split_pth all to None to use the
    # default benchmark paths (assumption carried over from an earlier note -- TODO confirm).
    mgf_pth = Path("../data/MassSpecGym_with_test/MassSpecGym_with_test.tsv")
    candidates_pth = Path(
        "../data/MassSpecGym_with_test/MassSpecGym_retrieval_candidates_formula_with_test.json"
    )
    split_pth = None
else:
    # Small fixtures under ../data/debug, meant for quick end-to-end runs.
    mgf_pth = Path("../data/debug/example_5_spectra.mgf")
    candidates_pth = Path("../data/debug/example_5_spectra_candidates.json")
    split_pth = Path("../data/debug/example_5_spectra_split.tsv")

## Deep Sets model on the fingerprint retrieval task

In [5]:
# --- Data ---------------------------------------------------------------
# The spectra and candidate files are the ones chosen in the path-selection
# cell (mgf_pth / candidates_pth), so DEBUG decides which data is loaded.
dataset = RetrievalDataset(
    pth=mgf_pth, spec_transform=SpecTokenizer(n_peaks=60),
    mol_transform=MolFingerprinter(), candidates_pth=candidates_pth,
)
data_module = MassSpecDataModule(dataset=dataset, split_pth=split_pth, batch_size=3)

# --- Model --------------------------------------------------------------
# Replace this with `RandomRetrieval()` to run the random baseline instead.
model = DeepSetsRetrieval(
    bootstrap_metrics=True, df_test_path='./df_test.pkl',
    out_channels=2048, fourier_features=True,
)

# --- Logging ------------------------------------------------------------
# The W&B logger may require running `wandb init` beforehand; to go without
# W&B, pass `logger=None` to the Trainer below.
project, name = "MassSpecGymRetrieval", "DeepSets"
logger = pl.loggers.WandbLogger(project=project, name=name, tags=[], log_model=False)

# --- Trainer ------------------------------------------------------------
# CPU is requested explicitly; metrics are logged at every step.
trainer = Trainer(accelerator="cpu", max_epochs=50, logger=logger, log_every_n_steps=1)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Users/anton/miniconda3/envs/massspecgym/lib/python3.11/site-packages/pytorch_lightning/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.


In [7]:
# Run the full cycle for the model configured above: baseline validation,
# training, then evaluation on the test dataloader.

# Validate before training
data_module.prepare_data()  # Explicit call needed for validate before fit
data_module.setup()  # Explicit call needed for validate before fit
trainer.validate(model, datamodule=data_module)

# Train
trainer.fit(model, datamodule=data_module)

# Test -- as the last expression of the cell, the returned list of metric
# dicts is also displayed as the cell output.
trainer.test(model, datamodule=data_module)

/Users/anton/miniconda3/envs/massspecgym/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Validation: |          | 0/? [00:00<?, ?it/s]


   | Name                    | Type             | Params
--------------------------------------------------------------
0  | ff                      | FourierFeatures  | 6.0 K 
1  | ff_proj_mz              | Linear           | 4.9 M 
2  | ff_proj_i               | Linear           | 206   
3  | phi                     | MLP              | 525 K 
4  | rho                     | MLP              | 1.3 M 
5  | loss_fn                 | CosSimLoss       | 0     
6  | val_fingerprint_cos_sim | CosineSimilarity | 0     
7  | val_hit_rate@1          | MeanMetric       | 0     
8  | val_hit_rate@5          | MeanMetric       | 0     
9  | val_hit_rate@20         | MeanMetric       | 0     
10 | val_mces@1              | MeanMetric       | 0     
--------------------------------------------------------------
6.7 M     Trainable params
6.0 K     Non-trainable params
6.8 M     Total params
27.003    Total estimated model params size (MB)


torch.Size([1, 61, 512])
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
 val_fingerprint_cos_sim    0.1643836796283722
     val_hit_rate@1                 0.0
     val_hit_rate@20                1.0
     val_hit_rate@5                 1.0
        val_loss            0.8356163501739502
       val_mces@1                  19.5
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

torch.Size([1, 61, 512])


/Users/anton/miniconda3/envs/massspecgym/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
/Users/anton/miniconda3/envs/massspecgym/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

torch.Size([3, 61, 512])


Validation: |          | 0/? [00:00<?, ?it/s]

torch.Size([1, 61, 512])
torch.Size([3, 61, 512])


Validation: |          | 0/? [00:00<?, ?it/s]

torch.Size([1, 61, 512])
torch.Size([3, 61, 512])


Validation: |          | 0/? [00:00<?, ?it/s]

torch.Size([1, 61, 512])
torch.Size([3, 61, 512])


Validation: |          | 0/? [00:00<?, ?it/s]

torch.Size([1, 61, 512])
torch.Size([3, 61, 512])


Validation: |          | 0/? [00:00<?, ?it/s]

torch.Size([1, 61, 512])
torch.Size([3, 61, 512])


Validation: |          | 0/? [00:00<?, ?it/s]

torch.Size([1, 61, 512])
torch.Size([3, 61, 512])


Validation: |          | 0/? [00:00<?, ?it/s]

torch.Size([1, 61, 512])
torch.Size([3, 61, 512])


Validation: |          | 0/? [00:00<?, ?it/s]

torch.Size([1, 61, 512])
torch.Size([3, 61, 512])


Validation: |          | 0/? [00:00<?, ?it/s]

torch.Size([1, 61, 512])
torch.Size([3, 61, 512])


Validation: |          | 0/? [00:00<?, ?it/s]

torch.Size([1, 61, 512])
torch.Size([3, 61, 512])


Validation: |          | 0/? [00:00<?, ?it/s]

torch.Size([1, 61, 512])
torch.Size([3, 61, 512])


Validation: |          | 0/? [00:00<?, ?it/s]

torch.Size([1, 61, 512])


/Users/anton/miniconda3/envs/massspecgym/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...
/Users/anton/miniconda3/envs/massspecgym/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]

torch.Size([1, 61, 512])




────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
test_fingerprint_cos_sim    0.13307173550128937
     test_hit_rate@1                0.0
   test_hit_rate@1_std              nan
    test_hit_rate@20                0.0
  test_hit_rate@20_std              nan
     test_hit_rate@5                0.0
   test_hit_rate@5_std              nan
        test_loss           0.8669282793998718
       test_mces@1                 14.0
     test_mces@1_std                nan
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_fingerprint_cos_sim': 0.13307173550128937,
  'test_loss': 0.8669282793998718,
  'test_hit_rate@1': 0.0,
  'test_hit_rate@1_std': nan,
  'test_hit_rate@5': 0.0,
  'test_hit_rate@5_std': nan,
  'test_hit_rate@20': 0.0,
  'test_hit_rate@20_std': nan,
  'test_mces@1': 14.0,
  'test_mces@1_std': nan}]

## Random baseline on the fingerprint retrieval task

In [9]:
# Fingerprint length handed to MolFingerprinter.
fp_size = 4096

# Data: spectra go through SpecBinner, molecules through MolFingerprinter;
# file locations come from the path-selection cell.
dataset = RetrievalDataset(
    pth=mgf_pth, spec_transform=SpecBinner(),
    mol_transform=MolFingerprinter(fp_size=fp_size), candidates_pth=candidates_pth,
)
data_module = MassSpecDataModule(dataset=dataset, split_pth=split_pth, batch_size=64)

# Model: the random retrieval baseline, built with its default arguments.
model = RandomRetrieval()

# Logging: the W&B logger may require running `wandb init` beforehand; to go
# without W&B, pass `logger=None` to the Trainer below.
# NOTE(review): the run is named "RandomFFN" although the model is
# RandomRetrieval -- confirm the name is intended.
project, name = "MassSpecGymRetrieval", "RandomFFN"
logger = pl.loggers.WandbLogger(project=project, name=name, tags=[], log_model=False)

# Trainer: CPU only, metrics logged every 50 steps.
trainer = Trainer(accelerator="cpu", max_epochs=50, logger=logger, log_every_n_steps=50)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Users/anton/miniconda3/envs/massspecgym/lib/python3.11/site-packages/pytorch_lightning/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.


## Fingerprint FFN model on the fingerprint retrieval task

In [9]:
# Fingerprint length; used both by MolFingerprinter and as the model's output width.
fp_size = 4096

# Data: spectra go through SpecBinner, molecules through MolFingerprinter;
# file locations come from the path-selection cell.
dataset = RetrievalDataset(
    pth=mgf_pth, spec_transform=SpecBinner(),
    mol_transform=MolFingerprinter(fp_size=fp_size), candidates_pth=candidates_pth,
)
data_module = MassSpecDataModule(dataset=dataset, split_pth=split_pth, batch_size=64)

# Model: feed-forward network predicting a fingerprint of length fp_size.
# NOTE(review): in_channels=1005 presumably equals the length of the vector
# SpecBinner() produces with its default settings -- confirm.
model = FingerprintFFNRetrieval(in_channels=1005, out_channels=fp_size)

# Logging: the W&B logger may require running `wandb init` beforehand; to go
# without W&B, pass `logger=None` to the Trainer below.
project, name = "MassSpecGymRetrieval", "FingerprintFFN_debug"
logger = pl.loggers.WandbLogger(project=project, name=name, tags=[], log_model=False)

# Trainer: CPU only, metrics logged every 50 steps.
trainer = Trainer(accelerator="cpu", max_epochs=50, logger=logger, log_every_n_steps=50)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Users/anton/miniconda3/envs/massspecgym/lib/python3.11/site-packages/pytorch_lightning/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.


## MIST on the fingerprint retrieval task

In [5]:
# Fingerprint length handed to MolFingerprinter.
fp_size = 4096

# Data: spectra go through SpecBinner, molecules through MolFingerprinter;
# file locations come from the path-selection cell.
dataset = RetrievalDataset(
    pth=mgf_pth, spec_transform=SpecBinner(),
    mol_transform=MolFingerprinter(fp_size=fp_size), candidates_pth=candidates_pth,
)
data_module = MassSpecDataModule(dataset=dataset, split_pth=split_pth, batch_size=64)

# Model: retrieval driven by fingerprint predictions read from disk.
# The loaded object is indexed by 'name' and 'fp_predict'; the two are paired
# up into a lookup dict (for a repeated name the last entry wins).
# SECURITY: pd.read_pickle can execute arbitrary code while unpickling --
# only load this file if it comes from a trusted source.
df = pd.read_pickle('fp_preds_MassSpecGym_df.pkl')
dct = {key: fp for key, fp in zip(df['name'], df['fp_predict'])}
model = FromDictRetrieval(dct=dct)

# Logging: the W&B logger may require running `wandb init` beforehand; to go
# without W&B, pass `logger=None` to the Trainer below.
project, name = "MassSpecGymRetrieval", "MIST"
logger = pl.loggers.WandbLogger(project=project, name=name, tags=[], log_model=False)

# Trainer: CPU only, metrics logged every 50 steps.
trainer = Trainer(accelerator="cpu", max_epochs=50, logger=logger, log_every_n_steps=50)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Users/anton/miniconda3/envs/massspecgym/lib/python3.11/site-packages/pytorch_lightning/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.


## Dummy model on the de novo generation task

In [13]:
# Data: spectra go through SpecTokenizer(n_peaks=60); no molecule transform is
# applied. File locations come from the path-selection cell, so DEBUG decides
# which data is loaded.
dataset = MassSpecDataset(pth=mgf_pth, spec_transform=SpecTokenizer(n_peaks=60), mol_transform=None)
data_module = MassSpecDataModule(dataset=dataset, split_pth=split_pth, batch_size=2)

# Model: the dummy de novo baseline.
model = DummyDeNovo(df_test_path='./df_test.pkl')

# Logging: the W&B logger may require running `wandb init` beforehand; to go
# without W&B, pass `logger=None` to the Trainer below.
# NOTE(review): the run name is spelled "RandomBasline" and the model is
# DummyDeNovo -- kept unchanged here; confirm whether it should be renamed.
project, name = "MassSpecGymDeNovo", "RandomBasline"
logger = pl.loggers.WandbLogger(project=project, name=name, tags=[], log_model=False)

# Trainer: CPU only, metrics logged at every step.
trainer = Trainer(accelerator="cpu", max_epochs=50, logger=logger, log_every_n_steps=1)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Users/anton/miniconda3/envs/massspecgym/lib/python3.11/site-packages/pytorch_lightning/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.


## De novo SMILES transformer

In [17]:
# Data: spectra go through SpecTokenizer(n_peaks=60); no molecule transform is
# applied. File locations come from the path-selection cell.
dataset = MassSpecDataset(pth=mgf_pth, spec_transform=SpecTokenizer(n_peaks=60), mol_transform=None)
data_module = MassSpecDataModule(dataset=dataset, split_pth=split_pth, batch_size=32)

# Model: SMILES transformer with a BPE SMILES tokenizer, dropout switched off.
model = SmilesTransformer(
    input_dim=2, d_model=512, nhead=8,
    num_encoder_layers=4, num_decoder_layers=4, dropout=0.0,
    smiles_tokenizer=SmilesBPETokenizer(max_len=200),
    k_predictions=1, pre_norm=False, max_smiles_len=100,
    validate_only_loss=True,
)

# Logging.
# NOTE(review): if an earlier W&B run is still open, Lightning warns that this
# logger reuses it; call wandb.finish() beforehand when a separate run is wanted.
project, name = "MassSpecGymDeNovo", "SmilesTransformer_debug_overfitting"
logger = pl.loggers.WandbLogger(project=project, name=name, tags=[], log_model=False)

# Trainer: CPU only, 100 epochs, logging at every step, validation every 50 epochs.
trainer = Trainer(
    accelerator="cpu", max_epochs=100, logger=logger,
    log_every_n_steps=1, check_val_every_n_epoch=50,
)

# Baseline validation before any training; prepare_data() and setup() have to
# be called by hand because validate runs before fit here.
data_module.prepare_data()
data_module.setup()
trainer.validate(model, datamodule=data_module)

# Training.
trainer.fit(model, datamodule=data_module)

Training tokenizer on 4116646 SMILES strings.





GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Users/anton/miniconda3/envs/massspecgym/lib/python3.11/site-packages/pytorch_lightning/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
/Users/anton/miniconda3/envs/massspecgym/lib/python3.11/site-packages/pytorch_lightning/loggers/wandb.py:396: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
/Users/anton/miniconda3/envs/massspecgym/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Validation: |          | 0/? [00:00<?, ?it/s]

[13:42:13] Invalid InChI prefix in generating InChI Key
[13:42:13] Invalid InChI prefix in generating InChI Key

   | Name                        | Type             | Params
------------------------------------------------------------------
0  | src_encoder                 | Linear           | 1.5 K 
1  | tgt_embedding               | Embedding        | 2.7 M 
2  | transformer                 | Transformer      | 29.4 M
3  | tgt_decoder                 | Linear           | 2.7 M 
4  | criterion                   | CrossEntropyLoss | 0     
5  | val_num_valid_mols          | MeanMetric       | 0     
6  | val_top_1_mces_dist         | MeanMetric       | 0     
7  | val_top_1_max_tanimoto_sim  | MeanMetric       | 0     
8  | val_top_1_accuracy          | MeanMetric       | 0     
9  | val_top_10_mces_dist        | MeanMetric       | 0     
10 | val_top_10_max_tanimoto_sim | MeanMetric       | 0     
11 | val_top_10_accuracy         | MeanMetric       | 0     
---------------------------

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      Validate metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
         val_loss               9.054847717285156
    val_num_valid_mols                 1.0
    val_top_10_accuracy                0.0
val_top_10_max_tanimoto_sim    0.03614457696676254
   val_top_10_mces_dist               100.0
    val_top_1_accuracy                 0.0
val_top_1_max_tanimoto_sim     0.03614457696676254
    val_top_1_mces_dist               100.0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/anton/miniconda3/envs/massspecgym/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
[13:42:16] Invalid InChI prefix in generating InChI Key
[13:42:16] Invalid InChI prefix in generating InChI Key
/Users/anton/miniconda3/envs/massspecgym/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

[13:42:48] Invalid InChI prefix in generating InChI Key
[13:42:48] Invalid InChI prefix in generating InChI Key
[13:42:48] Invalid InChI prefix in generating InChI Key
[13:42:48] Invalid InChI prefix in generating InChI Key
[13:42:48] Invalid InChI prefix in generating InChI Key
[13:42:48] Invalid InChI prefix in generating InChI Key
[13:42:50] Invalid InChI prefix in generating InChI Key
[13:42:50] Invalid InChI prefix in generating InChI Key
[13:42:50] Invalid InChI prefix in generating InChI Key
[13:42:50] Invalid InChI prefix in generating InChI Key
[13:42:50] Invalid InChI prefix in generating InChI Key
[13:42:50] Invalid InChI prefix in generating InChI Key
[13:42:52] Invalid InChI prefix in generating InChI Key
[13:42:52] Invalid InChI prefix in generating InChI Key
[13:42:52] Invalid InChI prefix in generating InChI Key
[13:42:52] Invalid InChI prefix in generating InChI Key
[13:42:52] Invalid InChI prefix in generating InChI Key
[13:42:52] Invalid InChI prefix in generating In

Validation: |          | 0/? [00:00<?, ?it/s]

[13:44:17] SMILES Parse Error: syntax error while parsing: C/C1=C/CC[C)O[C@@H]2[C@@H]2[C@@H]2[C@@H]2[C@@H]2[C@@H]2[C@@H]2[C@@H]2[C@@H]2[C@@H]2[C@@H]2c2c2c2c2c2c2c2c2c2c2c
[13:44:17] SMILES Parse Error: Failed parsing SMILES 'C/C1=C/CC[C)O[C@@H]2[C@@H]2[C@@H]2[C@@H]2[C@@H]2[C@@H]2[C@@H]2[C@@H]2[C@@H]2[C@@H]2[C@@H]2c2c2c2c2c2c2c2c2c2c2c' for input: 'C/C1=C/CC[C)O[C@@H]2[C@@H]2[C@@H]2[C@@H]2[C@@H]2[C@@H]2[C@@H]2[C@@H]2[C@@H]2[C@@H]2[C@@H]2c2c2c2c2c2c2c2c2c2c2c'
[13:44:19] SMILES Parse Error: syntax error while parsing: C/C1=C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/
[13:44:19] SMILES Parse Error: Failed parsing SMILES 'C/C1=C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/' for input: 'C/C1=C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/C/'
[13:44:19] SMILES Parse Error: syntax error while parsing: C/C1=C/CC[C)O[C@@H]2[C@@H]2[C@@H]2[C@@H]2[C@@H]2[C@

Validation: |          | 0/? [00:00<?, ?it/s]

[13:45:50] SMILES Parse Error: unclosed ring for input: 'CNC(=O)O[C@H]1COc2c(c(N3CCN(C4COC4)CC3)c2)[C@@H]1NC(=O)ccccc1'
`Trainer.fit` stopped: `max_epochs=100` reached.


In [18]:
# Qualitative check after training: take the first batch from the training
# dataloader and print its 'mol' entries, followed by what the model decodes
# for the same batch. Runs under inference mode, so no gradients are tracked.
with torch.inference_mode():
    batch = next(iter(data_module.train_dataloader()))
    print(batch['mol'])
    print(model.decode_smiles(batch))

['CNC(=O)O[C@H]1COc2c(cc(N3CCN(C4COC4)CC3)cc2)[C@@H]1NC(=O)c1ccc(F)cc1', 'COc1ncc2cc(C(=O)Nc3c(Cl)ccc(C(=O)NCc4cc(Cl)ccc4)c3)c(=O)[nH]c2n1', 'C/C1=C/CC[C@@]2(C)O[C@@H]2[C@H]2OC(=O)[C@H](CN(C)C)[C@@H]2CC1']
[['CNC(=O)O[C@H]1COc2c(c(N3CCN(C4COC4)CC3)c2)[C@@H]1NC(=O)ccccc1'], ['C/C1=C/CC[C@@]2(C)O[C@@H]2[C@H]2OC(=O)[C@H](CN(C)C)[C@@H]2CC1'], ['C/C1=C/CC[C@@]2(C)O[C@@H]2[C@H]2OC(=O)[C@H](CN(C)C)[C@@H]2CC1']]


## De novo random chemical generation

In [4]:
# Re-seed so this experiment starts from a known random state.
pl.seed_everything(0)

# Data: spectra go through SpecTokenizer(n_peaks=60); no molecule transform is
# applied. File locations come from the path-selection cell, so DEBUG decides
# which data is loaded.
dataset = MassSpecDataset(pth=mgf_pth, spec_transform=SpecTokenizer(n_peaks=60), mol_transform=None)
data_module = MassSpecDataModule(dataset=dataset, split_pth=split_pth, batch_size=2)

# Model: random de novo generation without a known formula. `name` doubles as
# the W&B run name and as the stem of the pickle passed as df_test_path.
# With formula_known=False the model has to be fit before it is evaluated; it
# raises a RuntimeError otherwise, asking to record training molecular weights
# with their formulas first.
name = "random_baseline_no_formula"
model = RandomDeNovo(
    formula_known=False, max_top_k=10,
    estimate_chem_element_stats=True, enforce_connectivity=False,
    df_test_path=Path(f'../data/test_results/de_novo/{name}.pkl'),
)

# Logging: the W&B logger may require running `wandb init` beforehand; to go
# without W&B, pass `logger=None` to the Trainer below.
project = "MassSpecGymDeNovo"
logger = pl.loggers.WandbLogger(project=project, name=name, tags=[], log_model=False)

# Trainer: CPU only, a single epoch, with validation batches and the sanity
# validation steps both set to zero.
trainer = Trainer(
    accelerator="cpu", max_epochs=1, logger=logger, log_every_n_steps=1000,
    limit_val_batches=0, num_sanity_val_steps=0,
)

Seed set to 0
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Users/anton/miniconda3/envs/massspecgym/lib/python3.11/site-packages/pytorch_lightning/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.


In [5]:
# Fit first, then evaluate on the test dataloader. For the random de novo
# model this fit runs without an optimizer (Lightning reports that
# configure_optimizers returned None). The list of test metric dicts returned
# by the last call is displayed as the cell output.
trainer.fit(model, datamodule=data_module)
trainer.test(model, datamodule=data_module)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33manton-bushuiev[0m. Use [1m`wandb login --relogin`[0m to force relogin


/Users/anton/miniconda3/envs/massspecgym/lib/python3.11/site-packages/pytorch_lightning/core/optimizer.py:181: `LightningModule.configure_optimizers` returned `None`, this fit will run with no optimizer

  | Name | Type | Params
------------------------------
------------------------------
0         Trainable params
0         Non-trainable params
0         Total params
0.000     Total estimated model params size (MB)
/Users/anton/miniconda3/envs/massspecgym/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.
/Users/anton/miniconda3/envs/massspecgym/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]

[15:39:32] Invalid InChI prefix in generating InChI Key
[15:39:32] Invalid InChI prefix in generating InChI Key
[15:39:32] Invalid InChI prefix in generating InChI Key
[15:39:32] Invalid InChI prefix in generating InChI Key
[15:39:32] Invalid InChI prefix in generating InChI Key
[15:39:33] Invalid InChI prefix in generating InChI Key
[15:39:33] Invalid InChI prefix in generating InChI Key
[15:39:33] Invalid InChI prefix in generating InChI Key
[15:39:33] Invalid InChI prefix in generating InChI Key
[15:39:33] Invalid InChI prefix in generating InChI Key
[15:39:33] Invalid InChI prefix in generating InChI Key
[15:39:33] Invalid InChI prefix in generating InChI Key
[15:39:33] Invalid InChI prefix in generating InChI Key
[15:39:33] Invalid InChI prefix in generating InChI Key
[15:39:33] Invalid InChI prefix in generating InChI Key
[15:39:33] Invalid InChI prefix in generating InChI Key
[15:39:33] Invalid InChI prefix in generating InChI Key
[15:39:33] Invalid InChI prefix in generating In

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
          Test metric                     DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
           test_loss                          0.0
      test_num_valid_mols                     10.0
      test_top_10_accuracy                    0.0
    test_top_10_accuracy_std                  0.0
  test_top_10_max_tanimoto_sim        0.10096427798271179
test_top_10_max_tanimoto_sim_std     0.0001744613837217912
     test_top_10_mces_dist             25.71958351135254
   test_top_10_mces_dist_std          0.04963930323719978
      test_top_1_accuracy                     0.0
    test_top_1_accuracy_std                   0.0
  test_top_1_max_tanimoto_sim         0.07304630428552628
test_top_1_max_tanimoto_sim_std      0.00013741747534368187
      test_top_1_mces_dist             28.5852985382

[{'test_loss': 0.0,
  'test_num_valid_mols': 10.0,
  'test_top_1_mces_dist': 28.585298538208008,
  'test_top_1_mces_dist_std': 0.053359825164079666,
  'test_top_1_max_tanimoto_sim': 0.07304630428552628,
  'test_top_1_max_tanimoto_sim_std': 0.00013741747534368187,
  'test_top_1_accuracy': 0.0,
  'test_top_1_accuracy_std': 0.0,
  'test_top_10_mces_dist': 25.71958351135254,
  'test_top_10_mces_dist_std': 0.04963930323719978,
  'test_top_10_max_tanimoto_sim': 0.10096427798271179,
  'test_top_10_max_tanimoto_sim_std': 0.0001744613837217912,
  'test_top_10_accuracy': 0.0,
  'test_top_10_accuracy_std': 0.0}]

## Train

In [11]:
# Baseline validation followed by training, for whichever `model`,
# `data_module` and `trainer` the most recently executed setup cell defined.
# NOTE(review): the recorded output of this cell is a RuntimeError raised
# during validation: random de novo generation without a known formula has to
# be trained before it can be evaluated. For that model, fit before validating.

# Validate before training
data_module.prepare_data()  # Explicit call needed for validate before fit
data_module.setup()  # Explicit call needed for validate before fit
trainer.validate(model, datamodule=data_module)

# Train
trainer.fit(model, datamodule=data_module)

/Users/anton/miniconda3/envs/massspecgym/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Validation: |          | 0/? [00:00<?, ?it/s]

RuntimeError: For random denovo generation without known formula, the model has to be trained first,to record training molecular weights with corresponding formulas.

## Test

In [6]:
# Evaluate the current `model` with the test dataloader of `data_module`; the
# returned list of metric dicts is displayed as the cell output.
trainer.test(model, datamodule=data_module)

/Users/anton/miniconda3/envs/massspecgym/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
test_fingerprint_cos_sim   0.056853797286748886
     test_hit_rate@1                0.0
    test_hit_rate@20                0.0
     test_hit_rate@5                0.0
        test_loss           0.9431462287902832
     test_mces_at_1                14.0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_fingerprint_cos_sim': 0.056853797286748886,
  'test_loss': 0.9431462287902832,
  'test_hit_rate@1': 0.0,
  'test_hit_rate@5': 0.0,
  'test_hit_rate@20': 0.0,
  'test_mces_at_1': 14.0}]