# Agent-Based SEIR Model - Phase 1: Core Infrastructure

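This notebook implements Phase 1 of a spatial agent-based SEIR model: the Agent class (Phase 1.1), population initialization (Phase 1.2), a grid-based spatial index for neighbor queries (Phase 1.3), and visualization helpers (Phase 1.4).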

## Transition from Compartmental to Agent-Based Model

**Original Model (Compartmental SEIR):**
- Population divided into compartments: S, E1, E2, I1, I2, R1, R2
- Homogeneous mixing (everyone equally likely to contact everyone)
- Described by differential equations
- Population: ~10,000 individuals

**Target Model (Agent-Based with Spatial Structure):**
- Each individual is a discrete agent with position and state
- Spatial proximity determines contacts
- Heterogeneous mixing based on location
- Same population size, but now explicitly spatial

In [None]:
# Import required libraries
import numpy as np
import matplotlib.pyplot as plt
from typing import Optional, Dict, List

## Agent Class

### Required Capabilities:
1. Store unique identifier, 2D position (x, y), disease state, and time in current state
2. Calculate Euclidean distance to other agents
3. Validate inputs (non-negative positions and time in state, valid disease states)
4. Provide readable string representation
5. Be lightweight for large populations (10,000+ agents)
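The memory requirement in item 5 rests on `__slots__`, which the `Agent` class below uses. A minimal standalone illustration (the two toy classes are hypothetical, not part of the model): a slotted class has no per-instance `__dict__`, which is where the per-object savings come from, and it rejects attributes that were not declared.

```python
# Two toy classes (hypothetical, for illustration only) that differ only in __slots__
class PlainPoint:
    def __init__(self, x: float, y: float):
        self.x = x
        self.y = y


class SlottedPoint:
    __slots__ = ('x', 'y')  # fixed attribute layout, no per-instance __dict__

    def __init__(self, x: float, y: float):
        self.x = x
        self.y = y


plain = PlainPoint(1.0, 2.0)
slotted = SlottedPoint(1.0, 2.0)

# A regular instance carries a dict; a slotted one does not
print(hasattr(plain, '__dict__'))    # True
print(hasattr(slotted, '__dict__'))  # False

# A misspelled attribute fails loudly instead of silently creating a new one
try:
    slotted.stat = 'S'
except AttributeError as err:
    print(f"AttributeError: {err}")
```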

### Valid Disease States:
- **S**: Susceptible
- **E1**: Exposed to strain 1
- **E2**: Exposed to strain 2
- **I1**: Infected with strain 1
- **I2**: Infected with strain 2
- **R1**: Recovered from strain 1
- **R2**: Recovered from strain 2
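Each state code is a compartment letter followed, except for `S`, by a strain number. Later phases will often need those two parts separately. A minimal sketch, assuming a hypothetical helper `parse_state` that the notebook does not define:

```python
from typing import Optional, Tuple

VALID_STATES = {'S', 'E1', 'E2', 'I1', 'I2', 'R1', 'R2'}


def parse_state(state: str) -> Tuple[str, Optional[int]]:
    """Split a state code like 'I2' into ('I', 2); 'S' has no strain, so ('S', None).

    Hypothetical helper for illustration; it assumes the seven codes listed above.
    """
    if state not in VALID_STATES:
        raise ValueError(f"Invalid disease state '{state}'")
    compartment = state[0]
    strain = int(state[1:]) if len(state) > 1 else None
    return compartment, strain


print(parse_state('S'))   # ('S', None)
print(parse_state('E1'))  # ('E', 1)
print(parse_state('R2'))  # ('R', 2)
```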

In [None]:
class Agent:
    """
    Individual agent in a spatial SEIR epidemic model.
    
    Each agent represents a single individual with a unique identity, spatial position,
    disease state, and temporal tracking of disease progression.
    
    Attributes:
        id (int): Unique identifier for the agent
        x (float): X-coordinate position in 2D space
        y (float): Y-coordinate position in 2D space
        state (str): Current disease state - one of ['S', 'E1', 'E2', 'I1', 'I2', 'R1', 'R2']
        time_in_state (float): Duration (in days) the agent has been in current state
    
    The class uses __slots__ for memory efficiency when creating large populations.
    """
    
    # Valid disease states for the two-strain SEIR model
    VALID_STATES = {'S', 'E1', 'E2', 'I1', 'I2', 'R1', 'R2'}
    
    # Use __slots__ to reduce memory footprint (important for 10k+ agents)
    __slots__ = ['id', 'x', 'y', 'state', 'time_in_state']
    
    def __init__(self, 
                 id: int, 
                 x: float, 
                 y: float, 
                 state: str = 'S',
                 time_in_state: float = 0.0):
        """
        Initialize an agent with position and disease state.
        
        Parameters:
            id: Unique integer identifier for the agent
            x: X-coordinate position (must be non-negative)
            y: Y-coordinate position (must be non-negative)
            state: Disease state (default: 'S' for susceptible)
            time_in_state: Initial time in current state (default: 0.0 days)
        
        Raises:
            ValueError: If coordinates are negative or state is invalid
        """
        # Validate coordinates
        if x < 0 or y < 0:
            raise ValueError(f"Coordinates must be non-negative. Got x={x}, y={y}")
        
        # Validate disease state
        if state not in self.VALID_STATES:
            raise ValueError(
                f"Invalid disease state '{state}'. Must be one of {self.VALID_STATES}"
            )
        
        # Validate time_in_state
        if time_in_state < 0:
            raise ValueError(f"time_in_state must be non-negative. Got {time_in_state}")
        
        # Initialize attributes
        self.id = id
        self.x = x
        self.y = y
        self.state = state
        self.time_in_state = time_in_state
    
    def distance_to(self, other: 'Agent') -> float:
        """
        Calculate Euclidean distance to another agent.
        
        Parameters:
            other: Another Agent instance
        
        Returns:
            Euclidean distance between this agent and the other agent
        """
        return np.sqrt((self.x - other.x)**2 + (self.y - other.y)**2)
    
    def __repr__(self) -> str:
        """
        Readable string representation for debugging and logging.
        
        Returns:
            String showing agent ID, position, and current state
        """
        return (f"Agent(id={self.id}, pos=({self.x:.1f}, {self.y:.1f}), "
                f"state='{self.state}', t={self.time_in_state:.1f})")
    
    def __str__(self) -> str:
        """
        User-friendly string representation.
        
        Returns:
            Simplified string for display
        """
        return f"Agent #{self.id} [{self.state}] at ({self.x:.1f}, {self.y:.1f})"

## Population Initialization Function

### Required Capabilities:
1. Create a specified number of agents distributed in 2D space
2. Assign disease states according to a given distribution
3. Position agents spatially with uniform random distribution
4. Ensure reproducibility with random seed control
5. Return a list of Agent objects ready for simulation

### Design:
- **State Distribution**: Dictionary format `{'S': 9900, 'E1': 50, 'E2': 50}`
- **Spatial Distribution**: Uniform random across square space
- **ID Assignment**: Sequential (0, 1, 2, ..., n-1)
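- **State Shuffling**: States are shuffled so that agent ID order carries no information about disease state (positions are drawn independently of state, so states are not spatially clustered either)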
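The design above (expand the count dictionary into one state per agent, shuffle, then draw uniform positions) can be sketched with only the standard library. This is an illustration under stated assumptions, not the notebook's `initialize_population`: the function name `sketch_initial_states` and the `(id, x, y, state)` tuples are hypothetical, and it uses a local `random.Random(seed)` instance so that seeding does not reset random state used elsewhere.

```python
import random
from typing import Dict, List, Optional, Tuple


def sketch_initial_states(
    n_agents: int,
    space_size: float,
    initial_states: Dict[str, int],
    seed: Optional[int] = None,
) -> List[Tuple[int, float, float, str]]:
    """Return (id, x, y, state) tuples following the design above (hypothetical helper)."""
    if any(count < 0 for count in initial_states.values()):
        raise ValueError("State counts must be non-negative")
    if sum(initial_states.values()) != n_agents:
        raise ValueError("State counts must sum to n_agents")

    rng = random.Random(seed)  # local generator: does not touch global random state

    # 1. Expand {'S': 3, 'E1': 1} into ['S', 'S', 'S', 'E1']
    states = [state for state, count in initial_states.items() for _ in range(count)]

    # 2. Shuffle so that agent ID order carries no information about state
    rng.shuffle(states)

    # 3. Sequential IDs, independent uniform positions in [0, space_size]
    return [
        (i, rng.uniform(0, space_size), rng.uniform(0, space_size), states[i])
        for i in range(n_agents)
    ]


agents = sketch_initial_states(10, 5.0, {'S': 8, 'E1': 1, 'E2': 1}, seed=42)
for agent in agents[:3]:
    print(agent)
```

Calling the function twice with the same seed returns identical tuples, which is the reproducibility property the design list asks for.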

In [None]:
def initialize_population(
    n_agents: int,
    space_size: float,
    initial_states: Optional[Dict[str, int]] = None,
    spatial_distribution: str = 'uniform',
    seed: Optional[int] = None
) -> List[Agent]:
    """
    Initialize a population of agents in 2D space.
    
    Creates agents with specified disease state distribution and spatial positions.
    Agent states are shuffled so that disease state is uncorrelated with agent ID.
    
    Parameters:
        n_agents: Total number of agents to create (must be positive)
        space_size: Size of square space (agents positioned in [0, space_size] x [0, space_size])
        initial_states: Dictionary mapping disease states to counts.
                       Example: {'S': 9900, 'E1': 50, 'E2': 50}
                       If None, all agents are susceptible ('S')
                       State counts must be non-negative and sum to n_agents
        spatial_distribution: 'uniform' for uniform random distribution (only option currently)
        seed: Random seed for reproducibility. If None, uses random state.
    
    Returns:
        List of Agent objects with sequential IDs (0 to n_agents-1)
    
    Raises:
        ValueError: If n_agents <= 0, space_size <= 0, state counts don't sum to n_agents,
                   or invalid state names are provided
    
    Examples:
        >>> # Create 100 susceptible agents
        >>> pop = initialize_population(100, 50.0, seed=42)
        
        >>> # Match ODE initial conditions
        >>> pop = initialize_population(
        ...     10000, 100.0, 
        ...     initial_states={'S': 9900, 'E1': 50, 'E2': 50},
        ...     seed=42
        ... )
    """
    # Validate inputs
    if n_agents <= 0:
        raise ValueError(f"n_agents must be positive. Got {n_agents}")
    
    if space_size <= 0:
        raise ValueError(f"space_size must be positive. Got {space_size}")
    
    # Default: all agents susceptible
    if initial_states is None:
        initial_states = {'S': n_agents}
    
    # Validate state names
    valid_states = Agent.VALID_STATES
    invalid_states = set(initial_states.keys()) - valid_states
    if invalid_states:
        raise ValueError(
            f"Invalid state names: {invalid_states}. Must be one of {valid_states}"
        )
    
    # Validate counts: negative counts would otherwise pass the sum check
    # (e.g. {'S': 102, 'E1': -2} for n_agents=100) and silently build the wrong population
    negative = {state: count for state, count in initial_states.items() if count < 0}
    if negative:
        raise ValueError(f"State counts must be non-negative. Got {negative}")
    
    total_state_count = sum(initial_states.values())
    if total_state_count != n_agents:
        raise ValueError(
            f"Sum of initial_states ({total_state_count}) must equal n_agents ({n_agents})"
        )
    
    # Validate spatial_distribution
    if spatial_distribution != 'uniform':
        raise ValueError(
            f"spatial_distribution must be 'uniform'. Got '{spatial_distribution}'"
        )
    
    # Local random generator: seeding it does not reset NumPy's global random state
    rng = np.random.default_rng(seed)
    
    # Expand counts into one state label per agent, then shuffle so that
    # agent ID order carries no information about disease state
    state_assignments = [state for state, count in initial_states.items()
                         for _ in range(count)]
    rng.shuffle(state_assignments)
    
    # Draw all positions at once, uniform on [0, space_size)
    xs = rng.uniform(0, space_size, size=n_agents)
    ys = rng.uniform(0, space_size, size=n_agents)
    
    # Create agents with sequential IDs and shuffled states
    return [
        Agent(id=i, x=float(xs[i]), y=float(ys[i]), state=state_assignments[i])
        for i in range(n_agents)
    ]

## Population Initialization Examples

Demonstrations of different ways to initialize populations.
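Most examples below tally states with the same hand-written dictionary loop. `collections.Counter` does the same job in one call. A small helper (hypothetical `count_states`, not defined in the notebook) can also return the counts in the canonical state order used for printing; the `SimpleNamespace` objects stand in for `Agent` instances, since only the `.state` attribute is read.

```python
from collections import Counter
from types import SimpleNamespace
from typing import Dict, Iterable

STATE_ORDER = ['S', 'E1', 'E2', 'I1', 'I2', 'R1', 'R2']


def count_states(agents: Iterable) -> Dict[str, int]:
    """Count agents per state, keeping only non-zero states in canonical order."""
    counts = Counter(agent.state for agent in agents)
    return {state: counts[state] for state in STATE_ORDER if counts[state] > 0}


# Stand-ins for Agent objects: only the .state attribute is used
demo = [SimpleNamespace(state=s) for s in ['S', 'S', 'E2', 'S', 'E1', 'E2']]
print(count_states(demo))  # {'S': 3, 'E1': 1, 'E2': 2}
```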

In [None]:
# Example 1: Default population (all susceptible)
print("Example 1: Default Population (All Susceptible)")
print("=" * 60)

pop_default = initialize_population(n_agents=100, space_size=50.0, seed=42)

print(f"Created {len(pop_default)} agents")
print(f"First agent: {pop_default[0]}")
print(f"Last agent: {pop_default[-1]}")

# Count states
state_counts = {}
for agent in pop_default:
    state_counts[agent.state] = state_counts.get(agent.state, 0) + 1

print(f"\nState distribution: {state_counts}")
print()

In [None]:
# Example 2: Match ODE initial conditions
print("Example 2: Match ODE Initial Conditions")
print("=" * 60)

initial_states = {'S': 9900, 'E1': 50, 'E2': 50}
pop_ode = initialize_population(
    n_agents=10000,
    space_size=100.0,
    initial_states=initial_states,
    seed=42
)

print(f"Created {len(pop_ode)} agents matching ODE initial conditions")

# Count states
state_counts = {}
for agent in pop_ode:
    state_counts[agent.state] = state_counts.get(agent.state, 0) + 1

print(f"\nState distribution:")
for state in ['S', 'E1', 'E2', 'I1', 'I2', 'R1', 'R2']:
    count = state_counts.get(state, 0)
    if count > 0:
        print(f"  {state}: {count:,}")

print(f"\nFirst 5 agents (showing state shuffling):")
for agent in pop_ode[:5]:
    print(f"  {agent}")
print()

In [None]:
# Example 3: Reproducibility with seed
print("Example 3: Reproducibility with Seed")
print("=" * 60)

pop_seed1a = initialize_population(n_agents=10, space_size=10.0, seed=123)
pop_seed1b = initialize_population(n_agents=10, space_size=10.0, seed=123)
pop_seed2 = initialize_population(n_agents=10, space_size=10.0, seed=456)

print("Populations with same seed (123):")
print(f"Pop A, Agent 0: x={pop_seed1a[0].x:.4f}, y={pop_seed1a[0].y:.4f}")
print(f"Pop B, Agent 0: x={pop_seed1b[0].x:.4f}, y={pop_seed1b[0].y:.4f}")
print(f"Identical: {pop_seed1a[0].x == pop_seed1b[0].x and pop_seed1a[0].y == pop_seed1b[0].y}")

print(f"\nPopulation with different seed (456):")
print(f"Pop C, Agent 0: x={pop_seed2[0].x:.4f}, y={pop_seed2[0].y:.4f}")
print(f"Different from Pop A: {pop_seed1a[0].x != pop_seed2[0].x or pop_seed1a[0].y != pop_seed2[0].y}")
print()

In [None]:
# Example 4: Population with multiple disease states
print("Example 4: Population with Multiple Disease States")
print("=" * 60)

initial_states = {
    'S': 85,
    'E1': 5,
    'E2': 5,
    'I1': 2,
    'I2': 2,
    'R1': 0,
    'R2': 1
}

pop_mixed = initialize_population(
    n_agents=100,
    space_size=50.0,
    initial_states=initial_states,
    seed=42
)

# Count and display states
state_counts = {}
for agent in pop_mixed:
    state_counts[agent.state] = state_counts.get(agent.state, 0) + 1

print(f"Created {len(pop_mixed)} agents with mixed states")
print(f"\nState distribution:")
for state in ['S', 'E1', 'E2', 'I1', 'I2', 'R1', 'R2']:
    count = state_counts.get(state, 0)
    if count > 0:
        print(f"  {state}: {count}")

# Show that states are interleaved across agent IDs (shuffled, not grouped by state)
print(f"\nFirst 10 agents (demonstrating state shuffling):")
for agent in pop_mixed[:10]:
    print(f"  {agent}")
print()

## Agent Class Examples

Basic demonstrations of individual agent functionality.
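The agents created in the next cell are placed so the distances are easy to check by hand: agent 0 at (10, 20) and agent 1 at (13, 24) differ by (3, 4), giving a distance of 5; agent 0 and agent 2 at (50, 50) differ by (40, 30), giving 50. The standard-library `math.hypot` computes the same Euclidean distance for scalar coordinates and avoids NumPy's per-call overhead on plain Python floats; using it inside `distance_to` is an optional change, not something the notebook does.

```python
import math

# Positions of agents 0, 1 and 2 from the next cell
p0 = (10.0, 20.0)
p1 = (13.0, 24.0)
p2 = (50.0, 50.0)

# Euclidean distance sqrt(dx**2 + dy**2) for scalar coordinates
d01 = math.hypot(p1[0] - p0[0], p1[1] - p0[1])
d02 = math.hypot(p2[0] - p0[0], p2[1] - p0[1])

print(f"{d01:.2f}")  # 5.00
print(f"{d02:.2f}")  # 50.00
```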

In [None]:
# Create individual agents
agent1 = Agent(id=0, x=10.0, y=20.0, state='S')
agent2 = Agent(id=1, x=13.0, y=24.0, state='I1')
agent3 = Agent(id=2, x=50.0, y=50.0, state='E2')

print("Created agents:")
print(agent1)
print(agent2)
print(agent3)

In [None]:
# Calculate distances between agents
dist_12 = agent1.distance_to(agent2)
dist_13 = agent1.distance_to(agent3)

print(f"\nDistance calculations:")
print(f"Agent 0 to Agent 1: {dist_12:.2f} units")
print(f"Agent 0 to Agent 2: {dist_13:.2f} units")

In [None]:
# Create a small population using the initialization function (seeded for reproducibility)

n_agents = 100
space_size = 50.0

# Use initialize_population function
population = initialize_population(
    n_agents=n_agents,
    space_size=space_size,
    initial_states={'S': 90, 'E1': 5, 'E2': 5},
    seed=42
)

# Print population summary
print(f"\nCreated population of {len(population)} agents")
print("\nState distribution:")
state_counts = {}
for agent in population:
    state_counts[agent.state] = state_counts.get(agent.state, 0) + 1

for state in ['S', 'E1', 'E2', 'I1', 'I2', 'R1', 'R2']:
    count = state_counts.get(state, 0)
    if count > 0:
        print(f"  {state}: {count}")

In [None]:
# Visualize the population in 2D space
state_colors = {
    'S': 'lightgray',
    'E1': 'gold',
    'E2': 'orange',
    'I1': 'red',
    'I2': 'darkred',
    'R1': 'lightblue',
    'R2': 'darkblue'
}

fig, ax = plt.subplots(figsize=(10, 10))

# Plot each state separately for legend
for state in ['S', 'E1', 'E2', 'I1', 'I2', 'R1', 'R2']:
    agents_in_state = [agent for agent in population if agent.state == state]
    if agents_in_state:
        x_coords = [agent.x for agent in agents_in_state]
        y_coords = [agent.y for agent in agents_in_state]
        ax.scatter(x_coords, y_coords, 
                  c=state_colors[state], 
                  s=100, 
                  alpha=0.7,
                  edgecolors='black',
                  linewidths=0.5,
                  label=f"{state} (n={len(agents_in_state)})")

ax.set_xlim(0, space_size)
ax.set_ylim(0, space_size)
ax.set_xlabel('X Position', fontsize=12)
ax.set_ylabel('Y Position', fontsize=12)
ax.set_title('Agent Positions in 2D Space', fontsize=14, fontweight='bold')
ax.set_aspect('equal')
ax.grid(True, alpha=0.3, linestyle='--')
ax.legend(loc='upper right', fontsize=10)

plt.tight_layout()
plt.show()

## Phase 1.3: Spatial Index for Neighbor Finding

### Purpose:
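Efficient neighbor finding is critical for spatial transmission models. Checking every pair of agents costs O(N²) per time step, i.e. O(N) per query. A grid-based spatial index restricts each query to the agent's own cell and a few surrounding cells, so at a fixed population density and contact radius the expected cost per query stays roughly constant as N grows (it degrades toward O(N) only if most agents crowd into the scanned cells). Rebuilding the index costs O(N).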

### Design:
- Divide space into a grid of cells
- Each cell contains agents in that region
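- To find neighbors: check the agent's cell plus `ceil(radius / cell_size)` rings of surrounding cells, then filter candidates by exact distance
- Cell size ≥ typical contact radius keeps each query to the 3×3 block around the agent's cell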
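The cost of a query depends on how many cells are scanned. With the rule used in `find_neighbors` below, a radius `r` and cell size `c` give `k = ceil(r / c)` cells in each direction, so `(2k + 1)²` cells in total; choosing `c >= r` keeps this at a 3×3 block. The sketch (hypothetical helper `cells_scanned`) evaluates that count for the radii used in the examples and, for comparison, the number of pairs a brute-force check examines for a 10,000-agent population.

```python
import math


def cells_scanned(radius: float, cell_size: float) -> int:
    """Number of grid cells inspected per query: (2 * ceil(r / c) + 1) ** 2."""
    k = math.ceil(radius / cell_size)
    return (2 * k + 1) ** 2


# Cell size 10, as in the spatial index examples
for r in [2.0, 5.0, 10.0, 20.0]:
    print(f"radius {r:5.1f}: {cells_scanned(r, 10.0)} cells")
# radius 2, 5 and 10 -> 9 cells; radius 20 -> 25 cells

# Brute force: every unordered pair of N agents is compared once per step
n = 10_000
print(f"{n * (n - 1) // 2:,} pairs")  # 49,995,000 pairs
```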

In [None]:
class SpatialIndex:
    """
    Grid-based spatial index for efficient neighbor queries in 2D space.
    
    Divides the space into a grid of cells. Agents are assigned to cells based
    on their position. Neighbor queries only need to check the agent's cell and
    adjacent cells, avoiding O(N²) comparisons.
    
    Attributes:
        space_size (float): Size of the square space
        cell_size (float): Size of each grid cell
        n_cells (int): Number of cells per dimension
        grid (dict): Dictionary mapping (cell_x, cell_y) to list of agents
    """
    
    def __init__(self, space_size: float, cell_size: float):
        """
        Initialize spatial index with grid structure.
        
        Parameters:
            space_size: Size of square space [0, space_size] x [0, space_size]
            cell_size: Size of each grid cell (should be >= typical contact radius)
        
        Raises:
            ValueError: If space_size or cell_size are not positive
        """
        if space_size <= 0:
            raise ValueError(f"space_size must be positive. Got {space_size}")
        if cell_size <= 0:
            raise ValueError(f"cell_size must be positive. Got {cell_size}")
        
        self.space_size = space_size
        self.cell_size = cell_size
        self.n_cells = int(np.ceil(space_size / cell_size))
        self.grid = {}
    
    def get_cell(self, x: float, y: float) -> tuple:
        """
        Convert position to grid cell indices.
        
        Parameters:
            x, y: Position coordinates
        
        Returns:
            Tuple (cell_x, cell_y) of integer cell indices
        """
        cell_x = int(x / self.cell_size)
        cell_y = int(y / self.cell_size)
        
        # Clamp to valid range
        cell_x = max(0, min(cell_x, self.n_cells - 1))
        cell_y = max(0, min(cell_y, self.n_cells - 1))
        
        return (cell_x, cell_y)
    
    def update(self, agents: List[Agent]):
        """
        Rebuild spatial index from current agent positions.
        
        Parameters:
            agents: List of all agents to index
        """
        # Clear existing grid
        self.grid = {}
        
        # Assign agents to cells
        for agent in agents:
            cell = self.get_cell(agent.x, agent.y)
            if cell not in self.grid:
                self.grid[cell] = []
            self.grid[cell].append(agent)
    
    def find_neighbors(self, agent: Agent, radius: float) -> List[Agent]:
        """
        Find all agents within a given radius of the query agent.
        
        Parameters:
            agent: The query agent
            radius: Search radius
        
        Returns:
            List of agents within radius (excluding the query agent itself)
        """
        neighbors = []
        
        # Get agent's cell
        center_cell = self.get_cell(agent.x, agent.y)
        
        # Determine how many cells to check in each direction
        cell_range = int(np.ceil(radius / self.cell_size))
        
        # Check agent's cell and nearby cells
        for dx in range(-cell_range, cell_range + 1):
            for dy in range(-cell_range, cell_range + 1):
                check_cell = (center_cell[0] + dx, center_cell[1] + dy)
                
                # Skip cells that are empty or out of bounds (neither is stored in the grid)
                if check_cell not in self.grid:
                    continue
                
                # Check each agent in this cell
                for other_agent in self.grid[check_cell]:
                    # Skip self
                    if other_agent.id == agent.id:
                        continue
                    
                    # Check actual distance
                    distance = agent.distance_to(other_agent)
                    if distance <= radius:
                        neighbors.append(other_agent)
        
        return neighbors
    
    def get_stats(self) -> dict:
        """
        Get statistics about the spatial index.
        
        Returns:
            Dictionary with statistics
        """
        if not self.grid:
            return {
                'n_cells_total': self.n_cells ** 2,
                'n_cells_occupied': 0,
                'n_agents': 0,
                'agents_per_cell_mean': 0,
                'agents_per_cell_max': 0
            }
        
        agents_per_cell = [len(agents) for agents in self.grid.values()]
        
        return {
            'n_cells_total': self.n_cells ** 2,
            'n_cells_occupied': len(self.grid),
            'n_agents': sum(agents_per_cell),
            'agents_per_cell_mean': np.mean(agents_per_cell),
            'agents_per_cell_max': np.max(agents_per_cell)
        }

## Spatial Index Examples

Demonstrate building and using the spatial index.
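A grid index is easy to get subtly wrong (off-by-one cell ranges, clamping at the upper boundary), so it is worth checking against brute force on a small population. The sketch below is self-contained and uses plain `(x, y)` tuples rather than the notebook's `Agent` and `SpatialIndex` classes; `cell_of`, `grid_neighbors` and `brute_neighbors` are hypothetical helpers that mirror the same cell assignment and `ceil(radius / cell_size)` scan range for non-negative coordinates.

```python
import math
import random
from collections import defaultdict
from typing import Dict, List, Set, Tuple

Point = Tuple[float, float]
Grid = Dict[Tuple[int, int], List[int]]


def cell_of(x: float, y: float, cell_size: float, n_cells: int) -> Tuple[int, int]:
    """Cell indices for a position, clamped so x == space_size falls in the last cell."""
    return (min(int(x / cell_size), n_cells - 1),
            min(int(y / cell_size), n_cells - 1))


def grid_neighbors(points: List[Point], i: int, radius: float,
                   cell_size: float, n_cells: int, grid: Grid) -> Set[int]:
    """Indices within `radius` of point i, scanning ceil(radius / cell_size) rings of cells."""
    cx, cy = cell_of(points[i][0], points[i][1], cell_size, n_cells)
    k = math.ceil(radius / cell_size)
    found = set()
    for dx in range(-k, k + 1):
        for dy in range(-k, k + 1):
            # .get avoids inserting empty lists into the defaultdict
            for j in grid.get((cx + dx, cy + dy), []):
                if j != i and math.dist(points[i], points[j]) <= radius:
                    found.add(j)
    return found


def brute_neighbors(points: List[Point], i: int, radius: float) -> Set[int]:
    """Reference answer: compare point i against every other point."""
    return {j for j in range(len(points))
            if j != i and math.dist(points[i], points[j]) <= radius}


# Small random population in a 100 x 100 domain with 10 x 10 cells
rng = random.Random(0)
space_size, cell_size = 100.0, 10.0
n_cells = math.ceil(space_size / cell_size)
points = [(rng.uniform(0, space_size), rng.uniform(0, space_size)) for _ in range(500)]

grid: Grid = defaultdict(list)
for idx, (x, y) in enumerate(points):
    grid[cell_of(x, y, cell_size, n_cells)].append(idx)

# Compare both methods for several radii, including radii larger than one cell
mismatches = 0
for radius in [2.0, 5.0, 10.0, 20.0]:
    for i in range(0, len(points), 25):
        if grid_neighbors(points, i, radius, cell_size, n_cells, grid) != brute_neighbors(points, i, radius):
            mismatches += 1

print(f"mismatches: {mismatches}")
```

A non-zero mismatch count would point at the cell-range or clamping logic rather than the distance test.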

In [None]:
# Example 1: Create spatial index for a population
print("Example 1: Building Spatial Index")
print("=" * 60)

# Create a population
population = initialize_population(
    n_agents=1000,
    space_size=100.0,
    initial_states={'S': 950, 'E1': 25, 'E2': 25},
    seed=42
)

# Create spatial index with cell size = 10
spatial_index = SpatialIndex(space_size=100.0, cell_size=10.0)
spatial_index.update(population)

# Get statistics
stats = spatial_index.get_stats()
print(f"Spatial Index Statistics:")
print(f"  Total cells: {stats['n_cells_total']}")
print(f"  Occupied cells: {stats['n_cells_occupied']}")
print(f"  Total agents: {stats['n_agents']}")
print(f"  Agents per cell (mean): {stats['agents_per_cell_mean']:.1f}")
print(f"  Agents per cell (max): {stats['agents_per_cell_max']}")
print()

In [None]:
# Example 2: Find neighbors within contact radius
print("Example 2: Finding Neighbors")
print("=" * 60)

# Pick an arbitrary agent (fixed index so the example is reproducible)
query_agent = population[100]
contact_radius = 5.0

print(f"Query agent: {query_agent}")
print(f"Contact radius: {contact_radius}")

# Find neighbors
neighbors = spatial_index.find_neighbors(query_agent, contact_radius)

print(f"\nFound {len(neighbors)} neighbors within radius {contact_radius}")

# Show first few neighbors
print(f"\nFirst 5 neighbors:")
for i, neighbor in enumerate(neighbors[:5]):
    distance = query_agent.distance_to(neighbor)
    print(f"  {i+1}. {neighbor} - distance: {distance:.2f}")

# Count neighbors by state
neighbor_states = {}
for neighbor in neighbors:
    neighbor_states[neighbor.state] = neighbor_states.get(neighbor.state, 0) + 1

print(f"\nNeighbors by disease state:")
for state in ['S', 'E1', 'E2', 'I1', 'I2', 'R1', 'R2']:
    count = neighbor_states.get(state, 0)
    if count > 0:
        print(f"  {state}: {count}")
print()

In [None]:
# Example 3: Compare different contact radii
print("Example 3: Effect of Contact Radius")
print("=" * 60)

query_agent = population[500]
radii = [2.0, 5.0, 10.0, 20.0]

print(f"Query agent at position ({query_agent.x:.1f}, {query_agent.y:.1f})\n")

for radius in radii:
    neighbors = spatial_index.find_neighbors(query_agent, radius)
    print(f"Radius {radius:5.1f}: {len(neighbors):4d} neighbors")
print()
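A rough sanity check for the radius comparison above: if agents are spread uniformly over an L × L square and the query disc lies entirely inside the domain, the expected number of other agents within radius r is about (N - 1)·πr² / L². Near the edges the disc is clipped and counts will be lower. The snippet (hypothetical helper `expected_neighbors`) evaluates this for N = 1,000 and L = 100, as in the examples; observed counts will scatter around these values.

```python
import math


def expected_neighbors(n_agents: int, space_size: float, radius: float) -> float:
    """Expected neighbors of one agent under uniform placement, ignoring edge effects."""
    return (n_agents - 1) * math.pi * radius**2 / space_size**2


for r in [2.0, 5.0, 10.0, 20.0]:
    print(f"Radius {r:5.1f}: ~{expected_neighbors(1000, 100.0, r):6.1f} expected neighbors")
```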

## Phase 1.4: Enhanced Visualization Functions

Visualization helpers for inspecting the spatial distribution of disease states and the contact neighborhood of individual agents.

In [None]:
def plot_population_detailed(
    agents: List[Agent],
    space_size: float,
    title: str = "Agent-Based SEIR Model - Population Snapshot",
    figsize: tuple = (12, 10),
    show_stats: bool = True
) -> tuple:
    """
    Create detailed visualization of agent population with statistics.
    
    Parameters:
        agents: List of Agent objects
        space_size: Size of the spatial domain
        title: Plot title
        figsize: Figure size (width, height)
        show_stats: Whether to show statistics panel
    
    Returns:
        Tuple of (fig, axes), where axes is (ax_main, ax_stats) if show_stats
        else (ax_main,)
    """
    # Define colors for each state
    state_colors = {
        'S': '#D3D3D3',   # Light gray
        'E1': '#FFD700',  # Gold
        'E2': '#FFA500',  # Orange
        'I1': '#FF4500',  # Red-orange
        'I2': '#8B0000',  # Dark red
        'R1': '#87CEEB',  # Sky blue
        'R2': '#4169E1'   # Royal blue
    }
    
    # Count agents by state
    state_counts = {}
    for agent in agents:
        state_counts[agent.state] = state_counts.get(agent.state, 0) + 1
    
    # Create figure
    if show_stats:
        fig, (ax_main, ax_stats) = plt.subplots(1, 2, figsize=figsize,
                                                gridspec_kw={'width_ratios': [3, 1]})
    else:
        fig, ax_main = plt.subplots(1, 1, figsize=figsize)
        ax_stats = None
    
    # Main plot: agent positions
    for state in ['S', 'E1', 'E2', 'I1', 'I2', 'R1', 'R2']:
        agents_in_state = [agent for agent in agents if agent.state == state]
        if agents_in_state:
            x_coords = [agent.x for agent in agents_in_state]
            y_coords = [agent.y for agent in agents_in_state]
            count = len(agents_in_state)
            
            ax_main.scatter(
                x_coords, y_coords,
                c=state_colors[state],
                s=50,
                alpha=0.6,
                edgecolors='black',
                linewidths=0.5,
                label=f"{state}: {count}"
            )
    
    ax_main.set_xlim(0, space_size)
    ax_main.set_ylim(0, space_size)
    ax_main.set_xlabel('X Position', fontsize=12, fontweight='bold')
    ax_main.set_ylabel('Y Position', fontsize=12, fontweight='bold')
    ax_main.set_title(title, fontsize=14, fontweight='bold', pad=20)
    ax_main.set_aspect('equal')
    ax_main.grid(True, alpha=0.3, linestyle='--', linewidth=0.5)
    ax_main.legend(loc='upper right', fontsize=10, framealpha=0.9)
    
    # Statistics panel
    if show_stats and ax_stats is not None:
        ax_stats.axis('off')
        
        stats_text = "Population Statistics\n" + "=" * 25 + "\n\n"
        stats_text += f"Total Agents: {len(agents):,}\n\n"
        
        stats_text += "Disease States:\n"
        for state in ['S', 'E1', 'E2', 'I1', 'I2', 'R1', 'R2']:
            count = state_counts.get(state, 0)
            if count > 0:
                pct = 100 * count / len(agents)
                stats_text += f"  {state}: {count:4d} ({pct:5.2f}%)\n"
        
        # Add spatial extent
        x_coords = [agent.x for agent in agents]
        y_coords = [agent.y for agent in agents]
        stats_text += f"\nSpatial Extent:\n"
        stats_text += f"  X: [{min(x_coords):.1f}, {max(x_coords):.1f}]\n"
        stats_text += f"  Y: [{min(y_coords):.1f}, {max(y_coords):.1f}]\n"
        
        ax_stats.text(0.1, 0.95, stats_text, 
                     transform=ax_stats.transAxes,
                     fontsize=10,
                     verticalalignment='top',
                     fontfamily='monospace',
                     bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.3))
    
    plt.tight_layout()
    axes = (ax_main, ax_stats) if show_stats else (ax_main,)
    return fig, axes


def plot_contact_network(
    agents: List[Agent],
    spatial_index: SpatialIndex,
    query_agent: Agent,
    contact_radius: float,
    space_size: float,
    figsize: tuple = (10, 10)
) -> tuple:
    """
    Visualize the contact network for a specific agent.
    
    Parameters:
        agents: List of all agents
        spatial_index: SpatialIndex for finding neighbors
        query_agent: Agent to show contacts for
        contact_radius: Contact radius to visualize
        space_size: Size of spatial domain
        figsize: Figure size
    
    Returns:
        Tuple of (fig, ax)
    """
    state_colors = {
        'S': '#D3D3D3', 'E1': '#FFD700', 'E2': '#FFA500',
        'I1': '#FF4500', 'I2': '#8B0000', 'R1': '#87CEEB', 'R2': '#4169E1'
    }
    
    fig, ax = plt.subplots(figsize=figsize)
    
    # Find neighbors
    neighbors = spatial_index.find_neighbors(query_agent, contact_radius)
    neighbor_ids = set(n.id for n in neighbors)
    
    # Plot all other agents (faded) with a single scatter call
    background = [a for a in agents
                  if a.id != query_agent.id and a.id not in neighbor_ids]
    if background:
        ax.scatter([a.x for a in background], [a.y for a in background],
                   c=[state_colors[a.state] for a in background],
                   s=20, alpha=0.2, edgecolors='none')
    
    # Plot neighbors, then draw a line from the query agent to each one
    if neighbors:
        ax.scatter([n.x for n in neighbors], [n.y for n in neighbors],
                   c=[state_colors[n.state] for n in neighbors],
                   s=80, alpha=0.7, edgecolors='black', linewidths=1)
    for neighbor in neighbors:
        ax.plot([query_agent.x, neighbor.x], [query_agent.y, neighbor.y],
                'k-', alpha=0.3, linewidth=0.5)
    
    # Plot query agent
    ax.scatter(query_agent.x, query_agent.y, c='red',
              s=200, marker='*', edgecolors='black', linewidths=2,
              label=f'Query Agent (ID {query_agent.id})', zorder=10)
    
    # Draw contact radius circle
    circle = plt.Circle((query_agent.x, query_agent.y), contact_radius,
                       color='red', fill=False, linestyle='--',
                       linewidth=2, label=f'Contact Radius ({contact_radius})')
    ax.add_patch(circle)
    
    ax.set_xlim(0, space_size)
    ax.set_ylim(0, space_size)
    ax.set_xlabel('X Position', fontsize=12, fontweight='bold')
    ax.set_ylabel('Y Position', fontsize=12, fontweight='bold')
    ax.set_title(f'Contact Network: Agent {query_agent.id} ({len(neighbors)} contacts)',
                fontsize=14, fontweight='bold')
    ax.set_aspect('equal')
    ax.grid(True, alpha=0.3, linestyle='--')
    ax.legend(loc='upper right', fontsize=10)
    
    plt.tight_layout()
    return fig, ax

## Visualization Examples

In [None]:
# Example 1: Detailed population plot
population = initialize_population(
    n_agents=1000,
    space_size=100.0,
    initial_states={'S': 900, 'E1': 50, 'E2': 50},
    seed=42
)

fig, axes = plot_population_detailed(
    agents=population,
    space_size=100.0,
    title="Population Snapshot - Initial Conditions",
    show_stats=True
)
plt.show()

In [None]:
# Example 2: Contact network visualization
spatial_index = SpatialIndex(space_size=100.0, cell_size=10.0)
spatial_index.update(population)

# Pick an exposed agent to visualize contacts
exposed_agents = [a for a in population if a.state == 'E1']
if exposed_agents:
    query_agent = exposed_agents[0]
    
    fig, ax = plot_contact_network(
        agents=population,
        spatial_index=spatial_index,
        query_agent=query_agent,
        contact_radius=5.0,
        space_size=100.0
    )
    plt.show()

## Phase 1 Complete!

### What We've Built:

**Phase 1.1 - Agent Class:**
- Individual agents with position, state, and temporal tracking
- Distance calculations
- Input validation

**Phase 1.2 - Population Initialization:**
- Create populations matching ODE initial conditions
- State distribution control
- Reproducible with seeds
- Uniform random placement across the square domain

**Phase 1.3 - Spatial Index:**
- Efficient neighbor finding: roughly constant expected cost per query at fixed density, versus O(N) per query (O(N²) per step) for brute force
- Grid-based spatial indexing
- Contact radius queries

**Phase 1.4 - Visualization:**
- Population snapshots with statistics
- Contact network visualization
- Disease state color coding

### Ready for Phase 2:
Next steps will implement disease progression (E→I→R transitions) and spatial transmission mechanics.

The infrastructure is now in place to build a full spatial agent-based SEIR model!