In [1]:
# %pip install deepchem scikit-learn matplotlib pandas sentence_transformers

## Note: These fine-tuning notebooks do not reproduce the exact results reported in the paper; please follow the settings described in the paper to reproduce them.

In [1]:
import logging
import pickle
import warnings
from typing import List, Dict, Tuple, Type, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import transformers
from deepchem.molnet import load_clintox, load_tox21, load_bace_classification, load_bbbp
from numpy import ndarray
from sklearn.decomposition import KernelPCA
from sklearn.metrics import roc_curve, auc, roc_auc_score, f1_score
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import PredefinedSplit
from sklearn.model_selection import train_test_split
from sklearn.multiclass import OneVsRestClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import normalize
from sklearn.svm import SVC
from sklearn.utils import class_weight
from torch import Tensor, device
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer

from data_reader import DataReader
warnings.filterwarnings("ignore")

In [14]:

logger = logging.getLogger(__name__)


class CBERT(object):
    """
    Text encoder (used in this notebook for SMILES strings) built on a
    Hugging Face tokenizer/model pair.

    The class loads the tokenizer and model with ``AutoTokenizer`` /
    ``AutoModel`` and exposes :meth:`encode`, which maps one string or a
    collection of strings to pooled embedding vectors.
    """

    def __init__(self, model_name_or_path: str,
                 device: str = None,
                 num_cells: int = 100,
                 num_cells_in_search: int = 10,
                 pooler=None):
        """
        Load the tokenizer and model and choose the pooling strategy.

        Args:
            model_name_or_path: Passed unchanged to
                ``AutoTokenizer.from_pretrained`` and ``AutoModel.from_pretrained``.
            device: Default device for :meth:`encode`. When ``None``, ``"cuda"``
                is used if ``torch.cuda.is_available()``, otherwise ``"cpu"``.
            num_cells: Stored on the instance; not used by the code in this class.
            num_cells_in_search: Stored on the instance; not used by the code in
                this class.
            pooler: ``"cls"`` or ``"cls_before_pooler"``. When ``None``,
                ``"cls_before_pooler"`` is selected if ``"unsup"`` occurs in
                ``model_name_or_path``, otherwise ``"cls"``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.model = AutoModel.from_pretrained(model_name_or_path)
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device

        # Index-related attributes are initialised here but never read by this class.
        self.index = None
        self.is_faiss_index = False
        self.num_cells = num_cells
        self.num_cells_in_search = num_cells_in_search

        if pooler is not None:
            self.pooler = pooler
        elif "unsup" in model_name_or_path:
            logger.info("Use `cls_before_pooler` for unsupervised models. If you want to use other pooling policy, specify `pooler` argument.")
            self.pooler = "cls_before_pooler"
        else:
            self.pooler = "cls"

    def encode(self, sentence: Union[str, List[str]],
               device: str = None,
               return_numpy: bool = False,
               normalize_to_unit: bool = True,
               keepdim: bool = False,
               batch_size: int = 64,
               max_length: int = 128) -> Union[ndarray, Tensor]:
        """
        Embed one string or an iterable of items.

        Args:
            sentence: A single string, or an iterable whose items are converted
                with ``str()`` before tokenisation.
            device: Device to run on; defaults to ``self.device``. The model is
                moved to this device as a side effect.
            return_numpy: Return a NumPy array instead of a ``torch.Tensor``.
            normalize_to_unit: Divide every embedding by its L2 norm.
            keepdim: For a single-string input, keep the leading batch
                dimension instead of returning a 1-D vector.
            batch_size: Number of items tokenised and encoded per forward pass.
            max_length: Truncation length handed to the tokenizer.

        Returns:
            A CPU tensor (or NumPy array) with one row per input item; a 1-D
            vector for a single-string input unless ``keepdim`` is true.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
            NotImplementedError: If ``self.pooler`` is not a supported value.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")

        # The string check has to come BEFORE the list coercion: iterating a
        # str yields its characters, which would embed every character of a
        # single SMILES as its own "sentence".
        single_sentence = isinstance(sentence, str)
        if single_sentence:
            sentence = [sentence]
        else:
            sentence = [str(smile) for smile in sentence]

        target_device = self.device if device is None else device
        self.model = self.model.to(target_device)

        embedding_list = []
        with torch.no_grad():
            # One iteration per batch; the final batch may be shorter.
            for start in tqdm(range(0, len(sentence), batch_size)):
                inputs = self.tokenizer(
                    sentence[start:start + batch_size],
                    padding=True,
                    truncation=True,
                    max_length=max_length,
                    return_tensors="pt"
                )
                inputs = {k: v.to(target_device) for k, v in inputs.items()}
                outputs = self.model(**inputs, return_dict=True)
                if self.pooler == "cls":
                    embeddings = outputs.pooler_output
                elif self.pooler == "cls_before_pooler":
                    # First-token hidden state, taken before the model's pooler.
                    embeddings = outputs.last_hidden_state[:, 0]
                else:
                    raise NotImplementedError
                if normalize_to_unit:
                    embeddings = embeddings / embeddings.norm(dim=1, keepdim=True)
                # Collect on CPU so accelerator memory is released per batch.
                embedding_list.append(embeddings.cpu())
        embeddings = torch.cat(embedding_list, 0)

        if single_sentence and not keepdim:
            embeddings = embeddings[0]

        if return_numpy and not isinstance(embeddings, ndarray):
            return embeddings.numpy()
        return embeddings

In [15]:
# Name (or path) of the domain-adapted encoder checkpoint; adjust it to match
# the dataset that was used for domain adaptation.
model_name_or_path = "emtrl/simcse-smole-bert-muv-mlm"
# Module-level encoder instance; load_dataset() reads this global.
# NOTE(review): this name does not contain "unsup" and no `pooler` is passed, so
# CBERT selects the "cls" pooler and encode() reads `outputs.pooler_output`. The
# load warning printed for this cell lists bert.pooler.dense.* as newly
# initialised, so the embeddings presumably pass through untrained pooler
# weights. Consider pooler="cls_before_pooler" -- TODO confirm against the paper.
encoder = CBERT(model_name_or_path)

Some weights of BertModel were not initialized from the model checkpoint at emtrl/simcse-smole-bert-muv-mlm and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [16]:
def load_dataset(dataset_name):
    """
    Read a dataset through ``DataReader`` and embed the SMILES of every split.

    Uses the module-level ``encoder`` to compute the embeddings.

    Args:
        dataset_name: Name forwarded to ``DataReader``.

    Returns:
        A 6-tuple ``(X_train, y_train, X_valid, y_valid, X_test, y_test)``.
        Each ``X_*`` is the result of ``encoder.encode`` on that split's
        ``smiles`` attribute; each ``y_*`` is that split's ``y`` attribute,
        returned unchanged.
    """
    dataset = DataReader(dataset_name)
    train_split = dataset.train_dataset
    valid_split = dataset.valid_dataset
    test_split = dataset.test_dataset

    print(f"Loading and embedding SMILES for dataset {dataset_name}")
    return (
        encoder.encode(train_split.smiles), train_split.y,
        encoder.encode(valid_split.smiles), valid_split.y,
        encoder.encode(test_split.smiles), test_split.y,
    )

In [17]:
def train_and_evaluate_model(X_train, y_train, X_valid, y_valid, X_test, y_test):
    """
    Grid-search a one-vs-rest SVM on the training data and score it on the test data.

    Model selection uses 3-fold cross-validation on the training split with
    ``roc_auc`` scoring. ``X_valid`` and ``y_valid`` are accepted for interface
    symmetry with ``load_dataset`` but are not used here.

    Args:
        X_train, y_train: Training features and labels passed to ``GridSearchCV.fit``.
        X_valid, y_valid: Unused.
        X_test, y_test: Held-out features and labels used for the returned score.

    Returns:
        The value of ``roc_auc_score`` on the test split. When ``y_test`` is
        1-D or has a trailing dimension of 1, the score is computed from
        column 1 of the predicted probabilities; otherwise the full
        probability matrix is passed.
    """
    # ps = PredefinedSplit(test_fold)
    print("Training Classifier")
    param_grid = {
        'estimator__class_weight': ['balanced'],
        'estimator__kernel': ['rbf', 'sigmoid'],
        'estimator__C': [1, 0.5, 0.25],
        'estimator__gamma': ['auto', 'scale'],
    }
    base_classifier = OneVsRestClassifier(SVC(probability=True, random_state=23))
    search = GridSearchCV(base_classifier, param_grid,
                          cv=3, scoring='roc_auc', n_jobs=-1)
    search.fit(X_train, y_train)
    test_proba = search.predict_proba(X_test)

    # Single-task labels arrive either as (n,) or (n, 1).
    label_shape = np.array(y_test).shape
    if len(label_shape) == 1 or label_shape[-1] == 1:
        return roc_auc_score(y_test, test_proba[:, 1])
    return roc_auc_score(y_test, test_proba)

In [18]:
def evaluate_dataset(dataset_name):
    """
    Embed a dataset, fit and score the classifier, and print the test AUROC.

    Args:
        dataset_name: Name handed to ``load_dataset``.
    """
    embedded_splits = load_dataset(dataset_name=dataset_name)
    # load_dataset yields (X_train, y_train, X_valid, y_valid, X_test, y_test),
    # which is exactly the positional order train_and_evaluate_model expects.
    roc_score = train_and_evaluate_model(*embedded_splits)

    # NOTE(review): `:2f` is a minimum field width of 2 with the default six
    # decimals; `.2f` may have been intended -- confirm before changing.
    print(f"The AUROC score for dataset {dataset_name} is {roc_score:2f}")

## Evaluate MoleculeNet Datasets

In [None]:
# Evaluate each benchmark in turn, printing a row of asterisks between runs.
for _position, _benchmark in enumerate(("clintox", "bace", "bbbp", "tox21")):
    if _position > 0:
        print(f"\n{'*'*100}\n")
    evaluate_dataset(dataset_name=_benchmark)

Loading and embedding SMILES for dataset clintox


100%|███████████████████████████████████████████| 19/19 [00:21<00:00,  1.13s/it]
100%|█████████████████████████████████████████████| 3/3 [00:03<00:00,  1.04s/it]
100%|█████████████████████████████████████████████| 3/3 [00:03<00:00,  1.02s/it]


Training Classifier
The AUROC score for dataset clintox is 0.981275

****************************************************************************************************

Loading and embedding SMILES for dataset bace


100%|███████████████████████████████████████████| 19/19 [00:17<00:00,  1.09it/s]
100%|█████████████████████████████████████████████| 3/3 [00:02<00:00,  1.40it/s]
100%|█████████████████████████████████████████████| 3/3 [00:03<00:00,  1.06s/it]


Training Classifier
The AUROC score for dataset bace is 0.776268

****************************************************************************************************

Loading and embedding SMILES for dataset bbbp


100%|███████████████████████████████████████████| 26/26 [00:26<00:00,  1.02s/it]
100%|█████████████████████████████████████████████| 4/4 [00:04<00:00,  1.12s/it]
100%|█████████████████████████████████████████████| 4/4 [00:04<00:00,  1.04s/it]


Training Classifier
The AUROC score for dataset bbbp is 0.712593

****************************************************************************************************





Loading and embedding SMILES for dataset tox21


100%|███████████████████████████████████████████| 98/98 [01:12<00:00,  1.36it/s]
 69%|██████████████████████████████▍             | 9/13 [00:09<00:04,  1.05s/it]