In [11]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [12]:
import os
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms

In [13]:
# Defining the Segmentation Dataset class
class SegmentationDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.image_files = os.listdir(image_dir)
        self.mask_files = os.listdir(mask_dir)

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_name = os.path.join(self.image_dir, self.image_files[idx])
        mask_name = os.path.join(self.mask_dir, self.mask_files[idx])

        image = Image.open(img_name).convert("RGB")
        mask = Image.open(mask_name).convert("L")

        if self.transform:
            image = self.transform(image)
            mask = self.transform(mask)

        return {'image': image, 'mask': mask}

In [14]:
# Defining the augmented transformations
aug_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(30),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
    transforms.ToTensor()
])

In [15]:
# Defining image and mask directories
image_dir = '/content/drive/MyDrive/images'
mask_dir = '/content/drive/MyDrive/masks'

In [16]:
# Creating augmented dataset
augmented_dataset = SegmentationDataset(image_dir, mask_dir, transform=aug_transform)

In [17]:
# Spliting the augmented dataset
train_size_aug = int(0.6 * len(augmented_dataset))
val_size_aug = int(0.2 * len(augmented_dataset))
test_size_aug = len(augmented_dataset) - train_size_aug - val_size_aug

train_dataset_aug, val_dataset_aug, test_dataset_aug = random_split(augmented_dataset,
                                                                    [train_size_aug, val_size_aug, test_size_aug])

# Defining batch size and create data loaders for augmented dataset
batch_size_aug = 16
train_loader_aug = DataLoader(train_dataset_aug, batch_size=batch_size_aug, shuffle=True)
val_loader_aug = DataLoader(val_dataset_aug, batch_size=batch_size_aug, shuffle=False)
test_loader_aug = DataLoader(test_dataset_aug, batch_size=batch_size_aug, shuffle=False)

In [18]:
# Defining the UNet architecture for segmentation
class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        # Defining encoder layers
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        # Defining decoder layers
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2),  # Additional upsampling layer
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 1, kernel_size=1)
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

In [19]:
# Defining Binary Cross-Entropy Loss for segmentation
criterion_seg = nn.BCEWithLogitsLoss()

# Defining a new segmentation model instance
segmentation_model_h2 = UNet()

# Defining new optimizer with adjusted learning rate
optimizer_seg_h2 = torch.optim.Adam(segmentation_model_h2.parameters(), lr=0.0001)

In [22]:
# Training the segmentation model with augmented dataset and early stopping
num_epochs_aug = 10
early_stopping_patience = 3
best_val_loss = float('inf')
epochs_no_improve = 0

for epoch in range(num_epochs_aug):
    segmentation_model_h2.train()
    running_loss_aug = 0.0

    # Training loop
    for batch_idx, batch_aug in enumerate(train_loader_aug):
        images_aug, masks_aug = batch_aug['image'], batch_aug['mask']

        optimizer_seg_h2.zero_grad()

        outputs_aug = segmentation_model_h2(images_aug)
        masks_aug_resized = nn.functional.interpolate(masks_aug, size=outputs_aug.shape[2:], mode='bilinear', align_corners=True)

        loss_aug = criterion_seg(outputs_aug, masks_aug_resized)
        loss_aug.backward()
        optimizer_seg_h2.step()

        running_loss_aug += loss_aug.item() * images_aug.size(0)

    epoch_loss_aug = running_loss_aug / len(train_dataset_aug)
    print(f'Augmented Segmentation Epoch [{epoch + 1}/{num_epochs_aug}], Training Loss: {epoch_loss_aug:.4f}')

    # Validation loop
    segmentation_model_h2.eval()
    val_loss_aug = 0.0

    with torch.no_grad():
        for batch_idx, batch_val in enumerate(val_loader_aug):
            images_val, masks_val = batch_val['image'], batch_val['mask']
            outputs_val = segmentation_model_h2(images_val)
            masks_val_resized = nn.functional.interpolate(masks_val, size=outputs_val.shape[2:], mode='bilinear', align_corners=True)
            val_loss_aug += criterion_seg(outputs_val, masks_val_resized).item() * images_val.size(0)

    val_loss_aug /= len(val_dataset_aug)
    print(f'Augmented Segmentation Validation Loss: {val_loss_aug:.4f}')

    # Check for early stopping
    if val_loss_aug < best_val_loss:
        best_val_loss = val_loss_aug
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= early_stopping_patience:
            print(f'Early stopping after epoch {epoch + 1}')
            break


Augmented Segmentation Epoch [1/10], Training Loss: 0.1972
Augmented Segmentation Validation Loss: 0.0734
Augmented Segmentation Epoch [2/10], Training Loss: 0.0700
Augmented Segmentation Validation Loss: 0.0583
Augmented Segmentation Epoch [3/10], Training Loss: 0.0556
Augmented Segmentation Validation Loss: 0.0534
Augmented Segmentation Epoch [4/10], Training Loss: 0.0505
Augmented Segmentation Validation Loss: 0.0482
Augmented Segmentation Epoch [5/10], Training Loss: 0.0465
Augmented Segmentation Validation Loss: 0.0450
Augmented Segmentation Epoch [6/10], Training Loss: 0.0440
Augmented Segmentation Validation Loss: 0.0431
Augmented Segmentation Epoch [7/10], Training Loss: 0.0429
Augmented Segmentation Validation Loss: 0.0424
Augmented Segmentation Epoch [8/10], Training Loss: 0.0426
Augmented Segmentation Validation Loss: 0.0419
Augmented Segmentation Epoch [9/10], Training Loss: 0.0421
Augmented Segmentation Validation Loss: 0.0418
Augmented Segmentation Epoch [10/10], Training

In [24]:
torch.save(segmentation_model_h2.state_dict(), '/content/drive/MyDrive/SEG_Hypothesis_2.pth')