In [None]:
import numpy as np
import pandas as pd
from rdkit import Chem
from rdkit.Chem.Scaffolds import MurckoScaffold
from rdkit.Chem import rdFingerprintGenerator
from sklearn.utils.extmath import randomized_svd
from sklearn.cluster import KMeans
from collections import defaultdict

# Read the two activity tables and the kinase-family annotation table
df_pretrain_orig = pd.read_csv('datasets/chembl_pretraining.csv')
df_finetune_orig = pd.read_csv('datasets/pkis2_finetuning.csv')
df_families = pd.read_csv('uniprotkb_chembl_ID_kinomescan_cleaned.csv')

# Locate the SMILES column of the finetuning table regardless of letter case
smiles_col_finetune = next(
    (name for name in df_finetune_orig.columns if name.lower() == 'smiles'),
    None,
)
if smiles_col_finetune is None:
    raise KeyError("No 'smiles' column found in pkis2_finetuning.csv")
if smiles_col_finetune != 'smiles':
    # Normalise the column name to lower-case 'smiles' (copy, then drop the old one)
    df_finetune_orig['smiles'] = df_finetune_orig[smiles_col_finetune]
    df_finetune_orig = df_finetune_orig.drop(columns=[smiles_col_finetune])

# Map each KINOMEscan gene symbol to its kinase family
protein_to_family = {
    gene: family
    for gene, family in zip(df_families['KINOMEscan® Gene Symbol'], df_families['Family'])
}

# Every column other than 'smiles' is treated as a protein (target) column
pretrain_proteins = [name for name in df_pretrain_orig.columns if name != 'smiles']
finetune_proteins = [name for name in df_finetune_orig.columns if name != 'smiles']
all_proteins = list({*pretrain_proteins, *finetune_proteins})

# Canonicalize SMILES
def canonicalize_smiles(s):
    """Return the non-isomeric canonical SMILES for ``s``.

    Returns ``None`` when ``s`` is missing (``pd.isna``) or RDKit cannot
    parse it. Stereochemistry is dropped (``isomericSmiles=False``).
    """
    mol = None if pd.isna(s) else Chem.MolFromSmiles(s)
    if mol is None:
        return None
    return Chem.MolToSmiles(mol, isomericSmiles=False)

# Compute scaffolds
def bm_scaf(s):
    """Return the non-isomeric SMILES of the Murcko scaffold of ``s``.

    Returns ``None`` when ``s`` is missing, cannot be parsed by RDKit, or
    yields an empty scaffold (zero atoms).
    """
    mol = None if pd.isna(s) else Chem.MolFromSmiles(s)
    if mol is None:
        return None
    scaffold = MurckoScaffold.GetScaffoldForMol(mol)
    if scaffold is not None and scaffold.GetNumAtoms() > 0:
        return Chem.MolToSmiles(scaffold, isomericSmiles=False)
    return None

# Function to perform splitting with spectral clustering
def split_dataset(df, proteins, protein_to_family, forced_assignments=None):
    """Assign whole scaffold groups of ``df`` to train (0), val (1) or test (2).

    All rows sharing a ``scaffold_smiles`` value land in the same split.
    Scaffolds are placed greedily so that kinase families with few
    measurements are represented in every split, with an 80/10/10 target on
    the number of scaffolds per split.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns ``scaffold_smiles``, ``canon_smiles`` and one
        column per entry of ``proteins`` (NaN meaning "not measured").
        The frame is modified in place: columns ``scaf_id`` and ``split`` are
        added.
    proteins : list of str
        Protein (target) column names to consider.
    protein_to_family : dict
        Maps a protein name to its family; proteins missing from the mapping
        still count for protein coverage but not for family coverage.
    forced_assignments : collection of str, optional
        Scaffold SMILES that must be placed in the train split.

    Returns
    -------
    tuple
        ``(df, split_scaffolds, protein_coverage, protein_measurements)``:
        ``df`` with the ``split`` column (rows without a scaffold keep ``-1``);
        ``split_scaffolds`` mapping ``'train'``/``'val'``/``'test'`` to a set of
        scaffold SMILES; ``protein_coverage`` mapping each protein to a
        three-element list of booleans (measured in train/val/test);
        ``protein_measurements``, a ``defaultdict(int)`` with the number of
        non-null values per protein among rows that have a scaffold.
    """
    # Encode scaffolds as integer ids; rows with a missing scaffold get id -1
    # and are excluded from every scaffold group below.
    scaf_id, scaf_vocab = pd.factorize(df['scaffold_smiles'], sort=True)
    df['scaf_id'] = scaf_id

    grouped = df.groupby('scaf_id').indices
    scaff_to_idx = {k: np.array(v, dtype=int) for k, v in grouped.items() if k >= 0}
    scaffold_ids = list(scaff_to_idx)

    # Cluster one representative molecule per scaffold; the label is only used
    # as the last tie-breaker when ordering scaffolds for assignment.
    rep_idx = [rows[0] for rows in scaff_to_idx.values()]
    cluster_labels = _cluster_scaffolds(df['canon_smiles'], rep_idx)
    cluster_of = dict(zip(scaffold_ids, cluster_labels))

    # Per-scaffold family measurement counts and per-protein bookkeeping,
    # computed from a single boolean "is measured" matrix instead of a
    # row-by-row Python loop.
    all_families = set(protein_to_family.values())
    measured = df[proteins].notna().to_numpy()
    family_counts = {}
    protein_to_scaffolds = defaultdict(set)
    protein_measurements = defaultdict(int)

    for k, idx in scaff_to_idx.items():
        counts_for_k = family_counts[k] = defaultdict(int)
        per_protein = measured[idx].sum(axis=0)
        for j in np.flatnonzero(per_protein):
            prot = proteins[j]
            n_meas = int(per_protein[j])
            protein_to_scaffolds[prot].add(k)
            protein_measurements[prot] += n_meas
            if prot in protein_to_family:
                counts_for_k[protein_to_family[prot]] += n_meas

    total_per_family = defaultdict(int)
    for counts_for_k in family_counts.values():
        for family, count in counts_for_k.items():
            total_per_family[family] += count

    rare_families = {f for f, total in total_per_family.items() if total < 50}
    critical_families = {f for f, total in total_per_family.items() if total < 20}

    # Target number of scaffolds per split; train absorbs the rounding remainder.
    n_scaff = len(scaff_to_idx)
    ratios = np.array([0.8, 0.1, 0.1])
    targets = (ratios * n_scaff).astype(int)
    targets[0] = n_scaff - targets[1] - targets[2]

    assign = {k: None for k in scaffold_ids}
    split_sizes = np.zeros(3, dtype=int)
    # family_cov[f][s] = number of scaffolds in split s with a measurement for family f
    family_cov = {f: np.zeros(3, dtype=int) for f in all_families}

    def place(k, s):
        """Record scaffold ``k`` in split ``s`` and update size/coverage counters."""
        assign[k] = s
        split_sizes[s] += 1
        for f, c in family_counts[k].items():
            if c > 0:
                family_cov[f][s] += 1

    # Forced scaffolds always go to train.
    if forced_assignments:
        for k in scaffold_ids:
            if scaf_vocab[k] in forced_assignments:
                place(k, 0)

    # Scaffolds touching a critical family are placed first, only into splits
    # that are still below their target size.
    critical_scaffolds = [
        k for k in scaffold_ids
        if assign[k] is None and any(f in critical_families for f in family_counts[k])
    ]

    for k in critical_scaffolds:
        best_split = None
        best_score = -1000

        for s in range(3):
            if split_sizes[s] >= targets[s]:
                continue
            score = 0
            for f in family_counts[k]:
                if f in critical_families:
                    if family_cov[f][s] == 0:
                        score += 100
                    elif family_cov[f][s] < 2:
                        score += 50
                elif f in rare_families:
                    if family_cov[f][s] == 0:
                        score += 10

            if score > best_score:
                best_score = score
                best_split = s

        # If every split is already full the scaffold stays unassigned and is
        # handled by the general pass below.
        if best_split is not None:
            place(k, best_split)

    # Order the remaining scaffolds: rarest family first, then most families,
    # then cluster label.
    def priority(k):
        fams = family_counts[k]
        if fams:
            rarest = min(total_per_family[f] for f, c in fams.items() if c > 0)
        else:
            rarest = 10**9
        return (rarest, -len(fams), cluster_of[k])

    remaining = [k for k in scaffold_ids if assign[k] is None]
    order = sorted(remaining, key=priority)

    for k in order:
        gains = []
        for s in range(3):
            g = 0
            for f, c in family_counts[k].items():
                if c > 0:
                    if f in critical_families and family_cov[f][s] == 0:
                        g += 1000
                    elif f in rare_families and family_cov[f][s] == 0:
                        g += 100
                    elif family_cov[f][s] == 0:
                        g += 10
                    elif family_cov[f][s] < 3:
                        g += 1

            # Coverage gain ranks first, then "not over target", then the
            # smaller split; the split index itself breaks remaining ties.
            over = 100 if split_sizes[s] >= targets[s] else 0
            gains.append((g, -over, -split_sizes[s], s))

        place(k, max(gains)[-1])

    # Which splits contain at least one scaffold with a measurement per protein.
    protein_coverage = {}
    for prot in proteins:
        split_coverage = [False, False, False]
        for sid in protein_to_scaffolds.get(prot, ()):
            if assign.get(sid) is not None:
                split_coverage[assign[sid]] = True
        protein_coverage[prot] = split_coverage

    # Propagate scaffold assignments to molecules; scaffold-less rows keep -1.
    mol_split = np.full(len(df), -1, dtype=int)
    for k, s in assign.items():
        mol_split[scaff_to_idx[k]] = s

    df['split'] = mol_split

    split_scaffolds = {
        'train': set(scaf_vocab[k] for k, s in assign.items() if s == 0 and k >= 0),
        'val': set(scaf_vocab[k] for k, s in assign.items() if s == 1 and k >= 0),
        'test': set(scaf_vocab[k] for k, s in assign.items() if s == 2 and k >= 0)
    }

    return df, split_scaffolds, protein_coverage, protein_measurements


def _jaccard_matrix(A, B, chunk=2048):
    """Return the float32 matrix of Jaccard similarities between rows of boolean arrays ``A`` and ``B``."""
    # uint16 products are safe as long as a row has fewer than 65536 on-bits.
    B_t = B.astype(np.uint16).T
    b_bits = B.sum(axis=1).astype(np.int32)
    n_rows = A.shape[0]
    out = np.empty((n_rows, B.shape[0]), dtype=np.float32)
    for start in range(0, n_rows, chunk):
        block = A[start:start + chunk]
        inter = block.astype(np.uint16) @ B_t
        a_bits = block.sum(axis=1).astype(np.int32)[:, None]
        # Clamp to 1 so two all-zero fingerprints give similarity 0, not NaN.
        union = np.maximum(a_bits + b_bits[None, :] - inter, 1)
        out[start:start + chunk] = inter / union
    return out


def _cluster_scaffolds(canon_smiles, rep_idx, n_clusters=3, n_landmarks=1024, fp_size=2048):
    """Cluster scaffold representatives with a landmark (Nystrom-style) spectral embedding.

    ``canon_smiles`` is the SMILES column and ``rep_idx`` the row positions of
    one representative per scaffold. Returns one integer label per entry of
    ``rep_idx``. With fewer representatives than ``n_clusters`` KMeans cannot
    run, so every representative gets label 0.
    """
    n = len(rep_idx)
    if n < n_clusters:
        return np.zeros(n, dtype=int)

    # Morgan fingerprints (radius 3) of the representatives only.
    gen = rdFingerprintGenerator.GetMorganGenerator(radius=3, fpSize=fp_size)
    X = np.zeros((n, fp_size), dtype=bool)
    for r, i in enumerate(rep_idx):
        mol = Chem.MolFromSmiles(canon_smiles.iloc[i])
        if mol is not None:
            on = list(gen.GetFingerprint(mol).GetOnBits())
            if on:
                X[r, on] = True

    # Similarities to a random landmark subset (C) and among landmarks (W).
    rng = np.random.default_rng(0)
    m = min(n_landmarks, n)
    landmarks = rng.choice(n, size=m, replace=False)
    B = X[landmarks]
    C = _jaccard_matrix(X, B, chunk=2048).astype(np.float32)
    W = _jaccard_matrix(B, B, chunk=1024).astype(np.float32)

    # Small ridge keeps the eigenvalues away from zero before inverting.
    eps = 1e-6
    evals, evecs = np.linalg.eigh(W + eps * np.eye(m, dtype=np.float32))
    Winv = (evecs / evals) @ evecs.T
    Winvsqrt = (evecs / np.sqrt(evals)) @ evecs.T

    # Approximate row sums of the full similarity matrix, then normalise by them.
    u = C.sum(axis=0)
    d = C @ (Winv @ u)
    scale = (1.0 / np.sqrt(d + 1e-12)).astype(np.float32)
    E = (C * scale[:, None]) @ Winvsqrt

    U, _, _ = randomized_svd(E, n_components=n_clusters, random_state=0)
    Y = U / (np.linalg.norm(U, axis=1, keepdims=True) + 1e-12)
    return KMeans(n_clusters=n_clusters, n_init=10, random_state=0).fit_predict(Y)

# Optimization loop to achieve 100% coverage
def optimize_protein_coverage(df_pretrain_orig, df_finetune_orig, pretrain_proteins, finetune_proteins,
                            protein_to_family, target_coverage=1.0, max_iterations=50):
    """Split both datasets repeatedly, dropping proteins until coverage is reached.

    Each iteration splits the pretraining set, forces every finetuning
    scaffold that is in the pretraining train split into the finetuning train
    split, and then checks which proteins lack a measurement in at least one
    split. The problematic protein with the fewest measurements is removed and
    the process repeats.

    Parameters
    ----------
    df_pretrain_orig, df_finetune_orig : pandas.DataFrame
        Source tables with a ``smiles`` column and one column per protein.
        They are not modified.
    pretrain_proteins, finetune_proteins : list of str
        Protein columns of the respective table.
    protein_to_family : dict
        Protein name to family mapping passed through to ``split_dataset``.
    target_coverage : float
        Fraction of proteins that must be present in all three splits.
    max_iterations : int
        Upper bound on the number of split attempts (at least 1).

    Returns
    -------
    tuple
        ``(df_pretrain_split, df_finetune_split, proteins_to_remove,
        current_pretrain_proteins, current_finetune_proteins)``. In the two
        split frames ``original_idx`` holds the row label of the
        corresponding row in the *input* frame, so callers can select rows
        from the unfiltered tables even when unparsable SMILES were dropped.
        The removed set and the two protein lists always describe the same
        (final) split.

    Raises
    ------
    ValueError
        If ``max_iterations`` is smaller than 1.
    """
    if max_iterations < 1:
        raise ValueError("max_iterations must be at least 1")

    df_pretrain, pretrain_labels = _prepare_for_split(df_pretrain_orig)
    df_finetune, finetune_labels = _prepare_for_split(df_finetune_orig)

    # The finetuning scaffolds do not change between iterations.
    finetune_scaffolds = set(df_finetune['scaffold_smiles'].dropna())

    proteins_to_remove = set()

    for iteration in range(1, max_iterations + 1):
        current_pretrain_proteins = [p for p in pretrain_proteins if p not in proteins_to_remove]
        current_finetune_proteins = [p for p in finetune_proteins if p not in proteins_to_remove]

        # Split pretraining
        df_pretrain_split, pretrain_scaffolds, pretrain_coverage, pretrain_measurements = split_dataset(
            df_pretrain.copy(), current_pretrain_proteins, protein_to_family
        )

        # Scaffolds seen in pretraining-train must not end up in finetuning val/test.
        critical_scaffolds = pretrain_scaffolds['train'] & finetune_scaffolds

        # Split finetuning
        df_finetune_split, finetune_scaffolds_split, finetune_coverage, finetune_measurements = split_dataset(
            df_finetune.copy(), current_finetune_proteins, protein_to_family,
            forced_assignments=critical_scaffolds
        )

        # (protein, number of measurements, number of splits without it)
        pretrain_problematic = [
            (prot, pretrain_measurements[prot], sum(not c for c in cov))
            for prot, cov in pretrain_coverage.items() if not all(cov)
        ]
        finetune_problematic = [
            (prot, finetune_measurements[prot], sum(not c for c in cov))
            for prot, cov in finetune_coverage.items() if not all(cov)
        ]

        total_proteins = len(current_pretrain_proteins) + len(current_finetune_proteins)
        n_problematic = len(pretrain_problematic) + len(finetune_problematic)
        coverage_rate = 1.0 - (n_problematic / total_proteins) if total_proteins > 0 else 1.0

        print(f"Iteration {iteration}: Coverage={coverage_rate:.3f}, Problematic={n_problematic}")

        if coverage_rate >= target_coverage or n_problematic == 0:
            break

        if iteration == max_iterations:
            # No further split follows, so removing another protein here would
            # make the returned removed set disagree with the returned lists.
            print(f"  Stopping after {max_iterations} iterations; target coverage not reached")
            break

        # Remove the problematic protein with the fewest measurements; ties go
        # to the one missing from more splits, then to pretraining order.
        candidates = pretrain_problematic + finetune_problematic
        protein_to_remove, n_meas, _ = min(candidates, key=lambda x: (x[1], -x[2]))
        proteins_to_remove.add(protein_to_remove)
        print(f"  Removing {protein_to_remove} ({n_meas} measurements)")

    # During the loop 'original_idx' is the row position in the filtered frame;
    # translate it back to the row label of the unfiltered input before returning.
    df_pretrain_split['original_idx'] = pretrain_labels[df_pretrain_split['original_idx'].to_numpy()]
    df_finetune_split['original_idx'] = finetune_labels[df_finetune_split['original_idx'].to_numpy()]

    return df_pretrain_split, df_finetune_split, proteins_to_remove, current_pretrain_proteins, current_finetune_proteins


def _prepare_for_split(df_orig):
    """Return a filtered copy of ``df_orig`` ready for splitting, plus the source row labels it kept."""
    df = df_orig.copy()
    df['canon_smiles'] = df['smiles'].apply(canonicalize_smiles)
    df = df[df['canon_smiles'].notna()]
    # Remember which input rows survived before the index is renumbered.
    source_labels = df.index.to_numpy()
    df = df.reset_index(drop=True)
    df['original_idx'] = np.arange(len(df))
    df['scaffold_smiles'] = df['canon_smiles'].map(bm_scaf)
    return df, source_labels

# Run optimization
print("Starting optimization for 100% protein coverage...")
(
    df_pretrain_split,
    df_finetune_split,
    removed_proteins,
    final_pretrain_proteins,
    final_finetune_proteins,
) = optimize_protein_coverage(
    df_pretrain_orig,
    df_finetune_orig,
    pretrain_proteins,
    finetune_proteins,
    protein_to_family,
    target_coverage=1.0,
)

# Write one CSV per pretraining split, selecting rows of the source table by 'original_idx'
for split_val, name in enumerate(['train', 'val', 'test']):
    orig_indices = df_pretrain_split.loc[df_pretrain_split['split'] == split_val, 'original_idx'].values
    mask = df_pretrain_orig.index.isin(orig_indices)
    df_out = df_pretrain_orig.loc[mask, ['smiles'] + final_pretrain_proteins]
    df_out.to_csv(f'datasets/chembl_pretraining_{name}.csv', index=False)

# Same for the finetuning splits
for split_val, name in enumerate(['train', 'val', 'test']):
    orig_indices = df_finetune_split.loc[df_finetune_split['split'] == split_val, 'original_idx'].values
    mask = df_finetune_orig.index.isin(orig_indices)
    df_out = df_finetune_orig.loc[mask, ['smiles'] + final_finetune_proteins]
    df_out.to_csv(f'datasets/pkis2_finetuning_{name}.csv', index=False)

# Per-molecule split assignments
assignment_columns = ['smiles', 'scaffold_smiles', 'split']
df_pretrain_split[assignment_columns].to_csv('split_pretrain.csv', index=False)
df_finetune_split[assignment_columns].to_csv('split_finetune.csv', index=False)

# Scaffold sets per split, used for the leakage check below
split_codes = {'train': 0, 'val': 1, 'test': 2}

pretrain_scaffolds_final = {
    name: set(df_pretrain_split.loc[df_pretrain_split['split'] == code, 'scaffold_smiles'].dropna())
    for name, code in split_codes.items()
}

finetune_scaffolds_final = {
    name: set(df_finetune_split.loc[df_finetune_split['split'] == code, 'scaffold_smiles'].dropna())
    for name, code in split_codes.items()
}

overlap_val = pretrain_scaffolds_final['train'] & finetune_scaffolds_final['val']
overlap_test = pretrain_scaffolds_final['train'] & finetune_scaffolds_final['test']

# Report
print("\n=== OPTIMIZATION COMPLETE ===")
print(f"Removed {len(removed_proteins)} proteins to achieve 100% coverage:")
for prot in sorted(removed_proteins):
    print(f"  - {prot}")

pretrain_sizes = [(df_pretrain_split['split'] == code).sum() for code in range(3)]
finetune_sizes = [(df_finetune_split['split'] == code).sum() for code in range(3)]

print("\nFinal splits:")
print(f"  Pretraining: {pretrain_sizes[0]}/{pretrain_sizes[1]}/{pretrain_sizes[2]}")
print(f"  Finetuning: {finetune_sizes[0]}/{finetune_sizes[1]}/{finetune_sizes[2]}")

print("\nOrthogonality check:")
print(f"  pretrain_train ∩ finetune_val: {len(overlap_val)} scaffolds")
print(f"  pretrain_train ∩ finetune_test: {len(overlap_test)} scaffolds")

print("\nRemaining proteins:")
print(f"  Pretraining: {len(final_pretrain_proteins)}/{len(pretrain_proteins)} proteins")
print(f"  Finetuning: {len(final_finetune_proteins)}/{len(finetune_proteins)} proteins")