# U-Net implementation following https://youtu.be/u1loyDCoGbE, based on the original U-Net paper: https://arxiv.org/pdf/1505.04597.pdf

In [1]:
%matplotlib inline
import glob
import os
import shutil
import time
from pathlib import Path

import cv2 as cv
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# # Google Colab specifics
# from google.colab import drive
# drive.mount('/content/drive')
# !cp "/content/drive/MyDrive/helpers.py" .
# imgs_names = glob.glob( '/content/drive/MyDrive/th_analysedimages/*.tif')
# labels_names = glob.glob('/content/drive/MyDrive/labels/*.png')

In [11]:
# Local GitHub project specifics.
# PROJECT_DIR defaults to the original author's checkout; set the PROJECT_DIR environment
# variable to point at another checkout instead of editing absolute paths in the notebook.
PROJECT_DIR = Path(os.environ.get(
    "PROJECT_DIR",
    "/Users/theophanemayaud/Dev/EPFL MA1/Machine Learning/cs-433-project-2-ml_fools",
))
# glob.escape keeps the directory part from being interpreted as a pattern (e.g. '[' in a name)
project_glob_root = glob.escape(str(PROJECT_DIR))
imgs_names = glob.glob(os.path.join(project_glob_root, "th_analysedimages", "*.tif"))
labels_names = glob.glob(os.path.join(project_glob_root, "th_csv_labels", "png_masks_emb", "*.png"))

# `from helpers import ...` in the training cell needs helpers.py in the working directory
try:
    shutil.copy(PROJECT_DIR / "helpers.py", ".")
except shutil.SameFileError:
    pass  # the notebook already runs from PROJECT_DIR, nothing to copy

In [12]:
# Sort both lists so that image i and label i refer to the same sample: the training
# function pairs them by index. NOTE(review): this assumes image and label file names
# sort in the same order — confirm against the naming scheme.
imgs_names = sorted(imgs_names)
print(f"Found {len(imgs_names)} images")

labels_names = sorted(labels_names)
print(f"Found {len(labels_names)} labels")

# Fail early instead of training on nothing or on mis-paired samples
assert imgs_names, "No images found: check the image path / glob pattern"
assert len(imgs_names) == len(labels_names), (
    f"{len(imgs_names)} images but {len(labels_names)} labels: they must pair one-to-one"
)

Found 357 images
Found 357 labels


<img src="./U-Net structure.png" width="500">

In [13]:
# UNet definitions
class UNet(nn.Module):
    """U-Net segmentation network (https://arxiv.org/pdf/1505.04597.pdf) with unpadded convolutions.

    Contracting path: four `double_conv` levels (64, 128, 256, 512 channels), each followed
    by a 2x2 max-pool, then a 1024-channel bottleneck `double_conv`.
    Expanding path: each level up-samples with a 2x2 transposed convolution, concatenates
    the center-cropped feature map of the matching contracting level (skip connection) and
    applies another `double_conv`. A final 1x1 convolution maps the 64 channels to one
    score map per class.

    Because the 3x3 convolutions are unpadded, the output is spatially smaller than the
    input: a 572x572 input yields a 388x388 output.

    Args:
        in_channels: number of input image channels (default 1, grayscale).
        out_channels: number of output score maps, one per class (default 2). Raw scores
            are returned (no sigmoid/softmax), as expected by `nn.CrossEntropyLoss`.
    """

    def __init__(self, in_channels=1, out_channels=2):
        super().__init__()

        # Layers for going down the U (contracting path)
        self.max_pool_2x2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.d_double_conv_1 = double_conv(in_channels, 64)
        self.d_double_conv_2 = double_conv(64, 128)
        self.d_double_conv_3 = double_conv(128, 256)
        self.d_double_conv_4 = double_conv(256, 512)
        self.d_double_conv_5 = double_conv(512, 1024)

        # Layers for going up the U (expanding path). Each u_double_conv receives twice the
        # up-sampled channels because the skip connection is concatenated to them.
        self.up_trans_4 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=2, stride=2)
        self.u_double_conv_4 = double_conv(1024, 512)
        self.up_trans_3 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=2, stride=2)
        self.u_double_conv_3 = double_conv(512, 256)
        self.up_trans_2 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=2, stride=2)
        self.u_double_conv_2 = double_conv(256, 128)
        self.up_trans_1 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, stride=2)
        self.u_double_conv_1 = double_conv(128, 64)

        # 1x1 convolution producing one raw score map per class
        self.out = nn.Conv2d(in_channels=64, out_channels=out_channels, kernel_size=1)

    def forward(self, image):
        """Compute per-pixel class scores.

        Args:
            image: float tensor of shape (N, in_channels, H, W); the network is designed
                for H = W = 572.

        Returns:
            Tensor of shape (N, out_channels, 388, 388) for a 572x572 input, holding raw
            (un-normalised) class scores.
        """
        # Going down the U: keep each level's output for the skip connections
        d1 = self.d_double_conv_1(image)  # first level
        x = self.max_pool_2x2(d1)
        d2 = self.d_double_conv_2(x)  # second level
        x = self.max_pool_2x2(d2)
        d3 = self.d_double_conv_3(x)  # third level
        x = self.max_pool_2x2(d3)
        d4 = self.d_double_conv_4(x)  # fourth level
        x = self.max_pool_2x2(d4)
        x = self.d_double_conv_5(x)  # bottleneck (fifth level): no max-pool after it

        # Going up the U: up-sample, crop the matching skip map to size, concatenate on channels
        x = self.up_trans_4(x)
        d4 = crop_img(tensor=d4, target_tensor=x)
        x = self.u_double_conv_4(torch.cat([d4, x], 1))

        x = self.up_trans_3(x)
        d3 = crop_img(tensor=d3, target_tensor=x)
        x = self.u_double_conv_3(torch.cat([d3, x], 1))

        x = self.up_trans_2(x)
        d2 = crop_img(tensor=d2, target_tensor=x)
        x = self.u_double_conv_2(torch.cat([d2, x], 1))

        x = self.up_trans_1(x)
        d1 = crop_img(tensor=d1, target_tensor=x)
        x = self.u_double_conv_1(torch.cat([d1, x], 1))

        return self.out(x)
        
        
    
# Helper functions used by UNet (to reduce redundancy)
def double_conv(nb_in_channels, nb_out_channels):
    """Return two stacked 3x3 convolutions (no padding), each followed by an in-place ReLU.

    The first convolution maps `nb_in_channels` to `nb_out_channels`; the second keeps
    `nb_out_channels`. With no padding, each convolution trims 2 pixels from H and W
    (4 in total for the pair).
    """
    layers = []
    for conv_in_channels in (nb_in_channels, nb_out_channels):
        layers.append(nn.Conv2d(conv_in_channels, nb_out_channels, kernel_size=3))
        layers.append(nn.ReLU(inplace=True))
    return nn.Sequential(*layers)

def crop_img(tensor, target_tensor):
    """Center-crop `tensor` spatially to the height and width of `target_tensor`.

    Used for the U-Net skip connections: with unpadded convolutions the contracting-path
    feature map is larger than the up-sampled one, so its center is cut out before the
    two are concatenated.

    Args:
        tensor: 4-D tensor (N, C, H, W) to crop.
        target_tensor: 4-D tensor whose last two dimensions give the wanted size.

    Returns:
        A slice (view) of `tensor` with shape (N, C, target_H, target_W). Height and width
        are handled independently. An odd size difference is supported: the extra pixel is
        dropped on the bottom/right side.

    Raises:
        ValueError: if the target is larger than `tensor` in either spatial dimension.
    """
    tensor_h, tensor_w = tensor.size()[2:]
    target_h, target_w = target_tensor.size()[2:]
    if target_h > tensor_h or target_w > tensor_w:
        raise ValueError(
            f"cannot crop a {tensor_h}x{tensor_w} tensor to a larger {target_h}x{target_w} target"
        )
    # Slicing by explicit start + target size (rather than trimming delta//2 from both ends)
    # guarantees the exact target size even when the size difference is odd.
    top = (tensor_h - target_h) // 2
    left = (tensor_w - target_w) // 2
    return tensor[:, :, top:top + target_h, left:left + target_w]

In [14]:
# Test forward to see how it behaves
# image = torch.rand((1, 1, 572, 572))
# print(image.type())
# model = UNet()
# y = model(image).detach().numpy()
# print(f"Output shape : {y.shape}")
# print(f"Output max value : {y[0,0, :, :].max()}, min={y[0,0, :, :].min()}")
# _ , axs = plt.subplots(ncols=2)
# axs[0].imshow(image.detach().numpy()[0,0,:,:])
# axs[0].set_title("Original image")
# axs[1].imshow(y[0,0, :, :])
# axs[1].set_title("Output label image")
# plt.show()

In [15]:
# Training function
from helpers import png_to_mask, segment_dataset

def train_model(model, img_pathnames, label_pathnames, criterion, optimizer, device, num_epochs=25):
    """Train `model` one image segment at a time and print progress.

    For every epoch, each image/label pair is read from disk (the label through
    `png_to_mask`). `segment_dataset` splits the pair into image segments, viewed as
    (1, 1, 572, 572), and label segments, viewed as (1, 388, 388). Each segment is one
    optimisation step (batch size 1). For every `disp_img_mod`-th image, the loss, step
    duration, loss reduction and embolism-surface error of a few segments are printed;
    for the other images only progress dots are printed.

    Args:
        model: network mapping a (1, 1, 572, 572) float tensor to (1, C, 388, 388) scores.
        img_pathnames: image file paths, paired by index with `label_pathnames`.
        label_pathnames: label mask file paths (same length as `img_pathnames`).
        criterion: loss taking (scores, long targets of shape (1, 388, 388)),
            e.g. `torch.nn.CrossEntropyLoss`.
        optimizer: optimizer over the parameters of `model`.
        device: device the segments are moved to; must be the device `model` lives on.
        num_epochs: number of passes over all images.

    Raises:
        ValueError: if the two path lists do not have the same length.
    """
    if len(img_pathnames) != len(label_pathnames):
        raise ValueError(
            f"{len(img_pathnames)} images but {len(label_pathnames)} labels: they must pair one-to-one"
        )

    disp_img_mod = 50  # detailed stats are printed for every 50th image

    print("Starting the training on images !\n")
    model.train()
    for epoch in range(num_epochs):
        print(f"Epoch {1+epoch}/{num_epochs}", end="")

        for image_i, (img_path, label_path) in enumerate(zip(img_pathnames, label_pathnames)):
            image = cv.imread(img_path, cv.IMREAD_UNCHANGED)
            label = png_to_mask(cv.imread(label_path, cv.IMREAD_UNCHANGED))

            image_segments, label_segments = segment_dataset([image], [label])
            nb_segments = image_segments.shape[0]

            for segment_i in range(nb_segments):
                since = time.perf_counter()  # wall-clock duration of this training step
                # Only the model parameters are optimised, so the input needs no gradient
                img_seg = torch.tensor(image_segments[segment_i, :, :]).view(1, 1, 572, 572).to(device).float()
                label_seg = torch.tensor(label_segments[segment_i, :, :].astype(float)).view(1, 388, 388).to(device).long()

                prediction = model(img_seg)
                loss = criterion(prediction, label_seg)

                # Compute the gradient and take one optimisation step
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                if image_i % disp_img_mod == 0:
                    time_elapsed = time.perf_counter() - since
                    if segment_i == 0:
                        print(f"\n|  Image {1+image_i}/{len(img_pathnames)} '{img_path}'", end="")
                    if segment_i % 3 == 0 and segment_i < 9:
                        # Monitoring only: re-evaluate the same segment after the step without
                        # building an autograd graph (saves memory and time)
                        with torch.no_grad():
                            new_loss = criterion(model(img_seg), label_seg)

                        ori_lab_seg = label_seg.cpu().numpy()[0, :, :].astype(int)
                        pred_lab_seg = torch.argmax(prediction, dim=1).cpu().detach().numpy()[0, :, :]
                        # max(..., 1) avoids dividing by zero for segments without embolism pixels
                        ori_lab_counts = max(np.count_nonzero(ori_lab_seg == 1), 1)
                        pred_lab_counts = np.count_nonzero(pred_lab_seg == 1)
                        emb_surf_pred_error = round(100 * (pred_lab_counts - ori_lab_counts) / ori_lab_counts, 4)
                        print(f"\n|  |  Segment {1+segment_i}/{nb_segments} of image {1+image_i} : loss={loss.item()} duration={int(time_elapsed)//60}m {int(time_elapsed%60)}s. Loss reduced {loss.item()-new_loss.item()}. Emb surf pred err (pred1s-orig1s)/orig1s = {emb_surf_pred_error}% (target is 0%)", end="")
                else:
                    if image_i % disp_img_mod == 1 and segment_i == 0:
                        print("\n|  Next images & segments ", end="")
                    if segment_i % 6 == 0:
                        print(".", end="")

    print("\n Finished training")

In [None]:
# Select the device; a GPU (e.g. on Google Colab) makes training much faster.
if torch.cuda.is_available():
    # When this cell is re-run, drop the previous model and its optimizer (which also holds
    # references to the parameters and the Adam state) so empty_cache() can actually release
    # their GPU memory. On a fresh kernel these names do not exist yet, hence pop(..., None).
    for stale_name in ("model", "optimizer"):
        globals().pop(stale_name, None)
    torch.cuda.empty_cache()
else:
    print("Things will go much quicker if you enable a GPU, ex in Colab under 'Runtime / Change Runtime Type'")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"    Selected device is: {device}\n")

# Same float as the original `10e-5` literal. NOTE(review): confirm 1e-5 was not intended.
learning_rate = 1e-4
model = UNet().to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
train_model(model, imgs_names, labels_names, criterion, optimizer, device, num_epochs=100)

In [None]:
# Test prediction on one segment of one image
img_nb = 100
test_seg_i = 0
test_img = cv.imread(imgs_names[img_nb], cv.IMREAD_UNCHANGED)
test_label = png_to_mask(cv.imread(labels_names[img_nb], cv.IMREAD_UNCHANGED))
test_image_segments, test_label_segments = segment_dataset([test_img], [test_label])
# Inference only: the input needs no gradient
test_img_seg = torch.tensor(test_image_segments[test_seg_i, :, :]).view(1, 1, 572, 572).to(device).float()
test_label_seg = torch.tensor(test_label_segments[test_seg_i, :, :].astype(float)).view(1, 1, 388, 388).to(device).float()

model.eval()
# no_grad: avoid building (and keeping alive in `test_pred`) an autograd graph of the whole network
with torch.no_grad():
    test_pred = model(test_img_seg)
# Class index per pixel (argmax over the class dimension)
y_test = torch.argmax(test_pred, dim=1).cpu().numpy()

In [None]:
# Show predictions: input segment, ground-truth label and predicted label side by side.
# `y_test` comes from an argmax over the class dimension, so it already holds integer class
# indices (0 or 1 for the default 2-class UNet) and needs no thresholding.
y_test_round = y_test.copy()

print(f"Output shape : {y_test_round.shape}")
print(f"Output max value : {y_test_round[0, :, :].max()}, min={y_test_round[0, :, :].min()}")
print(f"Output average={y_test_round.mean()}")

_, axs = plt.subplots(ncols=3, figsize=(40, 40))

axs[0].set_title("Original image")
orig_image = test_img_seg.cpu().detach().numpy()[0, 0, :, :]
axs[0].imshow(orig_image)

axs[1].set_title("Original label")
ori_lab = test_label_seg.cpu().detach().numpy()[0, 0, :, :]
axs[1].imshow(ori_lab)

axs[2].set_title("Model predicted label")
pred_lab = y_test_round[0, :, :]
axs[2].imshow(pred_lab)

# Embolism surface error, same convention as the training log: (pred1s-orig1s)/orig1s
ori_lab_seg = ori_lab.astype(int)
pred_lab_seg = pred_lab
ori_lab_counts = np.count_nonzero(ori_lab_seg == 1)
pred_lab_counts = np.count_nonzero(pred_lab_seg == 1)
print(f"Original embolism pixels = {ori_lab_counts}, predicted={pred_lab_counts}")
if ori_lab_counts == 0:
    # A relative error is undefined when the ground truth has no embolism pixels
    print("Emb surf pred err undefined: original label has no embolism pixels (target is 0%)")
else:
    emb_surf_pred_error = 100 * (pred_lab_counts - ori_lab_counts) / ori_lab_counts
    print(f"Emb surf pred err (pred1s-orig1s)/orig1s = {emb_surf_pred_error}% (target is 0%)")

plt.show()