In [1]:
# pip install einops


In [2]:
import torch
import torchvision
from einops import rearrange
from torch import nn
from torchvision.ops import StochasticDepth
from typing import List
from typing import Iterable
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets
from torchvision import transforms
import os
from PIL import Image
import numpy as np
import random
import torch.nn.functional as F

In [3]:
class ImageMaskDataset(Dataset):
    def __init__(self, image_dir, mask_dir, filenames, transform_image=None, transform_mask=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.image_filenames = filenames
        self.transform_image = transform_image
        self.transform_mask = transform_mask

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        img_name = self.image_filenames[idx]
        image_path = os.path.join(self.image_dir, img_name)
        mask_name = img_name  # Assuming filenames match
        mask_path = os.path.join(self.mask_dir, mask_name)

        # Load image and mask
        image = Image.open(image_path).convert("RGB")
        mask = Image.open(mask_path).convert("L")

        # Apply transformations
        if self.transform_image:
            image = self.transform_image(image)
        if self.transform_mask:
            mask = self.transform_mask(mask)

        return {'pixel_values': image, 'labels': mask}


In [4]:
transform_image = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # Normalize using ImageNet mean and std if desired
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

transform_mask = transforms.Compose([
    transforms.Resize((224, 224), interpolation=Image.NEAREST),
    transforms.PILToTensor(),  # Converts to tensor without normalization
    transforms.Lambda(lambda x: x.squeeze().long())  # Remove channel dimension and convert to long
])


In [5]:
# Use the paths to your image and mask directories
image_dir = r'C:/Users/yifen/Desktop/Desktop/Yi Fen Folder/MITB/Deep Learning for Visual Recognition/Project/Project Dataset/train_images'
mask_dir = r'C:/Users/yifen/Desktop/Desktop/Yi Fen Folder/MITB/Deep Learning for Visual Recognition/Project/Project Dataset/train_masks'
all_filenames = sorted(os.listdir(image_dir))

In [6]:
class LayerNorm2d(nn.LayerNorm):
    def forward(self, x):
        x = rearrange(x, "b c h w -> b h w c")
        x = super().forward(x)
        x = rearrange(x, "b h w c -> b c h w")
        return x


class OverlapPatchMerging(nn.Sequential):
    def __init__(
        self, in_channels: int, out_channels: int, patch_size: int, overlap_size: int
    ):
        super().__init__(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=patch_size,
                stride=overlap_size,
                padding=patch_size // 2,
                bias=False
            ),
            LayerNorm2d(out_channels)
        )

class EfficientMultiHeadAttention(nn.Module):
    def __init__(self, channels: int, reduction_ratio: int = 1, num_heads: int = 8):
        super().__init__()
        self.reducer = nn.Sequential(
            nn.Conv2d(
                channels, channels, kernel_size=reduction_ratio, stride=reduction_ratio
            ),
            LayerNorm2d(channels),
        )
        self.att = nn.MultiheadAttention(
            channels, num_heads=num_heads, batch_first=True
        )

    def forward(self, x):
        _, _, h, w = x.shape
        reduced_x = self.reducer(x)
        # attention needs tensor of shape (batch, sequence_length, channels)
        reduced_x = rearrange(reduced_x, "b c h w -> b (h w) c")
        x = rearrange(x, "b c h w -> b (h w) c")
        out = self.att(x, reduced_x, reduced_x)[0]
        # reshape it back to (batch, channels, height, width)
        out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w)
        return out
    
class MixMLP(nn.Sequential):
    def __init__(self, channels: int, expansion: int = 4):
        super().__init__(
            # dense layer
            nn.Conv2d(channels, channels, kernel_size=1),
            # depth wise conv
            nn.Conv2d(
                channels,
                channels * expansion,
                kernel_size=3,
                groups=channels,
                padding=1,
            ),
            nn.GELU(),
            # dense layer
            nn.Conv2d(channels * expansion, channels, kernel_size=1),
        )

class ResidualAdd(nn.Module):
    """Just an util layer"""
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        out = self.fn(x, **kwargs)
        x = x + out
        return x

class SegFormerEncoderBlock(nn.Sequential):
    def __init__(
        self,
        channels: int,
        reduction_ratio: int = 1,
        num_heads: int = 8,
        mlp_expansion: int = 4,
        drop_path_prob: float = .0
    ):
        super().__init__(
            ResidualAdd(
                nn.Sequential(
                    LayerNorm2d(channels),
                    EfficientMultiHeadAttention(channels, reduction_ratio, num_heads),
                )
            ),
            ResidualAdd(
                nn.Sequential(
                    LayerNorm2d(channels),
                    MixMLP(channels, expansion=mlp_expansion),
                    StochasticDepth(p=drop_path_prob, mode="batch")
                )
            ),
        )

class SegFormerEncoderStage(nn.Sequential):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        patch_size: int,
        overlap_size: int,
        drop_probs: List[int],
        depth: int = 2,
        reduction_ratio: int = 1,
        num_heads: int = 8,
        mlp_expansion: int = 4,
    ):
        super().__init__()
        self.overlap_patch_merge = OverlapPatchMerging(
            in_channels, out_channels, patch_size, overlap_size,
        )
        self.blocks = nn.Sequential(
            *[
                SegFormerEncoderBlock(
                    out_channels, reduction_ratio, num_heads, mlp_expansion, drop_probs[i]
                )
                for i in range(depth)
            ]
        )
        self.norm = LayerNorm2d(out_channels)


def chunks(data: Iterable, sizes: List[int]):
    """
    Given an iterable, returns slices using sizes as indices
    """
    curr = 0
    for size in sizes:
        chunk = data[curr: curr + size]
        curr += size
        yield chunk
        
class SegFormerEncoder(nn.Module):
    def __init__(
        self,
        in_channels: int,
        widths: List[int],
        depths: List[int],
        all_num_heads: List[int],
        patch_sizes: List[int],
        overlap_sizes: List[int],
        reduction_ratios: List[int],
        mlp_expansions: List[int],
        drop_prob: float = .0
    ):
        super().__init__()
        # create drop paths probabilities (one for each stage's block)
        drop_probs =  [x.item() for x in torch.linspace(0, drop_prob, sum(depths))]
        self.stages = nn.ModuleList(
            [
                SegFormerEncoderStage(*args)
                for args in zip(
                    [in_channels, *widths],
                    widths,
                    patch_sizes,
                    overlap_sizes,
                    chunks(drop_probs, sizes=depths),
                    depths,
                    reduction_ratios,
                    all_num_heads,
                    mlp_expansions
                )
            ]
        )
        
    def forward(self, x):
        features = []
        for stage in self.stages:
            x = stage(x)
            features.append(x)
        return features

class SegFormerDecoderBlock(nn.Sequential):
    def __init__(self, in_channels: int, out_channels: int, scale_factor: int = 2):
        super().__init__(
            nn.UpsamplingBilinear2d(scale_factor=scale_factor),
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
        )

class SegFormerDecoder(nn.Module):
    def __init__(self, out_channels: int, widths: List[int], scale_factors: List[int]):
        super().__init__()
        self.stages = nn.ModuleList(
            [
                SegFormerDecoderBlock(in_channels, out_channels, scale_factor)
                for in_channels, scale_factor in zip(widths, scale_factors)
            ]
        )
    
    def forward(self, features):
        new_features = []
        for feature, stage in zip(features,self.stages):
            x = stage(feature)
            new_features.append(x)
        return new_features

class SegFormerSegmentationHead(nn.Module):
    def __init__(self, channels: int, num_classes: int, num_features: int = 4):
        super().__init__()
        self.fuse = nn.Sequential(
            nn.Conv2d(channels * num_features, channels, kernel_size=1, bias=False),
            nn.ReLU(),  # Applies ReLU activation to introduce non-linearity
            nn.BatchNorm2d(channels)  # BatchNorm for better stability and convergence
        )
        self.predict = nn.Conv2d(channels, num_classes, kernel_size=1)
        # Add an upsampling layer to match the resolution of 224x224
        self.upsample = nn.Upsample(size=(224, 224), mode='bilinear', align_corners=False)

    def forward(self, features):
        x = torch.cat(features, dim=1)
        x = self.fuse(x)
        x = self.predict(x)
        x = self.upsample(x)  # Upsample to 224x224
        return x

class SegFormer(nn.Module):
    def __init__(
        self,
        in_channels: int,
        widths: List[int],
        depths: List[int],
        all_num_heads: List[int],
        patch_sizes: List[int],
        overlap_sizes: List[int],
        reduction_ratios: List[int],
        mlp_expansions: List[int],
        decoder_channels: int,
        scale_factors: List[int],
        num_classes: int,
        drop_prob: float = 0.0,
    ):

        super().__init__()
        self.encoder = SegFormerEncoder(
            in_channels,
            widths,
            depths,
            all_num_heads,
            patch_sizes,
            overlap_sizes,
            reduction_ratios,
            mlp_expansions,
            drop_prob,
        )
        self.decoder = SegFormerDecoder(decoder_channels, widths[::-1], scale_factors)
        self.head = SegFormerSegmentationHead(
            decoder_channels, num_classes, num_features=len(widths)
        )

    def forward(self, x):
        features = self.encoder(x)
        features = self.decoder(features[::-1])
        segmentation = self.head(features)
        return segmentation

segformer = SegFormer(
    in_channels=3,
    widths=[64, 128, 256, 512],
    depths=[3, 4, 6, 3],
    all_num_heads=[1, 2, 4, 8],
    patch_sizes=[7, 3, 3, 3],
    overlap_sizes=[4, 2, 2, 2],
    reduction_ratios=[8, 4, 2, 1],
    mlp_expansions=[4, 4, 4, 4],
    decoder_channels=256,
    scale_factors=[8, 4, 2, 1],
    num_classes=9,
)

# segmentation = segformer(torch.randn((1, 3, 224, 224)))
# segmentation.shape # torch.Size([1, 100, 56, 56])

In [7]:
import math

total_images = 1631

# Calculate the number of images for each set
train_size = int(0.90 * total_images)
val_size = int(0.10 * total_images)
# test_size = total_images - train_size - val_size  # Adjust for any rounding errors

print(f"Training images: {train_size}")
print(f"Validation images: {val_size}")
# print(f"Test images: {test_size}")

Training images: 1467
Validation images: 163


In [8]:
from sklearn.model_selection import train_test_split

# First split into training and temp (validation + test)
train_filenames, val_filenames = train_test_split(all_filenames,test_size=(val_size),random_state=42)

# Now split temp into validation and test
# val_filenames, test_filenames = train_test_split(temp_filenames,test_size=test_size,random_state=42)

print(f"Total images: {len(all_filenames)}")
print(f"Training images: {len(train_filenames)}")
print(f"Validation images: {len(val_filenames)}")
# print(f"Test images: {len(test_filenames)}")

Total images: 1631
Training images: 1468
Validation images: 163


In [9]:
# Training dataset and loader
train_dataset = ImageMaskDataset(
    image_dir=image_dir,
    mask_dir=mask_dir,
    filenames=train_filenames,
    transform_image=transform_image,
    transform_mask=transform_mask
)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=0)

# Validation dataset and loader
val_dataset = ImageMaskDataset(
    image_dir=image_dir,
    mask_dir=mask_dir,
    filenames=val_filenames,
    transform_image=transform_image,
    transform_mask=transform_mask
)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=0)

# Test dataset and loader
# test_dataset = ImageMaskDataset(
#     image_dir=image_dir,
#     mask_dir=mask_dir,
#     filenames=test_filenames,
#     transform_image=transform_image,
#     transform_mask=transform_mask
# )
# test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=0)

In [10]:
def compute_iou(preds, labels, num_classes):
    """
    Computes IoU for each class between predicted labels and ground truth labels.

    Args:
        preds (torch.Tensor): Predicted labels, shape (N, H, W)
        labels (torch.Tensor): Ground truth labels, shape (N, H, W)
        num_classes (int): Number of classes

    Returns:
        list: IoU for each class
    """
    ious = []
    preds = preds.view(-1)
    labels = labels.view(-1)
    for cls in range(num_classes):
        pred_inds = preds == cls
        target_inds = labels == cls
        intersection = torch.sum((pred_inds & target_inds).float()).item()
        union = torch.sum((pred_inds | target_inds).float()).item()
        if union == 0:
            iou = float('nan')  # If there is no ground truth, set IoU to NaN
        else:
            iou = intersection / union
        ious.append(iou)
    return ious

In [11]:
from torch.nn.modules.loss import _Loss
import typing
from typing import List
from typing import Optional
from torch import Tensor

def soft_dice_score(
    output: torch.Tensor, target: torch.Tensor, smooth: float = 0.0, eps: float = 1e-7, dims=None
) -> torch.Tensor:
    """

    :param output:
    :param target:
    :param smooth:
    :param eps:
    :return:

    Shape:
        - Input: :math:`(N, NC, *)` where :math:`*` means any number
            of additional dimensions
        - Target: :math:`(N, NC, *)`, same shape as the input
        - Output: scalar.

    """
    assert output.size() == target.size()
    if dims is not None:
        intersection = torch.sum(output * target, dim=dims)
        cardinality = torch.sum(output + target, dim=dims)
    else:
        intersection = torch.sum(output * target)
        cardinality = torch.sum(output + target)
    dice_score = (2.0 * intersection + smooth) / (cardinality + smooth).clamp_min(eps)
    return dice_score


__all__ = ["DiceLoss"]

BINARY_MODE = "binary"
MULTICLASS_MODE = "multiclass"
MULTILABEL_MODE = "multilabel"

class DiceLoss(_Loss):
    """
    Implementation of Dice loss for image segmentation task.
    It supports binary, multiclass and multilabel cases
    """

    def __init__(
        self,
        mode: str,
        classes: List[int] = None,
        log_loss=False,
        from_logits=True,
        smooth: float = 0.0,
        ignore_index=None,
        eps=1e-7,
    ):
        """

        :param mode: Metric mode {'binary', 'multiclass', 'multilabel'}
        :param classes: Optional list of classes that contribute in loss computation;
        By default, all channels are included.
        :param log_loss: If True, loss computed as `-log(jaccard)`; otherwise `1 - jaccard`
        :param from_logits: If True assumes input is raw logits
        :param smooth:
        :param ignore_index: Label that indicates ignored pixels (does not contribute to loss)
        :param eps: Small epsilon for numerical stability
        """
        assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}
        super(DiceLoss, self).__init__()
        self.mode = mode
        if classes is not None:
            assert mode != BINARY_MODE, "Masking classes is not supported with mode=binary"
            classes = to_tensor(classes, dtype=torch.long)

        self.classes = classes
        self.from_logits = from_logits
        self.smooth = smooth
        self.eps = eps
        self.ignore_index = ignore_index
        self.log_loss = log_loss

    def forward(self, y_pred: Tensor, y_true: Tensor) -> Tensor:
        """

        :param y_pred: NxCxHxW
        :param y_true: NxHxW
        :return: scalar
        """
        assert y_true.size(0) == y_pred.size(0)

        if self.from_logits:
            # Apply activations to get [0..1] class probabilities
            # Using Log-Exp as this gives more numerically stable result and does not cause vanishing gradient on
            # extreme values 0 and 1
            if self.mode == MULTICLASS_MODE:
                y_pred = y_pred.log_softmax(dim=1).exp()
            else:
                y_pred = F.logsigmoid(y_pred).exp()

        bs = y_true.size(0)
        num_classes = y_pred.size(1)
        dims = (0, 2)

        if self.mode == BINARY_MODE:
            y_true = y_true.view(bs, 1, -1)
            y_pred = y_pred.view(bs, 1, -1)

            if self.ignore_index is not None:
                mask = y_true != self.ignore_index
                y_pred = y_pred * mask
                y_true = y_true * mask

        if self.mode == MULTICLASS_MODE:
            y_true = y_true.view(bs, -1)
            y_pred = y_pred.view(bs, num_classes, -1)

            if self.ignore_index is not None:
                mask = y_true != self.ignore_index
                y_pred = y_pred * mask.unsqueeze(1)

                y_true = F.one_hot((y_true * mask).to(torch.long), num_classes)  # N,H*W -> N,H*W, C
                y_true = y_true.permute(0, 2, 1) * mask.unsqueeze(1)  # H, C, H*W
            else:
                y_true = F.one_hot(y_true, num_classes)  # N,H*W -> N,H*W, C
                y_true = y_true.permute(0, 2, 1)  # H, C, H*W

        if self.mode == MULTILABEL_MODE:
            y_true = y_true.view(bs, num_classes, -1)
            y_pred = y_pred.view(bs, num_classes, -1)

            if self.ignore_index is not None:
                mask = y_true != self.ignore_index
                y_pred = y_pred * mask
                y_true = y_true * mask

        scores = soft_dice_score(y_pred, y_true.type_as(y_pred), smooth=self.smooth, eps=self.eps, dims=dims)

        if self.log_loss:
            loss = -torch.log(scores.clamp_min(self.eps))
        else:
            loss = 1.0 - scores

        # Dice loss is undefined for non-empty classes
        # So we zero contribution of channel that does not have true pixels
        # NOTE: A better workaround would be to use loss term `mean(y_pred)`
        # for this case, however it will be a modified jaccard loss

        mask = y_true.sum(dims) > 0
        loss *= mask.to(loss.dtype)

        if self.classes is not None:
            loss = loss[self.classes]

        return loss.mean()

In [12]:
class weighted_loss(nn.Module):
    def __init__(self, reduction='mean', lamb=1.25):
        super().__init__()
        self.reduction = reduction
        self.lamb = lamb
        self.base_loss = torch.nn.CrossEntropyLoss(reduction=reduction)
        self.dice_loss = DiceLoss(mode='multiclass')
        
    def forward(self, logits, target):
        base_l = self.base_loss(logits, target)
        dice_l = self.dice_loss(logits, target)
        
        return base_l + self.lamb * dice_l

In [13]:
# Set up device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define the model
model = segformer
model = model.to(device)

# Define loss function (using CrossEntropyLoss for multi-class segmentation)
criterion = weighted_loss()

# Define optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Load checkpoint
checkpoint_path = 'Segformer_Scratch(V1)_checkpoint.pth'
start_epoch = 0  # Initialize start epoch

try:
    # Load the checkpoint
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1  # Start from the next epoch
    print(f"Resuming training from epoch {start_epoch}")
except FileNotFoundError:
    print("Checkpoint not found. Starting from scratch.")

num_epochs = 250  # Total epochs to train

# Continue training from the loaded checkpoint
for epoch in range(start_epoch, num_epochs):
    # Training phase
    model.train()
    train_loss = 0
    for batch in train_loader:
        images = batch['pixel_values'].to(device)
        masks = batch['labels'].to(device)

        outputs = model(images)
        logits = outputs  # Directly use the output as logits

        # Resize logits to match the mask size
        upsampled_logits = F.interpolate(
            logits, size=masks.shape[-2:], mode='bilinear', align_corners=False
        )

        # Compute loss
        loss = criterion(upsampled_logits, masks)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    avg_train_loss = train_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Training Loss: {avg_train_loss:.4f}")

    # Validation phase
    model.eval()
    val_loss = 0
    iou_list = []
    with torch.no_grad():
        for batch in val_loader:
            images = batch['pixel_values'].to(device)
            masks = batch['labels'].to(device)

            outputs = model(images)
            logits = outputs  # Directly use the output as logits

            # Resize logits to match mask size
            upsampled_logits = F.interpolate(
                logits, size=masks.shape[-2:], mode='bilinear', align_corners=False
            )

            loss = criterion(upsampled_logits, masks)
            val_loss += loss.item()

            # Compute predictions
            _, preds = torch.max(upsampled_logits, 1)
            ious = compute_iou(preds, masks, num_classes=9)
            iou_list.append(ious)

    avg_val_loss = val_loss / len(val_loader)
    iou_list = np.array(iou_list)
    mean_iou_per_class = np.nanmean(iou_list, axis=0)
    mIoU = np.nanmean(mean_iou_per_class)

    # Save model and optimizer states
    torch.save({
        'epoch': epoch,  # Save the current epoch number
        'model_state_dict': model.state_dict(),  # Save model parameters
        'optimizer_state_dict': optimizer.state_dict(),  # Save optimizer parameters
        'loss': loss,  # Optionally, save the loss value if needed
        'train_loss': train_loss,  # Save training losses
        'val_losses': val_loss,  # Save validation losses       
    }, 'Segformer_Scratch(V2)_checkpoint.pth')  # File name where the checkpoint will be saved

    print(f"Epoch [{epoch+1}/{num_epochs}], Validation Loss: {avg_val_loss:.4f}, Validation mIoU: {mIoU:.4f}")

print('Finished Training')

# After training is completed, save the model's weights
torch.save(model.state_dict(),'Segformer_Scratch(V2)_weights.pth')
print("Model weights saved successfully.")

  checkpoint = torch.load(checkpoint_path)


Resuming training from epoch 50
Epoch [51/250], Training Loss: 0.0954
Epoch [51/250], Validation Loss: 0.1741, Validation mIoU: 0.8246
Epoch [52/250], Training Loss: 0.0946
Epoch [52/250], Validation Loss: 0.1644, Validation mIoU: 0.8314
Epoch [53/250], Training Loss: 0.0922
Epoch [53/250], Validation Loss: 0.1726, Validation mIoU: 0.8222
Epoch [54/250], Training Loss: 0.0911
Epoch [54/250], Validation Loss: 0.1526, Validation mIoU: 0.8379
Epoch [55/250], Training Loss: 0.0945
Epoch [55/250], Validation Loss: 0.1835, Validation mIoU: 0.8172
Epoch [56/250], Training Loss: 0.0952
Epoch [56/250], Validation Loss: 0.1657, Validation mIoU: 0.8320
Epoch [57/250], Training Loss: 0.0884
Epoch [57/250], Validation Loss: 0.1588, Validation mIoU: 0.8379
Epoch [58/250], Training Loss: 0.0892
Epoch [58/250], Validation Loss: 0.1604, Validation mIoU: 0.8366
Epoch [59/250], Training Loss: 0.1059
Epoch [59/250], Validation Loss: 0.1783, Validation mIoU: 0.8213
Epoch [60/250], Training Loss: 0.0976
Epo

In [14]:
num_classes = 9  # Including background

class_names = ['background', 'spleen', 'right kidney', 'left kidney', 'gallbladder', 'pancreas', 'liver', 'stomach', 'aorta']

# Evaluation on test set
model.eval()
test_loss = 0
iou_list = []
with torch.no_grad():
    for images, masks in val_loader:
        images = batch['pixel_values'].to(device)
        masks = batch['labels'].to(device)

        outputs = model(images)
        logits = outputs

        # Resize logits to match mask size
        upsampled_logits = F.interpolate(logits, size=masks.shape[-2:], mode='bilinear', align_corners=False)
        loss = criterion(upsampled_logits, masks)
        test_loss += loss.item()

        # Compute predictions
        _, preds = torch.max(upsampled_logits, 1)
        ious = compute_iou(preds, masks, num_classes=9)
        iou_list.append(ious)

avg_test_loss = test_loss / len(val_loader)
iou_list = np.array(iou_list)
mean_iou_per_class = np.nanmean(iou_list, axis=0)
mIoU = np.nanmean(mean_iou_per_class)

print(f"Val Loss: {avg_test_loss:.4f}, Val mIoU: {mIoU:.4f}")

# Display IoU for each class
for idx, iou in enumerate(mean_iou_per_class):
    print(f"Class {idx} ({class_names[idx]}): IoU = {iou:.4f}")

Val Loss: 0.2893, Val mIoU: 0.7764
Class 0 (background): IoU = 0.9886
Class 1 (spleen): IoU = 0.9316
Class 2 (right kidney): IoU = 0.8702
Class 3 (left kidney): IoU = 0.9015
Class 4 (gallbladder): IoU = 0.0000
Class 5 (pancreas): IoU = 0.9178
Class 6 (liver): IoU = 0.8685
Class 7 (stomach): IoU = 0.8532
Class 8 (aorta): IoU = 0.6564
