<a href="https://colab.research.google.com/github/spate472/RecreatingRetinaUNet/blob/main/Attempt3_ToyDataNoPH.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch torchvision matplotlib

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
import numpy as np
import random
import matplotlib.pyplot as plt
import cv2

# Fix every random number generator used in this notebook (Python, NumPy,
# PyTorch CPU and all CUDA devices) to the same seed for reproducibility.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# ToyDataset class
class ToyDataset(Dataset):
    """Synthetic segmentation dataset: one circle or ring on a noisy background.

    Every access draws a fresh random sample from the global ``random`` and
    ``np.random`` generators. ``idx`` is only range-checked, so the same index
    yields a different image on each call.

    Args:
        num_samples: Length reported by ``__len__``.
        image_size: ``(height, width)`` of the generated images and masks.
        noise_factor: Upper bound of the uniform base noise, as a fraction of 255.
        transform: Stored on the instance; ``__getitem__`` does not apply it.
    """

    def __init__(self, num_samples=1000, image_size=(320, 320), noise_factor=0.2, transform=None):
        self.num_samples = num_samples
        self.image_size = image_size
        self.noise_factor = noise_factor
        self.transform = transform

    def __len__(self):
        return self.num_samples

    @staticmethod
    def _draw_shape(canvas, shape_type, center, radius, foreground, background):
        """Draw a filled circle (type 0) or a ring (any other type) onto ``canvas`` in place."""
        if shape_type == 0:
            cv2.circle(canvas, center, radius, foreground, -1)
        else:
            # Ring outline of thickness 3, then clear the interior disc.
            cv2.circle(canvas, center, radius, foreground, 3)
            # Clamp so small radii never pass a negative radius to OpenCV.
            cv2.circle(canvas, center, max(radius - 10, 0), background, -1)

    def __getitem__(self, idx):
        """Generate one random sample.

        Returns:
            A tuple ``(image, mask)``: ``image`` is a float32 tensor of shape
            ``(3, height, width)`` scaled to [0, 1]; ``mask`` is a float32
            tensor of shape ``(1, height, width)`` holding 0 or 1.

        Raises:
            IndexError: If ``idx`` lies outside ``[-len(self), len(self))``.
        """
        # Raising here lets plain ``for sample in dataset`` iteration terminate.
        if not -self.num_samples <= idx < self.num_samples:
            raise IndexError(f"index {idx} is out of range for dataset of size {self.num_samples}")

        height, width = self.image_size
        min_side = min(height, width)

        # Create a blank image (black background)
        image = np.zeros((height, width, 3), dtype=np.uint8)

        # Randomly choose shape type: 0 = circle, 1 = ring
        shape_type = random.randint(0, 1)

        # Random center and radius, scaled with the image size. For the
        # default 320x320 these ranges are exactly 64..256 and 20..40.
        center_x = random.randint(width // 5, width - width // 5)
        center_y = random.randint(height // 5, height - height // 5)
        min_radius = max(1, min_side // 16)
        max_radius = max(min_radius, min_side // 8)
        radius = random.randint(min_radius, max_radius)
        center = (center_x, center_y)

        # Random shape colour, one value per channel
        color = (random.randint(0, 100), random.randint(100, 255), random.randint(200, 255))

        self._draw_shape(image, shape_type, center, radius, color, (0, 0, 0))

        # Uniform base noise plus a per-channel offset. Channel 2 receives the
        # largest offset, which tints the background towards that channel
        # (blue when the array is displayed as RGB).
        noise = np.random.uniform(low=0, high=self.noise_factor, size=(height, width, 3))
        noise[..., 0] += np.random.uniform(0.0, 0.1, (height, width))
        noise[..., 1] += np.random.uniform(0.0, 0.2, (height, width))
        noise[..., 2] += np.random.uniform(0.1, 0.3, (height, width))

        # Clip the values to ensure they stay within the valid image range (0-255)
        noisy_image = np.clip(image + noise * 255, 0, 255).astype(np.uint8)

        # Ground-truth mask: the same shape drawn in white on black
        mask = np.zeros((height, width), dtype=np.uint8)
        self._draw_shape(mask, shape_type, center, radius, 255, 0)

        # Normalize image and mask to [0, 1]
        noisy_image = noisy_image.astype(np.float32) / 255.0
        mask = mask.astype(np.float32) / 255.0

        # Convert to tensors: image HWC -> CHW, mask gains a channel dimension
        noisy_image = torch.from_numpy(noisy_image.transpose((2, 0, 1)))
        mask = torch.from_numpy(mask).unsqueeze(0)

        return noisy_image, mask

# Baseline Model 1: Simple UNet implementation
class SimpleUNet(nn.Module):
    """Minimal encoder/decoder baseline producing one logit per pixel.

    A 3x3 convolution with ReLU and 2x average pooling forms the encoder. A
    stride-2 transposed convolution restores the resolution, and a 1x1
    convolution maps the 16 feature channels to a single output channel.
    """

    def __init__(self):
        super().__init__()
        # Encoder
        self.enc_conv = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
        self.pool = nn.AvgPool2d(kernel_size=2)

        # Decoder
        self.upconv = nn.ConvTranspose2d(in_channels=16, out_channels=16, kernel_size=2, stride=2)
        self.dec_conv = nn.Conv2d(in_channels=16, out_channels=1, kernel_size=1)

    def forward(self, x):
        """Return raw (pre-sigmoid) mask logits for the image batch ``x``."""
        encoded = self.pool(F.relu(self.enc_conv(x)))
        decoded = self.upconv(encoded)
        return self.dec_conv(decoded)

# Baseline Model 2: Simple RetinaNet implementation
class SimpleRetinaNet(nn.Module):
    """ResNet18-based baseline: truncated backbone, small conv neck, 1x1 head.

    The backbone is randomly initialised. The head emits one logit per pixel
    at reduced resolution, which is bilinearly upsampled back to the spatial
    size of the input.
    """

    def __init__(self):
        super().__init__()
        # Randomly initialised ResNet18 backbone. Calling resnet18() with no
        # arguments avoids the deprecated `pretrained` keyword.
        resnet = models.resnet18()
        # Use only up to layer2 (128 channels) to avoid too much downsampling
        self.encoder = nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,
            resnet.layer1,
            resnet.layer2,
        )
        # Neck (process feature maps)
        self.neck = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
        # Head (predict mask)
        self.head = nn.Sequential(
            nn.Conv2d(32, 1, kernel_size=1)
        )

    def forward(self, x):
        """Return mask logits with the same spatial size as ``x``."""
        # Remember the input resolution so the output matches it even when
        # the input height/width are not multiples of 8.
        input_size = x.shape[-2:]
        x = self.encoder(x)
        x = self.neck(x)
        x = self.head(x)
        return F.interpolate(x, size=input_size, mode='bilinear', align_corners=False)

# Our implementation of the Retina U-Net Model
class RetinaUNet(nn.Module):
    """U-Net style segmentation network on a randomly initialised ResNet18 encoder.

    Five encoder stages feed a four-step decoder. Each decoder step applies a
    stride-2 transposed convolution, concatenates the matching encoder
    feature map, and applies a 3x3 convolution with ReLU. The output is one
    logit per pixel at the spatial size of the input.
    """

    def __init__(self):
        super().__init__()
        # Calling resnet18() with no arguments avoids the deprecated
        # `pretrained` keyword and still yields random weights.
        resnet = models.resnet18()

        # Encoder layers (for skip connections)
        self.enc1 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu)
        self.enc2 = nn.Sequential(resnet.maxpool, resnet.layer1)
        self.enc3 = resnet.layer2
        self.enc4 = resnet.layer3
        self.enc5 = resnet.layer4

        # Decoder: each conv takes upsampled channels + skip channels
        self.upconv1 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.conv1 = nn.Conv2d(256 + 256, 256, kernel_size=3, padding=1)

        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(128 + 128, 128, kernel_size=3, padding=1)

        self.upconv3 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(64 + 64, 64, kernel_size=3, padding=1)

        self.upconv4 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.conv4 = nn.Conv2d(32 + 64, 32, kernel_size=3, padding=1)

        self.final_conv = nn.Conv2d(32, 1, kernel_size=1)

    @staticmethod
    def _concat_skip(x, skip):
        """Concatenate ``x`` with ``skip`` on the channel axis, resizing ``x`` if needed."""
        # Sizes only differ when the input is not a multiple of 32, because
        # strided stages then round and 2x upsampling no longer lines up.
        if x.shape[-2:] != skip.shape[-2:]:
            x = F.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=False)
        return torch.cat([x, skip], dim=1)

    def forward(self, x):
        """Return mask logits with the same spatial size as ``x``."""
        input_size = x.shape[-2:]

        # Encoder with skip connections
        x1 = self.enc1(x)
        x2 = self.enc2(x1)
        x3 = self.enc3(x2)
        x4 = self.enc4(x3)
        x5 = self.enc5(x4)

        # Decoder with skip connections
        x = F.relu(self.conv1(self._concat_skip(self.upconv1(x5), x4)))
        x = F.relu(self.conv2(self._concat_skip(self.upconv2(x), x3)))
        x = F.relu(self.conv3(self._concat_skip(self.upconv3(x), x2)))
        x = F.relu(self.conv4(self._concat_skip(self.upconv4(x), x1)))

        # Restore the input resolution instead of a hard-coded 320x320.
        x = F.interpolate(x, size=input_size, mode='bilinear', align_corners=False)
        return self.final_conv(x)

# Training loop
def train(model, dataloader, optimizer, device, num_epochs=5, dropout_p=0.5, criterion=None):
    """Train ``model`` and print the average loss after every epoch.

    Args:
        model: Network mapping an image batch to mask logits.
        dataloader: Iterable yielding ``(images, masks)`` batches.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device the batches are moved to.
        num_epochs: Number of passes over ``dataloader``.
        dropout_p: Probability for dropout applied directly to the input
            images; values <= 0 disable it.
        criterion: Loss taking ``(logits, targets)``. When omitted, a
            module-level ``criterion`` is used if one is defined, otherwise
            ``nn.BCEWithLogitsLoss()``.

    Returns:
        List with the average loss of each epoch.
    """
    if criterion is None:
        # Earlier versions read the loss from a notebook global; honour it
        # when present so existing cells behave the same.
        criterion = globals().get("criterion")
        if criterion is None:
            criterion = nn.BCEWithLogitsLoss()

    model.train()
    epoch_losses = []
    for epoch in range(num_epochs):
        running_loss = 0.0
        num_batches = 0
        for imgs, masks in dataloader:
            imgs = imgs.to(device)
            masks = masks.to(device).float()

            # Apply input dropout
            if dropout_p > 0:
                imgs = F.dropout(imgs, p=dropout_p, training=True)

            # Ensure masks carry an explicit channel dimension
            if masks.ndimension() == 3:
                masks = masks.unsqueeze(1)

            optimizer.zero_grad()
            outputs = model(imgs)

            # Compare at the mask resolution rather than a hard-coded 320x320;
            # this is a no-op when the sizes already agree.
            if outputs.shape[-2:] != masks.shape[-2:]:
                outputs = F.interpolate(outputs, size=masks.shape[-2:], mode='nearest')

            # Compute and optimize loss
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        # Guard against an empty dataloader (no batches seen).
        avg_loss = running_loss / max(num_batches, 1)
        epoch_losses.append(avg_loss)
        print(f'>> Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}\n')

    return epoch_losses



# Visualize a sample image, mask, and prediction from the model
def visualize_samples(model, dataloader, device, num_samples=3):
    """Plot input, ground-truth mask and predicted mask for the first batch.

    Args:
        model: Network mapping an image batch to mask logits.
        dataloader: Iterable yielding ``(images, masks)`` batches; only the
            first batch is used.
        device: Device the batch is moved to for inference.
        num_samples: Number of rows to plot; capped at the batch size.
    """
    model.eval()
    imgs, masks = next(iter(dataloader))
    imgs = imgs.to(device)
    masks = masks.to(device)

    with torch.no_grad():
        outputs = torch.sigmoid(model(imgs)).cpu()
    # Drop only the channel axis. A bare squeeze() would also drop the batch
    # axis when the batch holds a single sample.
    if outputs.ndim == 4 and outputs.shape[1] == 1:
        outputs = outputs.squeeze(1)

    # Never index past the end of the batch.
    num_samples = min(num_samples, imgs.shape[0])
    if num_samples <= 0:
        return

    # squeeze=False keeps `axes` two-dimensional even for a single row.
    fig, axes = plt.subplots(num_samples, 3, figsize=(12, 4 * num_samples), squeeze=False)

    for i in range(num_samples):
        # Show input image (CHW -> HWC for matplotlib)
        axes[i, 0].imshow(imgs[i].cpu().permute(1, 2, 0))
        axes[i, 0].set_title(f'Image {i+1}')
        axes[i, 0].axis('off')

        # Show ground truth mask
        gt_mask = masks[i].cpu()
        if gt_mask.ndim == 3 and gt_mask.shape[0] == 1:
            gt_mask = gt_mask.squeeze(0)
        axes[i, 1].imshow(gt_mask, cmap='gray')
        axes[i, 1].set_title(f'Ground Truth Mask {i+1}')
        axes[i, 1].axis('off')

        # Show prediction (sigmoid probabilities)
        pred_mask = outputs[i].numpy()
        axes[i, 2].imshow(pred_mask, cmap='gray')
        axes[i, 2].set_title(f'Predicted Mask {i+1}')
        axes[i, 2].axis('off')

    plt.tight_layout()
    plt.show()


# Evaluating the Dice Score between the Ground Truth Mask and the Model's Prediction
def evaluate(model, loader, device):
    """Compute, print and return the mean per-sample Dice score.

    Predictions are the sigmoid of the model output thresholded at 0.5.

    Args:
        model: Network mapping an image batch to mask logits.
        loader: Iterable yielding ``(images, masks)`` batches.
        device: Device the batches are moved to.

    Returns:
        Mean Dice score over all samples as a float, or NaN when ``loader``
        yields no samples.
    """
    model.eval()
    dice_scores = []
    with torch.no_grad():
        for images, masks in loader:
            images = images.to(device)
            masks = masks.to(device).float()
            # Give 3-D masks a channel axis so they do not broadcast against
            # the 4-D predictions.
            if masks.ndimension() == 3:
                masks = masks.unsqueeze(1)

            outputs = torch.sigmoid(model(images))
            # Score at the mask resolution, as the training loop does.
            if outputs.shape[-2:] != masks.shape[-2:]:
                outputs = F.interpolate(outputs, size=masks.shape[-2:], mode='nearest')
            preds = (outputs > 0.5).float()

            # Calculate Dice score for each sample
            intersection = (preds * masks).sum(dim=(1, 2, 3))
            union = preds.sum(dim=(1, 2, 3)) + masks.sum(dim=(1, 2, 3))
            dice = (2. * intersection) / (union + 1e-8)
            dice_scores.extend(dice.cpu().numpy())

    # Compute average Dice score; an empty loader gives NaN instead of a
    # ZeroDivisionError.
    if dice_scores:
        mean_dice = float(sum(dice_scores) / len(dice_scores))
    else:
        mean_dice = float("nan")
    print(f"Mean Dice score: {mean_dice:.4f}")
    return mean_dice

In [None]:
# Build the synthetic dataset and hold out 20% of it for validation
full_dataset = ToyDataset(num_samples=1000)
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    full_dataset, [train_size, val_size]
)

# Batch loaders: only the training split is shuffled
train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=10, shuffle=False)

# Preview ten randomly chosen training samples, image and mask side by side
indices = random.sample(range(len(train_dataset)), 10)

fig, axes = plt.subplots(2, 5, figsize=(20, 8))
fig.suptitle("Random Samples: Image and Mask Side-by-Side", fontsize=18)

for ax, idx in zip(axes.ravel(), indices):
    image, mask = train_dataset[idx]

    # Image: CHW tensor -> HWC array for matplotlib
    image_np = image.permute(1, 2, 0).numpy()

    # Mask: make sure there is a trailing channel axis, then rescale if needed
    mask_np = mask.squeeze().numpy()
    if mask_np.ndim == 2:
        mask_np = mask_np[..., np.newaxis]
    peak = mask_np.max()
    if peak > 1:
        mask_np = mask_np / peak

    # Replicate the mask to three channels so it can sit next to the image
    mask_rgb = np.repeat(mask_np, 3, axis=-1)
    combined = np.concatenate((image_np, mask_rgb), axis=1)
    ax.imshow(combined)
    ax.axis('off')

plt.tight_layout()
plt.subplots_adjust(top=0.88)
plt.show()

In [None]:
# Experiment 1: SimpleUNet baseline
# Pick the GPU when one is available and place the model on it
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleUNet().to(device)
# Adam optimizer with binary cross-entropy on raw logits
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.BCEWithLogitsLoss()
# Train, then inspect predictions and the Dice score on the validation split
train(model, train_loader, optimizer, device, num_epochs=5, dropout_p=0.3)
visualize_samples(model, val_loader, device, num_samples=10)
evaluate(model, val_loader, device)


In [None]:
# Experiment 2: SimpleRetinaNet baseline
# Build the model, then pick the GPU when one is available
model = SimpleRetinaNet()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
model = model.to(device)
# Adam optimizer with binary cross-entropy on raw logits
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.BCEWithLogitsLoss()
# Train, then inspect predictions and the Dice score on the validation split
train(model, train_loader, optimizer, device, num_epochs=5, dropout_p=0.3)
visualize_samples(model, val_loader, device, num_samples=10)
evaluate(model, val_loader, device)

In [None]:
# Experiment 3: RetinaUNet
# Build the model, then pick the GPU when one is available
model = RetinaUNet()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
model = model.to(device)
# Adam optimizer with binary cross-entropy on raw logits
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.BCEWithLogitsLoss()
# Train, then inspect predictions and the Dice score on the validation split
train(model, train_loader, optimizer, device, num_epochs=5, dropout_p=0.3)
visualize_samples(model, val_loader, device, num_samples=10)
evaluate(model, val_loader, device)