In [2]:
import matplotlib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import glob
import os
import torch
import random
import scipy
from os.path import basename
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.optim as optim
from torch.nn import Linear, GRU, Conv2d, Dropout, MaxPool2d, BatchNorm1d
from torch.nn.functional import relu, elu, relu6, sigmoid, tanh, softmax
from torch.utils.data import DataLoader, random_split
import torchvision
from torch import optim
from pytorch_toolbelt.losses import DiceLoss

In [4]:
from augmentation.batch_loader import * # these are scripts created by us
from unet.Network import *


In [None]:
# Load the data and create the train and validation loaders.
# The data glob can be overridden with the DATA_PATH environment variable; the
# default is the original cluster location.
data_path = os.environ.get("DATA_PATH", r"/zhome/20/8/175218/data/clean_data/*.*")
# Pass data_path explicitly (the cell previously referenced an undefined `path`).
train_one_hot, train_image, train_image_standard, train_image_standard_hot = image_standard_format(data_path)
test_one_hot, test_image, test_image_standard, test_image_standard_hot = test_images(train_one_hot, # one-hot encoded image
                                                                                     train_image, # real image
                                                                                     train_image_standard, # four-dimensional array, 3 RGB channels
                                                                                     train_image_standard_hot) # four-dimensional one-hot array used for batch creation
trainloader, validloader = batch(train_image_standard,
                                 train_image_standard_hot) # train and validation loaders (batch size 9 per the original note; set inside batch())

In [None]:
# Select the segmentation architecture to train ("deeplab" or the custom U-Net).
model_for_training = "deeplab"
device = "cuda" if torch.cuda.is_available() else "cpu"
if model_for_training == "deeplab":
    # torchvision DeepLabV3 with a ResNet-101 backbone and a 9-class output head.
    Net = torchvision.models.segmentation.deeplabv3_resnet101(num_classes = 9)
    # Swap in a freshly initialised first convolution (3 input channels -> 64).
    # NOTE(review): this layer has the same shape as the stock conv1; if the
    # installed torchvision loads a pretrained backbone by default, this discards
    # those first-layer weights — confirm this is intended.
    Net.backbone.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
else:
    # Custom U-Net from unet/Network.py, written in-house to allow dropout at several stages.
    Net = U_Net_Model()

In [None]:
# Move the network's parameters to the GPU when one exists; otherwise it stays on CPU.
if torch.cuda.is_available():
    Net.to("cuda")

In [None]:
## hyperparameters
LEARNING_RATE = 0.00004   # Adam learning rate
wDecay = 0.000001         # Adam weight decay
epochs = 100              # number of training epochs
min_valid_loss = np.inf   # best validation loss seen so far; reset before each training run
# Per-class weights for the cross-entropy term (index 0 is the background class).
# Placed on the GPU only when one exists — an unconditional .cuda() crashed on
# CPU-only machines, unlike every other device move in this notebook.
class_weights = torch.tensor(
    [0.2, 1, 1, 1, 1, 3, 3, 3, 3],
    dtype=torch.float,
).to("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Final objective = Dice loss + class-weighted cross-entropy.
# mode='multilabel' treats each of the output channels as its own class map.
criterion1 = DiceLoss(mode = 'multilabel')
criterion2 = nn.CrossEntropyLoss(weight = class_weights)
# Adam over all network parameters, with L2 weight decay.
optimizer = optim.Adam(params=Net.parameters(), lr=LEARNING_RATE, weight_decay=wDecay)

In [None]:
## Pixel-accuracy helper and the per-epoch metric histories filled by the training loop
def accuracy(ys, ts):
    """Fraction of positions where the argmax of ``ys`` over dim 1 equals ``ts``.

    Args:
        ys: tensor reduced with argmax along dim 1 (e.g. class scores or one-hot maps).
        ts: label tensor comparable element-wise with that argmax result.

    Returns:
        0-d float tensor holding the mean match rate in [0, 1].
    """
    matches = ys.max(dim=1).indices == ts
    return matches.float().mean()


validation_losses, validation_accuracies = [], []
training_losses, training_accuracies = [], []
step = 0

In [None]:
## training and validation
# A checkpoint is written every time the mean validation loss reaches a new minimum.
CHECKPOINT_PATH = 'saved_model_dice_deepCE_smol_col_W.pth'


def _segmentation_logits(model_output):
    """Return the logits tensor from a forward-pass result.

    Dict outputs are indexed with ``'out'`` (how this loop reads DeepLabV3);
    any other output — presumably a bare tensor from the custom U-Net,
    TODO confirm — is returned unchanged.
    """
    if isinstance(model_output, dict):
        return model_output['out']
    return model_output


for e in range(epochs):
    # ---- training ----
    train_loss = 0.0   # sum over samples of (batch-mean loss * batch size)
    numT = 0           # training samples seen this epoch
    n_totalT = 0       # training batches seen this epoch
    train_accuracies_batches = []
    Net.train()
    for data, target in trainloader:
        if torch.cuda.is_available():
            data, target = data.cuda(), target.cuda()

        optimizer.zero_grad()
        pred = _segmentation_logits(Net(data))
        # Combined objective: Dice + class-weighted cross-entropy.
        loss = criterion1(pred, target) + criterion2(pred, target)
        loss.backward()
        optimizer.step()

        numT += len(target)
        train_loss += loss.item() * len(target)
        n_totalT += 1
        step += 1  # global step counter (was `step =+ 1`, which just assigned 1)
        predictions = pred.max(1)[1]
        # accuracy() argmaxes its FIRST argument over dim 1, so the target
        # (presumably one-hot along dim 1 — TODO confirm) goes first.
        train_accuracies_batches.append(accuracy(target, predictions))
    train_acc = (sum(train_accuracies_batches) / n_totalT).item()
    mean_train_loss = train_loss / numT
    training_accuracies.append(train_acc)   # for plot
    training_losses.append(mean_train_loss)

    # ---- validation ----
    valid_loss = 0.0
    numV = 0
    n_totalV = 0
    valid_accuracies_batches = []
    Net.eval()
    # No gradients are needed for evaluation; without no_grad every batch
    # would build an autograd graph and hold its activations in memory.
    with torch.no_grad():
        for data, labels in validloader:
            if torch.cuda.is_available():
                data, labels = data.cuda(), labels.cuda()

            pred = _segmentation_logits(Net(data))
            loss = criterion1(pred, labels) + criterion2(pred, labels)
            valid_loss += loss.item() * len(labels)
            numV += len(labels)
            n_totalV += 1
            predictions = pred.max(1)[1]
            valid_accuracies_batches.append(accuracy(labels, predictions))
    valid_acc = (sum(valid_accuracies_batches) / n_totalV).item()
    mean_valid_loss = valid_loss / numV
    validation_accuracies.append(valid_acc)   # for plot
    validation_losses.append(mean_valid_loss)

    print(f'Epoch {e+1} \t\t Training Loss: {mean_train_loss} \t\t Validation Loss: {mean_valid_loss}')
    print(f'Epoch {e+1} \t\t Train Accuracy: {train_acc} \t\t Validation Accuracy: {valid_acc}')
    # Compare the per-sample mean (the value that is logged) against the best so far.
    if mean_valid_loss < min_valid_loss:
        print(f'Validation Loss Decreased({min_valid_loss:.6f}--->{mean_valid_loss:.6f}) \t Saving The Model')
        min_valid_loss = mean_valid_loss
        # Save the state dict; the save call was previously commented out and misspelled.
        torch.save(Net.state_dict(), CHECKPOINT_PATH)

In [None]:
# Persist the per-epoch metric histories and plot the loss curves.
# Only epochs that finished both training and validation are kept, so all four
# histories share one x-axis. Building the axis from `epochs` made plt.plot raise
# whenever training stopped early.
n_completed = min(len(training_losses), len(training_accuracies),
                  len(validation_losses), len(validation_accuracies))
epoch_range = np.arange(0, n_completed, 1)
validation_losses = np.array(validation_losses[:n_completed])
np.savetxt("validation_losses", validation_losses)
validation_accuracies = np.array(validation_accuracies[:n_completed])
np.savetxt("validation_accuracies", validation_accuracies)
training_accuracies = np.array(training_accuracies[:n_completed])
np.savetxt("training_accuracies", training_accuracies)
training_losses = np.array(training_losses[:n_completed])
np.savetxt("training_losses", training_losses)
fig = plt.figure(figsize=(12,4))
plt.subplot(1, 2, 1)
plt.plot(epoch_range, training_losses, label='train_loss')
plt.plot(epoch_range, validation_losses, label='valid_loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()