In [2]:
# !pip install gdown jcopdl
# !gdown https://drive.google.com/uc?id=12DT5Px7FQV7gZEcygWvKb5aZQw2ZprSP
# !unzip /content/mnist.zip

In [1]:
import torch
from torch import nn, optim
from jcopdl.callback import Callback, set_config

# Use the first CUDA GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed PyTorch's RNGs (all devices) before any model or DataLoader is built, so that
# weight init, batch shuffling and latent sampling are repeatable across runs.
SEED = 42
torch.manual_seed(SEED)

device

device(type='cuda', index=0)

In [2]:
 
# Sanity check: True means PyTorch can see a CUDA GPU, so `device` above resolves to cuda:0.
torch.cuda.is_available()

True

# Dataset dan Dataloader (hanya train)

In [3]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [4]:
bs = 64

# Grayscale -> ToTensor ([0, 1]) -> Normalize(mean=0.5, std=0.5) maps pixels to [-1, 1],
# the same range as the generator's final tanh activation.
transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

train_set = datasets.ImageFolder("data/train", transform=transform)
# drop_last=True: the Generator uses BatchNorm on batches the same size as the real batch,
# and BatchNorm raises in training mode on a batch of size 1. A ragged final batch would
# also skew the batch statistics. Shuffling each epoch means dropped samples still get seen.
trainloader = DataLoader(train_set, batch_size=bs, shuffle=True, drop_last=True)

In [4]:
%%writefile model_gan.py
import torch
from torch import nn
from jcopdl.layers import linear_block

class Discriminator(nn.Module):
    """MLP discriminator: flattens an image to 784 values and scores it with a sigmoid.

    Hidden stack is 784 -> 512 -> 256 -> 128 with leaky-ReLU blocks, followed by a
    128 -> 1 block with sigmoid activation (so the output is suited to nn.BCELoss).
    """
    def __init__(self):
        super().__init__()
        widths = [784, 512, 256, 128]
        hidden_blocks = [
            linear_block(n_in, n_out, activation='lrelu')
            for n_in, n_out in zip(widths[:-1], widths[1:])
        ]
        head = linear_block(128, 1, activation='sigmoid')
        # Layer order (Flatten, 3 hidden blocks, head) matches the original Sequential indices.
        self.fc = nn.Sequential(nn.Flatten(), *hidden_blocks, head)

    def forward(self, x):
        out = self.fc(x)
        return out

class Generator(nn.Module):
    """MLP generator mapping a latent vector of size ``z_dim`` to 784 output values.

    The last block is 1024 -> 784 with tanh activation. The hidden blocks use BatchNorm,
    so in train mode every batch must contain more than one sample.
    """
    def __init__(self, z_dim):
        super().__init__()
        self.z_dim = z_dim
        self.fc = nn.Sequential(
            linear_block(self.z_dim, 128, activation='lrelu'),
            linear_block(128, 256, activation='lrelu', batch_norm=True),
            linear_block(256, 512, activation='lrelu', batch_norm=True),
            linear_block(512, 1024, activation='lrelu', batch_norm=True),
            linear_block(1024, 784, activation='tanh')
        )

    def forward(self, x):
        return self.fc(x)

    def generate(self, n, device=None):
        """Sample ``n`` latent vectors z ~ N(0, I) and return G(z).

        Args:
            n: number of samples to generate.
            device: where to create z; defaults to the device of the model's parameters.

        Gradients are not disabled here because the training loop backpropagates through
        the result. Wrap the call in ``torch.no_grad()`` for inference-only sampling.
        """
        if device is None:
            device = next(self.parameters()).device
        z = torch.randn((n, self.z_dim), device=device)
        # Go through __call__ (not self.fc) so forward hooks and overrides apply.
        return self(z)

Overwriting model_gan.py


In [5]:
# Run hyperparameters: latent size for the Generator and the DataLoader batch size.
config = set_config(dict(
    z_dim=100,
    batch_size=bs,
))

# Training Preparation -> MCO (Model, Criterion, Optimizer)

In [6]:
from model_gan import Discriminator, Generator

In [7]:
# Model: both networks live on the selected device (D is built first, then G).
D = Discriminator().to(device)
G = Generator(config.z_dim).to(device)

# Criterion: binary cross-entropy on the discriminator's sigmoid score.
criterion = nn.BCELoss()

# Optimizers: a separate Adam instance per network, sharing one learning rate.
lr = 0.0002
d_optimizer = optim.Adam(D.parameters(), lr=lr)
g_optimizer = optim.Adam(G.parameters(), lr=lr)

# Training

In [8]:
import os
from torchvision.utils import save_image
# Create the sample-image and checkpoint folders (no error if they already exist).
for target_dir in ("output/GAN/", "model/GAN/"):
    os.makedirs(target_dir, exist_ok=True)

In [9]:
max_epochs = 1000
# Fixed latent batch so the saved sample grids are comparable from one snapshot to the next.
fixed_z = torch.randn((64, config.z_dim), device=device)

for epoch in range(max_epochs):
    D.train()
    G.train()
    for real_img, _ in trainloader:
        n_data = real_img.shape[0]

        # Real and fake images
        real_img = real_img.to(device)
        fake_image = G.generate(n_data, device)
        # Real and fake target labels
        real = torch.ones((n_data, 1), device=device)
        fake = torch.zeros((n_data, 1), device=device)

        # --- Discriminator step ---
        d_optimizer.zero_grad()
        # Real image -> D -> real label
        d_real_loss = criterion(D(real_img), real)
        # Fake image -> D -> fake label; detach so this step does not backprop into G
        d_fake_loss = criterion(D(fake_image.detach()), fake)
        d_loss = d_real_loss + d_fake_loss
        d_loss.backward()
        d_optimizer.step()

        # --- Generator step ---
        g_optimizer.zero_grad()
        # Fake image -> D, scored against the *real* label (G tries to fool D).
        # Gradients also land in D's params; d_optimizer.zero_grad() clears them next batch.
        g_loss = criterion(D(fake_image), real)
        g_loss.backward()
        g_optimizer.step()

    # Always report/save on the final epoch too, otherwise the last epochs
    # after the most recent multiple of 15 would never be checkpointed.
    is_last_epoch = epoch == max_epochs - 1

    if epoch % 5 == 0 or is_last_epoch:
        # Losses are from the last batch of the epoch; D_loss is averaged over its real/fake terms.
        print(f"Epoch {epoch:5} : | D_loss : {d_loss.item() / 2:5f} | G_loss : {g_loss.item():5f}")

    if epoch % 15 == 0 or is_last_epoch:
        G.eval()
        # Sampling only: no autograd graph needed.
        with torch.no_grad():
            samples = G(fixed_z)
        save_image(samples.view(-1, 1, 28, 28), f"output/GAN/{epoch:04d}.jpg", nrow=8, normalize=True)

        torch.save(D, "model/GAN/discriminator.pth")
        torch.save(G, "model/GAN/generator.pth")
        
        

Epoch     0 : | D_loss : 0.620034 | G_loss : 0.893318
Epoch     5 : | D_loss : 0.029480 | G_loss : 15.326870
Epoch    10 : | D_loss : 0.000914 | G_loss : 10.692772
Epoch    15 : | D_loss : 0.230609 | G_loss : 11.841581
Epoch    20 : | D_loss : 0.020662 | G_loss : 11.104354
Epoch    25 : | D_loss : 0.115729 | G_loss : 9.415239
Epoch    30 : | D_loss : 0.335380 | G_loss : 9.909477
Epoch    35 : | D_loss : 0.063014 | G_loss : 5.366125
Epoch    40 : | D_loss : 0.136325 | G_loss : 4.241593
Epoch    45 : | D_loss : 0.245341 | G_loss : 4.268900
Epoch    50 : | D_loss : 0.272129 | G_loss : 4.680780
Epoch    55 : | D_loss : 0.258448 | G_loss : 3.385092
Epoch    60 : | D_loss : 0.226616 | G_loss : 3.432909
Epoch    65 : | D_loss : 0.160086 | G_loss : 4.082263
Epoch    70 : | D_loss : 0.360658 | G_loss : 4.361314
Epoch    75 : | D_loss : 0.234619 | G_loss : 3.316228
Epoch    80 : | D_loss : 0.426729 | G_loss : 2.225915
Epoch    85 : | D_loss : 0.071605 | G_loss : 5.672840
Epoch    90 : | D_loss :