# 1: Imports and Setup

In [1]:
# ==================================================================================
# CELL 1: IMPORTS AND SETUP
# ==================================================================================


import os
import sys
import json
import gc
import warnings
warnings.filterwarnings('ignore')

import pandas as pd
import numpy as np
import torch
import torch_geometric
from pathlib import Path
from sklearn.model_selection import KFold, train_test_split
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

# Add the parent directory to the Python path
# This allows importing the gnn_dta_mtl package
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(''))))

# Set device
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Number of GPUs: {torch.cuda.device_count()}")

# Import your package - use absolute import instead of relative
from gnn_dta_mtl import (
    MTL_DTAModel, DTAModel,
    MTL_DTA, DTA,
    CrossValidator, MTLTrainer,
    StructureStandardizer, StructureProcessor, StructureChunkLoader,
    ESMEmbedder,
    add_molecular_properties_parallel,
    compute_ligand_efficiency,
    compute_mean_ligand_efficiency,
    filter_by_properties,
    prepare_mtl_experiment,
    build_mtl_dataset, build_mtl_dataset_optimized,
    evaluate_model,
    plot_results, plot_predictions, create_summary_report,
    ExperimentLogger,
    save_model, save_results, create_output_dir
)

import os
import sys
import json
import gc
import warnings
warnings.filterwarnings('ignore')

import pandas as pd
import numpy as np
import torch
import torch_geometric
from pathlib import Path
from sklearn.model_selection import KFold, train_test_split
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

# Add the parent directory to the Python path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(''))))

# Check GPU availability
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"Number of GPUs: {torch.cuda.device_count()}")
    for i in range(torch.cuda.device_count()):
        print(f"  GPU {i}: {torch.cuda.get_device_name(i)}")
        print(f"    Memory: {torch.cuda.get_device_properties(i).total_memory / 1e9:.1f} GB")

# Import your package
from gnn_dta_mtl import (
    MTL_DTAModel,
    StructureChunkLoader,
    prepare_mtl_experiment,
    build_mtl_dataset_optimized,
    save_model, save_results, create_output_dir
)

# Set random seeds
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print("\n✓ Environment setup complete!")

PyTorch version: 2.0.1+cu118
CUDA available: True
Number of GPUs: 16
  GPU 0: NVIDIA A100-SXM4-40GB
    Memory: 42.3 GB
  GPU 1: NVIDIA A100-SXM4-40GB
    Memory: 42.3 GB
  GPU 2: NVIDIA A100-SXM4-40GB
    Memory: 42.3 GB
  GPU 3: NVIDIA A100-SXM4-40GB
    Memory: 42.3 GB
  GPU 4: NVIDIA A100-SXM4-40GB
    Memory: 42.3 GB
  GPU 5: NVIDIA A100-SXM4-40GB
    Memory: 42.3 GB
  GPU 6: NVIDIA A100-SXM4-40GB
    Memory: 42.3 GB
  GPU 7: NVIDIA A100-SXM4-40GB
    Memory: 42.3 GB
  GPU 8: NVIDIA A100-SXM4-40GB
    Memory: 42.3 GB
  GPU 9: NVIDIA A100-SXM4-40GB
    Memory: 42.3 GB
  GPU 10: NVIDIA A100-SXM4-40GB
    Memory: 42.3 GB
  GPU 11: NVIDIA A100-SXM4-40GB
    Memory: 42.3 GB
  GPU 12: NVIDIA A100-SXM4-40GB
    Memory: 42.3 GB
  GPU 13: NVIDIA A100-SXM4-40GB
    Memory: 42.3 GB
  GPU 14: NVIDIA A100-SXM4-40GB
    Memory: 42.3 GB
  GPU 15: NVIDIA A100-SXM4-40GB
    Memory: 42.3 GB

✓ Environment setup complete!


In [2]:


# ==================================================================================
# CELL 2: CONFIGURATION
# ==================================================================================
from datetime import datetime

# Main configuration
CONFIG = {
    # Data paths - POINT TO YOUR DATA HERE
    'data_path': './deduplicated_complexes_parallel.csv',  # YOUR DEDUPLICATED DATA
    'structure_chunks_dir': '../input/chunk/',  # YOUR STRUCTURE CHUNKS
    
    # Task configuration
    'task_cols': ['pKi', 'pEC50', 'pKd (Wang, FEP)', 'pKd', 'pIC50', 'potency'],
    
    # Model configuration
    'model_config': {
        'prot_emb_dim': 1280,
        'prot_gcn_dims': [128, 256, 256],
        'prot_fc_dims': [1024, 128],
        'drug_node_in_dim': [66, 1],
        'drug_node_h_dims': [128, 64],
        'drug_edge_in_dim': [16, 1],
        'drug_edge_h_dims': [32, 1],
        'drug_fc_dims': [1024, 128],
        'mlp_dims': [1024, 512],
        'mlp_dropout': 0.25
    },
    
    # Training configuration for multi-GPU
    'training_config': {
        'batch_size_per_gpu': 512,  # 512 * 16 GPUs = 8192 total batch size
        'n_epochs': 200,
        'learning_rate': 0.001,
        'patience': 30,
        'n_folds': 5
    },
    
    # Other settings
    'seed': SEED,
    'gradient_accumulation_steps': 1,
    'use_mixed_precision': True  # For A100 tensor cores
}

# Create experiment directory
experiment_name = f'ddp_training_{datetime.now():%Y%m%d_%H%M%S}'
experiment_dir = Path(f'../output/experiments/{experiment_name}')
experiment_dir.mkdir(parents=True, exist_ok=True)

CONFIG['experiment_dir'] = str(experiment_dir)
CONFIG['checkpoint_dir'] = str(experiment_dir / 'checkpoints')
CONFIG['log_dir'] = str(experiment_dir / 'logs')

# Create subdirectories
(experiment_dir / 'checkpoints').mkdir(exist_ok=True)
(experiment_dir / 'logs').mkdir(exist_ok=True)
(experiment_dir / 'results').mkdir(exist_ok=True)
(experiment_dir / 'figures').mkdir(exist_ok=True)

# Save configuration
config_path = experiment_dir / 'config.json'
with open(config_path, 'w') as f:
    json.dump(CONFIG, f, indent=2)

print(f"✓ Experiment: {experiment_name}")
print(f"✓ Config saved to: {config_path}")
print(f"✓ Using {torch.cuda.device_count()} GPUs")


✓ Experiment: ddp_training_20250926_074933
✓ Config saved to: ../output/experiments/ddp_training_20250926_074933/config.json
✓ Using 16 GPUs


In [3]:

# ==================================================================================
# CELL 3: LOAD AND PREPARE DATA
# ==================================================================================
print("\nLoading data...")

# Load your deduplicated data
df = pd.read_csv(CONFIG['data_path'])[:1000]
print(f"✓ Loaded {len(df):,} samples")

# Display data info
print("\nData columns:", df.columns.tolist())
print("\nData shape:", df.shape)

# Check task coverage
print("\nTask coverage:")
for task in CONFIG['task_cols']:
    if task in df.columns:
        coverage = df[task].notna().sum()
        print(f"  {task}: {coverage:,} samples ({coverage/len(df)*100:.1f}%)")

# Calculate task ranges for weighting
task_ranges = prepare_mtl_experiment(df, CONFIG['task_cols'])

# Create chunk loader for structure loading
print("\nInitializing structure chunk loader...")
chunk_loader = StructureChunkLoader(
    chunk_dir=CONFIG['structure_chunks_dir'],
    cache_size=10
)
print("✓ Chunk loader ready")

# Save task ranges to config
CONFIG['task_ranges'] = task_ranges
with open(config_path, 'w') as f:
    json.dump(CONFIG, f, indent=2)

print("\n✓ Data preparation complete!")


Loading data...
✓ Loaded 1,000 samples

Data columns: ['protein_pdb_path', 'ligand_sdf_path', 'smiles', 'pKi', 'source_file', 'is_experimental', 'resolution', 'pEC50', 'pKd (Wang, FEP)', 'pKd', 'pIC50', 'SMILES', 'potency', 'assay', 'standardized_protein_pdb', 'standardized_ligand_sdf', 'std_smiles', 'protein_id', 'InChIKey', 'MolWt', 'HeavyAtomCount', 'QED', 'NumHDonors', 'NumHAcceptors', 'NumRotatableBonds', 'TPSA', 'LogP', 'LE_pKi', 'LEnorm_pKi', 'LE_pEC50', 'LEnorm_pEC50', 'LE_pKd (Wang, FEP)', 'LEnorm_pKd (Wang, FEP)', 'LE_pKd', 'LEnorm_pKd', 'LE_pIC50', 'LEnorm_pIC50', 'LE_potency', 'LEnorm_potency', 'LE', 'LE_norm', 'carbon_count', 'sequence_id', 'processed_inchikey', 'inchikey_hash', 'has_valid_inchikey', 'num_merged', 'merged_from_sources']

Data shape: (1000, 48)

Task coverage:
  pKi: 218 samples (21.8%)
  pEC50: 56 samples (5.6%)
  pKd (Wang, FEP): 1 samples (0.1%)
  pKd: 77 samples (7.7%)
  pIC50: 652 samples (65.2%)
  potency: 0 samples (0.0%)
Task ranges for weighting:


IOStream.flush timed out
IOStream.flush timed out


Loaded 406088 structures from 10 chunks
Optimization: mmap=True, pickle=False, preload=False
✓ Chunk loader ready

✓ Data preparation complete!


In [10]:


# ==================================================================================
# CELL 4: LAUNCH FULL TRAINING (MULTI-GPU)
# ==================================================================================
import subprocess
import time
import threading
from IPython.display import display, clear_output

def launch_ddp_training(config_path, mode='train', n_gpus=16):
    """
    Launch DDP training using torchrun.
    """
    # Save data path for the launch script
    config_with_data = CONFIG.copy()
    config_with_data['data_path'] = os.path.abspath(CONFIG['data_path'])
    with open(config_path, 'w') as f:
        json.dump(config_with_data, f, indent=2)
    
    cmd = [
        'torchrun',
        '--nproc_per_node', str(n_gpus),
        '--master_port', '12355',
        '../training/launch_training.py',
        '--config', str(config_path),
        '--mode', mode,
        '--n_gpus', str(n_gpus)
    ]
    
    print(f"Launching {mode} with {n_gpus} GPUs...")
    print(f"Command: {' '.join(cmd)}")
    
    # Launch process
    process = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        universal_newlines=True,
        bufsize=1
    )
    
    # Print output in real-time
    for line in iter(process.stdout.readline, ''):
        if line:
            print(line.rstrip())
    
    process.wait()
    print(f"\n{mode} completed!")
    
    return process.returncode

# Launch full training
print("="*70)
print("LAUNCHING FULL TRAINING ON 16 GPUs")
print("="*70)
print(f"Total batch size: {CONFIG['training_config']['batch_size_per_gpu'] * 16}")
print(f"Learning rate: {CONFIG['training_config']['learning_rate']}")
print(f"Epochs: {CONFIG['training_config']['n_epochs']}")
print(f"Mixed precision: {CONFIG['use_mixed_precision']}")
print("="*70)

return_code = launch_ddp_training(config_path, mode='train', n_gpus=16)

if return_code == 0:
    print("\n✓ Training completed successfully!")
else:
    print(f"\n⚠ Training finished with return code {return_code}")


LAUNCHING FULL TRAINING ON 16 GPUs
Total batch size: 8192
Learning rate: 0.001
Epochs: 200
Mixed precision: True
Launching train with 16 GPUs...
Command: torchrun --nproc_per_node 16 --master_port 12355 ../training/launch_training.py --config ../output/experiments/ddp_training_20250926_074933/config.json --mode train --n_gpus 16
W0926 08:08:42.672000 56446 site-packages/torch/distributed/run.py:766]
W0926 08:08:42.672000 56446 site-packages/torch/distributed/run.py:766] *****************************************
W0926 08:08:42.672000 56446 site-packages/torch/distributed/run.py:766] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0926 08:08:42.672000 56446 site-packages/torch/distributed/run.py:766] *****************************************

train completed!

✓ Training completed successfully!


In [14]:
# ==================================================================================
# CELL 1: CHECK WHAT HAPPENED
# ==================================================================================
import os
from pathlib import Path

# Check if log files were created
log_dir = Path(CONFIG['log_dir'])
checkpoint_dir = Path(CONFIG['checkpoint_dir'])

print("Checking outputs...")
print(f"Log directory exists: {log_dir.exists()}")
print(f"Checkpoint directory exists: {checkpoint_dir.exists()}")

if log_dir.exists():
    log_files = list(log_dir.glob('*'))
    print(f"Log files: {log_files}")

if checkpoint_dir.exists():
    checkpoint_files = list(checkpoint_dir.glob('*'))
    print(f"Checkpoint files: {checkpoint_files}")

# ==================================================================================
# CELL 2: CREATE A SIMPLER SINGLE-FILE TRAINING SCRIPT
# ==================================================================================

# Let's create a simpler, single-file solution that's easier to debug
training_script = '''
import os
import sys
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from pathlib import Path
import argparse
import json
import pandas as pd
import numpy as np
from tqdm import tqdm
from datetime import datetime

# Add parent directory to path
sys.path.append(str(Path(__file__).parent.parent))

from gnn_dta_mtl import (
    MTL_DTAModel, 
    build_mtl_dataset_optimized, 
    StructureChunkLoader,
    prepare_mtl_experiment,
    create_data_splits
)

class SimpleDDPTrainer:
    def __init__(self, rank, world_size, config):
        self.rank = rank
        self.world_size = world_size
        self.config = config
        
        # Setup process group
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12355'
        dist.init_process_group("nccl", rank=rank, world_size=world_size)
        
        # Set device
        self.device = torch.device(f'cuda:{rank}')
        torch.cuda.set_device(rank)
        
    def train(self):
        """Main training function"""
        print(f"Process {self.rank} starting...")
        
        # Load data on rank 0 only
        if self.rank == 0:
            print("Loading data...")
            df = pd.read_csv(self.config['data_path'])
            
            # Limit data size for testing
            if 'sample_size' in self.config:
                df = df.head(self.config['sample_size'])
                print(f"Using {len(df)} samples")
            
            # Calculate task ranges
            from gnn_dta_mtl import prepare_mtl_experiment
            task_ranges = prepare_mtl_experiment(df, self.config['task_cols'])
            
            # Create chunk loader
            chunk_loader = StructureChunkLoader(
                chunk_dir=self.config['structure_chunks_dir'],
                cache_size=10
            )
            
            # Create splits
            from gnn_dta_mtl.datasets import create_data_splits
            splits = create_data_splits(df, split_method='random', split_frac=[0.8, 0.1, 0.1], seed=42)
            df_train = splits['train']
            df_valid = splits['valid']
            
            print(f"Train: {len(df_train)}, Valid: {len(df_valid)}")
            
            # Build datasets
            print("Building datasets...")
            train_dataset = build_mtl_dataset_optimized(df_train, chunk_loader, self.config['task_cols'])
            valid_dataset = build_mtl_dataset_optimized(df_valid, chunk_loader, self.config['task_cols'])
            
            data_ready = True
        else:
            train_dataset = None
            valid_dataset = None
            task_ranges = None
            data_ready = None
        
        # Broadcast data ready signal
        data_ready = [data_ready]
        dist.broadcast_object_list(data_ready, src=0)
        
        if not data_ready[0]:
            print(f"Process {self.rank}: Data preparation failed")
            return
        
        # Broadcast datasets
        if self.rank == 0:
            datasets_and_ranges = [train_dataset, valid_dataset, task_ranges]
        else:
            datasets_and_ranges = [None, None, None]
        
        dist.broadcast_object_list(datasets_and_ranges, src=0)
        train_dataset, valid_dataset, task_ranges = datasets_and_ranges
        
        # Create distributed samplers
        train_sampler = DistributedSampler(
            train_dataset,
            num_replicas=self.world_size,
            rank=self.rank,
            shuffle=True
        )
        
        valid_sampler = DistributedSampler(
            valid_dataset,
            num_replicas=self.world_size,
            rank=self.rank,
            shuffle=False
        )
        
        # Create data loaders
        import torch_geometric
        batch_size = self.config['training_config'].get('batch_size_per_gpu', 32)
        
        train_loader = torch_geometric.loader.DataLoader(
            train_dataset,
            batch_size=batch_size,
            sampler=train_sampler,
            num_workers=2,
            pin_memory=True
        )
        
        valid_loader = torch_geometric.loader.DataLoader(
            valid_dataset,
            batch_size=batch_size,
            sampler=valid_sampler,
            num_workers=2,
            pin_memory=True
        )
        
        # Create model
        model = MTL_DTAModel(
            task_names=self.config['task_cols'],
            **self.config['model_config']
        ).to(self.device)
        
        # Wrap in DDP
        model = DDP(model, device_ids=[self.rank])
        
        # Create optimizer and loss
        optimizer = torch.optim.Adam(model.parameters(), lr=self.config['training_config']['learning_rate'])
        
        # Simple MSE loss
        criterion = nn.MSELoss()
        
        # Training loop
        n_epochs = self.config['training_config'].get('n_epochs', 5)
        
        if self.rank == 0:
            print(f"Starting training for {n_epochs} epochs...")
            checkpoint_dir = Path(self.config['checkpoint_dir'])
            checkpoint_dir.mkdir(parents=True, exist_ok=True)
        
        for epoch in range(n_epochs):
            # Set epoch for sampler
            train_sampler.set_epoch(epoch)
            
            # Training
            model.train()
            train_loss = 0
            n_batches = 0
            
            if self.rank == 0:
                pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{n_epochs}")
            else:
                pbar = train_loader
            
            for batch in pbar:
                xd = batch['drug'].to(self.device)
                xp = batch['protein'].to(self.device)
                y = batch['y'].to(self.device)
                
                optimizer.zero_grad()
                pred = model(xd, xp)
                
                # Handle NaN values
                mask = ~torch.isnan(y)
                if mask.sum() > 0:
                    loss = criterion(pred[mask], y[mask])
                    loss.backward()
                    optimizer.step()
                    
                    train_loss += loss.item()
                    n_batches += 1
            
            # Print progress
            if self.rank == 0 and n_batches > 0:
                avg_train_loss = train_loss / n_batches
                print(f"Epoch {epoch+1}: Train Loss = {avg_train_loss:.4f}")
                
                # Save checkpoint
                if epoch % 5 == 0 or epoch == n_epochs - 1:
                    checkpoint = {
                        'epoch': epoch,
                        'model_state_dict': model.module.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'loss': avg_train_loss
                    }
                    torch.save(checkpoint, checkpoint_dir / f'checkpoint_epoch_{epoch}.pt')
                    print(f"Saved checkpoint at epoch {epoch+1}")
        
        if self.rank == 0:
            print("Training completed!")
            # Save final model
            torch.save(model.module.state_dict(), checkpoint_dir / 'final_model.pt')
        
        # Clean up
        dist.destroy_process_group()

def run_process(rank, world_size, config):
    """Function to run in each process"""
    trainer = SimpleDDPTrainer(rank, world_size, config)
    trainer.train()

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, required=True)
    parser.add_argument('--n_gpus', type=int, default=16)
    args = parser.parse_args()
    
    # Load config
    with open(args.config, 'r') as f:
        config = json.load(f)
    
    # Spawn processes
    mp.spawn(run_process, args=(args.n_gpus, config), nprocs=args.n_gpus, join=True)

if __name__ == '__main__':
    main()
'''

# Save the training script
script_path = Path('../simple_ddp_train.py')
with open(script_path, 'w') as f:
    f.write(training_script)

print(f"✓ Created simplified training script: {script_path}")

# ==================================================================================
# CELL 3: RUN THE SIMPLIFIED TRAINING
# ==================================================================================

import subprocess

def run_simple_training(config_path, n_gpus=16):
    """Run the simplified training script"""
    
    # Update config for quick test
    CONFIG['sample_size'] = 1000  # Use only 1000 samples
    CONFIG['training_config']['n_epochs'] = 3  # Just 3 epochs for testing
    CONFIG['training_config']['batch_size_per_gpu'] = 16  # Smaller batch
    
    # Save updated config
    with open(config_path, 'w') as f:
        json.dump(CONFIG, f, indent=2)
    
    cmd = [
        'torchrun',
        '--nproc_per_node', str(n_gpus),
        '--master_port', '12356',  # Different port to avoid conflicts
        '../simple_ddp_train.py',
        '--config', str(config_path),
        '--n_gpus', str(n_gpus)
    ]
    
    print(f"Running command: {' '.join(cmd)}")
    
    # Run the training
    process = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        universal_newlines=True
    )
    
    # Print output
    for line in iter(process.stdout.readline, ''):
        if line:
            print(line.rstrip())
    
    process.wait()
    return process.returncode

print("="*70)
print("RUNNING SIMPLIFIED DDP TRAINING")
print("="*70)
print(f"GPUs: 16 x A100")
print(f"Samples: 1000 (for testing)")
print(f"Epochs: 3")
print("="*70)

return_code = run_simple_training(config_path, n_gpus=16)

if return_code == 0:
    print("\n✓ Training completed successfully!")
    
    # Check outputs
    checkpoint_dir = Path(CONFIG['checkpoint_dir'])
    if checkpoint_dir.exists():
        checkpoints = list(checkpoint_dir.glob('*.pt'))
        print(f"\nSaved checkpoints:")
        for ckpt in checkpoints:
            print(f"  - {ckpt.name}")
else:
    print(f"\n⚠ Training failed with return code {return_code}")

Checking outputs...
Log directory exists: True
Checkpoint directory exists: True
Log files: []
Checkpoint files: []
✓ Created simplified training script: ../simple_ddp_train.py
RUNNING SIMPLIFIED DDP TRAINING
GPUs: 16 x A100
Samples: 1000 (for testing)
Epochs: 3
Running command: torchrun --nproc_per_node 16 --master_port 12356 ../simple_ddp_train.py --config ../output/experiments/ddp_training_20250926_074933/config.json --n_gpus 16
W0926 08:17:01.782000 59034 site-packages/torch/distributed/run.py:766]
W0926 08:17:01.782000 59034 site-packages/torch/distributed/run.py:766] *****************************************
W0926 08:17:01.782000 59034 site-packages/torch/distributed/run.py:766] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0926 08:17:01.782000 59034 site-packages/torch/distributed/run.py:766] *************************

In [5]:

# ==================================================================================
# CELL 5: MONITOR TRAINING PROGRESS
# ==================================================================================
import json
import glob

def check_training_progress(log_dir):
    """Check training progress from log files."""
    log_dir = Path(log_dir)
    
    # Find metrics file
    metrics_files = list(log_dir.glob('metrics_*.json'))
    if not metrics_files:
        print("No metrics files found yet...")
        return None
    
    # Load latest metrics
    latest_file = max(metrics_files, key=os.path.getmtime)
    with open(latest_file, 'r') as f:
        metrics = json.load(f)
    
    if metrics:
        latest = metrics[-1]
        print(f"Latest epoch: {latest.get('epoch', 'N/A')}")
        print(f"Train loss: {latest.get('train_loss', 0):.4f}")
        print(f"Valid loss: {latest.get('val_loss', 0):.4f}")
        
        if 'task_metrics' in latest:
            print("\nTask metrics:")
            for task, task_metrics in latest['task_metrics'].items():
                print(f"  {task}: RMSE={task_metrics['rmse']:.3f}, R²={task_metrics['r2']:.3f}")
    
    return metrics

# Check progress
print("Training Progress:")
print("-" * 50)
metrics = check_training_progress(CONFIG['log_dir'])

# Plot training curves if available
if metrics and len(metrics) > 1:
    epochs = [m['epoch'] for m in metrics]
    train_losses = [m.get('train_loss', 0) for m in metrics]
    val_losses = [m.get('val_loss', 0) for m in metrics]
    
    plt.figure(figsize=(10, 5))
    plt.plot(epochs, train_losses, label='Train Loss')
    plt.plot(epochs, val_losses, label='Valid Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Progress')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()

Training Progress:
--------------------------------------------------
No metrics files found yet...


In [6]:


# ==================================================================================
# CELL 6: LAUNCH CROSS-VALIDATION (MULTI-GPU)
# ==================================================================================
print("="*70)
print("LAUNCHING CROSS-VALIDATION ON 16 GPUs")
print("="*70)
print(f"Number of folds: {CONFIG['training_config']['n_folds']}")
print(f"Total batch size: {CONFIG['training_config']['batch_size_per_gpu'] * 16}")
print("="*70)

return_code = launch_ddp_training(config_path, mode='cv', n_gpus=16)

if return_code == 0:
    print("\n✓ Cross-validation completed successfully!")
else:
    print(f"\n⚠ Cross-validation finished with return code {return_code}")

LAUNCHING CROSS-VALIDATION ON 16 GPUs
Number of folds: 5
Total batch size: 8192
Launching cv with 16 GPUs...
Command: torchrun --nproc_per_node 16 --master_port 12355 ../gnn_dta_mtl/training/launch_training.py --config ../output/experiments/ddp_training_20250926_074933/config.json --mode cv --n_gpus 16
W0926 08:06:24.358000 56009 site-packages/torch/distributed/run.py:766]
W0926 08:06:24.358000 56009 site-packages/torch/distributed/run.py:766] *****************************************
W0926 08:06:24.358000 56009 site-packages/torch/distributed/run.py:766] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0926 08:06:24.358000 56009 site-packages/torch/distributed/run.py:766] *****************************************
/opt/conda/bin/python3.10: can't open file '/home/HX46_FR5/github/APEX2/gnn_dta_mtl/notebooks/../gnn_dta_mtl/traini

In [7]:


# ==================================================================================
# CELL 7: CHECK RESULTS
# ==================================================================================
def check_training_results(checkpoint_dir):
    """Check final training results."""
    checkpoint_dir = Path(checkpoint_dir)
    
    # Check for best model
    best_model_path = checkpoint_dir / 'best_model.pt'
    if best_model_path.exists():
        checkpoint = torch.load(best_model_path, map_location='cpu')
        print("✓ Best Model Found!")
        print(f"  Epoch: {checkpoint['epoch']}")
        print(f"  Validation Loss: {checkpoint.get('val_loss', 'N/A'):.4f}")
        
        if 'task_metrics' in checkpoint:
            print("\n  Task Metrics:")
            for task, metrics in checkpoint['task_metrics'].items():
                print(f"    {task}: RMSE={metrics['rmse']:.3f}, R²={metrics['r2']:.3f}")
    
    # Check for CV results
    cv_results_path = checkpoint_dir / 'cv_results.json'
    if cv_results_path.exists():
        with open(cv_results_path, 'r') as f:
            cv_results = json.load(f)
        
        print("\n✓ Cross-Validation Results:")
        for task, metrics in cv_results.items():
            print(f"  {task}:")
            print(f"    R²: {metrics['r2_mean']:.3f} ± {metrics['r2_std']:.3f}")
            print(f"    RMSE: {metrics['rmse_mean']:.3f} ± {metrics['rmse_std']:.3f}")

print("Final Results:")
print("="*70)
check_training_results(CONFIG['checkpoint_dir'])

Final Results:


In [8]:


# ==================================================================================
# CELL 8: LOAD BEST MODEL FOR INFERENCE
# ==================================================================================
def load_best_model(checkpoint_dir, model_config, task_cols):
    """Load the best trained model."""
    # Create model
    model = MTL_DTAModel(
        task_names=task_cols,
        **model_config
    )
    
    # Load checkpoint
    checkpoint_path = Path(checkpoint_dir) / 'best_model.pt'
    if checkpoint_path.exists():
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        model.load_state_dict(checkpoint['model_state_dict'])
        print(f"✓ Loaded model from epoch {checkpoint['epoch']}")
        return model
    else:
        print("⚠ No best model found!")
        return None

# Load the best model
best_model = load_best_model(
    CONFIG['checkpoint_dir'],
    CONFIG['model_config'],
    CONFIG['task_cols']
)

if best_model:
    # Move to GPU for inference
    best_model = best_model.cuda()
    best_model.eval()
    print("✓ Model ready for inference")

⚠ No best model found!


In [9]:


# ==================================================================================
# CELL 9: INFERENCE EXAMPLE
# ==================================================================================
if best_model:
    # Example: Make predictions on a sample
    print("\nPreparing sample for inference...")
    
    # Take a small sample from your data
    sample_df = df.sample(n=10, random_state=42)
    
    # Build dataset
    sample_dataset = build_mtl_dataset_optimized(
        sample_df, 
        chunk_loader, 
        CONFIG['task_cols']
    )
    
    # Create loader
    sample_loader = torch_geometric.loader.DataLoader(
        sample_dataset,
        batch_size=10,
        shuffle=False
    )
    
    # Make predictions
    best_model.eval()
    with torch.no_grad():
        for batch in sample_loader:
            xd = batch['drug'].cuda()
            xp = batch['protein'].cuda()
            y_true = batch['y'].cuda()
            
            # Predict
            y_pred = best_model(xd, xp)
            
            # Display results
            print("\nSample predictions:")
            for i, task in enumerate(CONFIG['task_cols']):
                if not torch.isnan(y_true[0, i]):
                    print(f"{task}:")
                    print(f"  True: {y_true[0, i].item():.3f}")
                    print(f"  Pred: {y_pred[0, i].item():.3f}")
            break

print("\n✓ All done!")


✓ All done!


In [5]:
len(df)

406088

# Drop duplicates

In [18]:
import pandas as pd
import numpy as np
from Bio import PDB
from Bio.PDB import PDBParser
from collections import Counter, defaultdict
from multiprocessing import Pool, cpu_count
from functools import partial
import warnings
from tqdm import tqdm
import hashlib
warnings.filterwarnings('ignore')

# ============================================================================
# PARALLEL PDB PROCESSING FUNCTIONS
# ============================================================================

def extract_sequence_from_pdb_single(pdb_path):
    """
    Extract sequences from a single PDB file for each chain.
    This function is designed to be run in parallel.
    
    Returns: (pdb_path, list of sequences)
    """
    try:
        parser = PDBParser(QUIET=True)
        structure = parser.get_structure('protein', pdb_path)
        
        sequences = []
        for model in structure:
            for chain in model:
                seq = []
                for residue in chain:
                    if PDB.is_aa(residue):
                        res_name = residue.get_resname()
                        # Convert 3-letter code to 1-letter code
                        three_to_one = {
                            'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E',
                            'PHE': 'F', 'GLY': 'G', 'HIS': 'H', 'ILE': 'I',
                            'LYS': 'K', 'LEU': 'L', 'MET': 'M', 'ASN': 'N',
                            'PRO': 'P', 'GLN': 'Q', 'ARG': 'R', 'SER': 'S',
                            'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y'
                        }
                        if res_name in three_to_one:
                            seq.append(three_to_one[res_name])
                
                if seq:
                    sequences.append(''.join(seq))
        
        return (pdb_path, sequences)
    except Exception as e:
        print(f"Error processing {pdb_path}: {e}")
        return (pdb_path, [])

def parallel_extract_pdb_sequences(pdb_paths, n_workers=None):
    """
    Extract sequences from multiple PDB files in parallel.
    
    Args:
        pdb_paths: List of PDB file paths
        n_workers: Number of parallel workers (None = use all CPUs)
    
    Returns:
        Dictionary mapping pdb_path to sequences
    """
    if n_workers is None:
        n_workers = cpu_count()
    
    print(f"Processing {len(pdb_paths)} PDB files with {n_workers} workers...")
    
    with Pool(n_workers) as pool:
        results = list(tqdm(
            pool.imap(extract_sequence_from_pdb_single, pdb_paths),
            total=len(pdb_paths),
            desc="Extracting PDB sequences"
        ))
    
    # Convert results to dictionary
    pdb_to_sequences = {pdb_path: sequences for pdb_path, sequences in results}
    return pdb_to_sequences

def create_global_sequence_mapping(pdb_to_sequences):
    """
    Create a global mapping of sequences to IDs across all PDBs.
    
    Args:
        pdb_to_sequences: Dictionary mapping PDB paths to their sequences
    
    Returns:
        seq_to_id: Dictionary mapping sequences to IDs
        id_to_seq: Dictionary mapping IDs to sequences
    """
    seq_to_id = {}
    id_to_seq = {}
    
    # Collect all unique sequences
    all_sequences = set()
    for sequences in pdb_to_sequences.values():
        all_sequences.update(sequences)
    
    # Assign IDs to unique sequences
    for i, seq in enumerate(sorted(all_sequences)):
        if i < 26:
            seq_id = chr(65 + i)  # A, B, C, ...
        else:
            # Use two letters for more than 26 sequences
            seq_id = f"{chr(65 + (i // 26 - 1))}{chr(65 + (i % 26))}"
        
        seq_to_id[seq] = seq_id
        id_to_seq[seq_id] = seq
    
    return seq_to_id, id_to_seq

def create_sequence_id_for_pdb(sequences, seq_to_id):
    """
    Create a sequence ID string for a PDB based on its sequences.
    E.g., "2A4B" means 2 of sequence A and 4 of sequence B
    """
    if not sequences:
        return "EMPTY"
    
    # Count occurrences of each sequence
    seq_counter = Counter(sequences)
    
    # Map sequences to IDs and count them
    id_counts = defaultdict(int)
    for seq, count in seq_counter.items():
        if seq in seq_to_id:
            id_counts[seq_to_id[seq]] += count
    
    # Create the ID string (alphabetically sorted)
    id_string = ''
    for seq_id in sorted(id_counts.keys()):
        count = id_counts[seq_id]
        if count == 1:
            id_string += seq_id
        else:
            id_string += f"{count}{seq_id}"
    
    return id_string

# ============================================================================
# PARALLEL INCHIKEY PROCESSING FUNCTIONS
# ============================================================================

def process_inchikey_batch(batch_data):
    """
    Process a batch of InChIKeys in parallel.
    This function can be extended to perform additional InChIKey processing.
    
    Args:
        batch_data: Tuple of (indices, inchikeys, smiles)
    
    Returns:
        List of tuples (index, processed_inchikey, additional_data)
    """
    indices, inchikeys, smiles = batch_data
    results = []
    
    for idx, inchikey, smile in zip(indices, inchikeys, smiles):
        # Here you can add more complex InChIKey processing
        # For now, we just validate and potentially standardize
        processed_inchikey = inchikey if pd.notna(inchikey) else f"UNKNOWN_{idx}"
        
        # You can add molecular property calculations here
        additional_data = {
            'inchikey_hash': hashlib.md5(str(inchikey).encode()).hexdigest()[:8],
            'has_valid_inchikey': pd.notna(inchikey)
        }
        
        results.append((idx, processed_inchikey, additional_data))
    
    return results

def parallel_process_inchikeys(df, n_workers=None, batch_size=10000):
    """
    Process InChIKeys in parallel batches.
    
    Args:
        df: DataFrame containing InChIKey column
        n_workers: Number of parallel workers
        batch_size: Size of each processing batch
    
    Returns:
        DataFrame with processed InChIKey information
    """
    if n_workers is None:
        n_workers = cpu_count()
    
    print(f"Processing {len(df)} InChIKeys with {n_workers} workers...")
    
    # Create batches
    batches = []
    for i in range(0, len(df), batch_size):
        batch_df = df.iloc[i:i+batch_size]
        batches.append((
            batch_df.index.tolist(),
            batch_df['InChIKey'].tolist(),
            batch_df['std_smiles'].tolist() if 'std_smiles' in df.columns else [None] * len(batch_df)
        ))
    
    # Process batches in parallel
    with Pool(n_workers) as pool:
        batch_results = list(tqdm(
            pool.imap(process_inchikey_batch, batches),
            total=len(batches),
            desc="Processing InChIKeys"
        ))
    
    # Flatten results and create mapping
    all_results = []
    for batch_result in batch_results:
        all_results.extend(batch_result)
    
    # Sort by index to maintain order
    all_results.sort(key=lambda x: x[0])
    
    # Add results to dataframe
    for idx, processed_inchikey, additional_data in all_results:
        df.loc[idx, 'processed_inchikey'] = processed_inchikey
        for key, value in additional_data.items():
            df.loc[idx, key] = value
    
    return df

# ============================================================================
# PARALLEL DEDUPLICATION FUNCTIONS
# ============================================================================

def process_duplicate_group(group_data, response_cols, priority_func):
    """
    Process a single group of duplicate complexes.
    
    Args:
        group_data: Tuple of (complex_id, group_dataframe)
        response_cols: List of response column names
        priority_func: Function to calculate priority
    
    Returns:
        Processed row as dictionary
    """
    complex_id, group = group_data
    
    if len(group) == 1:
        # No duplicates, keep as is
        row_dict = group.iloc[0].to_dict()
        row_dict['num_merged'] = 1
        return row_dict
    else:
        # Multiple rows for same complex
        group = group.copy()
        group['priority'] = group['source_file'].apply(priority_func)
        group = group.sort_values('priority')
        
        # Take the row with highest priority as base
        best_row = group.iloc[0].to_dict()
        
        # Merge response values from all rows in the group
        for col in response_cols:
            values = group[col].dropna().values
            
            if len(values) == 0:
                best_row[col] = np.nan
            elif len(values) == 1:
                best_row[col] = values[0]
            elif len(values) == 2:
                best_row[col] = min(values)
            else:  # 3 or more values
                best_row[col] = np.median(values)
        
        # Add merge information
        best_row['merged_from_sources'] = ','.join(group['source_file'].unique())
        best_row['num_merged'] = len(group)
        
        return best_row

def parallel_remove_duplicates(df, n_workers=None):
    """
    Remove duplicates iteratively based on complex (protein+ligand).
    Creates and returns a new dataframe without modifying the original.
    
    Args:
        df: DataFrame with sequence_id and InChIKey columns
    
    Returns:
        result_df: New deduplicated DataFrame
    """
    # Create a copy to avoid modifying original
    df_copy = df.copy()
    
    # Define response value columns
    response_cols = ['pKd (Wang, FEP)', 'pKd', 'pIC50', 'potency', 'pEC50', 'pKi']
    
    # Create complex identifier
    df_copy['complex_id'] = df_copy['InChIKey'] + '_' + df_copy['sequence_id']
    
    # Group by complex_id
    grouped = df_copy.groupby('complex_id')
    
    print(f"Processing {len(grouped)} unique complexes...")
    
    # Define priority function
    def get_priority(source):
        priority_map = {
            'PDBbind2020': 1,
            'FEP_Wang_2015': 2,
            'FEP_Zariquiey_extended_Wang_2015': 3,
            'HiQBind': 4,
            'BioLip2': 5,
            'processed_data': 6,
            'BindingNetv2': 7,
            'BindingNetv1': 8
        }
        return priority_map.get(source, 999)
    
    # Function to merge response values
    def merge_response_values(group, response_cols):
        """
        Merge response values from a group of duplicate rows.
        For each response type, select median if >=3 datapoints, else minimum.
        """
        merged = {}
        
        for col in response_cols:
            # Collect non-NaN values for this response type
            values = group[col].dropna().values
            
            if len(values) == 0:
                merged[col] = np.nan
            elif len(values) == 1:
                merged[col] = values[0]
            elif len(values) == 2:
                merged[col] = min(values)
            else:  # 3 or more values
                merged[col] = np.median(values)
        
        return merged
    
    # List to store processed rows
    result_rows = []
    
    # Iterate through each group
    for complex_id, group in tqdm(grouped, desc="Deduplicating complexes"):
        if len(group) == 1:
            # No duplicates, keep as is
            row_dict = group.iloc[0].to_dict()
            row_dict['num_merged'] = 1
            row_dict.pop('complex_id', None)  # Remove complex_id from result
            result_rows.append(row_dict)
        else:
            # Multiple rows for same complex
            # Create a copy of the group to avoid warnings
            group = group.copy()
            
            # Add priority column
            group['priority'] = group['source_file'].apply(get_priority)
            
            # Sort by priority (lowest number = highest priority)
            group = group.sort_values('priority')
            
            # Take the row with highest priority as base
            best_row = group.iloc[0].to_dict()
            
            # Merge response values from all rows in the group
            merged_values = merge_response_values(group, response_cols)
            
            # Update the best row with merged values
            for col, value in merged_values.items():
                best_row[col] = value
            
            # Add merge information
            best_row['merged_from_sources'] = ','.join(group['source_file'].unique())
            best_row['num_merged'] = len(group)
            
            # Remove temporary columns
            best_row.pop('priority', None)
            best_row.pop('complex_id', None)
            
            result_rows.append(best_row)
    
    # Create new dataframe from results
    result_df = pd.DataFrame(result_rows)
    
    # Ensure column order matches original (plus new columns)
    original_cols = [col for col in df.columns if col in result_df.columns]
    new_cols = [col for col in result_df.columns if col not in df.columns]
    result_df = result_df[original_cols + new_cols]
    
    print(f"Deduplication complete: {len(df)} -> {len(result_df)} rows")
    print(f"Removed {len(df) - len(result_df)} duplicate rows ({(1 - len(result_df)/len(df))*100:.2f}% reduction)")
    
    return result_df

# ============================================================================
# MAIN PIPELINE FUNCTION
# ============================================================================



# ============================================================================
# UTILITY FUNCTIONS
# ============================================================================

def print_summary_statistics(original_df, final_df):
    """
    Print detailed summary statistics about the deduplication process.
    """
    print("\n" + "="*60)
    print("SUMMARY STATISTICS")
    print("="*60)
    
    print(f"\nOriginal dataset:")
    print(f"  - Total rows: {len(original_df):,}")
    print(f"  - Unique proteins: {original_df['standardized_protein_pdb'].nunique():,}")
    print(f"  - Unique ligands: {original_df['InChIKey'].nunique():,}")
    
    print(f"\nProcessed dataset:")
    print(f"  - Total rows: {len(final_df):,}")
    print(f"  - Rows removed: {len(original_df) - len(final_df):,}")
    print(f"  - Reduction: {(1 - len(final_df)/len(original_df))*100:.2f}%")
    
    if 'num_merged' in final_df.columns:
        merged_stats = final_df['num_merged'].value_counts().sort_index()
        print(f"\nMerging statistics:")
        for num, count in merged_stats.items():
            print(f"  - {count:,} complexes {'kept as-is' if num == 1 else f'merged from {num} sources'}")
    
    print("\nResponse value coverage:")
    response_cols = ['pKd (Wang, FEP)', 'pKd', 'pIC50', 'potency', 'pEC50', 'pKi']
    for col in response_cols:
        if col in final_df.columns:
            coverage = (final_df[col].notna().sum() / len(final_df)) * 100
            print(f"  - {col}: {coverage:.2f}% coverage ({final_df[col].notna().sum():,} values)")
    
    print("\n" + "="*60)

# ============================================================================
# EXAMPLE USAGE
# ============================================================================



In [None]:
### print("="*60)
print("STARTING PARALLEL PROCESSING PIPELINE")
print("="*60)
print(f"Initial dataframe shape: {df.shape}")
# Step 1: Parallel PDB sequence extraction
print("\n" + "="*60)
print("STEP 1: PARALLEL PDB SEQUENCE EXTRACTION")
print("="*60)
pdb_paths = df['standardized_protein_pdb'].unique().tolist()
pdb_to_sequences = parallel_extract_pdb_sequences(pdb_paths, cpu_count())


In [None]:
# Create global sequence mapping
seq_to_id, id_to_seq = create_global_sequence_mapping(pdb_to_sequences)
print(f"Found {len(seq_to_id)} unique protein sequences")
# Apply sequence IDs to dataframe
sequence_ids = []
for _, row in tqdm(df.iterrows()):
    pdb_path = row['standardized_protein_pdb']
    sequences = pdb_to_sequences.get(pdb_path, [])
    seq_id = create_sequence_id_for_pdb(sequences, seq_to_id)
    sequence_ids.append(seq_id)
df['sequence_id'] = sequence_ids
# Create sequence mapping dataframe
seq_mapping = pd.DataFrame([
    {'sequence_id': sid, 'sequence': seq} 
    for sid, seq in id_to_seq.items()
])


In [None]:
# Step 2: Parallel InChIKey processing
print("\n" + "="*60)
print("STEP 2: PARALLEL INCHIKEY PROCESSING")
print("="*60)
df = parallel_process_inchikeys(df, cpu_count())
# Step 3: Parallel deduplication
print("\n" + "="*60)
print("STEP 3: PARALLEL DEDUPLICATION")
print("="*60)


In [17]:
df

Unnamed: 0,protein_pdb_path,ligand_sdf_path,smiles,pKi,source_file,is_experimental,resolution,pEC50,"pKd (Wang, FEP)",pKd,...,LE_potency,LEnorm_potency,LE,LE_norm,carbon_count,sequence_id,processed_inchikey,inchikey_hash,has_valid_inchikey,complex_id
0,../data/raw/BindingNetv2/high/target_CHEMBL390...,../data/raw/BindingNetv2/high/target_CHEMBL390...,CCCCCCSCC(NC(=O)CCC(N)C(=O)O)C(=O)NCCC(=O)O,3.259637,BindingNetv2,False,,,,,...,,,0.120727,0.000298,17,2ƘC,RILVFYFKIDXJNY-UHFFFAOYSA-M,2f8150df,True,RILVFYFKIDXJNY-UHFFFAOYSA-M_2ƘC
1,../data/raw/BindingNetv2/moderate/target_CHEMB...,../data/raw/BindingNetv2/moderate/target_CHEMB...,NC(CCC(=O)NC(CSCc1ccccc1)C(=O)NC(C(=O)O)c1cccc...,6.376751,BindingNetv2,False,,,,,...,,,0.193235,0.000409,23,2ƘC,ZPSKWMFLCHMEOY-UHFFFAOYSA-M,71a014c7,True,ZPSKWMFLCHMEOY-UHFFFAOYSA-M_2ƘC
2,../data/raw/BindingNetv2/high/target_CHEMBL390...,../data/raw/BindingNetv2/high/target_CHEMBL390...,Cc1ccc(CSCC(NC(=O)CCC(N)C(=O)O)C(=O)NCCC(=O)O)cc1,4.397940,BindingNetv2,False,,,,,...,,,0.151653,0.000357,19,2ƉY,MBXWAPNNAOGFPH-UHFFFAOYSA-M,1cba16d5,True,MBXWAPNNAOGFPH-UHFFFAOYSA-M_2ƉY
3,../data/raw/BindingNetv2/high/target_CHEMBL390...,../data/raw/BindingNetv2/high/target_CHEMBL390...,NC(CCC(=O)NC(CSCc1ccc(Cl)cc1)C(=O)NC(C(=O)O)c1...,6.920819,BindingNetv2,False,,,,,...,,,0.203553,0.000401,24,2ƘC,BXJSPWKYSSRFEB-UHFFFAOYSA-M,a823c50d,True,BXJSPWKYSSRFEB-UHFFFAOYSA-M_2ƘC
4,../data/raw/BindingNetv2/high/target_CHEMBL390...,../data/raw/BindingNetv2/high/target_CHEMBL390...,NC(CCC(=O)NC(CSCc1ccccc1)C(=O)NCCC(=O)O)C(=O)O,3.148742,BindingNetv2,False,,,,,...,,,0.112455,0.000274,18,2ŘN,QLVGMERIDWMEBM-UHFFFAOYSA-M,46313ec9,True,QLVGMERIDWMEBM-UHFFFAOYSA-M_2ŘN
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
406083,../data/raw/BioLip2/biolip_downloads/biolip_re...,../data/raw/BioLip2/biolip_downloads/biolip_re...,CC1CCCC(N2NC(C(C)(C)C)C[C@@H]2N[C@@H](O)NC2CCC...,,BioLip2,True,,,,7.130768,...,,,0.187652,0.000351,29,¦Q,NACQWABIDVXOMV-VFCRXLDWSA-N,b325f1a8,True,NACQWABIDVXOMV-VFCRXLDWSA-N_¦Q
406084,../data/raw/BioLip2/biolip_downloads/biolip_re...,../data/raw/BioLip2/biolip_downloads/biolip_re...,OC[C@H]1O[C@H](O[PH](O)(O)O[PH](O)(O)OC[C@H]2O...,,BioLip2,True,,,,3.124939,...,,,0.086804,0.000151,15,ĔV,PEEYOHTXAULSGT-KPLOLNPKSA-N,ba696fe5,True,PEEYOHTXAULSGT-KPLOLNPKSA-N_ĔV
406085,../data/raw/BioLip2/biolip_downloads/biolip_re...,../data/raw/BioLip2/biolip_downloads/biolip_re...,C[C@H]1S[C@H]2NC(N)N[C@@H](O)[C@@H]2C1SC1CCC(C...,,BioLip2,True,,,,7.698970,...,,,0.248354,0.000515,19,ȚS,VYTCQXDFOQVDLB-WZBLSFTPSA-N,3c3e05eb,True,VYTCQXDFOQVDLB-WZBLSFTPSA-N_ȚS
406086,../data/raw/BioLip2/biolip_downloads/biolip_re...,../data/raw/BioLip2/biolip_downloads/biolip_re...,CCCCCCC1CCC(C(O)NNC(S)NC)O1,,BioLip2,True,,,,4.096910,...,,,0.215627,0.000740,13,ǋO,QDYXIFNPTRANPQ-UHFFFAOYSA-N,217020ae,True,QDYXIFNPTRANPQ-UHFFFAOYSA-N_ǋO


In [19]:
final_df = parallel_remove_duplicates(df, cpu_count())
print("\n" + "="*60)
print("PROCESSING COMPLETE")
print("="*60)
print(f"Final dataframe shape: {final_df.shape}")
print(f"Removed {len(df) - len(final_df)} duplicate rows")
print(f"Reduction: {(1 - len(final_df)/len(df))*100:.2f}%")


Processing 395836 unique complexes...


Deduplicating complexes: 100%|██████████| 395836/395836 [01:30<00:00, 4389.68it/s]


Deduplication complete: 406088 -> 395836 rows
Removed 10252 duplicate rows (2.52% reduction)

PROCESSING COMPLETE
Final dataframe shape: (395836, 48)
Removed 10252 duplicate rows
Reduction: 2.52%


In [21]:
print_summary_statistics(df, final_df)

final_df.to_csv('deduplicated_complexes_parallel.csv', index=False)
seq_mapping.to_csv('sequence_mapping.csv', index=False)
    
print("Pipeline loaded successfully!")
print("To use: final_df, seq_mapping = process_dataframe_parallel(df)")


SUMMARY STATISTICS

Original dataset:
  - Total rows: 406,088
  - Unique proteins: 406,088
  - Unique ligands: 306,191

Processed dataset:
  - Total rows: 395,836
  - Rows removed: 10,252
  - Reduction: 2.52%

Merging statistics:
  - 388,888 complexes kept as-is
  - 5,986 complexes merged from 2 sources
  - 558 complexes merged from 3 sources
  - 185 complexes merged from 4 sources
  - 40 complexes merged from 5 sources
  - 40 complexes merged from 6 sources
  - 18 complexes merged from 7 sources
  - 24 complexes merged from 8 sources
  - 15 complexes merged from 9 sources
  - 6 complexes merged from 10 sources
  - 2 complexes merged from 11 sources
  - 12 complexes merged from 12 sources
  - 9 complexes merged from 13 sources
  - 7 complexes merged from 14 sources
  - 2 complexes merged from 15 sources
  - 2 complexes merged from 16 sources
  - 5 complexes merged from 17 sources
  - 2 complexes merged from 18 sources
  - 2 complexes merged from 19 sources
  - 1 complexes merged from 

NameError: name 'seq_mapping' is not defined

In [24]:
final_df

Unnamed: 0,protein_pdb_path,ligand_sdf_path,smiles,pKi,source_file,is_experimental,resolution,pEC50,"pKd (Wang, FEP)",pKd,...,LEnorm_potency,LE,LE_norm,carbon_count,sequence_id,processed_inchikey,inchikey_hash,has_valid_inchikey,num_merged,merged_from_sources
0,../data/raw/BindingNetv2/moderate/target_CHEMB...,../data/raw/BindingNetv2/moderate/target_CHEMB...,CCOc1cc2ncc(C#N)c(Nc3ccc(OCc4nc5ccccc5s4)c(Cl)...,,BindingNetv2,False,,,,,...,,0.167298,0.000272,33,QS,AAAAZQPHATYWOK-JXMROGBWSA-O,4f129304,True,1,
1,../data/raw/BindingNetv2/moderate/target_CHEMB...,../data/raw/BindingNetv2/moderate/target_CHEMB...,CCOc1cc2ncc(C#N)c(Nc3ccc(OCc4nc5ccccc5s4)c(Cl)...,,BindingNetv2,False,,,,,...,,0.156415,0.000255,33,ãK,AAAAZQPHATYWOK-JXMROGBWSA-O,4f129304,True,1,
2,../data/raw/BindingNetv2/high/target_CHEMBL258...,../data/raw/BindingNetv2/high/target_CHEMBL258...,Cc1ccc(O)cc1Nc1cc(N2CCOCC2)nc(-n2cnc3ccccc32)n1,,BindingNetv2,False,,,,,...,,0.228776,0.000568,22,ȭG,AAABTPAECTZDET-UHFFFAOYSA-N,62568322,True,1,
3,../data/raw/BindingNetv2/high/target_CHEMBL297...,../data/raw/BindingNetv2/high/target_CHEMBL297...,COC(=O)C[C@@H](NC(=O)c1ccc(-c2cn[nH]c2)c(C)c1)...,,BindingNetv2,False,,,,,...,,0.264950,0.000729,21,ǱN,AAACGYYPWMUUFL-LJQANCHMSA-N,b8ce4f00,True,1,
4,../data/raw/BindingNetv2/high/target_CHEMBL323...,../data/raw/BindingNetv2/high/target_CHEMBL323...,COC(=O)C[C@@H](NC(=O)c1ccc(-c2cn[nH]c2)c(C)c1)...,,BindingNetv2,False,,,,,...,,0.236761,0.000651,21,ȍV,AAACGYYPWMUUFL-LJQANCHMSA-N,b8ce4f00,True,1,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
395831,../data/raw/PDBbind2020/PDBbind2020/main/refin...,../data/raw/PDBbind2020/PDBbind2020/main/refin...,[H]/N=C(\NP(=O)(O[H])O[H])N([H])C([H])([H])C([...,,PDBbind2020,True,1.85,,,3.070581,...,,0.191911,0.000755,6,ÆIăS,,d41d8cd9,True,1,
395832,../data/raw/PDBbind2020/PDBbind2020/main/refin...,../data/raw/PDBbind2020/PDBbind2020/main/refin...,[H]/N=C(\NC(=O)C1:N:C(C2:C([H]):N:N(C([H])([H]...,6.873234,PDBbind2020,True,1.90,,,,...,,0.242082,0.000679,16,ôX,,d41d8cd9,True,6,PDBbind2020
395833,../data/raw/PDBbind2020/PDBbind2020/main/refin...,../data/raw/PDBbind2020/PDBbind2020/main/refin...,[H]/N=C(\NC(=O)C1:N:C(Cl):C(N2C([H])([H])C([H]...,5.867740,PDBbind2020,True,1.70,,,,...,,0.279416,0.000899,13,ôY,,d41d8cd9,True,2,PDBbind2020
395834,../data/raw/PDBbind2020/PDBbind2020/main/refin...,../data/raw/PDBbind2020/PDBbind2020/main/refin...,[H]/N=C(\NC(=O)C1:N:C(Cl):C(N([H])[H]):N:C:1N(...,5.275724,PDBbind2020,True,2.10,,,,...,,0.351715,0.001538,7,õB,,d41d8cd9,True,1,


# Task range

In [6]:
df = pd.read_csv("deduplicated_complexes_parallel.csv")

In [7]:
df

Unnamed: 0,protein_pdb_path,ligand_sdf_path,smiles,pKi,source_file,is_experimental,resolution,pEC50,"pKd (Wang, FEP)",pKd,...,LEnorm_potency,LE,LE_norm,carbon_count,sequence_id,processed_inchikey,inchikey_hash,has_valid_inchikey,num_merged,merged_from_sources
0,../data/raw/BindingNetv2/moderate/target_CHEMB...,../data/raw/BindingNetv2/moderate/target_CHEMB...,CCOc1cc2ncc(C#N)c(Nc3ccc(OCc4nc5ccccc5s4)c(Cl)...,,BindingNetv2,False,,,,,...,,0.167298,0.000272,33,QS,AAAAZQPHATYWOK-JXMROGBWSA-O,4f129304,True,1,
1,../data/raw/BindingNetv2/moderate/target_CHEMB...,../data/raw/BindingNetv2/moderate/target_CHEMB...,CCOc1cc2ncc(C#N)c(Nc3ccc(OCc4nc5ccccc5s4)c(Cl)...,,BindingNetv2,False,,,,,...,,0.156415,0.000255,33,ãK,AAAAZQPHATYWOK-JXMROGBWSA-O,4f129304,True,1,
2,../data/raw/BindingNetv2/high/target_CHEMBL258...,../data/raw/BindingNetv2/high/target_CHEMBL258...,Cc1ccc(O)cc1Nc1cc(N2CCOCC2)nc(-n2cnc3ccccc32)n1,,BindingNetv2,False,,,,,...,,0.228776,0.000568,22,ȭG,AAABTPAECTZDET-UHFFFAOYSA-N,62568322,True,1,
3,../data/raw/BindingNetv2/high/target_CHEMBL297...,../data/raw/BindingNetv2/high/target_CHEMBL297...,COC(=O)C[C@@H](NC(=O)c1ccc(-c2cn[nH]c2)c(C)c1)...,,BindingNetv2,False,,,,,...,,0.264950,0.000729,21,ǱN,AAACGYYPWMUUFL-LJQANCHMSA-N,b8ce4f00,True,1,
4,../data/raw/BindingNetv2/high/target_CHEMBL323...,../data/raw/BindingNetv2/high/target_CHEMBL323...,COC(=O)C[C@@H](NC(=O)c1ccc(-c2cn[nH]c2)c(C)c1)...,,BindingNetv2,False,,,,,...,,0.236761,0.000651,21,ȍV,AAACGYYPWMUUFL-LJQANCHMSA-N,b8ce4f00,True,1,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
395831,../data/raw/PDBbind2020/PDBbind2020/main/refin...,../data/raw/PDBbind2020/PDBbind2020/main/refin...,[H]/N=C(\NP(=O)(O[H])O[H])N([H])C([H])([H])C([...,,PDBbind2020,True,1.85,,,3.070581,...,,0.191911,0.000755,6,ÆIăS,,d41d8cd9,True,1,
395832,../data/raw/PDBbind2020/PDBbind2020/main/refin...,../data/raw/PDBbind2020/PDBbind2020/main/refin...,[H]/N=C(\NC(=O)C1:N:C(C2:C([H]):N:N(C([H])([H]...,6.873234,PDBbind2020,True,1.90,,,,...,,0.242082,0.000679,16,ôX,,d41d8cd9,True,6,PDBbind2020
395833,../data/raw/PDBbind2020/PDBbind2020/main/refin...,../data/raw/PDBbind2020/PDBbind2020/main/refin...,[H]/N=C(\NC(=O)C1:N:C(Cl):C(N2C([H])([H])C([H]...,5.867740,PDBbind2020,True,1.70,,,,...,,0.279416,0.000899,13,ôY,,d41d8cd9,True,2,PDBbind2020
395834,../data/raw/PDBbind2020/PDBbind2020/main/refin...,../data/raw/PDBbind2020/PDBbind2020/main/refin...,[H]/N=C(\NC(=O)C1:N:C(Cl):C(N([H])[H]):N:C:1N(...,5.275724,PDBbind2020,True,2.10,,,,...,,0.351715,0.001538,7,õB,,d41d8cd9,True,1,


In [8]:

# Calculate task ranges for weighting
task_ranges = prepare_mtl_experiment(df, CONFIG['task_cols'])

# Create task statistics
print("\nTask Statistics:")
for task in CONFIG['task_cols']:
    if task in df.columns:
        valid_values = df[task].dropna()
        if len(valid_values) > 0:
            print(f"{task}:")
            print(f"  Count: {len(valid_values)}")
            print(f"  Mean: {valid_values.mean():.2f}")
            print(f"  Std: {valid_values.std():.2f}")
            print(f"  Range: [{valid_values.min():.2f}, {valid_values.max():.2f}]")

Task ranges for weighting:
  pKi: range=10.00, weight=0.0627
  pEC50: range=9.95, weight=0.0629
  pKd (Wang, FEP): range=4.90, weight=0.1278
  pKd: range=10.92, weight=0.0574
  pIC50: range=10.00, weight=0.0627
  potency: range=1.00, weight=0.6266

Task Statistics:
pKi:
  Count: 95436
  Mean: 7.02
  Std: 1.43
  Range: [3.00, 13.00]
pEC50:
  Count: 18266
  Mean: 6.76
  Std: 1.30
  Range: [3.05, 13.00]
pKd (Wang, FEP):
  Count: 262
  Mean: 6.88
  Std: 1.00
  Range: [4.24, 9.14]
pKd:
  Count: 16114
  Mean: 6.52
  Std: 1.53
  Range: [3.00, 13.92]
pIC50:
  Count: 267096
  Mean: 6.81
  Std: 1.31
  Range: [3.00, 13.00]


In [9]:
CONFIG['structure_chunks_dir']

'../input/chunk/'

In [10]:
from rdkit import RDLogger

# Disable RDKit warnings
RDLogger.DisableLog('rdApp.*')


In [11]:
CONFIG['task_cols']

['pKi', 'pEC50', 'pKd (Wang, FEP)', 'pKd', 'pIC50', 'potency']

In [None]:

# Create chunk loader
chunk_loader = StructureChunkLoader(
    chunk_dir=CONFIG['structure_chunks_dir'],
    cache_size=10
)



IOStream.flush timed out


# 4 : Cross-Validation Training

In [12]:
# Cell 9: Cross-Validation Training
# Initialize cross-validator
cv = CrossValidator(
    model_config=CONFIG['model_config'],
    task_cols=CONFIG['task_cols'],
    task_ranges=task_ranges,
    n_folds=CONFIG['training_config']['n_folds'],
    batch_size=CONFIG['training_config']['batch_size'],
    n_epochs=CONFIG['training_config']['n_epochs'],
    learning_rate=CONFIG['training_config']['learning_rate'],
    patience=CONFIG['training_config']['patience'],
    device=device,
    seed=SEED
)


In [None]:

# Run cross-validation
print("\nStarting cross-validation...")
cv_results = cv.run(df, chunk_loader)


In [None]:
# Need to clear memory between each fold

In [None]:
# 100-200 epochs

In [None]:
# Print summary
cv.print_summary()

# Save CV results
save_results(
    cv_results, 
    os.path.join(CONFIG['experiment_dir'], 'results', 'cv_results.pkl'),
    format='pickle'
)

# 4 : Full Training on All Data

In [None]:
# Create train/valid/test splits
from gnn_dta_mtl.datasets import create_data_splits

splits = create_data_splits(
    df,
    split_method='random',  # or 'scaffold', 'protein', 'drug'
    split_frac=[0.7, 0.1, 0.2],
    seed=SEED
)

df_train = splits['train']
df_valid = splits['valid']
df_test = splits['test']

print(f"Train: {len(df_train)}, Valid: {len(df_valid)}, Test: {len(df_test)}")

# Build datasets
train_dataset = build_mtl_dataset_optimized(df_train, chunk_loader, CONFIG['task_cols'])
valid_dataset = build_mtl_dataset_optimized(df_valid, chunk_loader, CONFIG['task_cols'])
test_dataset = build_mtl_dataset_optimized(df_test, chunk_loader, CONFIG['task_cols'])

# Create data loaders
train_loader = torch_geometric.loader.DataLoader(
    train_dataset,
    batch_size=CONFIG['training_config']['batch_size'],
    shuffle=True,
    num_workers=0,
    pin_memory=True
)

valid_loader = torch_geometric.loader.DataLoader(
    valid_dataset,
    batch_size=CONFIG['training_config']['batch_size'],
    shuffle=False,
    num_workers=0,
    pin_memory=True
)

test_loader = torch_geometric.loader.DataLoader(
    test_dataset,
    batch_size=CONFIG['training_config']['batch_size'],
    shuffle=False,
    num_workers=0,
    pin_memory=True
)


In [None]:
os.path.join(CONFIG['experiment_dir'], 'models', 'final_model.pt')

# multi gpu

In [4]:
# Cell 1: Setup and Configuration
import os
import subprocess
import json
import pandas as pd
from pathlib import Path
from datetime import datetime
import time
import threading
from IPython.display import display, clear_output
import matplotlib.pyplot as plt

# Your existing configuration
CONFIG = {
    'data_path': '../input/combined/deduplicated_complexes_parallel.csv',
    'structure_chunks_dir': '../input/chunk/',
    'task_cols': ['pKi', 'pEC50', 'pKd (Wang, FEP)', 'pKd', 'pIC50', 'potency'],
    
    'model_config': {
        'prot_emb_dim': 1280,
        'prot_gcn_dims': [128, 256, 256],
        'prot_fc_dims': [1024, 128],
        'drug_node_in_dim': [66, 1],
        'drug_node_h_dims': [128, 64],
        'drug_fc_dims': [1024, 128],
        'mlp_dims': [1024, 512],
        'mlp_dropout': 0.25
    },
    
    'training_config': {
        'batch_size_per_gpu': 512,  # With 16 GPUs = 8192 total batch size
        'n_epochs': 200,
        'learning_rate': 0.001,
        'patience': 30,
        'n_folds': 5
    },
    
    'seed': 42,
    'gradient_accumulation_steps': 1
}

# Create directories
experiment_name = f'ddp_training_{datetime.now():%Y%m%d_%H%M%S}'
experiment_dir = Path(f'../output/experiments/{experiment_name}')
experiment_dir.mkdir(parents=True, exist_ok=True)

CONFIG['checkpoint_dir'] = str(experiment_dir / 'checkpoints')
CONFIG['log_dir'] = str(experiment_dir / 'logs')

# Save config
config_path = experiment_dir / 'config.json'
with open(config_path, 'w') as f:
    json.dump(CONFIG, f, indent=2)

print(f"Experiment: {experiment_name}")
print(f"Config saved to: {config_path}")
print(f"Using {torch.cuda.device_count()} GPUs")

Experiment: ddp_training_20250926_073940
Config saved to: ../output/experiments/ddp_training_20250926_073940/config.json
Using 16 GPUs


In [6]:
# Cell 2: Launch DDP Training
def launch_ddp_training(config_path, mode='train', n_gpus=16):
    """
    Launch DDP training using torchrun.
    """
    cmd = [
        'torchrun',
        '--nproc_per_node', str(n_gpus),
        '--master_port', '12355',
        '../training/launch_training.py',
        '--config', str(config_path),
        '--mode', mode,
        '--n_gpus', str(n_gpus)
    ]
    
    print(f"Launching {mode} with command:")
    print(' '.join(cmd))
    
    # Launch process
    process = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        universal_newlines=True,
        bufsize=1
    )
    
    return process

# Launch training
training_process = launch_ddp_training(config_path, mode='train', n_gpus=16)

Launching train with command:
torchrun --nproc_per_node 16 --master_port 12355 ../training/launch_training.py --config ../output/experiments/ddp_training_20250926_073940/config.json --mode train --n_gpus 16


In [7]:
# Cell 3: Monitor Training Progress
class TrainingMonitor:
    """Monitor DDP training progress from logs."""
    
    def __init__(self, log_dir, update_interval=5):
        self.log_dir = Path(log_dir)
        self.update_interval = update_interval
        self.metrics_file = self.log_dir / 'metrics_*.json'
        self.running = False
        self.thread = None
        
    def start(self):
        """Start monitoring."""
        self.running = True
        self.thread = threading.Thread(target=self._monitor_loop)
        self.thread.start()
        
    def stop(self):
        """Stop monitoring."""
        self.running = False
        if self.thread:
            self.thread.join()
    
    def _monitor_loop(self):
        """Main monitoring loop."""
        while self.running:
            try:
                # Find latest metrics file
                metrics_files = list(self.log_dir.glob('metrics_*.json'))
                if metrics_files:
                    latest_file = max(metrics_files, key=os.path.getmtime)
                    
                    # Load metrics
                    with open(latest_file, 'r') as f:
                        metrics = json.load(f)
                    
                    if metrics:
                        self._display_progress(metrics)
            except Exception as e:
                print(f"Error reading metrics: {e}")
            
            time.sleep(self.update_interval)
    
    def _display_progress(self, metrics):
        """Display training progress."""
        clear_output(wait=True)
        
        latest = metrics[-1] if metrics else {}
        
        print(f"{'='*60}")
        print(f"TRAINING PROGRESS")
        print(f"{'='*60}")
        
        if 'epoch' in latest:
            print(f"Epoch: {latest['epoch']}")
            print(f"Train Loss: {latest.get('train_loss', 0):.4f}")
            print(f"Valid Loss: {latest.get('val_loss', 0):.4f}")
            print(f"Batch Time: {latest.get('batch_time', 0):.2f}s")
            print(f"GPU Memory: {latest.get('gpu_memory_gb', 0):.1f}GB")
            
            if 'task_metrics' in latest:
                print("\nTask Metrics:")
                for task, task_metrics in latest['task_metrics'].items():
                    print(f"  {task}: RMSE={task_metrics['rmse']:.3f}, R²={task_metrics['r2']:.3f}")
        
        # Plot loss curves
        if len(metrics) > 1:
            epochs = [m['epoch'] for m in metrics]
            train_losses = [m.get('train_loss', 0) for m in metrics]
            val_losses = [m.get('val_loss', 0) for m in metrics]
            
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
            
            ax1.plot(epochs, train_losses, label='Train Loss')
            ax1.plot(epochs, val_losses, label='Valid Loss')
            ax1.set_xlabel('Epoch')
            ax1.set_ylabel('Loss')
            ax1.set_title('Training Progress')
            ax1.legend()
            ax1.grid(True, alpha=0.3)
            
            # Plot R2 scores
            if 'task_metrics' in metrics[-1]:
                tasks = list(metrics[-1]['task_metrics'].keys())
                r2_scores = {task: [] for task in tasks}
                
                for m in metrics:
                    if 'task_metrics' in m:
                        for task in tasks:
                            if task in m['task_metrics']:
                                r2_scores[task].append(m['task_metrics'][task]['r2'])
                
                for task, scores in r2_scores.items():
                    if scores:
                        ax2.plot(epochs[:len(scores)], scores, label=task)
                
                ax2.set_xlabel('Epoch')
                ax2.set_ylabel('R² Score')
                ax2.set_title('Task Performance')
                ax2.legend()
                ax2.grid(True, alpha=0.3)
            
            plt.tight_layout()
            plt.show()

# Start monitoring
monitor = TrainingMonitor(CONFIG['log_dir'])
monitor.start()

# To stop monitoring:
# monitor.stop()

# CV multi

In [None]:
# Cell 4: Launch Cross-Validation
# Stop previous monitoring if running
if 'monitor' in locals():
    monitor.stop()

# Launch CV
cv_process = launch_ddp_training(config_path, mode='cv', n_gpus=16)

# Start monitoring for CV
cv_monitor = TrainingMonitor(CONFIG['log_dir'])
cv_monitor.start()

In [None]:
# Cell 5: Check Training Status and Results
def check_training_status(checkpoint_dir):
    """Check training status from checkpoints."""
    checkpoint_dir = Path(checkpoint_dir)
    
    # Check for best model
    best_model_path = checkpoint_dir / 'best_model.pt'
    if best_model_path.exists():
        checkpoint = torch.load(best_model_path, map_location='cpu')
        print("Best Model Found!")
        print(f"  Epoch: {checkpoint['epoch']}")
        print(f"  Validation Loss: {checkpoint.get('val_loss', 'N/A'):.4f}")
        
        if 'task_metrics' in checkpoint:
            print("\n  Task Metrics:")
            for task, metrics in checkpoint['task_metrics'].items():
                print(f"    {task}: RMSE={metrics['rmse']:.3f}, R²={metrics['r2']:.3f}")
    
    # Check for CV results
    cv_results_path = checkpoint_dir / 'cv_results.json'
    if cv_results_path.exists():
        with open(cv_results_path, 'r') as f:
            cv_results = json.load(f)
        
        print("\nCross-Validation Results:")
        for task, metrics in cv_results.items():
            print(f"  {task}:")
            print(f"    R²: {metrics['r2_mean']:.3f} ± {metrics['r2_std']:.3f}")
            print(f"    RMSE: {metrics['rmse_mean']:.3f} ± {metrics['rmse_std']:.3f}")

check_training_status(CONFIG['checkpoint_dir'])

In [None]:
# Cell 6: Load Best Model for Inference
def load_best_model(checkpoint_dir, model_config, task_cols):
    """Load the best trained model."""
    from gnn_dta_mtl import MTL_DTAModel
    
    # Create model
    model = MTL_DTAModel(
        task_names=task_cols,
        **model_config
    )
    
    # Load checkpoint
    checkpoint_path = Path(checkpoint_dir) / 'best_model.pt'
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    
    print(f"Loaded model from epoch {checkpoint['epoch']}")
    return model

# Load model
best_model = load_best_model(
    CONFIG['checkpoint_dir'],
    CONFIG['model_config'],
    CONFIG['task_cols']
)

# Move to GPU for inference
best_model = best_model.cuda()
best_model.eval()

# Cell 4: Launch Cross-Validation
# Stop previous monitoring if running
if 'monitor' in locals():
    monitor.stop()

# Launch CV
cv_process = launch_ddp_training(config_path, mode='cv', n_gpus=16)

# Start monitoring for CV
cv_monitor = TrainingMonitor(CONFIG['log_dir'])
cv_monitor.start()# 6 : Evaluate Model

In [None]:
# Evaluate on test set (for full training)
if 'test_loader' in locals():
    print("\nEvaluating on test set...")
    test_results = evaluate_model(model, test_loader, CONFIG['task_cols'], device)
    
    # Print test results
    print("\nTest Results:")
    for task, metrics in test_results.items():
        print(f"\n{task}:")
        for metric_name, value in metrics.items():
            if isinstance(value, float):
                print(f"  {metric_name}: {value:.4f}")
    
    # Save test results
    save_results(
        test_results,
        os.path.join(CONFIG['experiment_dir'], 'results', 'test_results.json'),
        format='json'
    )

# For cross-validation results
else:
    print("\nCross-Validation Summary:")
    summary_df = create_summary_report(
        cv_results,
        CONFIG['task_cols'],
        os.path.join(CONFIG['experiment_dir'], 'results', 'cv_summary.csv')
    )
    print(summary_df)

# 7 : Visualization

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Set style
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")

# Plot CV results
if 'cv_results' in locals():
    fig = plot_results(
        cv_results,
        CONFIG['task_cols'],
        save_path=os.path.join(CONFIG['experiment_dir'], 'figures', 'cv_results.png')
    )

# Plot training history
if 'trainer' in locals() and hasattr(trainer, 'train_losses'):
    from gnn_dta_mtl.evaluation.visualization import plot_training_history
    plot_training_history(
        trainer.train_losses[:-22] + trainer.train_losses[-21:],
        trainer.val_losses[:-22] + trainer.val_losses[-21:],
        save_path=os.path.join(CONFIG['experiment_dir'], 'figures', 'training_history.png')
    )

# Plot metrics distribution across folds
if 'cv_results' in locals():
    from gnn_dta_mtl.evaluation.visualization import plot_metrics_distribution
    plot_metrics_distribution(
        cv_results,
        CONFIG['task_cols'],
        metric='r2',
        save_path=os.path.join(CONFIG['experiment_dir'], 'figures', 'r2_distribution.png')
    )
    
    plot_metrics_distribution(
        cv_results,
        CONFIG['task_cols'],
        metric='rmse',
        save_path=os.path.join(CONFIG['experiment_dir'], 'figures', 'rmse_distribution.png')
    )

In [None]:
# Analyze prediction errors
if 'cv_results' in locals():
    for task in CONFIG['task_cols']:
        if len(cv_results[task]['all_targets']) > 0:
            targets = np.array(cv_results[task]['all_targets'])
            preds = np.array(cv_results[task]['all_predictions'])
            
            # Calculate residuals
            residuals = targets - preds
            
            # Create figure with subplots
            fig, axes = plt.subplots(2, 2, figsize=(12, 10))
            
            # 1. Predictions vs Targets
            ax = axes[0, 0]
            from gnn_dta_mtl.evaluation.visualization import plot_predictions
            plot_predictions(targets, preds, task, ax)
            
            # 2. Residuals plot
            ax = axes[0, 1]
            from gnn_dta_mtl.evaluation.visualization import plot_residuals
            plot_residuals(targets, preds, task, ax)
            
            # 3. Residual distribution
            ax = axes[1, 0]
            ax.hist(residuals, bins=50, edgecolor='black', alpha=0.7)
            ax.axvline(x=0, color='red', linestyle='--', alpha=0.7)
            ax.set_xlabel('Residuals')
            ax.set_ylabel('Frequency')
            ax.set_title(f'Residual Distribution - {task}')
            
            # 4. Q-Q plot
            ax = axes[1, 1]
            from scipy import stats
            stats.probplot(residuals, dist="norm", plot=ax)
            ax.set_title(f'Q-Q Plot - {task}')
            
            plt.suptitle(f'Detailed Analysis - {task}', fontsize=14)
            plt.tight_layout()
            plt.savefig(
                os.path.join(CONFIG['experiment_dir'], 'figures', f'analysis_{task}.png'),
                dpi=300, bbox_inches='tight'
            )
            plt.show()

In [None]:
# Cell 14: Bootstrap Analysis
from gnn_dta_mtl.evaluation.metrics import bootstrap_metrics

if 'cv_results' in locals():
    print("\nBootstrap Confidence Intervals (95%):")
    
    for task in CONFIG['task_cols']:
        if len(cv_results[task]['all_targets']) > 0:
            targets = np.array(cv_results[task]['all_targets'])
            preds = np.array(cv_results[task]['all_predictions'])
            
            # Calculate bootstrap CIs
            boot_results = bootstrap_metrics(
                targets, preds,
                n_bootstrap=1000,
                confidence=0.95,
                seed=SEED
            )
            
            print(f"\n{task}:")
            for metric, (mean, lower, upper) in boot_results.items():
                print(f"  {metric}: {mean:.3f} [{lower:.3f}, {upper:.3f}]")

In [None]:
# Cell 15: Analyze Model Features
if 'model' in locals():
    from gnn_dta_mtl.utils.model_utils import (
        count_parameters, 
        get_model_size,
        get_activation_stats
    )
    
    # Model statistics
    n_params = count_parameters(model)
    model_size = get_model_size(model)
    
    print("Model Statistics:")
    print(f"  Total parameters: {n_params:,}")
    print(f"  Model size: {model_size['total_size_mb']:.2f} MB")
    
    # Get activation statistics
    if 'test_loader' in locals():
        act_stats = get_activation_stats(model, test_loader, device)
        
        # Visualize activation statistics
        layers = list(act_stats.keys())[-10:]  # Last 10 layers
        means = [act_stats[l]['mean'] for l in layers]
        stds = [act_stats[l]['std'] for l in layers]
        
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8))
        
        x = range(len(layers))
        ax1.bar(x, means, alpha=0.7)
        ax1.set_xticks(x)
        ax1.set_xticklabels(layers, rotation=45, ha='right')
        ax1.set_ylabel('Mean Activation')
        ax1.set_title('Layer Activation Statistics')
        
        ax2.bar(x, stds, alpha=0.7, color='orange')
        ax2.set_xticks(x)
        ax2.set_xticklabels(layers, rotation=45, ha='right')
        ax2.set_ylabel('Std Activation')
        
        plt.tight_layout()
        plt.savefig(
            os.path.join(CONFIG['experiment_dir'], 'figures', 'activation_stats.png'),
            dpi=300, bbox_inches='tight'
        )
        plt.show()

# Export


In [None]:
# Cell 16: Export Final Results
import json
from datetime import datetime

# Compile all results
final_results = {
    'experiment_name': experiment_name,
    'timestamp': datetime.now().isoformat(),
    'config': CONFIG,
    'data_stats': {
        'total_samples': len(df),
        'train_samples': len(df_train) if 'df_train' in locals() else None,
        'valid_samples': len(df_valid) if 'df_valid' in locals() else None,
        'test_samples': len(df_test) if 'df_test' in locals() else None
    }
}

# Add CV results summary
if 'cv' in locals() and hasattr(cv, 'summary'):
    final_results['cv_summary'] = cv.summary

# Add test results
if 'test_results' in locals():
    final_results['test_results'] = test_results

# Save comprehensive report
report_path = os.path.join(CONFIG['experiment_dir'], 'final_report.json')
with open(report_path, 'w') as f:
    json.dump(final_results, f, indent=2, default=str)

print(f"\nFinal report saved to: {report_path}")
print(f"All results saved in: {CONFIG['experiment_dir']}")

# Create LaTeX table for publication
if 'cv' in locals() and hasattr(cv, 'summary'):
    print("\nLaTeX Table for Publication:")
    print("\\begin{table}[h]")
    print("\\centering")
    print("\\begin{tabular}{lcccc}")
    print("\\hline")
    print("Task & R² & RMSE & MAE & N \\\\")
    print("\\hline")
    
    for task, metrics in cv.summary.items():
        print(f"{task} & "
              f"{metrics['r2_mean']:.3f}$\\pm${metrics['r2_std']:.3f} & "
              f"{metrics['rmse_mean']:.3f}$\\pm${metrics['rmse_std']:.3f} & "
              f"- & "
              f"{metrics['n_samples']} \\\\")
    
    print("\\hline")
    print("\\end{tabular}")
    print("\\caption{Cross-validation results for multi-task drug-target affinity prediction}")
    print("\\end{table}")

In [None]:
# Cell 17: Interactive Analysis Functions
def analyze_predictions_by_property(task, property_col='MolWt', n_bins=5):
    """Analyze predictions by molecular property"""
    if task not in cv_results or len(cv_results[task]['all_targets']) == 0:
        print(f"No results for {task}")
        return
    
    # Get predictions and targets
    targets = np.array(cv_results[task]['all_targets'])
    preds = np.array(cv_results[task]['all_predictions'])
    
    # Get property values (need to match with original df)
    # This assumes df is still aligned with cv_results
    property_values = df[property_col].values[:len(targets)]
    
    # Create bins
    bins = pd.qcut(property_values, n_bins, labels=False, duplicates='drop')
    
    # Calculate metrics per bin
    from sklearn.metrics import r2_score, mean_squared_error
    
    results = []
    for bin_idx in range(n_bins):
        mask = bins == bin_idx
        if mask.sum() > 0:
            r2 = r2_score(targets[mask], preds[mask])
            rmse = np.sqrt(mean_squared_error(targets[mask], preds[mask]))
            results.append({
                'bin': bin_idx,
                'n_samples': mask.sum(),
                'r2': r2,
                'rmse': rmse,
                f'{property_col}_mean': property_values[mask].mean()
            })
    
    results_df = pd.DataFrame(results)
    
    # Plot
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
    
    ax1.bar(results_df['bin'], results_df['r2'], alpha=0.7)
    ax1.set_xlabel(f'{property_col} Bin')
    ax1.set_ylabel('R²')
    ax1.set_title(f'R² by {property_col} - {task}')
    
    ax2.bar(results_df['bin'], results_df['rmse'], alpha=0.7, color='orange')
    ax2.set_xlabel(f'{property_col} Bin')
    ax2.set_ylabel('RMSE')
    ax2.set_title(f'RMSE by {property_col} - {task}')
    
    plt.tight_layout()
    plt.show()
    
    return results_df

# Example usage
if 'cv_results' in locals():
    for task in CONFIG['task_cols'][:1]:  # Analyze first task
        results_by_mw = analyze_predictions_by_property(task, 'MolWt')
        results_by_logp = analyze_predictions_by_property(task, 'LogP')

In [None]:
# Cell 18: Save Session State
import pickle

# Save important objects
session_state = {
    'config': CONFIG,
    'task_ranges': task_ranges,
    'cv_results': cv_results if 'cv_results' in locals() else None,
    'test_results': test_results if 'test_results' in locals() else None,
    'df_stats': {
        'shape': df.shape,
        'columns': df.columns.tolist(),
        'task_coverage': {task: df[task].notna().sum() for task in CONFIG['task_cols']}
    }
}

session_path = os.path.join(CONFIG['experiment_dir'], 'session_state.pkl')
with open(session_path, 'wb') as f:
    pickle.dump(session_state, f)

print(f"Session state saved to: {session_path}")
print("\nTo restore session in a new notebook:")
print(f"with open('{session_path}', 'rb') as f:")
print("    session_state = pickle.load(f)")

In [None]:
1