In [1]:
import os
from pathlib import Path
import kagglehub
import torch
from torch.utils.data import DataLoader
import numpy as np
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

from src.dataset import RetinaDataset
from src.config import Config
from src.model import UNet
from src.trainer import UNetTrainer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Seed every RNG this notebook relies on (PyTorch CPU/CUDA and NumPy) from the
# single project-wide seed so that repeated runs start from the same state.
torch.manual_seed(Config.random_state)
# Seed all visible GPUs, not only the current one; this is a no-op without CUDA.
torch.cuda.manual_seed_all(Config.random_state)
np.random.seed(Config.random_state)
# Ask cuDNN for deterministic kernels and switch off its autotuner: with
# benchmarking enabled cuDNN may select different algorithms from run to run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [3]:
# Resolve the compute device from the project configuration. Wrapping it in
# torch.device() is a no-op for a torch.device and also accepts a plain string
# such as "cuda:1", so the `.type` check below works either way.
device = torch.device(Config.device)
print(f"Using device: {device}")

if device.type == "cuda":
    # Report the memory of the device that was actually selected rather than
    # always querying GPU index 0.
    total_memory_bytes = torch.cuda.get_device_properties(device).total_memory
    print(f"GPU Memory: {total_memory_bytes / 1e9:.2f} GB")

Using device: cpu


In [4]:
# Fetch the retina blood-vessel dataset through kagglehub and use the location
# it reports as the root directory for every path built later in the notebook.
dataset_path = kagglehub.dataset_download("abdallahwagih/retina-blood-vessel")
DATA_DIR = Path(dataset_path)

# Quick sanity check: show the top-level entries of the downloaded dataset.
dataset_entries = os.listdir(DATA_DIR)
print("Dataset contents:", dataset_entries)

Dataset contents: ['Data']


In [5]:
# Load Paths
IMAGE_EXTENSIONS = (".png", ".jpg")


def list_image_files(directory):
    """Return the sorted full paths of the image files directly inside ``directory``.

    Only entries whose name ends with one of ``IMAGE_EXTENSIONS`` are kept
    (the match is case-sensitive).

    Args:
        directory: Directory to scan (not recursive).

    Returns:
        A lexicographically sorted list of ``os.path.join(directory, name)`` strings.
    """
    return sorted(
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if name.endswith(IMAGE_EXTENSIONS)
    )


train_img_dir = os.path.join(DATA_DIR, "Data/train/image")
train_mask_dir = os.path.join(DATA_DIR, "Data/train/mask")

test_img_dir = os.path.join(DATA_DIR, "Data/test/image")
test_mask_dir = os.path.join(DATA_DIR, "Data/test/mask")

train_images = list_image_files(train_img_dir)
train_masks = list_image_files(train_mask_dir)
test_images = list_image_files(test_img_dir)
test_masks = list_image_files(test_mask_dir)

# Images and masks are paired purely by their position in the sorted lists, so
# a differing count would silently misalign every following pair.
assert len(train_images) == len(train_masks), (
    f"train image/mask count mismatch: {len(train_images)} vs {len(train_masks)}"
)
assert len(test_images) == len(test_masks), (
    f"test image/mask count mismatch: {len(test_images)} vs {len(test_masks)}"
)

In [6]:
# Hold out 20% of the training image/mask pairs for validation. Passing both
# lists to the same call keeps each image aligned with its mask, and the fixed
# random_state makes the partition repeatable.
train_image, val_image, train_mask, val_mask = train_test_split(
    train_images,
    train_masks,
    test_size=0.2,
    random_state=Config.random_state,
)

In [7]:
# DataLoader defaults to shuffle=False, which would feed the training samples
# in the same fixed order every epoch; reshuffle them each epoch instead.
# The batch size is left at the DataLoader default.
train_loader = DataLoader(RetinaDataset(train_image, train_mask), shuffle=True)
# Validation and test data keep a fixed order so metrics and visualisations
# are comparable between runs.
val_loader = DataLoader(RetinaDataset(val_image, val_mask))
test_loader = DataLoader(RetinaDataset(test_images, test_masks))

In [8]:
# Build the network with its default constructor arguments.
print("Creating model...")
model = UNet()

# Report the total number of parameter elements across the whole model.
param_count = sum(param.numel() for param in model.parameters())
print(f"Model parameters: {param_count:,}")

# Hand the model, both loaders and the target device to the project trainer.
trainer = UNetTrainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
)

# Run training and keep whatever the trainer returns; the plotting cell below
# unpacks it as a mapping of loss histories.
print("Starting training...")
training_history = trainer.train()

Creating model...
Model parameters: 7,702,977
Starting training...
Starting training for 100 epochs...


Training:   3%|▎         | 2/64 [00:22<11:22, 11.00s/it, Loss=1.5053]


KeyboardInterrupt: 

In [None]:
# Plot the per-epoch loss curves recorded during training.
# NOTE(review): this unpacking assumes training_history holds exactly two
# entries, ordered (training losses, validation losses) — confirm against
# UNetTrainer.train().
train_losses, val_losses = training_history.values()

# Epochs are numbered from 1 on the x-axis.
epochs = range(1, len(train_losses) + 1)

fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(epochs, train_losses, label="Training Loss", marker="o")
ax.plot(epochs, val_losses, label="Validation Loss", marker="o")
ax.set_title("Training and Validation Loss Curves")
ax.set_xlabel("Epochs")
ax.set_ylabel("Loss")
ax.legend()
ax.grid(True)
plt.show()

In [None]:
def visualize_predictions(
    model,
    loader,
    device,
    num_examples=5,
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
):
    """Plot input image, ground-truth mask and predicted mask side by side.

    One figure is drawn per batch, using the first sample of that batch, for
    at most ``num_examples`` batches of ``loader``. Predictions are obtained
    by applying a sigmoid to the model output and thresholding at 0.5.

    The model is switched to evaluation mode for the duration of the call and
    its previous train/eval mode is restored afterwards, even if an error
    occurs while plotting.

    Args:
        model: Segmentation model; called as ``model(imgs)``.
        loader: Iterable yielding ``(imgs, masks)`` tensor pairs.
            Assumes ``imgs`` is laid out as (batch, channel, height, width)
            with three channels — TODO confirm against the dataset class.
        device: Device the batches are moved to before the forward pass.
        num_examples: Maximum number of batches to visualise.
        mean: Per-channel mean used to undo input normalisation for display.
        std: Per-channel standard deviation used to undo input normalisation.
            NOTE(review): the defaults presume the dataset normalises its
            images with these constants — verify.
    """
    channel_mean = np.asarray(mean)
    channel_std = np.asarray(std)

    # Remember the current mode so the caller's model is not left in eval mode.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for i, (imgs, masks) in enumerate(loader):
                if i >= num_examples:
                    break
                imgs, masks = imgs.to(device), masks.to(device)
                outputs = model(imgs)
                preds = torch.sigmoid(outputs) > 0.5  # Binarize predictions

                # Plot for the first image in the batch (adjust if batch_size >1)
                img = imgs[0].cpu().permute(1, 2, 0).numpy()
                # Undo the normalisation, then clip: values outside [0, 1]
                # are not valid float RGB data for imshow.
                img = np.clip(img * channel_std + channel_mean, 0.0, 1.0)
                mask = masks[0].cpu().squeeze().numpy()
                pred = preds[0].cpu().squeeze().float().numpy()

                fig, axs = plt.subplots(1, 3, figsize=(15, 5))
                axs[0].imshow(img)
                axs[0].set_title("Original Image")
                axs[1].imshow(mask, cmap="gray")
                axs[1].set_title("Ground Truth Mask")
                axs[2].imshow(pred, cmap="gray")
                axs[2].set_title("Predicted Mask")
                plt.show()
                # Release the figure so repeated calls do not accumulate open figures.
                plt.close(fig)
    finally:
        model.train(was_training)


# Show qualitative results for the first five test batches.
visualize_predictions(
    model=model,
    loader=test_loader,
    device=device,
    num_examples=5,
)