# Analyzing the U-Net network

This notebook walks through the U-Net layer by layer. U-Net is a convolutional network architecture for fast and precise image segmentation.

The notebook is organized as follows:

- importing the libraries
- defining the U-Net
- demonstrating the U-Net layer by layer

## Importing the libraries

```python
# Image and file libraries
from PIL import Image
from pandas import read_csv

# Standard library
import os.path

# PyTorch
import torch
from torch import nn
from torch.autograd import Variable  # deprecated since PyTorch 0.4, still importable
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

# torchvision
import torchvision
from torchvision import transforms
```

## The U-Net network

<img src="../figures/u-net-architecture.png" width="900" alt="U-Net architecture">

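U-Net pairs a contracting path (encoder), which captures context, with an expanding path (decoder), which recovers spatial resolution. Skip connections concatenate encoder features into the decoder so that the segmentation can be localized precisely. More details are available at https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/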

### The U-Net class

<img src="../figures/u-net.png" width="900" alt="U-Net diagram">

```python
import torch.nn.functional as F


class ConvBlock(torch.nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU."""

    def __init__(self, in_size, out_size, kernel_size=3, padding=1, stride=1):
        super(ConvBlock, self).__init__()
        self.conv = torch.nn.Conv2d(in_size, out_size, kernel_size,
                                    padding=padding, stride=stride)
        self.bn = torch.nn.BatchNorm2d(out_size)
        self.relu = torch.nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))


# UNet class
class UNet(torch.nn.Module):
    def __init__(self):
        super(UNet, self).__init__()

        # Contracting path; the stride-2 block halves the spatial resolution
        self.down_1 = torch.nn.Sequential(
            ConvBlock(1, 16),
            ConvBlock(16, 32, stride=2, padding=1))

        self.down_2 = torch.nn.Sequential(
            ConvBlock(32, 64),
            ConvBlock(64, 128))

        # Bottleneck: a 1x1 convolution mixes channels without changing H and W
        self.middle = ConvBlock(128, 128, kernel_size=1, padding=0)

        # Expanding path; input channels include the concatenated skip features
        self.up_2 = torch.nn.Sequential(
            ConvBlock(256, 128),  # 128 (down2) + 128 (upsampled)
            ConvBlock(128, 32))

        self.up_1 = torch.nn.Sequential(
            ConvBlock(64, 64),    # 32 (down1) + 32 (upsampled)
            ConvBlock(64, 32))

        self.output = torch.nn.Sequential(
            ConvBlock(32, 16),
            ConvBlock(16, 1, kernel_size=1, padding=0))

    def forward(self, x):
        down1 = self.down_1(x)
        out = F.max_pool2d(down1, kernel_size=2, stride=2)

        down2 = self.down_2(out)
        out = F.max_pool2d(down2, kernel_size=2, stride=2)

        out = self.middle(out)

        # F.upsample is deprecated; F.interpolate is the current equivalent
        out = F.interpolate(out, scale_factor=2, mode='nearest')
        out = torch.cat([down2, out], 1)  # skip connection along the channel axis
        out = self.up_2(out)

        out = F.interpolate(out, scale_factor=2, mode='nearest')
        out = torch.cat([down1, out], 1)  # skip connection along the channel axis
        out = self.up_1(out)

        # Restores the resolution lost to the stride-2 convolution in down_1
        out = F.interpolate(out, scale_factor=2, mode='nearest')
        return self.output(out)
```
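As a sanity check on the class above, the number of trainable parameters can be computed by hand. The sketch below is pure Python and assumes the PyTorch defaults used by `ConvBlock`: `Conv2d` has a bias term (`bias=True`), and `BatchNorm2d` has one learnable weight and one bias per channel (`affine=True`). Running statistics are buffers, not parameters, so they are not counted. The helper `conv_block_params` and the list `UNET_BLOCKS` are hypothetical names used only for this illustration.

```python
def conv_block_params(in_ch, out_ch, kernel_size=3):
    """Trainable parameters of one ConvBlock (hypothetical helper).

    Conv2d: in_ch * out_ch * k * k weights + out_ch biases.
    BatchNorm2d: out_ch weights (gamma) + out_ch biases (beta).
    ReLU has no parameters.
    """
    conv = in_ch * out_ch * kernel_size * kernel_size + out_ch
    bn = 2 * out_ch
    return conv + bn


# (in_channels, out_channels, kernel_size) of every ConvBlock in UNet, in order
UNET_BLOCKS = [
    (1, 16, 3), (16, 32, 3),      # down_1
    (32, 64, 3), (64, 128, 3),    # down_2
    (128, 128, 1),                # middle
    (256, 128, 3), (128, 32, 3),  # up_2
    (64, 64, 3), (64, 32, 3),     # up_1
    (32, 16, 3), (16, 1, 1),      # output
]

total = sum(conv_block_params(i, o, k) for i, o, k in UNET_BLOCKS)
print(f"trainable parameters: {total:,}")
```

With PyTorch available, the figure can be compared against `sum(p.numel() for p in UNet().parameters())`.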

## Layer-by-layer demonstration of the U-Net

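### Defining the input image

```python
# One single-channel (grayscale) 184x184 image filled with random values.
# torch.FloatTensor(1, 1, 184, 184) would allocate uninitialized memory instead.
image = torch.randn(1, 1, 184, 184)
image.shape
```

```
torch.Size([1, 1, 184, 184])
```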
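The four dimensions follow PyTorch's NCHW convention: batch size (N), channels (C), height (H) and width (W). The tensor is therefore a batch holding a single grayscale image of 184 × 184 pixels. In a contiguous tensor, element (n, c, h, w) sits at a flat offset determined by the strides. The pure-Python sketch below reproduces that computation. The function names are illustrative and not part of PyTorch, though for a contiguous tensor the result should agree with `image.stride()`.

```python
def contiguous_strides(shape):
    """Row-major (C-contiguous) strides, in elements, for a given shape."""
    strides = []
    step = 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


def flat_offset(index, shape):
    """Flat position of a multi-dimensional index in contiguous storage."""
    return sum(i * s for i, s in zip(index, contiguous_strides(shape)))


shape = (1, 1, 184, 184)  # N, C, H, W
print(contiguous_strides(shape))        # (33856, 33856, 184, 1)
print(flat_offset((0, 0, 1, 0), shape))  # 184: one row further down
```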

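### Wrapping the image in a Variable

```python
# Since PyTorch 0.4, Variable has been merged into Tensor: Variable(image)
# returns an ordinary tensor, so `input_image = image` works the same way.
input_image = Variable(image)
```

```python
input_image.shape
```

```
torch.Size([1, 1, 184, 184])
```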

### Defining the first block of the contracting path (down1)

```python
down1 = torch.nn.Sequential(ConvBlock(1, 16), ConvBlock(16, 32, stride=2, padding=1))
```

### Passing the image through the first block of the contracting path (down1)

```python
out_down1 = down1(input_image)
out_down1.shape
```

```
torch.Size([1, 32, 92, 92])
```
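The spatial size drops from 184 to 92 because the second `ConvBlock` in `down1` uses `stride=2`. With the default `dilation=1`, `Conv2d` computes each spatial dimension as out = ⌊(in + 2·padding − kernel_size) / stride⌋ + 1. The pure-Python helper below applies this formula to both blocks of `down1`. `conv2d_out` is a hypothetical name, not a PyTorch function.

```python
def conv2d_out(size, kernel_size=3, padding=1, stride=1, dilation=1):
    """Output length of one spatial dimension of nn.Conv2d."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


h = 184
h = conv2d_out(h)                       # ConvBlock(1, 16): 3x3, padding 1, stride 1
print(h)                                # 184: size preserved
h = conv2d_out(h, padding=1, stride=2)  # ConvBlock(16, 32, stride=2, padding=1)
print(h)                                # 92: size halved
```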

### Defining the second block of the contracting path (down2)

```python
down2 = torch.nn.Sequential(ConvBlock(32, 64), ConvBlock(64, 128))
```

### Applying max pooling and the second block of the contracting path (down2)

```python
out_down2 = down2(F.max_pool2d(out_down1, kernel_size=2, stride=2))
out_down2.shape
```

```
torch.Size([1, 128, 46, 46])
```
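Before `down2`, `F.max_pool2d(kernel_size=2, stride=2)` keeps the maximum of every non-overlapping 2 × 2 window, which halves each spatial dimension (92 → 46). The two 3 × 3 convolutions with `padding=1` then preserve that size while raising the channel count to 128. The pure-Python sketch below implements the pooling step on a single 2D channel stored as a list of lists. An odd last row or column is dropped, as with PyTorch's default `ceil_mode=False`. This illustrates the operation only; it is not how PyTorch implements it.

```python
def max_pool2d_2x2(x):
    """2x2 max pooling with stride 2 on a 2D list (floor mode)."""
    h, w = len(x) // 2, len(x[0]) // 2
    return [
        [max(x[2 * i][2 * j], x[2 * i][2 * j + 1],
             x[2 * i + 1][2 * j], x[2 * i + 1][2 * j + 1])
         for j in range(w)]
        for i in range(h)
    ]


feature_map = [
    [1, 3, 2, 0],
    [4, 2, 1, 5],
    [0, 1, 7, 2],
    [6, 2, 3, 1],
]
print(max_pool2d_2x2(feature_map))  # [[4, 5], [6, 7]]
```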

### Defining the central block (middle)

```python
middle = ConvBlock(128, 128, kernel_size=1, padding=0)
```

### Applying max pooling and the central block (middle)

```python
out_middle = middle(F.max_pool2d(out_down2, kernel_size=2, stride=2))
out_middle.shape
```

```
torch.Size([1, 128, 23, 23])
```
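The second max pooling halves the map again (46 → 23). The central block then applies a 1 × 1 convolution (`kernel_size=1, padding=0`), which leaves the spatial size unchanged and only mixes channels. At every pixel, the output channel vector is a weight matrix times the input channel vector, plus a bias. The pure-Python sketch below shows this per-pixel computation for a toy case with 2 input and 2 output channels. The weights are made up for illustration. The batch normalization and ReLU that `ConvBlock` applies afterwards are omitted.

```python
def conv1x1(x, weight, bias):
    """1x1 convolution on a single image.

    x:      input as [channel][row][col]
    weight: [out_channel][in_channel]
    bias:   [out_channel]
    Returns [out_channel][row][col].
    """
    in_ch, h, w = len(x), len(x[0]), len(x[0][0])
    return [
        [
            [bias[o] + sum(weight[o][c] * x[c][i][j] for c in range(in_ch))
             for j in range(w)]
            for i in range(h)
        ]
        for o in range(len(weight))
    ]


# Toy input: 2 channels, 2x2 pixels (illustrative values only)
x = [
    [[1, 2], [3, 4]],  # channel 0
    [[5, 6], [7, 8]],  # channel 1
]
weight = [[1, 0],    # output channel 0 copies input channel 0
          [1, -1]]   # output channel 1 = channel 0 - channel 1
bias = [0, 10]
print(conv1x1(x, weight, bias))
# [[[1, 2], [3, 4]], [[6, 6], [6, 6]]]
```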

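### Applying the second upsampling step of the expanding path (upsample2)

```python
# F.upsample is deprecated; F.interpolate is the current equivalent
out_upsample2 = F.interpolate(out_middle, scale_factor=2, mode='nearest')
out_upsample2.shape
```

```
torch.Size([1, 128, 46, 46])
```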
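The upsampling call used here (`F.upsample`, or its current replacement `F.interpolate`) defaults to `mode='nearest'`. Each pixel is repeated into a 2 × 2 block, which doubles height and width (23 → 46). The channel count stays the same, and the step has no learnable parameters. Below is a pure-Python sketch of nearest-neighbour upsampling on one channel with an integer scale factor.

```python
def upsample_nearest(x, scale=2):
    """Nearest-neighbour upsampling of a 2D list by an integer factor."""
    return [
        [x[i // scale][j // scale] for j in range(len(x[0]) * scale)]
        for i in range(len(x) * scale)
    ]


print(upsample_nearest([[1, 2],
                        [3, 4]]))
# [[1, 1, 2, 2], [1, 1, 2, 2], [3, 3, 4, 4], [3, 3, 4, 4]]
```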

### Defining the second block of the expanding path (up2)

```python
up2 = torch.nn.Sequential(ConvBlock(256, 128), ConvBlock(128, 32))
```

### Concatenating the output of down2 with upsample2

```python
concat2 = torch.cat([out_down2, out_upsample2], 1)
concat2.shape
```

```
torch.Size([1, 256, 46, 46])
```
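`torch.cat(..., 1)` joins the tensors along dimension 1, the channel axis. Every other dimension (batch, height, width) must match, and the channel counts add up: 128 + 128 = 256. This concatenation is the skip connection that lets the expanding path reuse the high-resolution features of the contracting path, and it is why `up2` starts with `ConvBlock(256, 128)`. The hypothetical pure-Python helper below computes the resulting shape and rejects mismatched inputs, following the rule `torch.cat` enforces.

```python
def cat_shape(shapes, dim=1):
    """Shape produced by concatenating tensors of the given shapes along `dim`."""
    first = shapes[0]
    for s in shapes[1:]:
        if len(s) != len(first):
            raise ValueError(f"rank mismatch: {s} vs {first}")
        for d, (a, b) in enumerate(zip(first, s)):
            if d != dim and a != b:
                raise ValueError(f"size mismatch in dim {d}: {a} vs {b}")
    out = list(first)
    out[dim] = sum(s[dim] for s in shapes)
    return tuple(out)


print(cat_shape([(1, 128, 46, 46), (1, 128, 46, 46)]))  # (1, 256, 46, 46)
print(cat_shape([(1, 32, 92, 92), (1, 32, 92, 92)]))    # (1, 64, 92, 92)
```

The second call corresponds to the later concatenation of `down1` with `upsample1`.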

### Applying the second block of the expanding path (up2)

```python
out_up2 = up2(concat2)
out_up2.shape
```

```
torch.Size([1, 32, 46, 46])
```

### Applying the first upsampling step of the expanding path (upsample1)

```python
out_upsample1 = F.interpolate(out_up2, scale_factor=2, mode='nearest')
out_upsample1.shape
```

```
torch.Size([1, 32, 92, 92])
```

### Defining the first block of the expanding path (up1)

```python
up1 = torch.nn.Sequential(ConvBlock(64, 64), ConvBlock(64, 32))
```

### Concatenating the outputs of down1 and upsample1

```python
concat1 = torch.cat([out_down1, out_upsample1], 1)
concat1.shape
```

```
torch.Size([1, 64, 92, 92])
```

### Applying the first block of the expanding path (up1)

```python
out_up1 = up1(concat1)
out_up1.shape
```

```
torch.Size([1, 32, 92, 92])
```

### Applying the last upsampling step of the expanding path (upsample0)

```python
out_upsample0 = F.interpolate(out_up1, scale_factor=2, mode='nearest')
out_upsample0.shape
```

```
torch.Size([1, 32, 184, 184])
```

### Defining the output block (output)

```python
output = torch.nn.Sequential(ConvBlock(32, 16), ConvBlock(16, 1, kernel_size=1, padding=0))
```

### Applying the output block (output)

```python
output(out_upsample0).shape
```

```
torch.Size([1, 1, 184, 184])
```
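Putting the steps together, the output matches the input's spatial size only because the final upsampling compensates for the stride-2 convolution in `down_1`. The bottleneck is therefore 8 times smaller than the input (184 / 8 = 23). The forward pass avoids a concatenation size mismatch and returns an output of the input's size only when height and width are multiples of 8. The pure-Python sketch below traces the `(channels, height, width)` shape through `UNet.forward` for a square input. It assumes the layer settings defined above, with channel counts hard-coded from the class. The function names are hypothetical, and the sketch is an illustrative check, not a substitute for running the model.

```python
def conv_out(size, kernel_size=3, padding=1, stride=1):
    """Spatial output size of a Conv2d (dilation 1)."""
    return (size + 2 * padding - kernel_size) // stride + 1


def trace_unet(size):
    """Trace (channels, height, width) through UNet.forward for a square input.

    Raises ValueError if a skip concatenation would get mismatched sizes.
    """
    stages = {}
    d1 = conv_out(conv_out(size), stride=2)                 # down_1
    stages["down1"] = (32, d1, d1)
    d2 = conv_out(conv_out(d1 // 2))                        # max_pool + down_2
    stages["down2"] = (128, d2, d2)
    mid = conv_out(d2 // 2, kernel_size=1, padding=0)       # max_pool + middle
    stages["middle"] = (128, mid, mid)
    if 2 * mid != d2:                                       # upsample + cat with down2
        raise ValueError(f"concat2 mismatch: {d2} vs {2 * mid}")
    u2 = conv_out(conv_out(2 * mid))                        # up_2
    stages["up2"] = (32, u2, u2)
    if 2 * u2 != d1:                                        # upsample + cat with down1
        raise ValueError(f"concat1 mismatch: {d1} vs {2 * u2}")
    u1 = conv_out(conv_out(2 * u2))                         # up_1
    stages["up1"] = (32, u1, u1)
    out = conv_out(conv_out(2 * u1), kernel_size=1, padding=0)  # upsample + output
    stages["output"] = (1, out, out)
    return stages


for name, shape in trace_unet(184).items():
    print(name, shape)
```

For 184, the traced shapes match the notebook outputs above. `trace_unet(180)` raises a `ValueError` at `concat2` (45 vs 44). `trace_unet(183)` runs but reports a 184 × 184 output.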