<a href="https://colab.research.google.com/github/stevengregori92/GenerativeAdversarialNetwork/blob/main/GAN_in_PyTorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install 'jcopdl<2.0'

Collecting jcopdl<2.0
  Downloading jcopdl-1.1.10.tar.gz (12 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: jcopdl
  Building wheel for jcopdl (setup.py) ... [?25l[?25hdone
  Created wheel for jcopdl: filename=jcopdl-1.1.10-py2.py3-none-any.whl size=17913 sha256=adeac5e7979cc318d98847e69675fc340b18b0a9324f7713cfd386e446d05610
  Stored in directory: /root/.cache/pip/wheels/41/95/30/86345d2446be19c7d97dee789a2597bee81cfbb7b24a847f7c
Successfully built jcopdl
Installing collected packages: jcopdl
Successfully installed jcopdl-1.1.10


In [None]:
!gdown https://drive.google.com/uc?id=1x4HUS6yQYnrEmyKRsIrGfag0wZlL3LDx

Downloading...
From: https://drive.google.com/uc?id=1x4HUS6yQYnrEmyKRsIrGfag0wZlL3LDx
To: /content/mnist.zip
  0% 0.00/10.5M [00:00<?, ?B/s]100% 10.5M/10.5M [00:00<00:00, 249MB/s]


In [None]:
!unzip /content/mnist.zip

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: data/train/3/3_192.jpg  
  inflating: __MACOSX/data/train/3/._3_192.jpg  
  inflating: data/train/3/3_22.jpg   
  inflating: __MACOSX/data/train/3/._3_22.jpg  
  inflating: data/train/3/3_804.jpg  
  inflating: __MACOSX/data/train/3/._3_804.jpg  
  inflating: data/train/3/3_810.jpg  
  inflating: __MACOSX/data/train/3/._3_810.jpg  
  inflating: data/train/3/3_36.jpg   
  inflating: __MACOSX/data/train/3/._3_36.jpg  
  inflating: data/train/3/3_757.jpg  
  inflating: __MACOSX/data/train/3/._3_757.jpg  
  inflating: data/train/3/3_743.jpg  
  inflating: __MACOSX/data/train/3/._3_743.jpg  
  inflating: data/train/3/3_794.jpg  
  inflating: __MACOSX/data/train/3/._3_794.jpg  
  inflating: data/train/3/3_958.jpg  
  inflating: __MACOSX/data/train/3/._3_958.jpg  
  inflating: data/train/3/3_780.jpg  
  inflating: __MACOSX/data/train/3/._3_780.jpg  
  inflating: data/train/3/3_970.jpg  
  inflating: __MACOSX/data/tr

In [None]:
import torch
from torch import nn, optim
from jcopdl.callback import Callback, set_config
# Seed torch's RNG (all devices) so weight init, noise sampling and shuffling are repeatable.
# Note: some CUDA kernels are still non-deterministic.
SEED = 42
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

# Dataset & DataLoader

In [None]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [None]:
bs = 64

# Grayscale (1 channel) -> float tensor in [0, 1] -> (x - 0.5) / 0.5, i.e. [-1, 1].
# Real images then share the range of the generator's tanh output, which keeps training stable.
transform = transforms.Compose(
    [transforms.Grayscale(), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
)

# ImageFolder reads one sub-directory per class under data/train/; the class labels are ignored by the GAN loop.
train_set = datasets.ImageFolder(root='data/train/', transform=transform)
trainloader = DataLoader(dataset=train_set, batch_size=bs, shuffle=True, num_workers=2)

# Architecture & Config

In [None]:
%%writefile model_gan.py

import torch
from torch import nn, optim
from jcopdl.layers import linear_block

class Discriminator(nn.Module):
  """MLP discriminator that scores every sample in a batch with one value.

  Each sample is flattened to 784 features (e.g. one 1x28x28 image) and passed
  through jcopdl ``linear_block``s:
  - hidden widths 784 -> 512 -> 256 -> 128, with activation='lrelu';
  - a final 1-unit block with activation='sigmoid'.

  The output is presumably (batch, 1), one probability-like score per sample,
  intended for use with ``nn.BCELoss``. The exact contents of each block
  (Linear + activation) come from jcopdl and are assumed here.
  """

  def __init__(self):
    super().__init__()
    widths = [784, 512, 256, 128]
    # One 'lrelu' block per consecutive pair of widths: 784->512, 512->256, 256->128.
    hidden = [
        linear_block(n_in, n_out, activation='lrelu')
        for n_in, n_out in zip(widths, widths[1:])
    ]
    self.fc = nn.Sequential(
        nn.Flatten(),
        *hidden,
        linear_block(widths[-1], 1, activation='sigmoid'),
    )

  def forward(self, x):
    """Return the discriminator score for each sample of batch ``x``."""
    scores = self.fc(x)
    return scores

class Generator(nn.Module):
  """MLP generator mapping latent vectors of length ``z_dim`` to 784-value samples.

  Layer widths are z_dim -> 128 -> 256 -> 512 -> 1024 -> 784, built from jcopdl
  ``linear_block``s:
  - hidden blocks use activation='lrelu';
  - the 256/512/1024 blocks also request batch_norm=True;
  - the output block uses activation='tanh'.

  Assuming the activation is applied last, outputs lie in [-1, 1]. Outputs are
  flat (batch, 784); callers reshape them as needed.

  Args:
    z_dim: length of the latent noise vector.
  """

  def __init__(self, z_dim):
    super().__init__()
    self.z_dim = z_dim
    layers = [nn.Flatten(), linear_block(z_dim, 128, activation='lrelu')]
    # Widening hidden blocks, each with batch norm requested.
    for n_in, n_out in ((128, 256), (256, 512), (512, 1024)):
      layers.append(linear_block(n_in, n_out, activation='lrelu', batch_norm=True))
    layers.append(linear_block(1024, 784, activation='tanh'))
    self.fc = nn.Sequential(*layers)

  def forward(self, x):
    """Decode latent batch ``x`` (flattened per sample) into generated samples."""
    return self.fc(x)

  def generate(self, n, device):
    """Draw ``n`` standard-normal latent vectors on ``device`` and decode them.

    Gradients are tracked; wrap the call in ``torch.no_grad()`` when only sampling.
    """
    noise = torch.randn(n, self.z_dim, device=device)
    return self.forward(noise)

Writing model_gan.py


In [None]:
# Run hyperparameters held in a jcopdl config object.
# z_dim is the length of the generator's latent noise vector.
config = set_config(dict(z_dim=100, batch_size=bs))

# Training Preparation

In [None]:
from model_gan import Discriminator, Generator

In [None]:
D, G = Discriminator().to(device), Generator(config.z_dim).to(device)

# D already ends in a sigmoid, so plain binary cross-entropy on probabilities is used.
criterion = nn.BCELoss()
# One Adam optimizer per network, both with learning rate 2e-4.
d_optimizer, g_optimizer = (
    optim.Adam(net.parameters(), lr=2e-4) for net in (D, G)
)

# Training

In [None]:
import os
from torchvision.utils import save_image

# Sample image grids are written to output/GAN/, pickled models to model/GAN/.
for out_dir in ('output/GAN/', 'model/GAN/'):
  os.makedirs(out_dir, exist_ok=True)

In [None]:
max_epochs = 300
log_every = 5      # print losses every `log_every` epochs
sample_every = 15  # save a sample grid and model checkpoint every `sample_every` epochs

for epoch in range(max_epochs):
  D.train()
  G.train()
  for real_img, _ in trainloader:
    n_data = real_img.shape[0]

    # Real and fake images for this batch
    real_img = real_img.to(device)
    fake_img = G.generate(n_data, device)

    # Target labels: 1 = real, 0 = fake
    real = torch.ones((n_data, 1), device=device)
    fake = torch.zeros((n_data, 1), device=device)

    ## Training Discriminator: real images -> real label, fake images -> fake label
    d_optimizer.zero_grad()
    d_real_loss = criterion(D(real_img), real)
    # detach() so the discriminator step does not backpropagate into G
    d_fake_loss = criterion(D(fake_img.detach()), fake)
    d_loss = d_real_loss + d_fake_loss
    d_loss.backward()
    d_optimizer.step()

    ## Training Generator: fake images are scored against the *real* label,
    ## pushing G toward images that D classifies as real.
    g_optimizer.zero_grad()
    g_loss = criterion(D(fake_img), real)
    # This also accumulates gradients into D's parameters. They are cleared by
    # d_optimizer.zero_grad() at the start of the next batch.
    g_loss.backward()
    g_optimizer.step()

  is_last_epoch = epoch == max_epochs - 1

  if epoch % log_every == 0 or is_last_epoch:
    # Losses come from the epoch's last batch. d_loss sums the real and fake terms,
    # so it is halved to report their mean.
    print(f'Epoch: {epoch:5} | D_loss: {d_loss.item() / 2:.5f} | G_loss: {g_loss.item():.5f}')

  # Always checkpoint the final epoch too. With max_epochs=300 and sample_every=15,
  # the last multiple of 15 is 285, so the final epochs would otherwise never be saved.
  if epoch % sample_every == 0 or is_last_epoch:
    G.eval()  # eval mode for sampling (affects the generator's batch-norm blocks)
    with torch.no_grad():  # sampling only; no autograd graph needed
      sample_img = G.generate(64, device)
    save_image(sample_img.view(-1, 1, 28, 28), f'output/GAN/{epoch:04d}.jpg', nrow=8, normalize=True)

    # NOTE: torch.save on a module pickles the whole object, so loading these files
    # requires model_gan.py to be importable. Saving state_dict() is more portable.
    torch.save(D, 'model/GAN/discriminator.pth')
    torch.save(G, 'model/GAN/generator.pth')

Epoch:     0 | D_loss: 0.04508 | G_loss: 14.39408
Epoch:     5 | D_loss: 0.09857 | G_loss: 29.98952
Epoch:    10 | D_loss: 0.19364 | G_loss: 9.09700
Epoch:    15 | D_loss: 0.16300 | G_loss: 6.69965
Epoch:    20 | D_loss: 0.26321 | G_loss: 15.49264
Epoch:    25 | D_loss: 0.34067 | G_loss: 4.51578
Epoch:    30 | D_loss: 0.04816 | G_loss: 5.01847
Epoch:    35 | D_loss: 0.22363 | G_loss: 4.25126
Epoch:    40 | D_loss: 0.29632 | G_loss: 2.65158
Epoch:    45 | D_loss: 0.25267 | G_loss: 2.96413
Epoch:    50 | D_loss: 0.16989 | G_loss: 3.37862
Epoch:    55 | D_loss: 0.11424 | G_loss: 3.09119
Epoch:    60 | D_loss: 0.06904 | G_loss: 3.04672
Epoch:    65 | D_loss: 0.47720 | G_loss: 1.61693
Epoch:    70 | D_loss: 0.18788 | G_loss: 2.61418
Epoch:    75 | D_loss: 0.67085 | G_loss: 1.67277
Epoch:    80 | D_loss: 0.18359 | G_loss: 3.22647
Epoch:    85 | D_loss: 0.39887 | G_loss: 1.61781
Epoch:    90 | D_loss: 0.43340 | G_loss: 1.55041
Epoch:    95 | D_loss: 0.34658 | G_loss: 2.34088
Epoch:   100 | D_