# 🚀 Prostate WSI Segmentation - Training Pipeline

This notebook demonstrates the complete training pipeline for prostate tissue segmentation.

## 📋 Setup Instructions:
1. Clone this repository
2. Install requirements: `pip install -r requirements.txt`
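3. Download the datasets (or use your own) and place them in the directories that `TrainingConfig` exposes as `TRAIN_DIR` and `VAL_DIR`; the notebook pairs each image `<name>.png` with a mask named `<name>_mask.png` in the same directory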
4. Run this notebook!


In [None]:
# Install requirements if running on Colab
import sys
if 'google.colab' in sys.modules:
    !git clone https://github.com/YOUR_USERNAME/hackathon.git
    %cd hackathon
    !pip install -r requirements.txt

In [None]:
# Import required libraries
import sys
import os
sys.path.append('src')

from config import TrainingConfig
from data_loader import WSIPatchDataset, get_transforms
from model import SegmentationModel, CombinedLoss
from train import train_model
from utils import calculate_mean_iou

import torch
import glob
import matplotlib.pyplot as plt

In [None]:
# Build the training configuration and echo the settings that matter most
config = TrainingConfig()

# Each entry pairs a message template with the config attribute it reports;
# attributes are read one at a time, right before the matching line is printed.
for template, attr_name in (
    ("🔧 Device: {}", "DEVICE"),
    ("📊 Classes: {}", "CLASSES"),
    ("🎯 Training for {} epochs", "EPOCHS"),
):
    print(template.format(getattr(config, attr_name)))

In [None]:
# Prepare training and validation paths
def collect_image_mask_pairs(directory):
    """Pair every image PNG found in `directory` with its expected mask path.

    Files whose name ends in ``_mask.png`` are treated as masks and excluded
    from the image list. The mask path of an image is built by inserting
    ``_mask`` before the file extension only, so a ``.png`` substring that
    happens to appear in a directory name is left untouched.

    Args:
        directory: Folder to scan (anything accepted by ``str()``, e.g. a
            ``pathlib.Path`` or a plain string).

    Returns:
        A list of ``(image_path, mask_path)`` string tuples, sorted by image
        path. Pairs whose mask file is missing are still returned; a warning
        is printed so the problem is visible before training starts.
    """
    pattern = os.path.join(str(directory), "*.png")
    image_paths = sorted(
        path for path in glob.glob(pattern) if not path.endswith('_mask.png')
    )

    pairs = []
    for image_path in image_paths:
        root, ext = os.path.splitext(image_path)
        pairs.append((image_path, f"{root}_mask{ext}"))

    # Surface missing masks now instead of failing somewhere inside training.
    missing_masks = [mask for _, mask in pairs if not os.path.isfile(mask)]
    if missing_masks:
        print(
            f"⚠️ {len(missing_masks)} image(s) in {directory} have no mask file "
            f"(first missing: {missing_masks[0]})"
        )
    return pairs


train_paths = collect_image_mask_pairs(config.TRAIN_DIR)
val_paths = collect_image_mask_pairs(config.VAL_DIR)

# Keep the separate image/mask lists available under their original names.
train_images = [image for image, _ in train_paths]
train_masks = [mask for _, mask in train_paths]
val_images = [image for image, _ in val_paths]
val_masks = [mask for _, mask in val_paths]

print(f"📊 Training WSIs: {len(train_paths)}")
print(f"📊 Validation WSIs: {len(val_paths)}")

In [None]:
# Launch training on the prepared (image, mask) path pairs
print("🚀 Starting training...")

# train_model hands back two values: the history plotted in the next cell
# and the best IoU reported below.
training_result = train_model(train_paths, val_paths, config)
history, best_iou = training_result

print("\n🎉 Training completed!")
print("🏆 Best WSI IoU: {:.4f}".format(best_iou))

In [None]:
# Plot training history: one panel per metric family, drawn left to right
panels = [
    ('Training Loss', [('train_loss', 'Train Loss'), ('val_loss', 'Val Loss')]),
    ('Patch-level IoU', [('train_iou', 'Train IoU'), ('val_iou', 'Val IoU')]),
    ('WSI-level IoU', [('wsi_iou', 'WSI IoU')]),
]

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

for ax, (panel_title, curves) in zip(axes, panels):
    # Each curve is looked up in `history` by key and labelled for the legend
    for history_key, curve_label in curves:
        ax.plot(history[history_key], label=curve_label)
    ax.set_title(panel_title)
    ax.legend()

fig.tight_layout()
plt.show()