In [18]:
import os

import cv2
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torchvision.models import VGG16_Weights, vgg16
from torchvision.utils import save_image

In [19]:
# Define the U-Net 3+ block
class UNet3PlusBlock(nn.Module):
    """Two stacked ``3x3 conv -> BatchNorm -> ReLU`` stages.

    Spatial size is preserved (``padding=1``). The first conv maps
    ``in_channels`` to ``out_channels``; the second keeps ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1)
        # Registration order (conv1, bn1, conv2, bn2, relu) fixes state_dict keys
        # and the sequence of random parameter initialisation.
        self.conv1 = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # A single shared activation; in-place is safe because it only ever
        # receives a freshly computed BatchNorm output.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        stages = ((self.conv1, self.bn1), (self.conv2, self.bn2))
        for conv, norm in stages:
            x = self.relu(norm(conv(x)))
        return x

In [20]:
# Define the U-Net 3+ architecture with VGG16 backbone
class UNet3PlusVGG16(nn.Module):
    """U-Net 3+ segmentation network with a VGG16 encoder.

    Encoder: the first four convolutional stages of torchvision's VGG16,
    producing 64/128/256/512 channels at strides 1/2/4/8. Stages 2-4 start with
    VGG's own MaxPool, so ``forward`` applies no extra pooling.

    Decoder (U-Net 3+ full-scale skips): every decoder node at level ``i``
    gathers
      * encoder features of shallower levels, max-pooled down to level ``i``,
      * the encoder feature of level ``i`` itself,
      * decoder features of deeper levels, bilinearly upsampled to level ``i``
        (the deepest "decoder" node is the last encoder output).
    Each source is projected to ``cat_channels`` by ``3x3 conv-BN-ReLU``. The
    concatenation is fused by a ``UNet3PlusBlock``. All resizing targets the
    exact size of the level's encoder feature, so inputs need not be divisible
    by 8.

    Args:
        num_classes: output channels (raw logits) of the final 1x1 conv.
        pretrained: load ImageNet weights into the VGG16 encoder.
        cat_channels: channels each full-scale skip source is projected to.

    Input ``(N, 1 or 3, H, W)``. Single-channel input is repeated to 3 channels
    because VGG16's first conv expects 3. Output ``(N, num_classes, H, W)``.
    """

    ENCODER_CHANNELS = (64, 128, 256, 512)

    def __init__(self, num_classes, pretrained=True, cat_channels=64):
        super(UNet3PlusVGG16, self).__init__()

        # `weights=` replaces the deprecated `pretrained=True` argument.
        weights = VGG16_Weights.IMAGENET1K_V1 if pretrained else None
        features = vgg16(weights=weights).features

        # Encoder: VGG16 `features` split right after each stage's last ReLU.
        self.encoder = nn.ModuleList([
            features[:4],    # conv1_x           -> 64 ch,  stride 1
            features[4:9],   # pool1 + conv2_x   -> 128 ch, stride 2
            features[9:16],  # pool2 + conv3_x   -> 256 ch, stride 4
            features[16:23]  # pool3 + conv4_x   -> 512 ch, stride 8
        ])

        num_levels = len(self.ENCODER_CHANNELS)
        fused_channels = cat_channels * num_levels

        # full_scale_skip[level][source] projects skip source `source` for the
        # decoder node at `level` (levels 0 .. num_levels - 2).
        self.full_scale_skip = nn.ModuleList()
        for level in range(num_levels - 1):
            projections = nn.ModuleList()
            for source in range(num_levels):
                if source <= level:
                    src_channels = self.ENCODER_CHANNELS[source]
                elif source == num_levels - 1:
                    # deepest node is the last encoder output itself
                    src_channels = self.ENCODER_CHANNELS[-1]
                else:
                    src_channels = fused_channels
                projections.append(
                    self._conv_bn_relu(src_channels, cat_channels))
            self.full_scale_skip.append(projections)

        # Decoder fusion blocks, indexed by level 0 .. num_levels - 2.
        self.decoder = nn.ModuleList(
            [UNet3PlusBlock(fused_channels, fused_channels)
             for _ in range(num_levels - 1)])

        # Final convolutional layer (applied at full input resolution)
        self.final_conv = nn.Conv2d(fused_channels, num_classes, kernel_size=1)

    @staticmethod
    def _conv_bn_relu(in_channels, out_channels):
        """3x3 conv -> BatchNorm -> ReLU used to project one skip source."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1,
                      bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        if x.shape[1] == 1:
            # grayscale -> 3 identical channels, keeping pretrained conv1 usable
            x = x.repeat(1, 3, 1, 1)

        encoder_outputs = []
        for block in self.encoder:
            x = block(x)
            encoder_outputs.append(x)

        num_levels = len(encoder_outputs)
        decoder_outputs = [None] * (num_levels - 1) + [encoder_outputs[-1]]

        # Build decoder nodes from deepest to shallowest.
        for level in range(num_levels - 2, -1, -1):
            target_size = encoder_outputs[level].shape[-2:]
            projected = []
            for source, project in enumerate(self.full_scale_skip[level]):
                if source < level:
                    feature = nn.functional.adaptive_max_pool2d(
                        encoder_outputs[source], target_size)
                elif source == level:
                    feature = encoder_outputs[level]
                else:
                    feature = nn.functional.interpolate(
                        decoder_outputs[source], size=target_size,
                        mode='bilinear', align_corners=False)
                projected.append(project(feature))
            decoder_outputs[level] = self.decoder[level](
                torch.cat(projected, dim=1))

        return self.final_conv(decoder_outputs[0])

In [21]:
class FHCDataset(Dataset):
    """Image/mask pairs for fetal head circumference segmentation.

    Row ``idx`` of ``csv_file`` names an image (first column) in ``img_dir``.
    Its mask is the file of the same name with ``.png`` replaced by
    ``_Mask.png``, in ``mask_dir``. Both are read as 8-bit grayscale.

    Args:
        csv_file: CSV whose first column holds the image file names.
        img_dir: directory containing the images.
        mask_dir: directory containing the masks.
        transform: optional callable applied to the image only. It receives an
            ``(H, W, 1)`` uint8 array.
        mask_transform: optional callable applied to the mask. It receives an
            ``(H, W, 1)`` uint8 array. When None, the mask becomes a float32
            tensor of shape ``(1, H, W)`` holding 1.0 where the pixel is
            non-zero and 0.0 elsewhere, a target suitable for
            ``nn.BCEWithLogitsLoss``.

    Raises:
        FileNotFoundError: from ``__getitem__`` if an image or mask can't be read.
    """

    def __init__(self, csv_file, img_dir, mask_dir, transform=None,
                 mask_transform=None):
        self.data = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.mask_transform = mask_transform

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _read_grayscale(path):
        """Read ``path`` as a 2-D uint8 array, failing loudly if unreadable."""
        image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            # cv2.imread reports missing/unreadable files by returning None,
            # which would otherwise surface later as a confusing shape error.
            raise FileNotFoundError(f"Could not read image file: {path}")
        return image

    def __getitem__(self, idx):
        file_name = self.data.iloc[idx, 0]
        img_path = os.path.join(self.img_dir, file_name)
        mask_path = os.path.join(
            self.mask_dir, file_name.replace('.png', '_Mask.png'))

        image = np.expand_dims(self._read_grayscale(img_path), axis=2)
        mask = self._read_grayscale(mask_path)

        if self.transform:
            image = self.transform(image)

        if self.mask_transform:
            mask = self.mask_transform(np.expand_dims(mask, axis=2))
        else:
            # Never reuse the image transform for labels: intensity
            # normalisation (e.g. Normalize((0.5,), (0.5,))) would turn
            # {0, 1} labels into {-1, 1}. Keep a binary float target instead.
            mask = torch.from_numpy((mask > 0).astype(np.float32)).unsqueeze(0)

        return image, mask

In [22]:
# Image preprocessing: (H, W, C) uint8 array -> (C, H, W) float tensor in
# [0, 1] (ToTensor), then (x - 0.5) / 0.5 so intensities span [-1, 1].
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=(0.5,), std=(0.5,))]
)

In [25]:
# Dataset locations per split: CSV index, image folder, mask folder
train_csv, train_img_dir, train_mask_dir = (
    'Dataset/training_set_pixel_size_and_HC.csv',
    'Dataset/training_set/images',
    'Dataset/training_set/masks',
)
val_csv, val_img_dir, val_mask_dir = (
    'Dataset/val_csv.csv',
    'Dataset/val_set/images',
    'Dataset/val_set/masks',
)

In [26]:
# One dataset per split, sharing the same image preprocessing
train_dataset = FHCDataset(csv_file=train_csv, img_dir=train_img_dir,
                           mask_dir=train_mask_dir, transform=transform)
val_dataset = FHCDataset(csv_file=val_csv, img_dir=val_img_dir,
                         mask_dir=val_mask_dir, transform=transform)

In [27]:
# Seed torch's RNG so loader shuffling and fresh layer initialisation are
# reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Create data loaders
batch_size = 8
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Initialize the model, loss function, and optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UNet3PlusVGG16(num_classes=1).to(device)
# The model outputs a single logit channel, so this is binary segmentation.
# CrossEntropyLoss over one class is identically zero (log-softmax of a lone
# logit is 0), which would give zero gradients. BCEWithLogitsLoss applies the
# sigmoid internally and expects float targets in [0, 1] shaped like the logits.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /home/suraj/.cache/torch/hub/checkpoints/vgg16-397923af.pth
40.8%


RuntimeError: invalid hash value (expected "397923af", got "50a066fff529ce943f2551543408a6f64c7cdbd85969259ef485baf296fd196e")