# LP-IOANet Training Notebook

This notebook runs the two-stage LP-IOANet training pipeline and monitors it with ResearchTracker metrics (PSNR, SSIM, MAE).

## Features
- GPU verification and configuration
- ResearchTracker metrics monitoring (PSNR, SSIM, MAE)
- Real-time training progress
- TensorBoard integration
- Training utilities and analysis tools

## 1. Environment Setup

In [None]:
import os
import sys
from pathlib import Path
import subprocess
import torch
import numpy as np

# Set up project path
# The notebook's working directory is treated as the project root and is put
# first on sys.path so project-local packages (e.g. the `src` package imported
# later in this notebook) can be found.
project_root = Path.cwd()
_root_entry = str(project_root)
if sys.path[:1] != [_root_entry]:
    # Guarded so that re-running this cell does not stack duplicate entries.
    sys.path.insert(0, _root_entry)

print(f"Project Root: {project_root}")
print(f"Python Version: {sys.version}")
print(f"PyTorch Version: {torch.__version__}")

## 2. GPU Verification

In [None]:
# Check GPU availability
# Prints a CUDA summary and sets `device` for use by later cells.
banner = "=" * 80
cuda_ok = torch.cuda.is_available()

print(banner)
print("GPU VERIFICATION")
print(banner)

if not cuda_ok:
    print("✗ CUDA Not Available - Training will run on CPU (very slow!)")
    print("  Recommendation: Use NVIDIA GPU for reasonable training time")
else:
    # Name and memory are queried for device index 0, as in the summary below.
    gpu_facts = [
        ("CUDA Available", "YES"),
        ("GPU Count", torch.cuda.device_count()),
        ("Current Device", torch.cuda.current_device()),
        ("GPU Name", torch.cuda.get_device_name(0)),
        ("GPU Memory", f"{torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB"),
    ]
    for label, value in gpu_facts:
        print(f"✓ {label}: {value}")

    # Reset GPU memory
    torch.cuda.empty_cache()
    print("✓ GPU Memory Cleared")

device = torch.device('cuda' if cuda_ok else 'cpu')
print(f"\nDevice to use: {device}")
print(banner)

## 3. Dataset Verification

In [None]:
# Check dataset structure
IMAGE_PATTERNS = ("*.png", "*.jpg")


def _count_images(directory, patterns=IMAGE_PATTERNS):
    """Return how many entries in `directory` match any of the glob `patterns`."""
    return sum(1 for pattern in patterns for _ in directory.glob(pattern))


def _report_split(label, input_dir, target_dir):
    """Print the status of one dataset split.

    Args:
        label: Split name used in the printed messages ("Training" or "Test").
        input_dir: Directory expected to hold the input images.
        target_dir: Directory expected to hold the target images.

    Returns:
        The number of input images found, or None when either directory is
        missing.
    """
    if not (input_dir.exists() and target_dir.exists()):
        print(f"✗ {label} data directories not found")
        print(f"  Expected: {input_dir}, {target_dir}")
        return None

    # Count every supported extension. The previous `len(png) or len(jpg)`
    # ignored .jpg files whenever at least one .png was present.
    n_inputs = _count_images(input_dir)
    n_targets = _count_images(target_dir)
    print(f"✓ {label} set: {n_inputs} image pairs")
    if n_inputs != n_targets:
        # "Pairs" only holds when both sides have the same number of images.
        print(f"  ⚠ Input/target count mismatch: {n_inputs} inputs vs {n_targets} targets")
    return n_inputs


print("\n" + "=" * 80)
print("DATASET VERIFICATION")
print("=" * 80)

dataset_root = project_root / "Dataset"

# Expected layout: Dataset/{train,test}/{input,target}
train_dir = dataset_root / "train"
train_input = train_dir / "input"
train_target = train_dir / "target"

test_dir = dataset_root / "test"
test_input = test_dir / "input"
test_target = test_dir / "target"

if dataset_root.exists():
    print(f"✓ Dataset directory found: {dataset_root}")

    # Check training data
    train_count = _report_split("Training", train_input, train_target)

    # Check test data
    test_count = _report_split("Test", test_input, test_target)
else:
    print(f"✗ Dataset directory not found: {dataset_root}")
    print(f"  Please create the dataset structure:")
    print(f"  Dataset/")
    print(f"  ├── train/")
    print(f"  │   ├── input/")
    print(f"  │   └── target/")
    print(f"  └── test/")
    print(f"      ├── input/")
    print(f"      └── target/")

print("=" * 80)

## 4. Import Training Components

In [None]:
# Import training modules
# A failed import is reported rather than raised, so later cells that use these
# names will hit NameError if this cell printed an error.
try:
    from src.models import IOANet, LPIOANet, build_model
    from src.data import create_dataloaders
    from src.losses import ShadowLoss
    from src.utils import ResearchTracker
except ImportError as import_error:
    print(f"✗ Import error: {import_error}")
    print("  Make sure all source files are in place")
else:
    print("✓ All training modules imported successfully")

## 5. Training Configuration

In [None]:
import yaml

# Load training configuration
config_path = project_root / "config" / "config.yaml"

if config_path.exists():
    # Read as UTF-8 explicitly: without `encoding`, open() uses the platform's
    # locale encoding, which can fail on or mis-decode non-ASCII text.
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    print("\nTraining Configuration:")
    print("=" * 80)

    stage1 = config['training']['stage1']
    stage2 = config['training']['stage2']

    # The resolutions are fixed labels printed by this cell; they are not read
    # from the config file.
    for title, stage_cfg, resolution in (
        ("Stage 1 (IOANet)", stage1, "192x256"),
        ("Stage 2 (LP-IOANet)", stage2, "768x1024"),
    ):
        print(f"\n{title}:")
        print(f"  Epochs: {stage_cfg['epochs']}")
        print(f"  Batch Size: {stage_cfg['batch_size']}")
        print(f"  Input Resolution: {resolution}")
        print(f"  Learning Rate: {stage_cfg['learning_rate']}")

    # Loss weights
    # `or {}` also covers a `loss_weights:` key that is present but empty
    # (parsed as None), which would otherwise break the chained .get().
    stage1_weights = stage1.get('loss_weights') or {}
    stage2_weights = stage2.get('loss_weights') or {}
    print(f"\nLoss Configuration:")
    print(f"  Stage 1 L1 Weight: {stage1_weights.get('l1', 10.0)}")
    print(f"  Stage 1 LPIPS Weight: {stage1_weights.get('lpips', 5.0)}")
    print(f"  Stage 2 L1 Weight: {stage2_weights.get('l1', 1.0)}")

    print("=" * 80)
else:
    print(f"✗ Config file not found: {config_path}")

## 6. Run Training - Stage 1

In [None]:
# Run training script for Stage 1
# Launches train.py with the current interpreter from the project root. The
# child's output is not captured by this cell (capture_output=False).
import subprocess

banner = "=" * 80
print("\n" + banner)
print("STARTING STAGE 1 TRAINING (IOANet at 192x256)")
print(banner)
print("\nThis will train for 1000 epochs with ResearchTracker metrics.")
print("Metrics: PSNR, SSIM, MAE, Loss components")
print("\nPress Ctrl+C to interrupt training")
print(banner + "\n")

# Run training
stage1_cmd = [sys.executable, "train.py", "--stage", "1"]
result = subprocess.run(stage1_cmd, cwd=project_root, capture_output=False)

exit_code = result.returncode
print(f"\nTraining exit code: {exit_code}")
print(
    "✓ Stage 1 training completed successfully"
    if exit_code == 0
    else f"✗ Training failed with exit code {exit_code}"
)

## 7. Training Monitoring - Real-time Metrics

In [None]:
# Monitor training progress from logs
import json
from datetime import datetime

log_dir = project_root / "logs"

if not log_dir.exists():
    print("Logs directory not yet created (will be created during training)")
else:
    print("\nTraining Logs Found:")
    print("=" * 80)

    # List recent runs: entries named run_*, most recently modified first.
    runs = sorted(
        log_dir.glob("run_*"),
        key=lambda run_path: run_path.stat().st_mtime,
        reverse=True,
    )

    if not runs:
        print("No training runs found")
    else:
        latest_run = runs[0]
        guidance = [
            f"Latest run: {latest_run.name}",
            f"\nTo monitor training with TensorBoard:",
            f"  tensorboard --logdir={log_dir}",
            f"\nThen open: http://localhost:6006",
            "\nAvailable metrics:",
            "  - train/psnr, train/ssim, train/mae",
            "  - val/psnr, val/ssim, val/mae",
            "  - train/loss, val/loss",
            "  - train/lr (learning rate)",
        ]
        print("\n".join(guidance))

## 8. Run Training - Stage 2 (Optional)

In [None]:
# Run Stage 2 training (after Stage 1 is complete)
# The launch command at the bottom is intentionally left commented out; as
# written, this cell only prints a reminder.

banner = "=" * 80
for line in (
    "\n" + banner,
    "STARTING STAGE 2 TRAINING (LP-IOANet at 768x1024)",
    banner,
    "\n⚠️  Note: Stage 2 should only run after Stage 1 completes successfully",
    "    Ensure MAE < 0.05 before proceeding",
    "\nUncomment the code below to start Stage 2 training",
    banner,
):
    print(line)

# Uncomment to run:
# result = subprocess.run(
#     [sys.executable, "train.py", "--stage", "2"],
#     cwd=project_root,
#     capture_output=False
# )

## 9. Post-Training Analysis

In [None]:
# Check training outputs
import matplotlib.pyplot as plt

print("\nPost-Training Analysis")
print("=" * 80)

# Check checkpoints
checkpoint_dir = project_root / "checkpoints"
if checkpoint_dir.exists():
    checkpoints = sorted(checkpoint_dir.glob("*.pth"))
    print(f"✓ Checkpoints saved: {len(checkpoints)} files")

    best_model = checkpoint_dir / "best_model.pth"
    if best_model.exists():
        size_mb = best_model.stat().st_size / (1024 * 1024)
        print(f"✓ Best model saved: {size_mb:.2f} MB")
else:
    # Say so explicitly instead of printing nothing.
    print(f"✗ Checkpoint directory not found: {checkpoint_dir}")

# Check debug samples
debug_dir = project_root / "debug_samples"
if debug_dir.exists():
    # glob() gives no ordering guarantee, so sort by name before treating the
    # last element as the latest sample. This matches the name-based ordering
    # used by the sample viewer cell and assumes file names sort
    # chronologically (e.g. zero-padded epoch numbers) — TODO confirm.
    samples = sorted(debug_dir.glob("*.png"))
    print(f"✓ Debug samples saved: {len(samples)} images")

    if samples:
        print(f"  Latest: {samples[-1].name}")
        print(f"\n  View samples in: {debug_dir}")
else:
    print(f"✗ Debug sample directory not found: {debug_dir}")

print("=" * 80)

## 10. Test Metrics on Sample Batch

In [None]:
# Test ResearchTracker with sample data
# Feeds one seeded random batch through the tracker and prints what it reports.
banner = "=" * 80
print("\nTesting ResearchTracker Metrics")
print(banner)

from src.utils import ResearchTracker

# Create test data
tracker = ResearchTracker()

# Simulate batch (seeded so the random tensors are repeatable)
torch.manual_seed(42)
batch_shape = (4, 3, 192, 256)
pred = torch.rand(batch_shape)
target = torch.rand(batch_shape)
target = target * 0.7 + pred * 0.3  # Blend in the prediction to add correlation

losses = dict(total=1.23456, l1=0.12345, lpips=0.08901)

tracker.update(pred, target, losses)
metrics = tracker.get_avg()

print("\nSample Metrics Output:")
print("\n".join([
    f"  MAE:        {metrics['mae']:.5f}",
    f"  PSNR:       {metrics['psnr']:.2f} dB",
    f"  SSIM:       {metrics['ssim']:.4f}",
    f"  Total Loss: {metrics['total']:.5f}",
    f"  L1 Loss:    {metrics['l1']:.5f}",
    f"  LPIPS Loss: {metrics['lpips']:.5f}",
]))

# Test formatting
log = tracker.format_log(0, 1000, stage="TRAIN", lr=1e-3, time_sec=45.3)
print(f"\nFormatted Log Output:")
print(f"  {log}")

print("\n✓ ResearchTracker working correctly")
print(banner)

## 11. Utilities - Check Training Status

In [None]:
# Helper function to check training status
def check_training_status():
    """Print a best-effort snapshot of training activity and GPU memory use.

    On Windows this lists processes whose image name is python.exe via
    `tasklist`; elsewhere it searches process command lines for "train.py"
    via `pgrep -f`. If the process-listing tool cannot be started, that is
    reported and the GPU section still runs.

    Returns:
        None. All results are printed.
    """
    print("\nTraining Status Check")
    print("=" * 80)

    # Check for running processes
    on_windows = os.name == "nt"
    if on_windows:
        # NOTE(review): the notebook kernel presumably runs as python.exe too,
        # so this check may always succeed on Windows — confirm.
        query = ["tasklist", "/FI", "IMAGENAME eq python.exe"]
    else:
        query = ["pgrep", "-f", "train.py"]

    try:
        result = subprocess.run(query, capture_output=True, text=True)
    except OSError as exc:
        # Raised when the tool is not installed or cannot be executed;
        # `tasklist` does not exist on Linux/macOS, for example.
        print(f"✗ Could not query running processes: {exc}")
    else:
        if on_windows:
            process_found = "python.exe" in result.stdout
        else:
            # pgrep exits with status 0 and prints PIDs when something matched.
            process_found = result.returncode == 0 and bool(result.stdout.strip())

        if process_found:
            print("✓ Python process running (training may be active)")
        else:
            print("✗ No Python process found")

    # Check GPU
    # NOTE(review): these counters are read through torch in this process, so
    # they presumably exclude memory used by a separate train.py process — confirm.
    if torch.cuda.is_available():
        print(f"✓ GPU Memory Used: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
        print(f"✓ GPU Memory Cached: {torch.cuda.memory_reserved() / 1e9:.2f} GB")

    print("=" * 80)

check_training_status()

## 12. Utilities - View Debug Samples

In [None]:
# Display debug samples from training
from PIL import Image
import matplotlib.pyplot as plt
import glob

debug_dir = project_root / "debug_samples"

if debug_dir.exists():
    # Reverse name order: the highest-sorting file names come first.
    samples = sorted(glob.glob(str(debug_dir / "*.png")), reverse=True)

    if samples:
        print(f"\nFound {len(samples)} debug samples")
        print(f"Latest 3 samples:")

        for sample_path in samples[:3]:
            sample_name = Path(sample_path).name
            print(f"\n  {sample_name}")

            # Display image
            fig, ax = plt.subplots(figsize=(15, 5))
            # Open inside a context manager so the file handle is closed once
            # the image has been handed to matplotlib.
            with Image.open(sample_path) as img:
                ax.imshow(img)
            ax.axis('off')
            ax.set_title(sample_name)
            fig.tight_layout()
            plt.show()
            # Release the figure explicitly in case the active backend keeps
            # it alive after show().
            plt.close(fig)
    else:
        print("No debug samples yet (will be generated during training)")
else:
    print(f"Debug directory not found: {debug_dir}")

## 13. Next Steps

In [None]:
# Print a static checklist of follow-up steps. The text is a fixed literal:
# nothing here reads notebook state or reflects actual training results.
print("""
╔════════════════════════════════════════════════════════════════════════════════╗
║                         TRAINING NEXT STEPS                                    ║
╚════════════════════════════════════════════════════════════════════════════════╝

1. MONITOR TRAINING PROGRESS
   ✓ Run TensorBoard: tensorboard --logdir=logs/
   ✓ Open http://localhost:6006 in browser
   ✓ Watch PSNR, SSIM, MAE, and loss curves

2. STAGE 1 SUCCESS CRITERIA (after ~100 epochs)
   ✓ PSNR > 27 dB
   ✓ SSIM > 0.78
   ✓ MAE < 0.05
   ✓ Loss smoothly decreasing

3. REVIEW DEBUG SAMPLES
   ✓ Check debug_samples/epoch_XXXX_*.png every 5 epochs
   ✓ Verify attention masks highlight shadows correctly
   ✓ Look for artifacts at shadow boundaries

4. START STAGE 2 (After Stage 1 Success)
   ✓ Uncomment Stage 2 cell above
   ✓ Train at 768x1024 resolution
   ✓ Target PSNR > 30 dB

5. EVALUATION
   ✓ Run: python evaluate.py
   ✓ Compare metrics on test set
   ✓ Save outputs for visual inspection

6. TROUBLESHOOTING
   - PSNR stalled: Check attention masks in debug_samples/
   - SSIM dropping: Reduce L1 weight in config.yaml
   - Training slow: Ensure GPU is being used (check nvidia-smi)
   - Out of memory: Reduce batch_size in config.yaml

╔════════════════════════════════════════════════════════════════════════════════╗
║                    For questions, check RESEARCH_TRACKER_SUMMARY.md            ║
╚════════════════════════════════════════════════════════════════════════════════╝
""")