# Sanity check del notebook para JALAMAgent
Este notebook hace un sanity check de `JointActionLearningAgentModellingAgent` en todos los juegos durante al menos 10 episodes cada uno.

In [1]:
import numpy as np
import matplotlib.pyplot as plt

from games.rps import RPS
from games.mp import MP
from games.blotto import Blotto
from games.foraging import Foraging
from agents.jal_am_agent import JointActionLearningAgentModellingAgent

## Definir juegos y configuraciones
Define los juegos a probar y sus configuraciones.

In [2]:
games_to_test = [
    {"name": "RPS", "game": RPS, "config": {}},
    {"name": "MP", "game": MP, "config": {}},
    {"name": "Blotto", "game": Blotto, "config": {"S": 3, "N": 2}},
    {"name": "Foraging", "game": Foraging, "config": {"config": "Foraging-5x5-2p-1f-v3", "seed": 1}},
]

def get_game_instance(game_entry):
    if game_entry["name"] == "Blotto":
        return game_entry["game"](**game_entry["config"])
    elif game_entry["name"] == "Foraging":
        return game_entry["game"](**game_entry["config"])
    else:
        return game_entry["game"]()

## Inicializar JointActionLearningAgentModellingAgent para cada juego
Para cada juego, inicializar el `JointActionLearningAgentModellingAgent` para todos los agentes en el entorno.

In [3]:
def create_agents(game):
    # Configuración para JointActionLearningAgentModellingAgent
    # Constructor: __init__(self, juego, agente, alpha=0.1, gamma=0.99, epsilon=0.1, min_epsilon=0.01, epsilon_decay=0.995, seed=None)
    return {
        agent_id: JointActionLearningAgentModellingAgent(
            game=game, 
            agent=agent_id, 
            seed=1
        ) for agent_id in game.agents
    }

## Run episodes por game
Ejecutar al menos 10 episodes por juego y recopilar las accumulated rewards.

In [4]:
def play_episodes(game, agents, episodes=10):
    recompensas_acumuladas_totales = {agent_id: 0.0 for agent_id in game.agents}
    # Detect Foraging game para extra Debug prints
    is_foraging_game = "Foraging" in str(type(game))

    for ep in range(episodes):
        if is_foraging_game:
            print(f"[Foraging Debug] Starting Episode {ep+1}/{episodes}")
        
        game.reset()
        for agent_id in game.agents:
            agents[agent_id].reset()  # Reiniciar el estado del agente al comienzo de cada episodio

        turn = 0
        max_turns_per_episode = 200 

        while not (all(game.terminations.values()) or all(game.truncations.values())):
            if is_foraging_game:
                print(f"\n[Foraging Debug] Episode {ep+1}, Turn {turn+1}")

            current_actions = {}
            for agent_id in game.agents:
                try:
                    action = agents[agent_id].action()
                    current_actions[agent_id] = action
                    if is_foraging_game:
                        print(f"[Foraging Debug] Agent {agent_id} chose action: {action}")
                except Exception as e:
                    if is_foraging_game:
                        print(f"[Foraging Debug] ERROR in agent {agent_id}.action(): {e}")
                    import traceback
                    traceback.print_exc()
                    raise 
            
            try:
                game.step(current_actions)
            except Exception as e:
                if is_foraging_game:
                    print(f"[Foraging Debug] ERROR in game.step({current_actions}): {e}")
                import traceback
                traceback.print_exc()
                raise 

            for agent_id in game.agents:
                agents[agent_id].update()  # Llamar a update para cada agente después del paso
                recompensas_acumuladas_totales[agent_id] += game.reward(agent_id)
            
            if is_foraging_game:
                print(f"[Foraging Debug] Rewards after turn {turn+1}: { {ag: game.reward(ag) for ag in game.agents} }")
                print(f"[Foraging Debug] Terminations: {game.terminations}")
                print(f"[Foraging Debug] Truncations: {game.truncations}")

            turn += 1
            if turn >= max_turns_per_episode: 
                if is_foraging_game:
                    print(f"[Foraging Debug] Safety break: Exceeded {max_turns_per_episode} turns in episode {ep+1}.")
                for agent_id in game.agents:
                    game.truncations[agent_id] = True 
                break
        
        if is_foraging_game:
            print(f"[Foraging Debug] Episode {ep+1} finished. Total turns: {turn}")
            print(f"[Foraging Debug] Final cumulative rewards for episode: { {k:v for k,v in recompensas_acumuladas_totales.items()} }")

    return recompensas_acumuladas_totales

## Show results
Mostrar las accumulated rewards para cada Agent en cada game después de 10 episodes.

In [5]:
resultados = {}
for entry in games_to_test:
    print(f"\nTesting {entry['name']} with JointActionLearningAgentModellingAgent...")
    game = get_game_instance(entry)
    agents = create_agents(game)
    recompensas = play_episodes(game, agents, episodes=10)
    resultados[entry['name']] = recompensas
    for agent, recompensa in recompensas.items():
        print(f"Agent {agent}: Total reward in 10 episodes: {recompensa}")


Testing RPS with JointActionLearningAgentModellingAgent...
Agent agent_0: Total reward in 10 episodes: 0.0
Agent agent_1: Total reward in 10 episodes: 0.0

Testing MP with JointActionLearningAgentModellingAgent...
Agent agent_0: Total reward in 10 episodes: 4.0
Agent agent_1: Total reward in 10 episodes: -4.0

Testing Blotto with JointActionLearningAgentModellingAgent...
Agent agent_0: Total reward in 10 episodes: 0.0
Agent agent_1: Total reward in 10 episodes: 0.0

Testing Foraging with JointActionLearningAgentModellingAgent...
[Foraging Debug] Starting Episode 1/10

[Foraging Debug] Episode 1, Turn 1
[Foraging Debug] Agent agent_0 chose action: 4
[Foraging Debug] Agent agent_1 chose action: 1
[Foraging Debug] Rewards after turn 1: {'agent_0': 0, 'agent_1': 0}
[Foraging Debug] Terminations: {'agent_0': False, 'agent_1': False}
[Foraging Debug] Truncations: {'agent_0': False, 'agent_1': False}

[Foraging Debug] Episode 1, Turn 2
[Foraging Debug] Agent agent_0 chose action: 0
[Foraging

  logger.warn(
