# %%
"""
Example Notebook 1: Data Preparation for MTL-GNN-DTA
This notebook demonstrates how to prepare and process data for training
"""

# %% [markdown]
# # Data Preparation for MTL-GNN-DTA
# 
# This notebook demonstrates:
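# 1. Loading affinity data (a synthetic example dataset here) and analyzing task distributions
# 2. Standardizing molecular structures (the `standardize_protein` / `standardize_ligand` helpers are imported, but this notebook does not call them yet)
# 3. Computing molecular features (one SMILES is converted to a graph; no other molecular properties are computed yet)
# 4. Preparing data for model training: train/validation/test splits, task ranges, and saved output files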

# %% [markdown]
# ## 1. Setup and Imports

# %%
import os
import sys
import warnings
warnings.filterwarnings('ignore')

# Add parent directory to path for imports
sys.path.append('../../')

# Import MTL-GNN-DTA modules
from mtl_gnn_dta import Config, AffinityPredictor
from mtl_gnn_dta.preprocessing import (
    standardize_protein,
    standardize_ligand,
    validate_structures
)
from mtl_gnn_dta.features import DrugFeaturizer, ProteinFeaturizer
from mtl_gnn_dta.utils import setup_logging

import pandas as pd
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

# Initialise logging via the project's helper
setup_logging()

# Common look for every figure produced below
sns.set_style("whitegrid")
plt.rcParams['figure.figsize'] = (10, 6)

# Notebook banner: title line followed by a 50-character rule
print("MTL-GNN-DTA Data Preparation", "=" * 50, sep="\n")

# %% [markdown]
# ## 2. Load Configuration

# %%
# Build the default configuration object
config = Config()

# Alternatively, load the settings from a config file:
# config = Config('../../experiments/configs/default_config.yaml')

# Echo the settings that the rest of this notebook relies on
data_settings = config.data
print("Configuration loaded:")
print(f"  Data directory: {data_settings.processed_dir}")
print(f"  Task names: {config.model.task_names}")
print(f"  Batch size: {data_settings.batch_size}")

# %% [markdown]
# ## 3. Load Raw Data

# %%
# Example: build a synthetic dataset for demonstration.
# In practice, replace this cell with your actual data loading.

# Fixed seed so that "Restart & Run All" reproduces the same data, plots,
# splits and task ranges on every run.
SEED = 42
N_SAMPLES = 100          # number of synthetic protein-ligand rows
MISSING_FRACTION = 0.2   # approximate fraction of NaNs injected per task column

rng = np.random.default_rng(SEED)

# (mean, std) of the normal distribution used for each synthetic activity task
synthetic_task_params = {
    'pKi': (7.0, 1.5),
    'pEC50': (6.5, 1.2),
    'pKd': (7.2, 1.3),
    'pIC50': (6.8, 1.4),
}

sample_data = pd.DataFrame({
    'protein_pdb_path': ['data/structures/protein_001.pdb'] * N_SAMPLES,
    'ligand_sdf_path': ['data/structures/ligand_001.sdf'] * N_SAMPLES,
    'smiles': ['CC(C)CC1=CC=C(C=C1)C(C)C(O)=O'] * N_SAMPLES,
    **{task: rng.normal(mean, std, N_SAMPLES)
       for task, (mean, std) in synthetic_task_params.items()},
})

# Blank out a random subset of each task to simulate the sparse labels of real data
for col in synthetic_task_params:
    missing_mask = rng.random(N_SAMPLES) < MISSING_FRACTION
    sample_data.loc[missing_mask, col] = np.nan

print(f"Loaded {len(sample_data)} data points")
print("\nData overview:")
# DataFrame.info() prints its report itself and returns None; wrapping it in
# print() would emit a stray "None" line after the report.
sample_data.info()

# %% [markdown]
# ## 4. Data Quality Analysis

# %%
# Histogram of each activity task's labelled (non-NaN) values, with the mean
# and median marked as dashed vertical lines.
task_cols = ['pKi', 'pEC50', 'pKd', 'pIC50']

fig, axes = plt.subplots(2, 2, figsize=(12, 8))
axes = axes.flatten()

for ax, task in zip(axes, task_cols):
    labelled = sample_data[task].dropna()
    centre_mean = labelled.mean()
    centre_median = labelled.median()

    ax.hist(labelled, bins=30, alpha=0.7, edgecolor='black')
    ax.axvline(centre_mean, color='red', linestyle='--', label=f'Mean: {centre_mean:.2f}')
    ax.axvline(centre_median, color='green', linestyle='--', label=f'Median: {centre_median:.2f}')
    ax.set(xlabel=task, ylabel='Count', title=f'{task} Distribution (n={len(labelled)})')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.suptitle('Activity Value Distributions', fontsize=14)
plt.tight_layout()
plt.show()

# One summary line per task, computed over its labelled (non-NaN) values only
print("\nTask Statistics:")
print("=" * 50)
for task in task_cols:
    labelled = sample_data[task].dropna()
    summary = labelled.agg(['mean', 'std', 'min', 'max'])
    print(
        f"{task:10s}: n={len(labelled):4d}, mean={summary['mean']:.2f}, "
        f"std={summary['std']:.2f}, min={summary['min']:.2f}, max={summary['max']:.2f}"
    )

# %% [markdown]
# ## 5. Molecular Featurization Example

# %%
# Initialize featurizers
drug_featurizer = DrugFeaturizer()
protein_featurizer = ProteinFeaturizer()  # instantiated here; not used in this cell

# Example: Featurize a drug from SMILES
smiles_example = "CC(C)CC1=CC=C(C=C1)C(C)C(O)=O"
print(f"Featurizing SMILES: {smiles_example}")

drug_graph = drug_featurizer.featurize_from_smiles(smiles_example)
# NOTE(review): featurize_from_smiles presumably returns None (or another falsy
# value) when the SMILES cannot be featurized — TODO confirm against the library.
if drug_graph:
    print(f"  Nodes: {drug_graph.x.shape[0]}")
    print(f"  Node features: {drug_graph.x.shape[1]}")
    print(f"  Edges: {drug_graph.edge_index.shape[1]}")
    print(f"  Edge features: {drug_graph.edge_attr.shape[1] if drug_graph.edge_attr is not None else 0}")
else:
    # Report the failure explicitly instead of silently printing nothing,
    # which would otherwise look like a successful run.
    print("  Featurization failed: no graph was returned for this SMILES")

# %% [markdown]
# ## 6. Data Splitting

# %%
from sklearn.model_selection import train_test_split

# Two-stage split with a fixed seed:
#   1) hold out 20% of all rows as the test set;
#   2) take 12.5% of the remaining 80% (= 10% of the total) as validation.
# The result is roughly 70% train / 10% validation / 20% test.
train_data, test_data = train_test_split(sample_data, test_size=0.2, random_state=42)
train_data, val_data = train_test_split(train_data, test_size=0.125, random_state=42)

# Report each split's size and its share of the full dataset
n_total = len(sample_data)
print("Data split:")
for split_name, split_frame in (("Train", train_data), ("Validation", val_data), ("Test", test_data)):
    print(f"  {split_name}: {len(split_frame)} ({len(split_frame)/n_total*100:.1f}%)")

# %% [markdown]
# ## 7. Calculate Task Ranges for Multi-Task Learning

# %%
# Calculate per-task value ranges for loss weighting, on the TRAINING split only
# so that validation/test values do not leak into the weights. A task with no
# labelled training rows falls back to a range of 1.0.
task_ranges = {}
for task in task_cols:
    valid_values = train_data[task].dropna()
    if len(valid_values) > 0:
        # float() keeps the value JSON-serializable: max()-min() on an
        # integer-typed column yields numpy.int64, which json.dump rejects.
        task_ranges[task] = float(valid_values.max() - valid_values.min())
    else:
        task_ranges[task] = 1.0

# Inverse-range weights (1.0 when the range is 0, avoiding division by zero),
# normalized to sum to 1. The normalizer is computed once rather than being
# re-summed on every iteration of the report loop.
task_weights = {task: (1.0 / r if r > 0 else 1.0) for task, r in task_ranges.items()}
weight_total = sum(task_weights.values())

print("\nTask ranges for loss weighting:")
print("="*50)
for task, range_val in task_ranges.items():
    normalized_weight = task_weights[task] / weight_total
    print(f"{task:10s}: range={range_val:.2f}, normalized_weight={normalized_weight:.4f}")

# %% [markdown]
# ## 8. Save Processed Data

# %%
# Save processed data to the configured processed-data directory
import json

output_dir = Path(config.data.processed_dir)
output_dir.mkdir(parents=True, exist_ok=True)

# Track every file this cell writes, so that the report below is exact
written_files = []

# Save splits
split_frames = {
    'train_data.parquet': train_data,
    'val_data.parquet': val_data,
    'test_data.parquet': test_data,
}
for filename, frame in split_frames.items():
    split_path = output_dir / filename
    frame.to_parquet(split_path, index=False)
    written_files.append(split_path)

# Save task ranges. default=float converts numpy scalars (e.g. numpy.int64),
# which the json module cannot serialize on its own.
ranges_path = output_dir / 'task_ranges.json'
with open(ranges_path, 'w', encoding='utf-8') as f:
    json.dump(task_ranges, f, indent=2, default=float)
written_files.append(ranges_path)

print(f"\nData saved to {output_dir}")
print("Files created:")
# List exactly what was written. Globbing '*.parquet' would omit the JSON file
# and would also list stale parquet files left over from earlier runs.
for written_path in written_files:
    print(f"  - {written_path.name}")

# %% [markdown]
# ## 9. Data Quality Checks

# %%
# Missing labels per task in the training split
missing_counts = train_data[task_cols].isna().sum()
total = len(train_data)

print("\nMissing values per task:")
print("=" * 50)
for task, missing in missing_counts.items():
    print(f"{task:10s}: {missing:4d} / {total:4d} ({missing/total*100:.1f}%)")

# Correlations between task labels (NaNs are excluded pairwise by DataFrame.corr)
correlation_matrix = train_data[task_cols].corr()

fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(
    correlation_matrix, annot=True, fmt='.2f', cmap='coolwarm', center=0,
    square=True, linewidths=1, cbar_kws={"shrink": 0.8}, ax=ax,
)
ax.set_title('Task Correlation Matrix')
fig.tight_layout()
plt.show()

# %% [markdown]
# ## 10. Next Steps
# 
# Now that the data is prepared, you can:
# 1. Move to `02_model_training.ipynb` to train the model
# 2. Customize the featurization process for your specific molecules
# 3. Add additional data preprocessing steps as needed
# 
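# The prepared data includes:
# - Train/validation/test splits (`train_data.parquet`, `val_data.parquet`, `test_data.parquet`)
# - Task ranges for multi-task learning (`task_ranges.json`)
# 
# Task value distributions are inspected above but are not saved to disk.
# This notebook does **not** yet produce standardized molecular structures or computed molecular properties.
# Add those preprocessing steps before training on real data.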

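# NOTE: The last recorded run of this notebook failed with:
#   ModuleNotFoundError: No module named 'torch_scatter'
# This was presumably raised while importing `mtl_gnn_dta`.
# Install a `torch-scatter` build that matches your installed PyTorch/CUDA version, then re-run all cells.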