# In [None]:

from rdkit import Chem
from rdkit import Chem
import os
import pandas as pd
import os
from rdkit import Chem
from rdkit.Chem import Descriptors


def calculate_descriptors(smiles_list):
    """Compute a fixed panel of RDKit descriptors for each SMILES string.

    Args:
        smiles_list (list): SMILES strings, one per molecule.

    Returns:
        pd.DataFrame: One row per input SMILES (in input order) and one
        column per descriptor name. When RDKit does not return a molecule
        for a SMILES, a row of ``None`` placeholders is stored instead.
    """

    descriptor_names = [
        'MolLogP', 'MolMR', 'ExactMolWt', 'HeavyAtomCount', 'NumHAcceptors', 'NumHDonors',
        'NumHeteroatoms', 'NumRotatableBonds', 'NumAromaticRings', 'NumAliphaticRings',
        'RingCount', 'TPSA', 'LabuteASA', 'Kappa1', 'Kappa2', 'Kappa3',
        'Chi0', 'Chi1', 'Chi0n', 'Chi1n', 'Chi2n', 'Chi3n', 'Chi4n',
        'Chi0v', 'Chi1v', 'Chi2v', 'Chi3v', 'Chi4v',
        'PEOE_VSA1', 'PEOE_VSA2', 'PEOE_VSA3', 'PEOE_VSA4', 'PEOE_VSA5',
        'PEOE_VSA6', 'PEOE_VSA7', 'PEOE_VSA8', 'PEOE_VSA9', 'PEOE_VSA10',
        'PEOE_VSA11', 'PEOE_VSA12', 'PEOE_VSA13', 'PEOE_VSA14',
        'SMR_VSA1', 'SMR_VSA3', 'SMR_VSA4', 'SMR_VSA5', 'SMR_VSA6',
        'SMR_VSA7', 'SMR_VSA9', 'SMR_VSA10',
        'SlogP_VSA1', 'SlogP_VSA2', 'SlogP_VSA3', 'SlogP_VSA4', 'SlogP_VSA5',
        'SlogP_VSA6', 'SlogP_VSA7', 'SlogP_VSA8', 'SlogP_VSA10', 'SlogP_VSA11',
        'SlogP_VSA12',
        'EState_VSA1', 'EState_VSA2', 'EState_VSA3', 'EState_VSA4', 'EState_VSA5',
        'EState_VSA6', 'EState_VSA7', 'EState_VSA8', 'EState_VSA9', 'EState_VSA10',
        'VSA_EState1', 'VSA_EState2', 'VSA_EState3', 'VSA_EState4', 'VSA_EState5',
        'VSA_EState6', 'VSA_EState7', 'VSA_EState8', 'VSA_EState9', 'VSA_EState10'
    ]

    width = len(descriptor_names)
    rows = []
    for smi in smiles_list:
        parsed = Chem.MolFromSmiles(smi)
        if not parsed:
            # Keep a placeholder row so the output stays aligned with the input.
            rows.append([None] * width)
            continue
        # Each name is looked up as a function on rdkit.Chem.Descriptors.
        rows.append([getattr(Descriptors, name)(parsed) for name in descriptor_names])

    return pd.DataFrame(rows, columns=descriptor_names)

def fix_pdb_residue_names(pdb_block, resname="LIG"):
    """Overwrite the residue-name field of every atom record in a PDB block.

    Lines starting with ``HETATM`` or ``ATOM  `` get string positions 17-19
    (PDB columns 18-20) replaced; all other lines are passed through.

    Args:
        pdb_block (str): PDB-formatted text.
        resname (str): Residue name to write. It is left-justified and then
            padded or truncated to exactly three characters, so the
            fixed-width fields that follow it are never shifted.

    Returns:
        str: The rewritten block, lines joined by newline characters
        (no trailing newline is added).
    """
    # Clamp to the 3-character field width; a longer name would otherwise
    # lengthen the line and push every later column out of position.
    field = f"{resname:<3}"[:3]

    fixed_lines = []
    for line in pdb_block.splitlines():
        if line.startswith(("HETATM", "ATOM  ")):
            line = f"{line[:17]}{field}{line[20:]}"
        fixed_lines.append(line)
    return "\n".join(fixed_lines)

from rdkit import Chem
from rdkit.Chem import AllChem

def combine_protein_ligand(protein_pdb_path, ligand_sdf_path, output_pdb_path):
    """Write one PDB file holding the protein records followed by the ligand.

    The ligand is read from a MOL/SDF file (hydrogens kept), converted to
    PDB records in memory, and its atom records are renamed to residue
    ``LIG``. Protein lines starting with ``END`` are dropped and a single
    ``END`` record is written at the very end of the output.

    Args:
        protein_pdb_path (str): Path to the protein PDB file.
        ligand_sdf_path (str): Path to the ligand MOL/SDF file.
        output_pdb_path (str): Destination path for the combined PDB.

    Returns:
        str: ``output_pdb_path``.

    Raises:
        ValueError: If RDKit cannot read the ligand file.
    """
    mol = Chem.MolFromMolFile(ligand_sdf_path, removeHs=False)
    if mol is None:
        raise ValueError(f"Failed to read ligand SDF : {ligand_sdf_path}")

    # Convert in memory instead of going through a fixed "temp_ligand.pdb"
    # in the working directory: that shared file was never removed and would
    # be overwritten by concurrent calls.
    ligand_lines = []
    for line in Chem.MolToPDBBlock(mol).splitlines():
        if line.startswith("END"):
            # The single terminating END record is written below.
            continue
        if line.startswith(("HETATM", "ATOM")):
            line = line[:17] + "LIG" + line[20:]  # Replace residue name
        ligand_lines.append(line + "\n")

    # Read protein PDB (exclude END line if present)
    with open(protein_pdb_path, "r") as f:
        protein_lines = [line for line in f if not line.startswith("END")]
    if protein_lines and not protein_lines[-1].endswith("\n"):
        # Keep the first ligand record from being glued onto the last
        # protein record when the protein file lacks a final newline.
        protein_lines[-1] += "\n"

    # Write combined PDB
    with open(output_pdb_path, "w") as out:
        out.writelines(protein_lines)
        out.writelines(ligand_lines)
        out.write("END\n")

    return output_pdb_path

import os
import pandas as pd
from tqdm import tqdm
from multiprocessing import Pool, cpu_count

# Ensure these are defined or imported:
# combine_protein_ligand(), calculate_descriptors()

import os
import pandas as pd
import time

import os
import pandas as pd
import time
import numpy as np
from rdkit import Chem
from rdkit.Chem import Descriptors

def _process_complex(idx_row):
    """Build one complex, run dpocket on it, and save merged descriptors.

    Reads the module-level names ``complex_dir``, ``descriptor_dir`` and
    ``lig_code`` (the first two are expected to be defined elsewhere; they
    are not visible next to this function).

    NOTE: the dpocket input/output file names are fixed and relative to the
    current working directory, so concurrent calls sharing a working
    directory would overwrite each other's files.

    Args:
        idx_row (tuple): ``(idx, row)`` where ``row`` provides the keys
            ``'smiles'``, ``'standardized_protein_pdb'`` and
            ``'standardized_ligand_sdf'``.

    Returns:
        tuple: ``(complex_path, desc_path, merged)`` where ``merged`` is the
        DataFrame written to ``desc_path``. On any failure the pocket part
        of ``merged`` is a single row of NaN placeholder columns.
    """
    import shutil
    import subprocess

    idx, row = idx_row
    smiles = row['smiles']
    protein_pdb_path = row['standardized_protein_pdb']
    ligand_sdf_path = row['standardized_ligand_sdf']

    complex_path = os.path.join(complex_dir, f"{idx}.pdb")
    desc_path = os.path.join(descriptor_dir, f"{idx}.csv")
    dpocket_output_dir = "./dpout"
    dpocket_output_file = "dpout_explicitp.txt"
    dpocket_input_file = "dp_input.txt"

    try:
        # 1. Combine protein and ligand into one PDB
        complex_file = combine_protein_ligand(protein_pdb_path, ligand_sdf_path, complex_path)

        # 2. Prepare input and run dpocket
        with open(dpocket_input_file, "w") as f:
            f.write(f"{complex_file}\t{lig_code}\n")

        # Remove and recreate dpocket output dir
        if os.path.exists(dpocket_output_dir):
            shutil.rmtree(dpocket_output_dir)
        os.makedirs(dpocket_output_dir)

        # Drop any output left by a previous run: if dpocket fails now, the
        # read below must fail too instead of silently picking up another
        # complex's descriptors.
        if os.path.exists(dpocket_output_file):
            os.remove(dpocket_output_file)

        # Run dpocket without a shell so the input path is passed verbatim.
        subprocess.run(["dpocket", "-f", dpocket_input_file], check=False)

        # 4. Read dpocket output
        pocket_df = pd.read_csv(dpocket_output_file, sep=r'\s+')
        pocket_df['pdb'] = pocket_df['pdb'].str.replace('.pdb', '', regex=False)

        # 5. Compute ligand descriptors
        ligand_df = calculate_descriptors([smiles])

        # 6. Merge and save
        merged = pd.concat([ligand_df, pocket_df], axis=1)
        merged.to_csv(desc_path, index=False)
        return complex_file, desc_path, merged

    except Exception as e:
        # Log error index
        tqdm.write(f"[ERROR] idx={idx}: {str(e)}")

        # Fill everything with NaNs to match expected output shape
        ligand_df = calculate_descriptors([smiles])

        # Estimate dpocket column count from any other file (or hardcode if known)
        # NOTE(review): these placeholder column names (pocket_0..pocket_40)
        # are not dpocket's own column names, so fallback rows will not line
        # up by name with successful rows -- confirm downstream handling.
        num_pocket_cols = 41  # Replace with real count if exact
        pocket_cols = [f'pocket_{i}' for i in range(num_pocket_cols)]
        pocket_nan_df = pd.DataFrame([[np.nan]*num_pocket_cols], columns=pocket_cols)

        merged = pd.concat([ligand_df, pocket_nan_df], axis=1)
        merged.to_csv(desc_path, index=False)

        return complex_path, desc_path, merged

def generate_all_complexes(df, complex_dir):
    """Write one combined protein-ligand PDB per row of ``df``.

    Args:
        df (pd.DataFrame): Rows provide ``'standardized_protein_pdb'`` and
            ``'standardized_ligand_sdf'``; the row index names the output file.
        complex_dir (str): Output directory, created if missing.

    Returns:
        list: Paths of the written complex files, in row order.
    """
    os.makedirs(complex_dir, exist_ok=True)

    progress = tqdm(df.iterrows(), total=len(df), desc="Generating complexes", dynamic_ncols=True)
    written = []
    for row_id, record in progress:
        target = os.path.join(complex_dir, f"{row_id}.pdb")
        combine_protein_ligand(
            record['standardized_protein_pdb'],
            record['standardized_ligand_sdf'],
            target,
        )
        written.append(target)

    return written

def write_dpocket_input(complex_paths, lig_code="LIG", output_file="dp_input.txt"):
    """Write a dpocket input file: one ``<path><TAB><lig_code>`` line per path.

    Falsy entries in ``complex_paths`` (empty string, ``None``) are skipped.
    """
    with open(output_file, "w") as handle:
        handle.writelines(
            f"{entry}\t{lig_code}\n" for entry in complex_paths if entry
        )

def run_dpocket_batch(dp_input_file="dp_input.txt", output_dir="./"):
    """Run the ``dpocket`` executable on a dpocket input file.

    The command is started without a shell, so ``dp_input_file`` is passed
    to dpocket as a single argument even if it contains spaces or shell
    metacharacters.

    Args:
        dp_input_file (str): Path passed as ``dpocket -f <dp_input_file>``.
        output_dir (str): Unused by this function; kept so existing callers
            that pass it keep working.

    Returns:
        None. dpocket's exit status is not checked.
    """
    import subprocess

    try:
        subprocess.run(["dpocket", "-f", dp_input_file], check=False)
    except OSError as exc:
        # e.g. dpocket is not installed / not on PATH. Report and carry on,
        # mirroring the previous behaviour of not raising in that case.
        print(f"[ERROR] could not run dpocket: {exc}")

    
def parse_dpocket_outputs(df, dpocket_output_file="dpout_explicitp.txt"):
    try:
        pocket_df = pd.read_csv(dpocket_output_file, sep='\s+')
    except Exception as e:
        tqdm.write(f"[ERROR] couldn't read {dpocket_output_file}: {str(e)}")
        pocket_df = pd.DataFrame()

    # Normalize "pdb" column to match index
    pocket_df["complex_id"] = pocket_df["pdb"].apply(lambda x: os.path.basename(x).replace(".pdb", ""))
    pocket_df = pocket_df.drop(columns=["pdb"])

    # Convert to int if possible
    try:
        pocket_df["complex_id"] = pocket_df["complex_id"].astype(int)
    except:
        tqdm.write("[WARN] complex_id is not integer-based")

    # Match back to df
    pocket_df = pocket_df.set_index("complex_id")
    pocket_df = pocket_df.loc[df.index.intersection(pocket_df.index)]  # align with your df

    return pocket_df.reset_index()

def compute_all_ligand_descriptors(df):
    """Return ligand descriptors for every entry of ``df['smiles']``."""
    smiles_values = df["smiles"].tolist()
    return calculate_descriptors(smiles_values)


def merge_all_outputs(ligand_df, pocket_dfs):
    """Pair each ligand row with its pocket frame and stack the results.

    Args:
        ligand_df (pd.DataFrame): One row of ligand descriptors per complex.
        pocket_dfs (iterable): Pocket DataFrames, in the same order as the
            rows of ``ligand_df``. Pairing stops at the shorter of the two.

    Returns:
        pd.DataFrame: Ligand and pocket columns side by side, re-indexed
        from zero.
    """
    def _side_by_side(lig_row, pocket):
        # Both parts are re-indexed from zero so they align row-for-row.
        left = pd.DataFrame([lig_row]).reset_index(drop=True)
        right = pocket.reset_index(drop=True)
        return pd.concat([left, right], axis=1)

    pieces = [
        _side_by_side(lig_row, pocket)
        for (_, lig_row), pocket in zip(ligand_df.iterrows(), pocket_dfs)
    ]
    return pd.concat(pieces, ignore_index=True)

    
    
# Residue name written into dpocket input files for the ligand.
lig_code = "LIG"

# Column names grouped for dropping (first group, kept as a list).
columns_to_drop_set1 = [
    "pdb",
    "lig",
    "overlap",
    "PP-crit",
    "PP-dst",
    "crit4",
    "crit5",
    "crit6",
    "crit6_continue",
    "nb_AS_norm",
    "apol_as_prop_norm",
    "mean_loc_hyd_dens_norm",
    "polarity_score_norm",
    "as_density_norm",
    "as_max_dst_norm",
    "drug_score",
]

# Column names grouped for dropping (second group, kept as a set).
columns_to_drop_set2 = {
    "pock_vol", "nb_AS", "mean_as_ray", "mean_as_solv_acc",
    "apol_as_prop", "mean_loc_hyd_dens", "hydrophobicity_score",
    "volume_score", "polarity_score", "charge_score", "flex",
    "prop_polar_atm", "as_density", "as_max_dst",
    "convex_hull_volume",
    "surf_pol_vdw14", "surf_pol_vdw22", "surf_apol_vdw14", "surf_apol_vdw22",
    "n_abpa",
    "ALA", "ARG", "ASN", "ASP", "CYS", "GLN", "GLU", "GLY", "HIS", "ILE",
    "LEU", "LYS", "MET", "PHE", "PRO", "SER", "THR", "TRP", "TYR", "VAL",
    "pKd",
}


"""
iScore Method for Drug-Target Affinity Prediction with Cross-Validation
This implements the pocket + ligand descriptor approach with proper CV, training, and inference
"""

import os
import json
import time
import shutil
import warnings
import numpy as np
import pandas as pd
from pathlib import Path
from typing import Dict, List, Tuple, Optional
from tqdm import tqdm
from joblib import Parallel, delayed
import matplotlib.pyplot as plt

# Scientific computing
from sklearn.model_selection import KFold
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import r2_score, mean_squared_error
from sklearn.preprocessing import StandardScaler
import xgboost as xgb

# RDKit
from rdkit import Chem, RDLogger
from rdkit.Chem import Descriptors

# Suppress warnings
RDLogger.DisableLog('rdApp.*')
warnings.filterwarnings('ignore')

# ================== DESCRIPTOR CALCULATION ==================

class DescriptorCalculator:
    """Computes ligand/pocket descriptors and caches them as ``.npz`` files."""

    def __init__(self, cache_dir: str = "../data/descriptors_cache"):
        """Create the cache directory (if needed) and set the descriptor list.

        Args:
            cache_dir: Directory that holds the per-sample cache files.
        """
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        # Names of the functions looked up on rdkit.Chem.Descriptors.
        self.descriptor_names = [
            'MolLogP', 'MolMR', 'ExactMolWt', 'HeavyAtomCount', 'NumHAcceptors', 'NumHDonors',
            'NumHeteroatoms', 'NumRotatableBonds', 'NumAromaticRings', 'NumAliphaticRings',
            'RingCount', 'TPSA', 'LabuteASA', 'Kappa1', 'Kappa2', 'Kappa3',
            'Chi0', 'Chi1', 'Chi0n', 'Chi1n', 'Chi2n', 'Chi3n', 'Chi4n',
            'Chi0v', 'Chi1v', 'Chi2v', 'Chi3v', 'Chi4v',
            'PEOE_VSA1', 'PEOE_VSA2', 'PEOE_VSA3', 'PEOE_VSA4', 'PEOE_VSA5',
            'PEOE_VSA6', 'PEOE_VSA7', 'PEOE_VSA8', 'PEOE_VSA9', 'PEOE_VSA10',
            'PEOE_VSA11', 'PEOE_VSA12', 'PEOE_VSA13', 'PEOE_VSA14',
            'SMR_VSA1', 'SMR_VSA3', 'SMR_VSA4', 'SMR_VSA5', 'SMR_VSA6',
            'SMR_VSA7', 'SMR_VSA9', 'SMR_VSA10',
            'SlogP_VSA1', 'SlogP_VSA2', 'SlogP_VSA3', 'SlogP_VSA4', 'SlogP_VSA5',
            'SlogP_VSA6', 'SlogP_VSA7', 'SlogP_VSA8', 'SlogP_VSA10', 'SlogP_VSA11',
            'SlogP_VSA12',
            'EState_VSA1', 'EState_VSA2', 'EState_VSA3', 'EState_VSA4', 'EState_VSA5',
            'EState_VSA6', 'EState_VSA7', 'EState_VSA8', 'EState_VSA9', 'EState_VSA10',
            'VSA_EState1', 'VSA_EState2', 'VSA_EState3', 'VSA_EState4', 'VSA_EState5',
            'VSA_EState6', 'VSA_EState7', 'VSA_EState8', 'VSA_EState9', 'VSA_EState10'
        ]

    def calculate_ligand_descriptors(self, smiles: str) -> np.ndarray:
        """Calculate RDKit descriptors for a SMILES string.

        Returns:
            Array with one value per entry of ``self.descriptor_names``.
            A descriptor that raises is stored as NaN; if RDKit returns no
            molecule for the SMILES, the whole array is NaN.
        """
        mol = Chem.MolFromSmiles(smiles)
        if not mol:
            return np.full(len(self.descriptor_names), np.nan)

        desc_values = []
        for desc_name in self.descriptor_names:
            try:
                desc_values.append(getattr(Descriptors, desc_name)(mol))
            except Exception:
                # One failing descriptor must not discard the others, but
                # KeyboardInterrupt/SystemExit are no longer swallowed.
                desc_values.append(np.nan)
        return np.array(desc_values)

    def calculate_pocket_descriptors(self, complex_pdb_path: str, ligand_sdf_path: str,
                                   ligand_code: str = "LIG", dpocket_path: str = "dpocket") -> Dict:
        """Calculate pocket descriptors by running dpocket on a complex PDB.

        Args:
            complex_pdb_path: Combined protein-ligand PDB file.
            ligand_sdf_path: Not used by this method; kept for interface
                compatibility.
            ligand_code: Ligand residue name written to the dpocket input.
            dpocket_path: dpocket executable name or path.

        Returns:
            The first row of dpocket's ``dpout_explicitp.txt`` as a dict
            (without the ``pdb`` and ``lig`` columns), or an empty dict if
            dpocket failed or produced no rows.
        """
        import tempfile
        import subprocess

        # Create temporary directory for dpocket output
        with tempfile.TemporaryDirectory() as temp_dir:
            # dpocket is started with cwd=temp_dir, so a relative complex
            # path would be resolved against the wrong directory; write it
            # as an absolute path.
            input_file = os.path.join(temp_dir, "dpocket_input.txt")
            with open(input_file, 'w') as f:
                f.write(f"{os.path.abspath(complex_pdb_path)}\t{ligand_code}\n")

            try:
                subprocess.run(
                    [dpocket_path, "-f", input_file],
                    cwd=temp_dir,
                    capture_output=True,
                    text=True,
                    timeout=30
                )

                # The output is looked for in the working directory given
                # to dpocket above.
                output_file = os.path.join(temp_dir, "dpout_explicitp.txt")
                if os.path.exists(output_file):
                    pocket_df = pd.read_csv(output_file, sep=r'\s+')
                    # Return first row as dict (excluding pdb/lig columns)
                    if len(pocket_df) > 0:
                        return pocket_df.iloc[0].drop(['pdb', 'lig'], errors='ignore').to_dict()
            except Exception as e:
                # Best-effort boundary: report and fall through to {}.
                print(f"dpocket failed: {e}")

        # Return empty dict if failed
        return {}

    def combine_protein_ligand_pdb(self, protein_pdb_path: str, ligand_sdf_path: str,
                                   output_path: str, lig_code: str = "LIG") -> str:
        """Combine protein and ligand into a single PDB file.

        Args:
            protein_pdb_path: Protein PDB file; lines starting with ``END``
                are dropped.
            ligand_sdf_path: Ligand MOL/SDF file (hydrogens are kept).
            output_path: Destination of the combined PDB.
            lig_code: Residue name for the ligand atoms; padded or truncated
                to three characters to keep the PDB columns aligned.

        Returns:
            ``output_path``.

        Raises:
            ValueError: If RDKit cannot read the ligand file.
        """
        # Load ligand from SDF
        mol = Chem.MolFromMolFile(ligand_sdf_path, removeHs=False)
        if mol is None:
            raise ValueError(f"Failed to read ligand SDF: {ligand_sdf_path}")

        resname = f"{lig_code:<3}"[:3]

        # Convert ligand to PDB records and fix the residue name
        ligand_lines = []
        for line in Chem.MolToPDBBlock(mol).splitlines():
            if line.startswith("END"):
                # Only one END record is written, at the end of the file.
                continue
            if line.startswith(("HETATM", "ATOM")):
                line = line[:17] + resname + line[20:]
            ligand_lines.append(line + "\n")

        # Read protein PDB
        with open(protein_pdb_path, 'r') as f:
            protein_lines = [line for line in f if not line.startswith("END")]
        if protein_lines and not protein_lines[-1].endswith("\n"):
            # Avoid gluing the first ligand record onto the last protein line.
            protein_lines[-1] += "\n"

        # Write combined PDB
        with open(output_path, 'w') as f:
            f.writelines(protein_lines)
            f.writelines(ligand_lines)
            f.write("END\n")

        return output_path

    def get_cached_path(self, sample_id: str) -> Path:
        """Get path for cached descriptor file"""
        return self.cache_dir / f"{sample_id}_descriptors.npz"

    def save_descriptors(self, sample_id: str, ligand_desc: np.ndarray, pocket_desc: np.ndarray):
        """Save descriptors to cache (arrays stored as ``ligand`` / ``pocket``)."""
        cache_path = self.get_cached_path(sample_id)
        np.savez_compressed(cache_path, ligand=ligand_desc, pocket=pocket_desc)

    def load_descriptors(self, sample_id: str) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
        """Load descriptors from cache if available.

        Returns:
            ``(ligand, pocket)`` arrays, or ``(None, None)`` when no cache
            file exists for ``sample_id``.
        """
        cache_path = self.get_cached_path(sample_id)
        if not cache_path.exists():
            return None, None
        # Close the archive once both arrays have been read.
        with np.load(cache_path) as data:
            return data['ligand'], data['pocket']

# ================== DATASET PREPARATION ==================


# ================== MODEL TRAINING ==================

class iScoreModel:
    """Regression model (XGBoost or random forest) behind a StandardScaler."""

    def __init__(self, model_type: str = 'xgboost', **kwargs):
        """Build the scaler and the underlying regressor.

        Args:
            model_type: ``'xgboost'`` or ``'rf'``.
            **kwargs: Optional hyper-parameter overrides; only the keys read
                below are used, anything else is ignored.

        Raises:
            ValueError: If ``model_type`` is not a supported value.
        """
        # Fail here rather than with an AttributeError on the first
        # fit()/predict() call, when ``self.model`` would be missing.
        if model_type not in ('xgboost', 'rf'):
            raise ValueError(
                f"Unsupported model_type {model_type!r}; expected 'xgboost' or 'rf'"
            )

        self.model_type = model_type
        self.scaler = StandardScaler()

        if model_type == 'xgboost':
            self.model = xgb.XGBRegressor(
                n_estimators=kwargs.get('n_estimators', 500),
                max_depth=kwargs.get('max_depth', 6),
                learning_rate=kwargs.get('learning_rate', 0.01),
                subsample=kwargs.get('subsample', 0.8),
                colsample_bytree=kwargs.get('colsample_bytree', 0.8),
                random_state=kwargs.get('random_state', 42),
                n_jobs=kwargs.get('n_jobs', -1)
            )
        else:  # 'rf'
            self.model = RandomForestRegressor(
                n_estimators=kwargs.get('n_estimators', 500),
                max_depth=kwargs.get('max_depth', None),
                min_samples_split=kwargs.get('min_samples_split', 2),
                min_samples_leaf=kwargs.get('min_samples_leaf', 1),
                random_state=kwargs.get('random_state', 42),
                n_jobs=kwargs.get('n_jobs', -1)
            )

    def fit(self, X: np.ndarray, y: np.ndarray):
        """Fit the scaler on ``X`` and the regressor on the scaled features."""
        X_scaled = self.scaler.fit_transform(X)
        self.model.fit(X_scaled, y)

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Scale ``X`` with the fitted scaler and return the predictions."""
        X_scaled = self.scaler.transform(X)
        return self.model.predict(X_scaled)

    def save(self, path: str):
        """Save model, scaler and model type to ``path`` with joblib.

        Parent directories of ``path`` are created if missing.
        """
        import joblib
        Path(path).parent.mkdir(parents=True, exist_ok=True)

        joblib.dump({
            'model': self.model,
            'scaler': self.scaler,
            'model_type': self.model_type
        }, path)

    def load(self, path: str):
        """Replace model, scaler and model type with those stored at ``path``.

        NOTE: joblib files are pickle-based; only load files you trust.
        """
        import joblib
        data = joblib.load(path)
        self.model = data['model']
        self.scaler = data['scaler']
        self.model_type = data['model_type']

# ================== CROSS VALIDATION ==================

def cross_validate_iscore(df: pd.DataFrame, feature_cols: List[str], target_col: str = 'pKi',
                         n_splits: int = 5, model_type: str = 'xgboost', 
                         model_params: dict = None, random_state: int = 42):
    """
    Run shuffled K-fold cross-validation of an iScoreModel and print a report.

    Args:
        df: Table holding the feature columns and the target column.
        feature_cols: Names of the feature columns.
        target_col: Name of the target column.
        n_splits: Number of folds.
        model_type: Passed to ``iScoreModel``.
        model_params: Optional keyword arguments for ``iScoreModel``.
        random_state: Seed for the fold shuffling.

    Returns:
        dict with per-fold ``r2_scores``, ``rmse_scores``, ``mae_scores`` and
        ``fold_times``, plus the pooled ``predictions`` and ``true_values``.
    """
    banner = '=' * 50
    print(f"\n{banner}")
    print(f"Starting {n_splits}-Fold Cross-Validation")
    print(f"Model: {model_type}")
    print(f"Target: {target_col}")
    print(f"Features: {len(feature_cols)} descriptors")
    print(f"{banner}\n")

    features = df[feature_cols].values
    targets = df[target_col].values

    splitter = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)

    result_keys = ('r2_scores', 'rmse_scores', 'mae_scores',
                   'predictions', 'true_values', 'fold_times')
    cv_results = {key: [] for key in result_keys}

    for fold_no, (train_rows, test_rows) in enumerate(splitter.split(features), 1):
        print(f"\nFold {fold_no}/{n_splits}")
        print("-" * 30)

        fold_started = time.time()

        X_train = features[train_rows]
        X_test = features[test_rows]
        y_train = targets[train_rows]
        y_test = targets[test_rows]

        print(f"Train: {len(X_train)} samples")
        print(f"Test: {len(X_test)} samples")

        # A fresh model (and scaler) per fold, fitted on the training part only.
        model = iScoreModel(model_type=model_type, **(model_params or {}))

        print("Training model...")
        fit_started = time.time()
        model.fit(X_train, y_train)
        fit_seconds = time.time() - fit_started
        print(f"Training completed in {fit_seconds:.2f}s")

        y_hat = model.predict(X_test)

        fold_r2 = r2_score(y_test, y_hat)
        fold_rmse = np.sqrt(mean_squared_error(y_test, y_hat))
        fold_mae = np.mean(np.abs(y_test - y_hat))

        cv_results['r2_scores'].append(fold_r2)
        cv_results['rmse_scores'].append(fold_rmse)
        cv_results['mae_scores'].append(fold_mae)
        cv_results['predictions'].extend(y_hat)
        cv_results['true_values'].extend(y_test)
        cv_results['fold_times'].append(time.time() - fold_started)

        print(f"Fold {fold_no} Results:")
        print(f"  R²: {fold_r2:.4f}")
        print(f"  RMSE: {fold_rmse:.4f}")
        print(f"  MAE: {fold_mae:.4f}")

        # Feature importance (if available); listed in ascending importance.
        if hasattr(model.model, 'feature_importances_'):
            strongest = np.argsort(model.model.feature_importances_)[-5:]
            print(f"  Top 5 features: {[feature_cols[i] for i in strongest]}")

    print(f"\n{banner}")
    print("Cross-Validation Summary")
    print(f"{banner}")
    summary_rows = (
        ("R² Score", 'r2_scores'),
        ("RMSE", 'rmse_scores'),
        ("MAE", 'mae_scores'),
    )
    for label, key in summary_rows:
        scores = cv_results[key]
        print(f"{label}: {np.mean(scores):.4f} ± {np.std(scores):.4f}")
    print(f"Avg fold time: {np.mean(cv_results['fold_times']):.2f}s")

    # Metrics over the pooled out-of-fold predictions.
    pooled_true = np.array(cv_results['true_values'])
    pooled_pred = np.array(cv_results['predictions'])
    overall_r2 = r2_score(pooled_true, pooled_pred)
    overall_rmse = np.sqrt(mean_squared_error(pooled_true, pooled_pred))

    print(f"\nOverall Performance:")
    print(f"  R²: {overall_r2:.4f}")
    print(f"  RMSE: {overall_rmse:.4f}")

    return cv_results

# ================== VISUALIZATION ==================

def plot_cv_results(cv_results: dict, save_path: str = None):
    """Draw a 2x2 summary figure of cross-validation results.

    Panels: predicted vs. true scatter, residuals vs. predictions, R² per
    fold, and a histogram of absolute errors.

    Args:
        cv_results: Dict with ``true_values``, ``predictions`` and
            ``r2_scores`` entries.
        save_path: If given, the figure is also saved there (300 dpi).
    """
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))

    observed = np.array(cv_results['true_values'])
    predicted = np.array(cv_results['predictions'])

    # 1. Scatter plot of predictions with the identity line
    panel = axes[0, 0]
    panel.scatter(observed, predicted, alpha=0.5, s=10)
    lo, hi = observed.min(), observed.max()
    panel.plot([lo, hi], [lo, hi], 'r--', lw=2)
    panel.set_xlabel('True Values')
    panel.set_ylabel('Predicted Values')
    panel.set_title(f'Predictions (R²={r2_score(observed, predicted):.3f})')
    panel.grid(True, alpha=0.3)

    # 2. Residual plot
    residuals = observed - predicted
    panel = axes[0, 1]
    panel.scatter(predicted, residuals, alpha=0.5, s=10)
    panel.axhline(y=0, color='r', linestyle='--')
    panel.set_xlabel('Predicted Values')
    panel.set_ylabel('Residuals')
    panel.set_title('Residual Plot')
    panel.grid(True, alpha=0.3)

    # 3. Fold-wise performance
    fold_scores = cv_results['r2_scores']
    fold_numbers = range(1, len(fold_scores) + 1)
    panel = axes[1, 0]
    panel.plot(fold_numbers, fold_scores, 'o-', label='R²', color='blue')
    panel.set_xlabel('Fold')
    panel.set_ylabel('R² Score')
    panel.set_title('Fold-wise R² Performance')
    panel.grid(True, alpha=0.3)
    panel.legend()

    # 4. Distribution of absolute errors
    abs_errors = np.abs(residuals)
    mean_error = np.mean(abs_errors)
    panel = axes[1, 1]
    panel.hist(abs_errors, bins=30, edgecolor='black', alpha=0.7)
    panel.axvline(x=mean_error, color='r', linestyle='--', label=f'Mean: {mean_error:.3f}')
    panel.set_xlabel('Absolute Error')
    panel.set_ylabel('Frequency')
    panel.set_title('Error Distribution')
    panel.legend()
    panel.grid(True, alpha=0.3)

    plt.suptitle('iScore Cross-Validation Results', fontsize=14)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

# ================== FULL TRAINING ==================

def train_full_model(df: pd.DataFrame, feature_cols: List[str], target_col: str = 'pKi',
                    model_type: str = 'xgboost', model_params: dict = None,
                    save_path: str = 'models/iscore_model.pkl'):
    """
    Fit an iScoreModel on every row of ``df`` and optionally save it.

    Args:
        df: Table holding the feature columns and the target column.
        feature_cols: Names of the feature columns.
        target_col: Name of the target column.
        model_type: Passed to ``iScoreModel``.
        model_params: Optional keyword arguments for ``iScoreModel``.
        save_path: Where to save the fitted model; a falsy value skips saving.

    Returns:
        The fitted ``iScoreModel``. The printed R²/RMSE are computed on the
        training data itself.
    """
    rule = '=' * 50
    print(f"\n{rule}")
    print("Training Full Model")
    print(f"{rule}")
    print(f"Samples: {len(df)}")
    print(f"Features: {len(feature_cols)}")

    features = df[feature_cols].values
    targets = df[target_col].values

    model = iScoreModel(model_type=model_type, **(model_params or {}))

    print("Training...")
    started = time.time()
    model.fit(features, targets)
    elapsed = time.time() - started

    print(f"Training completed in {elapsed:.2f}s")

    # Evaluate on training set
    fitted = model.predict(features)
    train_r2 = r2_score(targets, fitted)
    train_rmse = np.sqrt(mean_squared_error(targets, fitted))

    print(f"Training Performance:")
    print(f"  R²: {train_r2:.4f}")
    print(f"  RMSE: {train_rmse:.4f}")

    if save_path:
        model.save(save_path)
        print(f"Model saved to: {save_path}")

    return model

# ================== INFERENCE ==================

class iScorePredictor:
    """Complete inference pipeline for iScore method"""

    def __init__(self, model_path: str, descriptor_calc: DescriptorCalculator = None):
        """Load a saved iScoreModel and set up the descriptor calculator.

        Args:
            model_path: File written by ``iScoreModel.save``.
            descriptor_calc: Calculator to use; a default
                ``DescriptorCalculator()`` is created when omitted.
        """
        self.model = iScoreModel()
        self.model.load(model_path)
        self.descriptor_calc = descriptor_calc or DescriptorCalculator()

        # Store feature names (should be saved with model in production)
        # NOTE(review): DescriptorCalculator.descriptor_names in this file
        # has 81 entries, not 77 -- confirm which count the model was
        # trained with. This attribute is not read in this class.
        self.ligand_features = 77  # Number of ligand features
        self.pocket_features = 41  # Number of pocket features

    def predict_single(self, protein_pdb_path: str, ligand_sdf_path: str,
                       smiles: str = None) -> Tuple[float, dict]:
        """
        Predict affinity for single protein-ligand pair

        Args:
            protein_pdb_path: Protein PDB file.
            ligand_sdf_path: Ligand MOL/SDF file.
            smiles: Ligand SMILES; derived from the ligand file when omitted.

        Returns:
            prediction: float - predicted affinity (``None`` on failure)
            info: dict - timing information, ``success`` flag and, on
                failure, the ``error`` message
        """
        import tempfile

        start_time = time.time()
        info = {}
        complex_path = None

        try:
            # Get SMILES if not provided
            if smiles is None:
                mol = Chem.MolFromMolFile(ligand_sdf_path)
                if mol is None:
                    raise ValueError(f"Failed to read ligand SDF: {ligand_sdf_path}")
                smiles = Chem.MolToSmiles(mol)

            # Calculate ligand descriptors
            desc_start = time.time()
            ligand_desc = self.descriptor_calc.calculate_ligand_descriptors(smiles)
            info['ligand_desc_time'] = time.time() - desc_start

            # Create complex PDB in a unique temp file: predict_batch runs
            # this method in parallel workers, which would overwrite each
            # other's data if they shared one fixed path.
            fd, complex_path = tempfile.mkstemp(prefix="temp_complex_", suffix=".pdb")
            os.close(fd)
            self.descriptor_calc.combine_protein_ligand_pdb(
                protein_pdb_path, ligand_sdf_path, complex_path
            )

            # Calculate pocket descriptors (simplified)
            # NOTE(review): placeholder -- these are random numbers, not
            # dpocket output, so predictions do not reflect the real pocket.
            pocket_start = time.time()
            pocket_desc = np.random.randn(self.pocket_features)  # Replace with actual dpocket
            info['pocket_desc_time'] = time.time() - pocket_start

            # Combine features
            X = np.hstack([ligand_desc, pocket_desc]).reshape(1, -1)

            # Make prediction
            pred_start = time.time()
            prediction = self.model.predict(X)[0]
            info['prediction_time'] = time.time() - pred_start

            info['total_time'] = time.time() - start_time
            info['success'] = True

            return prediction, info

        except Exception as e:
            info['error'] = str(e)
            info['success'] = False
            info['total_time'] = time.time() - start_time
            return None, info

        finally:
            # Clean up on success and on failure alike.
            if complex_path is not None and os.path.exists(complex_path):
                os.remove(complex_path)

    def predict_batch(self, protein_pdb_paths: List[str], ligand_sdf_paths: List[str],
                     smiles_list: List[str] = None, n_jobs: int = 4) -> pd.DataFrame:
        """
        Predict affinities for multiple protein-ligand pairs

        Args:
            protein_pdb_paths: Protein PDB files.
            ligand_sdf_paths: Ligand files, paired by position.
            smiles_list: Optional SMILES, paired by position.
            n_jobs: Number of joblib workers.

        Returns:
            DataFrame with one row per pair (ordered by input position) and
            the columns ``index``, ``prediction``, ``success``,
            ``total_time`` and ``error``.
        """
        if smiles_list is None:
            smiles_list = [None] * len(protein_pdb_paths)

        def process_pair(idx, pdb_path, sdf_path, smiles):
            pred, info = self.predict_single(pdb_path, sdf_path, smiles)
            return idx, pred, info

        # Process in parallel
        results = Parallel(n_jobs=n_jobs)(
            delayed(process_pair)(idx, pdb, sdf, smi)
            for idx, (pdb, sdf, smi) in enumerate(
                tqdm(zip(protein_pdb_paths, ligand_sdf_paths, smiles_list),
                     total=len(protein_pdb_paths), desc="Predicting")
            )
        )

        # Organize results
        predictions = []
        for idx, pred, info in sorted(results, key=lambda x: x[0]):
            predictions.append({
                'index': idx,
                'prediction': pred,
                'success': info['success'],
                'total_time': info.get('total_time', None),
                'error': info.get('error', None)
            })

        return pd.DataFrame(predictions)



# ================== COMPARATIVE ANALYSIS ==================

def compare_with_gnn(iscore_results: dict, gnn_results: dict = None):
    """
    Print (and, when GNN results are given, plot) an iScore vs. GNN comparison.

    Args:
        iscore_results: Dict with ``r2_scores`` and ``rmse_scores`` lists
            (optionally ``predictions``).
        gnn_results: Same layout for the GNN; when falsy only the iScore
            summary is printed.
    """
    def _report(title, results):
        # Mean ± standard deviation of the per-fold scores.
        r2 = results['r2_scores']
        rmse = results['rmse_scores']
        print(title)
        print(f"  R²: {np.mean(r2):.4f} ± {np.std(r2):.4f}")
        print(f"  RMSE: {np.mean(rmse):.4f} ± {np.std(rmse):.4f}")

    print("\n" + "="*60)
    print("COMPARATIVE ANALYSIS: iScore vs GNN")
    print("="*60)

    _report("\niScore Method:", iscore_results)

    if not gnn_results:
        return

    _report("\nGNN Method:", gnn_results)

    # Statistical comparison
    from scipy import stats

    iscore_r2 = iscore_results['r2_scores']
    gnn_r2 = gnn_results['r2_scores']

    # Paired t-test on R² scores
    t_stat, p_value = stats.ttest_rel(iscore_r2, gnn_r2)

    print(f"\nStatistical Comparison (paired t-test on R² scores):")
    print(f"  t-statistic: {t_stat:.4f}")
    print(f"  p-value: {p_value:.4f}")

    if not p_value < 0.05:
        print("  Result: No significant difference (p ≥ 0.05)")
    elif np.mean(iscore_r2) > np.mean(gnn_r2):
        print("  Result: iScore significantly better (p < 0.05)")
    else:
        print("  Result: GNN significantly better (p < 0.05)")

    # Visualization
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    # Box plot comparison
    box_ax = axes[0]
    box_ax.boxplot([iscore_r2, gnn_r2], 
                   labels=['iScore', 'GNN'])
    box_ax.set_ylabel('R² Score')
    box_ax.set_title('Model Performance Comparison')
    box_ax.grid(True, alpha=0.3)

    # Scatter plot of predictions (if available)
    scatter_ax = axes[1]
    if 'predictions' in iscore_results and 'predictions' in gnn_results:
        ours = iscore_results['predictions']
        theirs = gnn_results['predictions']
        scatter_ax.scatter(ours, theirs, 
                           alpha=0.5, s=10)
        low = min(min(ours), min(theirs))
        high = max(max(ours), max(theirs))
        scatter_ax.plot([low, high], [low, high], 'r--', lw=2)
        scatter_ax.set_xlabel('iScore Predictions')
        scatter_ax.set_ylabel('GNN Predictions')
        scatter_ax.set_title('Prediction Correlation')
        scatter_ax.grid(True, alpha=0.3)

    plt.suptitle('iScore vs GNN Comparison', fontsize=14)
    plt.tight_layout()
    plt.show()

# ================== TIMING ANALYSIS ==================

def analyze_computation_time(df: pd.DataFrame, n_samples: int = 10):
    """
    Analyze computation time for descriptor calculation.

    Times the ligand-descriptor calculation on a random sample of rows and a
    simulated pocket-descriptor step, prints summary statistics, shows a box
    plot and returns the raw timings.

    Args:
        df: DataFrame with a 'smiles' column.
        n_samples: Maximum number of rows to time (sampled with random_state=42).

    Returns:
        dict with 'ligand_times', 'pocket_times' and 'total_times' lists, in
        seconds, one entry per timed row. The lists are empty when there is
        nothing to time.
    """
    print("\n" + "="*60)
    print("COMPUTATION TIME ANALYSIS")
    print("="*60)

    ligand_times = []
    pocket_times = []
    total_times = []
    timings = {
        'ligand_times': ligand_times,
        'pocket_times': pocket_times,
        'total_times': total_times
    }

    # Nothing to time: return before the statistics/plot code, which cannot
    # handle empty lists (max([]) raises ValueError).
    n_to_time = min(n_samples, len(df))
    if n_to_time <= 0:
        print("\nNo samples to time.")
        return timings

    sample_df = df.sample(n=n_to_time, random_state=42)
    descriptor_calc = DescriptorCalculator()

    print(f"\nTiming {len(sample_df)} samples...")

    for _, row in tqdm(sample_df.iterrows(), total=len(sample_df)):
        # Time ligand descriptors (perf_counter is monotonic and has better
        # resolution than time.time for short intervals).
        start = time.perf_counter()
        descriptor_calc.calculate_ligand_descriptors(row['smiles'])
        ligand_time = time.perf_counter() - start
        ligand_times.append(ligand_time)

        # Time pocket descriptors (simplified).
        # In real implementation, this would call dpocket; the random vector
        # and the sleep only stand in for that computation.
        start = time.perf_counter()
        np.random.randn(41)
        time.sleep(0.1)  # Simulate dpocket computation
        pocket_time = time.perf_counter() - start
        pocket_times.append(pocket_time)

        total_times.append(ligand_time + pocket_time)

    print("\nTiming Results:")
    print(f"Ligand descriptors: {np.mean(ligand_times):.4f} ± {np.std(ligand_times):.4f}s")
    print(f"Pocket descriptors: {np.mean(pocket_times):.4f} ± {np.std(pocket_times):.4f}s")
    print(f"Total per complex: {np.mean(total_times):.4f} ± {np.std(total_times):.4f}s")

    # Extrapolate to full dataset
    total_time_hours = (len(df) * np.mean(total_times)) / 3600
    print(f"\nEstimated time for {len(df)} samples: {total_time_hours:.2f} hours")

    # Visualization
    fig, ax = plt.subplots(1, 1, figsize=(8, 6))

    series = [ligand_times, pocket_times, total_times]
    ax.boxplot(series,
               positions=[1, 2, 3],
               labels=['Ligand', 'Pocket', 'Total'])

    ax.set_ylabel('Time (seconds)')
    ax.set_title('Computation Time Distribution')
    ax.grid(True, alpha=0.3)

    # Add mean values just above each box
    for i, times in enumerate(series, 1):
        ax.text(i, max(times) * 1.05, f'μ={np.mean(times):.3f}s',
                ha='center', fontsize=10)

    plt.tight_layout()
    plt.show()

    return timings

# ================== MULTI-TASK LEARNING EXTENSION ==================

class MTL_iScoreModel(iScoreModel):
    """
    Multi-task learning version of iScore model.

    Keeps one independent regressor and one StandardScaler per task; every
    task is trained on the same feature matrix but only on the rows whose
    target for that task is not NaN.

    NOTE(review): the parent constructor is not called, so `self.model` and
    `self.scaler` are never set and the inherited `save`/`load` (which use
    those attributes) do not work for this class.
    """

    def __init__(self, task_names: List[str], model_type: str = 'xgboost', **kwargs):
        """
        Args:
            task_names: Names of the regression tasks (keys expected in `y_dict`).
            model_type: 'xgboost' or 'rf'.
            **kwargs: Passed unchanged to every per-task regressor constructor.

        Raises:
            ValueError: If `model_type` is not supported. Previously this left
                `self.models` empty and `fit` failed later with a KeyError.
        """
        if model_type == 'xgboost':
            make_model = xgb.XGBRegressor
        elif model_type == 'rf':
            make_model = RandomForestRegressor
        else:
            raise ValueError(
                f"Unsupported model_type {model_type!r}; expected 'xgboost' or 'rf'"
            )

        self.model_type = model_type
        self.task_names = task_names
        self.models = {}
        self.scalers = {}
        # Tasks that actually went through fit(); predict() is limited to these.
        self.fitted_tasks = set()

        for task in task_names:
            self.scalers[task] = StandardScaler()
            self.models[task] = make_model(**kwargs)

    def fit(self, X: np.ndarray, y_dict: dict):
        """
        Fit models for all tasks.

        Args:
            X: Feature matrix, one row per sample.
            y_dict: Maps task name to a target vector aligned with the rows of
                X; NaN marks a missing label. Tasks absent from `y_dict`, or
                with no labelled rows, are skipped.
        """
        X = np.asarray(X)
        for task in self.task_names:
            if task not in y_dict:
                continue

            # Remove samples with NaN for this task
            y_all = np.asarray(y_dict[task], dtype=float)
            mask = ~np.isnan(y_all)
            if not mask.any():
                continue

            X_task = X[mask]
            y_task = y_all[mask]
            X_scaled = self.scalers[task].fit_transform(X_task)
            self.models[task].fit(X_scaled, y_task)
            self.fitted_tasks.add(task)
            print(f"  Trained {task} on {len(X_task)} samples")

    def predict(self, X: np.ndarray) -> dict:
        """
        Predict for all fitted tasks.

        Returns:
            dict mapping task name to its prediction array. Tasks that were
            skipped during `fit` are omitted rather than raising on an
            unfitted scaler.
        """
        predictions = {}
        for task in self.task_names:
            if task in self.fitted_tasks:
                X_scaled = self.scalers[task].transform(X)
                predictions[task] = self.models[task].predict(X_scaled)
        return predictions

# ================== MAIN ENTRY POINT ==================


# ================== MAIN PIPELINE ==================

def run_complete_pipeline(df: pd.DataFrame, target_col: str = 'pKi',
                         n_splits: int = 5, model_type: str = 'xgboost',
                         save_dir: str = 'iscore_results'):
    """
    Run the iScore workflow end to end.

    Steps: build the descriptor dataset, drop rows without a target, run
    K-fold cross-validation, plot the CV results, train a model on all rows
    and write a JSON summary.

    Args:
        df: Input samples, passed to `prepare_iscore_dataset`.
        target_col: Name of the regression target column.
        n_splits: Number of CV folds.
        model_type: Model identifier forwarded to CV and full training.
        save_dir: Output directory (created if missing) for the descriptor
            cache, the CV plot, the saved model and the summary file.

    Returns:
        Tuple of (cleaned feature DataFrame, trained model, CV results dict).
    """
    out_dir = Path(save_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    print("="*60)
    print("iScore METHOD - COMPLETE PIPELINE")
    print("="*60)

    # Step 1: descriptor calculator with its cache under the output directory
    print("\n1. Initializing descriptor calculator...")
    calculator = DescriptorCalculator(cache_dir=out_dir / 'descriptors_cache')

    # Step 2: feature table
    print("\n2. Preparing dataset with descriptors...")
    featurized, feature_cols = prepare_iscore_dataset(df, calculator, use_cache=True)

    # Rows without a target value cannot be used for training
    df_clean = featurized.dropna(subset=[target_col])
    print(f"Final dataset: {len(df_clean)} samples with {len(feature_cols)} features")

    # Step 3: cross-validation
    print("\n3. Running cross-validation...")
    cv_results = cross_validate_iscore(
        df_clean,
        feature_cols,
        target_col=target_col,
        n_splits=n_splits,
        model_type=model_type,
    )

    # Step 4: CV plots
    print("\n4. Creating visualizations...")
    plot_cv_results(cv_results, save_path=out_dir / 'cv_results.png')

    # Step 5: final model on every cleaned row
    print("\n5. Training model on full dataset...")
    full_model = train_full_model(
        df_clean,
        feature_cols,
        target_col=target_col,
        model_type=model_type,
        save_path=str(out_dir / 'iscore_model.pkl'),
    )

    # Step 6: JSON summary
    r2_per_fold = cv_results['r2_scores']
    rmse_per_fold = cv_results['rmse_scores']
    summary = {
        'dataset_size': len(df_clean),
        'n_features': len(feature_cols),
        'cv_r2_mean': np.mean(r2_per_fold),
        'cv_r2_std': np.std(r2_per_fold),
        'cv_rmse_mean': np.mean(rmse_per_fold),
        'cv_rmse_std': np.std(rmse_per_fold),
        'model_type': model_type,
        'target': target_col
    }

    summary_file = out_dir / f'summary_{target_col}.json'
    with open(summary_file, 'w') as f:
        json.dump(summary, f, indent=2)

    print("\n" + "="*60)
    print("PIPELINE COMPLETED SUCCESSFULLY")
    print("="*60)
    print(f"Results saved to: {out_dir}")

    return df_clean, full_model, cv_results

def example_usage():
    """
    Demonstrate the iScore pipeline on a small subset of the data.

    Loads the standardized parquet file, keeps the first 100 complete rows,
    runs `run_complete_pipeline`, then shows a single prediction and a
    10-row batch prediction with the saved model.
    """
    print("\n" + "="*60)
    print("iScore METHOD - EXAMPLE USAGE")
    print("="*60)

    # 1. Load data and keep only rows that have every required field
    print("\n1. Loading data...")
    required_cols = ['standardized_protein_pdb', 'standardized_ligand_sdf', 'smiles', 'pKi']
    df = pd.read_parquet("../data/standardized/standardized_input.parquet")
    df = df.dropna(subset=required_cols).head(100)  # subset for testing
    print(f"Loaded {len(df)} samples")

    # 2. Full pipeline (return values are not needed further down)
    print("\n2. Running complete pipeline...")
    run_complete_pipeline(
        df,
        target_col='pKi',
        n_splits=5,
        model_type='xgboost',
        save_dir='iscore_results'
    )

    # 3. Inference with the model that the pipeline just saved
    print("\n3. Example inference on new data...")
    predictor = iScorePredictor(
        model_path='iscore_results/iscore_model.pkl'
    )

    # Single prediction
    test_idx = 0
    sample = df.iloc[test_idx]

    print(f"\nPredicting for sample {test_idx}...")
    prediction, info = predictor.predict_single(
        protein_pdb_path=sample['standardized_protein_pdb'],
        ligand_sdf_path=sample['standardized_ligand_sdf'],
        smiles=sample['smiles']
    )

    if not info['success']:
        print(f"Prediction failed: {info['error']}")
    else:
        print(f"Predicted pKi: {prediction:.3f}")
        print(f"Actual pKi: {sample['pKi']:.3f}")
        print(f"Total time: {info['total_time']:.3f}s")
        print(f"  - Ligand descriptors: {info['ligand_desc_time']:.3f}s")
        print(f"  - Pocket descriptors: {info['pocket_desc_time']:.3f}s")
        print(f"  - Prediction: {info['prediction_time']:.3f}s")

    # Batch prediction
    print("\n4. Batch prediction example...")
    test_df = df.head(10)

    predictions_df = predictor.predict_batch(
        protein_pdb_paths=list(test_df['standardized_protein_pdb']),
        ligand_sdf_paths=list(test_df['standardized_ligand_sdf']),
        smiles_list=list(test_df['smiles']),
        n_jobs=4
    )

    # Evaluate batch predictions
    successful_preds = predictions_df[predictions_df['success']]
    if not successful_preds.empty:
        predictions_df['actual'] = test_df['pKi'].values
        valid_preds = predictions_df.dropna(subset=['prediction', 'actual'])

        if not valid_preds.empty:
            actual = valid_preds['actual']
            predicted = valid_preds['prediction']
            r2 = r2_score(actual, predicted)
            rmse = np.sqrt(mean_squared_error(actual, predicted))

            print(f"\nBatch prediction results:")
            print(f"  Successful: {len(successful_preds)}/{len(test_df)}")
            print(f"  R²: {r2:.3f}")
            print(f"  RMSE: {rmse:.3f}")
            print(f"  Avg time: {successful_preds['total_time'].mean():.3f}s per complex")

    print("\n" + "="*60)
    print("EXAMPLE COMPLETED")
    print("="*60)
    
    
"""
iScore Method for Drug-Target Affinity Prediction with Cross-Validation
This implements the pocket + ligand descriptor approach with proper CV, training, and inference
"""

import os
import json
import time
import shutil
import warnings
import numpy as np
import pandas as pd
from pathlib import Path
from typing import Dict, List, Tuple, Optional
from tqdm import tqdm
from joblib import Parallel, delayed
import matplotlib.pyplot as plt

# Scientific computing
from sklearn.model_selection import KFold
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import r2_score, mean_squared_error
from sklearn.preprocessing import StandardScaler
import xgboost as xgb

# RDKit
from rdkit import Chem, RDLogger
from rdkit.Chem import Descriptors

# Suppress warnings
RDLogger.DisableLog('rdApp.*')
warnings.filterwarnings('ignore')

# ================== DESCRIPTOR CALCULATION ==================

class DescriptorCalculator:
    """Handles calculation and caching of molecular descriptors.

    Ligand descriptors come from RDKit (`rdkit.Chem.Descriptors`), pocket
    descriptors from the external `dpocket` program. Computed descriptor
    arrays can be cached per sample as compressed `.npz` files in `cache_dir`.
    """

    def __init__(self, cache_dir: str = "../data/descriptors_cache"):
        """Create the cache directory (if needed) and set up descriptor lists.

        Args:
            cache_dir: Directory for the per-sample descriptor cache files.
        """
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        # Names of the rdkit.Chem.Descriptors functions to evaluate, in the
        # order they appear in the ligand descriptor vector.
        self.descriptor_names = [
            'MolLogP', 'MolMR', 'ExactMolWt', 'HeavyAtomCount', 'NumHAcceptors', 'NumHDonors',
            'NumHeteroatoms', 'NumRotatableBonds', 'NumAromaticRings', 'NumAliphaticRings',
            'RingCount', 'TPSA', 'LabuteASA', 'Kappa1', 'Kappa2', 'Kappa3',
            'Chi0', 'Chi1', 'Chi0n', 'Chi1n', 'Chi2n', 'Chi3n', 'Chi4n',
            'Chi0v', 'Chi1v', 'Chi2v', 'Chi3v', 'Chi4v',
            'PEOE_VSA1', 'PEOE_VSA2', 'PEOE_VSA3', 'PEOE_VSA4', 'PEOE_VSA5',
            'PEOE_VSA6', 'PEOE_VSA7', 'PEOE_VSA8', 'PEOE_VSA9', 'PEOE_VSA10',
            'PEOE_VSA11', 'PEOE_VSA12', 'PEOE_VSA13', 'PEOE_VSA14',
            'SMR_VSA1', 'SMR_VSA3', 'SMR_VSA4', 'SMR_VSA5', 'SMR_VSA6',
            'SMR_VSA7', 'SMR_VSA9', 'SMR_VSA10',
            'SlogP_VSA1', 'SlogP_VSA2', 'SlogP_VSA3', 'SlogP_VSA4', 'SlogP_VSA5',
            'SlogP_VSA6', 'SlogP_VSA7', 'SlogP_VSA8', 'SlogP_VSA10', 'SlogP_VSA11',
            'SlogP_VSA12',
            'EState_VSA1', 'EState_VSA2', 'EState_VSA3', 'EState_VSA4', 'EState_VSA5',
            'EState_VSA6', 'EState_VSA7', 'EState_VSA8', 'EState_VSA9', 'EState_VSA10',
            'VSA_EState1', 'VSA_EState2', 'VSA_EState3', 'VSA_EState4', 'VSA_EState5',
            'VSA_EState6', 'VSA_EState7', 'VSA_EState8', 'VSA_EState9', 'VSA_EState10'
        ]
        # Column-name collections; not used by any method of this class.
        # NOTE(review): presumably dpocket output columns to drop/keep — confirm.
        self.columns_to_drop_set1 = [
            "pdb", "lig", "overlap", "PP-crit", "PP-dst", "crit4",
            "crit5", "crit6", "crit6_continue", "nb_AS_norm", "apol_as_prop_norm",
            "mean_loc_hyd_dens_norm", "polarity_score_norm", "as_density_norm",
            "as_max_dst_norm", "drug_score"
        ]

        self.columns_to_drop_set2 = {
            "pock_vol","nb_AS","mean_as_ray","mean_as_solv_acc","apol_as_prop","mean_loc_hyd_dens","hydrophobicity_score","volume_score","polarity_score","charge_score","flex","prop_polar_atm","as_density","as_max_dst",
            "convex_hull_volume","surf_pol_vdw14","surf_pol_vdw22","surf_apol_vdw14","surf_apol_vdw22","n_abpa","ALA","ARG","ASN","ASP","CYS","GLN","GLU","GLY","HIS","ILE","LEU","LYS","MET","PHE","PRO","SER","THR","TRP","TYR","VAL","pKd"
        }

    def calculate_ligand_descriptors(self, smiles: str) -> np.ndarray:
        """Calculate RDKit descriptors for a SMILES string.

        Returns:
            Array with one value per entry of `self.descriptor_names`. The
            whole array is NaN when `smiles` is not a string or cannot be
            parsed; a single entry is NaN when that descriptor fails.
        """
        n_desc = len(self.descriptor_names)

        # Missing values (None / float NaN from pandas) would make RDKit
        # raise instead of returning None, so handle them up front.
        if not isinstance(smiles, str):
            return np.full(n_desc, np.nan)

        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return np.full(n_desc, np.nan)

        desc_values = []
        for desc_name in self.descriptor_names:
            try:
                desc_values.append(getattr(Descriptors, desc_name)(mol))
            except Exception:
                # Best effort per descriptor; Exception (not a bare except)
                # so KeyboardInterrupt/SystemExit still propagate.
                desc_values.append(np.nan)
        return np.array(desc_values)

    def calculate_pocket_descriptors(self, complex_pdb_path: str, ligand_sdf_path: str,
                                   ligand_code: str = "LIG", dpocket_path: str = "dpocket") -> Dict:
        """
        Calculate pocket descriptors using dpocket.

        Args:
            complex_pdb_path: PDB file containing protein and ligand.
            ligand_sdf_path: Not used by this method; kept for interface
                compatibility.
            ligand_code: Residue name of the ligand inside the PDB file.
            dpocket_path: Name or path of the dpocket executable.

        Returns:
            dict with the first row of `dpout_explicitp.txt` (without the
            'pdb' and 'lig' columns), or an empty dict when dpocket fails or
            produces no rows.
        """
        import tempfile
        import subprocess

        # dpocket writes its output into the working directory, so run it in
        # a throw-away directory.
        with tempfile.TemporaryDirectory() as temp_dir:
            # Write dpocket input file
            input_file = os.path.join(temp_dir, "dpocket_input.txt")
            with open(input_file, 'w') as f:
                f.write(f"{complex_pdb_path}\t{ligand_code}\n")

            try:
                subprocess.run(
                    [dpocket_path, "-f", input_file],
                    cwd=temp_dir,
                    capture_output=True,
                    text=True,
                    timeout=30
                )

                # Parse dpocket output
                output_file = os.path.join(temp_dir, "dpout_explicitp.txt")
                if os.path.exists(output_file):
                    # Raw string: '\s' in a normal literal is an invalid escape.
                    pocket_df = pd.read_csv(output_file, sep=r'\s+')
                    if len(pocket_df) > 0:
                        return pocket_df.iloc[0].drop(['pdb', 'lig'], errors='ignore').to_dict()
            except Exception as e:
                # Deliberate best effort: report and fall through to {}.
                print(f"dpocket failed: {e}")

        # Return empty dict if failed
        return {}

    def combine_protein_ligand_pdb(self, protein_pdb_path: str, ligand_sdf_path: str,
                                   output_path: str, lig_code: str = "LIG") -> str:
        """Combine protein and ligand into single PDB file.

        Args:
            protein_pdb_path: Protein PDB file; its END* lines are dropped.
            ligand_sdf_path: Ligand structure file read with RDKit.
            output_path: Where the combined PDB file is written.
            lig_code: Residue name for the ligand atoms; padded or truncated
                to the 3-character residue-name field.

        Returns:
            `output_path`.

        Raises:
            ValueError: If the ligand file cannot be read.
        """
        # Load ligand from SDF
        mol = Chem.MolFromMolFile(ligand_sdf_path, removeHs=False)
        if mol is None:
            raise ValueError(f"Failed to read ligand SDF: {ligand_sdf_path}")

        # The residue name occupies exactly 3 columns (line[17:20]); any other
        # width would shift every following fixed-width field.
        res_name = lig_code[:3].rjust(3)

        ligand_lines = []
        for line in Chem.MolToPDBBlock(mol).splitlines():
            if line.startswith("END"):
                # A single END record is written below; skip the ligand
                # block's own terminator to avoid a duplicate.
                continue
            if line.startswith(("HETATM", "ATOM")):
                line = line[:17] + res_name + line[20:]
            ligand_lines.append(line)

        # Read protein PDB
        with open(protein_pdb_path, 'r') as f:
            protein_lines = [line for line in f if not line.startswith("END")]

        # Write combined PDB
        with open(output_path, 'w') as f:
            f.writelines(protein_lines)
            # Guard against a protein file whose last line has no newline.
            if protein_lines and not protein_lines[-1].endswith('\n'):
                f.write('\n')
            f.write('\n'.join(ligand_lines))
            f.write('\nEND\n')

        return output_path

    def get_cached_path(self, sample_id: str) -> Path:
        """Get path for cached descriptor file"""
        return self.cache_dir / f"{sample_id}_descriptors.npz"

    def save_descriptors(self, sample_id: str, ligand_desc: np.ndarray, pocket_desc: np.ndarray):
        """Save descriptors to cache (arrays stored as 'ligand' and 'pocket')."""
        cache_path = self.get_cached_path(sample_id)
        np.savez_compressed(cache_path, ligand=ligand_desc, pocket=pocket_desc)

    def load_descriptors(self, sample_id: str) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
        """Load descriptors from cache if available.

        Returns:
            (ligand, pocket) arrays, or (None, None) when no cache file exists.
        """
        cache_path = self.get_cached_path(sample_id)
        if not cache_path.exists():
            return None, None
        # Close the archive once both arrays are read; np.load keeps the file
        # handle open otherwise.
        with np.load(cache_path) as data:
            return data['ligand'], data['pocket']

# ================== DATASET PREPARATION ==================

def prepare_iscore_dataset(df: pd.DataFrame, descriptor_calc: DescriptorCalculator,
                          use_cache: bool = True, n_jobs: int = 4) -> Tuple[pd.DataFrame, List[str]]:
    """
    Prepare dataset with iScore descriptors.

    For every row the ligand descriptors are computed from `smiles`, a
    protein+ligand complex PDB is built from `standardized_protein_pdb` and
    `standardized_ligand_sdf`, and pocket descriptors are computed on that
    complex. Rows for which any step fails are dropped.

    Args:
        df: Input samples with the three columns named above.
        descriptor_calc: Calculator used for descriptors and caching.
        use_cache: Load/save per-sample descriptors through the calculator's
            cache. Cache keys are the row positions after the index reset.
        n_jobs: Number of joblib workers.

    Returns:
        (result_df, feature_names): the successfully processed rows with one
        added column per feature ('ligand_<i>' / 'pocket_<i>'), and the list
        of those column names. Both are empty when no row could be processed.
    """
    print(f"Preparing iScore dataset for {len(df)} samples...")

    # Reset index to ensure sequential numbering
    df_reset = df.reset_index(drop=True)

    def process_sample(idx, row):
        """Process single sample; returns (idx, ligand, pocket, elapsed)."""
        import tempfile

        sample_id = f"{idx}"

        # Try to load from cache. An unreadable cache entry must not abort
        # the whole run, so fall back to recomputation.
        if use_cache:
            try:
                ligand_desc, pocket_desc = descriptor_calc.load_descriptors(sample_id)
            except Exception as e:
                print(f"Ignoring unreadable cache entry for sample {idx}: {e}")
                ligand_desc, pocket_desc = None, None
            if ligand_desc is not None and pocket_desc is not None:
                return idx, ligand_desc, pocket_desc, None

        try:
            start_time = time.time()

            # Calculate ligand descriptors
            ligand_desc = descriptor_calc.calculate_ligand_descriptors(row.smiles)

            # Build the complex in a private temporary directory: no name
            # clashes between workers and cleanup also happens on error.
            with tempfile.TemporaryDirectory() as work_dir:
                complex_path = os.path.join(work_dir, f"complex_{idx}.pdb")
                descriptor_calc.combine_protein_ligand_pdb(
                    row.standardized_protein_pdb,
                    row.standardized_ligand_sdf,
                    complex_path
                )
                pocket_raw = descriptor_calc.calculate_pocket_descriptors(
                    complex_path, row.standardized_ligand_sdf
                )

            if not pocket_raw:
                raise RuntimeError("dpocket returned no pocket descriptors")

            # Drop unwanted columns and convert the rest to a float vector.
            # NOTE(review): assumes columns_to_drop_set1 is the intended drop
            # list for the dpocket output — confirm the column selection.
            excluded = set(descriptor_calc.columns_to_drop_set1)
            kept = {k: v for k, v in pocket_raw.items() if k not in excluded}
            pocket_desc = pd.to_numeric(pd.Series(kept), errors='coerce').to_numpy(dtype=float)

            elapsed = time.time() - start_time

            # Cache descriptors
            if use_cache:
                descriptor_calc.save_descriptors(sample_id, ligand_desc, pocket_desc)

            return idx, ligand_desc, pocket_desc, elapsed

        except Exception as e:
            print(f"Error processing sample {idx}: {e}")
            return idx, None, None, None

    # Process samples in parallel - use enumerate to get sequential indices
    results = Parallel(n_jobs=n_jobs)(
        delayed(process_sample)(idx, row)
        for idx, row in tqdm(enumerate(df_reset.itertuples(index=False)),
                            total=len(df_reset), desc="Computing descriptors")
    )

    # Organize results
    ligand_features = []
    pocket_features = []
    valid_indices = []
    times = []

    for idx, ligand_desc, pocket_desc, elapsed in results:
        if ligand_desc is not None and pocket_desc is not None:
            ligand_features.append(ligand_desc)
            pocket_features.append(pocket_desc)
            valid_indices.append(idx)
            if elapsed is not None:
                times.append(elapsed)

    # Without any valid sample the feature matrices would be 1-D and the
    # shape[1] lookups below would raise; return an empty result instead.
    if not valid_indices:
        print(f"Successfully processed 0/{len(df_reset)} samples")
        return df_reset.iloc[[]].copy(), []

    # Create feature matrix
    X_ligand = np.array(ligand_features)
    X_pocket = np.array(pocket_features)
    X = np.hstack([X_ligand, X_pocket])

    # Create result dataframe using iloc with valid indices
    result_df = df_reset.iloc[valid_indices].copy()

    # Add features as columns
    feature_names = [f'ligand_{i}' for i in range(X_ligand.shape[1])] + \
                   [f'pocket_{i}' for i in range(X_pocket.shape[1])]

    for i, fname in enumerate(feature_names):
        result_df[fname] = X[:, i]

    if times:
        print(f"Average descriptor computation time: {np.mean(times):.3f}s per complex")

    print(f"Successfully processed {len(result_df)}/{len(df_reset)} samples")

    return result_df, feature_names

# ================== MODEL TRAINING ==================

class iScoreModel:
    """Wrapper for iScore regression model.

    Bundles a StandardScaler with either an XGBoost or a random-forest
    regressor; features are standardized before fitting and predicting.
    """

    def __init__(self, model_type: str = 'xgboost', **kwargs):
        """
        Args:
            model_type: 'xgboost' or 'rf'.
            **kwargs: Optional hyper-parameter overrides. Only the keys read
                below are used; any other key is ignored.

        Raises:
            ValueError: If `model_type` is not supported. Previously this left
                `self.model` undefined and failed later with AttributeError.
        """
        if model_type not in ('xgboost', 'rf'):
            raise ValueError(
                f"Unsupported model_type {model_type!r}; expected 'xgboost' or 'rf'"
            )

        self.model_type = model_type
        self.scaler = StandardScaler()

        if model_type == 'xgboost':
            self.model = xgb.XGBRegressor(
                n_estimators=kwargs.get('n_estimators', 500),
                max_depth=kwargs.get('max_depth', 6),
                learning_rate=kwargs.get('learning_rate', 0.01),
                subsample=kwargs.get('subsample', 0.8),
                colsample_bytree=kwargs.get('colsample_bytree', 0.8),
                random_state=kwargs.get('random_state', 42),
                n_jobs=kwargs.get('n_jobs', -1)
            )
        else:  # 'rf'
            self.model = RandomForestRegressor(
                n_estimators=kwargs.get('n_estimators', 500),
                max_depth=kwargs.get('max_depth', None),
                min_samples_split=kwargs.get('min_samples_split', 2),
                min_samples_leaf=kwargs.get('min_samples_leaf', 1),
                random_state=kwargs.get('random_state', 42),
                n_jobs=kwargs.get('n_jobs', -1)
            )

    def fit(self, X: np.ndarray, y: np.ndarray):
        """Fit the scaler on X, then the model on the scaled features."""
        X_scaled = self.scaler.fit_transform(X)
        self.model.fit(X_scaled, y)

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Make predictions for X using the scaler fitted in `fit`."""
        X_scaled = self.scaler.transform(X)
        return self.model.predict(X_scaled)

    def save(self, path: str):
        """Save model, scaler and model type to `path` with joblib.

        The parent directory is created if it does not exist.
        """
        import joblib
        Path(path).parent.mkdir(parents=True, exist_ok=True)

        joblib.dump({
            'model': self.model,
            'scaler': self.scaler,
            'model_type': self.model_type
        }, path)

    def load(self, path: str):
        """Load model, scaler and model type from `path`, replacing the
        current ones.

        NOTE: joblib files are pickle-based; only load files from a trusted
        source.
        """
        import joblib
        data = joblib.load(path)
        self.model = data['model']
        self.scaler = data['scaler']
        self.model_type = data['model_type']

# ================== CROSS VALIDATION ==================

def cross_validate_iscore(df: pd.DataFrame, feature_cols: List[str], target_col: str = 'pKi',
                         n_splits: int = 5, model_type: str = 'xgboost',
                         model_params: dict = None, random_state: int = 42):
    """
    Run shuffled K-fold cross-validation for the iScore model.

    A fresh `iScoreModel` is trained per fold; R², RMSE and MAE are printed
    per fold and summarized at the end.

    Args:
        df: DataFrame holding the feature and target columns.
        feature_cols: Names of the feature columns.
        target_col: Name of the target column.
        n_splits: Number of folds.
        model_type: Model identifier passed to `iScoreModel`.
        model_params: Optional keyword arguments for `iScoreModel`.
        random_state: Seed for the fold shuffling.

    Returns:
        dict with per-fold lists 'r2_scores', 'rmse_scores', 'mae_scores' and
        'fold_times', plus the concatenated out-of-fold 'predictions' and
        'true_values'.
    """
    rule = '=' * 50
    print(f"\n{rule}")
    print(f"Starting {n_splits}-Fold Cross-Validation")
    print(f"Model: {model_type}")
    print(f"Target: {target_col}")
    print(f"Features: {len(feature_cols)} descriptors")
    print(f"{rule}\n")

    features = df[feature_cols].values
    targets = df[target_col].values

    splitter = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)

    result_keys = ('r2_scores', 'rmse_scores', 'mae_scores',
                   'predictions', 'true_values', 'fold_times')
    cv_results = {key: [] for key in result_keys}

    fold_idx = 0
    for train_idx, test_idx in splitter.split(features):
        fold_idx += 1
        print(f"\nFold {fold_idx}/{n_splits}")
        print("-" * 30)

        fold_start = time.time()

        X_train, y_train = features[train_idx], targets[train_idx]
        X_test, y_test = features[test_idx], targets[test_idx]

        print(f"Train: {len(X_train)} samples")
        print(f"Test: {len(X_test)} samples")

        # New model for every fold so nothing leaks between folds
        model = iScoreModel(model_type=model_type, **(model_params or {}))

        print("Training model...")
        train_start = time.time()
        model.fit(X_train, y_train)
        train_time = time.time() - train_start
        print(f"Training completed in {train_time:.2f}s")

        y_pred = model.predict(X_test)

        # Fold metrics
        r2 = r2_score(y_test, y_pred)
        rmse = np.sqrt(mean_squared_error(y_test, y_pred))
        mae = np.mean(np.abs(y_test - y_pred))

        for key, value in (('r2_scores', r2), ('rmse_scores', rmse), ('mae_scores', mae)):
            cv_results[key].append(value)
        cv_results['predictions'].extend(y_pred)
        cv_results['true_values'].extend(y_test)
        cv_results['fold_times'].append(time.time() - fold_start)

        print(f"Fold {fold_idx} Results:")
        print(f"  R²: {r2:.4f}")
        print(f"  RMSE: {rmse:.4f}")
        print(f"  MAE: {mae:.4f}")

        # Report the five highest-importance features when the estimator
        # exposes importances (listed in ascending order of importance)
        if hasattr(model.model, 'feature_importances_'):
            top_features = np.argsort(model.model.feature_importances_)[-5:]
            print(f"  Top 5 features: {[feature_cols[i] for i in top_features]}")

    # Mean ± std over folds
    print(f"\n{rule}")
    print("Cross-Validation Summary")
    print(f"{rule}")
    print(f"R² Score: {np.mean(cv_results['r2_scores']):.4f} ± {np.std(cv_results['r2_scores']):.4f}")
    print(f"RMSE: {np.mean(cv_results['rmse_scores']):.4f} ± {np.std(cv_results['rmse_scores']):.4f}")
    print(f"MAE: {np.mean(cv_results['mae_scores']):.4f} ± {np.std(cv_results['mae_scores']):.4f}")
    print(f"Avg fold time: {np.mean(cv_results['fold_times']):.2f}s")

    # Metrics over the pooled out-of-fold predictions
    pooled_true = np.array(cv_results['true_values'])
    pooled_pred = np.array(cv_results['predictions'])
    overall_r2 = r2_score(pooled_true, pooled_pred)
    overall_rmse = np.sqrt(mean_squared_error(pooled_true, pooled_pred))

    print(f"\nOverall Performance:")
    print(f"  R²: {overall_r2:.4f}")
    print(f"  RMSE: {overall_rmse:.4f}")

    return cv_results

# ================== VISUALIZATION ==================

def plot_cv_results(cv_results: dict, save_path: str = None):
    """Draw a 2x2 summary figure of cross-validation results.

    Panels: predicted vs. true values, residuals vs. predictions, R² per
    fold, and a histogram of absolute errors.

    Args:
        cv_results: dict with 'true_values', 'predictions' and 'r2_scores'.
        save_path: If given, the figure is saved there (300 dpi) before it
            is shown.
    """
    y_true = np.array(cv_results['true_values'])
    y_hat = np.array(cv_results['predictions'])
    residuals = y_true - y_hat
    abs_errors = np.abs(residuals)
    fold_r2 = cv_results['r2_scores']

    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    (ax_pred, ax_resid), (ax_fold, ax_err) = axes

    # Top-left: predictions against true values with the identity line
    lo, hi = y_true.min(), y_true.max()
    ax_pred.scatter(y_true, y_hat, alpha=0.5, s=10)
    ax_pred.plot([lo, hi], [lo, hi], 'r--', lw=2)
    ax_pred.set_xlabel('True Values')
    ax_pred.set_ylabel('Predicted Values')
    ax_pred.set_title(f'Predictions (R²={r2_score(y_true, y_hat):.3f})')
    ax_pred.grid(True, alpha=0.3)

    # Top-right: residuals
    ax_resid.scatter(y_hat, residuals, alpha=0.5, s=10)
    ax_resid.axhline(y=0, color='r', linestyle='--')
    ax_resid.set_xlabel('Predicted Values')
    ax_resid.set_ylabel('Residuals')
    ax_resid.set_title('Residual Plot')
    ax_resid.grid(True, alpha=0.3)

    # Bottom-left: R² per fold
    fold_numbers = range(1, len(fold_r2) + 1)
    ax_fold.plot(fold_numbers, fold_r2, 'o-', label='R²', color='blue')
    ax_fold.set_xlabel('Fold')
    ax_fold.set_ylabel('R² Score')
    ax_fold.set_title('Fold-wise R² Performance')
    ax_fold.grid(True, alpha=0.3)
    ax_fold.legend()

    # Bottom-right: absolute-error histogram with its mean marked
    mean_abs_error = np.mean(abs_errors)
    ax_err.hist(abs_errors, bins=30, edgecolor='black', alpha=0.7)
    ax_err.axvline(x=mean_abs_error, color='r', linestyle='--', label=f'Mean: {mean_abs_error:.3f}')
    ax_err.set_xlabel('Absolute Error')
    ax_err.set_ylabel('Frequency')
    ax_err.set_title('Error Distribution')
    ax_err.legend()
    ax_err.grid(True, alpha=0.3)

    plt.suptitle('iScore Cross-Validation Results', fontsize=14)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

# ================== FULL TRAINING ==================

def train_full_model(df: pd.DataFrame, feature_cols: List[str], target_col: str = 'pKi',
                    model_type: str = 'xgboost', model_params: dict = None,
                    save_path: str = 'models/iscore_model.pkl'):
    """
    Train an `iScoreModel` on every row of `df`.

    Prints the training time and the (in-sample) R² / RMSE, and saves the
    model when `save_path` is set.

    Args:
        df: DataFrame holding the feature and target columns.
        feature_cols: Names of the feature columns.
        target_col: Name of the target column.
        model_type: Model identifier passed to `iScoreModel`.
        model_params: Optional keyword arguments for `iScoreModel`.
        save_path: Destination for the saved model; falsy to skip saving.

    Returns:
        The fitted model.
    """
    rule = '=' * 50
    print(f"\n{rule}")
    print("Training Full Model")
    print(f"{rule}")
    print(f"Samples: {len(df)}")
    print(f"Features: {len(feature_cols)}")

    features = df[feature_cols].values
    targets = df[target_col].values

    model = iScoreModel(model_type=model_type, **(model_params or {}))

    print("Training...")
    t0 = time.time()
    model.fit(features, targets)
    train_time = time.time() - t0

    print(f"Training completed in {train_time:.2f}s")

    # In-sample fit quality (not a generalization estimate)
    fitted = model.predict(features)
    train_r2 = r2_score(targets, fitted)
    train_rmse = np.sqrt(mean_squared_error(targets, fitted))

    print(f"Training Performance:")
    print(f"  R²: {train_r2:.4f}")
    print(f"  RMSE: {train_rmse:.4f}")

    if save_path:
        model.save(save_path)
        print(f"Model saved to: {save_path}")

    return model


    # NOTE(review): as written, this def is indented directly after
    # train_full_model's `return model`, so it is an unreachable nested
    # function. It takes `self` and calls `self.predict_single`; presumably
    # it belongs to the iScorePredictor class used by example_usage, whose
    # header is missing here — confirm and restore.
    def predict_batch(self, protein_pdb_paths: List[str], ligand_sdf_paths: List[str],
                     smiles_list: List[str] = None, n_jobs: int = 4) -> pd.DataFrame:
        """
        Predict affinities for multiple protein-ligand pairs

        Each (protein, ligand, smiles) triple is passed to
        `self.predict_single` through joblib workers.

        Args:
            protein_pdb_paths: One protein PDB path per pair.
            ligand_sdf_paths: One ligand SDF path per pair, aligned with
                `protein_pdb_paths`.
            smiles_list: Optional SMILES per pair; when omitted, None is
                passed for every pair.
            n_jobs: Number of joblib workers.

        Returns:
            DataFrame with one row per pair, ordered by input position, with
            columns 'index', 'prediction', 'success', 'total_time', 'error'.
        """
        if smiles_list is None:
            smiles_list = [None] * len(protein_pdb_paths)
        
        def process_pair(idx, pdb_path, sdf_path, smiles):
            # Carry the input position along so results can be re-ordered.
            pred, info = self.predict_single(pdb_path, sdf_path, smiles)
            return idx, pred, info
        
        # Process in parallel
        results = Parallel(n_jobs=n_jobs)(
            delayed(process_pair)(idx, pdb, sdf, smi) 
            for idx, (pdb, sdf, smi) in enumerate(
                tqdm(zip(protein_pdb_paths, ligand_sdf_paths, smiles_list), 
                     total=len(protein_pdb_paths), desc="Predicting")
            )
        )
        
        # Organize results
        predictions = []
        # Sort by input position so rows line up with the input lists.
        for idx, pred, info in sorted(results, key=lambda x: x[0]):
            # 'success' is required in info; 'total_time' and 'error' are
            # optional and default to None.
            predictions.append({
                'index': idx,
                'prediction': pred,
                'success': info['success'],
                'total_time': info.get('total_time', None),
                'error': info.get('error', None)
            })
        
        return pd.DataFrame(predictions)

# ================== MAIN PIPELINE ==================

def run_complete_pipeline(df: pd.DataFrame, target_col: str = 'pKi',
                         n_splits: int = 5, model_type: str = 'xgboost',
                         save_dir: str = 'iscore_results'):
    """
    Drive the whole iScore workflow on one dataframe.

    Steps, in order: build descriptors, drop rows without a target value,
    cross-validate, plot the CV results, train a model on all remaining rows
    and write a JSON summary.

    Args:
        df: Input samples, passed unchanged to ``prepare_iscore_dataset``.
        target_col: Name of the regression target column.
        n_splits: Number of cross-validation folds.
        model_type: Model family forwarded to CV and to the final training.
        save_dir: Output directory (created if missing) for the descriptor
            cache, ``cv_results.png``, ``iscore_model.pkl`` and ``summary.json``.

    Returns:
        Tuple ``(df_clean, full_model, cv_results)``.
    """
    out_dir = Path(save_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    banner = "=" * 60

    print(banner)
    print("iScore METHOD - COMPLETE PIPELINE")
    print(banner)

    # 1. Initialize descriptor calculator
    print("\n1. Initializing descriptor calculator...")
    calculator = DescriptorCalculator(cache_dir=out_dir / 'descriptors_cache')

    # 2. Prepare dataset
    print("\n2. Preparing dataset with descriptors...")
    featured, feature_cols = prepare_iscore_dataset(df, calculator, use_cache=True)

    # Rows without a target value cannot be used for training or scoring.
    df_clean = featured.dropna(subset=[target_col])
    n_samples = len(df_clean)
    n_features = len(feature_cols)
    print(f"Final dataset: {n_samples} samples with {n_features} features")

    # 3. Cross-validation
    print("\n3. Running cross-validation...")
    cv_results = cross_validate_iscore(
        df_clean,
        feature_cols,
        target_col=target_col,
        n_splits=n_splits,
        model_type=model_type,
    )

    # 4. Visualize CV results
    print("\n4. Creating visualizations...")
    plot_cv_results(cv_results, save_path=out_dir / 'cv_results.png')

    # 5. Train full model
    print("\n5. Training model on full dataset...")
    full_model = train_full_model(
        df_clean,
        feature_cols,
        target_col=target_col,
        model_type=model_type,
        save_path=str(out_dir / 'iscore_model.pkl'),
    )

    # 6. Save results summary
    r2_per_fold = cv_results['r2_scores']
    rmse_per_fold = cv_results['rmse_scores']
    summary = {
        'dataset_size': n_samples,
        'n_features': n_features,
        'cv_r2_mean': np.mean(r2_per_fold),
        'cv_r2_std': np.std(r2_per_fold),
        'cv_rmse_mean': np.mean(rmse_per_fold),
        'cv_rmse_std': np.std(rmse_per_fold),
        'model_type': model_type,
        'target': target_col,
    }
    with open(out_dir / 'summary.json', 'w') as handle:
        json.dump(summary, handle, indent=2)

    print("\n" + banner)
    print("PIPELINE COMPLETED SUCCESSFULLY")
    print(banner)
    print(f"Results saved to: {out_dir}")

    return df_clean, full_model, cv_results

# ================== EXAMPLE USAGE ==================

def example_usage():
    """
    Walk through the iScore workflow on a small sample.

    Loads the standardized parquet file, keeps the first 100 complete rows,
    runs the complete pipeline, then demonstrates a single prediction and a
    ten-row batch prediction with the saved model, printing metrics as it goes.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("iScore METHOD - EXAMPLE USAGE")
    print(banner)

    # 1. Load your data
    print("\n1. Loading data...")
    required_cols = ['standardized_protein_pdb', 'standardized_ligand_sdf', 'smiles', 'pKi']
    df = (
        pd.read_parquet("../data/standardized/standardized_input.parquet")
        .dropna(subset=required_cols)  # keep only rows with every required field
        .head(100)                     # small subset for a quick run
    )
    print(f"Loaded {len(df)} samples")

    # 2. Run complete pipeline
    print("\n2. Running complete pipeline...")
    _features, _model, _cv = run_complete_pipeline(
        df,
        target_col='pKi',
        n_splits=5,
        model_type='xgboost',
        save_dir='iscore_results',
    )

    # 3. Example inference
    print("\n3. Example inference on new data...")
    predictor = iScorePredictor(model_path='iscore_results/iscore_model.pkl')

    # Single prediction on the first sample
    sample_pos = 0
    sample = df.iloc[sample_pos]

    print(f"\nPredicting for sample {sample_pos}...")
    predicted, details = predictor.predict_single(
        protein_pdb_path=sample['standardized_protein_pdb'],
        ligand_sdf_path=sample['standardized_ligand_sdf'],
        smiles=sample['smiles'],
    )

    if not details['success']:
        print(f"Prediction failed: {details['error']}")
    else:
        print(f"Predicted pKi: {predicted:.3f}")
        print(f"Actual pKi: {sample['pKi']:.3f}")
        print(f"Total time: {details['total_time']:.3f}s")
        print(f"  - Ligand descriptors: {details['ligand_desc_time']:.3f}s")
        print(f"  - Pocket descriptors: {details['pocket_desc_time']:.3f}s")
        print(f"  - Prediction: {details['prediction_time']:.3f}s")

    # Batch prediction
    print("\n4. Batch prediction example...")
    batch_df = df.head(10)

    batch_out = predictor.predict_batch(
        protein_pdb_paths=batch_df['standardized_protein_pdb'].tolist(),
        ligand_sdf_paths=batch_df['standardized_ligand_sdf'].tolist(),
        smiles_list=batch_df['smiles'].tolist(),
        n_jobs=4,
    )

    # Evaluate batch predictions
    ok_rows = batch_out[batch_out['success']]
    if len(ok_rows) > 0:
        batch_out['actual'] = batch_df['pKi'].values
        scored = batch_out.dropna(subset=['prediction', 'actual'])

        if len(scored) > 0:
            observed = scored['actual']
            estimated = scored['prediction']
            r2 = r2_score(observed, estimated)
            rmse = np.sqrt(mean_squared_error(observed, estimated))

            print(f"\nBatch prediction results:")
            print(f"  Successful: {len(ok_rows)}/{len(batch_df)}")
            print(f"  R²: {r2:.3f}")
            print(f"  RMSE: {rmse:.3f}")
            print(f"  Avg time: {ok_rows['total_time'].mean():.3f}s per complex")

    print("\n" + banner)
    print("EXAMPLE COMPLETED")
    print(banner)

# ================== COMPARATIVE ANALYSIS ==================

def compare_with_gnn(iscore_results: dict, gnn_results: dict = None):
    """
    Compare iScore results with GNN results.

    Prints the mean ± std of the per-fold R² and RMSE for iScore and, when
    ``gnn_results`` is given, for the GNN as well, followed by a paired t-test
    on the per-fold R² scores and a two-panel figure (R² box plot and, when
    both result dicts carry ``'predictions'``, a prediction scatter plot).

    Args:
        iscore_results: Dict with ``'r2_scores'`` and ``'rmse_scores'``
            sequences; ``'predictions'`` is optional.
        gnn_results: Same layout for the GNN; ``None`` or an empty dict limits
            the output to the iScore summary.

    Returns:
        None. Output goes to stdout and to a matplotlib figure.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("COMPARATIVE ANALYSIS: iScore vs GNN")
    print(banner)

    # iScore performance
    iscore_r2 = np.asarray(iscore_results['r2_scores'], dtype=float)
    iscore_rmse = np.asarray(iscore_results['rmse_scores'], dtype=float)
    print("\niScore Method:")
    print(f"  R²: {np.mean(iscore_r2):.4f} ± {np.std(iscore_r2):.4f}")
    print(f"  RMSE: {np.mean(iscore_rmse):.4f} ± {np.std(iscore_rmse):.4f}")

    if not gnn_results:
        return

    gnn_r2 = np.asarray(gnn_results['r2_scores'], dtype=float)
    gnn_rmse = np.asarray(gnn_results['rmse_scores'], dtype=float)
    print("\nGNN Method:")
    print(f"  R²: {np.mean(gnn_r2):.4f} ± {np.std(gnn_r2):.4f}")
    print(f"  RMSE: {np.mean(gnn_rmse):.4f} ± {np.std(gnn_rmse):.4f}")

    # Statistical comparison. A paired test needs the same folds on both sides
    # and at least two pairs; otherwise ttest_rel raises or yields NaN.
    if len(iscore_r2) == len(gnn_r2) and len(iscore_r2) >= 2:
        from scipy import stats

        t_stat, p_value = stats.ttest_rel(iscore_r2, gnn_r2)

        print(f"\nStatistical Comparison (paired t-test on R² scores):")
        print(f"  t-statistic: {t_stat:.4f}")
        print(f"  p-value: {p_value:.4f}")

        if p_value < 0.05:
            if np.mean(iscore_r2) > np.mean(gnn_r2):
                print("  Result: iScore significantly better (p < 0.05)")
            else:
                print("  Result: GNN significantly better (p < 0.05)")
        else:
            print("  Result: No significant difference (p ≥ 0.05)")
    else:
        print("\nStatistical Comparison skipped: the paired t-test needs the same "
              "number of folds (at least 2) for both methods")

    # Visualization
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    # Box plot comparison. Tick labels are set separately because the
    # boxplot `labels=` keyword is deprecated in recent Matplotlib releases.
    ax = axes[0]
    ax.boxplot([iscore_r2, gnn_r2])
    ax.set_xticklabels(['iScore', 'GNN'])
    ax.set_ylabel('R² Score')
    ax.set_title('Model Performance Comparison')
    ax.grid(True, alpha=0.3)

    # Scatter plot of predictions (only when both sides provide comparable data)
    ax = axes[1]
    iscore_pred = iscore_results.get('predictions')
    gnn_pred = gnn_results.get('predictions')
    can_scatter = (
        iscore_pred is not None
        and gnn_pred is not None
        and len(iscore_pred) > 0
        and len(iscore_pred) == len(gnn_pred)
    )
    if can_scatter:
        iscore_pred = np.asarray(iscore_pred, dtype=float)
        gnn_pred = np.asarray(gnn_pred, dtype=float)
        ax.scatter(iscore_pred, gnn_pred, alpha=0.5, s=10)
        min_val = min(iscore_pred.min(), gnn_pred.min())
        max_val = max(iscore_pred.max(), gnn_pred.max())
        # Identity line: points on it are predicted identically by both models.
        ax.plot([min_val, max_val], [min_val, max_val], 'r--', lw=2)
        ax.set_xlabel('iScore Predictions')
        ax.set_ylabel('GNN Predictions')
        ax.set_title('Prediction Correlation')
        ax.grid(True, alpha=0.3)
    else:
        # Nothing to draw; hide the panel instead of leaving an empty frame.
        ax.set_visible(False)

    plt.suptitle('iScore vs GNN Comparison', fontsize=14)
    plt.tight_layout()
    plt.show()

# ================== TIMING ANALYSIS ==================



# ================== MULTI-TASK LEARNING EXTENSION ==================

class MTL_iScoreModel(iScoreModel):
    """
    Multi-task learning version of iScore model.

    Keeps one independent (scaler, regressor) pair per task name. Tasks share
    the feature matrix but are trained separately, each on the rows where its
    own target is not NaN.
    """

    def __init__(self, task_names: List[str], model_type: str = 'xgboost', **kwargs):
        """
        Build one scaler and one regressor per task.

        Args:
            task_names: Names of the tasks; used as keys of the target and
                prediction dicts.
            model_type: ``'xgboost'`` or ``'rf'``.
            **kwargs: Passed unchanged to every per-task regressor constructor.

        Raises:
            ValueError: If ``model_type`` is not a supported value.
        """
        if model_type not in ('xgboost', 'rf'):
            # Fail here rather than with a KeyError on the first fit() call.
            raise ValueError(
                f"Unknown model_type {model_type!r}; expected 'xgboost' or 'rf'"
            )

        self.task_names = list(task_names)
        self.model_type = model_type
        self.models = {}
        self.scalers = {}
        # Tasks that have actually been trained; predict() is limited to these.
        self.fitted_tasks = set()

        for task in self.task_names:
            self.scalers[task] = StandardScaler()
            if model_type == 'xgboost':
                self.models[task] = xgb.XGBRegressor(**kwargs)
            else:
                self.models[task] = RandomForestRegressor(**kwargs)

    def fit(self, X: np.ndarray, y_dict: dict):
        """
        Fit models for all tasks present in ``y_dict``.

        Args:
            X: Feature matrix with one row per sample.
            y_dict: Maps task name to a target sequence aligned with ``X``;
                NaN marks samples without a label for that task.
        """
        X = np.asarray(X)
        for task in self.task_names:
            if task not in y_dict:
                continue

            # Convert once so lists and pandas Series behave like arrays.
            y_all = np.asarray(y_dict[task], dtype=float)
            labelled = ~np.isnan(y_all)
            X_task = X[labelled]
            y_task = y_all[labelled]

            if len(X_task) == 0:
                continue

            X_scaled = self.scalers[task].fit_transform(X_task)
            self.models[task].fit(X_scaled, y_task)
            self.fitted_tasks.add(task)
            print(f"  Trained {task} on {len(X_task)} samples")

    def predict(self, X: np.ndarray) -> dict:
        """
        Predict for all trained tasks.

        Returns:
            Dict mapping each fitted task name to its prediction array. Tasks
            that were never fitted are omitted, because their scaler cannot
            transform the input.
        """
        predictions = {}
        for task in self.task_names:
            if task not in self.fitted_tasks:
                continue
            X_scaled = self.scalers[task].transform(X)
            predictions[task] = self.models[task].predict(X_scaled)
        return predictions

    def save(self, path: str):
        """Save all per-task models and scalers to ``path`` with joblib."""
        import joblib
        model_path = Path(path)
        model_path.parent.mkdir(parents=True, exist_ok=True)

        joblib.dump({
            'models': self.models,
            'scalers': self.scalers,
            'task_names': self.task_names,
            'fitted_tasks': list(self.fitted_tasks),
            'model_type': self.model_type
        }, path)

    def load(self, path: str):
        """
        Load per-task models and scalers written by ``save``.

        SECURITY: joblib.load unpickles arbitrary objects; only load files
        from a trusted source.
        """
        import joblib
        data = joblib.load(path)
        self.models = data['models']
        self.scalers = data['scalers']
        self.task_names = data['task_names']
        self.fitted_tasks = set(data['fitted_tasks'])
        self.model_type = data['model_type']

# ================== INFERENCE PIPELINE ==================
class iScorePredictor:
    """
    Complete inference pipeline for iScore method.

    Wraps a trained ``iScoreModel`` together with a ``DescriptorCalculator``
    and turns protein/ligand files into affinity predictions.

    NOTE: descriptor generation writes fixed-name dpocket files in the current
    working directory, so predictions should not run concurrently in one
    directory.
    """

    def __init__(self, model_path: str, descriptor_calc: DescriptorCalculator = None):
        """
        Initialize predictor with trained model.

        Args:
            model_path: File previously written by ``iScoreModel.save``.
            descriptor_calc: Calculator to reuse; a default-configured one is
                created when omitted.
        """
        self.model = iScoreModel()
        self.model.load(model_path)
        self.descriptor_calc = descriptor_calc or DescriptorCalculator()

        # Load feature names from training (should be saved with model)
        # For now, we'll reconstruct them
        self.ligand_features = self.descriptor_calc.descriptor_names
        self.pocket_features = None  # Will be determined from dpocket output

    def predict_single(self, protein_pdb_path: str, ligand_sdf_path: str,
                       smiles: str = None) -> Tuple[float, dict]:
        """
        Predict affinity for single protein-ligand pair using exact dpocket method.

        Args:
            protein_pdb_path: Path to the protein PDB file.
            ligand_sdf_path: Path to the ligand SDF file.
            smiles: Ligand SMILES; derived from the SDF when omitted.

        Returns:
            ``(prediction, info)``. ``info`` always has ``success`` and
            ``total_time``; on success it also has ``ligand_desc_time``,
            ``pocket_desc_time`` and ``prediction_time``. On failure
            ``prediction`` is ``None`` and ``info['error']`` holds the message.
        """
        import uuid

        start_time = time.time()
        info = {}
        complex_path = None

        try:
            # Get SMILES if not provided
            if smiles is None:
                mol = Chem.MolFromMolFile(ligand_sdf_path)
                if mol is None:
                    raise ValueError(f"Failed to read ligand SDF : {ligand_sdf_path}")
                smiles = Chem.MolToSmiles(mol)

            # Calculate ligand descriptors
            desc_start = time.time()
            ligand_df = self.descriptor_calc.calculate_descriptors([smiles])
            ligand_desc = ligand_df.values[0]
            info['ligand_desc_time'] = time.time() - desc_start

            # Create complex PDB under a unique name so that overlapping calls
            # cannot overwrite or delete each other's file.
            complex_path = f"temp_complex_{uuid.uuid4().hex}.pdb"
            self.descriptor_calc.combine_protein_ligand(
                protein_pdb_path, ligand_sdf_path, complex_path
            )

            # Run dpocket
            pocket_start = time.time()
            pocket_df = self.descriptor_calc.run_dpocket_single(complex_path)

            if len(pocket_df) > 0:
                # Drop unnecessary columns
                pocket_df = pocket_df.drop(columns=self.descriptor_calc.columns_to_drop, errors='ignore')
                # Identifier columns such as 'pdb' are strings and cannot be
                # model features; keep the numeric descriptors only.
                pocket_df = pocket_df.select_dtypes(include=[np.number])
                pocket_desc = pocket_df.iloc[0].to_numpy(dtype=float)
            else:
                # Use NaN if dpocket fails
                pocket_desc = np.full(41, np.nan)  # Adjust size based on actual features

            info['pocket_desc_time'] = time.time() - pocket_start

            # Combine features. The float cast turns None entries (invalid
            # SMILES) into NaN so the replacement below applies to them too.
            X = np.hstack([ligand_desc, pocket_desc]).astype(float).reshape(1, -1)

            # Handle NaN values - replace with median or mean from training
            # For now, replace with 0
            X = np.nan_to_num(X, nan=0.0)

            # Make prediction
            pred_start = time.time()
            prediction = self.model.predict(X)[0]
            info['prediction_time'] = time.time() - pred_start

            info['total_time'] = time.time() - start_time
            info['success'] = True

            return prediction, info

        except Exception as e:
            info['error'] = str(e)
            info['success'] = False
            info['total_time'] = time.time() - start_time
            return None, info

        finally:
            # Clean up on both the success and the failure path.
            if complex_path and os.path.exists(complex_path):
                os.remove(complex_path)

    def predict_batch(self, protein_pdb_paths: List[str], ligand_sdf_paths: List[str],
                     smiles_list: List[str] = None, n_jobs: int = 1) -> pd.DataFrame:
        """
        Predict affinities for multiple protein-ligand pairs.

        Note: dpocket batch processing is more efficient than parallel single runs,
        so the whole batch goes through ``prepare_iscore_dataset`` in one pass.

        Args:
            protein_pdb_paths: Protein PDB file paths, one per pair.
            ligand_sdf_paths: Ligand SDF file paths, matched by position.
            smiles_list: Optional SMILES per pair; derived from the SDF files
                when omitted.
            n_jobs: Accepted for interface compatibility; not used by the
                batch path.

        Returns:
            DataFrame with one row per input pair and the columns ``index``,
            ``prediction`` (NaN when no prediction was produced), ``success``,
            ``total_time`` (batch wall time divided evenly over the pairs) and
            ``error``.

        Raises:
            ValueError: If the input lists differ in length.
        """
        columns = ['index', 'prediction', 'success', 'total_time', 'error']
        n_pairs = len(protein_pdb_paths)
        if len(ligand_sdf_paths) != n_pairs:
            raise ValueError("protein_pdb_paths and ligand_sdf_paths must have the same length")
        if smiles_list is not None and len(smiles_list) != n_pairs:
            raise ValueError("smiles_list must have one entry per protein-ligand pair")

        print(f"Predicting for {len(protein_pdb_paths)} complexes...")
        if n_pairs == 0:
            return pd.DataFrame(columns=columns)

        start_time = time.time()

        # Get SMILES if not provided
        if smiles_list is None:
            smiles_list = []
            for sdf_path in ligand_sdf_paths:
                mol = Chem.MolFromMolFile(sdf_path)
                smiles_list.append(Chem.MolToSmiles(mol) if mol else None)

        # Create temporary dataframe for batch processing. The extra position
        # column survives row filtering, so predictions can be mapped back to
        # the inputs they belong to.
        temp_df = pd.DataFrame({
            'standardized_protein_pdb': protein_pdb_paths,
            'standardized_ligand_sdf': ligand_sdf_paths,
            'smiles': smiles_list,
            '_input_index': np.arange(n_pairs)
        })

        # Use the batch preparation method
        result_df, feature_cols = prepare_iscore_dataset(
            temp_df, self.descriptor_calc, use_cache=False
        )

        # Make predictions
        predicted = {}
        if len(result_df) > 0:
            X = result_df[feature_cols].values
            X = np.nan_to_num(X, nan=0.0)  # Handle NaN
            predictions = np.asarray(self.model.predict(X))
            predicted = dict(zip(result_df['_input_index'].tolist(), predictions.tolist()))

        per_complex_time = (time.time() - start_time) / n_pairs

        # One output row per input; pairs filtered out during preparation are
        # reported as failures instead of silently disappearing.
        rows = []
        for idx in range(n_pairs):
            ok = idx in predicted
            rows.append({
                'index': idx,
                'prediction': predicted.get(idx, np.nan),
                'success': ok,
                'total_time': per_complex_time,
                'error': None if ok else 'descriptor preparation failed for this complex'
            })

        return pd.DataFrame(rows, columns=columns)
            
"""
iScore Method for Drug-Target Affinity Prediction with Cross-Validation
This implements the pocket + ligand descriptor approach with proper CV, training, and inference
"""

import os
import json
import time
import shutil
import warnings
import numpy as np
import pandas as pd
from pathlib import Path
from typing import Dict, List, Tuple, Optional
from tqdm import tqdm
from joblib import Parallel, delayed
import matplotlib.pyplot as plt

# Scientific computing
from sklearn.model_selection import KFold
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import r2_score, mean_squared_error
from sklearn.preprocessing import StandardScaler
import xgboost as xgb

# RDKit
from rdkit import Chem, RDLogger
from rdkit.Chem import Descriptors

# Suppress warnings: turn off every RDKit log channel matching 'rdApp.*' and
# ignore all Python warnings for the rest of the process.
# NOTE(review): the blanket 'ignore' filter also hides warnings raised by the
# other libraries imported above; narrow it if those are needed while debugging.
RDLogger.DisableLog('rdApp.*')
warnings.filterwarnings('ignore')

# ================== DESCRIPTOR CALCULATION ==================

class DescriptorCalculator:
    """
    Handles calculation and caching of molecular descriptors.

    Combines RDKit ligand descriptors with pocket descriptors obtained by
    running the external ``dpocket`` program on a protein-ligand complex PDB.

    NOTE: ``run_dpocket_single`` uses fixed file names (``dp_input.txt``,
    ``dpout_explicitp.txt``, ``./dpout``) in the current working directory,
    so it must not run concurrently from the same directory.
    """

    def __init__(self, cache_dir: str = "../data/descriptors_cache",
                 complex_dir: str = "../data/complex_pdbs",
                 descriptor_dir: str = "../data/descriptor_files"):
        """
        Create the working directories and define the descriptor set.

        Args:
            cache_dir: Directory for cached ``.npz`` descriptor files.
            complex_dir: Directory for generated complex PDB files.
            descriptor_dir: Directory for descriptor files.
        """
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.complex_dir = Path(complex_dir)
        self.complex_dir.mkdir(parents=True, exist_ok=True)
        self.descriptor_dir = Path(descriptor_dir)
        self.descriptor_dir.mkdir(parents=True, exist_ok=True)

        # RDKit descriptor functions, looked up by name on rdkit.Chem.Descriptors.
        self.descriptor_names = [
            'MolLogP', 'MolMR', 'ExactMolWt', 'HeavyAtomCount', 'NumHAcceptors', 'NumHDonors',
            'NumHeteroatoms', 'NumRotatableBonds', 'NumAromaticRings', 'NumAliphaticRings',
            'RingCount', 'TPSA', 'LabuteASA', 'Kappa1', 'Kappa2', 'Kappa3',
            'Chi0', 'Chi1', 'Chi0n', 'Chi1n', 'Chi2n', 'Chi3n', 'Chi4n',
            'Chi0v', 'Chi1v', 'Chi2v', 'Chi3v', 'Chi4v',
            'PEOE_VSA1', 'PEOE_VSA2', 'PEOE_VSA3', 'PEOE_VSA4', 'PEOE_VSA5',
            'PEOE_VSA6', 'PEOE_VSA7', 'PEOE_VSA8', 'PEOE_VSA9', 'PEOE_VSA10',
            'PEOE_VSA11', 'PEOE_VSA12', 'PEOE_VSA13', 'PEOE_VSA14',
            'SMR_VSA1', 'SMR_VSA3', 'SMR_VSA4', 'SMR_VSA5', 'SMR_VSA6',
            'SMR_VSA7', 'SMR_VSA9', 'SMR_VSA10',
            'SlogP_VSA1', 'SlogP_VSA2', 'SlogP_VSA3', 'SlogP_VSA4', 'SlogP_VSA5',
            'SlogP_VSA6', 'SlogP_VSA7', 'SlogP_VSA8', 'SlogP_VSA10', 'SlogP_VSA11',
            'SlogP_VSA12',
            'EState_VSA1', 'EState_VSA2', 'EState_VSA3', 'EState_VSA4', 'EState_VSA5',
            'EState_VSA6', 'EState_VSA7', 'EState_VSA8', 'EState_VSA9', 'EState_VSA10',
            'VSA_EState1', 'VSA_EState2', 'VSA_EState3', 'VSA_EState4', 'VSA_EState5',
            'VSA_EState6', 'VSA_EState7', 'VSA_EState8', 'VSA_EState9', 'VSA_EState10'
        ]

        # Columns to drop from dpocket output
        self.columns_to_drop = []

    def calculate_descriptors(self, smiles_list: List[str]) -> pd.DataFrame:
        """
        Calculate RDKit descriptors for a list of SMILES strings.

        Args:
            smiles_list: SMILES strings. Entries that are not strings (for
                example ``None`` or NaN) or that RDKit cannot parse produce a
                row of ``None`` values.

        Returns:
            DataFrame with one row per input and one column per name in
            ``self.descriptor_names``.
        """
        n_descriptors = len(self.descriptor_names)
        rows = []
        for smiles in smiles_list:
            # RDKit rejects non-string input, so missing SMILES are treated
            # like unparsable ones instead of aborting the whole batch.
            mol = Chem.MolFromSmiles(smiles) if isinstance(smiles, str) else None
            if mol:
                rows.append([getattr(Descriptors, name)(mol) for name in self.descriptor_names])
            else:
                rows.append([None] * n_descriptors)

        return pd.DataFrame(rows, columns=self.descriptor_names)

    def combine_protein_ligand(self, protein_pdb_path: str, ligand_sdf_path: str,
                               output_pdb_path: str) -> str:
        """
        Combine protein and ligand into single PDB file.

        The ligand is converted to PDB records whose residue name is rewritten
        to ``LIG`` and appended after the protein records.

        Args:
            protein_pdb_path: Path to the protein PDB file.
            ligand_sdf_path: Path to the ligand SDF file.
            output_pdb_path: Path of the combined PDB file to write.

        Returns:
            ``output_pdb_path``.

        Raises:
            ValueError: If RDKit cannot read the ligand file.
        """
        import tempfile

        # Load and convert ligand to PDB with residue name LIG
        mol = Chem.MolFromMolFile(ligand_sdf_path, removeHs=False)
        if mol is None:
            raise ValueError(f"Failed to read ligand SDF : {ligand_sdf_path}")

        # A unique temp file avoids collisions between concurrent calls, and
        # the finally block removes it even when conversion fails.
        fd, temp_ligand_path = tempfile.mkstemp(prefix="temp_ligand_", suffix=".pdb")
        os.close(fd)
        try:
            Chem.MolToPDBFile(mol, temp_ligand_path)

            ligand_lines = []
            with open(temp_ligand_path, "r") as f:
                for line in f:
                    if line.startswith(("HETATM", "ATOM")):
                        # Columns 18-20 of a PDB atom record hold the residue name.
                        line = line[:17] + "LIG" + line[20:]
                    ligand_lines.append(line)
        finally:
            if os.path.exists(temp_ligand_path):
                os.remove(temp_ligand_path)

        # Read protein PDB (exclude END line if present)
        with open(protein_pdb_path, "r") as f:
            protein_lines = [line for line in f if not line.startswith("END")]

        # Write combined PDB
        with open(output_pdb_path, "w") as out:
            out.writelines(protein_lines)
            out.writelines(ligand_lines)
            out.write("END\n")

        return output_pdb_path

    def run_dpocket_single(self, complex_pdb_path: str, lig_code: str = "LIG") -> pd.DataFrame:
        """
        Run dpocket on a single complex and return pocket descriptors.

        Args:
            complex_pdb_path: Path to the combined protein-ligand PDB file.
            lig_code: Residue name of the ligand inside the complex.

        Returns:
            The parsed ``dpout_explicitp.txt`` table with the ``.pdb`` suffix
            stripped from the ``pdb`` column and ``self.columns_to_drop``
            removed, or an empty DataFrame when dpocket fails or writes no
            output.
        """
        import shutil
        import subprocess

        dpocket_output_dir = "./dpout"
        dpocket_output_file = "dpout_explicitp.txt"
        dpocket_input_file = "dp_input.txt"

        try:
            # Write dpocket input
            with open(dpocket_input_file, "w") as f:
                f.write(f"{complex_pdb_path}\t{lig_code}\n")

            # Remove and recreate dpocket output dir
            if os.path.exists(dpocket_output_dir):
                shutil.rmtree(dpocket_output_dir)
            os.makedirs(dpocket_output_dir)

            # Delete output of an earlier run; otherwise a failed run would be
            # reported with another complex's descriptors.
            if os.path.exists(dpocket_output_file):
                os.remove(dpocket_output_file)

            # Run dpocket (argument list, no shell involved)
            subprocess.run(["dpocket", "-f", dpocket_input_file], check=False)

            # Read dpocket output
            if os.path.exists(dpocket_output_file):
                pocket_df = pd.read_csv(dpocket_output_file, sep=r'\s+')
                pocket_df['pdb'] = pocket_df['pdb'].str.replace('.pdb', '', regex=False)
                # Drop unnecessary columns
                pocket_df = pocket_df.drop(columns=self.columns_to_drop, errors='ignore')
                return pocket_df
        except Exception as e:
            print(f"dpocket error: {e}")

        return pd.DataFrame()

    def process_single_complex(self, idx: int, smiles: str, protein_pdb_path: str,
                              ligand_sdf_path: str) -> Tuple[np.ndarray, np.ndarray]:
        """
        Process a single protein-ligand complex to get all descriptors.

        Args:
            idx: Sample index; names the complex file ``<idx>.pdb`` inside
                ``self.complex_dir``.
            smiles: Ligand SMILES.
            protein_pdb_path: Path to the protein PDB file.
            ligand_sdf_path: Path to the ligand SDF file.

        Returns:
            ``(ligand_descriptors, pocket_descriptors)`` as float arrays.
            Values that could not be computed are NaN; any failure yields
            all-NaN arrays.
        """
        try:
            # Generate complex PDB
            complex_path = self.complex_dir / f"{idx}.pdb"
            self.combine_protein_ligand(protein_pdb_path, ligand_sdf_path, str(complex_path))

            # Calculate ligand descriptors (None from an invalid SMILES becomes NaN)
            ligand_features = self.calculate_descriptors([smiles]).to_numpy(dtype=float)[0]

            # Run dpocket and get pocket descriptors
            pocket_df = self.run_dpocket_single(str(complex_path))

            if len(pocket_df) > 0:
                # Take first row if multiple pockets. Only numeric columns are
                # kept so the array stays float and can be cached and reloaded.
                numeric_df = pocket_df.select_dtypes(include=[np.number])
                pocket_features = numeric_df.iloc[0].to_numpy(dtype=float)
            else:
                # Return NaN if dpocket fails
                pocket_features = np.full(41, np.nan)  # Assuming 41 pocket features

            return ligand_features, pocket_features

        except Exception as e:
            print(f"Error processing complex {idx}: {e}")
            return np.full(len(self.descriptor_names), np.nan), np.full(41, np.nan)

    def get_cached_path(self, sample_id: str) -> Path:
        """Get path for cached descriptor file"""
        return self.cache_dir / f"{sample_id}_descriptors.npz"

    def save_descriptors(self, sample_id: str, ligand_desc: np.ndarray, pocket_desc: np.ndarray):
        """Save descriptors to cache as a compressed ``.npz`` with keys ``ligand`` and ``pocket``."""
        cache_path = self.get_cached_path(sample_id)
        np.savez_compressed(cache_path, ligand=ligand_desc, pocket=pocket_desc)

    def load_descriptors(self, sample_id: str) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
        """
        Load descriptors from cache if available.

        Returns:
            ``(ligand, pocket)`` arrays, or ``(None, None)`` when nothing is
            cached for ``sample_id``.
        """
        cache_path = self.get_cached_path(sample_id)
        if not cache_path.exists():
            return None, None
        # The context manager closes the archive's file handle once both
        # arrays have been read into memory.
        with np.load(cache_path) as data:
            return data['ligand'], data['pocket']

# ================== DATASET PREPARATION ==================

def prepare_iscore_dataset(df: pd.DataFrame, descriptor_calc: DescriptorCalculator,
                          use_cache: bool = True, n_jobs: int = 4) -> Tuple[pd.DataFrame, List[str]]:
    """
    Prepare dataset with iScore descriptors using exact dpocket integration.

    Builds one complex PDB per row, runs dpocket once over all complexes,
    computes the RDKit ligand descriptors and joins everything onto the input
    rows. Rows where at least half of the feature values are missing are
    dropped. The merged feature table is also written to
    ``all_descriptors.csv`` in the current working directory.

    Args:
        df: Must contain the columns ``standardized_protein_pdb``,
            ``standardized_ligand_sdf`` and ``smiles``.
        descriptor_calc: Provides ``complex_dir``, ``columns_to_drop``,
            ``combine_protein_ligand`` and ``calculate_descriptors``.
        use_cache: Accepted for interface compatibility; currently unused.
        n_jobs: Accepted for interface compatibility; currently unused.

    Returns:
        ``(result_df, feature_cols)``: the input columns plus descriptor
        columns for the surviving rows, and the names of the feature columns
        (ligand descriptors followed by pocket descriptors).
    """
    import subprocess

    print(f"Preparing iScore dataset for {len(df)} samples...")

    # Reset index to ensure sequential numbering
    df_reset = df.reset_index(drop=True)
    n_samples = len(df_reset)

    # First, generate all complex PDB files
    print("Generating complex PDB files...")
    complex_paths = []
    for idx, row in tqdm(df_reset.iterrows(), total=n_samples, desc="Creating complexes"):
        complex_path = descriptor_calc.complex_dir / f"{idx}.pdb"
        try:
            descriptor_calc.combine_protein_ligand(
                row['standardized_protein_pdb'],
                row['standardized_ligand_sdf'],
                str(complex_path)
            )
            complex_paths.append(str(complex_path))
        except Exception as e:
            print(f"Error creating complex {idx}: {e}")
            complex_paths.append(None)

    # Write dpocket input file
    print("Preparing dpocket input...")
    dpocket_input_file = "dp_input.txt"
    dpocket_output_file = "dpout_explicitp.txt"
    with open(dpocket_input_file, "w") as f:
        f.writelines(f"{path}\tLIG\n" for path in complex_paths if path)

    # Delete output of an earlier run so that a failed run cannot be mistaken
    # for fresh results.
    if os.path.exists(dpocket_output_file):
        os.remove(dpocket_output_file)

    # Run dpocket in batch
    print("Running dpocket...")
    try:
        subprocess.run(["dpocket", "-f", dpocket_input_file], check=False)
    except OSError as e:
        # e.g. dpocket is not installed; fall through to the all-NaN fallback.
        print(f"Error running dpocket: {e}")

    # Parse dpocket output
    print("Parsing dpocket results...")
    pocket_df = pd.DataFrame()

    try:
        if os.path.exists(dpocket_output_file):
            parsed = pd.read_csv(dpocket_output_file, sep=r'\s+')
            parsed['pdb'] = parsed['pdb'].str.replace('.pdb', '', regex=False)

            # Drop unnecessary columns
            parsed = parsed.drop(columns=descriptor_calc.columns_to_drop, errors='ignore')

            # Map back to indices: each complex file is named "<row index>.pdb"
            parsed["complex_id"] = parsed["pdb"].apply(
                lambda x: int(os.path.basename(x).replace(".pdb", ""))
            )
            parsed = parsed.drop(columns=["pdb"], errors='ignore')

            # Keep the first pocket per complex and index by complex id so the
            # rows can be aligned with the samples they were computed from.
            parsed = parsed.drop_duplicates(subset="complex_id", keep="first")
            parsed = parsed.set_index("complex_id")

            # Only numeric columns can serve as model features.
            parsed = parsed.select_dtypes(include=[np.number])

            # Assign only after every step succeeded, so a parse error never
            # leaves a half-processed table behind.
            pocket_df = parsed
    except Exception as e:
        print(f"Error reading dpocket output: {e}")

    # Calculate ligand descriptors
    print("Calculating ligand descriptors...")
    ligand_df = descriptor_calc.calculate_descriptors(df_reset["smiles"].tolist())

    # Merge ligand and pocket descriptors
    print("Merging descriptors...")

    if len(pocket_df) > 0:
        # Align by complex id: samples without dpocket output get NaN rows
        # instead of shifting every later sample's descriptors.
        pocket_df = pocket_df.reindex(range(n_samples))
        pocket_df = pocket_df.reset_index(drop=True)
    else:
        # Create empty pocket dataframe with NaN
        num_pocket_features = 41  # Adjust based on actual dpocket output
        pocket_cols = [f'pocket_{i}' for i in range(num_pocket_features)]
        pocket_df = pd.DataFrame(
            np.full((n_samples, num_pocket_features), np.nan),
            columns=pocket_cols
        )

    # Combine all features
    result_df = pd.concat([
        df_reset.reset_index(drop=True),
        ligand_df.reset_index(drop=True),
        pocket_df.reset_index(drop=True)
    ], axis=1)

    # Get feature column names
    ligand_cols = ligand_df.columns.tolist()
    pocket_cols = pocket_df.columns.tolist()
    feature_cols = ligand_cols + pocket_cols

    # Remove samples with too many NaN features
    nan_threshold = 0.5  # Keep only samples with fewer than 50% NaN features
    feature_nan_count = result_df[feature_cols].isna().sum(axis=1)
    valid_mask = feature_nan_count < (len(feature_cols) * nan_threshold)
    result_df = result_df[valid_mask].reset_index(drop=True)

    print(f"Successfully processed {len(result_df)}/{len(df)} samples")
    print(f"Features: {len(ligand_cols)} ligand + {len(pocket_cols)} pocket = {len(feature_cols)} total")

    # Save merged descriptors
    result_df[feature_cols].to_csv("all_descriptors.csv", index=False)

    return result_df, feature_cols

# ================== MODEL TRAINING ==================

class iScoreModel:
    """
    Wrapper for iScore regression model.

    Pairs a ``StandardScaler`` with either an XGBoost regressor or a random
    forest regressor; features are standardized before fitting and before
    predicting.
    """

    # Model families the constructor knows how to build.
    SUPPORTED_MODEL_TYPES = ('xgboost', 'rf')

    def __init__(self, model_type: str = 'xgboost', **kwargs):
        """
        Build the scaler and the underlying regressor.

        Args:
            model_type: ``'xgboost'`` or ``'rf'``.
            **kwargs: Optional overrides for the hyper-parameters listed
                below; other keys are ignored.

        Raises:
            ValueError: If ``model_type`` is not supported.
        """
        # Validate first: without this an unknown type left the object with no
        # `.model`, which only surfaced later as an AttributeError in fit().
        if model_type not in self.SUPPORTED_MODEL_TYPES:
            raise ValueError(
                f"Unknown model_type {model_type!r}; "
                f"expected one of {self.SUPPORTED_MODEL_TYPES}"
            )

        self.model_type = model_type
        self.scaler = StandardScaler()

        if model_type == 'xgboost':
            self.model = xgb.XGBRegressor(
                n_estimators=kwargs.get('n_estimators', 500),
                max_depth=kwargs.get('max_depth', 6),
                learning_rate=kwargs.get('learning_rate', 0.01),
                subsample=kwargs.get('subsample', 0.8),
                colsample_bytree=kwargs.get('colsample_bytree', 0.8),
                random_state=kwargs.get('random_state', 42),
                n_jobs=kwargs.get('n_jobs', -1)
            )
        else:  # 'rf'
            self.model = RandomForestRegressor(
                n_estimators=kwargs.get('n_estimators', 500),
                max_depth=kwargs.get('max_depth', None),
                min_samples_split=kwargs.get('min_samples_split', 2),
                min_samples_leaf=kwargs.get('min_samples_leaf', 1),
                random_state=kwargs.get('random_state', 42),
                n_jobs=kwargs.get('n_jobs', -1)
            )

    def fit(self, X: np.ndarray, y: np.ndarray):
        """Fit the scaler on ``X``, then fit the model on the scaled features."""
        X_scaled = self.scaler.fit_transform(X)
        self.model.fit(X_scaled, y)

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Scale ``X`` with the fitted scaler and return the model's predictions."""
        X_scaled = self.scaler.transform(X)
        return self.model.predict(X_scaled)

    def save(self, path: str):
        """Save model, scaler and model type to ``path``, creating parent directories."""
        import joblib
        model_path = Path(path)
        model_path.parent.mkdir(parents=True, exist_ok=True)

        joblib.dump({
            'model': self.model,
            'scaler': self.scaler,
            'model_type': self.model_type
        }, path)

    def load(self, path: str):
        """
        Load model and scaler from a file written by ``save``.

        SECURITY: joblib.load unpickles arbitrary objects; only load files
        from a trusted source.
        """
        import joblib
        data = joblib.load(path)
        self.model = data['model']
        self.scaler = data['scaler']
        self.model_type = data['model_type']

# ================== CROSS VALIDATION ==================

def cross_validate_iscore(df: pd.DataFrame, feature_cols: List[str], target_col: str = 'pKi',
                         n_splits: int = 5, model_type: str = 'xgboost',
                         model_params: dict = None, random_state: int = 42):
    """
    Run shuffled K-fold cross-validation of an ``iScoreModel``.

    Args:
        df: Table holding the feature columns and the target column.
        feature_cols: Names of the columns used as model input.
        target_col: Name of the regression target column.
        n_splits: Number of folds.
        model_type: Model family passed to ``iScoreModel``.
        model_params: Extra keyword arguments for ``iScoreModel``.
        random_state: Seed for the fold shuffling.

    Returns:
        Dict with per-fold ``r2_scores``, ``rmse_scores``, ``mae_scores`` and
        ``fold_times``, plus ``predictions`` and ``true_values`` concatenated
        across the test folds in fold order.
    """
    rule = '=' * 50
    print(f"\n{rule}")
    print(f"Starting {n_splits}-Fold Cross-Validation")
    print(f"Model: {model_type}")
    print(f"Target: {target_col}")
    print(f"Features: {len(feature_cols)} descriptors")
    print(f"{rule}\n")

    # Prepare data
    features = df[feature_cols].values
    targets = df[target_col].values
    extra_params = model_params or {}

    # Initialize CV
    splitter = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)

    # Results storage
    cv_results = {
        'r2_scores': [],
        'rmse_scores': [],
        'mae_scores': [],
        'predictions': [],
        'true_values': [],
        'fold_times': []
    }

    # Cross-validation loop
    for fold_no, (train_rows, test_rows) in enumerate(splitter.split(features), 1):
        print(f"\nFold {fold_no}/{n_splits}")
        print("-" * 30)

        fold_began = time.time()

        # Split data
        X_train, y_train = features[train_rows], targets[train_rows]
        X_test, y_test = features[test_rows], targets[test_rows]

        print(f"Train: {len(X_train)} samples")
        print(f"Test: {len(X_test)} samples")

        # A fresh model per fold, so nothing carries over between folds.
        fold_model = iScoreModel(model_type=model_type, **extra_params)

        print("Training model...")
        fit_began = time.time()
        fold_model.fit(X_train, y_train)
        fit_seconds = time.time() - fit_began
        print(f"Training completed in {fit_seconds:.2f}s")

        # Score the held-out fold
        y_pred = fold_model.predict(X_test)
        fold_r2 = r2_score(y_test, y_pred)
        fold_rmse = np.sqrt(mean_squared_error(y_test, y_pred))
        fold_mae = np.mean(np.abs(y_test - y_pred))

        # Store results
        cv_results['r2_scores'].append(fold_r2)
        cv_results['rmse_scores'].append(fold_rmse)
        cv_results['mae_scores'].append(fold_mae)
        cv_results['predictions'].extend(y_pred)
        cv_results['true_values'].extend(y_test)
        cv_results['fold_times'].append(time.time() - fold_began)

        print(f"Fold {fold_no} Results:")
        print(f"  R²: {fold_r2:.4f}")
        print(f"  RMSE: {fold_rmse:.4f}")
        print(f"  MAE: {fold_mae:.4f}")

    # Per-fold statistics
    print(f"\n{rule}")
    print("Cross-Validation Summary")
    print(f"{rule}")
    for label, key in (("R² Score", 'r2_scores'), ("RMSE", 'rmse_scores'), ("MAE", 'mae_scores')):
        per_fold = cv_results[key]
        print(f"{label}: {np.mean(per_fold):.4f} ± {np.std(per_fold):.4f}")
    print(f"Avg fold time: {np.mean(cv_results['fold_times']):.2f}s")

    # Metrics over all out-of-fold predictions pooled together
    pooled_true = np.array(cv_results['true_values'])
    pooled_pred = np.array(cv_results['predictions'])
    overall_r2 = r2_score(pooled_true, pooled_pred)
    overall_rmse = np.sqrt(mean_squared_error(pooled_true, pooled_pred))

    print(f"\nOverall Performance:")
    print(f"  R²: {overall_r2:.4f}")
    print(f"  RMSE: {overall_rmse:.4f}")

    return cv_results

# ================== VISUALIZATION ==================

def plot_cv_results(cv_results: dict, save_path: str = None):
    """Draw a 2x2 diagnostic figure summarising cross-validation output.

    Panels (row-major): predicted-vs-true scatter with the identity line,
    residuals against predictions, per-fold R², and a histogram of absolute
    errors with the mean marked.

    Args:
        cv_results: Mapping providing 'true_values' and 'predictions'
            (pooled over folds) plus 'r2_scores' (one entry per fold).
        save_path: Optional output location; when truthy the figure is
            written there at 300 dpi before being displayed.
    """
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    (scatter_ax, residual_ax), (fold_ax, hist_ax) = axes

    y_true = np.array(cv_results['true_values'])
    y_hat = np.array(cv_results['predictions'])

    # Panel 1: predictions against ground truth, with the y = x reference line
    scatter_ax.scatter(y_true, y_hat, alpha=0.5, s=10)
    lo, hi = y_true.min(), y_true.max()
    scatter_ax.plot([lo, hi], [lo, hi], 'r--', lw=2)
    scatter_ax.set_xlabel('True Values')
    scatter_ax.set_ylabel('Predicted Values')
    scatter_ax.set_title(f'Predictions (R²={r2_score(y_true, y_hat):.3f})')
    scatter_ax.grid(True, alpha=0.3)

    # Panel 2: residuals (true - predicted) against the predicted value
    residuals = y_true - y_hat
    residual_ax.scatter(y_hat, residuals, alpha=0.5, s=10)
    residual_ax.axhline(y=0, color='r', linestyle='--')
    residual_ax.set_xlabel('Predicted Values')
    residual_ax.set_ylabel('Residuals')
    residual_ax.set_title('Residual Plot')
    residual_ax.grid(True, alpha=0.3)

    # Panel 3: R² of each fold, folds numbered from 1
    fold_r2 = cv_results['r2_scores']
    fold_numbers = range(1, len(fold_r2) + 1)
    fold_ax.plot(fold_numbers, fold_r2, 'o-', label='R²', color='blue')
    fold_ax.set_xlabel('Fold')
    fold_ax.set_ylabel('R² Score')
    fold_ax.set_title('Fold-wise R² Performance')
    fold_ax.grid(True, alpha=0.3)
    fold_ax.legend()

    # Panel 4: distribution of absolute errors with the mean highlighted
    abs_errors = np.abs(residuals)
    mean_abs_error = np.mean(abs_errors)
    hist_ax.hist(abs_errors, bins=30, edgecolor='black', alpha=0.7)
    hist_ax.axvline(x=mean_abs_error, color='r', linestyle='--',
                    label=f'Mean: {mean_abs_error:.3f}')
    hist_ax.set_xlabel('Absolute Error')
    hist_ax.set_ylabel('Frequency')
    hist_ax.set_title('Error Distribution')
    hist_ax.legend()
    hist_ax.grid(True, alpha=0.3)

    plt.suptitle('iScore Cross-Validation Results', fontsize=14)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

# ================== FULL TRAINING ==================

def train_full_model(df: pd.DataFrame, feature_cols: List[str], target_col: str = 'pKi',
                    model_type: str = 'xgboost', model_params: dict = None,
                    save_path: str = 'models/iscore_model.pkl'):
    """
    Train model on full dataset

    Fits an ``iScoreModel`` on every row of ``df``, reports the (optimistic)
    training-set R² and RMSE, and optionally saves the fitted model.

    Args:
        df: Table containing the feature columns and the target column.
        feature_cols: Names of the columns used as model inputs.
        target_col: Name of the column holding the regression target.
        model_type: Forwarded to ``iScoreModel`` as ``model_type``.
        model_params: Extra keyword arguments forwarded to ``iScoreModel``;
            ``None`` means no extra arguments.
        save_path: Where to save the fitted model; a falsy value skips saving.

    Returns:
        The fitted ``iScoreModel`` instance.
    """
    print(f"\n{'='*50}")
    print("Training Full Model")
    print(f"{'='*50}")
    print(f"Samples: {len(df)}")
    print(f"Features: {len(feature_cols)}")

    # Prepare data
    X = df[feature_cols].values
    y = df[target_col].values

    # Train model
    model = iScoreModel(model_type=model_type, **(model_params or {}))

    print("Training...")
    start_time = time.time()
    model.fit(X, y)
    train_time = time.time() - start_time

    print(f"Training completed in {train_time:.2f}s")

    # Evaluate on the training set; these numbers measure fit, not generalisation
    y_pred = model.predict(X)
    train_r2 = r2_score(y, y_pred)
    train_rmse = np.sqrt(mean_squared_error(y, y_pred))

    print("Training Performance:")
    print(f"  R²: {train_r2:.4f}")
    print(f"  RMSE: {train_rmse:.4f}")

    # Save model
    if save_path:
        # The default path lives under 'models/', which may not exist yet;
        # create the parent directory so saving does not fail after training.
        parent_dir = os.path.dirname(str(save_path))
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        model.save(save_path)
        print(f"Model saved to: {save_path}")

    return model

# ================== INFERENCE ==================

class iScorePredictor:
    """Complete inference pipeline for iScore method

    Loads a trained ``iScoreModel`` and predicts an affinity value for
    (protein PDB, ligand SDF) pairs, one at a time or in parallel batches.

    NOTE(review): pocket descriptors are currently random placeholders
    (``np.random.randn``), so predictions are not reproducible or meaningful
    until the real pocket-descriptor calculation is wired in.
    """

    def __init__(self, model_path: str, descriptor_calc: DescriptorCalculator = None):
        """Initialize predictor with trained model

        Args:
            model_path: Location of the saved model, passed to
                ``iScoreModel.load``.
            descriptor_calc: Descriptor calculator to reuse; a new
                ``DescriptorCalculator()`` is created when omitted.
        """
        self.model = iScoreModel()
        self.model.load(model_path)
        self.descriptor_calc = descriptor_calc or DescriptorCalculator()

        # Store feature names (should be saved with model in production)
        # NOTE(review): the ligand descriptor list used elsewhere in this file
        # has 81 names, not 77 — confirm this value (it is not read below).
        self.ligand_features = 77  # Number of ligand features
        self.pocket_features = 41  # Number of pocket features

    def predict_single(self, protein_pdb_path: str, ligand_sdf_path: str,
                       smiles: str = None) -> Tuple[float, dict]:
        """
        Predict affinity for single protein-ligand pair

        Any exception is caught and reported through ``info`` instead of being
        raised, so a batch run is not aborted by one bad complex.

        Args:
            protein_pdb_path: Path to the protein PDB file.
            ligand_sdf_path: Path to the ligand structure file.
            smiles: Ligand SMILES; derived from ``ligand_sdf_path`` when None.

        Returns:
            prediction: float - predicted affinity, or None on failure
            info: dict - timing information plus 'success' and, on failure,
                'error' with the exception message
        """
        # Local import: only this method needs it.
        import tempfile

        start_time = time.time()
        info = {}
        complex_path = None

        try:
            # Get SMILES if not provided
            if smiles is None:
                mol = Chem.MolFromMolFile(ligand_sdf_path)
                if mol is None:
                    # RDKit signals a parse failure by returning None
                    raise ValueError(f"Could not parse ligand file: {ligand_sdf_path}")
                smiles = Chem.MolToSmiles(mol)

            # Calculate ligand descriptors
            # NOTE(review): assumes calculate_descriptors accepts one SMILES
            # string and returns a 1-D vector — confirm against
            # DescriptorCalculator.
            desc_start = time.time()
            ligand_desc = self.descriptor_calc.calculate_descriptors(smiles)
            info['ligand_desc_time'] = time.time() - desc_start

            # Create complex PDB in a unique temporary file. A fixed shared
            # path would be overwritten and deleted by concurrent workers
            # started from predict_batch.
            fd, complex_path = tempfile.mkstemp(prefix="iscore_complex_", suffix=".pdb")
            os.close(fd)
            self.descriptor_calc.combine_protein_ligand_pdb(
                protein_pdb_path, ligand_sdf_path, complex_path
            )

            # Calculate pocket descriptors (simplified)
            pocket_start = time.time()
            pocket_desc = np.random.randn(self.pocket_features)  # Replace with actual dpocket
            info['pocket_desc_time'] = time.time() - pocket_start

            # Combine features into a single-row design matrix
            X = np.hstack([ligand_desc, pocket_desc]).reshape(1, -1)

            # Make prediction
            pred_start = time.time()
            prediction = self.model.predict(X)[0]
            info['prediction_time'] = time.time() - pred_start

            info['total_time'] = time.time() - start_time
            info['success'] = True

            return prediction, info

        except Exception as e:
            info['error'] = str(e)
            info['success'] = False
            info['total_time'] = time.time() - start_time
            return None, info

        finally:
            # Clean up the temporary complex on success and on failure alike;
            # cleanup is best-effort and must not mask the real result.
            if complex_path is not None and os.path.exists(complex_path):
                try:
                    os.remove(complex_path)
                except OSError:
                    pass

    def predict_batch(self, protein_pdb_paths: List[str], ligand_sdf_paths: List[str],
                     smiles_list: List[str] = None, n_jobs: int = 4) -> pd.DataFrame:
        """
        Predict affinities for multiple protein-ligand pairs

        Args:
            protein_pdb_paths: Protein PDB paths, one per complex.
            ligand_sdf_paths: Ligand structure paths, aligned with
                ``protein_pdb_paths``.
            smiles_list: Optional SMILES aligned with the paths; entries (or
                the whole list) may be None to derive SMILES from the ligand
                file.
            n_jobs: Number of joblib workers.

        Returns:
            DataFrame with one row per input pair, in input order, with
            columns 'index', 'prediction', 'success', 'total_time', 'error'.

        Raises:
            ValueError: If the input lists do not all have the same length.
        """
        n_pairs = len(protein_pdb_paths)
        if smiles_list is None:
            smiles_list = [None] * n_pairs

        # zip() would silently drop the unmatched tail of the longer list.
        if not (len(ligand_sdf_paths) == len(smiles_list) == n_pairs):
            raise ValueError(
                "protein_pdb_paths, ligand_sdf_paths and smiles_list must have the same "
                f"length (got {n_pairs}, {len(ligand_sdf_paths)}, {len(smiles_list)})"
            )

        def process_pair(idx, pdb_path, sdf_path, smiles):
            pred, info = self.predict_single(pdb_path, sdf_path, smiles)
            return idx, pred, info

        # Process in parallel
        results = Parallel(n_jobs=n_jobs)(
            delayed(process_pair)(idx, pdb, sdf, smi)
            for idx, (pdb, sdf, smi) in enumerate(
                tqdm(zip(protein_pdb_paths, ligand_sdf_paths, smiles_list),
                     total=n_pairs, desc="Predicting")
            )
        )

        # Organize results in input order
        predictions = [
            {
                'index': idx,
                'prediction': pred,
                'success': info['success'],
                'total_time': info.get('total_time', None),
                'error': info.get('error', None),
            }
            for idx, pred, info in sorted(results, key=lambda x: x[0])
        ]

        return pd.DataFrame(predictions)



# ================== COMPARATIVE ANALYSIS ==================

def compare_with_gnn(iscore_results: dict, gnn_results: dict = None):
    """
    Report iScore cross-validation scores and, optionally, compare with a GNN.

    The iScore mean ± std of R² and RMSE is always printed. When
    ``gnn_results`` is truthy, the same summary is printed for the GNN, a
    paired t-test is run on the two lists of per-fold R² scores, and a
    two-panel figure is shown: an R² box plot and, if both inputs carry
    'predictions', a scatter of one method's predictions against the other's.

    Args:
        iscore_results: Dict with 'r2_scores' and 'rmse_scores'
            (optionally 'predictions').
        gnn_results: Dict with the same keys for the GNN, or None.
    """
    rule = "=" * 60

    def _summarise(label, results):
        # Mean ± standard deviation of the per-fold metrics for one method
        r2, rmse = results['r2_scores'], results['rmse_scores']
        print(f"\n{label} Method:")
        print(f"  R²: {np.mean(r2):.4f} ± {np.std(r2):.4f}")
        print(f"  RMSE: {np.mean(rmse):.4f} ± {np.std(rmse):.4f}")

    print("\n" + rule)
    print("COMPARATIVE ANALYSIS: iScore vs GNN")
    print(rule)

    _summarise("iScore", iscore_results)

    # Nothing to compare against: the iScore summary is the whole report
    if not gnn_results:
        return

    _summarise("GNN", gnn_results)

    # Paired t-test on the fold-wise R² scores
    from scipy import stats

    iscore_r2 = iscore_results['r2_scores']
    gnn_r2 = gnn_results['r2_scores']
    t_stat, p_value = stats.ttest_rel(iscore_r2, gnn_r2)

    print(f"\nStatistical Comparison (paired t-test on R² scores):")
    print(f"  t-statistic: {t_stat:.4f}")
    print(f"  p-value: {p_value:.4f}")

    if not p_value < 0.05:
        verdict = "No significant difference (p ≥ 0.05)"
    elif np.mean(iscore_r2) > np.mean(gnn_r2):
        verdict = "iScore significantly better (p < 0.05)"
    else:
        verdict = "GNN significantly better (p < 0.05)"
    print(f"  Result: {verdict}")

    # Figure: R² box plot on the left, prediction agreement on the right
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    box_ax, agreement_ax = axes[0], axes[1]

    box_ax.boxplot([iscore_r2, gnn_r2],
                   labels=['iScore', 'GNN'])
    box_ax.set_ylabel('R² Score')
    box_ax.set_title('Model Performance Comparison')
    box_ax.grid(True, alpha=0.3)

    # The right-hand panel stays empty unless both methods supply predictions
    if 'predictions' in iscore_results and 'predictions' in gnn_results:
        iscore_pred = iscore_results['predictions']
        gnn_pred = gnn_results['predictions']
        agreement_ax.scatter(iscore_pred, gnn_pred,
                             alpha=0.5, s=10)
        lower = min(min(iscore_pred), min(gnn_pred))
        upper = max(max(iscore_pred), max(gnn_pred))
        agreement_ax.plot([lower, upper], [lower, upper], 'r--', lw=2)
        agreement_ax.set_xlabel('iScore Predictions')
        agreement_ax.set_ylabel('GNN Predictions')
        agreement_ax.set_title('Prediction Correlation')
        agreement_ax.grid(True, alpha=0.3)

    plt.suptitle('iScore vs GNN Comparison', fontsize=14)
    plt.tight_layout()
    plt.show()

# ================== TIMING ANALYSIS ==================

def analyze_computation_time(df: pd.DataFrame, n_samples: int = 10):
    """
    Time descriptor calculation on a random sample of rows.

    Ligand descriptor timing calls
    ``DescriptorCalculator.calculate_ligand_descriptors`` on each sampled
    row's 'smiles'. Pocket descriptor timing is simulated (random vector plus
    a 0.1 s sleep), so pocket and total figures are placeholders.

    Args:
        df: Table with a 'smiles' column.
        n_samples: Maximum number of rows to time (sampled with seed 42).

    Returns:
        Dict with per-sample timings in seconds under 'ligand_times',
        'pocket_times' and 'total_times'.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("COMPUTATION TIME ANALYSIS")
    print(rule)

    # Fixed seed so repeated runs time the same rows
    subset = df.sample(n=min(n_samples, len(df)), random_state=42)

    calculator = DescriptorCalculator()

    timings = {'ligand_times': [], 'pocket_times': [], 'total_times': []}

    print(f"\nTiming {len(subset)} samples...")

    for _, row in tqdm(subset.iterrows(), total=len(subset)):
        # Ligand descriptors
        tick = time.time()
        calculator.calculate_ligand_descriptors(row['smiles'])
        ligand_elapsed = time.time() - tick
        timings['ligand_times'].append(ligand_elapsed)

        # Pocket descriptors: stand-in for the real dpocket call
        tick = time.time()
        np.random.randn(41)
        time.sleep(0.1)  # Simulate dpocket computation
        pocket_elapsed = time.time() - tick
        timings['pocket_times'].append(pocket_elapsed)

        timings['total_times'].append(ligand_elapsed + pocket_elapsed)

    series = [timings['ligand_times'], timings['pocket_times'], timings['total_times']]

    print("\nTiming Results:")
    report_labels = ("Ligand descriptors", "Pocket descriptors", "Total per complex")
    for label, values in zip(report_labels, series):
        print(f"{label}: {np.mean(values):.4f} ± {np.std(values):.4f}s")

    # Extrapolate the mean per-complex time to the full dataset
    estimated_hours = (len(df) * np.mean(timings['total_times'])) / 3600
    print(f"\nEstimated time for {len(df)} samples: {estimated_hours:.2f} hours")

    # Box plot of the three timing series
    fig, ax = plt.subplots(1, 1, figsize=(8, 6))

    ax.boxplot(series,
               positions=[1, 2, 3],
               labels=['Ligand', 'Pocket', 'Total'])

    ax.set_ylabel('Time (seconds)')
    ax.set_title('Computation Time Distribution')
    ax.grid(True, alpha=0.3)

    # Annotate each box with its mean, placed just above the largest value
    for position, values in enumerate(series, 1):
        ax.text(position, max(values) * 1.05, f'μ={np.mean(values):.3f}s',
                ha='center', fontsize=10)

    plt.tight_layout()
    plt.show()

    return timings

# ================== MULTI-TASK LEARNING EXTENSION ==================

class MTL_iScoreModel(iScoreModel):
    """
    Multi-task learning version of iScore model

    Keeps one independent scaler and regressor per task; tasks share the
    feature matrix but are fitted only on rows with a non-NaN target.

    NOTE(review): ``iScoreModel.__init__`` is deliberately not called because
    its signature is not visible here; inherited methods other than ``fit``
    and ``predict`` (e.g. save/load) may not work with per-task models —
    confirm before relying on them.
    """

    def __init__(self, task_names: List[str], model_type: str = 'xgboost', **kwargs):
        """Create an unfitted scaler and regressor for every task.

        Args:
            task_names: Names of the prediction tasks.
            model_type: 'xgboost' or 'rf'.
            **kwargs: Forwarded to each regressor's constructor.

        Raises:
            ValueError: If ``model_type`` is not supported.
        """
        # Fail early: an unknown type used to create no models at all and
        # only surfaced later as a KeyError inside fit().
        if model_type not in ('xgboost', 'rf'):
            raise ValueError(
                f"Unsupported model_type {model_type!r}; expected 'xgboost' or 'rf'"
            )

        self.task_names = task_names
        self.models = {}
        self.scalers = {}

        for task in task_names:
            self.scalers[task] = StandardScaler()
            if model_type == 'xgboost':
                self.models[task] = xgb.XGBRegressor(**kwargs)
            else:
                self.models[task] = RandomForestRegressor(**kwargs)

    def fit(self, X: np.ndarray, y_dict: dict):
        """Fit models for all tasks

        Args:
            X: Feature matrix shared by all tasks.
            y_dict: Maps task name to target values aligned with the rows of
                ``X``; NaN marks a missing label. Tasks absent from the dict
                or without any labelled row are left unfitted.
        """
        X = np.asarray(X)
        for task in self.task_names:
            if task not in y_dict:
                continue

            # Coerce to a float array so NaN masking and boolean indexing work
            # for lists and pandas Series as well as ndarrays.
            y_all = np.asarray(y_dict[task], dtype=float)

            # Remove samples with NaN for this task
            mask = ~np.isnan(y_all)
            if not mask.any():
                print(f"  Skipped {task}: no labelled samples")
                continue

            X_task = X[mask]
            y_task = y_all[mask]

            X_scaled = self.scalers[task].fit_transform(X_task)
            self.models[task].fit(X_scaled, y_task)
            print(f"  Trained {task} on {len(X_task)} samples")

    def predict(self, X: np.ndarray) -> dict:
        """Predict for all tasks

        Returns:
            Dict mapping each fitted task to its predictions for ``X``.
            Tasks that were never fitted are omitted.
        """
        predictions = {}
        for task in self.task_names:
            if task not in self.models:
                continue
            scaler = self.scalers[task]
            # StandardScaler only gains ``scale_`` once fit; skip tasks that
            # fit() left untrained instead of failing on transform().
            if not hasattr(scaler, "scale_"):
                continue
            X_scaled = scaler.transform(X)
            predictions[task] = self.models[task].predict(X_scaled)
        return predictions

# ================== MAIN ENTRY POINT ==================
# ================== MAIN PIPELINE ==================

def run_complete_pipeline(df: pd.DataFrame, target_col: str = 'pKi',
                         n_splits: int = 5, model_type: str = 'xgboost',
                         save_dir: str = 'iscore_results'):
    """
    End-to-end iScore run: features, cross-validation, plots, final model.

    Steps: build descriptors with a cache under ``save_dir``, drop rows with
    a missing target, cross-validate, save the CV figure, train on all
    remaining rows, and write a JSON summary of the CV metrics.

    Args:
        df: Input table passed to ``prepare_iscore_dataset``.
        target_col: Column holding the regression target.
        n_splits: Number of cross-validation folds.
        model_type: Model type forwarded to CV and final training.
        save_dir: Output directory (created if missing) for the descriptor
            cache, 'cv_results.png', 'iscore_model.pkl' and 'summary.json'.

    Returns:
        Tuple of (feature table without missing targets, fitted full model,
        cross-validation results dict).
    """
    rule = "=" * 60
    out_dir = Path(save_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    print(rule)
    print("iScore METHOD - COMPLETE PIPELINE")
    print(rule)

    # Step 1: descriptor calculator with an on-disk cache
    print("\n1. Initializing descriptor calculator...")
    calculator = DescriptorCalculator(cache_dir=out_dir / 'descriptors_cache')

    # Step 2: feature table
    print("\n2. Preparing dataset with descriptors...")
    df_features, feature_cols = prepare_iscore_dataset(df, calculator, use_cache=True)

    # Rows without a target value cannot be used for training
    df_clean = df_features.dropna(subset=[target_col])
    n_rows, n_features = len(df_clean), len(feature_cols)
    print(f"Final dataset: {n_rows} samples with {n_features} features")

    # Step 3: cross-validation
    print("\n3. Running cross-validation...")
    cv_results = cross_validate_iscore(
        df_clean,
        feature_cols,
        target_col=target_col,
        n_splits=n_splits,
        model_type=model_type,
    )

    # Step 4: diagnostic figure
    print("\n4. Creating visualizations...")
    plot_cv_results(cv_results, save_path=out_dir / 'cv_results.png')

    # Step 5: final model on every cleaned row
    print("\n5. Training model on full dataset...")
    full_model = train_full_model(
        df_clean,
        feature_cols,
        target_col=target_col,
        model_type=model_type,
        save_path=str(out_dir / 'iscore_model.pkl'),
    )

    # Step 6: machine-readable summary of the run
    r2_scores = cv_results['r2_scores']
    rmse_scores = cv_results['rmse_scores']
    summary = {
        'dataset_size': n_rows,
        'n_features': n_features,
        'cv_r2_mean': np.mean(r2_scores),
        'cv_r2_std': np.std(r2_scores),
        'cv_rmse_mean': np.mean(rmse_scores),
        'cv_rmse_std': np.std(rmse_scores),
        'model_type': model_type,
        'target': target_col
    }

    with open(out_dir / 'summary.json', 'w') as summary_file:
        json.dump(summary, summary_file, indent=2)

    print("\n" + rule)
    print("PIPELINE COMPLETED SUCCESSFULLY")
    print(rule)
    print(f"Results saved to: {out_dir}")

    return df_clean, full_model, cv_results

# ================== EXAMPLE USAGE ==================

def example_usage():
    """
    Demonstrate the iScore workflow on a ten-row subset of the input table.

    Loads the standardized parquet file, runs the complete pipeline into
    'iscore_results', then uses the saved model for one single-complex
    prediction and one batch prediction, printing metrics along the way.
    """
    rule = "=" * 60
    required_cols = ['standardized_protein_pdb', 'standardized_ligand_sdf', 'smiles', 'pKi']

    print("\n" + rule)
    print("iScore METHOD - EXAMPLE USAGE")
    print(rule)

    # Step 1: load the data and keep complete rows only
    print("\n1. Loading data...")
    df = pd.read_parquet("../data/standardized/standardized_input.parquet")
    df = df.dropna(subset=required_cols)
    df = df.head(10)  # Use subset for testing
    print(f"Loaded {len(df)} samples")

    # Step 2: features, cross-validation and final model
    print("\n2. Running complete pipeline...")
    df_features, model, cv_results = run_complete_pipeline(
        df,
        target_col='pKi',
        n_splits=5,
        model_type='xgboost',
        save_dir='iscore_results'
    )

    # Step 3: reload the saved model through the inference wrapper
    print("\n3. Example inference on new data...")
    predictor = iScorePredictor(model_path='iscore_results/iscore_model.pkl')

    # Single prediction on the first row
    test_idx = 0
    sample = df.iloc[test_idx]

    print(f"\nPredicting for sample {test_idx}...")
    prediction, info = predictor.predict_single(
        protein_pdb_path=sample['standardized_protein_pdb'],
        ligand_sdf_path=sample['standardized_ligand_sdf'],
        smiles=sample['smiles']
    )

    if not info['success']:
        print(f"Prediction failed: {info['error']}")
    else:
        print(f"Predicted pKi: {prediction:.3f}")
        print(f"Actual pKi: {sample['pKi']:.3f}")
        print(f"Total time: {info['total_time']:.3f}s")
        print(f"  - Ligand descriptors: {info['ligand_desc_time']:.3f}s")
        print(f"  - Pocket descriptors: {info['pocket_desc_time']:.3f}s")
        print(f"  - Prediction: {info['prediction_time']:.3f}s")

    # Step 4: batch prediction over the same subset
    print("\n4. Batch prediction example...")
    test_df = df.head(10)

    predictions_df = predictor.predict_batch(
        protein_pdb_paths=test_df['standardized_protein_pdb'].tolist(),
        ligand_sdf_paths=test_df['standardized_ligand_sdf'].tolist(),
        smiles_list=test_df['smiles'].tolist(),
        n_jobs=4
    )

    # Score the batch only when at least one complex succeeded
    succeeded = predictions_df[predictions_df['success']]
    if len(succeeded) > 0:
        predictions_df['actual'] = test_df['pKi'].values
        scored = predictions_df.dropna(subset=['prediction', 'actual'])

        if len(scored) > 0:
            actual, predicted = scored['actual'], scored['prediction']
            r2 = r2_score(actual, predicted)
            rmse = np.sqrt(mean_squared_error(actual, predicted))

            print(f"\nBatch prediction results:")
            print(f"  Successful: {len(succeeded)}/{len(test_df)}")
            print(f"  R²: {r2:.3f}")
            print(f"  RMSE: {rmse:.3f}")
            print(f"  Avg time: {succeeded['total_time'].mean():.3f}s per complex")

    print("\n" + rule)
    print("EXAMPLE COMPLETED")
    print(rule)

# Input training data

In [None]:
# Load data
# Reads the standardized input table and sets the run configuration that the
# following cells read as globals.
df = pd.read_parquet("../data/standardized/standardized_input.parquet")
target_cols = ["pEC50"]  # target column(s); later cells loop over this list
n_splits = 5  # number of folds passed to cross_validate_iscore
model_type = 'xgboost'
save_dir= 'iscore_results_pEC50'  # output directory; converted to a Path in the next cell
# Keep rows that have both structure paths, a SMILES and every target value,
# then take only the first 30 rows (small subset — presumably for a quick run).
df = df.dropna(subset=['standardized_protein_pdb', 'standardized_ligand_sdf', 'smiles'] + target_cols)[:30]
df  # display the filtered table as the cell output

# Prepare features

In [None]:
"""
Run complete iScore pipeline: preparation, CV, training, and evaluation
"""
# NOTE: this cell only performs the preparation step (output directory,
# descriptor calculator, feature table) and times it; CV and training happen
# in later cells.

# From here on save_dir is a Path, so `save_dir / name` works in later cells.
save_dir = Path(save_dir)
save_dir.mkdir(parents=True, exist_ok=True)

# 1. Initialize descriptor calculator
print("\n1. Initializing descriptor calculator...")
# NOTE(review): the cache directory here is 'descriptors_cache2', while the
# inference cell uses 'descriptors_cache' — confirm this is intentional.
descriptor_calc = DescriptorCalculator(cache_dir=save_dir / 'descriptors_cache2')

# 2. Prepare dataset
print("\n2. Preparing dataset with descriptors...")

# Wall-clock timing of the feature preparation step
import time
start_time = time.time()

df_features, feature_cols = prepare_iscore_dataset(df, descriptor_calc, use_cache=True)

print("--- %s seconds ---" % (time.time() - start_time))


In [None]:
# Scratch calculation: presumably the elapsed seconds printed by the timing
# cell above divided by the 30 rows kept when loading, i.e. seconds per
# complex — TODO confirm.
169.0369691848755 /30

# Full train

In [None]:
# Hard-coded feature column list; this overwrites the `feature_cols` returned
# by prepare_iscore_dataset in the earlier cell.
# Layout: 81 ligand descriptor names ('MolLogP' ... 'VSA_EState10', the same
# names as the RDKit descriptor list at the top of the file), then 21 pocket
# columns ('lig_vol' ... 'n_abpa'), then the 20 amino-acid three-letter codes
# (presumably per-residue pocket composition — confirm against the pocket
# descriptor output).
feature_cols = ['MolLogP','MolMR','ExactMolWt','HeavyAtomCount','NumHAcceptors','NumHDonors','NumHeteroatoms','NumRotatableBonds','NumAromaticRings','NumAliphaticRings','RingCount','TPSA','LabuteASA','Kappa1','Kappa2','Kappa3','Chi0','Chi1','Chi0n','Chi1n',
 'Chi2n','Chi3n','Chi4n','Chi0v','Chi1v','Chi2v','Chi3v','Chi4v','PEOE_VSA1','PEOE_VSA2','PEOE_VSA3','PEOE_VSA4','PEOE_VSA5','PEOE_VSA6','PEOE_VSA7','PEOE_VSA8',
 'PEOE_VSA9','PEOE_VSA10','PEOE_VSA11','PEOE_VSA12','PEOE_VSA13','PEOE_VSA14','SMR_VSA1','SMR_VSA3','SMR_VSA4','SMR_VSA5','SMR_VSA6','SMR_VSA7','SMR_VSA9','SMR_VSA10','SlogP_VSA1',
 'SlogP_VSA2','SlogP_VSA3','SlogP_VSA4','SlogP_VSA5','SlogP_VSA6','SlogP_VSA7','SlogP_VSA8','SlogP_VSA10','SlogP_VSA11','SlogP_VSA12','EState_VSA1','EState_VSA2',
 'EState_VSA3','EState_VSA4','EState_VSA5','EState_VSA6','EState_VSA7','EState_VSA8','EState_VSA9','EState_VSA10','VSA_EState1','VSA_EState2','VSA_EState3','VSA_EState4','VSA_EState5','VSA_EState6','VSA_EState7',
 'VSA_EState8','VSA_EState9','VSA_EState10','lig_vol','pock_vol','nb_AS','mean_as_ray','mean_as_solv_acc','apol_as_prop','mean_loc_hyd_dens','hydrophobicity_score','volume_score','polarity_score',
 'charge_score','flex','prop_polar_atm','as_density','as_max_dst','convex_hull_volume','surf_pol_vdw14','surf_pol_vdw22','surf_apol_vdw14','surf_apol_vdw22','n_abpa',
 'ALA','ARG','ASN','ASP','CYS','GLN','GLU','GLY','HIS','ILE','LEU','LYS','MET','PHE','PRO','SER','THR','TRP','TYR','VAL']

In [None]:
# For every target column: cross-validate, plot, train a final model on all
# rows and write a JSON summary. Output file names carry the target name so
# several targets can share one save_dir. Uses the globals df_features,
# feature_cols, n_splits, model_type and save_dir set by earlier cells.
for target_col in target_cols:
    # Remove samples with NaN in target

    df_clean = df_features.dropna(subset=[target_col])
    print(f"Final dataset: {len(df_clean)} samples with {len(feature_cols)} features")
    
    # 3. Cross-validation
    print("\n3. Running cross-validation...")
    cv_results = cross_validate_iscore(
        df_clean, feature_cols, target_col=target_col,
        n_splits=n_splits, model_type=model_type
    )
    
    # 4. Visualize CV results
    print("\n4. Creating visualizations...")
    plot_cv_results(cv_results, save_path=save_dir / f'cv_results_{target_col}.png')


    # 5. Train full model
    print("\n5. Training model on full dataset...")
    model_path = save_dir / f'iscore_model_{target_col}.pkl'
    full_model = train_full_model(
        df_clean, feature_cols, target_col=target_col,
        model_type=model_type, save_path=str(model_path)
    )

    # 6. Save results summary
    # Mean and standard deviation of the per-fold CV metrics plus run metadata
    summary = {
        'dataset_size': len(df_clean),
        'n_features': len(feature_cols),
        'cv_r2_mean': np.mean(cv_results['r2_scores']),
        'cv_r2_std': np.std(cv_results['r2_scores']),
        'cv_rmse_mean': np.mean(cv_results['rmse_scores']),
        'cv_rmse_std': np.std(cv_results['rmse_scores']),
        'model_type': model_type,
        'target': target_col
    }

    with open(save_dir / f'summary_{target_col}.json', 'w') as f:
        json.dump(summary, f, indent=2)


# Inference

In [None]:
# Inference on the first 10 rows of df using the models saved by the training
# cell; adds one '<target>_pred' column per target to df_features.
descriptor_calc = DescriptorCalculator(cache_dir=save_dir / "descriptors_cache")
df_test = df[:10].copy()

# Build features for the whole df_test ONCE (shared by every target below)
# NOTE(review): this rebinds feature_cols to whatever prepare_iscore_dataset
# returns, whereas training used the hard-coded list from the earlier cell —
# confirm both have the same columns in the same order.
df_features, feature_cols = prepare_iscore_dataset(df_test, descriptor_calc, use_cache=True)

# Prepare data
X = df_features[feature_cols].values

for target_col in target_cols:
    # No rows are dropped here: a prediction is made for every row of df_test.

    # --- Setup (once per target) ---
    model_path = save_dir / f"iscore_model_{target_col}.pkl"


    # Load the saved model for this target
    model = iScoreModel()
    model.load(model_path)  # assuming this mutates in place

    # Predict
    y_pred = model.predict(X)

    # Attach predictions for easy per-row access
    # Copy first so the new column is added to a fresh frame, not to the
    # object returned by prepare_iscore_dataset.
    df_features = df_features.copy()
    df_features[f"{target_col}_pred"] = y_pred


In [None]:
# Scratch cell: a bare expression with no effect on the analysis.
1