In [1]:
# Dependencies for the whole notebook, followed by the global RNG seed.
import torch
import torchvision
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.utils import make_grid
from torchvision.datasets import MNIST # dataset
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
torch.manual_seed(0) # fixed seed, set for testing purposes

<torch._C.Generator at 0x7f9f40a1f7d0>

In [2]:
# Scratch cell: evaluating the bare name only echoes the torch.nn module repr; it assigns nothing.
nn

<module 'torch.nn' from '/usr/local/lib/python3.7/dist-packages/torch/nn/__init__.py'>

In [3]:
class Generator(nn.Module):
    """Maps a flat noise vector to an image through four transposed-convolution stages.

    Args:
        z_dim: length of each input noise vector.
        im_chan: number of channels in the generated image.
        hidden_dim: base channel width; the stages use 4x, 2x and 1x this value.

    With the kernel/stride settings chosen in ``__init__`` (and no padding) a
    1x1 input grows to 3 -> 6 -> 13 -> 28 pixels per side.
    """

    def __init__(self, z_dim=10, im_chan=1, hidden_dim=64):
        super().__init__()
        self.z_dim = z_dim
        wide, medium, narrow = hidden_dim * 4, hidden_dim * 2, hidden_dim
        # Stages are created in forward order so parameter initialisation
        # consumes the random stream in a fixed sequence.
        stages = [
            self.gen_block(z_dim, wide),
            self.gen_block(wide, medium, kernel_size=4, stride=1),
            self.gen_block(medium, narrow),
            self.gen_block(narrow, im_chan, kernel_size=4, final_layer=True),
        ]
        self.gen = nn.Sequential(*stages)

    def gen_block(self, in_channel, out_channel, kernel_size=3, stride=2, final_layer=False):
        """Build one upsampling stage.

        Hidden stages are ConvTranspose2d -> BatchNorm2d -> ReLU; the final
        stage is ConvTranspose2d -> Tanh so outputs lie in [-1, 1].
        """
        upsample = nn.ConvTranspose2d(in_channel, out_channel, kernel_size, stride)
        if final_layer:
            return nn.Sequential(upsample, nn.Tanh())
        return nn.Sequential(
            upsample,
            nn.BatchNorm2d(out_channel),
            nn.ReLU(inplace=True),
        )

    def unsqueeze_noise(self, noise):
        """Reshape noise of shape (n, z_dim) to (n, z_dim, 1, 1) for the conv stack."""
        batch = len(noise)
        return noise.view(batch, self.z_dim, 1, 1)

    def forward(self, noise):
        """Generate one image per noise vector."""
        return self.gen(self.unsqueeze_noise(noise))

In [4]:
class Critic(nn.Module):
    """Scores images with three strided convolution stages, returning one value per image.

    Args:
        im_chan: number of channels in the input image.
        hidden_dim: channel width of the first stage; the second uses twice as many.

    NOTE(review): the hidden stages use BatchNorm2d. The WGAN-GP paper advises
    against batch normalisation in a critic trained with a per-sample gradient
    penalty -- confirm this is intentional.
    """

    def __init__(self, im_chan=1, hidden_dim=16):
        super().__init__()
        stages = (
            self.crit_block(im_chan, hidden_dim),
            self.crit_block(hidden_dim, hidden_dim * 2),
            self.crit_block(hidden_dim * 2, 1, final_layer=True),
        )
        self.crit = nn.Sequential(*stages)

    def crit_block(self, in_channel, out_channel, kernel_size=4, stride=2, final_layer=False):
        """Build one downsampling stage.

        Hidden stages are Conv2d -> BatchNorm2d -> LeakyReLU(0.2); the final
        stage is a bare Conv2d so the score is unbounded.
        """
        conv = nn.Conv2d(in_channel, out_channel, kernel_size, stride)
        if final_layer:
            return nn.Sequential(conv)
        return nn.Sequential(
            conv,
            nn.BatchNorm2d(out_channel),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, image):
        """Return the critic output flattened to (batch, -1)."""
        scores = self.crit(image)
        return scores.view(len(scores), -1)

In [5]:
def get_noise(n_samples, z_dim, device='cuda'):
    """Draw standard-normal noise.

    Args:
        n_samples: number of noise vectors (rows).
        z_dim: length of each vector (columns).
        device: device on which the tensor is allocated.

    Returns:
        Tensor of shape (n_samples, z_dim).
    """
    shape = (n_samples, z_dim)
    return torch.randn(*shape, device=device)

In [6]:
# Training hyperparameters and runtime configuration.
criterion = nn.BCEWithLogitsLoss()  # not referenced by the training code visible in this notebook
n_epochs = 100
z_dim = 64           # length of the generator's input noise vector
display_step = 50    # show samples / loss curves every this many generator steps
batch_size = 128
# If the critic (discriminator) loss keeps falling while the generator loss
# rises, try lowering this value; shrinking it 10x at a time is fine.
lr = 0.0002

# Adam momentum coefficients.
beta_1 = 0.5
beta_2 = 0.999
# Use the GPU when one is present, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
c_lambda = 100       # weight of the gradient penalty in the critic loss
crit_repeats = 10    # critic updates per generator update


In [7]:
# Convert images to tensors, then shift/scale with mean 0.5 and std 0.5 so
# pixel values cover [-1, 1], the same range as the generator's Tanh output.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# MNIST is downloaded into the current directory on first use; batches are
# reshuffled every epoch.
dataloader = DataLoader(
    dataset=MNIST(root='.', download=True, transform=transform),
    batch_size=batch_size,
    shuffle=True,
)

In [8]:
# Instantiate both networks on the target device (generator first, then
# critic), then give each its own Adam optimizer with shared settings.
gen = Generator(z_dim=z_dim).to(device)
crit = Critic().to(device)
gen_opt = torch.optim.Adam(params=gen.parameters(), lr=lr, betas=(beta_1, beta_2))
crit_opt = torch.optim.Adam(params=crit.parameters(), lr=lr, betas=(beta_1, beta_2))

def weights_init(m):
    """Initialise one module in place; intended for use with ``Module.apply``.

    * Conv2d / ConvTranspose2d: weights drawn from N(0, 0.02).
    * BatchNorm2d: scale drawn from N(1, 0.02), bias set to 0.

    The batch-norm scale is centred on 1, not 0: a scale near zero multiplies
    every normalised activation by roughly zero at the start of training.
    Batch-norm layers built with ``affine=False`` have no weight/bias and are
    left alone. Any other module type is not modified.

    Args:
        m: the module to initialise.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)  # std was tuned down from 0.2 to 0.02
    elif isinstance(m, nn.BatchNorm2d):
        if m.weight is not None:
            torch.nn.init.normal_(m.weight, 1.0, 0.02)  # std was tuned down from 0.2 to 0.02
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)

# Re-initialise every submodule of both networks in place (generator first).
# Module.apply returns the module itself, so the names keep pointing at the
# same objects the optimizers were built from.
gen, crit = gen.apply(weights_init), crit.apply(weights_init)

In [9]:
def get_gradient(crit, real, fake, epsilon):
    """Gradient of the critic's scores with respect to real/fake interpolates.

    Args:
        crit: callable mapping a batch of images to a batch of scores.
        real: batch of real images.
        fake: batch of generated images, same shape as ``real``.
        epsilon: per-sample mixing weights, broadcastable against ``real``.

    Returns:
        Tensor with the same shape as the interpolated images holding
        d(score)/d(image). It is built with ``create_graph=True`` so a penalty
        computed from it can itself be back-propagated.
    """
    mixed_images = real * epsilon + fake * (1 - epsilon)
    # When none of the inputs tracks gradients the interpolate is a plain leaf
    # and autograd.grad would refuse to differentiate with respect to it, so
    # enable tracking here instead of relying on the caller to mark epsilon.
    if not mixed_images.requires_grad:
        mixed_images.requires_grad_(True)

    mixed_scores = crit(mixed_images)

    gradient = torch.autograd.grad(
        inputs=mixed_images,
        outputs=mixed_scores,
        # A ones-weighted sum yields each sample's own gradient in one pass.
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    return gradient

In [10]:
def gradient_penalty(gradient):
    """Mean squared deviation of each sample's gradient L2 norm from 1.

    Args:
        gradient: tensor whose first dimension indexes samples; all remaining
            dimensions are flattened per sample.

    Returns:
        Scalar tensor ``mean((||g_i||_2 - 1) ** 2)``.
    """
    # reshape (unlike view) also accepts non-contiguous gradients, e.g. ones
    # produced in a channels-last layout.
    per_sample = gradient.reshape(len(gradient), -1)
    gradient_norm = per_sample.norm(2, dim=1)
    return torch.mean((gradient_norm - 1) ** 2)

In [11]:
def get_gen_loss(crit_fake_pred):
    """Generator loss: the negated mean of the critic's scores on generated images.

    Args:
        crit_fake_pred: critic scores for a batch of generated images.

    Returns:
        Scalar tensor; lower means the critic scored the fakes higher.
    """
    return -1. * crit_fake_pred.mean()

def get_crit_loss(crit_fake_pred, crit_real_pred, gp, c_lambda):
    """Critic loss: mean fake score minus mean real score, plus the weighted penalty.

    Args:
        crit_fake_pred: critic scores for generated images.
        crit_real_pred: critic scores for real images.
        gp: gradient-penalty term (scalar tensor).
        c_lambda: multiplier applied to ``gp``.

    Returns:
        Scalar tensor ``mean(fake) - mean(real) + c_lambda * gp``.
    """
    score_gap = crit_fake_pred.mean() - crit_real_pred.mean()
    return score_gap + c_lambda * gp

In [12]:
def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28), value_range=(-1.0, 1.0)):
    """Display the first ``num_images`` images of a batch in a grid, five per row.

    Args:
        image_tensor: batch of images; reshaped to ``(-1, *size)``.
        num_images: maximum number of images to draw.
        size: (channels, height, width) of a single image.
        value_range: (low, high) range the pixel values are expected to span.
            The default matches the [-1, 1] range produced by the notebook's
            Normalize transform and the generator's Tanh.
    """
    low, high = value_range
    images = image_tensor.detach().cpu().view(-1, *size)[:num_images]
    # imshow expects float RGB data in [0, 1] and clips everything outside it,
    # so map the expected range onto [0, 1] first; otherwise the negative half
    # of the values would all be drawn as black.
    images = ((images - low) / (high - low)).clamp(0.0, 1.0)
    image_grid = make_grid(images, nrow=5)
    # make_grid returns (C, H, W); matplotlib wants (H, W, C).
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.show()

In [None]:
import matplotlib.pyplot as plt


# Training loop: for every batch, update the critic `crit_repeats` times and
# then the generator once. Per-step losses are recorded for plotting.
cur_step = 0
generator_losses = []
critic_losses = []
for epoch in range(n_epochs):
    for real, _ in tqdm(dataloader):
        cur_batch_size = len(real)
        real = real.to(device)

        # ---- Critic updates ----
        # The critic is updated several times per generator step; a better
        # trained critic is what lets the generator learn.
        mean_iteration_critic_loss = 0
        for _ in range(crit_repeats):
            crit_opt.zero_grad()
            fake_noise = get_noise(cur_batch_size, z_dim, device=device)
            # The generator is not trained in this phase, so skip building its
            # autograd graph altogether.
            with torch.no_grad():
                fake = gen(fake_noise)
            crit_fake_pred = crit(fake)  # D(G(z))
            crit_real_pred = crit(real)

            # One mixing weight per sample, broadcast over channels and pixels.
            epsilon = torch.rand(len(real), 1, 1, 1, device=device, requires_grad=True)
            gradient = get_gradient(crit, real, fake, epsilon)
            gp = gradient_penalty(gradient)
            crit_loss = get_crit_loss(crit_fake_pred, crit_real_pred, gp, c_lambda)

            # Running average of the critic loss over the inner repeats.
            mean_iteration_critic_loss += crit_loss.item() / crit_repeats
            # Every graph used here is rebuilt on the next repeat, so there is
            # nothing to retain after the backward pass.
            crit_loss.backward()
            crit_opt.step()

        critic_losses.append(mean_iteration_critic_loss)

        # ---- Generator update ----
        gen_opt.zero_grad()
        fake_noise_2 = get_noise(cur_batch_size, z_dim, device=device)
        fake_2 = gen(fake_noise_2)
        crit_fake_pred = crit(fake_2)

        gen_loss = get_gen_loss(crit_fake_pred)
        gen_loss.backward()
        gen_opt.step()

        generator_losses.append(gen_loss.item())

        # ---- Periodic progress report ----
        if cur_step % display_step == 0 and cur_step > 0:
            gen_mean = sum(generator_losses[-display_step:]) / display_step
            crit_mean = sum(critic_losses[-display_step:]) / display_step
            # Report the averaged critic loss, not the last raw loss tensor.
            print(f"step {cur_step}: Generator loss: {gen_mean}, Critic loss: {crit_mean}")
            show_tensor_images(fake)
            show_tensor_images(real)

            # Smooth both loss curves by averaging over bins of `step_bins`
            # steps; trailing steps that do not fill a bin are dropped.
            step_bins = 20
            num_examples = (len(generator_losses) // step_bins) * step_bins
            plt.plot(
                range(num_examples // step_bins),
                torch.Tensor(generator_losses[:num_examples]).view(-1, step_bins).mean(1),
                label="Generator loss"
            )
            plt.plot(
                range(num_examples // step_bins),
                torch.Tensor(critic_losses[:num_examples]).view(-1, step_bins).mean(1),
                label="Critic loss"
            )
            plt.legend()
            plt.show()
        cur_step += 1

  