In [None]:
# Imports: PyTorch core, torchvision dataset/transform/visualization helpers,
# tqdm progress bars, and matplotlib for showing image grids.
import torch
import torchvision
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.utils import make_grid
from torchvision.datasets import MNIST # MNIST dataset
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
torch.manual_seed(0) # fixed seed for testing / reproducible runs

<torch._C.Generator at 0x7f99458de8d0>

In [None]:
# Display the torch.nn module object (its repr only); a quick check that the
# import worked. The cell has no other effect.
nn

<module 'torch.nn' from '/usr/local/lib/python3.7/dist-packages/torch/nn/__init__.py'>

In [None]:
def gen_block(input_dim, output_dim):
    """Build one generator hidden block: Linear -> BatchNorm1d -> ReLU.

    Args:
        input_dim: number of input features.
        output_dim: number of output features (also the BatchNorm width).

    Returns:
        nn.Sequential whose index 0 is the Linear layer.
    """
    layers = [
        nn.Linear(input_dim, output_dim),
        nn.BatchNorm1d(output_dim),
        # In-place ReLU overwrites the BatchNorm output to save memory.
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)

In [None]:
class Generator(nn.Module):
    """MLP generator mapping a noise vector to a flattened image in (0, 1).

    Args:
        z_dim: length of the input noise vector.
        im_dim: length of the flattened output image (784 = 28 * 28).
        hidden_dim: width of the first hidden block; each following block
            doubles the width (hidden_dim, 2x, 4x, 8x).
    """

    def __init__(self, z_dim=10, im_dim=784, hidden_dim=128):
        super().__init__()
        # Feature widths along the network: z_dim -> h -> 2h -> 4h -> 8h.
        widths = [z_dim] + [hidden_dim * factor for factor in (1, 2, 4, 8)]
        hidden_blocks = [
            gen_block(d_in, d_out) for d_in, d_out in zip(widths[:-1], widths[1:])
        ]
        self.gen = nn.Sequential(
            *hidden_blocks,
            # Final projection to pixel space; Sigmoid squashes into (0, 1).
            nn.Linear(widths[-1], im_dim),
            nn.Sigmoid(),
        )

    def forward(self, noise):
        """Map a (n_samples, z_dim) noise batch to (n_samples, im_dim) images."""
        images = self.gen(noise)
        return images

    def get_gen(self):
        """Return the underlying nn.Sequential model."""
        return self.gen

In [None]:
def disc_block(input_dim, output_dim):
    """Build one discriminator hidden block: Linear -> LeakyReLU(0.2).

    Args:
        input_dim: number of input features.
        output_dim: number of output features.

    Returns:
        nn.Sequential of the two layers.
    """
    return nn.Sequential(
        nn.Linear(input_dim, output_dim),
        nn.LeakyReLU(0.2, inplace=True)
    )


class Discriminator(nn.Module):
    """MLP discriminator mapping a flattened image to one raw score (logit).

    The model-construction cell instantiates ``Discriminator()``, but no cell
    defined it, so the notebook failed on a fresh kernel. The output is an
    unnormalized logit; the training loss (BCEWithLogitsLoss) applies the
    sigmoid itself.

    Args:
        im_dim: length of the flattened input image (784 = 28 * 28).
        hidden_dim: base width; the blocks narrow 4h -> 2h -> h.
    """

    def __init__(self, im_dim=784, hidden_dim=128):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            disc_block(im_dim, hidden_dim * 4),
            disc_block(hidden_dim * 4, hidden_dim * 2),
            disc_block(hidden_dim * 2, hidden_dim),
            # No sigmoid here: a single logit per sample.
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, image):
        """Map a (n_samples, im_dim) batch to (n_samples, 1) logits."""
        return self.disc(image)

    def get_disc(self):
        """Return the underlying nn.Sequential model."""
        return self.disc

In [None]:
def get_noise(n_samples, z_dim, device='cpu'):
    """Sample standard-normal noise for the generator.

    Args:
        n_samples: number of noise vectors (batch size).
        z_dim: length of each noise vector.
        device: device on which to allocate the tensor.

    Returns:
        Tensor of shape (n_samples, z_dim).
    """
    shape = (n_samples, z_dim)
    return torch.randn(shape, device=device)

In [None]:
# Training configuration.
# BCEWithLogitsLoss applies the sigmoid internally, so the discriminator's
# predictions are passed to it as raw scores.
criterion = nn.BCEWithLogitsLoss()
n_epochs = 200
z_dim = 10  # length of the generator's input noise vector
display_step = 500  # report averaged losses / show samples every N batches
batch_size = 128
lr = 0.00001  # Adam learning rate, shared by generator and discriminator
# Fall back to CPU so the notebook still runs where no GPU is available
# (a hard-coded 'cuda' raises on CPU-only machines).
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Shuffled mini-batches of the MNIST training split (downloaded into the
# current directory), with images converted to tensors.
dataloader = DataLoader(
    dataset=MNIST(root='.', train=True, download=True, transform=transforms.ToTensor()),
    batch_size=batch_size,
    shuffle=True,
)
# Note: training on CelebA instead gives completely different results.

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=9912422.0), HTML(value='')))


Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=28881.0), HTML(value='')))


Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=1648877.0), HTML(value='')))


Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=4542.0), HTML(value='')))


Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw

Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [None]:
# Build both networks on the selected device, each with its own Adam
# optimizer. The generator is built first, as in the original ordering, so
# seeded weight initialization is unchanged.
gen = Generator(z_dim=z_dim).to(device)
gen_opt = torch.optim.Adam(params=gen.parameters(), lr=lr)
disc = Discriminator().to(device)
disc_opt = torch.optim.Adam(params=disc.parameters(), lr=lr)

In [None]:
def get_disc_loss(gen, disc, criterion, real, num_images, z_dim, device):
    """Discriminator loss: mean of the fake-vs-0 and real-vs-1 criterion terms.

    Args:
        gen: generator model.
        disc: discriminator model.
        criterion: loss taking (predictions, targets).
        real: batch of real (flattened) images.
        num_images: number of fake images to generate.
        z_dim: noise vector length.
        device: device for the sampled noise.

    Returns:
        Scalar loss tensor.
    """
    noise = get_noise(num_images, z_dim, device=device)
    # Detach the fakes so this loss does not backpropagate into the generator.
    fake_logits = disc(gen(noise).detach())
    fake_term = criterion(fake_logits, torch.zeros_like(fake_logits))
    real_logits = disc(real)
    real_term = criterion(real_logits, torch.ones_like(real_logits))
    return (fake_term + real_term) / 2

In [None]:
def get_gen_loss(gen, disc, criterion, num_images, z_dim, device):
    """Generator loss: criterion of the discriminator's verdict on fakes vs. all-ones.

    The generator is rewarded when the discriminator labels its samples as real.

    Args:
        gen: generator model.
        disc: discriminator model.
        criterion: loss taking (predictions, targets).
        num_images: number of fake images to generate.
        z_dim: noise vector length.
        device: device for the sampled noise.

    Returns:
        Scalar loss tensor (gradients flow into the generator).
    """
    logits = disc(gen(get_noise(num_images, z_dim, device=device)))
    target = torch.ones_like(logits)
    return criterion(logits, target)

In [None]:
def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):
    """Plot the first `num_images` images of a batch as a grid with 5 per row.

    Args:
        image_tensor: batch of images, flattened or not; it is reshaped to
            (-1, *size).
        num_images: how many images from the start of the batch to show.
        size: per-image (channels, height, width).
    """
    images = image_tensor.detach().cpu().view(-1, *size)[:num_images]
    grid = make_grid(images, nrow=5)
    # make_grid returns (C, H, W); imshow expects channels last.
    plt.imshow(grid.permute(1, 2, 0).squeeze())
    plt.show()

In [None]:
# Training loop: alternate one discriminator step and one generator step per batch.
cur_step = 0
# Running averages of each loss over the current display window.
mean_generator_loss = 0
mean_discriminator_loss = 0
# When True, check after every step that the generator's weights actually change.
test_generator = True
# Set to True if any of those runtime checks fails.
error = False


for epoch in range(n_epochs):
    # One iteration per batch, i.e. len(dataloader) iterations per epoch.
    for real, _ in tqdm(dataloader):
        cur_batch_size = len(real)
        # Flatten each image into a vector for the MLP models.
        real = real.view(cur_batch_size, -1).to(device)

        # --- Update the discriminator ---
        disc_opt.zero_grad()
        disc_loss = get_disc_loss(gen, disc, criterion, real, cur_batch_size, z_dim, device)
        # retain_graph is not needed: the fake batch is detached inside
        # get_disc_loss, and the generator step below builds a fresh graph,
        # so this graph is never backpropagated through again.
        disc_loss.backward()
        disc_opt.step()

        # Ultimately only the generator is kept; snapshot its first Linear
        # weights so we can verify the upcoming step changes them.
        if test_generator:
            old_generator_weights = gen.gen[0][0].weight.detach().clone()

        # --- Update the generator ---
        gen_opt.zero_grad()
        gen_loss = get_gen_loss(gen, disc, criterion, cur_batch_size, z_dim, device)
        gen_loss.backward()
        gen_opt.step()

        if test_generator:
            try:
                # Passes if lr > 2e-7; otherwise requires first-epoch gradients below 5e-4.
                assert lr > 0.0000002 or (gen.gen[0][0].weight.grad.abs().max() < 0.0005 and epoch == 0)
                # The optimizer step must have changed at least one weight.
                assert torch.any(gen.gen[0][0].weight.detach() != old_generator_weights)
            except Exception:
                # Narrowed from a bare `except:` so KeyboardInterrupt / SystemExit
                # can still stop training.
                error = True
                print("런타임 테스트 실패")  # "Runtime test failed"

        mean_discriminator_loss += disc_loss.item() / display_step
        mean_generator_loss += gen_loss.item() / display_step

        # Every display_step steps: report averaged losses and show fake vs. real samples.
        if cur_step % display_step == 0 and cur_step > 0:
            print(f"step {cur_step}: Generator loss: {mean_generator_loss}, discriminator loss: {mean_discriminator_loss}")
            # Sampling only for display, so skip building an autograd graph.
            with torch.no_grad():
                fake_noise = get_noise(cur_batch_size, z_dim, device=device)
                fake = gen(fake_noise)
            show_tensor_images(fake)
            show_tensor_images(real)
            mean_generator_loss = 0
            mean_discriminator_loss = 0
        cur_step += 1


Output hidden; open in https://colab.research.google.com to view.