In [None]:
#!pip install jcopdl
#!pip install gdown 

In [None]:
# Download Datasets
!gdown https://drive.google.com/uc?id=12DT5Px7FQV7gEcyGWvKb5aZQW2ZptSP
!unzip /content/mnist.zip

In [1]:
import os

import torch
from torch import nn, optim
from jcopdl.callback import Callback, set_config

# Use the first CUDA GPU when one is visible to PyTorch; otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda', index=0)

## Dataset dan DataLoader (Hanya Trainset)

In [4]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from jcopdl.layers import linear_block
from torch import nn


In [None]:
# Mini-batch size used when iterating over the training set.
bs = 64

# Force a single channel, convert to a tensor in [0, 1], then normalize with
# mean 0.5 / std 0.5, i.e. (x - 0.5) / 0.5, so pixels lie in [-1, 1]. This
# matches the range of a tanh output and keeps GAN training more stable.
data_transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

# ImageFolder reads one sub-directory per class under data/train/.
train_set = datasets.ImageFolder("data/train/", transform=data_transform)

# Cap the worker processes at the number of CPUs actually available. Requesting
# more (e.g. 8 on a 2-CPU Colab runtime) makes DataLoader warn and can slow
# down or freeze loading. os.cpu_count() may return None, hence the `or 1`.
trainloader = DataLoader(
    train_set,
    batch_size=bs,
    shuffle=True,
    num_workers=min(8, os.cpu_count() or 1),
)

## Arsitektur dan Config

In [None]:
import torch
from torch import nn
from jcopdl.layers import linear_block

class Discriminator(nn.Module):
    """Fully connected GAN discriminator.

    Each input sample is flattened to 784 features (28 * 28; this assumes
    1x28x28 grayscale MNIST images, so confirm before reusing on other data).
    The features then pass through three LeakyReLU hidden layers of width
    512 -> 256 -> 128. A single sigmoid unit at the end yields one score in
    (0, 1) per sample, presumably the probability that the sample is real.
    """

    def __init__(self):
        super().__init__()
        widths = [784, 512, 256, 128]
        # Consecutive width pairs: (784, 512), (512, 256), (256, 128).
        hidden_layers = [
            linear_block(n_in, n_out, activation="lrelu")
            for n_in, n_out in zip(widths[:-1], widths[1:])
        ]
        self.fc = nn.Sequential(
            nn.Flatten(),
            *hidden_layers,
            linear_block(widths[-1], 1, activation='sigmoid'),
        )

    def forward(self, x):
        """Return a (batch, 1) tensor of per-sample scores for the batch ``x``."""
        scores = self.fc(x)
        return scores
    

class Generator(nn.Module):
    """Fully connected GAN generator mapping latent noise to flattened images.

    The network widens 128 -> 256 -> 512 -> 1024 with LeakyReLU (batch norm
    on the middle three blocks). It ends in a 784-unit tanh layer, so every
    output value lies in [-1, 1], which is the same range the training images
    are normalized to. Reshape the output to (1, 28, 28) to view it as an
    image; this assumes MNIST-sized data.

    Args:
        z_dim: Size of the latent noise vector fed to the first layer.
    """

    def __init__(self, z_dim):
        super().__init__()
        self.z_dim = z_dim
        self.fc = nn.Sequential(
            linear_block(z_dim, 128, activation="lrelu"),
            # Each block's input width must equal the previous block's output
            # width (128 here). A mismatch breaks every forward pass.
            linear_block(128, 256, activation="lrelu", batch_norm=True),
            linear_block(256, 512, activation="lrelu", batch_norm=True),
            linear_block(512, 1024, activation="lrelu", batch_norm=True),
            linear_block(1024, 784, activation="tanh")
        )

    def forward(self, x):
        """Map latent vectors ``x`` of shape (batch, z_dim) to (batch, 784) outputs."""
        return self.fc(x)

    def generate(self, n, device):
        """Sample ``n`` standard-normal latent vectors on ``device`` and decode them.

        Args:
            n: Number of samples to generate.
            device: Device on which the noise (and therefore the output) is created.

        Returns:
            Tensor of shape (n, 784). It is not wrapped in ``torch.no_grad``, so
            gradients flow back into the generator.

        NOTE(review): if ``linear_block``'s batch_norm is ``nn.BatchNorm1d``,
        then ``n == 1`` fails in training mode. Call ``.eval()`` before sampling
        single images.
        """
        z = torch.randn((n, self.z_dim), device=device)
        return self.fc(z)