In [0]:
! git clone https://github.com/sigtot/unet-auto unet

Cloning into 'unet'...
remote: Enumerating objects: 441, done.[K
remote: Counting objects: 100% (441/441), done.[K
remote: Compressing objects: 100% (233/233), done.[K
remote: Total 441 (delta 246), reused 393 (delta 200), pack-reused 0[K
Receiving objects: 100% (441/441), 48.78 MiB | 31.52 MiB/s, done.
Resolving deltas: 100% (246/246), done.


# New Section

In [0]:
! git -C unet checkout master

Already on 'master'
Your branch is up to date with 'origin/master'.


In [0]:
! git -C unet pull

Already up to date.


In [0]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/drive


In [0]:
%load_ext autoreload
%autoreload 2
import os

import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import Cityscapes
from torchvision.utils import save_image
from unet.unet import PavelNet
from unet.unet import SigurdModel
import datetime
import numpy as np
import random

In [0]:
# set device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

print(f"Using device {device}")

# make a folder to save output images
if not os.path.exists('./mlp_img'):
    os.mkdir('./mlp_img')
    

# a function used to transform numpy array to image format
def to_img(x):
    #x = 0.5 * (x + 1)
    x = x.clamp(0, 1)
    x = x.view(x.size(0), x.size(1), x.size(2), x.size(3))
    return x

#################### select your hyperparameters ############################
num_epochs = 200
batch_size = 1
n_samples = 10

####### define image transforms, you can have other choices, explore it! #####
class DoubleCompose(transforms.Compose):
    def __init__(self, *args, **kwargs):
        super(DoubleCompose, self).__init__(*args, **kwargs)

    def __call__(self, image, target):
        seed = np.random.randint(2147483647)
        random.seed(seed)
        t_image = super(DoubleCompose, self).__call__(image)
        random.seed(seed)
        t_target = super(DoubleCompose, self).__call__(target)
        return t_image, t_target

transform = DoubleCompose([
    transforms.Resize((286, 286)),
    transforms.RandomCrop((256, 256)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

# download dataset
dataset = Cityscapes('/content/drive/My Drive/citytiny', transforms=transform, target_type='color')
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, pin_memory=True)

Using device cuda:0


In [0]:
model = SigurdModel(lr=0.0002, betas=(0.5, 0.999), weight_decay=1e-8, disc_mult=0.5, l1_mult=100, l1_only=False)
model.to(device)

SigurdModel(
  (discriminator): ArtNet(
    (model): Sequential(
      (0): Conv2d(6, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): LeakyReLU(negative_slope=0.2, inplace=True)
      (2): Sequential(
        (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.2, inplace=True)
      )
      (3): Sequential(
        (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.2, inplace=True)
      )
      (4): Sequential(
        (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.2, inplace=True)
      )
      (5): Conv2

In [0]:
model_save_path = "/content/drive/My Drive/models/model_cgan_l1_new.pt"

In [0]:
refresh_model = False
if os.path.isfile(model_save_path) and not refresh_model:
    checkpoint = torch.load(model_save_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    start_epoch = checkpoint['epoch']
    loss_G = checkpoint['loss_G']
    loss_D = checkpoint['loss_D']
    print(f"Loaded {start_epoch + 1} epochs from file")
else:
    start_epoch = 0
    print("Starting with a fresh model")

Loaded 189 epochs from file


In [0]:
for epoch in range(start_epoch, num_epochs):
    for i, data in enumerate(dataloader):
        if i > 10:
            exit(0)
        img, mask_with_alpha = data
        mask = mask_with_alpha[:, :3, :, :]
        img = img.to(device)
        mask = mask.to(device)
        #img = img.view(img.size(0), -1)
        
        img_pred = model.forward(mask, img)
        loss_G, loss_D = model.backward()

    # ===================user interaction========================
    print('epoch [{}/{}], loss G:{:.4f}, loss D: {:.4f}'.format(epoch + 1, num_epochs, loss_G, loss_D))
    if epoch % 1 == 0:
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'loss_G': loss_G,
            'loss_D': loss_D
            }, model_save_path)

        pic = to_img(img_pred.data)
        save_image(pic, './mlp_img/image_{}.png'.format(epoch))
        
        ori_pic = to_img(img.data)
        save_image(ori_pic, './mlp_img/ori_image_{}.png'.format(epoch))
        
        mask_pic = to_img(mask.data)
        save_image(mask_pic, './mlp_img/mask_image_{}.png'.format(epoch))


epoch [189/200], loss G:7.5662, loss D: 0.0000
epoch [190/200], loss G:4.9392, loss D: 0.0001
epoch [191/200], loss G:7.8371, loss D: 0.0000
epoch [192/200], loss G:8.5910, loss D: 0.0000
epoch [193/200], loss G:6.8394, loss D: 0.0000
epoch [194/200], loss G:5.9928, loss D: 0.0001
epoch [195/200], loss G:6.0210, loss D: 0.0000
epoch [196/200], loss G:5.5521, loss D: 0.0000
epoch [197/200], loss G:6.7824, loss D: 0.0000
epoch [198/200], loss G:5.3400, loss D: 0.0000
epoch [199/200], loss G:6.0071, loss D: 0.0000
epoch [200/200], loss G:5.4111, loss D: 0.0000


In [0]:
class CityScapesWithPaths(Cityscapes):
    def __init__(self, root, split='train', mode='fine', target_type='instance',
        transform=None, target_transform=None, transforms=None):
        super(CityScapesWithPaths, self).__init__(root, split=split, mode=mode, target_type=target_type,transform=transform, target_transform=target_transform, transforms=transforms)

    def __getitem__(self, index):
        original_tuple = super(CityScapesWithPaths, self).__getitem__(index)
        path = self.images[index]
        tuple_with_path = ((path,) + original_tuple)
        return tuple_with_path

batch_size = 1
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])
dataset_val = CityScapesWithPaths('/content/drive/My Drive/citysmall', transform=transform, target_transform=transform, target_type='color', split="val")
num_imgs = 20

if not os.path.exists('./val_img'):
    os.mkdir('./val_img')

for i, data in enumerate(dataset_val):
    if i >= num_imgs:
        break
    path, img, mask_with_alpha = data
    save_image(img, './val_img/ori_image_{}.jpg'.format(i))
    save_image(mask_with_alpha, './val_img/mask_image_{}.png'.format(i))

    mask = mask_with_alpha[:3, :, :]

    mask = mask.unsqueeze(0)
    img = img.unsqueeze(0)

    img = img.to(device)
    mask = mask.to(device)
    
    img_pred = model.forward(mask, img)

    pic = to_img(img_pred.data)
    save_image(img_pred, './val_img/image_{}.jpg'.format(i))
    


/content/drive/My Drive/citysmall/leftImg8bit/val/lindau/lindau_000035_000019_leftImg8bit.png
/content/drive/My Drive/citysmall/leftImg8bit/val/lindau/lindau_000043_000019_leftImg8bit.png
/content/drive/My Drive/citysmall/leftImg8bit/val/lindau/lindau_000012_000019_leftImg8bit.png
/content/drive/My Drive/citysmall/leftImg8bit/val/lindau/lindau_000052_000019_leftImg8bit.png
/content/drive/My Drive/citysmall/leftImg8bit/val/lindau/lindau_000000_000019_leftImg8bit.png
/content/drive/My Drive/citysmall/leftImg8bit/val/lindau/lindau_000019_000019_leftImg8bit.png
/content/drive/My Drive/citysmall/leftImg8bit/val/lindau/lindau_000015_000019_leftImg8bit.png
/content/drive/My Drive/citysmall/leftImg8bit/val/lindau/lindau_000026_000019_leftImg8bit.png
/content/drive/My Drive/citysmall/leftImg8bit/val/lindau/lindau_000007_000019_leftImg8bit.png
/content/drive/My Drive/citysmall/leftImg8bit/val/lindau/lindau_000042_000019_leftImg8bit.png
/content/drive/My Drive/citysmall/leftImg8bit/val/lindau/lin