# In [2]:
# Jupyter-ready splitter (pretraining + finetuning, scaffold-based, family-stratified)

import os, hashlib, json
import numpy as np, pandas as pd
from rdkit import Chem
from rdkit.Chem.Scaffolds import MurckoScaffold as MS

def canon_smiles(s):
    """Return the RDKit canonical SMILES for ``s``, or None if it is unusable.

    ``s`` is typically a raw cell value from the input CSV. Missing values
    (None / NaN / pd.NA), blank strings, unparsable SMILES, and SMILES that parse
    to a molecule with no atoms all map to None, so the caller's ``dropna`` removes
    them. Without the explicit checks, NaN would reach RDKit as the text 'nan',
    and an empty string parses to an empty molecule whose canonical SMILES is ''.
    """
    if not isinstance(s, str) and pd.isna(s):
        return None
    text = str(s).strip()
    if not text:
        return None
    m = Chem.MolFromSmiles(text)
    if m is None or m.GetNumAtoms() == 0:
        return None
    return Chem.MolToSmiles(m, canonical=True)
def inchikey(s):
    """Return the InChIKey RDKit computes for SMILES ``s``.

    Returns None when RDKit cannot parse ``s``.
    """
    mol = Chem.MolFromSmiles(s)
    if mol is None:
        return None
    return Chem.MolToInchiKey(mol)
def scaffold_smiles(s):
    """Return the canonical SMILES of the Bemis-Murcko scaffold of SMILES ``s``.

    Returns None when ``s`` cannot be parsed (or if no scaffold object is produced).
    NOTE(review): for molecules without rings RDKit's Murcko scaffold is expected
    to be empty, giving '' rather than None, so all acyclic compounds would share
    one scaffold group downstream -- confirm that lumping is intended.
    """
    mol = Chem.MolFromSmiles(s)
    if mol is None:
        return None
    core = MS.GetScaffoldForMol(mol)
    if core is None:
        return None
    return Chem.MolToSmiles(core, canonical=True)
def det_order(keys, seed):
    """Return ``keys`` shuffled deterministically by ``seed``.

    Each key (a str) is ranked by the integer value of md5(str(seed) + key), so the
    order is reproducible across runs and Python processes, unlike ``hash()``.
    The sort is stable, so keys with equal digests keep their input order.
    """
    prefix = str(seed)

    def _rank(key):
        digest = hashlib.md5((prefix + key).encode()).hexdigest()
        return int(digest, 16)

    ordered = list(keys)
    ordered.sort(key=_rank)
    return ordered
def load_map(path, protein_col, family_col):
    """Read a protein -> family lookup from the CSV at ``path``.

    Rows missing either column are skipped, and both columns are coerced to
    ``str``. If a protein appears on several rows, the last row wins (plain dict
    semantics); no conflict is reported.
    """
    wanted = [protein_col, family_col]
    table = pd.read_csv(path)[wanted].dropna().astype(str)
    proteins = table[protein_col].tolist()
    families = table[family_col].tolist()
    return {p: f for p, f in zip(proteins, families)}
def prep_df(path, smiles_col, protein2fam):
    """Load a compound CSV and derive canonical SMILES, scaffolds and family labels.

    Steps:
      * canonicalize ``smiles_col`` into ``__smiles`` and drop rows that fail;
      * drop duplicate canonical SMILES, keeping the first row (activity values of
        the later duplicates are discarded);
      * add ``__inchikey`` and ``__scaffold`` and drop rows without a scaffold;
      * treat every column whose name is a key of ``protein2fam`` as a per-protein
        activity column and collapse them per family.

    Returns ``(df, fams, M)``: ``fams`` is the sorted list of families that have at
    least one activity column, and ``M`` is an int matrix of shape
    ``(len(df), len(fams))`` with ``M[i, j] == 1`` when row ``i`` of the returned
    ``df`` has a value > 0 (NaN counted as 0) in any column mapped to ``fams[j]``.
    NOTE(review): assumes activity columns are numeric (``astype(float)`` raises
    otherwise) and that "> 0" is the right activity threshold for this data.
    """
    df = pd.read_csv(path)
    df['__smiles'] = df[smiles_col].map(canon_smiles)
    df = df.dropna(subset=['__smiles']).drop_duplicates('__smiles').reset_index(drop=True)
    df['__inchikey'] = df['__smiles'].map(inchikey)
    # Filter on scaffold BEFORE building M: building M first and dropping rows
    # afterwards would leave M's rows misaligned with df's rows.
    df['__scaffold'] = df['__smiles'].map(scaffold_smiles)
    df = df.dropna(subset=['__scaffold']).reset_index(drop=True)

    pcols = [c for c in df.columns if c in protein2fam]
    fams = sorted({protein2fam[c] for c in pcols})
    M = np.zeros((len(df), len(fams)), dtype=int)
    for j, fam in enumerate(fams):
        # Non-empty by construction: every family in fams came from some column in pcols.
        cols = [c for c in pcols if protein2fam[c] == fam]
        values = df[cols].fillna(0).astype(float).to_numpy()
        M[:, j] = (values > 0).any(axis=1).astype(int)
    return df, fams, M
def aggregate_scaffolds(df, fams, M):
    """Collapse compounds into one row per ``__scaffold`` value.

    ``M`` must be row-aligned with ``df`` by position (row ``i`` of ``M`` describes
    the ``i``-th row of ``df``), with one column per entry of ``fams``.

    Returns a DataFrame with columns:
      * ``scaffold``        -- the scaffold SMILES;
      * ``n``               -- number of compounds with that scaffold;
      * ``smiles_set``      -- set of their ``__smiles`` values;
      * ``fam_vec``         -- column sums of ``M`` over those compounds;
      * ``dominant_family`` -- family with the largest count (first on ties),
                               or None when the scaffold has no actives.
    The columns are present even when ``df`` is empty, so downstream code can rely on them.
    """
    columns = ['scaffold', 'n', 'smiles_set', 'fam_vec', 'dominant_family']
    # groupby().indices yields POSITIONAL indices, so look SMILES up by position
    # too; label-based .loc would be wrong for any non-RangeIndex frame.
    smiles = df['__smiles'].to_numpy()
    rows = []
    for scaf, idxs in df.groupby('__scaffold').indices.items():
        ids = np.asarray(idxs)
        vec = M[ids].sum(axis=0)
        dom = fams[int(vec.argmax())] if vec.sum() > 0 else None
        rows.append({
            'scaffold': scaf,
            'n': len(ids),
            'smiles_set': set(smiles[ids]),
            'fam_vec': vec,
            'dominant_family': dom,
        })
    return pd.DataFrame(rows, columns=columns)
def greedy_split(scaf_df, fams, ratios, seed, forbid_train_scaffolds=None, forbid_train_smiles=None):
    """Greedily assign each scaffold group to a split, balancing size and family coverage.

    Scaffolds are visited in the seed-dependent order given by ``det_order``. For
    each one, every split in ``ratios`` gets a gain:

        sum(min(fam_vec, remaining family need)) + 0.1 * min(n, remaining size need)

    where the targets are ``ratios[split]`` times the total compound count and
    times the total per-family active counts. The scaffold goes to the split with
    the highest gain; ties go to the split with the smaller current size.

    Scaffolds listed in ``forbid_train_scaffolds``, or containing any SMILES in
    ``forbid_train_smiles``, get a gain of -1e18 for the split named 'train'. They
    therefore land in one of the other splits (if ``ratios`` has any).

    ``scaf_df`` needs the columns ``scaffold``, ``n``, ``fam_vec`` and
    ``smiles_set``. A ``split`` column is added to ``scaf_df`` IN PLACE, and the
    same frame is returned. The forbid arguments accept any iterable and default
    to empty.
    """
    # None defaults instead of shared mutable set() defaults.
    forbid_train_scaffolds = set() if forbid_train_scaffolds is None else set(forbid_train_scaffolds)
    forbid_train_smiles = set() if forbid_train_smiles is None else set(forbid_train_smiles)

    F = len(fams)
    # Pull the columns out once instead of doing an .iloc row lookup per scaffold.
    scaffolds = scaf_df['scaffold'].tolist()
    sizes = scaf_df['n'].tolist()
    fam_vecs = scaf_df['fam_vec'].tolist()
    smiles_sets = scaf_df['smiles_set'].tolist()

    total = sum(sizes)
    fam_tot = np.zeros(F, dtype=float)
    for v in fam_vecs:
        fam_tot += v
    tgt_size = {k: int(round(total * r)) for k, r in ratios.items()}
    tgt_fam = {k: fam_tot * r for k, r in ratios.items()}
    cur_size = {k: 0 for k in ratios}
    cur_fam = {k: np.zeros(F, dtype=float) for k in ratios}

    pos = {s: i for i, s in enumerate(scaffolds)}
    assign = {}
    for s in det_order(scaffolds, seed):
        i = pos[s]
        n, vec = sizes[i], fam_vecs[i]
        banned_from_train = s in forbid_train_scaffolds or bool(smiles_sets[i] & forbid_train_smiles)
        gains = {}
        for split in ratios:
            if split == 'train' and banned_from_train:
                gains[split] = -1e18
                continue
            fam_need = np.maximum(tgt_fam[split] - cur_fam[split], 0)
            size_need = max(tgt_size[split] - cur_size[split], 0)
            gains[split] = float(np.minimum(vec, fam_need).sum()) + min(n, size_need) * 0.1
        best = max(gains, key=lambda k: (gains[k], -cur_size[k]))
        assign[s] = best
        cur_size[best] += n
        cur_fam[best] += vec
    scaf_df['split'] = scaf_df['scaffold'].map(assign)
    return scaf_df
def finalize_compounds(df, scaf_assign):
    """Propagate scaffold-level splits to compounds.

    Adds a ``split`` column to ``df`` IN PLACE by looking up each row's
    ``__scaffold`` in ``scaf_assign`` (columns ``scaffold`` / ``split``). Scaffolds
    missing from ``scaf_assign`` get NaN. Returns ``df`` itself.
    """
    lookup = {scaf: sp for scaf, sp in zip(scaf_assign['scaffold'], scaf_assign['split'])}
    df['split'] = df['__scaffold'].map(lookup)
    return df
def family_summary(df, fams, M, splits=('train', 'val', 'test')):
    """Count active compounds per family in each split.

    ``M`` must be row-aligned with ``df`` (one column per entry of ``fams``), and
    ``df`` needs a ``split`` column. ``splits`` selects which split names to report
    and in what order; it defaults to the original train/val/test.

    Returns a long DataFrame with columns ``family``, ``split`` and
    ``active_compounds``. A split with no compounds reports zeros.
    """
    frames = []
    for sp in splits:
        mask = (df['split'] == sp).to_numpy()
        # An all-False mask selects shape (0, F), whose column sum is already zeros.
        counts = M[mask].sum(axis=0)
        frames.append(pd.DataFrame({'family': fams, 'split': sp, 'active_compounds': counts}))
    if not frames:
        return pd.DataFrame(columns=['family', 'split', 'active_compounds'])
    return pd.concat(frames, ignore_index=True)
def align(df, fams_all, fams_local, M_local):
    """Re-index a per-family matrix onto a wider family list.

    Column ``k`` of ``M_local`` belongs to ``fams_local[k]``. The result has shape
    ``(len(df), len(fams_all))``: each column is copied from ``M_local`` when that
    family exists locally and is left as zeros otherwise. Only ``len(df)`` is read
    from ``df``.
    """
    local_pos = {fam: k for k, fam in enumerate(fams_local)}
    out = np.zeros((len(df), len(fams_all)), dtype=int)
    for col, fam in enumerate(fams_all):
        k = local_pos.get(fam)
        if k is None:
            continue
        out[:, col] = M_local[:, k]
    return out

# ---- Configure paths/params (edit these) ----
pretraining_path = 'datasets/chembl_pretraining.csv'
finetuning_path = 'datasets/pkis2_finetuning.csv'
mapping_path = 'uniprotkb_chembl_ID_kinomescan_cleaned.csv'  # protein -> family table
smiles_col = 'smiles'
protein_col = 'KINOMEscan® Gene Symbol'  # must match activity-column names in the datasets
family_col = 'Family'
ratios = {'train':0.8,'val':0.1,'test':0.1}
seed = 42
outdir = 'splits_out'

# Fail fast on a malformed split spec: greedy_split's size/family targets are
# fractions of the totals, and the leakage constraint only applies to 'train'.
if ('train' not in ratios
        or any(r < 0 for r in ratios.values())
        or abs(sum(ratios.values()) - 1.0) > 1e-6):
    raise ValueError(f"ratios must contain 'train', be non-negative and sum to 1; got {ratios}")
os.makedirs(outdir, exist_ok=True)

# ---- Run ----
# 1) Load the protein -> family map and featurize both datasets.
protein2fam = load_map(mapping_path, protein_col, family_col)
df_fine, fams_f, M_f_local = prep_df(finetuning_path, smiles_col, protein2fam)
df_pre, fams_p, M_p_local = prep_df(pretraining_path, smiles_col, protein2fam)

# 2) Put both label matrices on one shared, sorted family axis.
fams = sorted(set(fams_f).union(fams_p))
M_f = align(df_fine, fams, fams_f, M_f_local)
M_p = align(df_pre, fams, fams_p, M_p_local)

# 3) Split the finetuning set first, with no constraints.
scaf_f = greedy_split(aggregate_scaffolds(df_fine, fams, M_f), fams, ratios, seed)
df_fine = finalize_compounds(df_fine, scaf_f)

# 4) Whatever finetuning holds out (val/test) must not be trained on during pretraining.
heldout_f = scaf_f['split'].isin(['val', 'test'])
forbid_scafs = set(scaf_f.loc[heldout_f, 'scaffold'])
lst = scaf_f.loc[heldout_f, 'smiles_set'].tolist()
forbid_smiles = set().union(*lst)  # union of zero sets is just set()

# 5) Split the pretraining set under those constraints, with a different seed.
scaf_p = greedy_split(aggregate_scaffolds(df_pre, fams, M_p), fams, ratios, seed + 1, forbid_scafs, forbid_smiles)
df_pre = finalize_compounds(df_pre, scaf_p)

# ---- Write outputs ----
df_fine.to_csv(os.path.join(outdir, 'finetuning_with_splits.csv'), index=False)
df_pre.to_csv(os.path.join(outdir, 'pretraining_with_splits.csv'), index=False)
for dname, df in [('finetuning', df_fine), ('pretraining', df_pre)]:
    # One file per configured split (driven by ratios rather than a hard-coded list).
    for sp in ratios:
        df[df['split'] == sp].to_csv(os.path.join(outdir, f'{dname}_{sp}.csv'), index=False)

# Keep the 'dataset' column: the old column selection dropped it, so rows from the
# two datasets could not be told apart in scaffold_assignments.csv.
sa_cols = ['dataset', 'scaffold', 'split', 'n', 'dominant_family']
sa = pd.concat([
    scaf_f.assign(dataset='finetuning')[sa_cols],
    scaf_p.assign(dataset='pretraining')[sa_cols],
], ignore_index=True)
sa.to_csv(os.path.join(outdir, 'scaffold_assignments.csv'), index=False)

fs_f = family_summary(df_fine, fams, M_f).assign(dataset='finetuning')
fs_p = family_summary(df_pre, fams, M_p).assign(dataset='pretraining')
pd.concat([fs_f, fs_p], ignore_index=True).to_csv(os.path.join(outdir, 'family_coverage_summary.csv'), index=False)

# Scaffolds present in both datasets, with the split each dataset gave them.
overlap = pd.DataFrame({'scaffold': list(set(scaf_f['scaffold']) & set(scaf_p['scaffold']))})
overlap = overlap.merge(scaf_f[['scaffold', 'split']].rename(columns={'split': 'finetuning_split'}), on='scaffold', how='left')
overlap = overlap.merge(scaf_p[['scaffold', 'split']].rename(columns={'split': 'pretraining_split'}), on='scaffold', how='left')
overlap.to_csv(os.path.join(outdir, 'overlap_report.csv'), index=False)

# ---- Sanity checks ----
for dname, d in [('finetuning', df_fine), ('pretraining', df_pre)]:
    # Every compound must have a split, and each scaffold must sit in exactly one
    # split. (Comparing the *set* of per-scaffold counts to length 1 passed even
    # when every scaffold spanned two splits, and failed on an empty dataset.)
    assert d['split'].notna().all(), f'{dname}: some compounds have no split'
    assert d.groupby('__scaffold')['split'].nunique().eq(1).all(), f'{dname}: a scaffold spans multiple splits'
# Nothing held out by finetuning (val/test) may appear in pretraining train,
# checked at both scaffold and compound level.
assert not (forbid_scafs & set(scaf_p.loc[scaf_p['split'] == 'train', 'scaffold'])), \
    'finetuning val/test scaffold leaked into pretraining train'
assert not (forbid_smiles & set(df_pre.loc[df_pre['split'] == 'train', '__smiles'])), \
    'finetuning val/test compound leaked into pretraining train'
