## Create the environment, the buffer and the VAE

Create environment.

In [1]:
import gym

# Atari "Private Eye" from the Arcade Learning Environment (old gym API:
# reset() returns only the observation, step() returns a 4-tuple).
ENV_ID = "PrivateEye-v4"
env = gym.make(ENV_ID)


  logger.warn(
A.L.E: Arcade Learning Environment (version 0.7.5+db37282)
[Powered by Stella]


Choose the device.

In [None]:
import torch

# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

Create a replay buffer and fill it with transitions from random play.

In [None]:
import numpy as np

from stable_baselines3.common.buffers import ReplayBuffer

BUFFER_SIZE = 10000  # capacity of the replay buffer, in transitions
N_EPISODES = 10  # episodes of uniformly random play used to fill it

buffer = ReplayBuffer(BUFFER_SIZE, env.observation_space, env.action_space, device=device)

for _ in range(N_EPISODES):
    obs = env.reset()
    done = False
    while not done:
        # Uniformly random policy; the buffer is given the action as an array.
        action = np.array(env.action_space.sample())
        next_obs, reward, done, info = env.step(action)
        # `infos` is a list with one dict per environment; here there is a single env.
        buffer.add(obs, next_obs, action, reward, done, [info])
        obs = next_obs

Create a categorical VAE.

In [None]:
from go_explore.vae import CNNCategoricalVAE

# Build the model with its default hyper-parameters, then move it to `device`.
vae = CNNCategoricalVAE()
vae = vae.to(device)

## Test the VAE

Sample a batch of observations.

In [None]:
# Draw 10 random transitions; only their observations are used below.
batch = buffer.sample(10)
input = batch.observations

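Transpose to channels-first, resize, and rescale pixel values to [0, 1]. No explicit device transfer is needed: the buffer was created with `device=device`, so the sampled batch is already on that device.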

In [None]:
from torchvision.transforms.functional import resize

# Move the channel axis from last to third-from-last (channels-first, as
# torchvision's resize expects [..., H, W]), resize, then scale to [0, 1].
# NOTE: this rebinds `input`, so re-run the sampling cell before re-running this one.
input = resize(input.moveaxis(-1, -3), (129, 129))
input = input.float().div(255)

Build the reconstruction.

In [None]:
# Sanity-check forward pass on the sampled batch. The outputs are only
# displayed and passed to the loss for inspection, so no autograd graph is needed.
with torch.no_grad():
    recons, logits = vae(input)

Visualize the result.

In [None]:
from go_explore.utils import build_image

# One row per tensor: the preprocessed inputs, then their reconstructions.
rows = [input, recons]
build_image(rows)

Create the loss function.

In [None]:
import math
from typing import Tuple

import torch
import torch.nn.functional as F


def loss_func(
    input: torch.Tensor,
    recons: torch.Tensor,
    logits: torch.Tensor,
    kl_coef: float = 0.01,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Categorical-VAE loss: reconstruction MSE plus a KL term to a uniform prior.

    Args:
        input: Batch of preprocessed images fed to the VAE (the target).
        recons: Reconstructions of ``input``; same shape as ``input``.
        logits: Unnormalised scores of the latent categorical variables, with the
            classes along dim 2 (presumably ``(batch, n_latents, n_classes)`` —
            TODO confirm against ``CNNCategoricalVAE``).
        kl_coef: Weight of the KL term in the total loss.

    Returns:
        ``(loss, recons_loss, kl_loss)`` with ``loss = recons_loss + kl_coef * kl_loss``.
        ``kl_loss`` is averaged over every element of ``logits``, not summed over classes.
    """
    # Reconstruction loss: prediction first, target second.
    recons_loss = F.mse_loss(recons, input)

    # KL(q || Uniform(K)) = sum_k q_k * (log q_k + log K).
    # log_softmax gives exact, stable log-probabilities (no log(0), no epsilon bias).
    nb_classes = logits.shape[2]
    log_probs = F.log_softmax(logits, dim=2)
    probs = log_probs.exp()
    kl_loss = (probs * (log_probs + math.log(nb_classes))).mean()

    # Total loss
    loss = recons_loss + kl_coef * kl_loss
    return loss, recons_loss, kl_loss


Test the loss function.

In [None]:
# Expect finite values; the KL term is small for a freshly initialised model.
loss, recons_loss, kl_loss = loss_func(input, recons, logits)
print((loss, recons_loss, kl_loss))

In [None]:
import torch
from torch import optim

BATCH_SIZE = 10  # observations sampled per gradient step
N_STEPS = 5000  # number of gradient steps
LOG_EVERY = 1000  # log losses and snapshot a reconstruction every LOG_EVERY steps
VAE_INPUT_SIZE = (129, 129)  # spatial size images are resized to before the VAE

optimizer = optim.Adam(vae.parameters(), lr=2e-4)


def preprocess(observations: torch.Tensor) -> torch.Tensor:
    """Move channels first, resize to VAE_INPUT_SIZE and scale pixels to [0, 1]."""
    return resize(observations.moveaxis(-1, -3), VAE_INPUT_SIZE).float() / 255


# Fixed batch used for visualisation at the end: the first row is this batch,
# each following row is its reconstruction at a logged step.
test_image = preprocess(buffer.sample(BATCH_SIZE).observations)
images = [test_image]

for epoch in range(N_STEPS):  # each "epoch" here is a single gradient step
    # Sample
    input = preprocess(buffer.sample(BATCH_SIZE).observations)

    # Compute the output image
    vae.train()
    recons, logits = vae(input)

    # Compute the loss
    loss, recons_loss, kl_loss = loss_func(input, recons, logits)

    # Step the optimizer
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Also log/snapshot after the last step so the final row shows the trained model.
    if epoch % LOG_EVERY == 0 or epoch == N_STEPS - 1:
        print(
            "epoch: {:5d}\tloss: {:.5f}\trecons loss: {:.5f}\tkl loss: {:.5f}".format(
                epoch, loss.item(), recons_loss.item(), kl_loss.item()
            ),
        )
        vae.eval()
        # Visualisation only: without no_grad each snapshot would keep its whole
        # autograd graph alive inside `images` for the rest of the notebook.
        with torch.no_grad():
            images.append(vae(test_image)[0])


Visualize the result. The first row is a fixed batch of images from the buffer, and each following row is its reconstruction at a logged step during training.

In [None]:
# First row: the fixed test batch; each following row: its reconstruction at a logged training step.
build_image(images)