In [1]:
import random
import numpy as np
from tqdm import tqdm
from random import sample

In [2]:
class JointAgent:
    """
    Agent representing two entities that learn jointly using Q-learning.
    Each entity (agent 1 and agent 2) maintains its own value function.
    """
    def __init__(self, width, height, draw_radius, alpha=0.1, gamma=0.9, epsilon=0.1):
        # Initialize positions randomly for two agents
        self.x1 = random.randint(0, width - 1)
        self.y1 = random.randint(0, height - 1)
        self.x2 = random.randint(0, width - 1)
        self.y2 = random.randint(0, height - 1)
        self.width = width
        self.height = height
        # self.draw_radius is not used (commented out for performance reasons)
        self.alpha = alpha  # learning rate
        self.gamma = gamma  # discount factor for future rewards
        self.epsilon = epsilon  # epsilon for epsilon-greedy action selection
        self.q_values1 = {}  # Q-values (or value function) for agent 1
        self.q_values2 = {}  # Q-values for agent 2
        self.occupancy = {}  # Tracks visitation counts
        self.previous_state = None
        self.previous_action = None  
        self.possible_actions = ["stay", "up", "down", "left", "right"]
        self.rewards_active = ()  # Current active reward location(s)

    def update_occupancy(self):
        """Increment occupancy counter for current joint state."""
        state0 = self.get_state(0)[0]
        state1 = self.get_state(1)[0]
        self.occupancy[(state0, state1)] = 1 + self.occupancy.get((state0, state1), 0)

    def reset(self):
        """Reset both agent positions randomly and clear active rewards."""
        self.x1 = random.randint(0, self.width - 1)
        self.y1 = random.randint(0, self.height - 1)
        self.x2 = random.randint(0, self.width - 1)
        self.y2 = random.randint(0, self.height - 1)
        self.rewards_active = ()

    def reset_to_center(self):
        """Reset both agent positions to the center of the grid."""
        center = (int(self.width / 2), int(self.height / 2))
        self.x1, self.y1 = center
        self.x2, self.y2 = center
        self.rewards_active = ()

    def choose_action(self):
        """
        Choose an action pair (one for each agent) using an epsilon-greedy policy.
        First, all 25 (5x5) possible action pairs are evaluated and the one with the
        highest Q-value is chosen (with random tie-breaking). Q-values are computed
        separately for each agent.
        """
        if random.uniform(0, 1) < self.epsilon:
            # Random actions for exploration.
            return random.choice(self.possible_actions), random.choice(self.possible_actions)
        else:
            # Evaluate all possible action pairs.
            action_pairs = [(a1, a2) for a1 in self.possible_actions for a2 in self.possible_actions]
            # Get Q-values for agent 1 for each possible next state.
            q_values = [self.get_q_value(0,
                           self.get_next_state(((self.x1, self.y1), self.rewards_active), a1),
                           self.get_next_state(((self.x2, self.y2), self.rewards_active), a2))
                        for a1, a2 in action_pairs]
            max_indices = np.where(np.array(q_values) == np.max(q_values))[0]
            idx = random.choice(max_indices)
            action1 = action_pairs[idx][0]

            # Get Q-values for agent 2 for each possible next state.
            q_values = [self.get_q_value(1,
                           self.get_next_state(((self.x1, self.y1), self.rewards_active), a1),
                           self.get_next_state(((self.x2, self.y2), self.rewards_active), a2))
                        for a1, a2 in action_pairs]
            max_indices = np.where(np.array(q_values) == np.max(q_values))[0]
            idx = random.choice(max_indices)
            action2 = action_pairs[idx][1]

            return action1, action2  

    def get_next_state(self, states, action):
        """
        Compute the next state given the current state and an action.
        The state is represented as a tuple: (position, rewards_active)
        """
        pos, _ = states
        if action == "up":
            next_pos = (pos[0], max(0, pos[1] - 1))
        elif action == "down":
            next_pos = (pos[0], min(self.height - 1, pos[1] + 1))
        elif action == "left":
            next_pos = (max(0, pos[0] - 1), pos[1])
        elif action == "right":
            next_pos = (min(self.width - 1, pos[0] + 1), pos[1])
        elif action == "stay":
            next_pos = (pos[0], pos[1])
        return (next_pos, self.rewards_active)

    def get_q_value(self, idx, state1, state2):
        """Retrieve the Q-value for a given joint state for agent 0 or 1."""
        if idx == 0:
            return self.q_values1.get((state1, state2), 0)
        elif idx == 1:
            return self.q_values2.get((state1, state2), 0)

    def get_state(self, idx):
        """Return the current state (position and reward activation) for agent idx."""
        if idx == 0:
            return ((self.x1, self.y1), self.rewards_active)
        elif idx == 1:
            return ((self.x2, self.y2), self.rewards_active)

    def update_q_value(self, state1, state2, reward1, reward2):
        """
        Update Q-values for both agents using a Bellman update (Gauss-Seidel style).
        """
        next_states1 = [self.get_next_state(state1, a) for a in self.possible_actions]
        next_states2 = [self.get_next_state(state2, a) for a in self.possible_actions]
        
        current_state = (self.get_state(0), self.get_state(1))
        
        # Update for agent 0.
        next_q_values = [self.get_q_value(0, s1, s2) for s1 in next_states1 for s2 in next_states2]
        self.q_values1[current_state] = (1 - self.alpha) * self.get_q_value(0, self.get_state(0), self.get_state(1)) + \
                                        self.alpha * (reward1 + self.gamma * max(next_q_values))
        # Update for agent 1.
        next_q_values = [self.get_q_value(1, s1, s2) for s1 in next_states1 for s2 in next_states2]
        self.q_values2[current_state] = (1 - self.alpha) * self.get_q_value(1, self.get_state(0), self.get_state(1)) + \
                                        self.alpha * (reward2 + self.gamma * max(next_q_values))

    def move(self, idx, action):
        """
        Update the position of the agent (0 or 1) based on the given action.
        """
        if idx == 0:
            if action == "up":
                self.y1 = max(0, self.y1 - 1)
            elif action == "down":
                self.y1 = min(self.height - 1, self.y1 + 1)
            elif action == "left":
                self.x1 = max(0, self.x1 - 1)
            elif action == "right":
                self.x1 = min(self.width - 1, self.x1 + 1)
        elif idx == 1:
            if action == "up":
                self.y2 = max(0, self.y2 - 1)
            elif action == "down":
                self.y2 = min(self.height - 1, self.y2 + 1)
            elif action == "left":
                self.x2 = max(0, self.x2 - 1)
            elif action == "right":
                self.x2 = min(self.width - 1, self.x2 + 1)

In [3]:
class BaseEnvironment:
    """
    Base class for training environments.
    
    This class defines common parameters and methods such as the grid size,
    reward magnitudes, and a mapping from reward location identifiers to grid
    coordinates. Subclasses must override the reward activation and computation
    logic.
    """
    def __init__(self, grid_size, target_reward1, target_reward2, together_reward, travel_reward):
        self.grid_size = grid_size
        self.target_reward1 = target_reward1
        self.target_reward2 = target_reward2
        self.together_reward = together_reward
        self.travel_reward = travel_reward
        self.center_pos = (grid_size // 2, grid_size // 2)
        self.steps_without_reward = 0  # Counter to trigger resets when no reward is collected
        self.reset_reward_locations()  # Sets self.true_reward_locations based on regime

        # Mapping from reward identifiers to grid coordinates.
        # Note: single-element tuples use a trailing comma.
        self.reward_place_to_coord = {
            (): (),
            ('up',): ((grid_size // 2, grid_size - 1),),
            ('right',): ((grid_size - 1, grid_size // 2),),
            ('down',): ((grid_size // 2, 0),),
            ('left',): ((0, grid_size // 2),),
            ('up', 'right'): ((grid_size // 2, grid_size - 1), (grid_size - 1, grid_size // 2)),
            ('right', 'down'): ((grid_size - 1, grid_size // 2), (grid_size // 2, 0)),
            ('down', 'left'): ((grid_size // 2, 0), (0, grid_size // 2)),
            ('up', 'left'): ((0, grid_size // 2), (grid_size // 2, grid_size - 1)),
            ('up', 'down'): ((grid_size // 2, grid_size - 1), (grid_size // 2, 0)),
            ('right', 'left'): ((grid_size - 1, grid_size // 2), (0, grid_size // 2)),
            ('up', 'right', 'down'): ((grid_size // 2, grid_size - 1), (grid_size - 1, grid_size // 2), (grid_size // 2, 0)),
            ('right', 'down', 'left'): ((grid_size - 1, grid_size // 2), (grid_size // 2, 0), (0, grid_size // 2)),
            ('up', 'down', 'left'): ((grid_size // 2, 0), (0, grid_size // 2), (grid_size // 2, grid_size - 1)),
            ('up', 'right', 'left'): ((0, grid_size // 2), (grid_size // 2, grid_size - 1), (grid_size - 1, grid_size // 2)),
            ('up', 'right', 'down', 'left'): ((grid_size // 2, grid_size - 1), (grid_size - 1, grid_size // 2),
                                              (grid_size // 2, 0), (0, grid_size // 2))
        }

    def reset_reward_locations(self):
        """
        Set the true reward locations for the environment.
        Subclasses must override this to provide regime-specific settings.
        """
        raise NotImplementedError

    def check_and_activate_rewards(self, agent, pos1, pos2):
        """
        Check if an agent reaches the center to activate reward locations.
        This method may differ per regime.
        """
        # Default: do nothing.
        pass

    def compute_rewards(self, agent, pos1, pos2):
        """
        Compute and return rewards for both agents based on their positions.
        Must return a tuple: (reward1, reward2, collected) where collected is a boolean
        indicating if a target reward was collected.
        """
        collected = False
        reward1 = 0
        reward2 = 0
        if agent.rewards_active:
            coords = self.reward_place_to_coord.get(self.true_reward_locations, ())
            if coords:
                if pos1 == coords[0] and pos2 == coords[0]:
                    reward1 += self.target_reward1
                    reward2 += self.target_reward2
                    collected = True
                elif len(coords) > 1 and pos1 == coords[1] and pos2 == coords[1]:
                    reward1 += self.target_reward1
                    reward2 += self.target_reward2
                    collected = True
        return reward1, reward2, collected

    def step(self, agent, action1, action2):
        """
        Execute one time step in the environment.
        Moves the agents according to their chosen actions, computes the rewards,
        updates internal counters, and resets the environment if necessary.
        
        Returns:
            reward1, reward2: rewards for agent 1 and agent 2.
            info: (optional) additional information.
        """
        # Move the agents.
        agent.move(0, action1)
        agent.move(1, action2)
        agent.update_occupancy()
        
        # Get current positions.
        state1 = agent.get_state(0)
        state2 = agent.get_state(1)
        pos1, _ = state1
        pos2, _ = state2

        # Regime-specific activation of rewards.
        self.check_and_activate_rewards(agent, pos1, pos2)
        # Compute rewards based on current positions.
        reward1, reward2, collected = self.compute_rewards(agent, pos1, pos2)
        
        # Add a bonus if the agents are together.
        if pos1 == pos2:
            reward1 += self.together_reward
            reward2 += self.together_reward

        # Apply travel (energy loss) penalty.
        reward1 += self.travel_reward
        reward2 += self.travel_reward

        # Update step counter.
        if not collected:
            self.steps_without_reward += 1
        else:
            self.steps_without_reward = 0

        # If no rewards for a long time, reset agent positions and reward locations.
        if self.steps_without_reward > 50:
            agent.reset()
            self.steps_without_reward = 0
            self.reset_reward_locations()

        return reward1, reward2, {}

class EnvironmentRegime0(BaseEnvironment):
    """
    Regime 0:
      - Rewards are delivered immediately when ANY agent reaches the center.
      - There is no subsequent reward zone to navigate.
    """
    def reset_reward_locations(self):
        # For regime 0, reward zones are not used.
        self.true_reward_locations = None

    def check_and_activate_rewards(self, agent, pos1, pos2):
        # In regime 0, we do not need to activate a reward zone.
        pass

    def compute_rewards(self, agent, pos1, pos2):
        """
        If either agent reaches the center, deliver the target rewards immediately.
        """
        collected = False
        reward1 = 0
        reward2 = 0
        if pos1 == self.center_pos or pos2 == self.center_pos:
            reward1 = self.target_reward1
            reward2 = self.target_reward2
            collected = True
            # Clear any active reward zones.
            agent.rewards_active = ()
        return reward1, reward2, collected

class EnvironmentRegime1(BaseEnvironment):
    """
    Regime 1:
    - Reward locations are fixed to ('right', 'left').
    - Rewards are activated when an agent reaches the center (but the reward is
      only delivered when the agents later reach the designated coordinates).
    """
    def reset_reward_locations(self):
        self.true_reward_locations = ('right', 'left')

    def check_and_activate_rewards(self, agent, pos1, pos2):
        if not agent.rewards_active and (pos1 == self.center_pos or pos2 == self.center_pos):
            agent.rewards_active = self.true_reward_locations

class EnvironmentRegime2(BaseEnvironment):
    """
    Regime 2:
    - Reward locations are fixed to ('up', 'down').
    - Reward activation happens when either agent reaches the center.
    """
    def reset_reward_locations(self):
        self.true_reward_locations = ('up', 'down')

    def check_and_activate_rewards(self, agent, pos1, pos2):
        if not agent.rewards_active and (pos1 == self.center_pos or pos2 == self.center_pos):
            agent.rewards_active = self.true_reward_locations

class EnvironmentRegime3(BaseEnvironment):
    """
    Regime 3:
    - Reward locations are chosen randomly from a set of possibilities.
    - Rewards are activated upon an agent reaching the center.
    """
    def reset_reward_locations(self):
        possibilities = [
            ('up', 'right'),
            ('right', 'down'),
            ('down', 'left'),
            ('up', 'left'),
            ('right', 'left'),
            ('up', 'down')
        ]
        self.true_reward_locations = sample(possibilities, 1)[0]

    def check_and_activate_rewards(self, agent, pos1, pos2):
        if not agent.rewards_active and (pos1 == self.center_pos or pos2 == self.center_pos):
            agent.rewards_active = self.true_reward_locations


In [4]:
class Trainer:
    """
    Trainer class to handle the training loop of a JointAgent within a given environment.
    """
    def __init__(self, agent, environment, num_frames, epsilon_start, epsilon_end):
        """
        Initialize the trainer.
        
        Args:
            agent: Instance of JointAgent.
            environment: An instance of a BaseEnvironment subclass.
            num_frames: Total number of training frames (iterations).
            epsilon_start: Initial epsilon value.
            epsilon_end: Final epsilon value.
        """
        self.agent = agent
        self.env = environment
        self.num_frames = num_frames
        self.epsilon_start = epsilon_start
        self.epsilon_end = epsilon_end
        # Histories to track rewards (can be used for plotting later)
        self.rewards_history = {'agent1': [], 'agent2': []}

    def train(self):
        """Run the training loop."""
        for frame in tqdm(range(self.num_frames)):
            # Linearly update epsilon for exploration/exploitation trade-off.
            self.agent.epsilon = ((self.epsilon_end - self.epsilon_start) / self.num_frames) * frame + self.epsilon_start
            
            # Agent selects actions.
            action1, action2 = self.agent.choose_action()
            # Environment processes the step.
            reward1, reward2, info = self.env.step(self.agent, action1, action2)
            
            # If a previous state exists, update Q-values using the Bellman equation.
            current_state = (self.agent.get_state(0), self.agent.get_state(1))
            if self.agent.previous_state is not None:
                self.agent.update_q_value(self.agent.get_state(0), self.agent.get_state(1), reward1, reward2)
            self.agent.previous_state = current_state
            self.agent.previous_action = (action1, action2)
            
            # Record rewards for analysis.
            self.rewards_history['agent1'].append(reward1)
            self.rewards_history['agent2'].append(reward2)
        
        return self.rewards_history

In [5]:
if __name__ == "__main__":
    # Define training parameters
    grid_size = 5
    target_reward1 = 20
    target_reward2 = 20
    together_reward = 0
    travel_reward = -1
    num_frames = 1_000_00
    epsilon_start = 0.99
    epsilon_end = 0.4

    # Initialize the joint agent.
    agent = JointAgent(width=grid_size, height=grid_size, draw_radius=1)
    
    # Create one environment instance per regime.
    env0 = EnvironmentRegime0(grid_size, target_reward1, target_reward2, together_reward, travel_reward)
    env1 = EnvironmentRegime1(grid_size, target_reward1, target_reward2, together_reward, travel_reward)
    env2 = EnvironmentRegime2(grid_size, target_reward1, target_reward2, together_reward, travel_reward)
    env3 = EnvironmentRegime3(grid_size, target_reward1, target_reward2, together_reward, travel_reward)
    
    # List the environments and train sequentially.
    environments = [env0, env1, env2, env3]
    
    for i, env in enumerate(environments):
        print(f"--- Training in Environment Regime {i} ---")
        trainer = Trainer(agent, env, num_frames, epsilon_start, epsilon_end)
        rewards_history = trainer.train()
        # Reset reward activation for the next regime.
        agent.rewards_active = ()
        # (Optional) Save or plot rewards_history here.


--- Training in Environment Regime 0 ---


100%|██████████| 100000/100000 [00:01<00:00, 53689.38it/s]


--- Training in Environment Regime 1 ---


100%|██████████| 100000/100000 [00:01<00:00, 52899.61it/s]


--- Training in Environment Regime 2 ---


100%|██████████| 100000/100000 [00:01<00:00, 52848.09it/s]


--- Training in Environment Regime 3 ---


100%|██████████| 100000/100000 [00:01<00:00, 51704.98it/s]


In [6]:
agent.q_values1

{(((3, 0), ()), ((3, 3), ())): 151.9527366654862,
 (((3, 0), ()), ((2, 3), ())): 168.69492538147165,
 (((4, 0), ()), ((2, 4), ())): 150.87588362103986,
 (((4, 0), ()), ((1, 4), ())): 135.2449351865924,
 (((4, 1), ()), ((1, 4), ())): 134.89246212865785,
 (((4, 1), ()), ((0, 4), ())): 135.39289331490525,
 (((4, 2), ()), ((1, 4), ())): 151.77424358263673,
 (((4, 1), ()), ((2, 4), ())): 151.92590258076387,
 (((3, 1), ()), ((3, 4), ())): 151.94930255747727,
 (((2, 1), ()), ((3, 4), ())): 169.91470845721156,
 (((2, 2), ()), ((4, 4), ())): 180.93817103280813,
 (((2, 3), ()), ((4, 3), ())): 169.9985260672111,
 (((2, 4), ()), ((4, 3), ())): 151.99750049324612,
 (((1, 4), ()), ((4, 2), ())): 151.89149913022013,
 (((0, 4), ()), ((4, 1), ())): 126.11802684896608,
 (((0, 4), ()), ((4, 2), ())): 151.6394986864723,
 (((0, 3), ()), ((3, 2), ())): 169.9931379205486,
 (((0, 3), ()), ((4, 2), ())): 151.98205952275745,
 (((1, 3), ()), ((3, 2), ())): 169.06069646587866,
 (((2, 3), ()), ((4, 2), ())): 169.9

In [7]:
len(agent.q_values1.keys())

4362