In [None]:
import os
import numpy as np
import errno
import torchvision.utils as vutils
from torchvision import transforms, datasets
from tensorboardX import SummaryWriter
from IPython import display
from matplotlib import pyplot as plt
import torch
import torch.nn as nn
from torch import optim
from torch.autograd import Variable

# Preferred compute device. NOTE: nothing in this notebook moves models or
# tensors to it (every `.to(device)` call is commented out), so training runs on the CPU.
_cuda_available = torch.cuda.is_available()
device = torch.device("cuda:0") if _cuda_available else torch.device("cpu")
_cuda_available, device

In [None]:
class Logger:
    """
    Training logger for GAN experiments.

    Writes loss scalars and image grids to TensorBoard (tensorboardX
    ``SummaryWriter``), saves image grids as PNG files under
    ``./data/images/<model_name>/<data_name>``, prints progress to stdout,
    and saves model ``state_dict``s under
    ``../working/models/<model_name>/<data_name>``.
    """

    def __init__(self, model_name, data_name):
        self.model_name = model_name
        self.data_name = data_name

        # Used as the TensorBoard tag prefix and passed to the writer as `comment`.
        self.comment = f'{model_name}_{data_name}'
        # Sub-directory shared by saved images and saved models.
        self.data_subdir = f'{model_name}/{data_name}'

        # TensorBoard
        self.writer = SummaryWriter(comment=self.comment)

    def log(self, d_error, g_error, epoch, n_batch, num_batches):
        """
        Write the discriminator and generator losses as TensorBoard scalars.

        Tensors are converted to plain Python numbers first; any other value
        (float, numpy scalar) is passed to the writer unchanged. The global
        step is ``epoch * num_batches + n_batch``.
        """
        d_error = Logger._to_number(d_error)
        g_error = Logger._to_number(g_error)

        step = Logger._step(epoch, n_batch, num_batches)
        self.writer.add_scalar(f'{self.comment}/D_error', d_error, step)
        self.writer.add_scalar(f'{self.comment}/G_error', g_error, step)

    def log_images(self, images, num_images, epoch, n_batch, num_batches, format='NCHW', normalize=True):
        '''
        Log a batch of images to TensorBoard and save them as PNG grids.

        Args:
            images: torch tensor or numpy array laid out as ``format``.
            num_images: number of images in the batch; its integer square
                root is used as the number of images per row of the square grid.
            format: 'NCHW' (default) or 'NHWC'.
            normalize: forwarded to ``make_grid`` for BOTH grids.
        '''
        images = Logger._to_nchw(images, format)
        # Detach and move to CPU so the grids can later be converted to numpy.
        images = images.detach().cpu()

        step = Logger._step(epoch, n_batch, num_batches)
        img_name = f'{self.comment}/images'

        # Grid using make_grid's default layout (8 images per row)
        horizontal_grid = vutils.make_grid(images, normalize=normalize, scale_each=True)
        # Roughly square grid; nrow must be >= 1 or make_grid divides by zero.
        nrows = max(1, int(np.sqrt(num_images)))
        grid = vutils.make_grid(images, nrow=nrows, normalize=normalize, scale_each=True)

        # Add horizontal images to tensorboard
        self.writer.add_image(img_name, horizontal_grid, step)

        # Save plots
        self.save_torch_images(horizontal_grid, grid, epoch, n_batch)

    def save_torch_images(self, horizontal_grid, grid, epoch, n_batch, plot_horizontal=True):
        """
        Save both grids as PNG files; optionally display the horizontal one inline.

        Grids are expected channel-first (as returned by ``make_grid``) and are
        moved to channel-last for ``imshow``.
        """
        out_dir = f'./data/images/{self.data_subdir}'
        Logger._make_dir(out_dir)

        # Plot and save horizontal
        fig = plt.figure(figsize=(16, 16))
        plt.imshow(np.moveaxis(Logger._to_numpy(horizontal_grid), 0, -1))
        plt.axis('off')
        if plot_horizontal:
            display.display(fig)
        self._save_images(fig, epoch, n_batch, 'hori')
        plt.close(fig)

        # Save squared
        fig = plt.figure()
        plt.imshow(np.moveaxis(Logger._to_numpy(grid), 0, -1))
        plt.axis('off')
        self._save_images(fig, epoch, n_batch)
        plt.close(fig)

    def _save_images(self, fig, epoch, n_batch, comment=''):
        """Save `fig` as ``<comment>_epoch_<epoch>_batch_<n_batch>.png`` in the image directory."""
        out_dir = f'./data/images/{self.data_subdir}'
        Logger._make_dir(out_dir)
        fig.savefig(f'{out_dir}/{comment}_epoch_{epoch}_batch_{n_batch}.png')

    def display_status(self, epoch, num_epochs, n_batch, num_batches, d_error, g_error, d_pred_real, d_pred_fake):
        """
        Print epoch/batch progress, both losses and the mean discriminator
        outputs on real (D(x)) and generated (D(G(z))) samples.
        """
        d_error = Logger._to_number(d_error)
        g_error = Logger._to_number(g_error)
        d_real_mean = Logger._to_number(d_pred_real.mean())
        d_fake_mean = Logger._to_number(d_pred_fake.mean())

        print(f'Epoch: [{epoch}/{num_epochs}], Batch Num: [{n_batch}/{num_batches}]')
        print(f'Discriminator Loss: {d_error:.4f}, Generator Loss: {g_error:.4f}')
        print(f'D(x): {d_real_mean:.4f}, D(G(z)): {d_fake_mean:.4f}')

    def save_models(self, generator, discriminator, epoch):
        """Save both models' state_dicts as ``G_epoch_<epoch>`` / ``D_epoch_<epoch>``."""
        out_dir = f'../working/models/{self.data_subdir}'
        Logger._make_dir(out_dir)
        torch.save(generator.state_dict(), f'{out_dir}/G_epoch_{epoch}')
        torch.save(discriminator.state_dict(), f'{out_dir}/D_epoch_{epoch}')

    def close(self):
        """Flush and close the TensorBoard writer."""
        self.writer.close()

    # Private Functionality
    @staticmethod
    def _step(epoch, n_batch, num_batches):
        """Global step index: batches seen before this one across all epochs."""
        return epoch * num_batches + n_batch

    @staticmethod
    def _make_dir(directory):
        """Create `directory` (and parents); an existing directory is not an error."""
        os.makedirs(directory, exist_ok=True)

    @staticmethod
    def _to_number(value):
        """Convert a tensor to a Python number; return any other value unchanged."""
        if torch.is_tensor(value):
            return value.detach().cpu().item()
        return value

    @staticmethod
    def _to_numpy(value):
        """Convert a tensor (possibly on GPU / requiring grad) or array-like to numpy."""
        if torch.is_tensor(value):
            return value.detach().cpu().numpy()
        return np.asarray(value)

    @staticmethod
    def _to_nchw(images, format='NCHW'):
        """
        Return `images` as a torch tensor in NCHW layout.

        numpy input is wrapped with ``torch.from_numpy``. 'NHWC' input is
        permuted with (0, 3, 1, 2); a plain ``transpose(1, 3)`` would give
        N, C, W, H and flip every image. Other formats are returned as-is.
        """
        if isinstance(images, np.ndarray):
            images = torch.from_numpy(images)
        if format == 'NHWC':
            images = images.permute(0, 3, 1, 2)
        return images

In [None]:
def mnist_data(root='../working'):
    '''
    Load the MNIST training split, downloading it into `root` if needed.

    ``ToTensor`` yields pixel values in [0, 1]; ``Normalize((0.5,), (0.5,))``
    then computes (x - 0.5) / 0.5, mapping them to [-1, 1], the same range as
    the generator's tanh output.

    Args:
        root: dataset directory (default '../working').
    '''
    compose = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    return datasets.MNIST(root=root, train=True, transform=compose, download=True)

# Load the MNIST training split and iterate over it in shuffled batches of BATCH_SIZE.
BATCH_SIZE = 100
data = mnist_data()
data_loader = torch.utils.data.DataLoader(dataset=data, batch_size=BATCH_SIZE, shuffle=True)
num_batches = len(data_loader)  # batches per epoch

In [None]:
class DiscriminatorNet(nn.Module):
    """
    Fully connected GAN discriminator with three hidden layers.

    Maps a flattened 784-feature input through 1024 -> 512 -> 256 hidden
    units (each followed by LeakyReLU(0.2) and Dropout(0.3)) to a single
    sigmoid output in [0, 1]. The training code targets 1 for real samples
    and 0 for generated ones.
    """
    def __init__(self):
        super().__init__()
        in_features, out_features = 784, 1

        def hidden_block(n_in, n_out):
            # Linear at index 0 -> state_dict keys '<name>.0.weight' / '<name>.0.bias'.
            return nn.Sequential(nn.Linear(n_in, n_out), nn.LeakyReLU(0.2), nn.Dropout(0.3))

        # Construction order (hidden0, hidden1, hidden2, out) fixes the order of random weight init.
        self.hidden0 = hidden_block(in_features, 1024)
        self.hidden1 = hidden_block(1024, 512)
        self.hidden2 = hidden_block(512, 256)
        self.out = nn.Sequential(nn.Linear(256, out_features), nn.Sigmoid())

    def forward(self, x):
        for stage in (self.hidden0, self.hidden1, self.hidden2, self.out):
            x = stage(x)
        return x

discriminator = DiscriminatorNet()
# The model stays on the CPU: no input, noise or target tensor is moved to `device`.

In [None]:
def images_to_vectors(images):
    '''
    Flatten a batch of images into a batch of row vectors.

    Every dimension after the first (batch) dimension is flattened, so an
    MNIST batch of shape (N, 1, 28, 28) becomes (N, 784). Unlike ``view``,
    ``torch.flatten`` also accepts non-contiguous inputs (e.g. permuted
    tensors), and it handles an empty batch.
    '''
    return torch.flatten(images, start_dim=1)

def vectors_to_images(vectors, image_shape=(1, 28, 28)):
    '''
    Reshape a batch of flat vectors into a batch of images.

    Args:
        vectors: tensor whose first dimension is the batch size and whose
            remaining elements number prod(image_shape) per sample.
        image_shape: per-sample (C, H, W); defaults to MNIST's (1, 28, 28).

    Returns:
        Tensor of shape (N, *image_shape). ``reshape`` is used instead of
        ``view`` so non-contiguous inputs are accepted.
    '''
    return vectors.reshape(vectors.size(0), *image_shape)

In [None]:
class GeneratorNet(nn.Module):
    """
    Fully connected GAN generator with three hidden layers.

    Maps a 100-dimensional latent vector through 256 -> 512 -> 1024 hidden
    units (LeakyReLU(0.2) after each) to 784 outputs squashed into [-1, 1]
    by tanh, i.e. one flattened 28x28 image per latent vector.
    """
    def __init__(self):
        super().__init__()
        latent_dim = 100
        image_dim = 784

        def linear_leaky(n_in, n_out):
            # Linear at index 0 -> state_dict keys '<name>.0.weight' / '<name>.0.bias'.
            return nn.Sequential(nn.Linear(n_in, n_out), nn.LeakyReLU(0.2))

        # Construction order (hidden0, hidden1, hidden2, out) fixes the order of random weight init.
        self.hidden0 = linear_leaky(latent_dim, 256)
        self.hidden1 = linear_leaky(256, 512)
        self.hidden2 = linear_leaky(512, 1024)
        self.out = nn.Sequential(nn.Linear(1024, image_dim), nn.Tanh())

    def forward(self, x):
        return self.out(self.hidden2(self.hidden1(self.hidden0(x))))

generator = GeneratorNet()
# The model stays on the CPU: no input, noise or target tensor is moved to `device`.

In [None]:
def noise(size, n_features=100):
    '''
    Sample a batch of latent vectors for the generator.

    Args:
        size: number of vectors (batch size).
        n_features: length of each vector; defaults to 100, the input width
            of GeneratorNet.

    Returns:
        A float tensor of shape (size, n_features) with entries drawn i.i.d.
        from a standard normal distribution. It is a plain tensor; the old
        ``Variable`` wrapper was a deprecated no-op.
    '''
    return torch.randn(size, n_features)

In [None]:
# Discriminator and generator each get their own Adam optimizer with the same
# learning rate; all other Adam settings are left at their defaults.
LEARNING_RATE = 0.0002
d_optimizer = optim.Adam(params=discriminator.parameters(), lr=LEARNING_RATE)
g_optimizer = optim.Adam(params=generator.parameters(), lr=LEARNING_RATE)

In [None]:
loss = nn.BCELoss(reduction='mean')  # binary cross-entropy on the discriminator's sigmoid outputs, averaged over the batch

In [None]:
def ones_target(size, device=None):
    '''
    Target labels of 1.0 ("real") for BCELoss.

    Args:
        size: batch size.
        device: optional device for the tensor; None uses torch's default
            (the CPU unless changed), which is what this notebook relies on.

    Returns:
        A float tensor of ones with shape (size, 1), matching the (N, 1)
        output of the discriminator.
    '''
    return torch.ones(size, 1, device=device)

def zeros_target(size, device=None):
    '''
    Target labels of 0.0 ("fake") for BCELoss.

    Args:
        size: batch size.
        device: optional device for the tensor; None uses torch's default
            (the CPU unless changed), which is what this notebook relies on.

    Returns:
        A float tensor of zeros with shape (size, 1), matching the (N, 1)
        output of the discriminator.
    '''
    return torch.zeros(size, 1, device=device)

In [None]:
def train_discriminator(optimizer, real_data, fake_data):
    """
    Run one optimisation step for the module-level `discriminator`.

    Args:
        optimizer: optimizer over the discriminator's parameters.
        real_data: batch of real samples; dim 0 is the batch size.
        fake_data: generated samples with the same batch size as `real_data`
            (in this notebook the caller detaches them from the generator).

    Returns:
        (total_loss, pred_real, pred_fake), where total_loss is the sum of the
        BCE losses on real (target 1) and fake (target 0) samples.
    """
    batch_size = real_data.size(0)
    optimizer.zero_grad()

    # Real samples should be scored as 1; backward() accumulates into .grad.
    pred_real = discriminator(real_data)
    loss_real = loss(pred_real, ones_target(batch_size))
    loss_real.backward()

    # Generated samples should be scored as 0; gradients add to the real-sample ones.
    pred_fake = discriminator(fake_data)
    loss_fake = loss(pred_fake, zeros_target(batch_size))
    loss_fake.backward()

    # One parameter update from the accumulated gradients of both terms.
    optimizer.step()

    return loss_real + loss_fake, pred_real, pred_fake

In [None]:
def train_generator(optimizer, fake_data):
    """
    Run one optimisation step for the generator.

    The module-level `discriminator` scores `fake_data`, and BCE against an
    all-ones target (i.e. -log D(G(z))) is backpropagated through the
    discriminator into whatever produced `fake_data`. Only `optimizer`'s
    parameters are stepped. Gradients that also land in the discriminator's
    parameters are cleared by the next discriminator `zero_grad()`.

    Args:
        optimizer: optimizer over the generator's parameters.
        fake_data: generator output that is still attached to the graph.

    Returns:
        The generator loss tensor.
    """
    optimizer.zero_grad()
    scores = discriminator(fake_data)
    gen_loss = loss(scores, ones_target(fake_data.size(0)))
    gen_loss.backward()
    optimizer.step()
    return gen_loss

In [None]:
# Fixed latent batch reused every time samples are logged, so images are comparable across steps.
num_test_samples = 8
test_noise = noise(size=num_test_samples)

In [None]:
# Create logger instance
logger = Logger(model_name='VGAN', data_name='MNIST')
# Total number of epochs to train
num_epochs = 2
for epoch in range(num_epochs):
    for n_batch, (real_batch,_) in enumerate(data_loader):
        N = real_batch.size(0)
        
        # 1. Train Discriminator
        real_data = Variable(images_to_vectors(real_batch))
#         real_data = real_data.to(device)
        
        # Generate fake data and detach 
        # (so gradients are not calculated for generator)
        fake_data = generator(noise(N)).detach()
#         fake_data = fake_data.to(device)
        
        # Train D
        d_error, d_pred_real, d_pred_fake = train_discriminator(d_optimizer, real_data, fake_data)

        # 2. Train Generator
        # Generate fake data
        fake_data = generator(noise(N))
        
        # Train G
        g_error = train_generator(g_optimizer, fake_data)
        # Log batch error
        logger.log(d_error, g_error, epoch, n_batch, num_batches)
        # Display Progress every few batches
        if num_epochs % 10 == 0 and n_batch == 0:
            test_images = vectors_to_images(generator(test_noise))
            test_images = test_images.data
            logger.log_images(test_images, num_test_samples, epoch, n_batch, num_batches);
            # Display status Logs
            logger.display_status(epoch, num_epochs, n_batch, num_batches,
                                  d_error, g_error, d_pred_real, d_pred_fake)

In [None]:
# def mnist_data():
#     compose = transforms.Compose([
#         transforms.ToTensor(),
#         transforms.Normalize((0.5,), (0.5,))
#         ])
#     out_dir = '../working'
#     return datasets.MNIST(root=out_dir, train=True, transform=compose, download=True)

# # Load data
# data = mnist_data()
# # Create loader with data, so that we can iterate over it
# data_loader = torch.utils.data.DataLoader(data, batch_size=100, shuffle=True)
# # Num batches
# num_batches = len(data_loader)

In [None]:
# class GAN():
#     def __init__(self):
#         self.img_rows = 28
#         self.img_cols = 28
#         self.channels = 1
#         self.img_shape = (self.img_rows, self.img_cols, self.channels)

#         optimizer = torch.optim.Adam(0.0002, 0.5)

#         # Build and compile the discriminator
#         self.discriminator = self.build_discriminator()
#         self.discriminator.compile(loss='binary_crossentropy',
#                 optimizer=optimizer,
#             metrics=['accuracy'])

#         # Build and compile the generator
#         self.generator = self.build_generator()
#         self.generator.compile(loss='binary_crossentropy', optimizer=optimizer)

#         # The generator takes noise as input and generated imgs
#         z = Input(shape=(100,))
#         img = self.generator(z)

#         # For the combined model we will only train the generator
#         self.discriminator.trainable = False

#         # The valid takes generated images as input and determines validity
#         valid = self.discriminator(img)

#         # The combined model  (stacked generator and discriminator) takes
#         # noise as input => generates images => determines validity
#         self.combined = Model(z, valid)
#         self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)

#     def build_generator(self):

#         noise_shape = (100,)

#         model = Sequential()

#         model.add(Dense(256, input_shape=noise_shape))
#         model.add(LeakyReLU(alpha=0.2))
#         model.add(BatchNormalization(momentum=0.8))
#         model.add(Dense(512))
#         model.add(LeakyReLU(alpha=0.2))
#         model.add(BatchNormalization(momentum=0.8))
#         model.add(Dense(1024))
#         model.add(LeakyReLU(alpha=0.2))
#         model.add(BatchNormalization(momentum=0.8))
#         model.add(Dense(np.prod(self.img_shape), activation='tanh'))
#         model.add(Reshape(self.img_shape))

#         model.summary()

#         noise = Input(shape=noise_shape)
#         img = model(noise)

#         return Model(noise, img)

#     def build_discriminator(self):

#         img_shape = (self.img_rows, self.img_cols, self.channels)

#         model = Sequential()

#         model.add(Flatten(input_shape=img_shape))
#         model.add(Dense(512))
#         model.add(LeakyReLU(alpha=0.2))
#         model.add(Dense(256))
#         model.add(LeakyReLU(alpha=0.2))
#         model.add(Dense(1, activation='sigmoid'))
#         model.summary()

#         img = Input(shape=img_shape)
#         validity = model(img)

#         return Model(img, validity)

#     def train(self, epochs, batch_size=128, save_interval=50):

#         # Load the dataset
#         X_train = data

#         # Rescale -1 to 1
#         X_train = (X_train.astype(np.float32) - 127.5) / 127.5
#         X_train = np.expand_dims(X_train, axis=3)

#         half_batch = int(batch_size / 2)

#         for epoch in range(epochs):

#             # ---------------------
#             #  Train Discriminator
#             # ---------------------

#             # Select a random half batch of images
#             idx = np.random.randint(0, X_train.shape[0], half_batch)
#             imgs = X_train[idx]

#             noise = np.random.normal(0, 1, (half_batch, 100))

#             # Generate a half batch of new images
#             gen_imgs = self.generator.predict(noise)

#             # Train the discriminator
#             d_loss_real = self.discriminator.train_on_batch(imgs, np.ones((half_batch, 1)))
#             d_loss_fake = self.discriminator.train_on_batch(gen_imgs, np.zeros((half_batch, 1)))
#             d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)


#             # ---------------------
#             #  Train Generator
#             # ---------------------

#             noise = np.random.normal(0, 1, (batch_size, 100))

#             # The generator wants the discriminator to label the generated samples
#             # as valid (ones)
#             valid_y = np.array([1] * batch_size)

#             # Train the generator
#             g_loss = self.combined.train_on_batch(noise, valid_y)

#             # Plot the progress
#             print(f"{epoch} [D loss: {d_loss[0]}, acc.: {100*d_loss[1]:.2f}%] [G loss: {g_loss}]")

#             # If at save interval => save generated image samples
#             if epoch % save_interval == 0:
#                 self.save_imgs(epoch)

#     def save_imgs(self, epoch):
#         r, c = 5, 5
#         noise = np.random.normal(0, 1, (r * c, 100))
#         gen_imgs = self.generator.predict(noise)

#         # Rescale images 0 - 1
#         gen_imgs = 0.5 * gen_imgs + 0.5

#         fig, axs = plt.subplots(r, c)
#         cnt = 0
#         for i in range(r):
#             for j in range(c):
#                 axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
#                 axs[i,j].axis('off')
#                 cnt += 1
#         fig.savefig(f"../working/mnist_{epoch}.png")
#         plt.close()

# gan = GAN()

In [None]:
# gan.train(epochs=30000, batch_size=32, save_interval=200)