# Training with Negative Log Loss as the Loss Function
Implementation of the loss function described in [Lim et al. (2022) JCIM](https://pubs.acs.org/doi/10.1021/acs.jcim.2c00041) for use on Poisson-distributed (or negative-binomial-distributed) count data, e.g. DNA-encoded library screening data.


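Notable differences in how this type of model is set up:
- this loss function requires two target columns, one with positive and one with negative counts (`counts_pos` and `counts_neg` in this notebook). Both must hold count (int) values
- do not use target scaling, as the loss function takes the raw counts
- do not use output transforms that unscale predictions
- use a softplus-based predictor (RegressionSoftplusFNN) or, as in this notebook, `RegressionFFN` with `torch.nn.Softplus()` as `output_transform`, to ensure positive predictions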
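
To make the last point concrete, here is a minimal sketch of the softplus function in plain Python. It follows the textbook definition, with the `beta` and `threshold` defaults that appear in the model printout later in this notebook; it is an illustration, not the PyTorch implementation.

```python
import math


def softplus(x: float, beta: float = 1.0, threshold: float = 20.0) -> float:
    """Softplus: (1 / beta) * log(1 + exp(beta * x)), linear above the threshold."""
    if beta * x > threshold:
        # For large inputs softplus(x) is approximately x; returning x avoids overflow in exp.
        return x
    return math.log1p(math.exp(beta * x)) / beta


# Raw network outputs can have any sign; softplus maps them to positive values.
raw_outputs = [-5.0, -1.0, 0.0, 1.0, 5.0]
positive_preds = [softplus(x) for x in raw_outputs]
```

Unlike ReLU, softplus is smooth and stays strictly positive for moderate negative inputs, so the loss never has to evaluate a prediction of exactly zero.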
 

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/chemprop/chemprop/blob/main/examples/training.ipynb)

In [7]:
# Install chemprop from GitHub if running in Google Colab
import os

if os.getenv("COLAB_RELEASE_TAG"):
    try:
        import chemprop
    except ImportError:
        !git clone https://github.com/chemprop/chemprop.git
        %cd chemprop
        !pip install .
        %cd examples

# Import packages

In [8]:
from pathlib import Path

import torch 
from lightning import pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint
import pandas as pd
import numpy as np

from chemprop import data, featurizers, models, nn

# Change data inputs here

In [9]:
chemprop_dir = Path.cwd().parent
input_path = chemprop_dir / "tests" / "data" / "regression" / "mol" / "mol.csv" # path to your data .csv file
num_workers = 0 # number of workers for dataloader. 0 means using main process for data loading
smiles_column = 'smiles' # name of the column containing SMILES strings
target_columns = ['counts_pos', 'counts_neg']# list of names of the columns containing targets

## Load data and create synthetic count data for testing

In [10]:
df_input = pd.read_csv(input_path)


# create random count data for the NLogProbEnrichment metric
df_input['counts_pos'] = np.random.poisson(lam=6, size=df_input.shape[0])
df_input['counts_neg'] = np.random.poisson(lam=4, size=df_input.shape[0])

total_counts_pos = int(df_input['counts_pos'].sum())  # total counts across the positive column
total_counts_neg = int(df_input['counts_neg'].sum())  # total counts across the negative column

print(f"Total counts (positive samples): {total_counts_pos}")
print(f"Total counts (negative samples): {total_counts_neg}")

df_input

Total counts (positive samples): 643
Total counts (negative samples): 367


| Unnamed: 0 | smiles | lipo | counts_pos | counts_neg |
|---|---|---|---|---|
| 0 | `Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14` | 3.54 | 8 | 4 |
| 1 | `COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)...` | -1.18 | 6 | 4 |
| 2 | `COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl` | 3.69 | 7 | 3 |
| 3 | `OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(C...` | 3.37 | 5 | 2 |
| 4 | `Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)N...` | 3.10 | 6 | 3 |
| ... | ... | ... | ... | ... |
| 95 | `CC(C)N(CCCNC(=O)Nc1ccc(cc1)C(C)(C)C)C[C@H]2O[C...` | 2.20 | 9 | 4 |
| 96 | `CCN(CC)CCCCNc1ncc2CN(C(=O)N(Cc3cccc(NC(=O)C=C)...` | 2.04 | 8 | 6 |
| 97 | `CCSc1c(Cc2ccccc2C(F)(F)F)sc3N(CC(C)C)C(=O)N(C)...` | 4.49 | 12 | 2 |
| 98 | `COc1ccc(Cc2c(N)n[nH]c2N)cc1` | 0.20 | 7 | 3 |
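
The two totals printed above are later passed to the loss as `n1` and `n2`. One plausible reading, stated here as an assumption rather than taken from the chemprop source, is that they normalize each count by the overall depth of its sample. The hypothetical helper `naive_enrichment` below computes the simple ratio of depth-normalized counts for one compound. It is not the probabilistic loss from Lim et al., only the point estimate that such a loss treats with uncertainty.

```python
def naive_enrichment(k_pos: int, k_neg: int, n_pos: int, n_neg: int) -> float:
    """Ratio of depth-normalized counts: (k_pos / n_pos) / (k_neg / n_neg)."""
    if k_neg == 0:
        raise ValueError("k_neg is zero; the naive ratio is undefined")
    return (k_pos / n_pos) / (k_neg / n_neg)


# First row of the table above: counts_pos=8, counts_neg=4,
# with the totals printed by the notebook (643 and 367).
ratio = naive_enrichment(8, 4, n_pos=643, n_neg=367)
print(f"{ratio:.3f}")
```

With counts this small the ratio is noisy and undefined when the negative count is zero, which is one motivation for modelling enrichment probabilistically instead.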


## Get SMILES and targets

In [11]:

smis = df_input.loc[:, smiles_column].values
ys = df_input.loc[:, target_columns].values

In [12]:
smis[:5] # show first 5 SMILES strings

array(['Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14',
       'COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)CCc3ccccc23',
       'COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl',
       'OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(Cl)sc4[nH]3',
       'Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)NCC#N)c1'],
      dtype=object)

In [13]:
ys[:5] # show first 5 targets

array([[8, 4],
       [6, 4],
       [7, 3],
       [5, 2],
       [6, 3]])

## Get molecule datapoints

In [14]:
all_data = [data.MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smis, ys)]

## Perform data splitting for training, validation, and testing

In [15]:
# available split types
list(data.SplitType.keys())

['SCAFFOLD_BALANCED',
 'RANDOM_WITH_REPEATED_SMILES',
 'RANDOM',
 'KENNARD_STONE',
 'KMEANS']

Chemprop's `make_split_indices` function will always return a two- (if no validation) or three-length tuple.
Each member is a list of length `num_replicates`.
The inner lists then contain the actual indices for splitting.

The type signature for this return type is `tuple[list[list[int]], ...]`.
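
The nesting is easier to see in a small stand-alone sketch. The helper `random_split_indices` below is hypothetical and only mimics the shape of the return value described above (one list per split, one inner index list per replicate); it is not chemprop's splitting code.

```python
import random


def random_split_indices(n, sizes=(0.8, 0.1, 0.1), num_replicates=1, seed=0):
    """Return (train, val, test); each is a list with one index list per replicate."""
    rng = random.Random(seed)
    train, val, test = [], [], []
    for _ in range(num_replicates):
        idx = list(range(n))
        rng.shuffle(idx)  # a fresh random permutation for every replicate
        n_train = round(sizes[0] * n)
        n_val = round(sizes[1] * n)
        train.append(idx[:n_train])
        val.append(idx[n_train:n_train + n_val])
        test.append(idx[n_train + n_val:])
    return train, val, test


train_idx, val_idx, test_idx = random_split_indices(100, num_replicates=3)
first_replicate_train = train_idx[0]  # the indices for replicate 0
```

Indexing with `[0]` selects the first replicate, which is the same pattern used later when the datasets are built.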

In [16]:
mols = [d.mol for d in all_data]  # RDKit Mol objects are used for structure-based splits
train_indices, val_indices, test_indices = data.make_split_indices(mols, "random", (0.8, 0.1, 0.1))  # unpack the tuple into three separate lists
train_data, val_data, test_data = data.split_data_by_indices(
    all_data, train_indices, val_indices, test_indices
)

The return type of make_split_indices has changed in v2.1 - see help(make_split_indices)


Chemprop's splitting function implements our preferred method of data splitting, which is random replication.
It's also possible to use your own custom cross-validation splitter, such as those implemented in scikit-learn, as long as you get the data into the same `tuple[list[list[int]], ...]` format, for example:

In [17]:
from sklearn.model_selection import KFold

k_splits = KFold(n_splits=5)
k_train_indices, k_val_indices, k_test_indices = [], [], []
for train_idx, test_idx in k_splits.split(mols):
    k_train_indices.append(train_idx)  # one index array per fold (replicate)
    k_val_indices.append([])  # no validation set in this example
    k_test_indices.append(test_idx)
k_train_data, _, k_test_data = data.split_data_by_indices(
    all_data, k_train_indices, None, k_test_indices
)

## Get MoleculeDataset
Recall that the data is in a list equal in length to the number of replicates, so we select the zero index of the list to get the first replicate.

In [18]:
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()

train_dset = data.MoleculeDataset(train_data[0], featurizer)

#! NO scaler - loss function takes the raw counts as input
#scaler = train_dset.normalize_targets()

val_dset = data.MoleculeDataset(val_data[0], featurizer)
#val_dset.normalize_targets(scaler)

test_dset = data.MoleculeDataset(test_data[0], featurizer)

## Get DataLoader

In [19]:
train_loader = data.build_dataloader(train_dset, num_workers=num_workers)
val_loader = data.build_dataloader(val_dset, num_workers=num_workers, shuffle=False)
test_loader = data.build_dataloader(test_dset, num_workers=num_workers, shuffle=False)

# Change Message-Passing Neural Network (MPNN) inputs here

## Message Passing
A `MessagePassing` module takes molecular graphs and uses message passing to learn node-level hidden representations.

Options are `mp = nn.BondMessagePassing()` or `mp = nn.AtomMessagePassing()`
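
As a rough intuition for what message passing does, the toy sketch below propagates scalar features along the edges of a small graph. It has no learned weights and works on atoms with undirected edges, so it is a deliberate simplification and not chemprop's `BondMessagePassing` or `AtomMessagePassing`; the function name and the update rule are assumptions chosen for illustration.

```python
def message_passing(features, edges, depth=3):
    """Toy atom-level message passing with scalar features and no learned weights."""
    neighbors = {i: [] for i in range(len(features))}
    for a, b in edges:
        neighbors[a].append(b)
        neighbors[b].append(a)

    hidden = list(features)
    for _ in range(depth):
        # Each node sums the hidden states of its neighbours (the "messages") ...
        messages = [sum(hidden[u] for u in neighbors[v]) for v in neighbors]
        # ... and combines them with its own input feature through a ReLU.
        hidden = [max(0.0, features[v] + messages[v]) for v in neighbors]
    return hidden


# Chain of three nodes: 0 - 1 - 2, with a signal only on node 0.
hidden = message_passing([1.0, 0.0, 0.0], edges=[(0, 1), (1, 2)], depth=2)
```

After two iterations node 2 carries a non-zero value even though it is two bonds away from node 0, which is the sense in which each hidden representation summarizes a growing neighbourhood.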

In [20]:
mp = nn.BondMessagePassing()

## Aggregation
An `Aggregation` is responsible for constructing a graph-level representation from the set of node-level representations after message passing.

Available options can be found in `nn.agg.AggregationRegistry`, including
- `agg = nn.MeanAggregation()`
- `agg = nn.SumAggregation()`
- `agg = nn.NormAggregation()`
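
The three options differ only in how the per-node vectors are collapsed. A minimal plain-Python sketch follows, assuming that norm aggregation divides the sum by a fixed constant; the default of 100 used here is an assumption, so check the chemprop documentation for the actual value.

```python
def aggregate(node_vectors, how="mean", norm=100.0):
    """Collapse per-node vectors into one graph-level vector."""
    dim = len(node_vectors[0])
    summed = [sum(vec[d] for vec in node_vectors) for d in range(dim)]
    if how == "sum":
        return summed
    if how == "mean":
        return [s / len(node_vectors) for s in summed]
    if how == "norm":
        # Sum divided by a fixed constant (assumed default) instead of the node count.
        return [s / norm for s in summed]
    raise ValueError(f"unknown aggregation: {how}")


nodes = [[1.0, 2.0], [3.0, 4.0]]  # two nodes with two hidden features each
graph_mean = aggregate(nodes, "mean")
graph_sum = aggregate(nodes, "sum")
graph_norm = aggregate(nodes, "norm")
```

Mean aggregation is independent of molecule size, while sum and norm aggregation grow with the number of atoms.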

In [21]:
agg = nn.MeanAggregation()


## Feed-Forward Network (FFN)

An `FFN` takes the aggregated representations and makes target predictions.

Available options can be found in `nn.PredictorRegistry`.

For regression:
- `ffn = nn.RegressionFFN()`
- `ffn = nn.MveFFN()`
- `ffn = nn.EvidentialFFN()`

For classification:
- `ffn = nn.BinaryClassificationFFN()`
- `ffn = nn.BinaryDirichletFFN()`
- `ffn = nn.MulticlassClassificationFFN()`
- `ffn = nn.MulticlassDirichletFFN()`

For spectral:
- `ffn = nn.SpectralFFN()` # will be available in future version

In [22]:
print(nn.PredictorRegistry)

ClassRegistry {
    'regression': <class 'chemprop.nn.predictors.RegressionFFN'>,
    'regression-mve': <class 'chemprop.nn.predictors.MveFFN'>,
    'regression-evidential': <class 'chemprop.nn.predictors.EvidentialFFN'>,
    'regression-quantile': <class 'chemprop.nn.predictors.QuantileFFN'>,
    'classification': <class 'chemprop.nn.predictors.BinaryClassificationFFN'>,
    'classification-dirichlet': <class 'chemprop.nn.predictors.BinaryDirichletFFN'>,
    'multiclass': <class 'chemprop.nn.predictors.MulticlassClassificationFFN'>,
    'multiclass-dirichlet': <class 'chemprop.nn.predictors.MulticlassDirichletFFN'>,
    'spectral': <class 'chemprop.nn.predictors.SpectralFFN'>
}


In [23]:
output_transform = torch.nn.Softplus()

In [24]:
criterion = nn.metrics.NLogProbEnrichment(
    n1=total_counts_pos,
    n2=total_counts_neg,
    method="sqrt",
    zscale=1.0,
    zinterval=5.0,
)
ffn = nn.predictors.RegressionFFN(criterion=criterion, output_transform=output_transform)

## Batch Norm
`Batch Norm` normalizes the outputs of the aggregation by re-centering and re-scaling.

Set whether to use batch norm:

In [25]:
batch_norm = True
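
The re-centering and re-scaling can be sketched for a single feature as follows. This is a simplified training-time view that assumes a scale of 1 and a shift of 0; a real `BatchNorm1d` layer also learns an affine scale and shift and tracks running statistics for inference.

```python
import math


def batch_norm(values, eps=1e-5):
    """Normalize one feature across a batch to roughly zero mean and unit variance."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)  # biased variance
    return [(v - mean) / math.sqrt(var + eps) for v in values]


normalized = batch_norm([2.0, 4.0, 6.0, 8.0])
```

The small `eps` term keeps the division stable when a feature is nearly constant across the batch.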

## Metrics
`Metrics` are the ways to evaluate the performance of model predictions.

Available options can be found in `nn.metrics.MetricRegistry`:

In [26]:
print(nn.metrics.MetricRegistry)

ClassRegistry {
    'mse': <class 'chemprop.nn.metrics.MSE'>,
    'mae': <class 'chemprop.nn.metrics.MAE'>,
    'rmse': <class 'chemprop.nn.metrics.RMSE'>,
    'bounded-mse': <class 'chemprop.nn.metrics.BoundedMSE'>,
    'bounded-mae': <class 'chemprop.nn.metrics.BoundedMAE'>,
    'bounded-rmse': <class 'chemprop.nn.metrics.BoundedRMSE'>,
    'r2': <class 'chemprop.nn.metrics.R2Score'>,
    'binary-mcc': <class 'chemprop.nn.metrics.BinaryMCCMetric'>,
    'multiclass-mcc': <class 'chemprop.nn.metrics.MulticlassMCCMetric'>,
    'roc': <class 'chemprop.nn.metrics.BinaryAUROC'>,
    'prc': <class 'chemprop.nn.metrics.BinaryAUPRC'>,
    'accuracy': <class 'chemprop.nn.metrics.BinaryAccuracy'>,
    'f1': <class 'chemprop.nn.metrics.BinaryF1Score'>
}


In [27]:
metric_list = [nn.metrics.NLogProbEnrichment(n1=total_counts_pos, n2=total_counts_neg)] # Only the first metric is used for training and early stopping

## Construct MPNN

In [28]:
mpnn = models.MPNN(mp, agg, ffn, batch_norm, metric_list)
mpnn

MPNN(
  (message_passing): BondMessagePassing(
    (W_i): Linear(in_features=86, out_features=300, bias=False)
    (W_h): Linear(in_features=300, out_features=300, bias=False)
    (W_o): Linear(in_features=372, out_features=300, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
    (tau): ReLU()
    (V_d_transform): Identity()
    (graph_transform): Identity()
  )
  (agg): MeanAggregation()
  (bn): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=300, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=1, bias=True)
      )
    )
    (criterion): NLogProbEnrichment(n1=643, n2=367, method='sqrt', zscale=1.0, zinterval=5.0)
    (output_transform): Softplus(beta=1.0, threshold=20.0)
  )
  (X_d_transform): Identity()
  (metrics): Modu

# Set up trainer

In [29]:
# Configure model checkpointing
checkpointing = ModelCheckpoint(
    "checkpoints",  # Directory where model checkpoints will be saved
    "best-{epoch}-{val_loss:.2f}",  # Filename format for checkpoints, including epoch and validation loss
    "val_loss",  # Metric used to select the best checkpoint (based on validation loss)
    mode="min",  # Save the checkpoint with the lowest validation loss (minimization objective)
    save_last=True,  # Always save the most recent checkpoint, even if it's not the best
)


trainer = pl.Trainer(
    logger=False,
    enable_checkpointing=True, # Use `True` if you want to save model checkpoints. The checkpoints will be saved in the `checkpoints` folder.
    enable_progress_bar=True,
    #gradient_clip_val=1.0,
    accelerator="auto",
    devices=1,
    max_epochs=20, # number of epochs to train for
    callbacks=[checkpointing], # Use the configured checkpoint callback
)

/compchem/arc/apps/anaconda/anaconda-2020.11_dkcn-plws-arc02/envs/chemprop_dev/lib/python3.13/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /compchem/arc/apps/anaconda/anaconda-2020.11_dkcn-pl ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
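
The filename template and `mode="min"` setting of the checkpoint callback can be illustrated with a small stand-alone mimic. The helpers below are hypothetical and only reproduce the `name=value` pattern visible in the checkpoint filenames logged by this notebook; they are not Lightning's implementation, and the validation losses other than the one for epoch 6 are made-up illustrative values.

```python
import re


def format_checkpoint_name(template: str, metrics: dict) -> str:
    """Mimic of the 'name=value' expansion seen in the checkpoint filenames."""
    # Turn "{val_loss:.2f}" into "val_loss={val_loss:.2f}" before formatting.
    expanded = re.sub(r"\{(\w+)", r"\1={\1", template)
    return expanded.format(**metrics) + ".ckpt"


def best_checkpoint(history: list) -> dict:
    """mode='min': pick the epoch with the lowest monitored val_loss."""
    return min(history, key=lambda m: m["val_loss"])


# Illustrative history; only the epoch 6 entry matches the notebook's log.
history = [
    {"epoch": 5, "val_loss": 1.31},
    {"epoch": 6, "val_loss": 1.05},
    {"epoch": 7, "val_loss": 1.12},
]
name = format_checkpoint_name("best-{epoch}-{val_loss:.2f}", best_checkpoint(history))
```

Because the best checkpoint is chosen by validation loss, the weights restored for testing can come from an earlier epoch than the last one trained.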


# Start training

In [30]:
trainer.fit(mpnn, train_loader, val_loader)

/compchem/arc/apps/anaconda/anaconda-2020.11_dkcn-plws-arc02/envs/chemprop_dev/lib/python3.13/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:658: Checkpoint directory /compchem/arc/users/dvik/repos/chemprop/examples/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Loading `train_dataloader` to estimate number of stepping batches.
/compchem/arc/apps/anaconda/anaconda-2020.11_dkcn-plws-arc02/envs/chemprop_dev/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.

  | Name            | Type               | Params | Mode 
---------------------------------------------------------------
0 | message_passing | BondMessagePassing | 227 K  | train
1 | agg             | MeanAggregation    | 0      | train
2 | bn      

Sanity Checking DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]

/compchem/arc/apps/anaconda/anaconda-2020.11_dkcn-plws-arc02/envs/chemprop_dev/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Epoch 19: 100%|██████████| 2/2 [00:00<00:00, 14.79it/s, train_loss_step=0.462, val_loss=1.120, train_loss_epoch=0.409]

`Trainer.fit` stopped: `max_epochs=20` reached.



# Test results

In [31]:
results = trainer.test(dataloaders=test_loader)

Restoring states from the checkpoint path at /compchem/arc/users/dvik/repos/chemprop/examples/checkpoints/best-epoch=6-val_loss=1.05.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Loaded model weights from the checkpoint at /compchem/arc/users/dvik/repos/chemprop/examples/checkpoints/best-epoch=6-val_loss=1.05.ckpt
/compchem/arc/apps/anaconda/anaconda-2020.11_dkcn-plws-arc02/envs/chemprop_dev/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 115.15it/s]
