In [1]:
%matplotlib inline

import sys, os
sys.path.insert(0, '../../../../vision/')
sys.path.append('../../../../../pytorch-segmentation-detection/')

# Use second GPU -pytorch-segmentation-detection- change if you want to use a first one
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

from PIL import Image
from matplotlib import pyplot as plt

import torch
from torchvision import transforms
import torchvision
from torch.autograd import Variable
from pytorch_segmentation_detection.layers import GlobalAvgPool2d

import numpy as np
import torch.nn as nn

from pytorch_segmentation_detection.datasets.pascal_voc import PascalVOCSegmentation

import pytorch_segmentation_detection.models.fcn as fcns
import pytorch_segmentation_detection.models.resnet_dilated as resnet_dilated
from pytorch_segmentation_detection.transforms import (ComposeJoint,
                                                       RandomHorizontalFlipJoint,
                                                       RandomScaleJoint,
                                                       CropOrPad,
                                                       ResizeAspectRatioPreserve,
                                                       RandomCropJoint,
                                                       Split2D)


import torch.optim as optim

from sklearn.metrics import confusion_matrix

def flatten_logits(logits, number_of_classes):
    """Flattens the logits batch except for the logits dimension"""
    
    logits_permuted = logits.permute(0, 2, 3, 1)
    logits_permuted_cont = logits_permuted.contiguous()
    logits_flatten = logits_permuted_cont.view(-1, number_of_classes)
    
    return logits_flatten

def flatten_annotations(annotations):
    
    return annotations.view(-1)

def get_valid_annotations_index(flatten_annotations, mask_out_value=255):
    
    return torch.squeeze( torch.nonzero((flatten_annotations != mask_out_value )), 1)

def adjust_learning_rate(optimizer, iteration):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    """Pete: this also currently effectively stops training via an error
    since at max_iterations+1, it tries to raise a negative number to a fractional power"""
    
    max_iteration = 10000.0
    
    multiplier = (1.0 - (iteration / max_iteration)) ** (0.9)
    
    lr = 0.0001 * multiplier
    
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


from pytorch_segmentation_detection.transforms import RandomCropJoint


# need to stick to 21 for now until change loss functions
# since the data has labels for 21 classes, we get an error
# unless this is 21
descriptor_dimensionality = 21

In [2]:
train_transform = ComposeJoint(
                [
                    RandomHorizontalFlipJoint(),
                    RandomCropJoint(crop_size=(513, 513)),
                    [transforms.ToTensor(), None],
                    [transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), None],
                    [None, transforms.Lambda(lambda x: torch.from_numpy(np.asarray(x)).long()) ]
                ])

path_to_VOC = '/media/peteflo/3TBbackup/pytorch-pretrained/VOC'

trainset = PascalVOCSegmentation(path_to_VOC,
                                 download=False,
                                 joint_transform=train_transform)


trainloader = torch.utils.data.DataLoader(trainset, batch_size=1,
                                          shuffle=True, num_workers=4, drop_last=True)



valid_transform = ComposeJoint(
                [
                     [transforms.ToTensor(), None],
                     [transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), None],
                     [None, transforms.Lambda(lambda x: torch.from_numpy(np.asarray(x)).long()) ]
                ])


valset = PascalVOCSegmentation(path_to_VOC,
                               train=False,
                               download=False,
                               joint_transform=valid_transform)


valset_loader = torch.utils.data.DataLoader(valset, batch_size=1,
                                            shuffle=False, num_workers=2)

train_subset_sampler = torch.utils.data.sampler.SubsetRandomSampler(xrange(904))
train_subset_loader = torch.utils.data.DataLoader(dataset=trainset, batch_size=1,
                                                   sampler=train_subset_sampler,
                                                   num_workers=2)

# TODO: define something like a validate() function
# # Define the validation function to track MIoU during the training
# def validate():

#     ...

#     return mean_intersection_over_union



In [3]:
%matplotlib notebook
from matplotlib import pyplot as plt

# Create the training plot
loss_current_iteration = 0
loss_history = []
loss_iteration_number_history = []

validation_current_iteration = 0
validation_history = []
validation_iteration_number_history = []

train_validation_current_iteration = 0
train_validation_history = []
train_validation_iteration_number_history = []
 
f, (loss_axis, validation_axis) = plt.subplots(2, 1)

loss_axis.plot(loss_iteration_number_history, loss_history)
validation_axis.plot(validation_iteration_number_history, validation_history, 'b',
                     train_validation_iteration_number_history, train_validation_history, 'r')

loss_axis.set_title('Training loss')
validation_axis.set_title('Score (to be defined) on validation dataset')

plt.tight_layout()

<IPython.core.display.Javascript object>

In [4]:
fcn = resnet_dilated.Resnet34_8s(num_classes=descriptor_dimensionality)
fcn.cuda()
fcn.train()

# note: the softmax happens inside the CrossEntropyLoss
criterion = nn.CrossEntropyLoss(size_average=False).cuda()

optimizer = optim.Adam(fcn.parameters(), lr=0.0001, weight_decay=0.0001)

  own_state[name].copy_(param)


In [5]:
import time

best_validation_score = 0

iter_size = 20

for epoch in range(130):  # loop over the dataset multiple times

    running_loss = 0.0
    
    for i, data in enumerate(trainloader, 0):
        start = time.time()
        
        # get the inputs
        img, anno = data      
        # img is 10 x 3 x 513 x 513 (10 was batch size, 3 is for RGB)
        # anno is 10 x 513 x 513 (same as above, but no rgb)
        
        # We need to flatten annotations and logits to apply index of valid
        # annotations. All of this is because pytorch doesn't have tf.gather_nd()
        anno_flatten = flatten_annotations(anno)
        index = get_valid_annotations_index(anno_flatten, mask_out_value=255)
        anno_flatten_valid = torch.index_select(anno_flatten, 0, index)

        # wrap them in Variable
        # the index can be acquired on the gpu
        img, anno_flatten_valid, index = Variable(img.cuda()), Variable(anno_flatten_valid.cuda()), Variable(index.cuda())

        # zero the parameter gradients
        optimizer.zero_grad()
        adjust_learning_rate(optimizer, loss_current_iteration)

        W = 513
        H = 513
        N = 1
        D_in = W*H*3
        num_Hidden = 1000
        D_descriptor = descriptor_dimensionality
        D_out = W*H*D_descriptor
        
        # forward + backward + optimize
        image_a_pred = fcn(img)
        image_a_pred = image_a_pred.view(N,W*H,D_descriptor)
        
        img_2 = img + torch.ones_like(img)
        image_b_pred = fcn(img_2)
        image_b_pred = image_b_pred.view(N,W*H,D_descriptor)
        
        # setup needed for matches
        # N is batch size; D_in is input dimension;
        # H is hidden dimension; D_out is output dimension.
        
        dtype = torch.cuda.FloatTensor
        dtype_long = torch.cuda.LongTensor
        
        # make 300 fake matches
        num_matches = 3
        matches_a = Variable(torch.zeros(num_matches).type(dtype_long), requires_grad=False)
        matches_a[0] = 20*W + 20
        matches_a[1] = 4*W  + 20
        matches_a[2] = 20*W + 5
        matches_a_cat = Variable(torch.zeros(1).type(dtype_long))
        for j in range(100):
            matches_a_cat = torch.cat((matches_a_cat, matches_a), 0)
        matches_a = matches_a_cat[:-1]

        matches_b = Variable(torch.zeros(num_matches).type(dtype_long), requires_grad=False)
        matches_b[0] = 10*W + 10
        matches_b[1] = 3*W  + 10
        matches_b[2] = 20*W + 5
        matches_b_cat = Variable(torch.zeros(1).type(dtype_long))
        for j in range(100):
            matches_b_cat = torch.cat((matches_b_cat, matches_b), 0)
        matches_b = matches_b_cat[:-1]
        
        # make 300 fake non matches
        num_non_matches = 3
        non_matches_a = Variable(torch.zeros(num_matches).type(dtype_long))
        non_matches_a[0] = 22*W + 10
        non_matches_a[1] = 4*W  + 10
        non_matches_a[2] = 20*W + 3
        non_matches_a_cat = Variable(torch.zeros(1).type(dtype_long))
        for i in range(100):
            non_matches_a_cat = torch.cat((non_matches_a_cat, non_matches_a), 0)
        non_matches_a = non_matches_a_cat[:-1]

        non_matches_b = Variable(torch.zeros(num_matches).type(dtype_long))
        non_matches_b[0] = 10*W + 5
        non_matches_b[1] = 3*W  + 7
        non_matches_b[2] = 20*W + 9
        non_matches_b_cat = Variable(torch.zeros(1).type(dtype_long))
        for i in range(100):
            non_matches_b_cat = torch.cat((non_matches_b_cat, non_matches_b), 0)
        non_matches_b = non_matches_b_cat[:-1]
        
        loss = 0
    
        # add loss via matches
        matches_a_descriptors = torch.index_select(image_a_pred, 1, matches_a)
        matches_b_descriptors = torch.index_select(image_b_pred, 1, matches_b)
        loss += (matches_a_descriptors - matches_b_descriptors).pow(2).sum()
        
        # add loss via non_matches
        M_margin = 0.5 # margin parameter
        non_matches_a_descriptors = torch.index_select(image_a_pred, 1, non_matches_a)
        non_matches_b_descriptors = torch.index_select(image_b_pred, 1, non_matches_b)
        pixel_wise_loss = (matches_a_descriptors - matches_b_descriptors).pow(2).sum(dim=2)
        pixel_wise_loss = torch.add(torch.neg(pixel_wise_loss), M_margin)
        zeros_vec = torch.zeros_like(pixel_wise_loss)
        loss += torch.max(pixel_wise_loss, zeros_vec).sum()
                
        loss.backward()
        optimizer.step()
        
        print time.time() - start, " seconds on gpu " + os.environ["CUDA_VISIBLE_DEVICES"]
        print (i, loss.data[0])



0.934857845306  seconds on gpu 0
(99, 150.0)
0.237175941467  seconds on gpu 0
(99, 214.29100036621094)
0.243100881577  seconds on gpu 0
(99, 194.12046813964844)
0.237711906433  seconds on gpu 0
(99, 428.8049621582031)
0.221394062042  seconds on gpu 0
(99, 2421.378662109375)
0.243629932404  seconds on gpu 0
(99, 683.4022216796875)
0.217438936234  seconds on gpu 0
(99, 252.01956176757812)
0.21884894371  seconds on gpu 0
(99, 1109.9967041015625)
0.219982147217  seconds on gpu 0
(99, 498.9869384765625)
0.225056886673  seconds on gpu 0
(99, 240.7310791015625)
0.217658996582  seconds on gpu 0
(99, 151.36599731445312)
0.254911184311  seconds on gpu 0
(99, 545.020751953125)
0.224786043167  seconds on gpu 0
(99, 436.1492614746094)
0.224161863327  seconds on gpu 0
(99, 960.2179565429688)
0.212999820709  seconds on gpu 0
(99, 197.9718780517578)
0.232845067978  seconds on gpu 0
(99, 1055.8067626953125)
0.219347000122  seconds on gpu 0
(99, 885.7457275390625)
0.223605155945  seconds on gpu 0
(99, 3

Process Process-3:
Process Process-1:
Process Process-2:
Process Process-4:
Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/lib/python2.7/multiprocessing/process.py", line 258, in _bootstrap
Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/lib/python2.7/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/usr/lib/python2.7/multiprocessing/process.py", line 258, in _bootstrap
  File "/usr/lib/python2.7/multiprocessing/process.py", line 258, in _bootstrap
  File "/usr/lib/python2.7/multiprocessing/process.py", line 114, in run
    self.run()
    self.run()
    self.run()
    self._target(*self._args, **self._kwargs)
  File "/usr/lib/python2.7/multiprocessing/process.py", line 114, in run
  File "/usr/lib/python2.7/multiprocessing/process.py", line 114, in run
  File "/usr/lib/python2.7/multiprocessing/process.py", line 114, in run
  File "/usr/local/lib/python2.7/dist-packages/torch/utils/data/da

0.225718021393  seconds on gpu 0
(99, 279.74505615234375)


KeyboardInterrupt: 