## CELL 1: Mount Google Drive (REQUIRED - DO THIS FIRST)

In [None]:
# Attach Google Drive at /content/drive so the project folder under MyDrive is
# reachable from this runtime. force_remount=True re-attaches it even when an
# earlier run of this cell already mounted it.
from google.colab import drive

MOUNT_POINT = '/content/drive'
drive.mount(MOUNT_POINT, force_remount=True)

print("Google Drive mounted successfully!")

## CELL 2: Check GPU and Navigate to Project

In [None]:
import os
import torch

# --- GPU report -------------------------------------------------------------
print("=" * 50)
print("GPU INFORMATION")
print("=" * 50)
print(f"GPU Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
print(f"PyTorch Version: {torch.__version__}")
print("=" * 50)

# --- Project layout on Google Drive (later cells use these names) -------------
PROJECT_PATH = '/content/drive/MyDrive/HRF-Segmentation-Unet'
CODE_PATH = os.path.join(PROJECT_PATH, 'code')
DATA_PATH = os.path.join(PROJECT_PATH, 'data')
CHECKPOINT_PATH = os.path.join(PROJECT_PATH, 'checkpoints')

# Create checkpoint directory if it doesn't exist
os.makedirs(CHECKPOINT_PATH, exist_ok=True)

print(f"\nProject Path: {PROJECT_PATH}")
print(f"Code Path: {CODE_PATH}")
print(f"Data Path: {DATA_PATH}")
print(f"Checkpoint Path: {CHECKPOINT_PATH}")

# Fail early with an actionable message if Drive is not mounted or the path is wrong.
for label, folder_path in (("Code", CODE_PATH), ("Data", DATA_PATH)):
    if not os.path.isdir(folder_path):
        raise FileNotFoundError(
            f"{label} folder not found: {folder_path} "
            "(is Google Drive mounted and PROJECT_PATH correct?)"
        )

# Verify files exist
print("\nCode files in Google Drive:")
for name in sorted(os.listdir(CODE_PATH)):
    print(f"  - {name}")

print("\nData folders:")
for name in sorted(os.listdir(DATA_PATH)):
    entry_path = os.path.join(DATA_PATH, name)
    # Skip stray files (e.g. a README or .DS_Store): os.listdir on a file raises.
    if not os.path.isdir(entry_path):
        continue
    print(f"  - {name}/")
    print(f"    ({len(os.listdir(entry_path))} files)")

## CELL 3: Install Required Packages

In [None]:
# Install required packages (Colab already has most pre-installed)
# NOTE: `!pip` is an IPython shell escape, not Python syntax; this line only runs
# inside a notebook.
# NOTE(review): a failing `!pip` does not stop the cell, so the success message
# below prints regardless of the outcome; check the pip output. `-U` upgrades
# packages Colab may already provide (e.g. opencv, albumentations); if imports
# misbehave afterwards, restarting the runtime is likely needed.
!pip install -q opencv-python tqdm tifffile imagecodecs albumentations -U

print("✓ All packages installed successfully!")

## CELL 4: Import Your Code from Google Drive

In [None]:
import sys

# Put the Drive-hosted `code/` folder at the front of the import path so the
# project modules below resolve to the copies stored in Google Drive.
sys.path.insert(0, CODE_PATH)

# Project modules (the dataset loader comes from dataset_enhanced).
from dataset_enhanced import create_dataloaders
from losses import get_loss_function
from metrics import evaluate_batch, MetricsTracker
from unet import UNet
from train import Trainer

_module_descriptions = (
    "dataset_enhanced ",
    "losses (loss functions)",
    "metrics (evaluation metrics)",
    "unet (U-Net model)",
    "train (training framework)",
)

print("✓ All modules imported successfully!")
print("\nAvailable modules:")
for _description in _module_descriptions:
    print(f"  - {_description}")

## CELL 5: Verify Data

In [None]:
import os
import numpy as np
import cv2
import tifffile
import matplotlib.pyplot as plt

# Sanity-check the dataset: count files, load the first image and its mask at
# original size, and show image / mask / overlay side by side.
IMAGE_DIR = os.path.join(DATA_PATH, 'images')
MASK_DIR = os.path.join(DATA_PATH, 'masks')

IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.tiff', '.tif')
MASK_EXTENSIONS = ('.tiff', '.tif', '.jpeg', '.jpg', '.png')
# A mask's filename is the image stem plus one of these suffixes, tried in order.
MASK_SUFFIXES = ('_HRF.ome.tiff', '.tiff', '_mask.tiff', '_HRF.ome.tif', '.tif', '_mask.tif')


def _first_plane(mask):
    """Reduce a multi-page / multi-channel mask array to one 2-D plane.

    Size-1 axes are dropped, then the smallest remaining axis is treated as the
    channel/page axis and its first slice is kept, so both (C, H, W) and
    (H, W, C) layouts work. Comparing shape[0] with shape[1] instead would turn
    an (H, W, C) mask with H < W into a single pixel row.
    """
    mask = np.squeeze(mask)
    while mask.ndim > 2:
        mask = np.take(mask, 0, axis=int(np.argmin(mask.shape)))
    return mask


# Count files
image_files = sorted(f for f in os.listdir(IMAGE_DIR) if f.lower().endswith(IMAGE_EXTENSIONS))
mask_files = sorted(f for f in os.listdir(MASK_DIR) if f.lower().endswith(MASK_EXTENSIONS))

print("="*50)
print("DATA VERIFICATION")
print("="*50)
print(f"Total images: {len(image_files)}")
print(f"Total masks: {len(mask_files)}")
print("\nNOTE: Images loaded at original size!")
print("NOTE: Using TIFFFILE for proper .ome.tiff handling")

if image_files and mask_files:
    # Load sample
    sample_img_name = image_files[0]
    sample_img_path = os.path.join(IMAGE_DIR, sample_img_name)

    # cv2.imread returns None (it does not raise) when a file cannot be decoded.
    sample_img = cv2.imread(sample_img_path, cv2.IMREAD_COLOR)
    if sample_img is None:
        raise IOError(f"OpenCV could not read image: {sample_img_path}")
    sample_img = cv2.cvtColor(sample_img, cv2.COLOR_BGR2RGB)

    print("\nSample Image (ORIGINAL SIZE):")
    print(f"  Filename: {sample_img_name}")
    print(f"  Shape: {sample_img.shape}")
    print(f"  Dtype: {sample_img.dtype}")
    print(f"  Range: [{sample_img.min()}, {sample_img.max()}]")

    print(f"\nFinding mask for: {sample_img_name}")

    # Build candidates from the file stem. Case-sensitive str.replace('.jpg', ...)
    # missed upper-case extensions such as '.JPG' (which pass the filter above)
    # and could also match '.jpg' in the middle of a name.
    stem = os.path.splitext(sample_img_name)[0]
    candidates = [stem + suffix for suffix in MASK_SUFFIXES]

    sample_mask_path = None
    sample_mask = None

    for candidate in candidates:
        candidate_path = os.path.join(MASK_DIR, candidate)
        if not os.path.exists(candidate_path):
            print(f"  ✗ Not found: {candidate}")
            continue

        print(f"  ✓ Found: {candidate}")
        # tifffile (rather than OpenCV) handles scientific / OME-TIFF files.
        try:
            sample_mask = tifffile.imread(candidate_path)
        except Exception as e:  # report the unreadable file and try the next candidate
            print(f"    ✗ Error reading file: {e}")
            continue
        sample_mask_path = candidate_path
        print("    ✓ Successfully read with TIFFFILE")
        break

    if sample_mask is None:
        print("\n⚠️ MASK NOT FOUND OR COULD NOT BE READ!")
        print("Check that mask files exist and are in supported format (.tiff, .tif)")
    else:
        # Handle multi-channel / multi-page masks
        if sample_mask.ndim > 2:
            print(f"  Note: Mask is multi-channel/multi-page (shape: {sample_mask.shape})")
            sample_mask = _first_plane(sample_mask)
            print(f"        Using first channel (shape: {sample_mask.shape})")

        unique_values = np.unique(sample_mask)

        print("\nSample Mask:")
        print(f"  Filename: {os.path.basename(sample_mask_path)}")
        print(f"  Shape: {sample_mask.shape}")
        print(f"  Dtype: {sample_mask.dtype}")
        print(f"  Unique values: {unique_values}")
        print(f"  HRF pixels: {np.sum(sample_mask > 0)}")
        print("  Note: Legend area in mask should be 0 (not marked as HRF)")

        # Binarize (any non-zero value counts as foreground)
        sample_mask_binary = (sample_mask > 0).astype(np.float32)
        if len(unique_values) > 2:
            print(f"  Note: Mask has {len(unique_values)} unique values")
            print("        Will be binarized during training (threshold=0)")

        # The overlay indexes the image with the mask, so their H x W must match.
        img_hw = sample_img.shape[:2]
        if sample_mask_binary.shape != img_hw:
            print(f"  ⚠️ Mask size {sample_mask_binary.shape} differs from image size {img_hw}; "
                  "resizing mask (nearest neighbour) for the overlay only")
            sample_mask_binary = cv2.resize(
                sample_mask_binary, (img_hw[1], img_hw[0]), interpolation=cv2.INTER_NEAREST
            )

        # Visualize
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))

        axes[0].imshow(sample_img)
        axes[0].set_title('OCT Image (Original Size)')
        axes[0].axis('off')

        axes[1].imshow(sample_mask, cmap='gray')
        axes[1].set_title('HRF Mask (from .ome.tiff)')
        axes[1].axis('off')

        overlay = sample_img.copy()
        overlay[sample_mask_binary > 0] = [255, 0, 0]  # Red for HRF
        axes[2].imshow(overlay)
        axes[2].set_title('Overlay (Red = HRF, Legend = Background)')
        axes[2].axis('off')

        plt.tight_layout()
        plt.show()

        print("\n✓ Data verification successful!")
else:
    print("\n⚠️ WARNING: No images or masks found!")
    print("Make sure you uploaded them to Google Drive correctly.")

## CELL 6: Create Data Splits

In [None]:
import os
import numpy as np

# Split the image list 70 / 15 / 15 into train / val / test file lists.
IMAGE_DIR = os.path.join(DATA_PATH, 'images')

# Same case-insensitive extension set as the data-verification cell. The previous
# filter was case-sensitive and omitted '.tif', so such images were counted there
# but silently left out of every split.
# NOTE: if the folder contains '.tif' or upper-case-extension images, the seeded
# split below differs from runs made with the old filter; do not evaluate old
# checkpoints on the new test split.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.tiff', '.tif')
image_files = sorted(f for f in os.listdir(IMAGE_DIR) if f.lower().endswith(IMAGE_EXTENSIONS))

if not image_files:
    raise RuntimeError(f"No images found in {IMAGE_DIR}; cannot create data splits.")

TRAIN_FRACTION = 0.70
VAL_FRACTION = 0.15  # the test split receives the remainder

# Seed NumPy's global RNG and shuffle a copy so the split is reproducible.
# Global seeding is kept on purpose so the permutation matches earlier runs.
np.random.seed(42)
all_files = image_files.copy()
np.random.shuffle(all_files)

n_total = len(all_files)
n_train = int(n_total * TRAIN_FRACTION)
n_val = int(n_total * VAL_FRACTION)

train_files = all_files[:n_train]
val_files = all_files[n_train:n_train + n_val]
test_files = all_files[n_train + n_val:]

print("=" * 50)
print("DATA SPLIT")
print("=" * 50)
print(f"Total images: {n_total}")
print(f"Training: {len(train_files)} ({100*len(train_files)/n_total:.1f}%)")
print(f"Validation: {len(val_files)} ({100*len(val_files)/n_total:.1f}%)")
print(f"Test: {len(test_files)} ({100*len(test_files)/n_total:.1f}%)")

print("\n✓ Data split completed!")

## CELL 7: Create Data Loaders

In [None]:
# Build the train / val / test DataLoaders from the file lists produced by the
# split cell, then pull one training batch as a smoke test.
print("Creating data loaders...")
print("This may take a moment...\n")

loader_kwargs = dict(
    data_dir=DATA_PATH,
    train_files=train_files,
    val_files=val_files,
    test_files=test_files,
    batch_size=4,              # kept small because images keep their original size
    num_workers=2,             # Colab works best with 2 workers
    image_size=None,           # None = keep original size
    apply_augmentation=True,   # augmentation enabled
)
train_loader, val_loader, test_loader = create_dataloaders(**loader_kwargs)

print("=" * 50)
print("DATA LOADERS CREATED")
print("=" * 50)
for split_name, split_loader in (("Train", train_loader), ("Val", val_loader), ("Test", test_loader)):
    print(f"{split_name} batches: {len(split_loader)}")

# Fetch a single batch to confirm the pipeline works end to end.
images, masks = next(iter(train_loader))
print("\nBatch shape (note: may vary due to original sizes):")
print(f"  Images: {images.shape}")
print(f"  Masks: {masks.shape}")
print("\n✓ Data loaders ready!")

## CELL 8: Setup Model and Training Components

In [None]:
import torch
import torch.optim as optim

# Build the U-Net plus the loss, optimizer and LR scheduler used for training.

# Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

LEARNING_RATE = 0.001
WEIGHT_DECAY = 0.0001

# Create model
print("\nCreating U-Net model...")
model = UNet(
    n_channels=3,
    n_classes=1,
    bilinear=False,
    base_filters=64,
).to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print("Model: U-Net")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# Smoke-test one forward pass on random input. eval() matters here:
# torch.no_grad() only disables autograd, while normalization layers such as
# BatchNorm (if this UNet contains any) still update their running statistics
# from the random noise when the model is in train mode.
model.eval()
with torch.no_grad():
    test_input = torch.randn(2, 3, 512, 512, device=device)
    test_output = model(test_input)
model.train()  # restore training mode before handing the model to the Trainer
print("\nTest forward pass:")
print(f"  Input shape: {test_input.shape}")
print(f"  Output shape: {test_output.shape}")

# Loss function
print("\nSetup loss function...")
loss_fn = get_loss_function(
    'focal_tversky',
    alpha=0.3,
    beta=0.7,
    gamma=4/3,
)
print("Loss: Focal Tversky Loss (recommended for class imbalance)")

# Optimizer (printed values come from the same constants that configure it)
print("\nSetup optimizer...")
optimizer = optim.Adam(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)
print("Optimizer: Adam")
print(f"Learning Rate: {LEARNING_RATE}")

# Scheduler: halve the LR after 5 epochs without improvement. mode='max'
# presumably tracks a validation score to maximise (e.g. Dice); confirm in Trainer.
print("\nSetup scheduler...")
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='max',
    factor=0.5,
    patience=5,
)
print("Scheduler: ReduceLROnPlateau")

print("\n✓ All components ready!")

In [None]:
import torch
import os

# ========== LOAD CHECKPOINT (CELL 8.5) ==========
# Restores model / optimizer / scheduler state from best_model.pth when it exists.
# Run after the model-setup cell (needs model, optimizer, scheduler, device) and
# before the Trainer is created, since the Trainer wraps these same objects.
NUM_EPOCHS = 100

print("\n" + "="*50)
print("LOADING CHECKPOINT")
print("="*50)

# Same location the Trainer writes to (CHECKPOINT_PATH is defined in Cell 2).
checkpoint_path = os.path.join(CHECKPOINT_PATH, 'best_model.pth')

# Fresh-run defaults; both names are always defined after this cell.
start_epoch = 1
best_dice = 0.0

if os.path.exists(checkpoint_path):
    print(f"✓ Checkpoint found: {checkpoint_path}")

    # weights_only=False unpickles arbitrary objects: only load checkpoints you created.
    checkpoint = torch.load(checkpoint_path, weights_only=False, map_location=device)

    # Restore model weights
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    print(f"✓ Model weights loaded from epoch {checkpoint['epoch']}")

    # Restore optimizer / scheduler state when present (a missing entry no longer aborts).
    if checkpoint.get('optimizer_state_dict') is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        print("✓ Optimizer state restored")
    else:
        print("⚠️ No optimizer state in checkpoint - optimizer starts fresh")

    if checkpoint.get('scheduler_state_dict') is not None:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        print("✓ Scheduler state restored")
    else:
        print("⚠️ No scheduler state in checkpoint - scheduler starts fresh")

    # Best validation Dice recorded with the checkpoint (same key the evaluation cell reads).
    best_dice = (checkpoint.get('metrics') or {}).get('dice', 0.0)
    # NOTE(review): assumes checkpoint['epoch'] is the 1-based number of the last
    # finished epoch, matching the fresh-run default start_epoch = 1 - confirm in Trainer.
    start_epoch = checkpoint['epoch'] + 1  # Start from NEXT epoch

    print("\n✓ Ready to resume training!")
    print(f"  - Start epoch: {start_epoch}/{NUM_EPOCHS}")
    # Inclusive count: with start_epoch = 1 all NUM_EPOCHS epochs remain.
    print(f"  - Remaining epochs: {max(NUM_EPOCHS - start_epoch + 1, 0)}")
    print(f"  - Best validation Dice so far: {best_dice:.4f}")
    # NOTE(review): start_epoch / best_dice are not passed to trainer.train() in the
    # training cell (resume_from is commented out there) - confirm how Trainer resumes.
else:
    print(f"❌ Checkpoint not found at {checkpoint_path}")
    print("Starting fresh training from epoch 1")

print("="*50)

## CELL 9: Create Trainer

In [None]:
# Wrap the model, data loaders, loss and optimisation objects in the project's
# Trainer, which writes checkpoints to CHECKPOINT_PATH.
print("Creating trainer...\n")

USE_AMP = True      # mixed precision for faster training
LOG_WANDB = False   # set to True to log to Weights & Biases

trainer = Trainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    loss_fn=loss_fn,
    optimizer=optimizer,
    scheduler=scheduler,
    device=device,
    checkpoint_dir=CHECKPOINT_PATH,
    use_amp=USE_AMP,
    log_wandb=LOG_WANDB,
)

config_lines = (
    f"  Device: {device}",
    f"  Mixed Precision: {USE_AMP}",
    "  AUGMENTATION: Albumentations (Flip, Rotate, Brightness)",
    f"  Checkpoint Dir: {CHECKPOINT_PATH}",
)
print("Trainer created successfully!")
print("\nTraining Configuration:")
print("\n".join(config_lines))

## CELL 10: Start Training (Main Training Cell) ⭐

In [None]:
import os

import pandas as pd
import torch

# MAIN TRAINING CELL
# CHECKPOINT_PATH comes from Cell 2 and is deliberately NOT redefined here: the
# Trainer was created with that value, so the paths reported below match where it
# actually writes checkpoints.
NUM_EPOCHS = 100
EARLY_STOPPING_PATIENCE = 25
best_model_path = os.path.join(CHECKPOINT_PATH, 'best_model.pth')

print("="*60)
print("STARTING TRAINING")
print("="*60)
print(f"\nTraining started at: {pd.Timestamp.now()}")
print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")
print(f"Total epochs: {NUM_EPOCHS}")
# Report the batch size the loader really uses instead of a hard-coded number.
print(f"Batch size: {getattr(train_loader, 'batch_size', 'unknown')}")
print("Images loaded at original size")
print("\nNote: Training will save checkpoints to Google Drive automatically.")
print(f"Best model will be saved to: {best_model_path}")
print("="*60)
print()

# NOTE(review): when the checkpoint cell restored a checkpoint, its `start_epoch`
# is not passed here and `resume_from` is commented out, so the epoch counter
# presumably restarts - confirm against Trainer.train's signature before resuming.
trainer.train(
    num_epochs=NUM_EPOCHS,
    early_stopping_patience=EARLY_STOPPING_PATIENCE,
    # resume_from=best_model_path,
)

print("\n" + "="*60)
print("TRAINING COMPLETED!")
print("="*60)
print(f"Best Validation Dice: {trainer.best_val_dice:.4f}")
print(f"\nCheckpoints saved to: {CHECKPOINT_PATH}")

## CELL 11: Evaluation on Test Set

In [None]:
from metrics import evaluate_batch, MetricsTracker
import torch
from tqdm import tqdm
from sklearn.metrics import jaccard_score, roc_curve, auc, confusion_matrix
import numpy as np
import pandas as pd
import os

# Evaluate the best checkpoint on the test split: per-batch metrics (mean ± std)
# plus global pixel-level Jaccard, ROC/AUC and confusion matrix. One summary row
# per run is appended to a CSV in the checkpoint folder.

# Load best model (weights_only=False unpickles arbitrary objects - trusted files only)
checkpoint_path = os.path.join(CHECKPOINT_PATH, 'best_model.pth')
checkpoint = torch.load(checkpoint_path, weights_only=False, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

print(f"Loaded best model from: {checkpoint_path}")
print(f"Best validation Dice: {checkpoint['metrics']['dice']:.4f}\n")

# Evaluate on test set
metrics_tracker = MetricsTracker()
all_targets = []
all_probs = []

with torch.no_grad():
    for images, masks in tqdm(test_loader, desc='Evaluating on test set'):
        images = images.to(device)
        masks = masks.to(device)

        outputs = model(images)
        metrics_tracker.update(evaluate_batch(outputs, masks))

        # Flattened per-pixel values for the global sklearn metrics. Targets are
        # binarised (> 0.5): jaccard_score / roc_curve / confusion_matrix reject
        # non-binary ("continuous") labels. uint8 storage also needs 4x less memory
        # than float32 for whole-test-set pixel arrays.
        all_probs.append(torch.sigmoid(outputs).float().cpu().numpy().ravel())
        all_targets.append((masks > 0.5).cpu().numpy().astype(np.uint8).ravel())

# Get results
results = metrics_tracker.get_average()
std_results = metrics_tracker.get_std()

# Calculate global metrics
print("Calculating global metrics (Jaccard, ROC, Confusion Matrix)...")
all_targets_np = np.concatenate(all_targets)
all_probs_np = np.concatenate(all_probs)
all_preds_np = (all_probs_np > 0.5).astype(np.uint8)

# Jaccard
jaccard = jaccard_score(all_targets_np, all_preds_np)

# ROC and AUC
fpr, tpr, _ = roc_curve(all_targets_np, all_probs_np)
auc_score = auc(fpr, tpr)

# Confusion matrix. labels=[0, 1] keeps it 2x2 even when one class is absent
# from both targets and predictions; otherwise .ravel() yields a single value
# and the four-way unpack fails.
tn, fp, fn, tp = confusion_matrix(all_targets_np, all_preds_np, labels=[0, 1]).ravel()

# Save results to a CSV so averages across runs can be computed later
metrics_file = os.path.join(CHECKPOINT_PATH, 'test_metrics_history.csv')

new_row = {'timestamp': pd.Timestamp.now()}
for name in ('dice', 'iou', 'precision', 'recall', 'f1', 'specificity'):
    new_row[name] = results[name]
    new_row[f'{name}_std'] = std_results[name]
new_row.update({
    'jaccard': jaccard,
    'auc': auc_score,
    'tp': tp,
    'tn': tn,
    'fp': fp,
    'fn': fn,
})

# Append to CSV. pd.concat aligns on column names, so rows from older files that
# lack a column simply get NaN there.
new_df = pd.DataFrame([new_row])
if os.path.exists(metrics_file):
    df = pd.concat([pd.read_csv(metrics_file), new_df], ignore_index=True)
else:
    df = new_df

df.to_csv(metrics_file, index=False)
print(f"Results appended to: {metrics_file}")

print("\n" + "="*60)
print("TEST SET RESULTS")
print("="*60)
for metric, value in results.items():
    print(f"{metric.capitalize():15s}: {value:.4f} ± {std_results[metric]:.4f}")
print("-" * 60)
print(f"{'Jaccard':15s}: {jaccard:.4f}")
print(f"{'AUC':15s}: {auc_score:.4f}")
print(f"{'Confusion Matrix':15s}: TP={tp}, TN={tn}, FP={fp}, FN={fn}")
print("="*60)

## CELL 12: Visualize Predictions

In [None]:
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch

# Plot and save image / ground truth / prediction for a few random test samples,
# then plot the ROC curve and confusion matrix computed in the evaluation cell.

num_samples_to_visualize = 5
dataset = test_loader.dataset
dataset_len = len(dataset)

# Select random indices
if dataset_len < num_samples_to_visualize:
    indices = list(range(dataset_len))
    print(f"Warning: Dataset smaller than {num_samples_to_visualize}, evaluating all {dataset_len} samples.")
else:
    indices = random.sample(range(dataset_len), num_samples_to_visualize)

print(f"Visualizing samples at indices: {indices}")


def _to_display_rgb(img_tensor):
    """Min-max scale a (C, H, W) image tensor to an (H, W, C) float array in [0, 1].

    Min-max scaling (rather than dividing by the maximum) keeps mean/std-normalised
    inputs, which contain negative values, from being clipped to black.
    """
    img = img_tensor.detach().cpu().numpy().transpose(1, 2, 0).astype(np.float32)
    lo, hi = float(img.min()), float(img.max())
    if hi > lo:
        img = (img - lo) / (hi - lo)
    return np.clip(img, 0, 1)


# Predict each sample on its own: images are kept at original size, so samples
# can differ in H x W and torch.stack over them would fail.
model.eval()
samples = []  # (image tensor, 2-D mask array, 2-D boolean prediction) per index
with torch.no_grad():
    for idx in indices:
        # dataset[idx] returns an (image, mask) tuple of tensors
        img_tensor, mask_tensor = dataset[idx]
        logits = model(img_tensor.unsqueeze(0).to(device))
        pred = (torch.sigmoid(logits) > 0.5)[0].cpu().numpy()
        # np.squeeze drops singleton channel axes, so (1, H, W) and (H, W) both work.
        mask = torch.as_tensor(mask_tensor).cpu().numpy()
        samples.append((img_tensor, np.squeeze(mask), np.squeeze(pred)))

# Plotting and Saving Individual Files
print("Saving individual prediction images...")

for idx, (img_tensor, mask_np, pred_np) in zip(indices, samples):
    img_np = _to_display_rgb(img_tensor)

    # Create a separate figure for each sample
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    panels = ((img_np, 'Image', None), (mask_np, 'Ground Truth', 'gray'), (pred_np, 'Prediction', 'gray'))
    for ax, (data, label, cmap) in zip(axes, panels):
        ax.imshow(data, cmap=cmap)
        ax.set_title(f'Sample {idx}: {label}')
        ax.axis('off')

    plt.tight_layout()
    filename = f'prediction_sample_{idx}.png'
    save_path = os.path.join(CHECKPOINT_PATH, filename)
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()  # Display inline as well
    plt.close(fig)  # Close to free memory
    print(f"Saved {filename}")

# Visualize ROC and Confusion Matrix (fpr, tpr, auc_score, tn/fp/fn/tp come from the evaluation cell)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

# ROC Curve
ax1.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {auc_score:.4f})')
ax1.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
ax1.set_xlim([0.0, 1.0])
ax1.set_ylim([0.0, 1.05])
ax1.set_xlabel('False Positive Rate')
ax1.set_ylabel('True Positive Rate')
ax1.set_title('Receiver Operating Characteristic (ROC)')
ax1.legend(loc="lower right")

# Confusion Matrix
cm = np.array([[tn, fp], [fn, tp]])
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax2,
            xticklabels=['Predicted 0', 'Predicted 1'],
            yticklabels=['Actual 0', 'Actual 1'])
ax2.set_title('Confusion Matrix')
ax2.set_ylabel('Actual')
ax2.set_xlabel('Predicted')

plt.tight_layout()
plt.savefig(os.path.join(CHECKPOINT_PATH, 'metrics_visualization.png'), dpi=150, bbox_inches='tight')
plt.show()

print("Visualizations saved to Google Drive!")

## CELL 13: Review Saved Results (stored in Google Drive)

In [None]:
import glob
import os

# List everything saved in the checkpoint folder on Google Drive and check which
# of the key result files exist.
# To copy a file to your local machine from Colab instead:
#   from google.colab import files; files.download(path)

print("="*60)
print("RESULTS SAVED TO GOOGLE DRIVE")
print("="*60)
print(f"\nCheckpoint directory: {CHECKPOINT_PATH}")
print("\nSaved files:")
for name in sorted(os.listdir(CHECKPOINT_PATH)):
    entry_path = os.path.join(CHECKPOINT_PATH, name)
    if os.path.isfile(entry_path):
        print(f"  - {name} ({os.path.getsize(entry_path) / 1e6:.1f} MB)")
    else:
        print(f"  - {name}/")

print("\n✓ All results automatically saved to your Google Drive!")
print("\nYou can download them from:")
print("Google Drive → HRF-Segmentation-Unet → checkpoints")
print("\nKey files:")

# Names written by this notebook (evaluation / visualisation cells) or the Trainer.
key_files = (
    ('best_model.pth', 'Your trained model - USE THIS!'),
    ('latest_model.pth', 'Last checkpoint'),
    ('test_metrics_history.csv', 'Test metrics, one row per evaluation run'),
    ('metrics_visualization.png', 'ROC curve and confusion matrix'),
)
for name, description in key_files:
    status = '' if os.path.exists(os.path.join(CHECKPOINT_PATH, name)) else ' [not found]'
    print(f"  - {name} ({description}){status}")

sample_plots = glob.glob(os.path.join(CHECKPOINT_PATH, 'prediction_sample_*.png'))
print(f"  - prediction_sample_*.png (Visual results: {len(sample_plots)} file(s))")