In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.autograd import Variable
from utils import create_dataloader

### Creating config object (argparse workaround)

In [2]:
class Config:
    pass

config = Config()
config.mnist_path = None
config.batch_size = 16
config.num_workers = 3
config.num_epochs = 10
config.noise_size = 50
config.print_freq = 100


### Create dataloder

In [3]:
dataloader = create_dataloader(config)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Processing...
Done!


In [4]:
len(dataloader)

3750

In [5]:
for image, cat in dataloader:
    break

In [6]:
image.size()

torch.Size([16, 1, 28, 28])

In [7]:
28*28

784

### Create generator and discriminator

In [8]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential( 
            nn.Linear(config.noise_size, 200),
            nn.ReLU(inplace=True),
            nn.Linear(200, 28*28),
            nn.Sigmoid())
        
    def forward(self, x):
        return self.model(x)
    
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(28*28, 200),
            nn.ReLU(inplace=True),
            nn.Linear(200, 50),
            nn.ReLU(inplace=True),
            nn.Linear(50, 1), 
            nn.Sigmoid())
    def forward(self, x):
        return self.model(x)

In [9]:
generator = Generator()
discriminator = Discriminator()

### Create optimizers and loss

In [10]:
optim_G = optim.Adam(params=generator.parameters(), lr=0.0001)
optim_D = optim.Adam(params=discriminator.parameters(), lr=0.0001)

criterion = nn.BCELoss()

### Create necessary variables

In [11]:
input = Variable(torch.FloatTensor(config.batch_size, 28*28))
noise = Variable(torch.FloatTensor(config.batch_size, config.noise_size))
fixed_noise = Variable(torch.FloatTensor(config.batch_size, config.noise_size).normal_(0, 1))
label = Variable(torch.FloatTensor(config.batch_size))
real_label = 1
fake_label = 0

### Задание

1) Посмотрите на реализацию GAN

2) Поменяйте ее, чтобы получился LSGAN https://arxiv.org/pdf/1611.04076v2.pdf

3) Попробуйте оба GAN и LSGAN на CelebA http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html

4) Напишите отчет что попробовали, какие результаты получили, как вам кажется надо обучать GAN, чтобы добиться сходимости?

Обязательны графики.

### GAN

In [12]:
for epoch in range(config.num_epochs):
    for iteration, (images, cat) in enumerate(dataloader):
        ####### 
        # Discriminator stage: maximize log(D(x)) + log(1 - D(G(z))) 
        #######
        discriminator.zero_grad()
        
        # real
        label.data.fill_(real_label)
        input.data.copy_(images)
        output = discriminator(input)
        errD_x = criterion(output, label)
        errD_x.backward()
        
        # fake 
        noise.data.normal_(0, 1)
        fake = generator(noise)
        label.data.fill_(fake_label)
        output = discriminator(fake.detach())
        errD_z = criterion(output, label)
        errD_z.backward()
        
        optim_D.step()
        
        ####### 
        # Generator stage: maximize log(D(G(x))
        #######
        generator.zero_grad()
        label.data.fill_(real_label)
        output = discriminator(fake)
        errG = criterion(output, label)
        errG.backward()
        
        optim_G.step()
        
        if (iteration+1) % config.print_freq == 0:
            print('Epoch:{} Iter: {} errD_x: {:.2f} errD_z: {:.2f} errG: {:.2f}'.format(epoch+1,
                                                                                            iteration+1, 
                                                                                            errD_x.data[0],
                                                                                            errD_z.data[0], 
                                                                                            errG.data[0]))

  # Remove the CWD from sys.path while we load stuff.
  "Please ensure they have the same size.".format(target.size(), input.size()))


Epoch:1 Iter: 100 errD_x: 0.15 errD_z: 0.05 errG: 3.15
Epoch:1 Iter: 200 errD_x: 0.06 errD_z: 0.01 errG: 4.91
Epoch:1 Iter: 300 errD_x: 0.09 errD_z: 0.03 errG: 5.31
Epoch:1 Iter: 400 errD_x: 0.34 errD_z: 0.15 errG: 3.85
Epoch:1 Iter: 500 errD_x: 0.19 errD_z: 0.28 errG: 1.90
Epoch:1 Iter: 600 errD_x: 0.19 errD_z: 0.06 errG: 3.31
Epoch:1 Iter: 700 errD_x: 0.10 errD_z: 0.04 errG: 3.61
Epoch:1 Iter: 800 errD_x: 0.04 errD_z: 0.03 errG: 3.69
Epoch:1 Iter: 900 errD_x: 0.02 errD_z: 0.04 errG: 3.54
Epoch:1 Iter: 1000 errD_x: 0.03 errD_z: 0.05 errG: 3.50
Epoch:1 Iter: 1100 errD_x: 0.03 errD_z: 0.04 errG: 3.69
Epoch:1 Iter: 1200 errD_x: 0.04 errD_z: 0.05 errG: 3.29
Epoch:1 Iter: 1300 errD_x: 0.03 errD_z: 0.08 errG: 2.79
Epoch:1 Iter: 1400 errD_x: 0.04 errD_z: 0.12 errG: 2.34
Epoch:1 Iter: 1500 errD_x: 0.11 errD_z: 0.20 errG: 1.98
Epoch:1 Iter: 1600 errD_x: 0.22 errD_z: 0.09 errG: 2.69
Epoch:1 Iter: 1700 errD_x: 0.08 errD_z: 0.20 errG: 1.95
Epoch:1 Iter: 1800 errD_x: 0.01 errD_z: 0.12 errG: 2.51
E

Epoch:4 Iter: 3700 errD_x: 0.01 errD_z: 0.03 errG: 5.38
Epoch:5 Iter: 100 errD_x: 0.06 errD_z: 0.04 errG: 4.78
Epoch:5 Iter: 200 errD_x: 0.06 errD_z: 0.03 errG: 4.39
Epoch:5 Iter: 300 errD_x: 0.04 errD_z: 0.01 errG: 4.85
Epoch:5 Iter: 400 errD_x: 0.14 errD_z: 0.02 errG: 4.76
Epoch:5 Iter: 500 errD_x: 0.20 errD_z: 0.00 errG: 6.28
Epoch:5 Iter: 600 errD_x: 0.14 errD_z: 0.02 errG: 4.77
Epoch:5 Iter: 700 errD_x: 0.00 errD_z: 0.08 errG: 4.68
Epoch:5 Iter: 800 errD_x: 0.03 errD_z: 0.05 errG: 4.92
Epoch:5 Iter: 900 errD_x: 0.03 errD_z: 0.07 errG: 4.30
Epoch:5 Iter: 1000 errD_x: 0.11 errD_z: 0.06 errG: 4.58
Epoch:5 Iter: 1100 errD_x: 0.42 errD_z: 0.01 errG: 6.57
Epoch:5 Iter: 1200 errD_x: 0.09 errD_z: 0.01 errG: 5.54
Epoch:5 Iter: 1300 errD_x: 0.01 errD_z: 0.14 errG: 3.78
Epoch:5 Iter: 1400 errD_x: 0.06 errD_z: 0.03 errG: 5.42
Epoch:5 Iter: 1500 errD_x: 0.09 errD_z: 0.14 errG: 4.71
Epoch:5 Iter: 1600 errD_x: 0.13 errD_z: 0.01 errG: 6.18
Epoch:5 Iter: 1700 errD_x: 0.00 errD_z: 0.04 errG: 4.60
E

Epoch:8 Iter: 3600 errD_x: 0.05 errD_z: 0.13 errG: 3.91
Epoch:8 Iter: 3700 errD_x: 0.00 errD_z: 0.07 errG: 6.03
Epoch:9 Iter: 100 errD_x: 0.11 errD_z: 0.07 errG: 4.54
Epoch:9 Iter: 200 errD_x: 0.02 errD_z: 0.02 errG: 4.58
Epoch:9 Iter: 300 errD_x: 0.15 errD_z: 0.09 errG: 3.95
Epoch:9 Iter: 400 errD_x: 0.06 errD_z: 0.04 errG: 4.87
Epoch:9 Iter: 500 errD_x: 0.09 errD_z: 0.02 errG: 5.83
Epoch:9 Iter: 600 errD_x: 0.01 errD_z: 0.04 errG: 4.22
Epoch:9 Iter: 700 errD_x: 0.14 errD_z: 0.26 errG: 4.32
Epoch:9 Iter: 800 errD_x: 0.01 errD_z: 0.04 errG: 4.23
Epoch:9 Iter: 900 errD_x: 0.14 errD_z: 0.06 errG: 4.44
Epoch:9 Iter: 1000 errD_x: 0.30 errD_z: 0.04 errG: 4.57
Epoch:9 Iter: 1100 errD_x: 0.05 errD_z: 0.08 errG: 4.14
Epoch:9 Iter: 1200 errD_x: 0.01 errD_z: 0.07 errG: 4.36
Epoch:9 Iter: 1300 errD_x: 0.00 errD_z: 0.20 errG: 3.24
Epoch:9 Iter: 1400 errD_x: 0.01 errD_z: 0.30 errG: 3.50
Epoch:9 Iter: 1500 errD_x: 0.00 errD_z: 0.24 errG: 3.87
Epoch:9 Iter: 1600 errD_x: 0.15 errD_z: 0.25 errG: 4.03
E