# Train BonDNet 

In this notebook, we show how to train the BonDNet graph neural network model for bond dissociation energy (BDE) prediction. We only show how to train on CPUs. See [train_bde_distributed.py](./) for a script for training on GPUs (a single GPU or distributed training on multiple GPUs). 

```python
import torch
from tqdm import tqdm

from bondnet.data.dataset import ReactionNetworkDatasetGraphs, train_validation_test_split
from bondnet.data.dataloader import DataLoaderReactionNetwork
from bondnet.utils import seed_torch, parse_settings
from bondnet.model.training_utils import (
    evaluate,
    train,
    load_model,
    evaluate_r2,
    get_grapher,
)

# Seed the random number generators so that the data split and training are reproducible.
seed_torch()
```

```python
print(torch.__version__)
```

```
1.12.1
```


## Dataset 

We work with a small dataset consisting of 200 BDEs for neutral and charged molecules. The dataset is specified in three files:
- `molecules.sdf` This file contains all the molecules (both reactants and products) in the bond dissociation reactions. The molecules are specified in SDF format. 
- `molecule_attributes.yaml` This file contains extra molecular attributes (charges here) for the molecules given in `molecules.sdf`. Some molecular attributes can also be inferred from a molecule's SDF block; these are overridden by the attributes specified in `molecule_attributes.yaml`.  
- `reactions.csv` This file lists the bond dissociation reactions formed by the molecules given in `molecules.sdf`. Each line lists the reactant, products, and BDE of a reaction. The reactant and products are specified by their indices in `molecules.sdf`. 

See [here](./examples/train) for the three files used in this notebook. 
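To make the `reactions.csv` description concrete, here is a minimal sketch of reading such a file with Python's `csv` module. The column layout is an assumption made for illustration only (a header row, then the reactant index, one or two product indices, and the BDE in the last column), and the values are made up; check the actual files before relying on this layout.

```python
import csv
import io

# Hypothetical layout (assumption): reactant index, one or two product
# indices, and the BDE in the last column. Indices refer to the position
# of a molecule in molecules.sdf. The values are made up.
EXAMPLE_CSV = """reactant,product1,product2,bde
0,1,2,4.21
3,4,,3.87
"""


def read_reactions(text):
    """Return a list of (reactant, products, bde) tuples."""
    reactions = []
    for row in csv.DictReader(io.StringIO(text)):
        reactant = int(row["reactant"])
        # Skip empty product columns (a reaction may have a single product).
        products = [int(row[key]) for key in ("product1", "product2") if row[key]]
        reactions.append((reactant, products, float(row["bde"])))
    return reactions


for reactant, products, bde in read_reactions(EXAMPLE_CSV):
    print(reactant, products, bde)
```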

#### Grapher 

BonDNet is a graph neural network model that takes atom features (e.g. atom type), bond features (e.g. whether a bond is in a ring), and global features (e.g. total charge) as input. The features of each molecule are extracted by a grapher; in this notebook it is created by `get_grapher()` from `bondnet.model.training_utils`.
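As a rough illustration of what a grapher produces, the sketch below featurizes a tiny hand-written molecule into per-atom, per-bond, and global feature lists, keyed by the same names as `feature_names` in this notebook. The feature choices (one-hot element, a ring flag, total charge) and the `featurize` function are hypothetical stand-ins; BonDNet's own grapher uses its own, richer featurizers.

```python
ELEMENTS = ["C", "H", "O"]  # hypothetical element vocabulary


def one_hot(value, choices):
    """Encode `value` as a one-hot list over `choices`."""
    return [1.0 if value == choice else 0.0 for choice in choices]


def featurize(atoms, bonds, charge):
    """Hypothetical featurizer.

    atoms:  list of element symbols
    bonds:  list of (i, j, in_ring) tuples
    charge: total molecular charge
    """
    atom_feats = [one_hot(atom, ELEMENTS) for atom in atoms]
    bond_feats = [[1.0 if in_ring else 0.0] for _, _, in_ring in bonds]
    global_feats = [float(charge)]
    return {"atom": atom_feats, "bond": bond_feats, "global": global_feats}


# Toy fragment: C bonded to O, O bonded to H (other H atoms omitted).
feats = featurize(["C", "O", "H"], [(0, 1, False), (1, 2, False)], charge=0)
print(feats["atom"])  # [[1.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 1.0, 0.0]]
```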

#### Read dataset 

Let's now read the dataset and featurize the molecules using the grapher returned by `get_grapher()`. The dataset is then split into a training set (60%), a validation set (20%), and a test set (20%), as set by `validation=0.2, test=0.2` below. We train the model on the training set, use the validation set to decide when to stop training, and report the error on the test set. 
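The split is done by `train_validation_test_split` with `validation=0.2, test=0.2`. The sketch below shows what a seeded random index split with those fractions looks like for the 84 reactions read in this run; it is an illustration only, and the BonDNet helper's shuffling and rounding may differ.

```python
import random


def split_indices(n, validation=0.2, test=0.2, seed=0):
    """Shuffle range(n) and cut it into train/validation/test index lists."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    n_val = int(n * validation)
    n_test = int(n * test)
    val_idx = idx[:n_val]
    test_idx = idx[n_val:n_val + n_test]
    train_idx = idx[n_val + n_test:]
    return train_idx, val_idx, test_idx


train_idx, val_idx, test_idx = split_indices(84)
print(len(train_idx), len(val_idx), len(test_idx))  # 52 16 16
```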

```python
settings_file = "./settings.txt"
# Names of the three feature types produced by the grapher.
feature_names = ["atom", "bond", "global"]
# Parse the training and model hyperparameters from the settings file.
dict_train = parse_settings(settings_file)
```

```
using the following settings:
----------------------------------------
Small Dataset?: True
restore: True
distributed: False
on gpu: False
filter species? [2, 3]
num gpu: 1
xyz feeaturizer: False
hyperparam save file: ./hyper.pkl
dataset state dict: home/santiagovargas/Documents/Dataset/mg/dataset_state_dict.pkl
model dir /home/santiagovargas/Documents/Dataset/mg/
classifier False
batch size: 256
epochs: 100
lr: 0.000100
weight decay: 0.000
early_stop: True
scheduler: True
transfer_epochs: 100
transfer: True
loss: False
categories: 3
embedding size: 24
fc layers: 2
fc hidden layer: [128, 64]
gated layers: 3
gated hidden layers: [64, 64, 64]
num lstm iters: 6
num lstm layer: 3
gated fc layers: 2
fc activation: ReLU
fc batch norm: 0
fc dropout: 0.00
gated activation: ReLU
gated dropout: 0.10
gated batch norm: True
gated graph norm: 0
gated resid: True
----------------------------------------
```


```python
# Use CUDA only if it is requested in the settings file and actually available.
if dict_train["on_gpu"] and torch.cuda.is_available():
    device = torch.device("cuda")
    dict_train["gpu"] = device
else:
    device = torch.device("cpu")
    dict_train["gpu"] = "cpu"

path_mg_data = "../dataset/mg_dataset/20220613_reaction_data.json"

# Read the reactions and featurize the molecules with the grapher.
dataset = ReactionNetworkDatasetGraphs(
    grapher=get_grapher(),
    file=path_mg_data,
    out_file="./",
    target="ts",
    classifier=False,
    classif_categories=5,
    debug=True,
    filter_species=dict_train["filter_species"],
    device=device,
)
```

```
using bond featurizer w/xyz coords
reading file from: ../dataset/mg_dataset/20220613_reaction_data.json
rxn raw len: 100
Program finished in 1.1245083029998568 seconds
.............failures.............
reactions len: 84
valid ind len: 84
bond break fail count: 		13
default fail count: 		3
sdf map fail count: 		0
product bond fail count: 	0
about to group and organize
number of grouped reactions: 84
features: 256
labels: 84
molecules: 256
constructing graphs & features....
number of graphs valid: 256
number of graphs: 256
```


```python
# Split into training (60%), validation (20%), and test (20%) sets.
trainset, valset, testset = train_validation_test_split(dataset, validation=0.2, test=0.2)
train_loader = DataLoaderReactionNetwork(trainset, batch_size=100, shuffle=True)
# The validation and test sets are each evaluated as a single batch.
val_loader = DataLoaderReactionNetwork(valset, batch_size=len(valset), shuffle=False)
test_loader = DataLoaderReactionNetwork(testset, batch_size=len(testset), shuffle=False)
```

## Model 

We create the BonDNet model by instantiating the `GatedGCNReactionNetwork` class and providing the parameters defining the model structure. 
- `embedding_size` The common size to which the atom, bond, and global features are embedded, so that all three have the same length.
- `gated_num_layers` Number of graph-to-graph modules used to learn the molecular representation. 
- `gated_hidden_size` Hidden layer sizes in the graph-to-graph modules. 
- `gated_activation` Activation function applied after the hidden layers in the graph-to-graph modules. 
- `fc_num_layers` Number of hidden layers in the fully connected network that maps the reaction feature to the BDE. The reaction feature is obtained as the difference between the features of the products and the reactant. 
- `fc_hidden_size` Sizes of the hidden layers. 
- `fc_activation` Activation function applied after the hidden layers. 

There are other arguments (e.g. residual connections, dropout ratio, batch norm) that can be specified to fine-tune the model. See the documentation of `GatedGCNReactionNetwork` for more information.  
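The reaction feature described for `fc_num_layers` can be sketched in a few lines: subtract the reactant's feature vector from the sum of the product feature vectors, then pass the result through a fully connected layer with ReLU and a scalar output. The numbers, the toy layer sizes, and summing over multiple products are illustrative assumptions; in BonDNet the molecule features come from the gated graph layers and the weights are learned.

```python
def relu(v):
    return [max(0.0, x) for x in v]


def dense(v, weights, bias):
    """Compute W v + b, with W given as a list of rows."""
    return [sum(w * x for w, x in zip(row, v)) + b for row, b in zip(weights, bias)]


def reaction_feature(reactant, products):
    """Sum of product features minus reactant features (assumed combination)."""
    summed = [sum(vals) for vals in zip(*products)]
    return [p - r for p, r in zip(summed, reactant)]


# Toy 3-dimensional molecule features (made-up numbers).
reactant = [1.0, 2.0, 0.5]
products = [[0.5, 1.0, 0.0], [1.0, 0.5, 1.0]]
x = reaction_feature(reactant, products)  # [0.5, -0.5, 0.5]

# One hidden layer of size 2, then a scalar output (toy weights).
h = relu(dense(x, [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]], [0.0, 0.0]))  # [1.0, 0.0]
y = dense(h, [[2.0, 1.0]], [0.1])
print(x, h, y)
```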

```python
from bondnet.model.gated_reaction_network_graph import GatedGCNReactionNetwork

model = GatedGCNReactionNetwork(
    in_feats=dataset.feature_size,  # input feature sizes reported by the dataset
    embedding_size=8,
    gated_num_layers=2,
    gated_hidden_size=[32, 32],
    gated_activation="ReLU",
    fc_num_layers=2,
    fc_hidden_size=[64, 32],
    fc_activation="ReLU",
)
```

## Train the model 

The training loop relies on helper functions imported from `bondnet.model.training_utils`: `train`, `evaluate`, and `evaluate_r2`. 

The `train` function optimizes the model parameters for one epoch. Note that the target BDEs are centered and then normalized by their standard deviation (done in the `ReactionNetworkDataset`). To report the mean absolute error in the original units, the error therefore has to be scaled back by the standard deviation. This is achieved by the `WeightedL1Loss` function passed as `metric_fn`.   
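The rescaling works because the mean cancels in the difference between prediction and target: if $\hat y = \mu + \sigma \hat z$ and $y = \mu + \sigma z$, then $|\hat y - y| = \sigma |\hat z - z|$, so the MAE in original units is the normalized MAE times $\sigma$. A quick numerical check with made-up values (a sketch of the arithmetic, not BonDNet's `WeightedL1Loss`):

```python
from statistics import mean, pstdev

targets = [4.0, 3.2, 5.1, 2.7]  # made-up BDEs in original units
preds = [3.8, 3.5, 4.9, 3.0]    # made-up predictions in original units

# Center and normalize targets and predictions with the target statistics.
mu, sigma = mean(targets), pstdev(targets)
z_targets = [(y - mu) / sigma for y in targets]
z_preds = [(p - mu) / sigma for p in preds]

mae_normalized = mean(abs(a - b) for a, b in zip(z_preds, z_targets))
mae_original = mean(abs(a - b) for a, b in zip(preds, targets))

# Multiplying the normalized MAE by sigma recovers the MAE in original units.
print(round(mae_normalized * sigma, 6), round(mae_original, 6))
```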

Now, we have all the ingredients to train the model. 

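We optimize the model parameters by minimizing a mean squared error loss function using the `Adam` optimizer. `load_model` builds the model and two optimizers, `optimizer_transfer` and `optimizer`, from the settings parsed above (learning rate `0.0001`). Training runs in two stages: `transfer_epochs` (100) epochs with `optimizer_transfer`, then `epochs` (100) epochs with `optimizer`. After each epoch of the second stage we print the training loss, the training and validation errors, and the validation and training R². In the run shown below, the second stage was interrupted manually after epoch 8. 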
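The settings enable early stopping (`early_stop: True`), while the training cells only print per-epoch metrics. The sketch below shows the usual bookkeeping: keep the epoch with the lowest validation error and stop once it has not improved for `patience` epochs. `select_best_epoch` and `patience` are hypothetical names for illustration, not BonDNet's `EarlyStopping` class. The input treats the fourth column of the training log at the end of this notebook as the validation error, rounded to four decimals. In a real loop you would also save the weights (e.g. `torch.save(model.state_dict(), path)`) whenever the best score improves.

```python
def select_best_epoch(val_errors, patience=3):
    """Return (best_epoch, stop_epoch) for a sequence of validation errors.

    Training stops after `patience` consecutive epochs without improvement.
    If that never happens, stop_epoch is the last epoch.
    """
    best_epoch, best_error = 0, float("inf")
    for epoch, error in enumerate(val_errors):
        if error < best_error:
            best_epoch, best_error = epoch, error
            # A real training loop would checkpoint the model here.
        elif epoch - best_epoch >= patience:
            return best_epoch, epoch
    return best_epoch, len(val_errors) - 1


# Validation errors from the training log, epochs 0 to 8.
val_errors = [0.6268, 0.6558, 0.6667, 0.6604, 0.6487, 0.6392, 0.6342, 0.6330, 0.6341]
print(select_best_epoch(val_errors, patience=3))  # (0, 3)
```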

```python
# Add the input feature sizes to the settings, then build a fresh model and
# its two optimizers from them (this replaces the `model` created above).
dict_train["in_feats"] = dataset.feature_size
model, optimizer, optimizer_transfer = load_model(dict_train)

# Stage 1: transfer stage, trained with optimizer_transfer.
for epoch in tqdm(range(dict_train["transfer_epochs"])):
    loss_transfer, train_acc_transfer = train(
        model, feature_names, train_loader, optimizer_transfer, device=dict_train["gpu"]
    )
    val_acc_transfer = evaluate(model, feature_names, val_loader, device=dict_train["gpu"])
    train_r2 = evaluate_r2(model, feature_names, train_loader, device=dict_train["gpu"])
```

```
100%|██████████| 100/100 [01:19<00:00,  1.25it/s]
```


```python
# Stage 2: main training stage, trained with optimizer.
# Printed columns: epoch, training loss, training error, validation error,
# validation R^2, training R^2 (errors as returned by train() / evaluate()).
for epoch in range(dict_train["epochs"]):
    loss, train_acc = train(model, feature_names, train_loader, optimizer, device=dict_train["gpu"])
    # Evaluate on the validation set.
    val_acc = evaluate(model, feature_names, val_loader, device=dict_train["gpu"])
    val_r2 = evaluate_r2(model, feature_names, val_loader, device=dict_train["gpu"])
    train_r2 = evaluate_r2(model, feature_names, train_loader, device=dict_train["gpu"])

    print(
        "{:5d}   {:12.6e}   {:12.2e}   {:12.6e}   {:.2f}   {:.2f}".format(
            epoch, loss, train_acc, val_acc, val_r2, train_r2
        )
    )
```

```
    0   2.699188e+00       1.22e-01   6.267557e-01   -0.34   0.89
    1   7.060315e+00       1.65e-01   6.558151e-01   -0.43   0.96
    2   2.651794e+00       1.34e-01   6.666546e-01   -0.46   0.95
    3   2.853068e+00       1.42e-01   6.604475e-01   -0.44   0.95
    4   2.967622e+00       1.38e-01   6.486682e-01   -0.40   0.95
    5   2.870139e+00       1.37e-01   6.392334e-01   -0.37   0.95
    6   2.649525e+00       1.35e-01   6.342085e-01   -0.35   0.96
    7   2.374305e+00       1.25e-01   6.330262e-01   -0.35   0.96
    8   2.184971e+00       1.12e-01   6.341418e-01   -0.35   0.97

KeyboardInterrupt: 
```
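In the log, the last two columns are the validation and training R². A negative value means the predictions are worse than always predicting the mean of the targets; a training R² of about 0.95 alongside a negative validation R² suggests overfitting on this small dataset (84 reactions). For reference, a minimal R² computation using the standard definition (the same one `sklearn.metrics.r2_score` uses for a single target), with made-up values:

```python
def r2(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean_y = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean_y) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot


y_true = [1.0, 2.0, 3.0]
print(r2(y_true, [1.0, 2.0, 3.0]))  # 1.0 (perfect predictions)
print(r2(y_true, [2.0, 2.0, 2.0]))  # 0.0 (always predicting the mean)
print(r2(y_true, [3.0, 2.0, 1.0]))  # -3.0 (worse than the mean)
```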