## U-Net Paper

### Note: Working colab code @ https://drive.google.com/drive/folders/1jw0KE6uZ0a625Rniy6xxB_8GzKK3c1FZ?usp=sharing


<img src="./helper/1.jpg" alt="Drawing" style="width: 800px;"/>

**Note: It is fully convolutional; no fully connected layers are used**
* Input: grayscale, 572*572
* Output: 388*388, 2 output classes
* At every level: two 3*3 **valid** convolutions
* * Valid convolution (no padding): output size = (input-filter)/stride+1, so the 2nd layer shape = (572-3)/1+1 = 570
* Downsample with 2*2 max pooling, stride 2
* Upsample using transpose convolution
* * The resulting layer is a concatenation of
* * * the skip connection (the feature map from the left side of the U must be cropped to match the dimensions of the upsampled layer)
* * * the upsampled output

* Final 1*1 conv to get to the number of classes
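
The size arithmetic above can be traced end to end. The following is a minimal sketch, assuming a depth of 4 and the valid-convolution rule from the list; the function name `unet_output_size` is made up for illustration.

```python
def unet_output_size(size, depth=4):
    """Trace the spatial size through a U-Net built from valid 3x3 convolutions."""
    trace = [size]
    # Contracting path: two valid 3x3 convs (each removes 2 pixels), then a 2x2 max pool.
    for _ in range(depth):
        size = size - 2 - 2
        trace.append(size)
        size = size // 2
        trace.append(size)
    # Bottleneck: two more valid 3x3 convs.
    size = size - 4
    trace.append(size)
    # Expanding path: a 2x2 transpose conv doubles the size, then two valid 3x3 convs.
    for _ in range(depth):
        size = size * 2
        trace.append(size)
        size = size - 4
        trace.append(size)
    return size, trace


final_size, trace = unet_output_size(572)
print(trace)
print(final_size)  # 388
```

Each valid 3*3 convolution removes a 1-pixel border, which is why the 572*572 input ends up as a 388*388 output.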


* The downsampling path captures *what* we are looking at. The upsampling path (with skip connections) tells *where* it is.
* For a large image, the paper processes overlapping tiles, pads the image borders by mirroring (so that context is available for border pixels), and then stitches the results together
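
The mirroring step can be sketched with PyTorch's reflect padding. This is only an illustration of the border case, assuming a single 388*388 tile that needs a 572*572 input; the tensor names are hypothetical, and interior tiles in the paper take real neighbouring pixels as context rather than mirrored ones.

```python
import torch
import torch.nn.functional as F

tile = torch.randn(1, 1, 388, 388)  # (batch, channels, H, W)
margin = (572 - 388) // 2           # 92 pixels of context on each side

# Reflect padding mirrors the tile content across its border.
padded = F.pad(tile, (margin, margin, margin, margin), mode="reflect")
print(padded.shape)  # torch.Size([1, 1, 572, 572])
```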

#### Why not use 3*3 same convolutions and preserve the feature map shape, rather than downsampling and upsampling?
* The receptive field doesn't grow as fast
* It is computationally expensive (every layer runs at full resolution)
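
The receptive field point can be checked with the standard recurrence for stacked layers. The sketch below is an illustration, not taken from the paper: it compares eight 3*3 convolutions alone against the same eight convolutions with a 2*2 pool after every pair.

```python
def receptive_field(layers):
    """Receptive field of a stack of layers given as (kernel_size, stride) pairs."""
    rf, jump = 1, 1
    for kernel, stride in layers:
        rf += (kernel - 1) * jump  # each layer widens the field by (k - 1) input steps
        jump *= stride             # strides multiply the step between output pixels
    return rf


conv = (3, 1)
pool = (2, 2)

same_only = [conv] * 8                # eight 3x3 convs, no downsampling
with_pooling = [conv, conv, pool] * 4 # four stages of conv, conv, pool

print(receptive_field(same_only))     # 17
print(receptive_field(with_pooling))  # 76
```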


## From Implementation 
(found @ model.py)

At every stage:
* We have two 3*3 convolution layers (stride = 1). The paper uses valid convolutions; for ease of implementation, this version uses 'same' convolutions (padding = 1). It also adds a BN (batch norm) layer after each convolution.
* We set bias = False because the BN layer that follows subtracts the per-channel mean, which cancels the bias
* For the first stage in_channels = 1, out_channels = 64; for the second in_channels = 64, out_channels = 128
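
The claim that BN cancels the convolution bias can be checked numerically. This is a small sketch with made-up sizes (4 output channels, a 16*16 input), assuming BN is in training mode so that it normalizes with batch statistics.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

conv_bias = nn.Conv2d(1, 4, 3, 1, 1, bias=True)
conv_nobias = nn.Conv2d(1, 4, 3, 1, 1, bias=False)
with torch.no_grad():
    conv_nobias.weight.copy_(conv_bias.weight)  # same filters in both layers
    conv_bias.bias.fill_(5.0)                   # exaggerate the bias

bn = nn.BatchNorm2d(4)  # training mode by default: uses per-channel batch mean and variance
x = torch.randn(8, 1, 16, 16)

out_bias = bn(conv_bias(x))
out_nobias = bn(conv_nobias(x))

# The bias shifts every value in a channel by the same constant,
# and BN removes that constant when it subtracts the channel mean.
print(torch.allclose(out_bias, out_nobias, atol=1e-4))
```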

![](./helper/1.PNG)

In [None]:
import torch.nn as nn

class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),  # in_ch, out_ch, kernel_size, stride, padding
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        # conv -> BN -> ReLU, applied twice
        return self.conv(x)


## For the down part of U-Net:
* We need the 2 convolutions followed by max pooling to be repeated 4 times. The out_channels at these stages are: 64, 128, 256, 512

* * We make a list using the function that creates 2 convolution layers of the required dimensions
* * * 1st entry (1st stage): in_channels = 1, out_channels = 64 (1st in the list features)
* * * 2nd entry (2nd stage): in_channels = 64 (prev feature), out_channels = 128 (2nd in the list features) and so on
* * We loop through this list and
* * * Apply 2 layers of conv (DoubleConv)
* * * Downsample (max pool)

In [None]:
features = [64, 128, 256, 512]  # the out_channels/filters at every stage as we go down U-Net

# In __init__ (in_channels starts as the number of input image channels, e.g. 1)
self.downs = nn.ModuleList()
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

for feature in features:
    self.downs.append(DoubleConv(in_channels, feature))
    in_channels = feature

# In forward: let x be the input passed to this function

for down in self.downs:
    x = down(x)       # 2 layers of conv
    x = self.pool(x)  # downsample


We need a bottleneck layer with in_channels = 512 (features[-1]) and out_channels = 1024

In [None]:
self.bottleneck = DoubleConv(features[-1], features[-1]*2)

## For the up part of U-Net

* We need to make a list of skip connections (they are the outputs of the last conv layer in each down stage)
* * Reverse the list (since the last skip connection created going down gets used first when going up the U-Net)

* As we go up, out_channels are 512, 256, 128, 64 (i.e. reversed(features))

In the implementation below:
* * We make a list that alternates ConvTranspose2d and DoubleConv modules of the required dimensions
* * We create a new loop that consumes two entries of that list per iteration
* * * Apply ConvTranspose2d (upsample)
* * * Concatenate the result with the skip connection (from the reversed list)
* * * Max pooling followed by upsampling can lose a pixel: e.g. with input size 161, max pooling floors it to 80 and upsampling gives 160, which is 1 pixel less than the 161-wide skip connection, so concatenating skip + upsample is not possible. The paper crops the skip connection to match the upsampled output; here we resize the upsampled output to the skip connection's size before concatenation
* * * Pass the concatenated tensor through DoubleConv (2 conv layers)
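
The one-pixel mismatch for odd sizes can be reproduced in isolation. This is a minimal sketch with made-up channel counts; it uses `torch.nn.functional.interpolate` as a stand-in for the `TF.resize` call in the implementation, which is an assumption made here to avoid the torchvision dependency.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

skip = torch.randn(1, 8, 161, 161)  # feature map saved before pooling
pool = nn.MaxPool2d(kernel_size=2, stride=2)
up = nn.ConvTranspose2d(8, 8, kernel_size=2, stride=2)

pooled = pool(skip)     # 161 // 2 = 80
upsampled = up(pooled)  # 80 * 2 = 160, one pixel short of 161
print(pooled.shape[-1], upsampled.shape[-1])

# Resize the upsampled map to the skip connection's height and width, then concatenate.
if upsampled.shape != skip.shape:
    upsampled = F.interpolate(upsampled, size=skip.shape[2:])
merged = torch.cat((skip, upsampled), dim=1)
print(merged.shape)
```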

             

In [None]:
### make list of skip connections and reverse
 
skip_connections = [] 

for down in self.downs: # (shown separately just for illustration; in practice cache these while going down)
    x = down(x)     # down is basically 2 conv layers
    skip_connections.append(x)   # store the feature map before pooling
    x = self.pool(x)
 
skip_connections = skip_connections[::-1]


### list with ConvTranspose2d, DoubleConv of required dimension

self.ups = nn.ModuleList()

for feature in reversed(features):
    self.ups.append(
        nn.ConvTranspose2d(
            feature*2, feature, kernel_size = 2, stride = 2,
        )
    )
    self.ups.append(DoubleConv(feature*2, feature))
        
## Create a new loop that consumes two entries of the list above per iteration

for idx in range(0, len(self.ups), 2):
    x = self.ups[idx](x)    # upsample( transpose)

    skip_connection = skip_connections[idx//2] # 2 steps in self.ups, we only have one skip per stage

    if x.shape != skip_connection.shape:
        # TF is torchvision.transforms.functional
        x = TF.resize(x, size = skip_connection.shape[2:])


    concat_skip = torch.cat((skip_connection, x), dim = 1) # concat along the depth

    x = self.ups[idx+1](concat_skip) #2 layers of conv
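
The fragments above can be assembled into one module. The following is a sketch under stated assumptions rather than the reference implementation: the class name `UNet`, the `final_conv` attribute, and the small `features` tuple in the usage lines are illustrative choices, and `F.interpolate` replaces `TF.resize` so that only `torch` is needed.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)


class UNet(nn.Module):
    def __init__(self, in_channels=1, out_channels=1, features=(64, 128, 256, 512)):
        super().__init__()
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Down part: one DoubleConv per stage.
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)

        # Up part: alternate ConvTranspose2d and DoubleConv.
        for feature in reversed(features):
            self.ups.append(nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2))
            self.ups.append(DoubleConv(feature * 2, feature))

        # Final 1x1 conv maps to the number of output classes.
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skip_connections = []
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)  # cache before pooling
            x = self.pool(x)

        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]

        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)  # upsample
            skip = skip_connections[idx // 2]
            if x.shape != skip.shape:
                x = F.interpolate(x, size=skip.shape[2:])
            x = self.ups[idx + 1](torch.cat((skip, x), dim=1))

        return self.final_conv(x)


# Small feature sizes keep this check fast; 161 is not divisible by 16 on purpose.
model = UNet(in_channels=1, out_channels=1, features=(8, 16, 32, 64))
x = torch.randn(2, 1, 161, 161)
print(model(x).shape)  # torch.Size([2, 1, 161, 161])
```

Because this version uses 'same' convolutions and resizes before concatenation, the output has the same height and width as the input, unlike the 572 to 388 shrinkage in the paper.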

## Loss function and training 

@ dataset.py 

We convert the mask to binary {0,1} so that it can be used as the ground-truth target for our BCE loss later. 
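
A minimal sketch of that conversion, assuming the masks are stored as grayscale images whose foreground pixels are 255 (the array below is a made-up example, not data from the notebook):

```python
import numpy as np

# Hypothetical 2x3 mask as it might be loaded from a grayscale image file.
mask = np.array([[0, 255, 255],
                 [0,   0, 255]], dtype=np.float32)

mask[mask == 255.0] = 1.0  # foreground 255 -> 1, background stays 0
print(np.unique(mask))
```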

@ train.py 

We use `BCEWithLogitsLoss` (this passes our predictions through a sigmoid and then applies BCE against the target) 
#### Note
BCE acts elementwise on the 2D mask and the predicted image. By default it reduces the per-pixel losses by taking their mean (`reduction='mean'`).

Ref: 
Learnings from implementation found @ https://github.com/aladdinpersson/Machine-Learning-Collection/tree/master/ML/Pytorch/image_segmentation/semantic_segmentation_unet

## BCE nuances 

* BCE is used as a reconstruction / per-pixel loss (autoencoder, segmentation)
* BCE is calculated for each pixel according to the usual log loss formula. Then we reduce it by sum or mean (default)
* We check below that these two routes give the same value
* * BCEWithLogitsLoss (we don't apply a sigmoid activation to the final layer; the from-scratch implementation does the same)
* * Apply sigmoid, then take BCELoss
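
Written out, with $z_i$ the raw prediction (logit) at pixel $i$, $y_i \in \{0, 1\}$ the mask value, $\sigma$ the sigmoid, and $N$ the total number of elements (these symbols are introduced here for illustration), the per-pixel loss and its default mean reduction are:

$$
\ell_i = -\left[\, y_i \log \sigma(z_i) + (1 - y_i) \log\bigl(1 - \sigma(z_i)\bigr) \right],
\qquad
L = \frac{1}{N} \sum_{i=1}^{N} \ell_i
$$

With `reduction='sum'` the factor $\frac{1}{N}$ is dropped.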


In [None]:
import torch 

pred = torch.randn((3, 1, 160, 240)) # output of last layer without activation (raw logits)
target = torch.randint(low=0, high=2, size = (3, 1, 160, 240), dtype=torch.float32 ) # (0,1 mask)

In [None]:
criterion = torch.nn.BCEWithLogitsLoss()
criterion(pred, target)

# this gives output tensor(0.8079) (the exact value depends on the random draw)

In [None]:
m = torch.nn.Sigmoid()
loss = torch.nn.BCELoss()

loss(m(pred), target)

# m(pred) maps the logits into the range (0, 1), so this matches the BCEWithLogitsLoss value above
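
The per-pixel formula and the two reductions can also be checked by hand. This is a small sketch that regenerates `pred` and `target` with a fixed seed so that it runs on its own; the variable names `per_pixel`, `loss_mean` and `loss_sum` are illustrative.

```python
import torch

torch.manual_seed(0)
pred = torch.randn((3, 1, 160, 240))  # raw logits
target = torch.randint(low=0, high=2, size=(3, 1, 160, 240)).float()  # 0/1 mask

# Log loss computed per pixel from the sigmoid probabilities.
prob = torch.sigmoid(pred)
per_pixel = -(target * torch.log(prob) + (1 - target) * torch.log(1 - prob))

loss_mean = torch.nn.BCEWithLogitsLoss()(pred, target)                 # default: mean
loss_sum = torch.nn.BCEWithLogitsLoss(reduction="sum")(pred, target)   # sum over all pixels

print(torch.allclose(per_pixel.mean(), loss_mean, atol=1e-6))
print(torch.allclose(per_pixel.sum(), loss_sum, rtol=1e-4))
```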