In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
import os
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms
import torchvision.models as models

In [4]:
# Defining the Segmentation Dataset class
class SegmentationDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.image_files = os.listdir(image_dir)
        self.mask_files = os.listdir(mask_dir)

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_name = os.path.join(self.image_dir, self.image_files[idx])
        mask_name = os.path.join(self.mask_dir, self.mask_files[idx])

        image = Image.open(img_name).convert("RGB")
        mask = Image.open(mask_name).convert("L")

        if self.transform:
            image = self.transform(image)
            mask = self.transform(mask)

        return {'image': image, 'mask': mask}

In [5]:
# Defining image and mask directories
image_dir = '/content/drive/MyDrive/images'
mask_dir = '/content/drive/MyDrive/masks'

In [6]:
# Defining augmented transformations
aug_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(30),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
    transforms.ToTensor()
])

In [7]:
# Creating augmented dataset
augmented_dataset = SegmentationDataset(image_dir, mask_dir, transform=aug_transform)

In [8]:
# Spliting the augmented dataset
train_size_aug = int(0.6 * len(augmented_dataset))
val_size_aug = int(0.2 * len(augmented_dataset))
test_size_aug = len(augmented_dataset) - train_size_aug - val_size_aug

train_dataset_aug, val_dataset_aug, test_dataset_aug = random_split(augmented_dataset,
                                                                    [train_size_aug, val_size_aug, test_size_aug])

# Defining batch size and creating data loaders for augmented dataset
batch_size_aug = 16
train_loader_aug = DataLoader(train_dataset_aug, batch_size=batch_size_aug, shuffle=True)
val_loader_aug = DataLoader(val_dataset_aug, batch_size=batch_size_aug, shuffle=False)
test_loader_aug = DataLoader(test_dataset_aug, batch_size=batch_size_aug, shuffle=False)


In [9]:
# Defining the UNet architecture for segmentation with ResNet-50 backbone
class UNetWithResNet(nn.Module):
    def __init__(self):
        super(UNetWithResNet, self).__init__()
        # Loading pre-trained ResNet-50 as encoder
        self.encoder = models.resnet50(pretrained=True)
        # Modifing the first layer to accept 3 channels instead of 1
        self.encoder.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # Defining decoder layers
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(2048, 1024, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(1024, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 1, kernel_size=1)
        )

    def forward(self, x):
        # Forward pass through ResNet-50 encoder
        x = self.encoder.conv1(x)
        x = self.encoder.bn1(x)
        x = self.encoder.relu(x)
        x = self.encoder.maxpool(x)
        x = self.encoder.layer1(x)
        x = self.encoder.layer2(x)
        x = self.encoder.layer3(x)
        x = self.encoder.layer4(x)
        # Forward pass through decoder
        x = self.decoder(x)
        return x

In [10]:
# Defining Binary Cross-Entropy Loss for segmentation
criterion_seg = nn.BCEWithLogitsLoss()

# Defining a new segmentation model instance with ResNet-50 backbone
segmentation_model_h3 = UNetWithResNet()

# Freezing the ResNet-50 layers
for param in segmentation_model_h3.encoder.parameters():
    param.requires_grad = False

# Defining new optimizer with adjusted learning rate
optimizer_seg_h3 = torch.optim.Adam(segmentation_model_h3.decoder.parameters(), lr=0.0001)

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 117MB/s]


In [11]:
# Initializing variables for early stopping
num_epochs_aug = 10
best_val_loss = float('inf')
patience = 3  # Number of epochs to wait if validation loss stops improving
counter = 0  # Counter to keep track of epochs without improvement

# Training the segmentation model with augmented dataset and early stopping
for epoch in range(num_epochs_aug):
    segmentation_model_h3.train()
    running_loss_aug = 0.0

    for batch_idx, batch_aug in enumerate(train_loader_aug):
        images_aug, masks_aug = batch_aug['image'], batch_aug['mask']

        optimizer_seg_h3.zero_grad()

        outputs_aug = segmentation_model_h3(images_aug)
        masks_aug_resized = nn.functional.interpolate(masks_aug, size=outputs_aug.shape[2:], mode='bilinear', align_corners=True)

        loss_aug = criterion_seg(outputs_aug, masks_aug_resized)
        loss_aug.backward()
        optimizer_seg_h3.step()

        running_loss_aug += loss_aug.item() * images_aug.size(0)

    epoch_loss_aug = running_loss_aug / len(train_dataset_aug)
    print(f'Augmented Segmentation Epoch [{epoch + 1}/{num_epochs_aug}], Training Loss: {epoch_loss_aug:.4f}')

    # Validation loss calculation
    segmentation_model_h3.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch_idx, batch_val in enumerate(val_loader_aug):
            images_val, masks_val = batch_val['image'], batch_val['mask']
            outputs_val = segmentation_model_h3(images_val)
            outputs_val_resized = nn.functional.interpolate(outputs_val, size=(256, 256), mode='bilinear', align_corners=False)
            loss_val = criterion_seg(outputs_val_resized, masks_val)
            val_loss += loss_val.item() * images_val.size(0)
    val_loss /= len(val_dataset_aug)
    print(f'Validation Loss: {val_loss:.4f}')

    # Checking for improvement in validation loss
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        counter = 0
    else:
        counter += 1

    # Checking if early stopping conditions are met
    if counter >= patience:
        print("Early stopping triggered! Validation loss has not improved for {} epochs.".format(patience))
        break


Augmented Segmentation Epoch [1/10], Training Loss: 0.3414
Validation Loss: 0.1358
Augmented Segmentation Epoch [2/10], Training Loss: 0.0851
Validation Loss: 0.0648
Augmented Segmentation Epoch [3/10], Training Loss: 0.0631
Validation Loss: 0.0552
Augmented Segmentation Epoch [4/10], Training Loss: 0.0560
Validation Loss: 0.0526
Augmented Segmentation Epoch [5/10], Training Loss: 0.0511
Validation Loss: 0.0486
Augmented Segmentation Epoch [6/10], Training Loss: 0.0475
Validation Loss: 0.0460
Augmented Segmentation Epoch [7/10], Training Loss: 0.0450
Validation Loss: 0.0447
Augmented Segmentation Epoch [8/10], Training Loss: 0.0436
Validation Loss: 0.0432
Augmented Segmentation Epoch [9/10], Training Loss: 0.0425
Validation Loss: 0.0422
Augmented Segmentation Epoch [10/10], Training Loss: 0.0420
Validation Loss: 0.0419


In [12]:
torch.save(segmentation_model_h3.state_dict(), '/content/drive/MyDrive/SEG_Hypothesis_3.pth')