In [143]:
import torch
import torchvision
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils import data
from torchvision import datasets, models, transforms
from torchvision.datasets import ImageFolder
from skimage import color

import os
import cv2
import time
import sys
import argparse
import shutil
import collections
import numpy as np
import scipy.misc as misc

import math
import matplotlib.pyplot as plt
from PIL import Image

In [144]:
# Run configuration (a notebook stand-in for command-line arguments).
# `large` selects the 400x400 variant (Generator_Large / Discriminator_Large,
# CenterCrop(400), *_large checkpoints) instead of the 32x32 one.
args_dict = dict(
    path='',
    dataset='',
    large=False,
    batch_size=32,
    lr=1e-4,
    weight_decay=0,
    num_epoch=15,
    lamb=100,
    test='',
    generator='model/1126.filter1/GAN__100L1_bs32_Adam_lr0.0001/G_epoch5.pth.tar',
    generator_large='model/1127.large/GAN__100L1_bs32_Adam_lr0.0001/G_epoch14.pth.tar',
    discriminator='model/1126.filter1/GAN__100L1_bs32_Adam_lr0.0001/D_epoch5.pth.tar',
    discriminator_large='model/1127.large/GAN__100L1_bs32_Adam_lr0.0001/D_epoch14.pth.tar',
    save=True,
    gpu=0,
)


In [145]:
class Generator(nn.Module):
    """Small encoder-decoder generator with additive skip connections.

    Maps a 1-channel input to a 3-channel output of the same spatial size.
    Three stride-2 convolutions downsample by 8x and three stride-2 transposed
    convolutions upsample back. Encoder activations are added into the decoder
    at matching resolutions.

    Shape comments assume a 1x32x32 input (CenterCrop(32) in this notebook).
    Any H, W divisible by 8 round-trips to the same size.
    """

    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(1, 64, 3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu1 = nn.LeakyReLU(0.1)

        self.conv2 = nn.Conv2d(64, 128, 3, stride=2, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(128)
        self.relu2 = nn.LeakyReLU(0.1)

        self.conv3 = nn.Conv2d(128, 128, 3, stride=2, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(128)
        self.relu3 = nn.LeakyReLU(0.1)

        self.deconv4 = nn.ConvTranspose2d(128, 128, 3, stride=2, padding=1, output_padding=1, bias=False)
        self.bn4 = nn.BatchNorm2d(128)
        self.relu4 = nn.LeakyReLU(0.1)

        self.deconv5 = nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1, bias=False)
        self.bn5 = nn.BatchNorm2d(64)
        self.relu5 = nn.LeakyReLU(0.1)

        self.deconv6 = nn.ConvTranspose2d(64, 3, 3, stride=2, padding=1, output_padding=1, bias=False)
        # bn6/relu6 are never used in forward(): the last layer feeds tanh directly.
        # bn6's parameters and buffers are still part of state_dict(), though,
        # so removing them would break loading checkpoints saved from this class.
        self.bn6 = nn.BatchNorm2d(3)
        self.relu6 = nn.LeakyReLU(0.1)

        self._initialize_weights()

    def forward(self, x):
        """Return a 3-channel output squashed into [-1, 1] by tanh."""
        h = self.relu1(self.bn1(self.conv1(x)))  # 64, 16, 16
        skip1 = h

        h = self.relu2(self.bn2(self.conv2(h)))  # 128, 8, 8
        skip2 = h

        h = self.relu3(self.bn3(self.conv3(h)))  # 128, 4, 4

        # Out-of-place adds: avoid mutating activations autograd may have saved.
        h = self.relu4(self.bn4(self.deconv4(h))) + skip2  # 128, 8, 8
        h = self.relu5(self.bn5(self.deconv5(h))) + skip1  # 64, 16, 16

        return torch.tanh(self.deconv6(h))  # 3, 32, 32

    def _initialize_weights(self):
        """He-style normal init for (transposed) conv weights.

        Uses std = sqrt(2 / (kH * kW * out_channels)).
        """
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))

class Discriminator(nn.Module):
    """Convolutional discriminator for 3x32x32 images.

    Three stride-2 convs reduce 32x32 to 4x4. A 4x4 valid conv collapses that
    to 1x1, and a 1x1 conv plus sigmoid yields one probability per image.
    The output shape is (N, 1, 1, 1) for 32x32 inputs.
    """

    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(3, 64, 3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu1 = nn.LeakyReLU(0.1)

        self.conv2 = nn.Conv2d(64, 128, 3, stride=2, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(128)
        self.relu2 = nn.LeakyReLU(0.1)

        self.conv3 = nn.Conv2d(128, 128, 3, stride=2, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(128)
        self.relu3 = nn.LeakyReLU(0.1)

        self.conv4 = nn.Conv2d(128, 128, 4, stride=1, padding=0, bias=False)
        self.bn4 = nn.BatchNorm2d(128)
        self.relu4 = nn.LeakyReLU(0.1)

        self.conv5 = nn.Conv2d(128, 1, 1, stride=1, padding=0, bias=False)

        self._initialize_weights()

    def forward(self, x):
        """Return sigmoid scores in [0, 1], shape (N, 1, 1, 1) for 32x32 input."""
        h = self.relu1(self.bn1(self.conv1(x)))  # 64, 16, 16
        h = self.relu2(self.bn2(self.conv2(h)))  # 128, 8, 8
        h = self.relu3(self.bn3(self.conv3(h)))  # 128, 4, 4
        h = self.relu4(self.bn4(self.conv4(h)))  # 128, 1, 1
        return torch.sigmoid(self.conv5(h))

    def _initialize_weights(self):
        """He-style normal init for (transposed) conv weights.

        Uses std = sqrt(2 / (kH * kW * out_channels)).
        """
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))



In [146]:
class Discriminator_Large(nn.Module):
    """Convolutional discriminator for 3x400x400 images.

    Four stride-2 convs reduce 400x400 to 25x25. The 26x26 conv5 (stride 2,
    padding 1) then collapses 25x25 to 1x1, and a 1x1 conv plus sigmoid
    yields one probability per image. The output shape is (N, 1, 1, 1) for
    400x400 inputs.
    """

    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(3, 64, 3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu1 = nn.LeakyReLU(0.1)

        self.conv2 = nn.Conv2d(64, 128, 3, stride=2, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(128)
        self.relu2 = nn.LeakyReLU(0.1)

        self.conv3 = nn.Conv2d(128, 256, 3, stride=2, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(256)
        self.relu3 = nn.LeakyReLU(0.1)

        self.conv4 = nn.Conv2d(256, 256, 3, stride=2, padding=1, bias=False)
        self.bn4 = nn.BatchNorm2d(256)
        self.relu4 = nn.LeakyReLU(0.1)

        self.conv5 = nn.Conv2d(256, 256, 26, stride=2, padding=1, bias=False)
        self.bn5 = nn.BatchNorm2d(256)
        self.relu5 = nn.LeakyReLU(0.1)

        self.conv6 = nn.Conv2d(256, 1, 1, stride=1, padding=0, bias=False)

        self._initialize_weights()

    def forward(self, x):
        """Return sigmoid scores in [0, 1], shape (N, 1, 1, 1) for 400x400 input."""
        h = self.relu1(self.bn1(self.conv1(x)))  # 64, 200, 200
        h = self.relu2(self.bn2(self.conv2(h)))  # 128, 100, 100
        h = self.relu3(self.bn3(self.conv3(h)))  # 256, 50, 50
        h = self.relu4(self.bn4(self.conv4(h)))  # 256, 25, 25
        h = self.relu5(self.bn5(self.conv5(h)))  # 256, 1, 1
        return torch.sigmoid(self.conv6(h))

    def _initialize_weights(self):
        """He-style normal init for (transposed) conv weights.

        Uses std = sqrt(2 / (kH * kW * out_channels)).
        """
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))

class Generator_Large(nn.Module):
    """Deeper encoder-decoder generator for large inputs, with additive skips.

    Maps a 1-channel input to a 3-channel output of the same spatial size.
    Four stride-2 convolutions downsample by 16x and four stride-2 transposed
    convolutions upsample back. Encoder activations are added into the decoder
    at matching resolutions.

    Shape comments assume a 1x400x400 input (CenterCrop(400) in this notebook).
    Any H, W divisible by 16 round-trips to the same size.
    """

    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(1, 64, 3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu1 = nn.LeakyReLU(0.1)

        self.conv2 = nn.Conv2d(64, 128, 3, stride=2, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(128)
        self.relu2 = nn.LeakyReLU(0.1)

        self.conv3 = nn.Conv2d(128, 256, 3, stride=2, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(256)
        self.relu3 = nn.LeakyReLU(0.1)

        self.conv4 = nn.Conv2d(256, 256, 3, stride=2, padding=1, bias=False)
        self.bn4 = nn.BatchNorm2d(256)
        self.relu4 = nn.LeakyReLU(0.1)

        self.deconv5 = nn.ConvTranspose2d(256, 256, 3, stride=2, padding=1, output_padding=1, bias=False)
        self.bn5 = nn.BatchNorm2d(256)
        self.relu5 = nn.LeakyReLU(0.1)

        self.deconv6 = nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1, bias=False)
        self.bn6 = nn.BatchNorm2d(128)
        self.relu6 = nn.LeakyReLU(0.1)

        self.deconv7 = nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1, bias=False)
        self.bn7 = nn.BatchNorm2d(64)
        self.relu7 = nn.LeakyReLU(0.1)

        self.deconv8 = nn.ConvTranspose2d(64, 3, 3, stride=2, padding=1, output_padding=1, bias=False)
        # bn8/relu8 are never used in forward(): the last layer feeds tanh directly.
        # bn8's parameters and buffers are still part of state_dict(), though,
        # so removing them would break loading checkpoints saved from this class.
        self.bn8 = nn.BatchNorm2d(3)
        self.relu8 = nn.LeakyReLU(0.1)

        self._initialize_weights()

    def forward(self, x):
        """Return a 3-channel output squashed into [-1, 1] by tanh."""
        h = self.relu1(self.bn1(self.conv1(x)))  # 64, 200, 200
        skip1 = h

        h = self.relu2(self.bn2(self.conv2(h)))  # 128, 100, 100
        skip2 = h

        h = self.relu3(self.bn3(self.conv3(h)))  # 256, 50, 50
        skip3 = h

        h = self.relu4(self.bn4(self.conv4(h)))  # 256, 25, 25

        # Out-of-place adds: avoid mutating activations autograd may have saved.
        h = self.relu5(self.bn5(self.deconv5(h))) + skip3  # 256, 50, 50
        h = self.relu6(self.bn6(self.deconv6(h))) + skip2  # 128, 100, 100
        h = self.relu7(self.bn7(self.deconv7(h))) + skip1  # 64, 200, 200

        return torch.tanh(self.deconv8(h))  # 3, 400, 400

    def _initialize_weights(self):
        """He-style normal init for (transposed) conv weights.

        Uses std = sqrt(2 / (kH * kW * out_channels)).
        """
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))


In [147]:
# Centre-crop to the input size the selected architecture was built for
# (400x400 for the *_Large models, 32x32 otherwise), then convert to a tensor.
crop_size = 400 if args_dict['large'] else 32
transform = transforms.Compose([
    transforms.CenterCrop(crop_size),
    transforms.ToTensor(),
])

In [148]:
from torch.utils.data import Dataset
from PIL import Image
import os

class CustomDataset(Dataset):
    """Unlabelled dataset yielding one greyscale ('L') image per file in a directory.

    Args:
        X_dir: directory holding the images. Only regular, non-hidden files
            directly inside it are used. Sub-directories (e.g.
            `.ipynb_checkpoints`) and dotfiles (e.g. `.DS_Store`) are skipped,
            since PIL cannot open them.
        transform: optional callable applied to the PIL image, e.g. a
            torchvision transform pipeline.
    """

    def __init__(self, X_dir, transform=None):
        self.X_dir = X_dir
        self.transform = transform
        # Sorted so the index -> file mapping is stable across runs and
        # platforms; os.listdir returns entries in arbitrary order.
        self.X_filenames = sorted(
            name for name in os.listdir(X_dir)
            if not name.startswith('.')
            and os.path.isfile(os.path.join(X_dir, name))
        )

    def __len__(self):
        return len(self.X_filenames)

    def __getitem__(self, idx):
        X_path = os.path.join(self.X_dir, self.X_filenames[idx])
        # `with` closes the underlying file handle. convert('L') returns a new,
        # fully loaded image, so it stays usable after the file is closed.
        with Image.open(X_path) as img:
            X = img.convert('L')
        if self.transform:
            X = self.transform(X)
        return X


In [149]:
# Directory of test images; the large model reads from a different data root.
if args_dict['large']:
    X_test = 'data/vis'
else:
    X_test = 'dataset/vis'

In [150]:
dataset_train = CustomDataset(X_test, transform=transform)

# Batch size comes from the config rather than being hard-coded. No shuffling:
# this loader only feeds the visualisation, so batches should follow the
# dataset's index order instead of a random permutation.
train_loader = data.DataLoader(dataset_train, batch_size=args_dict['batch_size'], shuffle=False)

In [151]:
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [152]:
# Instantiate the architecture pair matching the configured input size and move it to `device`.
generator_cls, discriminator_cls = (
    (Generator_Large, Discriminator_Large) if args_dict['large']
    else (Generator, Discriminator)
)
generator = generator_cls().to(device=device)
discriminator = discriminator_cls().to(device=device)

In [153]:
# Pick the checkpoint pair matching the selected architecture.
if args_dict['large']:
    G = args_dict['generator_large']
    D = args_dict['discriminator_large']
else:
    G = args_dict['generator']
    D = args_dict['discriminator']

start_epoch_G = start_epoch_D = 0
# map_location lets checkpoints saved on a GPU be loaded on a CPU-only machine
# (and vice versa). NOTE: torch.load unpickles the file, so only load
# checkpoints from a trusted source.
if G:
    print('Resume model G: %s' % G)
    checkpoint_G = torch.load(G, map_location=device)
    generator.load_state_dict(checkpoint_G['state_dict'])
    start_epoch_G = checkpoint_G['epoch']
if D:
    print('Resume model D: %s' % D)
    checkpoint_D = torch.load(D, map_location=device)
    discriminator.load_state_dict(checkpoint_D['state_dict'])
    start_epoch_D = checkpoint_D['epoch']
if not G and not D:
    print('No Resume')
if start_epoch_G != start_epoch_D:
    # G and D are expected to come from the same training epoch; warn rather
    # than fail so a mismatched pair can still be inspected.
    print('Warning: generator epoch %s != discriminator epoch %s' % (start_epoch_G, start_epoch_D))
# Previously only defined when nothing was resumed; now always defined.
start_epoch = start_epoch_G

Resume model G: model/1126.filter1/GAN__100L1_bs32_Adam_lr0.0001/G_epoch5.pth.tar
Resume model D: model/1126.filter1/GAN__100L1_bs32_Adam_lr0.0001/D_epoch5.pth.tar


In [154]:
# Hyper-parameters come from args_dict so there is a single source of truth.
# Previously lr/num_epochs were hard-coded here, and num_epochs (50)
# disagreed with args_dict['num_epoch'].
lr = args_dict['lr']
num_epochs = args_dict['num_epoch']
criterion = nn.BCELoss()
L1 = nn.L1Loss()

# Both networks share the same Adam settings; beta1 = 0.5 as set originally.
adam_kwargs = dict(lr=lr, betas=(0.5, 0.999), eps=1e-8,
                   weight_decay=args_dict['weight_decay'])
optimizer_discriminator = optim.Adam(discriminator.parameters(), **adam_kwargs)
optimizer_generator = optim.Adam(generator.parameters(), **adam_kwargs)

In [155]:
date = 'Test'
size = ''
# Run name shared by the image and model output directories; it follows the
# same pattern as the checkpoint paths, e.g. "GAN__100L1_bs32_Adam_lr0.0001".
run_name = 'GAN_%s%s_%dL1_bs%d_%s_lr%s' % (
    args_dict['dataset'], size, args_dict['lamb'],
    args_dict['batch_size'], 'Adam', str(args_dict['lr']))
img_path = 'img/%s/%s/' % (date, run_name)
model_path = 'model/%s/%s/' % (date, run_name)
# exist_ok avoids the check-then-create race and makes the cell idempotent.
os.makedirs(img_path, exist_ok=True)
os.makedirs(model_path, exist_ok=True)

In [156]:
# Run the generator over every test image in inference mode.
# - eval(): BatchNorm uses the checkpoint's running statistics instead of the
#   current batch's, and this pass no longer overwrites those statistics.
# - no_grad(): no autograd graph is built.
# - The loop variable is `batch`, not `data`, so it does not shadow the
#   `torch.utils.data` module imported as `data`.
# Previously the generator only ever saw the last batch from the loader.
generator.eval()
input_batches, output_batches = [], []
with torch.no_grad():
    for batch in train_loader:
        batch = batch.to(device=device)
        input_batches.append(batch)
        output_batches.append(generator(batch))
if not input_batches:
    raise RuntimeError('train_loader yielded no images')
latent_space_samples = torch.cat(input_batches)
generated_samples = torch.cat(output_batches)

In [157]:
# Save one side-by-side figure (greyscale input | generated image) per sample.
generated_samples = generated_samples.cpu().detach()
grey_inputs = latent_space_samples.detach().cpu()

for i in range(len(generated_samples)):
    # 1-channel (1, H, W) input -> (H, W, 3) so it can sit next to the 3-channel output.
    grey = np.transpose(grey_inputs[i].numpy(), (1, 2, 0))
    grey = np.repeat(grey, 3, axis=2).astype(np.float64)
    # Generator output is tanh-bounded in [-1, 1]; map it to [0, 1] for imshow.
    pred_rgb = (np.transpose(generated_samples[i].numpy(), (1, 2, 0)).astype(np.float64) + 1) / 2.
    img_list = np.concatenate((grey, pred_rgb), 1)

    fig, ax = plt.subplots(figsize=(36, 27))
    ax.imshow(img_list)
    ax.axis('off')
    fig.tight_layout()
    fig.savefig(img_path + f'test{i}.png')
    # Close (not just clear) the figure so figures do not pile up in memory
    # or get echoed as empty "<Figure ...>" outputs.
    plt.close(fig)

<Figure size 3600x2700 with 0 Axes>

<Figure size 3600x2700 with 0 Axes>

<Figure size 3600x2700 with 0 Axes>