In [2]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
from torchvision import transforms, datasets
from torchvision.utils import save_image
import matplotlib.pyplot as plt

In [2]:
# Hyperparameters and runtime configuration.
SEED = 42          # seed for PyTorch's random number generator
EPOCHS = 500       # number of passes over the training set
BATCH_SIZE = 100   # samples per mini-batch
USE_CUDA = torch.cuda.is_available()
DEVICE = torch.device("cuda" if USE_CUDA else "cpu")

# Seed PyTorch before any model, data loader or noise tensor is created so
# that the random draws made by a fresh "Restart & Run All" are repeatable.
torch.manual_seed(SEED)

# Report the selected device (the message reads "Using the following device:").
print("다음 장치를 사용합니다: ", DEVICE)

다음 장치를 사용합니다:  cuda


In [3]:
# Fashion-MNIST training split, stored under ./data (download=True lets
# torchvision fetch it). Each image is converted to a tensor and then
# normalized with mean 0.5 and std 0.5.
trainset = datasets.FashionMNIST(
    root='./data',
    train=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]),
    download=True,
)

# Shuffled mini-batch iterator over the training split.
train_loader = torch.utils.data.DataLoader(
    trainset,
    shuffle=True,
    batch_size=BATCH_SIZE,
)

0it [00:00, ?it/s]

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz


100%|█████████▉| 26411008/26421880 [01:00<00:00, 540019.42it/s]

Extracting ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw



0it [00:00, ?it/s][A

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz



  0%|          | 0/29515 [00:00<?, ?it/s][A
 56%|█████▌    | 16384/29515 [00:00<00:00, 53524.97it/s][A
32768it [00:01, 25921.50it/s]                           [A

0it [00:00, ?it/s][A

Extracting ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz



  0%|          | 0/4422102 [00:00<?, ?it/s][A
  0%|          | 16384/4422102 [00:00<01:13, 59672.01it/s][A
  1%|          | 32768/4422102 [00:01<01:15, 57797.60it/s][A
  1%|          | 40960/4422102 [00:01<01:35, 46028.83it/s][A
  1%|          | 49152/4422102 [00:01<02:00, 36182.72it/s][A
  1%|▏         | 65536/4422102 [00:02<01:48, 40045.96it/s][A
  2%|▏         | 73728/4422102 [00:02<02:04, 34788.41it/s][A
  2%|▏         | 90112/4422102 [00:02<01:51, 38845.80it/s][A
  2%|▏         | 106496/4422102 [00:03<01:42, 42298.66it/s][A
  3%|▎         | 122880/4422102 [00:03<01:35, 45098.80it/s][A
  3%|▎         | 139264/4422102 [00:03<01:30, 47289.68it/s][A
  4%|▎         | 163840/4422102 [00:03<01:19, 53900.84it/s][A
  4%|▍         | 196608/4422102 [00:04<01:06, 63286.52it/s][A
  5%|▌         | 237568/4422102 [00:04<00:55, 75129.55it/s][A
  6%|▋         | 278528/4422102 [00:04<00:47, 86458.85it/s][A
  8%|▊         | 335872/4422102 [00:05<00:39, 102987.98it/s][A
  9%|▉       

Extracting ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz




8192it [00:00, 9968.30it/s]             [A[A

Extracting ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw
Processing...
Done!



26427392it [01:20, 540019.42it/s]                              
4423680it [00:27, 513321.84it/s]                             [A

In [4]:
# Generator: fully connected network taking a 64-dimensional input to a
# 784-dimensional output; the final Tanh keeps every output in [-1, 1].
G = nn.Sequential(*[
    nn.Linear(in_features=64, out_features=256),
    nn.ReLU(),
    nn.Linear(in_features=256, out_features=256),
    nn.ReLU(),
    nn.Linear(in_features=256, out_features=784),
    nn.Tanh(),
])

In [5]:
# Discriminator: fully connected network taking a 784-dimensional input to a
# single value; the final Sigmoid keeps that value in [0, 1].
D = nn.Sequential(*[
    nn.Linear(in_features=784, out_features=256),
    nn.LeakyReLU(negative_slope=0.2),
    nn.Linear(in_features=256, out_features=256),
    nn.LeakyReLU(negative_slope=0.2),
    nn.Linear(in_features=256, out_features=1),
    nn.Sigmoid(),
])

In [6]:
# Send the weights of both models to the selected device.
G, D = G.to(DEVICE), D.to(DEVICE)

# Binary cross-entropy loss, plus one Adam optimizer per network.
criterion = nn.BCELoss()
g_optimizer = optim.Adam(params=G.parameters(), lr=2e-4)
d_optimizer = optim.Adam(params=D.parameters(), lr=2e-4)

In [19]:
total_step = len(train_loader)  # number of mini-batches per epoch
TRAINING = False  # True: train and save snapshots; False: restore saved snapshots
if TRAINING:
    for epoch in range(1, EPOCHS+1):
        for images, _ in train_loader:
            # Size everything from the batch itself so that a final batch
            # smaller than BATCH_SIZE cannot break the reshape or the labels.
            batch_size = images.size(0)
            images = images.reshape(batch_size, -1).to(DEVICE)

            real_labels = torch.ones(batch_size, 1, device=DEVICE)
            fake_labels = torch.zeros(batch_size, 1, device=DEVICE)

            # Loss for the discriminator recognizing real images as real.
            outputs = D(images)
            d_loss_real = criterion(outputs, real_labels)
            real_score = outputs

            # Generate fake images from a random noise tensor.
            z = torch.randn(batch_size, 64, device=DEVICE)
            fake_images = G(z)

            # Loss for the discriminator recognizing fake images as fake.
            # detach() keeps this backward pass out of the generator, whose
            # gradients from this loss would be thrown away anyway.
            outputs = D(fake_images.detach())
            d_loss_fake = criterion(outputs, fake_labels)
            fake_score = outputs

            # Discriminator loss = loss on real images + loss on fake images.
            d_loss = d_loss_real + d_loss_fake

            # Backpropagate and update the discriminator only.
            d_optimizer.zero_grad()
            d_loss.backward()
            d_optimizer.step()

            # Loss measuring whether the generator fooled the (just updated)
            # discriminator: the same fakes are scored against real labels.
            # G has not changed since fake_images was computed, so reusing it
            # is equivalent to calling G(z) again.
            outputs = D(fake_images)
            g_loss = criterion(outputs, real_labels)

            # Backpropagate and update the generator only. The gradients this
            # leaves on D are cleared by d_optimizer.zero_grad() on the next
            # iteration before D is updated again.
            g_optimizer.zero_grad()
            g_loss.backward()
            g_optimizer.step()

        # Per-epoch report, using the values from the last mini-batch.
        print("EPOCH: [{}/{}] d_loss: {:.4f} g_loss: {:.4f} D(x):{:.2f} D(G(z)):{:.2f}".format(
            epoch, EPOCHS, d_loss.item(), g_loss.item(), real_score.mean().item(), fake_score.mean().item()
        ))

    # torch.save does not create missing directories.
    os.makedirs('./snapshot', exist_ok=True)
    torch.save(D.state_dict(), './snapshot/gan_discriminator.pt')
    torch.save(G.state_dict(), './snapshot/gan_generator.pt')
else:
    # load_state_dict expects the state dict itself, not a file path, so the
    # snapshot is first read with torch.load; map_location places the tensors
    # on the current device even if they were saved from a different one.
    D.load_state_dict(torch.load('./snapshot/gan_discriminator.pt', map_location=DEVICE))
    G.load_state_dict(torch.load('./snapshot/gan_generator.pt', map_location=DEVICE))

EPOCH: [1/500] d_loss: 0.3036 g_loss: 4.8440 D(x):0.93 D(G(z)):0.10
EPOCH: [2/500] d_loss: 0.2155 g_loss: 4.7782 D(x):0.90 D(G(z)):0.01
EPOCH: [3/500] d_loss: 0.1877 g_loss: 5.3818 D(x):0.94 D(G(z)):0.05
EPOCH: [4/500] d_loss: 0.1552 g_loss: 4.2822 D(x):0.96 D(G(z)):0.06
EPOCH: [5/500] d_loss: 0.2795 g_loss: 4.4524 D(x):0.93 D(G(z)):0.03
EPOCH: [6/500] d_loss: 0.2250 g_loss: 4.6137 D(x):0.92 D(G(z)):0.04
EPOCH: [7/500] d_loss: 0.2375 g_loss: 4.0393 D(x):0.91 D(G(z)):0.06
EPOCH: [8/500] d_loss: 0.4384 g_loss: 3.8509 D(x):0.88 D(G(z)):0.06
EPOCH: [9/500] d_loss: 0.5952 g_loss: 3.1912 D(x):0.92 D(G(z)):0.26
EPOCH: [10/500] d_loss: 0.3943 g_loss: 2.9390 D(x):0.92 D(G(z)):0.17
EPOCH: [11/500] d_loss: 0.3558 g_loss: 3.9738 D(x):0.86 D(G(z)):0.07
EPOCH: [12/500] d_loss: 0.3618 g_loss: 3.8864 D(x):0.88 D(G(z)):0.12
EPOCH: [13/500] d_loss: 0.2906 g_loss: 3.8178 D(x):0.92 D(G(z)):0.09
EPOCH: [14/500] d_loss: 0.3039 g_loss: 3.8422 D(x):0.91 D(G(z)):0.12
EPOCH: [15/500] d_loss: 0.4805 g_loss: 4.09

EPOCH: [120/500] d_loss: 1.1545 g_loss: 1.5852 D(x):0.59 D(G(z)):0.30
EPOCH: [121/500] d_loss: 1.1917 g_loss: 1.4632 D(x):0.65 D(G(z)):0.36
EPOCH: [122/500] d_loss: 1.0221 g_loss: 1.3521 D(x):0.75 D(G(z)):0.40
EPOCH: [123/500] d_loss: 0.8087 g_loss: 1.9893 D(x):0.68 D(G(z)):0.22
EPOCH: [124/500] d_loss: 1.2128 g_loss: 1.3383 D(x):0.61 D(G(z)):0.37
EPOCH: [125/500] d_loss: 0.9633 g_loss: 1.7676 D(x):0.67 D(G(z)):0.26
EPOCH: [126/500] d_loss: 1.0679 g_loss: 1.5175 D(x):0.64 D(G(z)):0.31
EPOCH: [127/500] d_loss: 0.8760 g_loss: 1.7512 D(x):0.70 D(G(z)):0.28
EPOCH: [128/500] d_loss: 1.0434 g_loss: 1.4978 D(x):0.68 D(G(z)):0.34
EPOCH: [129/500] d_loss: 0.9214 g_loss: 1.8175 D(x):0.70 D(G(z)):0.31
EPOCH: [130/500] d_loss: 0.9813 g_loss: 1.3471 D(x):0.73 D(G(z)):0.38
EPOCH: [131/500] d_loss: 0.9538 g_loss: 1.3720 D(x):0.71 D(G(z)):0.33
EPOCH: [132/500] d_loss: 1.0621 g_loss: 1.6169 D(x):0.67 D(G(z)):0.33
EPOCH: [133/500] d_loss: 1.0639 g_loss: 1.3358 D(x):0.66 D(G(z)):0.36
EPOCH: [134/500] d_l

EPOCH: [238/500] d_loss: 0.9502 g_loss: 1.4558 D(x):0.68 D(G(z)):0.32
EPOCH: [239/500] d_loss: 1.2567 g_loss: 1.1545 D(x):0.60 D(G(z)):0.39
EPOCH: [240/500] d_loss: 0.8975 g_loss: 1.4928 D(x):0.70 D(G(z)):0.32
EPOCH: [241/500] d_loss: 0.9118 g_loss: 1.6290 D(x):0.68 D(G(z)):0.30
EPOCH: [242/500] d_loss: 1.0496 g_loss: 1.1795 D(x):0.70 D(G(z)):0.37
EPOCH: [243/500] d_loss: 1.0615 g_loss: 1.4122 D(x):0.61 D(G(z)):0.30
EPOCH: [244/500] d_loss: 0.9477 g_loss: 1.3435 D(x):0.69 D(G(z)):0.34
EPOCH: [245/500] d_loss: 1.4353 g_loss: 1.0900 D(x):0.60 D(G(z)):0.45
EPOCH: [246/500] d_loss: 1.1543 g_loss: 1.3723 D(x):0.65 D(G(z)):0.38
EPOCH: [247/500] d_loss: 1.2019 g_loss: 1.1387 D(x):0.66 D(G(z)):0.42
EPOCH: [248/500] d_loss: 1.2393 g_loss: 1.0838 D(x):0.61 D(G(z)):0.42
EPOCH: [249/500] d_loss: 0.8598 g_loss: 1.8982 D(x):0.71 D(G(z)):0.27
EPOCH: [250/500] d_loss: 1.2297 g_loss: 1.1152 D(x):0.74 D(G(z)):0.47
EPOCH: [251/500] d_loss: 1.0805 g_loss: 1.5456 D(x):0.60 D(G(z)):0.29
EPOCH: [252/500] d_l

EPOCH: [356/500] d_loss: 1.0794 g_loss: 1.3709 D(x):0.64 D(G(z)):0.35
EPOCH: [357/500] d_loss: 0.7764 g_loss: 1.6612 D(x):0.69 D(G(z)):0.24
EPOCH: [358/500] d_loss: 0.9996 g_loss: 1.1781 D(x):0.68 D(G(z)):0.36
EPOCH: [359/500] d_loss: 0.9616 g_loss: 1.4682 D(x):0.68 D(G(z)):0.32
EPOCH: [360/500] d_loss: 1.0073 g_loss: 1.3124 D(x):0.68 D(G(z)):0.36
EPOCH: [361/500] d_loss: 1.2438 g_loss: 1.2847 D(x):0.57 D(G(z)):0.37
EPOCH: [362/500] d_loss: 0.9692 g_loss: 1.3894 D(x):0.69 D(G(z)):0.34
EPOCH: [363/500] d_loss: 0.8828 g_loss: 1.6398 D(x):0.68 D(G(z)):0.28
EPOCH: [364/500] d_loss: 0.7408 g_loss: 1.4338 D(x):0.80 D(G(z)):0.34
EPOCH: [365/500] d_loss: 1.2303 g_loss: 1.0909 D(x):0.64 D(G(z)):0.42
EPOCH: [366/500] d_loss: 1.2441 g_loss: 1.0262 D(x):0.62 D(G(z)):0.41
EPOCH: [367/500] d_loss: 1.0531 g_loss: 1.3674 D(x):0.65 D(G(z)):0.35
EPOCH: [368/500] d_loss: 0.9677 g_loss: 1.4994 D(x):0.67 D(G(z)):0.31
EPOCH: [369/500] d_loss: 1.0775 g_loss: 1.5146 D(x):0.60 D(G(z)):0.31
EPOCH: [370/500] d_l

EPOCH: [474/500] d_loss: 0.8337 g_loss: 1.5903 D(x):0.71 D(G(z)):0.27
EPOCH: [475/500] d_loss: 1.0369 g_loss: 1.3302 D(x):0.66 D(G(z)):0.37
EPOCH: [476/500] d_loss: 1.1180 g_loss: 1.3299 D(x):0.63 D(G(z)):0.34
EPOCH: [477/500] d_loss: 1.0783 g_loss: 1.2856 D(x):0.62 D(G(z)):0.35
EPOCH: [478/500] d_loss: 1.2126 g_loss: 1.1830 D(x):0.64 D(G(z)):0.41
EPOCH: [479/500] d_loss: 1.1786 g_loss: 1.3167 D(x):0.59 D(G(z)):0.34
EPOCH: [480/500] d_loss: 1.0834 g_loss: 1.2407 D(x):0.67 D(G(z)):0.38
EPOCH: [481/500] d_loss: 1.0796 g_loss: 1.3423 D(x):0.65 D(G(z)):0.35
EPOCH: [482/500] d_loss: 0.9489 g_loss: 1.4369 D(x):0.68 D(G(z)):0.32
EPOCH: [483/500] d_loss: 0.7829 g_loss: 1.6903 D(x):0.73 D(G(z)):0.27
EPOCH: [484/500] d_loss: 0.9574 g_loss: 1.7371 D(x):0.65 D(G(z)):0.27
EPOCH: [485/500] d_loss: 0.9361 g_loss: 1.3464 D(x):0.70 D(G(z)):0.34
EPOCH: [486/500] d_loss: 0.9017 g_loss: 1.5769 D(x):0.67 D(G(z)):0.27
EPOCH: [487/500] d_loss: 1.0150 g_loss: 1.4581 D(x):0.70 D(G(z)):0.36
EPOCH: [488/500] d_l

In [3]:
# Visualize images produced by the generator.
z = torch.randn(BATCH_SIZE, 64).to(DEVICE)
with torch.no_grad():  # inference only: no autograd graph is needed
    fake_images = G(z)

# Copy the batch to host memory once, rather than once per plotted image.
fake_images_np = fake_images.cpu().numpy()

# Show up to the first 10 samples, each reshaped from 784 values to 28x28.
for i in range(min(10, len(fake_images_np))):
    fake_images_img = np.reshape(fake_images_np[i], (28, 28))
    plt.imshow(fake_images_img, cmap='gray')
    plt.show()

NameError: name 'BATCH_SIZE' is not defined