In [1]:
import torch
import torchvision
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils import data
from torchvision import datasets, models, transforms
from torchvision.datasets import ImageFolder
from skimage import color

import os
import cv2
import time
import sys
import argparse
import shutil
import collections
import numpy as np
import scipy.misc as misc

import math
import matplotlib.pyplot as plt
from PIL import Image

In [2]:
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.conv1 = nn.Conv2d(3, 64, 3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu1 = nn.LeakyReLU(0.1)
        
        self.conv2 = nn.Conv2d(64, 128, 3, stride=2, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(128)
        self.relu2 = nn.LeakyReLU(0.1)
        
        self.conv3 = nn.Conv2d(128, 256, 3, stride=2, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(256)
        self.relu3 = nn.LeakyReLU(0.1)
        
        self.conv4 = nn.Conv2d(256, 256, 3, stride=2, padding=1, bias=False)
        self.bn4 = nn.BatchNorm2d(256)
        self.relu4 = nn.LeakyReLU(0.1)

        self.conv5 = nn.Conv2d(256, 256, 26, stride=2, padding=1, bias=False)
        self.bn5 = nn.BatchNorm2d(256)
        self.relu5 = nn.LeakyReLU(0.1)

        self.conv6 = nn.Conv2d(256, 1, 1, stride=1, padding=0, bias=False)
        
        self._initialize_weights()
            
    def forward(self, x):
        h = x
        h = self.conv1(h)
        h = self.bn1(h)
        h = self.relu1(h) # 64, 200, 200
        
        h = self.conv2(h)
        h = self.bn2(h)
        h = self.relu2(h) # 128, 100, 100
        
        h = self.conv3(h)
        h = self.bn3(h)
        h = self.relu3(h) # 256, 50, 50
        
        h = self.conv4(h)
        h = self.bn4(h)
        h = self.relu4(h) # 256, 25, 25

        h = self.conv5(h)
        h = self.bn5(h)
        h = self.relu5(h) # 256, 1, 1
        
        h = self.conv6(h)
        h = F.sigmoid(h)

        return h
    
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            if isinstance(m, nn.ConvTranspose2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))



In [3]:
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.conv1 = nn.Conv2d(1, 64, 3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu1 = nn.LeakyReLU(0.1)
        
        self.conv2 = nn.Conv2d(64, 128, 3, stride=2, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(128)
        self.relu2 = nn.LeakyReLU(0.1)

        self.conv3 = nn.Conv2d(128, 256, 3, stride=2, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(256)
        self.relu3 = nn.LeakyReLU(0.1)

        self.conv4 = nn.Conv2d(256, 256, 3, stride=2, padding=1, bias=False)
        self.bn4 = nn.BatchNorm2d(256)
        self.relu4 = nn.LeakyReLU(0.1)
        
        self.deconv5 = nn.ConvTranspose2d(256, 256, 3, stride=2, padding=1, output_padding=1, bias=False)
        self.bn5 = nn.BatchNorm2d(256)
        self.relu5 = nn.LeakyReLU(0.1)
        
        self.deconv6 = nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1, bias=False)
        self.bn6 = nn.BatchNorm2d(128)
        self.relu6 = nn.LeakyReLU(0.1)
        
        self.deconv7 = nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1, bias=False)
        self.bn7 = nn.BatchNorm2d(64)
        self.relu7 = nn.LeakyReLU(0.1)

        self.deconv8 = nn.ConvTranspose2d(64, 3, 3, stride=2, padding=1, output_padding=1, bias=False)
        self.bn8 = nn.BatchNorm2d(3)
        self.relu8 = nn.LeakyReLU(0.1)
        
        self._initialize_weights()

    def forward(self, x):
        h = x
        h = self.conv1(h)
        h = self.bn1(h)
        h = self.relu1(h) # 64, 200, 200
        pool1 = h
        
        h = self.conv2(h)
        h = self.bn2(h)
        h = self.relu2(h) # 128, 100, 100
        pool2 = h
        
        h = self.conv3(h)
        h = self.bn3(h)
        h = self.relu3(h) # 256, 50, 50
        pool3 = h
        
        h = self.conv4(h)
        h = self.bn4(h)
        h = self.relu4(h) # 256, 25, 25

        h = self.deconv5(h)
        h = self.bn5(h)
        h = self.relu5(h) # 256, 50, 50
        h += pool3

        h = self.deconv6(h)
        h = self.bn6(h)
        h = self.relu6(h) # 128, 100, 100
        h += pool2

        h = self.deconv7(h)
        h = self.bn7(h)
        h = self.relu7(h) # 64, 200, 200
        h += pool1
        
        h = self.deconv8(h)
        h = F.tanh(h) # 3, 400, 400

        return h
    
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            if isinstance(m, nn.ConvTranspose2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))


In [4]:
torch.manual_seed(111)

<torch._C.Generator at 0x233e3544890>

In [5]:
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [6]:
args_dict = {
   'path': '',
   'dataset': '',
   'large': False,
   'batch_size': 32,
   'lr': 1e-4,
   'weight_decay': 0,
   'num_epoch': 15,
   'lamb': 100,
   'test': '',
   'generator': 'model/1127.large/GAN__100L1_bs32_Adam_lr0.0001/G_epoch4.pth.tar',
   'discriminator': 'model/1127.large/GAN__100L1_bs32_Adam_lr0.0001/D_epoch4.pth.tar',
   'save': True,
   'gpu': 0
}


In [7]:
# from torch.utils.data import Dataset

# class CustomDataset(Dataset):
#     def __init__(self, X_train, y_train):
#         self.X_train = X_train
#         self.y_train = y_train

#     def __len__(self):
#         return len(self.X_train)

#     def __getitem__(self, idx):
#         img = self.X_train[idx]
#         img = np.array(img)

#         img_lab = self.y_train[idx]
#         img_lab = np.array(img_lab)

#         img = torch.FloatTensor(np.transpose(img, (2,0,1)))
#         img_lab = torch.FloatTensor(np.transpose(img_lab, (2,0,1)))

#         img = np.reshape(img, img.shape+(1,))
#         img_lab = np.reshape(img_lab, img_lab.shape+(1,))


#         return img, img_lab
# #

In [8]:
root_dataset_train = 'data/train_black'
root_dataset_test = 'data/test'
X_train_images_dir = f'data/train_black'
X_test_dir = f'data/test_black'
y_train_images_dir = f'data/train_color'
y_test_dir = f'data/test_color'
X_train, X_test = [], []
y_train, y_test = [], []

In [9]:
transform = transforms.Compose(
    [transforms.CenterCrop(400), transforms.ToTensor()]
)

In [10]:
from torch.utils.data import Dataset
from PIL import Image
import os

class CustomDataset(Dataset):
   def __init__(self, X_dir, y_dir, transform=None):
       self.X_dir = X_dir
       self.y_dir = y_dir
       self.transform = transform
       self.X_filenames = os.listdir(X_dir)
       self.y_filenames = os.listdir(y_dir)

   def __len__(self):
       return len(self.X_filenames)

   def __getitem__(self, idx):
       X_path = os.path.join(self.X_dir, self.X_filenames[idx])
       y_path = os.path.join(self.y_dir, self.y_filenames[idx])
       X = Image.open(X_path).convert('L')
       y = Image.open(y_path)
       if self.transform:
           X = self.transform(X)
           y = self.transform(y)
       return X, y


In [11]:
dataset_train = CustomDataset(X_train_images_dir, y_train_images_dir, transform=transform)

train_loader = data.DataLoader(dataset_train, batch_size=32, shuffle=True)

In [12]:
dataset_test = CustomDataset(X_test_dir, y_test_dir, transform=transform)

val_loader = data.DataLoader(dataset_test, batch_size=32, shuffle=True)

In [13]:
# for filename in os.listdir(X_train_images_dir):
#     if filename.endswith('.jpg'):
#         img = cv2.imread(os.path.join(X_train_images_dir, filename), cv2.IMREAD_GRAYSCALE)
#         X_train.append(img)
# for filename in os.listdir(X_test_dir):
#     if filename.endswith('.jpg'):
#         img = cv2.imread(os.path.join(X_test_dir, filename), cv2.IMREAD_GRAYSCALE)
#         X_test.append(img)
# for filename in os.listdir(y_train_images_dir):
#     if filename.endswith('.jpg'):
#         img = cv2.imread(os.path.join(y_train_images_dir, filename))
#         y_train.append(img)
# for filename in os.listdir(y_test_dir):
#     if filename.endswith('.jpg'):
#         img = cv2.imread(os.path.join(y_test_dir, filename))
#         y_test.append(img)

In [14]:
# #X_train_tensor = torch.from_numpy(np.array(X_train))
# #y_train_tensor = torch.from_numpy(np.array(y_train))

# X_train = [x.astype('float32') for x in X_train]
# y_train = [y.astype('float32') for y in y_train]


# dataset = CustomDataset(X_train, y_train)
# dataset_val = CustomDataset(X_test, y_test)
# #train_dataset = torch.utils.data.TensorDataset(X_train_tensor, y_train_tensor)

# batch_size = 32
# train_loader = data.DataLoader(
#     dataset, batch_size=args_dict['batch_size'], shuffle=False
# )

# batch_size = 32
# val_loader = data.DataLoader(
#     dataset_val, batch_size=args_dict['batch_size'], shuffle=False
# )

In [15]:
# print(np.array(X_train).shape)
# print(np.array(y_train).shape)

In [16]:
generator = Generator().to(device=device)
discriminator = Discriminator().to(device=device)

In [17]:
start_epoch_G = start_epoch_D = 0
if args_dict['generator']:
    print('Resume model G: %s' % args_dict['generator'])
    checkpoint_G = torch.load(args_dict['generator'])
    generator.load_state_dict(checkpoint_G['state_dict'])
    start_epoch_G = checkpoint_G['epoch']
if args_dict['discriminator']:
    print('Resume model D: %s' % args_dict['discriminator'])
    checkpoint_D = torch.load(args_dict['discriminator'])
    discriminator.load_state_dict(checkpoint_D['state_dict'])
    start_epoch_D = checkpoint_D['epoch']
assert start_epoch_G == start_epoch_D
if args_dict['generator'] == '' and args_dict['discriminator'] == '':
    print('No Resume')
    start_epoch = 0
print(start_epoch_D)

Resume model G: model/1127.large/GAN__100L1_bs32_Adam_lr0.0001/G_epoch4.pth.tar
Resume model D: model/1127.large/GAN__100L1_bs32_Adam_lr0.0001/D_epoch4.pth.tar
5


In [18]:
lr = 0.0001
num_epochs = 50
criterion = nn.BCELoss()
L1 = nn.L1Loss()

optimizer_discriminator = optim.Adam(discriminator.parameters(), 
            lr=args_dict['lr'], betas=(0.5, 0.999), 
            eps=1e-8, weight_decay=args_dict['weight_decay'])
optimizer_generator = optim.Adam(generator.parameters(), 
            lr=args_dict['lr'], betas=(0.5, 0.999),
            eps=1e-8, weight_decay=args_dict['weight_decay'])

In [19]:
if args_dict['generator']:
    optimizer_generator.load_state_dict(checkpoint_G['optimizer'])
if args_dict['discriminator']:
    optimizer_discriminator.load_state_dict(checkpoint_D['optimizer'])

In [20]:
class Scale(object):
    def __init__(self, size, interpolation=Image.BILINEAR):
        assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        print("scale")
        if isinstance(self.size, int):
            w, h = img.size
            if (w <= h and w == self.size) or (h <= w and h == self.size):
                return img
            if w < h:
                ow = self.size
                oh = int(self.size * h / w)
                return img.resize((ow, oh), self.interpolation)
            else:
                oh = self.size
                ow = int(self.size * w / h)
                return img.resize((ow, oh), self.interpolation)
        else:
            # import ipdb; ipdb.set_trace()
            return img.resize(self.size, self.interpolation)
        
        
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
        self.history = []
        self.dict = {} # save all data values here
        self.save_dict = {} # save mean and std here, for summary table

    def update(self, val, n=1, history=0):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
        if history:
            self.history.append(val)

    def __len__(self):
        return self.count

class Plotter_GAN(object):
    """plot loss for G and D"""
    def __init__(self):
        self.g_loss = []
        self.d_loss = []

    def g_update(self, loss):
        if type(loss) != float:
            loss = float(loss)
        self.g_loss.append(loss)

    def d_update(self, loss):
        if type(loss) != float:
            loss = float(loss)
        self.d_loss.append(loss)

    def draw(self, filename):
        name = 'loss'
        if len(self.g_loss) == len(self.d_loss):
            plt.figure()
            plt.plot(self.g_loss, label='G')
            plt.plot(self.d_loss, label='D')
            plt.legend(loc='upper left')
            plt.xlabel('epoch')
            plt.title(name)

            plt.tight_layout()
            plt.savefig(filename)
            plt.clf()
            plt.close()

class Plotter_GAN_TV(object):
    """plot loss for G and D with Training and Validation"""
    def __init__(self):
        self.g_loss_t = []
        self.d_loss_t = []
        self.g_loss_v = []
        self.d_loss_v = []

    def train_update(self, g_loss, d_loss):
        if type(g_loss) != float:
            g_loss = float(g_loss)
        if type(d_loss) != float:
            d_loss = float(d_loss)
        self.g_loss_t.append(g_loss)
        self.d_loss_t.append(d_loss)

    def val_update(self, g_loss, d_loss):
        if type(g_loss) != float:
            g_loss = float(g_loss)
        if type(d_loss) != float:
            d_loss = float(d_loss)
        self.g_loss_v.append(g_loss)
        self.d_loss_v.append(d_loss)

    def draw(self, filename):
        name = 'loss'
        if len(self.g_loss_t) == len(self.d_loss_t) and\
           len(self.g_loss_v) == len(self.d_loss_v) and\
           len(self.g_loss_t) == len(self.g_loss_v):
            plt.figure()
            plt.plot(self.g_loss_t, label='G_train')
            plt.plot(self.d_loss_t, label='D_train')
            plt.plot(self.g_loss_v, label='G_val')
            plt.plot(self.d_loss_v, label='D_val')
            plt.legend(loc='upper left')
            plt.xlabel('epoch')
            plt.title(name)

            plt.tight_layout()
            plt.savefig(filename)
            plt.clf()
            plt.close()



In [21]:
iteration = 0
print_interval = 500
plotter = Plotter_GAN_TV()
plotter_basic = Plotter_GAN()
date = f'1127.large'

In [22]:
size = ''
img_path = 'img/%s/GAN_%s%s_%dL1_bs%d_%s_lr%s/' \
           % (date, args_dict['dataset'], size, args_dict['lamb'], args_dict['batch_size'], 'Adam', str(args_dict['lr']))
model_path = 'model/%s/GAN_%s%s_%dL1_bs%d_%s_lr%s/' \
           % (date, args_dict['dataset'], size, args_dict['lamb'], args_dict['batch_size'], 'Adam', str(args_dict['lr']))
if not os.path.exists(img_path):
    os.makedirs(img_path)
if not os.path.exists(model_path):
    os.makedirs(model_path)

In [23]:
def save_checkpoint(state, is_best=0, filename='models/checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'models/model_best.pth.tar')

In [24]:
def train(train_loader, model_G, model_D, optimizer_G, optimizer_D, epoch, iteration):
    errorG = AverageMeter() # will be reset after each epoch
    errorD = AverageMeter() # will be reset after each epoch
    errorG_basic = AverageMeter() # basic will be reset after each print
    errorD_basic = AverageMeter() # basic will be reset after each print
    errorD_real = AverageMeter()
    errorD_fake = AverageMeter()
    errorG_GAN = AverageMeter()
    errorG_R = AverageMeter()

    model_G.train()
    model_D.train()

    real_label = 1
    fake_label = 0

    for i, (data, target) in enumerate(train_loader):
        data, target = Variable(data.to(device=device)), Variable(target.to(device=device))

        ########################
        # update D network
        ########################
        # train with real
        model_D.zero_grad()
        output = model_D(target)
        label = torch.FloatTensor(target.size(0)).fill_(real_label).to(device=device)
        labelv = Variable(label)
        errD_real = criterion(torch.squeeze(output), labelv)
        errD_real.backward()
        D_x = output.data.mean()

        # train with fake
        fake =  model_G(data)
        labelv = Variable(label.fill_(fake_label))
        output = model_D(fake.detach())
        errD_fake = criterion(torch.squeeze(output), labelv)
        errD_fake.backward()
        D_G_x1 = output.data.mean()

        errD = errD_real + errD_fake
        optimizer_D.step()

        ########################
        # update G network
        ########################
        model_G.zero_grad()
        labelv = Variable(label.fill_(real_label))
        output = model_D(fake)
        errG_GAN = criterion(torch.squeeze(output), labelv)
        errG_L1 = L1(fake.view(fake.size(0),-1), target.view(target.size(0),-1))

        errG = errG_GAN + args_dict['lamb'] * errG_L1
        errG.backward()
        D_G_x2 = output.data.mean()
        optimizer_G.step()

        # store error values
        errorG.update(errG.item(), target.size(0), history=1)
        errorD.update(errD.item(), target.size(0), history=1)
        errorG_basic.update(errG.item(), target.size(0), history=1)
        errorD_basic.update(errD.item(), target.size(0), history=1)
        errorD_real.update(errD_real.item(), target.size(0), history=1)
        errorD_fake.update(errD_fake.item(), target.size(0), history=1)

        errorD_real.update(errD_real.item(), target.size(0), history=1)
        errorD_fake.update(errD_fake.item(), target.size(0), history=1)
        errorG_GAN.update(errG_GAN.item(), target.size(0), history=1)
        errorG_R.update(errG_L1.item(), target.size(0), history=1)

        if iteration % print_interval == 0:
            print('Epoch%d[%d/%d]: Loss_D: %.4f (R %0.4f + F %0.4f) Loss_G: %0.4f (GAN %.4f + R %0.4f) D(x): %.4f D(G(z)): %.4f / %.4f' \
                % (epoch, i, len(train_loader),
                errorD_basic.avg, errorD_real.avg, errorD_fake.avg,
                errorG_basic.avg, errorG_GAN.avg, errorG_R.avg,
                D_x, D_G_x1, D_G_x2
                ))
            # plot image
            plotter_basic.g_update(errorG_basic.avg)
            plotter_basic.d_update(errorD_basic.avg)
            plotter_basic.draw(img_path + 'train_basic.png')
            # reset AverageMeter
            errorG_basic.reset()
            errorD_basic.reset()
            errorD_real.reset()
            errorD_fake.reset()
            errorG_GAN.reset()
            errorG_R.reset()

        iteration += 1

    return errorG.avg, errorD.avg


def validate(val_loader, model_G, model_D, optimizer_G, optimizer_D, epoch):
    errorG = AverageMeter()
    errorD = AverageMeter()

    model_G.eval()
    model_D.eval()

    real_label = 1
    fake_label = 0

    i = 0
    for i, (data, target) in enumerate(val_loader):
        data, target = Variable(data.to(device=device)), Variable(target.to(device=device))
        ########################
        # D network
        ########################
        # validate with real
        output = model_D(target)
        label = torch.FloatTensor(target.size(0)).fill_(real_label).to(device=device)
        labelv = Variable(label)
        errD_real = criterion(torch.squeeze(output), labelv)

        # validate with fake
        fake =  model_G(data)
        labelv = Variable(label.fill_(fake_label))
        output = model_D(fake.detach())
        errD_fake = criterion(torch.squeeze(output), labelv)

        errD = errD_real + errD_fake

        ########################
        # G network
        ########################
        labelv = Variable(label.fill_(real_label))
        output = model_D(fake)
        errG_GAN = criterion(torch.squeeze(output), labelv)
        errG_L1 = L1(fake.view(fake.size(0),-1), target.view(target.size(0),-1))

        errG = errG_GAN + args_dict['lamb'] * errG_L1

        errorG.update(errG.item(), target.size(0), history=1)
        errorD.update(errD.item(), target.size(0), history=1)

        if i == 0:
            vis_result(data.data, target.data, fake.data, epoch)

        if i % 50 == 0:
            print('Validating Epoch %d: [%d/%d]' \
                % (epoch, i, len(val_loader)))
        

    print('Validation: Loss_D: %.4f Loss_G: %.4f '\
        % (errorD.avg, errorG.avg))

    return errorG.avg, errorD.avg

def vis_result(data, target, output, epoch):
    '''visualize images for GAN'''
    img_list = []
    for i in range(32):
        l = torch.unsqueeze(torch.squeeze(data[i]), 0).cpu().numpy()
        raw = target[i].cpu().numpy()
        pred = output[i].cpu().numpy()

        raw_rgb = (np.transpose(raw, (1,2,0)).astype(np.float64) + 1) / 2.
        pred_rgb = (np.transpose(pred, (1,2,0)).astype(np.float64) + 1) / 2.

        grey = np.transpose(l, (1,2,0))
        grey = np.repeat(grey, 3, axis=2).astype(np.float64)
        img_list.append(np.concatenate((grey, raw_rgb, pred_rgb), 1))

    img_list = [np.concatenate(img_list[4*i:4*(i+1)], axis=1) for i in range(len(img_list) // 4)]
    img_list = np.concatenate(img_list, axis=0)

    plt.figure(figsize=(36,27))
    plt.imshow(img_list)
    plt.axis('off')
    plt.tight_layout()
    plt.savefig(img_path + 'epoch%d_val.png' % epoch)
    plt.clf()


In [27]:
for epoch in range(start_epoch_D, args_dict['num_epoch']):
    print('Epoch {}/{}'.format(epoch, args_dict['num_epoch'] - 1))
    print('-' * 20)
    if epoch == 0:
        val_lerrG, val_errD = validate(val_loader, generator, discriminator, optimizer_generator, optimizer_discriminator, epoch=-1)
    # train
    train_errG, train_errD = train(train_loader, generator, discriminator, optimizer_generator, optimizer_discriminator, epoch, iteration)
    # validate
    val_lerrG, val_errD = validate(val_loader, generator, discriminator, optimizer_generator, optimizer_discriminator, epoch)
    
    plotter.train_update(train_errG, train_errD)
    plotter.val_update(val_lerrG, val_errD)
    plotter.draw(img_path + 'train_val.png')
    
    if args_dict['save']:
        print('Saving check point')
        save_checkpoint({'epoch': epoch + 1,
                         'state_dict': generator.state_dict(),
                         'optimizer': optimizer_generator.state_dict(),
                         },
                         filename=model_path+'G_epoch%d.pth.tar' \
                         % epoch)
        save_checkpoint({'epoch': epoch + 1,
                         'state_dict': discriminator.state_dict(),
                         'optimizer': optimizer_discriminator.state_dict(),
                         },
                         filename=model_path+'D_epoch%d.pth.tar' \
                         % epoch)
    
    

Epoch 5/14
--------------------
Epoch5[0/157]: Loss_D: 6.5982 (R 0.1077 + F 6.4904) Loss_G: 27.8142 (GAN 10.0197 + R 0.1779) D(x): 0.9337 D(G(z)): 0.1206 / 0.0871
Validating Epoch 5: [0/24]
Validation: Loss_D: 12.3291 Loss_G: 20.3519 
Saving check point
Epoch 6/14
--------------------
Epoch6[0/157]: Loss_D: 12.8558 (R 0.0389 + F 12.8170) Loss_G: 28.2906 (GAN 12.4967 + R 0.1579) D(x): 0.9672 D(G(z)): 0.1813 / 0.1498
Validating Epoch 6: [0/24]
Validation: Loss_D: 15.4161 Loss_G: 31.5615 
Saving check point
Epoch 7/14
--------------------
Epoch7[0/157]: Loss_D: 28.9765 (R 2.8960 + F 26.0806) Loss_G: 26.4398 (GAN 13.8558 + R 0.1258) D(x): 0.5900 D(G(z)): 0.5266 / 0.4395
Validating Epoch 7: [0/24]
Validation: Loss_D: 6.6508 Loss_G: 15.4653 
Saving check point
Epoch 8/14
--------------------
Epoch8[0/157]: Loss_D: 13.0437 (R 0.3161 + F 12.7277) Loss_G: 17.9587 (GAN 6.3399 + R 0.1162) D(x): 0.8993 D(G(z)): 0.2775 / 0.1946
Validating Epoch 8: [0/24]
Validation: Loss_D: 14.4958 Loss_G: 14.1302 

<Figure size 3600x2700 with 0 Axes>

<Figure size 3600x2700 with 0 Axes>

<Figure size 3600x2700 with 0 Axes>

<Figure size 3600x2700 with 0 Axes>

<Figure size 3600x2700 with 0 Axes>

<Figure size 3600x2700 with 0 Axes>

<Figure size 3600x2700 with 0 Axes>

<Figure size 3600x2700 with 0 Axes>

<Figure size 3600x2700 with 0 Axes>

<Figure size 3600x2700 with 0 Axes>

In [None]:
for i, (data, target) in enumerate(train_loader):
    latent_space_samples = Variable(data.to(device=device))
    target_v = Variable(target.to(device=device))
    if i>32:
        break

generated_samples = generator(latent_space_samples)

In [None]:
generated_samples = generated_samples.cpu().detach()

img_list = []
for i in range(32):
    l = torch.unsqueeze(torch.squeeze(latent_space_samples[i]), 0).cpu().numpy()
    raw = target_v[i].cpu().numpy()
    pred = generated_samples[i].cpu().numpy()

    raw_rgb = (np.transpose(raw, (1,2,0)).astype(np.float64) + 1) / 2.
    pred_rgb = (np.transpose(pred, (1,2,0)).astype(np.float64) + 1) / 2.

    grey = np.transpose(l, (1,2,0))
    grey = np.repeat(grey, 3, axis=2).astype(np.float64)
    img_list.append(np.concatenate((grey, raw_rgb, pred_rgb), 1))

img_list = [np.concatenate(img_list[4*i:4*(i+1)], axis=1) for i in range(len(img_list) // 4)]
img_list = np.concatenate(img_list, axis=0)

plt.figure(figsize=(36,27))
plt.imshow(img_list)
plt.axis('off')
plt.tight_layout()
plt.savefig(img_path + 'epoch14_val.png')
plt.clf()