In [1]:
# !pip install gdown jcopdl
# !gdown https://drive.google.com/uc?id=12DT5Px7FQV7gZEcygWvKb5aZQw2ZprSP
# !unzip /content/mnist.zip

In [2]:
import torch
from torch import nn, optim
from jcopdl.callback import Callback, set_config

# Use the first GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda', index=0)

In [3]:
 
# Display whether PyTorch can see a CUDA device in this runtime.
bool(torch.cuda.is_available())

True

# Dataset dan Dataloader (hanya train)

In [4]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [5]:
bs = 64

# Force a single channel, convert to a [0, 1] tensor, then rescale to [-1, 1]
# so real images share the range of the generator's final tanh layer.
transform = transforms.Compose(
    [transforms.Grayscale(), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
)

# One sub-folder per class under data/train; the labels are ignored during GAN training.
train_set = datasets.ImageFolder(root="data/train", transform=transform)
trainloader = DataLoader(dataset=train_set, batch_size=bs, shuffle=True)

# Arsitektur dan Config

### Sumber: [paper WGAN](https://arxiv.org/pdf/1701.07875.pdf)
- Critic diakhiri aktivasi Linear, bukan Sigmoid:<br>
  `linear_block(128, 1, activation=None)`
- Diadopsi dari teori optimal transport, menggunakan Wasserstein loss:<br>
  ```python
  def wasserstein_loss(output, target):
      return output.mean() * target.mean()
  ```
- Note: Fake = +1 | Real = -1
- Optimizer berbasis momentum (misalnya Adam) kadang membuat GAN tidak stabil
- Menggunakan RMSProp dengan lr yang kecil (misalnya 5e-5):<br>
  ```python
  d_optimizer = optim.RMSprop(D.parameters(), lr=5e-5)
  g_optimizer = optim.RMSprop(G.parameters(), lr=5e-5)
  ```
- Weight pada critic dibatasi (clipping), misalnya ke rentang [-0.01, 0.01]:<br>
  ```python
  def clip_weights(self, vmin=-0.01, vmax=0.01):
      for p in self.parameters():
          p.data.clamp_(vmin, vmax)
  ```
- Critic di-train lebih sering daripada generator:<br>
  ```python
  if n_batch % 5 == 0:
      # train generator
  ```

In [6]:
from jcopdl.layers import linear_block

In [7]:
%%writefile model_wgan.py
import torch
from torch import nn
from jcopdl.layers import linear_block

class Critic(nn.Module):
    """WGAN critic: maps a flattened 28x28 (784-pixel) image to a real-valued score.

    Unlike a vanilla GAN discriminator, the last layer is a plain ``nn.Linear``
    with no sigmoid, because the Wasserstein loss works on raw, unbounded scores.
    """

    def __init__(self):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Flatten(),
            linear_block(784, 512, activation='lrelu'),
            linear_block(512, 256, activation='lrelu'),
            linear_block(256, 128, activation='lrelu'),
            # Plain linear output (no activation): the critic returns raw scores.
            nn.Linear(128, 1)
        )

    def forward(self, x):
        """Return critic scores of shape (batch, 1) for a batch of images."""
        return self.fc(x)

    def clip_weights(self, vmin=-0.01, vmax=0.01):
        """Clamp every parameter of the critic in place to ``[vmin, vmax]``.

        This is the WGAN weight-clipping step; call it after each critic
        optimizer step.

        Args:
            vmin: lower bound for every weight and bias.
            vmax: upper bound for every weight and bias.
        """
        # Mutate parameters under no_grad instead of through the discouraged
        # ``.data`` attribute, so autograd's in-place version tracking stays intact.
        with torch.no_grad():
            for p in self.parameters():
                p.clamp_(vmin, vmax)

class Generator(nn.Module):
    """MLP generator mapping a ``z_dim``-dimensional noise vector to a flattened image.

    The last layer is ``linear_block(1024, 784, activation='tanh')``, i.e. a
    784-value (28x28) output passed through tanh.
    """

    def __init__(self, z_dim):
        super().__init__()
        self.z_dim = z_dim
        self.fc = nn.Sequential(
            linear_block(self.z_dim, 128, activation='lrelu'),
            linear_block(128, 256, activation='lrelu', batch_norm=True),
            linear_block(256, 512, activation='lrelu', batch_norm=True),
            linear_block(512, 1024, activation='lrelu', batch_norm=True),
            linear_block(1024, 784, activation='tanh')
        )

    def forward(self, x):
        """Map a batch of noise vectors of shape (batch, z_dim) through the MLP."""
        return self.fc(x)

    def generate(self, n, device=None):
        """Sample ``n`` outputs from standard-normal noise.

        Args:
            n: number of samples to draw.
            device: device on which to create the noise. Defaults to the device
                of the model's parameters, so callers no longer have to pass it.

        Returns:
            The output of ``forward`` for a fresh ``(n, z_dim)`` noise batch.
            Autograd is left enabled; wrap the call in ``torch.no_grad()`` when
            only sampling.
        """
        if device is None:
            device = next(self.parameters()).device
        z = torch.randn((n, self.z_dim), device=device)
        # Call through the module (not self.fc directly) so forward hooks and
        # any forward() override also apply when sampling.
        return self(z)

Overwriting model_wgan.py


In [8]:
# Hyperparameters wrapped by jcopdl's set_config; later cells read `config.z_dim`.
# `bs` is the batch size chosen in the dataloader cell.
config = set_config(dict(z_dim=100, batch_size=bs))

# Training Preparation -> MCO (Model, Criterion, Optimizer)

In [9]:
from model_wgan import Critic, Generator

In [10]:
def wasserstein_loss(output, target):
    """Wasserstein loss: the mean of ``output * target``.

    In this notebook the targets are real = -1 and fake = +1, so minimizing the
    loss pushes critic scores of real samples up and those of fake samples down.

    Args:
        output: critic scores, e.g. of shape (batch, 1).
        target: label tensor broadcastable to ``output``.

    Returns:
        A scalar (0-dim) tensor.
    """
    # Multiply element-wise before averaging. This equals
    # ``output.mean() * target.mean()`` when every target in the batch is the
    # same value, but is also correct when a batch mixes real and fake labels.
    return (output * target).mean()

In [11]:
# Model
D = Critic().to(device)
G = Generator(config.z_dim).to(device)

# Criterion: Wasserstein loss (the vanilla GAN version used nn.BCELoss()).
criterion = wasserstein_loss

# Optimizer: RMSprop with a small lr for both networks (the vanilla GAN version
# used Adam with lr=0.0002; momentum-based optimizers can destabilize WGAN).
d_optimizer, g_optimizer = (
    optim.RMSprop(net.parameters(), lr=5e-5) for net in (D, G)
)

# Training

In [12]:
import os
from torchvision.utils import save_image
# Folders for generated sample grids and model checkpoints; no error if they already exist.
for _dir in ("output/WGAN/", "model/WGAN/"):
    os.makedirs(_dir, exist_ok=True)

In [13]:
max_epochs = 1000
n_critic = 5  # critic updates per generator update (see the WGAN notes above)

for epoch in range(max_epochs):
    D.train()
    G.train()
    for i, (real_img, _) in enumerate(trainloader):
        n_data = real_img.shape[0]

        # Real and fake images
        real_img = real_img.to(device)
        fake_image = G.generate(n_data, device)
        # Real and fake labels (WGAN convention here: real = -1, fake = +1)
        real = - torch.ones((n_data, 1), device=device)
        fake = torch.ones((n_data, 1), device=device)

        # Train the critic
        d_optimizer.zero_grad()
        ## Real image -> critic -> real label
        output = D(real_img)
        d_real_loss = criterion(output, real)
        ## Fake image -> critic -> fake label; detached so no gradient reaches G
        output = D(fake_image.detach())
        d_fake_loss = criterion(output, fake)

        d_loss = d_real_loss + d_fake_loss
        d_loss.backward()
        d_optimizer.step()
        # WGAN weight clipping, applied after every critic step
        D.clip_weights()

        # Train the generator only once every `n_critic` critic steps
        if i % n_critic == 0:
            g_optimizer.zero_grad()
            ## Fake image -> critic -> but with the real label
            output = D(fake_image)
            g_loss = criterion(output, real)
            g_loss.backward()
            g_optimizer.step()

    # The losses printed are those of the last batch of the epoch.
    if epoch % 5 == 0:
        print(f"Epoch {epoch:5} : | D_loss : {d_loss/2:5f} | G_loss : {g_loss:5f}")
    if epoch % 15 == 0:
        G.eval()
        # Keep `epoch` an int; build the zero-padded file tag in its own variable.
        tag = str(epoch).zfill(4)
        # Sampling only: no autograd graph is needed for the preview grid.
        with torch.no_grad():
            fake_image = G.generate(64, device=device)
        save_image(fake_image.view(-1, 1, 28, 28), f"output/WGAN/{tag}.jpg", nrow=8, normalize=True)

        torch.save(D, "model/WGAN/discriminator.pth")
        torch.save(G, "model/WGAN/generator.pth")
        
        

Epoch     0 : | D_loss : -0.017161 | G_loss : -0.020237
Epoch     5 : | D_loss : -0.031785 | G_loss : -1.105294
Epoch    10 : | D_loss : -0.057694 | G_loss : -1.850849
Epoch    15 : | D_loss : -0.062357 | G_loss : -1.666635
Epoch    20 : | D_loss : -0.054528 | G_loss : -1.344638
Epoch    25 : | D_loss : -0.067599 | G_loss : -1.021007
Epoch    30 : | D_loss : -0.086399 | G_loss : -0.571073
Epoch    35 : | D_loss : -0.035489 | G_loss : -0.541371
Epoch    40 : | D_loss : -0.057620 | G_loss : -0.384826
Epoch    45 : | D_loss : -0.039812 | G_loss : -0.433597
Epoch    50 : | D_loss : -0.063830 | G_loss : -0.179912
Epoch    55 : | D_loss : -0.078385 | G_loss : -0.031712
Epoch    60 : | D_loss : -0.085959 | G_loss : 0.373750
Epoch    65 : | D_loss : -0.049251 | G_loss : -0.251536
Epoch    70 : | D_loss : -0.062182 | G_loss : -0.497509
Epoch    75 : | D_loss : -0.081720 | G_loss : -0.768480
Epoch    80 : | D_loss : -0.075976 | G_loss : -0.708248
Epoch    85 : | D_loss : -0.107046 | G_loss : -0.