<a href="https://colab.research.google.com/github/stevengregori92/GenerativeAdversarialNetwork/blob/main/CGAN_in_PyTorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install 'jcopdl<2.0'



In [2]:
!gdown https://drive.google.com/uc?id=1x4HUS6yQYnrEmyKRsIrGfag0wZlL3LDx

Traceback (most recent call last):
  File "/usr/local/bin/gdown", line 8, in <module>
    sys.exit(main())
  File "/usr/local/lib/python3.10/dist-packages/gdown/cli.py", line 151, in main
    filename = download(
  File "/usr/local/lib/python3.10/dist-packages/gdown/download.py", line 155, in download
    res = sess.get(url, stream=True, verify=verify)
  File "/usr/local/lib/python3.10/dist-packages/requests/sessions.py", line 542, in get
    return self.request('GET', url, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/requests/sessions.py", line 529, in request
    resp = self.send(prep, **send_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/requests/sessions.py", line 645, in send
    r = adapter.send(request, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/requests/adapters.py", line 440, in send
    resp = conn.urlopen(
  File "/usr/local/lib/python3.10/dist-packages/urllib3/connectionpool.py", line 714, in urlopen
    httplib_response = self._make_req

In [3]:
!unzip /content/mnist.zip

Archive:  /content/mnist.zip
replace data/.DS_Store? [y]es, [n]o, [A]ll, [N]one, [r]ename: 

In [4]:
import torch
from torch import nn, optim
from jcopdl.callback import Callback, set_config
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

#Dataset & Dataloader

In [5]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [6]:
bs=64

transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]) #menjadi (-1, 1) agar stabil
])

train_set = datasets.ImageFolder('data/train/', transform=transform)
trainloader = DataLoader(train_set, batch_size=bs, shuffle = True, num_workers=2)

#Arsitektur & Config

In [7]:
%%writefile model_cgan.py

import torch
from torch import nn, optim
from jcopdl.layers import linear_block

class Discriminator(nn.Module):
  def __init__(self, n_classes):
    super().__init__()
    self.flatten = nn.Flatten()
    self.embed_label = nn.Embedding(n_classes, n_classes)
    self.fc = nn.Sequential(
        linear_block(784 + n_classes, 512, activation='lrelu'),
        linear_block(512, 256, activation='lrelu'),
        linear_block(256, 128, activation='lrelu'),
        linear_block(128, 1, activation='sigmoid')
    )

  def forward(self, x, y):
    x = self.flatten(x)
    y = self.embed_label(y)

    x = torch.cat([x,y], dim=1)
    return self.fc(x)

class Generator(nn.Module):
  def __init__(self, z_dim, n_classes):
    super().__init__()
    self.z_dim = z_dim
    self.embed_label = nn.Embedding(n_classes, n_classes)
    self.fc = nn.Sequential(
        nn.Flatten(),
        linear_block(z_dim + n_classes, 128, activation='lrelu'),
        linear_block(128, 256, activation='lrelu', batch_norm=True),
        linear_block(256, 512, activation='lrelu', batch_norm=True),
        linear_block(512, 1024, activation='lrelu', batch_norm=True),
        linear_block(1024, 784, activation='tanh')
    )

  def forward(self, x, y):
    y = self.embed_label(y)
    x = torch.cat([x,y], dim=1)
    return self.fc(x)

  def generate(self, labels, device):
    z = torch.randn((len(labels), self.z_dim), device=device)
    return self.forward(z, labels)

Overwriting model_cgan.py


In [8]:
config = set_config({
    'z_dim' : 100,
    'n_classes': len(train_set.classes),
    'batch_size': bs
})

#Training Preparation

In [9]:
from model_cgan import Discriminator, Generator

In [10]:
D = Discriminator(config.n_classes).to(device)
G = Generator(config.z_dim, config.n_classes).to(device)

criterion = nn.BCELoss()
d_optimizer = optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = optim.Adam(G.parameters(), lr=0.0002)

#Training

In [17]:
import os
from torchvision.utils import save_image

os.makedirs('output/CGAN/', exist_ok=True)
os.makedirs('model/CGAN/', exist_ok=True)

In [18]:
max_epochs = 500
fix_labels = torch.randint(10, (64,), device=device)
for epoch in range(max_epochs):
  D.train()
  G.train()
  for real_img, labels in trainloader:
    n_data = real_img.shape[0]

    # Real and Fake Images
    real_img, labels = real_img.to(device), labels.to(device)
    fake_img = G.generate(labels, device)

    # Real and Fake Labels
    real = torch.ones((n_data, 1), device=device)
    fake = torch.zeros((n_data, 1), device=device)

    ## Training Discriminator
    d_optimizer.zero_grad()
    # Real Image -> Discriminator -> Real Label
    output = D(real_img, labels)
    d_real_loss = criterion(output, real)

    #Fake Image -> Discriminator -> Fake Label
    output = D(fake_img.detach(), labels)
    d_fake_loss = criterion(output, fake)

    d_loss = d_real_loss + d_fake_loss
    d_loss.backward()
    d_optimizer.step()

    ## Training Generator
    g_optimizer.zero_grad()
    #Fake image -> Discriminator -> Tapi label Real
    output = D(fake_img, labels)
    g_loss = criterion(output, real)
    g_loss.backward()
    g_optimizer.step()

  if epoch % 5 == 0:
    print(f'Epoch: {epoch:5} | D_loss: {d_loss/2:.5f} | G_loss: {g_loss:.5f}')

  if epoch % 15 == 0:
    G.eval()
    epoch = str(epoch).zfill(4)
    fake_img = G.generate(fix_labels, device)
    save_image(fake_img.view(-1,1,28,28), f'output/CGAN/{epoch}.jpg', nrow=8, normalize=True)

    torch.save(D, 'model/CGAN/discriminator.pth')
    torch.save(G, 'model/CGAN/generator.pth')

Epoch:     0 | D_loss: 0.00113 | G_loss: 20.71733
Epoch:     5 | D_loss: 0.27093 | G_loss: 5.02341
Epoch:    10 | D_loss: 0.27929 | G_loss: 8.95749
Epoch:    15 | D_loss: 0.27190 | G_loss: 3.34692
Epoch:    20 | D_loss: 0.08922 | G_loss: 8.21469
Epoch:    25 | D_loss: 0.36645 | G_loss: 3.86576
Epoch:    30 | D_loss: 0.16714 | G_loss: 2.41342
Epoch:    35 | D_loss: 0.20032 | G_loss: 3.17189
Epoch:    40 | D_loss: 0.10784 | G_loss: 4.08948
Epoch:    45 | D_loss: 0.42051 | G_loss: 2.87173
Epoch:    50 | D_loss: 0.48815 | G_loss: 2.71894
Epoch:    55 | D_loss: 0.18478 | G_loss: 2.09532
Epoch:    60 | D_loss: 0.65884 | G_loss: 1.38047
Epoch:    65 | D_loss: 0.39648 | G_loss: 1.92454
Epoch:    70 | D_loss: 0.20311 | G_loss: 3.15567
Epoch:    75 | D_loss: 0.22001 | G_loss: 2.84710
Epoch:    80 | D_loss: 0.29085 | G_loss: 2.12982
Epoch:    85 | D_loss: 0.52415 | G_loss: 5.04550
Epoch:    90 | D_loss: 0.13432 | G_loss: 4.03024
Epoch:    95 | D_loss: 0.21035 | G_loss: 3.18914
Epoch:   100 | D_lo

In [19]:
!zip -r model.zip /content/model

  adding: content/model/ (stored 0%)
  adding: content/model/GAN/ (stored 0%)
  adding: content/model/CGAN/ (stored 0%)
  adding: content/model/CGAN/generator.pth (deflated 8%)
  adding: content/model/CGAN/discriminator.pth (deflated 7%)
