In [1]:
# !pip install gdown jcopdl
# !gdown https://drive.google.com/uc?id=12DT5Px7FQV7gZEcygWvKb5aZQw2ZprSP
# !unzip /content/mnist.zip

In [2]:
import torch
from torch import nn, optim
from jcopdl.callback import Callback, set_config

# Pick the compute device: the first CUDA GPU if PyTorch can see one,
# otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
device

device(type='cuda', index=0)

In [3]:
 
# Sanity check: display whether PyTorch reports a usable CUDA device.
torch.cuda.is_available()

True

# Dataset dan Dataloader (hanya train)

In [4]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [5]:
# Mini-batch size shared by the DataLoader and the run configuration.
bs = 64

# Preprocessing applied to every image, in order:
#   1. convert to grayscale,
#   2. convert to a tensor,
#   3. normalize with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

# Only the training split is loaded; class labels come from the
# sub-folder names under data/train.
train_set = datasets.ImageFolder(root="data/train", transform=transform)
trainloader = DataLoader(dataset=train_set, batch_size=bs, shuffle=True)

In [6]:
%%writefile model_cgan.py
import torch
from torch import nn
from jcopdl.layers import linear_block

class Discriminator(nn.Module):
    """Conditional discriminator over flattened images plus a label embedding.

    The image is flattened, the integer class label is looked up in an
    embedding table of width ``n_classes``, and the two vectors are
    concatenated before going through a stack of ``linear_block`` layers.
    The last block is requested with a sigmoid activation and has a single
    output unit.

    Args:
        n_classes: number of distinct class labels (also the embedding width).
    """

    def __init__(self, n_classes):
        super().__init__()
        self.flatten = nn.Flatten()
        self.embed_label = nn.Embedding(n_classes, n_classes)

        # 784 flattened image features + the label embedding, narrowing down
        # to a single output unit.
        widths = (784 + n_classes, 512, 256, 128)
        hidden = [
            linear_block(n_in, n_out, activation='lrelu')
            for n_in, n_out in zip(widths[:-1], widths[1:])
        ]
        self.fc = nn.Sequential(
            *hidden,
            linear_block(128, 1, activation='sigmoid'),
        )

    def forward(self, x, y):
        """Score a batch of images ``x`` conditioned on integer labels ``y``."""
        flat_img = self.flatten(x)
        label_vec = self.embed_label(y)
        joined = torch.cat([flat_img, label_vec], dim=1)
        return self.fc(joined)

class Generator(nn.Module):
    """Conditional generator mapping (noise, label) to a 784-long vector.

    The integer label is looked up in an embedding table of width
    ``n_classes`` and concatenated with the noise vector; the result goes
    through a stack of ``linear_block`` layers whose final block is requested
    with a tanh activation.

    Args:
        z_dim: length of the noise vector.
        n_classes: number of distinct class labels (also the embedding width).
    """

    def __init__(self, z_dim, n_classes):
        super().__init__()
        self.embed_label = nn.Embedding(n_classes, n_classes)
        self.z_dim = z_dim

        # Entry block (no batch_norm argument), three widening blocks with
        # batch norm enabled, then the tanh output block.
        blocks = [linear_block(self.z_dim + n_classes, 128, activation='lrelu')]
        for n_in, n_out in ((128, 256), (256, 512), (512, 1024)):
            blocks.append(
                linear_block(n_in, n_out, activation='lrelu', batch_norm=True)
            )
        blocks.append(linear_block(1024, 784, activation='tanh'))
        self.fc = nn.Sequential(*blocks)

    def forward(self, x, y):
        """Generate outputs from noise ``x`` conditioned on integer labels ``y``."""
        label_vec = self.embed_label(y)
        conditioned = torch.cat([x, label_vec], dim=1)
        return self.fc(conditioned)

    def generate(self, labels, device):
        """Sample one standard-normal noise vector per label and run ``forward``.

        Args:
            labels: class labels to condition on; one sample is produced per entry.
            device: device on which the noise tensor is created.
        """
        n_samples = len(labels)
        noise = torch.randn((n_samples, self.z_dim), device=device)
        return self.forward(noise, labels)

Overwriting model_cgan.py


In [7]:
# Run configuration: noise size, batch size, and the number of classes
# discovered in the training folder.
config = set_config(dict(
    z_dim=100,
    batch_size=bs,
    n_classes=len(train_set.classes),
))

# Training Preparation -> MCO (Model, Criterion, Optimizer)

In [8]:
from model_cgan import Discriminator, Generator

In [9]:
# Model: discriminator and generator, both moved to the selected device.
D = Discriminator(n_classes=config.n_classes).to(device)
G = Generator(z_dim=config.z_dim, n_classes=config.n_classes).to(device)

# Criterion: binary cross-entropy, shared by both networks.
criterion = nn.BCELoss()

# Optimizer: one Adam instance per network, same learning rate.
d_optimizer = optim.Adam(params=D.parameters(), lr=2e-4)
g_optimizer = optim.Adam(params=G.parameters(), lr=2e-4)

# Training

In [10]:
import os
from torchvision.utils import save_image
# Make sure the folders for sample images and model checkpoints exist.
for out_dir in ("output/CGAN/", "model/CGAN/"):
    os.makedirs(out_dir, exist_ok=True)

In [11]:
max_epochs = 1000

# Fixed set of 64 labels reused for every snapshot, so the saved grids are
# conditioned on the same classes each time. Labels are drawn from the
# dataset's actual class count so they never exceed the embedding table size.
fix_labels = torch.randint(config.n_classes, (64,), device=device)

for epoch in range(max_epochs):
    D.train()
    G.train()
    for real_img, labels in trainloader:
        n_data = real_img.shape[0]

        # Real and fake images
        real_img, labels = real_img.to(device), labels.to(device)
        fake_image = G.generate(labels, device)

        # Real and fake target labels
        real = torch.ones((n_data, 1), device=device)
        fake = torch.zeros((n_data, 1), device=device)

        # ---- Train the discriminator ----
        d_optimizer.zero_grad()
        ## Real image -> discriminator -> should be labelled real
        output = D(real_img, labels)
        d_real_loss = criterion(output, real)
        ## Fake image -> discriminator -> should be labelled fake
        ## (detached so this step does not backpropagate into the generator)
        output = D(fake_image.detach(), labels)
        d_fake_loss = criterion(output, fake)

        d_loss = d_real_loss + d_fake_loss
        d_loss.backward()
        d_optimizer.step()

        # ---- Train the generator ----
        g_optimizer.zero_grad()
        ## Fake image -> discriminator -> but with the "real" target
        output = D(fake_image, labels)
        g_loss = criterion(output, real)
        g_loss.backward()
        g_optimizer.step()

    # The reported losses are those of the last mini-batch of the epoch.
    if epoch % 5 == 0:
        print(f"Epoch {epoch:5} : | D_loss : {d_loss/2:5f} | G_loss : {g_loss:5f}")

    if epoch % 15 == 0:
        G.eval()
        # Separate name for the zero-padded file tag: the loop counter
        # `epoch` stays an int.
        epoch_tag = str(epoch).zfill(4)
        # Sampling for a snapshot needs no gradients; skip building the graph.
        with torch.no_grad():
            fake_image = G.generate(fix_labels, device=device)
        save_image(fake_image.view(-1, 1, 28, 28), f"output/CGAN/{epoch_tag}.jpg", nrow=8, normalize=True)

        # Checkpoints are overwritten at every snapshot.
        torch.save(D, "model/CGAN/discriminator.pth")
        torch.save(G, "model/CGAN/generator.pth")
        
        

Epoch     0 : | D_loss : 0.662523 | G_loss : 0.868342
Epoch     5 : | D_loss : 0.009376 | G_loss : 17.082020
Epoch    10 : | D_loss : 0.003385 | G_loss : 8.926982
Epoch    15 : | D_loss : 0.122001 | G_loss : 7.802656
Epoch    20 : | D_loss : 0.054513 | G_loss : 7.528009
Epoch    25 : | D_loss : 0.059280 | G_loss : 5.883502
Epoch    30 : | D_loss : 0.015263 | G_loss : 7.156335
Epoch    35 : | D_loss : 0.057760 | G_loss : 7.544318
Epoch    40 : | D_loss : 0.016073 | G_loss : 9.023596
Epoch    45 : | D_loss : 0.027810 | G_loss : 9.374855
Epoch    50 : | D_loss : 0.052018 | G_loss : 6.438066
Epoch    55 : | D_loss : 0.016761 | G_loss : 9.587646
Epoch    60 : | D_loss : 0.059611 | G_loss : 8.374855
Epoch    65 : | D_loss : 0.170939 | G_loss : 4.949499
Epoch    70 : | D_loss : 0.141104 | G_loss : 6.224888
Epoch    75 : | D_loss : 0.007346 | G_loss : 9.139545
Epoch    80 : | D_loss : 0.175620 | G_loss : 5.410927
Epoch    85 : | D_loss : 0.155877 | G_loss : 4.198637
Epoch    90 : | D_loss : 0.