# Layer Visualization Test for Stag Hunt Environment

This notebook demonstrates how to visualize each layer of the Stag Hunt environment separately at each timestep. This is useful for:
- Understanding how entities are distributed across layers
- Debugging layer-specific issues
- Seeing how agents, resources, and beams interact across layers
- Comparing terrain, dynamic, and beam layers side by side


In [None]:
import hydra
import numpy as np
import matplotlib.pyplot as plt
from omegaconf import DictConfig, OmegaConf

# sorrel imports
from sorrel.examples.staghunt.agents_v2 import StagHuntAgent
from sorrel.examples.staghunt.entities import Empty, entity_list
from sorrel.examples.staghunt.env import StagHuntEnv
from sorrel.examples.staghunt.world import StagHuntWorld
from sorrel.action.action_spec import ActionSpec
from sorrel.models.human_player import HumanPlayer, HumanObservation
from sorrel.utils.visualization import render_sprite, image_from_array


In [None]:
# Auto-reload edited modules (e.g. the sorrel environment code) before each
# cell runs, so source changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2


### Layer Visualization Functions


In [None]:
def visualize_layers_separately(world, title_prefix="", tile_size=32):
    """Show the first three rendered layers of ``world`` plus a composite view.

    The figure is a 2x2 grid. Panels 1-3 hold the layers returned by
    ``render_sprite``, labelled Terrain, Dynamic and Beam in that order.
    Panel 4 holds ``image_from_array(layers)``. A compositing error is
    reported inside panel 4 instead of aborting the plot.

    Args:
        world: The world to render.
        title_prefix: Text prepended to the figure title, e.g. ``"ASCII Map - "``.
        tile_size: Pixel edge length of one grid cell, passed to
            ``render_sprite``. It is also used to draw one tick/grid line per
            cell and to label those ticks with cell indices instead of pixel
            offsets.

    Side effects:
        Displays the figure. Increments the function attribute
        ``visualize_layers_separately.step``, which numbers the figure title.
    """
    layers = render_sprite(world, tile_size=[tile_size, tile_size])

    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    panels = list(axes.flat)  # row-major: panels[3] is axes[1, 1]

    layer_names = ['Terrain Layer', 'Dynamic Layer', 'Beam Layer']
    layer_descriptions = [
        'Walls, Spawn points, Sand tiles',
        'Agents, Resources (Stag/Hare)',
        'Interaction beams, Effects'
    ]

    # At most three layer panels; the fourth panel is reserved for the composite.
    n_shown = 0
    for ax, layer, name, desc in zip(panels[:3], layers, layer_names, layer_descriptions):
        ax.imshow(layer)
        ax.set_title(f"{name}\n{desc}")
        ax.set_xlabel('X coordinate')
        ax.set_ylabel('Y coordinate')

        # Ticks sit on cell boundaries but are labelled with the cell index,
        # so the axes read in grid coordinates rather than pixels.
        xticks = range(0, layer.shape[1], tile_size)
        yticks = range(0, layer.shape[0], tile_size)
        ax.set_xticks(xticks)
        ax.set_xticklabels([px // tile_size for px in xticks])
        ax.set_yticks(yticks)
        ax.set_yticklabels([px // tile_size for px in yticks])
        ax.grid(True, alpha=0.3)
        n_shown += 1

    # Hide layer panels that had no layer to display.
    for ax in panels[n_shown:3]:
        ax.axis('off')

    # Composited view in the 4th panel; failures are shown, not raised.
    composite_ax = panels[3]
    try:
        composited_array = np.array(image_from_array(layers))
        composite_ax.imshow(composited_array)
        composite_ax.set_title("Composited View\nAll layers combined")
        composite_ax.set_xlabel('X coordinate')
        composite_ax.set_ylabel('Y coordinate')
    except Exception as e:
        composite_ax.text(0.5, 0.5, f"Error compositing: {e}",
                          ha='center', va='center', transform=composite_ax.transAxes)
        composite_ax.set_title("Composited View (Error)")

    step = getattr(visualize_layers_separately, 'step', 0)
    fig.suptitle(f"{title_prefix}Layer Visualization - Step {step}",
                 fontsize=16, fontweight='bold')

    plt.tight_layout()
    plt.show()

    # Advance the step counter used in the next figure's title.
    visualize_layers_separately.step = step + 1


In [None]:
def analyze_layer_contents(world, step=0):
    """Print a per-layer census of entity types plus basic sanity checks.

    Args:
        world: Grid world exposing ``height``, ``width`` and a ``map`` indexed
            as ``map[y, x, layer]``, with the layer count in ``map.shape[2]``.
        step: Step number shown in the report header.

    Prints (returns nothing):
        * For each layer, how many cells hold each entity class name.
        * How many Stag/Hare entities (dynamic layer) sit on a Wall (terrain
          layer). Agent classes are excluded even if their name contains
          "Stag", e.g. ``StagHuntAgent``.
        * The position of every agent (dynamic layer) and the total.
        * The position of every beam (beam layer) and the total.

    Layers 0/1/2 are reported as Terrain/Dynamic/Beam; extra layers are
    reported as ``Layer <n>``. A check that needs a layer the map does not
    have is skipped (its total prints as 0) instead of raising ``IndexError``.
    """
    layer_labels = ('Terrain', 'Dynamic', 'Beam')
    n_layers = world.map.shape[2]

    print(f"\n=== Layer Analysis - Step {step} ===")
    print(f"World dimensions: {world.height}x{world.width}")

    # One dict per layer: entity class name -> number of cells holding it.
    counts_per_layer = [{} for _ in range(n_layers)]
    resources_on_walls = 0
    agent_positions = []
    beam_positions = []

    # A single row-major pass gathers everything. Positions are therefore
    # collected in the same (y, then x) order in which they are printed.
    for y in range(world.height):
        for x in range(world.width):
            cell_types = [type(world.map[y, x, layer]).__name__ for layer in range(n_layers)]
            for counts, type_name in zip(counts_per_layer, cell_types):
                counts[type_name] = counts.get(type_name, 0) + 1

            if n_layers > 1:
                terrain_type, dynamic_type = cell_types[0], cell_types[1]
                is_agent = 'Agent' in dynamic_type
                # 'Stag' also matches StagHuntAgent, so agents must be excluded explicitly.
                is_resource = ('Stag' in dynamic_type or 'Hare' in dynamic_type) and not is_agent
                if 'Wall' in terrain_type and is_resource:
                    resources_on_walls += 1
                if is_agent:
                    agent_positions.append((y, x))
            if n_layers > 2 and 'Beam' in cell_types[2]:
                beam_positions.append((y, x))

    # Counts by layer, entity types sorted alphabetically.
    for layer, counts in enumerate(counts_per_layer):
        header = f"{layer_labels[layer]} Layer" if layer < len(layer_labels) else f"Layer {layer}"
        print(f"\n{header}:")
        for entity_type, count in sorted(counts.items()):
            print(f"  {entity_type}: {count}")

    print("\n=== Issue Checks ===")

    if resources_on_walls == 0:
        print("✓ No resources found on walls")
    else:
        print(f"✗ Found {resources_on_walls} resources on walls")

    for y, x in agent_positions:
        print(f"  Agent at ({y}, {x})")
    print(f"Total agents: {len(agent_positions)}")

    for y, x in beam_positions:
        print(f"  Beam at ({y}, {x})")
    print(f"Total beams: {len(beam_positions)}")


### Test with ASCII Map Generation


In [None]:
def test_ascii_map_layers(config_path="../configs/config_ascii_map.yaml", n_steps=3,
                          n_actions=7, pause=True, seed=None):
    """Build a world from the ASCII-map config and visualize a few random steps.

    Args:
        config_path: YAML config loaded with ``OmegaConf.load``.
        n_steps: Number of environment steps to run after the initial state.
        n_actions: Each agent's action is drawn uniformly from
            ``range(n_actions)``. The default 7 matches the action count this
            notebook assumes (TODO confirm against the agent's ActionSpec).
        pause: If True, wait for Enter after each step. Pass False for
            non-interactive execution (Run All, nbconvert, papermill), where
            ``input()`` may block or raise.
        seed: Optional seed for reproducible actions. ``None`` keeps using
            NumPy's global random state.

    Returns:
        ``(world, experiment)`` after the last step.
    """
    print("=== Testing ASCII Map Generation ===")

    config = OmegaConf.load(config_path)

    # World laid out from the ASCII map described by the config.
    world = StagHuntWorld(config=config, default_entity=Empty())
    experiment = StagHuntEnv(world, config)

    print(f"World dimensions: {world.height}x{world.width}")
    print(f"Number of agents: {len(experiment.agents)}")
    print(f"Agent spawn points: {len(world.agent_spawn_points)}")
    print(f"Resource spawn points: {len(world.resource_spawn_points)}")

    # Same randint API either way; a seeded RandomState makes runs repeatable.
    rng = np.random if seed is None else np.random.RandomState(seed)

    # Restart figure numbering for this run.
    visualize_layers_separately.step = 0

    print("\n--- Initial State ---")
    analyze_layer_contents(world, 0)
    visualize_layers_separately(world, "ASCII Map - ")

    for step in range(1, n_steps + 1):
        print(f"\n--- Step {step} ---")

        actions = []
        for agent in experiment.agents:
            # The random policy ignores the observation. observe() is still
            # called in case it updates agent state (unverified).
            agent.observe(experiment.world)
            actions.append(rng.randint(0, n_actions))

        print(f"Actions: {actions}")

        experiment.step(actions)

        analyze_layer_contents(world, step)
        visualize_layers_separately(world, "ASCII Map - ")

        if pause:
            input("Press Enter to continue to next step...")

    return world, experiment

# Run the test
world_ascii, experiment_ascii = test_ascii_map_layers()


### Test with Random Generation


In [None]:
def test_random_generation_layers(config_path="../configs/config_ascii_map.yaml", height=24,
                                  width=25, n_steps=3, n_actions=7, pause=True, seed=None):
    """Build a randomly generated world and visualize a few random steps.

    The ASCII-map config is loaded and overridden to use
    ``generation_mode = "random"`` with the given dimensions.

    Args:
        config_path: YAML config loaded with ``OmegaConf.load``.
        height, width: Grid size written into ``config.world``.
        n_steps: Number of environment steps to run after the initial state.
        n_actions: Each agent's action is drawn uniformly from
            ``range(n_actions)``. The default 7 matches the action count this
            notebook assumes (TODO confirm against the agent's ActionSpec).
        pause: If True, wait for Enter after each step. Pass False for
            non-interactive execution (Run All, nbconvert, papermill), where
            ``input()`` may block or raise.
        seed: Optional seed for reproducible actions. ``None`` keeps using
            NumPy's global random state.

    Returns:
        ``(world, experiment)`` after the last step.
    """
    print("\n=== Testing Random Generation ===")

    config = OmegaConf.load(config_path)
    config.world.generation_mode = "random"
    config.world.height = height
    config.world.width = width

    world = StagHuntWorld(config=config, default_entity=Empty())
    experiment = StagHuntEnv(world, config)

    print(f"World dimensions: {world.height}x{world.width}")
    print(f"Number of agents: {len(experiment.agents)}")
    print(f"Agent spawn points: {len(world.agent_spawn_points)}")
    print(f"Resource spawn points: {len(world.resource_spawn_points)}")

    # Same randint API either way; a seeded RandomState makes runs repeatable.
    rng = np.random if seed is None else np.random.RandomState(seed)

    # Restart figure numbering for this run.
    visualize_layers_separately.step = 0

    print("\n--- Initial State ---")
    analyze_layer_contents(world, 0)
    visualize_layers_separately(world, "Random - ")

    for step in range(1, n_steps + 1):
        print(f"\n--- Step {step} ---")

        actions = []
        for agent in experiment.agents:
            # The random policy ignores the observation. observe() is still
            # called in case it updates agent state (unverified).
            agent.observe(experiment.world)
            actions.append(rng.randint(0, n_actions))

        print(f"Actions: {actions}")

        experiment.step(actions)

        analyze_layer_contents(world, step)
        visualize_layers_separately(world, "Random - ")

        if pause:
            input("Press Enter to continue to next step...")

    return world, experiment

# Run the test
world_random, experiment_random = test_random_generation_layers()


### Compare ASCII Map vs Random Generation


In [None]:
def compare_generation_methods(ascii_world=None, ascii_experiment=None,
                               random_world=None, random_experiment=None,
                               tile_size=32):
    """Plot an ASCII-map world and a randomly generated world side by side.

    Layout (3x3):
        * Column 0: ASCII-map layers (rows: terrain, dynamic, beam).
        * Column 1: random-generation layers (same rows).
        * Column 2: ASCII composite (row 0), random composite (row 1) and a
          text summary (row 2).

    Args:
        ascii_world, ascii_experiment: World/env built from the ASCII map.
            They default to the notebook globals ``world_ascii`` and
            ``experiment_ascii``.
        random_world, random_experiment: World/env built by random
            generation. They default to ``world_random`` and
            ``experiment_random``.
        tile_size: Pixel size of one grid cell passed to ``render_sprite``.
    """
    print("\n=== Comparing Generation Methods ===")

    # Fall back to the globals created by the earlier test cells.
    if ascii_world is None:
        ascii_world = world_ascii
    if ascii_experiment is None:
        ascii_experiment = experiment_ascii
    if random_world is None:
        random_world = world_random
    if random_experiment is None:
        random_experiment = experiment_random

    columns = [
        ("ASCII Map", ascii_world, ascii_experiment),
        ("Random", random_world, random_experiment),
    ]

    # Render both worlds before creating the figure, so a render failure
    # leaves no half-drawn figure behind.
    rendered = [render_sprite(world, tile_size=[tile_size, tile_size])
                for _, world, _ in columns]

    fig, axes = plt.subplots(3, 3, figsize=(20, 18))
    layer_names = ['Terrain Layer', 'Dynamic Layer', 'Beam Layer']

    for col, ((label, _, _), layers) in enumerate(zip(columns, rendered)):
        for row, (layer, name) in enumerate(zip(layers, layer_names)):
            ax = axes[row, col]
            ax.imshow(layer)
            ax.set_title(f"{label} - {name}")
            ax.set_xlabel('X coordinate')
            ax.set_ylabel('Y coordinate')

        # This column's composite goes in the last column, on row ``col``.
        comp_ax = axes[col, 2]
        try:
            comp_ax.imshow(np.array(image_from_array(layers)))
            comp_ax.set_title(f"{label} - Composited")
            comp_ax.set_xlabel('X coordinate')
            comp_ax.set_ylabel('Y coordinate')
        except Exception as e:
            comp_ax.text(0.5, 0.5, f"Error: {e}", ha='center', va='center',
                         transform=comp_ax.transAxes)
            comp_ax.set_title(f"{label} - Composited (Error)")

    # Summary text. The last two figures are spawn-point counts, not counts of
    # live agents/resources, and are labelled accordingly.
    summary_lines = ["", "Generation Method Comparison:"]
    for label, world, experiment in columns:
        summary_lines += [
            "",
            f"{label}:",
            f"- Dimensions: {world.height}x{world.width}",
            f"- Agents: {len(experiment.agents)}",
            f"- Agent Spawn Points: {len(world.agent_spawn_points)}",
            f"- Resource Spawn Points: {len(world.resource_spawn_points)}",
        ]

    summary_ax = axes[2, 2]
    summary_ax.axis('off')
    summary_ax.text(0.1, 0.9, "\n".join(summary_lines) + "\n", transform=summary_ax.transAxes,
                    fontsize=10, verticalalignment='top', fontfamily='monospace')
    summary_ax.set_title("Summary Statistics")

    plt.tight_layout()
    plt.show()

# Run the comparison
compare_generation_methods()


### Layer Exploration at Specific Coordinates


In [None]:
def explore_specific_coordinates(world, coords_list):
    """Print the entity type found on every layer at each requested cell.

    Args:
        world: Grid world exposing ``height``, ``width`` and a ``map`` indexed
            as ``map[y, x, layer]``, with the layer count in ``map.shape[2]``.
        coords_list: Iterable of ``(y, x)`` pairs, row first as in ``world.map``.

    Out-of-bounds coordinates (including negative ones) are reported instead
    of being indexed. Layers 0/1/2 are labelled Terrain/Dynamic/Beam; any
    additional layer is labelled ``Layer <n>`` rather than raising
    ``IndexError``.
    """
    layer_labels = ('Terrain', 'Dynamic', 'Beam')

    print("\n=== Exploring Specific Coordinates ===")

    for y, x in coords_list:
        # Explicit bounds check: negative indices would otherwise wrap around.
        if not (0 <= y < world.height and 0 <= x < world.width):
            print(f"\nCoordinate ({y}, {x}) is out of bounds!")
            continue

        print(f"\nCoordinate ({y}, {x}):")
        for layer in range(world.map.shape[2]):
            entity = world.map[y, x, layer]
            label = f"{layer_labels[layer]} Layer" if layer < len(layer_labels) else f"Layer {layer}"
            print(f"  {label}: {type(entity).__name__}")

# Probe a handful of cells along the main diagonal. Any that fall outside the
# ASCII map are reported as out of bounds by the function itself.
interesting_coords = [(d, d) for d in (0, 1, 5, 10, 15, 20)]
explore_specific_coordinates(world_ascii, interesting_coords)


### Summary and Notes


In [None]:
# Recap of what the notebook covered and the invariants worth checking.
# Joining with newlines reproduces the original one-print-per-line output.
_summary_lines = (
    "\n=== Layer Visualization Test Complete ===",
    "\nThis notebook demonstrates:",
    "1. How to visualize each layer separately",
    "2. How layers change over time",
    "3. Differences between ASCII map and random generation",
    "4. How to debug layer-specific issues",
    "5. How to explore specific coordinates",
    "\nKey insights:",
    "- Terrain layer: Static elements (walls, spawn points, sand)",
    "- Dynamic layer: Moving elements (agents, resources)",
    "- Beam layer: Temporary effects (interaction beams)",
    "- All layers should be properly aligned",
    "- Resources should never appear on walls",
    "- Beams should align with agent positions",
)
print("\n".join(_summary_lines))
