In [1]:
import torch
from transformers import AutoTokenizer, PreTrainedModel, AutoModelForMaskedLM
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
from pathlib import Path

In [2]:
# Class for the MLM-finetuned model
class MoLFormerMLMWithRegressionHead(PreTrainedModel):
    """
    Wrap a masked-LM backbone and add a single-output linear regression head.

    The head reads the final-layer hidden state at position 0 of each sequence
    and maps it to one scalar per sequence.
    """

    def __init__(self, pretrained_model, config=None):
        """
        Parameters
        ----------
        pretrained_model :
            The MLM backbone (loaded with AutoModelForMaskedLM in this notebook).
            Its `config.hidden_size` sets the regression head's input width.
        config : optional
            Model config; falls back to `pretrained_model.config` when omitted.
        """
        config = pretrained_model.config if config is None else config
        super().__init__(config)
        self.backbone = pretrained_model
        width = self.backbone.config.hidden_size
        self.regression_head = nn.Linear(width, 1)
        self.config = config

    def forward(self, input_ids, attention_mask=None):
        """Return one regression value per sequence (the trailing size-1 axis is squeezed)."""
        # The MLM output only carries hidden states when they are requested explicitly
        backbone_out = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        final_layer = backbone_out.hidden_states[-1]
        first_position = final_layer[:, 0, :]
        return self.regression_head(first_position).squeeze(-1)

In [3]:
class Dataset(torch.utils.data.Dataset):
    """
    Map-style dataset yielding (SMILES string, label) pairs.

    Parameters
    ----------
    data : pd.DataFrame
        Must contain 'SMILES' and 'label' columns.

    The class keeps the name `Dataset` because the Trainer instantiates it by
    that name. The base class is referenced through `torch.utils.data`
    explicitly, so re-running this cell does not make the class inherit from
    its own previous definition.
    """

    def __init__(self, data):
        self.smiles = data['SMILES'].to_numpy()
        self.label = data['label'].to_numpy()

    def __len__(self):
        return len(self.smiles)

    def __getitem__(self, idx):
        """Return the (SMILES, label) pair at position `idx`."""
        return self.smiles[idx], self.label[idx]

In [4]:
def collate_fn(batch):
    """
    Split a batch of (SMILES, label) pairs into two parallel lists.

    Intended for use as a DataLoader `collate_fn`. It returns a tuple
    `(smiles_list, label_list)`.
    """
    smiles_list, label_list = map(list, zip(*batch))
    return smiles_list, label_list

In [5]:
class Trainer:
    """
    Train and evaluate a MoLFormer backbone with a scalar regression head.

    The input frame is shuffled and split 70/20/10 into train, validation and
    test sets. The shuffle is seeded unless `seed=None`. `train()` fits the
    model, `test()` collects predictions on the test and train splits, and
    `save()` writes the loss history and diagnostic plots. Call them in that
    order.
    """

    # Tokenizer used unless `tokenizer_name` is passed to the constructor
    DEFAULT_TOKENIZER = "ibm/MoLFormer-XL-both-10pct"

    def __init__(self, data: pd.DataFrame, model_path: str, **kwargs):
        """
        Trainer initialization

        Parameters:
        -----------
        data: pd.DataFrame
            col ['SMILES', 'label']

        model_path: str
            Path (or hub id) of the pre-trained MLM checkpoint to load.

        **kwargs: Optional
            - model_type (str): ['nofit', 'finefit']. Default is 'nofit'.
              'nofit' freezes the whole backbone. 'finefit' trains only the
              backbone's bias terms (BitFit). The regression head is always
              trainable.
            - seed (int | None): seed for the data shuffle, the global torch RNG
              (regression-head initialisation) and train-loader shuffling.
              Default is 42. None disables seeding.
            - tokenizer_name (str): tokenizer to load.
              Default is "ibm/MoLFormer-XL-both-10pct".

        Raises:
        -------
        ValueError
            If `model_type` is not one of the supported values.
        """
        # Defaults, overridden by any user-supplied kwargs
        self.default_kwargs = {
            'model_type': 'nofit',
            'seed': 42,
            'tokenizer_name': self.DEFAULT_TOKENIZER,
        }
        self.default_kwargs.update(kwargs)
        for key, value in self.default_kwargs.items():
            setattr(self, key, value)

        # Validate before any expensive loading so a typo fails fast
        builders = {'nofit': self._nofit, 'finefit': self._bitfit}
        if self.model_type not in builders:
            raise ValueError(
                f"model_type must be one of {sorted(builders)}, got {self.model_type!r}"
            )

        if self.seed is not None:
            # Makes the regression-head initialisation reproducible
            torch.manual_seed(self.seed)

        print('Data initialization')
        # Shuffle before the positional split. The notebook concatenates two
        # sources, so an unshuffled split would send them to different subsets.
        data = data.sample(frac=1, random_state=self.seed).reset_index(drop=True)
        train_len = int(len(data) * 0.7)
        valid_len = int(len(data) * 0.2)
        test_len = len(data) - train_len - valid_len  # the remainder, so no row is dropped
        print(f'Train Points: {train_len}')
        print(f'Valid Points: {valid_len}')
        print(f'Test_Points: {test_len}')

        self.train_df = Dataset(data.iloc[:train_len])
        self.valid_df = Dataset(data.iloc[train_len:train_len + valid_len])
        self.test_df = Dataset(data.iloc[train_len + valid_len:])

        print('Model Initialization')
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.pre_model = AutoModelForMaskedLM.from_pretrained(model_path, trust_remote_code=True)
        self.pre_token = AutoTokenizer.from_pretrained(self.tokenizer_name, trust_remote_code=True)
        self.losses = {'Train': [], 'Valid': []}

        self.model = builders[self.model_type](self.pre_model).to(self.device)

    def train(self, train_params: dict):
        """
        Training

        Builds the optimiser, loss and data loaders, then runs `epochs` rounds.
        Each round is one optimisation pass over the train split followed by a
        no-gradient pass over the validation split. The mean per-batch losses
        are appended to `self.losses`.

        Parameters:
        -----------
        train_params: dict
            - epochs (int): No. of epochs
            - lr (float) : Learning Rate
            - wt_decay (float) : Weight Decay
            - batch_size (int) : Batch Size
        """
        print('Training Starting...')
        epochs = train_params['epochs']
        lr = train_params['lr']
        wt_decay = train_params['wt_decay']
        batch_size = train_params['batch_size']

        # Frozen backbone weights never receive gradients, so only the
        # parameters that can actually change are handed to the optimiser
        trainable = [p for p in self.model.parameters() if p.requires_grad]
        self.optimizer = torch.optim.Adam(trainable, lr=lr, weight_decay=wt_decay)
        self.criterion = nn.MSELoss()

        generator = None
        if self.seed is not None:
            generator = torch.Generator().manual_seed(self.seed)
        # collate_fn keeps SMILES as plain strings and labels as a plain list
        self.train_loader = DataLoader(self.train_df, batch_size=batch_size, shuffle=True,
                                       collate_fn=collate_fn, generator=generator)
        self.valid_loader = DataLoader(self.valid_df, batch_size=batch_size, collate_fn=collate_fn)
        self.test_loader = DataLoader(self.test_df, batch_size=batch_size, collate_fn=collate_fn)

        for i in range(epochs):
            # Train
            self.model.train()
            train_loss = 0.0
            for smiles, values in tqdm(self.train_loader, desc=f'Epoch: {i} - Train'):
                inputs, targets = self._encode(smiles, values)
                self.optimizer.zero_grad()
                loss = self.criterion(self.model(**inputs), targets)
                loss.backward()
                self.optimizer.step()
                # .item() detaches the value. Summing the tensor itself would
                # keep every batch's autograd graph alive for the whole epoch.
                train_loss += loss.item()
            self.losses['Train'].append(self._batch_mean(train_loss, self.train_loader))

            # Valid: evaluation only, with no gradients and no parameter updates
            _, _, valid_loss = self._predict(self.valid_loader, f'Epoch: {i} - Valid')
            self.losses['Valid'].append(valid_loss)

    def test(self):
        """
        Evaluate on the test split and re-predict the training split.

        Sets `pred_results`, `true_results` and `avg_loss` (mean per-batch loss
        on the test split), plus `Trainpred_results` and `Traintrue_results`.

        Raises:
        -------
        RuntimeError
            If called before `train()`, which builds the loaders and the loss.
        """
        if not hasattr(self, 'test_loader'):
            raise RuntimeError('Call train() before test().')
        print('Testing Starting...')
        self.pred_results, self.true_results, self.avg_loss = self._predict(
            self.test_loader, 'Testing')
        self.Trainpred_results, self.Traintrue_results, _ = self._predict(
            self.train_loader, 'Testing with Train Dataset')

    def save(self, save_path):
        """
        Write losses.csv and four PNG diagnostics into `save_path`.

        The directory is created if needed.

        Raises:
        -------
        RuntimeError
            If called before `test()`, which produces the predictions plotted here.
        """
        if not hasattr(self, 'pred_results'):
            raise RuntimeError('Call test() before save().')
        path = Path(save_path)
        path.mkdir(parents=True, exist_ok=True)

        # Saving Losses
        self.loss_df = pd.DataFrame(self.losses)
        self.loss_df.to_csv(path / 'losses.csv', index=False)

        self._plot_loss(self.loss_df['Train'], 'Train Loss', path / 'TrainLoss.png')
        self._plot_loss(self.loss_df['Valid'], 'Valid Loss', path / 'ValidLoss.png')
        self._plot_parity(self.true_results, self.pred_results, path / 'Test_TrueVpred.png',
                          title=f'Average Test Loss: {self.avg_loss}')
        self._plot_parity(self.Traintrue_results, self.Trainpred_results,
                          path / 'Train_TrueVpred.png')

    def _encode(self, smiles, values):
        """Tokenise a batch of SMILES and move the inputs and float targets to the device."""
        targets = torch.tensor(values, dtype=torch.float, device=self.device)
        inputs = self.pre_token(list(smiles), padding=True, return_tensors='pt').to(self.device)
        return inputs, targets

    def _predict(self, loader, desc):
        """
        Run the model over `loader` in eval mode without gradients.

        Returns (predictions, true values, mean per-batch loss). The mean is
        NaN for an empty loader.
        """
        self.model.eval()
        preds, trues, total = [], [], 0.0
        with torch.no_grad():
            for smiles, values in tqdm(loader, desc=desc):
                inputs, targets = self._encode(smiles, values)
                predict = self.model(**inputs)
                total += self.criterion(predict, targets).item()
                preds.extend(predict.cpu().tolist())
                trues.extend(targets.cpu().tolist())
        return preds, trues, self._batch_mean(total, loader)

    @staticmethod
    def _batch_mean(total, loader):
        """Average a summed per-batch loss over the number of batches (NaN if there are none)."""
        return float(total / len(loader)) if len(loader) else float('nan')

    @staticmethod
    def _plot_loss(losses, ylabel, out_file):
        """Line plot of a per-epoch loss series (epochs numbered from 1), saved to `out_file`."""
        fig, ax = plt.subplots(1, 1)
        ax.plot(range(1, len(losses) + 1), losses)
        ax.set_xlabel('Epochs')
        ax.set_ylabel(ylabel)
        ax.grid(True)
        fig.savefig(out_file, dpi=300)
        plt.close(fig)

    @staticmethod
    def _plot_parity(true_values, pred_values, out_file, title=None):
        """Hexbin of predicted vs true values, saved to `out_file`."""
        fig, ax = plt.subplots(1, 1)
        hb = ax.hexbin(x=true_values, y=pred_values, cmap='inferno', gridsize=50)
        cb = fig.colorbar(hb, ax=ax)
        # Without a `C` argument, hexbin colours each cell by how many points fall in it
        cb.set_label('Point count')
        if title is not None:
            ax.set_title(title)
        ax.set_xlabel('True values')
        ax.set_ylabel('Predicted values')
        ax.grid(True)
        fig.savefig(out_file, dpi=300)
        plt.close(fig)

    def _bitfit(self, model):
        """BitFit: leave only the backbone's bias terms trainable, then attach a regression head."""
        for name, param in model.named_parameters():
            param.requires_grad = 'bias' in name
        # The head is created after freezing, so its weight and bias stay trainable
        return MoLFormerMLMWithRegressionHead(model)

    def _nofit(self, model):
        """Freeze the whole backbone, then attach a (trainable) regression head."""
        for param in model.parameters():
            param.requires_grad = False
        return MoLFormerMLMWithRegressionHead(model)

In [None]:
# Data Variables

# Add the path to the external dataset here (a CSV with 'SMILES' and 'Label' columns)
EXTERNAL_DATA_PATH = ""
if not EXTERNAL_DATA_PATH:
    raise ValueError("Set EXTERNAL_DATA_PATH to the external CSV file before running this cell.")

# Build a new frame instead of renaming in place, so re-running the cell is safe
data1 = (
    pd.read_csv(EXTERNAL_DATA_PATH)[["SMILES", "Label"]]
    .dropna()
    .rename(columns={'Label': 'label'})
)

# Path to huggingFace dataset
DATASET_PATH = "scikit-fingerprints/MoleculeNet_Lipophilicity"

# load the dataset from HuggingFace
dataset = load_dataset(DATASET_PATH)
hf_df = pd.DataFrame(dataset['train'])[["SMILES", "label"]].dropna()

# Merge the two datasets. Labels are cast to float so a non-numeric label
# column fails here rather than in the middle of training.
data = (
    pd.concat([data1, hf_df], axis=0, ignore_index=True)
    .astype({'label': float})
)

In [None]:
# Train parameters (read by Trainer.train)
train_params = dict(
    epochs=50,        # number of passes over the training split
    lr=1e-4,          # Adam learning rate
    wt_decay=1e-5,    # Adam weight decay
    batch_size=128,   # samples per DataLoader batch
)

In [None]:
# Add the path to the pre-trained MLM model here.
# It is passed to AutoModelForMaskedLM.from_pretrained inside Trainer.
PATH: str = ''

In [None]:
# Frozen backbone: only the regression head is trained
x = Trainer(data, model_path=PATH, model_type='nofit')
x.train(train_params=train_params)
x.test()
x.save(save_path='noFit/')

In [None]:
# BitFit: backbone bias terms plus the regression head are trained
x = Trainer(data, model_path=PATH, model_type='finefit')
x.train(train_params=train_params)
x.test()
x.save(save_path='fineFit/')