In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import VOCSegmentation
from torchvision import transforms
from torchvision.transforms.functional import to_tensor
from tqdm import tqdm
from torch.optim.lr_scheduler import ReduceLROnPlateau

# =============================
# Define Variables and Configs
# =============================
NUM_CLASSES = 21
BATCH_SIZE = 10
LEARNING_RATE = 1e-3
EPOCHS = 50
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
PATIENCE = 7  # Early stopping patience
MIN_LR = 1e-6  # Minimum learning rate for ReduceLROnPlateau
IMG_SIZE = 128  # Ensure this is divisible by 2^n

# =========================
# Data Transforms and Setup
# =========================
input_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

target_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor()
])

# Load Datasets
data = VOCSegmentation(root="D:/DATASETS/segmentation", year="2012", image_set="trainval", download=True,
                                 transform=input_transform, target_transform=target_transform)
train_dataset,val_dataset=torch.utils.data.random_split(data,[int(len(data)*0.75),round(len(data)*0.25)+1])
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=3)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=3)

print(f"Training samples: {len(train_dataset)}, Validation samples: {len(val_dataset)}")

data_iter = iter(train_loader)
images, masks = next(data_iter)
print(images.shape, masks.shape)

# ================================
# Define U-Net++ Model with Dropout
# ================================
class UNetPlusPlus(nn.Module):
    def __init__(self, n_channels, n_classes, base_channels=64, dropout_rate=0.3):
        super(UNetPlusPlus, self).__init__()

        # Encoder (Contracting path)
        self.encoder1 = self.conv_block(n_channels, base_channels)
        self.encoder2 = self.conv_block(base_channels, base_channels * 2)
        self.encoder3 = self.conv_block(base_channels * 2, base_channels * 4)
        self.encoder4 = self.conv_block(base_channels * 4, base_channels * 8)

        # Decoder (Expanding path)
        self.decoder1 = self.upconv_block(base_channels * 8, base_channels * 4)
        self.decoder2 = self.upconv_block(base_channels * 4, base_channels * 2)
        self.decoder3 = self.upconv_block(base_channels * 2, base_channels)
        self.decoder4 = nn.ConvTranspose2d(base_channels, n_classes, kernel_size=2, stride=2)

        # Nested Skip Connections
        self.decoder1_1 = self.upconv_block(base_channels * 8 + base_channels * 4, base_channels * 4)
        self.decoder2_1 = self.upconv_block(base_channels * 4 + base_channels * 2, base_channels * 2)
        self.decoder3_1 = self.upconv_block(base_channels * 2 + base_channels, base_channels)

        self.dropout = nn.Dropout2d(dropout_rate)

    def conv_block(self, in_channels, out_channels):
        """Convolution block with Batch Normalization and ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def upconv_block(self, in_channels, out_channels):
        """Upsampling block with ConvTranspose and Batch Normalization."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        # Encoder Pathway
        e1 = self.encoder1(x)
        e2 = self.encoder2(e1)
        e3 = self.encoder3(e2)
        e4 = self.encoder4(e3)

        # Decoder Pathway
        d1 = self.decoder1(e4)
        d1 = nn.functional.interpolate(d1, size=e3.shape[2:], mode="bilinear", align_corners=False)
        d2 = self.decoder2(d1)
        d2 = nn.functional.interpolate(d2, size=e2.shape[2:], mode="bilinear", align_corners=False)
        d3 = self.decoder3(d2)
        d3 = nn.functional.interpolate(d3, size=e1.shape[2:], mode="bilinear", align_corners=False)
        out = self.decoder4(d3)

        # Nested Skip Pathways
        d1_1 = self.decoder1_1(torch.cat([e4, d1], 1))
        d2_1 = self.decoder2_1(torch.cat([e3, d2], 1))
        d3_1 = self.decoder3_1(torch.cat([e2, d3], 1))

        # Project intermediate tensors to NUM_CLASSES channels
        device = x.device  # Get the device of the input
        proj_d1_1 = nn.Conv2d(d1_1.shape[1], NUM_CLASSES, kernel_size=1).to(device)(d1_1)
        proj_d2_1 = nn.Conv2d(d2_1.shape[1], NUM_CLASSES, kernel_size=1).to(device)(d2_1)
        proj_d3_1 = nn.Conv2d(d3_1.shape[1], NUM_CLASSES, kernel_size=1).to(device)(d3_1)

        # Combine outputs
        out = out + proj_d1_1 + proj_d2_1 + proj_d3_1
        out = self.dropout(out)
        out = nn.functional.interpolate(out, size=(x.shape[2], x.shape[3]), mode="bilinear", align_corners=False)

        return out



model = UNetPlusPlus(n_channels=3, n_classes=NUM_CLASSES).to(DEVICE)

# ==============================
# Define Loss, Optimizer, Scheduler
# ==============================
criterion = nn.CrossEntropyLoss(ignore_index=-1)  # Ignore unlabeled pixels
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = ReduceLROnPlateau(optimizer, mode="min", patience=3, factor=0.9, min_lr=MIN_LR)

# ============================
# Training Loop
# ============================
class EarlyStopping:
    def __init__(self, patience=7, min_delta=1e-4, mode="min"):
        """
        Early stopping to terminate training when validation loss stops improving.
        :param patience: How many epochs to wait after last improvement.
        :param min_delta: Minimum change to qualify as an improvement.
        :param mode: "min" for loss (lower is better), "max" for accuracy/score (higher is better).
        """
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.best_score = None
        self.epochs_no_improve = 0
        self.early_stop = False

    def __call__(self, current_score):
        if self.best_score is None:
            self.best_score = current_score
        else:
            improvement = (current_score < self.best_score - self.min_delta) if self.mode == "min" else \
                          (current_score > self.best_score + self.min_delta)
            if improvement:
                self.best_score = current_score
                self.epochs_no_improve = 0
            else:
                self.epochs_no_improve += 1
                if self.epochs_no_improve >= self.patience:
                    self.early_stop = True

# ============================
# Utility Function for Accuracy
# ============================
def calculate_accuracy(outputs, masks):
    """Calculate pixel-wise accuracy."""
    # Get the predicted class for each pixel
    preds = torch.argmax(outputs, dim=1)  # Shape: (batch_size, height, width)
    correct = (preds == masks).float()  # Pixel-wise comparison
    accuracy = correct.sum() / correct.numel()  # Total correct pixels / total pixels
    return accuracy.item()
# ============================
# Utility Functions for Metrics
# ============================
def calculate_iou(outputs, masks):
    """Calculate Intersection over Union (IoU)."""
    preds = torch.argmax(outputs, dim=1)  # Shape: (batch_size, height, width)
    masks = masks.squeeze(1)  # Remove channel dimension if present
    intersection = (preds & masks).float().sum((1, 2))  # Logical AND
    union = (preds | masks).float().sum((1, 2))  # Logical OR
    iou = (intersection + 1e-6) / (union + 1e-6)  # Avoid division by zero
    return iou.mean().item()

def calculate_dice(outputs, masks):
    """Calculate Dice Score."""
    preds = torch.argmax(outputs, dim=1)  # Shape: (batch_size, height, width)
    masks = masks.squeeze(1)  # Remove channel dimension if present
    intersection = (preds & masks).float().sum((1, 2))  # Logical AND
    dice = (2.0 * intersection + 1e-6) / (preds.float().sum((1, 2)) + masks.float().sum((1, 2)) + 1e-6)
    return dice.mean().item()


# ============================
# Training Loop with IoU and Dice
# ============================
history = {"train_loss": [], "val_loss": [], "train_acc": [], "val_acc": [], "train_iou": [], "val_iou": [], "train_dice": [], "val_dice": []}
early_stopping = EarlyStopping(patience=7, min_delta=1e-4, mode="min")
for epoch in range(EPOCHS):
    model.train()
    train_loss = 0
    train_acc = 0
    train_iou = 0
    train_dice = 0
    with tqdm(train_loader, unit="batch") as tepoch:
        for images, masks in tepoch:
            tepoch.set_description(f"Epoch {epoch + 1}")
            images, masks = images.to(DEVICE), masks.long().to(DEVICE)
            optimizer.zero_grad()
            outputs = model(images)
            train_iou += calculate_iou(outputs, masks)
            train_dice += calculate_dice(outputs, masks)
            loss = criterion(outputs, masks.squeeze(1))
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_acc += calculate_accuracy(outputs, masks)
            train_iou += calculate_iou(outputs, masks)
            train_dice += calculate_dice(outputs, masks)

            tepoch.set_postfix(
                loss=train_loss / len(train_loader),
                acc=train_acc / len(train_loader),
                iou=train_iou / len(train_loader),
                dice=train_dice / len(train_loader)
            )

    history["train_loss"].append(train_loss / len(train_loader))
    history["train_acc"].append(train_acc / len(train_loader))
    history["train_iou"].append(train_iou / len(train_loader))
    history["train_dice"].append(train_dice / len(train_loader))

    # Validation Step
    model.eval()
    val_loss = 0
    val_acc = 0
    val_iou = 0
    val_dice = 0
    with torch.no_grad():
        for images, masks in val_loader:
            images, masks = images.to(DEVICE), masks.long().to(DEVICE)
            outputs = model(images)
            loss = criterion(outputs, masks.squeeze(1))
            val_loss += loss.item()
            val_acc += calculate_accuracy(outputs, masks)
            val_iou += calculate_iou(outputs, masks)
            val_dice += calculate_dice(outputs, masks)

    val_loss /= len(val_loader)
    val_acc /= len(val_loader)
    val_iou /= len(val_loader)
    val_dice /= len(val_loader)

    history["val_loss"].append(val_loss)
    history["val_acc"].append(val_acc)
    history["val_iou"].append(val_iou)
    history["val_dice"].append(val_dice)

    # Print metrics and apply scheduler
    print(f"Epoch {epoch + 1}: Train Loss: {history['train_loss'][-1]}, Train Acc: {history['train_acc'][-1]}, "
          f"Train IoU: {history['train_iou'][-1]}, Train Dice: {history['train_dice'][-1]}, "
          f"Val Loss: {history['val_loss'][-1]}, Val Acc: {history['val_acc'][-1]}, "
          f"Val IoU: {history['val_iou'][-1]}, Val Dice: {history['val_dice'][-1]}")
    scheduler.step(val_loss)

    early_stopping(val_loss)  # Use validation loss or dice as the monitored metric
    if early_stopping.early_stop:
        print(f"Early stopping triggered at epoch {epoch + 1}")
        break


Downloading http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar to D:/DATASETS/segmentation/VOCtrainval_11-May-2012.tar


100%|██████████████████████████████████████| 2.00G/2.00G [11:23<00:00, 2.93MB/s]


Extracting D:/DATASETS/segmentation/VOCtrainval_11-May-2012.tar to D:/DATASETS/segmentation
Training samples: 2184, Validation samples: 729
torch.Size([10, 3, 128, 128]) torch.Size([10, 1, 128, 128])


Epoch 1: 100%|█| 219/219 [57:08<00:00, 15.65s/batch, acc=0.799, dice=0.0221, iou


Epoch 1: Train Loss: 1.231534601210459, Train Acc: 0.7993181383399715, Train IoU: 0.0123526780857726, Train Dice: 0.022139836210020716, Val Loss: 0.2604102039173858, Val Acc: 0.9473016147744165, Val IoU: 2.198041404900496e-09, Val Dice: 2.198041404900496e-09


Epoch 2: 100%|█| 219/219 [41:28<00:00, 11.36s/batch, acc=0.835, dice=0.0461, iou


Epoch 2: Train Loss: 0.8118846572289183, Train Acc: 0.8348645529790556, Train IoU: 0.026451331100296675, Train Dice: 0.046079900666448105, Val Loss: 0.22300300471586723, Val Acc: 0.9473016147744165, Val IoU: 2.198041404900496e-09, Val Dice: 2.198041404900496e-09


Epoch 3: 100%|█| 219/219 [10:09:27<00:00, 166.98s/batch, acc=0.86, dice=0.0652, 


Epoch 3: Train Loss: 0.7973634341399963, Train Acc: 0.8600234032765915, Train IoU: 0.03726623050743221, Train Dice: 0.0651566607902506, Val Loss: 0.2119305962569093, Val Acc: 0.9473016147744165, Val IoU: 2.198041404900496e-09, Val Dice: 2.198041404900496e-09


Epoch 4:  99%|▉| 217/219 [2:41:59<00:22, 11.36s/batch, acc=0.854, dice=0.069, io