<a href="https://colab.research.google.com/github/tristanoprofetto/neural-networks/blob/main/CNN/U-Net/U_Net.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import VOCSegmentation
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
torch.manual_seed(0)
import torch.nn.functional as F
from skimage import io

In [2]:
def show_image(images, n_images=16, size=(1, 28, 28)):
  """Plot the first ``n_images`` images of a batch as a grid with 4 images per row.

  Args:
    images: tensor whose elements can be reshaped to ``(-1, *size)``. It is detached
      from the autograd graph and moved to CPU before plotting.
    n_images: how many images from the front of the batch to draw. It has a default so
      callers may omit it; fewer are drawn if the batch is smaller.
    size: per-image shape ``(channels, height, width)``.
  """
  # reshape (unlike view) also accepts non-contiguous tensors, e.g. the slices returned by crop().
  unflat = images.detach().cpu().reshape(-1, *size)
  grid = make_grid(unflat[:n_images], nrow=4)
  plt.imshow(grid.permute(1, 2, 0).squeeze())
  plt.show()

In [5]:
# Function for cropping an image given a tensor and the desired shape
def crop(image, new_size):
  """Center-crop the last two (spatial) dimensions of a 4-D tensor.

  Args:
    image: tensor indexed as ``image[:, :, row, col]``; only dims 2 and 3 are cropped.
    new_size: sequence (e.g. a ``torch.Size``) whose entries ``[2]`` and ``[3]`` are the
      target height and width. Entries ``[0]`` and ``[1]`` are ignored.

  Returns:
    A slice of ``image`` of height ``new_size[2]`` and width ``new_size[3]``, whose window
    starts at ``image.shape[d] // 2 - new_size[d] // 2`` along each spatial dim ``d``.
  """
  target_h, target_w = new_size[2], new_size[3]

  top = image.shape[2] // 2 - target_h // 2
  left = image.shape[3] // 2 - target_w // 2

  return image[:, :, top:top + target_h, left:left + target_w]

In [4]:
class ContractingBlock(nn.Module):
  """U-Net encoder stage: two unpadded 3x3 conv + ReLU layers, then 2x2 max-pooling.

  The first conv doubles the channel count. Because the convs have no padding, each trims
  2 pixels from height and width. An input of shape (N, C, H, W) therefore becomes
  (N, 2C, (H - 4) // 2, (W - 4) // 2).

  Args:
    inputChannels: number of channels C of the incoming feature map.
  """

  def __init__(self, inputChannels):
    super(ContractingBlock, self).__init__()

    self.c1 = nn.Conv2d(inputChannels, 2*inputChannels, kernel_size=3)
    self.c2 = nn.Conv2d(2*inputChannels, 2*inputChannels, kernel_size=3)
    self.activation = nn.ReLU()
    self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

  def forward(self, x):
    """Apply conv -> ReLU -> conv -> ReLU -> max-pool to ``x``."""
    x = self.activation(self.c1(x))
    x = self.activation(self.c2(x))
    return self.pool(x)

  # nn.Module.__call__ dispatches to ``forward``. This alias keeps explicit
  # ``block.feedForward(x)`` calls working.
  feedForward = forward

In [7]:
class ExpandingBlock(nn.Module):
  """U-Net decoder stage that upsamples and fuses a skip connection.

  Steps:
    1. Bilinear 2x upsampling (``align_corners=True``).
    2. A 2x2 conv that halves the channel count (no activation follows it).
    3. Concatenation, along the channel dim, with the skip connection center-cropped to
       the same spatial size.
    4. Two unpadded 3x3 conv + ReLU layers, each producing ``inputChannels // 2`` channels.

  An input of spatial size (H, W) yields (2H - 5, 2W - 5).

  Args:
    inputChannels: channel count of ``x``. ``c2`` expects ``inputChannels`` channels after
      the concat, so the skip connection must carry ``inputChannels // 2`` channels
      (presumably ``inputChannels`` is even).
  """

  def __init__(self, inputChannels):
    super(ExpandingBlock, self).__init__()

    self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    self.c1 = nn.Conv2d(inputChannels, inputChannels // 2, kernel_size=2, stride=1)
    self.c2 = nn.Conv2d(inputChannels, inputChannels // 2, kernel_size=3, stride=1)
    self.c3 = nn.Conv2d(inputChannels // 2, inputChannels // 2, kernel_size=3, stride=1)

    self.activation = nn.ReLU()

  def forward(self, x, skip_connection):
    """Upsample ``x``, merge it with ``skip_connection`` and refine with two convs."""
    x = self.upsample(x)
    x = self.c1(x)

    # The encoder feature map is larger than x (unpadded convs), so take its central window.
    skip_connection = crop(skip_connection, x.shape)
    x = torch.cat([x, skip_connection], dim=1)

    x = self.activation(self.c2(x))
    x = self.activation(self.c3(x))
    return x

  # nn.Module.__call__ dispatches to ``forward``. This alias keeps explicit
  # ``block.feedForward(x, skip)`` calls working.
  feedForward = forward

In [8]:
class FeatureMapBlock(nn.Module):
  """1x1 convolution that remaps the channel count and leaves spatial size unchanged.

  Args:
    inputChannels: channels of the incoming tensor.
    outputChannels: channels of the returned tensor.
  """

  def __init__(self, inputChannels, outputChannels):
    super(FeatureMapBlock, self).__init__()

    self.conv = nn.Conv2d(inputChannels, outputChannels, kernel_size=1)

  def forward(self, x):
    """Apply the 1x1 convolution to ``x``."""
    return self.conv(x)

  # nn.Module.__call__ dispatches to ``forward``. This alias keeps explicit
  # ``block.feedForward(x)`` calls working.
  feedForward = forward

In [10]:
class UNet(nn.Module):
  """U-Net with four contracting and four expanding stages.

  The encoder lifts the input to ``hiddenChannels`` with a 1x1 conv. Four
  ``ContractingBlock``s then double the channels at each stage (up to ``16 * hiddenChannels``).
  Four ``ExpandingBlock``s halve the channels back, each fusing the encoder output of
  matching depth. A final 1x1 conv maps to ``outputChannels``.

  Every conv is unpadded, so the output is spatially smaller than the input. For example,
  a 512x512 input yields a 373x373 output. The output is raw logits: no sigmoid is applied.

  Args:
    inputChannels: channels of the input image.
    outputChannels: channels of the predicted map.
    hiddenChannels: channel width after the first 1x1 conv.
  """

  def __init__(self, inputChannels, outputChannels, hiddenChannels=64):
    super(UNet, self).__init__()

    self.upfeature = FeatureMapBlock(inputChannels, hiddenChannels)

    self.c1 = ContractingBlock(hiddenChannels)
    self.c2 = ContractingBlock(2*hiddenChannels)
    self.c3 = ContractingBlock(4*hiddenChannels)
    self.c4 = ContractingBlock(8*hiddenChannels)

    self.e1 = ExpandingBlock(16*hiddenChannels)
    self.e2 = ExpandingBlock(8*hiddenChannels)
    self.e3 = ExpandingBlock(4*hiddenChannels)
    self.e4 = ExpandingBlock(2*hiddenChannels)

    self.downfeature = FeatureMapBlock(hiddenChannels, outputChannels)

  def forward(self, x):
    """Run the encoder/decoder on ``x`` and return per-pixel logits."""
    x0 = self.upfeature(x)
    x1 = self.c1(x0)
    x2 = self.c2(x1)
    x3 = self.c3(x2)
    x4 = self.c4(x3)

    # Each decoder stage fuses the encoder map one level shallower than its input.
    # e1 must receive x3 (8h channels) so its concat matches c2's 16h input channels.
    x5 = self.e1(x4, x3)
    x6 = self.e2(x5, x2)
    x7 = self.e3(x6, x1)
    x8 = self.e4(x7, x0)

    return self.downfeature(x8)

  # nn.Module.__call__ dispatches to ``forward``. This alias keeps explicit
  # ``model.feedForward(x)`` calls working.
  feedForward = forward

### Training 

* criterion: the loss function
* epochs: the number of full passes over the training dataset
* input_dim: the number of channels of the input image
* label_dim: the number of channels of the output image
* display_step: how often to display/visualize the images
* batch_size: the number of images per forward/backward pass
* lr: the learning rate
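* initial_shape: the side length of the input image (in pixels)
* target_shape: the side length of the output image (in pixels); the labels are center-cropped to this size so they align with the model's predictions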
* device: gpu or cpu

In [14]:
# Per-pixel binary loss on raw logits; UNet.forward ends in a 1x1 conv with no sigmoid.
criterion = nn.BCEWithLogitsLoss()
epochs = 200          # full passes over the dataset
input_dim = 1         # channels of the input image
label_dim = 1         # channels of the target mask
display_step = 20     # print/plot every `display_step` optimizer steps
batch_size = 4
lr = 0.0002
initial_shape = 512   # side length of the input images; NOTE(review): not read by any visible cell
target_shape = 373    # side length of the UNet output for a 512x512 input
# Fall back to CPU so the notebook still runs on machines without a CUDA GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
def _load_stack(path, **imread_kwargs):
  """Read ``path`` with ``io.imread``, convert it to a float tensor, insert a size-1 axis at
  position 1 and scale by 1/255."""
  stack = torch.Tensor(io.imread(path, **imread_kwargs))
  return stack[:, None, :, :] / 255


# NOTE(review): '/train' has no file extension and is read without the tifffile plugin used
# for the labels below. Confirm this is the intended volume file.
volumes = _load_stack('/train')
labels = _load_stack('/labels.tif', plugin="tifffile")

# Center-crop the labels to the UNet output size so they line up with the predictions.
labels = crop(labels, torch.Size([len(labels), 1, target_shape, target_shape]))
dataset = torch.utils.data.TensorDataset(volumes, labels)

In [None]:
def train():
  """Train a freshly initialised UNet on the global ``dataset`` with Adam and ``criterion``.

  Reads the notebook-level config: ``epochs``, ``batch_size``, ``lr``, ``device``,
  ``display_step``, ``input_dim``, ``label_dim`` and ``target_shape``.

  Every ``display_step`` optimizer steps (including step 0), it prints the loss. It also plots
  three things: the inputs center-cropped to the output size, the target masks, and the
  sigmoid of the predictions.

  Returns:
    The trained ``UNet``. Callers that ignore the result are unaffected.
  """
  data = DataLoader(dataset, batch_size, shuffle=True)

  model = UNet(input_dim, label_dim).to(device)
  optimizer = torch.optim.Adam(model.parameters(), lr=lr)
  current_step = 0

  for epoch in range(epochs):
    for real, masks in tqdm(data):
      current_batch = len(real)

      real = real.to(device)
      masks = masks.to(device)

      optimizer.zero_grad()
      predictions = model(real)
      loss = criterion(predictions, masks)
      loss.backward()
      optimizer.step()

      if current_step % display_step == 0:
        print(f"Epoch {epoch}: Step {current_step}: Model Loss: {loss.item()}")

        # The network output is smaller than its input. Crop the inputs to the same central
        # window the (pre-cropped) masks cover so the three plots line up.
        cropped_real = crop(real, torch.Size([current_batch, input_dim, target_shape, target_shape]))
        show_image(cropped_real, current_batch, size=(input_dim, target_shape, target_shape))
        show_image(masks, current_batch, size=(label_dim, target_shape, target_shape))
        show_image(torch.sigmoid(predictions), current_batch, size=(label_dim, target_shape, target_shape))

      current_step += 1

  return model

trained_model = train()