# Demo for MoleculeSTM Downstream: Property Prediction

## Load Packages

In [None]:
import warnings
warnings.filterwarnings('ignore')

import os
import argparse
import numpy as np
import pandas as pd
from tqdm import tqdm
from sklearn.metrics import accuracy_score, average_precision_score, roc_auc_score, mean_absolute_error, mean_squared_error

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader as torch_DataLoader
from torch_geometric.loader import DataLoader as pyg_DataLoader

from MoleculeSTM.datasets import MoleculeNetSMILESDataset, MoleculeNetGraphDataset
from MoleculeSTM.splitters import scaffold_split
from MoleculeSTM.utils import get_num_task_and_type, get_molecule_repr_MoleculeSTM
from MoleculeSTM.models.mega_molbart.mega_mol_bart import MegaMolBART
from MoleculeSTM.models import GNN, GNN_graphpred

: 

## Setup Arguments

In [None]:
# Demo hyper-parameters, declared as (flag, add_argument keyword options)
# pairs and registered below in exactly this order.
_ARGUMENT_SPECS = [
    ("--seed", dict(type=int, default=42)),
    ("--device", dict(type=int, default=0)),
    ("--training_mode", dict(type=str, default="fine_tuning", choices=["fine_tuning", "linear_probing"])),
    ("--molecule_type", dict(type=str, default="Graph", choices=["SMILES", "Graph"])),

    # ---------- dataset and split ----------
    ("--dataspace_path", dict(type=str, default="../data")),
    ("--dataset", dict(type=str, default="bace")),
    ("--split", dict(type=str, default="scaffold")),

    # ---------- optimization ----------
    ("--batch_size", dict(type=int, default=32)),
    ("--lr", dict(type=float, default=1e-4)),
    ("--lr_scale", dict(type=float, default=1)),
    ("--num_workers", dict(type=int, default=1)),
    ("--epochs", dict(type=int, default=5)),
    ("--weight_decay", dict(type=float, default=0)),
    ("--schedule", dict(type=str, default="cycle")),
    ("--warm_up_steps", dict(type=int, default=10)),

    # ---------- 2D GNN ----------
    ("--gnn_emb_dim", dict(type=int, default=300)),
    ("--num_layer", dict(type=int, default=5)),
    ("--JK", dict(type=str, default="last")),
    ("--dropout_ratio", dict(type=float, default=0.5)),
    ("--gnn_type", dict(type=str, default="gin")),
    ("--graph_pooling", dict(type=str, default="mean")),

    # ---------- saver / logging ----------
    ("--eval_train", dict(type=int, default=0)),
    ("--verbose", dict(type=int, default=1)),

    # ---------- checkpoints ----------
    ("--input_model_path", dict(type=str, default="demo_checkpoints_Graph/molecule_model.pth")),
    ("--output_model_dir", dict(type=str, default=None)),
]

parser = argparse.ArgumentParser()
for _flag, _options in _ARGUMENT_SPECS:
    parser.add_argument(_flag, **_options)

# Parse an explicitly empty argument list (instead of sys.argv) so that every
# option takes its default; this is what lets the cell run inside a notebook.
args = parser.parse_args([])
print("arguments\t", args)

: 

## Setup Seed

In [None]:
# Seed the NumPy and PyTorch random number generators from the configured seed.
np.random.seed(args.seed)
torch.manual_seed(args.seed)

# Run on the requested GPU when CUDA is usable (seeding every CUDA device
# as well); otherwise fall back to the CPU.
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(args.seed)
    device = torch.device("cuda:{}".format(args.device))
else:
    device = torch.device("cpu")

## Setup Dataset and Dataloader

In [None]:
# Task metadata for the chosen benchmark, and the folder its files live in.
num_tasks, task_mode = get_num_task_and_type(args.dataset)
dataset_folder = os.path.join(args.dataspace_path, "MoleculeNet_data", args.dataset)

# This demo always uses the graph dataset, batched by the PyG DataLoader.
use_pyg_dataset = True
dataloader_class = pyg_DataLoader
dataset = MoleculeNetGraphDataset(dataset_folder, args.dataset)

# The first column of the header-less CSV holds the SMILES strings that are
# handed to the scaffold splitter (80% train / 10% valid / 10% test).
smiles_frame = pd.read_csv(dataset_folder + "/processed/smiles.csv", header=None)
smiles_list = smiles_frame[0].tolist()
train_dataset, valid_dataset, test_dataset = scaffold_split(
    dataset,
    smiles_list,
    null_value=0,
    frac_train=0.8,
    frac_valid=0.1,
    frac_test=0.1,
    pyg_dataset=use_pyg_dataset,
)

# All three loaders share batch size and worker count; only training shuffles.
_loader_options = dict(batch_size=args.batch_size, num_workers=args.num_workers)
train_loader = dataloader_class(train_dataset, shuffle=True, **_loader_options)
val_loader = dataloader_class(valid_dataset, shuffle=False, **_loader_options)
test_loader = dataloader_class(test_dataset, shuffle=False, **_loader_options)

: 

## Initialize and Load Model

In [None]:
# Node-level GNN encoder configured from the command-line options.
molecule_node_model = GNN(
    num_layer=args.num_layer,
    emb_dim=args.gnn_emb_dim,
    JK=args.JK,
    drop_ratio=args.dropout_ratio,
    gnn_type=args.gnn_type,
)
# Graph-level model wrapping the encoder. It is built with num_tasks=1; the
# actual task outputs come from the separate `linear_model` created below.
model = GNN_graphpred(
    num_layer=args.num_layer,
    emb_dim=args.gnn_emb_dim,
    JK=args.JK,
    graph_pooling=args.graph_pooling,
    num_tasks=1,
    molecule_node_model=molecule_node_model,
)
molecule_dim = args.gnn_emb_dim

# The checkpoint flavour is inferred from its path: anything without
# "GraphMVP" in the path is loaded as a plain state dict.
if "GraphMVP" not in args.input_model_path:
    print("Start from pretrained model (MoleculeSTM) in {}.".format(args.input_model_path))
    state_dict = torch.load(args.input_model_path, map_location='cpu')
    model.load_state_dict(state_dict)
else:
    print("Start from pretrained model (GraphMVP) in {}.".format(args.input_model_path))
    model.from_pretrained(args.input_model_path)

model = model.to(device)
# Prediction head mapping the molecule representation to one output per task.
linear_model = nn.Linear(molecule_dim, num_tasks).to(device)

# Re-apply the seed after model construction.
# NOTE(review): the original comment attributes this to MegaMolBART changing
# the seed; that model is not instantiated in this cell -- confirm it is needed.
torch.manual_seed(args.seed)
np.random.seed(args.seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(args.seed)

Start from pretrained model (MoleculeSTM) in demo_checkpoints_Graph/molecule_model.pth.


## Setup Optimizer

In [None]:
# The linear head is always trained, at the base learning rate scaled by
# args.lr_scale.
model_param_group = [
    {"params": linear_model.parameters(), 'lr': args.lr * args.lr_scale},
]
# Fine-tuning additionally trains the encoder; its group comes first and
# inherits the optimizer-level learning rate. Linear probing leaves the
# encoder out of the optimizer entirely.
if args.training_mode == "fine_tuning":
    model_param_group.insert(0, {"params": model.parameters()})

optimizer = optim.Adam(model_param_group, lr=args.lr, weight_decay=args.weight_decay)

## Define Support Functions

In [None]:
def train_classification(model, device, loader, optimizer):
    """Train for one epoch and return the mean per-batch loss.

    Relies on the notebook-level globals ``args``, ``linear_model`` and
    ``criterion`` (and ``MegaMolBART_wrapper`` on the SMILES path).

    Args:
        model: molecule encoder passed to ``get_molecule_repr_MoleculeSTM``.
        device: device that batches and labels are moved to.
        loader: iterable of training batches.
        optimizer: optimizer stepped once per batch.

    Returns:
        Sum of the per-batch losses divided by ``len(loader)``.
    """
    # Only fine-tuning puts the encoder in train mode; for linear probing it
    # stays in eval mode. The linear head is always in train mode.
    if args.training_mode != "fine_tuning":
        model.eval()
    else:
        model.train()
    linear_model.train()

    batches = tqdm(loader) if args.verbose else loader
    running_loss = 0
    for batch in batches:
        # NOTE(review): "MegaMolBART" is not among the declared
        # --molecule_type choices ("SMILES", "Graph"), so with the parser in
        # this notebook only the graph branch can run -- confirm intent.
        if args.molecule_type == "MegaMolBART":
            SMILES_list, y = batch
            SMILES_list = list(SMILES_list)
            molecule_repr = get_molecule_repr_MoleculeSTM(
                SMILES_list, mol2latent=None,
                molecule_type="MegaMolBART", MegaMolBART_wrapper=MegaMolBART_wrapper)
            pred = linear_model(molecule_repr).float()
            y = y.to(device).float()
        else:
            batch = batch.to(device)
            molecule_repr = get_molecule_repr_MoleculeSTM(
                batch, mol2latent=None,
                molecule_type="Graph", molecule_model=model)
            pred = linear_model(molecule_repr).float()
            # Labels are reshaped to line up element-wise with the predictions.
            y = batch.y.view(pred.shape).to(device).float()

        # Entries whose label is 0 are excluded from the loss; the remaining
        # labels are mapped through (y + 1) / 2 before the criterion
        # (presumably a -1/+1 encoding turned into 0/1 targets).
        valid_mask = y ** 2 > 0
        per_label_loss = criterion(pred, (y + 1) / 2)
        per_label_loss = torch.where(
            valid_mask, per_label_loss, torch.zeros_like(per_label_loss))

        optimizer.zero_grad()
        # Average over the valid entries only.
        loss = per_label_loss.sum() / valid_mask.sum()
        loss.backward()
        optimizer.step()
        running_loss += loss.detach().item()

    return running_loss / len(loader)


@torch.no_grad()
def eval_classification(model, device, loader):
    """Score the model on ``loader`` and return the mean per-target ROC-AUC.

    Relies on the notebook-level globals ``args`` and ``linear_model`` (and
    ``MegaMolBART_wrapper`` on the SMILES path).

    Args:
        model: molecule encoder passed to ``get_molecule_repr_MoleculeSTM``.
        device: device that batches and labels are moved to.
        loader: iterable of evaluation batches.

    Returns:
        A 4-tuple ``(mean_roc_auc, 0, y_true, y_scores)``: the second item is
        a constant placeholder, and the last two are NumPy arrays holding all
        labels and all raw head outputs.

    Raises:
        ZeroDivisionError: if no target column contains both a +1 and a -1
            label, so that no ROC-AUC could be computed.
    """
    model.eval()
    linear_model.eval()

    batches = tqdm(loader) if args.verbose else loader
    label_chunks, score_chunks = [], []
    for batch in batches:
        # NOTE(review): "MegaMolBART" is not among the declared
        # --molecule_type choices ("SMILES", "Graph") -- confirm intent.
        if args.molecule_type == "MegaMolBART":
            SMILES_list, y = batch
            SMILES_list = list(SMILES_list)
            molecule_repr = get_molecule_repr_MoleculeSTM(
                SMILES_list, mol2latent=None,
                molecule_type="MegaMolBART", MegaMolBART_wrapper=MegaMolBART_wrapper)
            pred = linear_model(molecule_repr).float()
            y = y.to(device).float()
        else:
            batch = batch.to(device)
            molecule_repr = get_molecule_repr_MoleculeSTM(
                batch, mol2latent=None,
                molecule_type="Graph", molecule_model=model)
            pred = linear_model(molecule_repr).float()
            y = batch.y.view(pred.shape).to(device).float()

        label_chunks.append(y)
        score_chunks.append(pred)

    y_true = torch.cat(label_chunks, dim=0).cpu().numpy()
    y_scores = torch.cat(score_chunks, dim=0).cpu().numpy()

    num_targets = y_true.shape[1]
    roc_list = []
    for i in range(num_targets):
        column = y_true[:, i]
        # ROC-AUC needs both classes: at least one +1 and at least one -1.
        has_positive = np.sum(column == 1) > 0
        has_negative = np.sum(column == -1) > 0
        if not (has_positive and has_negative):
            print("{} is invalid".format(i))
            continue
        # Zero labels are left out of the score for this target.
        is_valid = column ** 2 > 0
        roc_list.append(roc_auc_score((column[is_valid] + 1) / 2, y_scores[is_valid, i]))

    if len(roc_list) < num_targets:
        print(len(roc_list))
        print("Some target is missing!")
        print("Missing ratio: %f" %(1 - float(len(roc_list)) / y_true.shape[1]))

    return sum(roc_list) / len(roc_list), 0, y_true, y_scores

## Start Training

In [None]:
train_func = train_classification
eval_func = eval_classification

# Per-epoch metric histories; index i holds the values for epoch i + 1.
train_roc_list, val_roc_list, test_roc_list = [], [], []
train_acc_list, val_acc_list, test_acc_list = [], [], []
# Best validation ROC-AUC seen so far and the (0-based) epoch index it came from.
best_val_roc, best_val_idx = -1, 0
# reduction="none" keeps the element-wise losses so that train_classification
# can mask out entries without a label before averaging.
criterion = nn.BCEWithLogitsLoss(reduction="none")

for epoch in range(1, args.epochs + 1):
    loss_acc = train_func(model, device, train_loader, optimizer)
    print("Epoch: {}\nLoss: {}".format(epoch, loss_acc))

    # Scoring the training split is optional; 0 is recorded when it is skipped.
    if args.eval_train:
        train_roc, train_acc, train_target, train_pred = eval_func(model, device, train_loader)
    else:
        train_roc = train_acc = 0
    val_roc, val_acc, val_target, val_pred = eval_func(model, device, val_loader)
    test_roc, test_acc, test_target, test_pred = eval_func(model, device, test_loader)

    train_roc_list.append(train_roc)
    train_acc_list.append(train_acc)
    val_roc_list.append(val_roc)
    val_acc_list.append(val_acc)
    test_roc_list.append(test_roc)
    test_acc_list.append(test_acc)
    print("train: {:.6f}\tval: {:.6f}\ttest: {:.6f}".format(train_roc, val_roc, test_roc))
    print()

    # Model selection on the validation split: remember the epoch with the
    # highest validation ROC-AUC. Without this update the summary below would
    # always report the first epoch.
    if val_roc > best_val_roc:
        best_val_roc = val_roc
        best_val_idx = epoch - 1

# Report the metrics of the epoch that performed best on validation.
print("best train: {:.6f}\tval: {:.6f}\ttest: {:.6f}".format(train_roc_list[best_val_idx], val_roc_list[best_val_idx], test_roc_list[best_val_idx]))

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 38/38 [00:01<00:00, 19.51it/s]


Epoch: 1
Loss: 0.6760892538647902


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  5.19it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  5.39it/s]


train: 0.000000	val: 0.642125	test: 0.663189



100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 38/38 [00:01<00:00, 23.04it/s]


Epoch: 2
Loss: 0.6383239313175804


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  5.31it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  5.58it/s]


train: 0.000000	val: 0.676190	test: 0.720049



100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 38/38 [00:01<00:00, 22.44it/s]


Epoch: 3
Loss: 0.6019486816305863


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  5.11it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  5.19it/s]


train: 0.000000	val: 0.683516	test: 0.752043



100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 38/38 [00:01<00:00, 23.23it/s]


Epoch: 4
Loss: 0.5672228501031273


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  5.20it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:01<00:00,  5.00it/s]


train: 0.000000	val: 0.686447	test: 0.774474



100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 38/38 [00:01<00:00, 23.79it/s]


Epoch: 5
Loss: 0.5250759069856844


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  7.85it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  5.60it/s]

train: 0.000000	val: 0.689377	test: 0.788211

best train: 0.000000	val: 0.642125	test: 0.663189



