# Data Preparation and Standardization for MTL-GNN-DTA

This notebook provides a complete pipeline for:
1. Loading and combining multiple affinity datasets
2. Standardizing protein and ligand structures
3. Computing molecular properties and ligand efficiency
4. Quality filtering and validation
5. Preparing data for model training

---

## 1. Setup and Imports

In [3]:
import os
import sys
import warnings
warnings.filterwarnings('ignore')

# Add parent directory to path for imports
sys.path.append('../../')

# Standard imports
import pandas as pd
import numpy as np
from pathlib import Path
import json
import pickle
from tqdm import tqdm
from multiprocessing import Pool, cpu_count
from sklearn.model_selection import train_test_split

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# Chemistry
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors, Crippen, Lipinski, rdMolDescriptors, QED, Draw
from rdkit.Chem.rdPartialCharges import ComputeGasteigerCharges
from rdkit.Chem.MolStandardize import rdMolStandardize
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.warning')

# Protein processing
from Bio.PDB import PDBParser, PDBIO
from pdbfixer import PDBFixer
from openmm.app import PDBFile

# MTL-GNN-DTA imports
from mtl_gnn_dta.preprocessing import pdb_processor, sdf_processor, validator
from mtl_gnn_dta.utils import setup_logging

# Setup
setup_logging()

# Leave one core free for the notebook itself, but never go below one worker:
# on a single-core machine `cpu_count() - 1` evaluates to 0, which both
# multiprocessing.Pool and joblib.Parallel reject with a ValueError.
N_PROC = max(1, cpu_count() - 1)

# Set style for plots
sns.set_style("whitegrid")
plt.rcParams['figure.figsize'] = (10, 6)

print("MTL-GNN-DTA Data Preparation Pipeline")
print("=" * 50)
print(f"Using {N_PROC} CPU cores for parallel processing")

ModuleNotFoundError: No module named 'pdbfixer'

In [2]:
pip install pdbfixer

ERROR: Could not find a version that satisfies the requirement pdbfixer (from versions: none)
ERROR: No matching distribution found for pdbfixer
Note: you may need to restart the kernel to use updated packages.


## 2. Define Helper Functions for Data Curation

In [None]:
# Define constants
POLAR_HEAVY = {7, 8, 15, 16}  # N, O, P, S

def dg_to_kd(delta_g_kcal, temp_k=298.15):
    """Convert a binding free energy ΔG into pKd.

    Despite the function name, the returned value is pKd = -log10(Kd), not Kd
    itself. With ΔG = RT·ln(Kd) this is pKd = -ΔG / (RT·ln 10).

    Args:
        delta_g_kcal: Binding free energy in kcal/mol.
        temp_k: Temperature in kelvin.

    Returns:
        pKd as a float.
    """
    import math
    R = 1.987e-3  # kcal/mol·K
    # Closed form of -log10(exp(ΔG / RT)). Evaluating exp() first underflows to
    # 0.0 for very negative ΔG, and log10(0.0) then raises ValueError.
    return -delta_g_kcal / (R * temp_k * math.log(10))

def standardize_smiles_from_sdf(sdf_path):
    """Read a ligand structure file and return a standardized canonical SMILES.

    The molecule is loaded unsanitized with its hydrogens kept, sanitized
    (falling back to ``rdMolStandardize.Cleanup`` when sanitization reports a
    failure), stripped of hydrogens bonded to non-N/O/P/S atoms, given explicit
    hydrogens on N/O/P/S atoms, and then passed through the rdMolStandardize
    steps (cleanup, normalize, fragment parent, canonical tautomer) with
    isotope labels cleared.

    Args:
        sdf_path: Path handed to ``Chem.MolFromMolFile``.

    Returns:
        Isomeric canonical SMILES string, or ``None`` when the file cannot be
        parsed, cleanup yields nothing, or any step raises (the exception is
        printed, not re-raised).
    """
    POLAR = {7, 8, 15, 16}  # N, O, P, S
    
    try:
        # Load without sanitizing so problematic structures can still be inspected/repaired
        mol = Chem.MolFromMolFile(sdf_path, removeHs=False, sanitize=False)
        if mol is None:
            return None
        
        # Sanitize
        # catchErrors=True makes SanitizeMol return a flag instead of raising;
        # anything other than SANITIZE_NONE triggers the Cleanup fallback.
        sanitize_result = Chem.SanitizeMol(mol, catchErrors=True)
        if sanitize_result != Chem.SanitizeFlags.SANITIZE_NONE:
            mol_cleaned = rdMolStandardize.Cleanup(mol)
            if mol_cleaned is None:
                return None
            mol = mol_cleaned
        
        # Remove non-polar hydrogens
        # Only the first neighbour is checked; hydrogens with no neighbours are kept.
        to_del = []
        for a in mol.GetAtoms():
            if a.GetAtomicNum() == 1:
                nbs = a.GetNeighbors()
                if len(nbs) > 0 and nbs[0].GetAtomicNum() not in POLAR:
                    to_del.append(a.GetIdx())
        
        if to_del:
            em = Chem.EditableMol(mol)
            # Delete from the highest index down so earlier removals do not shift later indices
            for idx in sorted(to_del, reverse=True):
                em.RemoveAtom(idx)
            mol = em.GetMol()
        
        # Update and standardize
        mol.UpdatePropertyCache(strict=False)
        AllChem.AssignAtomChiralTagsFromStructure(mol, replaceExistingTags=False)
        Chem.AssignStereochemistry(mol, force=True, cleanIt=False)
        
        # Add polar hydrogens
        # Collect N/O/P/S atoms that still carry implicit hydrogens
        targets = []
        for a in mol.GetAtoms():
            if a.GetAtomicNum() in POLAR and a.GetNumImplicitHs() > 0:
                targets.append(a.GetIdx())
        
        if targets:
            mol = Chem.AddHs(mol, addCoords=False, onlyOnAtoms=targets)
        
        # Final standardization
        mol = rdMolStandardize.Cleanup(mol)
        mol = rdMolStandardize.Normalizer().normalize(mol)
        mol = rdMolStandardize.FragmentParent(mol)
        mol = rdMolStandardize.TautomerEnumerator().Canonicalize(mol)
        
        # Clear isotopes
        for atom in mol.GetAtoms():
            atom.SetIsotope(0)
        
        # Recompute stereo labels on the final molecule before writing SMILES
        Chem.AssignStereochemistry(mol, force=True, cleanIt=True)
        return Chem.MolToSmiles(mol, isomericSmiles=True, canonical=True)
        
    except Exception as e:
        # Best-effort: report the failure and let the caller drop this ligand
        print(f"Error standardizing {sdf_path}: {e}")
        return None

def compute_props(smiles):
    """Compute RDKit descriptors for one SMILES string.

    Args:
        smiles: SMILES string. Non-strings, blank strings and strings RDKit
            cannot parse give a dictionary whose values are all ``None``.

    Returns:
        Dict with keys InChIKey, MolWt, HeavyAtomCount, QED, NumHDonors,
        NumHAcceptors, NumRotatableBonds, TPSA and LogP (in that order).
    """
    prop_names = ('InChIKey', 'MolWt', 'HeavyAtomCount', 'QED',
                  'NumHDonors', 'NumHAcceptors', 'NumRotatableBonds',
                  'TPSA', 'LogP')

    # Only attempt to parse usable, non-blank strings
    mol = None
    if isinstance(smiles, str) and smiles.strip() != '':
        mol = Chem.MolFromSmiles(smiles)

    if mol is None:
        return dict.fromkeys(prop_names)

    # Values are listed in the same order as prop_names
    values = (
        Chem.MolToInchiKey(mol),
        Descriptors.MolWt(mol),
        mol.GetNumHeavyAtoms(),
        QED.qed(mol),
        Lipinski.NumHDonors(mol),
        Lipinski.NumHAcceptors(mol),
        Lipinski.NumRotatableBonds(mol),
        rdMolDescriptors.CalcTPSA(mol),
        Crippen.MolLogP(mol),
    )
    return dict(zip(prop_names, values))

## 3. Load and Combine Multiple Datasets

In [None]:
# Define data directories
data_dir = Path("../data/curated/exp/")
data_dir.mkdir(parents=True, exist_ok=True)

# Load all available datasets
df_list = []
dataset_info = []

# List of expected dataset files
dataset_files = [
    "pKd_FEP_Wang_2015.parquet",
    "pKd_FEP_Zariquiey_extended_Wang_2015.parquet",
    "pKi_PDBbind2020.parquet",
    "pKd_PDBbind2020.parquet",
    "pKi_HiQBind.parquet",
    "pKd_HiQBind.parquet",
    "pIC50_HiQBind.parquet",
    "pEC50_HiQBind.parquet",
    "pKi_BioLip2.parquet",
    "pKd_BioLip2.parquet",
    "pIC50_BioLip2.parquet",
    "pEC50_BioLip2.parquet",
    "pKi_BindingNetv1.parquet",
    "pKd_BindingNetv1.parquet",
    "pIC50_BindingNetv1.parquet",
    "pEC50_BindingNetv1.parquet",
    "pKi_BindingNetv2.parquet",
    "pKd_BindingNetv2.parquet",
    "pIC50_BindingNetv2.parquet",
    "pEC50_BindingNetv2.parquet"
]

# A file is flagged experimental when its name contains one of these tags;
# every other file is flagged non-experimental.
EXPERIMENTAL_SOURCES = ("BioLip", "PDBbind", "HiQBind")

print("Loading datasets...")
for fname in dataset_files:
    full_path = data_dir / fname
    if not full_path.exists():
        continue

    try:
        df = pd.read_parquet(full_path)
    except Exception as e:
        print(f"  ✗ Error loading {fname}: {e}")
        continue

    # Derive the flag from the file name once. Reading it back with .iloc[0]
    # fails on an empty file after the frame has already been appended.
    is_experimental = any(tag in fname for tag in EXPERIMENTAL_SOURCES)
    df["source_file"] = fname
    df["is_experimental"] = is_experimental
    df_list.append(df)

    dataset_info.append({
        'file': fname,
        'samples': len(df),
        'experimental': is_experimental
    })
    print(f"  ✓ {fname}: {len(df)} samples")

# Combine all datasets
if df_list:
    df_combined = pd.concat(df_list, ignore_index=True)
    print(f"\nTotal samples loaded: {len(df_combined):,}")
else:
    print("No datasets found. Creating sample data...")
    # Create sample data for demonstration. A seeded generator keeps the demo
    # reproducible across Restart & Run All.
    n_demo = 1000
    rng = np.random.default_rng(42)
    df_combined = pd.DataFrame({
        'protein_pdb_path': ['data/protein_001.pdb'] * n_demo,
        'ligand_sdf_path': ['data/ligand_001.sdf'] * n_demo,
        'smiles': ['CC(C)CC1=CC=C(C=C1)C(C)C(O)=O'] * n_demo,
        'pKi': rng.normal(7.0, 1.5, n_demo),
        'pKd': rng.normal(7.2, 1.3, n_demo),
        'pIC50': rng.normal(6.8, 1.4, n_demo),
        'pEC50': rng.normal(6.5, 1.2, n_demo),
        'resolution': rng.uniform(1.5, 2.5, n_demo),
        'source_file': 'sample_data',
        'is_experimental': True
    })

# Clean up source file names: strip the leading task prefix (pKi_/pKd_/pIC50_/pEC50_)
# and the .parquet suffix. Anchoring on the task names leaves "sample_data" intact;
# the generic "everything up to the first underscore" pattern turned it into "data".
df_combined["source_file"] = (
    df_combined["source_file"]
    .str.replace(r"^p(?:Ki|Kd|IC50|EC50)_", "", regex=True)
    .str.replace(".parquet", "", regex=False)
)

# Replace infinite values
df_combined = df_combined.replace([np.inf, -np.inf], np.nan)

print(f"\nDataset shape: {df_combined.shape}")
print(f"Columns: {df_combined.columns.tolist()}")

## 4. Data Quality Analysis

In [None]:
# Visualize data distribution
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Dataset distribution
ax = axes[0, 0]
source_counts = df_combined["source_file"].value_counts()
source_counts.head(10).plot(kind='barh', ax=ax)
ax.set_xlabel("Count")
ax.set_title("Top 10 Data Sources")
ax.set_xscale('log')

# Experimental vs Computational
ax = axes[0, 1]
# Label each bar from its own value. value_counts() orders by frequency and
# may return a single class, so a fixed ['Computational', 'Experimental'] list
# can mislabel the bars or fail when the tick count does not match.
exp_counts = (
    df_combined["is_experimental"]
    .value_counts()
    .rename(index={True: 'Experimental', False: 'Computational'})
)
exp_counts.plot(kind='bar', ax=ax, rot=0)
ax.set_ylabel("Count")
ax.set_title("Experimental vs Computational Data")
ax.set_yscale('log')

# Resolution distribution (if available)
ax = axes[1, 0]
has_resolution = False
if 'resolution' in df_combined.columns:
    resolution_data = df_combined['resolution'].dropna()
    if len(resolution_data) > 0:
        has_resolution = True
        resolution_data.hist(bins=50, ax=ax)
        ax.axvline(2.5, color='red', linestyle='--', label='2.5Å cutoff')
        ax.set_xlabel("Resolution (Å)")
        ax.set_ylabel("Count")
        ax.set_title("Crystal Structure Resolution")
        ax.legend()
if not has_resolution:
    # Hide the panel rather than leaving an empty, unlabeled axes in the figure
    ax.set_visible(False)

# Missing data analysis
ax = axes[1, 1]
task_cols = ['pKi', 'pKd', 'pIC50', 'pEC50']
missing_data = []
for col in task_cols:
    if col in df_combined.columns:
        missing_pct = df_combined[col].isna().sum() / len(df_combined) * 100
        missing_data.append(missing_pct)
    else:
        # A task column that is absent altogether counts as fully missing
        missing_data.append(100)

ax.bar(task_cols, missing_data)
ax.set_ylabel("Missing (%)")
ax.set_title("Missing Data by Task")
ax.set_ylim(0, 100)

plt.tight_layout()
plt.show()

## 5. Activity Value Distribution Analysis

In [None]:
# Analyze activity value distributions
task_cols = ['pKi', 'pKd', 'pIC50', 'pEC50', 'pKd (Wang, FEP)']
available_tasks = [name for name in task_cols if name in df_combined.columns]

if available_tasks:
    fig, axes = plt.subplots(2, 3, figsize=(15, 8))
    axes = axes.flatten()

    # One panel per task; zip stops at the six available panels
    for panel, task in zip(axes, available_tasks):
        values = df_combined[task].dropna()
        if len(values) == 0:
            continue

        # Histogram of the observed values
        values.hist(bins=50, ax=panel, alpha=0.7, edgecolor='black')

        # Mark the central statistics
        mean_val = values.mean()
        median_val = values.median()
        panel.axvline(mean_val, color='red', linestyle='--', label=f'Mean: {mean_val:.2f}')
        panel.axvline(median_val, color='green', linestyle='--', label=f'Median: {median_val:.2f}')

        panel.set_xlabel(task)
        panel.set_ylabel('Count')
        panel.set_title(f'{task} (n={len(values):,})')
        panel.legend(fontsize=8)

    # Hide the panels that have no task assigned
    for panel in axes[len(available_tasks):]:
        panel.set_visible(False)

    plt.suptitle('Activity Value Distributions', fontsize=14)
    plt.tight_layout()
    plt.show()

    # Print statistics
    print("\nActivity Value Statistics:")
    print("="*60)
    for task in available_tasks:
        values = df_combined[task].dropna()
        if len(values) == 0:
            continue
        print(f"{task:20s}: n={len(values):6d}, mean={values.mean():6.2f}, "
              f"std={values.std():5.2f}, min={values.min():5.2f}, max={values.max():5.2f}")

## 6. Initial Data Filtering

In [None]:
print("Initial data filtering...")
print(f"Starting samples: {len(df_combined):,}")

# Filter by activity values (remove outliers): keep rows whose value is missing
# or lies strictly between 3 and 15
for col in available_tasks:
    if col not in df_combined.columns:
        continue
    n_before = len(df_combined)
    values = df_combined[col]
    in_range = (values > 3) & (values < 15)
    df_combined = df_combined[values.isna() | in_range]
    n_dropped = n_before - len(df_combined)
    if n_dropped != 0:
        print(f"  Filtered {n_dropped} samples with {col} outliers")

# Filter by resolution: keep rows whose resolution is missing or strictly
# between 0 and 3
if 'resolution' in df_combined.columns:
    n_before = len(df_combined)
    resolution = df_combined['resolution']
    acceptable = (resolution > 0) & (resolution < 3)
    df_combined = df_combined[resolution.isna() | acceptable]
    n_dropped = n_before - len(df_combined)
    if n_dropped != 0:
        print(f"  Filtered {n_dropped} samples with poor resolution")

print(f"Samples after filtering: {len(df_combined):,}")

## 7. Standardize Protein Structures

In [None]:
# Output locations for the standardized structures (created if missing)
protein_out_dir = Path("../data/standardized_clean/protein")
ligand_out_dir = Path("../data/standardized_clean/ligand")
for out_dir in (protein_out_dir, ligand_out_dir):
    out_dir.mkdir(parents=True, exist_ok=True)

# Demonstration subset of at most 100 rows, drawn with a fixed seed
# (use the full df_combined instead for full processing)
n_demo = min(100, len(df_combined))
df_sample = df_combined.sample(n_demo, random_state=42)
df_sample = df_sample.reset_index(drop=True)

print(f"Processing {len(df_sample)} samples...")

In [None]:
# Standardize proteins (if paths exist)
if 'protein_pdb_path' in df_sample.columns:
    from mtl_gnn_dta.preprocessing.pdb_processor import clean_protein_structure

    def process_protein(args):
        """Clean one protein structure and return the output path.

        Args:
            args: ``(idx, row)`` pair where ``row['protein_pdb_path']`` is the
                input PDB path and ``idx`` names the output file.

        Returns:
            Output path as a string, or ``None`` when the input is missing,
            cleaning reports failure, or an exception occurs (it is printed).
        """
        idx, row = args
        try:
            in_path = row['protein_pdb_path']
            out_path = protein_out_dir / f"{idx}.pdb"

            # Skip missing entries (None/NaN) instead of handing them to os.path.exists
            if isinstance(in_path, (str, os.PathLike)) and os.path.exists(in_path):
                # presumably returns a truthy value on success — confirm against pdb_processor
                success = clean_protein_structure(in_path, str(out_path))
                if success:
                    return str(out_path)
        except Exception as e:
            print(f"Error processing protein {idx}: {e}")
        return None

    print("Standardizing protein structures...")
    # Ship only the column the worker reads, not the whole row, to each process
    protein_args = list(df_sample[['protein_pdb_path']].iterrows())

    # Pool requires at least one worker; N_PROC may be 0 on a single-core machine.
    # NOTE(review): a function defined in a notebook cell may not be importable by
    # workers under the 'spawn' start method — confirm on macOS/Windows.
    n_workers = max(1, min(4, N_PROC))
    with Pool(n_workers) as pool:
        protein_paths = list(tqdm(
            pool.imap(process_protein, protein_args),
            total=len(protein_args),
            desc="Proteins"
        ))

    df_sample['standardized_protein_pdb'] = protein_paths
else:
    df_sample['standardized_protein_pdb'] = None
    print("No protein paths found")

## 8. Standardize Ligand Structures

In [None]:
# Standardize ligands (if paths exist)
if 'ligand_sdf_path' in df_sample.columns:
    from mtl_gnn_dta.preprocessing.sdf_processor import standardize_ligand

    def process_ligand(args):
        """Standardize one ligand structure and return the output path.

        Args:
            args: ``(idx, row)`` pair where ``row['ligand_sdf_path']`` is the
                input SDF path and ``idx`` names the output file.

        Returns:
            Output path as a string, or ``None`` when the input is missing,
            standardization reports failure, or an exception occurs (it is printed).
        """
        idx, row = args
        try:
            in_path = row['ligand_sdf_path']
            out_path = ligand_out_dir / f"{idx}.sdf"

            # Skip missing entries (None/NaN) instead of handing them to os.path.exists
            if isinstance(in_path, (str, os.PathLike)) and os.path.exists(in_path):
                # presumably returns a truthy value on success — confirm against sdf_processor
                success = standardize_ligand(in_path, str(out_path))
                if success:
                    return str(out_path)
        except Exception as e:
            print(f"Error processing ligand {idx}: {e}")
        return None

    print("Standardizing ligand structures...")
    # Ship only the column the worker reads, not the whole row, to each process
    ligand_args = list(df_sample[['ligand_sdf_path']].iterrows())

    # Pool requires at least one worker; N_PROC may be 0 on a single-core machine.
    # NOTE(review): a function defined in a notebook cell may not be importable by
    # workers under the 'spawn' start method — confirm on macOS/Windows.
    n_workers = max(1, min(4, N_PROC))
    with Pool(n_workers) as pool:
        ligand_paths = list(tqdm(
            pool.imap(process_ligand, ligand_args),
            total=len(ligand_args),
            desc="Ligands"
        ))

    df_sample['standardized_ligand_sdf'] = ligand_paths
else:
    df_sample['standardized_ligand_sdf'] = None
    print("No ligand paths found")

## 9. Standardize SMILES

In [None]:
# Standardize SMILES
print("Standardizing SMILES...")

def _canonicalize_smiles(smi):
    """Return the RDKit canonical SMILES of `smi`, or None if it is missing, blank or unparsable."""
    if not isinstance(smi, str) or not smi.strip():
        return None
    mol = Chem.MolFromSmiles(smi)  # parsed once and reused
    if mol is None:
        return None
    return Chem.MolToSmiles(mol, canonical=True)

# The 'standardized_ligand_sdf' column is always created by the ligand step (filled
# with None when nothing was written), so testing only for its presence made the
# SMILES fallback unreachable. Check that at least one standardized file was
# actually produced.
has_standardized_sdf = (
    'standardized_ligand_sdf' in df_sample.columns
    and bool(df_sample['standardized_ligand_sdf'].notna().any())
)

# If we have standardized SDF files, extract SMILES from them
if has_standardized_sdf:
    smiles_list = []
    for sdf_path in tqdm(df_sample['standardized_ligand_sdf'], desc="Extracting SMILES"):
        if isinstance(sdf_path, str) and os.path.exists(sdf_path):
            smiles = standardize_smiles_from_sdf(sdf_path)
        else:
            smiles = None
        smiles_list.append(smiles)
    df_sample['std_smiles'] = smiles_list
# Otherwise, standardize from existing SMILES
elif 'smiles' in df_sample.columns:
    df_sample['std_smiles'] = df_sample['smiles'].apply(_canonicalize_smiles)
else:
    # Neither source is available: define the column so the filter below
    # reports the problem instead of raising KeyError.
    df_sample['std_smiles'] = None
    print("No standardized ligand files or 'smiles' column available")

# Remove samples with invalid SMILES
before = len(df_sample)
df_sample = df_sample[df_sample['std_smiles'].notna()].copy()
after = len(df_sample)
print(f"Removed {before - after} samples with invalid SMILES")
print(f"Remaining samples: {len(df_sample)}")

## 10. Compute Molecular Properties

In [None]:
# Compute molecular properties
print("Computing molecular properties...")

from joblib import Parallel, delayed

smiles_list = df_sample['std_smiles'].tolist()
# joblib rejects n_jobs == 0, which N_PROC can be on a single-core machine
props = Parallel(n_jobs=max(1, N_PROC))(
    delayed(compute_props)(smi) for smi in tqdm(smiles_list, desc="Properties")
)

# compute_props('') returns every property key (all None); using its keys fixes
# the column layout even when no molecules are left.
prop_columns = list(compute_props(''))
props_df = pd.DataFrame(props, columns=prop_columns)

# Drop same-named columns first (from the source data or an earlier run of this
# cell); concatenating over them would create duplicate column names.
df_sample = pd.concat(
    [
        df_sample.reset_index(drop=True).drop(columns=prop_columns, errors='ignore'),
        props_df,
    ],
    axis=1,
)

print(f"Computed properties for {len(df_sample)} molecules")

## 11. Calculate Ligand Efficiency

In [None]:
# Calculate ligand efficiency
print("Calculating ligand efficiency...")

# Find available activity columns
activity_cols = []
for col in df_sample.columns:
    col_lower = col.strip().lower()
    # Skip the LE_*/LEnorm_* columns this cell creates: names such as 'LE_pKd'
    # contain 'pkd' and would be picked up as activities when the cell is re-run.
    if col_lower.startswith(('le_', 'lenorm_')):
        continue
    is_activity_name = col_lower in ['pki', 'pkd', 'pec50', 'pic50'] or 'pkd' in col_lower
    # Only numeric columns can be divided below
    if is_activity_name and pd.api.types.is_numeric_dtype(df_sample[col]):
        activity_cols.append(col)

print(f"Found activity columns: {activity_cols}")

# Denominators, with non-positive or missing values masked to NaN so the
# divisions below give NaN for those rows.
heavy_atoms = pd.to_numeric(df_sample['HeavyAtomCount'], errors='coerce')
heavy_atoms = heavy_atoms.where(heavy_atoms > 0)
mol_wt = pd.to_numeric(df_sample['MolWt'], errors='coerce')
mol_wt = mol_wt.where(mol_wt > 0)

# Calculate LE for each activity:
#   LE_<task>     = activity / heavy-atom count
#   LEnorm_<task> = LE_<task> / molecular weight
for col in activity_cols:
    le_col = f'LE_{col}'
    le_norm_col = f'LEnorm_{col}'
    df_sample[le_col] = df_sample[col] / heavy_atoms
    df_sample[le_norm_col] = df_sample[le_col] / mol_wt

# Calculate mean LE across tasks. Build the column lists from activity_cols:
# matching on the "LE_" prefix would also match the aggregate 'LE_norm' on a re-run.
le_cols = [f'LE_{col}' for col in activity_cols]
le_norm_cols = [f'LEnorm_{col}' for col in activity_cols]

if le_cols:
    df_sample['LE'] = df_sample[le_cols].mean(axis=1, skipna=True)
if le_norm_cols:
    df_sample['LE_norm'] = df_sample[le_norm_cols].mean(axis=1, skipna=True)

print("Ligand efficiency calculated")

## 12. Visualize Molecular Properties

In [None]:
# Plot property distributions
property_cols = ['LogP', 'QED', 'MolWt', 'HeavyAtomCount', 'TPSA', 'NumRotatableBonds']
# The ligand-efficiency aggregates are plotted only when they were computed
for le_name in ('LE', 'LE_norm'):
    if le_name in df_sample.columns:
        property_cols.append(le_name)

available_props = [name for name in property_cols if name in df_sample.columns]

if available_props:
    fig, axes = plt.subplots(2, 4, figsize=(16, 8))
    axes = axes.flatten()

    # One panel per property; zip stops at the eight available panels
    for panel, prop in zip(axes, available_props):
        values = df_sample[prop].dropna()
        if len(values) == 0:
            continue
        values.hist(bins=30, ax=panel, alpha=0.7, edgecolor='black')
        panel.set_xlabel(prop)
        panel.set_ylabel('Count')
        panel.set_title(f'{prop} Distribution')
        panel.grid(True, alpha=0.3)

    # Hide the panels that have no property assigned
    for panel in axes[len(available_props):]:
        panel.set_visible(False)

    plt.suptitle('Molecular Property Distributions', fontsize=14)
    plt.tight_layout()
    plt.show()

## 13. Quality-Based Filtering

In [None]:
# Report how many samples enter the quality-filtering stage
print("\nApplying quality filters...")
print(f"Starting samples: {len(df_sample)}")

# Count carbon atoms
def count_carbon_atoms(smiles):
    """Count carbon atoms (aliphatic and aromatic) in a SMILES string.

    The string is split into atom tokens so that two-letter elements such as
    Cl, Ca, Co or Cu are not counted as carbon, which plain character counting
    of 'C'/'c' does.

    Args:
        smiles: SMILES string; any non-string value (None, NaN, ...) counts as 0.

    Returns:
        Number of carbon atoms as an int.
    """
    import re  # local import: `re` is not among the notebook's top-level imports

    if not isinstance(smiles, str):
        return 0

    # Group 1: element symbol of a bracket atom, after an optional isotope number.
    # Group 2: unbracketed atom; Cl/Br are listed first so they win over C/B.
    atom_pattern = r'\[\d*([A-Z][a-z]?|[a-z]{1,2})[^\]]*\]|(Cl|Br|[BCNOPSFI]|[bcnops])'
    return sum(
        1
        for match in re.finditer(atom_pattern, smiles)
        if (match.group(1) or match.group(2)) in ('C', 'c')
    )

df_sample['carbon_count'] = df_sample['std_smiles'].apply(count_carbon_atoms)

# Define filters: each mask is True for rows that should be rejected
filters = {
    'carbon_lt_4': df_sample['carbon_count'] < 4,
    'low_heavy': df_sample['HeavyAtomCount'] < 5,
    'high_heavy': df_sample['HeavyAtomCount'] > 75,
    'high_MW': df_sample['MolWt'] > 1000,
}

# Add LE filters if available
if 'LE' in df_sample.columns:
    filters['low_le'] = df_sample['LE'] <= 0.05
    filters['high_le'] = df_sample['LE'] >= 0.7

if 'LE_norm' in df_sample.columns:
    filters['high_le_norm'] = df_sample['LE_norm'] >= 0.003

# Apply filters. The accumulator shares df_sample's index so that `|=` combines
# masks row by row; a default RangeIndex only works if df_sample was reset first.
bad_filter = pd.Series(False, index=df_sample.index)
for name, filter_mask in filters.items():
    n_filtered = int(filter_mask.sum())
    if n_filtered > 0:
        print(f"  {name}: {n_filtered} samples")
        bad_filter |= filter_mask

# Split good and bad samples (copies, so later edits do not touch df_sample)
df_bad = df_sample[bad_filter].copy()
df_good = df_sample[~bad_filter].copy()

print(f"\nFiltered out {len(df_bad)} samples")
print(f"Remaining good samples: {len(df_good)}")

## 14. Display Top/Bottom Molecules by Ligand Efficiency

In [None]:
# Display top and bottom molecules by LE
if 'LE' in df_good.columns and len(df_good) > 0:
    df_le_valid = df_good.dropna(subset=['LE'])

    if len(df_le_valid) > 0:
        # Up to five molecules from each end of the LE ranking
        n_show = min(5, len(df_le_valid))
        top_le = df_le_valid.nlargest(n_show, 'LE')
        bottom_le = df_le_valid.nsmallest(n_show, 'LE')

        summary_cols = ['std_smiles', 'LE', 'MolWt', 'HeavyAtomCount']
        # (frame, table heading, picture heading) for each end of the ranking
        ranked = (
            (top_le,
             "\nTop 5 molecules by Ligand Efficiency:",
             "\nTop 5 molecules visualization:"),
            (bottom_le,
             "\nBottom 5 molecules by Ligand Efficiency:",
             "\nBottom 5 molecules visualization:"),
        )

        # Text tables first, for both ends
        for frame, table_heading, _ in ranked:
            print(table_heading)
            print(frame[summary_cols].to_string())

        # Visualize molecules; any failure aborts the remaining drawings
        try:
            from rdkit.Chem import Draw

            for frame, _, picture_heading in ranked:
                print(picture_heading)
                mols = [Chem.MolFromSmiles(smi) for smi in frame['std_smiles'].head(5)]
                grid = Draw.MolsToGridImage(mols, molsPerRow=5, subImgSize=(200, 200))
                display(grid)
        except Exception as e:
            print(f"Could not visualize molecules: {e}")

## 15. Data Validation

In [None]:
# Validate final dataset
from mtl_gnn_dta.preprocessing.validator import DataValidator

validator = DataValidator()

# Generate validation report for the activity columns present in df_good
print("\nGenerating validation report...")
report_tasks = [col for col in activity_cols if col in df_good.columns]
validation_report = validator.generate_validation_report(df_good, report_tasks)

print("\nValidation Report:")
print("="*60)
print(f"Total samples: {validation_report['total_samples']}")


def _print_validation_section(heading, key):
    """Print the valid/invalid/percentage figures stored under `key` in the report."""
    print(heading)
    stats = validation_report[key]
    print(f"  Valid: {stats['valid']}")
    print(f"  Invalid: {stats['invalid']}")
    print(f"  Percentage valid: {stats['percentage_valid']:.1f}%")


# The molecule section is printed only when the report has a non-empty entry for it
if validation_report['molecule_validation']:
    _print_validation_section("\nMolecule validation:", 'molecule_validation')

_print_validation_section("\nActivity validation:", 'activity_validation')
_print_validation_section("\nFile validation:", 'file_validation')

summary = validation_report['summary']
print("\nSummary:")
print(f"  Overall valid: {summary['overall_valid']:.1f}%")
print(f"  Ready for training: {summary['ready_for_training']}")

## 16. Train/Validation/Test Split

In [None]:
# Split data for training
print("\nSplitting data into train/val/test sets...")

# Stage 1: hold out 20% of all rows as the test set.
# NOTE(review): rows are split at random, so the same molecule can land in
# more than one split — confirm that is acceptable for evaluation.
train_val_df, test_df = train_test_split(df_good, test_size=0.2, random_state=42)

# Stage 2: take 10% of the remainder for validation
# (about 72% train / 8% val / 20% test overall)
train_df, val_df = train_test_split(train_val_df, test_size=0.1, random_state=42)

n_total = len(df_good)
for label, part in (("Train set", train_df),
                    ("Validation set", val_df),
                    ("Test set", test_df)):
    print(f"{label}: {len(part)} samples ({len(part)/n_total*100:.1f}%)")

## 17. Calculate Task Ranges for Multi-Task Learning

In [None]:
# Calculate task ranges for loss weighting
task_ranges = {}
task_weights = {}

for task in activity_cols:
    if task not in train_df.columns:
        continue
    valid_values = train_df[task].dropna()
    if valid_values.empty:
        continue
    # Store plain Python floats: these dictionaries are written with json.dump
    # later, and numpy scalars such as float32 are not JSON-serializable.
    task_range = float(valid_values.max() - valid_values.min())
    task_ranges[task] = task_range
    # Weight is the inverse range; a task with zero range falls back to 1.0
    task_weights[task] = 1.0 / task_range if task_range > 0 else 1.0

# Normalize weights so they sum to 1
if task_weights:
    total_weight = sum(task_weights.values())
    task_weights = {k: v/total_weight for k, v in task_weights.items()}

print("\nTask ranges and weights for multi-task learning:")
print("="*60)
for task in task_ranges:
    print(f"{task:20s}: range={task_ranges[task]:.2f}, weight={task_weights[task]:.4f}")

## 18. Save Processed Data

In [None]:
# Create output directory
output_dir = Path("../data/processed")
output_dir.mkdir(parents=True, exist_ok=True)


def _json_default(obj):
    """Convert values json cannot encode natively (numpy scalars/arrays, paths)."""
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, Path):
        return str(obj)
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


# Save splits
print("\nSaving processed data...")

train_df.to_parquet(output_dir / "train_data.parquet", index=False)
val_df.to_parquet(output_dir / "val_data.parquet", index=False)
test_df.to_parquet(output_dir / "test_data.parquet", index=False)

# Save task ranges and weights
with open(output_dir / "task_ranges.json", 'w') as f:
    json.dump(task_ranges, f, indent=2, default=_json_default)

with open(output_dir / "task_weights.json", 'w') as f:
    json.dump(task_weights, f, indent=2, default=_json_default)

# Save complete standardized dataset
df_good.to_parquet(output_dir / "standardized_data.parquet", index=False)

# Save metadata
metadata = {
    'total_samples': len(df_good),
    'train_samples': len(train_df),
    'val_samples': len(val_df),
    'test_samples': len(test_df),
    'task_columns': activity_cols,
    'task_ranges': task_ranges,
    'task_weights': task_weights,
    'validation_report': validation_report
}

# The validation report comes from project code and the ranges from pandas, so
# either may hold numpy values; the default hook converts them instead of
# letting json.dump fail partway through the file.
with open(output_dir / "metadata.json", 'w') as f:
    json.dump(metadata, f, indent=2, default=_json_default)

print(f"Data saved to {output_dir}")
print("\nFiles created:")
# Sorted so the listing is stable between runs
for file in sorted(output_dir.glob("*")):
    print(f"  - {file.name}")

## 19. Summary Statistics

In [None]:
# Print final summary
banner = "="*60
print("\n" + banner)
print("DATA PREPARATION COMPLETE")
print(banner)

mol_wt = df_good['MolWt']
heavy_atoms = df_good['HeavyAtomCount']

print("\nFinal dataset statistics:")
print(f"  Total samples: {len(df_good):,}")
print(f"  Unique molecules: {df_good['InChIKey'].nunique():,}")
print(f"  Average molecular weight: {mol_wt.mean():.1f} ± {mol_wt.std():.1f}")
print(f"  Average heavy atoms: {heavy_atoms.mean():.1f} ± {heavy_atoms.std():.1f}")

if 'LE' in df_good.columns:
    ligand_eff = df_good['LE']
    print(f"  Average ligand efficiency: {ligand_eff.mean():.3f} ± {ligand_eff.std():.3f}")

# Size of each split as a share of the filtered dataset
print("\nData splits:")
n_total = len(df_good)
for label, part in (("Training", train_df), ("Validation", val_df), ("Test", test_df)):
    print(f"  {label}: {len(part):,} ({len(part)/n_total*100:.1f}%)")

# Labeled (non-missing) sample counts per task in each split
print("\nTasks available for training:")
for task in task_ranges:
    n_train, n_val, n_test = (part[task].notna().sum() for part in (train_df, val_df, test_df))
    print(f"  {task}: train={n_train}, val={n_val}, test={n_test}")

print("\n✅ Data is ready for model training!")
print("Next step: Run the model training notebook (02_model_training.ipynb)")