In [2]:
"""
reference: https://github.com/alwynmathew/Vanilla-GAN/blob/master/GAN_1.ipynb
Changed convolution layers to FC layers as GAN. 
"""

import os #,argparse 
import gzip
import torch.nn as nn
import numpy as np
import scipy.misc
import imageio
import matplotlib.pyplot as plt

import torch, time, pickle
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [9]:
"""input arguments"""

dataset = 'mnist'
epoch = 25 #25
batch_size = 128
sample_num = 100 #16
save_dir = './models/FC'
result_dir = './results/FC'
log_dir = './logs/FC'
lrG = 0.0002
lrD = 0.0002
beta1 = 0.5
beta2 = 0.999
gpu_mode = False
model_name = 'GAN_FC'

In [10]:
"""checking arguments"""

# --save_dir
if not os.path.exists(save_dir):
    os.makedirs(save_dir)

# --result_dir
if not os.path.exists(result_dir):
    os.makedirs(result_dir)

# --result_dir
if not os.path.exists(log_dir):
    os.makedirs(log_dir)


In [11]:
"""print network"""

def print_network(net):
    num_params = 0
    for param in net.parameters():
        num_params += param.numel()
    print(net)
    print('Total number of parameters: %d' % num_params)

In [12]:
"""save images"""

def save_images(images, size, image_path):
    return imsave(images, size, image_path)

def imsave(images, size, path):
    image = np.squeeze(merge(images, size))
    return scipy.misc.imsave(path, image)

"""merge images"""

def merge(images, size):
    h, w = images.shape[1], images.shape[2]
    if (images.shape[3] in (3,4)):
        c = images.shape[3]
        img = np.zeros((h * size[0], w * size[1], c))
        for idx, image in enumerate(images):
            i = idx % size[1]
            j = idx // size[1]
            img[j * h:j * h + h, i * w:i * w + w, :] = image
        return img
    elif images.shape[3]==1:
        img = np.zeros((h * size[0], w * size[1]))
        for idx, image in enumerate(images):
            i = idx % size[1]
            j = idx // size[1]
            img[j * h:j * h + h, i * w:i * w + w] = image[:,:,0]
        return img
    else:
        raise ValueError('in merge(images,size) images parameter ''must have dimensions: HxW or HxWx3 or HxWx4')

"""generate animation"""
        
def generate_animation(path, num):
    images = []
    for e in range(num):
        img_name = path + '_epoch%03d' % (e+1) + '.png'
        images.append(imageio.imread(img_name))
    imageio.mimsave(path + '_generate_animation.gif', images, fps=5)

In [13]:
"""plot loss"""

def loss_plot(hist, path = 'Train_hist.png', model_name = ''):
    x = range(len(hist['D_loss']))

    y1 = hist['D_loss']
    y2 = hist['G_loss']

    plt.plot(x, y1, label='D_loss')
    plt.plot(x, y2, label='G_loss')

    plt.xlabel('Iter')
    plt.ylabel('Loss')

    plt.legend(loc=4)
    plt.grid(True)
    plt.tight_layout()

    path = os.path.join(path, model_name + '_loss.png')

    plt.savefig(path)

    plt.close()

In [14]:
"""initialize weights"""

def initialize_weights(net):
    for m in net.modules():

        if isinstance(m, nn.Linear):
            m.weight.data.normal_(0, 0.02)
            m.bias.data.zero_()


In [15]:
"""generator"""

class generator(nn.Module):

    def __init__(self):
        #print('---------- generator -------------')
        super(generator, self).__init__()

        self.input_height = 28
        self.input_width = 28
        self.input_dim = 62
        self.output_dim = 1


        self.fc = nn.Sequential(
            nn.Linear(self.input_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, 2048),
            nn.BatchNorm1d(2048),
            nn.ReLU(),
            nn.Linear(2048, 784),
            nn.Sigmoid(),
        )
        #utils.
        initialize_weights(self)
    
    def forward(self, input):
        x = self.fc(input)
        x = x.view(-1, 1, 28, 28)

        return x

In [48]:
"""discriminator"""

class discriminator(nn.Module):

    def __init__(self):
        super(discriminator, self).__init__()

        self.input_height = 28
        self.input_width = 28
        self.input_dim = 1
        self.output_dim = 1

        self.fc = nn.Sequential(
            nn.Linear(784, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
        )
        self.fc2 = nn.Sequential(
            nn.Linear(1024, 2048),
            nn.LeakyReLU(0.2),
            nn.Linear(2048,1),
            nn.Sigmoid(),
        )
        #utils.
        initialize_weights(self)
        
    def forward(self, input):
        x = input.view(-1, 784)
        #print("x11",x.shape)
        x = self.fc(x)
        #print("x22", x.shape)
        x = self.fc2(x)
        #print("x333", x.shape)

        return x

In [50]:
class GAN():
    def __init__(self):

        # networks init
        self.G = generator()
        self.D = discriminator()
        #print('---------- GAN -------------')
        self.G_optimizer = optim.Adam(self.G.parameters(), lr=lrG, betas=(beta1, beta2))
        self.D_optimizer = optim.Adam(self.D.parameters(), lr=lrD, betas=(beta1, beta2))


        self.G.to(device)
        self.D.to(device)
        self.BCE_loss = nn.BCELoss().to(device)


        print('---------- Networks architecture -------------')
        #utils.
        print_network(self.G)
        print('-----------------------------------------------')
        #utils.
        print_network(self.D)
        print('-----------------------------------------------')

        # load dataset

        self.data_loader = DataLoader(datasets.MNIST('./data', train=True, download=True,
                                                                      transform=transforms.Compose(
                                                                          [transforms.ToTensor()])),
                                                       batch_size=batch_size, shuffle=True)

        
        print ("Size of %s data loader : " % (dataset), self.data_loader.dataset.__len__())
        print('-----------------------------------------------')
        self.z_dim = 62

        # fixed noise
        self.sample_z_ = Variable(torch.rand((batch_size, self.z_dim))).to(device)


    def train(self):
        self.train_hist = {}
        self.train_hist['D_loss'] = []
        self.train_hist['G_loss'] = []
        self.train_hist['per_epoch_time'] = []
        self.train_hist['total_time'] = []


        self.y_real_, self.y_fake_ = Variable(torch.ones(batch_size, 1)).to(device), Variable(torch.zeros(batch_size, 1)).to(device)


        self.D.train()
        print('training start!!')
        start_time = time.time()
        for epochs in range(epoch):
            self.G.train()
            epoch_start_time = time.time()
            for iter, (x_, _) in enumerate(self.data_loader):
                if iter == self.data_loader.dataset.__len__() // batch_size:
                    break

                z_ = torch.rand((batch_size, self.z_dim))

                x_, z_ = Variable(x_).to(device), Variable(z_).to(device)


                # update D network
                self.D_optimizer.zero_grad()

                D_real = self.D(x_)
                D_real_loss = self.BCE_loss(D_real, self.y_real_)

                G_ = self.G(z_)
                D_fake = self.D(G_)
                D_fake_loss = self.BCE_loss(D_fake, self.y_fake_)

                D_loss = D_real_loss + D_fake_loss
                self.train_hist['D_loss'].append(D_loss.data[0])

                D_loss.backward()
                self.D_optimizer.step()

                # update G network
                self.G_optimizer.zero_grad()

                G_ = self.G(z_)
                D_fake = self.D(G_)
                G_loss = self.BCE_loss(D_fake, self.y_real_)
                self.train_hist['G_loss'].append(G_loss.data[0])

                G_loss.backward()
                self.G_optimizer.step()

                if ((iter + 1) % 100) == 0:
                    print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" %
                          ((epochs + 1), (iter + 1), self.data_loader.dataset.__len__() // batch_size, D_loss.data[0], G_loss.data[0]))

            self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time)
            self.visualize_results((epochs+1))

        self.train_hist['total_time'].append(time.time() - start_time)
        print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (np.mean(self.train_hist['per_epoch_time']),
              epoch, self.train_hist['total_time'][0]))
        print("Training finish!... save training results")

        self.save()
        #utils.
        generate_animation(result_dir + '/' + dataset + '/' + model_name + '/' + model_name,
                                 epoch)
        #utils.
        loss_plot(self.train_hist, os.path.join(save_dir, dataset, model_name), model_name)

    def visualize_results(self, epochs, fix=True):
        self.G.eval()

        if not os.path.exists(result_dir + '/' + dataset + '/' + model_name):
            os.makedirs(result_dir + '/' + dataset + '/' + model_name)

        tot_num_samples = min(sample_num, batch_size)
        image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))

        with torch.no_grad():
            
            if fix:
                """ fixed noise """
                samples = self.G(self.sample_z_)
            else:
                """ random noise """

                sample_z_ = Variable(torch.rand((batch_size, self.z_dim))).to(device)

                samples = self.G(sample_z_)


            samples = samples.cpu().data.numpy().transpose(0, 2, 3, 1)


        #utils.
        save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
                          result_dir + '/' + dataset + '/' + model_name + '/' + model_name + '_epoch%03d' % epochs + '.png')
    
    def forward(self, input):
        x = self.conv(input)
        x = x.view(-1, 128 * (self.input_height // 4) * (self.input_width // 4))
        x = self.fc(x)

        return x


    def save(self):
        save_dir_full = os.path.join(save_dir, dataset, model_name)

        if not os.path.exists(save_dir_full):
            os.makedirs(save_dir_full)

        torch.save(self.G.state_dict(), os.path.join(save_dir_full, model_name + '_G.pkl'))
        torch.save(self.D.state_dict(), os.path.join(save_dir_full, model_name + '_D.pkl'))

        with open(os.path.join(save_dir_full, model_name + '_history.pkl'), 'wb') as f:
            pickle.dump(self.train_hist, f)

    def load(self):
        save_dir_full = os.path.join(save_dir, dataset, model_name)

        self.G.load_state_dict(torch.load(os.path.join(save_dir_full, model_name + '_G.pkl')))
        self.D.load_state_dict(torch.load(os.path.join(save_dir_full, model_name + '_D.pkl')))


In [51]:
"""run GAN"""
gan = GAN()


---------- Networks architecture -------------
generator(
  (fc): Sequential(
    (0): Linear(in_features=62, out_features=1024, bias=True)
    (1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=1024, out_features=2048, bias=True)
    (4): BatchNorm1d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): Linear(in_features=2048, out_features=784, bias=True)
    (7): Sigmoid()
  )
)
Total number of parameters: 3776272
-----------------------------------------------
discriminator(
  (fc): Sequential(
    (0): Linear(in_features=784, out_features=1024, bias=True)
    (1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2)
  )
  (fc2): Sequential(
    (0): Linear(in_features=1024, out_features=2048, bias=True)
    (1): LeakyReLU(negative_slope=0.2)
    (2): Linear(in_features=2048, out_features=1, bias

In [52]:


"""train GAN"""

gan.train()
print(" [*] Training finished!")



training start!!




Epoch: [ 1] [ 100/ 468] D_loss: 1.00089478, G_loss: 1.21788812
Epoch: [ 1] [ 200/ 468] D_loss: 0.95763034, G_loss: 1.34242034
Epoch: [ 1] [ 300/ 468] D_loss: 1.11121500, G_loss: 1.31351554
Epoch: [ 1] [ 400/ 468] D_loss: 0.96685958, G_loss: 1.31351531


`imsave` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``imageio.imwrite`` instead.
  


Epoch: [ 2] [ 100/ 468] D_loss: 1.01915467, G_loss: 1.27695477
Epoch: [ 2] [ 200/ 468] D_loss: 0.97965437, G_loss: 1.36600912
Epoch: [ 2] [ 300/ 468] D_loss: 1.02389574, G_loss: 1.11102450
Epoch: [ 2] [ 400/ 468] D_loss: 1.03224218, G_loss: 1.08895791
Epoch: [ 3] [ 100/ 468] D_loss: 0.97423315, G_loss: 1.37309456
Epoch: [ 3] [ 200/ 468] D_loss: 1.02497053, G_loss: 1.46939242
Epoch: [ 3] [ 300/ 468] D_loss: 0.92513692, G_loss: 1.42075467
Epoch: [ 3] [ 400/ 468] D_loss: 1.04585361, G_loss: 1.32140732
Epoch: [ 4] [ 100/ 468] D_loss: 1.05246997, G_loss: 1.45350337
Epoch: [ 4] [ 200/ 468] D_loss: 1.12713385, G_loss: 1.05427825
Epoch: [ 4] [ 300/ 468] D_loss: 1.11999846, G_loss: 1.89888525
Epoch: [ 4] [ 400/ 468] D_loss: 1.05544233, G_loss: 1.35864329
Epoch: [ 5] [ 100/ 468] D_loss: 1.10269225, G_loss: 1.55846763
Epoch: [ 5] [ 200/ 468] D_loss: 0.97971928, G_loss: 1.63332987
Epoch: [ 5] [ 300/ 468] D_loss: 0.91346729, G_loss: 1.61674762
Epoch: [ 5] [ 400/ 468] D_loss: 0.92214721, G_loss: 1.3

In [53]:
"""test GAN"""

# visualize learned generator
gan.visualize_results(epoch)
print(" [*] Testing finished!")

 [*] Testing finished!


`imsave` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``imageio.imwrite`` instead.
  


In [1]:
from IPython.display import HTML
HTML('<img src="./results/GAN/GAN_FC_generate_animation.gif">')