# Demo for MoleculeSTM Downstream: Structure-Text Retrieval

## Load Packages

In [1]:
import warnings
warnings.filterwarnings('ignore')

import os
import time
import argparse
import numpy as np
import pandas as pd
from tqdm import tqdm
from collections import defaultdict

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader as torch_DataLoader
from torch_geometric.loader import DataLoader as pyg_DataLoader

from transformers import AutoModel, AutoTokenizer
from MoleculeSTM.datasets import DrugBank_Datasets_SMILES_retrieval, DrugBank_Datasets_Graph_retrieval
from MoleculeSTM.models.mega_molbart.mega_mol_bart import MegaMolBART
from MoleculeSTM.models import GNN, GNN_graphpred
from MoleculeSTM.utils import prepare_text_tokens, get_molecule_repr_MoleculeSTM, freeze_network

# Turn off the HuggingFace `tokenizers` internal parallelism via its environment
# variable; per the original note this is done to keep warnings out of the output
# (presumably the fork warning triggered once DataLoader workers are started).
os.environ.update(TOKENIZERS_PARALLELISM='False')

ModuleNotFoundError: No module named 'networkx'

## Set Up Arguments

In [2]:
# Run configuration for the retrieval demo. `parse_args("")` parses an empty
# argument list instead of sys.argv, so every option takes the default declared
# below; pass a list of flags to parse_args to override them.
# NOTE(review): in the cells shown, --input_model_path, --load_latent_projector,
# --training_mode, --text_type, --molecule_type, --normalize/--no_normalize, --T and
# the optimization options are never read; checkpoints are located through
# --input_model_dir.

_TASK_CHOICES = [
    "molecule_description", "molecule_description_Raw",
    "molecule_description_removed_PubChem", "molecule_description_removed_PubChem_Raw",
    "molecule_pharmacodynamics", "molecule_pharmacodynamics_Raw",
    "molecule_pharmacodynamics_removed_PubChem", "molecule_pharmacodynamics_removed_PubChem_Raw",
]

# (flag, add_argument keyword arguments), listed in registration order.
_ARG_SPECS = [
    # general
    ("--seed", dict(type=int, default=42)),
    ("--device", dict(type=int, default=0)),
    ("--SSL_emb_dim", dict(type=int, default=256)),
    ("--text_type", dict(type=str, default="SciBERT", choices=["SciBERT", "BioBERT"])),
    ("--load_latent_projector", dict(type=int, default=1)),
    ("--training_mode", dict(type=str, default="zero_shot", choices=["zero_shot"])),
    # dataset and split
    ("--dataspace_path", dict(type=str, default="../data")),
    ("--task", dict(type=str, default="molecule_description", choices=_TASK_CHOICES)),
    ("--test_mode", dict(type=str, default="given_text", choices=["given_text", "given_molecule"])),
    # optimization
    ("--T_list", dict(type=int, nargs="+", default=[4, 10, 20])),
    ("--batch_size", dict(type=int, default=32)),
    ("--num_workers", dict(type=int, default=8)),
    ("--epochs", dict(type=int, default=1)),
    ("--text_lr", dict(type=float, default=1e-5)),
    ("--mol_lr", dict(type=float, default=1e-5)),
    ("--text_lr_scale", dict(type=float, default=0.1)),
    ("--mol_lr_scale", dict(type=float, default=0.1)),
    ("--decay", dict(type=float, default=0)),
    # contrastive objective
    ("--SSL_loss", dict(type=str, default="EBM_NCE", choices=["EBM_NCE", "InfoNCE"])),
    ("--CL_neg_samples", dict(type=int, default=1)),
    ("--T", dict(type=float, default=0.1)),
    ("--normalize", dict(dest="normalize", action="store_true")),
    ("--no_normalize", dict(dest="normalize", action="store_false")),
    # BERT model
    ("--max_seq_len", dict(type=int, default=512)),
    # molecule model
    ("--molecule_type", dict(type=str, default="SMILES", choices=["SMILES", "Graph"])),
    # MegaMolBART
    ("--vocab_path", dict(type=str, default="../MoleculeSTM/bart_vocab.txt")),
    # saver / checkpoints
    ("--eval_train", dict(type=int, default=0)),
    ("--verbose", dict(type=int, default=0)),
    ("--input_model_dir", dict(type=str, default="demo_checkpoints_SMILES")),
    ("--input_model_path", dict(type=str, default="demo_checkpoints_SMILES/molecule_model.pth")),
]

parser = argparse.ArgumentParser()
for _flag, _kwargs in _ARG_SPECS:
    parser.add_argument(_flag, **_kwargs)
# --normalize / --no_normalize write the same destination; normalizing is the default.
parser.set_defaults(normalize=True)

args = parser.parse_args("")
print("arguments\t", args)

arguments	 Namespace(CL_neg_samples=1, SSL_emb_dim=256, SSL_loss='EBM_NCE', T=0.1, T_list=[4, 10, 20], batch_size=32, dataspace_path='../data', decay=0, device=0, epochs=1, eval_train=0, input_model_dir='demo_checkpoints_SMILES', input_model_path='demo_checkpoints_SMILES/molecule_model.pth', load_latent_projector=1, max_seq_len=512, mol_lr=1e-05, mol_lr_scale=0.1, molecule_type='SMILES', normalize=True, num_workers=8, seed=42, task='molecule_description', test_mode='given_text', text_lr=1e-05, text_lr_scale=0.1, text_type='SciBERT', training_mode='zero_shot', verbose=0, vocab_path='../MoleculeSTM/bart_vocab.txt')


## Set Up Random Seeds and Device

In [3]:
# Seed NumPy and PyTorch (plus every CUDA device when one is available) with
# args.seed, then pick the compute device: GPU number args.device if CUDA is
# available, otherwise the CPU.
cuda_ok = torch.cuda.is_available()
np.random.seed(args.seed)
torch.random.manual_seed(args.seed)
if cuda_ok:
    torch.cuda.manual_seed_all(args.seed)
device = torch.device("cuda:{}".format(args.device)) if cuda_ok else torch.device("cpu")

## Load SciBERT

In [4]:
# Text encoder: SciBERT tokenizer + base model via from_pretrained (cached under
# <dataspace_path>/pretrained_SciBERT), then overwritten with the MoleculeSTM
# text weights from the demo checkpoint directory.
# NOTE(review): args.text_type is not consulted here -- SciBERT is always loaded.
scibert_name = 'allenai/scibert_scivocab_uncased'
pretrained_SciBERT_folder = os.path.join(args.dataspace_path, 'pretrained_SciBERT')
text_tokenizer = AutoTokenizer.from_pretrained(scibert_name, cache_dir=pretrained_SciBERT_folder)
text_model = AutoModel.from_pretrained(scibert_name, cache_dir=pretrained_SciBERT_folder)
text_model = text_model.to(device)
text_dim = 768  # must equal the width of "pooler_output", which text2latent consumes

input_model_path = os.path.join(args.input_model_dir, "text_model.pth")
print("Loading from {}...".format(input_model_path))
# torch.load unpickles the file -- only point this at trusted checkpoints.
state_dict = torch.load(input_model_path, map_location='cpu')
text_model.load_state_dict(state_dict)

Some weights of the model checkpoint at allenai/scibert_scivocab_uncased were not used when initializing BertModel: ['cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Loading from demo_checkpoints_SMILES/text_model.pth...


<All keys matched successfully>

## Load MoleculeSTM-SMILES

In [5]:
# SMILES encoder: build the MegaMolBART wrapper from the BART vocabulary, keep a
# handle on its underlying model, and load the MoleculeSTM molecule weights into it.
input_model_path = os.path.join(args.input_model_dir, "molecule_model.pth")
print("Loading from {}...".format(input_model_path))
MegaMolBART_wrapper = MegaMolBART(vocab_path=args.vocab_path, input_dir=None, output_dir=None)
molecule_model = MegaMolBART_wrapper.model
# torch.load unpickles the file -- only point this at trusted checkpoints.
state_dict = torch.load(input_model_path, map_location='cpu')
molecule_model.load_state_dict(state_dict)
molecule_dim = 256  # must equal the input width of mol2latent

# Per the original note, constructing MegaMolBART rewrites the global seeds,
# so restore args.seed for NumPy, PyTorch and (if present) CUDA.
rng_seed = args.seed
np.random.seed(rng_seed)
torch.random.manual_seed(rng_seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(rng_seed)

Loading from demo_checkpoints_SMILES/molecule_model.pth...
using world size: 1 and model-parallel size: 1 
using torch.float32 for parameters ...
-------------------- arguments --------------------
  adam_beta1 ...................... 0.9
  adam_beta2 ...................... 0.999
  adam_eps ........................ 1e-08
  adlr_autoresume ................. False
  adlr_autoresume_interval ........ 1000
  apply_query_key_layer_scaling ... False
  apply_residual_connection_post_layernorm  False
  attention_dropout ............... 0.1
  attention_softmax_in_fp32 ....... False
  batch_size ...................... None
  bert_load ....................... None
  bias_dropout_fusion ............. False
  bias_gelu_fusion ................ False
  block_data_path ................. None
  checkpoint_activations .......... False
  checkpoint_in_cpu ............... False
  checkpoint_num_layers ........... 1
  clip_grad ....................... 1.0
  contigious_checkpointing ........ False
  cpu_opti

## Load Projection Layers

In [6]:
# Linear heads that project text (text_dim wide) and molecule (molecule_dim wide)
# features into the shared args.SSL_emb_dim space; their weights are restored
# from the demo checkpoint directory.
text2latent = nn.Linear(text_dim, args.SSL_emb_dim)
mol2latent = nn.Linear(molecule_dim, args.SSL_emb_dim)

for projector, ckpt_name in [(text2latent, "text2latent_model.pth"),
                             (mol2latent, "mol2latent_model.pth")]:
    input_model_path = os.path.join(args.input_model_dir, ckpt_name)
    print("Loading from {}...".format(input_model_path))
    state_dict = torch.load(input_model_path, map_location='cpu')
    projector_load_result = projector.load_state_dict(state_dict)

# Display the outcome of the last load (missing / unexpected keys, if any).
projector_load_result

Loading from demo_checkpoints_SMILES/text2latent_model.pth...
Loading from demo_checkpoints_SMILES/mol2latent_model.pth...


<All keys matched successfully>

## Define Support Functions

In [7]:
def cycle_index(num, shift):
    """Return the indices 0..num-1 rotated left by `shift`.

    Element i of the result is (i + shift) % num, e.g.
    cycle_index(5, 2) -> tensor([2, 3, 4, 0, 1]).

    The modulo form handles shift == 0 (identity), shift >= num and negative
    shifts. The previous slice-assignment version crashed on all three:
    `arr[-0:]` selects the whole array, so shift == 0 raised a shape mismatch.
    Not called anywhere in the cells shown.

    Args:
        num: number of indices (>= 0).
        shift: rotation amount; any integer.

    Returns:
        int64 tensor of shape (num,).
    """
    if num == 0:
        # Avoid a modulo-by-zero; there is nothing to rotate.
        return torch.arange(0)
    return (torch.arange(num) + shift) % num


def do_CL_eval(X, Y, neg_Y, args):
    """Rank each query's positive against its negatives by cosine similarity.

    Args:
        X: query representations, shape (B, d).
        Y: positive candidates, shape (B, d); Y[i] is the match for X[i].
        neg_Y: negative candidates, shape (K, B, d); K may be 0.
        args: unused; kept so existing call sites keep working.

    Returns:
        (CL_loss, CL_conf, CL_acc):
          CL_loss -- cross-entropy of the (B, 1 + K) similarity logits, with the
                     positive at column 0;
          CL_conf -- numpy array of shape (B,), each query's highest similarity;
          CL_acc  -- fraction of queries whose positive scored highest.
    """
    X = F.normalize(X, dim=-1).unsqueeze(1)                 # (B, 1, d)
    candidates = torch.cat([Y.unsqueeze(0), neg_Y], dim=0)  # (1 + K, B, d)
    candidates = F.normalize(candidates.transpose(0, 1), dim=-1)  # (B, 1 + K, d)

    # Squeeze only the query axis. A bare .squeeze() also drops the batch axis
    # when B == 1 (e.g. a size-1 final batch), or the candidate axis when K == 0.
    # Either case breaks argmax(dim=1) and the loss below.
    logits = torch.bmm(X, candidates.transpose(1, 2)).squeeze(1)  # (B, 1 + K)
    B = logits.size(0)
    labels = torch.zeros(B, dtype=torch.long, device=logits.device)  # positive is column 0

    CL_loss = F.cross_entropy(logits, labels)
    pred = logits.argmax(dim=1)
    # detach() so this also works when called outside torch.no_grad().
    CL_conf = logits.max(dim=1)[0].detach().cpu().numpy()
    CL_acc = pred.eq(labels).sum().item() * 1. / B
    return CL_loss, CL_conf, CL_acc


def get_text_repr(text):
    """Encode text into the shared latent space.

    `text` is presumably a batch (list) of strings from the dataloader -- TODO
    confirm. It is tokenized through prepare_text_tokens (with `text_tokenizer`,
    `device` and args.max_seq_len), run through `text_model`, and the encoder's
    "pooler_output" is projected by `text2latent`.
    """
    token_ids, attention_mask = prepare_text_tokens(
        device=device,
        description=text,
        tokenizer=text_tokenizer,
        max_seq_len=args.max_seq_len,
    )
    encoded = text_model(input_ids=token_ids, attention_mask=attention_mask)
    return text2latent(encoded["pooler_output"])


@torch.no_grad()
def eval_epoch(dataloader):
    """Zero-shot retrieval accuracy for each candidate-pool size in args.T_list.

    For every sample, the positive partner is ranked against T - 1 negatives.
    The query is the text and the candidates are molecules for
    args.test_mode == "given_text"; the roles are swapped for "given_molecule".

    Args:
        dataloader: yields batches whose items 0..3 are the texts, the SMILES,
            the negative texts and the negative SMILES (negatives indexed by
            negative slot first -- presumably; TODO confirm against the dataset).

    Returns:
        np.ndarray of float, one accuracy per entry of args.T_list. This is a
        per-sample mean, so a short final batch is not over-weighted (the
        previous per-batch mean over-weighted it).

    Raises:
        ValueError: unsupported args.test_mode, or the dataloader yields no samples.
    """
    text_model.eval()
    molecule_model.eval()
    text2latent.eval()
    mol2latent.eval()

    # Derive these from args instead of reading the module-level `test_mode` /
    # `T_max`, which are only assigned in a later notebook cell.
    test_mode = args.test_mode
    T_max = max(args.T_list) - 1  # negatives needed by the largest T
    if test_mode not in ("given_text", "given_molecule"):
        raise ValueError("Unsupported test_mode: {}".format(test_mode))

    def encode_smiles(smiles):
        # SMILES strings -> latent vectors via get_molecule_repr_MoleculeSTM + mol2latent.
        return get_molecule_repr_MoleculeSTM(
            list(smiles), mol2latent=mol2latent,
            molecule_type="SMILES", MegaMolBART_wrapper=MegaMolBART_wrapper)

    correct_counts = np.zeros(len(args.T_list), dtype=np.int64)
    num_samples = 0
    for batch in (tqdm(dataloader) if args.verbose else dataloader):
        text, molecule_data, neg_text, neg_molecule_data = batch[0], batch[1], batch[2], batch[3]

        text_repr = get_text_repr(text)
        molecule_repr = encode_smiles(molecule_data)

        if test_mode == "given_text":
            query_repr, pos_repr = text_repr, molecule_repr
            neg_repr = torch.stack([encode_smiles(neg_molecule_data[idx]) for idx in range(T_max)])
        else:  # "given_molecule"
            query_repr, pos_repr = molecule_repr, text_repr
            neg_repr = torch.stack([get_text_repr(neg_text[idx]) for idx in range(T_max)])

        batch_size = query_repr.size(0)
        for T_idx, T in enumerate(args.T_list):
            # T candidates in total: the positive plus the first T - 1 negatives.
            _, _, acc = do_CL_eval(query_repr, pos_repr, neg_repr[:T - 1], args)
            # acc == correct / batch_size; round back to the exact integer count.
            correct_counts[T_idx] += int(round(acc * batch_size))
        num_samples += batch_size

    if num_samples == 0:
        raise ValueError("eval_epoch received an empty dataloader")
    return correct_counts / num_samples

## Start Retrieval

In [8]:
# Move every network to the compute device chosen in the seed cell.
text_model = text_model.to(device)
molecule_model = molecule_model.to(device)
text2latent = text2latent.to(device)
mol2latent = mol2latent.to(device)

# Each query is ranked against T - 1 negatives, so the dataset must supply
# enough negatives for the largest T in args.T_list.
T_max = max(args.T_list) - 1

test_mode = args.test_mode
dataset_folder = os.path.join(args.dataspace_path, "DrugBank_data")

# Only the SMILES retrieval path is wired up here (args.molecule_type is not consulted).
dataset_class = DrugBank_Datasets_SMILES_retrieval
dataloader_class = torch_DataLoader

# Text-file template per --task; "{}" is presumably filled with the split name
# passed to the dataset (here 'full').
task_to_template = {
    "molecule_description": "SMILES_description_{}.txt",
    "molecule_description_removed_PubChem": "SMILES_description_removed_from_PubChem_{}.txt",
    "molecule_description_Raw": "SMILES_description_{}_Raw.txt",
    "molecule_description_removed_PubChem_Raw": "SMILES_description_removed_from_PubChem_{}_Raw.txt",
    "molecule_pharmacodynamics": "SMILES_pharmacodynamics_{}.txt",
    "molecule_pharmacodynamics_removed_PubChem": "SMILES_pharmacodynamics_removed_from_PubChem_{}.txt",
    "molecule_pharmacodynamics_Raw": "SMILES_pharmacodynamics_{}_Raw.txt",
    "molecule_pharmacodynamics_removed_PubChem_Raw": "SMILES_pharmacodynamics_removed_from_PubChem_{}_Raw.txt",
}
if args.task not in task_to_template:
    # Previously an unknown task left `template` undefined and failed later with a NameError.
    raise ValueError("Unsupported task: {}".format(args.task))
template = task_to_template[args.task]

full_dataset = dataset_class(dataset_folder, 'full', neg_sample_size=T_max, template=template)
# NOTE: the original author reported that the program can block with a non-zero
# num_workers; if this cell hangs, set args.num_workers = 0.
full_dataloader = dataloader_class(
    full_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)

initial_test_acc_list = eval_epoch(full_dataloader)
print('Results', initial_test_acc_list)

Loading negative samples from ../data/DrugBank_data/index/SMILES_description_full.txt
Results [0.94256757 0.89864865 0.84797297]
