In [1]:
import gymnasium
import ale_py
import argparse
from tensorboardX import SummaryWriter
import cv2
import numpy as np
from einops import rearrange
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import deque
from tqdm import tqdm
import copy
import colorama
import random
import json
import shutil
import pickle
import os
import wandb
import importlib

In [2]:
import sys
sys.path.append(os.path.abspath("/data/I6347325/work_space/STORM"))# Dynamically reload the modules to reflect any changes

import utils
import sub_models.replay_buffer
import env_wrapper
# import agents
import sub_models.director_agents
import sub_models.functions_losses
import sub_models.world_models
import sub_models.constants
import train

importlib.reload(utils)
importlib.reload(sub_models.replay_buffer)
importlib.reload(env_wrapper)
importlib.reload(sub_models.director_agents)
importlib.reload(sub_models.functions_losses)
importlib.reload(sub_models.world_models)
importlib.reload(sub_models.constants)
importlib.reload(train)

from utils import seed_np_torch, Logger, load_config
from sub_models.replay_buffer import ReplayBuffer
from train import (
    build_single_env,
    build_vec_env,
    build_world_model,
    build_agent,
    joint_train_world_model_agent,
)
from sub_models.constants import DEVICE
print(DEVICE, DEVICE.type)

# ignore warnings
import warnings

warnings.filterwarnings("ignore")
if torch.cuda.is_available():
    torch.cuda.set_device(DEVICE)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True



cuda:0 cuda


In [3]:
class RunParams:
    def __init__(self, env_names, exp_name = "STORM-Director"):
        # self._env_name = env_name
        self.exp_name = exp_name
        self.seed = 1
        self.config_path = "../config_files/STORM.yaml"
        # self.trajectory_path = f"D_TRAJ/{self._env_name}.pkl"
        self.env_names = env_names

        self.conf = load_config(self.config_path)
        self.print_args()
    def print_args(self):
        print(colorama.Fore.GREEN + "Arguments:" + colorama.Style.RESET_ALL)
        print(colorama.Fore.GREEN + "-----------------" + colorama.Style.RESET_ALL)
        print(colorama.Fore.GREEN + "exp_name: " + colorama.Style.RESET_ALL + self.exp_name)
        print(colorama.Fore.GREEN + "seed: " + colorama.Style.RESET_ALL + str(self.seed))
        # print(colorama.Fore.GREEN + "config_path: " + colorama.Style.RESET_ALL + self.config_path)
        print(colorama.Fore.GREEN + "env_name: " + colorama.Style.RESET_ALL)
        print(self.env_names)
        print(colorama.Fore.GREEN + "-----------------" + colorama.Style.RESET_ALL)

env_names = ["MiniGrid-Empty-8x8-v0", "MiniGrid-SimpleCrossingS9N3-v0", "MiniGrid-DoorKey-8x8-v0", "MiniGrid-FourRooms-v0"] 
run_params = RunParams(env_names, exp_name = "STORM-Director-2")
# set seed
seed_np_torch(seed=run_params.seed)
# tensorboard writer
logger = Logger(path=f"runs/{run_params.exp_name}")

[32mArguments:[0m
[32m-----------------[0m
[32mexp_name: [0mSTORM-Director-2
[32mseed: [0m1
[32menv_name: [0m
['MiniGrid-Empty-8x8-v0', 'MiniGrid-SimpleCrossingS9N3-v0', 'MiniGrid-DoorKey-8x8-v0', 'MiniGrid-FourRooms-v0']
[32m-----------------[0m


In [8]:
run_params.env_names

['MiniGrid-Empty-8x8-v0',
 'MiniGrid-SimpleCrossingS9N3-v0',
 'MiniGrid-DoorKey-8x8-v0',
 'MiniGrid-FourRooms-v0']

In [4]:
print(f"Train Steps: {run_params.conf.JointTrainAgent.SampleMaxSteps}")
print(f"Train Batch Size: {run_params.conf.JointTrainAgent.BatchSize}")
print(f"Train Buffer Max Length: {run_params.conf.JointTrainAgent.BufferMaxLength}")
print(f"Number of Environments: {len(run_params.env_names)}")
# Setuop env, models, replay buffer
# getting action_dim with dummy env
dummy_env = build_single_env(
    run_params.env_names[0], run_params.conf.BasicSettings.ImageSize)
action_dim = dummy_env.action_space.n

# build world model and agent
world_model = build_world_model(run_params.conf, action_dim)
agent = build_agent(run_params.conf, action_dim)
print(f"World model transformer: {world_model.storm_transformer.__class__.__name__}")
# Log the number of parameters for both models
world_model_params = sum(p.numel() for p in world_model.parameters() if p.requires_grad)
agent_params = sum(p.numel() for p in agent.parameters() if p.requires_grad)

# build replay buffer
replay_buffer = ReplayBuffer(
    obs_shape=(run_params.conf.BasicSettings.ImageSize, run_params.conf.BasicSettings.ImageSize, 3),
    num_envs=len(run_params.env_names),
    max_length=run_params.conf.JointTrainAgent.BufferMaxLength,
    warmup_length=run_params.conf.JointTrainAgent.BufferWarmUp,
    store_on_gpu=run_params.conf.BasicSettings.ReplayBufferOnGPU,
)
# judge whether to load demonstration trajectory
if run_params.conf.JointTrainAgent.UseDemonstration:
    print(
        colorama.Fore.MAGENTA
        + f"loading demonstration trajectory from {run_params.trajectory_path}"
        + colorama.Style.RESET_ALL
    )
    replay_buffer.load_trajectory(path=run_params.trajectory_path)

Train Steps: 25000
Train Batch Size: 128
Train Buffer Max Length: 50000
Number of Environments: 4
World model transformer: StochasticTransformerKVCache


## Breakdown: joint_train_world_model_agent()

In [10]:
## setup variable names for breakdown
env_names=run_params.env_names
num_envs=len(run_params.env_names)
max_steps=run_params.conf.JointTrainAgent.SampleMaxSteps
image_size=run_params.conf.BasicSettings.ImageSize
train_dynamics_every_steps=run_params.conf.JointTrainAgent.TrainDynamicsEverySteps
train_agent_every_steps=run_params.conf.JointTrainAgent.TrainAgentEverySteps
batch_size=3 #FIXME: run_params.conf.JointTrainAgent.BatchSize
demonstration_batch_size=(
    run_params.conf.JointTrainAgent.DemonstrationBatchSize
    if run_params.conf.JointTrainAgent.UseDemonstration
    else 0
)
batch_length=16 #FIXME: run_params.conf.JointTrainAgent.BatchLength
imagine_batch_size=run_params.conf.JointTrainAgent.ImagineBatchSize
imagine_demonstration_batch_size=(
    run_params.conf.JointTrainAgent.ImagineDemonstrationBatchSize
    if run_params.conf.JointTrainAgent.UseDemonstration
    else 0
)
imagine_context_length=run_params.conf.JointTrainAgent.ImagineContextLength
imagine_batch_length=16 #FIXME: run_params.conf.JointTrainAgent.ImagineBatchLength
save_every_steps=run_params.conf.JointTrainAgent.SaveEverySteps
seed=run_params.seed
args=run_params

vec_env = build_vec_env(env_names, image_size)
print(
    "Current env: "
    + colorama.Fore.YELLOW
    + f"{len(env_names)} parallel envs"
    + colorama.Style.RESET_ALL
)

# reset envs and variables
sum_reward = np.zeros(num_envs)
current_obs, current_info = vec_env.reset()
context_obs = deque(maxlen=16)
context_action = deque(maxlen=16)

Current env: [33m4 parallel envs[0m




### Sample from env part

In [38]:
metrics ={}
for total_steps in tqdm(range(32)):
    # sample part >>>
    if replay_buffer.ready:  # ready only after warmpup
        # WM and Agent are in eval mode
        world_model.eval()
        agent.eval()
        with torch.no_grad():
            if len(context_action) == 0:
                # this is the case in the first step
                action = vec_env.action_space.sample()  # [E]
            else:
                context_latent = world_model.encode_obs(
                    torch.cat(list(context_obs), dim=1)
                )
                model_context_action = np.stack(list(context_action), axis=1)
                model_context_action = torch.Tensor(model_context_action).to(DEVICE)
                prior_flattened_sample, last_dist_feat = (
                    world_model.calc_last_dist_feat(
                        context_latent, model_context_action
                    )
                )  # [E,n,1024], [E,n,512]
                action = agent.sample_as_env_action(
                    torch.cat([prior_flattened_sample, last_dist_feat], dim=-1),
                    greedy=False,
                )  # [E]
        # [E, H, W, C] -> [E, 1, C, H, W]
        context_obs.append(
            torch.permute(
                torch.tensor(current_obs, device=DEVICE), (0, 3, 1, 2)
            ).unsqueeze(1)
            / 255
        )
        context_action.append(action)
    else:
        # sample single random action
        action = vec_env.action_space.sample()  # [E]

    # Perform action in the env and observe the next state, reward, done, truncated
    # Single Unbatched instances: # ((4, 64, 64, 3), (4,), (4,), (4,), (4,)); E=4
    obs, reward, done, truncated, info = vec_env.step(action)

    # Append the transition to the replay buffer
    replay_buffer.append(current_obs, action, reward, done)

    done_flag = np.logical_or(done, truncated)
    if done_flag.any():  # end of episode
        for i in range(num_envs):
            if done_flag[i]:
                env_id = env_names[i][:16]
                # Log reward for this environment
                metrics[f"sample/{env_id}_reward"] = sum_reward[i]
                metrics[f"sample/{env_id}_episode_steps"] = step_counters[i]
                metrics["replay_buffer/length"] = len(replay_buffer)
                # Reset reward tracker and step counter
                sum_reward[i] = 0
                step_counters[i] = 0

    # Update current_obs, current_info and sum_reward
    sum_reward += reward  # [E]
    current_obs = obs
    current_info = info
    step_counters += 1

100%|██████████| 32/32 [00:00<00:00, 643.16it/s]


In [26]:
obs, reward, done, truncated, info = vec_env.step(action)

In [39]:
replay_buffer.length

32

In [45]:
buffer_sample = replay_buffer.sample(batch_size=64, external_batch_size=0, batch_length=16)
buffer_sample["obs"].shape, buffer_sample["action"].shape, buffer_sample["reward"].shape

(torch.Size([64, 16, 3, 64, 64]), torch.Size([64, 16]), torch.Size([64, 16]))

In [44]:
buffer_sample["obs"].shape

torch.Size([64, 16, 3, 64, 64])

In [None]:
replay_buffer.buffer["obs"].shape 

torch.Size([12500, 4, 64, 64, 3])

: 

### Train world model part

In [27]:
##Train world model part >>>
wm_metrics = train_world_model(
    replay_buffer=replay_buffer,
    world_model=world_model,
    batch_size=batch_size,
    demonstration_batch_size=demonstration_batch_size,
    batch_length=batch_length,
    # logger=logger,
)
##<<< Train world model part
## Breakdown of the above code
# Sample from replay buffer
# buffer_sample = replay_buffer.sample(
#     batch_size, demonstration_batch_size, batch_length
# )
# for key, value in buffer_sample.items():
#     print(f"{key}, Value shape: {value.shape}")
# obs, Value shape: torch.Size([3, 16, 3, 64, 64])
# action, Value shape: torch.Size([3, 16])
# reward, Value shape: torch.Size([3, 16])
# termination, Value shape: torch.Size([3, 16])
# goal, Value shape: torch.Size([3, 16])
# skill, Value shape: torch.Size([3, 16])
# print(f"Shapes of obs: {obs.shape}, action: {action.shape}, reward: {reward.shape}, termination: {termination.shape}")

## Train world model with the sampled data
# world_model.update(buffer_sample["obs"], buffer_sample["action"], buffer_sample["reward"], buffer_sample["termination"], logger=logger)


### Train agent part

In [28]:
# Train agent part >>>
# print("Training Agent...")
log_video = False
imagined_rollout = world_model_imagine_data(
    replay_buffer=replay_buffer,
    world_model=world_model,
    agent=agent,
    imagine_batch_size=imagine_batch_size,
    imagine_demonstration_batch_size=imagine_demonstration_batch_size,
    imagine_context_length=imagine_context_length,
    imagine_batch_length=imagine_batch_length,
    log_video=log_video,
    # logger=logger,
)
for k, v  in imagined_rollout.items():
    print(f"Shape of {k}: {v.shape}")

## breakdown : world_model_imagine_data
# imagine_batch_size = 3
# world_model.eval()
# agent.eval()

# buffer_sample = replay_buffer.sample(
#     imagine_batch_size, imagine_demonstration_batch_size, imagine_context_length
# )
# print(f"Buffer sample items:")
# for k, v  in buffer_sample.items():
#     print(f"Shape of {k}: {v.shape}")

# imagined_rollout = world_model.imagine_data(
#     agent,
#     buffer_sample,
#     imagine_batch_size=imagine_batch_size + imagine_demonstration_batch_size,
#     imagine_batch_length=imagine_batch_length,
#     log_video=log_video,
#     logger=logger,
# )
# print(f"\n\nImagine rollout items:")
# for k, v  in imagined_rollout.items():
#     print(f"{k}: {v.shape}")

init_imagine_buffer: 1024x16@torch.float32
Shape of sample: torch.Size([1024, 16, 1024])
Shape of hidden: torch.Size([1024, 16, 512])
Shape of action: torch.Size([1024, 16])
Shape of reward: torch.Size([1024, 16])
Shape of termination: torch.Size([1024, 16])
Shape of goal: torch.Size([1024, 16, 1024])
Shape of skill: torch.Size([1024, 16, 8, 8])


In [29]:
# Update agent with imagined data
metrics = agent.update(imagined_rollout)
# <<< Train agent part

TypeError: can't convert cuda:1 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

In [26]:
from pprint import pprint
# Print the metrics
pprint(metrics)

{'goal_VAE_loss': -685.745361328125,
 'goal_kl_loss': 0.2932986617088318,
 'goal_recon_loss': -685.745361328125,
 'manager_ActorCritic/S': 0.040681224316358566,
 'manager_ActorCritic/critic_loss': 12.936750411987305,
 'manager_ActorCritic/entropy_loss': 1.984375,
 'manager_ActorCritic/norm_ratio': 1.0,
 'manager_ActorCritic/policy_loss': 1.179916501045227,
 'manager_ActorCritic/total_loss': 12.233854293823242,
 'success_manager': 0.0,
 'worker_ActorCritic/S': 0.021855171769857407,
 'worker_ActorCritic/critic_loss': 11.247394561767578,
 'worker_ActorCritic/entropy_loss': 2.140625,
 'worker_ActorCritic/norm_ratio': 1.0,
 'worker_ActorCritic/policy_loss': -0.05270551145076752,
 'worker_ActorCritic/total_loss': 9.16343879699707}


### Final full call

In [None]:
# Initialize wandb
# with wandb.init(
#     project="WMBRL",  # Replace with your project name
#     name=run_params.exp_name,   # Use the experiment name from RunParam
#     config = {
#         "env_name": run_params.env_name,
#         "seed": run_params.seed,
#     }
# ) as run:
    # Log the configuration to wandb
    # run.config.update(run_params.conf)
    # run.log({"WM_params": f"{world_model_params:.2e}", "Agent_params": f"{agent_params:.2e}"})
    # logger = WandbLogger(run)
    # train
joint_train_world_model_agent(
    env_name=run_params.env_name,
    num_envs=run_params.conf.JointTrainAgent.NumEnvs,
    max_steps=run_params.conf.JointTrainAgent.SampleMaxSteps,
    image_size=run_params.conf.BasicSettings.ImageSize,
    replay_buffer=replay_buffer,
    world_model=world_model,
    agent=agent,
    train_dynamics_every_steps=run_params.conf.JointTrainAgent.TrainDynamicsEverySteps,
    train_agent_every_steps=run_params.conf.JointTrainAgent.TrainAgentEverySteps,
    batch_size=run_params.conf.JointTrainAgent.BatchSize,
    demonstration_batch_size=(
        run_params.conf.JointTrainAgent.DemonstrationBatchSize
        if run_params.conf.JointTrainAgent.UseDemonstration
        else 0
    ),
    batch_length=run_params.conf.JointTrainAgent.BatchLength,
    imagine_batch_size=run_params.conf.JointTrainAgent.ImagineBatchSize,
    imagine_demonstration_batch_size=(
        run_params.conf.JointTrainAgent.ImagineDemonstrationBatchSize
        if run_params.conf.JointTrainAgent.UseDemonstration
        else 0
    ),
    imagine_context_length=run_params.conf.JointTrainAgent.ImagineContextLength,
    imagine_batch_length=run_params.conf.JointTrainAgent.ImagineBatchLength,
    save_every_steps=run_params.conf.JointTrainAgent.SaveEverySteps,
    seed=run_params.seed,
    logger=logger,
    args=run_params,
)



[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


[34m[1mwandb[0m: Currently logged in as: [33mriju11-mukherjee[0m ([33mrm_ai[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Current env: [33mALE/MsPacman-v5[0m


  0%|          | 0/15000 [00:00<?, ?it/s]

[32mSaving model at total steps 0[0m


  6%|▌         | 931/15000 [00:01<00:16, 870.59it/s]

init_imagine_buffer: 1024x16@torch.float16


  8%|▊         | 1155/15000 [01:39<2:22:26,  1.62it/s]