In [None]:
import gymnasium as gym
import gym_donkeycar

import os
from ruamel.yaml import YAML
import numpy as np
from matplotlib import pyplot as plt
import ipywidgets as widgets
from ipywidgets import HBox, VBox
from IPython.display import display
from tqdm import tqdm

from PIL import Image
import imageio

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import distributions as dist
from torch.distributions import Normal, Categorical

import torchvision
from torchvision import transforms

from tensorboard import notebook
from torch.utils.tensorboard import SummaryWriter
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

# suppress warnings
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
os.environ["IMAGEIO_IGNORE_WARNINGS"] = "True"

from networks.utils import to_np, load_config, save_image_and_reconstruction

# custom classes and functions
from networks.blocks import ConvBlock, CategoricalStraightThrough
from networks.rssm import RSSM
from networks.mlp import MLP
from networks.categorical_vae import CategoricalVAE
from networks.actor_critic import ContinuousActorCritic
from preprocessing import grayscale_transform as transform

###

import stable_baselines3 as sb3
from stable_baselines3 import SAC, PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.results_plotter import load_results, ts2xy

import gym.spaces as gym_spaces
import gymnasium as gym # overwrite OpenAI gym
from gymnasium import spaces
from gymnasium.spaces import Box
from stable_baselines3.common.vec_env import DummyVecEnv
from gym_donkeycar.envs.donkey_env import DonkeyEnv
from stable_baselines3.common import env_checker

###

torch.cuda.empty_cache()
%matplotlib inline

## Load Hyperparameters from YAML config

In [12]:
config = load_config()

for key in config:
    locals()[key] = config[key]

print(config)

{'device': device(type='cuda', index=0), 'A': 3, 'Z': 1024, 'debug': False, 'logdir': 'logs/', 'seed': 0, 'size': [128, 128], 'grayscale': True, 'toy_env': True, 'n_episodes': 5000, 'max_episode_steps': 100, 'env_id': 'donkey-minimonaco-track-v0', 'max_grad_norm': 100, 'batch_size': 8, 'H': 512, 'num_categoricals': 32, 'num_classes': 32, 'mlp_n_layers': 3, 'mlp_hidden_dims': 256, 'action_clip': 1}


## Init Networks

In [13]:
rssm = RSSM().to(device)

rssm_optim = optim.Adam(
    rssm.parameters(), 
    lr=1e-4,

    # l2 regularizer
    weight_decay=1e-6, 
)

# value_net = MLP(input_dims=Z, output_dims=1).to(device) # state (H+Z) -> 1
policy_net = MLP(input_dims=Z, output_dims=A, out_type="gaussian").to(device) # state (H+Z) -> A

Initializing encoder:
- adding ConvBlock((1, 32))                   ==> output shape: (32, 64, 64) ==> prod: 131072
- adding ConvBlock((32, 64))                   ==> output shape: (64, 32, 32) ==> prod: 65536
- adding ConvBlock((64, 128))                   ==> output shape: (128, 16, 16) ==> prod: 32768
- adding ConvBlock((128, 256))                   ==> output shape: (256, 8, 8) ==> prod: 16384
- adding ConvBlock((256, 64))                   ==> output shape: (64, 4, 4) ==> prod: 1024
- adding Flatten()
- adding Reshape: (*,1024) => (*,32,32)

Initializing decoder:
- adding Reshape: (*,1024) => (*,64,4,4)
- adding transpose ConvBlock(64, 64)                   ==> output shape: (64, 8, 8) ==> prod: 4096
- adding transpose ConvBlock(64, 256)                   ==> output shape: (256, 16, 16) ==> prod: 65536
- adding transpose ConvBlock(256, 128)                   ==> output shape: (128, 32, 32) ==> prod: 131072
- adding transpose ConvBlock(128, 64)                   ==> output shape: (

In [14]:
agent = ContinuousActorCritic(
     n_features=Z, 
     n_actions=A,
     n_envs=1,
     gamma=0.999,
     lam=0.95,
     entropy_coeff=0.01,
     critic_lr=5e-4, # it's very sensitive to higher learning rates (gets nans)
     actor_lr=1e-4,
    action_clip=2
).to(device)

## Training loop

Notes:
- currently taking random actions (not the output of the actor)

In [None]:
""" training loop """

rssm.train()

# Create the environment
if toy_env:
    assert A==3
    env = gym.make("CarRacing-v2", max_episode_steps=100, render_mode="rgb_array") # rgb_array/human # 50 steps
else:
    assert A==2
    sim_config = {
        "exe_path" : "/home/till/Desktop/Thesis/donkeycar_sim/DonkeySimLinux/donkey_sim.x86_64",
        "port" : 9091
    }
    env = gym.make(
        "GymV21Environment-v0", 
        env_id=env_id,
        max_episode_steps=max_episode_steps,
        make_kwargs={
            "conf": sim_config
        })

# Logging
log_dir = "logs/"
writer = SummaryWriter(log_dir)
notebook.start(f"--logdir={log_dir}")

episode_losses = { # for loss plots
    "episode_loss": [],
    "episode_image_loss": [],
    "episode_reward_loss": [],
    "episode_continue_loss": [],
    "episode_dyn_loss": [],
    "episode_rep_loss": [],
}

try:
    for episode in tqdm(range(n_episodes)):

        # Get the initial state
        obs, info = env.reset()

        # Reset the RNN's hidden state
        h = torch.zeros(rssm.num_rnn_layers, 1, H, device=device) # seq_len, B, H

        # Add a new loss for the current episode and initialize it to 0
        episode_length = 0
        for key in episode_losses:
            episode_losses[key].append(torch.tensor(0, device=device, dtype=torch.float32))

        # Play one episode
        done = False
        while not done:

            x = transform(obs).view(-1, 1, 128, 128)

            """ WORLD MODEL LEARNING """

            # predict z and generate the true stochastic latent variable z with the encoder
            z_prior = rssm.dynamics_mlp(h).view(-1, num_categoricals, num_classes) # (1,32,32) for the softmax
            z_prior = F.softmax(z_prior, -1).flatten(start_dim=1, end_dim=2) # (1, 1024)
            z = rssm.vae.encode(x).flatten(start_dim=1, end_dim=2)

            # apply external actor and critic nets on z
            action_mean, action_var = policy_net(z)
            action = torch.normal(mean=action_mean, std=torch.sqrt(action_var)) # # Ax1 vector
            action = torch.clip(action, action_clip_min, action_clip_max)

            v = value_net(z)

            # predict one step using the RSSM
            h, reward_pred, continue_prob, continue_pred, x_pred = rssm.step(action, h, z)

            # choose and execute an action
            next_obs, reward, terminated, truncated, info = env.step(to_np(action.squeeze()))        

            done = terminated or truncated
            obs = next_obs

            # calculate the loss
            continue_target = torch.tensor(1 - done, device=device, dtype=torch.float32)
            reward = torch.tensor(reward, device=device, dtype=torch.float32)
            losses = rssm.get_losses(x, x_pred, reward, reward_pred, 
                                     continue_target, continue_prob, z_prior, z)

            # Add loss for the current step to the episode loss
            episode_length += 1
            for key in losses:
                episode_losses["episode_" + key][-1] += losses[key]

        # Calculate the mean loss of the episode
        for key in episode_losses:
            episode_losses[key][-1] /= episode_length

        # update the world model at the end of an episode using the mean loss of the episode
        rssm_optim.zero_grad()
        episode_losses["episode_loss"][-1].backward()
        nn.utils.clip_grad_norm_(rssm.vae.parameters(), max_norm=100.0, norm_type=2)  
        rssm_optim.step()

        # Detach the losses to save memory and log them in TensorBoard
        for key in episode_losses:
            episode_losses[key][-1] = episode_losses[key][-1].detach().item()
            writer.add_scalar(key, episode_losses[key][-1], global_step=episode)
        
        # save original image and reconstruction
        if episode % 10 == 0:
            save_image_and_reconstruction(x, x_pred, episode)

    env.close()

except KeyboardInterrupt:
    """ Clean handling for interrupts to stop training early """
    print("Stopping training.")
    # Delete the last loss if the training was stopped early
    # so that the list only consists of floats
    for key in episode_losses:
        if isinstance(episode_losses[key][-1], torch.Tensor):
            episode_losses[key] = episode_losses[key][:-1]

    # Close the TensorBoard writer and the gym environment
    writer.close()
    env.close()

## Imagine

In [5]:
""" imagine n steps """

def imagine_n_steps(obs, n, save_images=False):

    if save_images:
        images = []

    # reset h
    h = torch.zeros(rssm.num_rnn_layers, 1, H, device=device)

    # encode the first state
    x = transform(obs).view(-1, 1, 128, 128)
    z = rssm.vae.encode(x).flatten(start_dim=1, end_dim=2)


    for imagination_step in range(50):

        # predict z from h
        z_prior = rssm.dynamics_mlp(h).view(-1, num_categoricals, num_classes) # (1,32,32) for the softmax
        z_prior = F.softmax(z_prior, -1).flatten(start_dim=1, end_dim=2) # (1, 1024)
        z = z_prior

        # sample an action
        # action_mean, action_var = policy_net(z)
        # action = torch.normal(mean=action_mean, std=torch.sqrt(action_var)) # Ax1 vector
        # action = torch.clip(action, action_clip_min, action_clip_max)
        
        # predict the value
        # v = value_net(z)
        # action = torch.randn(2, device=device)
        
        action, log_prob, actor_entropy = agent.get_action(z)

        # predict one step using the RSSM
        h, reward_pred, continue_prob, continue_pred, x_pred = rssm.step(action, h, z)
        done = 1 - continue_pred

        if save_images:
            images.append((255 * to_np(x_pred[0][0])).astype("uint8"))

    if save_images:
        imageio.mimsave("reconstructions/imagined_episode.gif", images, duration=0.03)

In [8]:
if toy_env:
    assert A==3
    env = gym.make("CarRacing-v2", max_episode_steps=100, render_mode="rgb_array") # rgb_array/human # 50 steps
else:
    assert A==2
    sim_config = {
        "exe_path" : "/home/till/Desktop/Thesis/donkeycar_sim/DonkeySimLinux/donkey_sim.x86_64",
        "port" : 9091
    }
    env = gym.make(
        "GymV21Environment-v0", 
        env_id=env_id,
        max_episode_steps=max_episode_steps,
        make_kwargs={
            "conf": sim_config
        })
obs, info = env.reset()

imagine_n_steps(obs, 100, save_images=False)

env.close()

In [6]:
class ImaginationEnv(gym.Env):
    """ Custom gymnasium environment for training inside the world model (RSSM). """
    def __init__(self):
        super(ImaginationEnv, self).__init__()
        
        # define action and observation space
        # they must be gym.spaces objects
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(H+Z,), dtype=np.float32)
        self.action_space = spaces.Box(low=-action_clip, high=action_clip, shape=(A,), dtype=np.float32)
    
    def step(self, action):
        observation = np.random.rand(H+Z)
        reward = np.random.rand(1).item()
        terminated = np.random.rand(1).item() > 0.5
        truncated = np.random.rand(1).item() > 0.5
        info = {}
        return observation, reward, terminated, truncated, info
    
    def reset(self):
        observation = np.random.rand(H+Z)
        info = {}
        return observation, info

    def render(self):
        pass
    
    def close(self):
        pass
    
imagination_env = ImaginationEnv()

In [7]:
imagination_env.reset()

(array([0.40952018, 0.5596945 , 0.55404613, ..., 0.31293196, 0.46543275,
        0.97797887]),
 {})

In [8]:
imagination_env.step(0.1)

(array([0.88832871, 0.31962352, 0.60064864, ..., 0.36925936, 0.64272986,
        0.4995378 ]),
 0.29841965459690645,
 False,
 True,
 {})

In [9]:
agent = PPO("MlpPolicy", imagination_env, verbose=1, tensorboard_log="logs/")

Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [10]:
sb3.__version__

'2.0.0a5'

In [None]:
class ReplayBuffer():
    def __init__(self):

## Plot results

In [None]:
rolling_length = max(1, int(len(episode_losses["episode_loss"])/20))

fig, axs = plt.subplots(ncols=3, nrows=2, figsize=(3*5, 2*5))

# Iterate over the keys and plot the losses
for i, key in enumerate(episode_losses.keys()):
    row = i // 3
    col = i % 3

    axs[row, col].set_title(key)
    losses = episode_losses[key]
    losses_moving_average = (
        np.convolve(
            np.array(losses).flatten(), np.ones(rolling_length), mode="valid"
        )
        / rolling_length
    )
    axs[row, col].plot(range(len(losses)), losses, label=key)
    axs[row, col].plot(range(len(losses_moving_average)), losses_moving_average, label="moving average")
    axs[row, col].legend(loc="upper right")

plt.tight_layout()
plt.show()

## Test area

In [None]:
# GPU memory consumption:
# 10 steps -> 6504MiB
# 50 steps -> 19990MiB

In [None]:
""" save reconstructions for one episode """

save_one_episode = False


if save_one_episode:
    rssm.eval()

    # Create the environment
    if toy_env:
        assert A==3
        env = gym.make("CarRacing-v2", max_episode_steps=150, render_mode="rgb_array") # rgb_array/human # 50 steps
    else:
        assert A==2
        sim_config = {
            "exe_path" : "/home/till/Desktop/Thesis/donkeycar_sim/DonkeySimLinux/donkey_sim.x86_64",
            "port" : 9091
        }
        env = gym.make(
            "GymV21Environment-v0", 
            env_id=env_id,
            max_episode_steps=max_episode_steps,
            make_kwargs={
                "conf": sim_config
            })


    try:
        for episode in tqdm(range(1)):

            # Get the initial state
            obs, info = env.reset()

            # Reset the RNN's hidden state
            h = torch.zeros(rssm.num_rnn_layers, 1, H, device=device) # seq_len, B, H

            # Add a new loss for the current episode and initialize it to 0
            episode_length = 0

            # Play one episode
            done = False
            while not done:

                x = transform(obs).view(-1, 1, 128, 128)

                """ WORLD MODEL LEARNING """

                # predict z and generate the true stochastic latent variable z with the encoder
                z_prior = rssm.dynamics_mlp(h).view(-1, num_categoricals, num_classes) # (1,32,32) for the softmax
                z_prior = F.softmax(z_prior, -1).flatten(start_dim=1, end_dim=2) # (1, 1024)
                z = rssm.vae.encode(x).flatten(start_dim=1, end_dim=2)

                # apply external actor and critic nets on z
                action = torch.tensor([(np.random.rand() - 0.5)/3, (np.random.rand() + 0.5)/3], device=device).unsqueeze(dim=0) # policy_net(z) # Ax1 vector
                v = value_net(z)

                # predict one step using the RSSM and apply the actor-critic
                h, reward_pred, continue_prob, continue_pred, x_pred = rssm.step(action, h, z)

                # choose and execute an action
                next_obs, reward, terminated, truncated, info = env.step(to_np(action.squeeze()))        

                done = terminated or truncated
                obs = next_obs

                # calculate the loss
                continue_target = torch.tensor(1 - done, device=device, dtype=torch.float32)
                reward = torch.tensor(reward, device=device, dtype=torch.float32)

                # TODO: z_prior, z_posterior
                z_prior = torch.tensor(0, device=device, dtype=torch.float32)
                z_posterior = torch.tensor(0, device=device, dtype=torch.float32)

                plt.imsave(f"reconstructions/{episode_length}.png", to_np(x_pred[0][0]), cmap="gray")
                episode_length += 1

        env.close()

    except KeyboardInterrupt:
        """ Clean handling for interrupts to stop training early """
        print("Stopping training.")
        # Delete the last loss if the training was stopped early
        # so that the list only consists of floats
        for key in episode_losses:
            if isinstance(episode_losses[key][-1], torch.Tensor):
                episode_losses[key] = episode_losses[key][:-1]

        # Close the TensorBoard writer and the gym environment
        writer.close()
        env.close()

In [None]:
!tensorboard --logdir="logs/" 


NOTE: Using experimental fast data loading logic. To disable, pass
    "--load_fast=false" and report issues on GitHub. More details:
    https://github.com/tensorflow/tensorboard/issues/4784

Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.11.0 at http://localhost:6006/ (Press CTRL+C to quit)
