In [18]:
# Offline training of an end-to-end clip
# Encoder/decoder.
# Try to return to torch_emb=False

import os
import inspect
import pandas as pd
import random
import pickle

from rdkit import Chem
from rdkit.Chem import AllChem

import torch.multiprocessing as mp
from torch.utils.data.datapipes.iter import FileLister, Shuffler

from coati.data.dataset import COATI_dataset
from coati.common.s3 import cache_read
from coati.training.train_coati import train_autoencoder, do_args
from coati.common.util import dir_or_file_exists, makedir, query_yes_no
from coati.common.s3 import copy_bucket_dir_from_s3
from coati.data.batch_pipe import UnstackPickles, UrBatcher, stack_batch

# Dataset

In [20]:
cache_dir = "/Users/stefanhangler/Documents/Uni/Msc_AI/3_Semester/Seminar_Practical Work/Code.nosync/COATI/examples/practical_work_tests/"
DATA_PATH = "datasets/chembl_smiles_coors"

## COATI Dataset Class (adapted)
smiles_to_3d is a new function that takes a SMILES string and returns a 3D structure. It uses RDKit to generate the 3D structure.

In [26]:
class COATI_dataset:
    def __init__(
        self,
        cache_dir,
        fields=["smiles", "atoms", "coords"],
        test_split_mode="row",
        test_frac=0.02,  # in percent.
        valid_frac=0.02,  # in percent.
    ):
        self.cache_dir = cache_dir
        self.summary = {"dataset_type": "coati", "fields": fields}
        self.test_frac = test_frac
        self.fields = fields
        self.valid_frac = valid_frac
        assert int(test_frac * 100) >= 0 and int(test_frac * 100) <= 50
        assert int(valid_frac * 100) >= 0 and int(valid_frac * 100) <= 50
        assert int(valid_frac * 100 + test_frac * 100) < 50
        self.test_split_mode = test_split_mode

    def partition_routine(self, row):
        """Partitioning logic for dataset splits."""
        if not "mod_molecule" in row:
            tore = ["raw"]
            tore.append("train")
            return tore
        else:
            tore = ["raw"]

            if row["mod_molecule"] % 100 >= int(
                (self.test_frac + self.valid_frac) * 100
            ):
                tore.append("train")
            elif row["mod_molecule"] % 100 >= int((self.test_frac * 100)):
                tore.append("valid")
            else:
                tore.append("test")

            return tore

    # new function to convert SMILES to 3D coordinates
    def smiles_to_3d(self, row):
        """Convert SMILES string to 3D representation and optimize with MMFF94s."""
        smi = row['smiles']
        mol = Chem.MolFromSmiles(smi)
        if mol:
            mol = Chem.AddHs(mol)
            AllChem.EmbedMolecule(mol, AllChem.ETKDG())
            AllChem.MMFFOptimizeMolecule(mol, mmffVariant='MMFF94s')
            coords = mol.GetConformer().GetPositions()
            row['coords'] = coords
            row['atoms'] = [atom.GetSymbol() for atom in mol.GetAtoms()]
        return row

    def get_data_pipe(
        self,
        rebuild=False,
        batch_size=32,
        partition: str = "raw",
        required_fields=[],
        distributed_rankmod_total=None,
        distributed_rankmod_rank=1,
        xform_routine=None,
    ):
        
        """Set up the data pipeline with RDKit processing included."""
        print(f"trying to open a {partition} datapipe for...")
        if (
            not dir_or_file_exists(os.path.join(self.cache_dir, DATA_PATH, "0.pkl"))
        ) or rebuild:
            makedir(self.cache_dir)
            # Automatically proceed with downloading the data without asking for confirmation
            copy_bucket_dir_from_s3(DATA_PATH, self.cache_dir)

        # Use the smiles_to_3d function as the transformation routine
        xform_routine = self.smiles_to_3d if xform_routine is None else xform_routine

        pipe = (
            FileLister(
                root=os.path.join(self.cache_dir, DATA_PATH),
                recursive=False,
                masks=["*.pkl"],
            )
            .shuffle()
            .open_files(mode="rb")
            .unstack_pickles()
            .unbatch()
            .shuffle(buffer_size=200000)
            .map(xform_routine)
        )
        pipe = pipe.ur_batcher(
            batch_size=batch_size,
            partition=partition,
            xform_routine=xform_routine,
            partition_routine=self.partition_routine,
            distributed_rankmod_total=distributed_rankmod_total,
            distributed_rankmod_rank=distributed_rankmod_rank,
            direct_mode=False,
            required_fields=self.fields,
        )

        return pipe


## Load ChEMBL canonical smile strings

In [15]:
# load Chembl dataset smile strings
with cache_read("s3://terray-public/datasets/chembl_canonical_smiles.pkl", "rb") as f:
    chembl_canonical_smiles = pickle.loads(f.read(), encoding="UTF-8")

# Shuffle the dataset and select a subset for the example
random.shuffle(chembl_canonical_smiles)
chembl_subset = chembl_canonical_smiles[:10_000]
chembl_subset = [{"smiles": s} for s in chembl_subset]


In [16]:
chembl_subset[:1]

[{'smiles': 'COc1cccc(NC(=O)CCC2CCN(C(=O)c3cncs3)CC2)c1'}]

Store in pickle file for faster read/write

In [23]:
# Assuming you have a cache_dir variable defined as in the COATI_dataset class
subset_file_path = os.path.join(cache_dir, DATA_PATH, "chembl_subset.pkl")
with open(subset_file_path, "wb") as f:
    pickle.dump(chembl_subset, f)

TEST Datset class

In [24]:
# Create an instance of the COATI_dataset class
dataset = COATI_dataset(
    cache_dir=cache_dir,
    fields=["smiles", "atoms", "coords"],
    test_split_mode="row",
    test_frac=0.02,
    valid_frac=0.02
)

# Call the get_data_pipe method to get the data pipeline
data_pipe = dataset.get_data_pipe(
    rebuild=False,  # Set to True if you want to force a rebuild of the data cache
    batch_size=32,  # The batch size for training
    partition="raw",  # Which partition of the data to use: 'raw', 'train', 'valid', or 'test'
    # Other parameters can be left as defaults or specified according to your needs
)

# You can then iterate over data_pipe to access your data
for batch in data_pipe:
    # Process each batch as needed
    print(batch)
    break


trying to open a raw datapipe for...
Will download ~340 GB of data to /Users/stefanhangler/Documents/Uni/Msc_AI/3_Semester/Seminar_Practical Work/Code.nosync/COATI/examples/practical_work_tests . This will take a while. Are you sure? [y/n] 

KeyboardInterrupt: Interrupted by user