In [3]:
from enumeration import SmilesEnumerator
from data_reader import DataReader
import pandas as pd

## Prepare Data for Contrastive Learning Using SMILES Enumeration

In [10]:
"""
This script only requires unlabelled SMILES and generates data for contrastive learning
"""

def perform_augmentation(dataset_path, out_file_path, smiles_column="smiles",
                         labels_column="smiles", n_augmentations=2):
    """Enumerate SMILES from ``dataset_path`` and write the result to ``out_file_path``.

    The rows come from ``SmilesEnumerator.enumerate_smiles_hard_neg`` with
    ``random_pairs=True`` (presumably anchor / positive / hard-negative triplets —
    confirm against the enumerator). Their shape is printed before the CSV is written.

    Args:
        dataset_path: Input file handed to ``DataReader``.
        out_file_path: Destination CSV (written without index, ``utf-8-sig`` encoded).
        smiles_column: Column holding the SMILES strings.
        labels_column: Label column for ``DataReader``; defaults to the SMILES column
            because no labels are needed here.
        n_augmentations: Forwarded as ``replication_count`` to the enumerator.
    """
    reader = DataReader(dataset_path, smiles_column, labels_column=labels_column)
    enumerator = SmilesEnumerator(canonical=False, enum=True)
    triplets = enumerator.enumerate_smiles_hard_neg(
        reader,
        smiles_column,
        random_pairs=True,
        replication_count=n_augmentations,
    )

    print(triplets.shape)
    triplets.to_csv(out_file_path, index=False, encoding="utf-8-sig")
    

In [11]:
# Unlabelled SMILES in, enumerated contrastive-learning rows out.
in_file_path = "data/smiles.csv"
out_file_path = "data/smiles_for_contrastive_learning.csv"

perform_augmentation(
    dataset_path=in_file_path,
    out_file_path=out_file_path,
)

100%|██████████| 200/200 [00:00<00:00, 42478.27it/s]

(200, 3)





## Prepare Data for MTR by computing normalization values (mean, std) of Physicochemical Properties

In [32]:
import json
import os

import numpy as np
import torch
from rdkit import Chem
from rdkit.ML.Descriptors.MoleculeDescriptors import MolecularDescriptorCalculator
from sklearn.preprocessing import StandardScaler
from torch.utils.data import Dataset
from tqdm import tqdm

In [33]:
class PhysicoChemcialPropertyExtractor:
    """Computes RDKit physicochemical descriptors for SMILES strings on the fly.

    Every descriptor listed in RDKit's ``Descriptors.descList`` is computed, except
    ``Ipc``.

    Args:
        tokenizer: Callable used by ``preprocess`` to encode SMILES. It may be ``None``
            when only ``_compute_descriptors`` is used.
        block_size: Value passed as ``max_length`` to the tokenizer in ``preprocess``.
            Defaults to ``None``, which is forwarded to the tokenizer unchanged.
    """

    def __init__(self, tokenizer, block_size=None):
        super().__init__()
        # Import the submodule explicitly: ``Chem.Descriptors`` is not guaranteed to
        # exist after a bare ``from rdkit import Chem``.
        from rdkit.Chem import Descriptors

        self.tokenizer = tokenizer
        # Read by preprocess() as the tokenizer's max_length.
        self.block_size = block_size

        self.descriptors = [name for name, _ in Descriptors.descList]
        # NOTE(review): Ipc is excluded, presumably because its values can be
        # extremely large — confirm the original motivation.
        self.descriptors.remove("Ipc")
        self.calculator = MolecularDescriptorCalculator(self.descriptors)
        self.num_labels = len(self.descriptors)

    def __len__(self):
        """Return the number of descriptors produced per molecule (``num_labels``)."""
        return self.num_labels

    def _compute_descriptors(self, smiles):
        """Return the descriptor vector for one SMILES string.

        An unparseable SMILES yields an all-zero vector. NaN and +/-inf descriptor
        values are replaced by 0.0.

        Returns:
            ``np.ndarray`` of shape ``(num_labels,)`` and dtype float64.
        """
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            mol_descriptors = np.zeros(self.num_labels, dtype=np.float64)
        else:
            mol_descriptors = np.asarray(
                self.calculator.CalcDescriptors(mol), dtype=np.float64
            )
            mol_descriptors = np.nan_to_num(
                mol_descriptors, nan=0.0, posinf=0.0, neginf=0.0
            )
        assert mol_descriptors.size == self.num_labels

        return mol_descriptors

    def preprocess(self, feature_dict):
        """Tokenize ``feature_dict["text"]`` and attach its descriptors as ``"label"``.

        Assumes ``feature_dict["text"]`` is a single SMILES string.

        Returns:
            A dict mapping each tokenizer output key to a tensor, plus ``"label"``.
            ``"label"`` is a float32 tensor of shape ``(num_labels,)``.
        """
        smiles = feature_dict["text"]
        batch_encoding = self.tokenizer(
            smiles,
            add_special_tokens=True,
            truncation=True,
            padding="max_length",
            max_length=self.block_size,
        )
        batch_encoding = {k: torch.tensor(v) for k, v in batch_encoding.items()}

        mol_descriptors = self._compute_descriptors(smiles)
        batch_encoding["label"] = torch.tensor(mol_descriptors, dtype=torch.float32)

        return batch_encoding



def get_data_files(train_path):
    """Resolve ``train_path`` to the data file(s) it refers to.

    Args:
        train_path: Path to a single file or to a directory.

    Returns:
        For a directory, a list of the paths of all its entries. The list is sorted
        by name so it does not depend on filesystem ordering. For a file,
        ``train_path`` itself (a single path, not a list).

    Raises:
        ValueError: If ``train_path`` is neither an existing directory nor a file.
    """
    if os.path.isdir(train_path):
        # os.listdir returns entries in arbitrary order; sort for reproducibility.
        return [
            os.path.join(train_path, file_name)
            for file_name in sorted(os.listdir(train_path))
        ]
    if os.path.isfile(train_path):
        return train_path

    raise ValueError("Please pass in a proper train path")


In [34]:
# Load the SMILES to profile explicitly instead of relying on kernel state.
# NOTE(review): assumes the input file used for augmentation (its "smiles" column)
# is the intended source — confirm.
smiles = pd.read_csv(in_file_path)["smiles"].dropna().tolist()

## init physicochemical property extractor (no tokenizer: only descriptors are computed)
physicochemical_fingerprint_extractor = PhysicoChemcialPropertyExtractor(tokenizer=None)

# One descriptor vector per SMILES; unparseable SMILES yield an all-zero vector.
physicochemical_fingerprints = [
    physicochemical_fingerprint_extractor._compute_descriptors(smile)
    for smile in tqdm(smiles)
]

100%|██████████| 100/100 [00:05<00:00, 18.06it/s]


In [36]:
physicochemical_fingerprints_np = np.array(physicochemical_fingerprints)

# Per-descriptor statistics across molecules (axis 0 indexes molecules).
mean = np.mean(physicochemical_fingerprints_np, axis=0)
std = np.std(physicochemical_fingerprints_np, axis=0)

# Descriptors that are constant across the sample have std == 0. If pre-training
# divides by std, that would produce inf/NaN targets, so use 1.0 instead. This is the
# same convention as sklearn's StandardScaler for zero-variance features.
std = np.where(std == 0, 1.0, std)

## dump the output as json to be used in pre-training
# NOTE(review): the "207" suffix presumably is the descriptor count, which depends on
# the installed RDKit version — confirm it matches mean.size.
with open("data/normalization_values_207.json", "w") as outfile:
    json.dump({"mean": mean.tolist(), "std": std.tolist()}, outfile)