<a href="https://colab.research.google.com/github/pranay8297/deep-learning-projects/blob/master/ChemBERTA_Arrakis_Implementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from ipdb import set_trace as st
import sklearn
from simpletransformers.classification import ClassificationModel, ClassificationArgs
from torchmetrics import MeanAbsolutePercentageError
from ipdb import set_trace as st
import wandb
import os

import numpy as np
import pandas as pd
import json

from typing import List
from deepchem.molnet import load_bbbp, load_clearance, load_clintox, load_delaney, load_hiv, load_qm7, load_tox21
from rdkit import Chem
from transformers import RobertaTokenizerFast
import torch
import torch.nn as nn

from transformers import PreTrainedModel, RobertaModel
import matplotlib.pyplot as plt
import seaborn as sns
import torch.nn.functional as F
from torch.optim import Adam, AdamW
from tqdm import tqdm

from torch.utils.data import DataLoader
from pytorch_metric_learning import losses
from sklearn.metrics import roc_auc_score
from torch.utils.data import random_split
from transformers import RobertaConfig, RobertaTokenizerFast, RobertaForMaskedLM
from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel
from transformers.file_utils import ModelOutput
from dataclasses import dataclass
from torch.nn import CrossEntropyLoss, MSELoss
from typing import Dict, List, Optional, Tuple
from transformers import DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments
from nlp import load_dataset

from transformers.data.data_collator import InputDataClass
from transformers.tokenization_utils_base import BatchEncoding
from transformers.trainer_callback import EarlyStoppingCallback

from  torch.utils.data import random_split
from scipy.stats import pearsonr, spearmanr, kendalltau
from bertviz import head_view, model_view
# The Hugging Face tokenizers library reads this setting from the environment,
# not from a Python variable, so export it as well (presumably meant to silence
# the fast-tokenizer parallelism/fork warning).
TOKENIZERS_PARALLELISM = False
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [None]:
# Pretrained SMILES tokenizer shared by the dataset and inference helpers below.
smiles_tokenizer = RobertaTokenizerFast.from_pretrained(
    "seyonec/SMILES_tokenized_PubChem_shard00_160k",
    max_len=512,
)

class MyCustomException(Exception):
    """Project-specific exception type (not raised anywhere in this notebook section)."""

def multitask_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Tensor]:
    """Collate a list of per-example dicts into one batch of tensors.

    The ``"label"`` key is renamed to ``"labels"``. Tensor values are stacked
    along a new batch dimension; other non-``None``, non-string values are
    converted with ``torch.tensor``. The first example decides which keys are
    collated.

    Args:
        features: dict-like examples (plain objects are converted with ``vars``).

    Returns:
        Mapping from key to batched tensor (empty when ``features`` is empty).
    """
    if not features:
        return {}
    if not isinstance(features[0], (dict, BatchEncoding)):
        features = [vars(f) for f in features]

    first = features[0]
    batch = {}

    label = first.get("label")
    if label is not None:
        if isinstance(label, torch.Tensor):
            batch["labels"] = torch.stack([f["label"] for f in features])
        else:
            # Plain Python labels (floats, ints, lists) cannot be torch.stack-ed.
            batch["labels"] = torch.tensor([f["label"] for f in features])

    # All other keys: skip the label, None values and raw strings (e.g. SMILES text).
    for k, v in first.items():
        if k != "label" and v is not None and not isinstance(v, str):
            if isinstance(v, torch.Tensor):
                batch[k] = torch.stack([f[k] for f in features])
            else:
                batch[k] = torch.tensor([f[k] for f in features])

    return batch

def preprocess(line, tokenizer, block_size, text_name='SMILES', label_names=None):
    """Tokenize one dataset row and attach its labels as floats.

    Args:
        line: mapping holding the SMILES string under ``text_name`` and one
            value per entry of ``label_names``.
        tokenizer: callable tokenizer; called with special tokens, truncation
            and ``"max_length"`` padding to ``block_size``.
        block_size: padded / truncated sequence length.
        text_name: key of the SMILES string in ``line``.
        label_names: keys of the label values; ``None`` means no labels.

    Returns:
        dict mapping every tokenizer output key, plus ``"label"``, to a tensor.
    """
    if label_names is None:
        label_names = []

    smiles = line[text_name]
    labels = [float(line[name]) for name in label_names]

    batch_encoding = tokenizer(
        smiles,
        add_special_tokens=True,
        truncation=True,
        padding="max_length",
        max_length=block_size,
    )
    batch_encoding["label"] = labels
    return {k: torch.tensor(v) for k, v in batch_encoding.items()}

def get_data_files(train_path):
    """Resolve ``train_path`` to the data file(s) that should be loaded.

    Args:
        train_path: a file, or a directory containing data files.

    Returns:
        For a directory: the regular files inside it, sorted by path. They are
        sorted because ``os.listdir`` order is arbitrary, and sorting makes
        dataset row order reproducible. For a file: ``train_path`` itself.

    Raises:
        ValueError: if ``train_path`` is neither a file nor a directory.
    """
    if os.path.isdir(train_path):
        return sorted(
            os.path.join(train_path, name)
            for name in os.listdir(train_path)
            if os.path.isfile(os.path.join(train_path, name))
        )
    if os.path.isfile(train_path):
        return train_path

    raise ValueError("Please pass in a proper train path")

class RegressionDataset(torch.utils.data.Dataset):
    """CSV-backed dataset that tokenizes each SMILES row when it is accessed.

    Args:
        tokenizer: tokenizer handed to ``preprocess``.
        file_path: a CSV file, or a directory of CSV files.
        block_size: padded / truncated token length.
        x_col: name of the SMILES column; defaults to the last column.
        y_col: label column names; ``None`` or empty means every column
            between the first and the last.
    """

    def __init__(self, tokenizer, file_path: str, block_size = 515, x_col = None, y_col = None):
        super().__init__()
        print("init dataset")
        self.tokenizer = tokenizer
        self.file_path = file_path
        self.block_size = block_size

        data_files = get_data_files(file_path)
        self.dataset = load_dataset("csv", data_files=data_files)["train"]
        dataset_columns = list(self.dataset.features.keys())
        self.smiles_column = x_col if x_col else dataset_columns[-1]
        # Copy so later mutation of the caller's list cannot change this dataset.
        self.label_columns = list(y_col) if y_col else dataset_columns[1:-1]
        self.num_labels = len(self.label_columns)

        print("Loaded Dataset")
        self.len = len(self.dataset)
        print("Number of lines: " + str(self.len))
        print("Block size: " + str(self.block_size))

    def __len__(self):
        return self.len

    def __getitem__(self, i):
        # Tokenization happens per access; nothing is cached.
        return preprocess(self.dataset[i], self.tokenizer, self.block_size, self.smiles_column, self.label_columns)

class RobertaForRegression(RobertaPreTrainedModel):
    """RoBERTa encoder followed by ``num_outputs`` independent prediction heads.

    Besides the standard RoBERTa fields, ``config`` must carry the custom
    attributes ``num_outputs`` (number of heads; falsy means 1) and
    ``is_classification``, as ``create_model`` sets them.
    """
    _keys_to_ignore_on_load_missing = ["position_ids"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.num_outputs = config.num_outputs if config.num_outputs else 1

        self.roberta = RobertaModel(config, add_pooling_layer=False)
        # One head per target.
        self.regression_heads = nn.ModuleList([RobertaRegressionHead(config) for _ in range(self.num_outputs)])
        self.loss_fct = MSELoss()
        self.init_weights()

        self.is_classification = bool(config.is_classification)

        # Optional label standardization, disabled by default. To enable it,
        # set do_norm = True and assign mean / std (scalars or broadcastable tensors).
        self.do_norm = False
        self.mean = None
        self.std = None

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Encode the inputs and run every head.

        Without ``labels``: returns the list of head outputs, one tensor per
        head, passed through ``unnormalize_logits``.

        With ``labels`` of shape ``(batch, num_targets)`` (a 1-D tensor is
        treated as a single target): computes the MSE of head ``i`` against
        label column ``i``, summed over targets. It returns ``(loss, logits, ...)``
        when ``return_dict`` is false, otherwise a ``RegressionOutput``.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Index instead of `.last_hidden_state` so the tuple output
        # (return_dict=False) works as well.
        sequence_output = outputs[0]

        logits = [rh(sequence_output) for rh in self.regression_heads]

        if labels is None:
            return self.unnormalize_logits(logits)

        normalized_labels = self.normalize_logits(labels)
        if normalized_labels.dim() == 1:
            normalized_labels = normalized_labels.unsqueeze(-1)
        normalized_labels = normalized_labels.to(logits[0].dtype)

        # Head i is scored against label column i, and every target counts once.
        # squeeze(-1) assumes each head emits one value (num_labels == 1, as in create_model).
        loss = self.loss_fct(logits[0].squeeze(-1), normalized_labels[:, 0])
        for i in range(1, normalized_labels.shape[1]):
            loss = loss + self.loss_fct(logits[i].squeeze(-1), normalized_labels[:, i])

        if not return_dict:
            output = (logits,) + outputs[2:]
            return (loss,) + output

        return RegressionOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def normalize_logits(self, tensor):
        """Standardize a label tensor with ``self.mean`` / ``self.std`` when ``do_norm`` is set."""
        if self.do_norm:
            return (tensor - self.mean) / self.std
        return tensor

    def unnormalize_logits(self, tensor):
        """Map every head output back to label scale when ``do_norm`` is set.

        ``tensor`` is the list of per-head outputs.
        """
        if self.do_norm:
            return [(t * self.std) + self.mean for t in tensor]
        return tensor

    def freeze_unfreeze(self, action = 'f'):
        """Set ``requires_grad`` on the encoder: ``'u'`` unfreezes, anything else freezes.

        The heads are never touched.
        """
        flag_to_set = action == 'u'
        for param in self.roberta.parameters():
            param.requires_grad = flag_to_set

    def save_model(self, roberta_path, head_path = None):
        """Save the encoder state dict to ``roberta_path``, and the heads to ``head_path`` if given."""
        torch.save(self.roberta.state_dict() , roberta_path)
        if head_path:
            torch.save(self.regression_heads.state_dict() , head_path)

        print(f'Saved the model at : {roberta_path} and {head_path}')

    def load_model(self, roberta_path, head_path = None):
        """Load encoder weights from ``roberta_path``, and head weights from ``head_path`` if given."""
        # Load onto CPU first so GPU-saved checkpoints also open on CPU-only
        # machines; load_state_dict then copies into the module's current device.
        self.roberta.load_state_dict(torch.load(roberta_path, map_location="cpu"))
        if head_path:
            self.regression_heads.load_state_dict(torch.load(head_path, map_location="cpu"))

        print(f'Sucesfully loaded model from : {roberta_path}, {head_path}')

class RobertaRegressionHead(nn.Module):
    """Prediction head applied to the first (<s>) token of the encoder output.

    Pipeline: dropout -> dense -> ReLU -> dropout -> linear projection to
    ``config.num_labels`` values. A sigmoid is applied at the end when
    ``config.is_classification`` is truthy; a config without that attribute is
    treated as regression.
    """

    def __init__(self, config):
        # nn.Module.__init__ must run before any attribute is assigned.
        super().__init__()
        self.is_classification = bool(getattr(config, "is_classification", False))
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, features, **kwargs):
        """Map ``features`` (indexed as (batch, seq, hidden)) to (batch, num_labels)."""
        x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.relu(x)
        x = self.dropout(x)
        x = self.out_proj(x)

        if self.is_classification:
            x = torch.sigmoid(x)
        return x

@dataclass
class RegressionOutput(ModelOutput):
    """Output container returned by ``RobertaForRegression.forward`` when labels
    are given and ``return_dict`` is true."""
    # Summed per-head MSE loss.
    loss: Optional[torch.FloatTensor] = None
    # List of per-head outputs (one tensor per head), despite the tensor annotation.
    logits: torch.FloatTensor = None
    # Passed through from the RoBERTa encoder output.
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None

def train_regression(model, train_dataloader, val_dataloader, learning_rate, epochs, wd = 1e-02):
    """Fine-tune ``model`` on a (multi-target) regression task with MSE loss.

    Uses AdamW with a OneCycleLR schedule that is stepped once per training
    batch. Per-epoch mean losses are printed. Running means are logged to
    Weights & Biases: train every 10 batches, validation every 3.
    Incomplete trailing batches are skipped. Errors propagate instead of
    opening a debugger.

    Args:
        model: model whose forward, called with only ``input_ids``, returns a
            list with one output tensor per target.
        train_dataloader, val_dataloader: yield dicts with ``'input_ids'`` and
            ``'label'`` (indexed as (batch, num_targets)).
        learning_rate: peak learning rate of the OneCycle schedule.
        epochs: number of passes over ``train_dataloader``.
        wd: AdamW weight decay.
    """
    hyperparameters = {
        "batch_size": train_dataloader.batch_size,
        "num_epochs": epochs,
    }
    wandb.config.update(hyperparameters)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    # Batches are moved to `device`, so the model has to live there too.
    model.to(device)

    criterion_1 = nn.MSELoss()
    optimizer = AdamW(model.parameters(), lr = learning_rate, weight_decay = wd)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr = learning_rate, steps_per_epoch = len(train_dataloader), epochs = epochs)

    def _multitask_loss(outputs, labels):
        # Head i is scored against label column i; each target counts exactly once.
        loss = criterion_1(outputs[0].squeeze(-1), labels[:, 0])
        for i in range(1, labels.shape[1]):
            loss = loss + criterion_1(outputs[i].squeeze(-1), labels[:, i])
        return loss

    for epoch_num in range(epochs):

        model.train()
        total_loss_train = []
        for dd in tqdm(train_dataloader):
            input_id = dd['input_ids']
            if input_id.shape[0] != train_dataloader.batch_size:
                continue
            input_id = input_id.to(device)
            train_label = dd['label'].float().to(device)

            # NOTE: attention_mask is not passed to the model (same as predict()).
            outputs = model(input_id)
            loss = _multitask_loss(outputs, train_label)

            model.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

            total_loss_train.append(loss.item())
            if len(total_loss_train) % 10 == 0:
                wandb.log({
                  "Train Loss Running Mean": np.mean(total_loss_train)
                })

        model.eval()
        total_loss_val = []
        with torch.no_grad():
            for vdd in val_dataloader:
                input_id = vdd['input_ids']
                if input_id.shape[0] != val_dataloader.batch_size:
                    continue
                input_id = input_id.to(device)
                val_label = vdd['label'].float().to(device)

                outputs = model(input_id)
                loss = _multitask_loss(outputs, val_label)

                total_loss_val.append(loss.item())
                if len(total_loss_val) % 3 == 0:
                    wandb.log({
                      "Validation Loss Running Mean": np.mean(total_loss_val)
                    })

        print(
            f'Epochs: {epoch_num + 1} | Train Loss: {np.mean(total_loss_train): .3f} \
            | Val Loss: {np.mean(total_loss_val): .3f}')

def train_classification(model, train_dataloader, val_dataloader, learning_rate, epochs, wd = 1e-02):
    """Fine-tune ``model`` on a (multi-label) binary task with BCE loss.

    The heads are expected to output probabilities (sigmoid heads, i.e.
    ``is_classification=True``). Uses AdamW with a OneCycleLR schedule that is
    stepped once per training batch. Per-epoch mean losses are printed.
    Running means are logged to Weights & Biases: train every 10 batches,
    validation every 3. Incomplete trailing batches are skipped. Errors
    propagate instead of opening a debugger.

    Args:
        model: model whose forward, called with only ``input_ids``, returns a
            list with one output tensor per target.
        train_dataloader, val_dataloader: yield dicts with ``'input_ids'`` and
            ``'label'`` (indexed as (batch, num_targets)).
        learning_rate: peak learning rate of the OneCycle schedule.
        epochs: number of passes over ``train_dataloader``.
        wd: AdamW weight decay.
    """
    hyperparameters = {
        "batch_size": train_dataloader.batch_size,
        "num_epochs": epochs,
    }
    wandb.config.update(hyperparameters)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    if model.device != device:
        model.to(device)
    criterion_1 = nn.BCELoss()
    optimizer = AdamW(model.parameters(), lr = learning_rate, weight_decay = wd)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr = learning_rate, steps_per_epoch = len(train_dataloader), epochs = epochs)

    def _multitask_loss(outputs, labels):
        # Head i is scored against label column i; each target counts exactly once.
        loss = criterion_1(outputs[0].squeeze(-1), labels[:, 0])
        for i in range(1, labels.shape[1]):
            loss = loss + criterion_1(outputs[i].squeeze(-1), labels[:, i])
        return loss

    for epoch_num in range(epochs):

        model.train()
        total_loss_train = []
        for dd in tqdm(train_dataloader):
            input_id = dd['input_ids']
            if input_id.shape[0] != train_dataloader.batch_size:
                continue
            input_id = input_id.to(device)
            train_label = dd['label'].float().to(device)

            # NOTE: attention_mask is not passed to the model (same as predict()).
            outputs = model(input_id)
            loss = _multitask_loss(outputs, train_label)

            model.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

            total_loss_train.append(loss.item())
            if len(total_loss_train) % 10 == 0:
                wandb.log({
                  "Train Loss Running Mean": np.mean(total_loss_train)
                })

        model.eval()
        total_loss_val = []
        with torch.no_grad():
            for vdd in val_dataloader:
                input_id = vdd['input_ids']
                if input_id.shape[0] != val_dataloader.batch_size:
                    continue
                input_id = input_id.to(device)
                val_label = vdd['label'].float().to(device)

                outputs = model(input_id)
                loss = _multitask_loss(outputs, val_label)

                total_loss_val.append(loss.item())
                if len(total_loss_val) % 3 == 0:
                    wandb.log({
                      "Validation Loss Running Mean": np.mean(total_loss_val)
                    })

        print(
            f'Epochs: {epoch_num + 1} | Train Loss: {np.mean(total_loss_train): .3f} \
            | Val Loss: {np.mean(total_loss_val): .3f}')

## **Functions and classes for using these models**

In [None]:
# Start a Weights & Biases run; the training helpers below log to it.
wandb.init(project="Deleney Test")

In [None]:
def create_model(num_outputs = 1, is_classification = False):
    """Build an untrained ``RobertaForRegression`` with this project's small RoBERTa config.

    Args:
        num_outputs: number of independent prediction heads (one per target).
        is_classification: if True, every head ends in a sigmoid.

    Returns:
        The freshly initialized model.
    """
    config = RobertaConfig(
        vocab_size = 600,
        max_position_embeddings = 515,
        num_attention_heads = 6,
        num_hidden_layers = 6,
        type_vocab_size = 515,
        num_labels = 1,
        is_gpu = True,
        num_outputs = num_outputs,
        position_embedding_type = 'random',
        is_classification = is_classification
    )
    return RobertaForRegression(config)

def load_model(model, roberta_path, head_path):
    """Load encoder/head weights into ``model`` in place and return it for chaining."""
    model.load_model(roberta_path, head_path)
    return model

def tokenize_smiles(smiles_string):
    """Tokenize one SMILES string with the module-level ``smiles_tokenizer``.

    Special tokens are added and the sequence is truncated or padded to
    exactly 515 tokens. Returns the ``input_ids`` entry of the encoding.
    """
    encoded = smiles_tokenizer(
        smiles_string,
        add_special_tokens=True,
        truncation=True,
        padding="max_length",
        max_length=515,
    )
    return encoded["input_ids"]

# Use batch_size = 2 on a local machine; on a GPU server you can use batch_size up to 32.
def prepare_batch_loader(smiles_list = None, batch_size = 16):
    """Tokenize SMILES strings into a non-shuffled DataLoader of token-id rows.

    Args:
        smiles_list: SMILES strings to tokenize; ``None`` means no molecules.
        batch_size: molecules per batch.

    Returns:
        DataLoader over a long tensor with one row of token ids per molecule.
    """
    if smiles_list is None:
        smiles_list = []
    tokenized_smiles = [tokenize_smiles(s) for s in smiles_list]
    # Build the integer tensor directly rather than round-tripping through float32.
    ds = torch.tensor(tokenized_smiles, dtype = torch.long)
    return DataLoader(ds, batch_size = batch_size)

def predict(model, dl):
    """Run ``model`` over every batch of ``dl`` and return the first head's predictions.

    Args:
        model: module with a ``device`` attribute; ``model(x)`` must return a
            sequence whose first element is the first head's output, presumably
            (batch, 1).
        dl: iterable of input-id tensors, e.g. from ``prepare_batch_loader``.

    Returns:
        numpy array with the predictions of all batches concatenated along the
        batch axis (empty float32 array when ``dl`` yields nothing).
    """
    model.eval()
    batch_preds = []
    with torch.no_grad():
        for x in dl:
            if model.device != x.device:
                x = x.to(model.device)
            outputs = model(x)
            # squeeze(-1), not squeeze(): a batch of one must stay 1-D, otherwise
            # torch.cat rejects the resulting 0-d tensor.
            batch_preds.append(outputs[0].squeeze(-1).to(torch.device('cpu')))
    if not batch_preds:
        return torch.Tensor().numpy()
    return torch.cat(batch_preds, dim = 0).numpy()

In [None]:
# Build the architecture and restore the fine-tuned Delaney weights from disk.
model = load_model(
    create_model(num_outputs=1, is_classification=False),
    'models/delaney_finetuned_base.pth',
    'models/delaney_finetuned_heads_v2.pth',
)

# Example SMILES strings to score.
smiles_list = ['ClCC(Cl)(Cl)Cl','CC(Cl)(Cl)Cl', 'ClC(Cl)C(Cl)Cl', 'ClCC(Cl)Cl', 'FC(F)(Cl)C(F)(Cl)Cl', 'CC(Cl)Cl', 'ClC(=C)Cl', 'CCOC(C)OCC', 'BrCCBr', 'ClCCCl', 'CC(Cl)CCl', 'FC(F)(Cl)C(F)(F)Cl', 'CCOCCOCC', 'C=CC=C', 'ClCCCCl', 'CCNC(=S)NCC', 'C=CCC=C', 'C=CCCC=C', 'CC(C)CBr', 'CCCCBr', 'CCCCCCCBr', 'CCCCCCBr', 'CCCCCCCCBr', 'CCCCCBr', 'CCCBr', 'CCCCO', 'CCC=C', 'CCC#C', 'ClCCBr', 'ClCC(C)C', 'CCCCCl', 'CCCCCCCCl', 'CCCCCCCl', 'CCCCCCl', 'CCCCl', 'CCCCCCCCCCO', 'CCCCCCCCC=C', 'CCCCCCCCCCCCO', 'CCCCCCCO', 'CCCCCC=C', 'CCCCCC#C', 'CCCCCCCCCCCCCCCCO', 'CCCCCCO', 'CCCCC=C', 'CCCC(O)C=C', 'CCCCC#C', 'CCCCI', 'CCCCCCCI', 'CCCI', 'CCCN(=O)=O', 'CCCCCCCCCO', 'CCCCCCCC=C', 'CCCCCCCC#C', 'CCCCCCCCCCCCCCCCCCO', 'CCCCCCCCO']

# Tokenize into batches.
batch_loader = prepare_batch_loader(smiles_list=smiles_list)

# Run on the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

predictions = predict(model=model, dl=batch_loader)

**Training the model on a new dataset** — and the required functions.

In [None]:
# Prepare train and validation data loaders. Use batch_size = 32 only on a GPU server;
# otherwise use batch_size = 4. Training on a local machine is not recommended.
def get_data_loaders(f_path, smiles_col, label_col, batch_size = 32):
    """Build train / validation DataLoaders from a CSV using a random 80/20 split.

    Args:
        f_path: CSV file, or a directory of CSV files.
        smiles_col: name of the SMILES column.
        label_col: name of the single label column.
        batch_size: batch size for both loaders (previously ignored; 32 was always used).

    Returns:
        ``(train_loader, valid_loader)``, both shuffled.
    """
    ds = RegressionDataset(smiles_tokenizer, file_path = f_path, block_size = 515, x_col = smiles_col, y_col = [label_col])

    n_train = int(len(ds) * 0.8)
    train_ds, valid_ds = random_split(ds, lengths = (n_train, len(ds) - n_train))
    train_loader = DataLoader(train_ds, batch_size = batch_size, shuffle = True)
    valid_loader = DataLoader(valid_ds, batch_size = batch_size, shuffle = True)
    return train_loader, valid_loader

# One epoch of training + validation, shared by train_regression_v2 / train_classification_v2.
def train_one_epoch(model, train_dataloader, val_dataloader, criterion_1, optimizer, scheduler, device, is_classification = False):
    """Run one training pass and one validation pass.

    Args:
        model: model whose forward, called with only ``input_ids``, returns a
            list with one output tensor per target.
        train_dataloader, val_dataloader: yield dicts with ``'input_ids'`` and
            ``'label'`` (indexed as (batch, num_targets)). Incomplete trailing
            batches are skipped.
        criterion_1: per-target loss; losses are summed over targets.
        optimizer: stepped once per training batch.
        scheduler: ``ReduceLROnPlateau``-style scheduler, stepped once with the
            epoch's mean validation loss.
        device: device the batches are moved to.
        is_classification: if True, the metric is the ROC AUC of the first
            target over the whole epoch (nan when only one class is present);
            otherwise it is the mean per-batch sqrt(loss).

    Returns:
        ``(mean train loss, mean val loss, train metric, val metric)``.
    """

    def _multitask_loss(outputs, labels):
        # Head i is scored against label column i; each target counts exactly once.
        loss = criterion_1(outputs[0].squeeze(-1), labels[:, 0])
        for i in range(1, labels.shape[1]):
            loss = loss + criterion_1(outputs[i].squeeze(-1), labels[:, i])
        return loss

    def _epoch_roc_auc(y_true_parts, y_score_parts):
        # roc_auc_score raises ValueError when only one class is present.
        if not y_true_parts:
            return float('nan')
        y_true = np.concatenate(y_true_parts)
        if np.unique(y_true).size < 2:
            return float('nan')
        return roc_auc_score(y_true, np.concatenate(y_score_parts))

    total_loss_train, rmse_train = [], []
    train_true, train_score = [], []
    model.train()
    for dd in tqdm(train_dataloader):
        input_id = dd['input_ids']
        if input_id.shape[0] != train_dataloader.batch_size:
            continue
        input_id = input_id.to(device)
        train_label = dd['label'].float().to(device)

        # NOTE: attention_mask is not passed to the model (same as predict()).
        outputs = model(input_id)
        loss = _multitask_loss(outputs, train_label)

        model.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss_train.append(loss.item())
        if is_classification:
            train_true.append(train_label[:, 0].cpu().numpy())
            train_score.append(outputs[0].squeeze(-1).detach().cpu().numpy())
        else:
            rmse_train.append(np.sqrt(loss.item()))

        if len(total_loss_train) % 10 == 0:
            wandb.log({
              "Train Loss Running Mean": np.mean(total_loss_train)
            })

    total_loss_val, rmse_valid = [], []
    val_true, val_score = [], []
    model.eval()
    with torch.no_grad():
        for vdd in val_dataloader:
            input_id = vdd['input_ids']
            if input_id.shape[0] != val_dataloader.batch_size:
                continue
            input_id = input_id.to(device)
            val_label = vdd['label'].float().to(device)

            outputs = model(input_id)
            loss = _multitask_loss(outputs, val_label)

            total_loss_val.append(loss.item())
            if is_classification:
                val_true.append(val_label[:, 0].cpu().numpy())
                val_score.append(outputs[0].squeeze(-1).cpu().numpy())
            else:
                rmse_valid.append(np.sqrt(loss.item()))

            if len(total_loss_val) % 3 == 0:
                wandb.log({
                  "Validation Loss Running Mean": np.mean(total_loss_val)
                })

    # Plateau schedulers expect one step per epoch, not one per validation batch.
    if total_loss_val:
        scheduler.step(np.mean(total_loss_val))

    if is_classification:
        train_metric = _epoch_roc_auc(train_true, train_score)
        valid_metric = _epoch_roc_auc(val_true, val_score)
    else:
        train_metric = np.mean(rmse_train)
        valid_metric = np.mean(rmse_valid)
    return np.mean(total_loss_train), np.mean(total_loss_val), train_metric, valid_metric

def train_regression_v2(model, train_dataloader, val_dataloader, learning_rate = 1e-04, epochs = 50, wd = 1e-02, warmup_epochs = 5):
    """Two-phase regression fine-tuning with MSE loss.

    Phase 1 freezes the RoBERTa encoder and trains only the heads for
    ``warmup_epochs`` epochs. Phase 2 unfreezes everything and trains for
    ``epochs`` epochs. Both phases share one AdamW optimizer. Each phase gets
    its own ``ReduceLROnPlateau`` schedule (min_lr 1e-4, then 1e-6).

    Args:
        model: model exposing ``freeze_unfreeze`` (e.g. from ``create_model``).
        train_dataloader, val_dataloader: loaders such as those returned by
            ``get_data_loaders``. These are now used instead of module-level
            globals.
        learning_rate: initial AdamW learning rate.
        epochs: number of full fine-tuning epochs (phase 2).
        wd: AdamW weight decay.
        warmup_epochs: number of head-only epochs (phase 1); the default 5
            matches the previous hard-coded value.

    Returns:
        The trained model.
    """
    hyperparameters = {
        "batch_size": train_dataloader.batch_size,
        "num_epochs": epochs,
    }
    wandb.config.update(hyperparameters)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    if model.device != device:
        model.to(device)

    criterion_1 = nn.MSELoss()
    optimizer = AdamW(model.parameters(), lr = learning_rate, weight_decay = wd)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                            optimizer,
                            mode = 'min',
                            factor = 0.7,
                            patience = 10,
                            min_lr = 1e-04
                )

    if use_cuda:
        criterion_1 = criterion_1.cuda()

    # Phase 1: encoder frozen, heads only.
    model.freeze_unfreeze()
    for epoch_num in range(warmup_epochs):

        train_loss, valid_loss, train_rmse, valid_rmse = train_one_epoch(model, train_dataloader, val_dataloader, criterion_1, optimizer, scheduler, device)

        print(
            f'Epochs: {epoch_num + 1} | Train Loss: {train_loss: .3f} \
            | Val Loss: {valid_loss: .3f} \
            | Train RMSE: {train_rmse: .3f} \
            | Val RMSE: {valid_rmse: .3f}')

    # Phase 2: full fine-tuning with a fresh plateau schedule.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                            optimizer,
                            mode = 'min',
                            factor = 0.7,
                            patience = 10,
                            min_lr = 1e-06
                )
    model.freeze_unfreeze('u')
    for epoch_num in range(epochs):

        train_loss, valid_loss, train_rmse, valid_rmse = train_one_epoch(model, train_dataloader, val_dataloader, criterion_1, optimizer, scheduler, device)

        print(
            f'Epochs: {epoch_num + 1} | Train Loss: {train_loss: .3f} \
            | Val Loss: {valid_loss: .3f} \
            | Train RMSE: {train_rmse: .3f} \
            | Val RMSE: {valid_rmse: .3f}')

    return model

def train_classification_v2(model, train_dataloader, val_dataloader, learning_rate = 1e-04, epochs = 50, wd = 1e-02, warmup_epochs = 5):
    """Two-phase binary-classification fine-tuning with BCE loss.

    Phase 1 freezes the RoBERTa encoder and trains only the heads for
    ``warmup_epochs`` epochs. Phase 2 unfreezes everything and trains for
    ``epochs`` epochs. Both phases share one AdamW optimizer. Each phase gets
    its own ``ReduceLROnPlateau`` schedule (min_lr 1e-4, then 1e-6). The heads
    must output probabilities (``create_model(is_classification=True)``).

    Args:
        model: model exposing ``freeze_unfreeze``.
        train_dataloader, val_dataloader: loaders such as those returned by
            ``get_data_loaders``. These are now used instead of module-level
            globals.
        learning_rate: initial AdamW learning rate.
        epochs: number of full fine-tuning epochs (phase 2).
        wd: AdamW weight decay.
        warmup_epochs: number of head-only epochs (phase 1); the default 5
            matches the previous hard-coded value.

    Returns:
        The trained model.
    """
    hyperparameters = {
        "batch_size": train_dataloader.batch_size,
        "num_epochs": epochs,
    }
    wandb.config.update(hyperparameters)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    if model.device != device:
        model.to(device)

    criterion_1 = nn.BCELoss()
    optimizer = AdamW(model.parameters(), lr = learning_rate, weight_decay = wd)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                            optimizer,
                            mode = 'min',
                            factor = 0.7,
                            patience = 10,
                            min_lr = 1e-04
                )

    if use_cuda:
        criterion_1 = criterion_1.cuda()

    # Phase 1: encoder frozen, heads only.
    model.freeze_unfreeze()
    for epoch_num in range(warmup_epochs):

        train_loss, valid_loss, train_roc_auc, valid_roc_auc = train_one_epoch(model, train_dataloader, val_dataloader, criterion_1, optimizer, scheduler, device, is_classification = True)

        print(
            f'Epochs: {epoch_num + 1} | Train Loss: {train_loss: .3f} \
            | Val Loss: {valid_loss: .3f} \
            | Train ROC AUC: {train_roc_auc: .3f} \
            | Val ROC AUC: {valid_roc_auc: .3f}')

    # Phase 2: full fine-tuning with a fresh plateau schedule.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                            optimizer,
                            mode = 'min',
                            factor = 0.7,
                            patience = 10,
                            min_lr = 1e-06
                )
    model.freeze_unfreeze('u')
    for epoch_num in range(epochs):

        train_loss, valid_loss, train_roc_auc, valid_roc_auc = train_one_epoch(model, train_dataloader, val_dataloader, criterion_1, optimizer, scheduler, device, is_classification = True)

        print(
            f'Epochs: {epoch_num + 1} | Train Loss: {train_loss: .3f} \
            | Val Loss: {valid_loss: .3f} \
            | Train ROC AUC: {train_roc_auc: .3f} \
            | Val ROC AUC: {valid_roc_auc: .3f}')

    return model

In [None]:
# Build train/validation loaders from the Delaney solubility CSV.
train_loader, valid_loader = get_data_loaders(
    "./chem_datasets/dataset-delaney.csv",
    smiles_col='SMILES',
    label_col='measured log(solubility:mol/L)',
)

# Fine-tune a freshly initialized model.
model = train_regression_v2(create_model(), train_loader, valid_loader)

# Persist encoder and head weights separately (update the paths as needed).
model.save_model('./models/deleney_roberta_06_14.pth', './models/deleney_head_06_14.pth')