# Human Player Test for State Punishment Beta

This notebook allows you to play the State Punishment Beta game as a human player. It provides:

1. **Environment State Visualization**: Shows the current state of all agents in the environment
2. **Punishment Level Display**: Shows the current punishment level and voting statistics
3. **Agent Metrics**: Displays scalar values like rewards, social harm, individual scores, and encounters for each agent
4. **Interactive Gameplay**: Step-by-step control with visual feedback

## How to Play
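- Every agent in the environment is human-controlled, so on each turn you are prompted once per agent
- At the `Select Action:` prompt, type a key and press Enter
- Movement: `w`=Up, `a`=Left, `s`=Down, `d`=Right
- Voting: `4`=Vote Increase, `5`=Vote Decrease
- No action: `6`
- Type `quit` to exit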

## Game Mechanics
- Collect resources (A, B, C, D, E) to gain points
- Some resources are taboo and will cause punishment
- Vote to increase or decrease the punishment level
- Social harm affects all agents when taboo resources are collected


In [None]:
# Setup and imports
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import clear_output
import time
from pathlib import Path

# Add the sorrel module to path
module_path = os.path.abspath('../../..')
if module_path not in sys.path:
    sys.path.insert(0, module_path)

# Import sorrel components
from sorrel.examples.state_punishment_beta.env import StatePunishmentEnv
from sorrel.examples.state_punishment_beta.world import StatePunishmentWorld
from sorrel.examples.state_punishment_beta.agents import StatePunishmentAgent
from sorrel.examples.state_punishment_beta.entities import EmptyEntity
from sorrel.models.human_player import HumanPlayer
from sorrel.action.action_spec import ActionSpec
from sorrel.observation.observation_spec import OneHotObservationSpec
from sorrel.utils.visualization import plot, render_sprite

print("Imports successful!")


In [None]:
# Configuration for the game
# Nested settings dict. It is passed to both StatePunishmentWorld and
# StatePunishmentEnv when the environment is created, and is also read directly
# by the game loop (experiment.max_turns) and by the analysis helper
# (world.*_value and world.taboo_resources).
config = {
    "experiment": {
        "epochs": 1,
        "max_turns": 50,  # upper bound on turns played in run_human_game()
        "record_period": 50,
        "run_name": "human_player_test",
        "num_agents": 3,
        "initial_resources": 15,
    },
    # NOTE(review): these look like learning hyperparameters. The human-controlled
    # agents in this notebook skip memory storage (add_memory is a no-op), so they
    # are presumably only consumed when the env builds its default agents — confirm.
    "model": {
        "agent_vision_radius": 2,
        "epsilon": 0.0,  # No exploration for human player
        "epsilon_decay": 0.001,
        "full_view": True,  # also decides whether env_dims is passed to the observation spec
        "layer_size": 128,
        "n_frames": 3,
        "n_step": 3,
        "sync_freq": 100,
        "model_update_freq": 4,
        "batch_size": 64,
        "memory_size": 512,
        "LR": 0.00025,
        "TAU": 0.001,
        "GAMMA": 0.99,
        "n_quantiles": 8,
        "device": "cpu",
    },
    "world": {
        "height": 10,
        "width": 10,
        # Per-resource values; printed as "Resource Values" by analyze_game_results().
        "a_value": 3.0,
        "b_value": 7.0,
        "c_value": 2.0,
        "d_value": -2.0,
        "e_value": 1.0,
        "spawn_prob": 0.05,
        "respawn_prob": 0.02,
        "init_punishment_prob": 0.1,
        "punishment_magnitude": -10.0,
        "change_per_vote": 0.2,
        # NOTE(review): all five resources are listed as taboo, although the intro
        # text says only "some" resources are taboo — confirm this is intended.
        "taboo_resources": ["A", "B", "C", "D", "E"],
        # Five equal weights that sum to 1.0.
        "entity_spawn_probs": {
            "A": 0.2, "B": 0.2, "C": 0.2, "D": 0.2, "E": 0.2
        }
    },
    "use_composite_views": False,
    "use_composite_actions": False,
    "use_multi_env_composite": False,
}

print("Configuration loaded!")


In [None]:
# Create custom human player model for state punishment
class StatePunishmentHumanPlayer(HumanPlayer):
    """Human-driven model that maps keyboard input to state punishment actions.

    Accepted input at the ``Select Action:`` prompt (leading/trailing whitespace
    and letter case are ignored):

    * ``w`` / ``s`` / ``a`` / ``d`` -> 0 (up) / 1 (down) / 2 (left) / 3 (right)
    * ``4`` -> vote increase, ``5`` -> vote decrease, ``6`` -> no action
    * any number that appears in ``self.action_list``
    * ``quit`` -> raises ``KeyboardInterrupt``
    """

    # Keyboard shortcut -> action index.
    _KEY_TO_ACTION = {
        "w": 0,  # Up
        "s": 1,  # Down
        "a": 2,  # Left
        "d": 3,  # Right
        "4": 4,  # Vote increase
        "5": 5,  # Vote decrease
        "6": 6,  # No action
    }

    # Number of invalid entries tolerated before the prompt gives up.
    _MAX_INVALID_INPUTS = 5

    def _render_state(self, state: np.ndarray) -> None:
        """Clear the cell output and plot the observation as one image per frame.

        Relies on ``SLICE``, ``input_size``, ``tile_size`` and ``num_channels``
        being provided by the ``HumanPlayer`` base class.
        """
        clear_output(wait=True)

        # Drop the leading ``SLICE`` columns, then reshape the remainder back
        # into (frames, height_px, width_px, channels) images.
        frames = state[:, self.SLICE:]
        frames = frames.reshape(
            (
                -1,
                self.input_size[0] * self.tile_size,
                self.input_size[1] * self.tile_size,
                self.num_channels,
            )
        )
        frames = np.array(frames, dtype=int)
        plot([frames[i, :, :, :] for i in range(frames.shape[0])])

    def take_action(self, state: np.ndarray):
        """Prompt the human for an action and return its integer index.

        Args:
            state: Observation handed over by the agent. It is only used for
                rendering, and only when ``self.show`` is truthy.

        Returns:
            The selected action index as an ``int``.

        Raises:
            KeyboardInterrupt: If the user types ``quit`` or enters more than
                ``_MAX_INVALID_INPUTS`` invalid inputs.
        """
        if self.show:
            self._render_state(state)

        invalid_inputs = 0
        while True:
            # Normalise so that "W", " w " and "QUIT" are accepted as well.
            key = input("Select Action: ").strip().lower()

            if key in self._KEY_TO_ACTION:
                return self._KEY_TO_ACTION[key]
            # Direct action numbers
            if key in [str(act) for act in self.action_list]:
                return int(key)
            if key == "quit":
                raise KeyboardInterrupt("Quitting...")

            invalid_inputs += 1
            if invalid_inputs > self._MAX_INVALID_INPUTS:
                raise KeyboardInterrupt("Too many invalid inputs. Quitting...")
            print("Please try again. Possible actions:")
            print("Movement: w=Up, s=Down, a=Left, d=Right")
            print("Voting: 4=Increase punishment, 5=Decrease punishment")
            print("Other: 6=No action, quit=Exit")
            print(f"Or enter action number: {list(self.action_list)}")

print("Custom human player created!")


In [None]:
# Create custom human agent that bypasses memory stacking
class StatePunishmentHumanAgent(StatePunishmentAgent):
    """State punishment agent whose decisions come straight from its model.

    Both overrides exist so that a human-driven model can be plugged in without
    going through the memory stacking / experience storage of the base agent.
    """

    def get_action(self, state: np.ndarray) -> int:
        """Return the action chosen by ``self.model`` for the raw ``state``.

        The state is forwarded unchanged; no previous frames are stacked onto it.
        """
        return self.model.take_action(state)

    def add_memory(self, state: np.ndarray, action: int, reward: float, done: bool) -> None:
        """Discard the transition.

        A human player does not learn from stored experience, so nothing is
        recorded; this also sidesteps the dimension mismatch that storing the
        state would otherwise cause.
        """
        return None

print("Custom human agent created!")


In [None]:
# Create the environment and world
# The same config dict is handed to both constructors; an EmptyEntity instance
# is supplied as the world's default_entity.
world = StatePunishmentWorld(config=config, default_entity=EmptyEntity())
env = StatePunishmentEnv(world, config)

# Echo the key settings so the player can check the setup before starting.
print(f"Environment created with {config['world']['height']}x{config['world']['width']} grid")
print(f"Number of agents: {config['experiment']['num_agents']}")
print(f"Taboo resources: {config['world']['taboo_resources']}")
print(f"Initial punishment probability: {config['world']['init_punishment_prob']}")


In [None]:
# Replace all agents with human players
# Each entry of env.agents is overwritten by index with a freshly built
# StatePunishmentHumanAgent; the original `agent` object is not used in the loop.
# NOTE(review): only the env.agents list entry is swapped here. If the environment
# has already placed the original agents in the world, confirm that env.reset()
# (called in run_human_game) places the replacements.
for i, agent in enumerate(env.agents):
    # Create observation spec
    # NOTE(review): presumably this list must match the entity kinds the world can
    # contain — confirm against the environment's own agent setup.
    entity_list = ["EmptyEntity", "Wall", "A", "B", "C", "D", "E", "StatePunishmentAgent"]
    observation_spec = OneHotObservationSpec(
        entity_list,
        full_view=config["model"]["full_view"],
        vision_radius=config["model"]["agent_vision_radius"],
        # env_dims is only passed when full_view is enabled; otherwise None.
        env_dims=(config["world"]["height"], config["world"]["width"]) if config["model"]["full_view"] else None,
    )
    
    # Create action spec
    # The list positions 0..6 line up with the key mapping in
    # StatePunishmentHumanPlayer.take_action (w=0, s=1, a=2, d=3, "4"=4, "5"=5, "6"=6).
    # Assumes ActionSpec numbers actions in list order — TODO confirm.
    action_names = ["up", "down", "left", "right", "vote_increase", "vote_decrease", "noop"]
    action_spec = ActionSpec(action_names)
    
    # Create human player model
    # show=True makes take_action() render the observation before prompting.
    human_model = StatePunishmentHumanPlayer(
        input_size=observation_spec.input_size,
        action_space=action_spec.n_actions,
        memory_size=1,
        show=True
    )
    
    # Create human agent
    # The three composite flags are hard-coded to False here, matching the
    # values in the top-level config.
    human_agent = StatePunishmentHumanAgent(
        observation_spec=observation_spec,
        action_spec=action_spec,
        model=human_model,
        agent_id=i,
        use_composite_views=False,
        use_composite_actions=False,
        use_multi_env_composite=False,
    )
    
    # Replace the original agent
    env.agents[i] = human_agent

print(f"Replaced {len(env.agents)} agents with human players!")


In [None]:
# Function to display environment state and metrics
def display_game_state(env, turn, step_info=None):
    """Render a 2x2 dashboard of the current game state.

    Panels:
        * top-left: colour-coded grid built from layer 0 of ``env.world.map``,
          with an ``A<i>`` label at each agent's location;
        * top-right: current punishment probability plus vote counts;
        * bottom-left: bar chart of each agent's ``individual_score``;
        * bottom-right: table of social harm and per-resource encounters.

    Args:
        env: Environment exposing ``world`` (``height``, ``width``, ``map``,
            ``state_system``, ``get_social_harm``) and ``agents``.
        turn: Turn number shown in the figure title.
        step_info: Optional dict with ``action``, ``reward``, ``punishment``
            and ``social_harm`` entries; printed below the figure when truthy.
    """
    clear_output(wait=True)

    # RGB colour per entity kind; any other kind is drawn white (empty).
    kind_colors = {
        'Wall': [0, 0, 0],                  # Black
        'StatePunishmentAgent': [0, 0, 1],  # Blue
        'A': [1, 0, 0],                     # Red
        'B': [0, 1, 0],                     # Green
        'C': [1, 1, 0],                     # Yellow
        'D': [1, 0, 1],                     # Magenta
        'E': [0, 1, 1],                     # Cyan
    }
    default_color = [1, 1, 1]

    # Create visualization of the world
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    fig.suptitle(f'State Punishment Beta - Turn {turn}', fontsize=16, fontweight='bold')

    # Plot 1: Full environment view
    ax1 = axes[0, 0]
    ax1.set_title('Environment State (All Agents)', fontweight='bold')

    world_map = np.zeros((env.world.height, env.world.width, 3))
    for y in range(env.world.height):
        for x in range(env.world.width):
            entity = env.world.map[y, x, 0]
            # Entities without a `kind` attribute keep the zero (black) fill.
            if hasattr(entity, 'kind'):
                world_map[y, x] = kind_colors.get(entity.kind, default_color)

    ax1.imshow(world_map)
    ax1.set_xlabel('X Position')
    ax1.set_ylabel('Y Position')
    ax1.grid(True, alpha=0.3)

    # Add agent labels
    for i, agent in enumerate(env.agents):
        if hasattr(agent, 'location') and agent.location:
            # location is indexed as (row, column); imshow wants (x=column, y=row).
            y, x = agent.location[0], agent.location[1]
            ax1.text(x, y, f'A{i}', ha='center', va='center',
                     color='white', fontweight='bold', fontsize=12)

    # Plot 2: Punishment level and voting info
    ax2 = axes[0, 1]
    ax2.set_title('Punishment System', fontweight='bold')

    current_prob = env.world.state_system.prob
    if current_prob > 0.5:
        prob_color = 'red'
    elif current_prob > 0.2:
        prob_color = 'orange'
    else:
        prob_color = 'green'
    # A real newline (not a literal backslash-n) splits the tick label in two.
    ax2.bar(['Current Punishment\nProbability'], [current_prob], color=prob_color)
    ax2.set_ylim(0, 1)
    ax2.set_ylabel('Probability')

    vote_stats = env.world.state_system.get_epoch_vote_stats()
    ax2.text(0, 0.8, f'Votes Up: {vote_stats["vote_up"]}', fontsize=10)
    ax2.text(0, 0.7, f'Votes Down: {vote_stats["vote_down"]}', fontsize=10)
    ax2.text(0, 0.6, f'Total Votes: {vote_stats["total_votes"]}', fontsize=10)

    # Plot 3: Agent individual scores
    ax3 = axes[1, 0]
    ax3.set_title('Agent Individual Scores', fontweight='bold')

    agent_scores = [agent.individual_score for agent in env.agents]
    agent_labels = [f'Agent {i}' for i in range(len(env.agents))]
    bars = ax3.bar(agent_labels, agent_scores, color=['blue', 'green', 'red'][:len(agent_scores)])
    ax3.set_ylabel('Score')

    # Add value labels on bars
    for bar, score in zip(bars, agent_scores):
        ax3.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.1,
                 f'{score:.1f}', ha='center', va='bottom')

    # Plot 4: Agent encounters and social harm, shown as a table
    ax4 = axes[1, 1]
    ax4.set_title('Agent Encounters & Social Harm', fontweight='bold')
    ax4.axis('off')

    header = ['Agent', 'Social Harm', 'A', 'B', 'C', 'D', 'E']
    rows = []
    for i, agent in enumerate(env.agents):
        social_harm = env.world.get_social_harm(i)
        encounters = agent.encounters
        # Encounter counts are looked up under lowercase keys; missing keys show 0.
        counts = [str(encounters.get(resource, 0)) for resource in ['a', 'b', 'c', 'd', 'e']]
        rows.append([f'Agent {i}', f'{social_harm:.1f}'] + counts)

    table = ax4.table(cellText=rows, colLabels=header,
                      cellLoc='center', loc='center')
    table.auto_set_font_size(False)
    table.set_fontsize(9)
    table.scale(1, 2)

    # Style the header row
    for col in range(len(header)):
        table[(0, col)].set_facecolor('#40466e')
        table[(0, col)].set_text_props(weight='bold', color='white')

    plt.tight_layout()
    plt.show()
    # This function runs every turn; release the figure so figures do not
    # accumulate on backends that keep them open after show().
    plt.close(fig)

    # Print additional step information if provided
    if step_info:
        print("\nStep Information:")
        print(f"Action taken: {step_info.get('action', 'N/A')}")
        print(f"Reward received: {step_info.get('reward', 0):.2f}")
        print(f"Punishment applied: {step_info.get('punishment', 0):.2f}")
        print(f"Social harm received: {step_info.get('social_harm', 0):.2f}")

print("Display function created!")


In [None]:
# Main game loop
def run_human_game():
    """Run the human player game with step-by-step visualization.

    Uses the notebook-level ``env`` and ``config`` objects. After a start
    prompt the environment is reset, and each turn:

    1. the dashboard is redrawn with ``display_game_state``;
    2. every agent in ``env.agents`` is asked for an action, acts, and has the
       returned reward added to its ``individual_score``;
    3. resources are spawned and the punishment level is recorded.

    The loop stops after ``config['experiment']['max_turns']`` turns or once
    ``env.world.is_done`` is truthy. ``KeyboardInterrupt`` (raised when the
    player types ``quit``) ends the game quietly; any other exception is
    reported with a traceback instead of propagating.
    """
    # Display names indexed by action id.
    action_labels = ['Up', 'Down', 'Left', 'Right', 'Vote Increase', 'Vote Decrease', 'No Action']

    print("Starting State Punishment Beta Human Player Game!")
    print("\nControls:")
    print("Movement: w=Up, s=Down, a=Left, d=Right")
    print("Voting: 4=Increase punishment, 5=Decrease punishment")
    print("Other: 6=No action, quit=Exit")
    print("\nPress Enter to start...")
    input()

    # Reset environment
    env.reset()

    turn = 0
    done = False

    try:
        while not done and turn < config['experiment']['max_turns']:
            turn += 1

            # Display current state
            display_game_state(env, turn)

            # Let each agent (human player) take a turn
            for agent_idx, agent in enumerate(env.agents):
                print(f"\n--- Agent {agent_idx}'s Turn ---")

                # Observe, choose, act.
                state = agent.pov(env.world)
                action = agent.get_action(state)
                reward = agent.act(env.world, action)

                # Update individual score
                agent.individual_score += reward

                social_harm = env.world.get_social_harm(agent_idx)

                # Guard the label lookup so an action id outside the label
                # table cannot abort the game with an IndexError.
                if 0 <= action < len(action_labels):
                    label = action_labels[action]
                else:
                    label = 'Unknown'

                print(f"Action: {action} ({label})")
                print(f"Reward: {reward:.2f}")
                print(f"Individual Score: {agent.individual_score:.2f}")
                print(f"Social Harm: {social_harm:.2f}")

                # Check if done
                if env.world.is_done:
                    done = True
                    break

            # Spawn new resources after all agents have moved. This also runs
            # on the turn in which the world reported done.
            env._spawn_resources()

            # Record punishment level for this turn
            env.world.record_punishment_level()

            # Small delay for better visualization
            time.sleep(0.5)

        # Final display
        display_game_state(env, turn)
        print("\n=== GAME OVER ===")
        print("Final scores:")
        for i, agent in enumerate(env.agents):
            print(f"Agent {i}: {agent.individual_score:.2f}")

        # Final punishment statistics
        vote_stats = env.world.state_system.get_epoch_vote_stats()
        print(f"\nFinal punishment probability: {env.world.state_system.prob:.3f}")
        print(f"Total votes cast: {vote_stats['total_votes']}")
        print(f"Votes to increase: {vote_stats['vote_up']}")
        print(f"Votes to decrease: {vote_stats['vote_down']}")

    except KeyboardInterrupt:
        print("\nGame interrupted by user.")
    except Exception as e:
        # Report the failure without propagating it out of the notebook cell.
        print(f"\nError during game: {e}")
        import traceback
        traceback.print_exc()

print("Game loop function created!")


In [None]:
# Run the game
# Interactive: waits for Enter to start, then prompts for one action per agent
# on every turn. Type 'quit' at a prompt to stop early.
run_human_game()


## Game Analysis

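After playing the game, you can analyze the results with the `analyze_game_results(env)` helper defined in the cell below (its call is commented out by default; uncomment it to run). Questions to explore: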

1. **Punishment Dynamics**: How did the punishment level change throughout the game?
2. **Agent Behavior**: Which agents collected taboo resources and how did this affect others?
3. **Voting Patterns**: How did the voting behavior influence the punishment system?
4. **Social Harm**: How did social harm accumulate and affect agent scores?

### Key Metrics to Consider:
- **Individual Scores**: Each agent's total reward
- **Social Harm**: Negative impact on other agents from taboo resource collection
- **Punishment Probability**: Current level of punishment in the system
- **Voting Statistics**: How agents voted to change punishment levels
- **Resource Encounters**: Which resources each agent collected


In [None]:
# Optional: Analyze game results
def analyze_game_results(env):
    """Print a text summary of a completed game.

    Sections: per-agent performance, punishment probability (initial, final,
    change) with vote counts, transgression statistics, and the configured
    resource values. Resource values and the taboo list are read from the
    notebook-level ``config`` dict.

    Args:
        env: Environment exposing ``agents`` and ``world`` (``state_system``,
            ``get_social_harm``).
    """
    print("=== GAME ANALYSIS ===")

    # Agent performance
    print("\nAgent Performance:")
    for i, agent in enumerate(env.agents):
        print(f"Agent {i}:")
        print(f"  Individual Score: {agent.individual_score:.2f}")
        print(f"  Social Harm Received: {env.world.get_social_harm(i):.2f}")
        print(f"  Resource Encounters: {dict(agent.encounters)}")
        print(f"  Vote History: {agent.vote_history}")

    # Punishment system analysis
    state_system = env.world.state_system
    print("\nPunishment System Analysis:")
    print(f"Initial Probability: {state_system.init_prob:.3f}")
    print(f"Final Probability: {state_system.prob:.3f}")
    print(f"Probability Change: {state_system.prob - state_system.init_prob:.3f}")

    vote_stats = state_system.get_epoch_vote_stats()
    print(f"Total Votes: {vote_stats['total_votes']}")
    print(f"Votes to Increase: {vote_stats['vote_up']}")
    print(f"Votes to Decrease: {vote_stats['vote_down']}")

    # The ratio is only defined once at least one vote has been cast.
    if vote_stats['total_votes'] > 0:
        print(f"Vote Ratio (Increase/Total): {vote_stats['vote_up'] / vote_stats['total_votes']:.3f}")

    # Transgression analysis: show only the entries whose key mentions
    # 'transgressions'.
    print("\nTransgression Analysis:")
    transgression_stats = state_system.get_transgression_stats()
    for stat_name, stat_value in transgression_stats.items():
        if 'transgressions' in stat_name:
            print(f"{stat_name}: {stat_value}")

    # Resource value analysis
    print("\nResource Values:")
    world_cfg = config['world']
    for resource in ['A', 'B', 'C', 'D', 'E']:
        value = world_cfg[f'{resource.lower()}_value']
        is_taboo = resource in world_cfg['taboo_resources']
        print(f"{resource}: {value} (Taboo: {is_taboo})")

# Uncomment to run analysis after the game
# analyze_game_results(env)
