In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import random
import pygame
import sys
import numpy as np
from collections import deque

# Constants
TILE_SIZE = 24
GRID_WIDTH = 30
GRID_HEIGHT = 20
SCREEN_WIDTH = TILE_SIZE * GRID_WIDTH
SCREEN_HEIGHT = TILE_SIZE * GRID_HEIGHT
FPS = 10  # Increased FPS for smoother animation
NUM_ACTIONS = 4  # Number of actions (up, down, left, right)

# Initialize Pygame
pygame.init()
screen = pygame.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT))
pygame.display.set_caption("Fireboy and Watergirl Grid World")
clock = pygame.time.Clock()

# Colors
COLORS = {
    "0": (120, 66, 18),   # Brown - Empty
    "1": (160, 82, 45),   # Light Brown - Platform
    "F": (255, 0, 0),     # Fireboy - Red Heart
    "W": (0, 0, 255),     # Watergirl - Blue Heart
    "G": (0, 255, 0),     # Green Agent
    "RP": (255, 0, 0),  # Red Poison - Now a block
    "BP": (0, 0, 255),  # Blue Poison - Now a block
    "GP": (0, 255, 0),  # Green Poison - Now a block
    "FG": (255, 0, 0),    # Fireboy Goal
    "WG": (0, 0, 255),    # Watergirl Goal
    "RB": (255, 0, 0),    # Red Gem
    "BB": (0, 0, 255),    # Blue Gem
    "YS": (255, 255, 0),  # Yellow Slide
    "PS": (128, 0, 128),  # Purple Slide
    "YB": (255, 255, 0),  # Yellow Button
    "PB": (128, 0, 128),  # Purple Button
}

BASE_COLOR = COLORS["0"]

# Original grid layout - THIS WILL NEVER BE MODIFIED
ORIGINAL_GRID = [["0" for _ in range(GRID_WIDTH)] for _ in range(GRID_HEIGHT)]

# Place sample tiles
ORIGINAL_GRID[18][1] = "F"    # Fireboy
ORIGINAL_GRID[15][1] = "W"    # Watergirl
ORIGINAL_GRID[2][19] = "G"    # Green agent

# Poison tiles
for x in range(10, 13):
    ORIGINAL_GRID[18][x] = "RP"

for x in range(17, 20):
    ORIGINAL_GRID[18][x] = "BP"

for x in range(18, 21):
    ORIGINAL_GRID[14][x] = "GP"

# Gems
ORIGINAL_GRID[17][11] = ORIGINAL_GRID[8][10] = "RB"
ORIGINAL_GRID[17][18] = ORIGINAL_GRID[10][20] = "BB"

# Goals
ORIGINAL_GRID[2][26] = "FG"
ORIGINAL_GRID[2][28] = "WG"

# Slides
for x in range(1, 4):
    ORIGINAL_GRID[9][x] = "YS"  # Yellow slides

for x in range(26, 29):
    ORIGINAL_GRID[6][x] = "PS"  # Purple slides

# Buttons
ORIGINAL_GRID[8][6] = ORIGINAL_GRID[12][9] = "YB"  # Yellow buttons
ORIGINAL_GRID[8][14] = ORIGINAL_GRID[5][22] = "PB"  # Purple buttons

# Assigning 1's for walls and platforms
for i in range(0, 30):
    ORIGINAL_GRID[0][i] = "1"
    ORIGINAL_GRID[19][i] = "1"

for i in range(0, 20):
    ORIGINAL_GRID[i][0] = "1"
    ORIGINAL_GRID[i][29] = "1"

ORIGINAL_GRID[1][1] = ORIGINAL_GRID[1][2] = "1"
for i in range(4, 29):
    ORIGINAL_GRID[3][i] = "1"

ORIGINAL_GRID[4][1] = ORIGINAL_GRID[4][2] = ORIGINAL_GRID[5][1] = ORIGINAL_GRID[5][2] = "1"

for i in range(1, 26):
    ORIGINAL_GRID[6][i] = "1"

ORIGINAL_GRID[7][18] = ORIGINAL_GRID[7][19] = ORIGINAL_GRID[7][20] = "1"
ORIGINAL_GRID[8][18] = ORIGINAL_GRID[8][19] = ORIGINAL_GRID[8][20] = "1"

for i in range(4, 16):
    ORIGINAL_GRID[9][i] = "1"

for i in range(17, 22):
    ORIGINAL_GRID[11][i] = "1"

for i in range(22, 29):
    ORIGINAL_GRID[10][i] = "1"

ORIGINAL_GRID[10][16] = "1"
ORIGINAL_GRID[11][26] = ORIGINAL_GRID[11][27] = ORIGINAL_GRID[11][28] = "1"
ORIGINAL_GRID[12][26] = ORIGINAL_GRID[12][27] = ORIGINAL_GRID[12][28] = "1"

for i in range(1, 13):
    ORIGINAL_GRID[13][i] = "1"

ORIGINAL_GRID[14][13] = "1"

for i in range(14, 27):
    ORIGINAL_GRID[15][i] = "1"

for i in range(1, 7):
    ORIGINAL_GRID[16][i] = "1"

ORIGINAL_GRID[17][27] = ORIGINAL_GRID[17][28] = ORIGINAL_GRID[18][27] = ORIGINAL_GRID[18][28] = "1"

# Copy original grid to working grid
grid = [[cell for cell in row] for row in ORIGINAL_GRID]

# Track slide positions and active buttons
yellow_slides_positions = []
purple_slides_positions = []
yellow_buttons_active = False
purple_buttons_active = False
slide_animation_counter = 0
SLIDE_ANIMATION_SPEED = 15  # Controls how fast slides move

# Initialize slide positions
def initialize_slides():
    global yellow_slides_positions, purple_slides_positions
    yellow_slides_positions = []
    purple_slides_positions = []
    
    for y in range(GRID_HEIGHT):
        for x in range(GRID_WIDTH):
            if ORIGINAL_GRID[y][x] == "YS":
                yellow_slides_positions.append((y, x))
            elif ORIGINAL_GRID[y][x] == "PS":
                purple_slides_positions.append((y, x))

# Create grid maps for environment features
def create_grid_map(feature_type):
    """Create a boolean grid map for different environmental features"""
    grid_map = np.zeros((GRID_HEIGHT, GRID_WIDTH), dtype=bool)
    
    for y in range(GRID_HEIGHT):
        for x in range(GRID_WIDTH):
            cell = grid[y][x]
            if feature_type == 'wall' and cell == "1":
                grid_map[y][x] = True
            elif feature_type == 'red_poison' and cell == "RP":
                grid_map[y][x] = True
            elif feature_type == 'blue_poison' and cell == "BP":
                grid_map[y][x] = True
            elif feature_type == 'green_poison' and cell == "GP":
                grid_map[y][x] = True
            elif feature_type == 'red_gem' and cell == "RB":
                grid_map[y][x] = True
            elif feature_type == 'blue_gem' and cell == "BB":
                grid_map[y][x] = True
            elif feature_type == 'fire_goal' and cell == "FG":
                grid_map[y][x] = True
            elif feature_type == 'water_goal' and cell == "WG":
                grid_map[y][x] = True
            elif feature_type == 'yellow_button' and cell == "YB":
                grid_map[y][x] = True
            elif feature_type == 'purple_button' and cell == "PB":
                grid_map[y][x] = True
            elif feature_type == 'yellow_slide' and cell == "YS":
                grid_map[y][x] = True
            elif feature_type == 'purple_slide' and cell == "PS":
                grid_map[y][x] = True
    
    return grid_map

# Create all grid maps
def initialize_environment_maps():
    wall_map = create_grid_map('wall')
    red_poison_map = create_grid_map('red_poison')
    blue_poison_map = create_grid_map('blue_poison')
    green_poison_map = create_grid_map('green_poison')
    red_gem_map = create_grid_map('red_gem')
    blue_gem_map = create_grid_map('blue_gem')
    fire_goal_map = create_grid_map('fire_goal')
    water_goal_map = create_grid_map('water_goal')
    yellow_button_map = create_grid_map('yellow_button')
    purple_button_map = create_grid_map('purple_button')
    yellow_slide_map = create_grid_map('yellow_slide')
    purple_slide_map = create_grid_map('purple_slide')
    
    # Create combined maps
    poison_map = red_poison_map | blue_poison_map | green_poison_map
    
    # Create environment dictionary
    environment = {
        'wall': wall_map,
        'poison': poison_map,
        'red_poison': red_poison_map,
        'blue_poison': blue_poison_map,
        'green_poison': green_poison_map,
        'red_gems': red_gem_map.copy(),
        'blue_gems': blue_gem_map.copy(),
        'fire_goal': fire_goal_map,
        'water_goal': water_goal_map,
        'yellow_button': yellow_button_map,
        'purple_button': purple_button_map,
        'yellow_slide': yellow_slide_map,
        'purple_slide': purple_slide_map
    }
    
    return environment

# Agent position setup
def get_initial_positions():
    fireboy_pos = None
    watergirl_pos = None
    green_pos = None
    
    for y in range(GRID_HEIGHT):
        for x in range(GRID_WIDTH):
            if grid[y][x] == "F":
                fireboy_pos = (y, x)
            elif grid[y][x] == "W":
                watergirl_pos = (y, x)
            elif grid[y][x] == "G":
                green_pos = (y, x)
    
    return fireboy_pos, watergirl_pos, green_pos

# Draw Function
def draw_grid():
    for y in range(GRID_HEIGHT):
        for x in range(GRID_WIDTH):
            tile = grid[y][x]
            color = COLORS.get(tile, BASE_COLOR)

            # Agents: Heart shape
            if tile in ["F", "W", "G"]:
                pygame.draw.rect(screen, BASE_COLOR, (x*TILE_SIZE, y*TILE_SIZE, TILE_SIZE, TILE_SIZE))
                pygame.draw.polygon(screen, color, [
                    (x*TILE_SIZE + TILE_SIZE//2, y*TILE_SIZE + 6),
                    (x*TILE_SIZE + 6, y*TILE_SIZE + TILE_SIZE//2),
                    (x*TILE_SIZE + TILE_SIZE - 6, y*TILE_SIZE + TILE_SIZE//2),
                ])
                pygame.draw.circle(screen, color, (x*TILE_SIZE + TILE_SIZE//3, y*TILE_SIZE + TILE_SIZE//3), 5)
                pygame.draw.circle(screen, color, (x*TILE_SIZE + 2*TILE_SIZE//3, y*TILE_SIZE + TILE_SIZE//3), 5)

            # Gems: Rhombus
            elif tile in ["RB", "BB"]:
                pygame.draw.rect(screen, BASE_COLOR, (x*TILE_SIZE, y*TILE_SIZE, TILE_SIZE, TILE_SIZE))
                pygame.draw.polygon(screen, color, [
                    (x*TILE_SIZE + TILE_SIZE//2, y*TILE_SIZE),
                    (x*TILE_SIZE + TILE_SIZE, y*TILE_SIZE + TILE_SIZE//2),
                    (x*TILE_SIZE + TILE_SIZE//2, y*TILE_SIZE + TILE_SIZE),
                    (x*TILE_SIZE, y*TILE_SIZE + TILE_SIZE//2)
                ])

            # Buttons: Circle
            elif tile in ["YB", "PB"]:
                pygame.draw.rect(screen, BASE_COLOR, (x*TILE_SIZE, y*TILE_SIZE, TILE_SIZE, TILE_SIZE))
                pygame.draw.circle(screen, color, (x*TILE_SIZE + TILE_SIZE//2, y*TILE_SIZE + TILE_SIZE//2), TILE_SIZE//4)

            # Goals: Triangle
            elif tile in ["FG", "WG"]:
                pygame.draw.rect(screen, BASE_COLOR, (x*TILE_SIZE, y*TILE_SIZE, TILE_SIZE, TILE_SIZE))
                pygame.draw.polygon(screen, color, [
                    (x*TILE_SIZE + TILE_SIZE//2, y*TILE_SIZE),
                    (x*TILE_SIZE, y*TILE_SIZE + TILE_SIZE),
                    (x*TILE_SIZE + TILE_SIZE, y*TILE_SIZE + TILE_SIZE)
                ])

            # Slides: Full tile fill
            elif tile in ["YS", "PS"]:
                pygame.draw.rect(screen, color, (x*TILE_SIZE, y*TILE_SIZE, TILE_SIZE, TILE_SIZE))
                
            # Poison: Now blocks with specific colors
            elif tile in ["RP", "BP", "GP"]:
                pygame.draw.rect(screen, color, (x*TILE_SIZE, y*TILE_SIZE, TILE_SIZE, TILE_SIZE))

            # Default or walls
            else:
                pygame.draw.rect(screen, color, (x*TILE_SIZE, y*TILE_SIZE, TILE_SIZE, TILE_SIZE))

            # Draw the black border
            pygame.draw.rect(screen, (0, 0, 0), (x*TILE_SIZE, y*TILE_SIZE, TILE_SIZE, TILE_SIZE), 1)

# Define the policy network for cooperative agents
class PolicyNetwork(nn.Module):
    def __init__(self, input_size, output_size):
        super(PolicyNetwork, self).__init__()
        self.fc1 = nn.Linear(input_size, 128)
        self.fc2 = nn.Linear(128, output_size)
    
    def forward(self, state):
        x = torch.relu(self.fc1(state))
        return torch.softmax(self.fc2(x), dim=-1)

# Define the DQN for the green agent
class DQN(nn.Module):
    def __init__(self, input_size, output_size):
        super(DQN, self).__init__()
        self.fc1 = nn.Linear(input_size, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, output_size)
    
    def forward(self, state):
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        return self.fc3(x)  # Q-values (not softmax)

# Initialize policies for cooperative agents (Fireboy and Watergirl)
policy_f = PolicyNetwork(input_size=10, output_size=NUM_ACTIONS)
policy_w = PolicyNetwork(input_size=10, output_size=NUM_ACTIONS)

# Initialize DQN for the green adversarial agent with larger input size to include more opponent information
# Increased input size for more complete state representation
dqn = DQN(input_size=20, output_size=NUM_ACTIONS)
target_dqn = DQN(input_size=20, output_size=NUM_ACTIONS)  # Target network for stable learning
target_dqn.load_state_dict(dqn.state_dict())  # Initially the same as the online network

# Optimizers
optimizer_f = optim.Adam(policy_f.parameters(), lr=0.001)
optimizer_w = optim.Adam(policy_w.parameters(), lr=0.001)
optimizer_dqn = optim.Adam(dqn.parameters(), lr=0.0005)  # Lower learning rate for stability

# DQN hyperparameters
BATCH_SIZE = 64
GAMMA = 0.99  # Discount factor
EPSILON_START = 1.0
EPSILON_END = 0.1
EPSILON_DECAY = 10000
TARGET_UPDATE = 10  # How often to update target network
epsilon = EPSILON_START

# Experience replay memory
class ReplayMemory:
    def __init__(self, capacity):
        self.memory = deque(maxlen=capacity)
    
    def push(self, state, action, next_state, reward, done):
        self.memory.append((state, action, next_state, reward, done))
    
    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)
    
    def __len__(self):
        return len(self.memory)

# Initialize replay memory
memory = ReplayMemory(10000)

# Initialize agent state
def reset_environment():
    # Reset grid to original state first
    global grid, yellow_buttons_active, purple_buttons_active, slide_animation_counter
    grid = [[cell for cell in row] for row in ORIGINAL_GRID]
    yellow_buttons_active = False
    purple_buttons_active = False
    slide_animation_counter = 0
    
    # Initialize slides
    initialize_slides()
    
    # Get fresh positions from the reset grid
    fireboy_pos, watergirl_pos, green_pos = get_initial_positions()
    
    # Create agent dictionaries with position and other properties
    red_agent = {'pos': fireboy_pos, 'color': 'red', 'is_pressing_button': False, 'goal_reached': False, 'collected_gems': 0, 'died': False}
    blue_agent = {'pos': watergirl_pos, 'color': 'blue', 'is_pressing_button': False, 'goal_reached': False, 'collected_gems': 0, 'died': False}
    green_agent = {'pos': green_pos, 'color': 'green', 'is_pressing_button': False, 'goal_reached': False, 'collected_gems': 0, 'died': False}
    
    # Create fresh environment maps from the reset grid
    environment = initialize_environment_maps()
    
    return red_agent, blue_agent, green_agent, environment

# Handle slide movement
def update_slides():
    global yellow_slides_positions, purple_slides_positions, grid, slide_animation_counter
    
    # Clear previous slide positions from grid
    for y in range(GRID_HEIGHT):
        for x in range(GRID_WIDTH):
            if grid[y][x] == "YS" or grid[y][x] == "PS":
                grid[y][x] = "0"
    
    # Update yellow slides
    new_yellow_positions = []
    for y, x in yellow_slides_positions:
        # If button is pressed, move down 3 rows, otherwise stay in original position
        if yellow_buttons_active:
            target_y = min(y + 3, GRID_HEIGHT - 1)
            # Check if we're in animation
            if slide_animation_counter < SLIDE_ANIMATION_SPEED:
                # Calculate intermediate position
                animation_progress = slide_animation_counter / SLIDE_ANIMATION_SPEED
                current_y = int(y + (target_y - y) * animation_progress)
                new_yellow_positions.append((current_y, x))
                grid[current_y][x] = "YS"
            else:
                new_yellow_positions.append((target_y, x))
                grid[target_y][x] = "YS"
        else:
            # If buttons not active, return to original position
            if slide_animation_counter < SLIDE_ANIMATION_SPEED:
                original_y = [pos[0] for pos in yellow_slides_positions if pos[1] == x][0]
                animation_progress = slide_animation_counter / SLIDE_ANIMATION_SPEED
                # Find the current position in the animation
                if any(pos[1] == x for pos in yellow_slides_positions):
                    current_pos = [pos for pos in yellow_slides_positions if pos[1] == x][0]
                    current_y = int(current_pos[0] - (current_pos[0] - original_y) * animation_progress)
                    new_yellow_positions.append((current_y, x))
                    grid[current_y][x] = "YS"
            else:
                new_yellow_positions.append((y, x))
                grid[y][x] = "YS"
    
    # Update purple slides
    new_purple_positions = []
    for y, x in purple_slides_positions:
        # If button is pressed, move down 3 rows, otherwise stay in original position
        if purple_buttons_active:
            target_y = min(y + 3, GRID_HEIGHT - 1)
            # Check if we're in animation
            if slide_animation_counter < SLIDE_ANIMATION_SPEED:
                # Calculate intermediate position
                animation_progress = slide_animation_counter / SLIDE_ANIMATION_SPEED
                current_y = int(y + (target_y - y) * animation_progress)
                new_purple_positions.append((current_y, x))
                grid[current_y][x] = "PS"
            else:
                new_purple_positions.append((target_y, x))
                grid[target_y][x] = "PS"
        else:
            # If buttons not active, return to original position
            if slide_animation_counter < SLIDE_ANIMATION_SPEED:
                original_y = [pos[0] for pos in purple_slides_positions if pos[1] == x][0]
                animation_progress = slide_animation_counter / SLIDE_ANIMATION_SPEED
                # Find the current position in the animation
                if any(pos[1] == x for pos in purple_slides_positions):
                    current_pos = [pos for pos in purple_slides_positions if pos[1] == x][0]
                    current_y = int(current_pos[0] - (current_pos[0] - original_y) * animation_progress)
                    new_purple_positions.append((current_y, x))
                    grid[current_y][x] = "PS"
            else:
                new_purple_positions.append((y, x))
                grid[y][x] = "PS"
    
    # Update slide positions for next frame
    yellow_slides_positions = new_yellow_positions
    purple_slides_positions = new_purple_positions
    
    # Update animation counter
    if slide_animation_counter < SLIDE_ANIMATION_SPEED:
        slide_animation_counter += 1

# Get state representation for policy networks (cooperative agents)
def get_state_representation(agent, other_agent, environment):
    y, x = agent['pos']
    oy, ox = other_agent['pos']
    
    # Simple state representation: relative positions and nearby features
    state = [
        y / GRID_HEIGHT,                  # Normalized y position
        x / GRID_WIDTH,                   # Normalized x position
        (oy - y) / GRID_HEIGHT,           # Relative y position to other agent
        (ox - x) / GRID_WIDTH,            # Relative x position to other agent
        1.0 if agent['is_pressing_button'] else 0.0,  # Is agent pressing button
        1.0 if other_agent['is_pressing_button'] else 0.0,  # Is other agent pressing button
        1.0 if agent['goal_reached'] else 0.0,  # Has agent reached goal
        1.0 if other_agent['goal_reached'] else 0.0,  # Has other agent reached goal
        agent['collected_gems'] / 2.0,    # Normalized collected gems (assuming max 2)
        other_agent['collected_gems'] / 2.0  # Other agent's collected gems
    ]
    
    # Use torch.tensor with requires_grad=True for policy gradients
    return torch.tensor(state, dtype=torch.float, requires_grad=True)

# Enhanced state representation for DQN (green agent)
def get_dqn_state_representation(green_agent, red_agent, blue_agent, environment):
    gy, gx = green_agent['pos']
    ry, rx = red_agent['pos']
    by, bx = blue_agent['pos']
    
    # Find locations of goals
    fire_goal_positions = np.where(environment['fire_goal'])
    water_goal_positions = np.where(environment['water_goal'])
    fgy, fgx = fire_goal_positions[0][0], fire_goal_positions[1][0] if len(fire_goal_positions[0]) > 0 else (0, 0)
    wgy, wgx = water_goal_positions[0][0], water_goal_positions[1][0] if len(water_goal_positions[0]) > 0 else (0, 0)
    
    # Calculate distances to fire/water agents
    dist_to_red = np.sqrt((gy - ry)**2 + (gx - rx)**2) / np.sqrt(GRID_HEIGHT**2 + GRID_WIDTH**2)
    dist_to_blue = np.sqrt((gy - by)**2 + (gx - bx)**2) / np.sqrt(GRID_HEIGHT**2 + GRID_WIDTH**2)
    
    # Calculate distances from fire/water to their goals
    red_to_goal = np.sqrt((ry - fgy)**2 + (rx - fgx)**2) / np.sqrt(GRID_HEIGHT**2 + GRID_WIDTH**2)
    blue_to_goal = np.sqrt((by - wgy)**2 + (bx - wgx)**2) / np.sqrt(GRID_HEIGHT**2 + GRID_WIDTH**2)
    
    # Create a comprehensive state representation
    state = [
        gy / GRID_HEIGHT,                  # Green agent y position
        gx / GRID_WIDTH,                   # Green agent x position
        ry / GRID_HEIGHT,                  # Red agent y position
        rx / GRID_WIDTH,                   # Red agent x position
        by / GRID_HEIGHT,                  # Blue agent y position
        bx / GRID_WIDTH,                   # Blue agent x position
        dist_to_red,                       # Distance to red agent
        dist_to_blue,                      # Distance to blue agent
        red_to_goal,                       # Red agent's distance to goal
        blue_to_goal,                      # Blue agent's distance to goal
        1.0 if red_agent['is_pressing_button'] else 0.0,  # Is red agent pressing button
        1.0 if blue_agent['is_pressing_button'] else 0.0,  # Is blue agent pressing button
        1.0 if red_agent['goal_reached'] else 0.0,        # Has red reached goal
        1.0 if blue_agent['goal_reached'] else 0.0,       # Has blue reached goal
        red_agent['collected_gems'] / 2.0,                # Red collected gems
        blue_agent['collected_gems'] / 2.0,               # Blue collected gems
        fgy / GRID_HEIGHT,                 # Fire goal y position
        fgx / GRID_WIDTH,                  # Fire goal x position
        wgy / GRID_HEIGHT,                 # Water goal y position
        wgx / GRID_WIDTH,                  # Water goal x position
    ]
    
    return torch.tensor(state, dtype=torch.float).unsqueeze(0)  # Add batch dimension

# Check if agent is on a button and update button state
def check_button_press(agent, environment):
    global yellow_buttons_active, purple_buttons_active, slide_animation_counter
    
    y, x = agent['pos']
    was_pressing = agent['is_pressing_button']
    
    if agent['color'] == 'red' and environment['yellow_button'][y][x]:
        agent['is_pressing_button'] = True
        yellow_buttons_active = True
        slide_animation_counter = 0  # Reset animation counter when button is pressed
    elif agent['color'] == 'blue' and environment['purple_button'][y][x]:
        agent['is_pressing_button'] = True
        purple_buttons_active = True
        slide_animation_counter = 0  # Reset animation counter when button is pressed
    else:
        agent['is_pressing_button'] = False
        
    # If the agent stopped pressing a button, reset that button state
    if was_pressing and not agent['is_pressing_button']:
        if agent['color'] == 'red':
            yellow_buttons_active = False
            slide_animation_counter = 0
        elif agent['color'] == 'blue':
            purple_buttons_active = False
            slide_animation_counter = 0

# Check if agent has collected a gem - now color-specific and gems disappear after collection
def check_gem_collection(agent, environment):
    y, x = agent['pos']
    if agent['color'] == 'red' and environment['red_gems'][y][x]:
        environment['red_gems'][y][x] = False  # Remove collected gem
        grid[y][x] = "0"  # Update visual grid to remove the gem
        agent['collected_gems'] += 1
        return True
    elif agent['color'] == 'blue' and environment['blue_gems'][y][x]:
        environment['blue_gems'][y][x] = False  # Remove collected gem
        grid[y][x] = "0"  # Update visual grid to remove the gem
        agent['collected_gems'] += 1
        return True
    return False

# Check if agent has reached their goal
def check_goal_reached(agent, environment):
    y, x = agent['pos']
    if agent['color'] == 'red' and environment['fire_goal'][y][x]:
        agent['goal_reached'] = True
        return True
    elif agent['color'] == 'blue' and environment['water_goal'][y][x]:
        agent['goal_reached'] = True
        return True
    return False

# Check if agent is in poison - now with specific fatal poison types!
def check_poison(agent, environment):
    y, x = agent['pos']
    
    # Red agent (Fireboy) dies in blue and green poison
    if agent['color'] == 'red':
        if environment['blue_poison'][y][x] or environment['green_poison'][y][x]:
            agent['died'] = True
            return True
    
    # Blue agent (Watergirl) dies in red and green poison
    elif agent['color'] == 'blue':
        if environment['red_poison'][y][x] or environment['green_poison'][y][x]:
            agent['died'] = True
            return True
    
    # Green agent dies in all poisons
    elif agent['color'] == 'green':
        if environment['poison'][y][x]:
            agent['died'] = True
            return True
            
    return False

# Perform an action and return new state, reward, and done flag
def step_environment(red_agent, blue_agent, green_agent, actions, environment):
    r_action, b_action, g_action = actions
    r_reward, b_reward, g_reward = 0, 0, 0
    
    # Action mapping: 0=up, 1=right, 2=down, 3=left
    action_mapping = [(-1, 0), (0, 1), (1, 0), (0, -1)]  # (dy, dx)
    
    # Execute actions for all agents
    execute_action(red_agent, r_action, action_mapping, environment)
    execute_action(blue_agent, b_action, action_mapping, environment)
    execute_action(green_agent, g_action, action_mapping, environment)
    
    # Update grid with new agent positions
    update_grid_with_agents(red_agent, blue_agent, green_agent)
    
    # Check for button presses
    check_button_press(red_agent, environment)
    check_button_press(blue_agent, environment)
    
    # Update slides based on button states
    update_slides()
    
    # Check for gem collection
    if check_gem_collection(red_agent, environment):
        r_reward += 10  # Reward for collecting a gem
    if check_gem_collection(blue_agent, environment):
        b_reward += 10  # Reward for collecting a gem
    
    # Check for goal reached
    if check_goal_reached(red_agent, environment):
        r_reward += 50  # Big reward for reaching goal
    if check_goal_reached(blue_agent, environment):
        b_reward += 50  # Big reward for reaching goal
    
    # Check for poison
    if check_poison(red_agent, environment):
        r_reward -= 100  # Penalty for dying
    if check_poison(blue_agent, environment):
        b_reward -= 100  # Penalty for dying
    if check_poison(green_agent, environment):
        g_reward -= 10  # Smaller penalty for adversary dying
    
    # Adversarial reward: green gets positive reward when fire/water dies
    if red_agent['died'] or blue_agent['died']:
        g_reward += 50
    
    # Proximity penalty/reward for green agent (rewards being close to heroes)
    gy, gx = green_agent['pos']
    ry, rx = red_agent['pos']
    by, bx = blue_agent['pos']
    
    # Calculate distances
    dist_to_red = np.sqrt((gy - ry)**2 + (gx - rx)**2)
    dist_to_blue = np.sqrt((gy - by)**2 + (gx - bx)**2)
    
    # Reward for being close to heroes
    proximity_reward = max(0, 10 - dist_to_red) + max(0, 10 - dist_to_blue)
    g_reward += proximity_reward * 0.5
    
    # Cooperative reward: shared between red and blue
    team_reward = r_reward + b_reward
    r_reward = team_reward
    b_reward = team_reward
    
    # Check if game is over
    game_over = (red_agent['goal_reached'] and blue_agent['goal_reached']) or \
                red_agent['died'] or blue_agent['died'] or \
                green_agent['died']
    
    return (r_reward, b_reward, g_reward), game_over

# Execute an action for a single agent
def execute_action(agent, action, action_mapping, environment):
    if agent['goal_reached'] or agent['died']:
        return  # Don't move if agent is done
    
    y, x = agent['pos']
    dy, dx = action_mapping[action]
    new_y, new_x = y + dy, x + dx
    
    # Check if the move is valid (not a wall)
    if 0 <= new_y < GRID_HEIGHT and 0 <= new_x < GRID_WIDTH and not environment['wall'][new_y][new_x]:
        agent['pos'] = (new_y, new_x)

# Update the grid with agent positions
def update_grid_with_agents(red_agent, blue_agent, green_agent):
    # First clear previous agent positions
    for y in range(GRID_HEIGHT):
        for x in range(GRID_WIDTH):
            if grid[y][x] in ["F", "W", "G"]:
                grid[y][x] = "0"
    
    # Then place agents at their new positions if they're not dead
    if not red_agent['died']:
        ry, rx = red_agent['pos']
        grid[ry][rx] = "F"
    
    if not blue_agent['died']:
        by, bx = blue_agent['pos']
        grid[by][bx] = "W"
    
    if not green_agent['died']:
        gy, gx = green_agent['pos']
        grid[gy][gx] = "G"

# Select action with epsilon-greedy policy for DQN
def select_dqn_action(state, epsilon):
    if random.random() < epsilon:
        return random.randint(0, NUM_ACTIONS - 1)
    else:
        with torch.no_grad():
            # Get Q-values and return action with highest Q-value
            q_values = dqn(state)
            return q_values.max(1)[1].item()

# Train the DQN with batch from replay memory
def optimize_dqn():
    if len(memory) < BATCH_SIZE:
        return
    
    transitions = memory.sample(BATCH_SIZE)
    # Transpose the batch
    batch = list(zip(*transitions))
    
    # Extract components
    state_batch = torch.cat(batch[0])
    action_batch = torch.tensor(batch[1], dtype=torch.long).unsqueeze(1)
    next_state_batch = torch.cat(batch[2])
    reward_batch = torch.tensor(batch[3], dtype=torch.float).unsqueeze(1)
    done_batch = torch.tensor(batch[4], dtype=torch.float).unsqueeze(1)
    
    # Compute Q(s_t, a) - the model computes Q(s_t), then we select the columns of actions taken
    state_action_values = dqn(state_batch).gather(1, action_batch)
    
    # Compute V(s_{t+1}) for all next states using target network
    with torch.no_grad():
        next_state_values = target_dqn(next_state_batch).max(1)[0].unsqueeze(1)
    
    # Compute the expected Q values: r + γ * max_a' Q(s', a')
    expected_state_action_values = reward_batch + (GAMMA * next_state_values * (1 - done_batch))
    
    # Compute Huber loss
    loss = F.smooth_l1_loss(state_action_values, expected_state_action_values)
    
    # Optimize the model
    optimizer_dqn.zero_grad()
    loss.backward()
    for param in dqn.parameters():
        param.grad.data.clamp_(-1, 1)  # Gradient clipping
    optimizer_dqn.step()

# Training loop function
def train():
    global epsilon
    episode_rewards = []
    win_rate = []
    frame_count = 0
    
    for episode in range(1000):  # Train for 1000 episodes
        # Reset environment and get initial state
        red_agent, blue_agent, green_agent, environment = reset_environment()
        episode_reward_r = 0
        episode_reward_b = 0
        episode_reward_g = 0
        won = False
        
        for frame in range(100):  # Max 100 frames per episode
            # Select actions
            red_state = get_state_representation(red_agent, blue_agent, environment)
            blue_state = get_state_representation(blue_agent, red_agent, environment)
            green_state = get_dqn_state_representation(green_agent, red_agent, blue_agent, environment)
            
            # For red and blue agents: use policy networks
            red_probs = policy_f(red_state)
            red_action = torch.multinomial(red_probs, 1).item()
            
            blue_probs = policy_w(blue_state)
            blue_action = torch.multinomial(blue_probs, 1).item()
            
            # For green agent: use epsilon-greedy from DQN
            epsilon = EPSILON_END + (EPSILON_START - EPSILON_END) * \
                      math.exp(-1. * frame_count / EPSILON_DECAY)
            green_action = select_dqn_action(green_state, epsilon)
            
            # Take a step in the environment
            (r_reward, b_reward, g_reward), done = step_environment(
                red_agent, blue_agent, green_agent, 
                (red_action, blue_action, green_action), 
                environment
            )
            
            # Get next state
            next_red_state = get_state_representation(red_agent, blue_agent, environment)
            next_blue_state = get_state_representation(blue_agent, red_agent, environment)
            next_green_state = get_dqn_state_representation(green_agent, red_agent, blue_agent, environment)
            
            # Store transition in replay memory for green agent
            memory.push(green_state, green_action, next_green_state, g_reward, done)
            
            # Train DQN
            optimize_dqn()
            
            # Update episode rewards
            episode_reward_r += r_reward
            episode_reward_b += b_reward
            episode_reward_g += g_reward
            
            # Check if game was won
            if red_agent['goal_reached'] and blue_agent['goal_reached']:
                won = True
            
            frame_count += 1
            
            # Update target network periodically
            if frame_count % TARGET_UPDATE == 0:
                target_dqn.load_state_dict(dqn.state_dict())
            
            # Update display and handle events
            draw_grid()
            pygame.display.flip()
            clock.tick(FPS)
            
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    pygame.quit()
                    sys.exit()
            
            if done:
                break
        
        # Record episode statistics
        episode_rewards.append((episode_reward_r, episode_reward_b, episode_reward_g))
        win_rate.append(1 if won else 0)
        
        # Print episode information every 10 episodes
        if episode % 10 == 0:
            print(f"Episode: {episode}, Red Reward: {episode_reward_r}, Blue Reward: {episode_reward_b}, Green Reward: {episode_reward_g}")
            print(f"Win Rate (last 10): {sum(win_rate[-10:]) / 10}")
            
        # Save models every 100 episodes
        if episode % 100 == 0:
            torch.save(policy_f.state_dict(), "policy_f.pth")
            torch.save(policy_w.state_dict(), "policy_w.pth")
            torch.save(dqn.state_dict(), "dqn.pth")

# Main game loop
def main():
    running = True
    red_agent, blue_agent, green_agent, environment = reset_environment()
    
    # Training mode flag
    training_mode = False
    
    while running:
        # Handle Pygame events
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                running = False
            elif event.type == pygame.KEYDOWN:
                if event.key == pygame.K_t:
                    # Toggle training mode
                    training_mode = not training_mode
                    print(f"Training mode: {'ON' if training_mode else 'OFF'}")
                elif event.key == pygame.K_r:
                    # Reset environment
                    red_agent, blue_agent, green_agent, environment = reset_environment()
                elif event.key == pygame.K_SPACE and training_mode:
                    # Start training
                    train()
        
        # Interactive mode (non-training)
        if not training_mode:
            # Get keyboard input for manual control
            keys = pygame.key.get_pressed()
            
            # Red agent controls (WASD)
            red_action = None
            if keys[pygame.K_w]:
                red_action = 0  # Up
            elif keys[pygame.K_d]:
                red_action = 1  # Right
            elif keys[pygame.K_s]:
                red_action = 2  # Down
            elif keys[pygame.K_a]:
                red_action = 3  # Left
            
            # Blue agent controls (Arrow keys)
            blue_action = None
            if keys[pygame.K_UP]:
                blue_action = 0  # Up
            elif keys[pygame.K_RIGHT]:
                blue_action = 1  # Right
            elif keys[pygame.K_DOWN]:
                blue_action = 2  # Down
            elif keys[pygame.K_LEFT]:
                blue_action = 3  # Left
            
            # Get green agent action from model
            green_state = get_dqn_state_representation(green_agent, red_agent, blue_agent, environment)
            green_action = select_dqn_action(green_state, 0.1)  # Small exploration rate
            
            # Take actions if provided
            if red_action is not None and blue_action is not None:
                step_environment(red_agent, blue_agent, green_agent, 
                                (red_action, blue_action, green_action), environment)
        
        # Update slide positions
        update_slides()
        
        # Draw the grid
        screen.fill((0, 0, 0))
        draw_grid()
        
        # Display game over message if applicable
        if red_agent['died'] or blue_agent['died']:
            font = pygame.font.SysFont(None, 55)
            text = font.render("Game Over! Press R to restart", True, (255, 255, 255))
            screen.blit(text, (SCREEN_WIDTH/2 - text.get_width()/2, SCREEN_HEIGHT/2 - text.get_height()/2))
        elif red_agent['goal_reached'] and blue_agent['goal_reached']:
            font = pygame.font.SysFont(None, 55)
            text = font.render("You Win! Press R to restart", True, (255, 255, 255))
            screen.blit(text, (SCREEN_WIDTH/2 - text.get_width()/2, SCREEN_HEIGHT/2 - text.get_height()/2))
        
        # Update the display
        pygame.display.flip()
        clock.tick(FPS)
    
    pygame.quit()
    sys.exit()

# Run the game
if __name__ == "__main__":
    main()

SystemExit: 

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import random
import pygame
import sys
import numpy as np
from collections import deque
import math

# Constants
TILE_SIZE = 24
GRID_WIDTH = 30
GRID_HEIGHT = 20
SCREEN_WIDTH = TILE_SIZE * GRID_WIDTH
SCREEN_HEIGHT = TILE_SIZE * GRID_HEIGHT
FPS = 10  # Increased FPS for smoother animation
NUM_ACTIONS = 4  # Number of actions (up, down, left, right)

# Initialize Pygame
pygame.init()
screen = pygame.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT))
pygame.display.set_caption("Fireboy and Watergirl Grid World")
clock = pygame.time.Clock()

# Colors
COLORS = {
    "0": (120, 66, 18),   # Brown - Empty
    "1": (160, 82, 45),   # Light Brown - Platform
    "F": (255, 0, 0),     # Fireboy - Red Heart
    "W": (0, 0, 255),     # Watergirl - Blue Heart
    "G": (0, 255, 0),     # Green Agent
    "RP": (255, 0, 0),  # Red Poison - Now a block
    "BP": (0, 0, 255),  # Blue Poison - Now a block
    "GP": (0, 255, 0),  # Green Poison - Now a block
    "FG": (255, 0, 0),    # Fireboy Goal
    "WG": (0, 0, 255),    # Watergirl Goal
    "RB": (255, 0, 0),    # Red Gem
    "BB": (0, 0, 255),    # Blue Gem
    "YS": (255, 255, 0),  # Yellow Slide
    "PS": (128, 0, 128),  # Purple Slide
    "YB": (255, 255, 0),  # Yellow Button
    "PB": (128, 0, 128),  # Purple Button
}

BASE_COLOR = COLORS["0"]

# Original grid layout - THIS WILL NEVER BE MODIFIED
ORIGINAL_GRID = [["0" for _ in range(GRID_WIDTH)] for _ in range(GRID_HEIGHT)]

# Place sample tiles
ORIGINAL_GRID[18][1] = "F"    # Fireboy
ORIGINAL_GRID[15][1] = "W"    # Watergirl
ORIGINAL_GRID[2][19] = "G"    # Green agent

# Poison tiles
for x in range(10, 13):
    ORIGINAL_GRID[18][x] = "RP"

for x in range(17, 20):
    ORIGINAL_GRID[18][x] = "BP"

for x in range(18, 21):
    ORIGINAL_GRID[14][x] = "GP"

# Gems
ORIGINAL_GRID[17][11] = ORIGINAL_GRID[8][10] = "RB"
ORIGINAL_GRID[17][18] = ORIGINAL_GRID[10][20] = "BB"

# Goals
ORIGINAL_GRID[2][26] = "FG"
ORIGINAL_GRID[2][28] = "WG"

# Slides
for x in range(1, 4):
    ORIGINAL_GRID[9][x] = "YS"  # Yellow slides

for x in range(26, 29):
    ORIGINAL_GRID[6][x] = "PS"  # Purple slides

# Buttons
ORIGINAL_GRID[8][6] = ORIGINAL_GRID[12][9] = "YB"  # Yellow buttons
ORIGINAL_GRID[8][14] = ORIGINAL_GRID[5][22] = "PB"  # Purple buttons

# Assigning 1's for walls and platforms
for i in range(0, 30):
    ORIGINAL_GRID[0][i] = "1"
    ORIGINAL_GRID[19][i] = "1"

for i in range(0, 20):
    ORIGINAL_GRID[i][0] = "1"
    ORIGINAL_GRID[i][29] = "1"

ORIGINAL_GRID[1][1] = ORIGINAL_GRID[1][2] = "1"
for i in range(4, 29):
    ORIGINAL_GRID[3][i] = "1"

ORIGINAL_GRID[4][1] = ORIGINAL_GRID[4][2] = ORIGINAL_GRID[5][1] = ORIGINAL_GRID[5][2] = "1"

for i in range(1, 26):
    ORIGINAL_GRID[6][i] = "1"

ORIGINAL_GRID[7][18] = ORIGINAL_GRID[7][19] = ORIGINAL_GRID[7][20] = "1"
ORIGINAL_GRID[8][18] = ORIGINAL_GRID[8][19] = ORIGINAL_GRID[8][20] = "1"

for i in range(4, 16):
    ORIGINAL_GRID[9][i] = "1"

for i in range(17, 22):
    ORIGINAL_GRID[11][i] = "1"

for i in range(22, 29):
    ORIGINAL_GRID[10][i] = "1"

ORIGINAL_GRID[10][16] = "1"
ORIGINAL_GRID[11][26] = ORIGINAL_GRID[11][27] = ORIGINAL_GRID[11][28] = "1"
ORIGINAL_GRID[12][26] = ORIGINAL_GRID[12][27] = ORIGINAL_GRID[12][28] = "1"

for i in range(1, 13):
    ORIGINAL_GRID[13][i] = "1"

ORIGINAL_GRID[14][13] = "1"

for i in range(14, 27):
    ORIGINAL_GRID[15][i] = "1"

for i in range(1, 7):
    ORIGINAL_GRID[16][i] = "1"

ORIGINAL_GRID[17][27] = ORIGINAL_GRID[17][28] = ORIGINAL_GRID[18][27] = ORIGINAL_GRID[18][28] = "1"

# Copy original grid to working grid
grid = [[cell for cell in row] for row in ORIGINAL_GRID]

# Track slide positions and active buttons
yellow_slides_positions = []
purple_slides_positions = []
yellow_buttons_active = False
purple_buttons_active = False
slide_animation_counter = 0
SLIDE_ANIMATION_SPEED = 15  # Controls how fast slides move

# Initialize slide positions
def initialize_slides():
    global yellow_slides_positions, purple_slides_positions
    yellow_slides_positions = []
    purple_slides_positions = []
    
    for y in range(GRID_HEIGHT):
        for x in range(GRID_WIDTH):
            if ORIGINAL_GRID[y][x] == "YS":
                yellow_slides_positions.append((y, x))
            elif ORIGINAL_GRID[y][x] == "PS":
                purple_slides_positions.append((y, x))

# Create grid maps for environment features
def create_grid_map(feature_type):
    """Create a boolean grid map for different environmental features"""
    grid_map = np.zeros((GRID_HEIGHT, GRID_WIDTH), dtype=bool)
    
    for y in range(GRID_HEIGHT):
        for x in range(GRID_WIDTH):
            cell = grid[y][x]
            if feature_type == 'wall' and cell == "1":
                grid_map[y][x] = True
            elif feature_type == 'red_poison' and cell == "RP":
                grid_map[y][x] = True
            elif feature_type == 'blue_poison' and cell == "BP":
                grid_map[y][x] = True
            elif feature_type == 'green_poison' and cell == "GP":
                grid_map[y][x] = True
            elif feature_type == 'red_gem' and cell == "RB":
                grid_map[y][x] = True
            elif feature_type == 'blue_gem' and cell == "BB":
                grid_map[y][x] = True
            elif feature_type == 'fire_goal' and cell == "FG":
                grid_map[y][x] = True
            elif feature_type == 'water_goal' and cell == "WG":
                grid_map[y][x] = True
            elif feature_type == 'yellow_button' and cell == "YB":
                grid_map[y][x] = True
            elif feature_type == 'purple_button' and cell == "PB":
                grid_map[y][x] = True
            elif feature_type == 'yellow_slide' and cell == "YS":
                grid_map[y][x] = True
            elif feature_type == 'purple_slide' and cell == "PS":
                grid_map[y][x] = True
    
    return grid_map

# Create all grid maps
def initialize_environment_maps():
    wall_map = create_grid_map('wall')
    red_poison_map = create_grid_map('red_poison')
    blue_poison_map = create_grid_map('blue_poison')
    green_poison_map = create_grid_map('green_poison')
    red_gem_map = create_grid_map('red_gem')
    blue_gem_map = create_grid_map('blue_gem')
    fire_goal_map = create_grid_map('fire_goal')
    water_goal_map = create_grid_map('water_goal')
    yellow_button_map = create_grid_map('yellow_button')
    purple_button_map = create_grid_map('purple_button')
    yellow_slide_map = create_grid_map('yellow_slide')
    purple_slide_map = create_grid_map('purple_slide')
    
    # Create combined maps
    poison_map = red_poison_map | blue_poison_map | green_poison_map
    
    # Create environment dictionary
    environment = {
        'wall': wall_map,
        'poison': poison_map,
        'red_poison': red_poison_map,
        'blue_poison': blue_poison_map,
        'green_poison': green_poison_map,
        'red_gems': red_gem_map.copy(),
        'blue_gems': blue_gem_map.copy(),
        'fire_goal': fire_goal_map,
        'water_goal': water_goal_map,
        'yellow_button': yellow_button_map,
        'purple_button': purple_button_map,
        'yellow_slide': yellow_slide_map,
        'purple_slide': purple_slide_map
    }
    
    return environment

# Agent position setup
def get_initial_positions():
    fireboy_pos = None
    watergirl_pos = None
    green_pos = None
    
    for y in range(GRID_HEIGHT):
        for x in range(GRID_WIDTH):
            if grid[y][x] == "F":
                fireboy_pos = (y, x)
            elif grid[y][x] == "W":
                watergirl_pos = (y, x)
            elif grid[y][x] == "G":
                green_pos = (y, x)
    
    return fireboy_pos, watergirl_pos, green_pos

# Draw Function
def draw_grid():
    for y in range(GRID_HEIGHT):
        for x in range(GRID_WIDTH):
            tile = grid[y][x]
            color = COLORS.get(tile, BASE_COLOR)

            # Agents: Heart shape
            if tile in ["F", "W", "G"]:
                pygame.draw.rect(screen, BASE_COLOR, (x*TILE_SIZE, y*TILE_SIZE, TILE_SIZE, TILE_SIZE))
                pygame.draw.polygon(screen, color, [
                    (x*TILE_SIZE + TILE_SIZE//2, y*TILE_SIZE + 6),
                    (x*TILE_SIZE + 6, y*TILE_SIZE + TILE_SIZE//2),
                    (x*TILE_SIZE + TILE_SIZE - 6, y*TILE_SIZE + TILE_SIZE//2),
                ])
                pygame.draw.circle(screen, color, (x*TILE_SIZE + TILE_SIZE//3, y*TILE_SIZE + TILE_SIZE//3), 5)
                pygame.draw.circle(screen, color, (x*TILE_SIZE + 2*TILE_SIZE//3, y*TILE_SIZE + TILE_SIZE//3), 5)

            # Gems: Rhombus
            elif tile in ["RB", "BB"]:
                pygame.draw.rect(screen, BASE_COLOR, (x*TILE_SIZE, y*TILE_SIZE, TILE_SIZE, TILE_SIZE))
                pygame.draw.polygon(screen, color, [
                    (x*TILE_SIZE + TILE_SIZE//2, y*TILE_SIZE),
                    (x*TILE_SIZE + TILE_SIZE, y*TILE_SIZE + TILE_SIZE//2),
                    (x*TILE_SIZE + TILE_SIZE//2, y*TILE_SIZE + TILE_SIZE),
                    (x*TILE_SIZE, y*TILE_SIZE + TILE_SIZE//2)
                ])

            # Buttons: Circle
            elif tile in ["YB", "PB"]:
                pygame.draw.rect(screen, BASE_COLOR, (x*TILE_SIZE, y*TILE_SIZE, TILE_SIZE, TILE_SIZE))
                pygame.draw.circle(screen, color, (x*TILE_SIZE + TILE_SIZE//2, y*TILE_SIZE + TILE_SIZE//2), TILE_SIZE//4)

            # Goals: Triangle
            elif tile in ["FG", "WG"]:
                pygame.draw.rect(screen, BASE_COLOR, (x*TILE_SIZE, y*TILE_SIZE, TILE_SIZE, TILE_SIZE))
                pygame.draw.polygon(screen, color, [
                    (x*TILE_SIZE + TILE_SIZE//2, y*TILE_SIZE),
                    (x*TILE_SIZE, y*TILE_SIZE + TILE_SIZE),
                    (x*TILE_SIZE + TILE_SIZE, y*TILE_SIZE + TILE_SIZE)
                ])

            # Slides: Full tile fill
            elif tile in ["YS", "PS"]:
                pygame.draw.rect(screen, color, (x*TILE_SIZE, y*TILE_SIZE, TILE_SIZE, TILE_SIZE))
                
            # Poison: Now blocks with specific colors
            elif tile in ["RP", "BP", "GP"]:
                pygame.draw.rect(screen, color, (x*TILE_SIZE, y*TILE_SIZE, TILE_SIZE, TILE_SIZE))

            # Default or walls
            else:
                pygame.draw.rect(screen, color, (x*TILE_SIZE, y*TILE_SIZE, TILE_SIZE, TILE_SIZE))

            # Draw the black border
            pygame.draw.rect(screen, (0, 0, 0), (x*TILE_SIZE, y*TILE_SIZE, TILE_SIZE, TILE_SIZE), 1)

# Enhanced neural network for MAPPO
class MAPPONetwork(nn.Module):
    def __init__(self, input_size, hidden_size, action_size):
        super(MAPPONetwork, self).__init__()
        # Actor network (policy)
        self.actor = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, action_size),
        )
        
        # Critic network (value function)
        self.critic = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1)
        )
    
    def forward(self, state):
        # Return action probabilities and value
        return F.softmax(self.actor(state), dim=-1), self.critic(state)
    
    def get_action(self, state, deterministic=False):
        probs, _ = self.forward(state)
        if deterministic:
            return torch.argmax(probs).item()
        else:
            return torch.multinomial(probs, 1).item()
    
    def evaluate(self, state, action):
        probs, value = self.forward(state)
        dist = torch.distributions.Categorical(probs)
        log_prob = dist.log_prob(action)
        entropy = dist.entropy()
        return log_prob, entropy, value

# Define the DQN for the green agent
class DQN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(DQN, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)
    
    def forward(self, state):
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        return self.fc3(x)  # Q-values (not softmax)

# Memory buffer for MAPPO
class MAPPOMemory:
    def __init__(self):
        self.states = []
        self.actions = []
        self.rewards = []
        self.next_states = []
        self.dones = []
        self.log_probs = []
        self.values = []
    
    def push(self, state, action, reward, next_state, done, log_prob, value):
        self.states.append(state)
        self.actions.append(action)
        self.rewards.append(reward)
        self.next_states.append(next_state)
        self.dones.append(done)
        self.log_probs.append(log_prob)
        self.values.append(value)
    
    def clear(self):
        self.states.clear()
        self.actions.clear()
        self.rewards.clear()
        self.next_states.clear()
        self.dones.clear()
        self.log_probs.clear()
        self.values.clear()
    
    def __len__(self):
        return len(self.states)

# Experience replay memory for DQN
class ReplayMemory:
    def __init__(self, capacity):
        self.memory = deque(maxlen=capacity)
    
    def push(self, state, action, next_state, reward, done):
        self.memory.append((state, action, next_state, reward, done))
    
    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)
    
    def __len__(self):
        return len(self.memory)

# Initialize neural networks with proper hidden sizes
HIDDEN_SIZE = 128
INPUT_SIZE_MAPPO = 16  # Expanded input size for better state representation
INPUT_SIZE_DQN = 24    # Even larger input size for the green agent

# Initialize MAPPO networks for cooperative agents
policy_f = MAPPONetwork(INPUT_SIZE_MAPPO, HIDDEN_SIZE, NUM_ACTIONS)
policy_w = MAPPONetwork(INPUT_SIZE_MAPPO, HIDDEN_SIZE, NUM_ACTIONS)

# Initialize DQN for the green adversarial agent
dqn = DQN(INPUT_SIZE_DQN, HIDDEN_SIZE, NUM_ACTIONS)
target_dqn = DQN(INPUT_SIZE_DQN, HIDDEN_SIZE, NUM_ACTIONS)
target_dqn.load_state_dict(dqn.state_dict())  # Initially the same as the online network

# Optimizers
optimizer_f = optim.Adam(policy_f.parameters(), lr=0.0005)
optimizer_w = optim.Adam(policy_w.parameters(), lr=0.0005)
optimizer_dqn = optim.Adam(dqn.parameters(), lr=0.0005)

# MAPPO memory buffers
memory_f = MAPPOMemory()
memory_w = MAPPOMemory()

# DQN experience replay memory
dqn_memory = ReplayMemory(10000)

# DQN hyperparameters
BATCH_SIZE = 64
GAMMA = 0.99  # Discount factor
EPSILON_START = 1.0
EPSILON_END = 0.1
EPSILON_DECAY = 10000
TARGET_UPDATE = 10  # How often to update target network
epsilon = EPSILON_START

# MAPPO hyperparameters
PPO_EPOCHS = 4
PPO_EPSILON = 0.2  # Clip parameter
VALUE_COEF = 0.5
ENTROPY_COEF = 0.01
GAE_LAMBDA = 0.95

# Initialize agent state
def reset_environment():
    # Reset grid to original state first
    global grid, yellow_buttons_active, purple_buttons_active, slide_animation_counter
    grid = [[cell for cell in row] for row in ORIGINAL_GRID]
    yellow_buttons_active = False
    purple_buttons_active = False
    slide_animation_counter = 0
    
    # Initialize slides
    initialize_slides()
    
    # Get fresh positions from the reset grid
    fireboy_pos, watergirl_pos, green_pos = get_initial_positions()
    
    # Create agent dictionaries with position and other properties
    red_agent = {'pos': fireboy_pos, 'color': 'red', 'is_pressing_button': False, 'goal_reached': False, 'collected_gems': 0, 'died': False}
    blue_agent = {'pos': watergirl_pos, 'color': 'blue', 'is_pressing_button': False, 'goal_reached': False, 'collected_gems': 0, 'died': False}
    green_agent = {'pos': green_pos, 'color': 'green', 'is_pressing_button': False, 'goal_reached': False, 'collected_gems': 0, 'died': False}
    
    # Create fresh environment maps from the reset grid
    environment = initialize_environment_maps()
    
    return red_agent, blue_agent, green_agent, environment

# Handle slide movement
def update_slides():
    global yellow_slides_positions, purple_slides_positions, grid, slide_animation_counter
    
    # Clear previous slide positions from grid
    for y in range(GRID_HEIGHT):
        for x in range(GRID_WIDTH):
            if grid[y][x] == "YS" or grid[y][x] == "PS":
                grid[y][x] = "0"
    
    # Update yellow slides
    new_yellow_positions = []
    for y, x in yellow_slides_positions:
        # If button is pressed, move down 3 rows, otherwise stay in original position
        if yellow_buttons_active:
            target_y = min(y + 3, GRID_HEIGHT - 1)
            # Check if we're in animation
            if slide_animation_counter < SLIDE_ANIMATION_SPEED:
                # Calculate intermediate position
                animation_progress = slide_animation_counter / SLIDE_ANIMATION_SPEED
                current_y = int(y + (target_y - y) * animation_progress)
                new_yellow_positions.append((current_y, x))
                grid[current_y][x] = "YS"
            else:
                new_yellow_positions.append((target_y, x))
                grid[target_y][x] = "YS"
        else:
            # If buttons not active, return to original position
            if slide_animation_counter < SLIDE_ANIMATION_SPEED:
                original_y = [pos[0] for pos in yellow_slides_positions if pos[1] == x][0]
                animation_progress = slide_animation_counter / SLIDE_ANIMATION_SPEED
                # Find the current position in the animation
                if any(pos[1] == x for pos in yellow_slides_positions):
                    current_pos = [pos for pos in yellow_slides_positions if pos[1] == x][0]
                    current_y = int(current_pos[0] - (current_pos[0] - original_y) * animation_progress)
                    new_yellow_positions.append((current_y, x))
                    grid[current_y][x] = "YS"
            else:
                new_yellow_positions.append((y, x))
                grid[y][x] = "YS"
    
    # Update purple slides
    new_purple_positions = []
    for y, x in purple_slides_positions:
        # If button is pressed, move down 3 rows, otherwise stay in original position
        if purple_buttons_active:
            target_y = min(y + 3, GRID_HEIGHT - 1)
            # Check if we're in animation
            if slide_animation_counter < SLIDE_ANIMATION_SPEED:
                # Calculate intermediate position
                animation_progress = slide_animation_counter / SLIDE_ANIMATION_SPEED
                current_y = int(y + (target_y - y) * animation_progress)
                new_purple_positions.append((current_y, x))
                grid[current_y][x] = "PS"
            else:
                new_purple_positions.append((target_y, x))
                grid[target_y][x] = "PS"
        else:
            # If buttons not active, return to original position
            if slide_animation_counter < SLIDE_ANIMATION_SPEED:
                original_y = [pos[0] for pos in purple_slides_positions if pos[1] == x][0]
                animation_progress = slide_animation_counter / SLIDE_ANIMATION_SPEED
                # Find the current position in the animation
                if any(pos[1] == x for pos in purple_slides_positions):
                    current_pos = [pos for pos in purple_slides_positions if pos[1] == x][0]
                    current_y = int(current_pos[0] - (current_pos[0] - original_y) * animation_progress)
                    new_purple_positions.append((current_y, x))
                    grid[current_y][x] = "PS"
            else:
                new_purple_positions.append((y, x))
                grid[y][x] = "PS"
    
    # Update slide positions for next frame
    yellow_slides_positions = new_yellow_positions
    purple_slides_positions = new_purple_positions
    
    # Update animation counter
    if slide_animation_counter < SLIDE_ANIMATION_SPEED:
        slide_animation_counter += 1

# Enhanced state representation for MAPPO
def get_mappo_state(agent, other_agent, green_agent, environment):
    y, x = agent['pos']
    oy, ox = other_agent['pos']
    gy, gx = green_agent['pos']
    
    # Find locations of goals
    fire_goal_positions = np.where(environment['fire_goal'])
    water_goal_positions = np.where(environment['water_goal'])
    fgy, fgx = fire_goal_positions[0][0], fire_goal_positions[1][0] if len(fire_goal_positions[0]) > 0 else (0, 0)
    wgy, wgx = water_goal_positions[0][0], water_goal_positions[1][0] if len(water_goal_positions[0]) > 0 else (0, 0)
    
    # Target goal based on agent color
    goal_y, goal_x = (fgy, fgx) if agent['color'] == 'red' else (wgy, wgx)
    
    # Calculate distances
    dist_to_goal = np.sqrt((y - goal_y)**2 + (x - goal_x)**2) / np.sqrt(GRID_HEIGHT**2 + GRID_WIDTH**2)
    dist_to_other = np.sqrt((y - oy)**2 + (x - ox)**2) / np.sqrt(GRID_HEIGHT**2 + GRID_WIDTH**2)
    dist_to_green = np.sqrt((y - gy)**2 + (x - gx)**2) / np.sqrt(GRID_HEIGHT**2 + GRID_WIDTH**2)
    
    # Comprehensive state representation
    state = [
        y / GRID_HEIGHT,                  # Normalized y position
        x / GRID_WIDTH,                   # Normalized x position
        oy / GRID_HEIGHT,                 # Other agent y position
        ox / GRID_WIDTH,                  # Other agent x position
        gy / GRID_HEIGHT,                 # Green agent y position
        gx / GRID_WIDTH,                  # Green agent x position
        dist_to_goal,                     # Distance to own goal
        dist_to_other,                    # Distance to other agent
        dist_to_green,                    # Distance to green agent
        1.0 if agent['is_pressing_button'] else 0.0,  # Is agent pressing button
        1.0 if other_agent['is_pressing_button'] else 0.0,  # Is other agent pressing button
        1.0 if yellow_buttons_active else 0.0,        # Is yellow button active
        1.0 if purple_buttons_active else 0.0,        # Is purple button active
        1.0 if agent['goal_reached'] else 0.0,        # Has agent reached goal
        agent['collected_gems'] / 2.0,                # Normalized collected gems
        other_agent['collected_gems'] / 2.0           # Other agent's collected gems
    ]
    
    return torch.tensor(state, dtype=torch.float)

# Enhanced state representation for DQN (green agent)
def get_dqn_state(green_agent, red_agent, blue_agent, environment):
    gy, gx = green_agent['pos']
    ry, rx = red_agent['pos']
    by, bx = blue_agent['pos']
    
    # Find locations of goals
    fire_goal_positions = np.where(environment['fire_goal'])
    water_goal_positions = np.where(environment['water_goal'])
    fgy, fgx = fire_goal_positions[0][0], fire_goal_positions[1][0] if len(fire_goal_positions[0]) > 0 else (0, 0)
    wgy, wgx = water_goal_positions[0][0], water_goal_positions[1][0] if len(water_goal_positions[0]) > 0 else (0, 0)
    
    # Calculate distances to fire/water agents
    dist_to_red = np.sqrt((gy - ry)**2 + (gx - rx)**2) / np.sqrt(GRID_HEIGHT**2 + GRID_WIDTH**2)
    dist_to_blue = np.sqrt((gy - by)**2 + (gx - bx)**2) / np.sqrt(GRID_HEIGHT**2 + GRID_WIDTH**2)
    
    # Calculate distances from fire/water to their goals
    red_to_goal = np.sqrt((ry - fgy)**2 + (rx - fgx)**2) / np.sqrt(GRID_HEIGHT**2 + GRID_WIDTH**2)
    blue_to_goal = np.sqrt((by - wgy)**2 + (bx - wgx)**2) / np.sqrt(GRID_HEIGHT**2 + GRID_WIDTH**2)
    
    # Look in different directions for obstacles and agents
    directions = [(0, 1), (1, 0), (0, -1), (-1, 0)]  # right, down, left, up
    dir_features = []
    
    for dy, dx in directions:
        ny, nx = gy + dy, gx + dx
        # Check if there's a wall
        if 0 <= ny < GRID_HEIGHT and 0 <= nx < GRID_WIDTH and environment['wall'][ny][nx]:
            dir_features.append(1.0)
        else:
            dir_features.append(0.0)
            
        # Check if there's a fire/water agent
        if (ny, nx) == (ry, rx) or (ny, nx) == (by, bx):
            dir_features.append(1.0)
        else:
            dir_features.append(0.0)
    
    # Comprehensive state for green agent
    state = [
        gy / GRID_HEIGHT,                 # Normalized y position
        gx / GRID_WIDTH,                  # Normalized x position
        ry / GRID_HEIGHT,                 # Red agent y position
        rx / GRID_WIDTH,                  # Red agent x position
        by / GRID_HEIGHT,                 # Blue agent y position
        bx / GRID_WIDTH,                  # Blue agent x position
        dist_to_red,                      # Distance to red agent
        dist_to_blue,                     # Distance to blue agent
        red_to_goal,                      # Red agent's distance to goal
        blue_to_goal,                     # Blue agent's distance to goal
        1.0 if red_agent['goal_reached'] else 0.0,   # Has red agent reached goal
        1.0 if blue_agent['goal_reached'] else 0.0,  # Has blue agent reached goal
        red_agent['collected_gems'] / 2.0,           # Red agent's collected gems
        blue_agent['collected_gems'] / 2.0,          # Blue agent's collected gems
        1.0 if green_agent['is_pressing_button'] else 0.0,  # Is green agent pressing button
        1.0 if yellow_buttons_active else 0.0,        # Is yellow button active
        1.0 if purple_buttons_active else 0.0,        # Is purple button active
    ]
    
    # Add the directional features
    state.extend(dir_features)
    
    return torch.tensor(state, dtype=torch.float)

# Move the agent based on action
def move_agent(agent, action, environment):
    y, x = agent['pos']
    
    # Define actions: 0=up, 1=right, 2=down, 3=left
    dy, dx = [(0, 0), (-1, 0), (0, 1), (1, 0), (0, -1)][action]
    
    # Check if the move would take us to a wall
    new_y, new_x = y + dy, x + dx
    
    # Don't move if hitting a wall
    if (new_y < 0 or new_y >= GRID_HEIGHT or new_x < 0 or new_x >= GRID_WIDTH or 
        environment['wall'][new_y][new_x]):
        return (y, x)  # Stay in place
    
    return (new_y, new_x)

# Check for interactions based on agent movement
def check_interactions(agent, environment):
    y, x = agent['pos']
    result = {
        'consumed_gem': False,
        'died': False,
        'reached_goal': False,
        'pressed_button': False
    }
    
    # Check for gem collection
    if agent['color'] == 'red' and environment['red_gems'][y][x]:
        environment['red_gems'][y][x] = False
        agent['collected_gems'] += 1
        result['consumed_gem'] = True
        # Remove the gem from the grid
        if grid[y][x] == "RB":
            grid[y][x] = "0"
    
    elif agent['color'] == 'blue' and environment['blue_gems'][y][x]:
        environment['blue_gems'][y][x] = False
        agent['collected_gems'] += 1
        result['consumed_gem'] = True
        # Remove the gem from the grid
        if grid[y][x] == "BB":
            grid[y][x] = "0"
    
    # Check for poison interaction
    if (agent['color'] == 'red' and environment['blue_poison'][y][x]) or \
       (agent['color'] == 'blue' and environment['red_poison'][y][x]) or \
       (environment['green_poison'][y][x]):
        agent['died'] = True
        result['died'] = True
    
    # Check for goal reaching
    if agent['color'] == 'red' and environment['fire_goal'][y][x]:
        agent['goal_reached'] = True
        result['reached_goal'] = True
    elif agent['color'] == 'blue' and environment['water_goal'][y][x]:
        agent['goal_reached'] = True
        result['reached_goal'] = True
    
    # Check for button pressing
    if environment['yellow_button'][y][x] or environment['purple_button'][y][x]:
        agent['is_pressing_button'] = True
        result['pressed_button'] = True
        
        # Activate corresponding button
        global yellow_buttons_active, purple_buttons_active, slide_animation_counter
        if environment['yellow_button'][y][x]:
            yellow_buttons_active = True
        if environment['purple_button'][y][x]:
            purple_buttons_active = True
        
        # Reset animation counter when a button is pressed
        slide_animation_counter = 0
    else:
        agent['is_pressing_button'] = False
    
    return result

# Update grid based on agent positions
def update_grid_state(red_agent, blue_agent, green_agent):
    # Clear previous agent positions
    for y in range(GRID_HEIGHT):
        for x in range(GRID_WIDTH):
            if grid[y][x] in ["F", "W", "G"]:
                grid[y][x] = "0"
    
    # Place agents in new positions
    ry, rx = red_agent['pos']
    by, bx = blue_agent['pos']
    gy, gx = green_agent['pos']
    
    grid[ry][rx] = "F"
    grid[by][bx] = "W"
    grid[gy][gx] = "G"

# Calculate rewards for fireboy and watergirl (cooperative)
def calculate_cooperative_reward(agent, interactions, other_agent):
    reward = 0
    
    # Base reward for survival
    reward += 0.1
    
    # Reward for gem collection
    if interactions['consumed_gem']:
        reward += 1.0
    
    # Reward for reaching goal
    if interactions['reached_goal']:
        reward += 5.0
    
    # Penalty for death
    if interactions['died']:
        reward -= 5.0
        
    # Reward for button pressing
    if interactions['pressed_button']:
        reward += 0.5
    
    # Cooperative bonus when both agents reach their goals
    if agent['goal_reached'] and other_agent['goal_reached']:
        reward += 10.0
    
    # Penalty for being too far from other agent
    y, x = agent['pos']
    oy, ox = other_agent['pos']
    distance = np.sqrt((y - oy)**2 + (x - ox)**2)
    
    # Encourage staying within a reasonable distance
    if distance > 10:
        reward -= 0.2
    
    return reward

# Calculate rewards for green agent (adversarial)
def calculate_adversarial_reward(green_agent, red_agent, blue_agent):
    reward = 0
    
    # Base reward for survival
    reward += 0.1
    
    # Reward for being close to fireboy/watergirl
    gy, gx = green_agent['pos']
    ry, rx = red_agent['pos']
    by, bx = blue_agent['pos']
    
    dist_to_red = np.sqrt((gy - ry)**2 + (gx - rx)**2)
    dist_to_blue = np.sqrt((gy - by)**2 + (gx - bx)**2)
    
    # Reward inversely proportional to distance to closest agent
    min_dist = min(dist_to_red, dist_to_blue)
    reward += max(0, (10 - min_dist) / 10) * 0.5
    
    # Reward for causing fireboy/watergirl to die
    if red_agent['died'] or blue_agent['died']:
        reward += 5.0
    
    # Penalty if both fireboy and watergirl reach their goals
    if red_agent['goal_reached'] and blue_agent['goal_reached']:
        reward -= 5.0
    
    # Reward for pressing buttons (interfering with fireboy/watergirl)
    if green_agent['is_pressing_button']:
        reward += 0.5
    
    return reward

# MAPPO training function
def train_mappo(memory, policy, optimizer):
    # Return early if not enough data
    if len(memory) < 32:
        return
    
    # Convert memory to tensors
    states = torch.stack(memory.states)
    actions = torch.tensor(memory.actions)
    rewards = torch.tensor(memory.rewards)
    next_states = torch.stack(memory.next_states)
    dones = torch.tensor(memory.dones, dtype=torch.float)
    old_log_probs = torch.stack(memory.log_probs)
    old_values = torch.stack(memory.values)
    
    # Calculate returns and advantages
    returns = []
    advantages = []
    gae = 0
    
    # Get next state values
    _, next_values = policy(next_states)
    next_values = next_values.squeeze(-1)
    
    # Calculate returns and advantages
    for i in reversed(range(len(rewards))):
        if i == len(rewards) - 1:
            next_return = next_values[i] * (1 - dones[i])
        else:
            next_return = returns[0]
        
        # Calculate return
        current_return = rewards[i] + GAMMA * next_return
        returns.insert(0, current_return)
        
        # Calculate advantage using GAE
        delta = rewards[i] + GAMMA * next_values[i] * (1 - dones[i]) - old_values[i]
        gae = delta + GAMMA * GAE_LAMBDA * (1 - dones[i]) * gae
        advantages.insert(0, gae)
    
    # Convert to tensors
    returns = torch.tensor(returns)
    advantages = torch.tensor(advantages)
    
    # Normalize advantages
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
    
    # Perform PPO update
    for _ in range(PPO_EPOCHS):
        # Get updated action probabilities and values
        current_log_probs, entropies, current_values = [], [], []
        
        for i in range(len(states)):
            log_prob, entropy, value = policy.evaluate(states[i], actions[i])
            current_log_probs.append(log_prob)
            entropies.append(entropy)
            current_values.append(value)
        
        current_log_probs = torch.stack(current_log_probs)
        entropies = torch.stack(entropies)
        current_values = torch.stack(current_values).squeeze(-1)
        
        # Calculate ratios and surrogate losses
        ratios = torch.exp(current_log_probs - old_log_probs.detach())
        
        # Calculate surrogate objectives
        surr1 = ratios * advantages
        surr2 = torch.clamp(ratios, 1.0 - PPO_EPSILON, 1.0 + PPO_EPSILON) * advantages
        
        # Calculate actor, critic, and entropy losses
        actor_loss = -torch.min(surr1, surr2).mean()
        critic_loss = F.mse_loss(current_values, returns)
        entropy_loss = -entropies.mean()
        
        # Total loss
        loss = actor_loss + VALUE_COEF * critic_loss + ENTROPY_COEF * entropy_loss
        
        # Perform optimization step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    # Clear memory after update
    memory.clear()

# DQN training function
def train_dqn():
    global epsilon
    
    # Skip training if memory is too small
    if len(dqn_memory) < BATCH_SIZE:
        return
    
    # Sample batch from memory
    batch = dqn_memory.sample(BATCH_SIZE)
    batch_states, batch_actions, batch_next_states, batch_rewards, batch_dones = zip(*batch)
    
    # Convert to tensors
    states = torch.stack(batch_states)
    actions = torch.tensor(batch_actions).unsqueeze(1)
    next_states = torch.stack(batch_next_states)
    rewards = torch.tensor(batch_rewards).unsqueeze(1)
    dones = torch.tensor(batch_dones, dtype=torch.float).unsqueeze(1)
    
    # Compute Q values
    q_values = dqn(states).gather(1, actions)
    
    # Compute next Q values with target network
    with torch.no_grad():
        next_q_values = target_dqn(next_states).max(1, keepdim=True)[0]
    
    # Compute target Q values
    target_q_values = rewards + GAMMA * next_q_values * (1 - dones)
    
    # Compute loss
    loss = F.mse_loss(q_values, target_q_values)
    
    # Optimize
    optimizer_dqn.zero_grad()
    loss.backward()
    optimizer_dqn.step()
    
    # Decay epsilon
    epsilon = max(EPSILON_END, EPSILON_START - (EPSILON_START - EPSILON_END) * (len(dqn_memory) / EPSILON_DECAY))

# Fireboy action selection with MAPPO
def select_fireboy_action(red_agent, blue_agent, green_agent, environment, training=True):
    state = get_mappo_state(red_agent, blue_agent, green_agent, environment)
    
    if training:
        probs, value = policy_f(state)
        action = torch.multinomial(probs, 1).item()
        log_prob = torch.log(probs[action])
        return action, log_prob, value.item()
    else:
        return policy_f.get_action(state, deterministic=True), None, None

# Watergirl action selection with MAPPO
def select_watergirl_action(blue_agent, red_agent, green_agent, environment, training=True):
    state = get_mappo_state(blue_agent, red_agent, green_agent, environment)
    
    if training:
        probs, value = policy_w(state)
        action = torch.multinomial(probs, 1).item()
        log_prob = torch.log(probs[action])
        return action, log_prob, value.item()
    else:
        return policy_w.get_action(state, deterministic=True), None, None

# Green agent action selection with DQN
def select_green_action(green_agent, red_agent, blue_agent, environment, training=True):
    state = get_dqn_state(green_agent, red_agent, blue_agent, environment)
    
    if training and random.random() < epsilon:
        return random.randint(0, NUM_ACTIONS - 1)
    else:
        with torch.no_grad():
            return torch.argmax(dqn(state)).item()

# Main game loop
def main():
    global grid, epsilon, target_dqn
    
    # Training parameters
    num_episodes = 10000
    target_update_counter = 0
    render_every = 100  # Only render every 100 episodes
    
    # Training loop
    for episode in range(num_episodes):
        # Reset environment
        red_agent, blue_agent, green_agent, environment = reset_environment()
        
        # Track episode stats
        episode_length = 0
        episode_f_reward = 0
        episode_w_reward = 0
        episode_g_reward = 0
        done = False
        
        # Episode loop
        while not done and episode_length < 200:  # Max 200 steps per episode
            # Render if needed
            if episode % render_every == 0:
                # Handle pygame events
                for event in pygame.event.get():
                    if event.type == pygame.QUIT:
                        pygame.quit()
                        sys.exit()
                
                # Render the game
                screen.fill((0, 0, 0))
                draw_grid()
                pygame.display.flip()
                clock.tick(FPS)
            
            # Update slides and buttons
            update_slides()
            
            # Fireboy's turn
            f_action, f_log_prob, f_value = select_fireboy_action(red_agent, blue_agent, green_agent, environment)
            red_agent['pos'] = move_agent(red_agent, f_action, environment)
            
            # Watergirl's turn
            w_action, w_log_prob, w_value = select_watergirl_action(blue_agent, red_agent, green_agent, environment)
            blue_agent['pos'] = move_agent(blue_agent, w_action, environment)
            
            # Green agent's turn
            g_action = select_green_action(green_agent, red_agent, blue_agent, environment)
            green_agent['pos'] = move_agent(green_agent, g_action, environment)
            
            # Update the grid
            update_grid_state(red_agent, blue_agent, green_agent)
            
            # Check for interactions
            f_interactions = check_interactions(red_agent, environment)
            w_interactions = check_interactions(blue_agent, environment)
            g_interactions = check_interactions(green_agent, environment)
            
            # Calculate rewards
            f_reward = calculate_cooperative_reward(red_agent, f_interactions, blue_agent)
            w_reward = calculate_cooperative_reward(blue_agent, w_interactions, red_agent)
            g_reward = calculate_adversarial_reward(green_agent, red_agent, blue_agent)
            
            # Update episode rewards
            episode_f_reward += f_reward
            episode_w_reward += w_reward
            episode_g_reward += g_reward
            
            # Get new states
            f_next_state = get_mappo_state(red_agent, blue_agent, green_agent, environment)
            w_next_state = get_mappo_state(blue_agent, red_agent, green_agent, environment)
            g_next_state = get_dqn_state(green_agent, red_agent, blue_agent, environment)
            
            # Check if episode is done
            done = (red_agent['died'] or blue_agent['died'] or
                   (red_agent['goal_reached'] and blue_agent['goal_reached']))
            
            # Store experiences
            memory_f.push(
                get_mappo_state(red_agent, blue_agent, green_agent, environment),
                f_action,
                f_reward,
                f_next_state,
                done,
                torch.tensor(f_log_prob),
                torch.tensor(f_value)
            )
            
            memory_w.push(
                get_mappo_state(blue_agent, red_agent, green_agent, environment),
                w_action,
                w_reward,
                w_next_state,
                done,
                torch.tensor(w_log_prob),
                torch.tensor(w_value)
            )
            
            dqn_memory.push(
                get_dqn_state(green_agent, red_agent, blue_agent, environment),
                g_action,
                g_next_state,
                g_reward,
                done
            )
            
            # Train models
            train_mappo(memory_f, policy_f, optimizer_f)
            train_mappo(memory_w, policy_w, optimizer_w)
            train_dqn()
            
            episode_length += 1
            
            # Update target network
            target_update_counter += 1
            if target_update_counter % TARGET_UPDATE == 0:
                target_dqn.load_state_dict(dqn.state_dict())
        
        # Print episode stats
        if episode % 100 == 0:
            print(f"Episode {episode}")
            print(f"F Reward: {episode_f_reward:.2f}, W Reward: {episode_w_reward:.2f}, G Reward: {episode_g_reward:.2f}")
            print(f"Episode Length: {episode_length}")
            print(f"Epsilon: {epsilon:.4f}")
            print()
    
    # Save the trained models
    torch.save(policy_f.state_dict(), "policy_f.pth")
    torch.save(policy_w.state_dict(), "policy_w.pth")
    torch.save(dqn.state_dict(), "dqn.pth")
    
    print("Training complete!")

# Run the game
if __name__ == "__main__":
    initialize_slides()
    main()

  torch.tensor(f_log_prob),
  torch.tensor(w_log_prob),


RuntimeError: mat1 and mat2 shapes cannot be multiplied (64x25 and 24x128)

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import random
import pygame
import sys
import numpy as np
from collections import deque
import math

# Constants
TILE_SIZE = 24
GRID_WIDTH = 30
GRID_HEIGHT = 20
SCREEN_WIDTH = TILE_SIZE * GRID_WIDTH
SCREEN_HEIGHT = TILE_SIZE * GRID_HEIGHT
FPS = 10  # Increased FPS for smoother animation
NUM_ACTIONS = 4  # Number of actions (up, down, left, right)

# Initialize Pygame
pygame.init()
screen = pygame.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT))
pygame.display.set_caption("Fireboy and Watergirl Grid World")
clock = pygame.time.Clock()

# Colors
COLORS = {
    "0": (120, 66, 18),   # Brown - Empty
    "1": (160, 82, 45),   # Light Brown - Platform
    "F": (255, 0, 0),     # Fireboy - Red Heart
    "W": (0, 0, 255),     # Watergirl - Blue Heart
    "G": (0, 255, 0),     # Green Agent
    "RP": (255, 0, 0),  # Red Poison - Now a block
    "BP": (0, 0, 255),  # Blue Poison - Now a block
    "GP": (0, 255, 0),  # Green Poison - Now a block
    "FG": (255, 0, 0),    # Fireboy Goal
    "WG": (0, 0, 255),    # Watergirl Goal
    "RB": (255, 0, 0),    # Red Gem
    "BB": (0, 0, 255),    # Blue Gem
    "YS": (255, 255, 0),  # Yellow Slide
    "PS": (128, 0, 128),  # Purple Slide
    "YB": (255, 255, 0),  # Yellow Button
    "PB": (128, 0, 128),  # Purple Button
}

BASE_COLOR = COLORS["0"]

# Original grid layout - THIS WILL NEVER BE MODIFIED
ORIGINAL_GRID = [["0" for _ in range(GRID_WIDTH)] for _ in range(GRID_HEIGHT)]

# Place sample tiles
ORIGINAL_GRID[18][1] = "F"    # Fireboy
ORIGINAL_GRID[15][1] = "W"    # Watergirl
ORIGINAL_GRID[2][19] = "G"    # Green agent

# Poison tiles
for x in range(10, 13):
    ORIGINAL_GRID[18][x] = "RP"

for x in range(17, 20):
    ORIGINAL_GRID[18][x] = "BP"

for x in range(18, 21):
    ORIGINAL_GRID[14][x] = "GP"

# Gems
ORIGINAL_GRID[17][11] = ORIGINAL_GRID[8][10] = "RB"
ORIGINAL_GRID[17][18] = ORIGINAL_GRID[10][20] = "BB"

# Goals
ORIGINAL_GRID[2][26] = "FG"
ORIGINAL_GRID[2][28] = "WG"

# Slides
for x in range(1, 4):
    ORIGINAL_GRID[9][x] = "YS"  # Yellow slides

for x in range(26, 29):
    ORIGINAL_GRID[6][x] = "PS"  # Purple slides

# Buttons
ORIGINAL_GRID[8][6] = ORIGINAL_GRID[12][9] = "YB"  # Yellow buttons
ORIGINAL_GRID[8][14] = ORIGINAL_GRID[5][22] = "PB"  # Purple buttons

# Assigning 1's for walls and platforms
for i in range(0, 30):
    ORIGINAL_GRID[0][i] = "1"
    ORIGINAL_GRID[19][i] = "1"

for i in range(0, 20):
    ORIGINAL_GRID[i][0] = "1"
    ORIGINAL_GRID[i][29] = "1"

ORIGINAL_GRID[1][1] = ORIGINAL_GRID[1][2] = "1"
for i in range(4, 29):
    ORIGINAL_GRID[3][i] = "1"

ORIGINAL_GRID[4][1] = ORIGINAL_GRID[4][2] = ORIGINAL_GRID[5][1] = ORIGINAL_GRID[5][2] = "1"

for i in range(1, 26):
    ORIGINAL_GRID[6][i] = "1"

ORIGINAL_GRID[7][18] = ORIGINAL_GRID[7][19] = ORIGINAL_GRID[7][20] = "1"
ORIGINAL_GRID[8][18] = ORIGINAL_GRID[8][19] = ORIGINAL_GRID[8][20] = "1"

for i in range(4, 16):
    ORIGINAL_GRID[9][i] = "1"

for i in range(17, 22):
    ORIGINAL_GRID[11][i] = "1"

for i in range(22, 29):
    ORIGINAL_GRID[10][i] = "1"

ORIGINAL_GRID[10][16] = "1"
ORIGINAL_GRID[11][26] = ORIGINAL_GRID[11][27] = ORIGINAL_GRID[11][28] = "1"
ORIGINAL_GRID[12][26] = ORIGINAL_GRID[12][27] = ORIGINAL_GRID[12][28] = "1"

for i in range(1, 13):
    ORIGINAL_GRID[13][i] = "1"

ORIGINAL_GRID[14][13] = "1"

for i in range(14, 27):
    ORIGINAL_GRID[15][i] = "1"

for i in range(1, 7):
    ORIGINAL_GRID[16][i] = "1"

ORIGINAL_GRID[17][27] = ORIGINAL_GRID[17][28] = ORIGINAL_GRID[18][27] = ORIGINAL_GRID[18][28] = "1"

# Copy original grid to working grid
grid = [[cell for cell in row] for row in ORIGINAL_GRID]

# Track slide positions and active buttons
yellow_slides_positions = []
purple_slides_positions = []
yellow_buttons_active = False
purple_buttons_active = False
slide_animation_counter = 0
SLIDE_ANIMATION_SPEED = 15  # Controls how fast slides move

# Initialize slide positions
def initialize_slides():
    global yellow_slides_positions, purple_slides_positions
    yellow_slides_positions = []
    purple_slides_positions = []
    
    for y in range(GRID_HEIGHT):
        for x in range(GRID_WIDTH):
            if ORIGINAL_GRID[y][x] == "YS":
                yellow_slides_positions.append((y, x))
            elif ORIGINAL_GRID[y][x] == "PS":
                purple_slides_positions.append((y, x))

# Create grid maps for environment features
def create_grid_map(feature_type):
    """Create a boolean grid map for different environmental features"""
    grid_map = np.zeros((GRID_HEIGHT, GRID_WIDTH), dtype=bool)
    
    for y in range(GRID_HEIGHT):
        for x in range(GRID_WIDTH):
            cell = grid[y][x]
            if feature_type == 'wall' and cell == "1":
                grid_map[y][x] = True
            elif feature_type == 'red_poison' and cell == "RP":
                grid_map[y][x] = True
            elif feature_type == 'blue_poison' and cell == "BP":
                grid_map[y][x] = True
            elif feature_type == 'green_poison' and cell == "GP":
                grid_map[y][x] = True
            elif feature_type == 'red_gem' and cell == "RB":
                grid_map[y][x] = True
            elif feature_type == 'blue_gem' and cell == "BB":
                grid_map[y][x] = True
            elif feature_type == 'fire_goal' and cell == "FG":
                grid_map[y][x] = True
            elif feature_type == 'water_goal' and cell == "WG":
                grid_map[y][x] = True
            elif feature_type == 'yellow_button' and cell == "YB":
                grid_map[y][x] = True
            elif feature_type == 'purple_button' and cell == "PB":
                grid_map[y][x] = True
            elif feature_type == 'yellow_slide' and cell == "YS":
                grid_map[y][x] = True
            elif feature_type == 'purple_slide' and cell == "PS":
                grid_map[y][x] = True
    
    return grid_map

# Create all grid maps
def initialize_environment_maps():
    wall_map = create_grid_map('wall')
    red_poison_map = create_grid_map('red_poison')
    blue_poison_map = create_grid_map('blue_poison')
    green_poison_map = create_grid_map('green_poison')
    red_gem_map = create_grid_map('red_gem')
    blue_gem_map = create_grid_map('blue_gem')
    fire_goal_map = create_grid_map('fire_goal')
    water_goal_map = create_grid_map('water_goal')
    yellow_button_map = create_grid_map('yellow_button')
    purple_button_map = create_grid_map('purple_button')
    yellow_slide_map = create_grid_map('yellow_slide')
    purple_slide_map = create_grid_map('purple_slide')
    
    # Create combined maps
    poison_map = red_poison_map | blue_poison_map | green_poison_map
    
    # Create environment dictionary
    environment = {
        'wall': wall_map,
        'poison': poison_map,
        'red_poison': red_poison_map,
        'blue_poison': blue_poison_map,
        'green_poison': green_poison_map,
        'red_gems': red_gem_map.copy(),
        'blue_gems': blue_gem_map.copy(),
        'fire_goal': fire_goal_map,
        'water_goal': water_goal_map,
        'yellow_button': yellow_button_map,
        'purple_button': purple_button_map,
        'yellow_slide': yellow_slide_map,
        'purple_slide': purple_slide_map
    }
    
    return environment

# Agent position setup
def get_initial_positions():
    fireboy_pos = None
    watergirl_pos = None
    green_pos = None
    
    for y in range(GRID_HEIGHT):
        for x in range(GRID_WIDTH):
            if grid[y][x] == "F":
                fireboy_pos = (y, x)
            elif grid[y][x] == "W":
                watergirl_pos = (y, x)
            elif grid[y][x] == "G":
                green_pos = (y, x)
    
    return fireboy_pos, watergirl_pos, green_pos

# Draw Function
def draw_grid():
    for y in range(GRID_HEIGHT):
        for x in range(GRID_WIDTH):
            tile = grid[y][x]
            color = COLORS.get(tile, BASE_COLOR)

            # Agents: Heart shape
            if tile in ["F", "W", "G"]:
                pygame.draw.rect(screen, BASE_COLOR, (x*TILE_SIZE, y*TILE_SIZE, TILE_SIZE, TILE_SIZE))
                pygame.draw.polygon(screen, color, [
                    (x*TILE_SIZE + TILE_SIZE//2, y*TILE_SIZE + 6),
                    (x*TILE_SIZE + 6, y*TILE_SIZE + TILE_SIZE//2),
                    (x*TILE_SIZE + TILE_SIZE - 6, y*TILE_SIZE + TILE_SIZE//2),
                ])
                pygame.draw.circle(screen, color, (x*TILE_SIZE + TILE_SIZE//3, y*TILE_SIZE + TILE_SIZE//3), 5)
                pygame.draw.circle(screen, color, (x*TILE_SIZE + 2*TILE_SIZE//3, y*TILE_SIZE + TILE_SIZE//3), 5)

            # Gems: Rhombus
            elif tile in ["RB", "BB"]:
                pygame.draw.rect(screen, BASE_COLOR, (x*TILE_SIZE, y*TILE_SIZE, TILE_SIZE, TILE_SIZE))
                pygame.draw.polygon(screen, color, [
                    (x*TILE_SIZE + TILE_SIZE//2, y*TILE_SIZE),
                    (x*TILE_SIZE + TILE_SIZE, y*TILE_SIZE + TILE_SIZE//2),
                    (x*TILE_SIZE + TILE_SIZE//2, y*TILE_SIZE + TILE_SIZE),
                    (x*TILE_SIZE, y*TILE_SIZE + TILE_SIZE//2)
                ])

            # Buttons: Circle
            elif tile in ["YB", "PB"]:
                pygame.draw.rect(screen, BASE_COLOR, (x*TILE_SIZE, y*TILE_SIZE, TILE_SIZE, TILE_SIZE))
                pygame.draw.circle(screen, color, (x*TILE_SIZE + TILE_SIZE//2, y*TILE_SIZE + TILE_SIZE//2), TILE_SIZE//4)

            # Goals: Triangle
            elif tile in ["FG", "WG"]:
                pygame.draw.rect(screen, BASE_COLOR, (x*TILE_SIZE, y*TILE_SIZE, TILE_SIZE, TILE_SIZE))
                pygame.draw.polygon(screen, color, [
                    (x*TILE_SIZE + TILE_SIZE//2, y*TILE_SIZE),
                    (x*TILE_SIZE, y*TILE_SIZE + TILE_SIZE),
                    (x*TILE_SIZE + TILE_SIZE, y*TILE_SIZE + TILE_SIZE)
                ])

            # Slides: Full tile fill
            elif tile in ["YS", "PS"]:
                pygame.draw.rect(screen, color, (x*TILE_SIZE, y*TILE_SIZE, TILE_SIZE, TILE_SIZE))
                
            # Poison: Now blocks with specific colors
            elif tile in ["RP", "BP", "GP"]:
                pygame.draw.rect(screen, color, (x*TILE_SIZE, y*TILE_SIZE, TILE_SIZE, TILE_SIZE))

            # Default or walls
            else:
                pygame.draw.rect(screen, color, (x*TILE_SIZE, y*TILE_SIZE, TILE_SIZE, TILE_SIZE))

            # Draw the black border
            pygame.draw.rect(screen, (0, 0, 0), (x*TILE_SIZE, y*TILE_SIZE, TILE_SIZE, TILE_SIZE), 1)

# Enhanced neural network for MAPPO
class MAPPONetwork(nn.Module):
    def __init__(self, input_size, hidden_size, action_size):
        super(MAPPONetwork, self).__init__()
        # Actor network (policy)
        self.actor = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, action_size),
        )
        
        # Critic network (value function)
        self.critic = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1)
        )
    
    def forward(self, state):
        # Return action probabilities and value
        return F.softmax(self.actor(state), dim=-1), self.critic(state)
    
    def get_action(self, state, deterministic=False):
        probs, _ = self.forward(state)
        if deterministic:
            return torch.argmax(probs).item()
        else:
            return torch.multinomial(probs, 1).item()
    
    def evaluate(self, state, action):
        probs, value = self.forward(state)
        dist = torch.distributions.Categorical(probs)
        log_prob = dist.log_prob(action)
        entropy = dist.entropy()
        return log_prob, entropy, value

# Define the DQN for the green agent
class DQN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(DQN, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)
    
    def forward(self, state):
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        return self.fc3(x)  # Q-values (not softmax)

# Memory buffer for MAPPO
class MAPPOMemory:
    def __init__(self):
        self.states = []
        self.actions = []
        self.rewards = []
        self.next_states = []
        self.dones = []
        self.log_probs = []
        self.values = []
    
    def push(self, state, action, reward, next_state, done, log_prob, value):
        self.states.append(state)
        self.actions.append(action)
        self.rewards.append(reward)
        self.next_states.append(next_state)
        self.dones.append(done)
        self.log_probs.append(log_prob)
        self.values.append(value)
    
    def clear(self):
        self.states.clear()
        self.actions.clear()
        self.rewards.clear()
        self.next_states.clear()
        self.dones.clear()
        self.log_probs.clear()
        self.values.clear()
    
    def __len__(self):
        return len(self.states)

# Experience replay memory for DQN
class ReplayMemory:
    def __init__(self, capacity):
        self.memory = deque(maxlen=capacity)
    
    def push(self, state, action, next_state, reward, done):
        self.memory.append((state, action, next_state, reward, done))
    
    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)
    
    def __len__(self):
        return len(self.memory)

# Initialize neural networks with proper hidden sizes
HIDDEN_SIZE = 128
INPUT_SIZE_MAPPO = 16  # Expanded input size for better state representation
# FIX: Match the input size to the actual features in the state
INPUT_SIZE_DQN = 25    # Updated to match actual state size from get_dqn_state

# Initialize MAPPO networks for cooperative agents
policy_f = MAPPONetwork(INPUT_SIZE_MAPPO, HIDDEN_SIZE, NUM_ACTIONS)
policy_w = MAPPONetwork(INPUT_SIZE_MAPPO, HIDDEN_SIZE, NUM_ACTIONS)

# Initialize DQN for the green adversarial agent
dqn = DQN(INPUT_SIZE_DQN, HIDDEN_SIZE, NUM_ACTIONS)
target_dqn = DQN(INPUT_SIZE_DQN, HIDDEN_SIZE, NUM_ACTIONS)
target_dqn.load_state_dict(dqn.state_dict())  # Initially the same as the online network

# Optimizers
optimizer_f = optim.Adam(policy_f.parameters(), lr=0.0005)
optimizer_w = optim.Adam(policy_w.parameters(), lr=0.0005)
optimizer_dqn = optim.Adam(dqn.parameters(), lr=0.0005)

# MAPPO memory buffers
memory_f = MAPPOMemory()
memory_w = MAPPOMemory()

# DQN experience replay memory
dqn_memory = ReplayMemory(10000)

# DQN hyperparameters
BATCH_SIZE = 64
GAMMA = 0.99  # Discount factor
EPSILON_START = 1.0
EPSILON_END = 0.1
EPSILON_DECAY = 10000
TARGET_UPDATE = 10  # How often to update target network
epsilon = EPSILON_START

# MAPPO hyperparameters
PPO_EPOCHS = 4
PPO_EPSILON = 0.2  # Clip parameter
VALUE_COEF = 0.5
ENTROPY_COEF = 0.01
GAE_LAMBDA = 0.95

# Initialize agent state
def reset_environment():
    # Reset grid to original state first
    global grid, yellow_buttons_active, purple_buttons_active, slide_animation_counter
    grid = [[cell for cell in row] for row in ORIGINAL_GRID]
    yellow_buttons_active = False
    purple_buttons_active = False
    slide_animation_counter = 0
    
    # Initialize slides
    initialize_slides()
    
    # Get fresh positions from the reset grid
    fireboy_pos, watergirl_pos, green_pos = get_initial_positions()
    
    # Create agent dictionaries with position and other properties
    red_agent = {'pos': fireboy_pos, 'color': 'red', 'is_pressing_button': False, 'goal_reached': False, 'collected_gems': 0, 'died': False}
    blue_agent = {'pos': watergirl_pos, 'color': 'blue', 'is_pressing_button': False, 'goal_reached': False, 'collected_gems': 0, 'died': False}
    green_agent = {'pos': green_pos, 'color': 'green', 'is_pressing_button': False, 'goal_reached': False, 'collected_gems': 0, 'died': False}
    
    # Create fresh environment maps from the reset grid
    environment = initialize_environment_maps()
    
    return red_agent, blue_agent, green_agent, environment

# Handle slide movement
def update_slides():
    global yellow_slides_positions, purple_slides_positions, grid, slide_animation_counter
    
    # Clear previous slide positions from grid
    for y in range(GRID_HEIGHT):
        for x in range(GRID_WIDTH):
            if grid[y][x] == "YS" or grid[y][x] == "PS":
                grid[y][x] = "0"
    
    # Update yellow slides
    new_yellow_positions = []
    for y, x in yellow_slides_positions:
        # If button is pressed, move down 3 rows, otherwise stay in original position
        if yellow_buttons_active:
            target_y = min(y + 3, GRID_HEIGHT - 1)
            # Check if we're in animation
            if slide_animation_counter < SLIDE_ANIMATION_SPEED:
                # Calculate intermediate position
                animation_progress = slide_animation_counter / SLIDE_ANIMATION_SPEED
                current_y = int(y + (target_y - y) * animation_progress)
                new_yellow_positions.append((current_y, x))
                grid[current_y][x] = "YS"
            else:
                new_yellow_positions.append((target_y, x))
                grid[target_y][x] = "YS"
        else:
            # If buttons not active, return to original position
            if slide_animation_counter < SLIDE_ANIMATION_SPEED:
                original_y = [pos[0] for pos in yellow_slides_positions if pos[1] == x][0]
                animation_progress = slide_animation_counter / SLIDE_ANIMATION_SPEED
                # Find the current position in the animation
                if any(pos[1] == x for pos in yellow_slides_positions):
                    current_pos = [pos for pos in yellow_slides_positions if pos[1] == x][0]
                    current_y = int(current_pos[0] - (current_pos[0] - original_y) * animation_progress)
                    new_yellow_positions.append((current_y, x))
                    grid[current_y][x] = "YS"
            else:
                new_yellow_positions.append((y, x))
                grid[y][x] = "YS"
    
    # Update purple slides
    new_purple_positions = []
    for y, x in purple_slides_positions:
        # If button is pressed, move down 3 rows, otherwise stay in original position
        if purple_buttons_active:
            target_y = min(y + 3, GRID_HEIGHT - 1)
            # Check if we're in animation
            if slide_animation_counter < SLIDE_ANIMATION_SPEED:
                # Calculate intermediate position
                animation_progress = slide_animation_counter / SLIDE_ANIMATION_SPEED
                current_y = int(y + (target_y - y) * animation_progress)
                new_purple_positions.append((current_y, x))
                grid[current_y][x] = "PS"
            else:
                new_purple_positions.append((target_y, x))
                grid[target_y][x] = "PS"
        else:
            # If buttons not active, return to original position
            if slide_animation_counter < SLIDE_ANIMATION_SPEED:
                original_y = [pos[0] for pos in purple_slides_positions if pos[1] == x][0]
                animation_progress = slide_animation_counter / SLIDE_ANIMATION_SPEED
                # Find the current position in the animation
                if any(pos[1] == x for pos in purple_slides_positions):
                    current_pos = [pos for pos in purple_slides_positions if pos[1] == x][0]
                    current_y = int(current_pos[0] - (current_pos[0] - original_y) * animation_progress)
                    new_purple_positions.append((current_y, x))
                    grid[current_y][x] = "PS"
            else:
                new_purple_positions.append((y, x))
                grid[y][x] = "PS"
    
    # Update slide positions for next frame
    yellow_slides_positions = new_yellow_positions
    purple_slides_positions = new_purple_positions
    
    # Update animation counter
    if slide_animation_counter < SLIDE_ANIMATION_SPEED:
        slide_animation_counter += 1

# Enhanced state representation for MAPPO
def get_mappo_state(agent, other_agent, green_agent, environment):
    y, x = agent['pos']
    oy, ox = other_agent['pos']
    gy, gx = green_agent['pos']
    
    # Find locations of goals
    fire_goal_positions = np.where(environment['fire_goal'])
    water_goal_positions = np.where(environment['water_goal'])
    fgy, fgx = fire_goal_positions[0][0], fire_goal_positions[1][0] if len(fire_goal_positions[0]) > 0 else (0, 0)
    wgy, wgx = water_goal_positions[0][0], water_goal_positions[1][0] if len(water_goal_positions[0]) > 0 else (0, 0)
    
    # Target goal based on agent color
    goal_y, goal_x = (fgy, fgx) if agent['color'] == 'red' else (wgy, wgx)
    
    # Calculate distances
    dist_to_goal = np.sqrt((y - goal_y)**2 + (x - goal_x)**2) / np.sqrt(GRID_HEIGHT**2 + GRID_WIDTH**2)
    dist_to_other = np.sqrt((y - oy)**2 + (x - ox)**2) / np.sqrt(GRID_HEIGHT**2 + GRID_WIDTH**2)
    dist_to_green = np.sqrt((y - gy)**2 + (x - gx)**2) / np.sqrt(GRID_HEIGHT**2 + GRID_WIDTH**2)
    
    # Comprehensive state representation
    state = [
        y / GRID_HEIGHT,                  # Normalized y position
        x / GRID_WIDTH,                   # Normalized x position
        oy / GRID_HEIGHT,                 # Other agent y position
        ox / GRID_WIDTH,                  # Other agent x position
        gy / GRID_HEIGHT,                 # Green agent y position
        gx / GRID_WIDTH,                  # Green agent x position
        dist_to_goal,                     # Distance to own goal
        dist_to_other,                    # Distance to other agent
        dist_to_green,                    # Distance to green agent
        1.0 if agent['is_pressing_button'] else 0.0,  # Is agent pressing button
        1.0 if other_agent['is_pressing_button'] else 0.0,  # Is other agent pressing button
        1.0 if yellow_buttons_active else 0.0,        # Is yellow button active
        1.0 if purple_buttons_active else 0.0,        # Is purple button active
        1.0 if agent['goal_reached'] else 0.0,        # Has agent reached goal
        agent['collected_gems'] / 2.0,                # Normalized collected gems
        other_agent['collected_gems'] / 2.0           # Other agent's collected gems
    ]
    
    return torch.tensor(state, dtype=torch.float)

# Enhanced state representation for DQN (green agent)
def get_dqn_state(green_agent, red_agent, blue_agent, environment):
    gy, gx = green_agent['pos']
    ry, rx = red_agent['pos']
    by, bx = blue_agent['pos']
    
    # Find locations of goals
    fire_goal_positions = np.where(environment['fire_goal'])
    water_goal_positions = np.where(environment['water_goal'])
    fgy, fgx = fire_goal_positions[0][0], fire_goal_positions[1][0] if len(fire_goal_positions[0]) > 0 else (0, 0)
    wgy, wgx = water_goal_positions[0][0], water_goal
    wgy, wgx = water_goal_positions[0][0], water_goal_positions[1][0] if len(water_goal_positions[0]) > 0 else (0, 0)
    
    # Distance calculations
    dist_to_red = np.sqrt((gy - ry)**2 + (gx - rx)**2) / np.sqrt(GRID_HEIGHT**2 + GRID_WIDTH**2)
    dist_to_blue = np.sqrt((gy - by)**2 + (gx - bx)**2) / np.sqrt(GRID_HEIGHT**2 + GRID_WIDTH**2)
    dist_to_fire_goal = np.sqrt((gy - fgy)**2 + (gx - fgx)**2) / np.sqrt(GRID_HEIGHT**2 + GRID_WIDTH**2)
    dist_to_water_goal = np.sqrt((gy - wgy)**2 + (gx - wgx)**2) / np.sqrt(GRID_HEIGHT**2 + GRID_WIDTH**2)
    
    # Get button information
    yellow_button_positions = np.where(environment['yellow_button'])
    purple_button_positions = np.where(environment['purple_button'])
    
    # Find distances to buttons
    dist_to_yellow_button = min([np.sqrt((gy - y)**2 + (gx - x)**2) for y, x in zip(yellow_button_positions[0], yellow_button_positions[1])]) / np.sqrt(GRID_HEIGHT**2 + GRID_WIDTH**2) if len(yellow_button_positions[0]) > 0 else 1.0
    dist_to_purple_button = min([np.sqrt((gy - y)**2 + (gx - x)**2) for y, x in zip(purple_button_positions[0], purple_button_positions[1])]) / np.sqrt(GRID_HEIGHT**2 + GRID_WIDTH**2) if len(purple_button_positions[0]) > 0 else 1.0
    
    # Check surroundings (4 adjacent cells)
    directions = [(0, 1), (1, 0), (0, -1), (-1, 0)]
    surroundings = []
    for dy, dx in directions:
        ny, nx = gy + dy, gx + dx
        if 0 <= ny < GRID_HEIGHT and 0 <= nx < GRID_WIDTH:
            if environment['wall'][ny][nx]:
                surroundings.append(1.0)  # Wall
            elif (ny, nx) == (ry, rx):
                surroundings.append(2.0)  # Red agent
            elif (ny, nx) == (by, bx):
                surroundings.append(3.0)  # Blue agent
            else:
                surroundings.append(0.0)  # Empty
        else:
            surroundings.append(1.0)  # Out of bounds (treat as wall)
    
    # Comprehensive state representation for the green agent
    state = [
        gy / GRID_HEIGHT,                  # Normalized y position
        gx / GRID_WIDTH,                   # Normalized x position
        ry / GRID_HEIGHT,                  # Red agent y position
        rx / GRID_WIDTH,                   # Red agent x position
        by / GRID_HEIGHT,                  # Blue agent y position
        bx / GRID_WIDTH,                   # Blue agent x position
        dist_to_red,                       # Distance to red agent
        dist_to_blue,                      # Distance to blue agent
        dist_to_fire_goal,                 # Distance to fire goal
        dist_to_water_goal,                # Distance to water goal
        dist_to_yellow_button,             # Distance to closest yellow button
        dist_to_purple_button,             # Distance to closest purple button
        1.0 if yellow_buttons_active else 0.0,  # Is yellow button active
        1.0 if purple_buttons_active else 0.0,  # Is purple button active
        1.0 if green_agent['is_pressing_button'] else 0.0,  # Is green agent pressing button
        1.0 if red_agent['goal_reached'] else 0.0,  # Has red agent reached goal
        1.0 if blue_agent['goal_reached'] else 0.0,  # Has blue agent reached goal
        red_agent['collected_gems'] / 2.0,  # Red agent's collected gems
        blue_agent['collected_gems'] / 2.0,  # Blue agent's collected gems
    ] + surroundings  # Add surroundings (4 values)
    
    return torch.tensor(state, dtype=torch.float)

# Get actions and update agent positions
def step_environment(red_agent, blue_agent, green_agent, environment, red_action, blue_action, green_action):
    # Process button presses first (from previous state)
    check_button_presses(red_agent, blue_agent, green_agent, environment)
    
    # Update slide positions based on button activity
    update_slides()
    
    # Update environment maps after potential slide movement
    environment = initialize_environment_maps()
    
    # Process agent movements based on actions
    red_done = move_agent(red_agent, environment, red_action)
    blue_done = move_agent(blue_agent, environment, blue_action)
    green_done = move_agent(green_agent, environment, green_action)
    
    # Check collectibles and hazards after movement
    red_reward = check_collectibles_and_hazards(red_agent, environment)
    blue_reward = check_collectibles_and_hazards(blue_agent, environment)
    green_reward = check_collectibles_and_hazards(green_agent, environment)
    
    # Check if agents reached their goals
    red_reached_goal = check_goal(red_agent, environment, 'fire_goal')
    blue_reached_goal = check_goal(blue_agent, environment, 'water_goal')
    
    # Calculate reward based on game state
    # For cooperative agents (red and blue)
    if red_agent['died'] or blue_agent['died']:
        red_reward -= 10
        blue_reward -= 10
    
    if red_reached_goal:
        red_agent['goal_reached'] = True
        red_reward += 20
        blue_reward += 10
    
    if blue_reached_goal:
        blue_agent['goal_reached'] = True
        blue_reward += 20
        red_reward += 10
    
    # Both reached goal is a big win
    if red_agent['goal_reached'] and blue_agent['goal_reached']:
        red_reward += 50
        blue_reward += 50
    
    # For adversarial agent (green)
    if red_agent['died'] or blue_agent['died']:
        green_reward += 15
    
    if red_agent['goal_reached'] or blue_agent['goal_reached']:
        green_reward -= 10
    
    if red_agent['goal_reached'] and blue_agent['goal_reached']:
        green_reward -= 30
    
    # Check if episode is done
    done = (red_agent['died'] or blue_agent['died'] or 
            (red_agent['goal_reached'] and blue_agent['goal_reached']) or
            red_done or blue_done or green_done)
    
    # Update grid representation for visualization
    update_grid_representation(red_agent, blue_agent, green_agent)
    
    return red_reward, blue_reward, green_reward, done, environment

# Check if agent is on a button
def check_button_presses(red_agent, blue_agent, green_agent, environment):
    global yellow_buttons_active, purple_buttons_active, slide_animation_counter
    
    # Reset button states
    yellow_buttons_pressed = False
    purple_buttons_pressed = False
    
    # Check each agent
    for agent in [red_agent, blue_agent, green_agent]:
        y, x = agent['pos']
        if environment['yellow_button'][y][x]:
            yellow_buttons_pressed = True
            agent['is_pressing_button'] = True
        elif environment['purple_button'][y][x]:
            purple_buttons_pressed = True
            agent['is_pressing_button'] = True
        else:
            agent['is_pressing_button'] = False
    
    # If button state changed, reset animation counter
    if yellow_buttons_pressed != yellow_buttons_active or purple_buttons_pressed != purple_buttons_active:
        slide_animation_counter = 0
    
    yellow_buttons_active = yellow_buttons_pressed
    purple_buttons_active = purple_buttons_pressed

# Move agent based on action
def move_agent(agent, environment, action):
    y, x = agent['pos']
    
    if action == 0:  # UP
        y -= 1
    elif action == 1:  # RIGHT
        x += 1
    elif action == 2:  # DOWN
        y += 1
    elif action == 3:  # LEFT
        x -= 1
    
    # Check if new position is valid
    if 0 <= y < GRID_HEIGHT and 0 <= x < GRID_WIDTH and not environment['wall'][y][x]:
        agent['pos'] = (y, x)
    
    # Check if agent is out of bounds or in an invalid position
    return not (0 <= y < GRID_HEIGHT and 0 <= x < GRID_WIDTH)

# Check for collectibles and hazards
def check_collectibles_and_hazards(agent, environment):
    y, x = agent['pos']
    reward = 0
    
    # Check for gems
    if agent['color'] == 'red' and environment['red_gems'][y][x]:
        environment['red_gems'][y][x] = False
        agent['collected_gems'] += 1
        reward += 5
        # Clear gem from grid for visualization
        grid[y][x] = "0"
    
    elif agent['color'] == 'blue' and environment['blue_gems'][y][x]:
        environment['blue_gems'][y][x] = False
        agent['collected_gems'] += 1
        reward += 5
        # Clear gem from grid for visualization
        grid[y][x] = "0"
    
    # Check for poison (specific to agent type)
    if (agent['color'] == 'red' and environment['blue_poison'][y][x]) or \
       (agent['color'] == 'blue' and environment['red_poison'][y][x]) or \
       (agent['color'] in ['red', 'blue', 'green'] and environment['green_poison'][y][x]):
        agent['died'] = True
        reward -= 20
    
    return reward

# Check if agent reached goal
def check_goal(agent, environment, goal_type):
    y, x = agent['pos']
    if environment[goal_type][y][x]:
        return True
    return False

# Update grid representation for visualization
def update_grid_representation(red_agent, blue_agent, green_agent):
    # Clear agent positions from grid
    for y in range(GRID_HEIGHT):
        for x in range(GRID_WIDTH):
            if grid[y][x] in ["F", "W", "G"]:
                grid[y][x] = "0"
    
    # Add agents to grid
    ry, rx = red_agent['pos']
    by, bx = blue_agent['pos']
    gy, gx = green_agent['pos']
    
    grid[ry][rx] = "F"
    grid[by][bx] = "W"
    grid[gy][gx] = "G"

# Select action using MAPPO policies (with exploration)
def select_mappo_action(policy, state, deterministic=False):
    with torch.no_grad():
        action_probs, value = policy(state)
        if deterministic:
            action = torch.argmax(action_probs).item()
        else:
            action = torch.multinomial(action_probs, 1).item()
        log_prob = torch.log(action_probs[action])
        
    return action, log_prob.item(), value.item()

# Select action using DQN (with epsilon-greedy exploration)
def select_dqn_action(state, epsilon):
    if random.random() < epsilon:
        return random.randint(0, NUM_ACTIONS - 1)
    else:
        with torch.no_grad():
            q_values = dqn(state)
            return torch.argmax(q_values).item()

# Training functions
def train_mappo(memories, policies, optimizers, gamma, epsilon_clip, value_coef, entropy_coef):
    for memory, policy, optimizer in zip(memories, policies, optimizers):
        if len(memory) == 0:
            continue
        
        # Convert lists to tensors
        states = torch.stack(memory.states)
        actions = torch.tensor(memory.actions)
        rewards = torch.tensor(memory.rewards)
        next_states = torch.stack(memory.next_states)
        dones = torch.tensor(memory.dones).float()
        old_log_probs = torch.tensor(memory.log_probs)
        old_values = torch.tensor(memory.values)
        
        # Compute returns and advantages
        returns = []
        advantages = []
        gae = 0
        
        # Calculate GAE advantage and returns
        with torch.no_grad():
            _, next_values = policy(next_states)
            next_values = next_values.squeeze(-1)
            
            for t in reversed(range(len(rewards))):
                if t == len(rewards) - 1:
                    next_value = next_values[t] * (1 - dones[t])
                else:
                    next_value = old_values[t + 1] * (1 - dones[t])
                
                delta = rewards[t] + gamma * next_value - old_values[t]
                gae = delta + gamma * GAE_LAMBDA * (1 - dones[t]) * gae
                advantages.insert(0, gae)
                returns.insert(0, gae + old_values[t])
        
        advantages = torch.tensor(advantages)
        returns = torch.tensor(returns)
        
        # Normalize advantages
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
        
        # PPO update
        for _ in range(PPO_EPOCHS):
            # Evaluate actions
            log_probs, entropies, values = [], [], []
            for i in range(len(states)):
                log_prob, entropy, value = policy.evaluate(states[i], actions[i])
                log_probs.append(log_prob)
                entropies.append(entropy)
                values.append(value)
            
            log_probs = torch.stack(log_probs)
            entropies = torch.stack(entropies)
            values = torch.cat(values)
            
            # Compute ratio and clipped ratio
            ratios = torch.exp(log_probs - old_log_probs)
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1.0 - epsilon_clip, 1.0 + epsilon_clip) * advantages
            
            # Compute losses
            policy_loss = -torch.min(surr1, surr2).mean()
            value_loss = F.mse_loss(values, returns)
            entropy_loss = -entropies.mean()
            
            # Total loss
            loss = policy_loss + value_coef * value_loss - entropy_coef * entropy_loss
            
            # Update policy
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        
        # Clear memory
        memory.clear()

# Train DQN function
def train_dqn(memory, batch_size, gamma):
    if len(memory) < batch_size:
        return
    
    # Sample from replay memory
    batch = memory.sample(batch_size)
    states, actions, next_states, rewards, dones = zip(*batch)
    
    # Convert to tensors
    states = torch.stack(states)
    actions = torch.tensor(actions).unsqueeze(1)
    next_states = torch.stack(next_states)
    rewards = torch.tensor(rewards).unsqueeze(1)
    dones = torch.tensor(dones, dtype=torch.float).unsqueeze(1)
    
    # Calculate current Q values
    current_q_values = dqn(states).gather(1, actions)
    
    # Calculate target Q values
    with torch.no_grad():
        max_next_q_values = target_dqn(next_states).max(1)[0].unsqueeze(1)
        target_q_values = rewards + (gamma * max_next_q_values * (1 - dones))
    
    # Compute loss and update
    loss = F.smooth_l1_loss(current_q_values, target_q_values)
    
    optimizer_dqn.zero_grad()
    loss.backward()
    optimizer_dqn.step()

# Main game loop
def game_loop():
    global epsilon
    
    # Stats and counters
    episode = 0
    steps = 0
    total_red_reward = 0
    total_blue_reward = 0
    total_green_reward = 0
    
    # Initialize environment
    red_agent, blue_agent, green_agent, environment = reset_environment()
    
    running = True
    while running:
        # Handle Pygame events
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                running = False
            elif event.type == pygame.KEYDOWN:
                if event.key == pygame.K_ESCAPE:
                    running = False
        
        # Get states
        red_state = get_mappo_state(red_agent, blue_agent, green_agent, environment)
        blue_state = get_mappo_state(blue_agent, red_agent, green_agent, environment)
        green_state = get_dqn_state(green_agent, red_agent, blue_agent, environment)
        
        # Select actions
        red_action, red_log_prob, red_value = select_mappo_action(policy_f, red_state)
        blue_action, blue_log_prob, blue_value = select_mappo_action(policy_w, blue_state)
        green_action = select_dqn_action(green_state, epsilon)
        
        # Update epsilon for exploration
        epsilon = max(EPSILON_END, EPSILON_START - steps / EPSILON_DECAY)
        
        # Take step in environment
        red_reward, blue_reward, green_reward, done, environment = step_environment(
            red_agent, blue_agent, green_agent, environment, red_action, blue_action, green_action
        )
        
        # Get next states
        next_red_state = get_mappo_state(red_agent, blue_agent, green_agent, environment)
        next_blue_state = get_mappo_state(blue_agent, red_agent, green_agent, environment)
        next_green_state = get_dqn_state(green_agent, red_agent, blue_agent, environment)
        
        # Store transitions in memory
        memory_f.push(red_state, red_action, red_reward, next_red_state, done, red_log_prob, red_value)
        memory_w.push(blue_state, blue_action, blue_reward, next_blue_state, done, blue_log_prob, blue_value)
        dqn_memory.push(green_state, green_action, next_green_state, green_reward, done)
        
        # Train agents
        train_mappo([memory_f, memory_w], [policy_f, policy_w], [optimizer_f, optimizer_w], 
                   GAMMA, PPO_EPSILON, VALUE_COEF, ENTROPY_COEF)
        train_dqn(dqn_memory, BATCH_SIZE, GAMMA)
        
        # Update target network occasionally
        steps += 1
        if steps % TARGET_UPDATE == 0:
            target_dqn.load_state_dict(dqn.state_dict())
        
        # Draw the grid
        screen.fill((0, 0, 0))
        draw_grid()
        pygame.display.flip()
        
        # Update rewards
        total_red_reward += red_reward
        total_blue_reward += blue_reward
        total_green_reward += green_reward
        
        # Check if episode is done
        if done:
            print(f"Episode {episode} completed!")
            print(f"Red Reward: {total_red_reward:.2f}, Blue Reward: {total_blue_reward:.2f}, Green Reward: {total_green_reward:.2f}")
            print(f"Red Gems: {red_agent['collected_gems']}, Blue Gems: {blue_agent['collected_gems']}")
            print(f"Red Goal: {red_agent['goal_reached']}, Blue Goal: {blue_agent['goal_reached']}")
            print(f"Red Died: {red_agent['died']}, Blue Died: {blue_agent['died']}")
            
            # Reset environment for next episode
            red_agent, blue_agent, green_agent, environment = reset_environment()
            total_red_reward = total_blue_reward = total_green_reward = 0
            episode += 1
        
        # Cap frame rate
        clock.tick(FPS)

    pygame.quit()
    sys.exit()

# Main entry point
if __name__ == "__main__":
    initialize_slides()
    game_loop()

pygame 2.6.1 (SDL 2.28.4, Python 3.12.7)
Hello from the pygame community. https://www.pygame.org/contribute.html


NameError: name 'water_goal' is not defined