In [None]:
import os
import random
from pathlib import Path

import kagglehub
import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

from src.dataset import RetinaDataset, norm
from src.config import Config
from src.model import UNet
from src.trainer import UNetTrainer

In [None]:
# Set random seed for reproducibility.
# Seed every RNG the notebook may draw from (Python, NumPy, PyTorch CPU and
# all CUDA devices) so that repeated runs produce the same random numbers.
random.seed(Config.random_state)
np.random.seed(Config.random_state)
torch.manual_seed(Config.random_state)
torch.cuda.manual_seed_all(Config.random_state)  # silently ignored without CUDA

# cuDNN: request deterministic kernels and keep the autotuner off, since
# benchmarking may select different algorithms from one run to the next.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
# Set device
device = Config.device
print(f"Using device: {device}")

if device.type == "cuda":
    # Query the device that was actually selected instead of hard-coding
    # GPU 0, so the reported memory is correct on multi-GPU machines.
    gpu_props = torch.cuda.get_device_properties(device)
    print(f"GPU Memory: {gpu_props.total_memory / 1e9:.2f} GB")

In [None]:
# Fetch the retina blood-vessel dataset through kagglehub; the returned
# location is wrapped in a Path and its top-level entries are listed.
dataset_path = kagglehub.dataset_download("abdallahwagih/retina-blood-vessel")
DATA_DIR = Path(dataset_path)

dataset_entries = os.listdir(DATA_DIR)
print("Dataset contents:", dataset_entries)

In [None]:
# Load Paths
IMAGE_EXTENSIONS = (".png", ".jpg")


def list_image_files(directory):
    """Return the sorted full paths of the ``.png``/``.jpg`` files in ``directory``.

    Args:
        directory: Folder to scan (non-recursive).

    Returns:
        list[str]: ``directory`` joined with every matching file name,
        sorted lexicographically.
    """
    return sorted(
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if name.endswith(IMAGE_EXTENSIONS)
    )


train_img_dir = os.path.join(DATA_DIR, "Data/train/image")
train_mask_dir = os.path.join(DATA_DIR, "Data/train/mask")

test_img_dir = os.path.join(DATA_DIR, "Data/test/image")
test_mask_dir = os.path.join(DATA_DIR, "Data/test/mask")

train_images = list_image_files(train_img_dir)
train_masks = list_image_files(train_mask_dir)

test_images = list_image_files(test_img_dir)
test_masks = list_image_files(test_mask_dir)

# Images and masks are paired purely by their position in the sorted lists,
# so fail fast if an image folder and its mask folder are out of step.
assert len(train_images) == len(train_masks), (
    f"train: {len(train_images)} images vs {len(train_masks)} masks"
)
assert len(test_images) == len(test_masks), (
    f"test: {len(test_images)} images vs {len(test_masks)} masks"
)

In [None]:
# Hold out 20% of the training pairs for validation. Images and masks are
# passed to the same call so both lists are split with the same indices.
# NOTE: this rebinds train_images / train_masks, so re-running this cell on
# its own splits the already-reduced training lists again.
split_parts = train_test_split(
    train_images,
    train_masks,
    test_size=0.2,
    random_state=Config.random_state,
)
train_images, val_images, train_masks, val_masks = split_parts

In [None]:
# One dataset per split, each built from parallel image / mask path lists.
train_ds = RetinaDataset(train_images, train_masks)
val_ds = RetinaDataset(val_images, val_masks)
test_ds = RetinaDataset(test_images, test_masks)

# Settings shared by all three loaders; only the training loader shuffles.
loader_kwargs = {"batch_size": 4, "num_workers": 2}

train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_ds, shuffle=False, **loader_kwargs)

In [None]:
# Build the network and report how many parameters it has
print("Creating model...")
model = UNet()

param_count = sum(param.numel() for param in model.parameters())
print(f"Model parameters: {param_count:,}")

# Hand the model, both loaders and the target device to the trainer
trainer_kwargs = {
    "model": model,
    "train_loader": train_loader,
    "val_loader": val_loader,
    "device": device,
}
trainer = UNetTrainer(**trainer_kwargs)

# Run training and keep whatever history the trainer returns for plotting
print("Starting training...")
training_history = trainer.train()

In [None]:
# Score the held-out test loader with the trainer's validation routine,
# which is unpacked here as a (loss, dice) pair.
test_metrics = trainer.validate(test_loader)
test_loss, test_dice = test_metrics
print(f"Test Loss: {test_loss:.4f}, Test Dice: {test_dice:.4f}")

In [None]:
# Plotting
# NOTE(review): the positional unpack assumes training_history holds exactly
# three entries in the order (train loss, val loss, val dice) — confirm
# against UNetTrainer.train.
train_losses, val_losses, val_dices = training_history.values()

epochs = range(1, len(train_losses) + 1)

# Figure 1: training vs. validation loss per epoch
loss_fig, loss_ax = plt.subplots(figsize=(8, 5))
loss_ax.plot(epochs, train_losses, label="Training Loss")
loss_ax.plot(epochs, val_losses, label="Validation Loss")
loss_ax.set_title("Training and Validation Loss Curves")
loss_ax.set_xlabel("Epochs")
loss_ax.set_ylabel("Loss")
loss_ax.legend()
loss_ax.grid(True)
plt.show()

# Figure 2: validation Dice per epoch, y-axis pinned to [0, 1]
dice_epochs = range(1, len(val_dices) + 1)

dice_fig, dice_ax = plt.subplots(figsize=(8, 5))
dice_ax.plot(
    dice_epochs,
    val_dices,
    marker="o",
    color="tab:green",
    label="Val Dice",
)
dice_ax.set_xlabel("Epoch")
dice_ax.set_ylabel("Dice coefficient")
dice_ax.set_title("Validation Dice over epochs")
dice_ax.grid(True, alpha=0.3)
dice_ax.set_ylim(0, 1.0)
dice_ax.set_xticks(dice_epochs)
dice_ax.legend()
plt.show()

In [None]:
# Number of test batches to visualise; the first image of each batch is shown.
num_examples = 5

# Switch to inference mode explicitly so this cell does not depend on whatever
# train/eval state the trainer left the model in.
model.eval()

with torch.no_grad():
    for i, (images, masks) in enumerate(test_loader):
        if i >= num_examples:
            break

        # Only the inputs need to be on the model's device; the mask is
        # used for plotting alone.
        images = images.to(device)
        outputs = model(images)
        predictions = torch.sigmoid(outputs) > 0.5  # Binarize predictions

        # Plot for the first image in the batch
        image = images[0].cpu().permute(1, 2, 0).numpy()
        image = image * np.array(norm["std"]) + np.array(norm["mean"])  # Denormalize
        image = np.clip(image, 0.0, 1.0)

        mask = masks[0].cpu().squeeze().numpy()
        prediction = predictions[0].cpu().squeeze().float().numpy()

        panels = (
            (image, "Original Image", None),
            (mask, "Ground Truth Mask", "gray"),
            (prediction, "Predicted Mask", "gray"),
        )

        fig, axs = plt.subplots(1, 3, figsize=(15, 5))
        for ax, (panel, title, cmap) in zip(axs, panels):
            ax.imshow(panel, cmap=cmap)
            ax.set_title(title)

        plt.show()
        plt.close(fig)  # release the figure so repeated iterations do not accumulate