In [1]:
import os
import glob
import imageio
import torch
import numpy as np
from tqdm import tqdm
from scipy.stats import entropy
import torch.nn as nn
from dataset import Dataset
from architectures.architecture_dcgan import DCGAN_D, DCGAN_G
from architectures.architecture_resnet import ResNetGan_D, ResNetGan_G
from architectures.architecture_wavegan import Wave_D, Wave_G
from trainers_advanced.trainer import Trainer
from utils import save, load

In [2]:
torch.backends.cudnn.benchmark = True

In [3]:
dir_name = '/home/tone/Workspace/128'
basic_types = None

In [4]:
lr_D, lr_G, bs = 0.0002, 0.0002, 16
sz, nc, nz, ngf, ndf = 128, 3, 100, 64, 64
use_sigmoid, spectral_norm, attention_layer = False, True, 256

In [5]:
data = Dataset(dir_name, basic_types)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [6]:
netD = DCGAN_D(sz, nc, ndf, use_sigmoid).to(device)
netG = DCGAN_G(sz, nz, nc, ngf).to(device)

In [7]:
trn_dl = data.get_loader(sz, bs)

In [8]:
#trainer = Trainer('SGAN', netD, netG, device, trn_dl, lr_D = lr_D, lr_G = lr_G, resample = True, weight_clip = None, use_gradient_penalty = False, loss_interval = 130, image_interval = 130, save_img_dir = 'saved_images')
trainer = Trainer('WGAN', netD, netG, device, trn_dl, lr_D = lr_D, lr_G = lr_G, resample = True, weight_clip = 0.01, use_gradient_penalty = False, loss_interval = 150, image_interval = 300, save_img_dir = 'saved_imges')

In [9]:
trainer.train(50)
save('saved/cur_state_1.state', netD, netG, trainer.optimizerD, trainer.optimizerG)
torch.save(netG.state_dict(), 'saved/cur_state_G_1.pth')

  0%|          | 0/130 [00:00<?, ?it/s]

[1/50] [1/130] errD : 0.0820, errG : 0.0027


100%|██████████| 130/130 [00:08<00:00, 15.50it/s]
  2%|▏         | 3/130 [00:00<00:19,  6.47it/s]

[2/50] [1/130] errD : -0.4235, errG : 0.9175


100%|██████████| 130/130 [00:07<00:00, 16.69it/s]
  2%|▏         | 3/130 [00:00<00:20,  6.29it/s]

[3/50] [1/130] errD : -0.5053, errG : -0.1776


100%|██████████| 130/130 [00:07<00:00, 16.66it/s]
  2%|▏         | 3/130 [00:00<00:19,  6.47it/s]

[4/50] [1/130] errD : -0.3886, errG : 0.4840


100%|██████████| 130/130 [00:07<00:00, 16.64it/s]
  2%|▏         | 3/130 [00:00<00:19,  6.46it/s]

[5/50] [1/130] errD : -0.0676, errG : -0.4194


100%|██████████| 130/130 [00:07<00:00, 16.63it/s]
  2%|▏         | 3/130 [00:00<00:20,  6.35it/s]

[6/50] [1/130] errD : -0.3531, errG : 0.3350


100%|██████████| 130/130 [00:07<00:00, 16.57it/s]
  2%|▏         | 3/130 [00:00<00:19,  6.47it/s]

[7/50] [1/130] errD : -0.2939, errG : 0.8011


100%|██████████| 130/130 [00:07<00:00, 16.57it/s]
  2%|▏         | 3/130 [00:00<00:19,  6.54it/s]

[8/50] [1/130] errD : -0.6206, errG : -0.1267


100%|██████████| 130/130 [00:07<00:00, 16.58it/s]
  2%|▏         | 3/130 [00:00<00:20,  6.34it/s]

[9/50] [1/130] errD : -0.5872, errG : 0.8305


100%|██████████| 130/130 [00:07<00:00, 16.56it/s]
  2%|▏         | 3/130 [00:00<00:19,  6.47it/s]

[10/50] [1/130] errD : -0.6919, errG : -0.7463


100%|██████████| 130/130 [00:07<00:00, 16.61it/s]
  2%|▏         | 3/130 [00:00<00:19,  6.38it/s]

[11/50] [1/130] errD : -0.3455, errG : 0.2919


100%|██████████| 130/130 [00:07<00:00, 16.54it/s]
  2%|▏         | 3/130 [00:00<00:19,  6.46it/s]

[12/50] [1/130] errD : -0.3770, errG : 0.5099


100%|██████████| 130/130 [00:07<00:00, 16.52it/s]
  2%|▏         | 3/130 [00:00<00:19,  6.36it/s]

[13/50] [1/130] errD : -0.2495, errG : 0.1345


100%|██████████| 130/130 [00:07<00:00, 16.50it/s]
  2%|▏         | 3/130 [00:00<00:19,  6.38it/s]

[14/50] [1/130] errD : -0.2895, errG : 0.6472


100%|██████████| 130/130 [00:07<00:00, 16.53it/s]
  2%|▏         | 3/130 [00:00<00:19,  6.50it/s]

[15/50] [1/130] errD : -0.3154, errG : 0.2076


100%|██████████| 130/130 [00:07<00:00, 16.53it/s]
  2%|▏         | 3/130 [00:00<00:19,  6.38it/s]

[16/50] [1/130] errD : -0.3088, errG : 0.9643


100%|██████████| 130/130 [00:07<00:00, 16.51it/s]
  2%|▏         | 3/130 [00:00<00:19,  6.42it/s]

[17/50] [1/130] errD : -0.0162, errG : -0.9008


100%|██████████| 130/130 [00:07<00:00, 16.51it/s]
  2%|▏         | 3/130 [00:00<00:19,  6.42it/s]

[18/50] [1/130] errD : -0.1864, errG : 0.1765


100%|██████████| 130/130 [00:07<00:00, 16.49it/s]
  2%|▏         | 3/130 [00:00<00:19,  6.45it/s]

[19/50] [1/130] errD : -0.3914, errG : 0.8724


100%|██████████| 130/130 [00:07<00:00, 16.49it/s]
  2%|▏         | 3/130 [00:00<00:20,  6.34it/s]

[20/50] [1/130] errD : -0.1205, errG : -0.3562


100%|██████████| 130/130 [00:07<00:00, 16.47it/s]
  2%|▏         | 3/130 [00:00<00:19,  6.43it/s]

[21/50] [1/130] errD : -0.0457, errG : -0.7677


100%|██████████| 130/130 [00:07<00:00, 16.53it/s]
  2%|▏         | 3/130 [00:00<00:19,  6.39it/s]

[22/50] [1/130] errD : -0.0667, errG : -0.2462


100%|██████████| 130/130 [00:07<00:00, 16.46it/s]
  2%|▏         | 3/130 [00:00<00:19,  6.41it/s]

[23/50] [1/130] errD : -0.1739, errG : 0.1955


100%|██████████| 130/130 [00:07<00:00, 16.49it/s]
  2%|▏         | 3/130 [00:00<00:19,  6.40it/s]

[24/50] [1/130] errD : -0.3963, errG : -0.2909


100%|██████████| 130/130 [00:07<00:00, 16.50it/s]
  2%|▏         | 3/130 [00:00<00:19,  6.49it/s]

[25/50] [1/130] errD : -0.4672, errG : 0.8165


100%|██████████| 130/130 [00:07<00:00, 16.50it/s]
  2%|▏         | 3/130 [00:00<00:20,  6.29it/s]

[26/50] [1/130] errD : -0.6916, errG : 0.9856


100%|██████████| 130/130 [00:07<00:00, 16.49it/s]
  2%|▏         | 3/130 [00:00<00:19,  6.48it/s]

[27/50] [1/130] errD : -0.2496, errG : 0.7913


100%|██████████| 130/130 [00:07<00:00, 16.49it/s]
  2%|▏         | 3/130 [00:00<00:19,  6.37it/s]

[28/50] [1/130] errD : -0.4735, errG : -0.2889


100%|██████████| 130/130 [00:07<00:00, 16.47it/s]
  2%|▏         | 3/130 [00:00<00:19,  6.43it/s]

[29/50] [1/130] errD : -0.0585, errG : -0.1037


100%|██████████| 130/130 [00:07<00:00, 16.51it/s]
  2%|▏         | 3/130 [00:00<00:20,  6.12it/s]

[30/50] [1/130] errD : -0.0782, errG : 0.3964


100%|██████████| 130/130 [00:07<00:00, 16.46it/s]
  2%|▏         | 3/130 [00:00<00:19,  6.50it/s]

[31/50] [1/130] errD : -0.1226, errG : 0.2314


100%|██████████| 130/130 [00:07<00:00, 16.52it/s]
  2%|▏         | 3/130 [00:00<00:20,  6.23it/s]

[32/50] [1/130] errD : -0.2833, errG : 0.6910


100%|██████████| 130/130 [00:07<00:00, 16.48it/s]
  2%|▏         | 3/130 [00:00<00:19,  6.53it/s]

[33/50] [1/130] errD : -0.3512, errG : 0.5267


100%|██████████| 130/130 [00:07<00:00, 16.50it/s]
  2%|▏         | 3/130 [00:00<00:19,  6.36it/s]

[34/50] [1/130] errD : -0.3077, errG : 0.8237


100%|██████████| 130/130 [00:07<00:00, 16.48it/s]
  2%|▏         | 3/130 [00:00<00:19,  6.54it/s]

[35/50] [1/130] errD : -0.1495, errG : -0.1891


100%|██████████| 130/130 [00:07<00:00, 16.51it/s]
  2%|▏         | 3/130 [00:00<00:19,  6.49it/s]

[36/50] [1/130] errD : -0.4496, errG : 0.8785


100%|██████████| 130/130 [00:07<00:00, 16.48it/s]
  2%|▏         | 3/130 [00:00<00:18,  6.70it/s]

[37/50] [1/130] errD : -1.4072, errG : -0.0532


100%|██████████| 130/130 [00:07<00:00, 16.52it/s]
  2%|▏         | 3/130 [00:00<00:19,  6.59it/s]

[38/50] [1/130] errD : -2.1886, errG : 1.3033


100%|██████████| 130/130 [00:07<00:00, 16.45it/s]
  2%|▏         | 3/130 [00:00<00:19,  6.57it/s]

[39/50] [1/130] errD : -2.8685, errG : 1.4210


100%|██████████| 130/130 [00:07<00:00, 16.49it/s]
  2%|▏         | 3/130 [00:00<00:19,  6.50it/s]

[40/50] [1/130] errD : -2.8160, errG : 1.4377


100%|██████████| 130/130 [00:07<00:00, 16.49it/s]
  2%|▏         | 3/130 [00:00<00:19,  6.64it/s]

[41/50] [1/130] errD : -2.7869, errG : 1.3632


100%|██████████| 130/130 [00:07<00:00, 16.48it/s]
  2%|▏         | 3/130 [00:00<00:19,  6.39it/s]

[42/50] [1/130] errD : -1.4251, errG : 1.0027


100%|██████████| 130/130 [00:07<00:00, 16.47it/s]
  2%|▏         | 3/130 [00:00<00:19,  6.45it/s]

[43/50] [1/130] errD : -0.5945, errG : 1.2066


100%|██████████| 130/130 [00:07<00:00, 16.48it/s]
  2%|▏         | 3/130 [00:00<00:19,  6.42it/s]

[44/50] [1/130] errD : -2.0961, errG : 0.9168


100%|██████████| 130/130 [00:07<00:00, 16.47it/s]
  2%|▏         | 3/130 [00:00<00:20,  6.29it/s]

[45/50] [1/130] errD : -0.8620, errG : -0.4396


100%|██████████| 130/130 [00:07<00:00, 16.46it/s]
  2%|▏         | 3/130 [00:00<00:19,  6.64it/s]

[46/50] [1/130] errD : -1.3289, errG : 0.3769


100%|██████████| 130/130 [00:07<00:00, 16.45it/s]
  2%|▏         | 3/130 [00:00<00:20,  6.29it/s]

[47/50] [1/130] errD : -1.6598, errG : 1.3597


100%|██████████| 130/130 [00:07<00:00, 16.47it/s]
  2%|▏         | 3/130 [00:00<00:19,  6.55it/s]

[48/50] [1/130] errD : -1.3724, errG : -0.0315


100%|██████████| 130/130 [00:07<00:00, 16.47it/s]
  2%|▏         | 3/130 [00:00<00:19,  6.36it/s]

[49/50] [1/130] errD : -1.5198, errG : 0.0104


100%|██████████| 130/130 [00:07<00:00, 16.45it/s]
  2%|▏         | 3/130 [00:00<00:19,  6.54it/s]

[50/50] [1/130] errD : -1.5998, errG : 1.4096


100%|██████████| 130/130 [00:07<00:00, 16.47it/s]


In [10]:
N    = 2080
cuda = True

# Set up dtype
if cuda:
    dtype = torch.cuda.FloatTensor
else:
    if torch.cuda.is_available():
        print("WARNING: You have a CUDA device, so you should probably set cuda=True")
    dtype = torch.FloatTensor

In [11]:
# Get predictions
preds = np.zeros((N, 1))

for i, data in enumerate(tqdm(trn_dl)):
    real_images = data[0].to(device)
    #print('pred ', i ,'=',netD(real_images)[0][0])
    preds[i*bs:i*bs + bs] = netD(real_images)[0][0].cpu().detach().numpy()
    

100%|██████████| 130/130 [00:01<00:00, 90.29it/s]


In [12]:
preds

array([[1.51164424],
       [1.51164424],
       [1.51164424],
       ...,
       [1.3673563 ],
       [1.3673563 ],
       [1.3673563 ]])

In [13]:
# Now compute the mean kl-div
split_scores = []
splits       = 10

scores = []
for i in range(splits):
    part = preds[(i * preds.shape[0] // splits):((i + 1) * preds.shape[0] // splits), :]
    kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
    kl = np.mean(np.sum(kl, 1))
    scores.append(np.exp(kl))
    
print('IS = ',np.mean(scores), '+-', np.std(scores))

IS =  1.002745778091639 +- 0.0008931803294836496


In [14]:
#1.0490048101931169 +- 0.018750268311493064

In [16]:

filenames = glob.glob('./saved_images/*.jpg')
filenames = sorted(filenames)
images = []
for filename in filenames:
    images.append(imageio.imread(filename))
imageio.mimsave('wgan.gif', images)


In [None]:
'''
from PIL import Image, ImageSequence

# Output (max) size
size = 320, 240

# Open source
im = Image.open("sgan.gif")

# Get sequence iterator
frames = ImageSequence.Iterator(im)

# Wrap on-the-fly thumbnail generator
def thumbnails(frames):
    for frame in frames:
        thumbnail = frame.copy()
        thumbnail.thumbnail(size, Image.ANTIALIAS)
        yield thumbnail

frames = thumbnails(frames)

# Save output
om = next(frames) # Handle first frame separately
om.info = im.info # Copy sequence info
om.save("out.gif", save_all=True, append_images=list(frames))
'''

In [None]:
'''
trainer.train(50)
save('saved/cur_state_2.state', netD, netG, trainer.optimizerD, trainer.optimizerG)
torch.save(netG.state_dict(), 'saved/cur_state_G_2.pth')
'''