In [1]:
import gymnasium
import ale_py
import argparse
from tensorboardX import SummaryWriter
import cv2
import numpy as np
from einops import rearrange
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import deque
from tqdm import tqdm
import copy
import colorama
import random
import json
import shutil
import pickle
import os
import wandb
import importlib

In [2]:
# Dynamically reload the project modules so edits on disk are picked up without
# restarting the kernel.
#
# Reload order matters: a module that does `from X import name` keeps whatever
# `name` was bound when *it* was (re)loaded. Reloading X afterwards does not update
# that binding. Leaf modules are therefore reloaded before the modules that may
# import from them.
# NOTE(review): the dependency order below is assumed from module names/roles
# (constants/losses as leaves, `train` as the top-level consumer) — confirm against
# the actual imports. Submodules not listed here are NOT reloaded.
import utils
import replay_buffer
import env_wrapper
import agents
import sub_models.functions_losses
import sub_models.world_models
import sub_models.constants
import train

for _module in (
    sub_models.constants,  # presumably imported by other modules (e.g. DEVICE) — reload first
    sub_models.functions_losses,
    sub_models.world_models,
    utils,
    replay_buffer,
    env_wrapper,
    agents,
    train,  # exposes the builders/training loop below; reload last
):
    importlib.reload(_module)

from utils import seed_np_torch, Logger, load_config
from replay_buffer import ReplayBuffer
from train import (
    build_single_env,
    build_vec_env,
    build_world_model,
    build_agent,
    train_world_model_step,
    world_model_imagine_data,
    joint_train_world_model_agent,
)
from sub_models.constants import DEVICE
print(DEVICE, DEVICE.type)

cuda:1 cuda


In [3]:
# ignore warnings
import warnings

warnings.filterwarnings("ignore")
if torch.cuda.is_available():
    torch.cuda.set_device(DEVICE)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

In [4]:
class WandbLogger:
    """Adapter exposing a ``log(key, value, step=None)`` interface on top of a wandb run.

    Every call is forwarded to ``run.log`` as a single-entry dict.
    """

    def __init__(self, run):
        # The wandb run object that receives all logged values.
        self.run = run

    def log(self, key, value, step=None):
        """Log one ``key -> value`` pair to the wrapped run.

        Args:
            key: Metric name.
            value: Metric value.
            step: Optional explicit step. When omitted, ``run.log`` is called without
                a ``step`` argument so wandb uses its own step counter.
        """
        payload = {key: value}
        if step is None:
            self.run.log(payload)
            return
        self.run.log(payload, step=step)

In [5]:

class RunParams:
    """Parameters for a single training run.

    Derives the ALE environment id and the demonstration-trajectory path from a short
    game name, loads the config file via ``load_config``, and prints a colored summary.

    Args:
        env_name: Short ALE game name, e.g. ``"MsPacman"``.
        exp_name: Experiment name (used by the notebook for the run directory / wandb run name).
        seed: Random seed stored on the instance (defaults to 1, the previous hard-coded value).
        config_path: Path of the config file passed to ``load_config``.
    """

    def __init__(
        self,
        env_name="MsPacman",
        exp_name="TEM-Transformer",
        seed=1,
        config_path="config_files/STORM.yaml",
    ):
        self._env_name = env_name
        self.exp_name = exp_name
        self.seed = seed
        self.config_path = config_path
        self.trajectory_path = f"D_TRAJ/{self._env_name}.pkl"
        # Full gymnasium id of the ALE v5 environment.
        self.env_name = f"ALE/{self._env_name}-v5"

        self.conf = load_config(self.config_path)
        self.print_args()

    def print_args(self):
        """Print the run arguments as ``name: value`` lines with green labels."""
        green, reset = colorama.Fore.GREEN, colorama.Style.RESET_ALL
        separator = green + "-----------------" + reset
        print(green + "Arguments:" + reset)
        print(separator)
        for attr in ("exp_name", "seed", "config_path", "trajectory_path", "env_name"):
            # str() so non-string values (e.g. the int seed) print instead of raising TypeError.
            print(green + f"{attr}: " + reset + str(getattr(self, attr)))
        print(separator)

run_params = RunParams(env_name="MsPacman", exp_name="TEM-Transformer_1")
# Set seeds via the project helper (presumably numpy and torch, per its name — confirm in utils).
seed_np_torch(seed=run_params.seed)
# Per-experiment output directory, shared by the TensorBoard logger and the config snapshot.
run_dir = f"runs/{run_params.exp_name}"
# tensorboard writer
logger = Logger(path=run_dir)
# Snapshot the config next to the run outputs. Create the directory explicitly so the
# copy does not depend on Logger having created it as a side effect.
os.makedirs(run_dir, exist_ok=True)
shutil.copy(run_params.config_path, f"{run_dir}/config.yaml")

[32mArguments:[0m
[32m-----------------[0m
[32mexp_name: [0mTEM-Transformer_1
[32mseed: [0m1
[32mconfig_path: [0mconfig_files/STORM.yaml
[32mtrajectory_path: [0mD_TRAJ/MsPacman.pkl
[32menv_name: [0mALE/MsPacman-v5
[32m-----------------[0m


'runs/TEM-Transformer_1/config.yaml'

In [6]:
print(f"Train Steps: {run_params.conf.JointTrainAgent.SampleMaxSteps}")
print(f"Train Batch Size: {run_params.conf.JointTrainAgent.BatchSize}")
print(f"Train Buffer Max Length: {run_params.conf.JointTrainAgent.BufferMaxLength}")

# Set up env, models, replay buffer.
# A throwaway env is built only to read the size of the discrete action space.
# Close it immediately so its emulator resources are released.
dummy_env = build_single_env(
    run_params.env_name, run_params.conf.BasicSettings.ImageSize, seed=0
)
try:
    action_dim = dummy_env.action_space.n
finally:
    dummy_env.close()

# build world model and agent
world_model = build_world_model(run_params.conf, action_dim)
agent = build_agent(run_params.conf, action_dim)
print(f"World model transformer: {world_model.storm_transformer.__class__.__name__}")
# Count trainable parameters of both models (reported to wandb in the training cell).
world_model_params = sum(p.numel() for p in world_model.parameters() if p.requires_grad)
agent_params = sum(p.numel() for p in agent.parameters() if p.requires_grad)

# build replay buffer
image_size = run_params.conf.BasicSettings.ImageSize
replay_buffer = ReplayBuffer(
    obs_shape=(image_size, image_size, 3),
    num_envs=run_params.conf.JointTrainAgent.NumEnvs,
    max_length=run_params.conf.JointTrainAgent.BufferMaxLength,
    warmup_length=run_params.conf.JointTrainAgent.BufferWarmUp,
    store_on_gpu=run_params.conf.BasicSettings.ReplayBufferOnGPU,
)
# Optionally pre-fill the buffer with a demonstration trajectory.
if run_params.conf.JointTrainAgent.UseDemonstration:
    print(
        colorama.Fore.MAGENTA
        + f"loading demonstration trajectory from {run_params.trajectory_path}"
        + colorama.Style.RESET_ALL
    )
    replay_buffer.load_trajectory(path=run_params.trajectory_path)


Train Steps: 15000
Train Batch Size: 256
Train Buffer Max Length: 100000


A.L.E: Arcade Learning Environment (version 0.10.2+c9d4b19)
[Powered by Stella]


World model transformer: TEMTransformerKVCache


In [None]:
# Initialize wandb and run joint training of the world model and the agent.
use_demonstration = run_params.conf.JointTrainAgent.UseDemonstration
with wandb.init(
    project="WMBRL",  # Replace with your project name
    name=run_params.exp_name,  # Use the experiment name from RunParams
    config={
        "env_name": run_params.env_name,
        "seed": run_params.seed,
    },
) as run:
    # Log the full training configuration to wandb
    run.config.update(run_params.conf)
    # Parameter counts are static run metadata, so store them as numbers in the run config.
    # Logging them via run.log would store them as formatted strings. Calling run.log without
    # a step also advances wandb's step counter, and wandb drops any later entry logged with
    # an explicit step smaller than the current one.
    run.config.update({"WM_params": world_model_params, "Agent_params": agent_params})
    # Separate name so the TensorBoard `logger` from the setup cell is not shadowed.
    wandb_logger = WandbLogger(run)
    # train
    joint_train_world_model_agent(
        env_name=run_params.env_name,
        num_envs=run_params.conf.JointTrainAgent.NumEnvs,
        max_steps=run_params.conf.JointTrainAgent.SampleMaxSteps,
        image_size=run_params.conf.BasicSettings.ImageSize,
        replay_buffer=replay_buffer,
        world_model=world_model,
        agent=agent,
        train_dynamics_every_steps=run_params.conf.JointTrainAgent.TrainDynamicsEverySteps,
        train_agent_every_steps=run_params.conf.JointTrainAgent.TrainAgentEverySteps,
        batch_size=run_params.conf.JointTrainAgent.BatchSize,
        # Demonstration samples are only mixed in when demonstrations were loaded.
        demonstration_batch_size=(
            run_params.conf.JointTrainAgent.DemonstrationBatchSize if use_demonstration else 0
        ),
        batch_length=run_params.conf.JointTrainAgent.BatchLength,
        imagine_batch_size=run_params.conf.JointTrainAgent.ImagineBatchSize,
        imagine_demonstration_batch_size=(
            run_params.conf.JointTrainAgent.ImagineDemonstrationBatchSize
            if use_demonstration
            else 0
        ),
        imagine_context_length=run_params.conf.JointTrainAgent.ImagineContextLength,
        imagine_batch_length=run_params.conf.JointTrainAgent.ImagineBatchLength,
        save_every_steps=run_params.conf.JointTrainAgent.SaveEverySteps,
        seed=run_params.seed,
        logger=wandb_logger,
        args=run_params,
    )



[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


[34m[1mwandb[0m: Currently logged in as: [33mriju11-mukherjee[0m ([33mrm_ai[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Current env: [33mALE/MsPacman-v5[0m


  0%|          | 0/15000 [00:00<?, ?it/s]

[32mSaving model at total steps 0[0m


  6%|▌         | 931/15000 [00:01<00:16, 870.59it/s]

init_imagine_buffer: 1024x16@torch.float16


  8%|▊         | 1155/15000 [01:39<2:22:26,  1.62it/s]