In [1]:
# part 4 experimentation: Unsupervised domain adaptation using adaptive batch normalization
# major library dependencies: jupyter, numpy, matplotlib, pytorch, scikit-image, pillow

# Fine-tuning with Pseudo-ground Truth
import torch
from dataset import camvidLoader
from torch.utils.data import Dataset, DataLoader
import numpy as np
import data_aug as aug
import torch.nn.functional as F
from unet import UNet
from tempfile import TemporaryDirectory
import os
import torch.nn as nn
import matplotlib.pyplot as plt
from skimage.io import imsave

# Execution device for every tensor and model in this notebook.
device = 'cpu'  # can be set to "cuda" if you have a GPU

# NOTE: the pre-trained 'camvid_sunny_model.pt' checkpoint is intentionally NOT
# loaded here. The model cell loads it right before copying its weights into
# `unet_model`, and nothing in between reads it, so loading it here as well
# only deserialized the same file twice.

# Target-domain (cloudy) data; the test split is built without augmentation.
data_root = './CamVid/cloudy'
test_data = camvidLoader(
    root=data_root,
    split='test',
    is_aug=False,
    img_size=[256, 256],
    is_pytorch_transform=True,
)

num_classes = 14  # number of classes is always 14 for this project.

# Human-readable class names, indexed by class id (used when printing per-class IoU).
labels = [
    'Sky', 'Building', 'Pole', 'Road', 'LaneMarking', 'SideWalk', 'Pavement',
    'Tree', 'SignSymbol', 'Fence', 'Car_Bus', 'Pedestrian', 'Bicyclist', 'Others',
]

## Load parameters, model and dataset

In [2]:
# Hyper-parameters shared by the cells below.
epochs = 5        # number of passes over the training loader
lr = 5e-6         # learning rate handed to the Adam optimizer
batch_size = 4    # samples per mini-batch for every DataLoader
num_workers = 8   # worker processes used by every DataLoader

In [3]:
# Restore the checkpoint trained on the sunny domain. The file stores a whole
# pickled model object, so only its state dict is reused below.
unet = torch.load('camvid_sunny_model.pt', map_location=torch.device(device))

# Build a fresh UNet and copy the pre-trained weights into it.
unet_model = UNet(3, num_classes, width=32, bilinear=True)
unet_model.load_state_dict(unet.state_dict())
unet_model = unet_model.to(device)

# Optimizer over all model parameters, and the cross-entropy loss used for
# both the adaptation loop and validation.
optimizer = torch.optim.Adam(unet_model.parameters(), lr=lr)
loss_func = torch.nn.CrossEntropyLoss()

In [4]:
# Augmentation pipeline applied to the training split only.
aug_obj = aug.Compose([
    aug.RandomHorizontalFlip(),
    aug.RandomResizedCrop(256),
    aug.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
])

# Training split: augmented, shuffled, and incomplete final batches are dropped.
train_dataset = camvidLoader(
    root=data_root,
    split='train',
    is_aug=True,
    img_size=[256, 256],
    is_pytorch_transform=True,
    aug=aug_obj,
)
train_loader = DataLoader(
    train_dataset,
    num_workers=num_workers,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

# Validation split: no augmentation, fixed order, every sample kept.
val_dataset = camvidLoader(
    root=data_root,
    split='val',
    is_aug=False,
    img_size=[256, 256],
    is_pytorch_transform=True,
    aug=None,
)
val_loader = DataLoader(
    val_dataset,
    num_workers=num_workers,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
)

# Test split (dataset object created in the first cell): fixed order, every sample kept.
test_loader = DataLoader(
    test_data,
    num_workers=num_workers,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
)

## Implement model finetuning using AdaBN

In [5]:
def compute_entropy_map(model_output):
    """Compute a per-position entropy map for every sample of a batch of logits.

    Each sample ``model_output[i]`` is soft-maxed along its first axis (the
    class axis), and the entropy ``-sum(p * log(p + 1e-5))`` is taken along
    that same axis, so the result for a sample has the sample's shape minus
    its first dimension.

    Args:
        model_output: tensor of raw network outputs whose leading dimension
            indexes the samples of the batch (the original author notes each
            sample is 14*256*256 for this project).

    Returns:
        A list with one NumPy array per sample, detached and moved to the CPU.
    """
    epsilon = 1e-5  # keeps log() finite when a probability is exactly zero

    def _sample_entropy(sample_logits):
        # class probabilities for one sample, normalised over the class axis
        class_probs = F.softmax(sample_logits, dim=0)
        weighted_log = class_probs * torch.log(class_probs + epsilon)
        return -torch.sum(weighted_log, dim=0)

    return [
        _sample_entropy(model_output[sample_idx]).cpu().detach().numpy()
        for sample_idx in range(0, model_output.shape[0])
    ]

In [None]:
# Directory that receives one full-model checkpoint per epoch.
model_dir = './model_dir/'
os.makedirs(model_dir, exist_ok=True)

# AdaBN leaves every learnable weight untouched, so freeze all parameters.
# BatchNorm running mean/variance are buffers rather than parameters; they are
# still refreshed by forward passes while the BN layers are in training mode.
for param in unet_model.parameters():
    param.requires_grad = False


def _set_adabn_mode(model):
    """Put `model` in eval mode, then switch only its BatchNorm2d layers to train mode.

    This way the forward passes re-estimate the BN running statistics on the
    target domain while every other mode-dependent layer behaves as in inference.
    """
    model.eval()
    for module in model.modules():
        if isinstance(module, torch.nn.BatchNorm2d):
            module.train()


with TemporaryDirectory() as tempdir:

    # The best state dict is kept in a temporary file. It has to be reloaded
    # before this `with` block exits, because the directory is removed on exit.
    best_model_params_path = os.path.join(tempdir, 'best_model_params.pt')
    torch.save(unet_model.state_dict(), best_model_params_path)

    best_loss = float('Inf')

    for epoch in range(epochs):

        # adaptation loop: forward passes only, to update the BN statistics
        train_loss = 0
        count = 0
        _set_adabn_mode(unet_model)

        for idx_batch, (imagergb, _, filename) in enumerate(train_loader):

            # send to the device (GPU or CPU)
            x = imagergb.to(device)

            # No parameter is trainable, so there is nothing to back-propagate
            # and no optimizer step to take (the previous zero_grad()/step()
            # calls were no-ops because backward() was never called).
            # BN running statistics are updated by the forward pass itself,
            # also under no_grad.
            with torch.no_grad():
                # prediction with 14 class scores per pixel
                y = unet_model(x)

                # pseudo-labels are the model's own most likely class
                pseudo_labels = torch.argmax(y, dim=1)

                # monitoring only: cross-entropy of the prediction against its
                # own pseudo-labels (low value = confident predictions)
                loss = loss_func(y, pseudo_labels)

            if idx_batch % 2 == 0:
                print("train epoch = " + str(epoch) + " | batch = " + str(idx_batch) + " | loss = "+str(loss.item()))

            train_loss += loss.item()
            count += 1

        train_loss /= count

        # evaluation loop: BN layers use their (just updated) running statistics
        unet_model.eval()
        val_loss = 0
        count = 0

        with torch.no_grad():  # no gradient required during validation loop
            for idx_batch, (imagergb, labelmask, filename) in enumerate(val_loader):
                x = imagergb.to(device)
                y_ = labelmask.to(device)
                y = unet_model(x)
                loss = loss_func(y, y_)
                val_loss += loss.item()
                count += 1

        val_loss /= count

        # keep the state dict with the lowest validation loss seen so far
        if val_loss < best_loss:
            best_loss = val_loss
            torch.save(unet_model.state_dict(), best_model_params_path)

        print(f'In epoch {epoch}, the train loss is: {train_loss}, the val loss is: {val_loss}')

        # record the full model at each checkpoint
        model_location = model_dir + "model_file_epoch_" + str(epoch) + ".pt"
        torch.save(unet_model, model_location)

    # restore the best-performing weights while the temporary file still exists
    unet_model.load_state_dict(torch.load(best_model_params_path, map_location=torch.device(device)))

train epoch = 0 | batch = 0 | loss = 0.40610939264297485
train epoch = 0 | batch = 2 | loss = 0.3450988531112671
train epoch = 0 | batch = 4 | loss = 0.34040388464927673
train epoch = 0 | batch = 6 | loss = 0.3558237552642822
train epoch = 0 | batch = 8 | loss = 0.3471675217151642
train epoch = 0 | batch = 10 | loss = 0.3194240927696228
train epoch = 0 | batch = 12 | loss = 0.39745572209358215
train epoch = 0 | batch = 14 | loss = 0.3933342695236206
train epoch = 0 | batch = 16 | loss = 0.3413313925266266
In epoch 0, the train loss is: 0.36306368311246234, the val loss is: 1.305429220199585
train epoch = 1 | batch = 0 | loss = 0.3962474465370178
train epoch = 1 | batch = 2 | loss = 0.4062955379486084
train epoch = 1 | batch = 4 | loss = 0.34931445121765137
train epoch = 1 | batch = 6 | loss = 0.3907968997955322
train epoch = 1 | batch = 8 | loss = 0.3547840118408203
train epoch = 1 | batch = 10 | loss = 0.37394607067108154
train epoch = 1 | batch = 12 | loss = 0.3365594744682312
train 

## Evaluate cloudy dataset performance

In [7]:
def global_accuracy_metric(y_true, y_pred):
    """Return the fraction of entries where the prediction equals the ground truth.

    Args:
        y_true: NumPy array of ground-truth class ids.
        y_pred: NumPy array of predicted class ids.

    Returns:
        Number of matching entries divided by ``y_pred.size``.
    """
    num_correct = np.sum(y_true == y_pred)
    num_entries = y_pred.size
    return num_correct / num_entries

def IoU_metric(y_true, y_pred, n_classes=None):
    """Compute the intersection-over-union of every class for one image.

    Args:
        y_true: NumPy array of ground-truth class ids.
        y_pred: NumPy array of predicted class ids, same shape as ``y_true``.
        n_classes: number of class ids to score (``0 .. n_classes - 1``).
            Defaults to the notebook-level ``num_classes``.

    Returns:
        A list of length ``n_classes``. Entry ``i`` is
        ``|pred == i AND true == i| / |pred == i OR true == i|``, or NaN when
        class ``i`` appears in neither array (union is empty), so that such
        classes can be skipped with ``np.nanmean``.
    """
    if n_classes is None:
        n_classes = num_classes

    iou_per_image = []

    for i in range(n_classes):
        pred_mask = (y_pred == i)
        true_mask = (y_true == i)
        intersection = np.logical_and(pred_mask, true_mask).sum()
        union = np.logical_or(pred_mask, true_mask).sum()

        # if the union is 0 the class is absent, so its IoU is undefined (NaN);
        # `np.nan` is used because the `np.NAN` alias was removed in NumPy 2.0
        if union == 0:
            iou = np.nan
        else:
            iou = intersection / union

        iou_per_image.append(iou)

    return iou_per_image

In [8]:
# Per-image results accumulated over the whole test split.
global_acc = []     # one global pixel accuracy per image
perclass_acc = []   # one list of per-class IoU values per image
img_file = []       # file names, in the order the images were evaluated

unet_model.eval()

with torch.no_grad():
    for idx_batch, (imagergb, labelmask, filename) in enumerate(test_loader):

        img_file.extend(filename)

        x = imagergb.to(device)
        y_ = labelmask.to(device)
        y = unet_model(x)

        # most likely class per pixel for the whole batch, plus the matching
        # ground truth, both converted once to int NumPy arrays on the CPU
        pred_batch = torch.argmax(y, dim=1).cpu().int().numpy()
        gt_batch = y_.cpu().int().numpy()

        for idx in range(0, y.shape[0]):
            max_index = pred_batch[idx]
            gt_correct_format = gt_batch[idx]

            # global accuracy of this image
            correct_prediction = global_accuracy_metric(gt_correct_format, max_index)
            global_acc.append(correct_prediction)

            # per-class IoU of this image
            iou_per_image = IoU_metric(gt_correct_format, max_index)
            perclass_acc.append(iou_per_image)

In [9]:
import warnings

# Aggregate the per-image metrics. A class that is absent from every test image
# has NaN IoU everywhere, which makes np.nanmean emit a RuntimeWarning. Silence
# only that category and only around these computations, instead of disabling
# every warning for the rest of the kernel session.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=RuntimeWarning)
    mean_global_acc = np.mean(global_acc)
    # mean IoU of each class over all images, ignoring images where it is undefined
    overall_class_iou = np.nanmean(perclass_acc, axis=0)
    # mean over classes, ignoring classes that never occur
    mean_iou = np.nanmean(overall_class_iou)

# print the global image accuracy
print(f'The global accuracy overall image is: {mean_global_acc}')

# print the average mIOU
print(f'The average mIoU scores is: {mean_iou}\n')

# print the IOU per class
for idx in range(num_classes):
    print(f'The overall IOU scores for class {labels[idx]} is {overall_class_iou[idx]}')

The global accuracy overall image is: 0.7122147878011068
The average mIoU scores is: 0.32051338053990197

The overall IOU scores for class Sky is 0.8191607140651854
The overall IOU scores for class Building is 0.47109222447447824
The overall IOU scores for class Pole is 0.0
The overall IOU scores for class Road is 0.6600544010361143
The overall IOU scores for class LaneMarking is 0.15773450144547776
The overall IOU scores for class SideWalk is 0.6623841835538951
The overall IOU scores for class Pavement is nan
The overall IOU scores for class Tree is 0.6254697701298607
The overall IOU scores for class SignSymbol is 0.0
The overall IOU scores for class Fence is 0.0199589543375284
The overall IOU scores for class Car_Bus is 0.48060968842677415
The overall IOU scores for class Pedestrian is 0.2662015479132645
The overall IOU scores for class Bicyclist is 0.0
The overall IOU scores for class Others is 0.004007961636147182
