# Training with `NLogProbEnrichment`

This notebook demonstrates how to use the loss function described in [Lim et al. (2022) JCIM](https://pubs.acs.org/doi/10.1021/acs.jcim.2c00041) on Poisson-distributed (or negative-binomial-distributed) count data, e.g. DNA-encoded library screening data.

Notable differences in how this type of model is set up compared to typical Chemprop training:
- this loss function requires two target columns, "positive" and "negative". Both must contain count (integer) values
- do not scale the targets, as the loss function takes the raw counts
- the output transform for the FFN must be set to `Softplus`
- the `NLogProbEnrichment` metric must be used
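
The first two requirements can be checked before any datapoints are built. The sketch below is one way to do that with the standard library only; `check_count_targets` is a hypothetical helper written for this notebook, not part of Chemprop.

```python
from numbers import Integral


def check_count_targets(rows):
    """Raise ValueError unless every row is a pair of non-negative integers.

    `rows` is any iterable of (positive_count, negative_count) pairs.
    Hypothetical helper for illustration; not a Chemprop function.
    """
    for i, row in enumerate(rows):
        if len(row) != 2:
            raise ValueError(f"row {i}: expected 2 targets, got {len(row)}")
        for value in row:
            # bool is a subclass of int, so reject it explicitly
            if isinstance(value, bool) or not isinstance(value, Integral):
                raise ValueError(f"row {i}: {value!r} is not an integer count")
            if value < 0:
                raise ValueError(f"row {i}: {value!r} is negative")


check_count_targets([(8, 5), (3, 3), (8, 2)])  # passes silently
```

A float such as `1.5`, or a value that has already been normalized, would raise a `ValueError` here instead of silently reaching the loss function.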

This notebook is adapted from the [regular training demo notebook](https://chemprop.readthedocs.io/en/latest/training.html), which may be helpful as a reference.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/chemprop/chemprop/blob/main/examples/training.ipynb)

In [1]:
# Install chemprop from GitHub if running in Google Colab
import os

if os.getenv("COLAB_RELEASE_TAG"):
    try:
        import chemprop
    except ImportError:
        !git clone https://github.com/chemprop/chemprop.git
        %cd chemprop
        !pip install .
        %cd examples

## Initial Setup

We'll follow the typical procedure for importing the necessary packages and defining some overall settings related to the data.

In [2]:
from pathlib import Path

import torch 
from lightning import pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint
import pandas as pd
import numpy as np

from chemprop import data, featurizers, models, nn

In [3]:
chemprop_dir = Path.cwd().parent
input_path = chemprop_dir / "tests" / "data" / "regression" / "mol" / "mol.csv"
num_workers = 0
smiles_column = 'smiles'

The first big difference between `NLogProbEnrichment` training and conventional Chemprop training is that there must be exactly two target columns: the count of positive samples and the count of negative samples.
For this demo, these columns aren't actually in the dataset, so we'll randomly generate them for demonstration purposes.

In [4]:
target_columns = ['counts_pos', 'counts_neg']

In [5]:
df_input = pd.read_csv(input_path)


# creating some random count data for the NLogProbEnrichment metric
df_input['counts_pos'] = np.random.poisson(lam=6, size=df_input.shape[0])
df_input['counts_neg'] = np.random.poisson(lam=4, size=df_input.shape[0])

total_counts_pos = int(df_input['counts_pos'].sum())  # total number of positive samples
total_counts_neg = int(df_input['counts_neg'].sum())  # total number of negative samples

print(f"Total counts (positive samples): {total_counts_pos}")
print(f"Total counts (negative samples): {total_counts_neg}")

df_input

Total counts (positive samples): 624
Total counts (negative samples): 417


| | smiles | lipo | counts_pos | counts_neg |
|---|---|---|---|---|
| 0 | `Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14` | 3.54 | 8 | 5 |
| 1 | `COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)...` | -1.18 | 3 | 3 |
| 2 | `COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl` | 3.69 | 8 | 2 |
| 3 | `OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(C...` | 3.37 | 6 | 3 |
| 4 | `Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)N...` | 3.10 | 6 | 8 |
| ... | ... | ... | ... | ... |
| 95 | `CC(C)N(CCCNC(=O)Nc1ccc(cc1)C(C)(C)C)C[C@H]2O[C...` | 2.20 | 3 | 2 |
| 96 | `CCN(CC)CCCCNc1ncc2CN(C(=O)N(Cc3cccc(NC(=O)C=C)...` | 2.04 | 4 | 3 |
| 97 | `CCSc1c(Cc2ccccc2C(F)(F)F)sc3N(CC(C)C)C(=O)N(C)...` | 4.49 | 6 | 9 |
| 98 | `COc1ccc(Cc2c(N)n[nH]c2N)cc1` | 0.20 | 3 | 1 |


In [6]:
smis = df_input.loc[:, smiles_column].values
ys = df_input.loc[:, target_columns].values

## Training Preparation

Now we follow the typical procedure to set up our data and neural network, with just a few small changes to facilitate the `NLogProbEnrichment` loss.

In [7]:
all_data = [data.MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smis, ys)]

In [8]:
mols = [d.mol for d in all_data]  # RDKit Mol objects are used for structure-based splits
train_indices, val_indices, test_indices = data.make_split_indices(mols, "random", (0.8, 0.1, 0.1))  # unpack the tuple into three separate lists
train_data, val_data, test_data = data.split_data_by_indices(
    all_data, train_indices, val_indices, test_indices
)

The return type of make_split_indices has changed in v2.1 - see help(make_split_indices)


When creating the `MoleculeDataset` objects, one would often rescale the target variables like this:

```python
scaler = train_dset.normalize_targets()
val_dset.normalize_targets(scaler)
```

We do NOT do this here, since the loss function operates on the counts directly.
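
To see why, consider what z-score standardization (the usual form of target normalization) does to a handful of counts. This is a small standard-library illustration using the first five `counts_pos` values from the table above; it does not call Chemprop.

```python
from statistics import mean, pstdev

counts = [8, 3, 8, 6, 6]  # counts_pos for the first five rows shown above

mu, sigma = mean(counts), pstdev(counts)
scaled = [(c - mu) / sigma for c in counts]

# Some standardized values are negative and none are whole numbers,
# so they can no longer be interpreted as observed counts.
print(scaled)
```

A count-based likelihood has no meaning for negative or fractional "counts", which is why the targets are passed through unchanged.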

In [9]:
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()
train_dset = data.MoleculeDataset(train_data[0], featurizer)
val_dset = data.MoleculeDataset(val_data[0], featurizer)
test_dset = data.MoleculeDataset(test_data[0], featurizer)
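
The `[0]` in the cell above follows from the nested return shape mentioned in the warning: each split appears to come back as a list with one entry per replicate, so a single split lives at index 0. The toy sketch below mimics that shape with plain lists; `split_by_indices` is a hypothetical stand-in, not Chemprop's implementation.

```python
def split_by_indices(items, *index_groups):
    """For each group, return one list of items per replicate.

    Each group is a list of index lists (one inner list per replicate),
    mirroring the nested shape the notebook indexes with `[0]`.
    Hypothetical helper, not Chemprop's implementation.
    """
    return tuple(
        [[items[i] for i in replicate] for replicate in group]
        for group in index_groups
    )


items = ["a", "b", "c", "d", "e"]
train_idx, val_idx, test_idx = [[0, 1, 2]], [[3]], [[4]]  # one replicate each

train, val, test = split_by_indices(items, train_idx, val_idx, test_idx)
print(train)     # [['a', 'b', 'c']]
print(train[0])  # ['a', 'b', 'c']
```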

In [10]:
train_loader = data.build_dataloader(train_dset, num_workers=num_workers)
val_loader = data.build_dataloader(val_dset, num_workers=num_workers, shuffle=False)
test_loader = data.build_dataloader(test_dset, num_workers=num_workers, shuffle=False)

In [11]:
mp = nn.BondMessagePassing()

In [12]:
agg = nn.MeanAggregation()

The output of our FFN must go through the `Softplus` activation function, which we will use as our `output_transform`.

In [13]:
output_transform = torch.nn.Softplus()
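
For intuition, Softplus maps any real number to a strictly positive one, which is presumably what the loss needs from the network output. The sketch below re-implements the function in plain Python for scalar inputs, using the `beta=1.0` and `threshold=20.0` defaults that appear in the model summary later in this notebook; it is an illustration, not the PyTorch code.

```python
import math


def softplus(x, beta=1.0, threshold=20.0):
    """Softplus: log(1 + exp(beta * x)) / beta, linear above the threshold."""
    if beta * x > threshold:
        return x  # avoids overflow in exp for large inputs
    return math.log1p(math.exp(beta * x)) / beta


for raw in (-5.0, 0.0, 5.0, 50.0):
    # every transformed value is strictly positive
    print(raw, softplus(raw))
```

Large positive inputs pass through almost unchanged, while negative inputs are squashed toward (but never reach) zero.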

We'll then initialize the actual loss function:

In [14]:
criterion = nn.metrics.NLogProbEnrichment(
    n1=total_counts_pos,
    n2=total_counts_neg,
    method="sqrt", 
    zscale=1.0,
    zinterval=5.0,
)
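
The totals `n1` and `n2` are passed in because the positive and negative selections generally contain different numbers of reads, so raw counts are not directly comparable. As rough intuition only, a naive enrichment estimate divides each count by its total before taking the ratio. The function below is a hypothetical illustration of that normalization and is not the quantity `NLogProbEnrichment` computes.

```python
def naive_enrichment(k_pos, k_neg, n_pos, n_neg):
    """Ratio of the compound's share of reads in each selection.

    A simple point estimate for illustration; it is not the quantity
    NLogProbEnrichment optimizes.
    """
    if k_neg == 0:
        raise ValueError("k_neg is zero; the naive ratio is undefined")
    return (k_pos / n_pos) / (k_neg / n_neg)


# Row 0 of the table above, with the totals printed earlier
print(naive_enrichment(k_pos=8, k_neg=5, n_pos=624, n_neg=417))
```

The naive ratio breaks down when `k_neg` is zero and ignores sampling noise in small counts, which is part of the motivation for a likelihood-based loss.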

And finally build the FFN. Note that we pass our `Softplus` as the `output_transform`, which is required when training with the `NLogProbEnrichment` loss:

In [15]:
ffn = nn.predictors.RegressionFFN(criterion=criterion, output_transform=output_transform)

In [16]:
batch_norm = True

In [17]:
metric_list = [nn.metrics.NLogProbEnrichment(n1=total_counts_pos, n2=total_counts_neg)]

In [18]:
mpnn = models.MPNN(mp, agg, ffn, batch_norm, metric_list)
mpnn

MPNN(
  (message_passing): BondMessagePassing(
    (W_i): Linear(in_features=86, out_features=300, bias=False)
    (W_h): Linear(in_features=300, out_features=300, bias=False)
    (W_o): Linear(in_features=372, out_features=300, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
    (tau): ReLU()
    (V_d_transform): Identity()
    (graph_transform): Identity()
  )
  (agg): MeanAggregation()
  (bn): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=300, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=1, bias=True)
      )
    )
    (criterion): NLogProbEnrichment(n1=624, n2=417, method='sqrt', zscale=1.0, zinterval=5)
    (output_transform): Softplus(beta=1.0, threshold=20.0)
  )
  (X_d_transform): Identity()
  (metrics): Module

## Training and Inference

In [19]:
checkpointing = ModelCheckpoint(
    "checkpoints",
    "best-{epoch}-{val_loss:.2f}",
    "val_loss",
    mode="min",
    save_last=True,
)
trainer = pl.Trainer(
    logger=False,
    enable_checkpointing=True,
    enable_progress_bar=True,
    accelerator="auto",
    devices=1,
    max_epochs=20,
    callbacks=[checkpointing],
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [20]:
trainer.fit(mpnn, train_loader, val_loader)

/home/jackson/miniforge3/envs/chemprop-dev/lib/python3.12/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:751: Checkpoint directory /home/jackson/chemprop/examples/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loading `train_dataloader` to estimate number of stepping batches.
/home/jackson/miniforge3/envs/chemprop-dev/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.

  | Name            | Type               | Params | Mode 
---------------------------------------------------------------
0 | message_passing | BondMessagePassing | 227 K  | train
1 | agg             | MeanAggregation    | 0      | train
2 | bn              | BatchNorm1d        | 600    | train
3 | predictor       | RegressionFFN      | 90

Sanity Checking DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]

/home/jackson/miniforge3/envs/chemprop-dev/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:433: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Epoch 19: 100%|██████████| 2/2 [00:00<00:00, 32.95it/s, train_loss_step=0.449, val_loss=0.632, train_loss_epoch=0.282]

`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 19: 100%|██████████| 2/2 [00:00<00:00, 18.67it/s, train_loss_step=0.449, val_loss=0.632, train_loss_epoch=0.282]


In [None]:
results = trainer.test(dataloaders=test_loader, weights_only=False)  # weights_only=False is only required for Lightning 2.6+

Restoring states from the checkpoint path at /home/jackson/chemprop/examples/checkpoints/best-epoch=19-val_loss=0.63.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at /home/jackson/chemprop/examples/checkpoints/best-epoch=19-val_loss=0.63.ckpt
/home/jackson/miniforge3/envs/chemprop-dev/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:433: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 262.47it/s]
