[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/xiptos/generative/blob/main/notebooks/vae_generator.ipynb)

In [None]:
# Mount Google Drive (Colab only) so later cells can reach files under /content/gdrive.
import google.colab.drive as drive

drive.mount('/content/gdrive')

In [None]:
# Work inside the "synthetic" folder on the mounted Drive; the repository is cloned here.
%cd /content/gdrive/MyDrive/synthetic/

In [None]:
# Clone the project repository into the current directory (the Drive folder selected above).
!git clone https://github.com/xiptos/generative.git

In [None]:
# Run from the notebooks folder: the checkpoint path below is relative ('../models/...'),
# and presumably the `modules` / `vae_train` imports rely on this working directory — confirm.
%cd generative/notebooks

# Variational Autoencoder for Image Generation

In [None]:
# based on https://github.com/AntixK/PyTorch-VAE/blob/master/models/vanilla_vae.py
import torch
from torch.utils.data import DataLoader

from modules.dataset.pixelart_dataset import FilteredDatasetOneHot, PixelartDataset
from modules.vae.vae import PixelVAE
from vae_train import IMAGE_SIZE, LATENT_DIM, transform

from modules.dataset.pixelart_dataset import onehot2label
import numpy as np

import matplotlib.pyplot as plt

In [None]:
# Trained VAE weights (a state_dict), resolved relative to the notebooks/ working directory.
MODEL_FILE = "../models/vae_model_50.pth"

In [None]:
# Prefer CUDA, then Apple MPS, then fall back to the CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

# Class ids kept by the dataset filter.
TARGET_CLASSES = [4]
# Number of reconstructions of the sampled image appended after the original.
NUM_RECONSTRUCTIONS = 1

dataset = FilteredDatasetOneHot(PixelartDataset(transform=transform), target_classes=TARGET_CLASSES)
# batch_size=1 with shuffle=True: each run draws one random image.
loader = DataLoader(dataset=dataset, batch_size=1, shuffle=True)

# NOTE(review): `latend_dim` (sic) must match PixelVAE's constructor keyword;
# check modules/vae/vae.py before "fixing" the spelling here.
model = PixelVAE(image_size=IMAGE_SIZE, latend_dim=LATENT_DIM).to(device)
model.load_state_dict(torch.load(MODEL_FILE, map_location=device))
model.eval()

# Take one random image. next(iter(...)) fails loudly on an empty dataset.
# A for/break loop would instead leave `pics` undefined, or stale from an
# earlier notebook run.
try:
    pic, _ = next(iter(loader))
except StopIteration:
    raise RuntimeError(
        f"No images found for target_classes={TARGET_CLASSES}; cannot sample a picture."
    ) from None
pics = pic.to(device)
orig = torch.clone(pics)  # the sampled image, batch of size 1

# Pure inference: no autograd graph is needed. Without no_grad, the tensors
# require grad and must be detached before any NumPy conversion.
with torch.no_grad():
    for _ in range(NUM_RECONSTRUCTIONS):
        # Only the reconstruction of the original is used, so feed just the
        # original rather than the growing batch.
        recon, mu, log = model(orig)
        # Reshape to the input's own (1, C, H, W) shape instead of hard-coding
        # 3 channels, so it concatenates cleanly onto `pics`.
        pic = recon[0].reshape_as(orig)
        pics = torch.cat((pics, pic), dim=0)


In [None]:

# Show individual images with a title.
def imshow(img, ax, title):
    """Draw one channel-first (C, H, W) image tensor on a matplotlib Axes.

    Args:
        img: image tensor, normalised to [-1, 1]. It may live on any device
            and may require grad.
        ax: matplotlib Axes to draw on.
        title: text shown above the image (font size 8). Axis ticks are hidden.
    """
    img = img / 2 + 0.5  # undo the [-1, 1] normalisation -> [0, 1]
    # .numpy() raises for CUDA/MPS tensors and for tensors that require grad,
    # so detach and move to the CPU first.
    npimg = img.detach().cpu().numpy()
    # Out-of-range values would make matplotlib warn and clip anyway; clip
    # explicitly so the display is deterministic and warning-free.
    npimg = np.clip(npimg, 0.0, 1.0)
    # matplotlib expects channel-last (H, W, C).
    ax.imshow(np.transpose(npimg, (1, 2, 0)))
    ax.set_title(title, fontsize=8)
    ax.axis('off')

# Lay out every tensor in `pics` in a grid of at most 8 columns.
batch_size = pics.size(0)
cols = min(8, batch_size)
rows = (batch_size + cols - 1) // cols  # ceiling division

# squeeze=False always yields a 2-D array of Axes. A plain 1x1 subplots()
# returns a bare Axes, which has no .flatten().
fig, axs = plt.subplots(rows, cols, figsize=(cols * 2, rows * 2), squeeze=False)
axs = axs.flatten()

for i in range(batch_size):
    # The tensors may track gradients and may sit on CUDA/MPS; the NumPy
    # conversion inside imshow needs a detached CPU tensor.
    imshow(pics[i].detach().cpu(), axs[i], "sample")

# Hide any leftover subplots in the last row that received no image.
for ax in axs[batch_size:]:
    ax.axis('off')

plt.tight_layout()
plt.show()