In [None]:
import argparse
import math
import os
import os.path as osp

import numpy as np
import scipy.misc
import torch
import torch.nn.functional as F
import tqdm
from torch.autograd import Variable


In [None]:
# HYPERPARAMS
max_iteration = 100000    # iteration budget for training
# NOTE(review): 1e-14 is extremely small; it presumably assumes a summed (not
# per-pixel-averaged) loss -- confirm against how cross_entropy2d is called.
lr = 1.0e-14              # SGD learning rate
momentum = 0.99           # SGD momentum
weight_decay = 0.0005     # SGD L2 weight decay
interval_validate = 4000  # run validation every this many training iterations
# Images per batch for the train/test DataLoaders. 1 is the conservative choice:
# only the images are resized by the dataset transform, so masks of different
# sizes may not stack into a larger batch.
batch_size = 1

In [None]:
# Use the GPU whenever PyTorch can see one.
cuda = torch.cuda.is_available()

# Reproducibility: fix the CPU generator's seed and, when a GPU is in use,
# the CUDA generator's seed as well.
torch.manual_seed(1337)
if cuda:
    torch.cuda.manual_seed(1337)

In [None]:
class CDATA(torch.utils.data.Dataset):  # Extend PyTorch's Dataset class
    """Segmentation dataset built from VOC2010 JPEG images and ``SegmentationPart`` masks.

    Every image/mask pair listed in the split file is loaded into memory in
    ``__init__``.

    Args:
        root_dir: directory that contains ``pascal_data/`` and ``VOCdevkit/``.
        train: if True read ``train_id.txt``, otherwise ``test_id.txt``.
        transform: optional callable applied to the RGB image in ``__getitem__``.
        target_transform: optional callable applied to the segmentation mask in
            ``__getitem__`` (e.g. a nearest-neighbour resize plus conversion to
            an integer tensor). Defaults to None, which returns the mask as loaded.
    """

    def __init__(self, root_dir, train, transform=None, target_transform=None):
        split_file = 'train_id.txt' if train else 'test_id.txt'
        rfile = osp.join(root_dir, 'pascal_data', 'pascal_data', split_file)
        ldir = osp.join(root_dir, 'VOCdevkit', 'VOC2010', 'JPEGImages')
        sdir = osp.join(root_dir, 'pascal_data', 'pascal_data', 'SegmentationPart')
        self.transform = transform
        self.target_transform = target_transform
        self.img = []
        self.seg = []

        # NOTE(review): `PIL` is not imported in the imports cell; add
        # `import PIL.Image` there, otherwise loading raises NameError.
        with open(rfile, 'r') as f:
            for line in f:
                # One image id per line. Strip the trailing newline, which would
                # otherwise end up inside the file name.
                image_id = line.strip()
                if not image_id:
                    continue  # tolerate blank lines (e.g. a trailing empty line)
                # Open each file in a `with` block so its handle is released
                # immediately. convert()/copy() return images that no longer
                # need the underlying file.
                with PIL.Image.open(osp.join(ldir, image_id + '.jpg')) as image:
                    self.img.append(image.convert('RGB'))
                with PIL.Image.open(osp.join(sdir, image_id + '.png')) as segment:
                    self.seg.append(segment.copy())

    def __len__(self):
        """Return the number of (image, mask) pairs loaded from the split file."""
        return len(self.img)

    def __getitem__(self, idx):
        """Return the tuple ``(image, mask)`` for sample ``idx``.

        ``transform`` (if set) is applied to the image and ``target_transform``
        (if set) to the mask; otherwise each is returned as stored.
        """
        image = self.img[idx]
        segment = self.seg[idx]
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            segment = self.target_transform(segment)
        return (image, segment)


In [None]:
# Resize images to 224x224 and convert them to tensors. This transform is
# applied to the images only, not to the segmentation masks.
# transforms.Resize is the current name of the deprecated transforms.Scale.
composed_transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])

# Dataset root (supply the proper root_dir), overridable via the PASCAL_ROOT
# environment variable. A trailing separator is enforced because CDATA may build
# paths by plain string concatenation.
DATA_ROOT = osp.join(os.environ.get('PASCAL_ROOT', '/extra_data/ayushya/'), '')
train_dataset = CDATA(root_dir=DATA_ROOT, train=True, transform=composed_transform)
test_dataset = CDATA(root_dir=DATA_ROOT, train=False, transform=composed_transform)

# Sanity-check the dataset sizes.
# NOTE(review): the previously quoted expected sizes (16854 / 1870) may have been
# carried over from a template -- confirm against train_id.txt / test_id.txt.
print('Size of train dataset: %d' % len(train_dataset))
print('Size of test dataset: %d' % len(test_dataset))


kwargs = {'num_workers': 4, 'pin_memory': True} if cuda else {}

# Create loaders for the dataset.
# NOTE(review): `batch_size` must be defined before this cell runs. The datasets
# above return masks as PIL images, which the default collate function cannot
# batch; a mask -> integer tensor transform is needed before training.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, **kwargs)


In [None]:
# MODEL
# TODO: define the network here, e.g.
#     model = torchfcn.models.FCN8s(n_class=21)
# If `model` is not defined by this point, the `.cuda()` call below raises
# NameError on machines with a GPU.

# Resume bookkeeping; train_model does not read these as written.
resume = 0
start_epoch, start_iteration = 0, 0

# Move the network's parameters to the GPU when one is available.
if cuda:
    model = model.cuda()


In [None]:
# LOSS
def cross_entropy2d(input, target, weight=None, size_average=True):
    """Pixel-wise cross-entropy for dense (segmentation) predictions.

    Args:
        input: raw class scores of shape (n, c, h, w).
        target: integer class labels of shape (n, h, w). Pixels with a negative
            label are ignored.
        weight: optional per-class weight tensor of length c, forwarded to
            ``F.nll_loss``.
        size_average: if True, divide the summed loss by the number of
            non-ignored pixels; otherwise return the sum.

    Returns:
        Scalar loss tensor. It is 0 when every pixel is ignored.

    Raises:
        ValueError: if ``target`` does not have shape (n, h, w).
    """
    n, c, h, w = input.size()
    if target.size() != (n, h, w):
        raise ValueError('target shape %s does not match input (n, h, w) = %s'
                         % (tuple(target.size()), (n, h, w)))
    # Log-probabilities over the class axis: (n, c, h, w)
    log_p = F.log_softmax(input, dim=1)
    # (n, h, w, c) so each pixel's class scores form the last axis
    log_p = log_p.permute(0, 2, 3, 1).contiguous()
    # Keep only pixels with a valid (non-negative) label
    mask = target >= 0
    log_p = log_p[mask]    # (num_valid, c)
    target = target[mask]  # (num_valid,)
    loss = F.nll_loss(log_p, target, weight=weight, reduction='sum')
    if size_average:
        # Guard against 0/0 when every pixel is ignored (loss is 0 then)
        loss = loss / max(int(mask.sum()), 1)
    return loss


In [None]:
# OPTIMIZER
# SGD with momentum and L2 weight decay over every model parameter;
# lr / momentum / weight_decay come from the hyperparameter cell.
optim = torch.optim.SGD(params=model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)

In [None]:
# VALIDATION
def validate(iteration):
    """Evaluate the global ``model`` on ``test_loader`` and return the mean loss.

    The model is put in eval mode for the pass. Its previous train/eval mode is
    restored afterwards, because this is called in the middle of training.

    Args:
        iteration: current training iteration; used only in the progress-bar label.

    Returns:
        float: the per-batch loss divided by the batch size, averaged over all
        batches. It is 0.0 when ``test_loader`` is empty.

    Raises:
        ValueError: if the loss becomes NaN.
    """
    was_training = model.training
    model.eval()
    val_loss = 0.0
    try:
        # No autograd graph is needed for evaluation
        with torch.no_grad():
            for data, target in tqdm.tqdm(
                    test_loader, total=len(test_loader),
                    desc='Valid iteration=%d' % iteration, ncols=80,
                    leave=False):
                if cuda:
                    data, target = data.cuda(), target.cuda()
                score = model(data)

                loss_value = cross_entropy2d(score, target).item()
                if np.isnan(loss_value):
                    raise ValueError('loss is nan while validating')
                val_loss += loss_value / len(data)
                # TODO: collect predictions for segmentation metrics and
                # optionally save prediction images.
    finally:
        model.train(was_training)

    if len(test_loader) > 0:
        val_loss /= len(test_loader)
    return val_loss

        

In [None]:
# TRAINING
def train_model():
    """Train the global ``model`` with ``optim`` on ``train_loader``.

    Performs at most ``max_iteration`` optimizer steps (iterations
    ``0 .. max_iteration - 1``). Calls ``validate(iteration)`` whenever
    ``iteration`` is a multiple of ``interval_validate``, starting at iteration 0.

    Raises:
        ValueError: if ``train_loader`` is empty or the training loss becomes NaN.
    """
    iters_per_epoch = len(train_loader)
    if iters_per_epoch == 0:
        raise ValueError('train_loader is empty')
    max_epoch = int(math.ceil(1. * max_iteration / iters_per_epoch))
    model.train()
    for epoch in tqdm.trange(0, max_epoch, desc='Train', ncols=80):
        for batch_idx, (data, target) in tqdm.tqdm(
                enumerate(train_loader), total=iters_per_epoch,
                desc='Train epoch=%d' % epoch, ncols=80, leave=False):
            iteration = batch_idx + epoch * iters_per_epoch

            if iteration % interval_validate == 0:
                validate(iteration)
                model.train()  # make sure the next step runs in training mode

            # Stop before stepping past the budget; `return` exits both loops
            if iteration >= max_iteration:
                return

            if cuda:
                data, target = data.cuda(), target.cuda()
            optim.zero_grad()
            score = model(data)

            loss = cross_entropy2d(score, target)
            loss = loss / len(data)
            if np.isnan(loss.item()):
                raise ValueError('loss is nan while training')
            loss.backward()
            optim.step()

            # TODO: compute per-batch metrics (pixel accuracy, mean IU, ...), e.g.
            # torchfcn.utils.label_accuracy_score on score.data.max(1)[1] vs target.

