# Image segmentation (with U-Nets)


In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.optim as optim
from os import listdir
import skimage
from skimage import io


## U-Net Architecture
![architecture](.\architecture.png "UNet architecture")

### Contracting Path
It consits of 4 blocks and every block has the same architecture (typical architecture of a concolutional network)
1. Two 3x3 unpadded convolutions each followed by a rectified linear unit
    - because its is unpadded we loss with every convolution 2 pixels in each dimension
2. 2x2 max pooling operations with stride 2
    - with this setting the image get downsampled by a factor of 2

They start with 64 feature channels (filters) at the first block and with each block they double the number of feature channels

The output of the last block will be sended through two convoltuion layers (3x3) each followed by a ReLU, because of the doubling of the feature channels the resulting feature map for each layer is 1024 dimensional.

### Expansive Path
Its symetric to the contracting path, so again 4 blocks. However the architecture changes slightly
1. Upsampling by a 2x2 up-convolution (that halves the number of feature channels)
2. Copy a cropped version of the feature map from the corresponding feature map of the contracting path and concatenate with the upsampled feature map
    - The cropping is necessary due to the loss of border pixels in every convolution
3. Two 3x3 unpadded convolutions each followed by a rectified linear unit

At the final layer a 1x1 convolution is used to map each 64-component feature vector to the desired number of classes (in our case 2 classes). 

In total the network has 23 convolutional layers. (2\*4 (Contracting Path) + 2 (last layer) + 3\*4 (Expansive Path) + 1 (final layer))

In [9]:
# because 18 of the 23 convolutional layers uses the same setting, we defined a own function for this
def conv3x3(_input, output):
    return nn.Conv2d(_input, output, 3, padding=1)


# define block of contracting path
class ContractingBlock(nn.Module):
    def __init__(self,in_channels,channels):
        super(ContractingBlock, self).__init__()
        # for the first conv layer the number of input channels are the number of channels form the previous block and they will be doubled (first block starts with 64) 
        self.conv1 = conv3x3(in_channels,channels)
        self.conv2 = conv3x3(channels,channels)
        self.relu = nn.ReLU(inplace=True)
        self.downsample_block = nn.MaxPool2d(kernel_size=2, stride=2)
        
    def forward(self, x, isInitBlock = 0):
        if not isInitBlock:
            x = self.downsample_block(x)
        
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.relu(x)
        print(x.size())
        
        return x

# define block of expansive path
class ExpansiveBlock(nn.Module):
    def __init__(self, in_channels, channels):
        super(ExpansiveBlock, self).__init__()
        self.upsampled = nn.ConvTranspose2d(in_channels=in_channels, out_channels=channels, kernel_size=2, stride=2)
        self.conv1 = conv3x3(in_channels,channels)
        self.conv2 = conv3x3(channels,channels)
        self.relu = nn.ReLU(inplace=True)
    
    def forward(self, x, cop):
        x = self.upsampled(x)
        n, c, h, w = x.size()
        cop = cop[0:n,0:c,0:h,0:w]
        x = torch.cat([cop, x], dim=1)
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.relu(x)
        print(x.size())
        
        return x


In [10]:
class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        self.contracted1 = ContractingBlock(3,64)
        self.contracted2 = ContractingBlock(64,128)
        self.contracted3 = ContractingBlock(128,256)
        self.contracted4 = ContractingBlock(256,512)
        self.bottom = ContractingBlock(512,1024)
        self.expanded1 = ExpansiveBlock(1024,512)
        self.expanded2 = ExpansiveBlock(512,256)
        self.expanded3 = ExpansiveBlock(256,128)
        self.expanded4 = ExpansiveBlock(128,64)
        self.final = nn.Conv2d(64, 2, 1, padding=1)
        
    def forward(self,x):
        con1 = self.contracted1(x, 1)
        con2 = self.contracted2(con1)
        con3 = self.contracted3(con2)
        con4 = self.contracted4(con3)
        bot = self.bottom(con4)
        exp1 = self.expanded1(bot,con4)
        exp2 = self.expanded2(exp1,con3)
        exp3 = self.expanded3(exp2,con2)
        exp4 = self.expanded4(exp3,con1)
        fin = self.final(exp4)
        
        return fin

unet = UNet()

## Loss-Function
For the loss a pixel-wise soft-max over the final feature map combined with the corss entropy loss function is computed.
The function torch.nn.functional.corss_entropy() does exactly this.

In [31]:
def EngergyFunction(featuremap, target,weight):
    n, c, h, w = featuremap.size()
    featuremap = featuremap.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
    target = target.contiguous().view(-1)
    loss = F.cross_entropy(featuremap, target, weight=weight, reduction='elementwise_mean')
    
    return loss
        

## Optimizer
Like described in [1] we use stochastic gradient descent with a momentum of 0.99 for optimization.

In [13]:
optimizer = optim.SGD(
        unet.parameters(), 
        lr=0.01,
        momentum=0.99)

## Dataloader

In [14]:
def dataLister(path):
    data_path = path + '/images/'
    label_path = path + '/1st_manual/'
    data = listdir(data_path)
    labels = listdir(label_path)
    
    combined = list()
    for d, l in zip(data, labels):
        combined.append((data_path + d, label_path + l))
        
    #print(combined)
    
    return combined

In [26]:
def imageLoader(data_path, label_path):
    data_image = io.imread(data_path)
    label_image = io.imread(label_path)
    
    data = torch.from_numpy(data_image)
    label = torch.from_numpy(label_image)
    
    data = data.unsqueeze(0)
    label = label.unsqueeze(0).unsqueeze(0)
    
    data = data.transpose(2,3).transpose(1,2)
    data = data.type(torch.FloatTensor)
    label = label.type(torch.FloatTensor)
    
    return data, label

In [38]:
traindata = dataLister('./data/test')

for x in traindata:
    data, label = imageLoader(*x)
    #print(data.size())
    #print(label.size())



  strip = decompress(strip)


In [28]:
data = data.type(torch.FloatTensor)
print(data.type)
output = unet(data)

<built-in method type of Tensor object at 0x0000018954435EE8>
torch.Size([1, 64, 584, 565])
torch.Size([1, 128, 292, 282])
torch.Size([1, 256, 146, 141])
torch.Size([1, 512, 73, 70])
torch.Size([1, 1024, 36, 35])
torch.Size([1, 512, 72, 70])
torch.Size([1, 256, 144, 140])
torch.Size([1, 128, 288, 280])
torch.Size([1, 64, 576, 560])


In [32]:
n, c, h, w = output.size()
label = label[0:n,0:c,0:h,0:w]
print(EngergyFunction(output, label,1))

ValueError: Expected input batch_size (324836) to match target batch_size (1).

In [48]:
print(label.size())
label = label.contiguous().view(-1)
print(sum(label)/255)

torch.Size([329960])
tensor(24265.)


In [49]:
print(output.size())

torch.Size([1, 2, 578, 562])


## References
[1] U-Net: Convolutional Networks for Biomedical Image Segmentation