In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import clear_output
from IPython.display import HTML
import pickle
import time
# from google.colab import files

In [2]:
import numpy as np
from pathlib import Path
from torchvision import transforms
import torch
from torch import nn
import torchvision.utils as vutils

In [3]:
from utils import visual_data, load_cifar10
from networks import Generator, Discriminator, weights_init, DiscriminatorMiniBatchDiscrimination, DCGAN
from trainer import  update_params
from losses import D_loss, G_featMatch_loss, G_loss

In [4]:
!nvidia-smi

Wed Apr 15 19:15:11 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 440.64.00    Driver Version: 440.64.00    CUDA Version: 10.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  GeForce GTX 960     Off  | 00000000:01:00.0  On |                  N/A |
|  0%   44C    P5    14W / 128W |    547MiB /  4040MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage    

In [5]:
# Setting up constants
# device
ngpu = 1

# single image
imageSize = 64
imageMean = (0.4923172 , 0.48307145, 0.4474483)
imageStd = (0.24041407, 0.23696952, 0.25565723)

# data loader
numWorkers = 3
batchSize = 16

# Network Arch
nc = 3 # Number of channels
nz = 100 # Latent vector
ngf = 64 # relates to the depth of feature maps carried through the generator
ndf = 32 # sets the depth of feature maps propagated through the discriminator

# Training
num_epochs = 5

# Adam Optimizer
lr = .0002
beta1 = .5

# convention of the labeling for the real and the fake datasets
## one-sided label smoothing
real_label = .9
fake_label = 0

# label smoothing
## insdead of real label=.9 give uniform between .8, 1
## insdead of fake label=0 give uniform between 0, .2
label_smoothing = True

# flip labels, with probability pFlip, flip the labels passed to the discriminator
pFlip = 0.05

# if to use last two layers' features for feature matching
double_layer = False

# inner loop repetitions
D_inner_repeat = 1
G_inner_repeat = 1

# whether to do minibatch normalization
miniBatchDiscrimination = True

In [6]:
# folder to store/load data
dataFolder = Path("./data")
# Decide which device to run on
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

In [7]:
torch.cuda.is_available()

True

In [8]:
cifarFolder = dataFolder/"CIFAR10"
# prepare data for loading
tsfms = transforms.Compose([
    transforms.Resize(imageSize), 
    transforms.ToTensor(),
    transforms.Normalize(imageMean, imageStd)
])
trainLoader, test_loader = load_cifar10(cifarFolder, tsfms, batchSize, numWorkers)

Files already downloaded and verified
Files already downloaded and verified


In [9]:
# init D and G network
gen = Generator(ngpu, nz, ngf, nc).to(device)
gen.apply(weights_init);

if not miniBatchDiscrimination:
    disc = Discriminator(ngpu, nc, ndf).to(device)
else: 
    disc = DiscriminatorMiniBatchDiscrimination(ngpu, nc, ndf).to(device)

disc.apply(weights_init);

In [10]:
dcgan = DCGAN(gen, disc, device, real_label, fake_label, pFlip, label_smoothing, double_layer)

In [11]:
# setup optmization 
optimizerD = torch.optim.Adam(dcgan.discriminator.parameters(), lr=lr, betas=(beta1, 0.999))
# optimizerD = torch.optim.SGD(disc.parameters(), lr=0.1, momentum=0.9)
optimizerG = torch.optim.Adam(dcgan.generator.parameters(), lr=lr, betas=(beta1, 0.999))

In [12]:
# fixed noise z for viusalization of the progress of the training
fixed_noise = torch.randn(batchSize, dcgan.generator.nz, 1, 1, device=device)

# training loop 
# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0
g=0
print("Starting Training Loop...")

Starting Training Loop...


In [13]:
num_epochs = 30

In [16]:
g=0
# For each epoch
current = time.time()
# for each epoch
for epoch in range(num_epochs):
    # For each batch
    for i, data in enumerate(trainLoader, 0):

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################   ` 

        # real batch and fake batch
        real_batch = data[0].to(device)
        fake_batch = dcgan.fake_batch(real_batch.size(0))
#         fake_batch = batch_fake_samples(real_batch.size(0), gen, nz, device)
        
#         # forward pass
#         output_real, f1real, f2real = dcgan.discriminate(real_batch)
#         output_fake, _, _ = dcgan.discriminate(fake_batch.detach())
        
#         # loss for D
#         errD = dcgan.discriminator_loss(output_real, output_fake, 
#                                         real_label, fake_label, 
#                                         pFlip, label_smoothing,
#                                         device)
        errD = dcgan.discriminator_loss(real_batch, fake_batch)
        
        # backward pass and optimization step
        update_params(optimizerD, errD)
        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        # for _ in range(G_inner_repeat):
        if errD.item() < .8 or g==0:
            g += 1
            # fake batch 
            fake_batch = dcgan.fake_batch(real_batch.size(0))
            
            # forward pass
#             output, f1fake, f2fake = dcgan.discriminate(fake_batch)
            
#             # loss function
#             errG = dcgan.generator_loss(
#                 output, real_label, label_smoothing, device,
#                 f1fake, f2fake, f1real.detach(), f2real.detach(), double_layer
#             )
            errG = dcgan.generator_loss(real_batch, fake_batch)

#             errG += .1*G_loss(output, real_label, label_smoothing, device)
#             errG = G_featMatch_loss(f1fake, f2fake, 
#                                     f1real.detach(), f2real.detach(), double_layer)            

            
            # backward pass and optimization step
            update_params(optimizerG, errG)

        ###########################
        # (3) save and print progress
        ###########################
        # for progress prints
        D_x = output_real.mean().item()
        D_G_z1 = output_fake.mean().item()
        D_G_z2 = output.mean().item()
        
        # Output training stats
        if i % 50 == 0:
            print('[{:0>2}/{:0>2}][{:0>4}/{:0>4}]  Loss_D: {:.3f}  Loss_G: {:.3f}  D(x): {:.3f}  D(G(z)): {:.3f} / {:.3f}  t={:6.3f}  iterG={}/50'.format(
                epoch, num_epochs, i, len(trainLoader),errD.item(), errG.item(), D_x, D_G_z1, D_G_z2, time.time()-current, g, prec=3))
            g = 0
            current = time.time()

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(trainLoader)-1)):
            clear_output()
            with torch.no_grad():
                fake = gen(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        iters += 1

[00/30][0550/3125]  Loss_D: 0.768  Loss_G: 29245.742  D(x): 0.385  D(G(z)): 0.421 / 0.370  t= 2.179  iterG=42/50
[00/30][0600/3125]  Loss_D: 0.682  Loss_G: 31667.564  D(x): 0.385  D(G(z)): 0.421 / 0.370  t= 2.246  iterG=47/50
[00/30][0650/3125]  Loss_D: 0.668  Loss_G: 30688.010  D(x): 0.385  D(G(z)): 0.421 / 0.370  t= 2.121  iterG=42/50
[00/30][0700/3125]  Loss_D: 0.687  Loss_G: 31329.967  D(x): 0.385  D(G(z)): 0.421 / 0.370  t= 2.085  iterG=41/50
[00/30][0750/3125]  Loss_D: 0.790  Loss_G: 30501.301  D(x): 0.385  D(G(z)): 0.421 / 0.370  t= 2.084  iterG=40/50


Process Process-9:
Process Process-8:
Traceback (most recent call last):
Process Process-7:
Traceback (most recent call last):
  File "/home/rainbow/miniconda3/envs/dcgan_faster/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
Traceback (most recent call last):
  File "/home/rainbow/miniconda3/envs/dcgan_faster/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/rainbow/miniconda3/envs/dcgan_faster/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 96, in _worker_loop
    r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)
  File "/home/rainbow/miniconda3/envs/dcgan_faster/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/home/rainbow/miniconda3/envs/dcgan_faster/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/home/rainbow/miniconda3/envs/dcgan_faster/lib/pytho

KeyboardInterrupt: 

  File "/home/rainbow/miniconda3/envs/dcgan_faster/lib/python3.6/multiprocessing/connection.py", line 414, in _poll
    r = wait([self], timeout)
  File "/home/rainbow/miniconda3/envs/dcgan_faster/lib/python3.6/multiprocessing/connection.py", line 911, in wait
    ready = selector.select(timeout)
  File "/home/rainbow/miniconda3/envs/dcgan_faster/lib/python3.6/selectors.py", line 376, in select
    fd_event_list = self._poll.poll(timeout)
KeyboardInterrupt


In [None]:
len(G_losses)

In [None]:
plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses,label="G")
plt.plot(D_losses,label="D", alpha=.5)
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

In [None]:
trainIter = iter(trainLoader)

In [None]:
real_batch = next(trainIter)

noise = torch.randn(64, nz, 1, 1, device=device)
fake = gen(noise).detach().cpu()
fake = vutils.make_grid(fake, padding=2, normalize=True)

# Plot the real images
plt.figure(figsize=(15,15))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(),(1,2,0)))

# Plot the fake images from the last epoch
plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(fake,(1,2,0)))
plt.show()

In [None]:
# print(optimizerD, optimizerG, trainLoader.batch_size, ngf, ndf, real_label)

In [None]:
# genParams = gen.state_dict()
# discParams = disc.state_dict()

In [None]:
def dumpAndDL(variable, fileName):
    with open(fileName, 'wb') as f:
        pickle.dump(variable,
                    f)
        # files.download(fileName)
        
def chunks(lst, n):
    """Yield successive n-sized chunks from lst."""
    for i in range(0, len(lst), n):
        yield lst[i:i + n]

In [None]:
# dumpAndDL(
#     {'genParams': genParams,
#      'discParams': discParams,
#     },
#     'bestParams.p'
# )

In [None]:
dumpAndDL(
    {'G_losses': G_losses,
     'D_losses':D_losses,
     },
     'losses.p'
       )

In [None]:
n=len(img_list)

for idx, chunk in enumerate(chunks(img_list, n)):
    fileName = 'img_list_{}.p'.format(idx)
    dumpAndDL(
        chunk,
        fileName
    )

In [None]:
# # load params to networks
# fileName = 'bestParams.p'
# netParams = None
# with open(fileName, 'rb') as f:
#      netParams = pickle.load(f)
# gen.load_state_dict(netParams['genParams'])
# disc.load_state_dict(netParams['discParams'])

In [None]:
# # load params to networks
# fileNames = ['img_list_0.p', 'img_list_1.p', 'img_list_2.p', 'img_list_3.p']

# img_list = []
# for fileName in fileNames:
#     with open(fileName, 'rb') as f:
#         img_list += pickle.load(f)['img_list']


In [None]:
# len(img_list)

In [None]:
# #%%capture animation
# fig = plt.figure(figsize=(10, 10))
# plt.axis("off")
# ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list[:100]]
# ani = animation.ArtistAnimation(fig, ims, interval=100, repeat_delay=1000, blit=True)

# HTML(ani.to_jshtml())

# # writer = animation.writers['ffmpeg']
# # ani.save('im.mp4', writer=writer, dpi=100)