In [1]:
import csv, os, multiprocessing as mp
import time
import random
from functools import partial
from rdkit import Chem, DataStructs
from rdkit.Chem.rdFingerprintGenerator import GetMorganGenerator
from rdkit.ML.Cluster import Butina
import pandas as pd
import numpy as np

# ---------- helpers ----------------------------------------------------------

def mol_from_smiles(sm):
    """Parse a SMILES string into an RDKit Mol.

    Returns None instead of raising when the input cannot be used:
    - non-string values (e.g. None, or the float NaN that pandas produces
      for an empty CSV cell) are rejected before RDKit is called, because
      Chem.MolFromSmiles raises on non-string arguments;
    - otherwise the result of Chem.MolFromSmiles is returned as-is, which
      is None for SMILES that RDKit cannot parse.
    """
    if not isinstance(sm, str):
        return None
    return Chem.MolFromSmiles(sm)

# Create Morgan fingerprint generator (modern approach).
# Shared module-level generator used by ecfp6(): radius=3 is the Morgan radius
# (the "6" in ECFP6 refers to the corresponding diameter) and fpSize=2048 is
# the fingerprint size requested from the generator.
morgan_generator = GetMorganGenerator(radius=3, fpSize=2048)

def ecfp6(mol):
    """Fingerprint *mol* with the shared Morgan generator; falsy input gives None."""
    if not mol:
        return None
    return morgan_generator.GetFingerprint(mol)

def fp_chunk(smiles):
    """Pool worker: turn one SMILES into its fingerprint (None if it has no Mol)."""
    return ecfp6(mol_from_smiles(smiles))

def build_dist_mat(fps, thresh=0.6):
    """
    Flattened pairwise distance list for Butina (dist = 1 - tanimoto).

    For every i >= 1, the distances from fps[i] to fps[0..i-1] are appended
    in that order, so the result has n*(n-1)/2 entries.
    `thresh` is accepted for signature compatibility but is not used here.
    """
    flat = []
    for row in range(1, len(fps)):
        similarities = DataStructs.BulkTanimotoSimilarity(fps[row], fps[:row])
        flat += [1 - sim for sim in similarities]
    return flat

def diverse_subsample(fps, indices, target_size=15000, random_seed=42):
    """
    Create a diverse subsample using a sampled MaxMin selection.

    Starting from one random molecule, each step draws up to 1000 random
    candidates from the not-yet-selected pool and keeps the candidate whose
    minimum Tanimoto distance to the already selected molecules is largest
    (first candidate wins ties).

    The minimum distance of every molecule to the selected set is kept in a
    running array that is refreshed with a single BulkTanimotoSimilarity call
    per pick, instead of re-comparing each candidate with every selected
    fingerprint in Python.

    Parameters:
    - fps: list of fingerprints (no None entries)
    - indices: original indices, parallel to fps
    - target_size: number of molecules to select
    - random_seed: seed applied to both `random` and `numpy.random`

    Returns: (subsample_fps, subsample_indices)
      The inputs are returned unchanged when len(fps) <= target_size;
      two empty lists are returned when target_size <= 0.
    """
    if len(fps) <= target_size:
        return fps, indices
    if target_size <= 0:
        return [], []

    print(f"Creating diverse subsample of {target_size:,} from {len(fps):,} molecules...")

    random.seed(random_seed)
    np.random.seed(random_seed)

    n_total = len(fps)

    # Start with random molecule
    first_idx = random.randint(0, n_total - 1)
    selected_indices = [first_idx]
    # Kept as an ordered list so random.sample draws from the same sequence
    # on every run with the same seed.
    remaining_indices = list(range(n_total))
    remaining_indices.remove(first_idx)

    # min_dists[j] = distance from molecule j to its nearest selected molecule.
    min_dists = np.full(n_total, np.inf)
    last_selected = first_idx

    # MaxMin selection: pick molecule most distant from already selected
    for i in range(1, min(target_size, n_total)):
        if i % 1000 == 0:
            print(f"  Selected {i:,}/{target_size:,} diverse molecules...")

        # Fold the most recently selected molecule into the running minimum.
        sims = DataStructs.BulkTanimotoSimilarity(fps[last_selected], fps)
        np.minimum(min_dists, 1.0 - np.asarray(sims), out=min_dists)

        # Sample subset for efficiency (check 1000 random molecules)
        check_indices = random.sample(remaining_indices, min(1000, len(remaining_indices)))

        # max() returns the first maximal candidate, i.e. earlier candidates win ties.
        best_idx = max(check_indices, key=min_dists.__getitem__)

        selected_indices.append(best_idx)
        remaining_indices.remove(best_idx)
        last_selected = best_idx

    selected_fps = [fps[i] for i in selected_indices]

    # Map back to original indices
    original_indices = [indices[i] for i in selected_indices]

    print(f"  Diverse subsample created: {len(selected_fps):,} molecules")
    return selected_fps, original_indices

def hierarchical_butina_clustering(fps, indices, thresh=0.6, max_subsample=15000, progress_interval=2000):
    """
    Hierarchical Butina clustering for large datasets

    Stage 0: Set aside invalid molecules (fingerprint is None) with id -1
    Stage 1: Create diverse subsample
    Stage 2: Apply Butina clustering to subsample
    Stage 3: Assign remaining molecules to the most similar cluster centroid;
             a molecule whose best similarity is below `thresh` starts a new
             singleton cluster, which later molecules may also join.

    Parameters:
    - fps: list of fingerprints (None marks an invalid molecule)
    - indices: original indices of molecules, parallel to fps
    - thresh: Tanimoto similarity threshold (Butina uses distThresh = 1 - thresh)
    - max_subsample: maximum number of molecules clustered directly with Butina
    - progress_interval: report progress every this many assigned molecules
      (0 or None disables progress reporting)

    Returns:
    - cluster_assignments: dict mapping original_index -> cluster_id
      (-1 for invalid molecules)
    """

    print(f"Starting hierarchical Butina clustering (threshold={thresh})")
    print(f"Dataset size: {len(fps):,} molecules")

    overall_start = time.time()

    # Stage 0: None fingerprints cannot be compared, so label them up front
    # and keep them away from the similarity computations below.
    cluster_assignments = {}
    valid_fps = []
    valid_indices = []
    for fp, orig_idx in zip(fps, indices):
        if fp is None:
            cluster_assignments[orig_idx] = -1  # Invalid molecule
        else:
            valid_fps.append(fp)
            valid_indices.append(orig_idx)

    if not valid_fps:
        return cluster_assignments

    if len(valid_fps) <= max_subsample:
        print("Dataset small enough for direct Butina clustering")
        cluster_assignments.update(direct_butina_clustering(valid_fps, valid_indices, thresh))
        return cluster_assignments

    # Stage 1: Create diverse subsample
    subsample_fps, subsample_indices = diverse_subsample(valid_fps, valid_indices, max_subsample)

    # Stage 2: Apply Butina to subsample
    print(f"Applying Butina clustering to subsample of {len(subsample_fps):,} molecules...")
    start_time = time.time()

    dists = build_dist_mat(subsample_fps, thresh)
    print(f"Built distance matrix ({len(dists):,} distances) in {time.time() - start_time:.1f}s")

    clusters = Butina.ClusterData(
        dists, len(subsample_fps), distThresh=1 - thresh, isDistData=True
    )
    butina_time = time.time() - start_time
    print(f"Butina clustering completed in {butina_time:.1f}s: {len(clusters):,} clusters")

    # Centroid fingerprints, positioned so that list index == cluster id.
    centroid_fps = []
    for cluster_id, cluster in enumerate(clusters):
        centroid_fps.append(subsample_fps[cluster[0]])  # First molecule in cluster is centroid

        # Assign all molecules in this cluster
        for mol_idx in cluster:
            cluster_assignments[subsample_indices[mol_idx]] = cluster_id

    print(f"Created {len(centroid_fps):,} cluster centroids")

    # Stage 3: Assign remaining molecules to nearest centroids
    print("Assigning remaining molecules to clusters...")

    # Create set of subsampled indices for fast lookup
    subsampled_set = set(subsample_indices)

    assignment_start = time.time()
    processed = 0

    for fp, orig_idx in zip(valid_fps, valid_indices):
        if orig_idx in subsampled_set:
            continue  # Already assigned in subsample

        # One bulk call against all current centroids; argmax keeps the
        # lowest cluster id when several centroids tie for the best similarity.
        sims = DataStructs.BulkTanimotoSimilarity(fp, centroid_fps)
        best_cluster = int(np.argmax(sims))
        best_sim = sims[best_cluster]

        if best_sim >= thresh and best_sim > 0:
            cluster_assignments[orig_idx] = best_cluster
        else:
            # Create new singleton cluster
            cluster_assignments[orig_idx] = len(centroid_fps)
            centroid_fps.append(fp)

        processed += 1

        # Progress reporting
        if progress_interval and processed % progress_interval == 0:
            elapsed = time.time() - assignment_start
            remaining = len(valid_fps) - len(subsampled_set) - processed
            rate = processed / elapsed if elapsed > 0 else 0
            eta = remaining / rate if rate > 0 else 0

            print(f"  Assigned {processed:,} molecules | "
                  f"Rate: {rate:.1f} mol/s | "
                  f"ETA: {eta/60:.1f}min | "
                  f"Total clusters: {len(centroid_fps):,}")

    total_time = time.time() - overall_start
    print(f"Hierarchical clustering completed in {total_time:.1f}s")
    print(f"Final cluster count: {len(centroid_fps):,}")

    return cluster_assignments

def direct_butina_clustering(fps, indices, thresh=0.6):
    """
    Direct Butina clustering for smaller datasets

    Parameters:
    - fps: list of fingerprints (None marks an invalid molecule)
    - indices: original indices of molecules, parallel to fps
    - thresh: Tanimoto similarity threshold (Butina uses distThresh = 1 - thresh)

    Returns:
    - dict mapping original_index -> cluster_id; invalid molecules get -1
    """
    print("Applying direct Butina clustering...")

    # None fingerprints cannot enter the distance computation: label them -1
    # and cluster only the valid ones.
    cluster_assignments = {}
    valid_fps = []
    valid_indices = []
    for fp, orig_idx in zip(fps, indices):
        if fp is None:
            cluster_assignments[orig_idx] = -1
        else:
            valid_fps.append(fp)
            valid_indices.append(orig_idx)

    if not valid_fps:
        print("Butina clustering completed: 0 clusters")
        return cluster_assignments

    start_time = time.time()
    dists = build_dist_mat(valid_fps, thresh)
    print(f"Built distance matrix in {time.time() - start_time:.1f}s")

    clusters = Butina.ClusterData(
        dists, len(valid_fps), distThresh=1 - thresh, isDistData=True
    )
    print(f"Butina clustering completed: {len(clusters):,} clusters")

    # Map to cluster assignments
    for cluster_id, cluster in enumerate(clusters):
        for mol_idx in cluster:
            cluster_assignments[valid_indices[mol_idx]] = cluster_id

    return cluster_assignments

# ---------- main workflow (notebook version) --------------------------------

def cluster_molecules(csv_file, smiles_col="smiles", thresh=0.6, workers=None, out_csv="clusters.csv"):
    """
    Cluster molecules using hierarchical Butina clustering

    Reads SMILES from a CSV file, computes fingerprints (in parallel when
    workers > 1), clusters them, writes the SMILES column plus a
    'cluster_id' column to `out_csv`, and prints summary statistics.

    Parameters:
    - csv_file: path to input CSV file
    - smiles_col: column name containing SMILES
    - thresh: Tanimoto similarity threshold (0.6 = 60% similarity)
    - workers: number of parallel workers (None = use all cores - 1)
    - out_csv: output file name

    Returns:
    - the input DataFrame with an added 'cluster_id' column
      (-1 for rows without a cluster assignment)
    """

    if workers is None:
        # os.cpu_count() may return None when the count cannot be determined.
        workers = max(1, (os.cpu_count() or 1) - 1)

    print(f"Starting hierarchical Butina clustering with threshold {thresh} using {workers} workers...")

    # 1  Read SMILES
    df = pd.read_csv(csv_file)
    smiles = df[smiles_col].tolist()
    print(f"Loaded {len(smiles):,} SMILES")

    # 2  Parallel fingerprints
    print("Computing fingerprints...")
    start_time = time.time()

    fps = None
    if workers > 1:
        try:
            ctx = mp.get_context("fork")
            with ctx.Pool(processes=workers) as pool:
                fps = pool.map(fp_chunk, smiles)
        except Exception as exc:
            # Only ordinary errors trigger the sequential fallback (e.g. the
            # "fork" start method being unavailable); KeyboardInterrupt and
            # SystemExit propagate instead of silently restarting the work.
            print(f"Multiprocessing failed ({exc!r}), using sequential processing...")
    if fps is None:
        fps = [fp_chunk(sm) for sm in smiles]

    fp_time = time.time() - start_time
    valid_count = sum(1 for fp in fps if fp is not None)
    print(f"Generated {valid_count:,} valid fingerprints (dropped {len(smiles) - valid_count} invalid) in {fp_time:.1f}s")

    # 3  Hierarchical Butina clustering
    print("Performing hierarchical Butina clustering...")
    indices = list(range(len(fps)))
    cluster_assignments = hierarchical_butina_clustering(fps, indices, thresh)

    # 4  Create results dataframe and save
    # Assignments are keyed by row position, so look them up by position
    # rather than through the DataFrame's index labels.
    df['cluster_id'] = [cluster_assignments.get(pos, -1) for pos in range(len(df))]

    # Save to CSV
    df[[smiles_col, 'cluster_id']].to_csv(out_csv, index=False)
    print(f"Wrote {out_csv}")

    # Print cluster statistics
    cluster_series = df['cluster_id']
    valid_clusters = cluster_series[cluster_series >= 0]
    cluster_sizes = valid_clusters.value_counts().sort_values(ascending=False)

    print("\nCluster statistics:")
    print(f"Total clusters: {len(cluster_sizes)}")
    print(f"Valid molecules clustered: {len(valid_clusters):,}")
    print(f"Invalid molecules: {(cluster_series == -1).sum()}")
    print(f"Largest cluster: {cluster_sizes.iloc[0] if len(cluster_sizes) > 0 else 0} molecules")
    print(f"Clusters with >10 molecules: {(cluster_sizes > 10).sum()}")
    print(f"Clusters with >100 molecules: {(cluster_sizes > 100).sum()}")
    print(f"Singleton clusters: {(cluster_sizes == 1).sum()}")
    print(f"Mean cluster size: {cluster_sizes.mean():.1f}")

    return df

# ---------- Run clustering ---------------------------------------------------

# Run configuration
CSV_FILE = "chembl_pretraining.csv"                # input CSV with the SMILES column
SMILES_COL = "smiles"                              # name of that column
THRESHOLD = 0.6                                    # Tanimoto similarity threshold (60%)
WORKERS = 12                                       # worker processes; adjust to your system
OUTPUT_FILE = "chembl_clusters_hierarchical.csv"   # where the cluster table is written

# Execute the pipeline; any ordinary failure is reported with a traceback
# instead of aborting the notebook cell.
try:
    result_df = cluster_molecules(
        CSV_FILE,
        smiles_col=SMILES_COL,
        thresh=THRESHOLD,
        workers=WORKERS,
        out_csv=OUTPUT_FILE,
    )
    print(
        f"\nClustering completed successfully!",
        f"Results saved to {OUTPUT_FILE}",
        f"DataFrame shape: {result_df.shape}",
        sep="\n",
    )
except Exception as err:
    print(f"Error during clustering: {err}")
    import traceback
    traceback.print_exc()

Starting hierarchical Butina clustering with threshold 0.6 using 12 workers...
Loaded 79,492 SMILES
Computing fingerprints...
Generated 79,492 valid fingerprints (dropped 0 invalid) in 4.0s
Performing hierarchical Butina clustering...
Starting hierarchical Butina clustering (threshold=0.6)
Dataset size: 79,492 molecules
Creating diverse subsample of 15,000 from 79,492 molecules...
  Selected 1,000/15,000 diverse molecules...


KeyboardInterrupt: 