
# **VAE-GAN Project with the Cat Faces Dataset (64×64)**  
### *Assignment 4 - Deep Learning*

**Course:** Deep Learning  
**Student:** Hermes Winarski  
**Model:** VAE-GAN (Variational Autoencoder + Generative Adversarial Network)  
**Dataset:** [Cat Faces Dataset - Kaggle](https://www.kaggle.com/datasets/veeralakrishna/cat-faces-dataset)

---

## 🧠 Introduction
The goal of this project is to develop a **VAE-GAN** model capable of generating new, realistic images of cat faces from the *Cat Faces (64×64)* dataset.  
The VAE-GAN combines two classic deep learning approaches to image generation:

- **Variational Autoencoder (VAE):** learns to represent images in a continuous, probabilistic latent space, enabling reconstruction and interpolation between samples.  
- **Generative Adversarial Network (GAN):** pits two models against each other, the generator and the discriminator, to produce increasingly realistic images.

Combining the two architectures aims to get the best of both worlds: the **smooth latent structure of the VAE** and the **sharpness and realism of the GAN**.
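
As a compact summary of how the two parts are combined, the losses used in the training loop of this notebook can be written roughly as below, up to the per-element averaging applied in the code. The symbols are notation introduced here for illustration: $E$ is the encoder, $G$ the decoder, $D$ the discriminator logit, $y$ the smoothed real label (0.9 in the code), and $\mathrm{BCE}$ the binary cross-entropy on logits. The weights $\beta$ and $\lambda$ correspond to `beta_kl(epoch)` and `LAMBDA_GAN`.

$$
\mathcal{L}_{E,G} = \lVert x - \hat{x} \rVert_1 + \beta \, D_{\mathrm{KL}}\big(q_E(z \mid x) \,\|\, \mathcal{N}(0, I)\big) + \lambda \, \mathrm{BCE}\big(D(\hat{x}),\, y\big), \qquad \hat{x} = G(z), \quad z \sim q_E(z \mid x)
$$

$$
\mathcal{L}_{D} = \tfrac{1}{2} \Big[ \mathrm{BCE}\big(D(x),\, y\big) + \mathrm{BCE}\big(D(\hat{x}),\, 0\big) \Big]
$$

This is the simplified form implemented here: reconstruction is a pixel-wise L1 distance, and the fake images shown to the discriminator are reconstructions $\hat{x}$ rather than samples decoded from the prior.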

---


## ⚙️ Development
The full contents of the development notebook follow, **with outputs preserved**.


# VAE-GAN Starter: Cat Faces 64×64 (Kaggle)
Dataset: [Cat Faces (veeralakrishna)](https://www.kaggle.com/datasets/veeralakrishna/cat-faces-dataset)



In [13]:

import os
import math
import random
from glob import glob
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms as T
import torchvision.utils as vutils
from tqdm import tqdm

SEED = 1337
random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)


Device: cuda


In [14]:

# Paths (Kaggle mounts datasets under /kaggle/input/...)
KAGGLE_DS_PATH = '/kaggle/input/cat-faces-dataset'
LOCAL_FALLBACK = './data/cats'  # optional, for local use

if os.path.isdir(KAGGLE_DS_PATH):
    DATA_ROOT = KAGGLE_DS_PATH
else:
    DATA_ROOT = LOCAL_FALLBACK

IMG_EXTS = ('.jpg', '.jpeg', '.png', '.bmp', '.webp')

def list_images(root: str) -> List[str]:
    files = []
    for ext in IMG_EXTS:
        files.extend(glob(os.path.join(root, f'**/*{ext}'), recursive=True))
    return sorted(files)

all_imgs = list_images(DATA_ROOT)
print(f'Found {len(all_imgs)} images under {DATA_ROOT}')
assert len(all_imgs) > 0, "No images found. In Kaggle, click '+ Add data' and attach 'cat-faces-dataset'."


Found 29843 images under /kaggle/input/cat-faces-dataset


In [15]:

IMG_SIZE = 64

transform = T.Compose([
    T.Lambda(lambda im: im.convert('RGB')),  # ensure RGB
    T.Resize(IMG_SIZE, interpolation=Image.BICUBIC),
    T.CenterCrop(IMG_SIZE),
    T.RandomHorizontalFlip(p=0.5),
    T.ToTensor(),
    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

class ImageListDataset(Dataset):
    def __init__(self, paths: List[str], transform=None):
        self.paths = paths
        self.transform = transform
    def __len__(self):
        return len(self.paths)
    def __getitem__(self, idx):
        p = self.paths[idx]
        img = Image.open(p)
        if self.transform:
            img = self.transform(img)
        return img

# Simple split
val_ratio = 0.05
val_count = max(1, int(len(all_imgs) * val_ratio))
val_paths = all_imgs[:val_count]
train_paths = all_imgs[val_count:]

train_ds = ImageListDataset(train_paths, transform)
val_ds   = ImageListDataset(val_paths,   transform)

BATCH_SIZE = 128
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True, drop_last=True)
val_dl   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True, drop_last=False)

len(train_ds), len(val_ds)


(28351, 1492)
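
The split above takes the first 5% of the sorted file list as validation data, so the validation set follows file-name order. If that ordering happens to correlate with image content, a seeded shuffle before splitting is one possible alternative. The sketch below uses only the standard library; `shuffled_split` and the dummy file names are hypothetical and not part of the notebook.

```python
import random
from typing import List, Tuple

def shuffled_split(paths: List[str], val_ratio: float = 0.05,
                   seed: int = 1337) -> Tuple[List[str], List[str]]:
    """Return (train_paths, val_paths) after a reproducible shuffle."""
    shuffled = list(paths)                 # copy so the caller's list is untouched
    random.Random(seed).shuffle(shuffled)  # local RNG: global random state is unaffected
    val_count = max(1, int(len(shuffled) * val_ratio))
    return shuffled[val_count:], shuffled[:val_count]

# Dummy stand-ins for the 29843 image paths reported above.
dummy_paths = [f"cat_{i:05d}.jpg" for i in range(29843)]
train_paths_demo, val_paths_demo = shuffled_split(dummy_paths)
print(len(train_paths_demo), len(val_paths_demo))
```

The resulting sizes match the `(28351, 1492)` shown above; only the membership of each subset changes. In the notebook this would replace the two slicing lines. A validation transform without `RandomHorizontalFlip` might also be worth considering, so that reconstructions are compared on unaugmented images.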

In [None]:

# VAE-GAN (DCGAN-style)
Z_DIM = 128
G_CH  = 64   # base channels Generator/Decoder
D_CH  = 64   # base channels Discriminator

class Encoder(nn.Module):
    # x -> (mu, logvar)
    def __init__(self, z_dim=Z_DIM, ch=D_CH):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, ch, 4, 2, 1, bias=False),      # 32x32
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(ch, ch*2, 4, 2, 1, bias=False),   # 16x16
            nn.BatchNorm2d(ch*2),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(ch*2, ch*4, 4, 2, 1, bias=False), # 8x8
            nn.BatchNorm2d(ch*4),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(ch*4, ch*8, 4, 2, 1, bias=False), # 4x4
            nn.BatchNorm2d(ch*8),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.mu     = nn.Conv2d(ch*8, z_dim, 4, 1, 0)  # [B,z,1,1]
        self.logvar = nn.Conv2d(ch*8, z_dim, 4, 1, 0)

    def forward(self, x):
        h = self.net(x)
        mu = self.mu(h).squeeze(-1).squeeze(-1)
        logvar = self.logvar(h).squeeze(-1).squeeze(-1)
        # avoid numerical overflow in the KL term
        logvar = torch.clamp(logvar, -10.0, 10.0)
        return mu, logvar


def reparameterize(mu, logvar):
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std

class Decoder(nn.Module):
    # z -> x_hat (tanh)
    def __init__(self, z_dim=Z_DIM, ch=G_CH):
        super().__init__()
        self.net = nn.Sequential(
            nn.ConvTranspose2d(z_dim, ch*8, 4, 1, 0, bias=False), # 4x4
            nn.BatchNorm2d(ch*8),
            nn.ReLU(True),

            nn.ConvTranspose2d(ch*8, ch*4, 4, 2, 1, bias=False),  # 8x8
            nn.BatchNorm2d(ch*4),
            nn.ReLU(True),

            nn.ConvTranspose2d(ch*4, ch*2, 4, 2, 1, bias=False),  # 16x16
            nn.BatchNorm2d(ch*2),
            nn.ReLU(True),

            nn.ConvTranspose2d(ch*2, ch, 4, 2, 1, bias=False),    # 32x32
            nn.BatchNorm2d(ch),
            nn.ReLU(True),

            nn.ConvTranspose2d(ch, 3, 4, 2, 1, bias=False),       # 64x64
            nn.Tanh(),
        )

    def forward(self, z):
        if z.dim() == 2:
            z = z.unsqueeze(-1).unsqueeze(-1)
        return self.net(z)

class Discriminator(nn.Module):
    # x -> real/fake logit
    def __init__(self, ch=D_CH):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, ch, 4, 2, 1, bias=False),      # 32x32
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(ch, ch*2, 4, 2, 1, bias=False),   # 16x16
            nn.BatchNorm2d(ch*2),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(ch*2, ch*4, 4, 2, 1, bias=False), # 8x8
            nn.BatchNorm2d(ch*4),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(ch*4, ch*8, 4, 2, 1, bias=False), # 4x4
            nn.BatchNorm2d(ch*8),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(ch*8, 1, 4, 1, 0, bias=False),    # 1x1
        )

    def forward(self, x):
        return self.net(x).view(x.size(0), -1)

# Instantiate
enc = Encoder().to(device)
dec = Decoder().to(device)
dis = Discriminator().to(device)

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1 or classname.find('ConvTranspose') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

enc.apply(weights_init)
dec.apply(weights_init)
dis.apply(weights_init)

lr_G = 2e-4
lr_D = 2e-4
betas = (0.5, 0.999)

opt_E = torch.optim.Adam(enc.parameters(), lr=lr_G, betas=betas)
opt_G = torch.optim.Adam(dec.parameters(), lr=lr_G, betas=betas)
opt_D = torch.optim.Adam(dis.parameters(), lr=lr_D, betas=betas)

adv_criterion = nn.BCEWithLogitsLoss()
recon_criterion = nn.L1Loss()

BETA_KL = 1e-3
LAMBDA_GAN = 1.0
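
The spatial sizes annotated in the comments above (`# 32x32`, `# 16x16`, and so on) follow from the standard output-size formulas for `nn.Conv2d` and `nn.ConvTranspose2d`, assuming dilation 1 and no output padding. The helper names below are hypothetical; this is only a quick arithmetic check, not part of the model.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    # Conv2d output size for dilation=1: floor((size + 2p - k) / s) + 1
    return (size + 2 * padding - kernel) // stride + 1

def deconv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    # ConvTranspose2d output size for dilation=1, output_padding=0:
    # (size - 1) * s - 2p + k
    return (size - 1) * stride - 2 * padding + kernel

# (kernel, stride, padding) per layer, as in the Encoder/Discriminator and the Decoder.
down_layers = [(4, 2, 1)] * 4 + [(4, 1, 0)]
up_layers = [(4, 1, 0)] + [(4, 2, 1)] * 4

down_sizes = [64]
for k, s, p in down_layers:
    down_sizes.append(conv_out(down_sizes[-1], k, s, p))

up_sizes = [1]
for k, s, p in up_layers:
    up_sizes.append(deconv_out(up_sizes[-1], k, s, p))

print(down_sizes)  # [64, 32, 16, 8, 4, 1]
print(up_sizes)    # [1, 4, 8, 16, 32, 64]
```

The final 4×4 convolution with stride 1 and no padding is what collapses the 4×4 feature map to a single position, which is why `mu`, `logvar`, and the discriminator logit can be flattened to one vector per image.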


In [None]:
import datetime as dt
from tqdm import tqdm

# ---- "safe" hyperparameters ----
EPOCHS = 100
BATCH_SIZE = BATCH_SIZE  # no-op: keep the value already in use
lr_G = 1e-4
lr_D = 1e-4
for opt in (opt_E, opt_G, opt_D):
    for g in opt.param_groups:
        if g['lr'] in (2e-4, 0.0002):  # still at the old value, so lower it
            g['lr'] = lr_G if opt in (opt_E, opt_G) else lr_D

# ---- KL weight schedule (warm-up) ----
BETA_KL_BASE = 1e-3
def beta_kl(epoch):
    # linear ramp up to BETA_KL_BASE over the first 5 epochs
    # (epochs start at 1, so the first value is 0.2 * BETA_KL_BASE, not 0)
    return BETA_KL_BASE * min(1.0, epoch / 5.0)

# ---- utilities ----
def denorm(x):  # [-1,1] -> [0,1]
    return (x + 1) / 2

os.makedirs('/kaggle/working/samples', exist_ok=True)
os.makedirs('/kaggle/working', exist_ok=True)
fix_z = torch.randn(64, Z_DIM, device=device)

PRINT_EVERY = 25   # heartbeat interval, in batches
CLIP_NORM = 5.0    # gradient-norm clipping threshold

for epoch in range(1, EPOCHS + 1):
    enc.train(); dec.train(); dis.train()
    total_D = total_G = total_rec = total_kl = 0.0

    pbar = tqdm(
        train_dl,
        desc=f"Epoch {epoch}/{EPOCHS} ({len(train_ds)} imgs)",
        dynamic_ncols=True,
        mininterval=1.0,
        smoothing=0.0,
        leave=False
    )

    for step, x in enumerate(pbar, 1):
        x = x.to(device, non_blocking=True)
        b = x.size(0)

        # -------------------------
        # 1) Update Discriminator
        # -------------------------
        opt_D.zero_grad(set_to_none=True)

        with torch.no_grad():
            mu, logvar = enc(x)
            std = torch.exp(0.5 * logvar)
            z = mu + std * torch.randn_like(std)
            x_hat = dec(z)

        d_real = dis(x)
        d_fake = dis(x_hat.detach())

        # label smoothing helps stabilize training
        valid = 0.9 * torch.ones_like(d_real, device=device)
        fake  = torch.zeros_like(d_fake, device=device)

        loss_D_real = adv_criterion(d_real, valid)
        loss_D_fake = adv_criterion(d_fake,  fake)
        loss_D = 0.5 * (loss_D_real + loss_D_fake)

        # NaN/Inf check
        if torch.isnan(loss_D) or torch.isinf(loss_D):
            raise RuntimeError("Loss D became NaN/Inf: adjust the LR or check the data.")

        loss_D.backward()
        torch.nn.utils.clip_grad_norm_(dis.parameters(), CLIP_NORM)
        opt_D.step()

        # -------------------------
        # 2) Update Encoder + Decoder (VAE + GAN)
        # -------------------------
        opt_E.zero_grad(set_to_none=True)
        opt_G.zero_grad(set_to_none=True)

        mu, logvar = enc(x)
        std = torch.exp(0.5 * logvar)
        z = mu + std * torch.randn_like(std)
        x_hat = dec(z)

        # reconstruction + KL
        loss_rec = recon_criterion(x_hat, x)
        kl = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())

        # adversarial (we want D(x_hat) -> real)
        d_fake_for_G = dis(x_hat)
        loss_G_adv = adv_criterion(d_fake_for_G, valid)

        bk = beta_kl(epoch)
        loss_G_total = loss_rec + bk * kl + LAMBDA_GAN * loss_G_adv

        # NaN/Inf check
        for name, val in (("rec", loss_rec), ("kl", kl), ("G_adv", loss_G_adv), ("G_total", loss_G_total)):
            if torch.isnan(val) or torch.isinf(val):
                raise RuntimeError(f"Loss {name} became NaN/Inf: reduce the LR, enable clamping and grad clipping.")

        loss_G_total.backward()
        torch.nn.utils.clip_grad_norm_(list(enc.parameters()) + list(dec.parameters()), CLIP_NORM)
        opt_E.step(); opt_G.step()

        total_D   += loss_D.item()    * b
        total_G   += loss_G_adv.item()* b
        total_rec += loss_rec.item()  * b
        total_kl  += kl.item()        * b

        # update the progress bar
        pbar.set_postfix({
            "D":     f"{loss_D.item():.3f}",
            "G_adv": f"{loss_G_adv.item():.3f}",
            "rec":   f"{loss_rec.item():.3f}",
            "kl":    f"{kl.item():.3f}",
            "β":     f"{bk:.4f}"
        })

        # heartbeat (prints every N batches)
        if step % PRINT_EVERY == 0:
            print(f"[{dt.datetime.now().strftime('%H:%M:%S')}] "
                  f"ep={epoch} step={step}/{len(train_dl)} "
                  f"D={loss_D.item():.3f} G_adv={loss_G_adv.item():.3f} "
                  f"rec={loss_rec.item():.3f} kl={kl.item():.3f}", flush=True)

    # -------------------------
    # 3) Validation samples + checkpoint
    # -------------------------
    enc.eval(); dec.eval()
    with torch.no_grad():
        # recon
        x_val = next(iter(val_dl)).to(device)
        mu, logvar = enc(x_val)
        std = torch.exp(0.5 * logvar)
        z = mu + std * torch.randn_like(std)
        x_hat = dec(z)
        grid = torch.cat([x_val[:32], x_hat[:32]], dim=0)
        vutils.save_image(denorm(grid), f"/kaggle/working/samples/recon_e{epoch:03d}.png", nrow=16)

        # samples
        gen = dec(fix_z)
        vutils.save_image(denorm(gen), f"/kaggle/working/samples/samples_e{epoch:03d}.png", nrow=8)

    n = len(train_dl) * BATCH_SIZE  # samples actually seen: drop_last=True skips the final partial batch
    print(f"Epoch {epoch}: D={total_D/n:.4f} | G_adv={total_G/n:.4f} | rec={total_rec/n:.4f} | kl={total_kl/n:.4f}")

    # per-epoch checkpoint
    torch.save({
        "epoch": epoch,
        "enc": enc.state_dict(),
        "dec": dec.state_dict(),
        "dis": dis.state_dict(),
        "opt_E": opt_E.state_dict(),
        "opt_G": opt_G.state_dict(),
        "opt_D": opt_D.state_dict(),
    }, f"/kaggle/working/checkpoint_e{epoch:03d}.pt")

print("Training done. Check /kaggle/working/samples and /kaggle/working/*.pt")
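
To make the KL warm-up in the training loop concrete, the sketch below evaluates the same schedule for the first few epochs. `WARMUP_EPOCHS` is a hypothetical name for the constant `5.0` hard-coded in `beta_kl`; everything else mirrors the notebook.

```python
BETA_KL_BASE = 1e-3
WARMUP_EPOCHS = 5  # hypothetical name for the constant 5.0 used in beta_kl

def beta_kl(epoch: int) -> float:
    # Linear ramp, capped at BETA_KL_BASE once epoch >= WARMUP_EPOCHS
    return BETA_KL_BASE * min(1.0, epoch / WARMUP_EPOCHS)

# Epochs are 1-based in the training loop, so the schedule starts at epoch 1.
schedule = {epoch: beta_kl(epoch) for epoch in range(1, 8)}
for epoch, beta in schedule.items():
    print(f"epoch {epoch}: beta = {beta:.4f}")
```

Because the loop counts epochs from 1, the KL weight starts at 0.0002 rather than 0 and reaches 0.001 from epoch 5 onward. Starting exactly at 0 would require passing `epoch - 1` instead, which is a design choice rather than a correction.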



---

## 📊 Conclusion
After training, the **VAE-GAN** learned to produce reasonable sketches that capture the main features of cat faces, generating new, blurry images when reconstructing existing ones.  
Generation from scratch was a different story: even at the modest resolution of **64×64 pixels**, the model did not perform well and failed to capture the patterns needed to create images from random latent vectors.

*Hermes Winarski, Deep Learning | Assignment 4*
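
One detail that may be relevant when interpreting the weak samples from the prior, offered as a hypothesis rather than a finding of this notebook: both `nn.L1Loss()` and the `torch.mean` in the KL term average over every element, so the two terms are normalized by different counts (values per image versus latent dimensions). The arithmetic below converts the KL weight into the equivalent weight of a formulation that sums over each image and each latent vector. The variable names are hypothetical.

```python
IMG_CHANNELS, IMG_SIZE, Z_DIM = 3, 64, 128
BETA_KL_BASE = 1e-3

values_per_image = IMG_CHANNELS * IMG_SIZE * IMG_SIZE  # elements averaged by nn.L1Loss
latent_dims = Z_DIM                                    # elements averaged in the KL term

# mean_rec + beta * mean_kl == (sum_rec + beta_summed * sum_kl) / values_per_image
beta_summed = BETA_KL_BASE * values_per_image / latent_dims

print(values_per_image, latent_dims, round(beta_summed, 6))
```

Under these assumptions the KL term carries a weight of about 0.096 in summed terms, below the weight of 1 used in a standard VAE objective, which favors reconstruction over matching the prior. Whether this explains the poor generation from scratch would need an experiment, for example training with a larger `BETA_KL_BASE` and comparing the samples decoded from `fix_z`.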
