In [1]:
%matplotlib inline

import sys, os
sys.path.insert(0, '../../../../vision/')
sys.path.append('../../../../../pytorch-segmentation-detection/')

# Select which GPU CUDA may use ('0' is the first GPU); change the index to use a different one
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

from PIL import Image
from matplotlib import pyplot as plt

import torch
from torchvision import transforms
import torchvision
from torch.autograd import Variable
from pytorch_segmentation_detection.layers import GlobalAvgPool2d

import numpy as np
import torch.nn as nn

from pytorch_segmentation_detection.datasets.pascal_voc import PascalVOCSegmentation

import pytorch_segmentation_detection.models.fcn as fcns
import pytorch_segmentation_detection.models.resnet_dilated as resnet_dilated
from pytorch_segmentation_detection.transforms import (ComposeJoint,
                                                       RandomHorizontalFlipJoint,
                                                       RandomScaleJoint,
                                                       CropOrPad,
                                                       ResizeAspectRatioPreserve,
                                                       RandomCropJoint,
                                                       Split2D)


import torch.optim as optim

from sklearn.metrics import confusion_matrix

def flatten_logits(logits, number_of_classes):
    """Collapse a batch of logits into a 2-D matrix with the class axis last.

    Axis 1 of the 4-D ``logits`` tensor is moved to the end, and every
    remaining axis is merged into the first dimension, so each row of the
    result holds ``number_of_classes`` values.

    Args:
        logits: 4-D tensor whose axis 1 has ``number_of_classes`` entries.
        number_of_classes: size of the last dimension of the result.

    Returns:
        Tensor of shape ``(-1, number_of_classes)``.
    """
    # permute() only changes strides; contiguous() materialises the new
    # layout so that view() is allowed to reinterpret it.
    channels_last = logits.permute(0, 2, 3, 1).contiguous()
    return channels_last.view(-1, number_of_classes)

def flatten_annotations(annotations):
    """Flatten an annotation tensor into a 1-D tensor.

    Args:
        annotations: tensor of any shape.

    Returns:
        1-D tensor containing every element of ``annotations`` in row-major
        order.
    """
    # view() requires contiguous memory; contiguous() is a no-op for tensors
    # that already are, and makes transposed/sliced inputs work as well.
    return annotations.contiguous().view(-1)

def get_valid_annotations_index(flatten_annotations, mask_out_value=255):
    """Return the positions of all annotations that are not masked out.

    Args:
        flatten_annotations: 1-D tensor of annotation values.
        mask_out_value: value marking entries to ignore (default 255).

    Returns:
        1-D tensor with the indices of entries different from
        ``mask_out_value``.
    """
    is_valid = flatten_annotations != mask_out_value
    # nonzero() yields one row per hit, shape (num_hits, 1) for 1-D input;
    # drop the trailing axis to obtain a flat index vector.
    valid_positions = torch.nonzero(is_valid)
    return valid_positions.squeeze(1)

def adjust_learning_rate(optimizer, iteration, max_iteration=10000.0,
                         base_lr=0.0001, power=0.9):
    """Set the learning rate of every parameter group by polynomial decay.

    The rate is ``base_lr * (1 - iteration / max_iteration) ** power``.
    Once ``iteration`` reaches or exceeds ``max_iteration`` the rate is held
    at 0 instead of raising a negative number to a fractional power (which
    raises ValueError on Python 2 and yields a complex number on Python 3).

    Args:
        optimizer: object exposing ``param_groups``, a list of dicts whose
            ``'lr'`` entry is overwritten.
        iteration: current training iteration (0-based).
        max_iteration: iteration at which the learning rate reaches 0.
        base_lr: learning rate used at iteration 0.
        power: exponent of the polynomial decay.
    """
    # float() guards against integer division on Python 2 when both
    # arguments are ints.
    remaining = 1.0 - (iteration / float(max_iteration))

    # Clamp so that overshooting max_iteration gives lr == 0, not an error.
    remaining = max(remaining, 0.0)

    lr = base_lr * (remaining ** power)

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


from pytorch_segmentation_detection.transforms import RandomCropJoint


# This has to stay at 21 until the loss function is changed:
# the data has labels for 21 classes, and we get an error
# unless this is 21.
descriptor_dimensionality = 21

In [2]:
# Per-channel mean / std used to normalise the RGB input (the standard
# torchvision ImageNet statistics).
NORMALIZE_MEAN = (0.485, 0.456, 0.406)
NORMALIZE_STD = (0.229, 0.224, 0.225)

# Spatial size of the random crop taken from every training sample.
TRAIN_CROP_SIZE = (513, 513)

# Number of leading training samples drawn by train_subset_loader.
TRAIN_SUBSET_SIZE = 904


def _annotation_to_long_tensor(annotation):
    """Convert an annotation image into a LongTensor of its values."""
    return torch.from_numpy(np.asarray(annotation)).long()


def _to_tensor_steps():
    """Build the tensor-conversion steps shared by training and validation.

    Each two-element entry pairs a transform for the image with one for the
    annotation.
    NOTE(review): ``None`` appears to mean "leave this element unchanged" --
    confirm against ComposeJoint.
    """
    return [
        [transforms.ToTensor(), None],
        [transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD), None],
        [None, transforms.Lambda(_annotation_to_long_tensor)],
    ]


# Training: random flip + random crop, then conversion to tensors.
train_transform = ComposeJoint(
    [
        RandomHorizontalFlipJoint(),
        RandomCropJoint(crop_size=TRAIN_CROP_SIZE),
    ] + _to_tensor_steps())

# Dataset root. Override with the PASCAL_VOC_ROOT environment variable
# instead of editing the notebook; the default is the original local path.
path_to_VOC = os.environ.get('PASCAL_VOC_ROOT',
                             '/media/peteflo/3TBbackup/pytorch-pretrained/VOC')

trainset = PascalVOCSegmentation(path_to_VOC,
                                 download=False,
                                 joint_transform=train_transform)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=1,
                                          shuffle=True, num_workers=4, drop_last=True)

# Validation: no augmentation, only conversion to tensors.
valid_transform = ComposeJoint(_to_tensor_steps())

valset = PascalVOCSegmentation(path_to_VOC,
                               train=False,
                               download=False,
                               joint_transform=valid_transform)

valset_loader = torch.utils.data.DataLoader(valset, batch_size=1,
                                            shuffle=False, num_workers=2)

# Loader over a random ordering of the first TRAIN_SUBSET_SIZE training
# samples. list(range(...)) works on both Python 2 and Python 3, unlike
# xrange.
train_subset_sampler = torch.utils.data.sampler.SubsetRandomSampler(
    list(range(TRAIN_SUBSET_SIZE)))
train_subset_loader = torch.utils.data.DataLoader(dataset=trainset, batch_size=1,
                                                   sampler=train_subset_sampler,
                                                   num_workers=2)

# TODO: define something like a validate() function
# # Define the validation function to track MIoU during the training
# def validate():

#     ...

#     return mean_intersection_over_union



In [3]:
%matplotlib notebook
from matplotlib import pyplot as plt

# Set up the live training figure.
#
# Each curve has a running iteration counter plus two parallel lists:
# the iteration numbers and the values recorded at those iterations.
loss_current_iteration = 0
validation_current_iteration = 0
train_validation_current_iteration = 0

loss_history, loss_iteration_number_history = [], []
validation_history, validation_iteration_number_history = [], []
train_validation_history, train_validation_iteration_number_history = [], []

# Two stacked panels: training loss on top, scores underneath.
f, (loss_axis, validation_axis) = plt.subplots(nrows=2, ncols=1)

loss_axis.set_title('Training loss')
loss_axis.plot(loss_iteration_number_history, loss_history)

# Blue: validation-set history; red: train-subset history.
validation_axis.set_title('Score (to be defined) on validation dataset')
validation_axis.plot(
    validation_iteration_number_history, validation_history, 'b',
    train_validation_iteration_number_history, train_validation_history, 'r'
)

plt.tight_layout()

<IPython.core.display.Javascript object>

In [4]:
# Loss: summed (not averaged) cross-entropy.
# note: the softmax happens inside the CrossEntropyLoss
criterion = nn.CrossEntropyLoss(size_average=False)
criterion = criterion.cuda()

# Network: Resnet34_8s with one output channel per descriptor dimension,
# moved to the GPU and switched to training mode.
fcn = resnet_dilated.Resnet34_8s(num_classes=descriptor_dimensionality)
fcn.cuda()
fcn.train()

# Optimizer: Adam over all network parameters. The learning rate given here
# is the initial value; adjust_learning_rate() overwrites it during training.
optimizer = optim.Adam(fcn.parameters(), lr=0.0001, weight_decay=0.0001)

  own_state[name].copy_(param)


In [5]:
import time

best_validation_score = 0  # not used yet: no validation is run in this loop

iter_size = 20  # not used yet

# Width of every training crop; must match the crop_size given to
# RandomCropJoint when train_transform is built (513 x 513).
W = 513

# Number of optimisation steps to run. Keep in sync with max_iteration in
# adjust_learning_rate(): the schedule reaches a learning rate of 0 there,
# so further steps would be useless (and, depending on the schedule
# implementation, would raise on a negative base).
max_training_iterations = 10000


def _flatten_descriptors(prediction):
    """Reshape a network output into per-pixel descriptor rows.

    NOTE(review): assumes ``prediction`` is channel-first, (N, D, H, W) --
    the same layout flatten_logits() in this notebook expects; confirm
    against Resnet34_8s.

    The channel axis is moved last *before* flattening, so that row
    ``r * W + c`` of the result is the D-dimensional descriptor of pixel
    (r, c). Calling view() directly on the channel-first tensor would mix
    values belonging to different pixels into one row.

    Returns:
        Tensor of shape (N, H * W, D).
    """
    n, d, h, w = prediction.size()
    return prediction.permute(0, 2, 3, 1).contiguous().view(n, h * w, d)


def _make_fake_matches(pixel_coords, width, repeats=100):
    """Build a CUDA index tensor of flattened pixel positions.

    Args:
        pixel_coords: list of (row, col) pixel coordinates.
        width: image width used to flatten (row, col) into row * width + col.
        repeats: how many times the coordinate list is repeated.

    Returns:
        Long Variable on the GPU with len(pixel_coords) * repeats indices.
    """
    flat_indices = [row * width + col for (row, col) in pixel_coords] * repeats
    return Variable(torch.cuda.LongTensor(flat_indices), requires_grad=False)


# 300 fake matches (3 distinct pixel pairs, repeated 100 times). They do not
# depend on the batch, so they are built once, outside the loop.
matches_a = _make_fake_matches([(20, 20), (4, 20), (20, 5)], W)
matches_b = _make_fake_matches([(10, 10), (3, 10), (20, 5)], W)

training_finished = False

for epoch in range(130):  # loop over the dataset multiple times

    for i, data in enumerate(trainloader, 0):
        start = time.time()

        # img is N x 3 x 513 x 513 (N is the batch size, 3 is for RGB).
        # anno (N x 513 x 513) is not used by the descriptor loss below.
        img, anno = data
        img = Variable(img.cuda())

        # zero the parameter gradients and apply the learning-rate schedule
        optimizer.zero_grad()
        adjust_learning_rate(optimizer, loss_current_iteration)

        # forward pass on the image ...
        image_a_pred = _flatten_descriptors(fcn(img))

        # ... and on a stand-in "second view": the same image shifted by +1
        img_2 = img + torch.ones_like(img)
        image_b_pred = _flatten_descriptors(fcn(img_2))

        # loss: squared distance between the descriptors of matched pixels
        matches_a_descriptors = torch.index_select(image_a_pred, 1, matches_a)
        matches_b_descriptors = torch.index_select(image_b_pred, 1, matches_b)
        loss = (matches_a_descriptors - matches_b_descriptors).pow(2).sum()

        loss.backward()
        optimizer.step()

        # Advance the counter that drives adjust_learning_rate(); without
        # this the schedule would stay at iteration 0 forever.
        loss_current_iteration += 1

        # .item() exists on PyTorch >= 0.4; older versions index .data.
        loss_value = loss.item() if hasattr(loss, 'item') else loss.data[0]

        # Written as function calls with a single argument so that the
        # output is the same on Python 2 and Python 3.
        print("%s  seconds on gpu %s" % (time.time() - start,
                                         os.environ["CUDA_VISIBLE_DEVICES"]))
        print((i, loss_value))

        if loss_current_iteration >= max_training_iterations:
            training_finished = True
            break

    if training_finished:
        break



0.889913082123  seconds on gpu 0
(99, 711.9472045898438)
0.24299788475  seconds on gpu 0
(99, 3319.25244140625)
0.213974952698  seconds on gpu 0
(99, 962.6979370117188)
0.250316143036  seconds on gpu 0
(99, 764.55859375)
0.273765087128  seconds on gpu 0
(99, 1479.88134765625)
0.218839883804  seconds on gpu 0
(99, 1586.27685546875)
0.220192909241  seconds on gpu 0
(99, 173.2860870361328)
0.217671871185  seconds on gpu 0
(99, 926.4786376953125)
0.216414928436  seconds on gpu 0
(99, 551.9984741210938)
0.219424009323  seconds on gpu 0
(99, 564.038818359375)
0.223541975021  seconds on gpu 0
(99, 604.707763671875)
0.275671005249  seconds on gpu 0
(99, 62.91255187988281)
0.225982904434  seconds on gpu 0
(99, 347.6622314453125)
0.236505031586  seconds on gpu 0
(99, 291.3843688964844)
0.217339992523  seconds on gpu 0
(99, 318.18701171875)
0.219408035278  seconds on gpu 0
(99, 210.2272491455078)
0.226052045822  seconds on gpu 0
(99, 781.9302368164062)
0.222570896149  seconds on gpu 0
(99, 2296.2

Process Process-3:
Process Process-1:
Process Process-4:
Traceback (most recent call last):
Process Process-2:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/lib/python2.7/multiprocessing/process.py", line 258, in _bootstrap
  File "/usr/lib/python2.7/multiprocessing/process.py", line 258, in _bootstrap
  File "/usr/lib/python2.7/multiprocessing/process.py", line 258, in _bootstrap
  File "/usr/lib/python2.7/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
    self.run()
    self.run()
  File "/usr/lib/python2.7/multiprocessing/process.py", line 114, in run
    self.run()
  File "/usr/lib/python2.7/multiprocessing/process.py", line 114, in run
  File "/usr/lib/python2.7/multiprocessing/process.py", line 114, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/lib/python2.7/multiprocessing/process.py", line 114, in run
    self._target(*self._args, **self._kwargs)
    self._target(*se

0.220783948898  seconds on gpu 0
(99, 28.100666046142578)


KeyboardInterrupt: 