In [1]:
import functools
import os
# os.chdir('/home/mbortkie/repos/crl_subgoal')

# Select the GPU *before* importing anything that may pull in JAX: the CUDA
# device environment variables only take effect if they are set before the
# backend is initialised. `rb` and the project modules below presumably import
# JAX themselves, so setting these first is the safe order either way.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = '7'

import wandb

from rb import TrajectoryUniformSamplingQueue, jit_wrap, segment_ids_per_row

import jax
import jax.numpy as jnp
from jax import random
from typing import Tuple, Dict, Any
from dataclasses import dataclass
import chex
from flax import struct
from absl import app, flags
from ml_collections import config_flags
from impls.agents import agents
from impls.agents.crl import CRLAgent, get_config
from config import SRC_ROOT_DIR
# NOTE(review): the star imports below hide where names used later in this
# notebook (e.g. BoxPushingEnv, TimeStep, flatten_batch, collect_data) are
# defined; consider importing them explicitly.
from block_moving_env import *
from main import *





In [2]:
# Experiment hyperparameters shared by the cells below.

# Seed used to build the PRNG keys.
SEED = 2

# Environment layout.
GRID_SIZE = 10          # the grid is GRID_SIZE x GRID_SIZE
NUM_BOXES = 10          # passed to the env as `number_of_boxes`
NUM_ACTIONS = 6         # number of discrete actions (largest index is NUM_ACTIONS - 1)
EPISODE_LENGTH = 100    # passed to the env as `max_steps`

# Data collection and replay.
NUM_ENVS = 1024             # number of environments handed to the buffer / data collection
BATCH_SIZE = 1024           # replay buffer `sample_batch_size`
MAX_REPLAY_SIZE = 10_000    # replay buffer `max_replay_size`

In [None]:
# --- Environment -------------------------------------------------------------
# Box-pushing env wrapped in AutoResetWrapper; `step` and `reset` are replaced
# by vmapped + jitted versions so each call handles a batch of environments.
env = AutoResetWrapper(
    BoxPushingEnv(
        grid_size=GRID_SIZE,
        max_steps=EPISODE_LENGTH,
        number_of_boxes=NUM_BOXES,
    )
)
key = random.PRNGKey(SEED)
env.step = jax.jit(jax.vmap(env.step))
env.reset = jax.jit(jax.vmap(env.reset))

# `flatten_batch` is mapped over its 2nd and 3rd arguments; the 1st argument is
# shared across the batch and treated as static by jit.
_vmapped_flatten_batch = jax.vmap(flatten_batch, in_axes=(None, 0, 0))
jitted_flatten_batch = jax.jit(_vmapped_flatten_batch, static_argnums=(0,))

# --- Replay buffer -----------------------------------------------------------
# All-zero TimeStep used as the structural template for the buffer.
# NOTE(review): every field here is int32, while the env's grids print as int8
# further down in this notebook — confirm the buffer is insensitive to that.
_grid_shape = (GRID_SIZE, GRID_SIZE)
_int32_zeros = functools.partial(jnp.zeros, dtype=jnp.int32)
dummy_timestep = TimeStep(
    key=key,
    grid=_int32_zeros(_grid_shape),
    agent_pos=_int32_zeros((2,)),
    agent_has_box=_int32_zeros((1,)),
    steps=_int32_zeros((1,)),
    action=_int32_zeros((1,)),
    goal=_int32_zeros(_grid_shape),
    reward=_int32_zeros((1,)),
    done=_int32_zeros((1,)),
)

_trajectory_queue = TrajectoryUniformSamplingQueue(
    max_replay_size=MAX_REPLAY_SIZE,
    dummy_data_sample=dummy_timestep,
    sample_batch_size=BATCH_SIZE,
    num_envs=NUM_ENVS,
    episode_length=EPISODE_LENGTH,
)
replay_buffer = jit_wrap(_trajectory_queue)
buffer_state = jax.jit(replay_buffer.init)(key)



2025-07-02 19:48:05.565181: W external/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc:3022] Can't reduce memory use below 8.57GiB (9204898148 bytes) by rematerialization; only reduced to 8.74GiB (9379840080 bytes), down from 8.74GiB (9379840080 bytes) originally


In [4]:

# Agent
# Start from the default config, switch it to discrete actions, and look up
# the agent class named by the config.
config = get_config()
config['discrete'] = True
agent_class = agents[config['agent_name']]

# A single flattened all-zero grid with a leading batch axis of 1; it stands in
# for the observations and for both goal entries of the example batch.
_flat_dummy_grid = dummy_timestep.grid.reshape(1, -1)
example_batch = {
    'observations': _flat_dummy_grid,
    # TODO: make sure it should be the maximal value of action space
    'actions': jnp.full((1,), NUM_ACTIONS - 1, dtype=jnp.int32),  # single action for batch size 1
    'value_goals': _flat_dummy_grid,
    'actor_goals': _flat_dummy_grid,
}

print("Testing agent creation")
_create_args = (
    SEED,
    example_batch['observations'],
    example_batch['actions'],
    config,
    example_batch['value_goals'],
)
agent = agent_class.create(*_create_args)
print("Agent created")

Testing agent creation
Agent created


In [5]:
# PRNG key intended for data collection, built from SEED.
# NOTE(review): no later cell visible here references this key — confirm it is still needed.
data_collection_key = random.PRNGKey(SEED)


In [6]:
@jax.jit
def update_step(carry, _):
    """Perform one agent update on a batch sampled from the replay buffer.

    The ``(carry, x) -> (carry, y)`` signature matches what ``jax.lax.scan``
    expects of its body function.

    Args:
        carry: Tuple ``(buffer_state, agent, key)``.
        _: Unused scan input.

    Returns:
        ``((buffer_state, agent, key), update_info)``: the buffer state returned
        by ``replay_buffer.sample``, the updated agent, the advanced PRNG key,
        and whatever ``agent.update`` reports as its second result.
    """
    replay_state, learner, rng = carry
    rng, flatten_rng, double_batch_rng = jax.random.split(rng, 3)

    # Sample from the buffer and process each sampled row with its own PRNG key.
    replay_state, sampled = replay_buffer.sample(replay_state)
    num_rows = sampled.grid.shape[0]
    row_rngs = jax.random.split(flatten_rng, num_rows)
    # 0.99 is the static first argument of flatten_batch (presumably a discount — TODO confirm).
    state, future_state, goal_index = jitted_flatten_batch(0.99, sampled, row_rngs)

    state, actions, future_state, goal_index = apply_double_batch_trick(
        state, future_state, goal_index, double_batch_rng
    )

    def _flatten_grids(grids):
        # Keep the leading axis, collapse everything else into one axis.
        return grids.reshape(grids.shape[0], -1)

    # The future-state grids serve as both the value goals and the actor goals.
    goals = _flatten_grids(future_state.grid)
    batch = {
        'observations': _flatten_grids(state.grid),
        'actions': actions.squeeze(),
        'value_goals': goals,
        'actor_goals': goals,
    }

    learner, info = learner.update(batch)
    return (replay_state, learner, rng), info


In [7]:

# Number of `update_step` iterations scanned after each round of data collection.
UPDATES_PER_EPOCH = 1000


@jax.jit
def train_epoch(carry, _):
    """Collect data with the current agent, store it, then run the update loop.

    Shaped as a ``jax.lax.scan`` body: ``(carry, x) -> (carry, y)``.

    Args:
        carry: Tuple ``(buffer_state, agent, key)``.
        _: Unused scan input.

    Returns:
        ``((buffer_state, agent, key), None)``. The returned key is the one
        split off at the top of this function; the key coming out of the inner
        update scan is discarded.
    """
    replay_state, learner, rng = carry
    rng, collect_rng, update_rng = jax.random.split(rng, 3)

    # Only the third result of collect_data (the timesteps) is used.
    _unused_a, _unused_b, timesteps = collect_data(
        learner, collect_rng, env, NUM_ENVS, EPISODE_LENGTH
    )
    replay_state = replay_buffer.insert(replay_state, timesteps)

    scan_init = (replay_state, learner, update_rng)
    (replay_state, learner, _), _ = jax.lax.scan(
        update_step, scan_init, None, length=UPDATES_PER_EPOCH
    )
    return (replay_state, learner, rng), None

@jax.jit
def train_n_epochs(buffer_state, agent, key):
    """Run a fixed number of training epochs by scanning ``train_epoch``.

    Args:
        buffer_state: Replay buffer state to start from.
        agent: Agent to train.
        key: PRNG key threaded through the epochs.

    Returns:
        Tuple ``(buffer_state, agent, key)`` after the final epoch.
    """
    num_epochs = 10
    initial_carry = (buffer_state, agent, key)
    final_carry, _ = jax.lax.scan(train_epoch, initial_carry, None, length=num_epochs)
    buffer_state, agent, key = final_carry
    return buffer_state, agent, key

In [8]:
# Debug section: build a fresh environment with a new seed for manual inspection.
# NOTE(review): this rebinds the globals `SEED`, `env` and `key` that the
# training cells above rely on. `train_epoch` reads the global `env`, so if it
# is first traced after this cell it would see this non-vmapped environment.
# Consider using distinct names for the debug objects.
SEED = 4
env = BoxPushingEnv(grid_size=GRID_SIZE, max_steps=EPISODE_LENGTH, number_of_boxes=NUM_BOXES)
env = AutoResetWrapper(env)
key = random.PRNGKey(SEED)
# The vmapped/jitted overrides are left disabled here, so `step` and `reset`
# operate on a single environment:
# env.step = jax.jit(jax.vmap(env.step))
# env.reset = jax.jit(jax.vmap(env.reset))

In [9]:
# Reset the (non-vmapped) environment; `reset` returns the initial state and an info object.
state, info = env.reset(key)


In [10]:
# Display the grid of the freshly reset state.
state.grid

Array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
       [0, 1, 0, 0, 0, 0, 1, 1, 0, 0],
       [0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 2, 2, 0, 0, 0],
       [0, 1, 0, 0, 0, 1, 2, 2, 0, 2],
       [0, 0, 0, 0, 0, 0, 0, 2, 0, 0],
       [0, 0, 0, 1, 0, 1, 2, 0, 3, 2],
       [0, 0, 1, 0, 0, 2, 0, 0, 0, 2]], dtype=int8)

In [11]:
# Ask the inner env (`env._env`, presumably the BoxPushingEnv held by the wrapper)
# for the solved counterpart of `state`, then display its grid.
solved_state = env._env.create_solved_state(state)
solved_state.grid

Array([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
       [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
       [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
       [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
       [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
       [ 0,  0,  0,  0,  0, 10, 10,  0,  0,  0],
       [ 0,  0,  0,  0,  0,  0, 10, 10,  0, 10],
       [ 0,  0,  0,  0,  0,  0,  0, 10,  0,  0],
       [ 0,  0,  0,  0,  0,  0, 10,  0,  3, 10],
       [ 0,  0,  0,  0,  0, 10,  0,  0,  0, 10]], dtype=int8)

In [12]:
# Take one step with action 2 starting from the solved state and display the resulting grid.
# NOTE(review): the recorded grid is a different layout from `solved_state`,
# presumably because AutoResetWrapper reset the finished episode — confirm.
state_next, reward, done, info = env.step(solved_state, 2)
state_next.grid

Array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
       [0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
       [0, 0, 3, 0, 0, 0, 0, 0, 1, 1],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
       [0, 0, 0, 1, 0, 2, 0, 0, 2, 2],
       [0, 0, 0, 0, 0, 2, 0, 1, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 2, 2, 0],
       [0, 0, 0, 0, 1, 0, 2, 1, 2, 0],
       [0, 0, 0, 0, 1, 0, 0, 0, 2, 2]], dtype=int8)

In [13]:
# Show the reward, done flag and info returned by the step above.
print(reward, done, info)

1 True {'boxes_on_target': Array(10, dtype=int32)}


In [14]:
# Evaluate the inner env's goal predicate directly on the solved grid.
env._env._is_goal_reached(solved_state.grid)

Array(True, dtype=bool)

In [15]:
# Sanity check: OR-ing a Python `False` with the goal flag still yields the flag.
False | env._env._is_goal_reached(solved_state.grid)


Array(True, dtype=bool)

In [18]:
# Display the post-step grid again for reference.
state_next.grid

Array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
       [0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
       [0, 0, 3, 0, 0, 0, 0, 0, 1, 1],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
       [0, 0, 0, 1, 0, 2, 0, 0, 2, 2],
       [0, 0, 0, 0, 0, 2, 0, 1, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 2, 2, 0],
       [0, 0, 0, 0, 1, 0, 2, 1, 2, 0],
       [0, 0, 0, 0, 1, 0, 0, 0, 2, 2]], dtype=int8)

In [19]:
# Apply GridStatesEnum.project_to_no_target to the post-step grid and display the result
# (in the recorded output, the cells that were 2 become 0).
GridStatesEnum.project_to_no_target(state_next.grid)

Array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
       [0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
       [0, 0, 3, 0, 0, 0, 0, 0, 1, 1],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
       [0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 1, 0, 0, 1, 0, 0],
       [0, 0, 0, 0, 1, 0, 0, 0, 0, 0]], dtype=int8)