In [None]:
#Load Headers

import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.data as data
from tqdm import tqdm
import matplotlib.pyplot as plt
import matplotlib as mpl

from util import EarlyStopping, save_nets, save_predictions, load_best_weights
from model import UNet
from dataset import DataFolder
import matplotlib as mpl

In [None]:
#Hardcoded variables
seed = 1000  # single seed shared by every RNG below
np.random.seed(seed)
# PyTorch weight initialisation and DataLoader shuffling draw from torch's own
# RNG, which np.random.seed does not touch. torch.manual_seed seeds the CPU
# generator and all CUDA devices.
torch.manual_seed(seed)

batch_size = 8 #change batch size here
epochs = 200 #change number of epochs here
lr = 0.001  # Adam learning rate
# Both passed to EarlyStopping (util.py); presumably the number of non-improving
# epochs tolerated and the minimum loss change that counts as improvement.
patience = 10
min_delta = 0.001

In [None]:
#Loading Training, Validation, Test Sets

def make_loader(images_dir, masks_dir, mode, shuffle, num_workers=4):
    """Build a DataLoader over one DataFolder split using the notebook-wide ``batch_size``.

    Args:
        images_dir: directory holding the split's input images.
        masks_dir: directory holding the matching masks.
        mode: split tag forwarded unchanged as DataFolder's third argument.
        shuffle: reshuffle every epoch (True only for the training split).
        num_workers: number of DataLoader worker processes.

    Returns:
        A torch.utils.data.DataLoader over the DataFolder dataset.
    """
    return data.DataLoader(
        dataset=DataFolder(images_dir, masks_dir, mode),
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )


train_loader = make_loader('new_dataset/train/train_images_256/', 'new_dataset/train/train_masks_256/', 'train', shuffle=True)

# NOTE(review): the val/test splits also read from `train_images_256` /
# `train_masks_256` sub-directories — confirm this matches the on-disk layout.
valid_loader = make_loader('new_dataset/val/train_images_256/', 'new_dataset/val/train_masks_256/', 'validation', shuffle=False)
test_loader = make_loader('new_dataset/test/train_images_256/', 'new_dataset/test/train_masks_256/', 'test', shuffle=False)

In [None]:
#Unet and other utilities construction

# Network placed on the GPU; see model.py for the meaning of UNet's arguments.
model = UNet(1, shrink=1).cuda()
# Every network that is optimised here and checkpointed by save_nets during training.
nets = [model]

# One Adam parameter group per network, all at the configured learning rate.
params = [dict(params=net.parameters()) for net in nets]
solver = optim.Adam(params, lr=lr)

# Loss shared by training, validation and testing, plus the early-stopping tracker.
criterion = nn.CrossEntropyLoss()
es = EarlyStopping(patience=patience, min_delta=min_delta)

In [None]:
#Training Phase
# Per-epoch mean losses, consumed by the loss-curve plot cell.
train_allepoch_loss = []
valid_allepoch_loss = []

for epoch in range(1, epochs+1):
    with tqdm(total=len(train_loader.dataset), desc=f'Epoch {epoch}/{epochs}', unit='img', position=0, leave=True) as pbar:
        train_loss = []
        valid_loss = []

        # Training pass. train() restores training behaviour of any dropout /
        # batch-norm layers (defined in model.py) after the previous epoch's eval().
        model.train()
        for img, mask, _ in train_loader:
            solver.zero_grad()

            img = img.cuda()
            mask = mask.cuda()

            pred = model(img)
            loss = criterion(pred, mask)
            loss.backward()
            solver.step()

            batch_loss = loss.item()  # one GPU->CPU sync per batch
            train_loss.append(batch_loss)
            pbar.set_postfix(**{'loss (batch)': batch_loss})
            pbar.update(img.shape[0])

        # Validation pass: eval mode + no_grad so neither the weights nor any
        # batch-norm running statistics are modified by validation data.
        model.eval()
        with torch.no_grad():
            for img, mask, _ in valid_loader:
                img = img.cuda()
                mask = mask.cuda()

                pred = model(img)
                valid_loss.append(criterion(pred, mask).item())

        train_loss_mean = np.mean(train_loss)
        valid_loss_mean = np.mean(valid_loss)

        # Record before the early-stopping check so the epoch that triggers the
        # stop still appears in the loss curves.
        train_allepoch_loss.append(train_loss_mean)
        valid_allepoch_loss.append(valid_loss_mean)

        print('[EPOCH {}/{}] Train Loss: {:.4f}; Valid Loss: {:.4f}'.format(
            epoch, epochs, train_loss_mean, valid_loss_mean
        ))

        # As used below: stop flag, best validation loss so far, and the count
        # of epochs since the last improvement.
        flag, best, bad_epochs = es.step(torch.Tensor([valid_loss_mean]))
        if flag:
            print('Early stopping criterion met')
            break

        # bad_epochs == 0 means this epoch set a new best, so checkpoint it.
        if bad_epochs == 0:
            save_nets(nets, 'saved_model')
            print('Saving current best model')

        print('Current Valid loss: {:.4f}; Current best: {:.4f}; Bad epochs: {}'.format(
            valid_loss_mean, best.item(), bad_epochs
        ))

In [None]:
#Plot of training loss and validation loss across epochs

os.makedirs('saved_images', exist_ok=True)  # savefig does not create missing directories

fig, ax = plt.subplots()
ax.plot(np.arange(1, len(train_allepoch_loss)+1), train_allepoch_loss, 'b-', label="Training Set Loss")
ax.plot(np.arange(1, len(valid_allepoch_loss)+1), valid_allepoch_loss, 'r-', label="Validation Set Loss")
ax.legend(loc="upper right")
ax.set_title('Training and Validation Loss per Epoch')
ax.set_xlabel('Epochs')
ax.set_ylabel('Cross Entropy Loss')
fig.savefig('saved_images/trainval_loss_n.svg', transparent=True)
plt.show()

print('Training is over!!!')

In [None]:
#Testing Phase

# Restore the best checkpoint once, before the loop, and put any dropout /
# batch-norm layers into inference mode.
model = load_best_weights(model, 'saved_model')
model.eval()

test_loss = []
with torch.no_grad():
    for batch_idx, (img, mask, img_fns) in enumerate(test_loader):
        img = img.cuda()
        mask = mask.cuda()

        pred = model(img)
        loss = criterion(pred, mask)
        test_loss.append(loss.item())

        # Class id per pixel: argmax over dim 1 (assumed to be the class
        # dimension, as CrossEntropyLoss expects — confirm against UNet output).
        # Softmax is monotonic, so applying it first would not change the argmax.
        pred_mask = torch.argmax(pred, dim=1)
        # One tensor per image, leading dim kept at size 1. Works for a final
        # batch smaller than batch_size. The last batch's `pred_mask` is reused
        # by the plotting cell below.
        pred_mask = torch.split(pred_mask, 1, dim=0)
        save_predictions(pred_mask, img_fns, 'test_predictions')

        print('[Tested: {}/{}] Test Loss: {:.4f}'.format(
            batch_idx+1, len(test_loader), loss.item()
        ))

print('Final Test Loss: {:.4f}'.format(np.mean(test_loss)))

In [None]:
#Plot of last batch of test set
# `pred_mask` is the tuple of per-image masks left over from the final
# iteration of the testing loop.

os.makedirs('saved_images', exist_ok=True)  # savefig does not create missing directories

# One discrete colour per class id 0..7. BoundaryNorm maps values in [k, k+1)
# to colour k, so integer class ids select colours directly.
class_colors = ['black', 'blue', 'red', 'green', 'brown', 'cyan', 'yellow', 'royalblue']
cmap = mpl.colors.ListedColormap(class_colors)
bounds = list(range(len(class_colors) + 1))
norm = mpl.colors.BoundaryNorm(bounds, cmap.N)

# Three panels per row, with as many rows as needed, so batches larger than
# nine images still fit. For batch_size=8 this is the original 3x3, 15x15 layout.
n_cols = 3
n_rows = max(1, (len(pred_mask) + n_cols - 1) // n_cols)
fig = plt.figure(figsize=(15, 5 * n_rows))

for idx, pred in enumerate(pred_mask):
    ax = fig.add_subplot(n_rows, n_cols, idx + 1)
    im = ax.imshow(pred.squeeze(0).cpu().numpy(), interpolation='none', cmap=cmap, norm=norm)
    fig.colorbar(im, ax=ax)

# Lay out once, after every panel and colour bar exists.
fig.tight_layout()
fig.savefig('saved_images/predictions_n.svg', transparent=True)
plt.show()