# 🧱 DCGAN - Bricks Data

In this notebook, we'll walk through the steps required to train your own DCGAN on the LEGO bricks dataset.


In [None]:
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.nn.functional as F
from PIL import Image
from rich.progress import Progress
from torchvision.transforms import v2

from notebooks.utils import display

## 0. Parameters


In [None]:
# --- Data ---------------------------------------------------------------
IMAGE_SIZE = 64  # images are resized to IMAGE_SIZE x IMAGE_SIZE pixels
CHANNELS = 1  # single-channel (grayscale) images
BATCH_SIZE = 128

# --- Model / training ---------------------------------------------------
Z_DIM = 100  # length of the latent vector fed to the generator
EPOCHS = 300
LOAD_MODEL = False  # NOTE(review): not read by any cell in this notebook
ADAM_BETA_1, ADAM_BETA_2 = 0.5, 0.999
LEARNING_RATE = 2e-4
# Amplitude of the uniform noise used to soften the real/fake target labels.
NOISE_PARAM = 0.1

In [None]:
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
if use_cuda:
    gpu_name = torch.cuda.get_device_name(0)
    print(f"Let's use CUDA ({gpu_name})")

## 1. Prepare the data


In [None]:
class LegoBrickDataset(torch.utils.data.Dataset):
    """
    LEGO brick image dataset (from Kaggle)
    URL: https://www.kaggle.com/datasets/joosthazelzet/lego-brick-images

    Indexes every ``*.png`` file found recursively under ``data_root`` in
    sorted path order. ``dataset[i]`` loads the i-th file as a PIL image and
    returns it, passed through ``transform`` when one is given.

    Args:
        data_root: Directory searched recursively for ``*.png`` files.
        transform: Optional callable applied to each loaded PIL image.
    """

    def __init__(self, data_root='', transform=None):
        self.data_root = Path(data_root)
        # Sorting gives a deterministic index -> file mapping across runs.
        self.image_files = sorted(self.data_root.rglob('*.png'))
        self.transform = transform

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, index):
        image_file = self.image_files[index]
        # Image.open is lazy: it keeps the file handle open until the pixel
        # data is read. copy() forces the read inside the `with` block so the
        # handle is closed right here, not whenever the object is collected
        # (otherwise many worker processes can exhaust open-file limits).
        with Image.open(image_file) as img:
            image = img.copy()

        if self.transform is not None:
            image = self.transform(image)

        return image

In [None]:
# Preprocessing: PIL image -> tensor -> float32 scaled to [0, 1]
# -> IMAGE_SIZE x IMAGE_SIZE -> CHANNELS-channel grayscale
# -> normalised to [-1, 1] (the range of the generator's tanh output).
preprocessing_steps = [
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    v2.Grayscale(num_output_channels=CHANNELS),
    v2.Normalize(mean=[0.5], std=[0.5]),
]
transform = v2.Compose(preprocessing_steps)

train_data = LegoBrickDataset(
    data_root='./data/lego-brick-images/dataset',
    transform=transform,
)

# Shuffled mini-batches loaded by 4 worker processes; persistent_workers keeps
# those processes alive between epochs instead of re-spawning them.
loader_options = dict(
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    persistent_workers=True,
)
train_dataloader = torch.utils.data.DataLoader(train_data, **loader_options)

In [None]:
# Preview the first (up to) 10 preprocessed training images. Values are still
# normalised to [-1, 1]; the array has shape (n, CHANNELS, IMAGE_SIZE, IMAGE_SIZE).
# torch.stack builds the batch explicitly instead of relying on np.array's
# implicit per-tensor conversion.
n_preview = min(10, len(train_data))
train_sample = torch.stack([train_data[i] for i in range(n_preview)]).numpy()
display(train_sample)

## 2. Build the GAN


In [None]:
class Discriminator(nn.Module):
    """DCGAN discriminator that scores images as real (1) or fake (0).

    Four 4x4 stride-2 convolutions halve the spatial size each time
    (64 -> 32 -> 16 -> 8 -> 4). A final 4x4 unpadded convolution collapses the
    4x4 map to a single value, which is passed through a sigmoid.

    Input:  tensor of shape (batch, CHANNELS, 64, 64). Other spatial sizes do
            not collapse to 1x1, and ``view(-1, 1)`` would then mix positions
            into the batch axis.
    Output: tensor of shape (batch, 1) with probabilities in [0, 1].
    """

    # PyTorch's BatchNorm ``momentum`` is the weight given to the NEW batch
    # statistic: running = (1 - momentum) * running + momentum * batch.
    # Keras uses the opposite convention, so Keras' momentum=0.9 corresponds to
    # PyTorch's momentum=0.1. With 0.9, the running statistics used in eval mode
    # would track little more than the most recent batch.
    _BN_MOMENTUM = 0.1

    def __init__(self):
        super().__init__()
        layer1 = nn.Sequential(
            nn.Conv2d(CHANNELS, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
        )
        layer2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128, momentum=self._BN_MOMENTUM),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
        )
        layer3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256, momentum=self._BN_MOMENTUM),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
        )
        layer4 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(512, momentum=self._BN_MOMENTUM),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
        )
        # 4x4 feature map -> 1x1 score (no padding).
        layer5 = nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0, bias=False)

        # ModuleList (rather than a flat Sequential) keeps the original
        # state_dict keys, e.g. "net.1.1.running_mean".
        self.net = nn.ModuleList([layer1, layer2, layer3, layer4, layer5])

    def forward(self, x):
        for layer in self.net:
            x = layer(x)

        # (batch, 1, 1, 1) -> (batch, 1) probabilities.
        return torch.sigmoid(x.view(-1, 1))

In [None]:
class Generator(nn.Module):
    """DCGAN generator that maps latent vectors to images.

    A 4x4 unpadded transposed convolution turns the 1x1 latent input into a
    4x4 map. Four 4x4 stride-2 transposed convolutions then double the spatial
    size each time (4 -> 8 -> 16 -> 32 -> 64). A final tanh maps outputs to [-1, 1].

    Input:  any tensor viewable as (-1, Z_DIM, 1, 1), e.g. shape (batch, Z_DIM).
    Output: tensor of shape (batch, CHANNELS, 64, 64) with values in [-1, 1].
    """

    # PyTorch's BatchNorm ``momentum`` is the weight given to the NEW batch
    # statistic: running = (1 - momentum) * running + momentum * batch.
    # Keras uses the opposite convention, so Keras' momentum=0.9 corresponds to
    # PyTorch's momentum=0.1. With 0.9, the running statistics used by
    # ``netG.eval()`` at sampling time would be little more than the last
    # training batch's statistics.
    _BN_MOMENTUM = 0.1

    def __init__(self):
        super().__init__()
        layer1 = nn.Sequential(
            nn.ConvTranspose2d(Z_DIM, 512, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(512, momentum=self._BN_MOMENTUM),
            nn.LeakyReLU(0.2),
        )
        layer2 = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256, momentum=self._BN_MOMENTUM),
            nn.LeakyReLU(0.2),
        )
        layer3 = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128, momentum=self._BN_MOMENTUM),
            nn.LeakyReLU(0.2),
        )
        layer4 = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64, momentum=self._BN_MOMENTUM),
            nn.LeakyReLU(0.2),
        )
        layer5 = nn.ConvTranspose2d(64, CHANNELS, kernel_size=4, stride=2, padding=1, bias=False)

        # ModuleList (rather than a flat Sequential) keeps the original
        # state_dict keys, e.g. "net.0.1.running_mean".
        self.net = nn.ModuleList([layer1, layer2, layer3, layer4, layer5])

    def forward(self, x):
        # Treat each latent vector as a Z_DIM-channel 1x1 "image".
        x = x.view(-1, Z_DIM, 1, 1)
        for layer in self.net:
            x = layer(x)

        return torch.tanh(x)

In [None]:
# Build both networks directly on the selected device.
netD = Discriminator().to(device)
netG = Generator().to(device)

# The discriminator and generator use identical Adam hyper-parameters.
adam_kwargs = {'lr': LEARNING_RATE, 'betas': (ADAM_BETA_1, ADAM_BETA_2)}
optimD = torch.optim.Adam(netD.parameters(), **adam_kwargs)
optimG = torch.optim.Adam(netG.parameters(), **adam_kwargs)

## 3. Train the GAN


In [None]:
netD.train()
netG.train()

# Exponential moving averages (weight 0.1 on the newest batch) of the per-batch
# metrics shown in the progress bar. They start at 0, so the first few
# displayed values are biased towards 0.
ema_d_loss = 0.0
ema_d_acc = 0.0
ema_g_loss = 0.0
ema_g_acc = 0.0

with Progress() as progress:
    train_task = progress.add_task('Training', total=len(train_dataloader) * EPOCHS)
    for epoch in range(EPOCHS):
        for real_images in train_dataloader:
            real_images = real_images.to(device)
            batch_size = real_images.size(0)

            # --- Discriminator step ---------------------------------------
            # Sample random points in the latent space, directly on the device.
            random_latent_vectors = torch.randn((batch_size, Z_DIM), device=device)
            generated_images = netG(random_latent_vectors)
            real_predictions = netD(real_images)
            # detach() keeps the discriminator loss from back-propagating into netG.
            d_fake_predictions = netD(generated_images.detach())

            # Label smoothing with uniform noise: real targets lie in
            # (1 - NOISE_PARAM, 1], fake targets in [0, NOISE_PARAM).
            real_labels = torch.ones_like(real_predictions)
            real_noisy_labels = real_labels - NOISE_PARAM * torch.rand_like(real_labels)
            fake_labels = torch.zeros_like(d_fake_predictions)
            fake_noisy_labels = fake_labels + NOISE_PARAM * torch.rand_like(fake_labels)

            d_real_loss = F.binary_cross_entropy(real_predictions, real_noisy_labels)
            d_fake_loss = F.binary_cross_entropy(d_fake_predictions, fake_noisy_labels)
            d_loss = (d_real_loss + d_fake_loss) * 0.5

            # zero_grad() also clears the netD gradients that the previous
            # generator step's backward pass accumulated.
            optimD.zero_grad()
            d_loss.backward()
            optimD.step()

            # --- Generator step -------------------------------------------
            random_latent_vectors = torch.randn((batch_size, Z_DIM), device=device)
            generated_images = netG(random_latent_vectors)
            g_fake_predictions = netD(generated_images)

            # The generator is trained to make netD label its samples as real
            # (clean, un-smoothed targets of 1).
            g_loss = F.binary_cross_entropy(g_fake_predictions, real_labels)
            optimG.zero_grad()
            g_loss.backward()
            optimG.step()

            # --- Metrics --------------------------------------------------
            # Discriminator accuracy must use the predictions from its own step.
            # Previously the name was reused, so d_fake_acc was silently computed
            # on the generator step's batch. Generator accuracy uses the
            # generator step's predictions.
            d_real_acc = (real_predictions > 0.5).float().mean()
            d_fake_acc = (d_fake_predictions < 0.5).float().mean()
            d_acc = (d_real_acc + d_fake_acc) * 0.5
            g_acc = (g_fake_predictions > 0.5).float().mean()

            ema_d_loss = 0.9 * ema_d_loss + 0.1 * d_loss.item()
            ema_d_acc = 0.9 * ema_d_acc + 0.1 * d_acc.item()
            ema_g_loss = 0.9 * ema_g_loss + 0.1 * g_loss.item()
            ema_g_acc = 0.9 * ema_g_acc + 0.1 * g_acc.item()

            # Display metrics
            metrics = (
                f'd-loss: {ema_d_loss:.4f} | '
                f'd-acc: {ema_d_acc:.4f} | '
                f'g-loss: {ema_g_loss:.4f} | '
                f'g-acc: {ema_g_acc:.4f}'
            )
            progress.update(train_task, advance=1, description=f'Epoch {epoch + 1}/{EPOCHS} | {metrics}')

In [None]:
# Save the weights of both networks (optimizer state is not included).
checkpoint_dir = Path('./checkpoint')
checkpoint_dir.mkdir(parents=True, exist_ok=True)

ckpt_dict = {'netD': netD.state_dict(), 'netG': netG.state_dict()}
checkpoint_path = checkpoint_dir / 'checkpoint.ckpt'
torch.save(ckpt_dict, checkpoint_path)

## 4. Generate new images


In [None]:
# Draw one latent vector per grid cell (3 rows x 10 columns) from a standard
# normal distribution, then move the batch to the model's device.
grid_width, grid_height = 10, 3
n_grid_cells = grid_width * grid_height
z_sample = torch.randn((n_grid_cells, Z_DIM)).to(device)

In [None]:
# eval() makes the generator's BatchNorm layers use their running statistics
# instead of the statistics of this particular batch.
netG.eval()
with torch.no_grad():
    generated = netG(z_sample).cpu().numpy()
    # Map tanh outputs from [-1, 1] back to [0, 1] and drop the channel axis.
    reconstructions = (generated * 0.5 + 0.5).reshape((-1, IMAGE_SIZE, IMAGE_SIZE))

In [None]:
# Draw a plot of decoded images
fig = plt.figure(figsize=(18, 5))
fig.subplots_adjust(hspace=0.4, wspace=0.4)

# Output the grid of faces
for i in range(grid_width * grid_height):
    ax = fig.add_subplot(grid_height, grid_width, i + 1)
    ax.axis("off")
    ax.imshow(reconstructions[i, :, :], cmap="Greys")

In [None]:
def compare_images(img1, img2):
    """Return the mean absolute per-pixel difference between two images.

    Lower values mean more similar images. Both inputs must support
    element-wise subtraction, e.g. NumPy arrays of the same shape.
    """
    abs_diff = np.abs(img1 - img2)
    return np.mean(abs_diff)

In [None]:
# Collect every training image in dataset order, mapped back from [-1, 1] to
# [0, 1] and with the channel axis dropped, for the nearest-neighbour comparison.
# A non-shuffled DataLoader decodes the PNGs in worker processes instead of
# one at a time in the notebook process. torch.cat(...).numpy() makes the
# tensor -> ndarray conversion explicit.
ordered_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4,
)
all_data = torch.cat(
    [batch.reshape(-1, IMAGE_SIZE, IMAGE_SIZE) * 0.5 + 0.5 for batch in ordered_loader]
).numpy()

In [None]:
r, c = 3, 5
fig, axs = plt.subplots(r, c, figsize=(10, 6))
fig.suptitle("Generated images", fontsize=20)

# Draw r * c new latent samples and decode them without tracking gradients.
noise = torch.randn(r * c, Z_DIM).to(device)
with torch.no_grad():
    raw_imgs = netG(noise).cpu().numpy()

# Map tanh outputs from [-1, 1] back to [0, 1] and drop the channel axis.
gen_imgs = (raw_imgs * 0.5 + 0.5).reshape((-1, IMAGE_SIZE, IMAGE_SIZE))

# axs.flat walks the subplot grid row by row, matching the sample order.
for ax, img in zip(axs.flat, gen_imgs):
    ax.imshow(img, cmap="gray_r")
    ax.axis("off")

plt.show()

In [None]:
fig, axs = plt.subplots(r, c, figsize=(10, 6))
fig.suptitle("Closest images in the training set", fontsize=20)

# For each generated image, show the training image with the smallest mean
# absolute pixel difference (brute-force nearest neighbour). This helps check
# by eye whether samples are near-copies of training images. np.argmin returns
# the first of several equally close candidates, matching a strict `<` scan.
# It needs no magic sentinel value or defensive copies.
for ax, gen_img in zip(axs.flat, gen_imgs):
    diffs = [compare_images(gen_img, train_img) for train_img in all_data]
    closest = all_data[int(np.argmin(diffs))]
    ax.imshow(closest, cmap="gray_r")
    ax.axis("off")

plt.show()