In [None]:
import torch
import numpy as np
import pandas as pd
import os

from data_factory import *
from model_factory import *

In [None]:
## Train a model that predicts one solubility-related property of molecules,
## as described in the blog post. `target_name` selects which property is
## learned; it is the only variable in this notebook that needs adjusting.
##
## The five possible values for target_name:
##   "mol_weight", "polar_area", "num_H_donors", "num_rings", "num_rot_bonds"
target_name = "num_rot_bonds"


In [None]:
# Create instance of model for predicting the solubility-related property
# selected by target_name.
# These models are typically denoted "model_ds" to distinguish them from the main "model"
# for predicting log solubility, which was pretrained in another notebook.
model_ds = ModelDS(target_name)

# Create train and validation datasets
train_dataset = DelaneyDataset(mode = "train")
valid_dataset = DelaneyDataset(mode = "valid")

# Reference torch.nn explicitly: this notebook never imports `nn` itself, so the
# bare name only resolved if one of the star imports above happened to re-export it.
# `criterion` is handed to train_ds and `metric` to valid_ds in the training cell.
criterion = torch.nn.MSELoss()
metric = torch.nn.MSELoss()

# Only the parameters of model_ds.model are handed to the optimizer.
optimizer = torch.optim.Adam(params = model_ds.model.parameters(), lr = 0.001)
# Multiply the learning rate by 0.5 once the monitored value has stopped improving
# for more than 10 consecutive scheduler steps.
# NOTE(review): `verbose` is deprecated in recent PyTorch releases — confirm the
# installed version still accepts it.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor = 0.5, patience = 10, verbose = True)

# Create instance of the main solubility model.
# NOTE(review): .cuda() presumably moves the model to the GPU, so this cell needs a
# CUDA-capable device — confirm before running on a CPU-only machine.
model = ChemTransformer().cuda()
model_factory = ModelFactory(model)
# Load saved weights for the main model
model_factory.load()

In [None]:
# Fit model_ds for a fixed number of epochs.
epochs = 400

for e in range(epochs):
    # Training pass; the loss it returns is what drives the LR scheduler below.
    train_loss = train_ds(
        model_ds,
        model_factory,
        train_dataset,
        criterion,
        optimizer,
    )
    # Validation pass on the validation dataset, scored with `metric`.
    valid_ds(
        model_ds,
        model_factory,
        valid_dataset,
        metric,
    )

    # Report the most recent losses for this epoch.
    model_ds.model_factory.print_last_loss(epoch = e)

    # Keep only the best result, selected on the "loss_valid" entry
    # (a single validation dataset, not k-fold cross-validation).
    model_ds.model_factory.save_best("loss_valid")

    # NOTE(review): the plateau scheduler is stepped on the *training* loss, while
    # checkpoint selection above uses "loss_valid" — confirm this is intended.
    scheduler.step(train_loss)