In [None]:
!pip install git+https://github.com/sanepunk/gymnax.git@main --quiet

In [None]:
import jax
import gymnax
from flax import nnx
import jax.numpy as jnp
import wandb
import optax
import collections
from tqdm import tqdm

# Derive four independent PRNG keys from a single seed-0 root key.
# `key` is the carry key; the other three are named after their intended use.
key, key_reset, key_act, key_step = jax.random.split(jax.random.key(0), 4)

# Environment instance plus its default parameter set.
env, env_params = gymnax.make("CartPole-v1")

In [None]:
class SmallMoEExpert(nnx.Module):
    """One expert of the MoE layer: a two-layer MLP.

    The forward pass is ``linear1 -> dropout -> leaky_relu -> linear2``.

    Args:
        input_dim: Size of the last axis of the input.
        hidden_dim: Width of the hidden layer.
        output_dim: Size of the last axis of the output.
        rngs: RNG container used to initialise both linear layers and passed to dropout.
        dropout: Dropout rate applied to the hidden pre-activation.
    """
    def __init__(self, input_dim, hidden_dim, output_dim, rngs: nnx.Rngs, dropout: float = 0.2):
        super().__init__()
        # Creation order is kept (linear1, linear2, dropout) so the same `rngs`
        # yields the same initial parameters.
        self.linear1 = nnx.Linear(in_features=input_dim, out_features=hidden_dim, rngs=rngs)
        self.linear2 = nnx.Linear(in_features=hidden_dim, out_features=output_dim, rngs=rngs)
        self.dropout = nnx.Dropout(rate=dropout, rngs=rngs)

    def __call__(self, x):
        """Map ``x`` (last axis ``input_dim``) to an array with last axis ``output_dim``."""
        hidden = self.linear1(x)
        hidden = self.dropout(hidden)
        hidden = jax.nn.leaky_relu(hidden)
        return self.linear2(hidden)

In [None]:
class SmallMoELayer(nnx.Module):
    """Small dense (soft-routed) mixture-of-experts layer.

    Every expert processes every input; the outputs are blended with softmax
    weights produced by a linear gating network.

    Args:
        input_dim: Size of the last axis of the input.
        hidden_dim: Hidden width of each expert.
        output_dim: Size of the last axis of the output.
        num_experts: Number of experts to create.
        rngs: RNG container used for the experts (weights and dropout) and the gate.
        dropout: Dropout rate forwarded to each expert.
    """
    def __init__(self, input_dim, hidden_dim, output_dim, num_experts, rngs: nnx.Rngs, dropout: float = 0.2):
        super().__init__()
        self.num_experts = num_experts

        # Create experts from the caller's `rngs`. `nnx.Rngs` is stateful, so each
        # expert receives different keys. Seeding with `nnx.Rngs(i)` (as before)
        # ignored the caller's seed and made every SmallMoELayer instance start
        # with the exact same expert weights and dropout streams.
        self.experts = [
            SmallMoEExpert(input_dim, hidden_dim, output_dim, rngs, dropout)
            for _ in range(num_experts)
        ]

        # Gating network: one logit per expert.
        self.gate = nnx.Linear(input_dim, num_experts, rngs=rngs)

    def __call__(self, x):
        """Return the gate-weighted sum of all expert outputs.

        Args:
            x: Array whose last axis has size ``input_dim``.

        Returns:
            Array with the same leading axes as ``x`` and last axis ``output_dim``.
        """
        # Gating weights over experts: [..., num_experts]
        gate_weights = jax.nn.softmax(self.gate(x), axis=-1)

        # Run every expert and stack on a new second-to-last axis:
        # [..., num_experts, output_dim]
        expert_outputs = jnp.stack([expert(x) for expert in self.experts], axis=-2)

        # Broadcast the weights over output_dim and sum out the expert axis.
        return jnp.sum(gate_weights[..., None] * expert_outputs, axis=-2)

In [None]:
class SimpleAttention(nnx.Module):
    """Single-head scaled dot-product self-attention without an output projection.

    Args:
        embed_dim: Feature size of the input; also the size of Q, K and V.
        rngs: RNG container used to initialise the three projections.
    """
    def __init__(self, embed_dim, rngs: nnx.Rngs):
        super().__init__()
        self.embed_dim = embed_dim
        # Creation order (query, key, value) is kept so initialisation is unchanged.
        self.query = nnx.Linear(in_features=embed_dim, out_features=embed_dim, rngs=rngs)
        self.key = nnx.Linear(in_features=embed_dim, out_features=embed_dim, rngs=rngs)
        self.value = nnx.Linear(in_features=embed_dim, out_features=embed_dim, rngs=rngs)

    def __call__(self, x):
        """Apply self-attention to ``x`` of shape [batch_size, seq_len, embed_dim].

        Returns an array of the same shape.
        """
        queries = self.query(x)
        keys = self.key(x)
        values = self.value(x)

        # With a single timestep the softmax runs over one score and equals 1,
        # so the attention output is exactly the value projection.
        if x.shape[1] == 1:
            return values

        # Scaled dot-product attention over the sequence axis.
        keys_t = jnp.swapaxes(keys, -2, -1)  # [batch_size, embed_dim, seq_len]
        scaled_scores = jnp.matmul(queries, keys_t) / jnp.sqrt(self.embed_dim)
        weights = jax.nn.softmax(scaled_scores, axis=-1)
        return jnp.matmul(weights, values)

In [None]:
class MoEAttentionPolicy(nnx.Module):
    """Discrete-action policy network built from MoE and attention blocks.

    Architecture: input projection -> MoE -> attention -> MoE -> attention ->
    linear + leaky-ReLU -> linear head. Each MoE and attention block is wrapped
    in a residual connection. ``__call__`` returns unnormalised action logits.

    Args:
        observation_space: Space object; ``shape[0]`` gives the observation size.
        action_space: Space object; ``n`` gives the number of discrete actions.
        rngs: RNG container used to initialise all sub-layers.
        dropout: Dropout rate forwarded to the MoE experts.
    """
    def __init__(self, observation_space, action_space, rngs: nnx.Rngs, dropout: float = 0.2):
        super().__init__()

        self.input_size = observation_space.shape[0]
        self.hidden_size = 64
        self.output_size = action_space.n

        # Input projection to match hidden size
        self.input_proj = nnx.Linear(self.input_size, self.hidden_size, rngs=rngs)

        # First MoE layer (very small - 2 experts)
        self.moe1 = SmallMoELayer(
            input_dim=self.hidden_size,
            hidden_dim=32,
            output_dim=self.hidden_size,
            num_experts=2,
            rngs=rngs,
            dropout=dropout
        )

        # Simple attention layer
        self.attention1 = SimpleAttention(self.hidden_size, rngs=rngs)

        # Second MoE layer (very small - 2 experts)
        self.moe2 = SmallMoELayer(
            input_dim=self.hidden_size,
            hidden_dim=32,
            output_dim=self.hidden_size,
            num_experts=2,
            rngs=rngs,
            dropout=dropout
        )

        self.attention2 = SimpleAttention(self.hidden_size, rngs=rngs)

        # Final layers
        self.output = nnx.Linear(self.hidden_size, self.hidden_size, rngs=rngs)
        self.mean_layer = nnx.Linear(self.hidden_size, self.output_size, rngs=rngs)
        # Not read anywhere in this class; kept so existing attribute access keeps working.
        self.action_low = -2.0
        self.action_high = 2.0

    def __call__(self, x):
        """Compute action logits.

        Args:
            x: ``[batch_size, obs_dim]`` or ``[batch_size, seq_len, obs_dim]``.

        Returns:
            ``[batch_size, output_size]`` for 2-D input, otherwise
            ``[batch_size, seq_len, output_size]``.
        """
        # Add a length-1 sequence axis for the attention/MoE blocks and remember
        # that we did, so only an axis inserted here is removed again. Checking
        # `shape[1] == 1` on the way out would also strip a caller-supplied axis.
        added_seq_dim = x.ndim == 2
        if added_seq_dim:
            x = x[:, None, :]  # [batch_size, 1, obs_dim]

        # Project input to hidden size
        x = self.input_proj(x)  # [batch_size, seq_len, hidden_size]

        # MoE / attention blocks, each with a residual connection
        x = x + self.moe1(x)
        x = x + self.attention1(x)
        x = x + self.moe2(x)
        x = x + self.attention2(x)

        # Final layers
        x = jax.nn.leaky_relu(self.output(x))
        logits = self.mean_layer(x)

        if added_seq_dim:
            logits = logits.squeeze(1)  # [batch_size, output_size]

        return logits

    def select_action(self, x, key):
        """Sample actions from the categorical distribution defined by the logits.

        Args:
            x: A single observation ``[obs_dim]`` or a batch ``[batch_size, obs_dim]``.
            key: JAX PRNG key used for sampling.

        Returns:
            A scalar action for a single observation, or one action per row for a
            batch (including a batch of size 1).
        """
        # Add a batch axis for a single observation and remember that we did, so a
        # genuine batch of size 1 still yields an output with a batch axis.
        added_batch_dim = x.ndim == 1
        if added_batch_dim:
            x = x[None, :]

        logits = self(x)

        if added_batch_dim:
            logits = logits.squeeze(0)

        return jax.random.categorical(key, logits)



In [None]:
# Start the experiment-tracking run and record the hyperparameters with it.
wandb.init(
    project="JAX-GYMNAX",
    config=dict(
        env="Cartpole",
        lr=1e-2,
        gamma=0.99,
        episodes=100,
    ),
)

In [None]:
@nnx.jit
def loss(model, obs, actions, returns):
    """Policy-gradient surrogate loss.

    Computes ``-mean(log pi(a_t | s_t) * returns_t)`` over the batch.

    Args:
        model: Callable mapping ``obs`` to per-action logits (2-D, actions on axis 1).
        obs: Batch of observations.
        actions: 1-D integer array with the action taken at each step.
        returns: Per-step weights multiplied with the log-probabilities.
    """
    log_probs = jax.nn.log_softmax(model(obs))
    # Select the log-probability of the action actually taken in each row.
    action_index = actions[:, None]
    chosen_log_probs = jnp.take_along_axis(log_probs, action_index, axis=1)
    chosen_log_probs = chosen_log_probs.squeeze()
    weighted = chosen_log_probs * returns
    return -jnp.mean(weighted)

def compute_returns(rewards, gamma):
    """Compute discounted returns ``G_t = r_t + gamma * G_{t+1}`` for one episode.

    This intentionally runs on the host rather than under ``jit``. Jitting a
    function that takes a Python list traces every element as a separate
    argument and unrolls the loop, so it recompiles for every distinct episode
    length. A plain loop over at most a few hundred floats is far cheaper.

    Args:
        rewards: Sequence of per-step rewards (Python numbers or scalar arrays),
            in chronological order. May be empty.
        gamma: Discount factor.

    Returns:
        1-D float32 array with one return per step, in chronological order.
    """
    reward_values = [float(r) for r in rewards]
    discount = float(gamma)

    running = 0.0
    returns = []
    # Accumulate backwards in time, then flip once (avoids O(n^2) insert(0, ...)).
    for reward in reversed(reward_values):
        running = reward + discount * running
        returns.append(running)
    returns.reverse()

    return jnp.asarray(returns, dtype=jnp.float32)

def train(env, env_params, model, episodes: int = 50, learning_rate=1e-3, gamma=0.99,
          episodes_per_batch: int = 10, seed: int = 0):
    """Train ``model`` with a batched policy-gradient loop.

    Each of the ``episodes`` iterations rolls out ``episodes_per_batch`` complete
    episodes with the current policy, computes discounted returns per episode,
    standardises the returns across the whole batch and applies one Adam update
    with global-norm gradient clipping (1.0). The loss and the mean total reward
    of the most recent (up to 100) episodes are logged to wandb and shown in the
    progress bar.

    Args:
        env: Environment exposing ``reset(key, params)`` and
            ``step(key, state, action, params)``.
        env_params: Parameters passed through to ``env.reset`` / ``env.step``.
        model: Policy with ``select_action(obs, key)``; updated in place.
        episodes: Number of update iterations (each covers a batch of episodes).
        learning_rate: Adam learning rate.
        gamma: Discount factor for the returns.
        episodes_per_batch: Episodes collected per gradient update.
        seed: Seed for the rollout PRNG key.

    Raises:
        ValueError: If ``episodes_per_batch`` is less than 1.
    """
    if episodes_per_batch < 1:
        raise ValueError("episodes_per_batch must be at least 1")

    optimizer = nnx.Optimizer(
        model,
        optax.chain(
            optax.clip_by_global_norm(1.0),
            optax.adam(learning_rate=learning_rate),
        )
    )
    grad_func = nnx.value_and_grad(loss)
    key = jax.random.PRNGKey(seed)
    # Total reward of the most recent episodes, used for the running average.
    total_rewards = collections.deque(maxlen=100)

    with tqdm(range(episodes)) as pbar:
        for i in pbar:
            batch_obs = []
            batch_actions = []
            batch_returns = []  # one 1-D array per episode
            # Several episodes per update for gradient stability.
            for _ in range(episodes_per_batch):
                episode_rewards = []

                done = False
                key, reset_key = jax.random.split(key)
                obs, state = env.reset(reset_key, env_params)

                while not done:
                    key, action_key, step_key = jax.random.split(key, 3)
                    action = model.select_action(obs, action_key)
                    next_obs, state, reward, done, _ = env.step(step_key, state, action, env_params)

                    batch_obs.append(obs)
                    batch_actions.append(action)
                    episode_rewards.append(reward)
                    obs = next_obs

                total_rewards.append(float(sum(episode_rewards)))
                # Keep each episode's returns as one array; flattening them into
                # thousands of 0-d arrays and re-stacking is needlessly slow.
                batch_returns.append(compute_returns(episode_rewards, gamma))

            final_obs = jnp.stack(batch_obs)
            final_actions = jnp.array(batch_actions)
            final_returns = jnp.concatenate(batch_returns)

            # Standardise returns over the batch (epsilon guards a zero std).
            final_returns = (final_returns - jnp.mean(final_returns)) / (jnp.std(final_returns) + 1e-8)

            value, grad = grad_func(model, final_obs, final_actions, final_returns)
            optimizer.update(grad)

            loss_value = value.item()  # single device-to-host transfer
            avg_reward = sum(total_rewards) / len(total_rewards)
            wandb.log({
                "episodic_reward": avg_reward,
                "global_step": i,
                "loss": loss_value
            })
            pbar.set_description(f"Episode: {i}, Loss: {loss_value:.4f}, Reward: {avg_reward:.2f}")

In [None]:
# Build the policy from the environment's observation and action spaces.
moe_attention_model = MoEAttentionPolicy(
    env.observation_space(env_params),
    env.action_space(env_params),
    rngs=nnx.Rngs(0),
    dropout=0.05
)
# Training mode: dropout in the experts is active.
moe_attention_model.train()

# Read the hyperparameters from the wandb run config so the values that are
# logged are the values that are actually used. Previously `learning_rate` and
# `gamma` were omitted, so train() silently fell back to its own defaults
# (lr=1e-3) while the run recorded lr=1e-2.
train(
    env,
    env_params,
    moe_attention_model,
    episodes=wandb.config["episodes"],
    learning_rate=wandb.config["lr"],
    gamma=wandb.config["gamma"],
)