In [1]:
# !pip install jcopdl gdown
# !gdown https://drive.google.com/uc?id=1KaiwyyYRGW8FbvSd4Feg1i1YW2k2s30u
# !unzip /content/celebA_redux.zip

In [2]:
import torch
from torch import nn, optim
from jcopdl.callback import Callback, set_config

# Seed the global RNGs (torch.manual_seed seeds CPU and all CUDA devices) so that
# weight init, latent noise and DataLoader shuffling are repeatable across runs.
# Some cuDNN kernels may still be non-deterministic.
SEED = 42
torch.manual_seed(SEED)

# Train on the first GPU when available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

# Dataset & Dataloader (Train Set Only)

In [3]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [4]:
bs = 64

# ToTensor gives values in [0, 1]; normalizing every RGB channel with mean 0.5 and
# std 0.5 rescales them to [-1, 1], the same range as the generator's tanh output.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])

# ImageFolder expects class sub-folders under the root; the class labels are ignored
# during GAN training.
# NOTE(review): there is no Resize/CenterCrop here, so the images are presumably
# already 64x64 (the size the critic/generator stacks work out to) — confirm.
train_set = datasets.ImageFolder("celebA_redux/celebA_redux/", transform=transform)
trainloader = DataLoader(dataset=train_set, batch_size=bs, shuffle=True, num_workers=4)

# Architecture & Config

In [5]:
# Display the installed jcopdl version; the model module below builds its layers
# with jcopdl.layers, so the layer defaults depend on this version.
import jcopdl
jcopdl.__version__

'1.1.10'

In [6]:
# NOTE(review): the previous cell already reports jcopdl 1.1.10 and the next cell
# reinstalls exactly that version, so this uninstall/reinstall pair looks redundant.
!pip uninstall jcopdl -y

Found existing installation: jcopdl 1.1.10
Uninstalling jcopdl-1.1.10:
  Successfully uninstalled jcopdl-1.1.10


In [7]:
# Pin jcopdl to the version this notebook was written against (1.1.10).
# NOTE(review): `--upgrade` is redundant with an exact `==` pin, and `%pip` targets
# the running kernel's environment more reliably than `!pip`.
!pip install --upgrade jcopdl==1.1.10

Collecting jcopdl==1.1.10
  Using cached jcopdl-1.1.10-py2.py3-none-any.whl
Installing collected packages: jcopdl
Successfully installed jcopdl-1.1.10


In [8]:
%%writefile model_wdcgan.py
import torch
from torch import nn
from jcopdl.layers import conv_block, tconv_block, linear_block

def conv(c_in, c_out, batch_norm=True, activation="lrelu"):
    """Build a downsampling ``conv_block``: 4x4 kernel, stride 2, padding 1, no bias, no pooling.

    Args:
        c_in: number of input channels.
        c_out: number of output channels.
        batch_norm: forwarded to ``conv_block``.
        activation: activation name forwarded to ``conv_block``.
    """
    geometry = {"kernel": 4, "stride": 2, "pad": 1}
    return conv_block(c_in, c_out, **geometry, bias=False, batch_norm=batch_norm,
                      activation=activation, pool_type=None)

def tconv(c_in, c_out, batch_norm=True, activation="lrelu"):
    """Build an upsampling ``tconv_block``: 4x4 kernel, stride 2, padding 1, no bias, no pooling.

    Args:
        c_in: number of input channels.
        c_out: number of output channels.
        batch_norm: forwarded to ``tconv_block``.
        activation: activation name forwarded to ``tconv_block``.
    """
    geometry = {"kernel": 4, "stride": 2, "pad": 1}
    return tconv_block(c_in, c_out, **geometry, bias=False, batch_norm=batch_norm,
                       activation=activation, pool_type=None)


class Critic(nn.Module):
    """WGAN critic: maps a batch of RGB images to one unbounded score per image.

    Four stride-2 conv blocks (3 -> 32 -> 64 -> 128 -> 256 channels) are followed by a
    4x4, stride-1, unpadded conv to a single channel, then flattened. Assuming
    ``conv_block`` wraps a plain ``nn.Conv2d`` with the given kernel/stride/pad
    (TODO confirm), a 64x64 input shrinks 64 -> 32 -> 16 -> 8 -> 4 -> 1, giving (N, 1).
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            conv(3, 32, batch_norm=False),  # no normalization on the raw-image input layer
            conv(32, 64),
            conv(64, 128),
            conv(128, 256),
            # Score head: explicitly no batch norm and no activation. The critic score
            # must be a raw per-sample value; batch-normalizing a 1-channel output would
            # make every score relative to the rest of its batch. Passing batch_norm=False
            # also avoids depending on conv_block's default.
            conv_block(256, 1, kernel=4, stride=1, pad=0, bias=False, batch_norm=False,
                       activation=None, pool_type=None),
            nn.Flatten(),
        )

    def forward(self, x):
        """Return critic scores for image batch ``x``, flattened to (N, C*H*W)."""
        return self.conv(x)

    def clip_weights(self, vmin=-0.01, vmax=0.01):
        """Clamp every parameter in place to ``[vmin, vmax]`` (WGAN weight clipping).

        Iterates over all of ``self.parameters()``, so any normalization affine
        parameters are clamped as well.
        """
        # In-place update under no_grad instead of mutating `.data`: the modern,
        # autograd-safe way to modify parameters outside the graph.
        with torch.no_grad():
            for p in self.parameters():
                p.clamp_(vmin, vmax)


class Generator(nn.Module):
    """DCGAN-style generator mapping latent noise of shape (N, z_dim, 1, 1) to RGB images.

    Six transposed-conv blocks use kernel 4, stride 2, padding 1. Assuming
    ``tconv_block`` wraps ``nn.ConvTranspose2d`` with those arguments (TODO confirm),
    each block doubles the spatial size: 1 -> 2 -> 4 -> 8 -> 16 -> 32 -> 64. That gives
    (N, 3, 64, 64) outputs, squashed by the final "tanh" activation.

    Args:
        z_dim: number of latent channels.
    """

    def __init__(self, z_dim):
        super().__init__()
        self.z_dim = z_dim
        self.tconv = nn.Sequential(
            tconv_block(z_dim, 512, kernel=4, stride=2, pad=1, bias=False, activation="lrelu", pool_type=None),
            tconv(512, 256),
            tconv(256, 128),
            tconv(128, 64),
            tconv(64, 32),
            tconv(32, 3, activation="tanh", batch_norm=False),
        )

    def forward(self, x):
        """Decode latent tensor ``x`` into images."""
        return self.tconv(x)

    def generate(self, n, device=None):
        """Sample ``n`` images from standard-normal latent noise.

        Args:
            n: number of images to generate.
            device: device on which to draw the noise. Defaults to the device of the
                model's parameters (CPU if the model has none).

        Returns:
            ``self(z)`` for ``z ~ N(0, I)`` of shape (n, z_dim, 1, 1).
        """
        if device is None:
            first_param = next(self.parameters(), None)
            device = first_param.device if first_param is not None else torch.device("cpu")
        z = torch.randn((n, self.z_dim, 1, 1), device=device)
        # Go through __call__ (not self.tconv directly) so forward hooks apply.
        return self(z)

Overwriting model_wdcgan.py


In [9]:
# Hyperparameters wrapped by jcopdl's set_config; later cells read them as attributes
# (e.g. config.z_dim).
config = set_config(dict(z_dim=100, batch_size=bs))

# Training Preparation (MCOC: Model, Criterion, Optimizer, Callback)

In [10]:
from model_wdcgan import Critic, Generator

In [11]:
def wasserstein_loss(output, target):
    """Wasserstein (WGAN) loss: mean of critic scores weighted element-wise by labels.

    Args:
        output: critic scores, e.g. shape (N, 1).
        target: labels broadcastable to ``output``. With -1 the loss is -mean(score),
            so minimizing it pushes scores up; with +1 it pushes scores down.

    Returns:
        Scalar tensor ``mean(output * target)``.
    """
    # Element-wise product rather than mean(output) * mean(target): the two agree only
    # when every label in the batch is the same; this form is also correct for mixed labels.
    return (output * target).mean()

In [12]:
# Critic and generator, both moved to the training device. The critic is created
# first (same order as before, so parameter init consumes the RNG identically).
D = Critic().to(device)
G = Generator(config.z_dim).to(device)

criterion = wasserstein_loss

# One RMSprop optimizer per network, both with lr=1e-4.
d_optimizer, g_optimizer = (optim.RMSprop(net.parameters(), lr=1e-4) for net in (D, G))

# Training

In [13]:
# !rm -rf /content/output

In [14]:
import os
from torchvision.utils import save_image
from tqdm.auto import tqdm

# Sample grids go to output/, checkpoints to model/; exist_ok makes the cell re-runnable.
for out_dir in ("output/WDCGAN/", "model/WDCGAN/"):
    os.makedirs(out_dir, exist_ok=True)

In [15]:
# WGAN training with weight clipping: the critic is updated on every batch, the
# generator on every `n_critic`-th batch. Sample grids and checkpoints are written
# every 15 epochs.
max_epochs = 1000
n_critic = 5  # critic updates per generator update

for epoch in range(max_epochs):
    D.train()
    G.train()
    for i, (real_img, _) in enumerate(trainloader):
        n_data = real_img.shape[0]
        update_generator = i % n_critic == 0

        ## Real and Fake Images
        real_img = real_img.to(device)
        # Record the generator's autograd graph only on iterations that backprop into G.
        # Critic-only steps detach the fakes anyway, so this saves memory/compute while
        # drawing exactly the same noise.
        with torch.set_grad_enabled(update_generator):
            fake_img = G.generate(n_data, device)

        ## Real and Fake Labels (-1 = real, +1 = fake)
        real = -torch.ones((n_data, 1), device=device)
        fake = torch.ones((n_data, 1), device=device)

        ## Training Critic: d_loss = mean(D(fake)) - mean(D(real))
        d_optimizer.zero_grad()
        d_real_loss = criterion(D(real_img), real)
        d_fake_loss = criterion(D(fake_img.detach()), fake)
        d_loss = d_real_loss + d_fake_loss
        d_loss.backward()
        d_optimizer.step()

        # Clamp critic parameters after each step (WGAN weight clipping)
        D.clip_weights()

        if update_generator:
            ## Training Generator: fakes are scored against the "real" label
            g_optimizer.zero_grad()
            g_loss = criterion(D(fake_img), real)
            g_loss.backward()
            g_optimizer.step()

    if epoch % 5 == 0:
        # Losses of the last batch of the epoch. g_loss is always defined because
        # batch 0 updates the generator.
        print(f"Epoch: {epoch:5} | D_loss: {d_loss.item() / 2:.5f} | G_loss: {g_loss.item():.5f}")

    if epoch % 15 == 0:
        G.eval()
        tag = str(epoch).zfill(4)  # separate name: don't shadow the integer loop variable
        # Sampling for visualization needs no autograd graph
        with torch.no_grad():
            sample_img = G.generate(64, device)
        save_image(sample_img, f"output/WDCGAN/{tag}.jpg", nrow=8, normalize=True)

        # NOTE(review): saving whole modules pickles them by import path (model_wdcgan);
        # state_dict() checkpoints are more portable. Kept as-is so existing loaders work.
        torch.save(D, "model/WDCGAN/critic.pth")
        torch.save(G, "model/WDCGAN/generator.pth")

Epoch:     0 | D_loss: -0.19008 | G_loss: 0.20854
Epoch:     5 | D_loss: -0.25356 | G_loss: 0.29609
Epoch:    10 | D_loss: -0.27076 | G_loss: 0.31270
Epoch:    15 | D_loss: -0.22399 | G_loss: 0.22184
Epoch:    20 | D_loss: -0.16483 | G_loss: 0.10128
Epoch:    25 | D_loss: -0.20062 | G_loss: 0.26080
Epoch:    30 | D_loss: -0.18804 | G_loss: 0.26900
Epoch:    35 | D_loss: -0.14951 | G_loss: 0.09250
Epoch:    40 | D_loss: -0.11740 | G_loss: 0.08126
Epoch:    45 | D_loss: -0.15283 | G_loss: 0.09500
Epoch:    50 | D_loss: -0.12296 | G_loss: 0.22264
Epoch:    55 | D_loss: -0.10538 | G_loss: 0.03324
Epoch:    60 | D_loss: -0.11213 | G_loss: 0.20821
Epoch:    65 | D_loss: -0.11334 | G_loss: 0.04439
Epoch:    70 | D_loss: -0.11902 | G_loss: 0.06656
Epoch:    75 | D_loss: -0.09465 | G_loss: 0.04598
Epoch:    80 | D_loss: -0.12831 | G_loss: 0.02047
Epoch:    85 | D_loss: -0.10817 | G_loss: 0.04513
Epoch:    90 | D_loss: -0.10672 | G_loss: 0.12306
Epoch:    95 | D_loss: -0.07999 | G_loss: 0.18913


KeyboardInterrupt: 