# Using pre-trained DreaMS in a custom model

This simple tutorial demonstrates how to use the pre-trained DreaMS weights in a custom PyTorch model.

In [2]:
import torch
from torch import nn
import torch.nn.functional as F
from dreams.api import PreTrainedModel
from dreams.models.dreams.dreams import DreaMS as DreaMSModel

# Example model
class CustomModel(nn.Module):
    """Example model: a pre-trained DreaMS encoder followed by a linear head.

    The encoder is loaded from a checkpoint with ``PreTrainedModel.from_ckpt``
    and switched to training mode with ``.train()``. The head maps the
    precursor peak token embedding to a single sigmoid-activated output
    (e.g., for binary classification).

    Args:
        ckpt_path: Path to the ``ssl_model.ckpt`` model downloaded from
            https://zenodo.org/records/10997887. The default value is only a
            placeholder and must be replaced with a real path.
        n_highest_peaks: Passed through to ``PreTrainedModel.from_ckpt``.
        emb_dim: Input size of the linear head. It has to match the size of
            the embeddings produced by the encoder; 1024 is the value used
            with the checkpoint referenced above.
    """

    def __init__(
        self,
        ckpt_path: str = "<path/to/ssl_model.ckpt>",
        n_highest_peaks: int = 60,
        emb_dim: int = 1024,
    ):
        super().__init__()
        self.spec_encoder = PreTrainedModel.from_ckpt(
            ckpt_path=ckpt_path, ckpt_cls=DreaMSModel, n_highest_peaks=n_highest_peaks
        ).model.train()

        # Example head for a downstream task (e.g., for binary classification)
        self.lin_out = nn.Linear(emb_dim, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode the input spectra and apply the example head.

        Args:
            x: Batch of spectra; the tutorial input has shape
                (batch size, num. peaks, 2), where 2 = m/z and intensity.

        Returns:
            Sigmoid-activated head output with one value per batch element.
        """
        token_embs = self.spec_encoder(x)
        precursor_emb = token_embs[:, 0, :]  # [:, 0, :] to get the precursor peak token embedding
        # torch.sigmoid is equivalent to F.sigmoid and is not deprecated in any PyTorch release
        return torch.sigmoid(self.lin_out(precursor_emb))

# Build the example model first, then push one random batch through it.
model = CustomModel()

# Example input dimensions: 32 = batch size, 100 = num. peaks, 2 = m/z and intensity
batch_size, num_peaks, peak_dim = 32, 100, 2
example_in = torch.rand(batch_size, num_peaks, peak_dim)

example_out = model(example_in)
example_out.shape  # displayed as the cell output

torch.Size([32, 1])