In [None]:
import torch
from torch import nn, optim
from torch.autograd.variable import Variable
from torchvision import transforms, datasets
import numpy as np
import matplotlib.pyplot as plt
import math
import os
% matplotlib inline
import skimage
from skimage import img_as_float
from skimage import io
import torch.utils.data as Utils # for Dataset module
from skimage.measure import block_reduce
import torch.nn as nn
import torch.nn.parallel


import IPython.display
from IPython.display import Image
# Directory holding the couch-gag frames used as training data.
couchpath = 'couch-gag'

# CUDA reads CUDA_VISIBLE_DEVICES only once, when the driver is first
# initialised, so it must be set before the first torch.cuda call (the
# is_available() check below) on a fresh kernel; setting it afterwards has no
# effect. setdefault keeps any value already supplied by the launcher or job
# scheduler instead of overriding it.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0,1,2,3,4,5,6,7')

if torch.cuda.is_available():
    print("We are using CUDA")
print("Let's use", torch.cuda.device_count(), "GPUs! :D")


In [None]:
from utilsGAN import Logger


In [None]:
# Load every frame in `couchpath` into memory as a float32 (C, H, W) array.
#
# Each frame is block-averaged DOWNSAMPLE x DOWNSAMPLE pixels at a time
# (skimage.measure.block_reduce), moved from skimage's channels-last (H, W, C)
# layout to PyTorch's channels-first (C, H, W) layout, and rescaled to [-1, 1],
# matching torchvision's ToTensor() followed by
# Normalize((.5, .5, .5), (.5, .5, .5)).
# NOTE(review): assumes 8-bit RGB(A) frames of 360 x 640 pixels, which reduce
# to the 3 x 72 x 128 shape used by the rest of the notebook -- confirm
# against the dataset.
DOWNSAMPLE = 5
FRAME_SHAPE = (3, 72, 128)

data = []
for filename in sorted(os.listdir(couchpath)):
    if filename.startswith('.'):
        # Skip hidden files (e.g. .DS_Store) that imread cannot decode.
        continue
    frame = io.imread(os.path.join(couchpath, filename))
    # Keep only the RGB channels; drops an alpha channel if present.
    frame = frame[..., :3]
    frame = block_reduce(frame, block_size=(DOWNSAMPLE, DOWNSAMPLE, 1), func=np.mean)
    # HWC -> CHW. A plain reshape(3, 72, 128) would reinterpret the HWC buffer
    # and scramble pixels across channels instead of moving the channel axis.
    # block_reduce(func=np.mean) returns float64; the models use float32.
    frame = np.ascontiguousarray(frame.transpose(2, 0, 1), dtype=np.float32)
    assert frame.shape == FRAME_SHAPE, (filename, frame.shape)
    data.append(frame / 127.5 - 1.0)






In [None]:

# Wrap the in-memory list of frames in a DataLoader that yields shuffled
# batches containing a single frame each.
data_loader = torch.utils.data.DataLoader(
    data,
    batch_size=1,
    shuffle=True,
)
# With batch_size=1 (and drop_last left at False) there is one batch per frame.
num_batches = len(data_loader)
print(num_batches)

In [None]:
class DiscriminatorNet(torch.nn.Module):
    """
    DCGAN-style convolutional discriminator for 3 x 72 x 128 frames.

    forward() accepts either flattened vectors of shape (N, 27648), as produced
    by images_to_vectors, or images of shape (N, 3, 72, 128). It returns an
    (N, 1) tensor of probabilities in [0, 1] that each sample is real, which
    lines up with the (N, 1) targets passed to nn.BCELoss.
    """
    def __init__(self, ndf=64):
        """
        Args:
            ndf: number of feature maps in the first conv layer; each
                following strided layer doubles it.
        """
        super(DiscriminatorNet, self).__init__()
        self.main = nn.Sequential(
            # (N, 3, 72, 128) -> (N, ndf, 36, 64)
            nn.Conv2d(3, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (N, ndf*2, 18, 32)
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (N, ndf*4, 9, 16)
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (N, ndf*8, 4, 8)
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # A kernel spanning the whole 4 x 8 map collapses it to a single
            # score per sample: -> (N, 1, 1, 1)
            nn.Conv2d(ndf * 8, 1, (4, 8), 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, input):
        """
        Args:
            input: (N, 27648) vectors or (N, 3, 72, 128) images.

        Returns:
            (N, 1) tensor of probabilities that each sample is real.
        """
        # reshape (not view) so non-contiguous inputs are accepted too.
        images = input.reshape(input.size(0), 3, 72, 128)
        return self.main(images).view(-1, 1)
# Build the discriminator, move its weights to the GPU, then wrap it so forward
# passes are split across the visible CUDA devices.
discriminator = torch.nn.DataParallel(DiscriminatorNet().cuda())

In [None]:
def images_to_vectors(images):
    """
    Flatten a batch of images to one row per sample.

    Args:
        images: tensor whose first dimension is the batch size N,
            e.g. (N, 3, 72, 128).

    Returns:
        (N, C*H*W) tensor -- (N, 27648) for 3 x 72 x 128 frames.
    """
    # flatten handles non-contiguous inputs (where .view raises) and still
    # works for an empty batch (where reshape(0, -1) would be ambiguous).
    return images.flatten(start_dim=1)

def vectors_to_images(vectors, shape=(3, 72, 128)):
    """
    Inverse of images_to_vectors: unflatten rows back into image tensors.

    Args:
        vectors: tensor whose first dimension is the batch size N and whose
            remaining elements total prod(shape) per sample.
        shape: per-sample (C, H, W) image shape; defaults to 3 x 72 x 128.

    Returns:
        (N, *shape) tensor.
    """
    # reshape (unlike view) also accepts non-contiguous inputs.
    return vectors.reshape(vectors.size(0), *shape)

In [None]:
class GeneratorNet(torch.nn.Module):
    """
    Convolutional generator that turns a noise tensor into a 3 x 72 x 128 frame.

    noise() samples noise with the same (N, 3, 72, 128) shape as a frame, so
    this network is an hourglass: three strided convolutions encode the noise
    down to a (ngf*4) x 9 x 16 feature map, and three transposed convolutions
    decode it back to full resolution. The Tanh output lies in [-1, 1].
    """
    def __init__(self, ngf=64):
        """
        Args:
            ngf: number of feature maps in the outermost hidden layers; the
                9 x 16 bottleneck uses ngf * 4.
        """
        super(GeneratorNet, self).__init__()
        self.main = nn.Sequential(
            # Encoder. (N, 3, 72, 128) -> (N, ngf, 36, 64)
            nn.Conv2d(3, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # -> (N, ngf*2, 18, 32)
            nn.Conv2d(ngf, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # -> (N, ngf*4, 9, 16)
            nn.Conv2d(ngf * 2, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # Decoder. -> (N, ngf*2, 18, 32)
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # -> (N, ngf, 36, 64)
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # -> (N, 3, 72, 128)
            nn.ConvTranspose2d(ngf, 3, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, input):
        """
        Args:
            input: noise of shape (N, 3, 72, 128) or flattened (N, 27648).

        Returns:
            (N, 3, 72, 128) tensor of generated frames in [-1, 1].
        """
        noise_images = input.reshape(input.size(0), 3, 72, 128)
        return self.main(noise_images)

        
        
# Build the generator, move its weights to the GPU, then wrap it so forward
# passes are split across the visible CUDA devices.
generator = torch.nn.DataParallel(GeneratorNet().cuda())

In [None]:
def noise(size, device='cuda'):
    '''
    Sample a batch of standard-normal noise shaped like the training frames.

    Args:
        size: number of samples N.
        device: where to allocate the tensor; the default 'cuda' is the
            current CUDA device, as with the previous hard-coded .cuda().

    Returns:
        (N, 3, 72, 128) float32 tensor of N(0, 1) samples.
    '''
    # Allocating on the target device avoids a host-side tensor plus a copy.
    # torch.autograd.Variable is a no-op wrapper since PyTorch 0.4, so it is
    # no longer needed.
    return torch.randn(size, 3, 72, 128, device=device)

In [None]:
# One Adam optimizer per network (discriminator first, then generator), both
# with the same learning rate and betas.
optimizerD, optimizerG = (
    optim.Adam(net.parameters(), lr=0.0002, betas=(0.9, 0.999))
    for net in (discriminator, generator)
)

In [None]:
# Binary cross-entropy between predicted probabilities and 0/1 labels.
loss = nn.BCELoss()


def ones_target(size, device='cuda'):
    '''
    Return a (size, 1) tensor of ones, the "real" label for nn.BCELoss.

    Args:
        size: batch size (number of rows).
        device: where to allocate the tensor; the default 'cuda' is the
            current CUDA device, as with the previous hard-coded .cuda().
    '''
    # Allocate directly on the device; Variable is a no-op since PyTorch 0.4.
    return torch.ones(size, 1, device=device)


def zeros_target(size, device='cuda'):
    '''
    Return a (size, 1) tensor of zeros, the "fake" label for nn.BCELoss.

    Args:
        size: batch size (number of rows).
        device: where to allocate the tensor; the default 'cuda' is the
            current CUDA device, as with the previous hard-coded .cuda().
    '''
    return torch.zeros(size, 1, device=device)

In [None]:
def train_discriminator(optimizer, real_data, fake_data):
    """
    Run one optimisation step of the global `discriminator`.

    Args:
        optimizer: optimizer over the discriminator's parameters.
        real_data: batch of real samples; dimension 0 is the batch size.
        fake_data: batch of generated samples with the same batch size
            (the training loop detaches it from the generator first).

    Returns:
        Tuple of (real BCE loss + fake BCE loss, predictions on real_data,
        predictions on fake_data).
    """
    batch_size = real_data.size(0)
    optimizer.zero_grad()

    # Real samples should be scored as 1 ...
    pred_on_real = discriminator(real_data)
    real_loss = loss(pred_on_real, ones_target(batch_size))
    real_loss.backward()

    # ... and generated samples as 0. Gradients from both backward passes
    # accumulate before the single optimizer step.
    pred_on_fake = discriminator(fake_data)
    fake_loss = loss(pred_on_fake, zeros_target(batch_size))
    fake_loss.backward()

    optimizer.step()
    return real_loss + fake_loss, pred_on_real, pred_on_fake

In [None]:
def train_generator(optimizer, fake_data):
    """
    Run one optimisation step of the generator.

    The generator is rewarded when the global `discriminator` scores its
    samples as real, so the loss is BCE against a target of ones.

    Args:
        optimizer: optimizer over the generator's parameters.
        fake_data: generated samples, still attached to the generator's graph
            so gradients can flow back into it.

    Returns:
        The scalar BCE loss tensor.
    """
    optimizer.zero_grad()
    verdict = discriminator(fake_data)
    gen_loss = loss(verdict, ones_target(fake_data.size(0)))
    gen_loss.backward()
    optimizer.step()
    return gen_loss

In [None]:
# A fixed noise batch, sampled once and reused at every logging step so the
# rendered samples can be compared from batch to batch.
num_test_samples = 81
test_noise = noise(size=num_test_samples)


In [None]:
# Create the logger used for metrics, sample images and checkpoints.
# NOTE(review): this cell previously attempted
#     logger = logger.load_models(G_epoch_0_batch_2, D_epoch_0_batch_2, 0, 2)
# inside a bare `except:`. On a fresh kernel `logger` is not defined before
# that line, and the G_/D_epoch_0_batch_2 names are not defined anywhere in the
# notebook. The call therefore raised NameError and the except branch always
# ran. TODO: restore checkpoints once the signature of
# utilsGAN.Logger.load_models is confirmed.
print("No checkpoint Loaded")
logger = Logger(model_name='SimpsonsGAN', data_name='couch-gag')

# Total number of epochs to train
num_epochs = 2
# Render samples, print status and save a checkpoint every `log_every` batches.
log_every = 1

for epoch in range(num_epochs):
    for n_batch, real_batch in enumerate(data_loader):
        N = real_batch.size(0)

        # 1. Train the discriminator.
        # The models hold float32 weights; .float() guards against float64
        # batches (e.g. arrays produced by block_reduce(func=np.mean)).
        real_data = images_to_vectors(real_batch.float()).cuda()
        # Detach so the discriminator step does not backpropagate into G.
        fake_data = generator(noise(N)).detach()
        d_error, d_pred_real, d_pred_fake = train_discriminator(
            optimizerD, real_data, fake_data)

        # 2. Train the generator on a fresh, non-detached batch of fakes.
        fake_data = generator(noise(N))
        g_error = train_generator(optimizerG, fake_data)

        # Log batch error
        logger.log(d_error, g_error, epoch, n_batch, num_batches)

        if n_batch % log_every == 0:
            # Sampling needs no gradients; no_grad avoids keeping activations
            # for the whole test batch in memory.
            with torch.no_grad():
                test_images = vectors_to_images(generator(test_noise))
            logger.log_images(
                test_images, num_test_samples,
                epoch, n_batch, num_batches
            )
            # Display status logs
            logger.display_status(
                epoch, num_epochs, n_batch, num_batches,
                d_error, g_error, d_pred_real, d_pred_fake
            )

            logger.save_models(generator, discriminator, epoch, n_batch)
            print("Checkpoint saved epoch " + str(epoch) + " batch " + str(n_batch))