In [1]:
import gymnasium as gym
import gym_donkeycar

import os
from ruamel.yaml import YAML
import numpy as np
from matplotlib import pyplot as plt
import ipywidgets as widgets
from ipywidgets import HBox, VBox
from IPython.display import display
from tqdm import tqdm

from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import distributions as dist
from torch.distributions import Normal, Categorical

import torchvision
from torchvision import transforms

from tensorboard import notebook
from torch.utils.tensorboard import SummaryWriter

# Set the environment variable to suppress TensorFlow warning
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# custom classes and functions
import models
from models.blocks import ConvBlock, TransposeConvBlock, ResConvBlock, CategoricalStraightThrough
from models.mlp import MLP
from models.categorical_vae import CategoricalVAE
from models.sequential_categorical_vae import SeqCatVAE
from preprocessing import grayscale_transform as transform

torch.cuda.empty_cache()
%matplotlib inline

  logger.warn(f"Overriding environment {new_spec.id} already in registry.")


## Load Hyperparameters from YAML config

In [2]:
# Parse the hyperparameters with the safe loader (no arbitrary object construction).
yaml = YAML(typ='safe')
with open("./config.yaml", "r") as file:
    config = yaml.load(file)

# Expose every hyperparameter as a notebook-level variable (e.g. `H`, `n_episodes`).
# `globals()` is the documented way to do this; writing to `locals()` only works
# by accident at module scope and would silently do nothing inside a function.
globals().update(config)

# set dependent parameters
# Use the GPU only if one is available AND the config explicitly asks for "cuda:0".
device = torch.device("cuda:0" if torch.cuda.is_available() and config["device"] == "cuda:0" else "cpu")
# Action dimensionality: 3 for the toy environment, 2 otherwise.
A = 3 if config["toy_env"] else 2
# Size of the flattened stochastic latent (number of categoricals x classes each).
Z = config["num_categoricals"] * config["num_classes"]

config

{'debug': False,
 'logdir': 'logs/',
 'seed': 0,
 'precision': 32,
 'device': 'cuda:0',
 'size': [128, 128],
 'grayscale': True,
 'toy_env': False,
 'n_episodes': 5000,
 'max_episode_steps': 100,
 'env_id': 'donkey-minimonaco-track-v0',
 'max_grad_norm': 100,
 'batch_size': 8,
 'H': 512,
 'num_categoricals': 32,
 'num_classes': 32}

## Utility

In [3]:
def to_np(x):
    """Return tensor `x` as a NumPy array.

    The tensor is detached from the autograd graph and moved to the CPU first,
    so this also works for CUDA tensors and tensors that require grad.
    """
    return x.detach().cpu().numpy()

## VAE

In [4]:
# Build a categorical VAE with its default arguments, move it to the
# configured device and print the summary provided by its `info()` method.
vae = CategoricalVAE()
vae = vae.to(device)
vae.info()

Initializing encoder:
- adding ConvBlock((1, 16))                   ==> output shape: (16, 64, 64) ==> prod: 65536
- adding ConvBlock((16, 32))                   ==> output shape: (32, 32, 32) ==> prod: 32768
- adding ConvBlock((32, 64))                   ==> output shape: (64, 16, 16) ==> prod: 16384
- adding ConvBlock((64, 128))                   ==> output shape: (128, 8, 8) ==> prod: 8192
- adding ConvBlock((128, 64))                   ==> output shape: (64, 4, 4) ==> prod: 1024
- adding Flatten()
- adding Reshape: (*,1024) => (*,32,32)

Initializing decoder:
- adding Reshape: (*,1024) => (*,64,4,4)
- adding transpose ConvBlock(64, 64)                   ==> output shape: (64, 8, 8) ==> prod: 4096
- adding transpose ConvBlock(64, 128)                   ==> output shape: (128, 16, 16) ==> prod: 32768
- adding transpose ConvBlock(128, 64)                   ==> output shape: (64, 32, 32) ==> prod: 65536
- adding transpose ConvBlock(64, 32)                   ==> output shape: (32, 64, 6

## RSSM

In [5]:
class RSSM(nn.Module):
    """Recurrent world model combining a categorical VAE, a GRU and three MLP heads.

    The class reads the notebook-level constants `A` (action size), `H` (GRU
    hidden size), `Z` (flattened latent size), `num_categoricals` and
    `num_classes`; they must be defined before the class is instantiated/used.
    """

    def __init__(self):
        super().__init__()

        # init the VAE
        self.vae = CategoricalVAE(features=H+Z)

        # init the RNN
        self.num_rnn_layers = 1
        self.rnn = nn.GRU(input_size=A+H+Z, hidden_size=H, num_layers=self.num_rnn_layers)

        # init MLPs
        self.dynamics_mlp = MLP(input_dims=H, output_dims=Z) # H -> Z
        self.reward_mlp = MLP(input_dims=H+Z, output_dims=1) # state (H+Z) -> 1
        self.continue_mlp = MLP(input_dims=H+Z, output_dims=1) # state (H+Z) -> 1, sigmoid is applied in step()

    def step(self, action, h, z):
        """Advance the recurrent state by one step and run the prediction heads.

        Args:
            action: action tensor, concatenated with `h` and `z` along dim 1.
            h: previous GRU hidden state; it is reshaped with `view(-1, H)`.
            z: flattened stochastic latent, concatenated along dim 1.

        Returns:
            Tuple `(h, reward_pred, continue_prob, continue_pred, x_reconstruction)`:
            the new hidden state returned by the GRU, the reward head output, the
            sigmoid of the continue head, its thresholded Python bool, and the
            output of `self.vae.decode(h, z)`.
        """
        # concatenate the rnn_input and apply RNN to obtain the next hidden state
        rnn_input = torch.cat((action, h.view(-1, H), z), 1)
        _, h = self.rnn(rnn_input, h.view(-1, H))

        state = torch.cat((h.view(-1, H), z), 1)

        # predict the reward and continue flag
        # (use this instance's sub-modules, not the notebook-level `rssm` object)
        reward_pred = self.reward_mlp(state)
        continue_prob = torch.sigmoid(self.continue_mlp(state)) # binary classification
        # NOTE: bool() only accepts a single-element tensor, so this line
        # restricts step() to one sample at a time.
        continue_pred = bool(continue_prob > 0.5)

        x_reconstruction = self.vae.decode(h, z)

        return h, reward_pred, continue_prob, continue_pred, x_reconstruction

    def get_losses(self,
                   x_target, x_pred,
                   reward_target, reward_pred,
                   continue_target, continue_prob,
                   z_pred, z):
        """Compute the individual world-model losses and their weighted sum.

        Args:
            x_target, x_pred: original image and its reconstruction (MSE).
            reward_target, reward_pred: true and predicted reward (MSE on the
                squeezed tensors).
            continue_target, continue_prob: continue label and predicted
                probability (binary cross-entropy on the squeezed tensors).
            z_pred, z: prior and posterior latents; both are reshaped to
                `(-1, num_categoricals, num_classes)` and passed as `probs`
                to `OneHotCategorical`.

        Returns:
            Dict with the keys "loss", "image_loss", "reward_loss",
            "continue_loss", "dyn_loss" and "rep_loss".
        """
        image_loss = F.mse_loss(x_target, x_pred, reduction="mean")
        reward_loss = F.mse_loss(reward_target.squeeze(), reward_pred.squeeze(), reduction="mean")
        continue_loss = F.binary_cross_entropy(continue_prob.squeeze(), continue_target.squeeze())

        # DreamerV3 KL losses: regularize the posterior (z) towards the prior (z_pred)
        kld = dist.kl.kl_divergence

        # define the distributions with grad
        dist_z = dist.OneHotCategorical(probs=z.view(-1, num_categoricals, num_classes))
        dist_z_pred = dist.OneHotCategorical(probs=z_pred.view(-1, num_categoricals, num_classes))

        # define the distributions without grad (stop-gradient copies)
        dist_z_sg = dist.OneHotCategorical(probs=z.detach().view(-1, num_categoricals, num_classes))
        dist_z_pred_sg = dist.OneHotCategorical(probs=z_pred.detach().view(-1, num_categoricals, num_classes))

        # calculate the mean KL-divergence across the categoricals;
        # torch.max(1, .) lower-bounds each term at 1
        dyn_loss = torch.max(torch.tensor(1), torch.mean(kld(dist_z_sg, dist_z_pred)))
        rep_loss = torch.max(torch.tensor(1), torch.mean(kld(dist_z, dist_z_pred_sg)))

        # calculate the combined loss
        loss = 1.0 * (image_loss + reward_loss + continue_loss) + 0.5 * dyn_loss + 0.1 * rep_loss

        return {"loss": loss, "image_loss": image_loss, "reward_loss": reward_loss,
                "continue_loss": continue_loss, "dyn_loss": dyn_loss, "rep_loss": rep_loss}

In [6]:
# World model, moved to the configured device
rssm = RSSM().to(device)

# Adam optimizer over all world-model parameters;
# the small weight decay acts as an L2 regularizer
rssm_optim = optim.Adam(rssm.parameters(), lr=1e-4, weight_decay=1e-6)

# Critic and actor heads. Both take an input of size Z (the flattened latent z).
value_net = MLP(input_dims=Z, output_dims=1).to(device) # z (Z) -> 1
policy_net = MLP(input_dims=Z, output_dims=A).to(device) # z (Z) -> A

Initializing encoder:
- adding ConvBlock((1, 16))                   ==> output shape: (16, 64, 64) ==> prod: 65536
- adding ConvBlock((16, 32))                   ==> output shape: (32, 32, 32) ==> prod: 32768
- adding ConvBlock((32, 64))                   ==> output shape: (64, 16, 16) ==> prod: 16384
- adding ConvBlock((64, 128))                   ==> output shape: (128, 8, 8) ==> prod: 8192
- adding ConvBlock((128, 64))                   ==> output shape: (64, 4, 4) ==> prod: 1024
- adding Flatten()
- adding Reshape: (*,1024) => (*,32,32)

Initializing decoder:
- adding Reshape: (*,1024) => (*,64,4,4)
- adding transpose ConvBlock(64, 64)                   ==> output shape: (64, 8, 8) ==> prod: 4096
- adding transpose ConvBlock(64, 128)                   ==> output shape: (128, 16, 16) ==> prod: 32768
- adding transpose ConvBlock(128, 64)                   ==> output shape: (64, 32, 32) ==> prod: 65536
- adding transpose ConvBlock(64, 32)                   ==> output shape: (32, 64, 6

## Training loop

In [None]:
""" training loop """

rssm.train()

# Create the environment. The episode length comes from the config in both branches.
if toy_env:
    assert A == 3
    env = gym.make("CarRacing-v2", max_episode_steps=max_episode_steps, render_mode="rgb_array") # rgb_array/human
else:
    assert A == 2
    sim_config = {
        # Simulator binary; set DONKEY_SIM_PATH to override the default location.
        "exe_path" : os.environ.get(
            "DONKEY_SIM_PATH",
            "/home/till/Desktop/Thesis/donkeycar_sim/DonkeySimLinux/donkey_sim.x86_64",
        ),
        "port" : 9091
    }
    env = gym.make(
        "GymV21Environment-v0",
        env_id=env_id,
        max_episode_steps=max_episode_steps,
        make_kwargs={
            "conf": sim_config
        })

# Logging (directory taken from the config)
log_dir = logdir
writer = SummaryWriter(log_dir)
notebook.start(f"--logdir={log_dir}")

# plt.imsave does not create missing directories
os.makedirs("reconstructions", exist_ok=True)

episode_losses = { # for loss plots
    "episode_loss": [],
    "episode_image_loss": [],
    "episode_reward_loss": [],
    "episode_continue_loss": [],
    "episode_dyn_loss": [],
    "episode_rep_loss": [],
}

try:
    for episode in tqdm(range(n_episodes)):

        # Get the initial state
        obs, info = env.reset()

        # Reset the RNN's hidden state
        h = torch.zeros(rssm.num_rnn_layers, 1, H, device=device) # num_layers, B, H

        # Add a new loss for the current episode and initialize it to 0
        episode_length = 0
        for key in episode_losses:
            episode_losses[key].append(torch.tensor(0, device=device, dtype=torch.float32))

        # Play one episode
        done = False
        while not done:

            x = transform(obs).view(-1, 1, 128, 128)

            """ WORLD MODEL LEARNING """

            # predict z and generate the true stochastic latent variable z with the encoder
            z_pred = rssm.dynamics_mlp(h).view(-1, num_categoricals, num_classes) # (1,32,32) for the softmax
            z_pred = F.softmax(z_pred, -1).flatten(start_dim=1, end_dim=2) # (1, 1024)
            z = rssm.vae.encode(x).flatten(start_dim=1, end_dim=2)

            # apply external actor and critic nets on z
            action = policy_net(z)
            v = value_net(z)

            # predict one step using the RSSM
            h, reward_pred, continue_prob, continue_pred, x_pred = rssm.step(action, h, z)

            # execute the chosen action
            next_obs, reward, terminated, truncated, info = env.step(to_np(action.squeeze()))

            done = terminated or truncated
            obs = next_obs

            # calculate the loss
            continue_target = torch.tensor(1 - done, device=device, dtype=torch.float32)
            reward = torch.tensor(reward, device=device, dtype=torch.float32)

            # TODO: z_prior, z_posterior

            losses = rssm.get_losses(x, x_pred, reward, reward_pred,
                                     continue_target, continue_prob, z_pred, z)

            # Add loss for the current step to the episode loss
            episode_length += 1
            for key in losses:
                episode_losses["episode_" + key][-1] += losses[key]

        # Calculate the mean loss of the episode
        for key in episode_losses:
            episode_losses[key][-1] /= episode_length

        # update the world model at the end of an episode using the mean loss of the episode
        rssm_optim.zero_grad()
        episode_losses["episode_loss"][-1].backward()
        # NOTE(review): only the VAE gradients are clipped although the optimizer
        # updates all of rssm.parameters() - confirm whether that is intended.
        nn.utils.clip_grad_norm_(rssm.vae.parameters(), max_norm=max_grad_norm, norm_type=2)
        rssm_optim.step()

        # Detach the losses to save memory and log them in TensorBoard
        for key in episode_losses:
            episode_losses[key][-1] = episode_losses[key][-1].detach().item()
            writer.add_scalar(key, episode_losses[key][-1], global_step=episode)

        # save original image and reconstruction of the episode's last step
        if episode % 10 == 0:
            plt.imsave(f"reconstructions/episode_{episode}_original.png", to_np(x[0][0]), cmap="gray")
            plt.imsave(f"reconstructions/episode_{episode}_reconstruction.png", to_np(x_pred[0][0]), cmap="gray")

except KeyboardInterrupt:
    # Clean handling for interrupts to stop training early
    print("Stopping training.")
    # Delete the last loss if the training was stopped early
    # so that the list only consists of floats. The lists may still be
    # empty if the interrupt arrived before the first episode started.
    for key in episode_losses:
        if episode_losses[key] and isinstance(episode_losses[key][-1], torch.Tensor):
            episode_losses[key] = episode_losses[key][:-1]

finally:
    # Always release the TensorBoard writer and the gym environment,
    # whether training finished, was interrupted or failed with an error.
    writer.close()
    env.close()

starting DonkeyGym env
Setting default: start_delay 5.0
Setting default: max_cte 8.0
Setting default: frame_skip 1
Setting default: cam_resolution (120, 160, 3)
Setting default: log_level 20
Setting default: host localhost
Setting default: steer_limit 1.0
Setting default: throttle_min 0.0
Setting default: throttle_max 1.0
donkey subprocess started
Found path: /home/till/Desktop/Thesis/donkeycar_sim/DonkeySimLinux/donkey_sim.x86_64


INFO:gym_donkeycar.core.client:connecting to localhost:9091 
  logger.warn(


loading scene mini_monaco


INFO:gym_donkeycar.envs.donkey_sim:on need car config
INFO:gym_donkeycar.envs.donkey_sim:sending car config.
INFO:gym_donkeycar.envs.donkey_sim:sim started!


  if not isinstance(terminated, (bool, np.bool8)):
  0%|▎                                                                                      | 20/5000 [02:08<8:53:47,  6.43s/it]

In [None]:
# GPU memory consumption:
# 10 steps -> 6504MiB
# 50 steps -> 19990MiB

## Plot results

In [None]:
# Window of the moving average: 1/20 of the logged episodes, but at least 1.
rolling_length = max(1, len(episode_losses["episode_loss"]) // 20)

fig, axs = plt.subplots(ncols=4, nrows=2, figsize=(4*5, 2*5))
flat_axes = axs.flatten()

# One panel per loss: raw per-episode values plus their moving average
for ax, (key, values) in zip(flat_axes, episode_losses.items()):
    values = np.array(values).flatten()

    ax.set_title(key)
    ax.set_xlabel("episode")
    ax.set_ylabel("loss")
    ax.plot(range(len(values)), values, label=key)

    # np.convolve raises on empty input, so skip the average until data exists
    if len(values) >= rolling_length and len(values) > 0:
        moving_average = (
            np.convolve(values, np.ones(rolling_length), mode="valid")
            / rolling_length
        )
        # "valid" drops the first rolling_length - 1 points; start the x-axis
        # there so each average sits at the last episode of its window.
        ax.plot(range(rolling_length - 1, len(values)), moving_average, label="moving average")

    ax.legend(loc="upper right")

# Hide the panels of the 2x4 grid that have no loss assigned
for ax in flat_axes[len(episode_losses):]:
    ax.set_visible(False)

plt.tight_layout()
plt.show()

## Test area