In [1]:
import sys
from pathlib import Path

In [2]:
# Put the parent of the current working directory on sys.path so that the
# `paths` module imported below can be resolved.
# NOTE(review): this depends on the kernel's working directory; presumably the
# notebook is started from a direct child of the directory holding paths.py —
# confirm when launching Jupyter from elsewhere.
_parent_dir = str(Path().resolve().parent)
# Guarded append: re-running this cell must not pile up duplicate entries.
if _parent_dir not in sys.path:
    sys.path.append(_parent_dir)
from paths import PROJECT_ROOT

In [11]:
import torch
import pandas as pd
import numpy as np
from dreams.utils.data import MSData
from benchmark.models.lit_binary_classifier import LitBinaryClassifier

In [8]:
# 1) Input locations, all derived from PROJECT_ROOT
_notebooks_dir = PROJECT_ROOT / "notebooks"
_rawfiles_dir = PROJECT_ROOT / "data" / "rawfiles"

# Checkpoint that the classifier is restored from.
CKPT = _notebooks_dir / "dreams_chlorine_finetuned.ckpt"
# HDF5 file whose spectra are scored.
HDF5_PTH = _rawfiles_dir / "202312_20_P09-Leaf-r1_1uL_high_quality_dedup.hdf5"

In [20]:
# 2) Restore the classifier from its checkpoint and switch it to eval mode.
# The checkpoint is mapped onto the CPU first; the module is then moved to
# CUDA when a GPU is visible, otherwise it stays on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = LitBinaryClassifier.load_from_checkpoint(str(CKPT), map_location="cpu")
model = model.to(device)
model.eval()


LitBinaryClassifier(
  (model): DreamsClassifier(
    (spec_encoder): DreaMS(
      (fourier_enc): FourierFeatures()
      (ff_fourier): FeedForward(
        (ff): Sequential(
          (0): Linear(in_features=11994, out_features=512, bias=True)
          (1): Dropout(p=0.1, inplace=False)
          (2): ReLU()
          (3): Linear(in_features=512, out_features=512, bias=True)
          (4): Dropout(p=0.1, inplace=False)
          (5): ReLU()
          (6): Linear(in_features=512, out_features=512, bias=True)
          (7): Dropout(p=0.1, inplace=False)
          (8): ReLU()
          (9): Linear(in_features=512, out_features=512, bias=True)
          (10): Dropout(p=0.1, inplace=False)
          (11): ReLU()
          (12): Linear(in_features=512, out_features=980, bias=True)
          (13): ReLU()
        )
      )
      (ff_peak): FeedForward(
        (ff): Sequential(
          (0): Linear(in_features=2, out_features=44, bias=True)
          (1): ReLU()
        )
      )
      (tr

In [21]:
# 3) Open the HDF5 dataset through the MSData wrapper.
# NOTE(review): in_mem=False presumably avoids loading the whole file into
# memory up front — confirm against the MSData API.
_hdf5_path_str = str(HDF5_PTH)
msdata = MSData.from_hdf5(_hdf5_path_str, in_mem=False)

In [16]:
# 4) Pull the spectra and their metadata out of the MSData object.
spectra_list  = msdata.get_values("spectrum")      # expected: N arrays of shape (2, n_peaks)
spectra_arr   = np.stack(spectra_list, axis=0)     # → (N, 2, n_peaks)
# Fail fast on an unexpected layout: the permutation below would otherwise
# either raise an opaque axes error or silently feed the model a wrongly
# oriented tensor.
assert spectra_arr.ndim == 3 and spectra_arr.shape[1] == 2, (
    f"expected stacked spectra of shape (N, 2, n_peaks), got {spectra_arr.shape}"
)
# Permute so the model sees (N, n_peaks, 2)
spectra_arr   = spectra_arr.transpose(0, 2, 1)    # → (N, n_peaks, 2)

prec_mzs      = msdata.get_values("precursor_mz")
file_names    = msdata.get_values("file_name")
scan_numbers  = msdata.get_values("scan_number")
# zip() stops at the shorter input, so check the lengths explicitly instead of
# letting a mismatch silently drop identifiers.
assert len(file_names) == len(scan_numbers) == len(prec_mzs) == len(spectra_arr), (
    "metadata columns and spectra must all describe the same number of scans"
)
# One "<file name>:<scan number>" identifier per spectrum.
ids           = [f"{fn}:{sn}" for fn, sn in zip(file_names, scan_numbers)]

In [24]:
from torch.utils.data import TensorDataset, DataLoader

In [25]:
# 5) Wrap the spectra in a DataLoader for batched scoring.
spec_tensor = torch.from_numpy(spectra_arr).to(torch.float32)
# Features only: inference needs no labels.
dataset = TensorDataset(spec_tensor)
# shuffle=False keeps batches in dataset order, so the concatenated
# predictions line up row-for-row with the arrays built from the same MSData.
loader = DataLoader(dataset=dataset, batch_size=64, shuffle=False)

In [26]:
# 6) Inference in batches
# Scoring needs no gradients, so the whole loop runs under no_grad.
all_probs = []
with torch.no_grad():
    for (batch_spec,) in loader:
        batch_spec = batch_spec.to(device)
        logits     = model(batch_spec)
        # Flatten every batch to 1-D (one probability per spectrum).
        # NOTE(review): the model's output shape is not visible here. Flattening
        # is a no-op for (B,) outputs, and it also covers (B, 1) outputs (which
        # would otherwise become a 2-D DataFrame column) and a 0-d scalar from a
        # squeezed final batch of size one (which torch.cat refuses to join).
        probs      = torch.sigmoid(logits).reshape(-1).cpu()
        all_probs.append(probs)


In [27]:
# Collapse the per-batch tensors into a single NumPy vector of probabilities.
# `all_probs` is rebound from a list of tensors to an array here, so an
# unguarded torch.cat would raise if this cell were executed a second time;
# the isinstance check makes the cell safe to re-run.
if isinstance(all_probs, list):
    all_probs = torch.cat(all_probs).numpy()
# Flatten so a (N, 1) result still yields one value per spectrum and can be
# used directly as a DataFrame column.
all_probs = all_probs.reshape(-1)

In [28]:
# 7) Build the results table and keep the ten most confident positives.
# NOTE(review): in the recorded run every reported score lies between 0.600853
# and 0.600898. A distribution this flat may mean the inputs do not match what
# the classifier was trained on — confirm before trusting the ranking.
df = pd.DataFrame(
    {"id": ids, "precursor_mz": prec_mzs, "chlorine_confidence": all_probs}
)
# Only spectra scored above 0.5 are treated as chlorine-containing.
_is_positive = df["chlorine_confidence"] > 0.5
_ranked = df.loc[_is_positive].sort_values("chlorine_confidence", ascending=False)
top10 = _ranked.head(10).reset_index(drop=True)

In [29]:
# Leave the frame as the cell's final expression so the notebook renders it
# with its rich table display instead of plain printed text.
top10

                                    id  precursor_mz  chlorine_confidence
0  202312_20_P09-Leaf-r1_1uL.mzML:2920    718.323425             0.600898
1  202312_20_P09-Leaf-r1_1uL.mzML:2806    723.350830             0.600882
2  202312_20_P09-Leaf-r1_1uL.mzML:2491    344.222412             0.600880
3  202312_20_P09-Leaf-r1_1uL.mzML:2807    362.179077             0.600870
4  202312_20_P09-Leaf-r1_1uL.mzML:1534    344.186005             0.600864
5  202312_20_P09-Leaf-r1_1uL.mzML:2066    764.408752             0.600862
6  202312_20_P09-Leaf-r1_1uL.mzML:2516    344.222961             0.600859
7  202312_20_P09-Leaf-r1_1uL.mzML:3660    221.080948             0.600859
8  202312_20_P09-Leaf-r1_1uL.mzML:2182    415.210632             0.600856
9  202312_20_P09-Leaf-r1_1uL.mzML:1412    317.184143             0.600853
