This notebook demonstrates how to fine‑tune the DreaMS model for a binary classification task (detecting chlorine in molecules) using the MassSpecGym dataset. We’ll:

1. **Annotate** the MassSpecGym MGF with chlorine labels.
2. **Prepare** a `ChlorineDetectionDataset` and a `BenchmarkDataModule`.
3. **Train** a baseline MLP classifier.
4. **Fine‑tune** the DreaMS encoder with a classification head (`LitDreamsClassifier`).
5. **Evaluate** on the test split and **save** the checkpoint.

All paths are defined relative to `PROJECT_ROOT` for reproducibility.

In [13]:
import sys
from pathlib import Path

# assumes the notebook is run from notebooks/, so the parent of the working directory is the repo root
sys.path.append(str(Path().resolve().parent))
from paths import PROJECT_ROOT

from benchmark.utils.data import annotate_mgf_with_label

import torch
import pytorch_lightning as pl

from massspecgym.data.transforms import SpecTokenizer
from benchmark.data.datasets import ChlorineDetectionDataset
from benchmark.data.data_module import BenchmarkDataModule

from benchmark.models.lit_dreams_module import LitDreamsClassifier




In [6]:
# Paths
DATA_DIR   = PROJECT_ROOT / "data" / "massspecgym"
ORIG_MGF   = DATA_DIR / "MassSpecGym.mgf"
LABELED_MGF = DATA_DIR / "MassSpecGym_chlorine.mgf"

MODEL_PATH = PROJECT_ROOT / "data" / "model_checkpoints" / "ssl_model.ckpt"

#### Here we define the function used to annotate our data. This step matters because it sets the ground truth. In MassSpecGym, each mass spectrum is annotated with its correct molecule, so we can derive further labels for each spectrum from that molecule.

#### Since we are solving a chlorine detection problem, we take the molecule associated with each mass spectrum and check whether it contains chlorine. If it does, the spectrum gets the label 1.0; otherwise it gets 0.0.

In [4]:
# Define labeling function: 1.0 if 'Cl' in FORMULA
label_fn = lambda md: float("Cl" in md.get("FORMULA", ""))
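
To make the labeling rule concrete, here is a minimal sketch that applies the same `label_fn` to a few metadata dictionaries. The dictionaries are hypothetical stand-ins for MGF metadata; only the `FORMULA` key is taken from the cell above.

```python
# Same labeling rule as in the cell above: 1.0 if the formula contains chlorine.
label_fn = lambda md: float("Cl" in md.get("FORMULA", ""))

# Hypothetical metadata dicts; only the FORMULA key matters for the rule.
examples = [
    {"FORMULA": "C6H5Cl"},   # chlorobenzene
    {"FORMULA": "C6H12O6"},  # glucose
    {"FORMULA": "C2HCl3"},   # trichloroethylene
    {},                      # missing formula falls back to "" and is labeled 0.0
]

labels = [label_fn(md) for md in examples]
print(labels)  # [1.0, 0.0, 1.0, 0.0]
```

Because element symbols start with an uppercase letter, the substring `Cl` can only come from chlorine, so a plain substring test is enough here. A spectrum with no formula is silently labeled 0.0, which may or may not be what you want.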

In [5]:
# Write out labeled MGF
annotate_mgf_with_label(ORIG_MGF, LABELED_MGF, label_fn)
print(f"Labeled MGF written to: {LABELED_MGF}")

Labeled MGF written to: /Users/macbook/CODE/DreaMS_MIMB/data/massspecgym/MassSpecGym_chlorine.mgf
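
Before training, it can be useful to check how balanced the two classes are. The sketch below is an illustration only: it parses a tiny in-memory MGF-like string with a hypothetical helper, `iter_mgf_metadata`, and counts the labels produced by `label_fn`. It is not the project's `annotate_mgf_with_label`, and the exact key names and casing in the real file are an assumption.

```python
from collections import Counter

# Tiny in-memory stand-in for an MGF file (hypothetical entries).
mgf_text = """BEGIN IONS
FORMULA=C6H5Cl
PRECURSOR_MZ=113.0153
50.0 0.2
77.0 1.0
END IONS
BEGIN IONS
FORMULA=C6H12O6
PRECURSOR_MZ=181.0707
60.0 0.5
END IONS
"""

def iter_mgf_metadata(text):
    """Yield one metadata dict per BEGIN IONS / END IONS block."""
    md = None
    for line in text.splitlines():
        line = line.strip()
        if line == "BEGIN IONS":
            md = {}
        elif line == "END IONS":
            if md is not None:
                yield md
            md = None
        elif md is not None and "=" in line:
            # Metadata lines look like KEY=value; peak lines have no "=".
            key, value = line.split("=", 1)
            md[key.strip().upper()] = value.strip()

label_fn = lambda md: float("Cl" in md.get("FORMULA", ""))
counts = Counter(label_fn(md) for md in iter_mgf_metadata(mgf_text))
print(dict(counts))
```

To run this on real data, read the text of `ORIG_MGF` instead of `mgf_text`. A heavily skewed count suggests that accuracy alone will be a weak metric.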


#### Now that the data is prepared, we can pass it to DreaMS and train the model to distinguish whether a mass spectrum comes from a chlorine-containing molecule.

In [7]:
# 1) Instantiate your Lightning module, pointing to the SSL checkpoint
lit = LitDreamsClassifier(
    ckpt_path=MODEL_PATH,
    n_highest_peaks=128,    # must match our tokenizer
    lr=1e-4,
    dropout=0.1,
    train_encoder=True     # fine-tune the entire encoder
)

In [8]:
# 2) Tokenize spectra into fixed-length set representations
spec_transform = SpecTokenizer(n_peaks=128)

#### Here we load the benchmark dataset that will be fed to the model for training.

In [9]:
ds_dreams = ChlorineDetectionDataset(
    pth=LABELED_MGF,
    spec_transform=spec_transform,
    dtype=torch.float32
)

In [10]:
# 3) Prepare the DataModule
dm_dreams = BenchmarkDataModule(
    dataset    = ds_dreams,
    batch_size = 16,
    num_workers= 0
)
dm_dreams.setup()

In [11]:
# Inspect one batch just to sanity-check shapes:
batch = next(iter(dm_dreams.train_dataloader()))
print("spec shape:", batch["spec"].shape)    # -> [B, 129, 2] for n_peaks=128
print("label shape:", batch["label"].shape)

spec shape: torch.Size([16, 129, 2])
label shape: torch.Size([16])
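
The spectrum tensor has 129 rows per spectrum rather than 128. A plausible reading is that the tokenizer prepends one precursor row to the 128 peak rows. The sketch below illustrates that layout in plain Python. It is an assumption-laden illustration, not the implementation of `SpecTokenizer`: the function `tokenize`, the `prec_intensity` value, and the zero padding are all hypothetical.

```python
def tokenize(peaks, precursor_mz, n_peaks=128, prec_intensity=1.1):
    """Return n_peaks + 1 rows of [mz, intensity]:
    one precursor row, the most intense peaks sorted by m/z, then zero padding."""
    # Keep the n_peaks most intense peaks, then restore m/z order.
    top = sorted(peaks, key=lambda p: p[1], reverse=True)[:n_peaks]
    top.sort(key=lambda p: p[0])
    rows = [[precursor_mz, prec_intensity]] + [list(p) for p in top]
    # Pad with zero rows so every spectrum has the same length.
    rows += [[0.0, 0.0] for _ in range(n_peaks + 1 - len(rows))]
    return rows

# Hypothetical spectrum with three (m/z, intensity) peaks.
tokens = tokenize([(50.0, 0.2), (77.0, 1.0), (112.0, 0.6)], precursor_mz=113.0)
print(len(tokens), len(tokens[0]))  # 129 2
```

Under this layout a batch of 16 spectra has shape `[16, 129, 2]`, which matches the output above.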


### We now move on to training: the model will see annotated examples and learn to recognize whether a mass spectrum comes from a chlorine-containing molecule

#### Training can take some time. However, once DreaMS has been fine-tuned, you do not need to repeat these steps; you can simply load the saved model.

# TODO add condition if GPU available
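
One way to address the TODO above is to choose the accelerator from what the machine reports. The sketch below is a suggestion, not part of the original notebook: `pick_accelerator` is a hypothetical helper, and the fallback to `"cpu"` when `torch` is missing is there only so that the snippet runs anywhere.

```python
def pick_accelerator(cuda_available: bool, mps_available: bool) -> str:
    """Map hardware availability to a Lightning accelerator string."""
    if cuda_available:
        return "gpu"
    if mps_available:
        return "mps"
    return "cpu"

try:
    import torch

    # torch.backends.mps is absent in older PyTorch versions, hence the getattr guard.
    mps_backend = getattr(torch.backends, "mps", None)
    accelerator = pick_accelerator(
        torch.cuda.is_available(),
        mps_backend is not None and mps_backend.is_available(),
    )
except ImportError:
    accelerator = "cpu"

print(accelerator)
```

The result can then be passed as `pl.Trainer(accelerator=accelerator, devices=1)`. Lightning also accepts `accelerator="auto"`, which performs a similar selection internally.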

In [None]:
# trainer = pl.Trainer(
#     max_epochs=5,
#     accelerator="cpu",  # or "gpu"
#     devices=1,
#     log_every_n_steps=10,
# )

In [14]:
trainer = pl.Trainer(
    max_epochs=1,
    limit_train_batches=5,     # run only 5 training batches
    limit_val_batches=3,       # run only 3 validation batches
    limit_test_batches=3,      # run only 3 test batches
    accelerator="cpu",         # or "gpu"
    devices=1,
)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Users/macbook/UTILS/anaconda3/envs/dreams_mimb/lib/python3.11/site-packages/pytorch_lightning/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
/Users/macbook/UTILS/anaconda3/envs/dreams_mimb/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default


In [15]:
trainer.fit(lit, datamodule=dm_dreams)


  | Name      | Type             | Params
-----------------------------------------------
0 | model     | DreamsClassifier | 95.6 M
1 | train_acc | BinaryAccuracy   | 0     
2 | val_acc   | BinaryAccuracy   | 0     
3 | val_auc   | BinaryAUROC      | 0     
-----------------------------------------------
95.6 M    Trainable params
0         Non-trainable params
95.6 M    Total params
382.202   Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/macbook/UTILS/anaconda3/envs/dreams_mimb/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.
/Users/macbook/UTILS/anaconda3/envs/dreams_mimb/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.
/Users/macbook/UTILS/anaconda3/envs/dreams_mimb/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (5) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.


In [16]:
trainer.test(lit, datamodule=dm_dreams)

/Users/macbook/UTILS/anaconda3/envs/dreams_mimb/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc                    1.0
        test_loss          7.947286206899662e-08
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 7.947286206899662e-08, 'test_acc': 1.0}]
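
These metrics come from only 3 test batches of 16 spectra (`limit_test_batches=3`), so they should be read with care. A useful reference point is the accuracy of always predicting the most frequent class. The sketch below computes that baseline. The label list is hypothetical and does not reflect the actual class balance of the test split.

```python
def majority_baseline_accuracy(labels):
    """Accuracy obtained by always predicting the most frequent class."""
    if not labels:
        raise ValueError("labels must be non-empty")
    positives = sum(1 for y in labels if y == 1.0)
    return max(positives, len(labels) - positives) / len(labels)

# Hypothetical labels for 3 batches of 16 spectra (48 total), 4 of them chlorinated.
labels = [1.0] * 4 + [0.0] * 44
baseline = majority_baseline_accuracy(labels)
print(f"majority baseline accuracy: {baseline:.3f}")  # 0.917
```

If the baseline on the real test labels is close to the model's accuracy, a threshold-free metric such as AUROC on the full test split is a better indicator of what the model has learned.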

#### The model has now been trained successfully (in this demo, on only a handful of batches). To avoid repeating the computationally expensive training, we save the learned model so that it can be reused.

In [17]:
# 4) Save your fine-tuned checkpoint
trainer.save_checkpoint("dreams_chlorine_finetuned.ckpt")
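
The cell above writes the checkpoint to the current working directory. If you would rather keep it next to the other model files, a minimal sketch is shown below. It uses a temporary directory as a stand-in for `PROJECT_ROOT` so that it runs anywhere, and the name `finetuned_ckpt` is hypothetical.

```python
import tempfile
from pathlib import Path

# Stand-in for PROJECT_ROOT so the sketch runs anywhere (hypothetical location).
project_root = Path(tempfile.mkdtemp())

# Mirror the layout used for MODEL_PATH: data/model_checkpoints/<name>.ckpt
ckpt_dir = project_root / "data" / "model_checkpoints"
ckpt_dir.mkdir(parents=True, exist_ok=True)
finetuned_ckpt = ckpt_dir / "dreams_chlorine_finetuned.ckpt"

print(finetuned_ckpt.relative_to(project_root))
```

In the notebook, the same path would be passed to both `trainer.save_checkpoint(...)` and `LitDreamsClassifier.load_from_checkpoint(...)`, so that saving and reloading do not depend on where the kernel was started.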

### 4 – Reload & verify

In [None]:
lit2 = LitDreamsClassifier.load_from_checkpoint("dreams_chlorine_finetuned.ckpt")
lit2.eval()
test_trainer = pl.Trainer(accelerator="cpu", devices=1)
test_results = test_trainer.test(lit2, datamodule=dm_dreams)
print(test_results)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Users/macbook/UTILS/anaconda3/envs/dreams_mimb/lib/python3.11/site-packages/pytorch_lightning/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
/Users/macbook/UTILS/anaconda3/envs/dreams_mimb/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]