라이브러리

In [3]:
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import matplotlib
import matplotlib.pylab as plt

from torchvision.utils import make_grid, save_image
import torchvision.datasets as datasets
import torchvision.transforms as transforms
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

변수 값

In [4]:
# Training hyper-parameters shared by the cells below.
batch_size: int = 512  # mini-batch size handed to the DataLoader
epochs: int = 200  # number of passes over the training set
sample_size: int = 64  # presumably the number of images sampled for visualization -- TODO confirm
nz: int = 128  # length of the noise vector fed to the generator
k: int = 1  # discriminator updates performed per mini-batch

MNIST를 내려받은 후 정규화

In [5]:
# Convert images to tensors and normalize with mean 0.5 / std 0.5 (single channel).
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

# MNIST training split, downloaded into ./data on first use.
train_dataset = datasets.MNIST(
    root="./data",
    train=True,
    download=True,
    transform=transform,
)

# Shuffled mini-batches loaded by 4 worker processes.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    num_workers=4,
    shuffle=True,
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



  cpuset_checked))


생성자 네트워크 생성

In [6]:
class Generator(nn.Module):
    """Fully connected generator mapping a noise vector to a 1x28x28 image.

    Args:
        nz: Length of the input noise vector.
    """

    def __init__(self, nz):
        super().__init__()
        self.nz = nz

        # Hidden stack: Linear + LeakyReLU(0.2) for each consecutive width pair.
        widths = (self.nz, 256, 512, 1024)
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.LeakyReLU(0.2))

        # Output layer: 784 = 28 * 28 values squashed into [-1, 1] by Tanh.
        layers.append(nn.Linear(widths[-1], 784))
        layers.append(nn.Tanh())
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Return generated images reshaped to (-1, 1, 28, 28)."""
        flat_images = self.main(x)
        return flat_images.view(-1, 1, 28, 28)

판별자 네트워크 생성

In [7]:
class Discriminator(nn.Module):
    """Fully connected discriminator scoring flattened 784-value images.

    The final Sigmoid yields one value in [0, 1] per input image.
    """

    def __init__(self):
        super().__init__()
        self.n_input = 784

        # Hidden stack: Linear + LeakyReLU(0.2) + Dropout(0.3) per hidden width.
        layers = []
        fan_in = self.n_input
        for fan_out in (1024, 512, 256):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.LeakyReLU(0.2))
            layers.append(nn.Dropout(0.3))
            fan_in = fan_out

        # Single-unit output squashed by Sigmoid.
        layers.append(nn.Linear(fan_in, 1))
        layers.append(nn.Sigmoid())
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten the input to (-1, 784) and return the per-image score."""
        flattened = x.view(-1, 784)
        return self.main(flattened)

초기화

In [8]:
# Build both networks, move them to the selected device, and show their layouts.
generator, discriminator = Generator(nz).to(device), Discriminator().to(device)
print(generator, discriminator, sep="\n")

Generator(
  (main): Sequential(
    (0): Linear(in_features=128, out_features=256, bias=True)
    (1): LeakyReLU(negative_slope=0.2)
    (2): Linear(in_features=256, out_features=512, bias=True)
    (3): LeakyReLU(negative_slope=0.2)
    (4): Linear(in_features=512, out_features=1024, bias=True)
    (5): LeakyReLU(negative_slope=0.2)
    (6): Linear(in_features=1024, out_features=784, bias=True)
    (7): Tanh()
  )
)
Discriminator(
  (main): Sequential(
    (0): Linear(in_features=784, out_features=1024, bias=True)
    (1): LeakyReLU(negative_slope=0.2)
    (2): Dropout(p=0.3, inplace=False)
    (3): Linear(in_features=1024, out_features=512, bias=True)
    (4): LeakyReLU(negative_slope=0.2)
    (5): Dropout(p=0.3, inplace=False)
    (6): Linear(in_features=512, out_features=256, bias=True)
    (7): LeakyReLU(negative_slope=0.2)
    (8): Dropout(p=0.3, inplace=False)
    (9): Linear(in_features=256, out_features=1, bias=True)
    (10): Sigmoid()
  )
)


옵티마이저, 손실함수 정의

In [9]:
# Binary cross-entropy on the discriminator's sigmoid output.
criterion = nn.BCELoss()

# One Adam optimizer per network, both with the same learning rate.
optim_g = optim.Adam(generator.parameters(), lr=0.0002)
optim_d = optim.Adam(discriminator.parameters(), lr=0.0002)

# Per-epoch history: generator losses, discriminator losses, sample image grids.
losses_g, losses_d, images = [], [], []

생성된 이미지 저장 함수 정의

In [10]:
def save_generator_image(image, path):
    """Write ``image`` to ``path`` by delegating to ``save_image``.

    Args:
        image: Image tensor (or grid) to write.
        path: Destination file path.
    """
    save_image(image, path)

판별자 학습을 위한 함수

In [11]:
def train_discriminator(optimizer, data_real, data_fake):
    """Run one optimization step of the discriminator.

    Uses the module-level ``discriminator``, ``criterion`` and ``device``.

    Args:
        optimizer: Optimizer holding the discriminator's parameters.
        data_real: Batch of real images (target label 1).
        data_fake: Batch of generated images (target label 0).

    Returns:
        The summed real + fake loss as a tensor detached from the autograd
        graph, so callers can accumulate it without keeping graph history.
    """
    # Size each label tensor from its own batch and create it directly on the
    # target device instead of allocating on the CPU and copying.
    real_label = torch.ones(data_real.size(0), 1, device=device)
    fake_label = torch.zeros(data_fake.size(0), 1, device=device)

    optimizer.zero_grad()

    output_real = discriminator(data_real)
    loss_real = criterion(output_real, real_label)

    # detach() guarantees this step never back-propagates into the generator,
    # even if the caller forgot to detach the fake batch.
    output_fake = discriminator(data_fake.detach())
    loss_fake = criterion(output_fake, fake_label)

    # Gradients of both terms accumulate before the single optimizer step.
    loss_real.backward()
    loss_fake.backward()
    optimizer.step()

    return (loss_real + loss_fake).detach()

생성자 학습을 위한 함수

In [12]:
def train_generator(optimizer, data_fake):
    """Run one optimization step of the generator.

    Uses the module-level ``discriminator``, ``criterion`` and ``device``.

    Args:
        optimizer: Optimizer holding the generator's parameters.
        data_fake: Batch of generated images still attached to the generator's
            autograd graph.

    Returns:
        The generator loss as a tensor detached from the autograd graph, so
        callers can accumulate it without keeping graph history.
    """
    # The generator is rewarded when the discriminator labels fakes as real (1).
    # Create the labels directly on the target device.
    real_label = torch.ones(data_fake.size(0), 1, device=device)

    optimizer.zero_grad()
    output = discriminator(data_fake)
    loss = criterion(output, real_label)
    loss.backward()
    optimizer.step()

    return loss.detach()

모델학습

In [13]:
!mkdir img #img 폴더 만들기

mkdir: cannot create directory ‘img’: File exists


In [14]:
# Put both networks in training mode (enables the discriminator's Dropout layers).
generator.train()
discriminator.train()

for epoch in range(epochs):
    loss_g = 0.0
    loss_d = 0.0
    # len(train_loader) counts the final partial batch too, unlike
    # int(len(dataset) / batch_size), which made the progress bar overrun.
    num_batches = len(train_loader)
    for idx, data in tqdm(enumerate(train_loader), total=num_batches):
        image, _ = data
        image = image.to(device)
        b_size = len(image)
        # k discriminator updates per generator update.
        for step in range(k):
            # detach(): the discriminator step must not update the generator.
            data_fake = generator(torch.randn(b_size, nz, device=device)).detach()
            data_real = image
            # Accumulate detached values so no autograd history is retained.
            loss_d += train_discriminator(optim_d, data_real, data_fake).detach()
        data_fake = generator(torch.randn(b_size, nz, device=device))
        loss_g += train_generator(optim_g, data_fake).detach()

    # Sample a fixed-size grid (sample_size images) instead of reusing the
    # size of the last, partial batch; no gradients are needed here.
    with torch.no_grad():
        generated_img = generator(torch.randn(sample_size, nz, device=device)).cpu()
    generated_img = make_grid(generated_img)
    # f-string so each epoch gets its own file instead of overwriting the
    # literal name "gen_img{epoch}.png".
    save_generator_image(generated_img, f"./img/gen_img{epoch}.png")
    images.append(generated_img)

    # Average over the number of updates actually performed
    # (the last index idx is one less than the batch count).
    epoch_loss_g = loss_g / num_batches
    epoch_loss_d = loss_d / (num_batches * k)
    losses_g.append(epoch_loss_g)
    losses_d.append(epoch_loss_d)

    print(f"Epoch {epoch} of {epochs}")
    print(f"Generator loss: {epoch_loss_g:.8f}, Discriminator loss: {epoch_loss_d:.8f}")

  cpuset_checked))
118it [00:11, 10.31it/s]                         


Epoch 0 of 200
Generator loss: 1.26433420, Discriminator loss: 0.89390093


118it [00:09, 12.45it/s]                         

Epoch 1 of 200
Generator loss: 1.68707442, Discriminator loss: 1.27960753



118it [00:10, 11.74it/s]                         

Epoch 2 of 200
Generator loss: 2.67859650, Discriminator loss: 1.02706456



118it [00:11,  9.85it/s]                         

Epoch 3 of 200
Generator loss: 6.85774612, Discriminator loss: 0.38197088



118it [00:09, 12.22it/s]                         

Epoch 4 of 200
Generator loss: 3.68248940, Discriminator loss: 0.66707629



118it [00:13,  8.70it/s]

Epoch 5 of 200
Generator loss: 2.49973941, Discriminator loss: 1.03940201



118it [00:14,  8.34it/s]                         

Epoch 6 of 200
Generator loss: 2.40022516, Discriminator loss: 0.95008326



118it [00:12,  9.50it/s]

Epoch 7 of 200
Generator loss: 1.76613986, Discriminator loss: 1.12662375



118it [00:13,  8.51it/s]                         


Epoch 8 of 200
Generator loss: 1.85582805, Discriminator loss: 0.93674350


118it [00:13,  8.82it/s]                         

Epoch 9 of 200
Generator loss: 2.26175237, Discriminator loss: 0.99317038



118it [00:12,  9.73it/s]                         

Epoch 10 of 200
Generator loss: 1.74709892, Discriminator loss: 1.10250056



118it [00:13,  8.82it/s]                         

Epoch 11 of 200
Generator loss: 1.81082988, Discriminator loss: 0.92109442



118it [00:12,  9.14it/s]                         

Epoch 12 of 200
Generator loss: 1.52849615, Discriminator loss: 1.19835448



118it [00:13,  8.96it/s]                         

Epoch 13 of 200
Generator loss: 1.94255388, Discriminator loss: 1.03905535



118it [00:10, 11.51it/s]                         

Epoch 14 of 200
Generator loss: 0.78360665, Discriminator loss: 1.28558040



118it [00:10, 10.95it/s]                         

Epoch 15 of 200
Generator loss: 1.23022997, Discriminator loss: 1.13170505



118it [00:09, 11.83it/s]                         

Epoch 16 of 200
Generator loss: 2.08819246, Discriminator loss: 0.96472770



118it [00:09, 11.94it/s]                         

Epoch 17 of 200
Generator loss: 2.54060292, Discriminator loss: 1.18727899



118it [00:10, 10.79it/s]                         

Epoch 18 of 200
Generator loss: 2.14646602, Discriminator loss: 0.80537504



118it [00:13,  8.67it/s]                         

Epoch 19 of 200
Generator loss: 2.21369195, Discriminator loss: 0.92397463



118it [00:12,  9.23it/s]                         

Epoch 20 of 200
Generator loss: 2.70703459, Discriminator loss: 0.61489713



118it [00:10, 10.94it/s]

Epoch 21 of 200
Generator loss: 2.80564332, Discriminator loss: 0.57373893



118it [00:12,  9.26it/s]                         

Epoch 22 of 200
Generator loss: 3.17875552, Discriminator loss: 0.65897471



118it [00:11,  9.90it/s]                         

Epoch 23 of 200
Generator loss: 2.42810512, Discriminator loss: 0.72600842



118it [00:11, 10.55it/s]                         

Epoch 24 of 200
Generator loss: 2.89098144, Discriminator loss: 0.69551122



118it [00:10, 10.86it/s]                         

Epoch 25 of 200
Generator loss: 3.29217052, Discriminator loss: 0.49910107



118it [00:10, 11.65it/s]                         


Epoch 26 of 200
Generator loss: 2.30159092, Discriminator loss: 0.60605043


118it [00:10, 11.72it/s]                         

Epoch 27 of 200
Generator loss: 2.28961444, Discriminator loss: 0.76601648



118it [00:10, 11.27it/s]


Epoch 28 of 200
Generator loss: 2.44315195, Discriminator loss: 0.57768494


118it [00:10, 11.38it/s]                         

Epoch 29 of 200
Generator loss: 2.86708045, Discriminator loss: 0.61051607



118it [00:09, 11.87it/s]                         

Epoch 30 of 200
Generator loss: 2.85195279, Discriminator loss: 0.57553601



118it [00:09, 11.85it/s]                         


Epoch 31 of 200
Generator loss: 2.78782392, Discriminator loss: 0.55651116


118it [00:10, 11.76it/s]


Epoch 32 of 200
Generator loss: 2.67412996, Discriminator loss: 0.60003936


118it [00:10, 11.80it/s]                         

Epoch 33 of 200
Generator loss: 3.08150601, Discriminator loss: 0.41962385



118it [00:09, 11.88it/s]                         

Epoch 34 of 200
Generator loss: 2.80679250, Discriminator loss: 0.60446781



118it [00:10, 11.69it/s]                         

Epoch 35 of 200
Generator loss: 2.50830102, Discriminator loss: 0.58268207



118it [00:10, 11.63it/s]                         

Epoch 36 of 200
Generator loss: 2.60208845, Discriminator loss: 0.52921396



 84%|████████▍ | 98/117 [00:09<00:01, 10.31it/s]


KeyboardInterrupt: ignored

생성자와 판별자의 오차 확인

In [None]:
# Plot the per-epoch generator and discriminator losses on one figure.
plt.figure()
# float() accepts 0-dim tensors (including CUDA ones) as well as plain numbers,
# so this cell can be re-run without failing on already-converted values.
losses_g = [float(loss) for loss in losses_g]
plt.plot(losses_g, label='Generator loss')

# Convert the discriminator history into losses_d (not losses_g) so the
# generator curve is not overwritten.
losses_d = [float(loss) for loss in losses_d]
plt.plot(losses_d, label='Discriminator Loss')
plt.legend()

생성된 이미지 출력

In [None]:
# The noise must be created on the generator's device; moving the *output*
# afterwards is too late and fails when the model lives on the GPU.
with torch.no_grad():
    fake_images = generator(torch.randn(b_size, nz, device=device))
for i in range(min(10, len(fake_images))):
  # Turn the i-th (1, 28, 28) image tensor into a 28x28 NumPy array and show it.
  fake_images_img = fake_images[i].cpu().numpy().reshape(28, 28)
  plt.imshow(fake_images_img, cmap='gray')
  plt.show()