# Image segmentation (with U-Nets)


In [135]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.optim as optim


## U-Net Architecture
![architecture](.\architecture.png "UNet architecture")

### Contracting Path
It consits of 4 blocks and every block has the same architecture (typical architecture of a concolutional network)
1. Two 3x3 unpadded convolutions each followed by a rectified linear unit
    - because its is unpadded we loss with every convolution 2 pixels in each dimension
2. 2x2 max pooling operations with stride 2
    - with this setting the image get downsampled by a factor of 2

They start with 64 feature channels (filters) at the first block and with each block they double the number of feature channels

The output of the last block will be sended through two convoltuion layers (3x3) each followed by a ReLU, because of the doubling of the feature channels the resulting feature map for each layer is 1024 dimensional.

### Expansive Path
Its symetric to the contracting path, so again 4 blocks. However the architecture changes slightly
1. Upsampling by a 2x2 up-convolution (that halves the number of feature channels)
2. Copy a cropped version of the feature map from the corresponding feature map of the contracting path and concatenate with the upsampled feature map
    - The cropping is necessary due to the loss of border pixels in every convolution
3. Two 3x3 unpadded convolutions each followed by a rectified linear unit

At the final layer a 1x1 convolution is used to map each 64-component feature vector to the desired number of classes (in our case 2 classes). 

In total the network has 23 convolutional layers. (2\*4 (Contracting Path) + 2 (last layer) + 3\*4 (Expansive Path) + 1 (final layer))

In [150]:
# because 18 of the 23 convolutional layers uses the same setting, we defined a own function for this
def conv3x3(_input, output):
    return nn.Conv2d(_input, output, 3, padding=1)

# define block of contracting path
class ContractingBlock(nn.Module):
    def __init__(self,in_channels,channels):
        super(ContractingBlock, self).__init__()
        # for the first conv layer the number of input channels are the number of channels form the previous block and they will be doubled (first block starts with 64) 
        self.conv1 = conv3x3(in_channels,channels)
        self.conv2 = conv3x3(channels,channels)
        self.relu = nn.ReLU(inplace=True)
        self.downsample_block = nn.MaxPool2d(kernel_size=2, stride=2)
        
    def forward(self, x, isInitBlock = 0):
        if not isInitBlock:
            x = self.downsample_block(x)
        
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.relu(x)
        print(x.size())
        
        return x
    
class BottomBlock(nn.Module):
    def __init__(self,in_channels,channels):
        super(BottomBlock, self).__init__()
        # for the first conv layer the number of input channels are the number of channels form the previous block and they will be doubled (first block starts with 64) 
        self.conv1 = conv3x3(in_channels,channels)
        self.conv2 = conv3x3(channels,channels)
        self.relu = nn.ReLU(inplace=True)
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.relu(x)
        
        return x

# define block of expansive path
class ExpansiveBlock(nn.Module):
    def __init__(self, in_channels, channels):
        super(ExpansiveBlock, self).__init__()
        self.upsampled = nn.ConvTranspose2d(in_channels=in_channels, out_channels=channels, kernel_size=2, stride=2)
        self.conv1 = conv3x3(in_channels,channels)
        self.conv2 = conv3x3(channels,channels)
        self.relu = nn.ReLU(inplace=True)
    
    def forward(self, x, cop):
        x = self.upsampled(x)
        cc = transforms.CenterCrop(x.size())
        cop = cc(cop)
        x = torch.cat([cop, x], dim=1)
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.relu(x)
        
        return x


In [151]:
class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        self.contracted1 = ContractingBlock(3,64)
        self.contracted2 = ContractingBlock(64,128)
        self.contracted3 = ContractingBlock(128,256)
        self.contracted4 = ContractingBlock(256,512)
        self.bottom = ContractingBlock(512,1024)
        self.expanded1 = ExpansiveBlock(1024,512)
        self.expanded2 = ExpansiveBlock(512,256)
        self.expanded3 = ExpansiveBlock(256,128)
        self.expanded4 = ExpansiveBlock(128,64)
        self.final = nn.Conv2d(64, 2, 1, padding=1)
        
    def forward(self,x):
        con1 = self.contracted1(x, 1)
        con2 = self.contracted2(con1)
        con3 = self.contracted3(con2)
        con4 = self.contracted4(con3)
        bot = self.bottom(con4)
        exp1 = self.expanded1(bot,con4)
        exp2 = self.expanded2(exp1,con3)
        exp3 = self.expanded3(exp2,con2)
        exp4 = self.expanded4(exp3,con1)
        fin = self.final(exp4)
        
        return fin

## Loss-Function


In [None]:
def EngergyFunction(featuremap, target,K):
    n, c, h, w = featuremap.size()
    featuremap = featuremap.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
    target = target.view(-1)
    
    for i in range(h):
        for j in range(w):
            
        

## Optimizer

## Dataloader

In [29]:
def load_dataset():
    data_path = './data/training'
    train_dataset = torchvision.datasets.ImageFolder(
        root=data_path,
        transform=torchvision.transforms.ToTensor())
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=1,
        num_workers=0,
        shuffle=True
    )
    return train_loader

In [174]:
import matplotlib.pyplot as plt
import numpy as np

# functions to show an image


def imshow(img):
    #img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()
    
train = load_dataset()
dataiter = iter(train)
image, label = dataiter.next()
print(image.size())
#npimg = image.numpy()
#npimg = np.reshape(npimg, (1,5))
print(npimg.shape)
cc = transforms.CenterCrop([1,3,400,400])
npimg = cc(npimg)
#imshow(torchvision.utils.make_grid(image))

torch.Size([1, 3, 584, 565])
(1, 3, 584, 565)


TypeError: 'int' object is not iterable

In [152]:
unet = UNet()

In [153]:
output = unet(image)

torch.Size([1, 64, 584, 565])
torch.Size([1, 128, 292, 282])
torch.Size([1, 256, 146, 141])
torch.Size([1, 512, 73, 70])
torch.Size([1, 1024, 36, 35])


TypeError: 'builtin_function_or_method' object is not iterable