In [45]:
import gymnasium
import ale_py
import argparse
from tensorboardX import SummaryWriter
import cv2
import numpy as np
from einops import rearrange
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import deque
from tqdm import tqdm
import copy
import colorama
import random
import json
import shutil
import pickle
import os
# import wandb
import importlib

In [46]:

import sys

# Make the local STORM checkout importable. The location can be overridden with the
# STORM_ROOT environment variable; the default is the original hard-coded checkout path.
STORM_ROOT = os.path.abspath(
    os.environ.get("STORM_ROOT", "/Users/rijulizer/work_space/Thesis/STORM")
)
# Re-running this cell must not keep appending the same entry to sys.path.
if STORM_ROOT not in sys.path:
    sys.path.append(STORM_ROOT)

# The project modules imported below are reloaded afterwards so that any changes made
# to them are reflected without restarting the kernel.

import utils
import sub_models.replay_buffer
import env_wrapper
# import agents
import sub_models.director_agents
import sub_models.functions_losses
import sub_models.world_models
import sub_models.constants
import train

# Reload every project module, in the original order, so that edits made on disk are
# picked up by the `from ... import ...` statements that follow.
for _module in (
    utils,
    sub_models.replay_buffer,
    env_wrapper,
    sub_models.director_agents,
    sub_models.functions_losses,
    sub_models.world_models,
    sub_models.constants,
    train,
):
    importlib.reload(_module)

from utils import seed_np_torch, Logger, load_config
from sub_models.replay_buffer import ReplayBuffer
from train import (
    build_single_env,
    build_vec_env,
    build_world_model,
    build_agent,
    train_world_model,
    world_model_imagine_data,
    joint_train_world_model_agent,
)
from sub_models.constants import DEVICE
print(DEVICE, DEVICE.type)

cpu cpu


In [47]:

# Silence every Python warning from here on.
import warnings

warnings.filterwarnings("ignore")

# Only when CUDA is available: select DEVICE as the current CUDA device and allow TF32
# for both matmul and cuDNN.
if torch.cuda.is_available():
    torch.cuda.set_device(DEVICE)
    for _tf32_backend in (torch.backends.cuda.matmul, torch.backends.cudnn):
        _tf32_backend.allow_tf32 = True
    
class RunParams:
    """Run-level settings for one experiment.

    The instance is what this notebook passes as ``args`` to
    ``joint_train_world_model_agent``. Constructing it loads the YAML config via
    ``load_config`` and prints a summary of the settings.

    Args:
        env_name: Short game name, e.g. ``"MsPacman"``. It is expanded to
            ``ALE/<env_name>-v5`` (``self.env_name``) and to
            ``D_TRAJ/<env_name>.pkl`` (``self.trajectory_path``).
        exp_name: Experiment name.
        seed: Random seed stored on the instance (default 1, the previously
            hard-coded value).
        config_path: Path handed to ``load_config`` (default is the previously
            hard-coded ``../config_files/STORM.yaml``).
    """

    def __init__(
        self,
        env_name="MsPacman",
        exp_name="TEM-Transformer",
        seed=1,
        config_path="../config_files/STORM.yaml",
    ):
        # Keep the short name privately; `env_name` below holds the full ALE id.
        self._env_name = env_name
        self.exp_name = exp_name
        self.seed = seed
        self.config_path = config_path
        self.trajectory_path = f"D_TRAJ/{self._env_name}.pkl"
        self.env_name = f"ALE/{self._env_name}-v5"

        self.conf = load_config(self.config_path)
        self.print_args()

    def print_args(self):
        """Print the run settings, one per line, with green labels."""
        green = colorama.Fore.GREEN
        reset = colorama.Style.RESET_ALL
        rule = green + "-----------------" + reset

        print(green + "Arguments:" + reset)
        print(rule)
        # str() keeps this working for non-string values (e.g. the integer seed).
        for label in ("exp_name", "seed", "config_path", "trajectory_path", "env_name"):
            print(green + f"{label}: " + reset + str(getattr(self, label)))
        print(rule)
    
    # def get_configs(self):
        
    #     config_dict = {
    #         "env_ImageSize": self.conf["BasicSettings"]["ImageSize"],
    #         "env_ReplayBufferOnGPU": self.conf["BasicSettings"]["ReplayBufferOnGPU"],
    #         "WM_InChannels": self.conf["Models"]["WorldModel"]["InChannels"],
    #         "WM_TransformerMaxLength": self.conf["Models"]["WorldModel"]["TransformerMaxLength"],
    #         "WM_TransformerHiddenDim": self.conf["Models"]["WorldModel"]["TransformerHiddenDim"],
    #         "WM_TransformerNumLayers": self.conf["Models"]["WorldModel"]["TransformerNumLayers"],
    #         "WM_TransformerNumHeads": self.conf["Models"]["WorldModel"]["TransformerNumHeads"],
    #         "Agent_NumLayers": self.conf["Models"]["Agent"]["NumLayers"],
    #         "Agent_HiddenDim": self.conf["Models"]["Agent"]["HiddenDim"],
    #         "Agent_Gamma": self.conf["Models"]["Agent"]["Gamma"],
    #         "Agent_Lambda": self.conf["Models"]["Agent"]["Lambda"],
    #         "Agent_EntropyCoef": self.conf["Models"]["Agent"]["EntropyCoef"],
    #         "Train_MaxSteps": self.conf["JointTrainAgent"]["SampleMaxSteps"],
    #         "Train_BufferMaxLength": self.conf["JointTrainAgent"]["BufferMaxLength"],
    #         "Train_BufferWarmUp": self.conf["JointTrainAgent"]["BufferWarmUp"],
    #         "Train_NumEnvs": self.conf["JointTrainAgent"]["NumEnvs"],
    #         "Train_BatchSize": self.conf["JointTrainAgent"]["BatchSize"],
    #         "Train_DemonstrationBatchSize": self.conf["JointTrainAgent"]["DemonstrationBatchSize"],
    #         "Train_BatchLength": self.conf["JointTrainAgent"]["BatchLength"],
    #         "Train_ImagineBatchSize": self.conf["JointTrainAgent"]["ImagineBatchSize"],
    #         "Train_ImagineDemonstrationBatchSize": self.conf["JointTrainAgent"]["ImagineDemonstrationBatchSize"],
    #         "Train_ImagineContextLength": self.conf["JointTrainAgent"]["ImagineContextLength"],
    #         "Train_ImagineBatchLength": self.conf["JointTrainAgent"]["ImagineBatchLength"],
    #         "Train_TrainDynamicsEverySteps": self.conf["JointTrainAgent"]["TrainDynamicsEverySteps"],
    #         "Train_TrainAgentEverySteps": self.conf["JointTrainAgent"]["TrainAgentEverySteps"],
    #         "Train_SaveEverySteps": self.conf["JointTrainAgent"]["SaveEverySteps"],
    #         "Train_UseDemonstration": self.conf["JointTrainAgent"]["UseDemonstration"],
    #     }
    #     return config_dict

# Build the run settings; constructing RunParams loads the config and prints the arguments.
run_params = RunParams(env_name="MsPacman", exp_name = "TEM-Transformer_1")
# set seed (presumably for NumPy and PyTorch, going by the helper's name — confirm in utils)
seed_np_torch(seed=run_params.seed)
# tensorboard writer; logs go under runs/<exp_name>
logger = Logger(path=f"runs/{run_params.exp_name}")
# copy config file
# shutil.copy(run_params.config_path, f"runs/{run_params.exp_name}/config.yaml")

[32mArguments:[0m
[32m-----------------[0m
[32mexp_name: [0mTEM-Transformer_1
[32mseed: [0m1
[32mconfig_path: [0m../config_files/STORM.yaml
[32mtrajectory_path: [0mD_TRAJ/MsPacman.pkl
[32menv_name: [0mALE/MsPacman-v5
[32m-----------------[0m


In [48]:
# Set up env, models, replay buffer
# Getting action_dim with a dummy env: it is built only to read the action-space size.
dummy_env = build_single_env(
    run_params.env_name, run_params.conf.BasicSettings.ImageSize, seed=0
)
try:
    action_dim = dummy_env.action_space.n
finally:
    # Close the throwaway env so it is not left open for the rest of the session;
    # it is not used again in the cells visible here.
    dummy_env.close()

# build world model and agent
world_model = build_world_model(run_params.conf, action_dim)
agent = build_agent(run_params.conf, action_dim)
print(f"World model transformer: {world_model.storm_transformer.__class__.__name__}")
# Log the number of trainable parameters for both models
world_model_params = sum(p.numel() for p in world_model.parameters() if p.requires_grad)
agent_params = sum(p.numel() for p in agent.parameters() if p.requires_grad)
print(f"World model parameters: {world_model_params}")
print(f"Agent parameters: {agent_params}")
# Build replay buffer
replay_buffer = ReplayBuffer(
    num_envs=run_params.conf.JointTrainAgent.NumEnvs,
    obs_shape=(run_params.conf.BasicSettings.ImageSize, run_params.conf.BasicSettings.ImageSize, 3),
    max_length=run_params.conf.JointTrainAgent.BufferMaxLength,
    # Shortened warm-up for quick notebook runs.
    warmup_length=20,  #FIXME: run_params.conf.JointTrainAgent.BufferWarmUp,
    store_on_gpu=run_params.conf.BasicSettings.ReplayBufferOnGPU,
)



World model transformer: StochasticTransformerKVCache
World model parameters: 16508547
Agent parameters: 5367561


In [38]:
# run_params.conf.JointTrainAgent.ImagineContextLength, run_params.conf.JointTrainAgent.ImagineBatchLength, run_params.conf.JointTrainAgent.SaveEverySteps,run_params.seed,

In [49]:
# Short smoke run of the full joint training loop. The numeric arguments are fixed
# notebook values; the rest are read from the loaded config / run_params.
metrics = joint_train_world_model_agent(
    # --- run bookkeeping ---
    args=run_params,
    seed=run_params.seed,
    logger=logger,
    save_every_steps=2500,
    # --- environment ---
    env_name=run_params.env_name,
    num_envs=run_params.conf.JointTrainAgent.NumEnvs,
    image_size=run_params.conf.BasicSettings.ImageSize,
    max_steps=80,
    # --- components ---
    replay_buffer=replay_buffer,
    world_model=world_model,
    agent=agent,
    # --- world-model training ---
    train_dynamics_every_steps=1,
    batch_size=64,
    demonstration_batch_size=0,
    batch_length=16,
    # --- agent training on imagined data ---
    train_agent_every_steps=1,
    imagine_batch_size=64,
    imagine_demonstration_batch_size=0,
    imagine_context_length=8,
    imagine_batch_length=16,
)



A.L.E: Arcade Learning Environment (version 0.10.2+c9d4b19)
[Powered by Stella]


Current env: [33mALE/MsPacman-v5[0m


  0%|          | 0/80 [00:00<?, ?it/s]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 26%|██▋       | 21/80 [00:08<00:23,  2.55it/s]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 28%|██▊       | 22/80 [00:16<00:50,  1.15it/s]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 29%|██▉       | 23/80 [00:24<01:25,  1.49s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 30%|███       | 24/80 [00:32<02:03,  2.20s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 31%|███▏      | 25/80 [00:40<02:42,  2.95s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 32%|███▎      | 26/80 [00:48<03:21,  3.73s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 34%|███▍      | 27/80 [00:56<03:58,  4.50s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 35%|███▌      | 28/80 [01:04<04:33,  5.26s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 36%|███▋      | 29/80 [01:12<05:02,  5.93s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 38%|███▊      | 30/80 [01:20<05:24,  6.49s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 39%|███▉      | 31/80 [01:29<05:41,  6.96s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 40%|████      | 32/80 [01:37<05:54,  7.39s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 41%|████▏     | 33/80 [01:45<05:55,  7.56s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 42%|████▎     | 34/80 [01:53<05:54,  7.70s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 44%|████▍     | 35/80 [02:01<05:49,  7.77s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 45%|████▌     | 36/80 [02:09<05:45,  7.86s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 46%|████▋     | 37/80 [02:17<05:42,  7.97s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 48%|████▊     | 38/80 [02:26<05:36,  8.01s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 49%|████▉     | 39/80 [02:34<05:34,  8.15s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 50%|█████     | 40/80 [02:43<05:31,  8.29s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 51%|█████▏    | 41/80 [02:51<05:21,  8.26s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 52%|█████▎    | 42/80 [02:59<05:13,  8.26s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 54%|█████▍    | 43/80 [03:07<05:03,  8.22s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 55%|█████▌    | 44/80 [03:15<04:55,  8.20s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 56%|█████▋    | 45/80 [03:23<04:46,  8.17s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 57%|█████▊    | 46/80 [03:32<04:41,  8.27s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 59%|█████▉    | 47/80 [03:40<04:31,  8.22s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 60%|██████    | 48/80 [03:49<04:25,  8.31s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 61%|██████▏   | 49/80 [03:57<04:14,  8.22s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 62%|██████▎   | 50/80 [04:05<04:06,  8.20s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 64%|██████▍   | 51/80 [04:13<03:56,  8.16s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 65%|██████▌   | 52/80 [04:24<04:09,  8.93s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 66%|██████▋   | 53/80 [04:32<03:58,  8.82s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 68%|██████▊   | 54/80 [04:40<03:45,  8.68s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 69%|██████▉   | 55/80 [04:49<03:35,  8.61s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 70%|███████   | 56/80 [04:57<03:23,  8.49s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 71%|███████▏  | 57/80 [05:05<03:13,  8.41s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 72%|███████▎  | 58/80 [05:13<03:02,  8.32s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 74%|███████▍  | 59/80 [05:22<02:54,  8.29s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 75%|███████▌  | 60/80 [05:30<02:45,  8.27s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 76%|███████▋  | 61/80 [05:38<02:38,  8.34s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 78%|███████▊  | 62/80 [05:46<02:28,  8.26s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 79%|███████▉  | 63/80 [05:55<02:19,  8.22s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 80%|████████  | 64/80 [06:03<02:11,  8.21s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 81%|████████▏ | 65/80 [06:11<02:03,  8.21s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 82%|████████▎ | 66/80 [06:20<01:56,  8.34s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 84%|████████▍ | 67/80 [06:28<01:47,  8.28s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 85%|████████▌ | 68/80 [06:37<01:41,  8.42s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 86%|████████▋ | 69/80 [06:46<01:36,  8.78s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 88%|████████▊ | 70/80 [06:54<01:26,  8.64s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 89%|████████▉ | 71/80 [07:03<01:16,  8.49s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 90%|█████████ | 72/80 [07:11<01:07,  8.40s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 91%|█████████▏| 73/80 [07:19<00:58,  8.40s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 92%|█████████▎| 74/80 [07:28<00:50,  8.46s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 94%|█████████▍| 75/80 [07:36<00:41,  8.38s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 95%|█████████▌| 76/80 [07:44<00:33,  8.30s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 96%|█████████▋| 77/80 [07:52<00:24,  8.30s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 98%|█████████▊| 78/80 [08:01<00:16,  8.25s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


 99%|█████████▉| 79/80 [08:09<00:08,  8.22s/it]

Training World Model...
Training Agent...
init_imagine_buffer: 64x16@torch.float32


100%|██████████| 80/80 [08:17<00:00,  6.22s/it]


## Breakdown: joint_train_world_model_agent()

In [36]:
## setup variable names for breakdown
# These mirror the keyword arguments of joint_train_world_model_agent() so that the
# cells below can be run as a step-by-step copy of that function.
env_name = run_params.env_name
num_envs = run_params.conf.JointTrainAgent.NumEnvs
max_steps = run_params.conf.JointTrainAgent.SampleMaxSteps
image_size = run_params.conf.BasicSettings.ImageSize
train_dynamics_every_steps = run_params.conf.JointTrainAgent.TrainDynamicsEverySteps
train_agent_every_steps = run_params.conf.JointTrainAgent.TrainAgentEverySteps
batch_size = 3  #FIXME: run_params.conf.JointTrainAgent.BatchSize
demonstration_batch_size = (
    run_params.conf.JointTrainAgent.DemonstrationBatchSize
    if run_params.conf.JointTrainAgent.UseDemonstration
    else 0
)
batch_length = 16  #FIXME: run_params.conf.JointTrainAgent.BatchLength
imagine_batch_size = run_params.conf.JointTrainAgent.ImagineBatchSize
imagine_demonstration_batch_size = (
    run_params.conf.JointTrainAgent.ImagineDemonstrationBatchSize
    if run_params.conf.JointTrainAgent.UseDemonstration
    else 0
)
imagine_context_length = run_params.conf.JointTrainAgent.ImagineContextLength
imagine_batch_length = 16  #FIXME: run_params.conf.JointTrainAgent.ImagineBatchLength
save_every_steps = run_params.conf.JointTrainAgent.SaveEverySteps
seed = run_params.seed
args = run_params


## Setup env
# Use the configured num_envs rather than a literal 1: `sum_reward` below and the
# per-env loop in the sampling cell are both sized by num_envs, and the replay buffer
# was built with the same config value, so all of them have to agree.
vec_env = build_vec_env(env_name, image_size, num_envs=num_envs, seed=seed)
print(
    "Current env: "
    + colorama.Fore.YELLOW
    + f"{env_name}"
    + colorama.Style.RESET_ALL
)

# reset envs and variables
sum_reward = np.zeros(num_envs)
current_obs, current_info = vec_env.reset()
# Rolling windows holding at most the 16 most recent observations / actions.
context_obs = deque(maxlen=16)
context_action = deque(maxlen=16)

Current env: [33mALE/MsPacman-v5[0m


A.L.E: Arcade Learning Environment (version 0.10.2+c9d4b19)
[Powered by Stella]


### Sample from env part

In [38]:
for total_steps in tqdm(range(32)):
    # sample part >>>
    if not replay_buffer.ready:
        # Buffer not ready yet (still in warmup): simply sample a random action.
        action = vec_env.action_space.sample()
    else:
        print("Replay buffer ready", total_steps)
        # WM and Agent are put in eval mode and used without gradient tracking.
        world_model.eval()
        agent.eval()
        with torch.no_grad():
            if context_action:
                # Encode the buffered context frames, concatenated along dim 1.
                context_latent = world_model.encode_obs(
                    torch.cat(list(context_obs), dim=1)
                )
                # Stack the buffered actions along axis 1 and move them to DEVICE.
                model_context_action = torch.Tensor(
                    np.stack(list(context_action), axis=1)
                ).to(DEVICE)
                prior_flattened_sample, last_dist_feat = world_model.calc_last_dist_feat(
                    context_latent, model_context_action
                )
                latent = torch.cat([prior_flattened_sample, last_dist_feat], dim=-1)
                # get the action from the agent
                action = agent.sample_as_env_action(latent)
            else:
                # No context collected yet (first step after the buffer became ready).
                action = vec_env.action_space.sample()
        # [B, H, W, C] -> [B, 1, C, H, W] # B=1, values divided by 255
        context_obs.append(
            torch.tensor(current_obs, device=DEVICE)
            .permute(0, 3, 1, 2)
            .unsqueeze(1)
            / 255
        )
        context_action.append(action)

    # Perform action in the env and observe the next state, reward, done, truncated
    obs, reward, done, truncated, info = vec_env.step(action)

    # Append the transition to the replay buffer; a lost life is stored as a termination.
    replay_buffer.append(
        current_obs, action, reward, np.logical_or(done, info["life_loss"])
    )

    # End of episode: zero the running reward of every env that finished.
    # NOTE(review): the reset happens before `sum_reward += reward`, so the reward of
    # the final step is carried into the next episode's sum — confirm this is intended.
    done_flag = np.logical_or(done, truncated)
    if done_flag.any():
        sum_reward[done_flag] = 0

    # Update current_obs, current_info and sum_reward
    sum_reward += reward
    current_obs, current_info = obs, info
    # <<< sample part

100%|██████████| 32/32 [00:00<00:00, 173.89it/s]

Replay buffer ready 21
Replay buffer ready 22
Replay buffer ready 23
Replay buffer ready 24
Replay buffer ready 25
Replay buffer ready 26
Replay buffer ready 27
Replay buffer ready 28
Replay buffer ready 29
Replay buffer ready 30
Replay buffer ready 31





In [39]:
# Shapes of the arrays returned by the last env step (and of the last action).
tuple(arr.shape for arr in (obs, action, reward, done, truncated))

((1, 64, 64, 3), (1,), (1,), (1,), (1,))

In [21]:
# Sanity check after the sampling loop — presumably the number of stored transitions
# (the recorded output is 32, matching the 32 sampling steps above).
replay_buffer.length

32

In [337]:
# manually fill the replay buffer goal and skill
# goal and skill should ideally be appended in the replay buffer
# replay_buffer.buffer["goal"][0:32] = torch.zeros(32,1,1024)
# replay_buffer.buffer["skill"][0:32] = torch.zeros(32,1, 8,8)

### Train world model part

In [22]:
##Train world model part >>>
# Use the helper under the name this notebook actually imports from `train`
# (`train_world_model`). `train_world_model_step` is not imported in any visible cell,
# so calling it raises NameError on a fresh kernel.
# NOTE(review): assumes `train_world_model` takes these same keyword arguments —
# confirm against train.py.
train_world_model(
    replay_buffer=replay_buffer,
    world_model=world_model,
    batch_size=batch_size,
    demonstration_batch_size=demonstration_batch_size,
    batch_length=batch_length,
    logger=logger,
)
##<<< Train world model part
## Breakdown of the above code
# Sample from replay buffer
# buffer_sample = replay_buffer.sample(
#     batch_size, demonstration_batch_size, batch_length
# )
# for key, value in buffer_sample.items():
#     print(f"{key}, Value shape: {value.shape}")
# obs, Value shape: torch.Size([3, 16, 3, 64, 64])
# action, Value shape: torch.Size([3, 16])
# reward, Value shape: torch.Size([3, 16])
# termination, Value shape: torch.Size([3, 16])
# goal, Value shape: torch.Size([3, 16])
# skill, Value shape: torch.Size([3, 16])
# print(f"Shapes of obs: {obs.shape}, action: {action.shape}, reward: {reward.shape}, termination: {termination.shape}")

## Train world model with the sampled data
# world_model.update(buffer_sample["obs"], buffer_sample["action"], buffer_sample["reward"], buffer_sample["termination"], logger=logger)


### Train agent part

In [23]:
# Train agent part >>>
# print("Training Agent...")
# Produce an imagined rollout from the world model and the agent, with video logging off.
log_video = False
imagined_rollout = world_model_imagine_data(
    # components
    replay_buffer=replay_buffer,
    world_model=world_model,
    agent=agent,
    # rollout sizes
    imagine_batch_size=imagine_batch_size,
    imagine_demonstration_batch_size=imagine_demonstration_batch_size,
    imagine_context_length=imagine_context_length,
    imagine_batch_length=imagine_batch_length,
    # logging
    log_video=log_video,
    logger=logger,
)
# Report the shape of every entry in the returned rollout.
for k in imagined_rollout:
    v = imagined_rollout[k]
    print(f"Shape of {k}: {v.shape}")

## breakdown : world_model_imagine_data
# imagine_batch_size = 3
# world_model.eval()
# agent.eval()

# buffer_sample = replay_buffer.sample(
#     imagine_batch_size, imagine_demonstration_batch_size, imagine_context_length
# )
# print(f"Buffer sample items:")
# for k, v  in buffer_sample.items():
#     print(f"Shape of {k}: {v.shape}")

# imagined_rollout = world_model.imagine_data(
#     agent,
#     buffer_sample,
#     imagine_batch_size=imagine_batch_size + imagine_demonstration_batch_size,
#     imagine_batch_length=imagine_batch_length,
#     log_video=log_video,
#     logger=logger,
# )
# print(f"\n\nImagine rollout items:")
# for k, v  in imagined_rollout.items():
#     print(f"{k}: {v.shape}")

init_imagine_buffer: 1024x16@torch.float32
Shape of sample: torch.Size([1024, 17, 1024])
Shape of hidden: torch.Size([1024, 17, 512])
Shape of action: torch.Size([1024, 16])
Shape of reward: torch.Size([1024, 16])
Shape of termination: torch.Size([1024, 16])
Shape of goal: torch.Size([1024, 16, 1024])
Shape of skill: torch.Size([1024, 16, 8, 8])


In [24]:
# For the sake of testing the code, trim the latent sequences to L = imagine_batch_length.
# "sample" and "hidden" come back one step longer than the other entries (17 vs 16 in
# the output above). Slicing with the variable instead of a literal 16 keeps this cell
# in sync with the rollout length used in the previous cell.
imagined_rollout["sample"] = imagined_rollout["sample"][:, :imagine_batch_length]
imagined_rollout["hidden"] = imagined_rollout["hidden"][:, :imagine_batch_length]
print("\n\nImagine rollout items:")
for k, v in imagined_rollout.items():
    print(f"{k}: {v.shape}")



Imagine rollout items:
sample: torch.Size([1024, 16, 1024])
hidden: torch.Size([1024, 16, 512])
action: torch.Size([1024, 16])
reward: torch.Size([1024, 16])
termination: torch.Size([1024, 16])
goal: torch.Size([1024, 16, 1024])
skill: torch.Size([1024, 16, 8, 8])


In [25]:
# Update agent with imagined data
# The return value is the metrics dict printed in the next cell. Note that this rebinds
# `metrics`, which the joint-training cell further up also assigns.
metrics = agent.update(imagined_rollout)
# <<< Train agent part

In [26]:
# pprint is used instead of print so the metrics dict is laid out one entry per line.
from pprint import pprint
# Print the metrics
pprint(metrics)

{'goal_VAE_loss': -685.745361328125,
 'goal_kl_loss': 0.2932986617088318,
 'goal_recon_loss': -685.745361328125,
 'manager_ActorCritic/S': 0.040681224316358566,
 'manager_ActorCritic/critic_loss': 12.936750411987305,
 'manager_ActorCritic/entropy_loss': 1.984375,
 'manager_ActorCritic/norm_ratio': 1.0,
 'manager_ActorCritic/policy_loss': 1.179916501045227,
 'manager_ActorCritic/total_loss': 12.233854293823242,
 'success_manager': 0.0,
 'worker_ActorCritic/S': 0.021855171769857407,
 'worker_ActorCritic/critic_loss': 11.247394561767578,
 'worker_ActorCritic/entropy_loss': 2.140625,
 'worker_ActorCritic/norm_ratio': 1.0,
 'worker_ActorCritic/policy_loss': -0.05270551145076752,
 'worker_ActorCritic/total_loss': 9.16343879699707}
