Choose the device.

In [None]:
import torch

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


Download the CIFAR-10 dataset and build the data loaders.

In [None]:
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Preprocessing shared by both splits, defined once so the training and
# validation pipelines cannot drift apart: convert each image to a tensor, then
# subtract 0.5 from every channel (a std of 1.0 leaves the scale untouched).
cifar_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (1.0, 1.0, 1.0)),
    ]
)

# Both splits are stored under ./data and downloaded on first use.
training_data = datasets.CIFAR10(
    root="data",
    train=True,
    download=True,
    transform=cifar_transform,
)

validation_data = datasets.CIFAR10(
    root="data",
    train=False,
    download=True,
    transform=cifar_transform,
)

batch_size = 64
training_loader = DataLoader(training_data, batch_size, shuffle=True)
# The validation loader is shuffled as well, so taking its first batch for
# visualisation yields a different set of images on each draw.
validation_loader = DataLoader(validation_data, batch_size, shuffle=True)


Build the model.

In [None]:
from go_explore.vae import VQ_VAE

# Instantiate the VQ-VAE with its default arguments, then place it on the
# selected device.
model = VQ_VAE()
model = model.to(device)


Define the optimizer.

In [None]:
from torch import optim

# Adam over every model parameter, with the AMSGrad variant switched off.
optimizer = optim.Adam(
    params=model.parameters(),
    lr=1e-3,
    amsgrad=False,
)


In [None]:
import numpy as np
import torch.nn.functional as F
from torchvision.transforms.functional import resize


# Variance of the raw training pixels after dividing by 255; the reconstruction
# MSE below is divided by this value.
data_variance = np.var(training_data.data / 255.0)

model.train()

for epoch in range(100):
    # Per-batch metrics collected over the current epoch.
    train_res_recon_error = []
    train_res_perplexity = []

    for batch, _ in training_loader:
        # Labels are ignored. Resize to 129x129 first, then move to the device.
        batch = resize(batch, (129, 129)).to(device)

        vq_loss, reconstruction, perplexity = model(batch)
        recon_error = F.mse_loss(reconstruction, batch) / data_variance
        loss = recon_error + vq_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_res_recon_error.append(recon_error.item())
        train_res_perplexity.append(perplexity.item())

    # Report the epoch averages of both metrics.
    mean_recon_error = np.mean(train_res_recon_error)
    mean_perplexity = np.mean(train_res_perplexity)
    print(
        "Epoch {:d}  recon_error: {:.3f}  perplexity: {:.3f}".format(
            epoch, mean_recon_error, mean_perplexity
        )
    )


In [None]:
import matplotlib.pyplot as plt

# Put the model in evaluation mode before running inference.
model.eval()

# Take one validation batch (labels discarded) and resize it to the same
# 129x129 resolution the training loop feeds the model.
# NOTE: `input` shadows the Python builtin; the name is kept because a later
# cell reads it.
input, _ = next(iter(validation_loader))
input = resize(input, (129, 129))
input = input.to(device)

# The reconstructions are only displayed, never back-propagated through, so
# skip building (and retaining) the autograd graph for this forward pass.
with torch.no_grad():
    _, recons, _ = model(input)


def show(img):
    """Display a channel-first image tensor with matplotlib, hiding both axes.

    Args:
        img: Tensor laid out as (channels, height, width). It is transposed to
            (height, width, channels) before being handed to ``plt.imshow``.
            It may live on any device and may be attached to an autograd graph.
    """
    # detach() and cpu() make the NumPy conversion safe for tensors that
    # require grad or live on the GPU; a bare .numpy() raises for both.
    npimg = img.detach().cpu().numpy()
    if np.issubdtype(npimg.dtype, np.floating):
        # imshow expects float RGB data in [0, 1] and otherwise clips it with a
        # warning; clip explicitly so the displayed image is the same but quiet.
        npimg = np.clip(npimg, 0.0, 1.0)
    image = plt.imshow(np.transpose(npimg, (1, 2, 0)), interpolation="nearest")
    image.axes.get_xaxis().set_visible(False)
    image.axes.get_yaxis().set_visible(False)


In [None]:
from torchvision.utils import make_grid

# Undo the -0.5 normalisation shift *before* tiling, the same way the cell that
# displays the original inputs does, so both grids get the same padding colour
# (adding 0.5 after make_grid would also brighten the zero padding to gray).
# detach() replaces the legacy `.data` accessor.
show(make_grid(recons.detach().cpu() + 0.5))


In [None]:
# Shift the original validation batch back by +0.5 (undoing the normalisation),
# then tile it into a single grid image for display.
originals = input.cpu() + 0.5
show(make_grid(originals))
