In [5]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image


### GAN 초간단 정리
- Discriminator
    - 784, 256
    - ReLU
    - 256, 256
    - ReLU
    - 256, 1
    - Sigmoid
- Generator
    - 64, 256
    - ReLU
    - 256, 256
    - 256, 784
    - Tanh

In [4]:
D = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 256),
    nn.ReLU(),
    nn.Linear(256, 1),
    nn.Sigmoid())

G = nn.Sequential(
    nn.Linear(64, 256),
    nn.ReLU(),
    nn.Linear(256, 256),
    nn.ReLU(),
    nn.Linear(256, 784),
    nn.Tanh())

In [7]:
transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=(0.5, 0.5, 0.5), 
                                     std=(0.5, 0.5, 0.5))])

In [8]:
mnist = datasets.MNIST(root='./data/',
                       train=True,
                       transform=transform,
                       download=True)

In [37]:
data_loader = torch.utils.data.DataLoader(dataset=mnist,
                                          batch_size=100,
                                          shuffle=True)

### torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate at 0x1093ef048>, pin_memory=False, drop_last=False)

### Docstring:     
Data loader. Combines a dataset and a sampler, and provides
single- or multi-process iterators over the dataset.

### Arguments:
- dataset (Dataset): dataset from which to load the data.  
- batch_size (int, optional): how many samples per batch to load(default: 1).  
- shuffle (bool, optional): set to ``True`` to have the data reshuffled at every epoch (default: False).  
- sampler (Sampler, optional): defines the strategy to draw samples from the dataset. If specified, ``shuffle`` must be False.  
- batch_sampler (Sampler, optional): like sampler, but returns a batch of indices at a time. Mutually exclusive with batch_size, shuffle, sampler, and drop_last.
- num_workers (int, optional): how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process(default: 0)  
- collate_fn (callable, optional): merges a list of samples to form a mini-batch.  
- pin_memory (bool, optional): If ``True``, the data loader will copy tensors into CUDA pinned memory before returning them.  
- drop_last (bool, optional): set to ``True`` to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller. (default: False)

In [10]:
if torch.cuda.is_available():
    D.cuda()
    G.cuda()

In [29]:
num_epochs = 200

In [None]:
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.001)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.001)

for epoch in range(100):
    for i, (images, _) in enumerate(data_loader):
        batch_size = images.size(0)
        images = Variable(images.view(batch_size, -1))
        
        real_labels = Variable(torch.ones(batch_size))
        fake_labels = Variable(torch.zeros(batch_size))
        # GPU를 사용할 땐 x = x.cuda()를 하고 Variable(x)를 return
        
        # Train discriminator
        outputs = D(images)
        d_loss_real = criterion(outputs, real_labels)
        real_scroe = outputs
        
        z = Variable(torch.randn(batch_size, 64))
        fake_images = G(z)
        outputs = D(fake_images)
        d_loss_fake = criterion(outputs, fake_labels)
        fake_scroe = outputs
        
        d_loss = d_loss_real + d_loss_fake
        D.zero_grad()
        d_loss.backward()
        d_optimizer.step()
        
        # Train Generator
        z = Variable(torch.randn(batch_size, 64))
        fake_images = G(z)
        outputs = D(fake_images)
        
        g_loss = criterion(outputs, real_labels)
        
        D.zero_grad()
        G.zero_grad()
        g_loss.backward()
        g_optimizer.step()
        
        if (epoch+1) % 5 == 0:
            print ('Epoch [{}/{}], Discriminator Loss: {:4f}, Generator Loss: {:4f}'.format(
                epoch+1, num_epochs, d_loss.data[0], g_loss.data[0]))
        
torch.save(G.state_dict(), './generator.pkl')
torch.save(D.state_dict(), './discriminator.pkl')

  "Please ensure they have the same size.".format(target.size(), input.size()))


Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 19.312897
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 19.399529
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 19.485228
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 19.569799
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 19.593847
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 19.619501
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 19.648684
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 19.681005
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 19.716059
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 19.753416
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 19.792627
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 19.833277
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 19.875092
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 19.917604
Epoch 

Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 20.422632
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 20.417824
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 20.417730
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 20.421968
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 20.429962
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 20.441498
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 20.455931
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 20.473099
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 20.492582
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 20.514088
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 20.535131
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 20.557833
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 20.581812
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 20.598509
Epoch 

Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 20.828096
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 20.854219
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 20.880623
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 20.907316
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 20.934080
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 20.960848
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 20.987661
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 21.014307
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 21.040796
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 21.067114
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 21.093134
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 21.110712
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 21.128756
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 21.147215
Epoch 

Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 21.979649
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 21.985619
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 21.991970
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 21.998672
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 22.005705
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 22.012932
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 22.020370
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 22.022058
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 22.023794
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 22.026340
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 22.029539
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 22.033358
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 22.037710
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 22.042587
Epoch 

Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 22.537249
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 22.539711
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 22.542568
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 22.545681
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 22.549044
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 22.552517
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 22.555578
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 22.555418
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 22.555262
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 22.555687
Epoch [5/200], Discriminator Loss: 0.000010, Generator Loss: 22.541100
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 22.528538
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 22.517736
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 22.508656
Epoch 

Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 22.693249
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 22.692879
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 22.693123
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 22.689949
Epoch [5/200], Discriminator Loss: 0.000000, Generator Loss: 22.687639
