<a href="https://colab.research.google.com/github/sarim711/image_segmentation_ISIC2016/blob/main/Image_Segmentation_ISIC2016_(1).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Source: https://github.com/zekikus/MedSegBench

# Install medsegbench lib
!pip install medsegbench

Collecting medsegbench
  Downloading medsegbench-1.0.0-py3-none-any.whl.metadata (797 bytes)
Collecting segmentation-models-pytorch (from medsegbench)
  Downloading segmentation_models_pytorch-0.4.0-py3-none-any.whl.metadata (32 kB)
Collecting efficientnet-pytorch>=0.6.1 (from segmentation-models-pytorch->medsegbench)
  Downloading efficientnet_pytorch-0.7.1.tar.gz (21 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting pretrainedmodels>=0.7.1 (from segmentation-models-pytorch->medsegbench)
  Downloading pretrainedmodels-0.7.4.tar.gz (58 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m58.8/58.8 kB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch->medsegbench)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch->medsegbench)
  Downloading nvi

In [None]:
# Import the library and the ISIC 2016 (melanoma segmentation) dataset utilities
import medsegbench
from medsegbench import Isic2016MSBench
from medsegbench import INFO

In [None]:
# Look up the "isic2016" entry of medsegbench's INFO mapping and unpack the
# fields used by this notebook
info = INFO["isic2016"]
n_channels, n_classes, n_samples = (
    info['n_channels_im'],
    len(info['pixel_labels']),
    info['n_samples'],
)

# Resolve the dataset class from the class name stored in the entry
DataClass = getattr(medsegbench, info['python_class'])

In [None]:
import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms

In [None]:
# Settings used when the datasets and data loaders are built
img_size = 128     # forwarded as `size` to the dataset class
download = True    # forwarded as `download` to the dataset class
BATCH_SIZE = 32    # training batch size; the evaluation loaders use 2 * BATCH_SIZE


In [None]:
class AugmentedDataClass(DataClass):
    """Dataset wrapper that applies random augmentation to the 'train' split.

    Samples are fetched from the parent class as ``(image, mask)`` pairs. For
    the 'train' split the image additionally receives colour jitter, and the
    image and mask receive the *same* random rotation and flips. Other splits
    are returned unchanged.

    Each geometric parameter is drawn once per sample and applied to both
    tensors. This keeps image and mask aligned regardless of how many random
    numbers the image-only colour jitter consumes, and it does not reseed the
    global torch RNG.

    Args:
        split: Split name forwarded to the parent class; augmentation is only
            active when it equals ``'train'``.
        size: Forwarded to the parent class.
        transform: Forwarded to the parent class (applied to the image).
        target_transform: Forwarded to the parent class (applied to the mask).
        download: Forwarded to the parent class.
    """

    def __init__(self, split, size=128, transform=None, target_transform=None, download=False):
        super().__init__(split=split, size=size, transform=transform, target_transform=target_transform, download=download)
        self.split = split

        # Photometric augmentation: image only, because changing intensities
        # of a label mask would corrupt the labels.
        self.image_augmentations = transforms.Compose([
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        ])

        # Geometric augmentation parameters, shared by image and mask.
        self.rotation_degrees = 30.0   # angle is drawn uniformly from [-30, 30]
        self.flip_probability = 0.5    # applied independently to each flip

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for ``idx``, augmented when split is 'train'."""
        image, mask = super().__getitem__(idx)

        if self.split == 'train':
            tf = transforms.functional

            # Colour jitter first, image only
            image = self.image_augmentations(image)

            # One rotation angle for both tensors. The default interpolation
            # of `rotate` is nearest-neighbour, so mask values are not blended.
            angle = transforms.RandomRotation.get_params(
                [-self.rotation_degrees, self.rotation_degrees]
            )
            image = tf.rotate(image, angle)
            mask = tf.rotate(mask, angle)

            # One coin toss per flip direction, shared by both tensors
            if torch.rand(1).item() < self.flip_probability:
                image, mask = tf.hflip(image), tf.hflip(mask)
            if torch.rand(1).item() < self.flip_probability:
                image, mask = tf.vflip(image), tf.vflip(mask)

        return image, mask


# Images and masks are both converted to tensors by the same transform
data_transform = transforms.Compose([transforms.ToTensor()])

# One dataset per split, all built with identical settings
train_dataset, validation_dataset, test_dataset = (
    AugmentedDataClass(
        split=split_name,
        size=img_size,
        transform=data_transform,
        target_transform=data_transform,
        download=download,
    )
    for split_name in ('train', 'val', 'test')
)

# Only the training loader shuffles; the evaluation loaders use a doubled batch size
train_loader = data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = data.DataLoader(validation_dataset, batch_size=2 * BATCH_SIZE, shuffle=False)
test_loader = data.DataLoader(test_dataset, batch_size=2 * BATCH_SIZE, shuffle=False)


100%|██████████| 47.9M/47.9M [00:36<00:00, 1.31MB/s]


In [None]:
# Training hyperparameters
lr = 0.001        # learning rate handed to the optimizer
NUM_EPOCHS = 10   # number of passes over the training loader

In [None]:
# Unet Model
class UNet(nn.Module):
    """U-Net with four encoder stages, a bottleneck and four decoder stages.

    Every stage is a pair of 3x3 convolutions, each followed by batch
    normalisation and ReLU. Encoder stages are separated by 2x2 max pooling;
    decoder stages upsample with a 2x2 transposed convolution and concatenate
    the matching encoder output along the channel axis. A 1x1 convolution
    maps the last 64 feature channels to ``out_channels``.

    Because skip connections are concatenated without resizing, the spatial
    size of the input must survive four halvings and doublings unchanged
    (i.e. be divisible by 16).

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the returned tensor.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        def double_conv(first_in, width):
            # (Conv3x3 -> BatchNorm -> ReLU) twice; the second conv keeps the width
            layers = []
            for conv_in in (first_in, width):
                layers.append(nn.Conv2d(conv_in, width, kernel_size=3, padding=1))
                layers.append(nn.BatchNorm2d(width))
                layers.append(nn.ReLU(inplace=True))
            return nn.Sequential(*layers)

        # Contracting path
        self.encoder1 = double_conv(in_channels, 64)
        self.encoder2 = double_conv(64, 128)
        self.encoder3 = double_conv(128, 256)
        self.encoder4 = double_conv(256, 512)

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.bottleneck = double_conv(512, 1024)

        # Expanding path: each decoder sees upsampled + skip channels
        self.upconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.decoder4 = double_conv(1024, 512)

        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.decoder3 = double_conv(512, 256)

        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.decoder2 = double_conv(256, 128)

        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder1 = double_conv(128, 64)

        self.out_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        """Return the raw (un-activated) output map for the batch ``x``."""
        # Contracting path; every stage output is kept as a skip connection
        skip1 = self.encoder1(x)
        skip2 = self.encoder2(self.pool(skip1))
        skip3 = self.encoder3(self.pool(skip2))
        skip4 = self.encoder4(self.pool(skip3))

        features = self.bottleneck(self.pool(skip4))

        # Expanding path, deepest stage first
        decoder_path = (
            (self.upconv4, self.decoder4, skip4),
            (self.upconv3, self.decoder3, skip3),
            (self.upconv2, self.decoder2, skip2),
            (self.upconv1, self.decoder1, skip1),
        )
        for upsample, refine, skip in decoder_path:
            features = refine(torch.cat((upsample(features), skip), dim=1))

        return self.out_conv(features)

# Smoke test: run one random 3-channel 256x256 sample through a fresh U-Net
# and report the shape of what comes out
model = UNet(3, 1)
x = torch.randn((1, 3, 256, 256))
output = model(x)
print(f"Model output shape: {output.shape}")


Model output shape: torch.Size([1, 1, 256, 256])


In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


In [None]:

# Unweighted binary cross-entropy on logits.
# NOTE(review): `criterion` is reassigned to a pos_weight-weighted loss in a later
# cell, before the first training call, so this instance is never used.
criterion = nn.BCEWithLogitsLoss()

In [None]:
# Adam over the parameters of the plain U-Net (`model`) with learning rate `lr`
optimizer = optim.Adam(model.parameters(), lr=lr)

In [None]:
def calculate_pixel_ratios(dataset):
    """Return the fractions of foreground and background pixels over all masks.

    Args:
        dataset: Iterable of ``(image, mask)`` pairs. Only the mask is read; it
            must support elementwise comparison and ``.sum().item()``.

    Returns:
        Tuple ``(foreground_ratio, background_ratio)``. Foreground pixels are
        those with ``mask > 0``, background pixels those with ``mask == 0``;
        both ratios are relative to the sum of the two counts. ``(0.0, 0.0)``
        is returned when no pixel was counted (for example an empty dataset),
        so the function never divides by zero.
    """
    foreground_pixels = 0
    background_pixels = 0

    for _, mask in dataset:
        # Counting does not depend on the mask's shape, so no reshaping is needed
        foreground_pixels += (mask > 0).sum().item()
        background_pixels += (mask == 0).sum().item()

    total_pixels = foreground_pixels + background_pixels
    if total_pixels == 0:
        return 0.0, 0.0

    return foreground_pixels / total_pixels, background_pixels / total_pixels

# Example usage
# NOTE: `train_dataset` applies random augmentation on every access (split 'train'),
# so these ratios are measured on augmented masks and can vary between runs.
foreground_ratio, background_ratio = calculate_pixel_ratios(train_dataset)
print(f"Foreground Ratio: {foreground_ratio:.4f}, Background Ratio: {background_ratio:.4f}")


Foreground Ratio: 0.2704, Background Ratio: 0.7296


In [None]:

# Weight the positive (foreground) class by the background/foreground pixel ratio
pos_weight_value = background_ratio / foreground_ratio
pos_weight = torch.tensor([pos_weight_value], device=device)

# Class-weighted binary cross-entropy on logits; this is the loss used for training
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)


In [None]:
def dice_coefficient(preds, targets, threshold=0.5):
    """Dice score between thresholded predictions and targets, as a Python float.

    Args:
        preds: Tensor of logits; sigmoid is applied, then values strictly above
            ``threshold`` become 1 and the rest 0.
        targets: Tensor of target values, multiplied elementwise with the
            binarised predictions.
        threshold: Probability cut-off applied after the sigmoid.

    Returns:
        ``2 * sum(pred * target) / (sum(pred) + sum(target) + 1e-8)`` computed
        over all elements of the tensors at once.
    """
    binary_preds = torch.sigmoid(preds).gt(threshold).float()
    overlap = torch.sum(binary_preds * targets)
    total_mass = binary_preds.sum() + targets.sum()
    # The small constant keeps the division defined when both tensors sum to zero
    score = (2.0 * overlap) / (total_mass + 1e-8)
    return score.item()

In [None]:
def _mean_or_nan(values):
    """Average a list of numbers; NaN for an empty list so reporting never divides by zero."""
    return sum(values) / len(values) if values else float("nan")


def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader`` and return ``(mean batch loss, mean batch Dice)``.

    When ``optimizer`` is given the model is put in training mode and updated
    after every batch; otherwise the model is put in evaluation mode and
    gradients are disabled.
    """
    training = optimizer is not None
    model.train(training)

    batch_losses = []
    batch_dice_scores = []

    with torch.set_grad_enabled(training):
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)

            if training:
                optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, targets)

            if training:
                loss.backward()
                optimizer.step()

            batch_losses.append(loss.item())
            # The metric is for reporting only, so it is computed on detached outputs
            batch_dice_scores.append(dice_coefficient(outputs.detach(), targets))

    return _mean_or_nan(batch_losses), _mean_or_nan(batch_dice_scores)


def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, device):
    """Train ``model`` and print training/validation loss and Dice after every epoch.

    Args:
        model: Module to train; it is moved to ``device`` in place.
        train_loader: Iterable of ``(inputs, targets)`` batches used for optimisation.
        val_loader: Iterable of ``(inputs, targets)`` batches used for validation only.
        criterion: Loss called as ``criterion(outputs, targets)``.
        optimizer: Optimizer stepping the model's parameters.
        num_epochs: Number of passes over ``train_loader``.
        device: Device the model and every batch are moved to.

    The reported values are unweighted means over batches. An empty loader is
    reported as ``nan`` instead of raising ZeroDivisionError.
    """
    model.to(device)

    for epoch in range(num_epochs):
        avg_train_loss, avg_train_dice = _run_epoch(model, train_loader, criterion, device, optimizer)
        avg_val_loss, avg_val_dice = _run_epoch(model, val_loader, criterion, device)

        # Print metrics for the epoch
        print(f"Epoch {epoch+1}/{num_epochs}")
        print(f"  Train Loss: {avg_train_loss:.4f}, Train Dice: {avg_train_dice:.4f}")
        print(f"  Val Loss: {avg_val_loss:.4f}, Val Dice: {avg_val_dice:.4f}")

# Train the plain U-Net; train_model prints the metrics of every epoch
train_model(model=model, train_loader=train_loader, val_loader=val_loader, criterion=criterion, optimizer=optimizer, num_epochs=NUM_EPOCHS, device=device)

Epoch 1/10
  Train Loss: 0.7402, Train Dice: 0.6312
  Val Loss: 4.9304, Val Dice: 0.4849
Epoch 2/10
  Train Loss: 0.6357, Train Dice: 0.6862
  Val Loss: 0.8690, Val Dice: 0.5243
Epoch 3/10
  Train Loss: 0.6182, Train Dice: 0.6951
  Val Loss: 0.6470, Val Dice: 0.6379
Epoch 4/10
  Train Loss: 0.5919, Train Dice: 0.7104
  Val Loss: 0.4823, Val Dice: 0.7419
Epoch 5/10
  Train Loss: 0.5795, Train Dice: 0.7193
  Val Loss: 0.5069, Val Dice: 0.7308
Epoch 6/10
  Train Loss: 0.5475, Train Dice: 0.7331
  Val Loss: 0.7931, Val Dice: 0.5288
Epoch 7/10
  Train Loss: 0.5502, Train Dice: 0.7281
  Val Loss: 0.6690, Val Dice: 0.6354
Epoch 8/10
  Train Loss: 0.5457, Train Dice: 0.7274
  Val Loss: 0.4763, Val Dice: 0.7218
Epoch 9/10
  Train Loss: 0.5107, Train Dice: 0.7420
  Val Loss: 0.4753, Val Dice: 0.7919
Epoch 10/10
  Train Loss: 0.5031, Train Dice: 0.7527
  Val Loss: 0.4807, Val Dice: 0.7269


In [None]:
# Test-set Dice of the plain U-Net.
# NOTE(review): `evaluate_on_test` is defined in a later cell of this notebook, so on
# a fresh kernel (Restart & Run All) this cell raises NameError. Move the definition
# above this cell.
evaluate_on_test(model, test_loader, device)

Average Dice Coefficient on Test Dataset: 0.7502


In [None]:
from torchvision.models import resnet34, ResNet34_Weights
import torch.nn.functional as F # import functional interface
class UNetWithResNetEncoder(nn.Module):
    """U-Net style decoder on top of a pre-trained ResNet-34 encoder.

    The encoder reuses the ResNet stem (conv1/bn1/relu) and its four residual
    stages; stride-2 convolutions inside the last stage are switched to
    stride 1. In ``forward`` a 2x2 max pooling is applied before every stage
    after the stem. The decoder mirrors the plain U-Net layout (transposed
    convolution, concatenation with a skip tensor, double convolution), but
    every skip tensor is bilinearly resized to the decoder resolution before
    concatenation, and the decoder output is bilinearly resized to the input
    resolution before the final 1x1 convolution. The stem output is not used
    as a skip connection.

    Args:
        out_channels: Number of channels of the returned tensor.
    """

    def __init__(self, out_channels=1):
        super().__init__()

        backbone = resnet34(weights=ResNet34_Weights.DEFAULT)

        # Encoder stages taken from the pre-trained backbone
        self.encoder1 = nn.Sequential(backbone.conv1, backbone.bn1, backbone.relu)
        self.encoder2 = backbone.layer1
        self.encoder3 = backbone.layer2
        self.encoder4 = backbone.layer3
        self.encoder5 = backbone.layer4

        # Switch the stride-2 convolutions of the last stage to stride 1 so it
        # does not downsample any further
        for module in self.encoder5.modules():
            if isinstance(module, nn.Conv2d) and module.stride == (2, 2):
                module.stride = (1, 1)

        self.bottleneck = self.conv_block(512, 1024)

        # Decoder levels 4..1: upconvN halves the channel count, decoderN
        # consumes upsampled + skip channels
        for level, width in zip((4, 3, 2, 1), (512, 256, 128, 64)):
            setattr(self, f"upconv{level}", nn.ConvTranspose2d(2 * width, width, kernel_size=2, stride=2))
            setattr(self, f"decoder{level}", self.conv_block(2 * width, width))

        self.out_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        """Build (Conv3x3 -> BatchNorm -> ReLU) twice as one ``nn.Sequential``."""
        layers = []
        for conv_in in (in_channels, out_channels):
            layers.extend([
                nn.Conv2d(conv_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the raw (un-activated) output map at the spatial size of ``x``."""
        # Encoder: 2x2 max pooling precedes every stage after the stem
        stage1 = self.encoder1(x)
        stage2 = self.encoder2(F.max_pool2d(stage1, 2))
        stage3 = self.encoder3(F.max_pool2d(stage2, 2))
        stage4 = self.encoder4(F.max_pool2d(stage3, 2))
        stage5 = self.encoder5(F.max_pool2d(stage4, 2))

        features = self.bottleneck(stage5)

        # Decoder, deepest level first; skips are resized to the decoder resolution
        decoder_path = (
            (self.upconv4, self.decoder4, stage5),
            (self.upconv3, self.decoder3, stage4),
            (self.upconv2, self.decoder2, stage3),
            (self.upconv1, self.decoder1, stage2),
        )
        for upsample, refine, skip in decoder_path:
            features = upsample(features)
            resized_skip = F.interpolate(skip, size=features.shape[2:], mode="bilinear", align_corners=False)
            features = refine(torch.cat((features, resized_skip), dim=1))

        # Final upsampling to match input size
        full_resolution = F.interpolate(features, size=x.shape[2:], mode="bilinear", align_corners=False)

        return self.out_conv(full_resolution)


# Smoke test: run one random 3-channel 256x256 sample through a fresh
# ResNet-encoder U-Net and report the shape of what comes out
model_resnet_unet = UNetWithResNetEncoder(out_channels=1)
x = torch.randn((1, 3, 256, 256))
output = model_resnet_unet(x)
print(f"Output shape: {output.shape}")


Downloading: "https://download.pytorch.org/models/resnet34-b627a593.pth" to /root/.cache/torch/hub/checkpoints/resnet34-b627a593.pth
100%|██████████| 83.3M/83.3M [00:00<00:00, 166MB/s]


Output shape: torch.Size([1, 1, 256, 256])


In [None]:
# Hyperparameters for the ResNet-encoder run
NUM_EPOCHS, lr = 10, 0.001

# Fresh model on the target device, class-weighted loss, and a new Adam
# optimizer bound to this model's parameters
model_resnet_unet = UNetWithResNetEncoder(out_channels=1)
model_resnet_unet = model_resnet_unet.to(device)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
optimizer = optim.Adam(model_resnet_unet.parameters(), lr=lr)

# Train with per-epoch training and validation metrics
train_model(
    model=model_resnet_unet,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=NUM_EPOCHS,
    device=device,
)

Epoch 1/10
  Train Loss: 0.5677, Train Dice: 0.7259
  Val Loss: 2.1767, Val Dice: 0.5009
Epoch 2/10
  Train Loss: 0.4407, Train Dice: 0.7863
  Val Loss: 0.4156, Val Dice: 0.7589
Epoch 3/10
  Train Loss: 0.4271, Train Dice: 0.7733
  Val Loss: 0.5171, Val Dice: 0.6980
Epoch 4/10
  Train Loss: 0.4015, Train Dice: 0.7917
  Val Loss: 0.3535, Val Dice: 0.7859
Epoch 5/10
  Train Loss: 0.4088, Train Dice: 0.7925
  Val Loss: 0.3706, Val Dice: 0.7809
Epoch 6/10
  Train Loss: 0.3959, Train Dice: 0.7961
  Val Loss: 0.4648, Val Dice: 0.7046
Epoch 7/10
  Train Loss: 0.3812, Train Dice: 0.7992
  Val Loss: 0.3775, Val Dice: 0.7584
Epoch 8/10
  Train Loss: 0.3732, Train Dice: 0.8043
  Val Loss: 0.4070, Val Dice: 0.7783
Epoch 9/10
  Train Loss: 0.3639, Train Dice: 0.8052
  Val Loss: 0.4476, Val Dice: 0.7096
Epoch 10/10
  Train Loss: 0.3557, Train Dice: 0.8085
  Val Loss: 0.2857, Val Dice: 0.8392


In [None]:
def predict(model, image, device, threshold=0.5):
    """Return the binary mask predicted by ``model`` for ``image``.

    Args:
        model: Module returning logits; it is switched to evaluation mode.
        image: Input tensor. A 3-D tensor is treated as a single unbatched
            image and gets a leading batch dimension before the forward pass;
            tensors of any other rank are passed to the model unchanged.
        device: Device the image is moved to (the model is not moved).
        threshold: Probabilities strictly above this value become 1.0.

    Returns:
        Float tensor of zeros and ones with leading size-1 batch and channel
        dimensions removed (dimensions larger than 1 are kept).
    """
    model.eval()  # Set the model to evaluation mode
    with torch.no_grad():
        image = image.to(device)
        if image.dim() == 3:
            # Single image without batch dimension: add it for the model
            image = image.unsqueeze(0)
        probabilities = torch.sigmoid(model(image))
        mask = (probabilities > threshold).float()
    return mask.squeeze(0).squeeze(0)  # Remove batch and channel dimensions

def evaluate_on_test(model, test_loader, device):
    """Print the mean per-batch Dice coefficient of ``model`` over ``test_loader``.

    Args:
        model: Module returning logits; moved to ``device`` and switched to
            evaluation mode.
        test_loader: Iterable of ``(images, masks)`` batches.
        device: Device the model and every batch are moved to.

    The raw logits are handed to ``dice_coefficient``, which applies the
    sigmoid and the 0.5 threshold itself; binarising here as well would apply
    both steps twice. An empty loader is reported as ``nan`` instead of
    raising ZeroDivisionError.
    """
    model.to(device)
    model.eval()
    dice_scores = []

    with torch.no_grad():
        for images, masks in test_loader:
            images, masks = images.to(device), masks.to(device)  # Move data to the device

            logits = model(images)
            dice_scores.append(dice_coefficient(logits, masks))

    avg_dice = sum(dice_scores) / len(dice_scores) if dice_scores else float("nan")
    print(f"Average Dice Coefficient on Test Dataset: {avg_dice:.4f}")

# Test-set Dice of the trained ResNet-encoder U-Net
evaluate_on_test(model_resnet_unet, test_loader, device)

Average Dice Coefficient on Test Dataset: 0.8499
