In [18]:
from craftax.craftax_env import make_craftax_env_from_name
from nicewebrl.nicejax import TimestepWrapper
import jax
import flax.linen as nn
from typing import Optional, Any
import functools
import flax.struct as struct
import jax.numpy as jnp
from flax.core import unfreeze

In [2]:
# Symbolic Craftax without auto-reset; TimestepWrapper makes reset/step work on
# timestep objects (reset returns one, step consumes one).
env = TimestepWrapper(
    make_craftax_env_from_name("Craftax-Symbolic-v1", auto_reset=False)
)
env_params = env.default_params

Loading Craftax textures from cache.
Textures successfully loaded from cache.


In [3]:
# Seed the PRNG, keep `rng` for later cells, and derive three keys:
# index 0 -> reset, 1 -> random action, 2 -> environment step.
rng, _rng = jax.random.split(jax.random.PRNGKey(0))
rngs = jax.random.split(_rng, 3)

# One transition: initial timestep, a uniformly sampled action, then a step.
init_timestep = env.reset(rngs[0], env_params)
action = env.action_space(env_params).sample(rngs[1])
timestep = env.step(rngs[2], init_timestep, action, env_params)

  return lax_numpy.astype(arr, dtype)


In [22]:
# Hyperparameters for the recurrent Dyna-style Q-learning agent on Craftax.
# Keys are grouped by subsystem; their insertion order is significant for anything
# that iterates over the dict (e.g. logging), so keep it stable.
config = dict(
    # --- Environment ---
    ENV_NAME="Craftax-Symbolic-v1",    # environment id; adjust as needed
    NUM_ENVS=32,                       # parallel envs (PureJaxRL DQN used 10; can be raised)

    # --- Training loop ---
    TOTAL_TIMESTEPS=1_000_000,         # total environment steps
    TRAINING_INTERVAL=5,               # env steps per actor sequence collection
    LEARNING_STARTS=10_000,            # timesteps collected before learning begins
    TARGET_UPDATE_INTERVAL=1_000,      # LEARNER UPDATES between target-network syncs (R2D2 uses ~2500)

    # --- Network ---
    RNN_HIDDEN_DIM=256,                # GRU hidden-state size (Dyna code used 256)
    ENCODER_HIDDEN_DIM=512,            # width of the observation-encoder MLP
    NUM_ENCODER_LAYERS=0,              # hidden layers in the encoder MLP (0 -> a single Dense layer)
    Q_HIDDEN_DIM=1024,                 # width of the Q-head MLP (Dyna code used 512)
    NUM_Q_LAYERS=2,                    # hidden layers in the Q-head MLP (Dyna code used 1)
    USE_BIAS=True,                     # include bias terms in Dense layers

    # --- Optimizer ---
    LR=3e-4,                           # learning rate
    LR_LINEAR_DECAY=False,             # linearly decay the learning rate
    EPS_ADAM=1e-5,                     # Adam epsilon (ACME default 1e-5)
    MAX_GRAD_NORM=80,                  # gradient-clipping norm (ACME default 40.0)
    TAU=1.0,                           # NOTE(review): presumably target-update mixing rate (1.0 = hard copy); confirm

    # --- Replay buffer ---
    BUFFER_SIZE=50_000,                # transitions held in the buffer (R2D2 often uses 1M+; memory-bound)
    TOTAL_BATCH_SIZE=1280,             # total transitions sampled from the buffer
    SAMPLE_BATCH_SIZE=32,              # batch size sampled from the buffer for learning (e.g. 32, 64)
    SAMPLING_PERIOD=1,                 # stored sequences overlap by N-1 steps (1 is standard)

    # --- Loss ---
    GAMMA=0.99,                        # discount factor
    TD_LAMBDA=0.9,                     # TD(lambda) parameter
    STEP_COST=0.0,                     # optional per-step cost (DynaLossFn default 0.0)
    ONLINE_COEFF=1.0,                  # weight of the loss on real data
    DYNA_COEFF=1.0,                    # weight of the loss on simulated data (DynaLossFn default 1.0)

    # --- Dyna simulation ---
    NUM_SIMULATIONS=2,                 # parallel simulations per starting state (DynaLossFn default 2)
    SIMULATION_LENGTH=10,              # length of each simulated rollout (DynaLossFn default 5)
    WINDOW_SIZE=1,                     # number of windows; must be 1 for DynaLossFn

    # --- Actor exploration (epsilon schedules) ---
    NUM_EPSILONS=256,                  # number of epsilon schedules
    EPSILON_MIN=0.05,                  # minimum epsilon
    EPSILON_MAX=0.9,                   # maximum epsilon
    EPSILON_BASE=0.1,                  # base epsilon

    # --- Logging ---
    LEARNER_LOG_PERIOD=500,            # LEARNER UPDATES between loss/metric logs
    GRADIENT_LOG_PERIOD=500,           # GRADIENT UPDATES between loss/metric logs
    LEARNER_EXTRA_LOG_PERIOD=5_000,    # LEARNER UPDATES between extra logging

    # --- Miscellaneous ---
    SEED=1,
    NUM_SEEDS=1,
    ENTITY="hoonshin",                 # presumably the W&B entity
    PROJECT="dyna-crafter",            # presumably the W&B project
    WANDB_MODE="disabled",             # W&B logging mode
)

In [26]:
@struct.dataclass
class Predictions:
    """Outputs of one `DynaAgent.__call__` step: Q-values and the next recurrent state."""

    # Q-values with one entry per discrete action on the last axis
    # (per step: [batch, num_actions]; stacked over time when produced under nn.scan).
    q_vals: jax.Array
    # Recurrent state after the step. DynaAgent stores the GRU hidden array here,
    # despite the PyTreeNode annotation.
    state: struct.PyTreeNode

class MLP(nn.Module):
    """Feed-forward stack: `num_layers` hidden Dense + leaky-ReLU layers, then a linear output.

    The output layer has `out_dim` units, or `hidden_dim` units when `out_dim` is
    falsy (None or 0). With `num_layers=0` the module reduces to one Dense layer.
    """
    hidden_dim: int
    out_dim: Optional[int] = None
    num_layers: int = 1
    use_bias: bool = True

    @nn.compact
    def __call__(self, x):
        width_out = self.out_dim or self.hidden_dim

        # Submodules are created in the same order as before (hidden layers first,
        # then the output head), so Flax's auto names Dense_0..Dense_k are stable.
        h = x
        for _layer in range(self.num_layers):
            hidden = nn.Dense(self.hidden_dim, use_bias=self.use_bias)
            h = jax.nn.leaky_relu(hidden(h))

        head = nn.Dense(width_out, use_bias=self.use_bias)
        return head(h)

class DynaAgent(nn.Module):
    """Recurrent Q-network that also exposes the true environment as a world model.

    Per step: MLP encoder -> GRU cell -> MLP Q-head with one output per discrete
    action. `__call__` is wrapped in `nn.scan`, so it consumes whole sequences and
    shares a single set of parameters across all time steps.

    Attributes:
        config: hyperparameter dict; this module reads ENCODER_HIDDEN_DIM,
            NUM_ENCODER_LAYERS, USE_BIAS, Q_HIDDEN_DIM, NUM_Q_LAYERS,
            RNN_HIDDEN_DIM and NUM_ENVS.
        env: timestep-wrapped environment; supplies the action count and is
            stepped by `apply_world_model`.
        env_params: environment parameters passed to `env.action_space` and
            `env.step`.
    """
    config: dict
    env: TimestepWrapper
    env_params: Any

    def setup(self):
        self.encoder = MLP(
            hidden_dim=self.config["ENCODER_HIDDEN_DIM"],
            num_layers=self.config["NUM_ENCODER_LAYERS"],
            use_bias=self.config["USE_BIAS"],
            name="encoder_mlp",
        )
        # Output width equals the number of discrete actions.
        self.q_head = MLP(
            hidden_dim=self.config["Q_HIDDEN_DIM"],
            out_dim=self.env.action_space(self.env_params).n,
            num_layers=self.config["NUM_Q_LAYERS"],
            use_bias=self.config["USE_BIAS"],
            name="q_head_mlp",
        )
        self.rnn = nn.GRUCell(
            features=self.config["RNN_HIDDEN_DIM"],
            name="gru_cell"
        )

        # Cache frequently used config values as plain attributes.
        self.hidden_size = self.config["RNN_HIDDEN_DIM"]
        # Kept for backward compatibility; apply_world_model infers its batch
        # size from its inputs rather than relying on this value.
        self.num_envs = self.config["NUM_ENVS"]

    @functools.partial(
        nn.scan,
        variable_broadcast="params",
        in_axes=0,
        out_axes=0,
        split_rngs={"params": False},
    )
    def __call__(self, carry, x):
        """Run one recurrent step; `nn.scan` maps this body over axis 0 of `x`.

        Callers pass whole sequences: `carry` is the initial GRU state
        [B, hidden_size] and `x = (obs, resets)` with obs [T, B, obs_dim...] and
        resets [T, B]. Inside the body each step sees one time slice:
        obs [B, obs_dim...] and resets [B].

        Where a step's reset flag is set (non-zero), that environment's incoming
        GRU state is replaced by zeros before the step.

        Returns:
            (final_carry, preds): the GRU state after the last step, and a
            `Predictions` whose fields are stacked along a leading T axis
            (out_axes=0): q_vals [T, B, num_actions], state [T, B, hidden_size].
        """
        rnn_state = carry
        obs, resets = x  # each [batch, ...]

        # Reinitialize RNN state for environments that have reset
        rnn_state = jnp.where(
            resets[:, None],  # [batch, 1], broadcast over the hidden axis
            self.initialize_carry(resets.shape[0], self.hidden_size),  # [batch, hidden]
            rnn_state
        )

        embeds = self.encoder(obs)  # [batch, embedding_dim]
        next_rnn_state, rnn_out = self.rnn(rnn_state, embeds)  # both [batch, hidden]
        q_vals = self.q_head(rnn_out)  # [batch, num_actions]

        preds = Predictions(q_vals=q_vals, state=next_rnn_state)
        return next_rnn_state, preds

    @staticmethod
    def initialize_carry(batch_size, hidden_size):
        """Return an all-zero GRU state of shape (batch_size, hidden_size)."""
        return jnp.zeros((batch_size, hidden_size))

    def apply_world_model(self, timestep: struct.PyTreeNode, action: jax.Array, rng: jax.Array) -> struct.PyTreeNode:
        """Simulate one batched step, using the ground-truth env as the 'world model'.

        Wraps the true `env.step` under `jax.vmap`.

        Args:
            timestep: batched timestep pytree; every leaf has a leading batch axis.
            action: batched actions; the batch size B is read from `action.shape[0]`.
            rng: a single PRNG key, split into one key per batch element.

        Returns:
            The batched next timestep.
        """
        def step_fn(rng, ts, act):
            return self.env.step(rng, ts, act, self.env_params)

        # Size the key split from the actual batch, not config["NUM_ENVS"]: a fixed
        # count makes vmap fail whenever the batch size differs (e.g. simulated
        # rollouts batched differently from the actors). When B == NUM_ENVS the
        # keys, and therefore the results, are identical to the previous behavior.
        batch_size = action.shape[0]
        rngs = jax.random.split(rng, batch_size)
        next_timestep = jax.vmap(step_fn)(rngs, timestep, action)
        return next_timestep

In [27]:
# Build the agent and initialise its parameters on dummy inputs: a length-1
# sequence (T=1) of all-zero observations and reset flags for each of the
# NUM_ENVS environments, starting from an all-zero GRU carry.
agent = DynaAgent(config=config, env=env, env_params=env_params)
rng, init_rng = jax.random.split(rng)

init_x = (
    jnp.zeros((1, config["NUM_ENVS"]) + tuple(env.observation_space(env_params).shape)),
    jnp.zeros((1, config["NUM_ENVS"])),
)
init_carry = DynaAgent.initialize_carry(config["NUM_ENVS"], config["RNN_HIDDEN_DIM"])
online_params = agent.init(init_rng, init_carry, init_x)

In [28]:
# Shape of every parameter leaf, keeping the nested module structure.
jax.tree_util.tree_map(jnp.shape, unfreeze(online_params))

{'params': {'encoder_mlp': {'Dense_0': {'bias': (512,),
    'kernel': (8268, 512)}},
  'gru_cell': {'hn': {'bias': (256,), 'kernel': (256, 256)},
   'hr': {'kernel': (256, 256)},
   'hz': {'kernel': (256, 256)},
   'in': {'bias': (256,), 'kernel': (512, 256)},
   'ir': {'bias': (256,), 'kernel': (512, 256)},
   'iz': {'bias': (256,), 'kernel': (512, 256)}},
  'q_head_mlp': {'Dense_0': {'bias': (1024,), 'kernel': (256, 1024)},
   'Dense_1': {'bias': (1024,), 'kernel': (1024, 1024)},
   'Dense_2': {'bias': (43,), 'kernel': (1024, 43)}}}}

In [21]:
# Cross-check: Flax's own GRUCell.initialize_carry should produce a carry of the
# same shape as DynaAgent.initialize_carry. `input_shape` is the shape of the
# cell's *input*, i.e. (batch, input_features). In DynaAgent the GRU input is the
# encoder output (ENCODER_HIDDEN_DIM wide), not RNN_HIDDEN_DIM. Flax uses only
# the leading batch dims, so the result is the same, but pass the real input width.
# NOTE: this rebinds `init_carry` from the agent-initialisation cell.
dummy_rnn = nn.GRUCell(features=config["RNN_HIDDEN_DIM"])
init_carry = dummy_rnn.initialize_carry(
    jax.random.PRNGKey(0),
    input_shape=(config["NUM_ENVS"], config["ENCODER_HIDDEN_DIM"]),
)
assert init_carry.shape == (config["NUM_ENVS"], config["RNN_HIDDEN_DIM"])
init_carry.shape

(32, 256)

In [29]:
# Length of the flat symbolic observation vector (the observation space has one axis).
env.observation_space(env_params).shape[0]

8268

In [30]:
# Full observation-space shape; a 1-tuple for the symbolic env.
env.observation_space(env_params).shape

(8268,)

In [34]:
# Scratch check: index -1 on a length-1 array selects its only element.
x = jnp.zeros(shape=(1,))
x[x.shape[0] - 1]

Array(0., dtype=float32)

In [35]:
# Scratch check: slicing with [:-1] drops the last row, leaving 10 of the 11.
x = jnp.zeros(shape=(11, 2))
x[: len(x) - 1]

Array([[0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32)