In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from torch.utils.data import DataLoader
import pandas as pd
from tqdm.notebook import tqdm
from pprint import pprint

In [3]:
from massspecgym.data.datasets import SimulationDataset
from massspecgym.transforms import SpecToMzsInts, MolToPyG, StandardMeta, MolToFingerprints
from massspecgym.simulation_utils.misc_utils import print_shapes
from massspecgym.simulation_utils.models import FPFFNModel

In [4]:
tsv_pth = "../data/MassSpecGym_small.tsv"
frag_pth = None
meta_keys = ["adduct","precursor_mz","instrument_type","collision_energy"]
adducts = ["[M+H]+"]
instruments = ["QTOF","QFT","Orbitrap","ITFT"]
max_ce = 100. # arbitrary
spec_transform = SpecToMzsInts()
mol_transform = MolToFingerprints()
meta_transform = StandardMeta(
    adducts=adducts,
    instruments=instruments,
    max_ce=max_ce)

In [5]:
mol_d = mol_transform.get_input_sizes()
meta_d = meta_transform.get_input_sizes()
model_d = dict(
    metadata_insert_location="mlp",
    collision_energy_insert_size=16,
    adduct_insert_size=16,
    instrument_type_insert_size=16,
    # output
    mz_max=1000.,
    mz_bin_res=0.01,
    # model
    mlp_hidden_size=1024,
    mlp_dropout=0.1,
    mlp_num_layers=3,
    mlp_use_residuals=True,
    ff_prec_mz_offset=500,
    ff_bidirectional=True,
    ff_output_map_size=128
)
model_d = {**model_d, **mol_d, **meta_d}
model = FPFFNModel(**model_d)

In [6]:
ds = SimulationDataset(
    tsv_pth=tsv_pth,
    frag_pth=None,
    meta_keys=meta_keys,
    spec_transform=spec_transform,
    mol_transform=mol_transform,
    meta_transform=meta_transform,
    frag_transform=None,
    cache_feats=False)

dl = DataLoader(
    ds,
    num_workers=0,
    shuffle=False,
    batch_size=8,
    drop_last=False,
    collate_fn=ds.collate_fn)

In [7]:
batch = next(iter(dl))
print_shapes(batch)

spec_mzs - (192,) - <class 'torch.Tensor'>
spec_ints - (192,) - <class 'torch.Tensor'>
fps - (8, 4263) - <class 'torch.Tensor'>
precursor_mz - (8,) - <class 'torch.Tensor'>
adduct - (8,) - <class 'torch.Tensor'>
instrument_type - (8,) - <class 'torch.Tensor'>
collision_energy - (8,) - <class 'torch.Tensor'>
weight - (8,) - <class 'torch.Tensor'>
spec_batch_idxs - (192,) - <class 'torch.Tensor'>


In [8]:
print(meta_d)
print(batch["collision_energy"])
print(model.collision_energy_embedder)

{'adduct_input_size': 2, 'instrument_type_input_size': 5, 'collision_energy_input_size': 100}
tensor([28, 20, 60, 60, 99, 20, 55, 20])
Sequential(
  (0): FourierFeaturizerAbsoluteSines()
  (1): Linear(in_features=7, out_features=16, bias=True)
)


In [9]:
# # model.collision_energy_embedder(torch.zeros_like(batch["collision_energy"]))
# from massspecgym.simulation_utils.formula_embedder import FourierFeaturizerAbsoluteSines
# ffs = FourierFeaturizerAbsoluteSines(max_count_int=100)
# print(ffs.num_dim)
# print(ffs.forward(batch["collision_energy"].reshape(-1,1)).shape)

In [10]:
out_d = model.forward(**batch)
print_shapes(out_d)

pred_mzs - (130544,) - <class 'torch.Tensor'>
pred_logprobs - (130544,) - <class 'torch.Tensor'>
pred_batch_idxs - (130544,) - <class 'torch.Tensor'>


In [11]:
from massspecgym.models.simulation.base import sparse_cosine_distance
loss = sparse_cosine_distance(
    batch["true_mzs"],
    batch["true_logprobs"],
    batch["true_batch_idxs"],
    out_d["pred_mzs"],
    out_d["pred_logprobs"],
    out_d["pred_batch_idxs"],
    1000.,
    0.01,
)
print(loss)
loss.backward()

KeyError: 'true_mzs'