# In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
import pygame
import sys
import numpy as np

# Constants
TILE_SIZE = 24  # Edge length of one grid cell, in pixels
GRID_WIDTH = 30  # Number of grid columns
GRID_HEIGHT = 20  # Number of grid rows
SCREEN_WIDTH = TILE_SIZE * GRID_WIDTH
SCREEN_HEIGHT = TILE_SIZE * GRID_HEIGHT
FPS = 10  # Frame cap handed to clock.tick() in the main loop
NUM_ACTIONS = 4  # Number of actions (up, down, left, right)

# Initialize Pygame: the window is sized to hold exactly the whole grid, and
# the shared `screen` / `clock` objects are used by the drawing and main-loop code.
pygame.init()
screen = pygame.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT))
pygame.display.set_caption("Fireboy and Watergirl Grid World")
clock = pygame.time.Clock()

# Colors: RGB triple for every tile code that can appear in the grid.
# Several codes share a colour; draw_grid tells them apart by shape.
COLORS = {
    "0": (120, 66, 18),   # Empty cell (brown); also used as BASE_COLOR
    "1": (160, 82, 45),   # Wall / platform (light brown)
    "F": (255, 0, 0),     # Fireboy agent (red)
    "W": (0, 0, 255),     # Watergirl agent (blue)
    "G": (0, 255, 0),     # Green agent
    "RP": (255, 0, 0),  # Red poison, drawn as a filled tile
    "BP": (0, 0, 255),  # Blue poison, drawn as a filled tile
    "GP": (0, 255, 0),  # Green poison, drawn as a filled tile
    "FG": (255, 0, 0),    # Fireboy goal
    "WG": (0, 0, 255),    # Watergirl goal
    "RB": (255, 0, 0),    # Red gem
    "BB": (0, 0, 255),    # Blue gem
    "YS": (255, 255, 0),  # Yellow slide
    "PS": (128, 0, 128),  # Purple slide
    "YB": (255, 255, 0),  # Yellow button
    "PB": (128, 0, 128),  # Purple button
}

# Background colour painted behind shaped tiles and used for unknown tile codes
BASE_COLOR = COLORS["0"]

# Original grid layout - THIS WILL NEVER BE MODIFIED after this section runs.
# Cells are addressed as ORIGINAL_GRID[row][col] and hold a tile code from COLORS.
ORIGINAL_GRID = [["0"] * GRID_WIDTH for _ in range(GRID_HEIGHT)]


def _fill_row(row, col_start, col_stop, tile):
    """Write `tile` into columns col_start..col_stop-1 of ORIGINAL_GRID[row]."""
    for col in range(col_start, col_stop):
        ORIGINAL_GRID[row][col] = tile


def _fill_cells(cells, tile):
    """Write `tile` into every (row, col) pair listed in `cells`."""
    for row, col in cells:
        ORIGINAL_GRID[row][col] = tile


# Agent start cells
ORIGINAL_GRID[18][1] = "F"    # Fireboy
ORIGINAL_GRID[15][1] = "W"    # Watergirl
ORIGINAL_GRID[2][19] = "G"    # Green agent

# Poison tiles
_fill_row(18, 10, 13, "RP")
_fill_row(18, 17, 20, "BP")
_fill_row(14, 18, 21, "GP")

# Gems
_fill_cells([(17, 11), (8, 10)], "RB")
_fill_cells([(17, 18), (10, 20)], "BB")

# Goals
ORIGINAL_GRID[2][26] = "FG"
ORIGINAL_GRID[2][28] = "WG"

# Slides
_fill_row(9, 1, 4, "YS")      # Yellow slides
_fill_row(6, 26, 29, "PS")    # Purple slides

# Buttons
_fill_cells([(8, 6), (12, 9)], "YB")     # Yellow buttons
_fill_cells([(8, 14), (5, 22)], "PB")    # Purple buttons

# Border walls. Sized from the grid constants rather than literal 30/20 so the
# frame always matches the dimensions ORIGINAL_GRID was allocated with.
_fill_row(0, 0, GRID_WIDTH, "1")
_fill_row(GRID_HEIGHT - 1, 0, GRID_WIDTH, "1")
for _border_row in range(GRID_HEIGHT):
    ORIGINAL_GRID[_border_row][0] = "1"
    ORIGINAL_GRID[_border_row][GRID_WIDTH - 1] = "1"

# Interior walls and platforms, listed top to bottom
_fill_row(1, 1, 3, "1")
_fill_row(3, 4, 29, "1")
_fill_row(4, 1, 3, "1")
_fill_row(5, 1, 3, "1")
_fill_row(6, 1, 26, "1")
_fill_row(7, 18, 21, "1")
_fill_row(8, 18, 21, "1")
_fill_row(9, 4, 16, "1")
_fill_row(11, 17, 22, "1")
_fill_row(10, 22, 29, "1")
ORIGINAL_GRID[10][16] = "1"
_fill_row(11, 26, 29, "1")
_fill_row(12, 26, 29, "1")
_fill_row(13, 1, 13, "1")
ORIGINAL_GRID[14][13] = "1"
_fill_row(15, 14, 27, "1")
_fill_row(16, 1, 7, "1")
_fill_row(17, 27, 29, "1")
_fill_row(18, 27, 29, "1")

# Working grid: a row-by-row copy, so edits never reach ORIGINAL_GRID
grid = [list(row) for row in ORIGINAL_GRID]

# Runtime state for slides and the buttons that drive them
yellow_slides_positions, purple_slides_positions = [], []
yellow_buttons_active = purple_buttons_active = False
slide_animation_counter = 0
SLIDE_ANIMATION_SPEED = 15  # Number of frames a slide animation is spread over

# Initialize slide positions
def initialize_slides():
    """Reset both global slide position lists from ORIGINAL_GRID.

    Each list receives the (row, col) of every matching tile ("YS" or "PS"),
    in row-major scan order.
    """
    global yellow_slides_positions, purple_slides_positions

    def cells_holding(tile_code):
        return [
            (row, col)
            for row in range(GRID_HEIGHT)
            for col in range(GRID_WIDTH)
            if ORIGINAL_GRID[row][col] == tile_code
        ]

    yellow_slides_positions = cells_holding("YS")
    purple_slides_positions = cells_holding("PS")

# Create grid maps for environment features
def create_grid_map(feature_type):
    """Return a (GRID_HEIGHT, GRID_WIDTH) boolean mask for one feature of `grid`.

    A cell is True where the working grid holds the tile code associated with
    `feature_type`. An unrecognised feature name yields an all-False mask.
    """
    tile_for_feature = {
        'wall': "1",
        'red_poison': "RP",
        'blue_poison': "BP",
        'green_poison': "GP",
        'red_gem': "RB",
        'blue_gem': "BB",
        'fire_goal': "FG",
        'water_goal': "WG",
        'yellow_button': "YB",
        'purple_button': "PB",
        'yellow_slide': "YS",
        'purple_slide': "PS",
    }

    mask = np.zeros((GRID_HEIGHT, GRID_WIDTH), dtype=bool)
    wanted = tile_for_feature.get(feature_type)
    if wanted is None:
        return mask

    for row in range(GRID_HEIGHT):
        for col in range(GRID_WIDTH):
            if grid[row][col] == wanted:
                mask[row][col] = True
    return mask

# Create all grid maps
def initialize_environment_maps():
    """Build the environment dictionary of boolean feature masks.

    Every entry is produced by create_grid_map from the current working grid;
    'poison' is the element-wise OR of the three colour-specific poison masks.
    """
    poison_by_color = {
        color: create_grid_map(f'{color}_poison')
        for color in ('red', 'blue', 'green')
    }

    return {
        'wall': create_grid_map('wall'),
        'poison': (
            poison_by_color['red']
            | poison_by_color['blue']
            | poison_by_color['green']
        ),
        'red_poison': poison_by_color['red'],
        'blue_poison': poison_by_color['blue'],
        'green_poison': poison_by_color['green'],
        # Gem masks are mutated as gems are collected; each call to
        # create_grid_map returns a new array, so these are private to this dict.
        'red_gems': create_grid_map('red_gem'),
        'blue_gems': create_grid_map('blue_gem'),
        'fire_goal': create_grid_map('fire_goal'),
        'water_goal': create_grid_map('water_goal'),
        'yellow_button': create_grid_map('yellow_button'),
        'purple_button': create_grid_map('purple_button'),
        'yellow_slide': create_grid_map('yellow_slide'),
        'purple_slide': create_grid_map('purple_slide'),
    }

# Agent position setup
def get_initial_positions():
    """Locate the three agents in the working grid.

    Returns:
        (fireboy_pos, watergirl_pos, green_pos), each a (row, col) tuple or
        None when the marker is absent. If a marker occurs more than once the
        last occurrence in row-major order is reported.
    """
    located = {"F": None, "W": None, "G": None}

    for row in range(GRID_HEIGHT):
        for col in range(GRID_WIDTH):
            marker = grid[row][col]
            if marker in located:
                located[marker] = (row, col)

    return located["F"], located["W"], located["G"]

# Draw Function
def draw_grid():
    """Paint every cell of the working grid onto `screen`.

    Agents, gems, buttons and goals are drawn as a shape over a BASE_COLOR
    background; every other tile is a solid rectangle in its own colour
    (BASE_COLOR for unknown codes). Each cell then gets a 1-pixel black outline.
    """
    half = TILE_SIZE // 2
    third = TILE_SIZE // 3
    outline = (0, 0, 0)

    for y in range(GRID_HEIGHT):
        top = y * TILE_SIZE
        for x in range(GRID_WIDTH):
            left = x * TILE_SIZE
            cell = (left, top, TILE_SIZE, TILE_SIZE)
            tile = grid[y][x]
            color = COLORS.get(tile, BASE_COLOR)

            if tile in ("F", "W", "G"):
                # Agents: a triangle plus two circles
                pygame.draw.rect(screen, BASE_COLOR, cell)
                pygame.draw.polygon(screen, color, [
                    (left + half, top + 6),
                    (left + 6, top + half),
                    (left + TILE_SIZE - 6, top + half),
                ])
                for center_x in (left + third, left + 2 * TILE_SIZE // 3):
                    pygame.draw.circle(screen, color, (center_x, top + third), 5)

            elif tile in ("RB", "BB"):
                # Gems: rhombus touching the middle of each cell edge
                pygame.draw.rect(screen, BASE_COLOR, cell)
                pygame.draw.polygon(screen, color, [
                    (left + half, top),
                    (left + TILE_SIZE, top + half),
                    (left + half, top + TILE_SIZE),
                    (left, top + half),
                ])

            elif tile in ("YB", "PB"):
                # Buttons: centred circle
                pygame.draw.rect(screen, BASE_COLOR, cell)
                pygame.draw.circle(screen, color, (left + half, top + half), TILE_SIZE // 4)

            elif tile in ("FG", "WG"):
                # Goals: upward-pointing triangle
                pygame.draw.rect(screen, BASE_COLOR, cell)
                pygame.draw.polygon(screen, color, [
                    (left + half, top),
                    (left, top + TILE_SIZE),
                    (left + TILE_SIZE, top + TILE_SIZE),
                ])

            else:
                # Slides, poison, walls, empty cells: solid fill
                pygame.draw.rect(screen, color, cell)

            pygame.draw.rect(screen, outline, cell, 1)

# Policy network class; one separate instance is created per agent below
class PolicyNetwork(nn.Module):
    """Two-layer fully connected policy: state vector in, action probabilities out."""

    def __init__(self, input_size, output_size):
        """Create a 128-unit hidden layer followed by the output layer."""
        super().__init__()
        hidden_units = 128
        self.fc1 = nn.Linear(input_size, hidden_units)
        self.fc2 = nn.Linear(hidden_units, output_size)

    def forward(self, state):
        """Return a softmax distribution over the last dimension of the output."""
        hidden = self.fc1(state).relu()
        logits = self.fc2(hidden)
        return logits.softmax(dim=-1)

# One policy network per agent (Fireboy, Watergirl, green agent), created in
# that order. input_size=10 is the length of the state vector assembled by
# get_state_representation.
policy_f, policy_w, policy_g = (
    PolicyNetwork(input_size=10, output_size=NUM_ACTIONS) for _ in range(3)
)

# One Adam optimizer per policy
optimizer_f, optimizer_w, optimizer_g = (
    optim.Adam(network.parameters(), lr=0.001)
    for network in (policy_f, policy_w, policy_g)
)

# Initialize agent state
def reset_environment():
    """Start a fresh episode.

    Restores the working grid from ORIGINAL_GRID, clears the button flags and
    the slide animation counter, re-reads the slide positions, and builds new
    agent dictionaries and environment masks.

    Returns:
        (red_agent, blue_agent, green_agent, environment)
    """
    global grid, yellow_buttons_active, purple_buttons_active, slide_animation_counter

    grid = [list(row) for row in ORIGINAL_GRID]
    yellow_buttons_active = purple_buttons_active = False
    slide_animation_counter = 0

    initialize_slides()

    def new_agent(start_cell, color):
        # All agents share the same bookkeeping fields and differ only in
        # start position and colour.
        return {
            'pos': start_cell,
            'color': color,
            'is_pressing_button': False,
            'goal_reached': False,
            'collected_gems': 0,
            'died': False,
        }

    # get_initial_positions returns (fireboy, watergirl, green) in that order
    red_agent, blue_agent, green_agent = (
        new_agent(start_cell, color)
        for start_cell, color in zip(get_initial_positions(), ('red', 'blue', 'green'))
    )

    # Masks are derived from the freshly restored grid
    environment = initialize_environment_maps()

    return red_agent, blue_agent, green_agent, environment

# Handle slide movement
SLIDE_TRAVEL_ROWS = 3  # Rows a slide descends while its button is held


def _advance_slide_group(tile, positions, button_active):
    """Move one colour of slides a frame closer to its goal row.

    The goal row is always measured from the slide's home row in ORIGINAL_GRID
    (home + SLIDE_TRAVEL_ROWS while the button is active, otherwise home), not
    from the position stored on the previous frame. This keeps a held button
    from pushing the slide further down every frame and lets the slide return
    home after release.

    Assumes at most one `tile` cell per column in ORIGINAL_GRID, which is how
    home rows are matched to current positions.

    Args:
        tile: slide tile code, "YS" or "PS".
        positions: current (row, col) of every slide of this colour.
        button_active: whether the matching button is currently pressed.

    Returns:
        The new list of (row, col) positions; `grid` is updated to match.
    """
    home_row_by_col = {
        col: row
        for row in range(GRID_HEIGHT)
        for col in range(GRID_WIDTH)
        if ORIGINAL_GRID[row][col] == tile
    }
    animating = slide_animation_counter < SLIDE_ANIMATION_SPEED
    progress = slide_animation_counter / SLIDE_ANIMATION_SPEED

    advanced = []
    for row, col in positions:
        home_row = home_row_by_col.get(col, row)
        if button_active:
            goal_row = min(home_row + SLIDE_TRAVEL_ROWS, GRID_HEIGHT - 1)
        else:
            goal_row = home_row

        if animating:
            # Close a `progress` share of the remaining gap; once the counter
            # reaches SLIDE_ANIMATION_SPEED the slide snaps onto the goal row.
            next_row = int(row + (goal_row - row) * progress)
        else:
            next_row = goal_row

        advanced.append((next_row, col))
        grid[next_row][col] = tile
    return advanced


def update_slides():
    """Advance the yellow and purple slides by one animation frame.

    Clears every slide tile from the working grid, re-places each slide at its
    new row, stores the new positions in the global position lists, and bumps
    slide_animation_counter until it reaches SLIDE_ANIMATION_SPEED.
    """
    global yellow_slides_positions, purple_slides_positions, slide_animation_counter

    # Remove last frame's slide tiles before drawing the new ones
    for grid_row in grid:
        for col, tile in enumerate(grid_row):
            if tile in ("YS", "PS"):
                grid_row[col] = "0"

    yellow_slides_positions = _advance_slide_group(
        "YS", yellow_slides_positions, yellow_buttons_active
    )
    purple_slides_positions = _advance_slide_group(
        "PS", purple_slides_positions, purple_buttons_active
    )

    if slide_animation_counter < SLIDE_ANIMATION_SPEED:
        slide_animation_counter += 1

# Get state representation for policy networks
def get_state_representation(agent, other_agent, environment):
    """Encode one agent's view of the game as a 10-element float tensor.

    Layout: own normalised row and column, offset to the other agent
    (normalised), both agents' button flags, both agents' goal flags, and both
    agents' gem counts divided by 2. `environment` is accepted but not read.

    Returns:
        A 1-D torch.float tensor created with requires_grad=True.
    """
    own_row, own_col = agent['pos']
    other_row, other_col = other_agent['pos']

    def as_flag(value):
        return 1.0 if value else 0.0

    features = [
        own_row / GRID_HEIGHT,
        own_col / GRID_WIDTH,
        (other_row - own_row) / GRID_HEIGHT,
        (other_col - own_col) / GRID_WIDTH,
    ]
    features.extend([
        as_flag(agent['is_pressing_button']),
        as_flag(other_agent['is_pressing_button']),
        as_flag(agent['goal_reached']),
        as_flag(other_agent['goal_reached']),
    ])
    # Gem counts are scaled by an assumed maximum of 2 per agent
    features.extend([
        agent['collected_gems'] / 2.0,
        other_agent['collected_gems'] / 2.0,
    ])

    return torch.tensor(features, dtype=torch.float, requires_grad=True)

# Check if agent is on a button and update button state
def check_button_press(agent, environment):
    """Update the agent's button flag and the global button/animation state.

    A red agent presses yellow buttons and a blue agent presses purple
    buttons; any other combination counts as not pressing. When a red or blue
    agent that was pressing is no longer pressing, its button colour is
    deactivated.

    NOTE(review): slide_animation_counter is set back to 0 on every call made
    while the agent stands on its button, not only on the first one. Since
    update_slides uses that counter as animation progress, confirm this is
    intended.
    """
    global yellow_buttons_active, purple_buttons_active, slide_animation_counter

    row, col = agent['pos']
    held_before = agent['is_pressing_button']
    color = agent['color']

    on_yellow = color == 'red' and environment['yellow_button'][row][col]
    on_purple = (
        not on_yellow
        and color == 'blue'
        and environment['purple_button'][row][col]
    )

    if on_yellow:
        yellow_buttons_active = True
    elif on_purple:
        purple_buttons_active = True

    pressing_now = bool(on_yellow or on_purple)
    agent['is_pressing_button'] = pressing_now

    if pressing_now:
        slide_animation_counter = 0  # Restart the slide animation
        return

    # Agent stepped off its button: release that colour
    if held_before:
        if color == 'red':
            yellow_buttons_active = False
            slide_animation_counter = 0
        elif color == 'blue':
            purple_buttons_active = False
            slide_animation_counter = 0

# Check if agent has collected a gem - now color-specific and gems disappear after collection
def check_gem_collection(agent, environment):
    """Collect the gem under the agent if it matches the agent's colour.

    Red agents take red gems and blue agents take blue gems. On collection the
    gem is removed from the environment mask and from the working grid, and
    the agent's gem count goes up by one.

    Returns:
        True when a gem was collected, False otherwise.
    """
    row, col = agent['pos']
    gem_layer = {'red': 'red_gems', 'blue': 'blue_gems'}.get(agent['color'])

    if gem_layer is None or not environment[gem_layer][row][col]:
        return False

    environment[gem_layer][row][col] = False  # Gem is gone for later checks
    grid[row][col] = "0"  # And no longer drawn
    agent['collected_gems'] += 1
    return True

# Check if agent has reached their goal
def check_goal_reached(agent, environment):
    """Mark the agent as finished when it stands on its own goal.

    Red agents use the 'fire_goal' mask and blue agents the 'water_goal' mask;
    agents of any other colour never reach a goal.

    Returns:
        True (and sets agent['goal_reached']) when on the goal, else False.
    """
    row, col = agent['pos']
    goal_layer = {'red': 'fire_goal', 'blue': 'water_goal'}.get(agent['color'])

    if goal_layer is None or not environment[goal_layer][row][col]:
        return False

    agent['goal_reached'] = True
    return True

# Check if agent is in poison - now with specific fatal poison types!
def check_poison(agent, environment):
    """Kill the agent if it stands on a poison tile that is lethal to it.

    Blue poison kills red agents, red poison kills blue agents, and green
    poison kills any agent.

    Returns:
        True (and sets agent['died']) when the agent is poisoned, else False.
    """
    row, col = agent['pos']
    color = agent['color']

    lethal = (
        (color == 'red' and environment['blue_poison'][row][col])
        or (color == 'blue' and environment['red_poison'][row][col])
        or environment['green_poison'][row][col]
    )
    if not lethal:
        return False

    agent['died'] = True
    return True

# Compute the per-step reward for one agent (also applies the step's game effects)
def calculate_reward(agent, other_agent, green_agent, environment, did_move):
    """Score one agent's step and apply the step's side effects.

    Runs, in order, the poison, button, gem and goal checks for `agent`
    (each of which may mutate the agent, the environment or global state) and
    sums the rewards: -10 poison, +5 when both `agent` and `other_agent` are
    pressing buttons, +3 gem, +10 goal, +0.1 for having moved, plus a small
    bonus that grows as the agent gets closer to its goal. Standing inside a
    wall short-circuits everything with -5. `green_agent` is not read.

    Returns:
        The numeric reward for this step.
    """
    row, col = agent['pos']

    # Should not happen with apply_action's wall checks; penalise and bail out
    if environment['wall'][row][col]:
        return -5

    total = 0

    if check_poison(agent, environment):
        total -= 10

    check_button_press(agent, environment)
    both_pressing = agent['is_pressing_button'] and other_agent['is_pressing_button']
    if both_pressing:
        total += 5  # Cooperation bonus

    if check_gem_collection(agent, environment):
        total += 3

    if check_goal_reached(agent, environment):
        total += 10

    if did_move:
        total += 0.1  # Nudge towards exploration

    # Proximity bonus: red aims for the fire goal, every other colour for the
    # water goal. Only the first goal cell found is considered.
    goal_layer = 'fire_goal' if agent['color'] == 'red' else 'water_goal'
    goal_cells = np.where(environment[goal_layer])

    if len(goal_cells[0]) > 0:
        goal_row, goal_col = goal_cells[0][0], goal_cells[1][0]
        separation = np.sqrt((row - goal_row)**2 + (col - goal_col)**2)
        total += 0.1 * (1.0 / (separation + 1.0))

    return total

# Action names; select_action maps a sampled policy output index to
# actions[index], so the order here defines what each output index means.
actions = ["up", "down", "left", "right"]

# Helper function to apply actions and update positions
def apply_action(agent, action, environment):
    """Move the agent one cell in the direction named by `action`, if allowed.

    The move is applied only when it stays inside the grid and the destination
    is not a wall; unknown action names leave the agent in place.
    agent['pos'] is always rewritten.

    Returns:
        True when the agent changed cell, False otherwise.
    """
    row, col = agent['pos']

    # action -> (row step, col step, destination is inside the grid)
    moves = {
        "up": (-1, 0, row > 0),
        "down": (1, 0, row < GRID_HEIGHT - 1),
        "left": (0, -1, col > 0),
        "right": (0, 1, col < GRID_WIDTH - 1),
    }

    did_move = False
    step = moves.get(action)
    if step is not None:
        d_row, d_col, inside = step
        if inside and not environment['wall'][row + d_row][col + d_col]:
            row, col = row + d_row, col + d_col
            did_move = True

    agent['pos'] = (row, col)
    return did_move

# Update grid based on agent positions
def update_grid(red_agent, blue_agent, green_agent):
    # First reset all agent positions in the grid
    for y in range(GRID_HEIGHT):
        for x in range(GRID_WIDTH):
            if grid[y][x] in ["F", "W", "G"]:
                # Check if there's something from the original grid to restore
                if ORIGINAL_GRID[y][x] not in ["F", "W", "G", "RB", "BB"]:  # Excluding gems
                    grid[y][x] = ORIGINAL_GRID[y][x]
                else:
                    grid[y][x] = "0"
    
    # Set new agent positions
    ry, rx = red_agent['pos']
    by, bx = blue_agent['pos']
    gy, gx = green_agent['pos']
    
    # Only place agents if they haven't died
    if not red_agent['died']:
        grid[ry][rx] = "F"
    if not blue_agent['died']:
        grid[by][bx] = "W"
    grid[gy][gx] = "G"

# Select action based on policy network
def select_action(policy, state):
    """Sample an action from the policy's output distribution.

    Gradients are deliberately left enabled so the returned probability can
    be used in the policy-gradient loss.

    Returns:
        (action name from `actions`, probability tensor of the sampled action)
    """
    action_probs = policy(state)
    sampled = torch.multinomial(action_probs, 1)
    choice = sampled.item()
    return actions[choice], action_probs[choice]

# Training loop for a single step
def train_agents(red_agent, blue_agent, green_agent, environment):
    """Play one step for all three agents and update their policies.

    Each agent samples an action from its own policy, the moves are applied,
    the grid is redrawn, rewards are computed, and each policy takes one
    optimizer step on the loss -log(prob of sampled action) * reward.

    Returns:
        (red_reward, blue_reward, green_reward)
    """
    red_state = get_state_representation(red_agent, blue_agent, environment)
    blue_state = get_state_representation(blue_agent, red_agent, environment)
    # The green agent has no state of its own: it sees the mean of the other two
    green_state = (red_state + blue_state) / 2

    agents = (red_agent, blue_agent, green_agent)

    # Sample in red, blue, green order
    picks = [
        select_action(policy, state)
        for policy, state in zip(
            (policy_f, policy_w, policy_g),
            (red_state, blue_state, green_state),
        )
    ]

    # Move every agent before any reward is computed
    moved = [
        apply_action(agent, action, environment)
        for agent, (action, _) in zip(agents, picks)
    ]

    update_grid(red_agent, blue_agent, green_agent)

    # Second argument is the agent treated as the "other" one for that reward
    rewards = [
        calculate_reward(red_agent, blue_agent, green_agent, environment, moved[0]),
        calculate_reward(blue_agent, red_agent, green_agent, environment, moved[1]),
        calculate_reward(green_agent, red_agent, blue_agent, environment, moved[2]),
    ]

    # Policy-gradient loss per agent
    losses = [
        -torch.log(prob) * torch.tensor([reward], dtype=torch.float)
        for (_, prob), reward in zip(picks, rewards)
    ]

    for optimizer, loss in zip((optimizer_f, optimizer_w, optimizer_g), losses):
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return tuple(rewards)

# Main game loop
def main():
    """Run the interactive training loop until the window is closed.

    Every frame: handle events ('R' restarts the episode), restart if the
    episode is over, take one training step, advance the slides, redraw, and
    wait for the frame cap. Ends by shutting pygame down and calling
    sys.exit().
    """
    initialize_slides()
    red_agent, blue_agent, green_agent, environment = reset_environment()

    episode_rewards_f, episode_rewards_w, episode_rewards_g = [], [], []
    steps = 0
    max_steps = 1000  # Maximum steps per episode

    font = pygame.font.Font(None, 24)
    text_color = (255, 255, 255)

    running = True
    while running:
        restart = False
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                running = False
            elif event.type == pygame.KEYDOWN and event.key == pygame.K_r:
                restart = True  # Manual reset on 'R'

        # An episode ends when both players are home, either one has died,
        # or the step budget is used up.
        episode_done = (
            (red_agent['goal_reached'] and blue_agent['goal_reached'])
            or red_agent['died']
            or blue_agent['died']
            or steps >= max_steps
        )

        if restart or episode_done:
            red_agent, blue_agent, green_agent, environment = reset_environment()
            steps = 0
            episode_rewards_f, episode_rewards_w, episode_rewards_g = [], [], []

        # One training step for all agents
        red_reward, blue_reward, green_reward = train_agents(
            red_agent, blue_agent, green_agent, environment
        )
        episode_rewards_f.append(red_reward)
        episode_rewards_w.append(blue_reward)
        episode_rewards_g.append(green_reward)

        update_slides()

        # Redraw the frame
        screen.fill((0, 0, 0))
        draw_grid()

        stats_text = [f"Step: {steps}"]
        for line_no, text in enumerate(stats_text):
            screen.blit(font.render(text, True, text_color), (10, 10 + 25 * line_no))

        pygame.display.flip()

        steps += 1
        clock.tick(FPS)

    pygame.quit()
    sys.exit()

# Start the game/training loop only when this file is run as a script
if __name__ == "__main__":
    main()

# Output: SystemExit (raised by the sys.exit() call at the end of main())