In [1]:
%matplotlib notebook

import sys, os
sys.path.append("/home/daniil/repos/pytorch-segmentation-detection/")
sys.path.append("/home/daniil/repos/pytorch-segmentation-detection/synchronized_batchnorm/")
sys.path.insert(0, '/home/daniil/repos/pytorch-segmentation-detection/vision/')


import torch.nn as nn
import torchvision.models as models
import torch

from pytorch_segmentation_detection.datasets.pascal_voc import PascalVOCSegmentation

from pytorch_segmentation_detection.transforms import (ComposeJoint,
                                                       RandomHorizontalFlipJoint,
                                                       RandomScaleJoint,
                                                       CropOrPad,
                                                       ResizeAspectRatioPreserve)

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torchvision.transforms as transforms

import numbers
import random

from matplotlib import pyplot as plt

import numpy as np
from PIL import Image

from sklearn.metrics import confusion_matrix

def flatten_logits(logits, number_of_classes):
    """Flattens the logits batch except for the logits dimension"""
    
    logits_permuted = logits.permute(0, 2, 3, 1)
    logits_permuted_cont = logits_permuted.contiguous()
    logits_flatten = logits_permuted_cont.view(-1, number_of_classes)
    
    return logits_flatten

def flatten_annotations(annotations):
    
    return annotations.view(-1)

def get_valid_annotations_index(flatten_annotations, mask_out_value=255):
    
    return torch.squeeze( torch.nonzero((flatten_annotations != mask_out_value )), 1)


def adjust_learning_rate(optimizer, iteration):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    
    max_iteration = 13000.0
    
    multiplier = (1.0 - (iteration / max_iteration)) ** (0.9)
    
    lr = 0.0001 * multiplier
    
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr



from pytorch_segmentation_detection.transforms import RandomCropJoint


number_of_classes = 21

labels = range(number_of_classes)

train_transform = ComposeJoint(
                [
                    RandomHorizontalFlipJoint(),
                    RandomScaleJoint(low=0.5, high=2.0),
                    RandomCropJoint(crop_size=(513, 513)),
                    [transforms.ToTensor(), None],
                    [transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), None],
                    [None, transforms.Lambda(lambda x: torch.from_numpy(np.asarray(x)).long()) ]
                ])

trainset = PascalVOCSegmentation(download=False,
                                 joint_transform=train_transform,
                                 split_mode=1)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=16,
                                          shuffle=True, num_workers=4, drop_last=True)


valid_transform = ComposeJoint(
                [
                     [transforms.ToTensor(), None],
                     [transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), None],
                     [None, transforms.Lambda(lambda x: torch.from_numpy(np.asarray(x)).long()) ]
                ])


valset = PascalVOCSegmentation(train=False,
                               download=False,
                               joint_transform=valid_transform,
                               split_mode=1)


valset_loader = torch.utils.data.DataLoader(valset, batch_size=1,
                                            shuffle=False, num_workers=2)

train_subset_sampler = torch.utils.data.sampler.SubsetRandomSampler(range(904))
train_subset_loader = torch.utils.data.DataLoader(dataset=trainset, batch_size=1,
                                                   sampler=train_subset_sampler,
                                                   num_workers=2)


# Define the validation function to track MIoU during the training
def validate():
    
    fcn.eval()
    
    overall_confusion_matrix = None

    for image, annotation in valset_loader:

        image = Variable(image.cuda())
        logits = fcn(image)

        # First we do argmax on gpu and then transfer it to cpu
        logits = logits.data
        _, prediction = logits.max(1)
        prediction = prediction.squeeze(1)

        prediction_np = prediction.cpu().numpy().flatten()
        annotation_np = annotation.numpy().flatten()

        # Mask-out value is ignored by default in the sklearn
        # read sources to see how that was handled

        current_confusion_matrix = confusion_matrix(y_true=annotation_np,
                                                    y_pred=prediction_np,
                                                    labels=labels)

        if overall_confusion_matrix is None:


            overall_confusion_matrix = current_confusion_matrix
        else:

            overall_confusion_matrix += current_confusion_matrix
    
    
    intersection = np.diag(overall_confusion_matrix)
    ground_truth_set = overall_confusion_matrix.sum(axis=1)
    predicted_set = overall_confusion_matrix.sum(axis=0)
    union =  ground_truth_set + predicted_set - intersection

    intersection_over_union = intersection / union.astype(np.float32)
    mean_intersection_over_union = np.mean(intersection_over_union)
    
    fcn.train()

    return mean_intersection_over_union


def validate_train():
    
    fcn.eval()
    
    overall_confusion_matrix = None

    for image, annotation in train_subset_loader:

        image = Variable(image.cuda())
        logits = fcn(image)

        # First we do argmax on gpu and then transfer it to cpu
        logits = logits.data
        _, prediction = logits.max(1)
        prediction = prediction.squeeze(1)

        prediction_np = prediction.cpu().numpy().flatten()
        annotation_np = annotation.numpy().flatten()

        # Mask-out value is ignored by default in the sklearn
        # read sources to see how that was handled

        current_confusion_matrix = confusion_matrix(y_true=annotation_np,
                                                    y_pred=prediction_np,
                                                    labels=labels)

        if overall_confusion_matrix is None:


            overall_confusion_matrix = current_confusion_matrix
        else:

            overall_confusion_matrix += current_confusion_matrix
    
    
    intersection = np.diag(overall_confusion_matrix)
    ground_truth_set = overall_confusion_matrix.sum(axis=1)
    predicted_set = overall_confusion_matrix.sum(axis=0)
    union =  ground_truth_set + predicted_set - intersection

    intersection_over_union = intersection / union.astype(np.float32)
    mean_intersection_over_union = np.mean(intersection_over_union)
    
    fcn.train()

    return mean_intersection_over_union


In [2]:
%matplotlib notebook

from matplotlib import pyplot as plt


# Create the training plot
loss_current_iteration = 0
loss_history = []
loss_iteration_number_history = []

validation_current_iteration = 0
validation_history = []
validation_iteration_number_history = []

train_validation_current_iteration = 0
train_validation_history = []
train_validation_iteration_number_history = []
 
f, (loss_axis, validation_axis) = plt.subplots(2, 1)

loss_axis.plot(loss_iteration_number_history, loss_history)
validation_axis.plot(validation_iteration_number_history, validation_history, 'b',
                     train_validation_iteration_number_history, train_validation_history, 'r')

loss_axis.set_title('Training loss')
validation_axis.set_title('MIoU on validation dataset')

plt.tight_layout()

<IPython.core.display.Javascript object>

In [3]:
from sync_batchnorm import SynchronizedBatchNorm2d, DataParallelWithCallback
from pytorch_segmentation_detection.models.psp import Resnet50_8s_psp

bn_layers = []

def make_batchnorm_syncronized(module):
    
    for child_module_name, child_module in module.named_children():
        
        if isinstance(child_module, nn.BatchNorm2d):
            
            sync_bn = SynchronizedBatchNorm2d(child_module.num_features)
            sync_bn.weight = child_module.weight
            sync_bn.bias = child_module.bias
            sync_bn.running_var = child_module.running_var
            sync_bn.running_mean = child_module.running_mean
            module.__setattr__(child_module_name, sync_bn)
            bn_layers.append(sync_bn)

fcn = Resnet50_8s_psp(num_classes=21)
#change_resnet_architecture_to_kaiming( fcn.resnet101_16s )
#fcn.resnet50_8s.load_state_dict(torch.load('resnet50-imagenet.pth'), strict=False)

fcn.apply(make_batchnorm_syncronized)

# parameters = list(map(lambda x: list(x.parameters()), [fcn.resnet101_16s.layer5,
#                                                        fcn.resnet101_16s.layer6,
#                                                        fcn.resnet101_16s.layer7,
#                                                        fcn.resnet101_16s.logits_conv_final]))
# parameters = sum(parameters, [])

fcn = DataParallelWithCallback(fcn, device_ids=[0, 1, 2, 3], output_device=3)

# fcn.load_state_dict( torch.load('resnet_101_16s_multigrid.pth'), strict=False )



In [4]:
fcn.cuda()
fcn.train()

#final_criterion = nn.CrossEntropyLoss(size_average=False).cuda(3)

final_criterion = nn.CrossEntropyLoss(size_average=False).cuda(3)
aux_criterion = nn.CrossEntropyLoss(size_average=False).cuda(3)

bn_layers_params = list(map(lambda x: list(x.parameters()), bn_layers))
bn_layers_params_set = set(sum(bn_layers_params, []))

all_params_set = set(fcn.parameters())

non_bn_laeyers_params_set = all_params_set - bn_layers_params_set


# optimizer = torch.optim.SGD(fcn.parameters(),
#                             lr=0.007,
#                             momentum=0.9)
# optimizer = optim.SGD([
#                 {'params': list(non_bn_laeyers_params_set), 'weight_decay': 0.0001},
#                 {'params': list(bn_layers_params_set), 'weight_decay': 0.9997}
#             ], lr=0.000000001, momentum=0.9)



#optimizer = optim.Adam(parameters, lr=0.0001, weight_decay=0.0001)

optimizer = optim.Adam(fcn.parameters(), lr=0.0001, weight_decay=0.0001)




In [5]:
best_validation_score = 0
#loss_current_iteration = 0

iter_size = 20

for epoch in range(1000):  # loop over the dataset multiple times

    running_loss = 0.0
    
    for i, data in enumerate(trainloader, 0):
        
        # get the inputs
        img, anno = data
        
        # We need to flatten annotations and logits to apply index of valid
        # annotations. All of this is because pytorch doesn't have tf.gather_nd()
        anno_flatten = flatten_annotations(anno)
        index = get_valid_annotations_index(anno_flatten, mask_out_value=255)
        anno_flatten_valid = torch.index_select(anno_flatten, 0, index)

        # wrap them in Variable
        # the index can be acquired on the gpu
        img, anno_flatten_valid, index = Variable(img.cuda()), Variable(anno_flatten_valid.cuda(3)), Variable(index.cuda(3))

        # zero the parameter gradients
        optimizer.zero_grad()
        
        adjust_learning_rate(optimizer, loss_current_iteration)


        # forward + backward + optimize
        final_logits = fcn(img)
        
        
        final_logits_flatten = flatten_logits(final_logits, number_of_classes=21)
        final_logits_flatten_valid = torch.index_select(final_logits_flatten, 0, index)
        loss = final_criterion(final_logits_flatten_valid, anno_flatten_valid)
        
        
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += (loss.data[0] / (final_logits_flatten.size(0) * 2)) 
        if i % 2 == 1:
            
            
            loss_history.append(running_loss / 2)
            loss_iteration_number_history.append(loss_current_iteration)
            
            loss_current_iteration += 1
            
            loss_axis.lines[0].set_xdata(loss_iteration_number_history)
            loss_axis.lines[0].set_ydata(loss_history)

            loss_axis.relim()
            loss_axis.autoscale_view()
            loss_axis.figure.canvas.draw()
            
            loss_current_iteration += 1
            
            running_loss = 0.0
        
            
            
    current_validation_score = validate()
    validation_history.append(current_validation_score)
    validation_iteration_number_history.append(validation_current_iteration)

    validation_current_iteration += 1

    validation_axis.lines[0].set_xdata(validation_iteration_number_history)
    validation_axis.lines[0].set_ydata(validation_history)



    current_train_validation_score = validate_train()
    train_validation_history.append(current_train_validation_score)
    train_validation_iteration_number_history.append(train_validation_current_iteration)

    train_validation_current_iteration += 1

    validation_axis.lines[1].set_xdata(train_validation_iteration_number_history)
    validation_axis.lines[1].set_ydata(train_validation_history)


    validation_axis.relim()
    validation_axis.autoscale_view()
    validation_axis.figure.canvas.draw()

    # Save the model if it has a better MIoU score.
    if current_validation_score > best_validation_score:

        torch.save(fcn.state_dict(), 'resnet_50_16s.pth')
        best_validation_score = current_validation_score
        print(best_validation_score)
        
                

print('Finished Training')



0.5465321830038283
0.5637863885429427
0.6174545663409374
0.6588139264402988
0.6642781720095828
0.6843888837353947
0.7047682773896456
0.7140351496218439
0.7160279419375072
0.7202573916368987
0.7291834144841298
0.7313510007892978
0.734529209030578
0.7360619722406061


TypeError: addcdiv_() takes 2 positional arguments but 3 were given