In [2]:
import random, pathlib, multiprocessing as mp
import pandas as pd
from rdkit import Chem, DataStructs, RDLogger
from rdkit.Chem import rdFingerprintGenerator as rfg
from rdkit.ML.Cluster import Butina

RDLogger.DisableLog("rdApp.*")          # silence RDKit chatter


In [4]:
csv_path   = "chembl_pretraining.csv"  # your file
tan_thresh = 0.7                       # Tanimoto threshold for clustering
test_frac  = 0.10                      # test-set size  (by clusters)
val_frac   = 0.10                      # validation size (by clusters)
radius     = 3                         # ECFP-6 → radius 3
fp_bits    = 2048                      # fingerprint length
seed       = 42                        # reproducibility
n_jobs     = mp.cpu_count()            # parallel SMILES → FP


In [6]:
df = pd.read_csv(csv_path, engine="pyarrow")
smiles_col = "smiles" if "smiles" in df.columns else "Smiles"
df = df.dropna(subset=[smiles_col]).reset_index(drop=True)

protein_cols = [c for c in df.columns if c not in {smiles_col}]
print(f"{len(df):,} molecules, {len(protein_cols)} protein columns "
      f"({df[protein_cols].count().sum():,} measured activities)")


79,492 molecules, 357 protein columns (122,822 measured activities)


In [6]:
# ─── Cell 4  ── Fingerprints with MorganGenerator (single-thread) ─────────────
from tqdm.auto import tqdm          # nice progress bar; pip install tqdm if needed

gen = rfg.GetMorganGenerator(radius=radius,
                             fpSize=fp_bits,
                             includeChirality=True)

def one_fp(smi: str):
    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        raise ValueError(f"Bad SMILES: {smi}")
    return gen.GetFingerprint(mol)

fps = [one_fp(smi) for smi in tqdm(df[smiles_col], desc="ECFP6")]

print(f"✓ fingerprints ready ({len(fps):,} molecules, single core)")


ECFP6:   0%|          | 0/79492 [00:00<?, ?it/s]

✓ fingerprints ready (79,492 molecules, single core)


In [None]:
# build condensed distance list (1-similarity) expected by Butina
dists = []
for i in range(1, len(fps)):
    dists.extend(1 - x for x in DataStructs.BulkTanimotoSimilarity(fps[i], fps[:i]))

clusters = Butina.ClusterData(dists, len(fps), tan_thresh, isDistData=True)
cid_map  = {m: c for c, members in enumerate(clusters) for m in members}
df["cluster_id"] = df.index.map(cid_map)

print(f"✓ {len(clusters):,} clusters (Tanimoto ≤ {tan_thresh} across clusters)")


In [7]:
random.seed(seed)
cids = list(df["cluster_id"].unique())
random.shuffle(cids)

limits  = {"test": int(round(len(cids)*test_frac)),
           "val":  int(round(len(cids)*val_frac))}
counts  = {"train": 0, "val": 0, "test": 0}
split_of = {}

for cid in cids:
    for label in ("test", "val", "train"):
        if label == "train" or counts[label] < limits[label]:
            split_of[cid] = label
            counts[label] += 1
            break

df["split"] = df["cluster_id"].map(split_of)

for s in ("train","val","test"):
    sub = df[df.split==s]
    print(f"{s:<5} : {len(sub):>7} mols in {sub.cluster_id.nunique():>6} clusters")


KeyError: 'cluster_id'