## Imports

In [8]:
import logging
import os
import random
import sys
from collections import deque
from operator import itemgetter

import gym_donkeycar
import gymnasium as gym
import imageio
import ipywidgets as widgets
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from IPython.display import display
from ipywidgets import HBox, VBox
from matplotlib import pyplot as plt
from PIL import Image
from ruamel.yaml import YAML
from scipy.ndimage import gaussian_filter1d
from scipy.stats import norm
from tensorboard import notebook
from tensorboard.backend.event_processing.event_accumulator import \
    EventAccumulator
from torch import distributions as dist
from torch.distributions import Categorical, Normal
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from tqdm import tqdm

# suppress warnings
import warnings
# Ignore UserWarnings whose originating module name matches the pattern "gymnasium".
warnings.filterwarnings("ignore", category=UserWarning, module="gymnasium")
# NOTE(review): both environment variables are set *after* tensorboard and imageio
# were imported above. If those libraries read the variables at import time, these
# assignments come too late to have any effect — confirm, or move them above the imports.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
os.environ["IMAGEIO_IGNORE_WARNINGS"] = "True"

import gym.spaces as gym_spaces
import gymnasium as gym  # overwrite OpenAI gym
import stable_baselines3 as sb3
from gym_donkeycar.envs.donkey_env import DonkeyEnv
from gymnasium import spaces
from gymnasium.spaces import Box
from gymnasium.experimental.wrappers import RescaleActionV0
from stable_baselines3 import A2C, PPO, SAC
from stable_baselines3.common import env_checker
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.results_plotter import load_results, ts2xy
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv

from src.actor_critic import ContinuousActorCritic
from src.blocks import CategoricalStraightThrough, ConvBlock
from src.categorical_vae import CategoricalVAE
from src.imagination_env import ImaginationEnv
from src.mlp import MLP
from src.preprocessing import grayscale_transform as transform
from src.replay_buffer import ReplayBuffer
from src.rssm import RSSM
from src.utils import (load_config, make_env, save_image_and_reconstruction,
                       to_np)
from src.vae import VAE

torch.cuda.empty_cache()
%matplotlib inline
%load_ext autoreload
%autoreload 2

# Load the config
config = load_config()
for key in config:
    locals()[key] = config[key]

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Create the environment

In [58]:
# Build the environment with the project helper (called without arguments, so it
# uses its own defaults). The log printed below lists the wrappers it applies.
env = make_env()

Making a toy env.
Making 1 vectorized envs.
Adding a Gymnasium RecordEpisodeStatistics wrapper.
Adding a TimeLimit wrapper with 1000 max episode steps.
Adding an AutoReset wrapper.
Adding a RescaleActionV0 wrapper to have an effective action space [-1,1].
Note: Clip actions at 1.0 => The agent can take agents from:
Low: [-1. -1. -1.] to High: [1. 1. 1.]


In [60]:
# Instantiate the VAE and move it to `device` (presumably one of the config entries
# injected into the namespace by the setup cell — confirm).
vae = VAE().to(device)
# Restore previously trained weights from the "weights/VAE_1" checkpoint.
vae.load_weights("weights/VAE_1")
# Attach an Adam optimizer directly to the model instance; the training loop
# below drives it through `vae.optim`.
vae.optim = optim.Adam(vae.parameters(), lr=1e-4, weight_decay=1e-6)

Initializing encoder:
- adding ConvBlock((3, 64))                   ==> output shape: (64, 32, 32) ==> prod: 65536
- adding ConvBlock((64, 128))                   ==> output shape: (128, 16, 16) ==> prod: 32768
- adding ConvBlock((128, 256))                   ==> output shape: (256, 8, 8) ==> prod: 16384
- adding ConvBlock((256, 512))                   ==> output shape: (512, 4, 4) ==> prod: 8192
- adding ConvBlock((512, 256))                   ==> output shape: (256, 2, 2) ==> prod: 1024
- adding Flatten()
- adding Linear() for Mu: 1024 and Logvar: 1024

Initializing decoder:
- adding Reshape: (*,1024) => (*,256,2,2)
- adding transpose ConvBlock(256, 256)                   ==> output shape: (256, 4, 4) ==> prod: 4096
- adding transpose ConvBlock(256, 512)                   ==> output shape: (512, 8, 8) ==> prod: 32768
- adding transpose ConvBlock(512, 256)                   ==> output shape: (256, 16, 16) ==> prod: 65536
- adding transpose ConvBlock(256, 128)                   ==> out

In [61]:
# Create the actor-critic agent and move it to the same device as the VAE.
agent = ContinuousActorCritic().to(device)

In [73]:
# Training loop: each update collects a rollout of n_steps_per_update environment
# steps, then performs one agent update and one VAE update on that batch.

# In-notebook history of the loss values (filled every config["log_interval"] updates).
losses = {
    "vae_loss": [],
    "reconstruction_loss": [],
    "KLD_loss": [],
    "critic_loss": [],
    "actor_loss": [],
}

# Logging
writer = SummaryWriter(log_dir)
if config["show_inline_tensorboard"]:
    notebook.start(f"--logdir={log_dir}")

try:
    for sample_phase in tqdm(range(n_updates)):

        # Per-step data needed for the VAE update
        batch_observations = []
        batch_reconstructions = []
        batch_mu = []
        batch_logvar = []

        # Per-step data needed for the agent (actor-critic) update
        batch_rewards = []
        batch_log_probs = []
        batch_value_preds = []
        batch_entropies = []
        batch_masks = []

        # The environment is reset explicitly only once; every later rollout
        # continues from the observation returned by the previous env.step.
        if sample_phase == 0:
            obs, info = env.reset(seed=42)

        for step in range(n_steps_per_update):
            # Drop the leading (environment) dimension of a 4-dimensional observation
            if len(obs.shape) == 4:
                obs = obs[0]

            # Preprocess the observation and encode it
            obs = transform(obs)
            z, reconstruction, mu, logvar = vae(obs)

            # Every 100th update, save the input next to its reconstruction
            # (note: this runs on every step of that update)
            if sample_phase % 100 == 0:
                save_image_and_reconstruction(obs, reconstruction, sample_phase)

            # Add the observation, reconstruction, mu, and logvar to the respective batches
            batch_observations.append(obs)
            batch_reconstructions.append(reconstruction)
            batch_mu.append(mu)
            batch_logvar.append(logvar)

            # Get an action from the latent state and take an environment step
            action, log_prob, actor_entropy = agent.get_action(z)
            obs, reward, terminated, truncated, info = env.step(to_np(action))

            # Collect the necessary data for an agent update
            batch_rewards.append(reward)
            batch_log_probs.append(log_prob)
            batch_entropies.append(actor_entropy)
            # mask = 0 marks a terminal transition, 1 a non-terminal one
            mask = torch.tensor(0.0 if terminated else 1.0)
            batch_masks.append(mask)
            value_pred = agent.critic(torch.Tensor(z))
            batch_value_preds.append(value_pred)

        # Stack the per-step lists into batch tensors
        batch_observations = torch.stack(batch_observations).to(device)  # [n_steps_per_update, *obs_shape]
        batch_reconstructions = torch.stack(batch_reconstructions).to(device)  # [n_steps_per_update, *obs_shape]
        batch_mu = torch.stack(batch_mu).to(device)  # [n_steps_per_update, latent_dim]
        batch_logvar = torch.stack(batch_logvar).to(device)  # [n_steps_per_update, latent_dim]
        batch_rewards = torch.tensor(batch_rewards).to(device)  # [n_steps_per_update]
        batch_log_probs = torch.stack(batch_log_probs).to(device)  # [n_steps_per_update]
        batch_value_preds = torch.stack(batch_value_preds).to(device)  # [n_steps_per_update]
        batch_entropies = torch.stack(batch_entropies).to(device)  # [n_steps_per_update]
        batch_masks = torch.stack(batch_masks).to(device)  # [n_steps_per_update]

        # Bootstrap value for GAE: evaluate the critic on the observation returned by
        # the final env.step, i.e. the state *after* the last collected transition.
        # Reusing the last acted-on latent `z` here would only duplicate
        # batch_value_preds[-1]. `obs` itself is left untouched so that the next
        # rollout preprocesses it exactly as before.
        with torch.no_grad():
            next_obs = obs[0] if len(obs.shape) == 4 else obs
            next_z, *_ = vae(transform(next_obs))
            last_value_pred = agent.critic(torch.Tensor(next_z)).to(device)

        # Update the agent's parameters
        critic_loss, actor_loss = agent.get_loss(
            batch_rewards, batch_log_probs, batch_value_preds, last_value_pred, batch_entropies, batch_masks
        )
        agent.update_parameters(critic_loss, actor_loss)

        # Update the VAE's parameters
        vae_loss, reconstruction_loss, KLD_loss = vae.get_loss(batch_observations, batch_reconstructions, batch_mu, batch_logvar)
        vae.optim.zero_grad()
        vae_loss.backward()
        vae.optim.step()

        if sample_phase % config["log_interval"] == 0:

            # Keep the losses in the in-notebook history
            losses["vae_loss"].append(to_np(vae_loss))
            losses["reconstruction_loss"].append(to_np(reconstruction_loss))
            losses["KLD_loss"].append(to_np(KLD_loss))
            losses["critic_loss"].append(to_np(critic_loss))
            losses["actor_loss"].append(to_np(actor_loss))

            # Log the mean per-step reward of this rollout (not an episode return)
            writer.add_scalar("Mean reward", np.mean(to_np(batch_rewards)), global_step=sample_phase)

            # Log the losses in TensorBoard
            writer.add_scalar("vae_loss", to_np(vae_loss), global_step=sample_phase)
            writer.add_scalar("reconstruction_loss", to_np(reconstruction_loss), global_step=sample_phase)
            writer.add_scalar("KLD_loss", to_np(KLD_loss), global_step=sample_phase)
            writer.add_scalar("critic_loss", to_np(critic_loss), global_step=sample_phase)
            writer.add_scalar("actor_loss", to_np(actor_loss), global_step=sample_phase)
finally:
    # Flush pending events and release the event file even when the run is
    # interrupted (e.g. by a KeyboardInterrupt).
    writer.close()



  5%|████▎                                                                               | 1028/20000 [10:34<3:15:15,  1.62it/s]


KeyboardInterrupt: 

In [75]:
# vae.save_weights()

In [11]:
# Inspect the raw actor output for the latent `z` left over from the training loop.
# The actor returns a pair, which the next cell unpacks as (mu, var).
agent.actor(z)

(tensor([-1., -1., -1.], device='cuda:0', grad_fn=<TanhBackward0>),
 tensor([2.2562e-09, 1.0139e-09, 8.2383e-10], device='cuda:0',
        grad_fn=<SigmoidBackward0>))

In [52]:
# Unpack the actor output for manual inspection.
# Note: this rebinds `mu`, which the training loop also uses for the VAE's mean.
mu, var = agent.actor(z)

In [54]:
# Rebuild the action distribution by hand: a multivariate normal centred on `mu`
# whose covariance is diag(var) (var is broadcast against an identity matrix, so
# only the diagonal entries are non-zero).
action_pd = dist.MultivariateNormal(mu, var * torch.eye(mu.shape[0], device=agent.device))
# Draw one action and display it.
action = action_pd.sample()
action

tensor([-1.0000, -1.0001, -1.0000], device='cuda:0')

In [56]:
# Log-density of the sampled action under the hand-built distribution.
action_pd.log_prob(action)

tensor(25.1843, device='cuda:0', grad_fn=<SubBackward0>)

tensor([[1., 0., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        [0., 0., 1.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 1., 0., 0.],
        [0., 0., 0.,  ..., 0., 1., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.]], device='cuda:0')

In [21]:
# Discounted return-to-go for every step of the last rollout, computed back to
# front: G_t = r_t + gamma * G_{t+1}, starting from G = 0 after the final step.
# Episode boundaries (batch_masks) are not taken into account here.
future_return = 0.0
returns = []

for r in reversed(batch_rewards):
    future_return = r + config["gamma"] * future_return
    # Append now and reverse once afterwards: list.insert(0, ...) shifts every
    # element on each call, which makes the loop quadratic in the rollout length.
    returns.append(future_return)

# Restore chronological order (first step first).
returns.reverse()

In [22]:
# Turn the list of returns into a tensor and standardise it: subtract the mean,
# then divide by the standard deviation (a small epsilon keeps the divisor non-zero).
returns = torch.tensor(returns)
returns = returns.sub(returns.mean()).div(returns.std().add(1e-8))

In [23]:
# Display the standardised returns.
returns

tensor([-1.6396, -1.5699, -1.5000, -1.4299, -1.3595, -1.2890, -1.2183, -1.1473,
        -1.0761, -1.0047, -0.9331, -0.8613, -0.7893, -0.7170, -0.6446, -0.5719,
        -0.4990, -0.4258, -0.3525, -0.2789, -0.2052, -0.1311, -0.0569,  0.0175,
         0.0922,  0.1671,  0.2422,  0.3176,  0.3932,  0.4690,  0.5450,  0.6213,
         0.6978,  0.7745,  0.8515,  0.9287,  1.0061,  1.0837,  1.1616,  1.2397,
         1.3181,  1.3967,  1.4755,  1.5546,  1.6339,  1.7134])

In [26]:
# Print, step by step, the action log-probability, the critic's value prediction
# and the standardised return side by side for a manual comparison.
# Note: the loop variables stay bound in the notebook namespace afterwards
# (`log_prob` is displayed by a later cell).
for log_prob, value_pred, value in zip(batch_log_probs, batch_value_preds, returns):
    print(log_prob, value_pred, value)

tensor(7.9608, device='cuda:0', grad_fn=<MeanBackward0>) tensor([-32.6531], device='cuda:0', grad_fn=<AddBackward0>) tensor(-1.6396)
tensor(7.8397, device='cuda:0', grad_fn=<MeanBackward0>) tensor([-32.6498], device='cuda:0', grad_fn=<AddBackward0>) tensor(-1.5699)
tensor(6.8291, device='cuda:0', grad_fn=<MeanBackward0>) tensor([-32.6525], device='cuda:0', grad_fn=<AddBackward0>) tensor(-1.5000)
tensor(7.7620, device='cuda:0', grad_fn=<MeanBackward0>) tensor([-32.6520], device='cuda:0', grad_fn=<AddBackward0>) tensor(-1.4299)
tensor(8.2424, device='cuda:0', grad_fn=<MeanBackward0>) tensor([-32.6579], device='cuda:0', grad_fn=<AddBackward0>) tensor(-1.3595)
tensor(8.2347, device='cuda:0', grad_fn=<MeanBackward0>) tensor([-32.6581], device='cuda:0', grad_fn=<AddBackward0>) tensor(-1.2890)
tensor(7.9923, device='cuda:0', grad_fn=<MeanBackward0>) tensor([-32.6581], device='cuda:0', grad_fn=<AddBackward0>) tensor(-1.2183)
tensor(7.9663, device='cuda:0', grad_fn=<MeanBackward0>) tensor([-32.

In [10]:
# Display whatever `log_prob` was bound to last (both the training loop and the
# zip loop above assign it).
log_prob

tensor(8.2914, device='cuda:0', grad_fn=<MeanBackward0>)

In [None]:
    # NOTE(review): this indented fragment repeats the loss bookkeeping done inside
    # the training loop, but without the log_interval condition. Running it appends
    # the most recent loss values to the history once more — confirm it is still
    # needed, otherwise delete the cell.
    losses["vae_loss"].append(to_np(vae_loss))
    losses["reconstruction_loss"].append(to_np(reconstruction_loss))
    losses["KLD_loss"].append(to_np(KLD_loss))
    losses["critic_loss"].append(to_np(critic_loss))
    losses["actor_loss"].append(to_np(actor_loss))


In [None]:
# Overview figure of the training history stored in `losses`: one panel per loss,
# showing the raw curve together with a Gaussian-smoothed trend line.
# (`plt` and `gaussian_filter1d` come from the import cell at the top.)

# Standard deviation of the Gaussian smoothing kernel, in logged points
smoothing_factor = 40
fig, axs = plt.subplots(2, 3, figsize=(12, 8))

# (key in `losses`, panel title), in plotting order; the panels fill the
# 2x3 grid row by row.
loss_panels = [
    ("vae_loss", "VAE Loss"),
    ("reconstruction_loss", "Reconstruction Loss"),
    ("KLD_loss", "KLD Loss"),
    ("critic_loss", "Critic Loss"),
    ("actor_loss", "Actor Loss"),
]

for ax, (loss_key, panel_title) in zip(axs.flat, loss_panels):
    smoothed = gaussian_filter1d(losses[loss_key], sigma=smoothing_factor)
    ax.plot(losses[loss_key], alpha=0.8)
    ax.plot(smoothed)
    ax.set_title(panel_title)
    # One point per entry appended to the loss history
    ax.set_xlabel("log step")
    ax.set_ylabel("loss")

# Remove the unused sixth panel
fig.delaxes(axs[1, 2])

# Adjust spacing between subplots
plt.tight_layout()

# Show the plot
plt.show()


In [None]:
# Shape check; the training loop annotates this stack as [n_steps_per_update, *obs_shape].
batch_observations.shape

In [None]:
# Shape check; the training loop annotates this stack as [n_steps_per_update, *obs_shape].
batch_reconstructions.shape

In [None]:
# Shape check; the training loop annotates this stack as [n_steps_per_update, latent_dim].
# NOTE(review): `mu` is also rebound by the actor-inspection cells — the batch is unaffected.
batch_mu.shape

In [None]:
# Shape check; the training loop annotates this stack as [n_steps_per_update, latent_dim].
batch_logvar.shape

In [None]:
# Shape check; the training loop annotates this tensor as [n_steps_per_update].
batch_rewards.shape

In [None]:
# Shape check; the training loop annotates this stack as [n_steps_per_update].
batch_log_probs.shape

In [None]:
# Shape check; the training loop annotates this stack as [n_steps_per_update].
# NOTE(review): the value predictions printed earlier each carry one element in
# their own dimension, so there may be a trailing dimension of size 1 — confirm.
batch_value_preds.shape

In [None]:
# Shape check of the bootstrap value that is passed to agent.get_loss.
last_value_pred.shape

In [None]:
# Shape check; the training loop annotates this stack as [n_steps_per_update].
batch_entropies.shape

In [None]:
# Shape check; the training loop annotates this stack as [n_steps_per_update].
batch_masks.shape

In [None]:
# max_episode_steps, 

In [None]:
# Quick standalone look at the raw (unsmoothed) total VAE loss history.
plt.plot(losses["vae_loss"])

In [None]:
# Quick standalone look at the raw (unsmoothed) reconstruction loss history.
plt.plot(losses["reconstruction_loss"])

In [None]:
# Quick standalone look at the raw (unsmoothed) KL-divergence loss history.
plt.plot(losses["KLD_loss"])

In [None]:
# Reset the environment with a fixed seed and look at the first preprocessed frame.
obs, info = env.reset(seed=42)

# A 4-dimensional observation carries a leading environment axis; keep only the
# first entry before preprocessing.
obs = transform(obs[0] if len(obs.shape) == 4 else obs)

# Move the first axis to the end (axes 1, 2, 0) so that imshow can render the frame.
frame = obs.permute(1, 2, 0)
plt.imshow(to_np(frame))

In [None]:
# Sanity check: push pure Gaussian noise, shaped like a preprocessed observation,
# through the VAE and look at what it reconstructs.
# The VAE's forward pass yields four values (z, reconstruction, mu, logvar), as
# unpacked in the training loop, so the latent has to be unpacked here as well;
# unpacking only three names raises a ValueError.
noise_z, reconstruction, mu, logvar = vae(torch.randn(obs.shape).to(device))
# Move the first axis to the end (axes 1, 2, 0) so that imshow can render the image.
plt.imshow(to_np(torch.permute(reconstruction, (1,2,0))))