In [1]:
import logging
import os
import random
import sys
from collections import deque
from operator import itemgetter

import gym_donkeycar
import gymnasium as gym
import imageio
import ipywidgets as widgets
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from IPython.display import display
from ipywidgets import HBox, VBox
from matplotlib import pyplot as plt
from PIL import Image
from ruamel.yaml import YAML
from scipy.ndimage import gaussian_filter1d
from scipy.stats import norm
from tensorboard import notebook
from tensorboard.backend.event_processing.event_accumulator import \
    EventAccumulator
from torch import distributions as dist
from torch.distributions import Categorical, Normal
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from tqdm import tqdm

import gym.spaces as gym_spaces
import gymnasium as gym  # overwrite OpenAI gym

# suppress warnings
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module="gymnasium.spaces.box") # module="gymnasium"
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
os.environ["IMAGEIO_IGNORE_WARNINGS"] = "True"

import stable_baselines3 as sb3
from gym_donkeycar.envs.donkey_env import DonkeyEnv
from gymnasium import spaces
from gymnasium.spaces import Box
from stable_baselines3 import A2C, PPO, SAC
from stable_baselines3.common import env_checker
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.results_plotter import load_results, ts2xy
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv

from src.actor_critic_discrete import DiscreteActorCritic
from src.actor_critic_dreamer import ActorCriticDreamer
from src.actor_critic import ContinuousActorCritic
from src.blocks import CategoricalStraightThrough, ConvBlock
from src.categorical_vae import CategoricalVAE
from src.imagination_env import make_imagination_env
from src.mlp import MLP
from src.preprocessing import transform
from src.replay_buffer import ReplayBuffer
from src.rssm import RSSM
from src.utils import (load_config, make_env, save_image_and_reconstruction,
                       to_np, symlog, symexp, twohot_encode, ExponentialMovingAvg,
                       ActionExponentialMovingAvg, MetricsTracker)
from src.vae import VAE

from typing import Dict, List, Union
from __future__ import annotations

torch.cuda.empty_cache()
%matplotlib inline
%load_ext autoreload
%autoreload 2

# Load the config
config = load_config()
for key in config:
    locals()[key] = config[key]

## Init the RSSM (including all networks)

In [2]:
# Build the world model (encoder/decoder, dynamics MLP, RNN, ...) and place it on `device`.
rssm = (
    RSSM()
    .to(device)
)

Initializing encoder:
- adding ConvBlock((3, 64))                   ==> output shape: (64, 32, 32) ==> prod: 65536
- adding ConvBlock((64, 128))                   ==> output shape: (128, 16, 16) ==> prod: 32768
- adding ConvBlock((128, 256))                   ==> output shape: (256, 8, 8) ==> prod: 16384
- adding ConvBlock((256, 512))                   ==> output shape: (512, 4, 4) ==> prod: 8192
- adding ConvBlock((512, 256))                   ==> output shape: (256, 2, 2) ==> prod: 1024
- adding Flatten()
- adding Reshape: (*,1024) => (*,32,32)

Initializing decoder:
- adding Reshape: (*,1024) => (*,256,2,2)
- adding transpose ConvBlock(256, 256)                   ==> output shape: (256, 4, 4) ==> prod: 4096
- adding transpose ConvBlock(256, 512)                   ==> output shape: (512, 8, 8) ==> prod: 32768
- adding transpose ConvBlock(512, 256)                   ==> output shape: (256, 16, 16) ==> prod: 65536
- adding transpose ConvBlock(256, 128)                   ==> output shap

## Create the imagination environment for training the agent

In [3]:
# Shared replay buffer: the real-environment training loop pushes preprocessed
# observations into it, and it is handed to the imagination env in the next cell.
replay_buffer = ReplayBuffer()

In [4]:
# Wrap the world model and replay buffer as an environment in which the agent
# learns purely in latent imagination (rendering disabled while training).
imagination_env = make_imagination_env(
    rssm,
    replay_buffer,
    render_mode=None,
)

Adding a TimeLimit wrapper with 16 max imagination episode steps.
Adding an AutoReset wrapper.
Adding a RescaleActionV0 wrapper. Low: [-1. -1. -1.], High: [1. 1. 1.]


## Init the agent

In [5]:
# PPO agent that is trained only inside the world model (the imagination env).
# Discount factor, GAE lambda and entropy bonus are taken from the config.
agent = PPO(
    "MlpPolicy",
    imagination_env,
    gamma=gamma,
    gae_lambda=lam,
    ent_coef=ent_coef,
    verbose=verbose,
    tensorboard_log=log_dir,
)

# Alternative agent (not wired up yet):
# agent = DiscreteActorCritic()

## Training loop


In [6]:
# LEGACY (kept for reference only, nothing in this cell executes): the original
# per-episode training loop. It is superseded by the batched sample-phase
# ("New dreamer training loop") cell further below.
#""" training loop """
#
#rssm.load_weights("weights/RSSM_1.70111713")
#rssm.train()
#
## Create the environment
#env = make_env()
#
## Logging
#writer = SummaryWriter(log_dir)
#if config["show_inline_tensorboard"]:
#    notebook.start(f"--logdir={log_dir}")
#
#episode_return = []
#episode_losses = { # for loss plots
#    "episode_loss": [],
#    "episode_image_loss": [],
#    "episode_reward_loss": [],
#    "episode_continue_loss": [],
#    "episode_dyn_loss": [],
#    "episode_rep_loss": [],
#}
#
#best_running_loss = np.inf # used for saving the best rssm model
#
#try:
#    for episode in tqdm(range(start_episode, n_seed_episodes + n_training_episodes)):
#
#        if verbose:
#            print("EPISODE:", episode)
#        
#        # Get the initial state
#        obs, info = env.reset()
#
#        # Reset the RNN's hidden state
#        h = torch.zeros(rssm.num_rnn_layers, 1, H, device=device) # seq_len, B, H
#
#        # Add a new loss for the current episode and initialize it to 0
#        episode_length = 0
#        episode_return.append(0)
#        for key in episode_losses:
#            episode_losses[key].append(torch.tensor(0, device=device, dtype=torch.float32))
#
#        # Play one episode
#        done = False
#        while not done:
#
#            # preprocess the observation and add it to the replay buffer
#            x = transform(obs).view(-1, 1 if grayscale else 3, *size) # (B, 3, 128, 128)
#            replay_buffer.push(x)
#
#            """ WORLD MODEL LEARNING """
#
#            # predict z and generate the true stochastic latent variable z with the encoder
#            z_prior = rssm.dynamics_mlp(h).view(-1, num_categoricals, num_classes) # (1,32,32) for the softmax
#            z_prior = F.softmax(z_prior, -1).flatten(start_dim=1, end_dim=2) # (1, 1024)
#            z = rssm.vae.encode(x).flatten(start_dim=1, end_dim=2)
#            state = to_np(torch.cat((h.flatten().detach(), z.flatten().detach()), dim=0))
#            
#            # get random action for the first seed episodes
#            # and in training mode get the action from the actor
#            if episode < n_seed_episodes:
#                action = env.action_space.sample()
#            else:
#                action, _ = agent.predict(state, deterministic=False)
#
#            # predict one step using the RSSM
#            h, reward_pred, continue_prob, continue_pred, x_pred = rssm.step(action, h, z)
#
#            # take an environment step with the action
#            obs, reward, terminated, truncated, info = env.step(action.squeeze())
#            done = terminated or truncated
#            
#            # Add the reward to the episode return
#            episode_return[-1] += reward
#
#            # calculate the loss
#            continue_target = torch.tensor(1 - done, device=device, dtype=torch.float32)
#            reward = torch.tensor(reward, device=device, dtype=torch.float32)
#            losses = rssm.get_losses(x, x_pred, reward, reward_pred, 
#                                     continue_target, continue_prob, z_prior, z)
#
#            # Add the loss for the current step to the episode loss
#            episode_length += 1
#            for key in losses:
#                episode_losses["episode_" + key][-1] += losses[key]
#
#        # Calculate the mean loss of the episode
#        for key in episode_losses:
#            episode_losses[key][-1] /= episode_length
#
#        # update the world model at the end of an episode using the mean loss of the episode
#        rssm.optim.zero_grad()
#        episode_losses["episode_loss"][-1].backward()
#        nn.utils.clip_grad_norm_(rssm.vae.parameters(), max_norm=max_grad_norm, norm_type=2)  
#        rssm.optim.step()
#
#        # Log the episode return
#        writer.add_scalar("episode_return", episode_return[-1], global_step=episode)
#        
#        # Detach the losses to save memory and log them in TensorBoard
#        for key in episode_losses:
#            episode_losses[key][-1] = episode_losses[key][-1].detach().item()
#            writer.add_scalar(key, episode_losses[key][-1], global_step=episode)
#        
#        # save original image and reconstruction
#        if episode % 10 == 0:
#            save_image_and_reconstruction(x, x_pred, episode)
#        
#        # save the rssm model with the best running loss
#        running_loss = np.mean(episode_losses["episode_loss"][-10:])
#        if episode > 0 and episode % 10 == 0 and running_loss < best_running_loss:
#            best_running_loss = running_loss
#            
#            # save the rssm and agent
#            rssm.save_weights(filename=f"RSSM_{best_running_loss:.8f}")
#            agent.save(f"weights/{agent.__class__.__name__}_agent")
#        
#        """ RL AGENT LEARNING (IN THE WORLD MODEL) """
#        if episode >= n_seed_episodes:
#            if verbose:
#                print("AGENT IS LEARNING")
#            agent.learn(
#                total_timesteps=imagination_timesteps_per_model_update,
#                progress_bar=imagination_progress_bar,
#                reset_num_timesteps=False
#            )
#
#    env.close()
#
#except KeyboardInterrupt:
#    """ Clean handling for interrupts to stop training early """
#    print("Stopping training.")
#    # Delete the last loss if the training was stopped early
#    # so that the list only consists of floats
#    for key in episode_losses:
#        if isinstance(episode_losses[key][-1], torch.Tensor):
#            episode_losses[key] = episode_losses[key][:-1]
#
#    # Close the TensorBoard writer and the gym environment
#    writer.close()
#    env.close()

In [7]:
# ----------------------------------------------------

In [8]:
# New dreamer training loop with batches for the distributional critic.
# Each sample phase collects `n_steps_per_model_update` real environment steps,
# updates the world model on the mean loss of that phase and then (after the
# seed phases) trains the agent inside the world model.
# TODO: add manual actor-critic training in the imagination env and drop sb3.

# To resume from a checkpoint, uncomment:
# rssm.load_weights("weights/RSSM_1.70111713")
# rssm.train()

# Create the environment
env = make_env()

tracker = MetricsTracker(
    # log the mean loss for the training metrics
    training_metrics=["loss", "image_loss", "reward_loss", "continue_loss", "dyn_loss", "rep_loss", "rewards"],

    # loss per step
    episode_metrics=["loss", "image_loss", "reward_loss", "continue_loss", "dyn_loss", "rep_loss", "rewards"],
)

try:
    for sample_phase in tqdm(range(start_phase, n_seed_phases + n_model_updates)):

        # Reset the RNN's hidden state at the start of every sample phase.
        # NOTE(review): this also keeps h (and its autograd history) from spanning
        # a parameter update - presumably the reason for resetting here.
        h = torch.zeros(rssm.num_rnn_layers, 1, H, device=device, dtype=torch.float32) # seq_len, B, H

        if sample_phase == start_phase:
            # Get the first obs; afterwards observations come from env.step
            # (the env auto-resets when an episode ends).
            obs, info = env.reset(seed=42)

        for step in range(n_steps_per_model_update):

            # --- WORLD MODEL LEARNING ---

            # Preprocess the CURRENT observation and store it. This must happen every
            # step so that the world model trains on the latest frame.
            x = transform(obs).view(-1, 1 if grayscale else 3, *size) # (B, 3, 128, 128)
            replay_buffer.push(x)

            # predict z and generate the true stochastic latent variable z with the encoder
            z_prior = rssm.dynamics_mlp(h).view(-1, num_categoricals, num_classes) # (1,32,32) for the softmax
            z_prior = F.softmax(z_prior, -1).flatten(start_dim=1, end_dim=2) # (1, 1024)
            z = rssm.vae.encode(x).flatten(start_dim=1, end_dim=2)

            # random actions during the seed phases, afterwards sample from the agent
            if sample_phase < n_seed_phases:
                action = env.action_space.sample()
            else:
                state = to_np(torch.cat((h.flatten().detach(), z.flatten().detach()), dim=0))
                action, _ = agent.predict(state, deterministic=False)

            # predict one step using the RSSM
            h, reward_pred, continue_prob, continue_pred, x_pred = rssm.step(action, h, z)

            # take an environment step with the action
            obs, reward, terminated, truncated, info = env.step(action.squeeze())
            episode_ended = terminated or truncated

            tracker.add(
                episode_metrics={
                    "rewards": reward,
                }
            )

            # Continue target: only a real termination ends the episode. A time-limit
            # truncation is not terminal and must not teach the continue predictor
            # that the episode ends in this state.
            continue_target = torch.tensor(1 - terminated, device=device, dtype=torch.float32)
            reward = torch.tensor(reward, device=device, dtype=torch.float32)
            losses = rssm.get_losses(x, x_pred, reward, reward_pred,
                                     continue_target, continue_prob, z_prior, z)

            # Add the loss for the current step to the episode metrics
            tracker.add(
                episode_metrics=losses # losses is a dict with batches for all episode metrics
            )

            # The env auto-resets, so the next observation starts a new episode:
            # start the recurrent state from scratch as well.
            if episode_ended:
                h = torch.zeros_like(h)

        # Mean loss of this sample phase (dict keyed by metric name)
        episode_losses = tracker.get_episode_batches(reduction="mean")

        # Update the world model using the mean loss of the sample phase
        rssm.update_parameters(episode_losses["loss"])

        # TODO: save the rssm and agent whenever the running loss improves
        #       (rssm.save_weights / agent.save, as in the legacy loop).

        # --- RL AGENT LEARNING (IN THE WORLD MODEL) ---
        if verbose and sample_phase == n_seed_phases:
            print("The agent starts learning.")

        if sample_phase >= n_seed_phases:
            agent.learn(
                total_timesteps=imagination_timesteps_per_model_update,
                progress_bar=imagination_progress_bar,
                reset_num_timesteps=False
            )

        # every `log_interval` sample phases:
        if sample_phase % config["log_interval"] == 0:

            # log mean episode losses
            tracker.add(
                training_metrics=episode_losses
            )

            # Return of the most recently finished episode, logged against the sample
            # phase. len(return_queue) is not usable as a step counter: gymnasium's
            # RecordEpisodeStatistics keeps a bounded deque, so its length saturates.
            if len(env.return_queue):
                tracker.writer.add_scalar("episode_return", env.return_queue[-1], global_step=sample_phase)

            # TODO Later: actor and critic losses

            # save original image and reconstruction
            save_image_and_reconstruction(x, x_pred, sample_phase)
finally:
    # Always release the environment, also when training is stopped early
    # (e.g. with a KeyboardInterrupt).
    env.close()
    
    

Making a toy env.
Making 1 vectorized envs.
Adding a Gymnasium RecordEpisodeStatistics wrapper.
Adding a TimeLimit wrapper with 1000 max episode steps.
Adding an AutoReset wrapper.
Adding a RescaleActionV0 wrapper. Low: [-1. -1. -1.], High: [1. 1. 1.]


  0%|                                                                                                | 0/501000 [00:00<?, ?it/s]

0


  0%|                                                                                    | 1/501000 [00:01<242:20:39,  1.74s/it]

1


  0%|                                                                                    | 2/501000 [00:02<143:58:59,  1.03s/it]

2


  0%|                                                                                    | 3/501000 [00:02<113:17:48,  1.23it/s]

3


  0%|                                                                                     | 4/501000 [00:03<99:37:38,  1.40it/s]

4


  0%|                                                                                     | 5/501000 [00:03<92:18:38,  1.51it/s]

5


  0%|                                                                                     | 6/501000 [00:04<89:26:35,  1.56it/s]

6


  0%|                                                                                     | 7/501000 [00:05<84:58:42,  1.64it/s]

7


  0%|                                                                                     | 8/501000 [00:05<82:50:32,  1.68it/s]

8


  0%|                                                                                     | 9/501000 [00:06<81:33:40,  1.71it/s]

9


  0%|                                                                                    | 10/501000 [00:06<80:48:13,  1.72it/s]

10


  0%|                                                                                    | 11/501000 [00:07<86:26:50,  1.61it/s]

11


  0%|                                                                                    | 12/501000 [00:08<83:24:05,  1.67it/s]

12


  0%|                                                                                    | 13/501000 [00:08<81:34:03,  1.71it/s]

13


  0%|                                                                                    | 14/501000 [00:09<79:58:35,  1.74it/s]

14


  0%|                                                                                    | 15/501000 [00:09<80:09:29,  1.74it/s]

15


  0%|                                                                                    | 16/501000 [00:10<81:28:47,  1.71it/s]

16


  0%|                                                                                    | 17/501000 [00:10<79:46:02,  1.74it/s]

17


  0%|                                                                                    | 18/501000 [00:11<78:50:43,  1.76it/s]

18


  0%|                                                                                    | 19/501000 [00:11<77:40:18,  1.79it/s]

19


  0%|                                                                                    | 20/501000 [00:12<77:01:22,  1.81it/s]

20


  0%|                                                                                    | 21/501000 [00:13<79:24:03,  1.75it/s]

21


  0%|                                                                                    | 22/501000 [00:13<78:08:35,  1.78it/s]

22


  0%|                                                                                    | 23/501000 [00:14<78:33:08,  1.77it/s]

23


  0%|                                                                                    | 24/501000 [00:14<78:37:52,  1.77it/s]

24


  0%|                                                                                    | 25/501000 [00:15<78:44:17,  1.77it/s]

25


  0%|                                                                                    | 26/501000 [00:16<80:24:58,  1.73it/s]

26


  0%|                                                                                    | 27/501000 [00:16<79:22:12,  1.75it/s]

27


  0%|                                                                                    | 28/501000 [00:17<79:06:09,  1.76it/s]

28


  0%|                                                                                    | 29/501000 [00:17<78:44:08,  1.77it/s]

29


  0%|                                                                                    | 30/501000 [00:18<78:35:10,  1.77it/s]

30


  0%|                                                                                    | 31/501000 [00:18<79:58:27,  1.74it/s]

31


  0%|                                                                                    | 32/501000 [00:19<79:47:46,  1.74it/s]

32


  0%|                                                                                    | 33/501000 [00:19<79:30:17,  1.75it/s]

33


  0%|                                                                                    | 34/501000 [00:20<79:34:52,  1.75it/s]

34


  0%|                                                                                    | 35/501000 [00:21<79:30:36,  1.75it/s]

35


  0%|                                                                                    | 36/501000 [00:21<83:56:52,  1.66it/s]

36


  0%|                                                                                    | 37/501000 [00:22<82:55:54,  1.68it/s]

37


  0%|                                                                                    | 38/501000 [00:22<81:40:09,  1.70it/s]

38


  0%|                                                                                    | 39/501000 [00:23<80:10:21,  1.74it/s]

39


  0%|                                                                                    | 40/501000 [00:24<79:55:20,  1.74it/s]

40


  0%|                                                                                    | 40/501000 [00:24<85:42:14,  1.62it/s]


KeyboardInterrupt: 

In [None]:
# try:
# - training loop
# except KeyboardInterrupt:
#     """ Clean handling for interrupts to stop training early """
#     print("Stopping training.")
#     # Delete the last loss if the training was stopped early
#     # so that the list only consists of floats
#     for key in episode_losses:
#         if isinstance(episode_losses[key][-1], torch.Tensor):
#             episode_losses[key] = episode_losses[key][:-1]
# 
#     # Close the TensorBoard writer and the gym environment
#     writer.close()
#     env.close()

## Plot the results

In [None]:
plot_results = False


def moving_average(values, window: int) -> np.ndarray:
    """Trailing moving average of a 1-D sequence.

    Equivalent to ``np.convolve(values, np.ones(window), mode="valid") / window``:
    element ``i`` is the mean of ``values[i : i + window]``.

    Args:
        values: Sequence or array of numbers; it is flattened first.
        window: Window length, clamped to ``[1, len(values)]`` so that a window
            longer than the data cannot produce a meaningless result.

    Returns:
        Float array of length ``len(values) - window + 1`` (empty for empty input).
    """
    arr = np.asarray(values, dtype=float).ravel()
    if arr.size == 0:
        return arr
    window = min(max(1, int(window)), arr.size)
    return np.convolve(arr, np.ones(window), mode="valid") / window


def plot_loss_histories(histories, ncols: int = 3, smoothing_divisor: int = 20):
    """Plot every loss history together with its moving average, one subplot per key.

    Args:
        histories: Mapping from metric name to a sequence of per-episode values.
        ncols: Number of subplot columns. Rows are added as needed so any number
            of metrics fits; unused axes are hidden.
        smoothing_divisor: The moving-average window is ``longest // smoothing_divisor``
            (at least 1), where ``longest`` is the length of the longest history.

    Returns:
        The matplotlib Figure, or None when ``histories`` is empty.
    """
    keys = list(histories)
    if not keys:
        return None

    longest = max(len(histories[key]) for key in keys)
    window = max(1, longest // smoothing_divisor)
    nrows = (len(keys) + ncols - 1) // ncols

    fig, axs = plt.subplots(ncols=ncols, nrows=nrows, figsize=(ncols * 5, nrows * 5), squeeze=False)
    for ax, key in zip(axs.flat, keys):
        values = np.asarray(histories[key], dtype=float).ravel()
        smoothed = moving_average(values, window)
        # A 'valid' moving average first covers index (window - 1): shift it so each
        # point lines up with the last episode of its window.
        offset = min(window, max(values.size, 1)) - 1
        ax.set_title(key)
        ax.plot(np.arange(values.size), values, label=key)
        ax.plot(np.arange(offset, offset + smoothed.size), smoothed, label="moving average")
        ax.legend(loc="upper right")
    for ax in axs.flat[len(keys):]:
        ax.set_visible(False)

    fig.tight_layout()
    plt.show()
    return fig


if plot_results:
    # NOTE(review): expects per-episode loss lists, as built by the legacy loop's
    # `episode_losses`. The batched loop rebinds `episode_losses` to
    # `tracker.get_episode_batches(reduction="mean")`, presumably only the latest
    # phase's means - TODO feed the full history here.
    plot_loss_histories(episode_losses)

## Showcase the trained agent playing in latent imagination

In [None]:
showcase_agent = False

if showcase_agent:

    showcase_rewards = []
    # gymnasium's Wrapper.render_mode is a read-only property, so set the render
    # mode on the innermost (unwrapped) environment instead of on the wrapper.
    base_env = imagination_env.unwrapped
    base_env.render_mode = "gif"

    try:
        obs, info = imagination_env.reset()

        for _ in range(500):
            # The agent was trained on imagination_env observations, so feed it the
            # imagined observation (not the h/z left over from the training loop).
            # deterministic=False: actions are sampled from the policy.
            action, _ = agent.predict(obs, deterministic=False)
            # action = imagination_env.action_space.sample()

            obs, reward, terminated, truncated, info = imagination_env.step(action)
            showcase_rewards.append(reward)
            imagination_env.render()

            if terminated or truncated:
                break
    finally:
        # Release the env and restore headless rendering even if the showcase fails.
        imagination_env.close()
        base_env.render_mode = None

    fig, ax = plt.subplots()
    ax.plot(showcase_rewards)
    ax.set(title="Showcase rewards in latent imagination", xlabel="step", ylabel="reward")
    plt.show()

## Test area

In [None]:
# !tensorboard 