In [None]:
# !pip install jcopdl gdown
# !gdown https://drive.google.com/uc?id=1sGZwJdbk7ZWqKEORJp6blndHX8azDXkt
# !unzip /content/mnist.zip

In [1]:
import torch
from torch import nn, optim
from jcopdl.callback import Callback, set_config

# Use the first CUDA GPU when PyTorch reports one as available; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
device

device(type='cpu')

# Dataset & Dataloader (train set only)

In [2]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [3]:
bs = 64

# Preprocessing applied to every image: convert to grayscale, turn it into a
# tensor, then normalise with mean 0.5 / std 0.5 so values end up in (-1, 1),
# which keeps training more stable.
transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

# Only the training split is loaded in this notebook.
train_set = datasets.ImageFolder(root="data/train/", transform=transform)
trainloader = DataLoader(
    dataset=train_set,
    batch_size=bs,
    shuffle=True,
    num_workers=4,
)

# Arsitektur & Config

In [4]:
%%writefile model_cgan.py
import torch
from torch import nn
from jcopdl.layers import linear_block

class Discriminator(nn.Module):
    """Label-conditioned discriminator made of fully connected blocks.

    The class label is embedded into an ``n_classes``-wide vector and
    concatenated with the flattened image (784 values per sample) before
    being fed to the ``fc`` stack, whose final block has a single output
    and ``activation='sigmoid'``.

    Args:
        n_classes: number of distinct labels; used both as the number of
            embedding rows and as the embedding width.
    """

    def __init__(self, n_classes):
        super().__init__()
        img_features = 784  # flattened image size expected by the first block

        self.flatten = nn.Flatten()
        self.embed_label = nn.Embedding(num_embeddings=n_classes, embedding_dim=n_classes)

        layers = [
            # Input is already 2-D when it reaches `fc` (see forward), so this
            # Flatten changes nothing; it stays so the layer indices inside
            # `fc` remain the same.
            nn.Flatten(),
            linear_block(img_features + n_classes, 512, activation="lrelu"),
            linear_block(512, 256, activation="lrelu"),
            linear_block(256, 128, activation="lrelu"),
            linear_block(128, 1, activation='sigmoid'),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, x, y):
        """Score a batch of images conditioned on their labels.

        Args:
            x: image batch; it is flattened per sample and must hold 784
                values per sample to match the first block of ``fc``.
            y: label indices looked up in ``embed_label`` (assumed to be one
                integer index per sample -- confirm against the caller).

        Returns:
            The output of ``fc`` for the concatenated image/label features.
        """
        flat_img = self.flatten(x)
        label_vec = self.embed_label(y)
        joined = torch.cat([flat_img, label_vec], dim=1)
        return self.fc(joined)
    
class Generator(nn.Module):
    """Label-conditioned generator made of fully connected blocks.

    A noise vector of width ``z_dim`` is concatenated with an
    ``n_classes``-wide label embedding and pushed through the ``fc`` stack,
    whose final block has 784 outputs and ``activation='tanh'``.

    Args:
        z_dim: width of the noise vector.
        n_classes: number of distinct labels; used both as the number of
            embedding rows and as the embedding width.
    """

    def __init__(self, z_dim, n_classes):
        super().__init__()
        self.z_dim = z_dim
        self.embed_label = nn.Embedding(num_embeddings=n_classes, embedding_dim=n_classes)

        layers = [
            linear_block(z_dim + n_classes, 128, activation="lrelu"),
            linear_block(128, 256, activation="lrelu", batch_norm=True),
            linear_block(256, 512, activation="lrelu", batch_norm=True),
            linear_block(512, 1024, activation='lrelu', batch_norm=True),
            linear_block(1024, 784, activation='tanh'),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, x, y):
        """Map noise plus labels to generated samples.

        Args:
            x: noise batch with ``z_dim`` values per sample.
            y: label indices looked up in ``embed_label`` (assumed to be one
                integer index per sample -- confirm against the caller).

        Returns:
            The output of ``fc`` for the concatenated noise/label features.
        """
        label_vec = self.embed_label(y)
        joined = torch.cat([x, label_vec], dim=1)
        return self.fc(joined)

    def generate(self, labels, device):
        """Draw fresh standard-normal noise and generate one sample per label.

        Args:
            labels: label indices; one sample is produced for each entry.
            device: device on which the noise tensor is created (``labels``
                are presumably expected on the same device -- confirm).

        Returns:
            The result of ``forward`` on the sampled noise and ``labels``.
            Gradient tracking is not disabled here.
        """
        noise = torch.randn(len(labels), self.z_dim, device=device)
        return self.forward(noise, labels)

Writing model_cgan.py


In [5]:
# Values needed to rebuild the networks, bundled through jcopdl's set_config.
config = set_config(dict(
    z_dim=100,                         # width of the generator's noise vector
    n_classes=len(train_set.classes),  # taken from the classes found in the training set
    batch_size=bs,
))

# Training Preparation -> MCO (Model, Criterion, Optimizer)

In [7]:
from model_cgan import Discriminator, Generator

In [8]:
# Model: discriminator and generator, both moved to the selected device.
D = Discriminator(n_classes=config.n_classes).to(device)
G = Generator(z_dim=config.z_dim, n_classes=config.n_classes).to(device)

# Criterion: binary cross-entropy for the real/fake decision.
criterion = nn.BCELoss()

# Optimizer: a separate Adam instance per network so each can be stepped on its own.
d_optimizer = optim.Adam(D.parameters(), lr=2e-4)
g_optimizer = optim.Adam(G.parameters(), lr=2e-4)

# Training

In [6]:
import os

from torchvision.utils import save_image

# Make sure the folders for sample images and saved models exist.
for out_dir in ("output/CGAN/", "model/CGAN/"):
    os.makedirs(out_dir, exist_ok=True)

In [None]:
max_epoch = 300

# Fixed labels for the periodic sample grid, so snapshots from different
# epochs show the same classes. The label range follows the dataset's class
# count instead of a hard-coded 10.
fix_labels = torch.randint(config.n_classes, (64,), device=device)

for epoch in range(max_epoch):
    # Back to training mode each epoch; G is switched to eval mode below
    # whenever a sample grid is rendered.
    D.train()
    G.train()
    for real_img, labels in trainloader:
        n_data = real_img.shape[0]

        ## Real and fake images
        real_img, labels = real_img.to(device), labels.to(device)
        fake_img = G.generate(labels, device)

        ## Real and fake targets
        real = torch.ones((n_data, 1), device=device)
        fake = torch.zeros((n_data, 1), device=device)

        ## Train the discriminator
        d_optimizer.zero_grad()
        # Real image -> discriminator -> target "real"
        output = D(real_img, labels)
        d_real_loss = criterion(output, real)

        # Fake image -> discriminator -> target "fake".
        # detach() keeps this loss from backpropagating into the generator.
        output = D(fake_img.detach(), labels)
        d_fake_loss = criterion(output, fake)

        d_loss = d_real_loss + d_fake_loss
        d_loss.backward()
        d_optimizer.step()

        ## Train the generator
        g_optimizer.zero_grad()

        # Fake image -> discriminator -> but with target "real"
        output = D(fake_img, labels)
        g_loss = criterion(output, real)
        g_loss.backward()
        g_optimizer.step()

    # Losses reported here are those of the last batch of the epoch.
    if epoch % 5 == 0:
        print(f"Epoch {epoch:5} | D_loss {d_loss.item()/2:.5f} | G_loss: {g_loss.item():.5f}")

    if epoch % 15 == 0:
        G.eval()
        # Sampling is only for inspection, so no autograd graph is needed.
        with torch.no_grad():
            sample_img = G.generate(fix_labels, device)
        # Zero-padded epoch in the file name; the loop variable itself stays an int.
        save_image(sample_img.view(-1, 1, 28, 28), f"output/CGAN/{epoch:04d}.jpg", nrow=8, normalize=True)

        # NOTE(review): torch.save on a module pickles the whole object, so
        # loading these files requires model_cgan to be importable. Kept as-is
        # so existing loaders of these files keep working.
        torch.save(D, "model/CGAN/discriminator.pth")
        torch.save(G, "model/CGAN/generator.pth")

# Periodic saves only happen on multiples of 15, which leaves the last epochs
# out; save once more so the files hold the fully trained networks.
torch.save(D, "model/CGAN/discriminator.pth")
torch.save(G, "model/CGAN/generator.pth")

In [None]:
# !zip -r model.zip /content/model/
# !zip -r output.zip /content/output/