## pix2pix with PyTorch Lightning

### Data loader

In [1]:
%matplotlib inline
import os
import sys
from PIL import Image
import re
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

In [2]:
import os
import torchvision
from argparse import ArgumentParser, Namespace
from collections import OrderedDict

In [3]:
from models.pix2pix.datasets import ImageDataset, FloorplanDataset

import albumentations as A
import torchvision.transforms as transforms

ModuleNotFoundError: No module named 'models.pix2pix'; 'models' is not a package

In [None]:
# Dataset sub-directory (joined onto ./datasets when the dataset is built)
# and the square size passed to A.Resize.
dataset_name = "floor/newyork"
img_height = img_width = 256

In [None]:
# Transform list handed to FloorplanDataset: the resize comes first, followed
# by the augmentations (all left at their library defaults except Rotate(23)).
resize = A.Resize(img_height, img_width)
augmentations = [
    A.Rotate(23),
    A.HorizontalFlip(),
    A.RandomBrightnessContrast(),
    A.HueSaturationValue(),
    A.RGBShift(),
    A.RandomGamma(),
]
transforms_ = [resize, *augmentations]

In [None]:
# Build the dataset from ./datasets/<dataset_name> in "test" mode, applying
# the transform list defined above.
dataset = FloorplanDataset(
    f"./datasets/{dataset_name}",
    transforms_=transforms_,
    mode="test",
)

In [None]:
# Number of samples the dataset reports for mode="test".
len(dataset)

In [None]:
def draw(A, B):
    """Show two images side by side in a single 10x10-inch figure.

    Args:
        A: image shown in the left panel under the title 'original'.
        B: image shown in the right panel under the title 'mask'.

    Both arguments are passed unchanged to ``plt.imshow``; the axes are
    hidden and the figure is rendered with ``plt.show()``.
    """
    plt.figure(figsize=(10, 10))

    panels = ((A, 'original'), (B, 'mask'))
    for position, (image, title) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.imshow(image)
        plt.title(title)
        plt.axis('off')

    plt.show()

In [None]:
# Inspect element 12 of the dataset's `files` attribute.
dataset.files[12]

In [None]:
# Print the shapes of the 'A' and 'B' entries of the first sample only.
# The guard keeps the original behaviour of printing nothing when the
# dataset is empty.
if len(dataset) > 0:
    item = dataset[0]
    print(item['A'].shape, item['B'].shape)

In [None]:
# Visualise sample 10. Axis 0 is moved to the end before plotting
# (presumably channel-first -> channel-last for imshow; confirm against
# FloorplanDataset).
item = dataset[10]
image_hwc = item['A'].transpose((1, 2, 0))
mask_hwc = item['B'].transpose((1, 2, 0))
draw(image_hwc, mask_hwc)

### Train

In [None]:
from pytorch_lightning.trainer import Trainer
from models.pix2pix_model import Pix2PixModel
from torchsummary import summary
import torch
from pytorch_lightning.loggers import TensorBoardLogger
from pathlib import Path
from time import time

#### Model Summary

In [None]:
# Model instance used in this section to inspect the training loader and the
# layer summary; a separate instance is created later for the training run.
model = Pix2PixModel(dataset_name="floor/newyork")

In [None]:
# Ask the model for its training data loader.
train_loader = model.train_dataloader()

In [None]:
# Print the shapes of the 'A' and 'B' tensors of the first batch, then stop.
for batch in train_loader:
    for key in ('A', 'B'):
        print(batch[key].shape)
    break

In [None]:
# Use the GPU when CUDA is available, otherwise the CPU
# (torch.device is available from PyTorch 0.4.0 on).
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
model = model.to(device)

# Per-layer summary for an input of shape (3, 256, 256).
summary(model, input_size=(3, 256, 256))

#### Model Training

In [None]:
# TensorBoard logger writing under logs/pix2pix_floorplan.
# The version is left to the logger's default; pass
# version=str(int(time())) to force a timestamp-named run directory instead.
tb_logger = TensorBoardLogger(
    'logs',
    name='pix2pix_floorplan',
)

# Make sure the run directory exists before training starts.
log_dir = Path(tb_logger.log_dir)
log_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Fresh model instance for the training run, with a batch size of 2.
model = Pix2PixModel(batch_size=2, dataset_name="floor/newyork")

In [None]:
# Train for up to 100 epochs, logging through tb_logger and saving weights
# under experiments/. Use GPU 0 when CUDA is available; otherwise fall back
# to the CPU (gpus=None) like the device selection elsewhere in this
# notebook, instead of failing on hosts without a GPU.
trainer = Trainer(
    logger=tb_logger,
    weights_save_path='experiments',
    gpus=[0] if torch.cuda.is_available() else None,
    max_epochs=100,
)
trainer.fit(model)

### Evaluate

In [None]:
# Move the trained model to the GPU when CUDA is available, else to the CPU
# (torch.device is available from PyTorch 0.4.0 on).
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
model = model.to(device)

In [None]:
# Visualise sample 15, which is also used for the evaluation below.
# Axis 0 is moved to the end before plotting (presumably channel-first ->
# channel-last for imshow; confirm against FloorplanDataset).
item = dataset[15]
image_hwc = item['A'].transpose((1, 2, 0))
mask_hwc = item['B'].transpose((1, 2, 0))
draw(image_hwc, mask_hwc)

In [None]:
# Turn the sample's arrays into single-item batches on the target device:
# 'B' becomes the model input and 'A' the ground truth.
# NOTE(review): `input` shadows the Python builtin; the name is kept because
# the following cells reference it.
input = torch.from_numpy(item['B']).unsqueeze(0).to(device)
gt = torch.from_numpy(item['A']).unsqueeze(0).to(device)

In [None]:
# Check the shape after the leading batch axis was added with unsqueeze(0).
input.shape

In [None]:
# Inference only: run the forward pass with autograd disabled so no graph is
# built and the output does not require grad (it is converted to NumPy for
# display later on).
with torch.no_grad():
    generated = model(input)

In [None]:
# Join input, generated output and ground truth along the last axis so they
# appear next to each other, then build an image grid from the result.
img_sample = torch.cat([input, generated, gt], dim=-1)
print(generated.shape, img_sample.shape)
grid = torchvision.utils.make_grid(img_sample)

In [None]:
def display_generated(grid):
    """Render an image tensor with matplotlib in a 20x10-inch figure.

    Args:
        grid: 3-axis torch tensor whose first axis is moved last for
            plotting, e.g. the output of ``torchvision.utils.make_grid``.
            It may live on any device and may still be attached to an
            autograd graph.
    """
    # detach() first: Tensor.numpy() raises on tensors that require grad,
    # which happens when `grid` derives from a model output computed with
    # autograd enabled.
    npgrid = grid.detach().cpu().numpy()
    plt.figure(figsize=(20, 10))

    # Move axis 0 to the end (imshow expects the colour channel last).
    plt.imshow(np.transpose(npgrid, (1, 2, 0)), interpolation='nearest')

In [None]:
# Show input | generated | ground truth side by side.
display_generated(grid)

#### JP Dataset

In [None]:
import glob
import re

In [None]:
# Settings for the JP data: here dataset_name holds the full directory path
# (unlike the earlier cell, where it was relative to ./datasets).
dataset_name = "./datasets/floor/jp"
mode = "train"
img_height = img_width = 256

In [None]:
# Sorted paths of every file in <dataset_name>/<mode> matching "*_multi.*".
pattern = os.path.join(dataset_name, mode) + "/*_multi.*"
files = sorted(glob.glob(pattern))

In [None]:
# Show the first ten matched paths.
files[:10]

In [None]:
# Use the third matched file as the example input.
path_A = files[2]

In [None]:
# Load the file as RGB inside a context manager so the underlying file
# handle is closed; convert() returns an image independent of the file.
with Image.open(path_A) as raw:
    im = raw.convert('RGB')

# Resize to (img_width, img_height), move the colour axis first and scale
# the 8-bit values to [0, 1] as float32.
# NOTE(review): confirm this matches the normalisation FloorplanDataset
# applies to training samples.
im_resized = im.resize((img_width, img_height))
img_A = np.array(im_resized).astype(np.float32).transpose((2, 0, 1))
img_A = img_A / 255.

In [None]:
# Wrap the array as a single-item batch on the target device and show its
# shape. NOTE(review): `input` shadows the Python builtin; the name is kept
# because the following cells reference it.
input = torch.from_numpy(img_A).unsqueeze(0).to(device)
input.shape

In [None]:
# Inference only: run the forward pass with autograd disabled so no graph is
# built and the output does not require grad (it is converted to NumPy for
# display later on).
with torch.no_grad():
    generated = model(input)
generated.shape

In [None]:
# Join input and generated output along the last axis so they appear next
# to each other, then build an image grid from the result.
img_sample = torch.cat([input, generated], dim=-1)
print(img_sample.shape)
grid = torchvision.utils.make_grid(img_sample)

In [None]:
# Show input | generated side by side.
display_generated(grid)