# Evaluating discovered update rules on Sokoban

This colab demonstrates how to instantiate the `Disco103` update rule and use it for training an RL agent on the `Sokoban` environment.

The repository also contains `ActorCritic` and `PolicyGradient` update rules and a CPU version of `Catch`; feel free to explore and repurpose this code for your needs.

In [1]:
# @title Install the package.

# !pip install git+https://github.com/google-deepmind/disco_rl.git

import collections

import chex
import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
import rlax
import seaborn as sns
import tqdm

# Types & utils
from disco_rl import types
from disco_rl import utils

# Environments
from disco_rl.environments import base as base_env
from disco_rl.environments import jittable_envs

# Learning
from disco_rl import agent as agent_lib

# Name of the batch axis; handed to the agent as `batch_axis_name`.
axis_name = 'i'  # for parallelisation

In [2]:
# @title Gym to Disco environment adapter for Sokoban

import sys
import os

# Make the Sokoban code importable when the notebook runs from `colabs/`.
# `from sokoban import ...` only resolves when the directory that *contains*
# `sokoban` is on sys.path, so the parent directory is added in addition to the
# `sokoban` directory itself. The `sokoban` directory is inserted last so that
# it keeps the highest priority.
_parent_dir = os.path.dirname(os.getcwd())
sokoban_path = os.path.join(_parent_dir, 'sokoban')
for _path in (_parent_dir, sokoban_path):
  if _path not in sys.path:
    sys.path.insert(0, _path)

from sokoban import mathisfun_sokoban as mfs
import dm_env
from dm_env import specs

class GymToDiscoEnv(base_env.Environment):
  """Adapter that wraps the Gymnasium Sokoban environment for Disco RL.

  A single `MathIsFunSokoban` instance is wrapped. Observations are emitted as
  float32 arrays divided by 255, and the wrapped environment is reset
  automatically whenever an episode terminates or is truncated.

  NOTE(review): only one environment is created and the emitted timesteps have
  no batch dimension, whatever `batch_size` is -- confirm callers rely on
  `batch_size=1` only.
  """

  def __init__(self, batch_size: int = 1, gamma: float = 0.99, seed: int = 0):
    """Creates the wrapped environment and resets it once with `seed`.

    Args:
      batch_size: Stored on the instance; not otherwise used by this class.
      gamma: Stored on the instance; not otherwise used by this class.
      seed: Seed passed to the first `reset` of the wrapped environment.
    """
    self.batch_size = batch_size
    self.gamma = gamma

    self._gym_env = mfs.MathIsFunSokoban()
    # This is the only seeded reset; every later reset is called without a seed.
    self._gym_env.reset(seed=seed)

  @staticmethod
  def _normalize(obs) -> jnp.ndarray:
    """Converts a raw observation to float32 and divides it by 255."""
    # presumably the raw observation holds 0..255 pixel values -- TODO confirm
    return jnp.asarray(obs, dtype=jnp.float32) / 255.0

  def _timestep(self, obs, step_type, reward) -> types.EnvironmentTimestep:
    """Packs a raw observation, step type and reward into a timestep."""
    return types.EnvironmentTimestep(
      observation={'observation': self._normalize(obs)},
      step_type=jnp.array(step_type, dtype=jnp.int32),
      reward=jnp.array(reward, dtype=jnp.float32)
    )

  def single_observation_spec(self) -> types.Specs:
    """Returns the observation spec: the env's shape with a float32 dtype."""
    obs_space = self._gym_env.observation_space
    # Only shape and dtype are declared here; the actual conversion and
    # scaling of observations happens in `reset` and `step`.
    return {
      'observation': specs.Array(
        shape=obs_space.shape,
        dtype=jnp.float32,
        name='observation'
      )
    }

  def single_action_spec(self) -> types.ActionSpec:
    """Returns a scalar int32 action spec bounded by `[0, n - 1]`."""
    n = self._gym_env.action_space.n
    return specs.BoundedArray(
      shape=(),
      dtype=jnp.int32,
      minimum=0,
      maximum=n - 1,
      name='action'
    )

  def reset(
      self, rng_key: chex.PRNGKey
  ) -> tuple[dict, types.EnvironmentTimestep]:
    """Resets the wrapped environment.

    Args:
      rng_key: Ignored. The wrapped environment was seeded in `__init__` and
        is reset here without a seed.

    Returns:
      An empty state dict and a FIRST timestep with zero reward.
    """
    del rng_key  # Unused, see docstring.
    obs, _ = self._gym_env.reset()
    return {}, self._timestep(obs, dm_env.StepType.FIRST, 0.0)

  def step(
      self, state: dict, actions: chex.ArrayTree
  ) -> tuple[dict, types.EnvironmentTimestep]:
    """Steps the wrapped environment with a single action.

    Args:
      state: Passed through unchanged; the environment state lives inside the
        wrapped Gym object.
      actions: A scalar or array of actions. Only the first element is used
        because exactly one environment is wrapped.

    Returns:
      The unchanged `state` and the resulting timestep. When the episode
      ended, the timestep is marked LAST, keeps the final reward, and carries
      the first observation of the freshly reset episode. The following step
      is then emitted as MID (no separate FIRST timestep is produced).
    """
    # Scalars, shape-(1,) arrays and larger arrays all reduce to "first
    # element as a Python int".
    action = int(jnp.asarray(actions).flatten()[0])

    obs, reward, terminated, truncated, _ = self._gym_env.step(action)

    if terminated or truncated:
      step_type = dm_env.StepType.LAST
      # Auto-reset: replace the terminal observation with the new episode's.
      obs, _ = self._gym_env.reset()
    else:
      step_type = dm_env.StepType.MID

    return state, self._timestep(obs, step_type, reward)



ModuleNotFoundError: No module named 'sokoban'

In [None]:
# @title Download and unpack `Disco103` weights.

def unflatten_params(flat_params: chex.ArrayTree) -> chex.ArrayTree:
  """Groups flat `<module>/b` and `<module>/w` entries by module name.

  Args:
    flat_params: Mapping whose keys have the form `<module>/b` or
      `<module>/w`. Iterating it must yield those keys, as a dict or the
      object returned by `np.load` on an `.npz` file does.

  Returns:
    A dict mapping every `<module>` prefix to `{'b': ..., 'w': ...}`.

  Raises:
    KeyError: If a module is missing its `/b` or `/w` entry (as raised by the
      mapping's own lookup).
  """
  params = {}
  for flat_key in flat_params:
    module_name = flat_key.rpartition('/')[0]
    if module_name in params:
      # Already assembled when the sibling key was visited. Skipping avoids
      # looking up (and, for lazily loaded archives, re-reading) both tensors
      # a second time.
      continue
    params[module_name] = {
        'b': flat_params[f'{module_name}/b'],
        'w': flat_params[f'{module_name}/w'],
    }
  return params


disco_103_fname = 'disco_103.npz'
disco_103_url = f"https://raw.githubusercontent.com/google-deepmind/disco_rl/main/disco_rl/update_rules/weights/{disco_103_fname}"
# !wget $disco_103_url

# Read the weights from the working directory, which is where the `wget`
# command above saves them; a machine-specific absolute path would only work
# on one computer.
with open(disco_103_fname, 'rb') as file:
  disco_103_params = unflatten_params(np.load(file))

# Every entry holds one bias and one weight tensor, hence the factor of two.
print(f'Loaded {len(disco_103_params) * 2} parameter tensors for Disco103.')

In [None]:
# @title Instantiate a simple MLP agent.


def get_env(batch_size: int) -> base_env.Environment:
  """Builds the Sokoban adapter environment with gamma=0.99 and seed=0."""
  sokoban_env = GymToDiscoEnv(batch_size=batch_size, gamma=0.99, seed=0)
  return sokoban_env


# Build an environment; in this cell it is only used to read its specs.
env = get_env(batch_size=1)

# Start from the Disco default settings and configure the network.
agent_settings = agent_lib.get_settings_disco()
agent_settings.net_settings.name = 'mlp'
agent_settings.net_settings.net_args = {
    'dense': (512, 512),
    'model_arch_name': 'lstm',
    'head_w_init_std': 1e-2,
    'model_kwargs': {
        'head_mlp_hiddens': (128,),
        'lstm_size': 128,
    },
}
agent_settings.learning_rate = 1e-2

# Instantiate the agent from the settings and the environment specs.
agent = agent_lib.Agent(
    agent_settings=agent_settings,
    single_observation_spec=env.single_observation_spec(),
    single_action_spec=env.single_action_spec(),
    batch_axis_name=axis_name,
)

# Freshly initialised update-rule parameters serve as the reference for the
# shapes and dtypes that the loaded Disco103 weights must have.
random_update_rule_params, _ = agent.update_rule.init_params(
    jax.random.PRNGKey(0)
)
if agent_settings.update_rule_name != 'disco':
  print('Not using a discovered rule, skipping check.')
else:
  chex.assert_trees_all_equal_shapes_and_dtypes(
      random_update_rule_params, disco_103_params
  )
  print('Update rule parameters have the same specs.')

In [None]:
# @title Helper functions for interacting with environments.
def unroll_cpu_actor(
    params,
    actor_state,
    ts,
    env_state,
    rng,
    env,
    rollout_len,
    actor_step_fn,
    devices,
):
  """Unrolls the policy in a CPU (non-jittable) environment.

  Args:
    params: Agent parameters passed to `actor_step_fn`.
    actor_state: Actor state carried from step to step.
    ts: Environment timestep the rollout starts from.
    env_state: Environment state the rollout starts from.
    rng: Random key; split once per step.
    env: Environment exposing `step(env_state, action)`.
    rollout_len: Number of environment steps to take.
    actor_step_fn: Callable `(params, rng, ts, actor_state)` returning
      `(actor_timestep, actor_state)`.
    devices: Ignored.

  Returns:
    A tuple `(actor_rollout, actor_state, ts, env_state)`: the actor timesteps
    stacked along axis 0 as an `ActorRollout`, followed by the final actor
    state, timestep and environment state.
  """
  del devices  # Not needed for single device
  actor_timesteps = []
  for _ in range(rollout_len):
    rng, step_rng = jax.random.split(rng)

    actor_timestep, actor_state = actor_step_fn(
        params, step_rng, ts, actor_state
    )
    # The environment is stepped with one Python int. Whatever the actor
    # returns (Python scalar, 0-d array, shape-(1,) array, larger array), the
    # first element is the action that gets executed.
    action = int(np.asarray(actor_timestep.actions).reshape(-1)[0])
    env_state, ts = env.step(env_state, action)

    actor_timesteps.append(actor_timestep)

  actor_rollout = types.ActorRollout.from_timestep(
      utils.tree_stack(actor_timesteps, axis=0)
  )
  return actor_rollout, actor_state, ts, env_state


def unroll_jittable_actor(
    params,
    actor_state,
    ts,
    env_state,
    rng,
    env,
    rollout_len,
    actor_step_fn,
    devices,
):
  """Unrolls the policy in a jittable environment with `jax.lax.scan`.

  `actor_step_fn` and `devices` are accepted but ignored; the module-level
  `agent.actor_step` is called instead. One random key per step is derived
  from `rng`.

  Returns:
    A tuple `(actor_rollout, actor_state, ts, env_state)` holding the scanned
    actor timesteps as an `ActorRollout` and the final actor state, timestep
    and environment state.
  """
  del actor_step_fn, devices

  def _scan_body(loop_state, key):
    cur_env_state, cur_ts, cur_actor_state = loop_state
    step_out, next_actor_state = agent.actor_step(
        params, key, cur_ts, cur_actor_state
    )
    next_env_state, next_ts = env.step(cur_env_state, step_out.actions)
    return (next_env_state, next_ts, next_actor_state), step_out

  step_keys = jax.random.split(rng, rollout_len)
  initial_loop_state = (env_state, ts, actor_state)
  final_loop_state, stacked_steps = jax.lax.scan(
      _scan_body, initial_loop_state, step_keys
  )
  env_state, ts, actor_state = final_loop_state

  return (
      types.ActorRollout.from_timestep(stacked_steps),
      actor_state,
      ts,
      env_state,
  )


def accumulate_rewards(acc_rewards, x):
  """Accumulates rewards over the leading axis, scaled by the discounts.

  Args:
    acc_rewards: Running reward totals carried in from the previous call.
    x: A `(rewards, discounts)` pair; both are scanned over their leading
      axis.

  Returns:
    A tuple `(final_acc, totals)`. `totals[t]` is the running total including
    `rewards[t]`. The value carried into the next step (and returned as
    `final_acc` after the last one) is that total multiplied by
    `discounts[t]`.
  """
  rewards, discounts = x

  def _accumulate(running_total, step_inputs):
    step_reward, step_discount = step_inputs
    new_total = running_total + step_reward
    return new_total * step_discount, new_total

  return jax.lax.scan(_accumulate, acc_rewards, (rewards, discounts))

In [None]:
# @title A simple Replay buffer.


class SimpleReplayBuffer:
  """FIFO replay buffer backed by a bounded deque."""

  def __init__(self, capacity: int, seed: int):
    """Creates an empty buffer.

    Args:
      capacity: Maximum number of stored items (the deque's `maxlen`).
      seed: Seed of the NumPy generator used by `sample`.
    """
    self.capacity = capacity
    self.np_rng = np.random.default_rng(seed)
    self.buffer = collections.deque(maxlen=capacity)

  def add(self, rollout: types.ActorRollout) -> None:
    """Fetches a rollout to the host, splits it on axis 2 and stores the parts."""
    host_rollout = jax.device_get(rollout)
    # Axis 2 is treated as the batch dimension.
    trajectories = rlax.tree_split_leaves(host_rollout, axis=2)
    self.buffer.extend(trajectories)

  def sample(self, batch_size: int) -> types.ActorRollout | None:
    """Draws `batch_size` stored items uniformly with replacement.

    Returns:
      The drawn items stacked along axis 2, or `None` (after printing a
      warning) when the buffer is empty.
    """
    num_stored = len(self.buffer)
    if not num_stored:
      print("Warning: Trying to sample from an empty buffer.")
      return None

    picks = self.np_rng.integers(num_stored, size=batch_size)
    chosen = []
    for idx in picks:
      chosen.append(self.buffer[idx])
    return utils.tree_stack(chosen, axis=2)

  def __len__(self) -> int:
    """Returns the number of items currently stored in the buffer."""
    return len(self.buffer)

In [None]:
# @title Training loop

# Loop and rollout sizes.
rollout_len = 29
batch_size = 64
num_steps = 1000
rng_key = jax.random.PRNGKey(0)

# Replay settings.
# NOTE(review): `replay_ratio` is not read anywhere in the visible cells.
replay_ratio = 32
min_buffer_size = batch_size
buffer = SimpleReplayBuffer(capacity=1024, seed=17)

# Sokoban runs as a single CPU environment on a single device.
num_envs = 1
env = get_env(num_envs)
devices = (jax.devices()[0],)

# Initial environment, agent and update-rule state.
env_state, ts = env.reset(rng_key)
learner_state = agent.initial_learner_state(rng_key)
actor_state = agent.initial_actor_state(rng_key)
acc_rewards = jnp.zeros((num_envs,))
update_rule_params = disco_103_params

# CPU mode: the agent's step functions are used directly (no pmap).
is_jittable_actor = False
unroll_actor = unroll_cpu_actor
actor_step_fn = agent.actor_step
learner_step_fn = agent.learner_step
acc_rewards_fn = accumulate_rewards

# Per-iteration logs filled by the training loop.
all_metrics, all_rewards, all_discounts = [], [], []
all_steps, all_returns = [], []
total_steps = 0

# Run the loop.
for step in tqdm.tqdm(range(num_steps)):
  rng_key, rng_actor, rng_learner = jax.random.split(rng_key, 3)

  # Collect a fresh rollout with the current parameters and store it.
  actor_rollout, actor_state, ts, env_state = unroll_actor(
      learner_state.params, actor_state, ts, env_state, rng_actor,
      env, rollout_len, actor_step_fn, devices,
  )
  buffer.add(actor_rollout)

  # Log step counts, rewards, discounts and running returns.
  total_steps += np.prod(actor_rollout.rewards.shape)
  all_steps.append(total_steps)
  all_rewards.append(jax.device_get(actor_rollout.rewards))
  all_discounts.append(jax.device_get(actor_rollout.discounts))
  acc_rewards, returns = acc_rewards_fn(
      acc_rewards, (actor_rollout.rewards, actor_rollout.discounts)
  )
  all_returns.append(jax.device_get(returns))

  # Skip learning until the buffer holds at least `min_buffer_size` items.
  if len(buffer) < min_buffer_size:
    continue

  # Update the agent's parameters on a batch sampled from the buffer.
  learner_rollout = buffer.sample(batch_size)
  learner_state, _, metrics = learner_step_fn(
      rng_learner,
      learner_rollout,
      learner_state,
      actor_state,
      update_rule_params,
      False,  # is_meta_training
  )
  all_metrics.append(jax.device_get(metrics))

# Collect all logs and statistics (no gather needed for single device)
if all_metrics:
  # Wrap any non-dict entry as {'loss': value} so that every entry is a dict,
  # then stack each metric across learner updates. The keys of the first entry
  # decide which metrics are kept.
  per_update = [
      entry if isinstance(entry, dict) else {'loss': entry}
      for entry in all_metrics
  ]
  all_metrics = {
      name: jnp.stack([entry[name] for entry in per_update])
      for name in per_update[0]
  }

In [None]:
# @title Process logs
all_returns = np.array(all_returns)
all_discounts = np.array(all_discounts)
all_steps = np.array(all_steps)

# `1 - discount` is used as the episode-end indicator: returns are summed over
# the steps where it is non-zero and divided by the number of such steps.
episode_ends = 1 - all_discounts
total_returns = (all_returns * episode_ends).sum(axis=(1, 2))
total_episodes = episode_ends.sum(axis=(1, 2))
# Iterations in which no episode ended give 0 / 0 = NaN. That value is kept
# (it shows up as a gap in the plot); only the warning is silenced.
with np.errstate(divide='ignore', invalid='ignore'):
  avg_returns = total_returns / total_episodes

# `all_metrics` may be a dict of per-metric stacked arrays (what the training
# cell leaves behind), a list with one dict per learner update, or empty when
# no learner update ran.
if isinstance(all_metrics, dict):
  metric_columns = {
      key: np.asarray(values) for key, values in all_metrics.items()
  }
elif all_metrics:
  metric_columns = {
      key: np.array([m[key] for m in all_metrics]) for key in all_metrics[0]
  }
else:
  metric_columns = {}

# Learner updates start only once the buffer is full enough, so metrics are
# left-padded with NaN to line up with the per-iteration step counts.
padded_metrics = {}
for key, values in metric_columns.items():
  if not np.issubdtype(values.dtype, np.floating):
    values = values.astype(float)  # NaN padding needs a float dtype.
  pad_width = len(all_steps) - len(values)
  padded_metrics[key] = np.pad(values, (pad_width, 0), constant_values=np.nan)

df = pd.DataFrame(
    dict(
        steps=all_steps,
        avg_returns=avg_returns,
        **padded_metrics,
    )
)

df['name'] = agent_settings.update_rule_name

In [None]:
# Plot the average return per iteration against the cumulative step count.
sns.lineplot(data=df, x='steps', y='avg_returns')