In [1]:
import numpy as np, pandas as pd
from rdkit import Chem
from rdkit.Chem.Scaffolds import MurckoScaffold
from rdkit.Chem import rdFingerprintGenerator
from sklearn.utils.extmath import randomized_svd
from sklearn.cluster import KMeans

# Load the pretraining (ChEMBL) and fine-tuning (PKIS2) tables, tag every row
# with its source in `__ds__`, and stack them so scaffolds are split jointly.
chembl = pd.read_csv('datasets/chembl_pretraining.csv')
pkis2 = pd.read_csv('datasets/pkis2_finetuning.csv')
for frame, tag in ((chembl, 'chembl_pretraining'), (pkis2, 'pkis2_finetuning')):
    frame['__ds__'] = tag
dfu = pd.concat([chembl, pkis2], ignore_index=True)
smiles = dfu['smiles'].tolist()
# Every column other than the SMILES string and the source tag is treated as a
# protein column (`prot_cols`).
non_prot = {'smiles', '__ds__'}
prot_cols = [col for col in dfu.columns if col not in non_prot]


In [5]:
def bm_scaf(s):
    """Return the canonical Bemis-Murcko scaffold SMILES for ``s``, or None.

    None is returned when ``s`` is not a string, is blank after stripping
    whitespace, cannot be parsed by RDKit, or yields an empty (zero-atom)
    scaffold -- presumably acyclic molecules. The scaffold is written without
    stereochemistry (``isomericSmiles=False``).
    """
    text = s.strip() if isinstance(s, str) else ""
    if not text:
        return None
    mol = Chem.MolFromSmiles(text)
    scaffold = None if mol is None else MurckoScaffold.GetScaffoldForMol(mol)
    if scaffold is None or scaffold.GetNumAtoms() == 0:
        return None
    return Chem.MolToSmiles(scaffold, isomericSmiles=False)

# Bemis-Murcko scaffold per molecule; rows without one (bm_scaf -> None) are
# factorized to id -1 and excluded from the scaffold groups below.
dfu['scaffold_smiles'] = dfu['smiles'].map(bm_scaf)
scaf_id, scaf_vocab = pd.factorize(dfu['scaffold_smiles'], sort=True)
dfu['scaf_id'] = scaf_id

# RDKit molecule per row (None for non-string, blank, or unparseable SMILES).
mols = []
for raw in dfu['smiles'].tolist():
    text = raw.strip() if isinstance(raw, str) else ''
    mols.append(Chem.MolFromSmiles(text) if text else None)

# Row positions of all molecules sharing a scaffold, keyed by scaf_id (>= 0 only).
grp = dfu.groupby('scaf_id').indices
scaff_to_idx = {}
for sid, members in grp.items():
    if sid >= 0:
        scaff_to_idx[sid] = np.array(members, dtype=int)
# The first row of each scaffold group stands in for the whole scaffold.
rep_idx = [members[0] for members in scaff_to_idx.values()]

# 2048-bit Morgan (radius 3) fingerprint of each representative *molecule*
# (not of the scaffold itself): one boolean row per scaffold, in
# scaff_to_idx order.
gen = rdFingerprintGenerator.GetMorganGenerator(radius=3, fpSize=2048)
X = np.zeros((len(rep_idx), 2048), dtype=bool)
for row, mol_pos in enumerate(rep_idx):
    rep_mol = mols[mol_pos]
    if rep_mol is None:
        continue
    on_bits = list(gen.GetFingerprint(rep_mol).GetOnBits())
    if on_bits:
        X[row, on_bits] = True

In [7]:
def jaccard_AB(A, B, chunk=2048):
    """Pairwise Jaccard (Tanimoto) similarity between the rows of two binary matrices.

    Parameters
    ----------
    A : np.ndarray, shape (n, d)
        Binary (bool or 0/1) matrix; each row is one bit vector.
    B : np.ndarray, shape (m, d)
        Binary matrix with the same number of columns as ``A``.
    chunk : int, default 2048
        Number of rows of ``A`` handled per matrix product; bounds peak memory.

    Returns
    -------
    np.ndarray of float32, shape (n, m)
        ``out[i, j] = |A[i] & B[j]| / |A[i] | B[j]|``, and 0 when both rows are empty.

    Raises
    ------
    ValueError
        If ``chunk`` is not a positive integer.
    """
    if chunk < 1:
        raise ValueError(f"chunk must be >= 1, got {chunk}")
    # Count overlaps with float32 products: they go through BLAS (NumPy's
    # integer matmul does not) and are exact up to 2**24 bits, whereas uint16
    # counts silently wrap past 65535.
    B_t = B.T.astype(np.float32)
    bsum = B.sum(axis=1).astype(np.float32)
    n = A.shape[0]
    out = np.empty((n, B.shape[0]), dtype=np.float32)
    for i in range(0, n, chunk):
        Ab = A[i:i + chunk].astype(np.float32)
        inter = Ab @ B_t
        union = Ab.sum(axis=1)[:, None] + bsum[None, :] - inter
        np.maximum(union, 1.0, out=union)   # both rows empty -> 0 / 1 = 0
        out[i:i + chunk] = inter / union
    return out

# Spectral clustering of scaffolds on an approximate Jaccard kernel. The full
# (n x n) kernel is never formed: a Nystrom approximation built from m randomly
# chosen landmark scaffolds is used, and the row-normalised leading singular
# vectors are clustered with k-means.
N_CLUSTERS = 3
rng = np.random.default_rng(0)
n = X.shape[0]
assert n >= N_CLUSTERS, f"need at least {N_CLUSTERS} scaffolds to cluster, got {n}"
m = min(1024, n)
L = rng.choice(n, size=m, replace=False)              # landmark row indices
B = X[L]
C = jaccard_AB(X, B, chunk=2048).astype(np.float32)   # (n, m) scaffold-to-landmark similarity
W = jaccard_AB(B, B, chunk=1024).astype(np.float32)   # (m, m) landmark-to-landmark similarity

# W is symmetric and (in exact arithmetic) positive semidefinite, but it can be
# (nearly) singular, e.g. when two landmarks share an identical fingerprint.
# A float32 eigh of W + 1e-6*I can still return eigenvalues <= 0 (the float32
# eigenvalue error is larger than the shift), making sqrt() NaN and 1/evals
# explode. Decompose in float64 and pseudo-invert instead, discarding
# eigenvalues below a tolerance scaled to W's float32 precision (the same
# rule numpy.linalg.matrix_rank uses), with `eps` as an absolute floor.
eps = 1e-6
evals, evecs = np.linalg.eigh(W.astype(np.float64))
tol = max(eps, float(evals.max()) * m * float(np.finfo(np.float32).eps))
keep = evals > tol
evals_kept, evecs_kept = evals[keep], evecs[:, keep]
Winv = ((evecs_kept / evals_kept) @ evecs_kept.T).astype(np.float32)
Winvsqrt = ((evecs_kept / np.sqrt(evals_kept)) @ evecs_kept.T).astype(np.float32)

# Approximate degrees d = C W^+ C^T 1 without materialising the n x n kernel.
u = C.sum(axis=0)
v = Winv @ u
d = C @ v
# Nystrom degrees are not guaranteed positive; clip so 1/sqrt(d) stays finite.
s = (1.0 / np.sqrt(np.maximum(d, 1e-12))).astype(np.float32)
# E E^T = D^{-1/2} C W^+ C^T D^{-1/2}, the approximate normalised kernel.
E = (C * s[:, None]) @ Winvsqrt
U, S, VT = randomized_svd(E, n_components=N_CLUSTERS, random_state=0)
Y = U / (np.linalg.norm(U, axis=1, keepdims=True) + 1e-12)
km = KMeans(n_clusters=N_CLUSTERS, n_init=10, random_state=0)
labels = km.fit_predict(Y)

In [8]:
# Per-scaffold bookkeeping for the greedy split in the next cell.
scaff_keys = list(scaff_to_idx.keys())
scaff_label = {k: labels[i] for i, k in enumerate(scaff_keys)}   # cluster id per scaffold

# Only molecules with a scaffold (scaf_id >= 0) are ever assigned to a split.
has_scaf = dfu['scaf_id'] >= 0
scaf_rows = dfu.loc[has_scaf, ['scaf_id', '__ds__']]

# Non-missing values per protein column, per scaffold. This uses one groupby
# instead of slicing dfu once per scaffold; reindex fixes the row order to
# scaff_keys and guarantees an entry for every scaffold.
prot_counts_df = (
    dfu.loc[has_scaf, prot_cols].notna()
    .groupby(scaf_rows['scaf_id']).sum()
    .reindex(scaff_keys, fill_value=0)
)
prot_counts = prot_counts_df.to_dict(orient='index')
sizes_total = {k: len(idx) for k, idx in scaff_to_idx.items()}
sizes_by_ds = {k: {} for k in scaff_keys}
for (k, ds), cnt in scaf_rows.groupby(['scaf_id', '__ds__']).size().items():
    sizes_by_ds[k][ds] = int(cnt)
assert sum(sizes_total.values()) == int(has_scaf.sum())

total_per_prot = prot_counts_df.sum(axis=0).astype(int)
# NOTE(review): not referenced by the split logic in the following cell.
need_cover = {p for p in prot_cols if total_per_prot[p] >= 3}


def _rarest_protein_total(k):
    """Smallest dataset-wide count among proteins scaffold k has data for; 10**9 if none."""
    totals = [total_per_prot[p] for p, c in prot_counts[k].items() if c > 0]
    return min(totals) if totals else 10**9


# Sort key: scaffolds carrying the rarest proteins first, then by cluster
# label, then larger scaffolds first.
order = sorted(
    scaff_to_idx.keys(),
    key=lambda k: (_rarest_protein_total(k), scaff_label[k], -sizes_total[k]),
)

ratios = np.array([0.8, 0.1, 0.1])   # train / val / test fractions
# Targets are based on molecules that will actually be assigned. Scaffold-less
# molecules are never written to any split, so counting them would inflate the
# val/test shares.
mol_counts_ds = scaf_rows.groupby('__ds__').size().to_dict()
targets_ds = {}
for ds, n_ds in mol_counts_ds.items():
    tgt = (ratios * n_ds).astype(int)
    tgt[0] = n_ds - tgt[1] - tgt[2]   # train absorbs the rounding remainder
    targets_ds[ds] = tgt

split_sizes_ds = {ds: np.zeros(3, dtype=int) for ds in mol_counts_ds}
cov = {p: np.zeros(3, dtype=int) for p in prot_cols}
assign = {}

In [9]:
def penalty_for(k, s):
    """Score the cost of putting scaffold ``k`` into split ``s`` (lower is better).

    The score sums two terms:

    * size term -- for each dataset in ``split_sizes_ds`` where ``k`` has
      molecules, the change in squared deviation of split ``s``'s running
      count from its target (``targets_ds``) if ``k`` were added;
    * coverage term -- ``0.01 * c * cov[p][s]`` for every protein ``p`` with
      ``c > 0`` measurements in ``k``, penalising splits already holding
      that protein's data.

    Reads the module-level ``split_sizes_ds``, ``targets_ds``, ``sizes_by_ds``,
    ``prot_counts`` and ``cov``; mutates nothing.
    """
    score = 0.0
    for ds_name in split_sizes_ds:
        added = sizes_by_ds[k].get(ds_name, 0)
        if not added:
            continue
        before = split_sizes_ds[ds_name][s] - targets_ds[ds_name][s]
        after = before + added
        score += after**2 - before**2
    for prot, n_meas in prot_counts[k].items():
        if n_meas > 0:
            score += 0.01 * n_meas * cov[prot][s]
    return score

# Greedy assignment: walk scaffolds in `order` and place each whole scaffold
# into the split (0=train, 1=val, 2=test) with the lowest penalty_for score
# (ties go to the lower split index), then update the running per-dataset
# split sizes and per-protein coverage that penalty_for reads.
# The running state is re-initialised here so that re-running this cell on its
# own does not accumulate onto a previous run's counts.
split_sizes_ds = {ds: np.zeros(3, dtype=int) for ds in mol_counts_ds}
cov = {p: np.zeros(3, dtype=int) for p in prot_cols}
assign = {}

for k in order:
    best_s = min(range(3), key=lambda sidx: penalty_for(k, sidx))
    assign[k] = best_s
    for ds in split_sizes_ds:
        split_sizes_ds[ds][best_s] += sizes_by_ds[k].get(ds, 0)
    for p, c in prot_counts[k].items():
        if c > 0:
            cov[p][best_s] += c

# Split id per dfu row (positionally aligned); -1 for molecules without a scaffold.
mol_split = -np.ones(len(dfu), dtype=int)
for k, s in assign.items():
    mol_split[scaff_to_idx[k]] = s

sel = pd.DataFrame({"smiles": dfu['smiles'].values, "split": mol_split, "scaffold": [scaf_vocab[i] if i >= 0 else None for i in scaf_id]})
sel = sel[sel['split'] >= 0]

# Attach splits by row position rather than merging on SMILES. The same SMILES
# string can occur in several rows (within a dataset or in both), and a
# SMILES-keyed merge then raises under validate='many_to_one'.
for ds_name in ['chembl_pretraining', 'pkis2_finetuning']:
    in_ds = (dfu['__ds__'] == ds_name).to_numpy()
    dfull = dfu.loc[in_ds].drop(columns=['__ds__', 'scaffold_smiles', 'scaf_id'])
    ds_split = mol_split[in_ds]
    out = dfull.assign(split=ds_split)[ds_split >= 0].reset_index(drop=True)
    for v, nm in enumerate(['train', 'val', 'test']):
        out[out['split'] == v].drop(columns=['split']).to_csv(f'datasets/{ds_name}_{nm}.csv', index=False)