# Setup

In [None]:
import torch
import torch.nn as nn
from torchvision import models
from torchvision.transforms import v2
from torchvision.models.resnet import ResNet18_Weights
from torchvision.transforms.functional import InterpolationMode
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import glob

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Dataloader

In [None]:
class DMCDataset(Dataset):
    '''Paired (image, mask) dataset for image binarization.

    Args:
        image_paths: paths to the input images, one per sample.
        mask_paths: paths to the target masks; mask_paths[i] pairs with
            image_paths[i].
        transform: optional callable applied to each input image.
        mask_transform: optional callable applied to each mask. It is applied
            independently of ``transform``.

    Raises:
        ValueError: if image_paths and mask_paths differ in length.
    '''

    def __init__(self, image_paths, mask_paths, transform=None, mask_transform=None):
        if len(image_paths) != len(mask_paths):
            raise ValueError(
                f'Got {len(image_paths)} images but {len(mask_paths)} masks; '
                'every image needs exactly one mask'
            )
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform
        self.mask_transform = mask_transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        '''Load sample ``idx`` and return the (image, mask) pair after transforms.'''
        # Image.open is lazy and keeps the file open until the pixels are read.
        # Decode eagerly inside `with` so each handle is closed immediately.
        # convert('RGB') guarantees 3 channels (RGBA/palette/greyscale PNGs),
        # which the 3-channel Normalize and the ResNet stem require.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert('RGB')
        with Image.open(self.mask_paths[idx]) as msk:
            mask = msk.copy()  # copy() loads the pixels; mode is left unchanged

        # Each transform is applied on its own, so a missing mask_transform no
        # longer crashes when only `transform` is given.
        if self.transform is not None:
            image = self.transform(image)
        if self.mask_transform is not None:
            mask = self.mask_transform(mask)

        return image, mask

In [None]:
# Input images: resize to the 512x512 model input (a multiple of 32, so the
# decoder output matches it), scale to float in [0, 1], then normalize with the
# ImageNet statistics the pretrained ResNet-18 encoder was trained with.
transform = v2.Compose([
    v2.Resize((512, 512), interpolation=InterpolationMode.BILINEAR),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Masks: resize with NEAREST, not BILINEAR. Bilinear blends 0/255 edges into
# intermediate grey levels, which become soft BCE targets and make exact-match
# metrics such as pixel accuracy undercount. Nearest keeps a binary mask binary.
mask_transform = v2.Compose([
    v2.Resize((512, 512), interpolation=InterpolationMode.NEAREST),
    v2.Grayscale(num_output_channels=1),  # single-channel target
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),  # 0/255 -> 0.0/1.0
])

In [None]:
# Get paths to all images and masks. glob() returns files in arbitrary
# (filesystem-dependent) order, so sort both lists to keep image i paired with
# mask i.
# NOTE(review): assumes matching noisy/ground_truth files sort identically
# (e.g. share file names) — confirm against the data generator.
all_image_paths = sorted(glob.glob('../data/synth_data/noisy/*.png'))
all_mask_paths = sorted(glob.glob('../data/synth_data/ground_truth/*.png'))
if len(all_image_paths) != len(all_mask_paths):
    raise ValueError(
        f'Found {len(all_image_paths)} images but {len(all_mask_paths)} masks'
    )

num_train = 500
num_test = 500

train_image_paths = all_image_paths[:num_train]
train_mask_paths = all_mask_paths[:num_train]

# Take the last num_test files, but never reach back into the training split
# (with fewer than num_train + num_test files, [-500:] would overlap [:500]).
test_start = max(num_train, len(all_image_paths) - num_test)
test_image_paths = all_image_paths[test_start:]
test_mask_paths = all_mask_paths[test_start:]

# Create datasets and dataloaders. Shuffle training batches so each epoch sees
# the samples in a new order; keep test order fixed for reproducible metrics.
train_dataset = DMCDataset(train_image_paths, train_mask_paths, transform=transform, mask_transform=mask_transform)
train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)
test_dataset = DMCDataset(test_image_paths, test_mask_paths, transform=transform, mask_transform=mask_transform)
test_dataloader = DataLoader(test_dataset, batch_size=4, shuffle=False)

In [None]:
# Train dataset example: one (image, mask) pair straight from the Dataset
image, mask = next(iter(train_dataset))
print(image.shape, mask.shape)

# Train DataLoader example: one collated batch
images, masks = next(iter(train_dataloader))
print(images.shape, masks.shape)

# `transform` normalizes with ImageNet mean/std; undo that before ToPILImage,
# which would otherwise garble the values that fall outside [0, 1].
imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
imagenet_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
image = v2.ToPILImage()((images[0] * imagenet_std + imagenet_mean).clamp(0, 1))
# image.show()

mask = v2.ToPILImage()(masks[0].squeeze(0))
# mask.show()

In [None]:
# Test dataset example: one (image, mask) pair straight from the Dataset
image, mask = next(iter(test_dataset))
print(image.shape, mask.shape)

# Test DataLoader example: one collated batch, previewed as PIL images
images, masks = next(iter(test_dataloader))
print(images.shape, masks.shape)

# `transform` normalizes with ImageNet mean/std; undo that before ToPILImage,
# which would otherwise garble the values that fall outside [0, 1].
imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
imagenet_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
image = v2.ToPILImage()((images[0] * imagenet_std + imagenet_mean).clamp(0, 1))
image.show()

mask = v2.ToPILImage()(masks[0].squeeze(0))
mask.show()

# Modifying Architecture

In [None]:
class binarizer(nn.Module):
    '''ResNet-18 encoder + transposed-convolution decoder producing a 1-channel mask.

    The encoder (ResNet-18 without its avgpool/fc head) downsamples by 32. The
    decoder's five stride-2 ConvTranspose2d layers each double the spatial size
    (x32 overall). So for inputs whose height and width are multiples of 32, the
    output has the input's spatial size (e.g. 3x512x512 -> 1x512x512). Outputs are
    sigmoid probabilities in (0, 1).

    Args:
        pretrained: load the default pretrained ResNet-18 weights for the
            encoder (the original behaviour). Pass False to build the
            architecture with random weights and no download.
    '''

    def __init__(self, pretrained=True):
        super().__init__()

        weights = ResNet18_Weights.DEFAULT if pretrained else None
        resnet18 = models.resnet18(weights=weights)

        # Drop avgpool and fc, keep the convolutional trunk:
        # output is 512 x H/32 x W/32 (512 x 16 x 16 for a 512x512 input).
        self.encoder = nn.Sequential(*list(resnet18.children())[:-2])

        # Each ConvTranspose2d(k=3, s=2, p=1, output_padding=1) exactly doubles H and W.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(32, 1, kernel_size=3, stride=2, padding=1, output_padding=1),
            # No ReLU here. A ReLU before the sigmoid clamps the logits to >= 0,
            # so the output could never go below 0.5. Background pixels would
            # also receive zero gradient once their logit reached 0.
            nn.Sigmoid(),  # per-pixel foreground probability in (0, 1)
        )

    def forward(self, x):
        '''Encode ``x`` with the ResNet-18 trunk and decode it to a 1-channel probability map.'''
        return self.decoder(self.encoder(x))

# Training Loop

In [None]:
# Hyperparameters
lr = 0.001
num_epochs = 10

# Model, optimizer and loss
model = binarizer().to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)
# The model already ends in a Sigmoid, so plain BCE on probabilities is used.
criterion = nn.BCELoss()

In [None]:
# Actual training loop
# Optional cap on batches per epoch for quick smoke tests (e.g. 10); None trains
# on the full loader every epoch.
max_batches_per_epoch = None

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    num_batches = 0

    for inputs, targets in train_dataloader:
        inputs, targets = inputs.to(device), targets.to(device)

        # Forward pass: sigmoid probabilities, compared against the mask targets
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1
        print(num_batches, end='\r')

        if max_batches_per_epoch is not None and num_batches >= max_batches_per_epoch:
            break

    # Average over the batches actually processed. len(train_dataloader) is wrong
    # whenever the epoch is cut short. max() guards against an empty loader.
    epoch_loss = running_loss / max(num_batches, 1)
    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}')

# Evaluation

In [None]:
def _segmentation_metrics(predictions, ground_truths):
    '''Return per-image (IoU, Dice, pixel accuracy) tensors for 0/1 masks.

    Both inputs are 0/1 float tensors of the same shape with the batch on dim 0;
    all other dims are reduced. An image where prediction and ground truth are
    both empty scores 1.0 on IoU and Dice instead of producing 0/0 = NaN.
    '''
    dims = tuple(range(1, predictions.dim()))
    intersection = (predictions * ground_truths).sum(dims)
    pred_area = predictions.sum(dims)
    gt_area = ground_truths.sum(dims)
    union = pred_area + gt_area - intersection
    both_empty = union == 0
    one = torch.ones_like(union)
    # For binary masks a non-empty union is >= 1, so clamp(min=1) only changes
    # the both-empty case, and torch.where overrides that case anyway.
    iou = torch.where(both_empty, one, intersection / union.clamp(min=1))
    dice = torch.where(both_empty, one, 2 * intersection / (pred_area + gt_area).clamp(min=1))
    accuracy = (predictions == ground_truths).float().mean(dims)
    return iou, dice, accuracy


def _show_examples(images, ground_truths, predictions, num_examples):
    '''Open image / ground-truth / prediction previews for the first samples of a batch.'''
    # NOTE: assumes images were normalized with ImageNet mean/std (as the
    # notebook's `transform` does). Undo that so ToPILImage gets values in [0, 1].
    mean = torch.tensor([0.485, 0.456, 0.406], device=images.device).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=images.device).view(3, 1, 1)
    to_pil = v2.ToPILImage()
    for i in range(min(num_examples, images.shape[0])):
        to_pil((images[i] * std + mean).clamp(0, 1).cpu()).show()
        to_pil(ground_truths[i].squeeze(0).cpu()).show()
        to_pil(predictions[i].squeeze(0).cpu()).show()


def evaluate_model(model, test_loader, threshold=0.5, max_batches=None, num_examples=1):
    '''Compute mean IoU, Dice coefficient and pixel accuracy of a binarization model.

    Metrics are computed per image and averaged over every image evaluated, so a
    short final batch is weighted correctly.

    Args:
        model: module whose output for an image batch has the same shape as the
            matching ground-truth mask batch (probabilities in [0, 1]).
        test_loader: iterable of (images, ground_truth_masks) batches.
        threshold: probability above which a predicted pixel counts as foreground.
        max_batches: stop after this many batches (None = the whole loader).
        num_examples: number of samples from the first batch to preview with
            ``PIL.Image.show`` (0 disables the preview).

    Returns:
        (mean_iou, mean_dice, mean_accuracy) as Python floats.

    Raises:
        ValueError: if no images are evaluated, or the model output and mask
            shapes differ.
    '''
    model.eval()  # Set model to evaluation mode
    # Evaluate on whatever device the model lives on (CPU if it has no parameters).
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')

    iou_sum = dice_sum = accuracy_sum = 0.0
    num_images = 0

    with torch.no_grad():  # No need to compute gradients for evaluation
        for batch_idx, (images, ground_truths) in enumerate(test_loader):
            if max_batches is not None and batch_idx >= max_batches:
                break

            images = images.to(model_device)
            # Masks may hold non-binary values (e.g. from interpolation during
            # resizing); binarize them so the metrics are well defined.
            ground_truths = (ground_truths.to(model_device) > 0.5).float()

            outputs = model(images)
            if outputs.shape != ground_truths.shape:
                raise ValueError(
                    f'model output shape {tuple(outputs.shape)} does not match '
                    f'mask shape {tuple(ground_truths.shape)}'
                )
            predictions = (outputs > threshold).float()  # Binarize predictions

            iou, dice, accuracy = _segmentation_metrics(predictions, ground_truths)
            iou_sum += iou.sum().item()
            dice_sum += dice.sum().item()
            accuracy_sum += accuracy.sum().item()
            num_images += predictions.shape[0]
            print(batch_idx + 1, end='\r')

            if batch_idx == 0 and num_examples > 0:
                _show_examples(images, ground_truths, predictions, num_examples)

    if num_images == 0:
        raise ValueError('test_loader yielded no images to evaluate')

    # Average the per-image metrics over every evaluated image
    mean_iou = iou_sum / num_images
    mean_dice = dice_sum / num_images
    mean_accuracy = accuracy_sum / num_images

    print(f'Mean IoU: {mean_iou:.4f}')
    print(f'Mean Dice Coefficient: {mean_dice:.4f}')
    print(f'Mean Pixel Accuracy: {mean_accuracy:.4f}')

    return mean_iou, mean_dice, mean_accuracy

In [None]:
# Report mean IoU / Dice / pixel accuracy of the trained model on the test split.
evaluate_model(model=model, test_loader=test_dataloader)