In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torch.utils.data as data
from torch.autograd import Variable
import torchvision
import torchvision.transforms as transforms

import numpy as np
import os
from utils import *
from gan import *

os.environ["CUDA_VISIBLE_DEVICES"] = '6'
torch.set_num_threads(4)

In [2]:
# load data
batch_size = 64

transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_data = torchvision.datasets.LSUN('/home/dayun/lsun', classes=['bedroom_train'], transform=transform)
train_loader = data.DataLoader(train_data, batch_size=batch_size, num_workers = 4, shuffle=True)


In [4]:
## hyperparameters
generator_iters = 600000
n_critic = 5

lr = 0.00005
c = 0.01

device = 'cuda' if torch.cuda.is_available() else 'cpu'

generator = DCGAN_BN_G()
discriminator = DCGAN_D()

generator.apply(weights_init)
discriminator.apply(weights_init)

DCGAN_D(
  (conv): Sequential(
    (0): Conv2d(3, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(512, 1024, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(1024, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
  )
)

In [5]:
## train
optim_D = optim.RMSprop(discriminator.parameters(), lr=lr)
optim_G = optim.RMSprop(generator.parameters(), lr=lr)

loss_D = torch.cuda.FloatTensor if device == 'cuda' else torch.FloatTensor
loss_G = torch.cuda.FloatTensor if device == 'cuda' else torch.FloatTensor

data = get_infinite_batches(train_loader)
LOSS = {'D_real': [],
        'D_fake': [],
        'G': [],
        'W': []}

if not os.path.exists('result_BN/'):
    os.makedirs('result_BN/')
    
if device == 'cuda':
    generator.to(device)
    discriminator.to(device)
    cudnn.benchmark = True        
    
for g in range(generator_iters+1):
    # discriminator
    for p in discriminator.parameters():
        p.requires_grad = True
        
    for n in range(n_critic):    
        discriminator.zero_grad()
        
        real = data.__next__().to(device).requires_grad_(True)
        D_loss_real = torch.mean(discriminator(real))
        
        z = torch.randn(batch_size, 100, 1, 1, device = device, requires_grad = True)
        fake = generator(z)
        D_loss_fake = torch.mean(discriminator(fake))

        W_loss = -D_loss_real + D_loss_fake
        W_loss.backward()
        optim_D.step()
   
        # clip
        for p in discriminator.parameters():
            p.data.clamp_(-c, c)
        
    # generator
    for p in discriminator.parameters():
        p.requires_grad = False
        
    generator.zero_grad()
    
    z = torch.randn(batch_size, 100, 1, 1, device = device, requires_grad = True)
    fake = generator(z)
    loss_G = -torch.mean(discriminator(generator(z)))
    loss_G.backward()
    optim_G.step()
        
    LOSS['D_real'].append(D_loss_real.data.cpu().numpy())
    LOSS['D_fake'].append(D_loss_fake.data.cpu().numpy())
    LOSS['G'].append(loss_G.data.cpu().numpy())
    LOSS['W'].append(-W_loss.data.cpu().numpy())
            
    if g % 1000 == 0:
        # save model
        save_model(generator, discriminator, 'BN')
        save_dict(LOSS, 'loss_BN')
        
        # save generator image
        z = Variable(torch.randn(64, 100, 1, 1)).cuda()
        samples = generator(z)
        samples = samples.mul(0.5).add(0.5)  # denormalize
        samples = samples.data.cpu()[:64]
        grid = torchvision.utils.make_grid(samples)
        torchvision.utils.save_image(grid, 'result_BN/iter_{}.png'.format(str(g)))
       
        # print wasserstein distance
        print("epoch: %d/%d\tW loss:%.5f" %(g, generator_iters, -W_loss))


Models save to ./generator.pkl & ./discriminator.pkl
epoch: 0/600000	W loss:0.21243
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 1000/600000	W loss:2.94358
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 2000/600000	W loss:3.02785
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 3000/600000	W loss:2.97260
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 4000/600000	W loss:2.92566
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 5000/600000	W loss:3.04067
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 6000/600000	W loss:3.02348
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 7000/600000	W loss:2.96425
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 8000/600000	W loss:0.31707
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 9000/600000	W loss:2.81641
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 10000/600000	W loss:2.74621
Models save to ./generator.pkl & ./discrimina

epoch: 93000/600000	W loss:1.61108
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 94000/600000	W loss:1.36402
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 95000/600000	W loss:1.19428
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 96000/600000	W loss:1.11579
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 97000/600000	W loss:1.61056
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 98000/600000	W loss:1.29310
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 99000/600000	W loss:1.31860
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 100000/600000	W loss:1.57745
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 101000/600000	W loss:1.19768
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 102000/600000	W loss:1.39432
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 103000/600000	W loss:1.11849
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 104000/600000	W loss:

Models save to ./generator.pkl & ./discriminator.pkl
epoch: 186000/600000	W loss:1.23924
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 187000/600000	W loss:1.15941
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 188000/600000	W loss:1.00838
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 189000/600000	W loss:1.10944
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 190000/600000	W loss:1.25862
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 191000/600000	W loss:1.18381
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 192000/600000	W loss:1.08798
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 193000/600000	W loss:1.11447
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 194000/600000	W loss:1.20772
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 195000/600000	W loss:1.25265
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 196000/600000	W loss:1.14354
Models save to ./gene

epoch: 278000/600000	W loss:0.86096
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 279000/600000	W loss:1.09255
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 280000/600000	W loss:0.82486
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 281000/600000	W loss:0.91873
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 282000/600000	W loss:1.12606
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 283000/600000	W loss:1.02617
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 284000/600000	W loss:0.95678
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 285000/600000	W loss:0.90459
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 286000/600000	W loss:0.97754
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 287000/600000	W loss:0.98558
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 288000/600000	W loss:0.88608
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 289000/600000	

Models save to ./generator.pkl & ./discriminator.pkl
epoch: 371000/600000	W loss:0.86895
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 372000/600000	W loss:0.95338
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 373000/600000	W loss:0.80111
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 374000/600000	W loss:0.95634
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 375000/600000	W loss:0.99818
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 376000/600000	W loss:0.76549
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 377000/600000	W loss:0.96077
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 378000/600000	W loss:0.74917
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 379000/600000	W loss:0.81842
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 380000/600000	W loss:0.92854
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 381000/600000	W loss:0.78729
Models save to ./gene

epoch: 463000/600000	W loss:0.71468
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 464000/600000	W loss:0.83797
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 465000/600000	W loss:0.94867
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 466000/600000	W loss:0.81169
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 467000/600000	W loss:0.77975
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 468000/600000	W loss:0.84645
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 469000/600000	W loss:0.69722
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 470000/600000	W loss:0.84761
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 471000/600000	W loss:0.88707
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 472000/600000	W loss:0.88813
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 473000/600000	W loss:0.93486
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 474000/600000	

Models save to ./generator.pkl & ./discriminator.pkl
epoch: 556000/600000	W loss:0.78757
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 557000/600000	W loss:0.78223
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 558000/600000	W loss:0.86032
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 559000/600000	W loss:0.88543
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 560000/600000	W loss:0.86013
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 561000/600000	W loss:0.78886
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 562000/600000	W loss:0.80176
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 563000/600000	W loss:0.81276
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 564000/600000	W loss:0.69806
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 565000/600000	W loss:0.81197
Models save to ./generator.pkl & ./discriminator.pkl
epoch: 566000/600000	W loss:0.83353
Models save to ./gene