In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
from pprint import pprint

import pandas as pd
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Subset
from tqdm.notebook import tqdm

In [3]:
from massspecgym.data.datasets import SimulationDataset
from massspecgym.transforms import SpecToMzsInts, MolToPyG, StandardMeta, MolToFingerprints
from massspecgym.simulation_utils.misc_utils import print_shapes
from massspecgym.models.simulation.fp_ffn import FPFFNSimulationMassSpecGymModel

In [4]:
# Data and featurization config.
# Smaller TSVs for quick iteration: "../data/MassSpecGym_small.tsv",
# "../data/MassSpecGym_3_small.tsv".
tsv_pth = "../data/MassSpecGym_3.tsv"
frag_pth = None  # no fragment file is passed for this run
meta_keys = ["adduct", "precursor_mz", "instrument_type", "collision_energy"]
fp_types = ["morgan", "maccs", "rdkit"]
adducts = ["[M+H]+"]
instrument_types = ["QTOF", "QFT", "Orbitrap", "ITFT"]
max_collision_energy = 200.  # arbitrary upper bound, chosen by hand

# Seed the Python / NumPy / torch RNGs before the model and dataloaders below are
# constructed, so weight init, dropout and training-set shuffling are reproducible
# under Restart & Run All.
SEED = 42
pl.seed_everything(SEED)

spec_transform = SpecToMzsInts(
    mz_from=10.,
    mz_to=1500.,  # NOTE(review): presumably should match the model's mz_max — confirm
)
mol_transform = MolToFingerprints(
    fp_types=fp_types,
)
meta_transform = StandardMeta(
    adducts=adducts,
    instrument_types=instrument_types,
    max_collision_energy=max_collision_energy,
)

In [5]:
# Hyperparameters for FPFFNSimulationMassSpecGymModel, grouped by role and then
# merged (in the same key order) into the single kwargs dict `pl_model_d`.
feature_kwargs = dict(
    fp_types=fp_types,
    adducts=adducts,
    instrument_types=instrument_types,
    max_collision_energy=max_collision_energy,
)
input_kwargs = dict(
    metadata_insert_location="mlp",
    collision_energy_insert_size=16,
    adduct_insert_size=16,
    instrument_type_insert_size=16,
    ints_transform="sqrt",  # alternative value: "none"
)
output_kwargs = dict(
    mz_max=1500.,  # alternative value: 1000.
    mz_bin_res=0.1,  # alternative value: 0.01
)
architecture_kwargs = dict(
    mlp_hidden_size=1024,
    mlp_dropout=0.1,
    mlp_num_layers=3,
    mlp_use_residuals=True,
    ff_prec_mz_offset=500,
    ff_bidirectional=True,
    ff_output_map_size=128,
)
optim_kwargs = dict(
    lr=1e-3,
    weight_decay=1e-7,
    train_sample_weight=True,
    eval_sample_weight=True,
)
pl_model_d = {
    **feature_kwargs,
    **input_kwargs,
    **output_kwargs,
    **architecture_kwargs,
    **optim_kwargs,
}
pl_model = FPFFNSimulationMassSpecGymModel(**pl_model_d)

In [6]:
# Full dataset plus a small-batch, unshuffled loader used for inspection (the batch
# peek and forward-pass smoke test below). Training/validation loaders are built
# later from the fold subsets.
ds = SimulationDataset(
    tsv_pth=tsv_pth,
    frag_pth=frag_pth,  # taken from the config cell rather than hard-coded
    meta_keys=meta_keys,
    spec_transform=spec_transform,
    mol_transform=mol_transform,
    meta_transform=meta_transform,
    frag_transform=None,
    cache_feats=False,
)

dl = DataLoader(
    ds,
    num_workers=8,  # 0 loads in the main process, which is easier to debug
    shuffle=False,
    batch_size=8,
    drop_last=False,
    collate_fn=ds.collate_fn,
)

  entry_df = pd.read_csv(self.tsv_pth, sep="\t")


In [7]:
# Optional sanity check over the whole inspection dataloader: every spectrum in a
# collated batch should contribute at least one peak. `spec_batch_idxs` holds, for
# each flattened peak, the position of its spectrum within the batch, so a position
# that never appears there is an empty spectrum. Disabled by default because it
# iterates the entire dataset.
CHECK_SPEC_BATCH_IDXS = False


def find_empty_spectra(batch):
    """Return the batch positions of spectra that contributed no peaks.

    The batch size is taken from ``batch["spec_id"]`` (one entry per spectrum)
    rather than from ``spec_batch_idxs.max() + 1``, so empty spectra at the *end*
    of a batch are detected too.

    Args:
        batch: Collated batch dict with ``"spec_id"`` and ``"spec_batch_idxs"``
            tensors.

    Returns:
        Sorted list of batch positions with no peaks (empty if every spectrum
        has at least one peak).
    """
    num_spectra = len(batch["spec_id"])
    present = set(batch["spec_batch_idxs"].unique().tolist())
    return [pos for pos in range(num_spectra) if pos not in present]


if CHECK_SPEC_BATCH_IDXS:
    for check_batch in tqdm(dl, total=len(dl)):
        empty_positions = find_empty_spectra(check_batch)
        if empty_positions:
            raise ValueError(
                f"Spectra with no peaks at batch positions {empty_positions}; "
                f"spec_id={check_batch['spec_id'].tolist()}"
            )

In [8]:
# Pull one collated batch from the inspection loader, list its tensor shapes, and
# print the largest peak m/z next to the largest precursor m/z in the batch.
batch = next(iter(dl))
print_shapes(batch)
for key in ("spec_mzs", "precursor_mz"):
    print(batch[key].max())

spec_mzs - (92,) - <class 'torch.Tensor'>
spec_ints - (92,) - <class 'torch.Tensor'>
fps - (8, 4263) - <class 'torch.Tensor'>
precursor_mz - (8,) - <class 'torch.Tensor'>
adduct - (8,) - <class 'torch.Tensor'>
instrument_type - (8,) - <class 'torch.Tensor'>
collision_energy - (8,) - <class 'torch.Tensor'>
spec_id - (8,) - <class 'torch.Tensor'>
weight - (8,) - <class 'torch.Tensor'>
spec_batch_idxs - (92,) - <class 'torch.Tensor'>
tensor(288.1230)
tensor(288.1225)


In [9]:
# Smoke-test a forward pass on the sample batch and inspect the output shapes.
# Call the module itself (not `.forward`) so nn.Module hooks run, and disable
# autograd because these outputs are only inspected, never backpropagated.
with torch.no_grad():
    out_d = pl_model(**batch)
print_shapes(out_d)

pred_mzs - (13529,) - <class 'torch.Tensor'>
pred_logprobs - (13529,) - <class 'torch.Tensor'>
pred_batch_idxs - (13529,) - <class 'torch.Tensor'>


In [10]:
# Toy check of torch.bucketize(..., right=True) on bin edges 0.5, 1.0, ..., 1000.0:
# a value x maps to index i with bins[i-1] <= x < bins[i], so a value equal to an
# edge lands in the bin *after* that edge, and values >= the last edge map to
# len(bins).
bin_width = 0.5
upper_mz = 1000.
bins = torch.arange(bin_width, upper_mz + bin_width, bin_width)
print(bins.min(), bins.max())
probe_values = [0., 0.5, 0.51, 10., 100., 999., 999.5, upper_mz - 1e-6, upper_mz]
mzs = torch.tensor(probe_values)
idxs = torch.bucketize(mzs, bins, right=True)
print(idxs)

tensor(0.5000) tensor(1000.)
tensor([   0,    1,    1,   20,  200, 1998, 1999, 2000, 2000])


In [11]:
# Split the dataset into train/val/test using the "fold" column of its metadata
# table, then build the training and validation loaders (test_ds is held out and
# not loaded here).
# NOTE(review): Subset receives DataFrame index *labels*; this assumes ds.entry_df
# keeps a default 0..n-1 index so labels equal dataset positions — confirm.
entry_df = ds.entry_df

train_ds, val_ds, test_ds = (
    Subset(ds, entry_df[entry_df["fold"] == fold].index)
    for fold in ("train", "val", "test")
)

dl_config = dict(
    num_workers=4,
    batch_size=64,
    drop_last=False,
    collate_fn=ds.collate_fn,
)

train_dl = DataLoader(train_ds, shuffle=True, **dl_config)
val_dl = DataLoader(val_ds, shuffle=False, **dl_config)

In [12]:
# Train with Weights & Biases logging. The W&B entity can be overridden through
# the WANDB_ENTITY environment variable instead of editing this cell. Set
# USE_WANDB = False to skip W&B; the Trainer then receives logger=None.
# NOTE(review): in Lightning, logger=None selects the Trainer's default logger and
# logger=False disables logging entirely — pick whichever is intended.
USE_WANDB = True
wandb_entity = os.environ.get("WANDB_ENTITY", "adamoyoung")
wandb_project = "MSG"
wandb_name = "simulation_test_full"
if USE_WANDB:
    logger = pl.loggers.WandbLogger(
        entity=wandb_entity,
        project=wandb_project,
        name=wandb_name,
        tags=[],
        log_model=False,
    )
else:
    logger = None

trainer = pl.Trainer(
    accelerator="cpu",
    max_epochs=10,
    logger=logger,
    log_every_n_steps=1,  # log metrics on every training step
)

trainer.fit(
    pl_model,
    train_dataloaders=train_dl,
    val_dataloaders=val_dl,
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33madamoyoung[0m. Use [1m`wandb login --relogin`[0m to force relogin



  | Name  | Type       | Params
-------------------------------------
0 | model | FPFFNModel | 13.8 M
-------------------------------------
13.8 M    Trainable params
1.6 K     Non-trainable params
13.8 M    Total params
55.221    Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]