In [1]:
from rdkit import DataStructs, Chem, RDLogger
from rdkit.Chem import rdFingerprintGenerator
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from datasets import load_dataset
from sklearn.model_selection._split import _BaseKFold
from typing import List
from transformers import AutoTokenizer, PreTrainedModel, AutoModelForMaskedLM
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
from pathlib import Path

In [2]:
# Module-level RDKit logger; ClusterData uses it to emit warnings.
logger=RDLogger.logger()

In [3]:
def EuclideanDist(pi, pj):
  """Euclidean (L2) distance between two points given as equal-length sequences.

  The square root is taken of the *sum* of squared coordinate differences.
  The previous form, ``np.sqrt(dv * dv).sum()``, summed absolute differences
  and therefore returned the Manhattan (L1) distance.
  """
  dv = np.array(pi) - np.array(pj)
  return float(np.sqrt(np.sum(dv * dv)))

def ClusterData(data,nPts,distThresh,isDistData=False,distFunc=EuclideanDist):
  """Butina-style clustering of ``nPts`` points by a distance threshold.

  **Arguments**

    - data: either a list of points (passed pairwise to ``distFunc``) or,
      when ``isDistData`` is true, a flat lower-triangle distance list laid
      out row by row, i.e. the order produced by::

          for i in range(nPts):
            for j in range(i):
              dists.append(dist(i, j))

    - nPts: number of points to cluster

    - distThresh: two points whose distance is <= this value are neighbours

    - isDistData: treat ``data`` as the flat distance list described above.
      A list longer than nPts*(nPts-1)/2 only triggers a logged warning.

    - distFunc: callable taking two points and returning their distance;
      ignored when ``isDistData`` is true

  **Returns**

    A tuple of tuples, one per cluster. Clusters are formed greedily: points
    are visited in descending order of total neighbour count (ties: higher
    index first); each not-yet-assigned point becomes a centroid and absorbs
    its still-unassigned neighbours. The centroid is the first element of
    each cluster tuple.
  """
  if isDistData and len(data) > nPts * (nPts - 1) / 2:
    logger.warning("Distance matrix is too long")

  # Symmetric adjacency lists of "within threshold" pairs.
  neighbours = [[] for _ in range(nPts)]
  flatPos = 0
  for row in range(nPts):
    for col in range(row):
      if isDistData:
        d = data[flatPos]
        flatPos += 1
      else:
        d = distFunc(data[row], data[col])
      if d <= distThresh:
        neighbours[row].append(col)
        neighbours[col].append(row)

  # (count, index) tuples are all distinct, so a reverse sort gives the same
  # order as an ascending sort followed by reversal.
  visitOrder = sorted(((len(nb), pos) for pos, nb in enumerate(neighbours)), reverse=True)

  assigned = [0] * nPts
  clusters = []
  for _, centre in visitOrder:
    if assigned[centre]:
      continue
    members = [centre]
    for nb in neighbours[centre]:
      if not assigned[nb]:
        members.append(nb)
        assigned[nb] = 1
    clusters.append(tuple(members))
  return tuple(clusters)

def taylor_butina_clustering(fp_list: List[DataStructs.ExplicitBitVect], cutoff: float = 0.65) -> List[int]:
  """
  Assign a Butina cluster id to every fingerprint.

  Pairwise Tanimoto similarities are computed with RDKit's
  ``DataStructs.BulkTanimotoSimilarity``, converted to distances
  (1 - similarity), and clustered with the local ``ClusterData`` function.

  Parameters:
  ----------
  fp_list: a list of fingerprints
  cutoff: distance cutoff (1 - Tanimoto similarity); pairs at or below it are neighbours

  Returns:
  -------
  A list with one integer cluster id per input fingerprint, in input order.
  Ids are the positions of the clusters in ClusterData's output.
  """
  n_fps = len(fp_list)
  # Flat lower triangle, row by row: the layout ClusterData expects with isDistData=True.
  lower_triangle = [
      1 - sim
      for row in range(1, n_fps)
      for sim in DataStructs.BulkTanimotoSimilarity(fp_list[row], fp_list[:row])
  ]
  clusters = ClusterData(lower_triangle, n_fps, cutoff, isDistData=True)

  labels = np.zeros(n_fps, dtype=int)
  for label, members in enumerate(clusters):
    labels[list(members)] = label
  return labels.tolist()

def get_butina_clusters(smiles_list: List[str], cutoff: float = 0.65, radius: int = 2, fp_size: int = 1024) -> List[int]:
  """
  Cluster a list of SMILES strings using the Butina clustering algorithm.

  Each SMILES is parsed with RDKit, encoded as a Morgan fingerprint, and the
  fingerprints are clustered on Tanimoto distance by taylor_butina_clustering.

  Parameters:
  ----------
  smiles_list: List of SMILES strings (any iterable of str, e.g. a pandas Series)
  cutoff: Tanimoto-distance cutoff (1 - similarity) used for clustering
  radius: Morgan fingerprint radius (default 2)
  fp_size: number of fingerprint bits (default 1024)

  Returns:
  -------
  List of cluster labels corresponding to each SMILES string in the input list.

  Raises:
  ------
  ValueError: if any SMILES string cannot be parsed by RDKit.
  """
  smiles_list = list(smiles_list)
  mol_list = [Chem.MolFromSmiles(x) for x in smiles_list]
  # MolFromSmiles returns None on a parse failure; passing None on to the
  # fingerprint generator would fail without saying which input was bad.
  bad = [i for i, mol in enumerate(mol_list) if mol is None]
  if bad:
    examples = ', '.join(f'{i}: {smiles_list[i]!r}' for i in bad[:5])
    raise ValueError(f'{len(bad)} SMILES could not be parsed (first: {examples})')
  fg = rdFingerprintGenerator.GetMorganGenerator(radius=radius, fpSize=fp_size)
  fp_list = [fg.GetFingerprint(mol) for mol in mol_list]
  return taylor_butina_clustering(fp_list, cutoff=cutoff)

In [4]:
# Class for the MLM-finetuned model
class MoLFormerMLMWithRegressionHead(PreTrainedModel):
    """Single-output linear regression head on an MLM-finetuned MoLFormer backbone.

    The backbone's last-layer hidden state at sequence position 0 is mapped by
    one ``nn.Linear(hidden_size, 1)`` layer to a scalar per sequence.
    """

    def __init__(self, pretrained_model, config=None):
        # Default to the backbone's own configuration.
        config = pretrained_model.config if config is None else config
        super().__init__(config)
        self.backbone = pretrained_model
        self.regression_head = nn.Linear(self.backbone.config.hidden_size, 1)
        self.config = config

    def forward(self, input_ids, attention_mask=None):
        """Return one prediction per sequence (the trailing size-1 dim is squeezed)."""
        backbone_out = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,  # the MLM model only returns hidden states on request
        )
        # Final layer, first position of every sequence (rank-3 indexing).
        first_token = backbone_out.hidden_states[-1][:, 0, :]
        return self.regression_head(first_token).squeeze(-1)

In [5]:
# Data Variables

# Tab-separated training file; the code below uses its 'SMILES' and 'Label' columns.
EXTERNAL_DATA_LOCATION = '/home/sunag/Documents/NNTI/Project/task2/new_train_data.tsv'
ext_data = pd.read_csv(EXTERNAL_DATA_LOCATION, sep='\t')
# Normalise the target column name to lower-case 'label'.
ext_data.rename(columns={'Label': 'label'}, inplace=True)

# Only this one dataset is used. `data` is an alias (not a copy), so the
# cluster column added below also appears on `ext_data`.
data = ext_data

# Butina cluster id per molecule, used as the grouping key for the CV split.
data["ButinaCluster"] = get_butina_clusters(data["SMILES"])

In [6]:
# Summary statistics of the numeric columns (target label and Butina cluster id).
data.describe()

Unnamed: 0,label,ButinaCluster
count,4242.0,4242.0
mean,2.187506,205.008015
std,1.198735,262.898946
min,-1.5,0.0
25%,1.42,23.0
50%,2.36,83.0
75%,3.1,284.0
max,4.5,1056.0


In [6]:
class GroupKFoldShuffleTV(_BaseKFold):
    """
    Group-aware K-fold cross-validator with optional shuffling,
    providing train, validation, and test indices.

    The unique groups are (optionally shuffled and) divided into ``n_splits``
    contiguous chunks. Fold ``i`` uses chunk ``i`` as test and chunk
    ``(i + 1) % n_splits`` as validation; all other groups are training.
    A group therefore never spans more than one of train/validation/test.

    Parameters:
    ----------
    n_splits : int, default=5
        Number of folds. Must be at least 3 (train, validation, test).

    shuffle : bool, default=False
        Whether to shuffle the groups before splitting into folds.

    random_state : int, RandomState instance or None, default=None
        Controls the randomness of the shuffling. Pass an int for reproducible
        output across multiple function calls. Ignored if `shuffle=False`.
        A RandomState instance is used directly (and advanced) as the generator.
    """
    def __init__(self, n_splits=5, *, shuffle=False, random_state=None):
        if n_splits < 3:
            raise ValueError("n_splits must be at least 3 to provide train, validation, and test splits.")
        super().__init__(n_splits=n_splits, shuffle=shuffle, random_state=random_state)

    def split(self, X, y=None, groups=None):
        """Yield ``(train_idx, val_idx, test_idx)`` positional index arrays per fold.

        ``X`` and ``y`` are accepted for scikit-learn API compatibility and are
        not used.

        Raises
        ------
        ValueError
            If ``groups`` is None, or there are fewer unique groups than
            ``n_splits``; otherwise some folds would have empty test/validation sets.
        """
        if groups is None:
            raise ValueError("Groups must be provided for group-aware splitting.")

        groups = np.asarray(groups)
        unique_groups = np.unique(groups)
        if len(unique_groups) < self.n_splits:
            raise ValueError(
                f"Cannot have number of splits n_splits={self.n_splits} greater "
                f"than the number of groups: {len(unique_groups)}."
            )

        if self.shuffle:
            # np.random.RandomState() cannot be seeded with another RandomState,
            # so honour the documented "RandomState instance" option explicitly.
            if isinstance(self.random_state, np.random.RandomState):
                rng = self.random_state
            else:
                rng = np.random.RandomState(self.random_state)
            rng.shuffle(unique_groups)

        group_chunks = np.array_split(unique_groups, self.n_splits)

        for i in range(self.n_splits):
            test_mask = np.isin(groups, group_chunks[i])
            val_mask = np.isin(groups, group_chunks[(i + 1) % self.n_splits])  # Next fold for validation
            train_mask = ~(test_mask | val_mask)

            yield np.flatnonzero(train_mask), np.flatnonzero(val_mask), np.flatnonzero(test_mask)

In [7]:
# Keep only the first of the five group-aware folds. Molecules from the same
# Butina cluster never appear in more than one of train/validation/test.
splitter = GroupKFoldShuffleTV(n_splits=5, shuffle=True, random_state=42)
fold_iter = enumerate(splitter.split(X=data, groups=data["ButinaCluster"]))
fold, (train_idx, val_idx, test_idx) = next(fold_iter)
split_dict = {'train': train_idx, 'validation': val_idx, 'test': test_idx}

In [8]:
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset of ``(SMILES string, regression target)`` pairs.

    NOTE(review): this name shadows the imported ``torch.utils.data.Dataset``.
    It is kept because ``Trainer`` instantiates ``Dataset``. The base class is
    referenced through ``torch`` so re-running this cell does not subclass the
    previous definition.

    Parameters
    ----------
    data : pd.DataFrame
        Must contain a 'SMILES' column and a target column named 'label'
        (or, for frames that were not renamed, 'Label').
    """
    def __init__(self, data):
        self.smiles = data['SMILES'].values
        # The loading cell renames 'Label' -> 'label', so reading only 'Label'
        # raised KeyError. Prefer 'label' and fall back to 'Label'.
        label_col = 'label' if 'label' in data.columns else 'Label'
        self.label = data[label_col].values

    def __len__(self):
        return len(self.smiles)

    def __getitem__(self, idx):
        return self.smiles[idx], self.label[idx]

In [None]:
class Trainer:
    """Fine-tune a SMILES regression model and evaluate it on a fixed split.

    Intended call order: ``Trainer(...)`` -> ``train(params)`` -> ``test()`` ->
    ``save(dir)``. ``test()`` needs the loaders/criterion created by ``train()``,
    and ``save()`` needs the results produced by ``test()``.
    """

    def __init__(self, data: pd.DataFrame, fin_model: nn.Module, split_dict: dict, **kwargs):
        """
        Trainer initialization

        Parameters:
        -----------
        data: pd.DataFrame
            col ['SMILES', 'label']

        fin_model: nn.Module
            - The neural network to be trained in this instance

        split_dict: dict
            - A dictionary of the indices of train, test and validation datasets
              (keys 'train', 'validation', 'test'; used as positional iloc indices)

        **kwargs: Optional
            - model_type (str): ['nofit', 'finefit']. Default is 'nofit'
              NOTE(review): accepted but currently not used by this class.
        """
        print('Data initialization')
        self.train_idx = split_dict['train']
        self.valid_idx = split_dict['validation']
        self.test_idx = split_dict['test']
        print(f'Train Points: {len(self.train_idx)}')
        print(f'Valid Points: {len(self.valid_idx)}')
        print(f'Test_Points: {len(self.test_idx)}')

        # Create training, test and validation datasets
        self.train_df = Dataset(data.iloc[self.train_idx])
        self.valid_df = Dataset(data.iloc[self.valid_idx])
        self.test_df = Dataset(data.iloc[self.test_idx])

        print('Model Initialization')
        MODEL_NAME = "ibm/MoLFormer-XL-both-10pct"
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Base MoLFormer tokenizer; presumably fin_model shares its vocabulary — confirm.
        self.pre_token = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
        self.losses = {'Train': [], 'Valid': []}
        self.model = fin_model.to(self.device)
        print('Model Initialization successful')

    def _prepare_batch(self, smiles, values):
        """Tokenise a batch of SMILES and move inputs and float targets to the device."""
        # The default collate already yields a tensor for numeric targets; as_tensor
        # converts it without the copy-construct warning torch.tensor(tensor) emits.
        values = torch.as_tensor(values, dtype=torch.float, device=self.device)
        tokens = self.pre_token(smiles, padding=True, return_tensors='pt').to(self.device)
        return tokens, values

    def _evaluate(self, loader, desc):
        """Run the model over ``loader`` in eval mode with gradients disabled.

        Returns ``(mean batch loss, predictions, targets)``; the two lists are
        aligned element-wise. The mean is NaN for an empty loader.
        """
        self.model.eval()
        total_loss = 0.0
        preds, trues = [], []
        with torch.no_grad(), tqdm(total=len(loader)) as pbar:
            pbar.set_description(desc)
            for smiles, values in loader:
                tokens, values = self._prepare_batch(smiles, values)
                predict = self.model(**tokens)
                total_loss += self.criterion(predict, values).item()
                preds.extend(predict.cpu().numpy().tolist())
                trues.extend(values.cpu().numpy().tolist())
                pbar.update(1)
        mean_loss = total_loss / len(loader) if len(loader) else float('nan')
        return mean_loss, preds, trues

    @staticmethod
    def _r2_score(true_values, pred_values):
        """Coefficient of determination over the full set; -inf if targets have zero variance."""
        true_arr = np.asarray(true_values, dtype=float)
        pred_arr = np.asarray(pred_values, dtype=float)
        ss_total = float(np.sum((true_arr - true_arr.mean()) ** 2)) if true_arr.size else 0.0
        ss_residual = float(np.sum((pred_arr - true_arr) ** 2))
        return 1 - ss_residual / ss_total if ss_total > 0 else float('-inf')

    def train(self, train_params: dict):
        """
        Training

        Parameters:
        -----------
        train_params: dict
            - epochs (int): No. of epochs
            - lr (float) : Learning Rate
            - wt_decay (float) : Weight Decay
            - batch_size (int) : Batch Size
            - shuffle (bool, optional) : shuffle the training loader each epoch (default True)

        Records the per-epoch mean batch loss in ``self.losses['Train']`` and
        ``self.losses['Valid']``.
        """
        print('Training Starting...')
        epochs = train_params['epochs']
        lr = train_params['lr']
        wt_decay = train_params['wt_decay']
        batch_size = train_params['batch_size']
        shuffle = train_params.get('shuffle', True)

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr, weight_decay=wt_decay)
        self.criterion = nn.MSELoss()

        self.train_loader = DataLoader(self.train_df, batch_size=batch_size, shuffle=shuffle)
        self.valid_loader = DataLoader(self.valid_df, batch_size=batch_size)
        self.test_loader = DataLoader(self.test_df, batch_size=batch_size)

        for i in range(epochs):
            # Train
            self.model.train()
            train_loss = 0.0
            with tqdm(total=len(self.train_loader)) as pbar:
                pbar.set_description(f'Epoch: {i} - Train')
                for smiles, values in self.train_loader:
                    tokens, values = self._prepare_batch(smiles, values)

                    self.optimizer.zero_grad()
                    predict = self.model(**tokens)
                    loss = self.criterion(predict, values)
                    loss.backward()
                    self.optimizer.step()

                    # .item(): summing the tensor itself would chain autograd history across batches.
                    train_loss += loss.item()
                    pbar.update(1)
            n_batches = len(self.train_loader)
            self.losses['Train'].append(train_loss / n_batches if n_batches else float('nan'))

            # Valid: evaluation only. The previous loop called backward()/step()
            # here, which trained the model on the validation split.
            valid_loss, _, _ = self._evaluate(self.valid_loader, f'Epoch: {i} - Valid')
            self.losses['Valid'].append(valid_loss)

    def test(self):
        """Evaluate on the test split (loss, R²) and collect train-split predictions for plotting.

        Sets ``avg_loss``, ``r2_score``, ``pred_results``/``true_results`` and
        ``Trainpred_results``/``Traintrue_results``. Requires ``train()`` first.
        """
        print('Testing Starting...')

        # Avg loss over all test batches
        self.avg_loss, self.pred_results, self.true_results = self._evaluate(self.test_loader, 'Testing')
        # R² against the mean of the whole test set; summing per-batch SS_tot
        # around each batch's own mean understates the variance.
        self.r2_score = self._r2_score(self.true_results, self.pred_results)
        print(f"Test Error: \n R² Score: {self.r2_score:.4f}")

        # Doing a test with train dataset
        _, self.Trainpred_results, self.Traintrue_results = self._evaluate(
            self.train_loader, 'Testing with Train Dataset'
        )

    def save(self, save_path):
        """Write ``losses.csv`` and four 300-dpi PNG plots into ``save_path`` (created if missing).

        Requires ``train()`` and ``test()`` to have been run.
        """
        path = Path(save_path)
        path.mkdir(parents=True, exist_ok=True)

        # Saving Losses
        self.loss_df = pd.DataFrame(self.losses)
        self.loss_df.to_csv(path / 'losses.csv', index=False)

        self._save_loss_curve(path / 'TrainLoss.png', self.loss_df['Train'], 'Train Loss')
        self._save_loss_curve(path / 'ValidLoss.png', self.loss_df['Valid'], 'Valid Loss')
        self._save_true_vs_pred(path / 'Test_TrueVpred.png', self.true_results, self.pred_results,
                                title=f'Average Test Loss: {self.avg_loss}')
        self._save_true_vs_pred(path / 'Train_TrueVpred.png', self.Traintrue_results, self.Trainpred_results)

    @staticmethod
    def _save_loss_curve(file_path, losses, ylabel):
        """Plot a per-epoch loss series (epochs numbered from 1) and save it."""
        fig, ax = plt.subplots(1, 1)
        ax.plot(range(1, len(losses) + 1), losses)
        ax.set_xlabel('Epochs')
        ax.set_ylabel(ylabel)
        ax.grid(True)
        fig.savefig(file_path, dpi=300)
        plt.close(fig)

    @staticmethod
    def _save_true_vs_pred(file_path, true_values, pred_values, title=None):
        """Save a hexbin density plot of true (x) vs predicted (y) values."""
        fig, ax = plt.subplots(1, 1)
        hb = ax.hexbin(x=true_values, y=pred_values, cmap='inferno', gridsize=50)
        cb = fig.colorbar(hb, ax=ax)
        # Hexbin colour encodes points per bin, not a loss value.
        cb.set_label("Counts")
        if title is not None:
            ax.set_title(title)
        ax.set_xlabel('True values')
        ax.set_ylabel('Predicted values')
        ax.grid(True)
        fig.savefig(file_path, dpi=300)
        plt.close(fig)

In [10]:
# Train parameters: one short epoch of Adam fine-tuning.
train_params = dict(
    epochs=1,          # passes over the training split
    lr=1e-4,           # Adam learning rate
    wt_decay=1e-5,     # Adam weight_decay
    batch_size=128,    # shared by the train/validation/test DataLoaders
)

In [24]:
# Local MLM-finetuned MoLFormer checkpoint (presumably the task-3 output — confirm).
PATH = '/home/sunag/Documents/NNTI/Project/task3/molformerMLM'
pre_model = AutoModelForMaskedLM.from_pretrained(PATH, trust_remote_code=True)
# Wrap the backbone with the scalar regression head; config defaults to the backbone's.
fin_model = MoLFormerMLMWithRegressionHead(pretrained_model=pre_model)

In [None]:
# End-to-end run: build the trainer on the selected split, fit, evaluate on the
# held-out test clusters, then write losses and plots under DS/.
x = Trainer(data=data, fin_model=fin_model, split_dict=split_dict)
x.train(train_params=train_params)
x.test()
x.save(save_path='DS/')