In [None]:
#!/usr/bin/env python3
"""
Data preparation script for REINVENT4 transfer learning
Downloads and processes ChEMBL data for molecular generation
"""

import pandas as pd
import requests
from rdkit import Chem
from rdkit.Chem import Descriptors, Crippen
import sqlite3
from typing import List, Optional
import os
from tqdm import tqdm

class ChEMBLDataProcessor:
    """
    Download, filter, split and save ChEMBL molecules as SMILES files
    for REINVENT4 transfer learning.

    All output files are written under ``output_dir``.
    """

    def __init__(self, output_dir: str = "./data"):
        """
        Args:
            output_dir: Directory where SMILES and metadata files are written.
                It is created (including parents) if it does not exist.
        """
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)

    @staticmethod
    def _parse_molecule(mol: dict) -> Optional[dict]:
        """
        Flatten one molecule record from the ChEMBL API into a row dict.

        Returns None when the record is not a dict, or has no ChEMBL ID or
        no canonical SMILES, so the caller can skip it.
        """
        if not isinstance(mol, dict):
            return None
        # ``mol.get(key, {})`` returns None (not {}) when the key is present
        # with a JSON null value, so normalise with ``or {}`` before chaining.
        structures = mol.get('molecule_structures') or {}
        smiles = structures.get('canonical_smiles')
        chembl_id = mol.get('molecule_chembl_id')
        if not smiles or not chembl_id:
            return None
        properties = mol.get('molecule_properties') or {}
        return {
            'chembl_id': chembl_id,
            'smiles': smiles,
            'molecular_weight': properties.get('mw_freebase'),
            'alogp': properties.get('alogp'),
        }

    @staticmethod
    def _fetch_json(url: str, params: dict, max_retries: int,
                    retry_delay: float) -> Optional[dict]:
        """
        GET ``url`` with ``params`` and return the decoded JSON body.

        A failed request (network error, HTTP error status, undecodable body)
        is retried up to ``max_retries`` more times, waiting ``retry_delay``
        seconds and doubling the wait after each failure. Returns None once
        every attempt has failed; each failure is printed.
        """
        import time

        attempts = max(0, max_retries) + 1
        for attempt in range(attempts):
            try:
                response = requests.get(url, params=params, timeout=30)
                response.raise_for_status()
                return response.json()
            except (requests.RequestException, ValueError) as e:
                # ValueError covers a response body that is not valid JSON.
                print(f"Error downloading batch: {e}")
                if attempt + 1 < attempts:
                    time.sleep(retry_delay * (2 ** attempt))
        return None

    def download_chembl_subset(self, limit: int = 100000,
                               max_retries: int = 3,
                               retry_delay: float = 2.0) -> pd.DataFrame:
        """
        Download a subset of ChEMBL data via their web services API.
        For the full dataset, download the SQLite database instead.

        Pages through the molecule endpoint 1000 records at a time until
        ``limit`` molecules with a canonical SMILES have been collected, the
        API returns an empty page, or a page cannot be fetched after retries.
        In the last case the molecules collected so far are still returned.

        Args:
            limit: Maximum number of molecules to collect.
            max_retries: Extra attempts per page after a failed request.
            retry_delay: Seconds to wait before the first retry; doubled on
                each subsequent retry.

        Returns:
            DataFrame with columns ``chembl_id``, ``smiles``,
            ``molecular_weight`` and ``alogp`` (possibly empty).
        """
        print(f"Downloading {limit} molecules from ChEMBL...")

        # ChEMBL REST API endpoint for molecules
        base_url = "https://www.ebi.ac.uk/chembl/api/data/molecule"
        columns = ['chembl_id', 'smiles', 'molecular_weight', 'alogp']

        molecules = []
        offset = 0
        batch_size = 1000

        while len(molecules) < limit:
            params = {
                'format': 'json',
                'limit': batch_size,
                'offset': offset,
                'molecule_chembl_id__isnull': False,
                'molecule_structures__canonical_smiles__isnull': False
            }

            data = self._fetch_json(base_url, params, max_retries, retry_delay)
            if data is None:
                break  # best effort: keep whatever has been collected
            page = data.get('molecules')
            if not page:
                break  # no more results

            for record in page:
                row = self._parse_molecule(record)
                if row is None:
                    continue
                molecules.append(row)
                if len(molecules) >= limit:
                    break

            offset += batch_size
            print(f"Downloaded {len(molecules)} molecules...")

        return pd.DataFrame(molecules, columns=columns)

    def filter_molecules(self, df: pd.DataFrame,
                         min_mw: float = 150,
                         max_mw: float = 500,
                         min_atoms: int = 10,
                         max_atoms: int = 50,
                         min_logp: float = -2,
                         max_logp: float = 5) -> pd.DataFrame:
        """
        Filter molecules based on drug-likeness criteria.

        Each SMILES is parsed with RDKit. Missing or unparsable SMILES are
        dropped. All bounds are inclusive.

        Args:
            df: Must contain ``chembl_id`` and ``smiles`` columns (unless empty).
            min_mw, max_mw: Molecular weight bounds (RDKit ``Descriptors.MolWt``).
            min_atoms, max_atoms: Heavy-atom count bounds.
            min_logp, max_logp: Crippen logP bounds (``Crippen.MolLogP``).

        Returns:
            DataFrame with columns ``chembl_id``, ``smiles``,
            ``molecular_weight``, ``logp`` and ``num_atoms``. The SMILES
            string is kept exactly as given in ``df``.
        """
        print("Filtering molecules for drug-likeness...")
        initial_count = len(df)
        columns = ['chembl_id', 'smiles', 'molecular_weight', 'logp', 'num_atoms']

        # An empty frame may lack the columns entirely, so do not index it.
        pairs = zip(df['chembl_id'], df['smiles']) if initial_count else ()

        kept = []
        for chembl_id, smiles in tqdm(pairs, total=initial_count,
                                      desc="Processing molecules"):
            # Skip missing values (None/NaN) and non-strings before RDKit.
            if not isinstance(smiles, str) or not smiles:
                continue
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                continue

            mw = Descriptors.MolWt(mol)
            num_atoms = mol.GetNumHeavyAtoms()
            logp = Crippen.MolLogP(mol)

            if (min_mw <= mw <= max_mw
                    and min_atoms <= num_atoms <= max_atoms
                    and min_logp <= logp <= max_logp):
                kept.append({
                    'chembl_id': chembl_id,
                    'smiles': smiles,
                    'molecular_weight': mw,
                    'logp': logp,
                    'num_atoms': num_atoms
                })

        filtered_df = pd.DataFrame(kept, columns=columns)
        print(f"Filtered from {initial_count} to {len(filtered_df)} molecules")
        return filtered_df

    def split_data(self, df: pd.DataFrame,
                   train_ratio: float = 0.8,
                   val_ratio: float = 0.1,
                   random_state: Optional[int] = None) -> tuple:
        """
        Shuffle ``df`` and split it into train/validation/test sets.

        The first ``int(len(df) * train_ratio)`` shuffled rows go to training,
        the next ``int(len(df) * val_ratio)`` rows go to validation, and the
        remainder goes to test.

        Args:
            df: Rows to split.
            train_ratio: Fraction of rows for training.
            val_ratio: Fraction of rows for validation.
            random_state: Seed for the shuffle. Pass an int for a reproducible
                split; ``None`` gives a different split on every call.

        Returns:
            ``(train_df, val_df, test_df)``. Index values are row positions
            in the shuffled frame.

        Raises:
            ValueError: If a ratio is negative or the two ratios sum above 1.
        """
        # Small tolerance so e.g. 0.7 + 0.3 is not rejected by float rounding.
        if train_ratio < 0 or val_ratio < 0 or train_ratio + val_ratio > 1 + 1e-9:
            raise ValueError(
                f"Invalid split ratios: train_ratio={train_ratio}, "
                f"val_ratio={val_ratio} (each must be >= 0, sum must be <= 1)"
            )

        df = df.sample(frac=1, random_state=random_state).reset_index(drop=True)

        n_train = int(len(df) * train_ratio)
        n_val = int(len(df) * val_ratio)

        train_df = df.iloc[:n_train]
        val_df = df.iloc[n_train:n_train + n_val]
        test_df = df.iloc[n_train + n_val:]

        return train_df, val_df, test_df

    def _write_smiles(self, filename: str, smiles) -> str:
        """Write one SMILES per line to ``output_dir/filename``; return the path."""
        path = os.path.join(self.output_dir, filename)
        with open(path, 'w', encoding='utf-8') as f:
            f.writelines(f"{s}\n" for s in smiles)
        return path

    def save_smiles_files(self, train_df: pd.DataFrame,
                          val_df: pd.DataFrame,
                          test_df: pd.DataFrame):
        """
        Save SMILES files for REINVENT4, plus a summary of split sizes.

        Writes the ``smiles`` column of each frame (one per line) to
        ``chembl_molecules.smi`` (train), ``chembl_validation.smi`` and
        ``chembl_test.smi``. Writes the row counts to ``dataset_info.txt``.
        All files go in ``output_dir`` and are overwritten if present.
        """
        self._write_smiles("chembl_molecules.smi", train_df['smiles'])
        self._write_smiles("chembl_validation.smi", val_df['smiles'])
        self._write_smiles("chembl_test.smi", test_df['smiles'])

        metadata_file = os.path.join(self.output_dir, "dataset_info.txt")
        with open(metadata_file, 'w', encoding='utf-8') as f:
            f.write(f"Training molecules: {len(train_df)}\n")
            f.write(f"Validation molecules: {len(val_df)}\n")
            f.write(f"Test molecules: {len(test_df)}\n")
            f.write(f"Total molecules: {len(train_df) + len(val_df) + len(test_df)}\n")

        print(f"Saved SMILES files to {self.output_dir}")
        print(f"Training: {len(train_df)} molecules")
        print(f"Validation: {len(val_df)} molecules") 
        print(f"Test: {len(test_df)} molecules")

def main():
    """
    Run the end-to-end preparation: download, filter, split, write files.

    Stops early with a message if the download or the filter step
    leaves no molecules.
    """
    processor = ChEMBLDataProcessor()

    raw = processor.download_chembl_subset(limit=50000)  # Adjust limit as needed
    if raw.empty:
        print("No molecules downloaded. Check your internet connection.")
        return

    druglike = processor.filter_molecules(raw)
    if druglike.empty:
        print("No molecules passed filtering criteria.")
        return

    # split_data returns (train, validation, test) in save_smiles_files' order.
    processor.save_smiles_files(*processor.split_data(druglike))

    for line in (
        "Data preparation complete!",
        "You can now run transfer learning with:",
        "reinvent -l transfer_learning.log transfer_learning.toml",
    ):
        print(line)

if __name__ == "__main__":
    # Run the pipeline only when executed as a script, not when imported.
    main()
    

Downloading 50000 molecules from ChEMBL...
Downloaded 1000 molecules...
Downloaded 2000 molecules...
Downloaded 3000 molecules...
Downloaded 4000 molecules...
Downloaded 5000 molecules...
Downloaded 6000 molecules...
Downloaded 7000 molecules...
Downloaded 8000 molecules...
Downloaded 9000 molecules...
Downloaded 10000 molecules...
Downloaded 11000 molecules...
Downloaded 12000 molecules...
Downloaded 13000 molecules...
Downloaded 14000 molecules...
Downloaded 15000 molecules...
Downloaded 16000 molecules...
Downloaded 17000 molecules...
Downloaded 18000 molecules...
Downloaded 19000 molecules...
Downloaded 20000 molecules...
Downloaded 21000 molecules...
Downloaded 22000 molecules...
Downloaded 23000 molecules...
Downloaded 24000 molecules...
Downloaded 25000 molecules...
Downloaded 26000 molecules...
Downloaded 27000 molecules...
Downloaded 28000 molecules...
Downloaded 29000 molecules...
Downloaded 30000 molecules...
Downloaded 31000 molecules...
Downloaded 32000 molecules...
Downlo

Processing molecules: 100%|██████████| 50000/50000 [00:43<00:00, 1154.10it/s]

Filtered from 50000 to 32785 molecules
Saved SMILES files to ./data
Training: 26228 molecules
Validation: 3278 molecules
Test: 3279 molecules
Data preparation complete!
You can now run transfer learning with:
reinvent -l transfer_learning.log transfer_learning.toml





: 