In [1]:
import argparse
import numpy
import torch
from torch import nn, optim
from torch.autograd.variable import Variable
import matplotlib.pyplot as plt
from torchvision import transforms, datasets

import os
# Expose only GPU index 1 to CUDA in this process.
# NOTE(review): this only takes effect if it runs before torch initialises CUDA.
os.environ.update(CUDA_VISIBLE_DEVICES="1")

In [2]:
from easydict import EasyDict as edict

# All tunable settings live on this one attribute-style dict.
opt = edict()

# -- Training schedule / optimiser --
opt.n_iters = 100          # number of D+G update steps in the training loop
opt.batch_size = 48        # packed pairs per step; each step draws 2 * batch_size images
opt.lr = 0.0001            # Adam learning rate (shared by G and D)
opt.b1 = 0.5               # Adam beta1
opt.b2 = 0.999             # Adam beta2
opt.seed = 0               # seeds numpy's and torch's global RNGs below

# -- Model / data shape --
opt.latent_dim = 128
opt.img_size = 256
opt.channels = 3

# -- Paths --
# Override the project root with the DIMENTIA_ROOT environment variable
# instead of editing hard-coded absolute paths; the defaults are unchanged.
_project_root = os.environ.get('DIMENTIA_ROOT', '/home/trojan/Desktop/dimentia')
opt.input_folder = os.path.join(_project_root, 'DCGAN', 'dataset')
opt.save_dir = os.path.join(_project_root, 'ralsgan', 'results')   # sample images
opt.dict_dir = os.path.join(_project_root, 'ralsgan', 'models')    # G/D checkpoints

# -- Logging --
opt.display_port = 8097    # NOTE(review): not used by any cell shown here
opt.display_server = 'http://localhost'
opt.sample_interval = 5    # record losses / save samples and checkpoints every N iterations

# Seed both RNGs used by this notebook (batch sampling uses numpy; weight init
# and noise use torch) so runs are repeatable.
numpy.random.seed(opt.seed)
torch.manual_seed(opt.seed)

In [3]:
def weights_init_normal(m):
    """DCGAN-style weight initialisation, intended for ``Module.apply``.

    - Modules whose class name contains ``Conv`` (e.g. ``Conv2d`` and
      ``ConvTranspose2d``): weights drawn from N(0, 0.02); biases untouched.
    - Modules whose class name contains ``BatchNorm``: weights drawn from
      N(1, 0.02) and biases set to zero.
    - Any other module is left as is.
    """
    kind = type(m).__name__
    if 'Conv' in kind:
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
        return
    if 'BatchNorm' in kind:
        nn.init.normal_(m.weight, mean=1.0, std=0.02)
        nn.init.zeros_(m.bias)

In [4]:
# (C, H, W) of a single image, and the number of scalars in one image.
# NOTE(review): neither name is referenced by the other cells shown here.
img_dims = (opt.channels,) + (opt.img_size,) * 2
n_features = opt.channels * opt.img_size ** 2

In [5]:
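# Appendix D.4 of the relativistic GAN paper: DCGAN generator (as built below it maps a 1x1 latent to a 256x256 image).
# NOTE(review): the original note read "DCGAN for 0." -- the intended dataset/resolution name appears garbled.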
class Generator(nn.Module):
    """DCGAN generator: latent vector -> image with values in [-1, 1].

    ConvTranspose2d -> BatchNorm2d -> ReLU blocks first project the 1x1
    latent to a 4x4 map (a fully connected layer expressed as a convolution),
    then double the spatial size five times. A final ConvTranspose2d doubles
    it once more to 256x256 with ``opt.channels`` outputs, followed by Tanh.

    Note: v2 of the relativistic GAN paper has a slight error in this
    architecture -- it goes 128 > 64 > 32 but then 64 > 3.
    """

    def __init__(self):
        super().__init__()

        def up_block(c_in, c_out, k_size=4, stride=2, padding=0):
            return [
                nn.ConvTranspose2d(c_in, c_out, kernel_size=k_size, stride=stride, padding=padding, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        widths = [1024, 512, 256, 128, 64, 32]
        # 1x1 latent -> 4x4 feature map.
        layers = up_block(opt.latent_dim, widths[0], 4, 1, 0)
        # Each subsequent block halves the channels and doubles H and W.
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers += up_block(c_in, c_out, 4, 2, 1)
        layers += [nn.ConvTranspose2d(widths[-1], opt.channels, 4, 2, 1), nn.Tanh()]
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        # Accept (N, latent_dim) or (N, latent_dim, 1, 1) noise.
        latent = z.view(-1, opt.latent_dim, 1, 1)
        return self.model(latent)

In [6]:
# PacGAN2 discriminator: scores pairs of images packed along the channel axis.
class Discriminator(nn.Module):
    """PacGAN2 critic.

    Input: two images stacked along the channel axis, i.e. ``2 * opt.channels``
    channels. Six stride-2 Conv2d -> (BatchNorm2d) -> LeakyReLU(0.2) blocks
    shrink a 256x256 input to 4x4. A final 4x4 convolution (a fully connected
    layer expressed as a convolution) produces one unbounded score per packed
    sample, flattened to shape (N, 1) for 256x256 inputs.
    """

    def __init__(self):
        super().__init__()

        def down_block(c_in, c_out, k_size=4, stride=2, padding=0, bn=False):
            layers = [nn.Conv2d(c_in, c_out, kernel_size=k_size, stride=stride, padding=padding, bias=False)]
            if bn:
                layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        # (in_channels, out_channels, use_batchnorm); the first two blocks skip BN.
        spec = [
            (opt.channels * 2, 32, False),
            (32, 64, False),
            (64, 128, True),
            (128, 256, True),
            (256, 512, True),
            (512, 1024, True),
        ]
        layers = []
        for c_in, c_out, use_bn in spec:
            layers.extend(down_block(c_in, c_out, 4, 2, 1, bn=use_bn))
        layers.append(nn.Conv2d(1024, 1, 4, 1, 0, bias=False))
        self.model = nn.Sequential(*layers)

    def forward(self, imgs):
        scores = self.model(imgs)
        return scores.view(imgs.size(0), -1)

In [7]:
# Resize every image to img_size x img_size, then ToTensor + Normalize(0.5, 0.5)
# rescale each channel to [-1, 1], the same range as the generator's Tanh output.
transform = transforms.Compose(
    [
        transforms.Resize(size=(opt.img_size, opt.img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

# Items are indexed as (image, label); only the image (index 0) is used downstream.
data = datasets.ImageFolder(root=opt.input_folder, transform=transform)


In [8]:
def generate_random_sample():
    """Yield an endless stream of random image batches drawn from ``data``.

    Each batch uses ``2 * opt.batch_size`` distinct dataset indices, drawn
    without replacement with numpy's global RNG. The images are stacked along
    a new leading dimension, and labels are discarded.
    NOTE: numpy raises ValueError if the dataset holds fewer than
    ``2 * opt.batch_size`` items.
    """
    while True:
        picks = numpy.random.choice(len(data), size=opt.batch_size * 2, replace=False)
        images = [data[idx][0] for idx in picks]
        yield torch.stack(images, 0)


random_sample = generate_random_sample()

def mse_loss(input, target):
    """Mean squared error between ``input`` and ``target``.

    ``target`` may broadcast against ``input``. The summed squared error is
    divided by the element count of ``input`` (not of the broadcast result),
    so for same-shaped tensors this is the ordinary mean.
    Returns a 0-dim tensor that stays attached to the autograd graph.
    """
    squared_error = (input - target).pow(2)
    return squared_error.sum() / input.numel()

In [9]:
cuda = torch.cuda.is_available()
# Legacy tensor constructor used to allocate tensors on the active device.
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
gan_loss = mse_loss

generator = Generator()
discriminator = Discriminator()

# Move the networks to the GPU *before* constructing the optimizers. The
# torch.optim documentation asks for this so that each optimizer holds the
# parameter objects that are actually trained.
if cuda:
    generator = generator.cuda()
    discriminator = discriminator.cuda()

# DCGAN initialisation (in place, on whichever device the weights now live).
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

optimizer_D = optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_G = optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

# Loss record (appended every opt.sample_interval iterations by the training loop).
g_losses = []
d_losses = []
epochs = []
loss_legend = ['Discriminator', 'Generator']

# Fixed latent batch of 25 samples, reused to visualise progress across iterations.
noise_fixed = Tensor(25, opt.latent_dim).normal_(0, 1)
# NOTE(review): not used by any cell shown here; it is created on the CPU with
# shape (1, latent_dim, 1, 1), unlike noise_fixed.
fixed_noise = torch.randn(1, opt.latent_dim, 1, 1)

In [10]:
import torchvision.utils as vutils


def pack_pairs(x):
    """Pack a double batch into PacGAN2 pairs.

    Concatenates samples ``[0, opt.batch_size)`` with samples
    ``[opt.batch_size, 2 * opt.batch_size)`` along dim 1, so sample ``i`` is
    paired with sample ``i + opt.batch_size``. Expects ``x`` shaped
    (2 * opt.batch_size, C, H, W) and returns (opt.batch_size, 2 * C, H, W).
    """
    return torch.cat((x[:opt.batch_size], x[opt.batch_size:opt.batch_size * 2]), 1)


# Make sure the output folders exist before the first sample/checkpoint save.
os.makedirs(opt.save_dir, exist_ok=True)
os.makedirs(opt.dict_dir, exist_ok=True)

# Constant margin of 1 used by the RaLSGAN losses, one entry per packed sample.
real = Tensor(opt.batch_size, 1).fill_(1.0)

for it in range(int(opt.n_iters)):
    print('Iter. {}'.format(it))

    # == Discriminator update == #
    imgs_real = pack_pairs(next(random_sample).type(Tensor))

    noise = Tensor(opt.batch_size * 2, opt.latent_dim).normal_(0, 1)
    # D's loss never back-propagates into G, so G's autograd graph is not needed here.
    with torch.no_grad():
        imgs_fake = pack_pairs(generator(noise))

    optimizer_D.zero_grad()
    c_xr = discriminator(imgs_real)
    c_xf = discriminator(imgs_fake)
    # Relativistic average LSGAN: push real scores to (mean fake score + 1)
    # and fake scores to (mean real score - 1).
    d_loss = gan_loss(c_xr, torch.mean(c_xf) + real) + gan_loss(c_xf, torch.mean(c_xr) - real)
    d_loss.backward()
    optimizer_D.step()

    # == Generator update == #
    imgs_real = pack_pairs(next(random_sample).type(Tensor))
    noise = Tensor(opt.batch_size * 2, opt.latent_dim).normal_(0, 1)
    imgs_fake = pack_pairs(generator(noise))

    # Real scores do not depend on G, so G's gradient does not need their graph.
    with torch.no_grad():
        c_xr = discriminator(imgs_real)
    c_xf = discriminator(imgs_fake)

    optimizer_G.zero_grad()
    # Same form as d_loss with the roles of real and fake swapped.
    g_loss = gan_loss(c_xf, torch.mean(c_xr) + real) + gan_loss(c_xr, torch.mean(c_xf) - real)
    g_loss.backward()
    optimizer_G.step()

    if it % opt.sample_interval == 0:
        # Keep a record of losses for plotting.
        epochs.append(it)
        g_losses.append(g_loss.item())
        d_losses.append(d_loss.item())

        # Generate images for the fixed noise so we can track how the GAN learns.
        # NOTE(review): the generator stays in train mode here, so BatchNorm uses
        # this batch's statistics (and updates its running stats from it).
        with torch.no_grad():
            imgs_fake_fixed = generator(noise_fixed)

        vutils.save_image(imgs_fake_fixed,
                          os.path.join(opt.save_dir, 'test{}.png'.format(it)),
                          normalize=True)

        torch.save(generator.state_dict(), os.path.join(opt.dict_dir, 'gen_it_%d.pth' % it))
        torch.save(discriminator.state_dict(), os.path.join(opt.dict_dir, 'disc_it_%d.pth' % it))

Iter. 0
Iter. 1
Iter. 2
Iter. 3
Iter. 4
Iter. 5
Iter. 6
Iter. 7
Iter. 8
Iter. 9
Iter. 10
Iter. 11
Iter. 12
Iter. 13
Iter. 14


KeyboardInterrupt: 