In [1]:
import os

import numpy as np
import pandas as pd
from rdkit import Chem
from rdkit.Chem import rdFingerprintGenerator
from rdkit.Chem.Scaffolds import MurckoScaffold
from sklearn.cluster import KMeans
from sklearn.utils.extmath import randomized_svd

In [2]:
# Load the pretraining table. `smiles` holds the molecule strings; every other
# column is treated downstream as a per-protein measurement column.
df = pd.read_csv('datasets/chembl_pretraining.csv')
smiles = df['smiles'].to_list()
protein_cols = df.columns.drop('smiles').tolist()

In [3]:
def bm_scaf(s):
    """Return the non-isomeric Murcko scaffold SMILES of `s`, or None.

    Parameters
    ----------
    s : str
        SMILES string of a molecule. Any non-string value (e.g. the float NaN
        pandas uses for an empty CSV cell, or None) is treated as "no molecule".

    Returns
    -------
    str or None
        Canonical scaffold SMILES written without stereo information, or None
        when the input is not a string, cannot be parsed, or yields a scaffold
        with zero atoms.
    """
    # Guard first: RDKit's parser raises on non-string input, which would abort
    # the whole `Series.map` call because of a single missing cell.
    if not isinstance(s, str): return None
    m = Chem.MolFromSmiles(s)
    if m is None: return None
    sc = MurckoScaffold.GetScaffoldForMol(m)
    # A zero-atom scaffold gets None, so it is not grouped under an
    # empty-string scaffold key.
    if sc is None or sc.GetNumAtoms()==0: return None
    return Chem.MolToSmiles(sc, isomericSmiles=False)
# Attach each molecule's scaffold string, then encode distinct scaffolds as
# integer ids (sorted vocabulary). Rows without a scaffold get id -1.
n_mols = df.shape[0]
df['scaffold_smiles'] = df['smiles'].map(bm_scaf)
scaf_id, scaf_vocab = df['scaffold_smiles'].factorize(sort=True)
df['scaf_id'] = scaf_id

In [4]:
# Parse every molecule, group row positions by scaffold id, and fingerprint
# one representative molecule (the first row) per scaffold.
mols = list(map(Chem.MolFromSmiles, smiles))
grouped = df.groupby('scaf_id').indices
scaff_to_idx = {}
for sid, rows in grouped.items():
    # id -1 marks rows without a scaffold; they are left out of the grouping
    if sid < 0:
        continue
    scaff_to_idx[sid] = np.array(rows, dtype=int)
rep_idx = [rows[0] for rows in scaff_to_idx.values()]
gen = rdFingerprintGenerator.GetMorganGenerator(radius=3, fpSize=2048)
# X[r] is the 2048-bit Morgan fingerprint of the r-th scaffold's representative
X = np.zeros((len(rep_idx), 2048), dtype=bool)
for row, mol_pos in enumerate(rep_idx):
    on_bits = list(gen.GetFingerprint(mols[mol_pos]).GetOnBits())
    if on_bits:
        X[row, on_bits] = True

In [5]:
def jaccard_AB(A, B, chunk=2048):
    """Pairwise Jaccard (Tanimoto) similarity between rows of two binary matrices.

    Parameters
    ----------
    A : ndarray of shape (n, d)
        Binary (bool or 0/1) row vectors.
    B : ndarray of shape (k, d)
        Binary (bool or 0/1) row vectors.
    chunk : int, default 2048
        Number of rows of `A` processed per matrix product; bounds peak memory.

    Returns
    -------
    ndarray of shape (n, k), dtype float32
        ``|a & b| / |a | b|`` for every pair. A pair of all-zero rows gets 0.0
        because the union is clamped to at least 1.

    Raises
    ------
    ValueError
        If `chunk` is smaller than 1.
    """
    if chunk < 1:
        # A negative step would yield an empty range and return uninitialised memory.
        raise ValueError("chunk must be a positive integer")
    # float32 products go through BLAS and are exact for counts up to 2**24,
    # whereas uint16 products silently wrap once an intersection exceeds 65535.
    B_t = B.astype(np.float32).T
    bsum = B.sum(axis=1).astype(np.int64)
    n = A.shape[0]
    out = np.empty((n, B.shape[0]), dtype=np.float32)
    for i in range(0, n, chunk):
        Ab = A[i:i+chunk]
        inter = Ab.astype(np.float32) @ B_t
        asum = Ab.sum(axis=1).astype(np.int64)[:, None]
        # |a ∪ b| = |a| + |b| - |a ∩ b|
        union = asum + bsum[None, :] - inter
        # avoid 0/0 for pairs of empty rows
        union = np.maximum(union, 1)
        out[i:i+chunk] = inter / union
    return out
# Landmark-based (Nyström-style) spectral clustering of the scaffold fingerprints.
# The full n x n similarity matrix is never formed: only similarities to `m`
# randomly chosen landmark rows are computed.
rng = np.random.default_rng(0)
n = X.shape[0]
m = min(1024, n)
L = rng.choice(n, size=m, replace=False)
B = X[L]
# C: similarity of every scaffold to every landmark; W: landmark-to-landmark block.
# jaccard_AB already returns float32, so copy=False avoids duplicating the arrays.
C = jaccard_AB(X, B, chunk=2048).astype(np.float32, copy=False)
W = jaccard_AB(B, B, chunk=1024).astype(np.float32, copy=False)
eps = 1e-6
evals, evecs = np.linalg.eigh(W + eps*np.eye(m, dtype=np.float32))
# Float32 round-off (or duplicated landmark rows) can make the smallest
# eigenvalues zero or slightly negative; flooring them keeps the inverse and the
# inverse square root finite instead of propagating inf/NaN into KMeans.
evals = np.maximum(evals, eps)
Winv = (evecs / evals) @ evecs.T
Winvsqrt = (evecs / np.sqrt(evals)) @ evecs.T
# Row sums of the approximated affinity C @ Winv @ C.T, computed without forming it.
u = C.sum(axis=0)
v = Winv @ u
d = C @ v
# The approximation can produce small negative degrees; clamp before the square root.
s = (1.0 / np.sqrt(np.maximum(d, 0.0) + 1e-12)).astype(np.float32)
# E @ E.T equals the degree-normalised approximate affinity, so the left
# singular vectors of E serve as the spectral embedding.
E = (C * s[:, None]) @ Winvsqrt
U, S, VT = randomized_svd(E, n_components=3, random_state=0)
# Row-normalise the embedding before clustering.
Y = U / (np.linalg.norm(U, axis=1, keepdims=True) + 1e-12)
labels = KMeans(n_clusters=3, n_init=10, random_state=0).fit_predict(Y)
scaff_order = np.argsort(labels)
scaff_keys = list(scaff_to_idx.keys())

In [6]:
# Count, for every scaffold, how many of its molecules have a non-null value in
# each protein column. A single grouped count replaces one DataFrame slice per
# scaffold; rows with scaffold id -1 (no scaffold) are excluded, matching
# `scaff_to_idx`.
scaff_prot_counts = df.loc[df['scaf_id'] >= 0].groupby('scaf_id')[protein_cols].count()
# prot_counts[scaffold_id] -> {protein column: number of non-null molecules}
prot_counts = scaff_prot_counts.to_dict(orient='index')
# Total non-null molecules per protein over all scaffolds.
total_per_prot = scaff_prot_counts.sum(axis=0).astype(int)
# Proteins with at least 3 measured molecules are the ones the split assignment
# tries to place in every split. NOTE(review): the threshold is on molecule
# counts, not scaffold counts, so a protein measured on a single scaffold still
# qualifies even though it can only land in one split.
need_cover = {p for p in protein_cols if total_per_prot[p] >= 3}

In [7]:
# Greedy assignment of whole scaffolds to splits 0/1/2 (train/val/test).
n_scaff = len(scaff_to_idx)
ratios = np.array([0.8,0.1,0.1])
# Target sizes are in numbers of scaffolds; split 0 absorbs the rounding remainder.
targets = (ratios * n_scaff).astype(int)
targets[0] = n_scaff - targets[1] - targets[2]
assign = dict.fromkeys(scaff_to_idx)
split_sizes = np.zeros(3, dtype=int)
# cov[p][s]: number of scaffolds already placed in split s that measure protein p
cov = {p: np.zeros(3, dtype=int) for p in protein_cols}


def _rarest_protein_total(scaffold_key):
    """Smallest overall count among the proteins this scaffold measures (10**9 if none)."""
    totals = [total_per_prot[p] for p, c in prot_counts[scaffold_key].items() if c > 0]
    return min(totals) if totals else 10**9


# Scaffolds carrying the rarest proteins are placed first.
order = sorted(scaff_to_idx.keys(), key=_rarest_protein_total)
for k in order:
    measured = [p for p, c in prot_counts[k].items() if c > 0]
    candidates = []
    for split in range(3):
        # number of must-cover proteins this scaffold would newly bring to `split`
        gain = sum(1 for p in measured if p in need_cover and cov[p][split] == 0)
        over = 1 if split_sizes[split] >= targets[split] else 0
        # Ranking: most new coverage, then not over target, then smallest split,
        # then highest split index.
        candidates.append((gain, -over, -split_sizes[split], split))
    chosen = max(candidates)[-1]
    assign[k] = chosen
    split_sizes[chosen] += 1
    for p in measured:
        cov[p][chosen] += 1

In [8]:
# Broadcast each scaffold's split to all of its molecules; rows that belong to
# no scaffold keep the sentinel -1.
mol_split = np.full(n_mols, -1, dtype=int)
for scaffold_key, rows in scaff_to_idx.items():
    mol_split[rows] = assign[scaffold_key]

In [9]:
# Row positions of the molecules in each split (0 = train, 1 = val, 2 = test).
train_idx, val_idx, test_idx = (np.flatnonzero(mol_split == code) for code in range(3))

In [10]:
# Per-molecule summary table written next to the notebook.
# NOTE(review): `cluster` is a constant -1 here even though spectral-cluster
# `labels` are computed earlier in the notebook — confirm whether that is intended.
scaffold_per_mol = [None if code < 0 else scaf_vocab[code] for code in scaf_id]
df_selectivity = pd.DataFrame({
    "mol_idx": np.arange(n_mols),
    "smiles": df['smiles'].values,
    "split": mol_split,
    "cluster": -1,
    "scaffold": scaffold_per_mol,
})
df_selectivity.to_csv("split_molecules_selectivity.csv", index=False)

In [11]:
# Write one CSV per split, keeping all original columns of the pretraining table.
os.makedirs('datasets', exist_ok=True)
df_full = pd.read_csv('datasets/chembl_pretraining.csv')
# `df_selectivity` has one row per input row, so a SMILES that occurs twice in
# the CSV would appear twice on the right side and make validate="many_to_one"
# raise. The split is determined by the SMILES string alone, so keeping one row
# per SMILES loses nothing.
split_by_smiles = df_selectivity[['smiles','split']].drop_duplicates(subset='smiles')
df_join = df_full.merge(split_by_smiles, on="smiles", how="inner", validate="many_to_one")
# Rows whose split is -1 (no scaffold) match none of the three values and are
# therefore not written to any file.
for split_val, name in enumerate(['train','val','test']):
    df_join[df_join['split']==split_val].drop(columns=['split']).to_csv(f'datasets/chembl_pretraining_{name}.csv', index=False)