In [None]:
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Tuple, Dict

import jax
import jax.numpy as jnp


# --- Batched Agent State ---

@jax.tree_util.register_pytree_node_class
@dataclass
class BatchSingleAgentState:
    """State of a batch of independent grid agents, registered as a JAX pytree.

    All four fields are pytree leaves, so instances can be passed through
    ``jax.jit``, ``jax.vmap`` and ``jax.lax.scan``. Field order matters:
    ``tree_unflatten`` rebuilds the object from children in the order
    produced by ``tree_flatten``.
    """

    x: jnp.ndarray       # per-agent x coordinate, shape (n_agents,)
    y: jnp.ndarray       # per-agent y coordinate, shape (n_agents,)
    q_table: jnp.ndarray # per-agent Q-table, shape (n_agents, grid_size, grid_size, num_actions)
    key: jnp.ndarray     # per-agent PRNG key, leading axis n_agents

    def tree_flatten(self):
        """Return ``(children, aux_data)``; there is no static auxiliary data."""
        return (self.x, self.y, self.q_table, self.key), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuild an instance from the children emitted by ``tree_flatten``."""
        x, y, q_table, key = children
        return cls(x=x, y=y, q_table=q_table, key=key)

# --- Abstract Base Classes ---

class Agent(ABC):
    """Interface for agents whose methods operate on a whole batch at once.

    Every method takes and/or returns a ``BatchSingleAgentState`` holding the
    state of all agents in the batch.
    """

    @abstractmethod
    def init_state_batch(self, n_agents: int, randomize: bool = True) -> BatchSingleAgentState:
        """Build the initial state for ``n_agents`` agents.

        ``randomize`` selects random versus fixed starting positions.
        """

    @abstractmethod
    def choose_action_batch(self, state: BatchSingleAgentState) -> Tuple[jnp.ndarray, BatchSingleAgentState]:
        """Pick one action per agent.

        Returns the actions together with a state whose PRNG keys have been advanced.
        """

    @abstractmethod
    def move_batch(self, state: BatchSingleAgentState, actions: jnp.ndarray) -> BatchSingleAgentState:
        """Apply ``actions`` to the agents' positions and return the moved state."""

    @abstractmethod
    def update_batch(self, state: BatchSingleAgentState, actions: jnp.ndarray,
                     rewards: jnp.ndarray, next_state: BatchSingleAgentState) -> BatchSingleAgentState:
        """Learn from the transition ``state --actions/rewards--> next_state``."""


class Environment(ABC):
    """Interface for environments that score a whole batch of agents at once."""

    @abstractmethod
    def reset(self) -> Dict:
        """Reset the environment and return a dict describing it."""

    @abstractmethod
    def compute_rewards_batch(self, state: BatchSingleAgentState) -> jnp.ndarray:
        """Return one reward per agent for the positions in ``state``."""

    @abstractmethod
    def is_terminal_batch(self, state: BatchSingleAgentState) -> bool:
        """Report whether ``state`` is terminal.

        JAX implementations may return a 0-d boolean array instead of a Python bool.
        """


# --- Concrete Implementations ---

class SingleAgent(Agent):
    """Tabular epsilon-greedy Q-learning agent on a square grid, vectorized over agents.

    Each method takes/returns a ``BatchSingleAgentState`` whose leading axis
    indexes independent agents. The per-agent logic is written for one agent
    and lifted with ``jax.vmap``. Every agent owns its own Q-table, indexed as
    ``q_table[x, y, action]``.

    Actions are:
        0 = stay
        1 = up (y - 1)
        2 = down (y + 1)
        3 = left (x - 1)
        4 = right (x + 1)

    Moves that would leave the grid are clamped to its border.
    """

    def __init__(self, grid_size: int, alpha: float = 0.1, gamma: float = 0.9,
                 epsilon: float = 0.1, seed: int = 0):
        """Store the hyperparameters.

        Args:
            grid_size: Side length of the square grid; coordinates lie in
                ``[0, grid_size)``.
            alpha: Q-learning step size.
            gamma: Discount factor applied to the bootstrapped next-state value.
            epsilon: Probability of taking a uniformly random action.
            seed: Seed for the PRNG key created in ``init_state_batch``.
        """
        self.grid_size = grid_size
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        self.num_actions = 5  # 0:stay, 1:up, 2:down, 3:left, 4:right
        self.seed = seed

    def init_state_batch(self, n_agents: int, randomize: bool = True) -> BatchSingleAgentState:
        """Create the initial state for ``n_agents`` agents.

        Positions are drawn uniformly from the grid when ``randomize`` is true;
        otherwise every agent starts at ``(grid_size // 2, grid_size // 2)``.
        Q-tables start at zero.

        The result depends only on ``self.seed`` and ``n_agents``, so repeated
        calls return identical positions and keys.
        """
        base_key = jax.random.PRNGKey(self.seed)
        # Give every agent its own independent key stream.
        agent_keys = jax.random.split(base_key, n_agents)

        def init_fn(key):
            # Two subkeys for the position, one key carried forward in the state.
            key, subkey_x, subkey_y = jax.random.split(key, 3)
            if randomize:  # Python bool: resolved at trace time, not traced.
                x = jax.random.randint(subkey_x, (), 0, self.grid_size, dtype=jnp.int32)
                y = jax.random.randint(subkey_y, (), 0, self.grid_size, dtype=jnp.int32)
            else:
                x = jnp.array(self.grid_size // 2, dtype=jnp.int32)
                y = jnp.array(self.grid_size // 2, dtype=jnp.int32)
            return x, y, key

        xs, ys, new_keys = jax.vmap(init_fn)(agent_keys)
        q_table = jnp.zeros((n_agents, self.grid_size, self.grid_size, self.num_actions))
        return BatchSingleAgentState(x=xs, y=ys, q_table=q_table, key=new_keys)

    def choose_action_batch(self, state: BatchSingleAgentState) -> Tuple[jnp.ndarray, BatchSingleAgentState]:
        """Pick one epsilon-greedy action per agent.

        With probability ``epsilon`` an agent takes a uniformly random action.
        Otherwise it takes an action with maximal Q-value at its current cell,
        breaking ties uniformly at random.

        Returns:
            ``(actions, new_state)``. ``actions`` is an int32 array of shape
            ``(n_agents,)``. ``new_state`` equals ``state`` except that every
            PRNG key has been advanced.
        """
        num_actions = self.num_actions
        epsilon = self.epsilon

        def choose_fn(x, y, q_table, key):
            key, explore_key = jax.random.split(key)
            explore = jax.random.uniform(explore_key) < epsilon

            def random_branch(key):
                key, action_key = jax.random.split(key)
                action = jax.random.randint(action_key, (), 0, num_actions, dtype=jnp.int32)
                return action, key

            def exploit_branch(key):
                # A plain argmax always returns the lowest tied index (0 = stay).
                # With zero-initialized Q-tables every greedy choice would be "stay".
                # Instead, draw random scores and keep them only for the maximal actions.
                key, tie_key = jax.random.split(key)
                q_values = q_table[x, y, :]
                is_best = q_values == jnp.max(q_values)
                scores = jnp.where(is_best, jax.random.uniform(tie_key, q_values.shape), -1.0)
                # Explicit int32 so both cond branches return identical dtypes.
                action = jnp.argmax(scores).astype(jnp.int32)
                return action, key

            return jax.lax.cond(explore, random_branch, exploit_branch, key)

        actions, new_keys = jax.vmap(choose_fn)(state.x, state.y, state.q_table, state.key)
        new_state = BatchSingleAgentState(x=state.x, y=state.y, q_table=state.q_table, key=new_keys)
        return actions, new_state

    def move_batch(self, state: BatchSingleAgentState, actions: jnp.ndarray) -> BatchSingleAgentState:
        """Move every agent one cell according to its action, clamped to the grid.

        Q-tables and keys are passed through unchanged.
        """
        grid_size = self.grid_size

        def move_fn(x, y, action):
            def stay():
                return x, y

            def up():
                return x, jnp.maximum(y - 1, 0)

            def down():
                return x, jnp.minimum(y + 1, grid_size - 1)

            def left():
                return jnp.maximum(x - 1, 0), y

            def right():
                return jnp.minimum(x + 1, grid_size - 1), y

            # Branch order must match the action encoding 0..4.
            return jax.lax.switch(action, [stay, up, down, left, right])

        new_x, new_y = jax.vmap(move_fn)(state.x, state.y, actions)
        return BatchSingleAgentState(x=new_x, y=new_y, q_table=state.q_table, key=state.key)

    def update_batch(self, state: BatchSingleAgentState, actions: jnp.ndarray,
                     rewards: jnp.ndarray, next_state: BatchSingleAgentState) -> BatchSingleAgentState:
        """Apply one tabular Q-learning update per agent.

        For each agent::

            Q[x, y, a] <- (1 - alpha) * Q[x, y, a]
                          + alpha * (r + gamma * max_a' Q[x', y', a'])

        where ``(x, y)`` comes from ``state`` and ``(x', y')`` from ``next_state``.
        The target always bootstraps from ``next_state``; there is no terminal
        masking.

        Note: only ``q_table`` is updated. The returned ``x``, ``y`` and ``key``
        are copied from ``state`` (the pre-transition state), not from
        ``next_state``. Callers advancing an episode must take positions and
        keys from ``next_state`` themselves.
        """
        alpha = self.alpha
        gamma = self.gamma

        def update_fn(x, y, q_table, action, reward, next_x, next_y):
            current_q = q_table[x, y, action]
            max_next_q = jnp.max(q_table[next_x, next_y, :])
            new_q = (1 - alpha) * current_q + alpha * (reward + gamma * max_next_q)
            return q_table.at[x, y, action].set(new_q)

        new_q_table = jax.vmap(update_fn)(
            state.x, state.y, state.q_table, actions, rewards, next_state.x, next_state.y
        )
        return BatchSingleAgentState(x=state.x, y=state.y, q_table=new_q_table, key=state.key)


class GridCenterReward(Environment):
    """Environment that pays a fixed reward to agents standing on the center cell.

    The center cell is ``(grid_size // 2, grid_size // 2)``.
    """

    def __init__(self, grid_size: int, center_reward: float = 1.0):
        mid = grid_size // 2
        self.grid_size = grid_size
        self.center = jnp.array([mid, mid], dtype=jnp.int32)
        self.center_reward = center_reward

    def reset(self) -> Dict:
        """Return the environment description (grid size and center cell)."""
        return dict(grid_size=self.grid_size, center=self.center)

    def _at_center(self, state: BatchSingleAgentState) -> jnp.ndarray:
        """Boolean mask with one entry per agent, true where the agent is on the center."""
        coords = jnp.stack([state.x, state.y], axis=1)  # (n_agents, 2)
        return jnp.all(coords == self.center, axis=1)

    def compute_rewards_batch(self, state: BatchSingleAgentState) -> jnp.ndarray:
        """Pay ``center_reward`` to agents on the center cell and 0.0 to all others."""
        return jnp.where(self._at_center(state), self.center_reward, 0.0)

    def is_terminal_batch(self, state: BatchSingleAgentState) -> bool:
        """True (as a 0-d JAX bool) if at least one agent is on the center cell."""
        return self._at_center(state).any()


# --- Trainer using JAX's vectorized scan ---
class Trainer:
    """Run batched Q-learning episodes for a ``SingleAgent`` in a ``GridCenterReward`` env.

    Each episode is a fixed-length ``jax.lax.scan`` of ``max_steps`` steps.
    ``env.is_terminal_batch`` is not consulted to stop early. Q-tables and PRNG
    keys persist across episodes; only the starting positions are re-sampled.
    """

    def __init__(self, agent: SingleAgent, environment: GridCenterReward,
                 n_agents: int, num_episodes: int = 1000, max_steps: int = 100):
        self.agent = agent
        self.env = environment
        self.n_agents = n_agents
        self.num_episodes = num_episodes
        self.max_steps = max_steps

    def _episode_step(self, state: BatchSingleAgentState, _):
        """One scan step: act, move, reward, learn.

        Returns:
            ``(next_state, (x, y, actions, rewards))``. The positions in the
            logged tuple are the pre-move positions.
        """
        actions, state_after_choice = self.agent.choose_action_batch(state)
        moved_state = self.agent.move_batch(state_after_choice, actions)
        rewards = self.env.compute_rewards_batch(moved_state)
        learned_state = self.agent.update_batch(state, actions, rewards, moved_state)
        # update_batch copies x/y/key from `state` (the pre-move state). Carrying
        # its output directly would pin agents to their start cell and freeze the
        # PRNG key. Instead, take the new Q-table from learned_state and the
        # moved positions and advanced key from moved_state.
        next_state = BatchSingleAgentState(
            x=moved_state.x,
            y=moved_state.y,
            q_table=learned_state.q_table,
            key=moved_state.key,
        )
        return next_state, (state.x, state.y, actions, rewards)

    def _reset_positions(self, state: BatchSingleAgentState) -> BatchSingleAgentState:
        """Re-sample every agent's start cell uniformly, keeping its Q-table.

        Uses each agent's own carried PRNG key, so episodes start from varying
        cells instead of repeating the seeded initial positions.
        """
        grid_size = self.agent.grid_size

        def sample(key):
            key, kx, ky = jax.random.split(key, 3)
            x = jax.random.randint(kx, (), 0, grid_size, dtype=jnp.int32)
            y = jax.random.randint(ky, (), 0, grid_size, dtype=jnp.int32)
            return x, y, key

        xs, ys, keys = jax.vmap(sample)(state.key)
        return BatchSingleAgentState(x=xs, y=ys, q_table=state.q_table, key=keys)

    def train_episode(self, state: BatchSingleAgentState):
        """Run ``max_steps`` steps from ``state``.

        Returns:
            ``(final_state, per_step_logs)``.
        """
        final_state, scan_info = jax.lax.scan(self._episode_step, state, None, length=self.max_steps)
        return final_state, scan_info

    def train(self):
        """Train for ``num_episodes`` episodes and return the final batched state.

        The state is initialized once, so learning accumulates across episodes.
        Every episode after the first starts from freshly sampled positions.
        """
        # Compile the episode once; every episode has the same shapes and dtypes.
        run_episode = jax.jit(self.train_episode)
        state = self.agent.init_state_batch(self.n_agents, randomize=True)
        for episode in range(self.num_episodes):
            _ = self.env.reset()
            if episode > 0:
                state = self._reset_positions(state)
            state, _scan_info = run_episode(state)
            print(f"Episode {episode+1} finished")
        return state



In [6]:
# Demo: train a batch of tabular Q-learning agents to reach the grid center.
grid_size = 5
n_agents = 1  # Raise this to train more agents in parallel and amortize JAX overhead.
agent = SingleAgent(grid_size=grid_size, alpha=0.1, gamma=0.9, epsilon=0.1, seed=42)
env = GridCenterReward(grid_size=grid_size, center_reward=1.0)
trainer = Trainer(agent=agent, environment=env, n_agents=n_agents, num_episodes=10, max_steps=20)
# Bind the result so it can be inspected and the notebook does not echo it.
final_state = trainer.train()

Episode 1 finished
Episode 2 finished
Episode 3 finished
Episode 4 finished
Episode 5 finished
Episode 6 finished
Episode 7 finished
Episode 8 finished
Episode 9 finished
Episode 10 finished
