# Human Player Multi-World Test

This notebook allows you to play the state punishment game with multiple agents using the same multi-world visualization system.

## Features:
- Interactive human player control
- Multi-world visualization (2×3 grid layout)
- Real-time punishment level display
- Step-by-step visualization generation


In [6]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from IPython.display import display, clear_output
import time

from sorrel.examples.state_punishment_beta_copy.entities import EmptyEntity
from sorrel.examples.state_punishment_beta_copy.env import MultiAgentStatePunishmentEnv, StatePunishmentEnv
from sorrel.examples.state_punishment_beta_copy.world import StatePunishmentWorld
from sorrel.examples.state_punishment_beta_copy.agents import StatePunishmentAgent
from sorrel.observation.observation_spec import OneHotObservationSpec
from sorrel.action.action_spec import ActionSpec
from sorrel.models.pytorch import PyTorchIQN
from sorrel.models.human_player import HumanPlayer


In [7]:
def create_multi_world_config(num_agents=3):
    """Create configuration for multi-world human player test using same settings as main.py."""
    from sorrel.examples.state_punishment_beta_copy.config import create_config
    
    # Use the same configuration as main.py
    config = create_config(
        num_agents=num_agents,
        epochs=1,  # Just for human play
        use_composite_views=False,
        use_composite_actions=False,
        use_multi_env_composite=False,
        simple_foraging=True,  # Allow voting to change punishment level
        use_random_policy=False,
        fixed_punishment_level=0.2,
        map_size=10,
        num_resources=20,
        learning_rate=0.00025,
        batch_size=64,
        memory_size=1024,
    )
    
    # Override some settings for human play
    config["experiment"]["max_turns"] = 50  # Longer episodes for human play
    config["experiment"]["record_period"] = 1  # Record every turn
    config["model"]["seed"] = 42  # Add seed for reproducibility
    
    return config


In [8]:
class MultiWorldHumanPlayer:
    """Human player for multi-world state punishment game with visualization."""
    
    def __init__(self, num_agents=3):
        self.num_agents = num_agents
        self.config = create_multi_world_config(num_agents)
        self.setup_environments()
        self.current_agent = 0  # Which agent the human controls
        self.turn_count = 0
        self.visualizations = []  # Store visualizations for each turn
        
    def setup_environments(self):
        """Set up the multi-agent environment with human player using proper environment setup."""
        from sorrel.examples.state_punishment_beta_copy.environment_setup import (
            create_shared_state_system,
            create_shared_social_harm,
            create_individual_environments,
            create_multi_agent_environment
        )
        
        # Create shared state system and social harm (same as main.py)
        shared_state_system = create_shared_state_system(
            self.config, 
            simple_foraging=False, 
            fixed_punishment_level=0.2
        )
        shared_social_harm = create_shared_social_harm(self.num_agents)
        
        # Create individual environments (one agent per world)
        individual_envs = create_individual_environments(
            self.config,
            num_agents=self.num_agents,
            simple_foraging=True,  # Allow voting to change punishment level
            use_random_policy=False  # We'll override the first agent with human player
        )
        
        # Replace the first agent with human player
        first_env = individual_envs[0]
        first_agent = first_env.agents[0]
        
        # Get the location of the original agent before replacing
        original_location = first_agent.location
        
        # Create human player model
        human_model = HumanPlayer(
            input_size=first_agent.observation_spec.input_size,
            action_space=first_agent.action_spec.n_actions,
            memory_size=self.config["model"]["memory_size"]
        )
        
        # Create new human player agent with custom get_action method
        class HumanPlayerAgent(StatePunishmentAgent):
            """StatePunishmentAgent with HumanPlayer model that overrides get_action."""
            
            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)
                # Override kind to match StatePunishmentAgent for observation compatibility
                self.kind = "StatePunishmentAgent"
            
            def get_action(self, state: np.ndarray) -> int:
                """Override get_action to work with HumanPlayer model."""
                if self.use_random_policy:
                    return np.random.randint(0, self.action_spec.n_actions)
                
                # For HumanPlayer, we need to format the state differently
                # HumanPlayer expects the state to be reshaped for visualization
                model_input = state.reshape(1, -1)
                action = self.model.take_action(model_input)
                return action
            
            def add_memory(self, state: np.ndarray, action: int, reward: float, done: bool) -> None:
                """Override add_memory to work with HumanPlayer model.
                
                HumanPlayer has a different memory structure that's incompatible
                with the standard state format, so we skip memory addition for human players.
                """
                # Skip memory addition for human players since they don't need training
                pass
        
        human_agent = HumanPlayerAgent(
            observation_spec=first_agent.observation_spec,
            action_spec=first_agent.action_spec,
            model=human_model,
            agent_id=0,
            simple_foraging=False  # Allow voting to change punishment level
        )
        
        # Place the human agent at the same location as the original agent
        first_env.world.add(original_location, human_agent)
        
        # Replace the agent in the environment
        first_env.agents[0] = human_agent
        
        # Create multi-agent environment
        self.multi_agent_env = create_multi_agent_environment(
            individual_envs=individual_envs,
            shared_state_system=shared_state_system,
            shared_social_harm=shared_social_harm
        )
    
    def generate_visualization(self):
        """Generate multi-world visualization for current state."""
        from sorrel.utils.visualization import render_sprite, image_from_array
        from PIL import Image, ImageDraw, ImageFont
        
        # Render each individual world
        world_images = []
        for env in self.multi_agent_env.individual_envs:
            full_sprite = render_sprite(env.world)
            world_img = image_from_array(full_sprite)
            world_images.append(world_img)
        
        # Create 2x3 grid layout
        rows, cols = 2, 3
        
        # Get dimensions of individual images
        if world_images:
            img_width, img_height = world_images[0].size
        else:
            return None
        
        # Create combined image
        combined_width = cols * img_width
        combined_height = rows * img_height
        combined_img = Image.new('RGB', (combined_width, combined_height), (255, 255, 255))
        
        # Place each world image in the grid
        for i, world_img in enumerate(world_images):
            if i >= rows * cols:
                break
                
            row = i // cols
            col = i % cols
            
            x = col * img_width
            y = row * img_height
            
            combined_img.paste(world_img, (x, y))
        
        # Add labels
        draw = ImageDraw.Draw(combined_img)
        try:
            font = ImageFont.truetype("arial.ttf", 16)
        except:
            font = ImageFont.load_default()
        
        # Add world labels
        for i, world_img in enumerate(world_images):
            if i >= rows * cols:
                break
                
            row = i // cols
            col = i % cols
            
            x = col * img_width + 5
            y = row * img_height + 5
            
            # World label with player type
            if i == 0:
                label = f"World {i+1} (HUMAN)"
                color = (0, 0, 255)  # Blue for human
            else:
                label = f"World {i+1} (AI)"
                color = (0, 0, 0)  # Black for AI
            
            draw.text((x, y), label, fill=color, font=font)
        
        # Add global punishment level in bottom right corner
        punishment_level = self.multi_agent_env.shared_state_system.prob
        punishment_text = f"Punishment Level: {punishment_level:.3f}"
        text_x = combined_width - 200
        text_y = combined_height - 30
        draw.text((text_x, text_y), punishment_text, fill=(255, 0, 0), font=font)
        
        # Add turn counter
        turn_text = f"Turn: {self.turn_count}"
        draw.text((5, combined_height - 30), turn_text, fill=(0, 0, 0), font=font)
        
        return combined_img
    
    def display_current_state(self):
        """Display the current multi-world state."""
        img = self.generate_visualization()
        if img:
            plt.figure(figsize=(15, 10))
            plt.imshow(img)
            plt.axis('off')
            plt.title(f"Multi-World State - Turn {self.turn_count}")
            plt.show()
            
            # Store visualization
            self.visualizations.append(img.copy())
    
    def show_action_guide(self):
        """Display the action guide for the human player."""
        print("\n" + "="*60)
        print("🎮 HUMAN PLAYER ACTION GUIDE")
        print("="*60)
        print("MOVEMENT ACTIONS:")
        print("  0: Up (W key)")
        print("  1: Down (S key)")
        print("  2: Left (A key)")
        print("  3: Right (D key)")
        print("\nVOTING ACTIONS:")
        print("  4: No operation (do nothing)")
        print("  5: Vote to INCREASE punishment level")
        print("  6: Vote to DECREASE punishment level")
        print("\nCONTROLS:")
        print("  • Use WASD keys OR numbers 0-6")
        print("  • Type 'quit' to exit")
        print("  • Current punishment level affects all agents")
        print("="*60)
    
    def show_current_state_info(self):
        """Display current social harm and punishment level."""
        print("\n" + "="*50)
        print("📊 CURRENT STATE INFO")
        print("="*50)
        
        # Social harm
        print("🔴 Social Harm:")
        for agent_id, harm in self.multi_agent_env.shared_social_harm.items():
            print(f"  Agent {agent_id + 1}: {harm:.3f}")
        
        # Punishment level
        print(f"\n⚖️  Punishment Level: {self.multi_agent_env.shared_state_system.prob:.3f}")
        
        # Agent scores
        print("\n💰 Agent Scores:")
        for i, env in enumerate(self.multi_agent_env.individual_envs):
            agent = env.agents[0]
            print(f"  Agent {i + 1}: {agent.individual_score:.3f}")
        
        print("="*50)
    
    def play_turn(self):
        """Play one turn of the game."""
        # Display current state
        self.display_current_state()
        
        print(f"\nTURN {self.turn_count + 1} - HUMAN PLAYER'S TURN")
        print("="*50)
        print("Current punishment level:", f"{self.multi_agent_env.shared_state_system.prob:.3f}")
        
        # Show action guide on first turn
        if self.turn_count == 0:
            self.show_action_guide()
        
        print("\nThe HumanPlayer model will now prompt you for input...")
        
        # Capture state before the turn
        social_harm_before = self.multi_agent_env.shared_social_harm.copy()
        punishment_before = self.multi_agent_env.shared_state_system.prob
        
        # Execute the turn - HumanPlayer will handle input automatically
        self.multi_agent_env.take_turn()
        
        # Capture state after the turn
        social_harm_after = self.multi_agent_env.shared_social_harm.copy()
        punishment_after = self.multi_agent_env.shared_state_system.prob
        agent_rewards = {}
        
        # Get individual agent rewards and scores
        for i, env in enumerate(self.multi_agent_env.individual_envs):
            agent = env.agents[0]
            agent_rewards[f"Agent {i+1}"] = {
                "individual_score": agent.individual_score,
                "last_action": agent.last_action if hasattr(agent, 'last_action') else "N/A"
            }
        
        # Display turn results
        print(f"\n📊 TURN {self.turn_count + 1} RESULTS:")
        print("="*50)
        
        # Social harm changes
        print("🔴 Social Harm Changes:")
        for agent_id, harm_after in social_harm_after.items():
            harm_before = social_harm_before.get(agent_id, 0.0)
            harm_change = harm_after - harm_before
            if harm_change != 0:
                print(f"  Agent {agent_id + 1}: {harm_before:.3f} → {harm_after:.3f} (Δ{harm_change:+.3f})")
            else:
                print(f"  Agent {agent_id + 1}: {harm_after:.3f} (no change)")
        
        # Agent rewards and actions
        print("\n💰 Agent Rewards & Actions:")
        for agent_name, info in agent_rewards.items():
            action_names = ["Up", "Down", "Left", "Right", "Noop", "Vote+", "Vote-"]
            action_name = action_names[info["last_action"]] if isinstance(info["last_action"], int) and 0 <= info["last_action"] < 7 else "Unknown"
            print(f"  {agent_name}: Score={info['individual_score']:.3f}, Action={action_name}")
        
        # Updated punishment level
        punishment_change = punishment_after - punishment_before
        print(f"\n⚖️  Punishment Level: {punishment_before:.3f} → {punishment_after:.3f} (Δ{punishment_change:+.3f})")
        
        self.turn_count += 1
        
        # Check if game is over
        if self.turn_count >= self.config["experiment"]["max_turns"]:
            print("\nGame Over! Maximum turns reached.")
            return False
        
        return True
    
    def play_game(self):
        """Play the complete game."""
        print(f"🎮 Starting Multi-World Human Player Game with {self.num_agents} agents!")
        print(f"You control World 1 (HUMAN), other worlds are controlled by AI.")
        print(f"Maximum turns: {self.config['experiment']['max_turns']}")
        print("\n💡 TIP: Type 'help' during the game to see the action guide again!")
        
        while self.play_turn():
            pass
        
        print("\nFinal Results:")
        print(f"Total turns played: {self.turn_count}")
        print(f"Final punishment level: {self.multi_agent_env.shared_state_system.prob:.3f}")
        
        # Display final state
        self.display_current_state()
        
        return self.visualizations


In [9]:
# Show Action Guide
def show_action_guide():
    """Display the action guide for reference."""
    print("\n" + "="*60)
    print("🎮 HUMAN PLAYER ACTION GUIDE")
    print("="*60)
    print("MOVEMENT ACTIONS:")
    print("  0: Up (W key)")
    print("  1: Down (S key)")
    print("  2: Left (A key)")
    print("  3: Right (D key)")
    print("\nVOTING ACTIONS:")
    print("  4: No operation (do nothing)")
    print("  5: Vote to INCREASE punishment level")
    print("  6: Vote to DECREASE punishment level")
    print("\nCONTROLS:")
    print("  • Use WASD keys OR numbers 0-6")
    print("  • Type 'quit' to exit")
    print("  • Current punishment level affects all agents")
    print("="*60)

# Show Current State Info (only works after game is created)
def show_current_state_info():
    """Display current social harm, punishment level, and agent scores."""
    if 'game' in globals():
        game.show_current_state_info()
    else:
        print("❌ No game created yet. Create a game first with: game = MultiWorldHumanPlayer(num_agents=3)")

# Call these functions anytime to see the guides
show_action_guide()



🎮 HUMAN PLAYER ACTION GUIDE
MOVEMENT ACTIONS:
  0: Up (W key)
  1: Down (S key)
  2: Left (A key)
  3: Right (D key)

VOTING ACTIONS:
  4: No operation (do nothing)
  5: Vote to INCREASE punishment level
  6: Vote to DECREASE punishment level

CONTROLS:
  • Use WASD keys OR numbers 0-6
  • Type 'quit' to exit
  • Current punishment level affects all agents


In [10]:
%load_ext autoreload

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [11]:
# Create and play the game
game = MultiWorldHumanPlayer(num_agents=2)
visualizations = game.play_game()


KeyboardInterrupt: Quitting...

In [None]:
# Optional: Save visualizations as GIF
def save_visualizations_as_gif(visualizations, filename="human_player_game.gif"):
    """Save all visualizations as an animated GIF."""
    if visualizations:
        visualizations[0].save(
            filename,
            save_all=True,
            append_images=visualizations[1:],
            duration=1000,  # 1 second per frame
            loop=0
        )
        print(f"Game saved as {filename}")
    else:
        print("No visualizations to save.")

# Uncomment to save the game as GIF
# save_visualizations_as_gif(visualizations, "human_player_multi_world_game.gif")


In [7]:
# Test with different numbers of agents
print("Testing with 5 agents...")
game_5 = MultiWorldHumanPlayer(num_agents=5)
visualizations_5 = game_5.play_game()


KeyboardInterrupt: Quitting...

In [None]:
print("Testing with 6 agents...")
game_6 = MultiWorldHumanPlayer(num_agents=6)
visualizations_6 = game_6.play_game()
