In [1]:
import os
from operator import itemgetter

import gym_donkeycar
import gymnasium as gym
import imageio
import ipywidgets as widgets
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from IPython.display import display
from ipywidgets import HBox, VBox
from matplotlib import pyplot as plt
from PIL import Image
from ruamel.yaml import YAML
from tensorboard import notebook
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
from torch import distributions as dist
from torch.distributions import Categorical, Normal
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from tqdm import tqdm

# Quiet down library log output.
# NOTE(review): these variables are set after `imageio` and `tensorboard` were
# imported above; if a library only reads them at import time, move these
# assignments before those imports for them to take effect.
os.environ.update({
    'TF_CPP_MIN_LOG_LEVEL': '2',
    "IMAGEIO_IGNORE_WARNINGS": "True",
})

import gym.spaces as gym_spaces
import gymnasium as gym  # overwrite OpenAI gym
import stable_baselines3 as sb3
from gym_donkeycar.envs.donkey_env import DonkeyEnv
from gymnasium import spaces
from gymnasium.spaces import Box
from stable_baselines3 import A2C, PPO, SAC
from stable_baselines3.common import env_checker
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.results_plotter import load_results, ts2xy
from stable_baselines3.common.vec_env import DummyVecEnv

from src.actor_critic import ContinuousActorCritic
from src.blocks import CategoricalStraightThrough, ConvBlock
from src.categorical_vae import CategoricalVAE
from src.imagination_env import ImaginationEnv
from src.mlp import MLP
from src.preprocessing import grayscale_transform as transform
from src.rssm import RSSM
from src.utils import load_config, save_image_and_reconstruction, to_np

# Release unused cached GPU memory held by PyTorch (e.g. left over from a
# previous run in the same kernel).
torch.cuda.empty_cache()
# Render matplotlib figures inline in the notebook.
%matplotlib inline

## Load hyperparameters from the YAML config

In [2]:
config = load_config()

# Expose every config entry (e.g. `device`, `H`, `n_episodes`) as a notebook
# variable. The Python docs state that the dict returned by `locals()` should
# not be modified, so writing into it is not a supported way to create
# variables (it only happens to work at module scope); `globals()` is.
globals().update(config)

print(config)

{'device': device(type='cuda', index=0), 'A': 3, 'Z': 1024, 'debug': False, 'logdir': 'logs/', 'seed': 0, 'size': [128, 128], 'grayscale': True, 'toy_env': True, 'n_episodes': 5000, 'max_episode_steps': 500, 'imagination_timesteps_per_model_update': 500, 'env_id': 'donkey-minimonaco-track-v0', 'max_grad_norm': 100, 'batch_size': 8, 'H': 512, 'num_categoricals': 32, 'num_classes': 32, 'mlp_n_layers': 3, 'mlp_hidden_dims': 256, 'action_clip': 1}


## Initialize the RSSM (including all networks)

In [3]:
# Build the RSSM world model on the configured device.
rssm = RSSM().to(device)

# Single optimizer for all RSSM parameters.
rssm_optim = optim.Adam(
    params=rssm.parameters(),
    lr=1e-4,
    # Adam's `weight_decay` adds an L2 penalty term to the gradients.
    weight_decay=1e-6,
)

Initializing encoder:
- adding ConvBlock((1, 32))                   ==> output shape: (32, 64, 64) ==> prod: 131072
- adding ConvBlock((32, 64))                   ==> output shape: (64, 32, 32) ==> prod: 65536
- adding ConvBlock((64, 128))                   ==> output shape: (128, 16, 16) ==> prod: 32768
- adding ConvBlock((128, 256))                   ==> output shape: (256, 8, 8) ==> prod: 16384
- adding ConvBlock((256, 64))                   ==> output shape: (64, 4, 4) ==> prod: 1024
- adding Flatten()
- adding Reshape: (*,1024) => (*,32,32)

Initializing decoder:
- adding Reshape: (*,1024) => (*,64,4,4)
- adding transpose ConvBlock(64, 64)                   ==> output shape: (64, 8, 8) ==> prod: 4096
- adding transpose ConvBlock(64, 256)                   ==> output shape: (256, 16, 16) ==> prod: 65536
- adding transpose ConvBlock(256, 128)                   ==> output shape: (128, 32, 32) ==> prod: 131072
- adding transpose ConvBlock(128, 64)                   ==> output shape: (

## Create the imagination environment for training the agent

In [4]:
# Environment backed by the RSSM in which the RL agent is trained. Imagined
# episodes are capped at 50 steps (`max_episode_steps`); rendering is off here.
imagination_env = ImaginationEnv(rssm, device, render_mode=None, max_episode_steps=50)

## Initialize the agent

In [14]:
# A2C agent with an MLP policy, trained on rollouts from `imagination_env`.
# SB3 wraps the env in Monitor + DummyVecEnv itself (see the cell output).
agent = A2C(
    policy="MlpPolicy",
    env=imagination_env,
    tensorboard_log="logs/",
    verbose=1,
)

Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


## Training loop

Notes:
- The initial observations used for RL agent training are currently not drawn from a replay buffer.

In [15]:
""" training loop """

# Each iteration: (1) play one real episode while accumulating the RSSM world-model
# losses, then update the RSSM once on the episode's mean loss; (2) train the RL
# agent inside the imagination environment built on the RSSM.

rssm.train()

# Create the environment
if toy_env:
    assert A == 3
    env = gym.make("CarRacing-v2", max_episode_steps=max_episode_steps, render_mode="rgb_array")  # rgb_array/human
else:
    assert A == 2
    sim_config = {
        # Simulator binary; can be overridden with the DONKEY_SIM_EXE_PATH
        # environment variable instead of editing this absolute local path.
        "exe_path": os.environ.get(
            "DONKEY_SIM_EXE_PATH",
            "/home/till/Desktop/Thesis/donkeycar_sim/DonkeySimLinux/donkey_sim.x86_64",
        ),
        "port": 9091,
    }
    env = gym.make(
        "GymV21Environment-v0",
        env_id=env_id,
        max_episode_steps=max_episode_steps,
        make_kwargs={"conf": sim_config},
    )

# Logging (log directory comes from the config)
log_dir = logdir
writer = SummaryWriter(log_dir)
notebook.start(f"--logdir={log_dir}")

episode_losses = {  # one entry per episode, used for the loss plots
    "episode_loss": [],
    "episode_image_loss": [],
    "episode_reward_loss": [],
    "episode_continue_loss": [],
    "episode_dyn_loss": [],
    "episode_rep_loss": [],
}

try:
    for episode in tqdm(range(n_episodes)):

        # Get the initial state
        obs, info = env.reset()

        # Reset the RNN's hidden state: (num_rnn_layers, batch=1, H)
        h = torch.zeros(rssm.num_rnn_layers, 1, H, device=device)

        # Start a new zero-initialized loss entry for the current episode
        episode_length = 0
        for key in episode_losses:
            episode_losses[key].append(torch.tensor(0, device=device, dtype=torch.float32))

        # Play one episode
        done = False
        while not done:

            # NOTE(review): 1 channel x 128 x 128 is hard-coded; presumably this
            # has to match the config's `grayscale`/`size` settings — confirm.
            x = transform(obs).view(-1, 1, 128, 128)

            # --- WORLD MODEL LEARNING ---

            # Prior over z predicted from h (softmax per categorical) and the
            # z produced by the encoder from the real observation
            z_prior = rssm.dynamics_mlp(h).view(-1, num_categoricals, num_classes)  # (1,32,32) for the softmax
            z_prior = F.softmax(z_prior, -1).flatten(start_dim=1, end_dim=2)  # (1, 1024)
            z = rssm.vae.encode(x).flatten(start_dim=1, end_dim=2)

            # Query the RL agent deterministically on the detached [h, z] state
            state = to_np(torch.cat((h.flatten().detach(), z.flatten().detach()), dim=0))
            action, _ = agent.predict(state, deterministic=True)

            # Predict one step using the RSSM
            h, reward_pred, continue_prob, continue_pred, x_pred = rssm.step(action, h, z)

            # Take a real environment step with the same action
            obs, reward, terminated, truncated, info = env.step(action.squeeze())
            done = terminated or truncated

            # Per-step losses. The continue target is 0 only on true termination:
            # a time-limit truncation is not a terminal state of the environment,
            # so it must not teach the continue predictor that the episode ends.
            continue_target = torch.tensor(1 - terminated, device=device, dtype=torch.float32)
            reward = torch.tensor(reward, device=device, dtype=torch.float32)
            losses = rssm.get_losses(x, x_pred, reward, reward_pred,
                                     continue_target, continue_prob, z_prior, z)

            # Accumulate the step losses (keeping the graph for the episode's backward pass)
            episode_length += 1
            for key in losses:
                episode_losses["episode_" + key][-1] += losses[key]

        # Average the accumulated losses over the episode
        for key in episode_losses:
            episode_losses[key][-1] /= episode_length

        # Update the world model once per episode on its mean loss. Clip the
        # gradients of all RSSM parameters (everything the optimizer updates),
        # using the configured `max_grad_norm`.
        rssm_optim.zero_grad()
        episode_losses["episode_loss"][-1].backward()
        nn.utils.clip_grad_norm_(rssm.parameters(), max_norm=max_grad_norm, norm_type=2)
        rssm_optim.step()

        # Convert the losses to floats (freeing the graph) and log them in TensorBoard
        for key in episode_losses:
            episode_losses[key][-1] = episode_losses[key][-1].detach().item()
            writer.add_scalar(key, episode_losses[key][-1], global_step=episode)

        # Every 10 episodes, save the last original image and its reconstruction
        if episode % 10 == 0:
            save_image_and_reconstruction(x, x_pred, episode)

        # --- RL AGENT LEARNING (IN THE WORLD MODEL) ---
        agent.learn(total_timesteps=imagination_timesteps_per_model_update)

except KeyboardInterrupt:
    # Clean handling for interrupts to stop training early
    print("Stopping training.")
    # Drop the interrupted episode's still-tensor-valued entries so that the
    # lists only contain floats (guarding against lists that are still empty)...
    for key in episode_losses:
        if episode_losses[key] and isinstance(episode_losses[key][-1], torch.Tensor):
            episode_losses[key].pop()
    # ...and keep all lists the same length, since an interrupt can land midway
    # through appending or converting the per-key entries.
    n_complete = min(len(values) for values in episode_losses.values())
    for key in episode_losses:
        del episode_losses[key][n_complete:]

finally:
    # Always release the TensorBoard writer and the environment, whether
    # training completed, was interrupted, or raised.
    writer.close()
    env.close()

  0%|                                                                                                  | 0/5000 [00:00<?, ?it/s]

Logging to logs/A2C_1
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 43       |
|    ep_rew_mean        | 0.529    |
| time/                 |          |
|    fps                | 526      |
|    iterations         | 100      |
|    time_elapsed       | 0        |
|    total_timesteps    | 500      |
| train/                |          |
|    entropy_loss       | -4.27    |
|    explained_variance | 0.836    |
|    learning_rate      | 0.0007   |
|    n_updates          | 99       |
|    policy_loss        | 0.0927   |
|    std                | 1        |
|    value_loss         | 0.000646 |
------------------------------------


  0%|                                                                                        | 1/5000 [00:05<8:05:12,  5.82s/it]

Logging to logs/A2C_2
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 44.5     |
|    ep_rew_mean        | -0.0141  |
| time/                 |          |
|    fps                | 537      |
|    iterations         | 100      |
|    time_elapsed       | 0        |
|    total_timesteps    | 500      |
| train/                |          |
|    entropy_loss       | -4.26    |
|    explained_variance | 0.979    |
|    learning_rate      | 0.0007   |
|    n_updates          | 199      |
|    policy_loss        | -0.168   |
|    std                | 1        |
|    value_loss         | 0.00881  |
------------------------------------


  0%|                                                                                        | 2/5000 [00:11<7:58:24,  5.74s/it]

Logging to logs/A2C_3
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 43.5     |
|    ep_rew_mean        | -0.704   |
| time/                 |          |
|    fps                | 517      |
|    iterations         | 100      |
|    time_elapsed       | 0        |
|    total_timesteps    | 500      |
| train/                |          |
|    entropy_loss       | -4.27    |
|    explained_variance | 0.902    |
|    learning_rate      | 0.0007   |
|    n_updates          | 299      |
|    policy_loss        | 0.354    |
|    std                | 1        |
|    value_loss         | 0.00758  |
------------------------------------


  0%|                                                                                        | 3/5000 [00:17<7:58:22,  5.74s/it]

Logging to logs/A2C_4
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 38.9     |
|    ep_rew_mean        | 0.332    |
| time/                 |          |
|    fps                | 513      |
|    iterations         | 100      |
|    time_elapsed       | 0        |
|    total_timesteps    | 500      |
| train/                |          |
|    entropy_loss       | -4.25    |
|    explained_variance | 0.996    |
|    learning_rate      | 0.0007   |
|    n_updates          | 399      |
|    policy_loss        | 0.00995  |
|    std                | 0.997    |
|    value_loss         | 1.04e-05 |
------------------------------------


  0%|                                                                                        | 4/5000 [00:23<8:03:51,  5.81s/it]

Logging to logs/A2C_5
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 39.2     |
|    ep_rew_mean        | 1.39     |
| time/                 |          |
|    fps                | 507      |
|    iterations         | 100      |
|    time_elapsed       | 0        |
|    total_timesteps    | 500      |
| train/                |          |
|    entropy_loss       | -4.23    |
|    explained_variance | 0.947    |
|    learning_rate      | 0.0007   |
|    n_updates          | 499      |
|    policy_loss        | 0.211    |
|    std                | 0.992    |
|    value_loss         | 0.00321  |
------------------------------------


  0%|                                                                                        | 5/5000 [00:29<8:08:08,  5.86s/it]

Logging to logs/A2C_6
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 37.6     |
|    ep_rew_mean        | 2.87     |
| time/                 |          |
|    fps                | 503      |
|    iterations         | 100      |
|    time_elapsed       | 0        |
|    total_timesteps    | 500      |
| train/                |          |
|    entropy_loss       | -4.22    |
|    explained_variance | 0.0571   |
|    learning_rate      | 0.0007   |
|    n_updates          | 599      |
|    policy_loss        | -0.0152  |
|    std                | 0.987    |
|    value_loss         | 2.75e-05 |
------------------------------------


  0%|                                                                                        | 6/5000 [00:35<8:13:53,  5.93s/it]

Logging to logs/A2C_7
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 44.8     |
|    ep_rew_mean        | 3.07     |
| time/                 |          |
|    fps                | 498      |
|    iterations         | 100      |
|    time_elapsed       | 1        |
|    total_timesteps    | 500      |
| train/                |          |
|    entropy_loss       | -4.19    |
|    explained_variance | 0.989    |
|    learning_rate      | 0.0007   |
|    n_updates          | 699      |
|    policy_loss        | 0.648    |
|    std                | 0.977    |
|    value_loss         | 0.0967   |
------------------------------------


  0%|                                                                                        | 7/5000 [00:41<8:16:13,  5.96s/it]

Logging to logs/A2C_8
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 46.4     |
|    ep_rew_mean        | 2.33     |
| time/                 |          |
|    fps                | 500      |
|    iterations         | 100      |
|    time_elapsed       | 0        |
|    total_timesteps    | 500      |
| train/                |          |
|    entropy_loss       | -4.18    |
|    explained_variance | 0.984    |
|    learning_rate      | 0.0007   |
|    n_updates          | 799      |
|    policy_loss        | 0.04     |
|    std                | 0.976    |
|    value_loss         | 9.52e-05 |
------------------------------------


  0%|▏                                                                                       | 8/5000 [00:47<8:16:34,  5.97s/it]

Logging to logs/A2C_9
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 45       |
|    ep_rew_mean        | 1.54     |
| time/                 |          |
|    fps                | 473      |
|    iterations         | 100      |
|    time_elapsed       | 1        |
|    total_timesteps    | 500      |
| train/                |          |
|    entropy_loss       | -4.2     |
|    explained_variance | -0.101   |
|    learning_rate      | 0.0007   |
|    n_updates          | 899      |
|    policy_loss        | -0.16    |
|    std                | 0.983    |
|    value_loss         | 0.00219  |
------------------------------------


  0%|▏                                                                                       | 9/5000 [00:53<8:22:48,  6.04s/it]

Logging to logs/A2C_10
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 44.1     |
|    ep_rew_mean        | 0.836    |
| time/                 |          |
|    fps                | 484      |
|    iterations         | 100      |
|    time_elapsed       | 1        |
|    total_timesteps    | 500      |
| train/                |          |
|    entropy_loss       | -4.2     |
|    explained_variance | 0.998    |
|    learning_rate      | 0.0007   |
|    n_updates          | 999      |
|    policy_loss        | 0.0698   |
|    std                | 0.98     |
|    value_loss         | 0.000396 |
------------------------------------


  0%|▏                                                                                      | 10/5000 [01:01<8:29:11,  6.12s/it]

Stopping training.





## Imagine: replay buffer (work in progress)

In [None]:
class ReplayBuffer():
    """Replay buffer (work in progress, not yet implemented).

    Placeholder with a valid body so that this cell runs and
    "Restart & Run All" does not stop here; storage and sampling
    still need to be added.
    """

    def __init__(self):
        # TODO: initialize the buffer's storage.
        pass

## Plot the results

In [None]:
def moving_average(values, window):
    """Return the trailing moving average of `values` over `window` points.

    Uses `np.convolve(..., mode="valid")`, so the result has
    `len(values) - window + 1` entries; entry `i` averages `values[i:i + window]`.
    Returns an empty array when there are fewer than `window` values.
    """
    values = np.asarray(values, dtype=float).flatten()
    if values.size < window:
        return np.empty(0)
    return np.convolve(values, np.ones(window), mode="valid") / window


# Smooth over ~5% of the recorded episodes (at least 1).
rolling_length = max(1, int(len(episode_losses["episode_loss"]) / 20))

n_cols = 3
n_rows = max(1, (len(episode_losses) + n_cols - 1) // n_cols)
fig, axs = plt.subplots(
    ncols=n_cols, nrows=n_rows, figsize=(n_cols * 5, n_rows * 5), squeeze=False
)

# One panel per loss: raw per-episode values plus their moving average
for i, (key, key_losses) in enumerate(episode_losses.items()):
    ax = axs[i // n_cols, i % n_cols]
    ax.set_title(key)
    ax.set_xlabel("episode")
    ax.set_ylabel("loss")
    if len(key_losses) == 0:
        continue  # nothing recorded yet (e.g. training interrupted immediately)

    losses_moving_average = moving_average(key_losses, rolling_length)
    ax.plot(range(len(key_losses)), key_losses, label=key)
    # Plot each average at the last episode of its window so it lines up
    # with the raw curve instead of being shifted left.
    ma_start = rolling_length - 1
    ax.plot(
        range(ma_start, ma_start + len(losses_moving_average)),
        losses_moving_average,
        label="moving average",
    )
    ax.legend(loc="upper right")

# Hide any unused panels of the grid
for ax in axs.flat[len(episode_losses):]:
    ax.set_visible(False)

plt.tight_layout()
plt.show()

## Showcase the trained agent

In [12]:
def run_showcase(env, policy, max_steps):
    """Roll out `policy` deterministically in `env` with GIF rendering.

    The policy acts on the observations returned by `env` itself, and the
    rollout stops as soon as the episode terminates or is truncated.

    Args:
        env: env with `reset()`, `step(action)`, `render()`, `close()` and a
            writable `render_mode` attribute (e.g. `imagination_env`).
        policy: object with `predict(obs, deterministic=...)` returning
            `(action, state)` (e.g. the SB3 `agent`).
        max_steps: upper bound on the number of steps to roll out.

    Returns:
        The number of environment steps taken.
    """
    previous_render_mode = env.render_mode
    env.render_mode = "gif"
    n_steps = 0
    try:
        obs, info = env.reset()
        for n_steps in range(1, max_steps + 1):
            action, _ = policy.predict(obs, deterministic=True)
            obs, reward, terminated, truncated, info = env.step(action)
            env.render()
            if terminated or truncated:
                break
    finally:
        # Close and restore the render mode even if the rollout fails.
        env.close()
        env.render_mode = previous_render_mode
    return n_steps


showcase_agent = False

if showcase_agent:
    run_showcase(imagination_env, agent, max_episode_steps)

## Test area

In [None]:
# GPU memory consumption:
# 10 steps -> 6504MiB
# 50 steps -> 19990MiB

In [None]:
# !tensorboard --logdir="logs/" 