# Assignment 10

Welcome to the assignment for week 10.

## Task 10: UNet for Semantic Segmentation

In this assignment, we are going to program our own UNet model (https://arxiv.org/pdf/1505.04597.pdf). The model outputs a segmentation map. This map may be slightly smaller than the input image, but it should preserve the image's spatial structure. The map consists of several layers (channels), one per class, and the goal of the network is to activate each class layer at exactly those pixels that belong to that class.
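The shapes involved can be sketched with random tensors. This is a minimal illustration, not a model. It assumes 21 classes (the 20 Pascal VOC object classes plus background) and an output a few pixels smaller than the 64 x 85 target, which is what happens with the unpadded convolutions of the original paper. The helper `center_crop` is hypothetical. It trims the target to the output size, and resizing the logits to the target size would be an equally valid choice.

```python
import torch
import torch.nn.functional as F

num_classes = 21                                       # assumption: 20 VOC object classes + background
logits = torch.randn(2, num_classes, 60, 80)           # network output: one channel ("layer") per class
target = torch.randint(0, num_classes, (2, 64, 85))    # ground-truth label map, one class id per pixel


def center_crop(t, size):
    """Hypothetical helper: crop the last two dims of t to `size`, keeping the center."""
    h, w = t.shape[-2:]
    th, tw = size
    top, left = (h - th) // 2, (w - tw) // 2
    return t[..., top:top + th, left:left + tw]


target_c = center_crop(target, logits.shape[-2:])      # (2, 60, 80), aligned with the output
loss = F.cross_entropy(logits, target_c)               # per-pixel softmax over the class channels
pred = logits.argmax(dim=1)                            # (2, 60, 80): most activated class per pixel
```

Training pushes the correct class channel up at each pixel. At inference time, `argmax` over the channel dimension turns the stack of class layers back into a single label map.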

In [None]:
from IPython.display import Image as IPyImage  # aliased so it does not clash with PIL's Image imported below
IPyImage(url="https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/u-net-architecture.png", width=700)

Feel free to use this picture as a reference when designing your UNet, but you don't have to stick to these layer dimensions exactly. By now you know several techniques to downsample and upsample an image (or the latent activations of an image); use them to your advantage. Don't forget the skip connections, indicated by the horizontal grey arrows.
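A few common down- and upsampling layers, plus one skip connection, can be compared on a feature map of the 64 x 85 resolution used later in this notebook. This is a minimal sketch, and the channel counts are arbitrary assumptions. Because 85 is odd, a pool-then-upsample round trip returns width 84, so the upsampled map has to be matched to the skip features before concatenation. Here that is done with `F.interpolate`. Cropping the skip (as in the original paper) or padding are alternatives.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(1, 16, 64, 85)  # encoder feature map at input resolution (16 channels assumed)

# Downsampling options
pool = nn.MaxPool2d(2)                                            # (64, 85) -> (32, 42)
strided = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1)   # (64, 85) -> (32, 43), learnable

# Upsampling options
up_t = nn.ConvTranspose2d(16, 16, kernel_size=2, stride=2)        # learnable, doubles H and W
up_i = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)  # fixed interpolation

down = pool(x)   # (1, 16, 32, 42)
up = up_t(down)  # (1, 16, 64, 84): one column short of the skip features

# Skip connection: match spatial size, then concatenate along the channel dimension
if up.shape[-2:] != x.shape[-2:]:
    up = F.interpolate(up, size=x.shape[-2:], mode="bilinear", align_corners=False)
merged = torch.cat([x, up], dim=1)  # 16 (skip) + 16 (upsampled) = 32 channels
```

Concatenation (rather than addition) is what the original UNet uses, so the decoder convolution after each skip sees both the high-resolution encoder features and the upsampled context.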

### Task 10.0 Prepare the PascalVOC dataset

* To speed things up, code for preparing the PascalVOC dataset is provided below.
* Adapt the dataset preparation to your needs.

In [2]:
import numpy as np
import torch
import torchvision
from torchvision.transforms.functional import to_pil_image
import matplotlib.pyplot as plt
from PIL import Image


In [3]:
class VOCSegLoader(torchvision.datasets.VOCSegmentation):
    def __init__(self,
                 root,
                 year='2012',
                 image_set='train',
                 download=False,
                 transform=None,
                 target_transform=None,
                 transforms=None):

        super().__init__(root, year=year, image_set=image_set, download=download,
                         transform=transform, target_transform=target_transform, transforms=transforms)


    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is the image segmentation.
        """
        img = Image.open(self.images[index]).convert('RGB')
        target = Image.open(self.masks[index])

        target = np.array(target)
        target[target == 255] = 0  # map the "void"/object-boundary label (255) to background (class 0)
        target = Image.fromarray(target)

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        target = torch.as_tensor(np.asarray(target, dtype=np.uint8), dtype=torch.long)
        return img, target
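The loader above relabels the void pixels (value 255) as background. This means the network is trained to predict background along object borders. One alternative, assuming you drop the `target[target == 255] = 0` line and keep 255 in the masks, is to exclude those pixels from the loss with the `ignore_index` argument of `nn.CrossEntropyLoss`. The toy tensors below are made up, and they show that ignored pixels simply drop out of the average.

```python
import torch
import torch.nn as nn

# Toy example: 1 image, 3 classes, 2x2 pixels; 255 marks a void pixel as in Pascal VOC masks
logits = torch.randn(1, 3, 2, 2)
target = torch.tensor([[[0, 1], [2, 255]]])

# Pixels labelled 255 contribute nothing to the loss
criterion = nn.CrossEntropyLoss(ignore_index=255)
loss = criterion(logits, target)

# Equivalent: average the per-pixel loss over the three valid pixels only
per_pixel = nn.CrossEntropyLoss(reduction="none", ignore_index=255)(logits, target)
manual = per_pixel[target != 255].mean()
```

If you go this route, the evaluation metric should skip the same pixels, otherwise predictions on void pixels are scored against an arbitrary label.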

In [None]:
_epochs = 3
batch_size_train = 100
batch_size_val = 100
learning_rate = 0.001
momentum = 0.9
log_interval = 10
image_size = (64, 85)


transform_data = torchvision.transforms.Compose([torchvision.transforms.Resize(image_size),
                                                 torchvision.transforms.ToTensor()])
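transform_label = torchvision.transforms.Compose([
    # Nearest-neighbour resizing keeps every mask pixel a valid class id (no blended labels)
    torchvision.transforms.Resize(image_size, interpolation=torchvision.transforms.InterpolationMode.NEAREST)
])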


train_dataset = VOCSegLoader('./data', year='2012', image_set='train', download=True,
                             transform=transform_data, target_transform=transform_label)
val_dataset = VOCSegLoader('./data', year='2012', image_set='val', download=True,
                           transform=transform_data, target_transform=transform_label)


train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True)  # reshuffle every epoch
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size_val)
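The label transform resizes masks with nearest-neighbour interpolation rather than the bilinear default used for images. The reason shows up on a small made-up mask with two classes (ids 0 and 15). Nearest neighbour only copies existing labels, while bilinear resampling can blend neighbouring ids into values that are not meaningful classes.

```python
import numpy as np
from PIL import Image

# Toy 8x8 label mask: class 0 on the left half, class 15 on the right half
mask = np.zeros((8, 8), dtype=np.uint8)
mask[:, 4:] = 15
mask_img = Image.fromarray(mask)  # single-channel ("L") image

nearest = np.array(mask_img.resize((3, 3), resample=Image.NEAREST))
bilinear = np.array(mask_img.resize((3, 3), resample=Image.BILINEAR))

print(np.unique(nearest))   # only labels that already exist in the mask
print(np.unique(bilinear))  # may include blended values that are not valid class ids
```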

In [None]:
def visualize_samples(dataloader, num_samples=5):
    # Get a batch of data
    dataiter = iter(dataloader)
    images, masks = next(dataiter)
    
    # Create a figure with a grid of subplots
    fig, axes = plt.subplots(num_samples, 2, figsize=(10, num_samples * 3))
    
    for idx in range(num_samples):
        # Get image and mask for current sample
        img = images[idx]
        mask = masks[idx]
        
        # Convert tensor to PIL Image for display
        img_display = to_pil_image(img)
        
        # Display original image
        axes[idx, 0].imshow(img_display)
        axes[idx, 0].set_title(f'Sample {idx + 1} - Image')
        axes[idx, 0].axis('off')
        
        # Display segmentation mask
        axes[idx, 1].imshow(mask.numpy(), cmap='tab20')
        axes[idx, 1].set_title(f'Sample {idx + 1} - Target Mask')
        axes[idx, 1].axis('off')
    
    plt.tight_layout()
    plt.show()

# Visualize samples from the training set
print("Visualizing 5 samples from the training set:")
visualize_samples(train_loader, num_samples=5)

### Task 10.1 Setting up and training your UNet model

* Implement your UNet model. Make use of down- and upsampling techniques. **(RESULT)**
* Build a training loop for your model and train it for a few minutes. **(RESULT)**
* Report the model performance using a suitable metric for semantic segmentation (e.g., mean Intersection over Union, mIoU). **(RESULT)**
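One way to compute mean IoU is from a class confusion matrix accumulated over the whole validation set. This is a minimal sketch, and the helper names `confusion_matrix`, `mean_iou` and `pixel_accuracy` are hypothetical. Per-class IoU is TP / (TP + FP + FN). Classes that appear neither in the predictions nor in the targets are left out of the mean. The `ignore_index` argument only matters if void pixels (255) are kept in the masks.

```python
import torch


def confusion_matrix(pred, target, num_classes, ignore_index=255):
    """(num_classes x num_classes) counts: rows = ground truth, columns = prediction."""
    valid = target != ignore_index
    idx = target[valid] * num_classes + pred[valid]
    return torch.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)


def mean_iou(conf):
    """Return (mIoU over classes that occur, per-class IoU tensor)."""
    conf = conf.double()
    tp = conf.diag()
    fp = conf.sum(dim=0) - tp   # predicted as class c, but belongs to another class
    fn = conf.sum(dim=1) - tp   # belongs to class c, but predicted as another class
    denom = tp + fp + fn
    iou = tp / denom.clamp(min=1)
    present = denom > 0
    return iou[present].mean().item(), iou


def pixel_accuracy(conf):
    return (conf.diag().sum() / conf.sum()).item()


# Toy usage: 3 classes, one void pixel (255) that is ignored
pred = torch.tensor([[0, 0, 1], [1, 2, 2]])
target = torch.tensor([[0, 1, 1], [1, 2, 255]])
conf = confusion_matrix(pred, target, num_classes=3)
miou, per_class = mean_iou(conf)   # IoUs 0.5, 2/3, 1.0 -> mIoU = 13/18
acc = pixel_accuracy(conf)         # 4 of 5 valid pixels correct
```

In a validation loop, you would sum `confusion_matrix(logits.argmax(dim=1), masks, 21)` over all batches and call `mean_iou` once at the end. Averaging per-batch mIoU values instead gives a slightly different (batch-size dependent) number. Pixel accuracy is easy to report alongside, but it is dominated by the background class on Pascal VOC.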

In [6]:
# code here

### Task 10.2 Reinjection Link Manipulation (BONUS)

* Implement a way to manipulate the reinjection links (i.e., the skip connections) in the UNet model and retrain without them. **(RESULT)**
* Report the change in performance with respect to your previous model. **(RESULT)**
* Implement an alternative model without reinjection links or down- and upsampling with a roughly similar number of parameters. Retrain and compare the new model's performance to your UNet models for this task. Which one performs best? **(RESULT)**
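One possible way to switch the reinjection links off is a flag on each decoder block that feeds zeros instead of the encoder features. The architecture and parameter count stay identical, so only the information flow changes. The block below is a hypothetical sketch, not a prescribed design. A parameter counter helps with the "roughly similar number of parameters" comparison.

```python
import torch
import torch.nn as nn


class UpBlock(nn.Module):
    """Hypothetical decoder block: upsample, concatenate the skip (or zeros), then convolve."""

    def __init__(self, in_ch, skip_ch, out_ch, use_skip=True):
        super().__init__()
        self.use_skip = use_skip
        self.up = nn.ConvTranspose2d(in_ch, in_ch, kernel_size=2, stride=2)
        # The conv always expects in_ch + skip_ch channels, so both variants have equal size
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch + skip_ch, out_ch, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x, skip):
        x = self.up(x)
        if not self.use_skip:
            skip = torch.zeros_like(skip)  # cut the reinjection link: no encoder information passes
        return self.conv(torch.cat([x, skip], dim=1))


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


with_skip = UpBlock(32, 16, 16, use_skip=True)
without_skip = UpBlock(32, 16, 16, use_skip=False)
x, skip = torch.randn(1, 32, 16, 20), torch.randn(1, 16, 32, 40)
out_a, out_b = with_skip(x, skip), without_skip(x, skip)
```

Zeroing keeps the comparison clean, because the weights that would read the skip channels simply receive zero gradient. Dropping the concatenation altogether is the other option. That shrinks the first decoder convolution, so the parameter counts would then differ.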

In [None]:
# code here

## Congratz, you made it! :)