# Training

# Import packages

In [1]:
import pandas as pd
from pathlib import Path
import numpy as np

from lightning import pytorch as pl

from chemprop import data, featurizers, models, nn
from chemprop.utils import make_mol

# Change data inputs here

In [2]:
# Parent of the current working directory; the example data is looked up beneath it.
chemprop_dir = Path.cwd().parent

# path to your data .csv file
input_path = chemprop_dir.joinpath("tests", "data", "regression", "mol", "mol.csv")

# number of workers for dataloader. 0 means using main process for data loading
num_workers = 0

# name of the column containing SMILES strings
smiles_column = "smiles"

# list of names of the columns containing targets
target_columns = ["lipo"]

## Load data

In [3]:
# Read the CSV at `input_path` into a DataFrame and show it as the cell output.
df_input = pd.read_csv(input_path)
df_input

Unnamed: 0,smiles,lipo
0,Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14,3.54
1,COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)...,-1.18
2,COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl,3.69
3,OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(C...,3.37
4,Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)N...,3.10
...,...,...
95,CC(C)N(CCCNC(=O)Nc1ccc(cc1)C(C)(C)C)C[C@H]2O[C...,2.20
96,CCN(CC)CCCCNc1ncc2CN(C(=O)N(Cc3cccc(NC(=O)C=C)...,2.04
97,CCSc1c(Cc2ccccc2C(F)(F)F)sc3N(CC(C)C)C(=O)N(C)...,4.49
98,COc1ccc(Cc2c(N)n[nH]c2N)cc1,0.20


## Get SMILES and targets

In [4]:
# Pull the SMILES column out of the input frame as a NumPy array.
smis = df_input.loc[:, smiles_column].values

# Parse every SMILES string into a molecule object.
# NOTE(review): the two positional flags are presumably `keep_h` and `add_h` — confirm.
mols = [make_mol(smi, False, False) for smi in smis]

# The targets are synthetic: one uniform random value per atom, stored as an
# (n_atoms, 1) array for each molecule. The columns named in `target_columns`
# are NOT read here.
# A seeded generator is used instead of the global `np.random` state so that the
# targets (and everything trained on them) are identical after Restart & Run All.
rng = np.random.default_rng(0)
ys = [rng.random((mol.GetNumAtoms(), 1)) for mol in mols]

In [5]:
# Preview the values read from the `smiles_column` column.
smis[:5] # show first 5 SMILES strings

array(['Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14',
       'COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)CCc3ccccc23',
       'COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl',
       'OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(Cl)sc4[nH]3',
       'Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)NCC#N)c1'],
      dtype=object)

In [6]:
# Each entry of `ys` is an (n_atoms, 1) array: one target value per atom of that molecule.
ys[:5] # show the per-atom target arrays of the first 5 molecules

[array([[0.68188908],
        [0.62999387],
        [0.96523715],
        [0.50685012],
        [0.33754227],
        [0.70763255],
        [0.55591384],
        [0.64739569],
        [0.28749773],
        [0.45436203],
        [0.12876126],
        [0.73623753],
        [0.00944813],
        [0.99915634],
        [0.82377346],
        [0.34107808],
        [0.29202741],
        [0.5674308 ],
        [0.45743398],
        [0.45504712],
        [0.05617857],
        [0.5222286 ],
        [0.85707514],
        [0.21466219]]),
 array([[8.08328255e-01],
        [2.09307610e-01],
        [9.45542778e-04],
        [4.90274432e-01],
        [7.50795213e-01],
        [6.71053773e-01],
        [4.47577353e-01],
        [6.61427405e-01],
        [2.43401908e-01],
        [4.23850060e-01],
        [5.55789454e-01],
        [7.12614913e-01],
        [7.55760409e-01],
        [9.50983625e-01],
        [9.26647048e-01],
        [1.83552608e-01],
        [3.64050285e-01],
        [8.18322612e-01],
  

## Get molecule datapoints

In [7]:
# Build one datapoint per molecule, pairing each SMILES string with its per-atom target array.
# NOTE(review): `from_smi` presumably re-parses the SMILES; the number of rows in `y` has to
# match the atom count of the molecule it builds — confirm it handles hydrogens like `make_mol` above.
all_data = [data.MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smis, ys)]

## Perform data splitting for training, validation, and testing

In [8]:
# available split types (the names listed in `data.SplitType`)
list(data.SplitType.keys())

['CV_NO_VAL',
 'CV',
 'SCAFFOLD_BALANCED',
 'RANDOM_WITH_REPEATED_SMILES',
 'RANDOM',
 'KENNARD_STONE',
 'KMEANS']

In [9]:
# Compute train/validation/test index sets with a "random" split in 0.8/0.1/0.1 proportions,
# then partition the datapoints accordingly.
mols = [d.mol for d in all_data]  # RDKit Mol objects are used for structure-based splits
train_indices, val_indices, test_indices = data.make_split_indices(mols, "random", (0.8, 0.1, 0.1))
train_data, val_data, test_data = data.split_data_by_indices(
    all_data, train_indices, val_indices, test_indices
)

## Get AtomDataset

In [10]:
# A single featurizer instance is shared by the training, validation and test datasets.
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()

train_dset = data.AtomDataset(train_data, featurizer)
# Normalize the training targets; the returned scaler is reused for the validation set
# and, in a later cell, to build the output transform.
scaler = train_dset.normalize_targets()

val_dset = data.AtomDataset(val_data, featurizer)
# Apply the scaler obtained from the training set to the validation targets.
val_dset.normalize_targets(scaler)

# The test targets are left unnormalized.
# NOTE(review): presumably fine because predictions are mapped back to the original
# scale by the predictor's output transform — confirm.
test_dset = data.AtomDataset(test_data, featurizer)


## Get DataLoader

In [11]:
# One loader per split, all using `num_workers` from the config cell.
# Shuffling is explicitly disabled for validation and test; the training loader keeps
# whatever `build_dataloader` does by default.
train_loader = data.build_dataloader(train_dset, num_workers=num_workers)
val_loader = data.build_dataloader(val_dset, num_workers=num_workers, shuffle=False)
test_loader = data.build_dataloader(test_dset, num_workers=num_workers, shuffle=False)

# Change Message-Passing Neural Network (MPNN) inputs here

## Message Passing
The message-passing module passes messages over the molecular graph to learn node-level hidden representations.

Options are `mp = nn.BondMessagePassing()` or `mp = nn.AtomMessagePassing()`

In [12]:
# Bond-based message passing with the library's default arguments;
# swap in `nn.AtomMessagePassing()` for the atom-based variant.
mp = nn.BondMessagePassing()

## Aggregation
An `Aggregation` is responsible for constructing a graph-level representation from the set of node-level representations after message passing.

Available options can be found in `nn.agg.AggregationRegistry`, including
- `agg = nn.MeanAggregation()`
- `agg = nn.SumAggregation()`
- `agg = nn.NormAggregation()`

In [13]:
# Print the registry that maps aggregation names to their classes.
print(nn.agg.AggregationRegistry)

ClassRegistry {
    'mean': <class 'chemprop.nn.agg.MeanAggregation'>,
    'sum': <class 'chemprop.nn.agg.SumAggregation'>,
    'norm': <class 'chemprop.nn.agg.NormAggregation'>
}


In [14]:
# Mean aggregation; the other registered choices are `nn.SumAggregation()` and `nn.NormAggregation()`.
agg = nn.MeanAggregation()

## Feed-Forward Network (FFN)

An `FFN` takes the aggregated representations and makes target predictions.

Available options can be found in `nn.PredictorRegistry`.

For regression:
- `ffn = nn.RegressionFFN()`
- `ffn = nn.MveFFN()`
- `ffn = nn.EvidentialFFN()`

For classification:
- `ffn = nn.BinaryClassificationFFN()`
- `ffn = nn.BinaryDirichletFFN()`
- `ffn = nn.MulticlassClassificationFFN()`
- `ffn = nn.MulticlassDirichletFFN()`

For spectral:
- `ffn = nn.SpectralFFN()` # will be available in a future version

In [15]:
# Print the registry that maps predictor names to their FFN classes.
print(nn.PredictorRegistry)

ClassRegistry {
    'regression': <class 'chemprop.nn.predictors.RegressionFFN'>,
    'regression-mve': <class 'chemprop.nn.predictors.MveFFN'>,
    'regression-evidential': <class 'chemprop.nn.predictors.EvidentialFFN'>,
    'classification': <class 'chemprop.nn.predictors.BinaryClassificationFFN'>,
    'classification-dirichlet': <class 'chemprop.nn.predictors.BinaryDirichletFFN'>,
    'multiclass': <class 'chemprop.nn.predictors.MulticlassClassificationFFN'>,
    'multiclass-dirichlet': <class 'chemprop.nn.predictors.MulticlassDirichletFFN'>,
    'spectral': <class 'chemprop.nn.predictors.SpectralFFN'>
}


In [16]:
# Built from the scaler returned by `train_dset.normalize_targets()`.
# NOTE(review): presumably maps model outputs back to the original (unnormalized) target scale — confirm.
output_transform = nn.UnscaleTransform.from_standard_scaler(scaler)

In [17]:
# Regression predictor head; the transform created in the previous cell is attached to its output.
ffn = nn.RegressionFFN(output_transform=output_transform)

## Batch Norm
A `Batch Norm` normalizes the outputs of the aggregation by re-centering and re-scaling.

Whether to use batch norm

In [18]:
# Passed to the model constructor below to switch batch normalization on.
batch_norm = True

## Metrics
`Metrics` are the ways to evaluate the performance of model predictions.

Available options can be found in `nn.metrics.MetricRegistry`:

In [19]:
# Print the registry that maps metric names to their classes.
print(nn.metrics.MetricRegistry)

ClassRegistry {
    'mae': <class 'chemprop.nn.metrics.MAEMetric'>,
    'mse': <class 'chemprop.nn.metrics.MSEMetric'>,
    'rmse': <class 'chemprop.nn.metrics.RMSEMetric'>,
    'bounded-mae': <class 'chemprop.nn.metrics.BoundedMAEMetric'>,
    'bounded-mse': <class 'chemprop.nn.metrics.BoundedMSEMetric'>,
    'bounded-rmse': <class 'chemprop.nn.metrics.BoundedRMSEMetric'>,
    'r2': <class 'chemprop.nn.metrics.R2Metric'>,
    'roc': <class 'chemprop.nn.metrics.BinaryAUROCMetric'>,
    'prc': <class 'chemprop.nn.metrics.BinaryAUPRCMetric'>,
    'accuracy': <class 'chemprop.nn.metrics.BinaryAccuracyMetric'>,
    'f1': <class 'chemprop.nn.metrics.BinaryF1Metric'>,
    'bce': <class 'chemprop.nn.metrics.BCEMetric'>,
    'ce': <class 'chemprop.nn.metrics.CrossEntropyMetric'>,
    'binary-mcc': <class 'chemprop.nn.metrics.BinaryMCCMetric'>,
    'multiclass-mcc': <class 'chemprop.nn.metrics.MulticlassMCCMetric'>,
    'sid': <class 'chemprop.nn.metrics.SIDMetric'>,
    'wasserstein': <class '

In [20]:
# NOTE(review): the predictor shown in the model printout carries its own `MSELoss` criterion,
# so the first metric presumably drives validation/early stopping rather than the training
# loss itself — confirm against the chemprop documentation.
metric_list = [nn.metrics.RMSEMetric(), nn.metrics.MAEMetric()] # Only the first metric is used for training and early stopping

## Construct the AtomMPNN

In [21]:
# Assemble the model from the pieces configured above: message passing, aggregation,
# predictor, batch-norm flag and metric list (passed positionally, in that order).
atom_mpnn = models.AtomMPNN(mp, agg, ffn, batch_norm, metric_list)

# Display the module structure as the cell output.
atom_mpnn

AtomMPNN(
  (message_passing): BondMessagePassing(
    (W_i): Linear(in_features=86, out_features=300, bias=False)
    (W_h): Linear(in_features=300, out_features=300, bias=False)
    (W_o): Linear(in_features=372, out_features=300, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
    (tau): ReLU()
    (V_d_transform): Identity()
    (graph_transform): Identity()
  )
  (agg): MeanAggregation()
  (bn): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=300, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=1, bias=True)
      )
    )
    (criterion): MSELoss(task_weights=[[1.0]])
    (output_transform): UnscaleTransform()
  )
  (X_d_transform): Identity()
)

# Set up trainer

In [22]:
# Configure the Lightning trainer: no logger, checkpointing and progress bar enabled,
# automatic accelerator selection on a single device, and a fixed epoch budget.
trainer = pl.Trainer(
    logger=False,
    enable_checkpointing=True, # Use `True` if you want to save model checkpoints. The checkpoints will be saved in the `checkpoints` folder.
    enable_progress_bar=True,
    accelerator="auto",
    devices=1,
    max_epochs=20, # number of epochs to train for
)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


# Start training

In [23]:
# Train on `train_loader`, validating on `val_loader`.
trainer.fit(atom_mpnn, train_loader, val_loader)

/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /Users/brianli/Documents/chemprop/examples/checkpoints exists and is not empty.
Loading `train_dataloader` to estimate number of stepping batches.
/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.

  | Name            | Type               | Params | Mode 
---------------------------------------------------------------
0 | message_passing | BondMessagePassing | 227 K  | train
1 | agg             | MeanAggregation    | 0      | train
2 | bn              | BatchNorm1d        | 600    | train
3 | predictor       | RegressionFFN      | 90.6 K | train
4 | X_d_transform   | Identity           | 0   

Sanity Checking DataLoader 0:   0%|                       | 0/1 [00:00<?, ?it/s]

/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
  M_all = torch.zeros(len(bmg.V), H.shape[1], dtype=H.dtype, device=H.device).scatter_reduce_(
  return squared_errors[mask].mean().sqrt()


Epoch 0: 100%|██████████████████| 2/2 [00:00<00:00,  3.40it/s, train_loss=0.971]
Validation: |                                             | 0/? [00:00<?, ?it/s][A
Validation:   0%|                                         | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                            | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0: 100%|████████████████████| 1/1 [00:00<00:00, 32.67it/s][A
Epoch 1: 100%|██| 2/2 [00:00<00:00,  9.16it/s, train_loss=0.997, val_loss=0.994][A
Validation: |                                             | 0/? [00:00<?, ?it/s][A
Validation:   0%|                                         | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                            | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0: 100%|████████████████████| 1/1 [00:00<00:00, 87.79it/s][A
Epoch 2: 100%|██| 2/2 [00:00<00:00,  9.36it/s, train_loss=0.956, val_loss=0.990][A
Validation: |                                             | 0/? [00:00<?, ?it/s

`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 19: 100%|█| 2/2 [00:00<00:00,  5.66it/s, train_loss=0.969, val_loss=1.000]


# Test results

In [25]:
# Evaluate the trained model on the held-out test loader; the return value is kept in `results`.
results = trainer.test(atom_mpnn, test_loader)

Testing DataLoader 0: 100%|███████████████████████| 1/1 [00:00<00:00, 57.42it/s]
