## Imports

In [1]:
from __future__ import annotations
from typing import Dict, List, Union

import logging
import os
import random
import sys
from collections import deque
from operator import itemgetter

import gym_donkeycar
import gymnasium as gym
import imageio
import ipywidgets as widgets
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from IPython.display import display
from ipywidgets import HBox, VBox
from matplotlib import pyplot as plt
from PIL import Image
from ruamel.yaml import YAML
from scipy.ndimage import gaussian_filter1d
from scipy.stats import norm
from tensorboard import notebook
from tensorboard.backend.event_processing.event_accumulator import \
    EventAccumulator
from torch import distributions as dist
from torch.distributions import Categorical, Normal
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from tqdm import tqdm

import gymnasium as gym

# suppress warnings
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module="gymnasium.spaces.box") # module="gymnasium"
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
os.environ["IMAGEIO_IGNORE_WARNINGS"] = "True"

import stable_baselines3 as sb3
from gym_donkeycar.envs.donkey_env import DonkeyEnv
from gymnasium import spaces
from gymnasium.spaces import Box
from stable_baselines3 import A2C, PPO, SAC
from stable_baselines3.common import env_checker
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.results_plotter import load_results, ts2xy
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv

import src
from src.actor_critic_discrete import DiscreteActorCritic
from src.actor_critic import ContinuousActorCritic
from src.blocks import CategoricalStraightThrough, ConvBlock
from src.categorical_vae import CategoricalVAE
from src.imagination_env import make_imagination_env
from src.mlp import MLP
from src.preprocessing import transform
from src.replay_buffer import ReplayBuffer
from src.rssm import RSSM
from src.utils import (load_config, make_env, save_image_and_reconstruction,
                       to_np, symlog, symexp, twohot_encode, ExponentialMovingAvg,
                       ActionExponentialMovingAvg, MetricsTracker)
from src.vae import VAE

torch.cuda.empty_cache()
%matplotlib inline
%load_ext autoreload
%autoreload 2

# Load the config
config = load_config()
for key in config:
    locals()[key] = config[key]

## Create the environment

In [2]:
make_one_env = lambda: gym.make("Pendulum-v1")
make_time_limit_env = lambda: gym.wrappers.TimeLimit(make_one_env(), max_episode_steps=config["max_episode_steps"])
make_auto_reset_env = lambda: gym.wrappers.AutoResetWrapper(make_time_limit_env())

env = gym.vector.AsyncVectorEnv([lambda: make_auto_reset_env() for i in range(n_envs)])

env = gym.experimental.wrappers.RescaleActionV0(env, min_action=config["action_space_low"], max_action=config["action_space_high"])
env = gym.wrappers.RecordEpisodeStatistics(env, deque_size=config["n_updates"])

In [3]:
# env = make_env()
# 
# env.reset()
# for i in range(1000):
#     obs, reward, terminated, truncated, info = env.step(np.random.rand(5,3)*100)
#     done = [te or tr for te, tr in zip(terminated, truncated)]
#     print(done)

In [7]:
# agent = ContinuousActorCritic()
agent = DiscreteActorCritic(n_features=3, n_actions=1)

# agent.load_weights("weights/ContinuousActorCritic_0")

# vae = VAE()
# vae.optim = optim.Adam(vae.parameters(), lr=1e-4, weight_decay=1e-6)

Initializing agent with 3 features and 1 actions.
Initializing critic.
Adding zero weight init to the output layer.
Initializing actor.


In [9]:
# New training loop with batches for the distributional critic

tracker = MetricsTracker(
    training_metrics=["critic_loss", "actor_loss"],
    episode_metrics=["rewards", "log_probs", "value_preds", "critic_dists", "entropies", "masks"],
)

for sample_phase in tqdm(range(n_updates)):
    
    if sample_phase == 0:
        obs, info = env.reset(seed=42)
        # obs = transform(torch.tensor(obs)) ### only for images

    for step in range(n_steps_per_update):
            
        value_pred, critic_dist = agent.apply_critic(obs)

        # Get an action and take an environment step
        action, log_prob, actor_entropy = agent.get_action(obs)
        obs, reward, terminated, truncated, info = env.step(to_np(action))
        # obs = transform(torch.tensor(obs)) ### only for images
        
        # every step:
        tracker.add(
            episode_metrics={
                "rewards": reward,
                "log_probs": log_prob,
                "value_preds": value_pred,
                "critic_dists": critic_dist,
                "entropies": actor_entropy,
                "masks": np.where(terminated, 0, 1),
            }
        )
    
    # every sample phase:
    episode_batches = tracker.get_episode_batches() # episode_batches is a dict
    last_value_pred, _ = agent.apply_critic(obs) # last value prediction for GAE

    # Update the agent's parameters
    # DEBUG:
    for key, val in episode_batches.items():
        print(key, val.shape)
    critic_loss, actor_loss = agent.get_loss(episode_batches, last_value_pred)
    agent.update_parameters(critic_loss, actor_loss)
    
    # Log the episode metrics
    if sample_phase % config["log_interval"] == 0:
        tracker.add(
            training_metrics={
                "critic_loss": critic_loss,
                "actor_loss": actor_loss,
            }
        )

        # Episode return
        if len(env.return_queue):
            tracker.writer.add_scalar("episode_return", np.array(env.return_queue)[-1], global_step=len(env.return_queue))

        # Actor and critic losses
        tracker.log_to_tensorboard(step=sample_phase)
        

  0%|                                                                                                | 0/500000 [00:01<?, ?it/s]

rewards torch.Size([16, 5])
log_probs torch.Size([16, 5])
value_preds torch.Size([16, 5])
critic_dists torch.Size([16, 5, 255])
entropies torch.Size([16, 5])
masks torch.Size([16, 5])





RuntimeError: grad can be implicitly created only for scalar outputs

In [19]:
last_value_pred, _ = agent.apply_critic(obs)

In [20]:
ep_rewards = episode_batches["rewards"]
ep_log_probs = episode_batches["log_probs"]
ep_value_preds = episode_batches["value_preds"]
batch_critic_dists = episode_batches["critic_dists"]
ep_entropies = episode_batches["entropies"]
ep_masks = episode_batches["masks"]

# FROM ORIGINAL:
# append the last value pred to the value preds tensor
last_value_pred = last_value_pred.view(1,-1).detach() # (1, B)
ep_value_preds = torch.cat((ep_value_preds, last_value_pred), dim=0) # (SEQ_LEN+1, B)

# set up tensors for the advantage calculation
returns = torch.zeros_like(ep_rewards).to(device) # (SEQ_LEN, B)
advantages = torch.zeros_like(ep_rewards).to(device) # (SEQ_LEN, B)
next_advantage = torch.zeros_like(last_value_pred) # (1, B)

# calculate advantages using GAE
for t in reversed(range(len(ep_rewards))):
    returns[t] = ep_rewards[t] + gamma * ep_masks[t] * ep_value_preds[t+1]
    td_error = returns[t] - ep_value_preds[t]
    advantages[t] = next_advantage = td_error + gamma * lam * ep_masks[t] * next_advantage

# categorical crossentropy (should be fine, I checked.)
twohot_returns = torch.stack([twohot_encode(r) for r in returns])

In [21]:
ep_value_preds.shape

torch.Size([17, 5])

In [22]:
last_value_pred.shape

torch.Size([1, 5])

In [23]:
advantages.shape

torch.Size([16, 5])

In [None]:
# two-hot encode returns
# twohot_returns = torch.stack([twohot_encode(r) for r in returns])  # (SEQ_LEN, B, NUM_BUCKETS)
# twohot_returns = twohot_returns.permute(1, 0, 2)  # (B, SEQ_LEN, NUM_BUCKETS)
# 
# # compute critic loss
# batch_critic_dists = batch_critic_dists.permute(1, 0, 2)  # (B, SEQ_LEN, NUM_BUCKETS)
# critic_loss = -torch.sum(twohot_returns * torch.log(batch_critic_dists), dim=(1, 2))
# critic_loss = torch.mean(critic_loss)

In [61]:
twohot_returns.shape

torch.Size([16, 5, 255])

In [163]:
# one sample
twohot_returns = torch.zeros(16, 1, 255).to(config["device"])

for i in range(16):
    twohot_returns[i,0,30] = 0.83
    twohot_returns[i,0,31] = 0.17

torch.manual_seed(0)
batch_critic_dists = torch.rand(16, 1, 255).to(config["device"])
batch_critic_dists = F.softmax(batch_critic_dists, dim=2)

In [174]:
# batch of samples
twohot_returns = torch.zeros(16, 7, 255).to(config["device"])
for i in range(16):
    twohot_returns[i,:,30] = 0.83 * torch.ones(7).to(config["device"])
    twohot_returns[i,:,31] = 0.17 * torch.ones(7).to(config["device"])

batch_critic_dists = torch.ones(16, 7, 255).to(config["device"])
batch_critic_dists = F.softmax(batch_critic_dists, dim=2)

In [181]:
result = agent._calculate_critic_loss(twohot_returns, batch_critic_dists)
result.shape

torch.Size([])

ValueError: only one element tensors can be converted to Python scalars

In [180]:
# ORIGINAL:

critic_loss = - twohot_returns[:,0,:].detach() @ torch.log(batch_critic_dists[:,0,:]).T
critic_loss = torch.sum(torch.diag(critic_loss))
critic_loss

tensor(88.6602, device='cuda:0')

In [None]:
# solut = 88.0143

In [63]:
a = twohot_returns.permute(1, 0, 2)
a.shape

torch.Size([5, 16, 255])

In [64]:
b = batch_critic_dists.permute(1, 0, 2)
b.shape

torch.Size([5, 16, 255])

In [66]:
critic_loss = -torch.sum(a * torch.log(b), dim=(1, 2))
critic_loss

tensor([88.6126, 88.9002, 88.6646, 88.8494, 88.8574], device='cuda:0',
       grad_fn=<NegBackward0>)

In [None]:
torch.scatter_add

In [66]:
torch.tensor(1).scatter_add_

<function Tensor.scatter_add_>

In [72]:
ep_rewards = episode_batches["rewards"]
ep_log_probs = episode_batches["log_probs"]
ep_value_preds = episode_batches["value_preds"]
batch_critic_dists = episode_batches["critic_dists"]
ep_entropies = episode_batches["entropies"]
ep_masks = episode_batches["masks"]
last_value_pred = torch.randn(5).to(device)

In [73]:
print(ep_rewards.shape)
print(last_value_pred.shape)

torch.Size([16, 5])
torch.Size([5])


In [74]:
last_value_pred = last_value_pred.unsqueeze(0).detach() # (1, B)
ep_value_preds = torch.cat((ep_value_preds, last_value_pred), dim=0) # (SEQ_LEN+1, B)

# set up tensors for the advantage calculation
returns = torch.zeros_like(ep_rewards).to(device) # (SEQ_LEN, B)
advantages = torch.zeros_like(ep_rewards).to(device) # (SEQ_LEN, B)
next_advantage = torch.zeros_like(last_value_pred) # (1, B)

In [75]:
# calculate advantages using GAE
for t in reversed(range(len(ep_rewards))):
    returns[t] = ep_rewards[t] + gamma * ep_masks[t] * ep_value_preds[t+1]
    td_error = returns[t] - ep_value_preds[t]
    advantages[t] = next_advantage = td_error + gamma * lam * ep_masks[t] * next_advantage

In [76]:
advantages.shape

torch.Size([16, 5])

In [77]:
returns.shape

torch.Size([16, 5])

In [94]:
twohot_returns = torch.stack([twohot_encode(r) for r in returns]) # (SEQ_LEN, B, NUM_BUCKETS)
critic_loss = -torch.sum(twohot_returns.detach() * torch.log(batch_critic_dists), dim=(0, 2))

In [95]:
critic_loss.shape

torch.Size([5])

In [91]:
torch.dot(twohot_returns.detach(), torch.log(batch_critic_dists).T)

RuntimeError: 1D tensors expected, but got 3D and 3D tensors

In [None]:
critic_loss = torch.sum(torch.diag(critic_loss))

In [82]:
twohot_returns.shape

torch.Size([16, 5, 255])

In [83]:
batch_critic_dists.shape

torch.Size([16, 5, 255])

In [99]:
actor_loss = -(ep_log_probs * advantages.detach()).mean() - ent_coef * ep_entropies
actor_loss.shape

torch.Size([16, 5])

In [100]:
(ep_log_probs * advantages).shape

torch.Size([16, 5])

In [81]:
twohot_returns.shape

torch.Size([16, 5, 255])

In [None]:
for key, val in episode_batches.items():
    print(key, val.shape)

In [None]:
episode_batches["rewards"]