In [None]:
from typing import List, Tuple, Dict
import jax
import jax.numpy as jnp

import jax
import jax.numpy as jnp
from abc import ABC, abstractmethod
from typing import List, Tuple, Dict

# Abstract base class for agents.
class Agent(ABC):
    """Contract shared by every agent in this file.

    All six methods are abstract: a subclass has to override each of them
    before it can be instantiated, and every default body simply raises
    ``NotImplementedError``.
    """

    @abstractmethod
    def __init__(self, *args, **kwargs):
        """
        Build the agent: store its parameters and state variables and create
        internal data structures (e.g., Q-tables or policy networks).
        """
        raise NotImplementedError

    @abstractmethod
    def reset(self, randomize: bool = True):
        """
        Put the agent back into a start-of-episode state.
        Args:
            randomize (bool): If True, randomize the initial state.
        """
        raise NotImplementedError

    @abstractmethod
    def get_state(self):
        """
        Report the state the agent is currently in.
        """
        raise NotImplementedError

    @abstractmethod
    def choose_action(self, state):
        """
        Pick an action for ``state`` with an exploration/exploitation policy.
        Args:
            state: The current state.
        Returns:
            The chosen action.
        """
        raise NotImplementedError

    @abstractmethod
    def step(self, action):
        """
        Carry out ``action`` and move the agent to its resulting state.
        Args:
            action: The action to perform.
        Returns:
            The new state after the action.
        """
        raise NotImplementedError

    @abstractmethod
    def update(self, state, action, reward, next_state):
        """
        Learn from one observed transition by adjusting internal parameters.
        Args:
            state: The state before taking the action.
            action: The action taken.
            reward: The reward received.
            next_state: The resulting state.
        """
        raise NotImplementedError

class Environment(ABC):
    """Contract for environments that score agents by their positions.

    Every method is abstract and its default body raises
    ``NotImplementedError``; concrete environments override all three.
    """

    @abstractmethod
    def reset(self) -> Dict:
        """
        Clear any internal environment state and hand back environment info
        if there is any to report (grid parameters, for instance).
        """
        raise NotImplementedError

    @abstractmethod
    def compute_rewards(self, agent_states: List[Tuple[int, int]]) -> List[float]:
        """
        Produce one reward per agent from the agents' current states.

        Args:
            agent_states (List[Tuple[int, int]]): List of (x, y) positions for all agents.

        Returns:
            List[float]: A list of rewards, one per agent.
        """
        raise NotImplementedError

    @abstractmethod
    def is_terminal(self, agent_states: List[Tuple[int, int]]) -> bool:
        """
        Decide from the agent states whether the environment has reached a
        terminal state.

        Args:
            agent_states (List[Tuple[int, int]]): List of agent positions.

        Returns:
            bool: True if the simulation should terminate.
        """
        raise NotImplementedError

class SingleAgent(Agent):
    """Tabular Q-learning agent that moves on a square grid world.

    The agent keeps its position as two Python ints (``x``, ``y``), a
    Q-table of shape ``(grid_size, grid_size, num_actions)`` and its own
    JAX PRNG key, which is split every time a random number is drawn.
    """

    def __init__(self, grid_size, alpha=0.1, gamma=0.9, epsilon=0.1, seed=0):
        """
        Initialize the single agent. Assumes a grid world environment.
        Args:
            grid_size (int): Size of the grid (assumes a square grid).
            alpha (float): Learning rate.
            gamma (float): Discount factor.
            epsilon (float): Exploration rate.
            seed (int): Random seed for reproducibility.
        """
        self.grid_size = grid_size
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        self.num_actions = 5  # Actions: 0:stay, 1:up, 2:down, 3:left, 4:right

        # Create a Q-table with shape: (grid_size, grid_size, num_actions)
        self.q_table = jnp.zeros((grid_size, grid_size, self.num_actions))

        # Initialize PRNG key for JAX-based randomness.
        self.key = jax.random.PRNGKey(seed)

        # Agent's initial position (starting at the center).
        self.x = grid_size // 2
        self.y = grid_size // 2

    def reset(self, randomize: bool = True):
        """
        Reset the agent's position. The Q-table is left untouched, so what
        the agent has learned carries over between episodes.
        Args:
            randomize (bool): If True, set a random starting position; otherwise, use the center.
        """
        if randomize:
            # A fresh subkey per coordinate keeps the two draws independent.
            self.key, subkey = jax.random.split(self.key)
            self.x = int(jax.random.randint(subkey, (), 0, self.grid_size))
            self.key, subkey = jax.random.split(self.key)
            self.y = int(jax.random.randint(subkey, (), 0, self.grid_size))
        else:
            self.x = self.y = self.grid_size // 2

    def get_state(self):
        """
        Return the agent's current state as a tuple (x, y).
        """
        return (self.x, self.y)

    # NOTE: jax.jit must not wrap a bound method directly: `self` would be
    # handed to the jitted function as an argument, and a plain Python object
    # is not a valid JAX type. The move logic needs no instance state, so it
    # is a jitted static method; `self._move(pos, action, grid_size)` still
    # works unchanged.
    @staticmethod
    @jax.jit
    def _move(pos, action, grid_size):
        """
        Compute the new position based on the current position and action.
        Moves that would leave the grid are clamped to the border.
        Args:
            pos (tuple): Current (x, y) position.
            action (int): Action index (0:stay, 1:up, 2:down, 3:left, 4:right).
            grid_size (int): Size of the grid.
        Returns:
            A tuple (x, y) of JAX scalars representing the new position.
        """
        x, y = pos

        def stay():
            return (x, y)

        def up():
            return (x, jnp.maximum(y - 1, 0))

        def down():
            return (x, jnp.minimum(y + 1, grid_size - 1))

        def left():
            return (jnp.maximum(x - 1, 0), y)

        def right():
            return (jnp.minimum(x + 1, grid_size - 1), y)

        # jax.lax.switch dispatches based on the action index.
        return jax.lax.switch(action, [stay, up, down, left, right])

    def choose_action(self, state):
        """
        Choose an action using an epsilon-greedy policy.
        Args:
            state (tuple): The current state (x, y).
        Returns:
            The selected action as an integer.
        """
        self.key, subkey = jax.random.split(self.key)
        rand_val = jax.random.uniform(subkey)
        if rand_val < self.epsilon:
            # Explore: choose a random action.
            self.key, subkey = jax.random.split(self.key)
            return int(jax.random.randint(subkey, (), 0, self.num_actions))

        # Exploit: choose the action with the highest Q-value.
        x, y = state
        return int(jnp.argmax(self.q_table[x, y, :]))

    def step(self, action):
        """
        Execute an action and update the agent's position.
        Args:
            action (int): The action to perform.
        Returns:
            The new state (x, y) after the move, as a tuple of Python ints
            (the same value ``get_state`` reports afterwards).
        """
        moved_x, moved_y = self._move(self.get_state(), action, self.grid_size)
        # Convert the JAX scalars to ints so the position can be used for
        # indexing, hashing and equality like any other state tuple.
        self.x, self.y = int(moved_x), int(moved_y)
        return self.get_state()

    def update(self, state, action, reward, next_state):
        """
        Update the Q-table using the Q-learning update rule:
        Q(s, a) <- (1 - alpha) * Q(s, a) + alpha * (reward + gamma * max_a' Q(s', a')).
        Args:
            state (tuple): The state before the action.
            action (int): The action taken.
            reward (float): The reward received.
            next_state (tuple): The state after the action.
        """
        x, y = state
        nx, ny = next_state

        # Current Q-value.
        current_q = self.q_table[x, y, action]
        # Maximum Q-value in the next state.
        max_next_q = jnp.max(self.q_table[nx, ny, :])

        # Q-learning update.
        new_q = (1 - self.alpha) * current_q + self.alpha * (reward + self.gamma * max_next_q)

        # JAX arrays are immutable: .at[...].set(...) returns a new table.
        self.q_table = self.q_table.at[x, y, action].set(new_q)

class GridCenterReward(Environment):
    """Stateless grid environment that rewards agents standing on the center cell."""

    def __init__(self, grid_size: int, center_reward: float = 1.0):
        """
        Set up the grid environment.

        Args:
            grid_size (int): The width/height of the grid.
            center_reward (float): Reward given when an agent reaches the center.
        """
        middle = grid_size // 2
        self.grid_size = grid_size
        self.center_reward = center_reward
        # The center cell is kept as a JAX array so it broadcasts against
        # an (n_agents, 2) array of positions.
        self.center = jnp.array([middle, middle])

    def _center_mask(self, agent_states):
        """Return a boolean JAX array, one entry per agent, True where the agent is on the center cell."""
        # List of (x, y) tuples -> array of shape (n_agents, 2).
        positions = jnp.array(agent_states)
        # An agent is at the center only if both coordinates match.
        return jnp.all(positions == self.center, axis=1)

    def reset(self) -> Dict:
        """
        Reset the environment. Nothing is stored between steps, so this
        only reports the grid size and the center cell.
        """
        return {"grid_size": self.grid_size, "center": self.center}

    def compute_rewards(self, agent_states: List[Tuple[int, int]]) -> List[float]:
        """
        Give ``center_reward`` to every agent on the center cell and 0.0 to the rest.

        Args:
            agent_states (List[Tuple[int, int]]): List of (x, y) positions for each agent.

        Returns:
            List[float]: A reward for each agent.
        """
        on_center = self._center_mask(agent_states)
        payouts = jnp.where(on_center, self.center_reward, 0.0)
        # Hand back plain Python floats rather than a JAX array.
        return list(payouts.tolist())

    def is_terminal(self, agent_states: List[Tuple[int, int]]) -> bool:
        """
        Report whether the episode should end, which is the case as soon as
        at least one agent stands on the center cell.

        Args:
            agent_states (List[Tuple[int, int]]): List of (x, y) positions for each agent.

        Returns:
            bool: True if any agent is at the center.
        """
        on_center = self._center_mask(agent_states)
        return bool(jnp.any(on_center))

class Trainer:
    """Run agents through episodes of an environment and let them learn.

    The agents are ordinary stateful Python objects (they mutate their own
    position, PRNG key and Q-table), so episodes are driven by an eager
    Python loop. Such side effects cannot run inside a traced JAX loop like
    ``jax.lax.scan``.
    """

    # The annotation is quoted so it is resolved lazily and this class does
    # not need ``Environment`` to be defined before it.
    def __init__(self, agents: List, environment: "Environment", num_episodes: int = 1000, max_steps: int = 100):
        """
        Initialize the trainer.

        Args:
            agents (List): A list of agent instances.
            environment (Environment): An instance of the environment.
            num_episodes (int): Number of episodes to train.
            max_steps (int): Maximum steps per episode.
        """
        self.agents = agents
        self.env = environment
        self.num_episodes = num_episodes
        self.max_steps = max_steps

    def _episode_step(self, state: jnp.ndarray, unused=None):
        """
        Advance every agent by one time step and let each one learn from it.

        Args:
            state: jnp.ndarray of shape (num_agents, 2) representing current agent positions.
            unused: placeholder kept for backward compatibility (not used).

        Returns:
            next_state: jnp.ndarray with the agents' positions after the step.
            info: Tuple containing (prev_state, actions, rewards, done).
        """
        # Work with plain Python tuples: that is what the agents and the
        # environment expect.
        positions = [tuple(row) for row in jnp.asarray(state).tolist()]

        # --- Action Selection ---
        chosen = [
            int(agent.choose_action(pos))
            for agent, pos in zip(self.agents, positions)
        ]

        # --- Environment Step ---
        # int() accepts both Python ints and JAX scalars returned by step().
        new_positions = [
            tuple(int(coord) for coord in agent.step(action))
            for agent, action in zip(self.agents, chosen)
        ]

        # --- Compute Rewards & Terminal Condition ---
        reward_values = [float(r) for r in self.env.compute_rewards(new_positions)]
        done = bool(self.env.is_terminal(new_positions))

        # --- Learning ---
        # Without this call the agents' parameters would never change.
        transitions = zip(self.agents, positions, chosen, reward_values, new_positions)
        for agent, pos, action, reward, new_pos in transitions:
            agent.update(pos, action, reward, new_pos)

        next_state = jnp.array(new_positions)  # shape (num_agents, 2)
        return next_state, (state, jnp.array(chosen), jnp.array(reward_values), done)

    def train_jax(self):
        """
        Run a single episode, stopping early once the environment is terminal.

        The name is kept for backward compatibility; the loop itself is a
        plain Python loop because the per-step function has Python side
        effects that cannot be traced by ``jax.lax.scan``.

        Returns:
            final_states: The states of all agents after the episode.
            scan_info: A tuple (states, actions, rewards, done flags), each
                stacked along a leading axis whose length is the number of
                steps actually taken (at most ``max_steps``).
        """
        # Gather initial states from all agents as a JAX array of shape (num_agents, 2).
        state = jnp.stack([jnp.array(agent.get_state()) for agent in self.agents])

        prev_states, actions, rewards, dones = [], [], [], []
        for _ in range(self.max_steps):
            state, (prev, act, rew, done) = self._episode_step(state, None)
            prev_states.append(prev)
            actions.append(act)
            rewards.append(rew)
            dones.append(done)
            if done:
                break

        if not prev_states:
            # max_steps <= 0: nothing to stack, so return empty histories.
            n_agents = len(self.agents)
            scan_info = (
                jnp.zeros((0,) + tuple(state.shape), dtype=state.dtype),
                jnp.zeros((0, n_agents), dtype=jnp.int32),
                jnp.zeros((0, n_agents)),
                jnp.zeros((0,), dtype=bool),
            )
            return state, scan_info

        scan_info = (
            jnp.stack(prev_states),
            jnp.stack(actions),
            jnp.stack(rewards),
            jnp.array(dones),
        )
        return state, scan_info

    def train(self) -> List[Dict]:
        """
        Run the training loop over multiple episodes.
        For each episode, the environment and all agents are reset and one
        episode is played.

        Returns:
            List[Dict]: One entry per episode with the keys ``"episode"``
            (1-based index), ``"steps"`` (steps taken) and ``"total_reward"``
            (sum of all agents' rewards over the episode).
        """
        report = []
        for episode in range(self.num_episodes):
            # Reset the environment and all agents.
            _ = self.env.reset()
            for agent in self.agents:
                agent.reset(randomize=True)

            _, (_, _, rewards, _) = self.train_jax()
            report.append(
                {
                    "episode": episode + 1,
                    "steps": int(rewards.shape[0]),
                    "total_reward": float(rewards.sum()),
                }
            )
            print(f"Episode {episode+1} finished.")
        return report


In [2]:
# One Q-learning agent and a center-reward environment on the same 5x5 grid.
agent = SingleAgent(
    grid_size=5,
    alpha=0.1,
    gamma=0.9,
    epsilon=0.1,
    seed=0,
)
env = GridCenterReward(
    grid_size=5,
    center_reward=1.0,
)

In [None]:
# Train the single agent for 1000 episodes of at most 100 steps each;
# `report` holds whatever Trainer.train() returns.
tr = Trainer(
    agents=[agent],
    environment=env,
    num_episodes=1000,
    max_steps=100,
)
report = tr.train()