# 🧬 molGPT: Conditional Molecular Generation with Transformers

**molGPT** is an end-to-end pipeline for generating novel drug-like molecules using a transformer-based language model (GPT-style). This project focuses on **conditional SMILES generation**, where molecular properties such as LogP, QED, TPSA, and scaffold are used to guide the generation process. It's an exciting intersection of **natural language processing** and **computational drug discovery**.

---

## 🚀 Highlights

- ✅ Trains a decoder-only transformer (GPT) to generate valid SMILES strings
- 🎯 Conditioned on molecular properties like:
  - **LogP** (lipophilicity)
  - **QED** (quantitative estimate of drug-likeness)
  - **TPSA** (topological polar surface area)
  - **Scaffold** (molecular backbone)
- 📊 Includes evaluation metrics:
  - SMILES validity
  - Molecular uniqueness
  - Structural novelty (Tanimoto similarity)
  - Property alignment

---

## 🧪 Dataset

Uses the [MOSES](https://github.com/molecularsets/moses) dataset — a curated collection of drug-like molecules, suitable for generative modeling.

---




In [1]:
# # Dependencies

# %pip install pandas rdkit transformers[torch] accelerate>=0.26.0
# %pip install scikit-learn matplotlib tqdm pathos
# %pip install torch --index-url https://download.pytorch.org/whl/cu118
# %pip install git+https://github.com/molecularsets/moses.git

## Import packages

In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm.auto import tqdm
from pathos.multiprocessing import ProcessingPool as Pool
from functools import partial

from rdkit import Chem
from rdkit.Chem import AllChem, Draw, Descriptors, QED, rdMolDescriptors
from rdkit.Chem.Draw import IPythonConsole
from rdkit.Chem.Scaffolds.MurckoScaffold import GetScaffoldForMol

import torch
from torch.utils.data import Dataset, DataLoader

from transformers import (
    GPT2LMHeadModel,
    GPT2Tokenizer,
    GPT2Config,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling
)

## 2. Data Loading and Filtering

 We'll use the MOSES dataset, which is a curated set of drug-like molecules specifically
 designed for machine learning applications. It's much smaller than ChEMBL
 (https://www.ebi.ac.uk/chembl/, database: https://chembl.gitbook.io/chembl-interface-documentation/)
 but still contains high-quality, drug-like compounds.

In [3]:
def load_moses_data():
    filtered_path = Path('dataset_v1_filtered.csv')
    if filtered_path.exists():
      print(f"Loading pre-filtered dataset from {filtered_path}")
      df = pd.read_csv(filtered_path)
      print(f"Loaded {len(df)} pre-filtered drug-like molecules")
      return df
    train_path = Path('dataset_v1.csv')
    print(f"Reading the original dataset from {train_path}")
    df = pd.read_csv(train_path)
    print(df.info)
    print("\nFirst few rows:")
    print(df.head())

    if 'smiles' in df.columns:
      df = df.rename(columns={'smiles': 'SMILES'})
    elif 'SMILES' not in df.columns:
      print("Available columns:", df.columns.tolist())
      raise ValueError("No 'smiles' or 'SMILES' column found in the dataset.")
    smiles_list = df['SMILES'].values
    print(f"\nFound {len(smiles_list)} SMILES strings in the dataset.")

    # Here we begin with filtering
    valid_mols = []
    for smi in tqdm(smiles_list, desc="Validating and filtering SMILES"):
      # Filter 1 ensures that all molecules are chemically and structurally valid
      mol = Chem.MolFromSmiles(smi)
      if mol is not None:
        # Filter 2 makes sure molecules with physicochemical properties out of a desirable range are removed from the list
        mw = Descriptors.ExactMolWt(mol) # Molecular weight (in Da units)
        logp = Descriptors.MolLogP(mol) # LogP(measured lipophilicity, i.e., how much a molecule likes to be solved in fat versus water)
        hbd = rdMolDescriptors.CalcNumHBD(mol) # Number of a molecule's hydrogen-bond donor heavy atoms
        hba = rdMolDescriptors.CalcNumHBA(mol) # Number of a molecule's hydrogen-bond acceptor heavy atoms

        if mw <= 500 and logp <=5 and hbd <= 5 and hba <= 10:
          # Filter 3 screens for problematic chemical groups shown to be associated with toxicity, carcinogenicity, etc.
          has_bad_groups = False
          patt_list = [
              '[N+]([O-])=O',  # Nitro groups: Highly reactive, can cause DNA damage and carcinogenicity
              '[S](=[O])(=[O])',  # Sulfonyl groups: Can be chemically reactive and cause skin/eye irritation
              '[P](=[O])',  # Phosphoryl groups: Potential toxicity and instability in biological systems
              '[As]'  # Arsenic: Highly toxic heavy metal with severe health risks and carcinogenic properties
          ]
          for patt in patt_list:
              if mol.HasSubstructMatch(Chem.MolFromSmarts(patt)):
                has_bad_groups = True
                break
          if not has_bad_groups:
            valid_mols.append(smi)

    # Create the filtered dataframe
    filtered_df = df[df['SMILES'].isin(valid_mols)]
    print(f"\nAfter filtering, {len(filtered_df)} molecules remain")
    return filtered_df

# Load and display filtered data
filtered_df = load_moses_data()
print("\nFirst few rows of filtered dataset:")
print(filtered_df.head())



Loading pre-filtered dataset from dataset_v1_filtered.csv
Loaded 1735494 pre-filtered drug-like molecules

First few rows of filtered dataset:
                                   SMILES  SPLIT
0  CCCS(=O)c1ccc2[nH]c(=NC(=O)OC)[nH]c2c1  train
1    CC(C)(C)C(=O)C(Oc1ccc(Cl)cc1)n1ccnc1  train
2  CC1C2CCC(C2)C1CN(CCO)C(=O)c1ccc(Cl)cc1   test
3     Cc1c(Cl)cccc1Nc1ncccc1C(=O)OCC(O)CO  train
4        Cn1cnc2c1c(=O)n(CC(O)CO)c(=O)n2C  train


In [4]:
filtered_df.to_csv("dataset_v1_filtered", index=False)


## 3. Descriptor Calculation (Scaffolds, logP, QED, TPSA)
 We compute additional descriptors needed:
 - Murcko Scaffolds: Core molecular framework obtained by removing side chains and keeping only ring systems and linkers between rings
 - QED
 - TPSA
 - LogP

We'll store these in the DataFrame alongside the SMILES.

In [5]:
def calculate_descriptors(smiles: str):
    mol = Chem.MolFromSmiles(smiles)
    if not mol:
        return None, None, None, None

    try:
        scaffold = GetScaffoldForMol(mol)
        scaffold_smiles = Chem.MolToSmiles(scaffold)

        qed_val = QED.qed(mol)

        tpsa_val = rdMolDescriptors.CalcTPSA(mol)

        logp_val = Descriptors.MolLogP(mol)

        return scaffold_smiles, logp_val, qed_val, tpsa_val
    except:
        return None, None, None, None

def process_batch(smiles_batch):
    results = []
    for smi in smiles_batch:
        results.append(calculate_descriptors(smi))
    return results

def calculate_descriptors_parallel(df, parallel=True, batch_size=100):
    print("Calculating molecular descriptors...")

    if parallel:
        n_cores = Pool().ncpus
        print(f"Detected {n_cores} CPU cores")
        print(f"Running descriptor calculations in parallel across {n_cores} cores")

        smiles_list = df['SMILES'].tolist()
        n_batches = (len(smiles_list) + batch_size - 1) // batch_size
        batches = [smiles_list[i*batch_size:(i+1)*batch_size]
                  for i in range(n_batches)]

        print(f"Processing {len(smiles_list)} SMILES strings in {n_batches} batches")

        with Pool() as pool:
            results = list(tqdm(
                pool.imap(process_batch, batches),
                total=n_batches,
                desc="Processing batches"
            ))

        all_results = [item for batch in results for item in batch]

    else:
        print("Running descriptor calculations sequentially")
        all_results = []
        for smi in tqdm(df['SMILES'], desc="Calculating descriptors"):
            all_results.append(calculate_descriptors(smi))

    scaffolds, logps, qeds, tpsas = zip(*all_results)

    df['Scaffold'] = scaffolds
    df['LogP'] = logps
    df['QED'] = qeds
    df['TPSA'] = tpsas

    df = df.dropna(subset=['Scaffold', 'LogP', 'QED', 'TPSA'])
    print(f"Final dataset size after descriptor calculation: {len(df)}")

    print("\nDescriptor Statistics:")
    print(f"LogP range: {df['LogP'].min():.2f} to {df['LogP'].max():.2f}")
    print(f"QED range: {df['QED'].min():.2f} to {df['QED'].max():.2f}")
    print(f"TPSA range: {df['TPSA'].min():.2f} to {df['TPSA'].max():.2f}")

    return df

filtered_df = calculate_descriptors_parallel(filtered_df, parallel=True)
filtered_df.head()

Calculating molecular descriptors...
Detected 16 CPU cores
Running descriptor calculations in parallel across 16 cores
Processing 1735494 SMILES strings in 17355 batches


Processing batches:   0%|          | 0/17355 [00:00<?, ?it/s]

NameError: name 'calculate_descriptors' is not defined