In [1]:
import functools
import os
# os.chdir('/home/mbortkie/repos/crl_subgoal')
import wandb

from rb import TrajectoryUniformSamplingQueue, jit_wrap, segment_ids_per_row
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   
os.environ['CUDA_VISIBLE_DEVICES'] = '7'

import jax
import jax.numpy as jnp
from jax import random
from typing import Tuple, Dict, Any
from dataclasses import dataclass
import chex
from flax import struct
from absl import app, flags
from ml_collections import config_flags
from impls.agents import agents
from impls.agents.crl import CRLAgent, get_config
from config import ROOT_DIR
from block_moving_env import *
from main import *





In [2]:
# vmap environment
NUM_ENVS = 1024
MAX_REPLAY_SIZE = 10000
BATCH_SIZE = 1024
EPISODE_LENGTH = 100
NUM_ACTIONS = 6
GRID_SIZE = 10
NUM_BOXES = 10
SEED = 2

In [None]:
env = BoxPushingEnv(grid_size=GRID_SIZE, max_steps=EPISODE_LENGTH, number_of_boxes=NUM_BOXES)
env = AutoResetWrapper(env)
key = random.PRNGKey(SEED)
env.step = jax.jit(jax.vmap(env.step))
env.reset = jax.jit(jax.vmap(env.reset))
jitted_flatten_batch = jax.jit(jax.vmap(flatten_batch, in_axes=(None, 0, 0)), static_argnums=(0,))

# Replay buffer
dummy_timestep = TimeStep(
    key=key,
    grid=jnp.zeros((GRID_SIZE, GRID_SIZE), dtype=jnp.int32),
    target_cells=jnp.zeros((NUM_BOXES, 2), dtype=jnp.int32),
    agent_pos=jnp.zeros((2,), dtype=jnp.int32),
    agent_has_box=jnp.zeros((1,), dtype=jnp.int32),
    steps=jnp.zeros((1,), dtype=jnp.int32),
    action=jnp.zeros((1,), dtype=jnp.int32),
    goal=jnp.zeros((GRID_SIZE, GRID_SIZE), dtype=jnp.int32),
    reward=jnp.zeros((1,), dtype=jnp.int32),
    done=jnp.zeros((1,), dtype=jnp.int32),
)
replay_buffer = jit_wrap(
    TrajectoryUniformSamplingQueue(
        max_replay_size=MAX_REPLAY_SIZE,
        dummy_data_sample=dummy_timestep,
        sample_batch_size=BATCH_SIZE,
        num_envs=NUM_ENVS,
        episode_length=EPISODE_LENGTH,
    )
)
buffer_state = jax.jit(replay_buffer.init)(key)



2025-07-01 22:30:03.084757: W external/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc:3022] Can't reduce memory use below 8.61GiB (9243810148 bytes) by rematerialization; only reduced to 8.70GiB (9338880080 bytes), down from 8.70GiB (9338880080 bytes) originally


In [4]:

# Agent
config = get_config()
config['discrete'] = True
agent_class = agents[config['agent_name']]
example_batch = {
    'observations':dummy_timestep.grid.reshape(1, -1),  # Add batch dimension 
    'actions': jnp.ones((1,), dtype=jnp.int32) * (NUM_ACTIONS-1), # TODO: make sure it should be the maximal value of action space  # Single action for batch size 1
    'value_goals': dummy_timestep.grid.reshape(1, -1),
    'actor_goals': dummy_timestep.grid.reshape(1, -1),
}

print("Testing agent creation")
agent = agent_class.create(
    SEED,
    example_batch['observations'],
    example_batch['actions'],
    config,
    example_batch['value_goals'],
)
print("Agent created")

Testing agent creation
Agent created


In [5]:
data_collection_key = random.PRNGKey(SEED)


In [6]:
@jax.jit
def update_step(carry, _):
    buffer_state, agent, key = carry
    key, batch_key, double_batch_key = jax.random.split(key, 3)
    # Sample and process transitions
    buffer_state, transitions = replay_buffer.sample(buffer_state)
    batch_keys = jax.random.split(batch_key, transitions.grid.shape[0])
    state, future_state, goal_index = jitted_flatten_batch(0.99, transitions, batch_keys)

    state, actions, future_state, goal_index = apply_double_batch_trick(state, future_state, goal_index, double_batch_key)
    # Create valid batch
    valid_batch = {
        'observations': state.grid.reshape(state.grid.shape[0], -1),
        'actions': actions.squeeze(),
        'value_goals': future_state.grid.reshape(future_state.grid.shape[0], -1),
        'actor_goals': future_state.grid.reshape(future_state.grid.shape[0], -1),
    }

    # Update agent
    agent, update_info = agent.update(valid_batch)
    return (buffer_state, agent, key), update_info


In [7]:

@jax.jit
def train_epoch(carry, _):
    buffer_state, agent, key = carry
    key, data_key, up_key = jax.random.split(key, 3)
    _, _, timesteps = collect_data(agent, data_key, env, NUM_ENVS, EPISODE_LENGTH)
    buffer_state = replay_buffer.insert(buffer_state, timesteps)
    (buffer_state, agent, _), _ = jax.lax.scan(update_step, (buffer_state, agent, up_key), None, length=1000)
    return (buffer_state, agent, key), None

@jax.jit
def train_n_epochs(buffer_state, agent, key):
    (buffer_state, agent, key), _ = jax.lax.scan(
        train_epoch,
        (buffer_state, agent, key),
        None,
        length=10,
    )
    return buffer_state, agent, key

In [8]:
collect_data(agent, key, env, 1024, 100)
update_step((buffer_state, agent, key), None)
print("done")

2025-07-01 22:30:25.649266: W external/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc:3022] Can't reduce memory use below 337.76MiB (354166672 bytes) by rematerialization; only reduced to 8.81GiB (9460974608 bytes), down from 8.86GiB (9515837288 bytes) originally


done


In [9]:

import time
import numpy as np

# Measure collect_data time over 10 executions
collect_times = []
for i in range(10):
    start_time = time.time()
    collect_data(agent, key, env, 1024, 100)
    collect_time = time.time() - start_time
    collect_times.append(collect_time)

avg_collect_time = np.mean(collect_times)
std_collect_time = np.std(collect_times)
print(f"collect_data average time: {avg_collect_time:.4f} ± {std_collect_time:.4f} seconds")

# Measure update_step time over 10 executions
update_times = []
for i in range(10):
    start_time = time.time()
    carry, _ = update_step((buffer_state, agent, key), None)
    update_time = time.time() - start_time
    (buffer_state, agent, key) = carry
    update_times.append(update_time)

avg_update_time = np.mean(update_times)
std_update_time = np.std(update_times)
print(f"update_step average time: {avg_update_time:.4f} ± {std_update_time:.4f} seconds")

collect_data average time: 0.0254 ± 0.0011 seconds
update_step average time: 0.0438 ± 0.1281 seconds


In [10]:

import time
import numpy as np

# Measure key splitting time over 10 executions
key_split_times = []
for i in range(10):
    start_time = time.time()
    key, batch_key, double_batch_key = jax.random.split(key, 3)
    key_split_time = time.time() - start_time
    key_split_times.append(key_split_time)

avg_key_split_time = np.mean(key_split_times)
std_key_split_time = np.std(key_split_times)
print(f"key splitting average time: {avg_key_split_time:.6f} ± {std_key_split_time:.6f} seconds")

# Measure buffer sampling time over 10 executions
buffer_sample_times = []
for i in range(10):
    start_time = time.time()
    buffer_state, transitions = replay_buffer.sample(buffer_state)
    buffer_sample_time = time.time() - start_time
    buffer_sample_times.append(buffer_sample_time)

avg_buffer_sample_time = np.mean(buffer_sample_times)
std_buffer_sample_time = np.std(buffer_sample_times)
print(f"buffer sampling average time: {avg_buffer_sample_time:.6f} ± {std_buffer_sample_time:.6f} seconds")

# Measure batch key generation time over 10 executions
batch_keys_times = []
for i in range(10):
    start_time = time.time()
    batch_keys = jax.random.split(batch_key, transitions.grid.shape[0])
    batch_keys_time = time.time() - start_time
    batch_keys_times.append(batch_keys_time)

avg_batch_keys_time = np.mean(batch_keys_times)
std_batch_keys_time = np.std(batch_keys_times)
print(f"batch keys generation average time: {avg_batch_keys_time:.6f} ± {std_batch_keys_time:.6f} seconds")

# Measure flatten batch time over 10 executions
flatten_batch_times = []
for i in range(10):
    start_time = time.time()
    state, future_state, goal_index = jitted_flatten_batch(0.99, transitions, batch_keys)
    flatten_batch_time = time.time() - start_time
    flatten_batch_times.append(flatten_batch_time)

avg_flatten_batch_time = np.mean(flatten_batch_times)
std_flatten_batch_time = np.std(flatten_batch_times)
print(f"flatten batch average time: {avg_flatten_batch_time:.6f} ± {std_flatten_batch_time:.6f} seconds")

# Measure double batch trick time over 10 executions
double_batch_times = []
for i in range(10):
    start_time = time.time()
    apply_double_batch_trick(state, future_state, goal_index, double_batch_key)
    double_batch_time = time.time() - start_time
    double_batch_times.append(double_batch_time)

avg_double_batch_time = np.mean(double_batch_times)
std_double_batch_time = np.std(double_batch_times)
print(f"double batch trick average time: {avg_double_batch_time:.6f} ± {std_double_batch_time:.6f} seconds")

state, actions, future_state, goal_index = apply_double_batch_trick(state, future_state, goal_index, double_batch_key)


# Measure batch creation time over 10 executions
batch_creation_times = []
for i in range(10):
    start_time = time.time()
    valid_batch = {
        'observations': state.grid.reshape(state.grid.shape[0], -1),
        'actions': actions.squeeze(),
        'value_goals': future_state.grid.reshape(future_state.grid.shape[0], -1),
        'actor_goals': future_state.grid.reshape(future_state.grid.shape[0], -1),
    }
    batch_creation_time = time.time() - start_time
    batch_creation_times.append(batch_creation_time)

avg_batch_creation_time = np.mean(batch_creation_times)
std_batch_creation_time = np.std(batch_creation_times)
print(f"batch creation average time: {avg_batch_creation_time:.6f} ± {std_batch_creation_time:.6f} seconds")

# Measure agent update time over 10 executions
agent_update_times = []
for i in range(10):
    start_time = time.time()
    agent, update_info = agent.update(valid_batch)
    agent_update_time = time.time() - start_time
    agent_update_times.append(agent_update_time)

avg_agent_update_time = np.mean(agent_update_times)
std_agent_update_time = np.std(agent_update_times)
print(f"agent update average time: {avg_agent_update_time:.6f} ± {std_agent_update_time:.6f} seconds")

# Print total average time
total_avg_time = avg_key_split_time + avg_buffer_sample_time + avg_batch_keys_time + avg_flatten_batch_time + avg_double_batch_time + avg_batch_creation_time + avg_agent_update_time
total_std_time = np.sqrt(std_key_split_time**2 + std_buffer_sample_time**2 + std_batch_keys_time**2 + std_flatten_batch_time**2 + std_double_batch_time**2 + std_batch_creation_time**2 + std_agent_update_time**2)
print(f"total average time: {total_avg_time:.6f} ± {total_std_time:.6f} seconds")

key splitting average time: 0.028839 ± 0.085857 seconds


2025-07-01 22:30:39.253427: W external/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc:3022] Can't reduce memory use below 270.04MiB (283154712 bytes) by rematerialization; only reduced to 8.78GiB (9432285316 bytes), down from 8.78GiB (9432289416 bytes) originally
2025-07-01 22:30:49.666778: W external/xla/xla/tsl/framework/bfc_allocator.cc:501] Allocator (GPU_0_bfc) ran out of memory trying to allocate 8.70GiB (rounded to 9338880000)requested by op 
2025-07-01 22:30:49.667153: W external/xla/xla/tsl/framework/bfc_allocator.cc:512] ***************************************************************************************************_
E0701 22:30:49.667184 3198855 pjrt_stream_executor_client.cc:2917] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 9338880000 bytes. [tf-allocator-allocation-error='']


ValueError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 9338880000 bytes.

In [11]:

# Measure scan execution time over 10 executions
scan_times = []
for i in range(10):
    start_time = time.time()
    (buffer_state, agent, _), _ = jax.lax.scan(update_step, (buffer_state, agent, key), None, length=10)
    scan_time = time.time() - start_time
    scan_times.append(scan_time)

avg_scan_time = np.mean(scan_times)
std_scan_time = np.std(scan_times)
print(f"scan execution average time: {avg_scan_time:.6f} ± {std_scan_time:.6f} seconds")

2025-07-01 20:14:51.037675: W external/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc:3022] Can't reduce memory use below 337.76MiB (354165852 bytes) by rematerialization; only reduced to 8.83GiB (9481893056 bytes), down from 8.86GiB (9515784168 bytes) originally
2025-07-01 20:15:12.432393: W external/xla/xla/tsl/framework/bfc_allocator.cc:501] Allocator (GPU_0_bfc) ran out of memory trying to allocate 8.70GiB (rounded to 9338880000)requested by op 
2025-07-01 20:15:12.432765: W external/xla/xla/tsl/framework/bfc_allocator.cc:512] ***************************************************************************************************_
E0701 20:15:12.432839 3139955 pjrt_stream_executor_client.cc:2917] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 9338880000 bytes. [tf-allocator-allocation-error='']


XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 9338880000 bytes.

In [11]:
SEED = 4
env = BoxPushingEnv(grid_size=GRID_SIZE, max_steps=EPISODE_LENGTH, number_of_boxes=NUM_BOXES)
env = AutoResetWrapper(env)
key = random.PRNGKey(SEED)
# env.step = jax.jit(jax.vmap(env.step))
# env.reset = jax.jit(jax.vmap(env.reset))

In [12]:
state, info = env.reset(key)


In [13]:
state.grid

Array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
       [0, 1, 0, 0, 0, 0, 1, 1, 0, 0],
       [0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 2, 2, 0, 0, 0],
       [0, 1, 0, 0, 0, 1, 2, 2, 0, 2],
       [0, 0, 0, 0, 0, 0, 0, 2, 0, 0],
       [0, 0, 0, 1, 0, 1, 2, 0, 3, 2],
       [0, 0, 1, 0, 0, 2, 0, 0, 0, 2]], dtype=int8)

In [14]:
solved_state = env._env.create_solved_state(state)
solved_state.grid

Array([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
       [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
       [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
       [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
       [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
       [ 0,  0,  0,  0,  0, 10, 10,  0,  0,  0],
       [ 0,  0,  0,  0,  0,  0, 10, 10,  0, 10],
       [ 0,  0,  0,  0,  0,  0,  0, 10,  0,  0],
       [ 0,  0,  0,  0,  0,  0, 10,  0,  3, 10],
       [ 0,  0,  0,  0,  0, 10,  0,  0,  0, 10]], dtype=int8)

In [15]:
state_next, reward, done, info = env.step(solved_state, 2)
state_next.grid

Array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
       [0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
       [0, 0, 3, 0, 0, 0, 0, 0, 1, 1],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
       [0, 0, 0, 1, 0, 2, 0, 0, 2, 2],
       [0, 0, 0, 0, 0, 2, 0, 1, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 2, 2, 0],
       [0, 0, 0, 0, 1, 0, 2, 1, 2, 0],
       [0, 0, 0, 0, 1, 0, 0, 0, 2, 2]], dtype=int8)

In [16]:
print(reward, done, info)

10 True {'boxes_on_target': Array(10, dtype=int32)}


In [23]:
env._env._is_goal_reached(solved_state.grid)

Array(True, dtype=bool)

In [19]:
False | env._env._is_goal_reached(solved_state.grid)


Array(True, dtype=bool)