In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import tensorboardX
import PIL
import os
import time
import h5py
import matplotlib.pyplot as plt
import itertools

from deeplabv3 import deeplabv3_resnet101

In [2]:
import train_unary
import train_potts

In [3]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [4]:
# Location of the Pascal VOC 2012 data. The three names below are passed to
# the dataset classes together with `voc_dir` (presumably as sub-paths of it
# -- confirm in train_unary).
voc_dir = "VOCdevkit/VOC2012"
image_dir, segmentation_dir = "JPEGImages", "SegmentationClass"
names_dir = "ImageSets/Segmentation"

In [5]:
# Paths into the `benchmark_RELEASE` dataset tree that are handed to
# ExpandedVOCDataset: label directory, image directory and the split file.
contour_dir, contour_img_dir = (
    "benchmark_RELEASE/dataset/cls",
    "benchmark_RELEASE/dataset/img",
)
contour_names = "train_noval.txt"

In [6]:
# Training hyper-parameters shared by every loader and the LR schedule.
batch_size = 6
n_epochs = 10

# Pascal VOC 'train' split.
train_voc_dataset = train_unary.PascalVOCDataset(
    voc_dir, image_dir, segmentation_dir, names_dir, 'train',
    batch_size=batch_size,
)
train_voc_loader = torch.utils.data.DataLoader(
    train_voc_dataset,
    batch_size=batch_size,
    shuffle=True,
)

# Pascal VOC 'val' split, built the same way as the training split.
val_voc_dataset = train_unary.PascalVOCDataset(
    voc_dir, image_dir, segmentation_dir, names_dir, 'val',
    batch_size=batch_size,
)
val_voc_loader = torch.utils.data.DataLoader(
    val_voc_dataset,
    batch_size=batch_size,
    shuffle=True,
)

In [7]:
# Expanded training set: the VOC 'train' arguments plus the benchmark_RELEASE
# paths (presumably extra annotations merged into the split -- confirm in
# train_unary.ExpandedVOCDataset).
train_expand_dataset = train_unary.ExpandedVOCDataset(
    voc_dir, image_dir, segmentation_dir, names_dir, 'train',
    contour_dir, contour_img_dir, contour_names, batch_size)
train_expand_loader = torch.utils.data.DataLoader(train_expand_dataset, batch_size=batch_size, shuffle=False)

# Total number of batches over the whole run; used as the horizon of the
# learning-rate schedule. len(loader) is the number of batches per epoch and
# already accounts for a partial final batch, whereas floor-dividing the
# sample count by batch_size would undercount whenever the dataset size is
# not a multiple of batch_size.
n_iter = n_epochs * len(train_expand_loader)
print(n_iter)

11810


In [8]:
# 'val' split served one sample per batch, in fixed order; this loader is the
# second argument given to train_potts.train_final below.
default_dataset = train_unary.DefaultDataset(
    voc_dir, image_dir, segmentation_dir, names_dir, 'val',
    batch_size=1,
)
default_loader = torch.utils.data.DataLoader(
    default_dataset,
    batch_size=1,
    shuffle=False,
)

In [9]:
class PolyLrDecay():
    """Polynomial ("poly") learning-rate decay factor.

    ``step(cur_iter)`` returns ``(1 - cur_iter / n_iter) ** power``, a
    multiplier that is 1 at iteration 0 and reaches 0 at ``n_iter``. The bound
    method is suitable as the ``lr_lambda`` callable of
    ``torch.optim.lr_scheduler.LambdaLR``.

    Args:
        n_iter: Number of iterations over which the factor decays to zero.
            Must be positive.
        power: Exponent applied to the remaining fraction of the schedule.

    Raises:
        ValueError: If ``n_iter`` is not positive.
    """

    def __init__(self, n_iter, power):
        if n_iter <= 0:
            raise ValueError("n_iter must be positive, got {!r}".format(n_iter))
        self.n_iter = n_iter
        self.power = power

    def step(self, cur_iter):
        """Return the decay multiplier for iteration ``cur_iter``.

        The remaining fraction is clamped at zero: once ``cur_iter`` exceeds
        ``n_iter`` the unclamped base would be negative, and a negative float
        raised to a fractional power evaluates to a complex number in
        Python 3, which is not a valid learning-rate factor.
        """
        remaining = max(0.0, 1.0 - cur_iter / self.n_iter)
        return remaining ** self.power


In [11]:
# Build the unary network and restore its weights from 'unary_final.pth'.
unary = deeplabv3_resnet101()
# map_location remaps the stored tensors onto the selected device, so a
# checkpoint written from a CUDA run can still be loaded when `device` has
# fallen back to the CPU.
unary.load_state_dict(torch.load('unary_final.pth', map_location=device))
unary = unary.to(device)

In [12]:
# Pairwise model built by train_potts from the same checkpoint file as the
# unary network. NOTE(review): the meaning of `64` and `last_layer=2` is
# defined in train_potts.load_pairwise -- confirm there.
pairwise = train_potts.load_pairwise(
    'unary_final.pth',
    64,
    last_layer=2,
)

In [13]:
# A single Adam optimizer over the unary classifier's parameters followed by
# all pairwise parameters; the remaining unary parameters are not handed to it.
opt = torch.optim.Adam(
    [*unary.classifier.parameters(), *pairwise.parameters()],
    lr=3e-6,
)

In [14]:
# Scale the base learning rate by the polynomial decay factor (power 0.9)
# computed over the n_iter-step horizon.
sch = torch.optim.lr_scheduler.LambdaLR(
    optimizer=opt,
    lr_lambda=PolyLrDecay(n_iter, 0.9).step,
)

In [None]:
# Train the unary and pairwise models together, starting from epoch 0 and
# iteration 0. train_final hands back both models plus an iteration counter.
unary, pairwise, iter_n = train_potts.train_final(
    train_expand_loader,
    default_loader,
    unary,
    pairwise,
    opt,
    sch,
    n_epochs,
    start_epoch=0,
    iter_n=0,
)