# Training an ML Model using the serenityff-charge package

## Import the necessary packages and set file paths

The [Trainer](../gnn/training/trainer.py) class facilitates the training of an ML model, and [ChargeCorrectedNodeWiseAttentiveFP](../gnn/utils/model.py) is the model that was trained in this work.

```python
from shutil import rmtree

import pandas as pd
import torch
from rdkit import Chem

from serenityff.charge.gnn.training.trainer import Trainer, ChargeCorrectedNodeWiseAttentiveFP

sdf_file = "../data/example.sdf"  # molecules and their reference charges
pt_file = "../data/example_graphs.pt"  # previously generated molecular graphs
state_dict_path = "../data/example_state_dict.pt"  # state dict of a trained model
model_path = "../data/example_model.pt"  # complete saved model
```

## Instantiate the trainer

You can choose whether to train the model on a CUDA-enabled GPU or on the CPU by passing `device="cuda"` or `device="cpu"` to the `Trainer`.
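If the same notebook should run on machines with and without a GPU, a common PyTorch idiom is to pick the device at runtime. This is a minimal sketch; only the `device` keyword is taken from the text above.

```python
import torch

# Use the GPU if PyTorch can see a CUDA device, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Training on: {device}")

# trainer = Trainer(device=device)
```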

By default, the training loss is computed with `torch.nn.functional.mse_loss`. To use a different loss, pass any other callable via the `loss_function` keyword argument.
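As a sketch of a custom loss, the function below wraps `torch.nn.functional.smooth_l1_loss`, which is less sensitive to outliers than the mean squared error. It assumes that the trainer calls the loss as `loss_function(prediction, target)` and expects a scalar tensor back, mirroring the signature of `mse_loss`; the name `smooth_l1_charge_loss` is hypothetical.

```python
import torch
import torch.nn.functional as F


def smooth_l1_charge_loss(prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Hypothetical drop-in replacement for F.mse_loss with the same (input, target) signature."""
    # Quadratic for small errors (|error| < 1), linear for large ones
    return F.smooth_l1_loss(prediction, target)


prediction = torch.tensor([0.0, 2.0])
target = torch.tensor([0.0, 0.0])
print(smooth_l1_charge_loss(prediction, target))  # mean of [0.0, 1.5] -> tensor(0.7500)

# Assumed usage, based on the keyword argument described above:
# trainer = Trainer(device="cpu", loss_function=smooth_l1_charge_loss)
```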

```python
# Train on the CPU; use device="cuda" to train on a GPU instead
trainer = Trainer(device="cpu")
```

## Initialize the model and optimizer and set the output path

### Load existing model

You can load an existing `ChargeCorrectedNodeWiseAttentiveFP` either from a saved model or from the saved state dict of a previously trained model.

This only works if the saved model or state dict comes from a `ChargeCorrectedNodeWiseAttentiveFP`.

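```python
# Uncomment one of the following lines to load a previously trained model
# trainer.model = state_dict_path  # from a saved state dict
# trainer.model = model_path  # from a complete saved model
```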
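At least for state dicts, the restriction to `ChargeCorrectedNodeWiseAttentiveFP` follows from how PyTorch works: a state dict only stores parameter tensors keyed by layer name, so it can only be loaded into a model with the same architecture. The generic round trip below uses a small `torch.nn.Linear` as a stand-in for the real model; it illustrates plain PyTorch and is not the trainer's own loading code.

```python
import os
import tempfile

import torch
from torch import nn

# Toy stand-in for a trained ChargeCorrectedNodeWiseAttentiveFP
trained = nn.Linear(4, 1)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "state_dict.pt")
    # Save only the parameters, not the model class itself
    torch.save(trained.state_dict(), path)

    # Loading requires a freshly built model with the same architecture
    restored = nn.Linear(4, 1)
    restored.load_state_dict(torch.load(path))

# A model with a different architecture rejects the state dict with a RuntimeError:
# nn.Linear(8, 1).load_state_dict(trained.state_dict())
```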

### Initialize a new model

Alternatively, you can start from a new, untrained model.

In either case, the trainer also needs an optimizer and a `save_prefix`, a path prefix that determines where the trained model and the loss files are saved.

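```python
# Create a new, untrained model
trainer.model = ChargeCorrectedNodeWiseAttentiveFP()
# The optimizer is built from the parameters of the model assigned above
trainer.optimizer = torch.optim.Adam(trainer.model.parameters(), lr=1e-5)
# Prefix for the output files (trained model and loss files)
trainer.save_prefix = "./training/example"
```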
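The order of these assignments matters: a PyTorch optimizer keeps references to the exact parameter tensors it was created with. If you later replace `trainer.model`, you therefore need a new optimizer as well, unless the trainer rebuilds it for you (the text above does not indicate that it does). The sketch below demonstrates this with toy `torch.nn.Linear` models rather than the real model; `tracked_parameter_ids` is a hypothetical helper.

```python
import torch
from torch import nn


def tracked_parameter_ids(optimizer: torch.optim.Optimizer) -> set:
    """Hypothetical helper: ids of all parameter tensors an optimizer will update."""
    return {id(p) for group in optimizer.param_groups for p in group["params"]}


old_model = nn.Linear(3, 1)
optimizer = torch.optim.Adam(old_model.parameters(), lr=1e-5)
new_model = nn.Linear(3, 1)

print(tracked_parameter_ids(optimizer) == {id(p) for p in old_model.parameters()})  # True
print(tracked_parameter_ids(optimizer) == {id(p) for p in new_model.parameters()})  # False

# Fix: build a fresh optimizer for the model that will actually be trained
optimizer = torch.optim.Adam(new_model.parameters(), lr=1e-5)
```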

## Load or generate molecular graphs

There are two ways to obtain molecular graphs: either load previously generated graphs with `Trainer.load_graphs_from_pt()`, or generate them from an .sdf file containing the training molecules and their charges. See [prep_sdf_input.ipynb](prep_sdf_input.ipynb) for how to prepare such .sdf files. Afterwards, `prepare_training_data()` splits the graphs into training and evaluation sets; `train_ratio=0.8` assigns 80% of the graphs to the training set.

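```python
# Option 1: load previously generated graphs
# trainer.load_graphs_from_pt(pt_file=pt_file)
# Option 2: generate the graphs from an .sdf file
trainer.gen_graphs_from_sdf(sdf_file=sdf_file)
# Use 80% of the graphs for training and the rest for evaluation
trainer.prepare_training_data(train_ratio=0.8)
```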
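Conceptually, `train_ratio=0.8` corresponds to a random 80/20 split of the molecular graphs. The sketch below shows such a split with `torch.utils.data.random_split`, using integers as stand-ins for graphs; whether `prepare_training_data()` shuffles or seeds its split this way is an assumption. Splitting whole molecules rather than individual atoms keeps all atoms of a molecule in the same set.

```python
import torch
from torch.utils.data import random_split

graphs = list(range(10))  # stand-ins for 10 molecular graphs
train_ratio = 0.8

n_train = int(train_ratio * len(graphs))
train_set, eval_set = random_split(
    graphs,
    [n_train, len(graphs) - n_train],
    generator=torch.Generator().manual_seed(42),  # fixed seed for a reproducible split
)
print(len(train_set), len(eval_set))  # 8 2
```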

## Train the model

To train the model, call `train_model()`, specifying the number of epochs and, optionally, the batch size (default: 64).

The function returns the training and evaluation losses. After training, it also saves these losses to separate files, together with the model's state dict, using the path given by `save_prefix`.

```python
epochs = 30
# Returns the training and evaluation losses
train_loss, eval_loss = trainer.train_model(epochs=epochs)
```
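For orientation, the loop below sketches what one epoch-based training run with an MSE loss, the Adam optimizer and a batch size of 64 typically looks like in plain PyTorch, recording a training and an evaluation loss per epoch. It runs on toy regression data with a `torch.nn.Linear` model and is an assumption about the general shape of `train_model()`, not serenityff-charge's actual implementation.

```python
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Toy regression data: 200 training and 56 evaluation samples
x = torch.randn(256, 4)
y = x @ torch.tensor([[1.0], [-2.0], [0.5], [0.0]])
train_loader = DataLoader(TensorDataset(x[:200], y[:200]), batch_size=64, shuffle=True)
eval_x, eval_y = x[200:], y[200:]

model = nn.Linear(4, 1)
# Larger learning rate than in the notebook so that the toy model converges quickly
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

train_losses, eval_losses = [], []
for epoch in range(30):
    model.train()
    running = 0.0
    for xb, yb in train_loader:
        optimizer.zero_grad()
        loss = F.mse_loss(model(xb), yb)
        loss.backward()
        optimizer.step()
        running += loss.item() * len(xb)  # weight by batch size (last batch is smaller)
    train_losses.append(running / len(train_loader.dataset))

    # Evaluate without tracking gradients
    model.eval()
    with torch.no_grad():
        eval_losses.append(F.mse_loss(model(eval_x), eval_y).item())

print(f"eval loss: epoch 1 {eval_losses[0]:.3f}, epoch 30 {eval_losses[-1]:.3f}")
```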

## Predict charges for a test molecule

To predict charges for known or unknown molecules, use the `predict()` function. It accepts RDKit molecules or molecular graphs created as shown above, as well as sequences of either.

```python
mol = Chem.MolFromSmiles("c1ccccc1")  # benzene
mol = Chem.AddHs(mol)  # add explicit hydrogens so that they receive charges too
trainer.predict(mol)
```

```
[[[-0.0001327507197856903],
  [-0.0001327507197856903],
  [-0.0001327507197856903],
  [-0.0001327507197856903],
  [-0.0001327507197856903],
  [-0.0001327507197856903],
  [0.00013274699449539185],
  [0.00013274699449539185],
  [0.00013274699449539185],
  [0.00013274699449539185],
  [0.00013274699449539185],
  [0.00013274699449539185]]]
```
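The output holds one list per input molecule and one single-element list per atom, presumably in RDKit atom order: the six aromatic carbons followed by the six hydrogens appended by `Chem.AddHs`. The values are tiny, so this briefly trained example model should not be expected to give physically meaningful charges. Assuming the charge correction in `ChargeCorrectedNodeWiseAttentiveFP` is meant to make the atomic charges add up to the molecule's net charge (zero for benzene), a quick sanity check on the printed values could look like this; `net_charges` is a hypothetical helper.

```python
# Prediction copied from the cell above: one list per molecule,
# one single-element list per atom (6 carbons, then 6 hydrogens)
prediction = [
    [[-0.0001327507197856903]] * 6 + [[0.00013274699449539185]] * 6,
]


def net_charges(prediction):
    """Hypothetical helper: sum the per-atom charges of each predicted molecule."""
    return [sum(atom[0] for atom in molecule) for molecule in prediction]


for total in net_charges(prediction):
    # Prints a value on the order of 1e-8, i.e. effectively zero
    print(f"net charge: {total:.1e}")
```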

Finally, remove all files generated by this example notebook.

```python
rmtree("./training/")
```