# Import data and packages

In [None]:
# Colab-only setup: mount Google Drive and copy the project's helper module and dataset
# into the current working directory.
from google.colab import drive
drive.mount('/content/drive')
# helpers.py is star-imported in the next cell; presumably it defines segment_dataset
# (used by train() below) — confirm against the file.
!cp "/content/drive/MyDrive/Colab Notebooks/helpers.py" .
# -a: recursive copy preserving attributes, -v: list each copied file.
# The config cell later globs /content/at_dataset/..., which assumes the working
# directory here is /content (the Colab default) — TODO confirm.
!cp -av "/content/drive/MyDrive/Colab Notebooks/at_dataset" .

In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import optim
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
import glob
import cv2 as cv
from helpers import *
%load_ext autoreload
%autoreload 2


In [None]:
#Accuracy 
def accuracy(pred, test_labels):
    '''
    Fraction of pixels whose predicted class equals the ground-truth label.

    pred: torch.tensor of per-class scores with the class axis at dim 1
        (result of U-net), e.g. [batch, n_classes, dim_image1, dim_image2].
    test_labels: torch.tensor of class labels (any numeric dtype), shaped either
        [batch, dim_image1, dim_image2] or [batch, 1, dim_image1, dim_image2];
        a singleton channel axis is squeezed away.

    Returns a 0-d float tensor in [0, 1].
    '''
    label_pred = torch.argmax(pred, dim=1)
    # Drop a singleton channel axis so [B, 1, H, W] labels line up with the [B, H, W]
    # predictions instead of broadcasting to [B, B, H, W] when B > 1.
    if test_labels.dim() == label_pred.dim() + 1 and test_labels.size(1) == 1:
        test_labels = test_labels.squeeze(1)
    # Count exact matches. The previous 1 - mean(|pred - label|) was only right for
    # binary labels, and mean() raised on integer label tensors.
    return (label_pred == test_labels).float().mean()

In [None]:
class DoubleConv(nn.Module):
    """Two stacked (Conv2d 3x3 -> BatchNorm2d -> ReLU) stages.

    The convolutions use no padding, so each one shrinks H and W by 2.
    mid_channels sets the width between the two stages; any falsy value (None, 0)
    means "use out_channels".
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels or out_channels
        stages = []
        for c_in, c_out in ((in_channels, hidden), (hidden, out_channels)):
            stages += [
                nn.Conv2d(c_in, c_out, kernel_size=3),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
        # Attribute name and layer order define the state_dict keys (double_conv.0 ... .5).
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        features = self.double_conv(x)
        return features

class Down(nn.Module):
    """Encoder stage: 2x2 max-pool (halves H and W) followed by a DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        # Kept as one Sequential named maxpool_conv so state_dict keys are unchanged.
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        downsampled = self.maxpool_conv(x)
        return downsampled


class Up(nn.Module):
    """Decoder stage: upsample, conv, crop-and-concatenate the skip map, then DoubleConv.

    bilinear: accepted for interface compatibility but currently ignored; upsampling
    is always bilinear (nn.Upsample, scale 2, align_corners=True).
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        # kernel_size=2 with dilation=2 spans a 3x3 window; padding=1 keeps H and W unchanged.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=2, padding=1, dilation=2)
        # Input is the concatenation of the skip map and the upsampled map.
        self.double_conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """
        x1: feature map from the deeper decoder/encoder stage.
        x2: encoder skip map; must be spatially at least as large as the upsampled x1.
        Raises ValueError if x2 is smaller than the upsampled x1.
        """
        # up-conv input
        x1 = self.conv(self.up(x1))

        # centre-crop the skip map to x1's size and concatenate along channels
        h1, w1 = x1.size(2), x1.size(3)
        h2, w2 = x2.size(2), x2.size(3)
        if h2 < h1 or w2 < w1:
            raise ValueError(
                f"skip map {h2}x{w2} is smaller than upsampled map {h1}x{w1}")
        top = (h2 - h1) // 2
        left = (w2 - w1) // 2
        # Explicit end indices: the old x2[..., d:-d] slice was empty when d == 0 and
        # one pixel too large when the size difference was odd.
        x2 = x2[:, :, top:top + h1, left:left + w1]
        x = torch.cat([x2, x1], dim=1)
        return self.double_conv(x)


class OutConv(nn.Module):
    """Final 1x1 convolution: maps each pixel's feature vector to per-class scores."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 1)

    def forward(self, x):
        scores = self.conv(x)
        return scores

In [None]:
class U_net(nn.Module):
    """U-Net: four pooling encoder stages (64 -> 1024 channels), four decoder stages
    that consume the matching encoder maps, and a 1x1 classifier.

    The convolutions are unpadded, so the output is spatially smaller than the input
    (train() pairs 572x572 inputs with 388x388 outputs).

    n_channels: input channels.
    n_classes: output channels (one logit map per class).
    bilinear: stored and forwarded to every Up block.
    """

    def __init__(self, n_channels=1, n_classes=2, bilinear=True):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        # Registration order is kept: it fixes parameters()/state_dict order and the
        # order in which weights are randomly initialised.
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024)
        self.up1 = Up(1024, 512, bilinear)
        self.up2 = Up(512, 256, bilinear)
        self.up3 = Up(256, 128, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        """Return raw logits with n_classes channels (no softmax/sigmoid applied)."""
        skips = []
        h = self.inc(x)
        for down in (self.down1, self.down2, self.down3, self.down4):
            skips.append(h)
            h = down(h)
        # Decoder pairs with the encoder outputs deepest-first.
        for up, skip in zip((self.up1, self.up2, self.up3, self.up4), reversed(skips)):
            h = up(h, skip)
        return self.outc(h)

In [None]:
#Training (adapting function from lab10)
def _split_dataset(pathX, pathY, data_size, split):
    """Pair image/label files by sorted name and draw a disjoint random train/test split.

    Returns (img_names_train, labels_names_train, img_names_test, labels_names_test)
    as numpy object arrays of file paths.
    Raises ValueError if no image matches pathX or the image and label counts differ.
    """
    imgs_names = np.array(sorted(glob.glob(pathX)), dtype=object)
    labels_names = np.array(sorted(glob.glob(pathY)), dtype=object)
    if len(imgs_names) == 0:
        raise ValueError(f"no images match {pathX!r}")
    # Images and labels are paired purely by sorted position, so the counts must agree.
    if len(imgs_names) != len(labels_names):
        raise ValueError(
            f"{len(imgs_names)} images but {len(labels_names)} labels; "
            "they are paired by sorted filename and must match one-to-one")
    n_samples = min(data_size, len(imgs_names))
    # replace=False: sampling with replacement could put the same image in both the
    # training and the test set (and repeat images), inflating the test accuracy.
    idx = np.random.choice(len(imgs_names), n_samples, replace=False)
    n_train = int(split * n_samples)
    idx_train, idx_test = idx[:n_train], idx[n_train:]
    return (imgs_names[idx_train], labels_names[idx_train],
            imgs_names[idx_test], labels_names[idx_test])


def _load_segments(img_path, lbl_path, in_size, out_size):
    """Read one image/label pair and tile it with segment_dataset (from helpers).

    The image is divided by 2**16 - 1, which assumes 16-bit input — TODO confirm for
    the whole dataset. Returns segment_dataset's (X, y) result.
    Raises FileNotFoundError if OpenCV cannot read either file.
    """
    raw_img = cv.imread(img_path, cv.IMREAD_UNCHANGED)
    raw_lbl = cv.imread(lbl_path, cv.IMREAD_UNCHANGED)
    # cv.imread signals failure by returning None instead of raising.
    if raw_img is None or raw_lbl is None:
        raise FileNotFoundError(f"could not read {img_path!r} / {lbl_path!r}")
    img = [raw_img / (2**16 - 1)]
    lbl = [raw_lbl]
    return segment_dataset(img, lbl, in_size, out_size, extend=True, augment=False)


def _plot_prediction(target, pred, out_size):
    """Show the ground-truth mask next to the predicted class map (out_size x out_size)."""
    fig, (ax_true, ax_pred) = plt.subplots(1, 2, figsize=(10, 4))
    ax_true.imshow(target)
    ax_true.axis('off')
    ax_true.set_title('Original')
    ax_pred.imshow(pred.view(out_size, out_size).numpy())
    ax_pred.axis('off')
    ax_pred.set_title('Prediction')
    plt.show()
    # Close explicitly so figures created in a loop do not accumulate in memory.
    plt.close(fig)


def train(model, criterion, pathX, pathY, optimizer, num_epochs, data_size, split, device,
          in_size=572, out_size=388):
    """Train model on a random subset of image/label pairs, then evaluate and plot the rest.

    @param model: torch.nn.Module; its output is compared with a 2-channel target of
        shape (1, 2, out_size, out_size) built below
    @param criterion: loss taking (logits, 2-channel 0/1 target), e.g. nn.BCEWithLogitsLoss
    @param pathX: glob pattern of input images
    @param pathY: glob pattern of label images; sorted order must pair 1:1 with pathX
    @param optimizer: torch.optim.Optimizer over model's parameters
    @param num_epochs: int, passes over the training images
    @param data_size: int, number of pairs to sample (capped at the number available)
    @param split: float, fraction of the sampled pairs used for training
    @param device: torch.device the tensors are moved to
    @param in_size: side length of each input tile (default 572)
    @param out_size: side length of each output/label tile (default 388)
    @return: list of per-segment test accuracies (floats)
    """
    (img_names_train, labels_names_train,
     img_names_test, labels_names_test) = _split_dataset(pathX, pathY, data_size, split)

    print("Starting training")
    for epoch in range(num_epochs):
        model.train()
        for i, (img_path, lbl_path) in enumerate(zip(img_names_train, labels_names_train)):
            print(f'epoch: {epoch+1}/{num_epochs}, img: {i+1}/{len(img_names_train)}')
            X, y = _load_segments(img_path, lbl_path, in_size, out_size)
            # Two-channel target: channel 1 is the label, channel 0 is XOR-flipped, i.e.
            # its complement assuming y holds 0/1 values (TODO confirm in helpers).
            Y = np.repeat(y[:, np.newaxis, :, :], 2, axis=1).astype(np.uint8)
            Y[:, 0, :, :] ^= True

            for j in range(X.shape[0]):
                tensor_X = torch.Tensor(X[j]).view(1, 1, in_size, in_size).to(device)
                tensor_Y = torch.Tensor(Y[j]).view(1, 2, out_size, out_size).to(device)

                prediction = model(tensor_X)
                loss = criterion(prediction, tensor_Y)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

    model.eval()
    accuracies_test = []
    # Evaluation needs no gradients; without no_grad each forward pass would build an
    # autograd graph and hold its activations in (GPU) memory.
    with torch.no_grad():
        for img_path, lbl_path in zip(img_names_test, labels_names_test):
            X, y = _load_segments(img_path, lbl_path, in_size, out_size)
            for j in range(X.shape[0]):
                tensor_X = torch.Tensor(X[j]).view(1, 1, in_size, in_size).to(device)
                tensor_y = torch.Tensor(y[j].astype(np.uint8)).view(1, 1, out_size, out_size).to(device)

                prediction = model(tensor_X)
                acc = accuracy(prediction, tensor_y).item()
                accuracies_test.append(acc)
                print("Test accuracy: {:.5f}".format(acc))

                pred = torch.argmax(prediction.to("cpu"), dim=1)
                _plot_prediction(y[j], pred, out_size)
    return accuracies_test


In [None]:
num_epochs=15
learning_rate=0.001
data_size = 100
split = 0.65  # fraction of the sampled images used for training; the rest are the test set
SEED = 0

pathX = '/content/at_dataset/images/*.tif'
pathY = '/content/at_dataset/labels/*.tiff'

# Seed numpy (train() draws its train/test split with np.random) and torch (weight
# initialisation) so a Restart & Run All reproduces the same split and starting model.
np.random.seed(SEED)
torch.manual_seed(SEED)

# Fail fast without a GPU: training this U-Net on CPU is impractically slow.
if not torch.cuda.is_available():
  raise RuntimeError("Things will go much quicker if you enable a GPU in Colab under 'Runtime / Change Runtime Type'")
device = torch.device("cuda")

model=U_net().to(device)

optimizer = optim.RMSprop(model.parameters(), lr=learning_rate, weight_decay=1e-8, momentum=0.99)
criterion = nn.BCEWithLogitsLoss() 

train(model, criterion, pathX, pathY, optimizer, num_epochs, data_size, split, device) 