Source: https://github.com/adam-maj/deep-learning/tree/main

In [None]:
# Make CUDA kernel launches synchronous so errors surface at the offending call
# (debugging aid; it slows execution down). `!export ...` would only set the
# variable inside a short-lived subshell, so it is set on this process instead.
# This must run before CUDA is initialised to take effect.
import os

os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [None]:
import os
import torch
import numpy as np
import torch.nn as nn
from torch.optim import Adam
import torch.nn.functional as F

import torchvision.transforms as transforms
from torchvision.io import read_image
from torch.utils.data import DataLoader, Dataset

from typing import List, Tuple

from tqdm import tqdm
import matplotlib.pyplot as plt
from torchvision.utils import save_image, make_grid

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Data geometry: the size handed to the resize transform and the loader batch size.
img_size = (128, 128)
batch_size = 256

# Model widths: channel count of the widest conv layer and size of the latent vector.
hidden_dim = 1024
latent_dim = 256

# Optimisation settings.
learning_rate = 3e-4
epochs = 100

## Dataset processing

Datasets available:

- HF: logo-wizard/modern-logo-dataset
- https://github.com/Wangjing1551/Logo-2k-plus-Dataset
- https://www.kaggle.com/datasets/siddharthkumarsah/logo-dataset-2341-classes-and-167140-images?resource=download
- https://www.kaggle.com/datasets/lyly99/logodet3k

In [None]:
def find_files(folder, ext: Tuple[str, ...]):
    """Recursively collect the files under ``folder`` that end with one of ``ext``.

    Matching is case-insensitive on both sides: file names and the requested
    extensions are lower-cased before comparison.

    Args:
        folder: Directory to walk, including all of its subdirectories.
        ext: Extension(s) to keep, e.g. ``('.jpg', '.png')``. A single string
            or any iterable of strings (tuple, list, set) is accepted.

    Returns:
        List of matching paths, each joined onto the directory it was found in,
        in the order ``os.walk`` yields them.
    """
    # str.endswith only accepts a str or a tuple, so normalise other iterables;
    # lower-case the suffixes because the file name is lower-cased below.
    if isinstance(ext, str):
        suffixes = (ext.lower(),)
    else:
        suffixes = tuple(suffix.lower() for suffix in ext)

    # Walk through the directory and its subdirectories
    out_files = [
        os.path.join(root, name)
        for root, _dirs, names in os.walk(folder)
        for name in names
        if name.lower().endswith(suffixes)
    ]

    print(f'Found {len(out_files)} image files in {folder}')
    return out_files

In [None]:
# Root of the training images; it is scanned recursively for logo files.
dataset_path = './dataset/datasetcopy/trainandtest/train/'
imgs = find_files(
    dataset_path,
    ext=('.jpg', '.jpeg', '.png'),
)

In [None]:
class LogoDataset(Dataset):
    """Dataset that loads image files from disk as float tensors scaled by 1/255.

    Args:
        img_files: Paths of the image files that make up the dataset.
        resize: Target size handed to ``transforms.Resize``.
        transform: Optional callable applied to every image after resizing.
    """

    def __init__(self, img_files: List[str], resize: Tuple[int, int], transform=None):
        self.img_files = img_files
        self.resize = resize
        self.transform = transform
        self.resize_transform = transforms.Resize(resize)

    def __len__(self):
        """Return the number of image files in the dataset."""
        return len(self.img_files)

    def __getitem__(self, idx):
        """Load, rescale, resize and (optionally) transform the image at ``idx``.

        Returns:
            A channels-first float tensor with exactly three colour channels
            (before any user ``transform`` is applied).
        """
        # Imported locally so the block does not depend on the file's import cell.
        from torchvision.io import ImageReadMode

        img_path = self.img_files[idx]
        # Decode straight to RGB: grayscale or RGBA files would otherwise give
        # 1- or 4-channel tensors that cannot be batched together or fed to a
        # 3-channel network.
        image = read_image(img_path, mode=ImageReadMode.RGB).float() / 255
        image = self.resize_transform(image)

        if self.transform:
            image = self.transform(image)

        return image

    def plot_img(self, idx):
        """Display the processed image at ``idx`` in a new matplotlib figure."""
        image = self[idx]

        # Move channels last for matplotlib. A reshape would merely reinterpret
        # the memory and scramble the pixels; permute actually swaps the axes.
        image = image.permute(1, 2, 0)

        plt.figure()
        plt.title(f'Image: {idx}')
        plt.imshow(image)

In [None]:
# Hold out the final 200 files for evaluation; everything before them is training data.
train_ds, test_ds = (
    LogoDataset(files, img_size) for files in (imgs[:-200], imgs[-200:])
)

print(len(train_ds), len(test_ds))

# Both loaders share the same batching / worker configuration.
loader_kwargs = dict(
    batch_size=batch_size,
    shuffle=True,
    num_workers=16,
    pin_memory=True,
    prefetch_factor=2,
)

train_loader = DataLoader(train_ds, **loader_kwargs)

test_loader = DataLoader(test_ds, **loader_kwargs)

## Modeling

In [None]:
class Encoder(nn.Module):
    """Convolutional VAE encoder producing the mean and log-variance of the latent.

    Four stride-2 convolutions each halve the spatial size while widening the
    channels up to ``hidden_dim``. The linear heads expect an 8x8 feature map,
    i.e. a 128x128 input image.

    Args:
        input_channels: Number of channels of the input images.
        hidden_dim: Channel count after the last convolution.
        latent_dim: Size of the latent vector.
    """

    def __init__(self, input_channels, hidden_dim, latent_dim):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, hidden_dim // 8, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(hidden_dim // 8, hidden_dim // 4, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv2d(hidden_dim // 4, hidden_dim // 2, kernel_size=4, stride=2, padding=1)
        self.conv4 = nn.Conv2d(hidden_dim // 2, hidden_dim, kernel_size=4, stride=2, padding=1)

        # Two heads over the flattened 8x8 feature map: mean and log-variance.
        self.fc1 = nn.Linear(hidden_dim * 8 * 8, latent_dim)
        self.fc2 = nn.Linear(hidden_dim * 8 * 8, latent_dim)

    def forward(self, x):
        """Encode a batch of images.

        Returns:
            Tuple ``(mean, log_variance)``, each of shape ``(batch, latent_dim)``.
        """
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = F.leaky_relu(conv(x), 0.2)
            # The functional dropout defaults to training=True; tie it to the
            # module's mode so that eval() really switches dropout off.
            x = F.dropout2d(x, 0.5, training=self.training)

        x = x.view(x.size(0), -1)
        mean = self.fc1(x)
        log_variance = self.fc2(x)

        return mean, log_variance

class Decoder(nn.Module):
    """Transposed-convolution VAE decoder mapping latent vectors to images.

    The latent vector is projected to a ``hidden_dim x 4 x 4`` feature map and
    then upsampled by five stride-2 transposed convolutions (4 -> 128 pixels
    per side). A sigmoid keeps the output in ``[0, 1]``.

    Args:
        latent_dim: Size of the latent vector.
        hidden_dim: Channel count of the first feature map.
        output_channels: Number of channels of the generated images.
    """

    def __init__(self, latent_dim, hidden_dim, output_channels):
        super().__init__()
        self.fc = nn.Linear(latent_dim, hidden_dim * 4 * 4)

        self.deconv1 = nn.ConvTranspose2d(hidden_dim, hidden_dim // 2, kernel_size=4, stride=2, padding=1)
        self.deconv2 = nn.ConvTranspose2d(hidden_dim // 2, hidden_dim // 4, kernel_size=4, stride=2, padding=1)
        self.deconv3 = nn.ConvTranspose2d(hidden_dim // 4, hidden_dim // 8, kernel_size=4, stride=2, padding=1)
        self.deconv4 = nn.ConvTranspose2d(hidden_dim // 8, output_channels, kernel_size=4, stride=2, padding=1)
        self.deconv5 = nn.ConvTranspose2d(output_channels, output_channels, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        """Decode a batch of latent vectors of shape ``(batch, latent_dim)``."""
        x = self.fc(x)
        x = x.view(x.size(0), -1, 4, 4)

        for deconv in (self.deconv1, self.deconv2, self.deconv3, self.deconv4):
            x = F.leaky_relu(deconv(x), 0.2)
            # The functional dropout defaults to training=True; tie it to the
            # module's mode so that eval() really switches dropout off.
            x = F.dropout2d(x, 0.5, training=self.training)

        # Final layer: no dropout, squash to the [0, 1] pixel range.
        x = torch.sigmoid(self.deconv5(x))

        return x

class Model(nn.Module):
    """Variational autoencoder that chains an encoder and a decoder.

    Args:
        Encoder: Callable mapping a batch to ``(mean, log_variance)``.
        Decoder: Callable mapping a latent batch to reconstructions.
    """

    def __init__(self, Encoder, Decoder):
        super().__init__()
        self.Encoder = Encoder
        self.Decoder = Decoder

    def reparameterization(self, mean, variance):
        """Sample ``z = mean + variance * epsilon`` with ``epsilon ~ N(0, I)``.

        Note that ``forward`` passes ``exp(0.5 * log_variance)`` here, i.e. the
        standard deviation, despite the parameter's name.
        """
        # Separate out the randomness into the epsilon term. randn_like already
        # matches the dtype and device of its argument, so no global device
        # lookup is needed.
        epsilon = torch.randn_like(variance)

        # Now gradients can still flow back through mean and variance
        z = mean + variance * epsilon

        return z

    def forward(self, x):
        """Encode ``x``, sample a latent and decode it.

        Returns:
            Tuple ``(x_hat, mean, log_variance)``.
        """
        mean, log_variance = self.Encoder(x)

        # Use the reparameterization trick to keep randomness differentiable
        z = self.reparameterization(mean, torch.exp(0.5 * log_variance))

        x_hat = self.Decoder(z)
        return x_hat, mean, log_variance

In [None]:
# Build the two halves of the VAE for 3-channel images.
encoder = Encoder(
    input_channels=3,
    hidden_dim=hidden_dim,
    latent_dim=latent_dim,
)
decoder = Decoder(
    latent_dim=latent_dim,
    hidden_dim=hidden_dim,
    output_channels=3,
)

# Wrap them in one module and move it to the selected device.
model = Model(Encoder=encoder, Decoder=decoder).to(device)

In [None]:
def num_params(m):
    """Return the total number of elements across the trainable parameters of ``m``."""
    total = 0
    for param in m.parameters():
        # Frozen parameters (requires_grad == False) are not counted.
        if param.requires_grad:
            total += param.numel()
    return total

# Trainable-parameter counts for the encoder, the decoder and the full model.
num_params(encoder), num_params(decoder), num_params(model)

In [None]:
# GPU-only speed-ups; this cell does nothing when running on the CPU.
# Checking device.type also matches indexed devices such as "cuda:0".
if device.type == "cuda":
    torch.backends.cuda.matmul.allow_tf32 = True  # allow tf32 on matmul
    torch.backends.cudnn.allow_tf32 = True  # allow tf32 on cudnn
    torch.set_float32_matmul_precision("high")
    # NOTE: torch.amp.autocast(...) is a context manager; constructing it
    # without entering a `with` block has no effect, so it is not created
    # here. Enabling bf16 autocast for real would also need a different
    # reconstruction loss, because F.binary_cross_entropy is rejected under
    # autocast.

    model = torch.compile(model)
    # warmup the model (named so that the builtin `input` is not shadowed)
    warmup_batch = torch.randn((1, 3, *img_size), device=device)
    for _ in range(10):
        model(warmup_batch)

In [None]:
def bce_loss(x, x_hat, mean, log_variance):
    """Return the VAE training loss: summed BCE reconstruction term plus KL term.

    Args:
        x: Target images.
        x_hat: Reconstructed images, used as probabilities by the BCE.
        mean: Latent means produced by the encoder.
        log_variance: Latent log-variances produced by the encoder.

    Returns:
        Scalar tensor, summed (not averaged) over every element of the batch.
    """
    # How well the reconstruction matches the input.
    recon_term = F.binary_cross_entropy(x_hat, x, reduction='sum')

    # Closed-form KL divergence between N(mean, exp(log_variance)) and N(0, I).
    kl_inner = 1 + log_variance - mean.pow(2) - log_variance.exp()
    kl_term = - 0.5 * torch.sum(kl_inner)

    # The two terms pull against each other: fidelity vs. a well-behaved latent.
    return recon_term + kl_term

# Adam over every parameter of the model, using the configured learning rate.
optimizer = Adam(model.parameters(), lr=learning_rate)

In [None]:
# Train the VAE, recording the per-batch loss for later plotting.
model.train()
all_losses = []

for epoch in range(epochs):
    overall_loss = 0.0
    samples_seen = 0
    for x in tqdm(train_loader, total=len(train_loader)):
        x = x.to(device)

        optimizer.zero_grad()

        x_hat, mean, log_variance = model(x)
        loss = bce_loss(x, x_hat, mean, log_variance)

        loss.backward()
        optimizer.step()

        # Store a plain float: keeping the tensor would keep its whole autograd
        # graph (and device memory) alive for every step of the run.
        loss_value = loss.item()
        all_losses.append(loss_value)
        overall_loss += loss_value
        samples_seen += x.size(0)

    # The loss is summed over samples, so normalise by the number of samples
    # actually processed (the last batch may be short); guard the empty case.
    print(f"epoch {epoch + 1}: average loss {overall_loss / max(samples_seen, 1)}")

In [None]:
# Convert each entry to a Python float first: matplotlib cannot plot tensors
# that live on the GPU or still require grad. Plain floats pass through unchanged.
plt.plot([float(value) for value in all_losses])

## Sampling

In [None]:
# Switch the model to evaluation mode before visualising and sampling.
model.eval()

def show_image(x, idx, figure=True):
    """Display image ``idx`` from a batch of channels-first images.

    Args:
        x: Tensor holding 3-channel images of size ``img_size``, laid out
            channels-first (e.g. ``(batch, 3, H, W)``).
        idx: Index of the image within the batch.
        figure: When true, open a new matplotlib figure first; pass ``False``
            to draw into the current axes (e.g. inside a subplot).
    """
    if figure:
        plt.figure()

    # Recover the (batch, channel, H, W) layout, then move channels last for
    # matplotlib. Viewing the data directly as (H, W, 3) would only reinterpret
    # the memory and scramble the pixels.
    images = x.reshape(-1, 3, *img_size)
    plt.imshow(images[idx].permute(1, 2, 0).detach().cpu().numpy())

def show_comparison(x, x_hat, idx):
    """Plot image ``idx`` of ``x`` next to the same index of ``x_hat``."""
    plt.figure()

    # Left panel: input image; right panel: the model's reconstruction.
    panels = ((x, "Original"), (x_hat, "Reconstruction"))
    for position, (batch, title) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        show_image(batch, idx, False)
        plt.title(title)

In [None]:
# Reconstruct one batch of held-out images without tracking gradients.
with torch.no_grad():
    x = next(iter(test_loader)).to(device)
    x_hat, _, _ = model(x)

# Compare two of the inputs with their reconstructions.
for sample_idx in (1, 4):
    show_comparison(x, x_hat, sample_idx)

In [None]:
# Decode latent vectors drawn from a standard normal into new images.
with torch.no_grad():
    noise = torch.randn(batch_size, latent_dim, device=device)
    generated_images = decoder(noise)

In [None]:
# Display one of the generated samples.
show_image(generated_images, idx=6)

In [None]:
# Display another generated sample.
show_image(generated_images, idx=7)