In [1]:
import functools
import os
# os.chdir('/home/mbortkie/repos/crl_subgoal')
import wandb

from rb import TrajectoryUniformSamplingQueue, jit_wrap, segment_ids_per_row
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   
os.environ['CUDA_VISIBLE_DEVICES'] = '4'

import jax
import jax.numpy as jnp
from jax import random
from typing import Tuple, Dict, Any
from dataclasses import dataclass
import chex
from flax import struct
from absl import app, flags
from ml_collections import config_flags
from agents import agents
from agents.crl import CRLAgent, get_config
from config import ROOT_DIR
from block_moving_env import *



In [2]:
# vmap environment
NUM_ENVS = 1024
MAX_REPLAY_SIZE = 10000
BATCH_SIZE = 1024
EPISODE_LENGTH = 100
NUM_ACTIONS = 6
GRID_SIZE = 5
NUM_BOXES = 1
SEED = 2

In [3]:
env = BoxPushingEnv(grid_size=GRID_SIZE, max_steps=EPISODE_LENGTH, number_of_boxes=NUM_BOXES)
env = AutoResetWrapper(env)
key = random.PRNGKey(SEED)
env.step = jax.jit(jax.vmap(env.step))
env.reset = jax.jit(jax.vmap(env.reset))
jitted_flatten_batch = jax.jit(jax.vmap(flatten_batch, in_axes=(None, 0, 0)), static_argnums=(0,))

# Replay buffer
dummy_timestep = TimeStep(
    key=key,
    grid=jnp.zeros((GRID_SIZE, GRID_SIZE), dtype=jnp.int32),
    target_cells=jnp.zeros((NUM_BOXES, 2), dtype=jnp.int32),
    agent_pos=jnp.zeros((2,), dtype=jnp.int32),
    agent_has_box=jnp.zeros((1,), dtype=jnp.int32),
    steps=jnp.zeros((1,), dtype=jnp.int32),
    action=jnp.zeros((1,), dtype=jnp.int32),
    goal=jnp.zeros((GRID_SIZE, GRID_SIZE), dtype=jnp.int32),
    reward=jnp.zeros((1,), dtype=jnp.int32),
)
replay_buffer = jit_wrap(
    TrajectoryUniformSamplingQueue(
        max_replay_size=MAX_REPLAY_SIZE,
        dummy_data_sample=dummy_timestep,
        sample_batch_size=BATCH_SIZE,
        num_envs=NUM_ENVS,
        episode_length=EPISODE_LENGTH,
    )
)
buffer_state = jax.jit(replay_buffer.init)(key)



In [4]:

# Agent
config = get_config()
config['discrete'] = True
agent_class = agents[config['agent_name']]
example_batch = {
    'observations':dummy_timestep.grid.reshape(1, -1),  # Add batch dimension 
    'actions': jnp.ones((1,), dtype=jnp.int32) * (NUM_ACTIONS-1), # TODO: make sure it should be the maximal value of action space  # Single action for batch size 1
    'value_goals': dummy_timestep.grid.reshape(1, -1),
    'actor_goals': dummy_timestep.grid.reshape(1, -1),
}

print("Testing agent creation")
agent = agent_class.create(
    SEED,
    example_batch['observations'],
    example_batch['actions'],
    config,
    example_batch['value_goals'],
)
print("Agent created")

Testing agent creation
Agent created


In [5]:
data_collection_key = random.PRNGKey(SEED)


In [6]:
@jax.jit
def update_step(carry, _):
    buffer_state, agent, key = carry
    key, batch_key, subkey1, subkey2 = jax.random.split(key, 4)
    # Sample and process transitions
    buffer_state, transitions = replay_buffer.sample(buffer_state)
    batch_keys = jax.random.split(batch_key, transitions.grid.shape[0])
    state, future_state, goal_index = jitted_flatten_batch(0.99, transitions, batch_keys)

    # Get random index for each batch
    random_indices = jax.random.randint(subkey1, (state.grid.shape[0],), minval=0, maxval=state.grid.shape[1])            

    # Extract data at random index
    state1 = jax.tree_util.tree_map(lambda x: x[jnp.arange(x.shape[0]), random_indices], state)
    actions = state1.action
    future_state1 = jax.tree_util.tree_map(lambda x: x[jnp.arange(x.shape[0]), random_indices], future_state)
    goal_index1 = jax.tree_util.tree_map(lambda x: x[jnp.arange(x.shape[0]), random_indices], goal_index)
    
    random_indices2 = jax.random.randint(subkey2, (state.grid.shape[0],), minval=0, maxval=state.grid.shape[1])            

    # Extract data at random index
    state2 = jax.tree_util.tree_map(lambda x: x[jnp.arange(x.shape[0]), random_indices2], state)
    actions2 = state2.action
    future_state2 = jax.tree_util.tree_map(lambda x: x[jnp.arange(x.shape[0]), random_indices2], future_state)
    goal_index2 = jax.tree_util.tree_map(lambda x: x[jnp.arange(x.shape[0]), random_indices2], goal_index)

    state = jax.tree_util.tree_map(lambda x1, x2: jnp.concatenate([x1, x2], axis=0), state1, state2)
    actions = jnp.concatenate([actions, actions2], axis=0)
    future_state = jax.tree_util.tree_map(lambda x1, x2: jnp.concatenate([x1, x2], axis=0), future_state1, future_state2)
    goal_index = jnp.concatenate([goal_index1, goal_index2], axis=0)
    
    # Create valid batch
    valid_batch = {
        'observations': state.grid.reshape(state.grid.shape[0], -1),
        'actions': actions.squeeze(),
        'value_goals': future_state.grid.reshape(future_state.grid.shape[0], -1),
        'actor_goals': future_state.grid.reshape(future_state.grid.shape[0], -1),
    }

    # Update agent
    agent, update_info = agent.update(valid_batch)
    return (buffer_state, agent, key), update_info

In [71]:
@jax.jit
def train_epoch(carry, _):
  buffer_state, agent, key = carry
  key, data_key, up_key = jax.random.split(key, 3)
  _, _, timesteps = collect_data(agent, data_key, env, NUM_ENVS, EPISODE_LENGTH)
  buffer_state = replay_buffer.insert(buffer_state, timesteps)
  (buffer_state, agent, _), _ = jax.lax.scan(update_step, (buffer_state, agent, up_key), None, length=1000)
  return (buffer_state, agent, key), None


In [72]:
(buffer_state, agent, key), _ = train_epoch((buffer_state, agent, key), None)

In [73]:

# then outside:
(buffer_state, agent, key), _ = jax.lax.scan(train_epoch, (buffer_state, agent, key), None, length=10)


In [74]:
for i in range(10):
    (buffer_state, agent, key), _ = train_epoch((buffer_state, agent, key), None)

In [75]:
@jax.jit
def train_n_epochs(buffer_state, agent, key):
    (buffer_state, agent, key), _ = jax.lax.scan(
        train_epoch,
        (buffer_state, agent, key),
        None,
        length=10,
    )
    return buffer_state, agent, key

# first call pays compilation cost
(buffer_state, agent, key) = train_n_epochs(buffer_state, agent, key)



In [76]:
# subsequent calls are ~0.05 s
for _ in range(10):
    buffer_state, agent, key = train_n_epochs(buffer_state, agent, key)


In [8]:

for epoch in range(10):
    key, data_collection_key, update_key = jax.random.split(key, 3)
    env_step, info, timesteps_all = collect_data(agent, data_collection_key, env, NUM_ENVS, EPISODE_LENGTH)
    buffer_state = replay_buffer.insert(buffer_state, timesteps_all)
    
    # Run scan for updates
    (buffer_state, agent, _), update_infos = jax.lax.scan(
        update_step,
        (buffer_state, agent, update_key),
        None,
        length=10
    )

    update_infos = jax.tree_util.tree_map(lambda x: x[-1], update_infos)
    
    update_infos.update({
        "eval/reward_min": timesteps_all.reward.min(),
        "eval/reward_max": timesteps_all.reward.max(), 
        "eval/reward_mean": timesteps_all.reward.mean()
    })

In [18]:
buffer_state, batch = replay_buffer.sample(buffer_state)
batch.grid.shape



(1024, 100, 5, 5)

In [19]:
timesteps_all.grid.shape
timesteps_all.grid.swapaxes(1, 0).shape

(1024, 100, 5, 5)

In [46]:
@jax.jit
def extract_at_indices(data, indices):
    return jax.tree_util.tree_map(lambda x: x[jnp.arange(x.shape[0]), indices], data)

@jax.jit
def apply_double_batch_trick(state, future_state, goal_index, key):
    """Sample two random indices and concatenate the results."""
    # Sample two random indices for each batch
    subkey1, subkey2 = jax.random.split(key, 2)
    random_indices1 = jax.random.randint(subkey1, (state.grid.shape[0],), minval=0, maxval=state.grid.shape[1])
    random_indices2 = jax.random.randint(subkey2, (state.grid.shape[0],), minval=0, maxval=state.grid.shape[1])

    state1 = extract_at_indices(state, random_indices1)
    state2 = extract_at_indices(state, random_indices2)
    future_state1 = extract_at_indices(future_state, random_indices1)
    future_state2 = extract_at_indices(future_state, random_indices2)
    goal_index1 = extract_at_indices(goal_index, random_indices1)
    goal_index2 = extract_at_indices(goal_index, random_indices2)
    
    # Concatenate the two samples
    state = jax.tree_util.tree_map(lambda x1, x2: jnp.concatenate([x1, x2], axis=0), state1, state2)
    actions = jnp.concatenate([state1.action, state2.action], axis=0)
    future_state = jax.tree_util.tree_map(lambda x1, x2: jnp.concatenate([x1, x2], axis=0), future_state1, future_state2)
    goal_index = jnp.concatenate([goal_index1, goal_index2], axis=0)
    
    return state, actions, future_state, goal_index

def evaluate_agent(agent, env, key, num_envs=1024, episode_length=100):
    """Evaluate agent by running rollouts using collect_data and computing losses."""
    key, data_key, double_batch_key = jax.random.split(key, 3)
    # Use collect_data for evaluation rollouts
    _, _, timesteps = collect_data(agent, data_key, env, num_envs, episode_length)
    timesteps = jax.tree_util.tree_map(lambda x: x.swapaxes(1, 0), timesteps)

    batch_keys = jax.random.split(data_key, num_envs)
    state, future_state, goal_index = jitted_flatten_batch(0.99, timesteps, batch_keys)
    
    # Sample and concatenate batch using the new function
    state, actions, future_state, goal_index = apply_double_batch_trick(state, future_state, goal_index, double_batch_key)
    
    # Create valid batch
    valid_batch = {
        'observations': state.grid.reshape(state.grid.shape[0], -1),
        'actions': actions.squeeze(),
        'value_goals': future_state.grid.reshape(future_state.grid.shape[0], -1),
        'actor_goals': future_state.grid.reshape(future_state.grid.shape[0], -1),
    }

    # Compute losses on example batch
    loss, loss_info = agent.total_loss(valid_batch, None)
    
    # Compile evaluation info
    eval_info = {
        'eval/mean_reward': timesteps.reward.mean(),
        'eval/min_reward': timesteps.reward.min(),
        'eval/max_reward': timesteps.reward.max(),
        'eval/total_loss': loss,
    }
    eval_info.update(loss_info)
    
    return eval_info


In [47]:

# Run evaluation
eval_results = evaluate_agent(agent, env, key, 1024, 100)
print("Evaluation Results:")
for k, v in eval_results.items():
    print(f"{k}: {v}")

(2048, 5, 5)
Evaluation Results:
eval/mean_reward: 0.07830078154802322
eval/min_reward: 0
eval/max_reward: 1
eval/total_loss: 1.5962635278701782
actor/actor_loss: 1.5877385139465332
actor/adv: -0.44561105966567993
actor/bc_log_prob: -1.6485228538513184
critic/binary_accuracy: 0.99951171875
critic/categorical_accuracy: 0.0029296875
critic/contrastive_loss: 0.004223821219056845
critic/logits: -8.592992782592773
critic/logits_neg: -8.593518257141113
critic/logits_pos: -7.519384384155273
critic/v_max: 0.0724712684750557
critic/v_mean: 0.002251780591905117
critic/v_min: 2.553087369960849e-06
value/binary_accuracy: 0.99951171875
value/categorical_accuracy: 0.0029296875
value/contrastive_loss: 0.004301227163523436
value/logits: -8.580251693725586
value/logits_neg: -8.580719947814941
value/logits_pos: -7.618547439575195
value/v_max: 0.04266826808452606
value/v_mean: 0.0023576815146952868
value/v_min: 1.0530050076340558e-06


In [44]:
eval_results = evaluate_agent(agent, env, key, 1024, 100)
