In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from torch.utils.data import DataLoader
import pandas as pd
from tqdm.notebook import tqdm
from pprint import pprint
from torch.utils.data import Subset
import pytorch_lightning as pl

In [3]:
from massspecgym.data.datasets import SimulationDataset
from massspecgym.transforms import SpecToMzsInts, MolToPyG, StandardMeta, MolToFingerprints
from massspecgym.simulation_utils.misc_utils import print_shapes
from massspecgym.models.simulation.fp_ffn import FPFFNSimulationMassSpecGymModel

In [4]:
tsv_pth = "../data/MassSpecGym_small.tsv"
frag_pth = None
meta_keys = ["adduct","precursor_mz","instrument_type","collision_energy"]
fp_types = ["morgan","maccs","rdkit"]
adducts = ["[M+H]+"]
instrument_types = ["QTOF","QFT","Orbitrap","ITFT"]
max_collision_energy = 200. # arbitrary
spec_transform = SpecToMzsInts()
mol_transform = MolToFingerprints(
    fp_types=fp_types
)
meta_transform = StandardMeta(
    adducts=adducts,
    instrument_types=instrument_types,
    max_collision_energy=max_collision_energy
)

In [5]:
pl_model_d = dict(
    # features
    fp_types=fp_types,
    adducts=adducts,
    instrument_types=instrument_types,
    max_collision_energy=max_collision_energy,
    # input
    metadata_insert_location="mlp",
    collision_energy_insert_size=16,
    adduct_insert_size=16,
    instrument_type_insert_size=16,
    ints_transform="none",
    # output
    mz_max=1000.,
    mz_bin_res=0.01,
    # model
    mlp_hidden_size=1024,
    mlp_dropout=0.1,
    mlp_num_layers=3,
    mlp_use_residuals=True,
    ff_prec_mz_offset=500,
    ff_bidirectional=True,
    ff_output_map_size=128,
    # optimization
    lr=1e-5,
    weight_decay=1e-7,
    train_sample_weight=True,
    eval_sample_weight=True
)
pl_model = FPFFNSimulationMassSpecGymModel(**pl_model_d)

In [6]:
ds = SimulationDataset(
    tsv_pth=tsv_pth,
    frag_pth=None,
    meta_keys=meta_keys,
    spec_transform=spec_transform,
    mol_transform=mol_transform,
    meta_transform=meta_transform,
    frag_transform=None,
    cache_feats=False)

dl = DataLoader(
    ds,
    num_workers=0,
    shuffle=False,
    batch_size=8,
    drop_last=False,
    collate_fn=ds.collate_fn)

In [7]:
batch = next(iter(dl))
print_shapes(batch)
print(batch["spec_mzs"].max())
print(batch["precursor_mz"].max())

spec_mzs - (152,) - <class 'torch.Tensor'>
spec_ints - (152,) - <class 'torch.Tensor'>
fps - (8, 4263) - <class 'torch.Tensor'>
precursor_mz - (8,) - <class 'torch.Tensor'>
adduct - (8,) - <class 'torch.Tensor'>
instrument_type - (8,) - <class 'torch.Tensor'>
collision_energy - (8,) - <class 'torch.Tensor'>
weight - (8,) - <class 'torch.Tensor'>
spec_batch_idxs - (152,) - <class 'torch.Tensor'>
tensor(450.1145)
tensor(1033.3547)


In [8]:
pl_model.cpu()


FPFFNSimulationMassSpecGymModel(
  (model): FPFFNModel(
    (collision_energy_embedder): Sequential(
      (0): FourierFeaturizerAbsoluteSines()
      (1): Linear(in_features=8, out_features=16, bias=True)
    )
    (adduct_embedder): Embedding(3, 16)
    (instrument_type_embedder): Embedding(6, 16)
    (ffn): SpecFFN(
      (in_layer): Linear(in_features=4311, out_features=1024, bias=True)
      (ff_layers): ModuleList(
        (0-2): 3 x NeimsBlock(
          (in_batch_norm): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (in_activation): LeakyReLU(negative_slope=0.01)
          (in_linear): Linear(in_features=1024, out_features=512, bias=True)
          (out_batch_norm): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (out_linear): Linear(in_features=512, out_features=1024, bias=True)
          (out_activation): LeakyReLU(negative_slope=0.01)
          (dropout): Dropout(p=0.1, inplace=False)
   

In [9]:
out_d = pl_model.forward(**batch)
print_shapes(out_d)

pred_mzs - (172650,) - <class 'torch.Tensor'>
pred_logprobs - (172650,) - <class 'torch.Tensor'>
pred_batch_idxs - (172650,) - <class 'torch.Tensor'>


In [10]:
bins = torch.arange(0.5,1000.+0.5,0.5)
print(bins.min(), bins.max())
mzs = torch.tensor([0., 0.5, 0.51, 10., 100., 999., 999.5, 1000.-1e-6, 1000.])
idxs = torch.bucketize(mzs, bins, right=True)
print(idxs)

tensor(0.5000) tensor(1000.)
tensor([   0,    1,    1,   20,  200, 1998, 1999, 2000, 2000])


In [11]:
# # Init data module
# data_module = MassSpecDataModule(
#     dataset=ds,
#     split_pth=split_pth,
#     batch_size=8
# )

train_ds = Subset(ds, ds.entry_df["fold"] == "train")
val_ds = Subset(ds, ds.entry_df["fold"] == "val")
test_ds = Subset(ds, ds.entry_df["fold"] == "test")

dl_config = {
    "num_workers": 0,
    "batch_size": 8,
    "drop_last": False,
    "collate_fn": ds.collate_fn
}

train_dl = DataLoader(train_ds, shuffle=True, **dl_config)
val_dl = DataLoader(val_ds, shuffle=False, **dl_config)

In [12]:
wandb_entity = "adamoyoung"
wandb_project = "MSG"
wandb_name = "simulation_test"
# logger = pl.loggers.WandbLogger(
#     entity=wandb_entity,
#     project=wandb_project,
#     name=wandb_name,
#     tags=[],
#     log_model=False,
# )
logger = None

# Init trainer
trainer = pl.Trainer(
    accelerator="cpu", max_epochs=5, logger=logger, log_every_n_steps=1
)

# Train
trainer.fit(
    pl_model, 
    train_dataloaders=train_dl, 
    val_dataloaders=val_dl
)

GPU available: False, used: False


TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33madamoyoung[0m. Use [1m`wandb login --relogin`[0m to force relogin



  | Name  | Type       | Params
-------------------------------------
0 | model | FPFFNModel | 46.9 M
-------------------------------------
46.9 M    Trainable params
1.6 K     Non-trainable params
46.9 M    Total params
187.482   Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/adamo.young/miniconda3/envs/MSG/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


TypeError: Cannot index by location index with a non-integer key