In [9]:
%load_ext autoreload
%autoreload 2
import os

from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import Cityscapes
from torchvision.utils import save_image
from unet import PavelNet

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [10]:
# make a folder to save output images
if not os.path.exists('./mlp_img'):
    os.mkdir('./mlp_img')
    

# a function used to transform numpy array to image format
def to_img(x):
    x = 0.5 * (x + 1)
    x = x.clamp(0, 1)
    x = x.view(x.size(0), x.size(1), x.size(2), x.size(3))
    return x

#################### select your hyperparameters ############################
num_epochs = 1
batch_size = 1
n_samples = 1

####### define image transforms, you can have other choices, explore it! #####
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

# download dataset
dataset = Cityscapes('./data/cityscapes', transform=transform, target_transform=transform, target_type='color')
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [11]:
model = PavelNet()
############ chose appropriate loss function ######################
criterion = nn.L1Loss()

# set the optimizer, explore the effect of different optimizers
lr = 0.1
optimizer = optim.RMSprop(model.parameters(), lr=lr, weight_decay=1e-8)

In [12]:
for epoch in range(num_epochs):
    for i, data in enumerate(dataloader):
        if i > n_samples:
            break
        img, mask_with_alpha = data
        mask = mask_with_alpha[:, :3, :, :]
        #img = img.view(img.size(0), -1)
        # ===================forward=====================
        img_pred = model(mask)
        loss = criterion(img_pred, img)
        # ===================backward====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # ===================user interaction========================
    print('epoch [{}/{}], loss:{:.4f}'.format(epoch + 1, num_epochs, loss.data))
    if epoch % 1 == 0:
        pic = to_img(img_pred.data)
        save_image(pic, './mlp_img/image_{}.png'.format(epoch))
        
        ori_pic = to_img(img.data)
        save_image(ori_pic, './mlp_img/ori_image_{}.png'.format(epoch))
        
        mask_pic = to_img(mask.data)
        save_image(mask_pic, './mlp_img/mask_image_{}.png'.format(epoch))

epoch [1/1], loss:0.7369
