# 04: A* Search and Heuristics

## Learning Objectives
- Understand how heuristics guide search
- Implement A* search algorithm
- Design admissible and consistent heuristics
- Compare different heuristic functions
- Analyze the impact of heuristic quality on performance

## 1. Introduction to Informed Search

**A* Search** combines:
- **g(n)**: Cost from start to node n (like UCS)
- **h(n)**: Heuristic estimate from n to goal
- **f(n) = g(n) + h(n)**: Estimated total cost through n

### Key Properties:
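- **Complete**: Yes, provided the branching factor is finite and every step cost is at least some ε > 0
- **Optimal**: Yes, if h is admissible (tree search) or consistent (graph search with an explored set, as implemented in Section 3)
- **Time**: O(b^d) in the worst case, where b is the branching factor and d is the solution depth; a good heuristic typically expands far fewer nodes
- **Space**: O(b^d), since all generated nodes are kept in memory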
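
As a small illustration of how the three quantities interact, the sketch below ranks the same frontier by g alone (UCS), by h alone (greedy), and by f = g + h (A*). The states and their g and h values are made up for the example.

```python
# Hypothetical frontier entries: (state, g = cost so far, h = estimate to goal)
candidates = [("A", 1.0, 6.0), ("B", 7.0, 1.0), ("C", 2.0, 9.0)]

def expansion_order(priority):
    """Return state names sorted by the given priority function of (g, h)."""
    ranked = sorted(candidates, key=lambda c: priority(c[1], c[2]))
    return [state for state, g, h in ranked]

ucs_order = expansion_order(lambda g, h: g)        # cost so far only
greedy_order = expansion_order(lambda g, h: h)     # heuristic only
astar_order = expansion_order(lambda g, h: g + h)  # f(n) = g(n) + h(n)

print("UCS:   ", ucs_order)     # ['A', 'C', 'B']
print("Greedy:", greedy_order)  # ['B', 'A', 'C']
print("A*:    ", astar_order)   # ['A', 'B', 'C']
```

A* expands A first because its estimated total cost (7) is lowest, even though B looks closest to the goal and is what a purely heuristic ordering would pick.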

In [None]:
# Import required modules
import heapq
import math
from typing import Any, List, Tuple, Dict, Callable
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.animation import FuncAnimation
from IPython.display import HTML

In [None]:
# Base classes
class SearchProblem:
    def get_start_state(self) -> Any:
        raise NotImplementedError
    
    def is_goal_state(self, state: Any) -> bool:
        raise NotImplementedError
    
    def get_successors(self, state: Any) -> List[Tuple[Any, str, float]]:
        raise NotImplementedError

class PriorityQueue:
    def __init__(self):
        self.heap = []
        self.counter = 0
        self.entry_finder = {}
        self.REMOVED = '<removed>'
    
    def push(self, state, priority, data=None):
        if state in self.entry_finder:
            self.remove(state)
        entry = [priority, self.counter, state, data]
        self.counter += 1
        self.entry_finder[state] = entry
        heapq.heappush(self.heap, entry)
    
    def pop(self):
        while self.heap:
            priority, _, state, data = heapq.heappop(self.heap)
            if state is not self.REMOVED:  # skip entries invalidated by remove()
                del self.entry_finder[state]
                return state, priority, data
        raise KeyError('Pop from empty priority queue')
    
    def remove(self, state):
        entry = self.entry_finder.pop(state)
        entry[2] = self.REMOVED
    
    def is_empty(self):
        return len(self.entry_finder) == 0
    
    def __len__(self):
        return len(self.entry_finder)

## 2. Heuristic Functions

### Admissible Heuristics
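A heuristic h is **admissible** if: h(n) ≤ h*(n) for all n, where h*(n) is the true cost of the cheapest path from n to a goal  
(Never overestimates the true cost)

### Consistent Heuristics
A heuristic h is **consistent** if: h(n) ≤ c(n,n') + h(n') for every node n and each successor n' of n, with h(goal) = 0  
(Satisfies the triangle inequality; every consistent heuristic is also admissible)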
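
Both definitions can be checked mechanically on a small graph. The sketch below uses a made-up four-node undirected graph and two made-up heuristic tables. It computes h* with Dijkstra from the goal, then tests each inequality. The names `true_costs_to_goal`, `is_admissible` and `is_consistent` are illustrative only and are not used elsewhere in this notebook.

```python
import heapq

# Hypothetical undirected graph: edges[n][m] is the step cost c(n, m); "G" is the goal
edges = {
    "S": {"A": 1, "B": 4},
    "A": {"S": 1, "B": 2, "G": 6},
    "B": {"S": 4, "A": 2, "G": 3},
    "G": {"A": 6, "B": 3},
}

def true_costs_to_goal(edges, goal):
    """Dijkstra from the goal; equals h* here because every edge is two-way."""
    dist = {goal: 0}
    heap = [(0, goal)]
    while heap:
        d, node = heapq.heappop(heap)
        if d > dist[node]:
            continue  # stale heap entry
        for nbr, cost in edges[node].items():
            if d + cost < dist.get(nbr, float("inf")):
                dist[nbr] = d + cost
                heapq.heappush(heap, (d + cost, nbr))
    return dist

def is_admissible(h, h_star):
    return all(h[n] <= h_star[n] for n in h_star)

def is_consistent(h, edges):
    return all(h[n] <= cost + h[m] for n in edges for m, cost in edges[n].items())

h_star = true_costs_to_goal(edges, "G")
h_good = {"S": 5, "A": 5, "B": 3, "G": 0}  # admissible and consistent
h_bad = {"S": 6, "A": 2, "B": 3, "G": 0}   # admissible, but h(S) > c(S, A) + h(A)

print(is_admissible(h_good, h_star), is_consistent(h_good, edges))  # True True
print(is_admissible(h_bad, h_star), is_consistent(h_bad, edges))    # True False
```

The second table shows that admissibility does not imply consistency: no value exceeds h*, yet the estimate drops by 4 across an edge of cost 1.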

In [None]:
def manhattan_distance(pos1: Tuple[int, int], pos2: Tuple[int, int]) -> float:
    """L1 distance - admissible for 4-directional movement with step cost >= 1"""
    return abs(pos1[0] - pos2[0]) + abs(pos1[1] - pos2[1])

def euclidean_distance(pos1: Tuple[int, int], pos2: Tuple[int, int]) -> float:
    """L2 distance - admissible when each step costs at least its straight-line length"""
    return math.sqrt((pos1[0] - pos2[0])**2 + (pos1[1] - pos2[1])**2)

def chebyshev_distance(pos1: Tuple[int, int], pos2: Tuple[int, int]) -> float:
    """L∞ distance - admissible for 8-directional movement with step cost >= 1"""
    return max(abs(pos1[0] - pos2[0]), abs(pos1[1] - pos2[1]))

def null_heuristic(state: Any, problem: Any = None) -> float:
    """Zero heuristic - makes A* behave like UCS"""
    return 0

# Visualize different distance metrics
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Create grid showing distances from center
center = (5, 5)
grid_size = 11

for ax, distance_fn, title in zip(
    axes,
    [manhattan_distance, euclidean_distance, chebyshev_distance],
    ['Manhattan Distance', 'Euclidean Distance', 'Chebyshev Distance']
):
    distances = np.zeros((grid_size, grid_size))
    for i in range(grid_size):
        for j in range(grid_size):
            distances[i, j] = distance_fn((i, j), center)
    
    im = ax.imshow(distances, cmap='YlOrRd')
    ax.set_title(title)
    ax.set_xticks([])
    ax.set_yticks([])
    
    # Mark center
    ax.plot(center[1], center[0], 'b*', markersize=15)
    
    # Add contour lines
    contour = ax.contour(distances, levels=5, colors='black', alpha=0.3)
    ax.clabel(contour, inline=True, fontsize=8)

plt.tight_layout()
plt.show()
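
The three metrics are ordered for any pair of points: Chebyshev ≤ Euclidean ≤ Manhattan. The sketch below checks this numerically on a few arbitrarily chosen grid cells. The helper functions are re-declared under shorter names so the cell runs on its own. On a 4-connected unit-cost grid all three are admissible, so Manhattan is the most informed of them.

```python
import math

def manhattan(a, b):
    return abs(a[0] - b[0]) + abs(a[1] - b[1])

def euclidean(a, b):
    return math.hypot(a[0] - b[0], a[1] - b[1])

def chebyshev(a, b):
    return max(abs(a[0] - b[0]), abs(a[1] - b[1]))

goal = (5, 5)
rows = []
for state in [(0, 0), (3, 4), (5, 9), (9, 2)]:  # arbitrary sample cells
    c = chebyshev(state, goal)
    e = euclidean(state, goal)
    m = manhattan(state, goal)
    rows.append((state, c, e, m))
    print(f"{str(state):<8} chebyshev={c:<3} euclidean={e:<6.2f} manhattan={m}")

# Chebyshev <= Euclidean <= Manhattan at every sampled cell
ordering_holds = all(c <= e <= m for _, c, e, m in rows)
print("Ordering holds:", ordering_holds)  # True
```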

## 3. A* Search Implementation

In [None]:
class AStarSearch:
    """A* search with customizable heuristic"""
    
    def __init__(self, heuristic: Callable = null_heuristic):
        self.heuristic = heuristic
        self.nodes_expanded = 0
        self.max_frontier_size = 0
        self.exploration_order = []
        self.f_values = {}  # f(n) values when expanded
        self.g_values = {}  # g(n) values when expanded
        self.h_values = {}  # h(n) values when expanded
    
    def search(self, problem: SearchProblem) -> Tuple[List[str], float]:
        """
        Search using A* algorithm
        Returns: (actions, total_cost)
        """
        start_state = problem.get_start_state()
        
        # Check if start is goal
        if problem.is_goal_state(start_state):
            return [], 0
        
        # Priority queue with f(n) = g(n) + h(n)
        frontier = PriorityQueue()
        
        # Calculate initial heuristic
        h_start = self.heuristic(start_state, problem)
        f_start = 0 + h_start
        
        # Push start state with f-value priority
        frontier.push(start_state, f_start, ([], 0))  # (path, g_cost)
        
        # Track best g-values
        best_g = {start_state: 0}
        
        # Explored set
        explored = set()
        
        while not frontier.is_empty():
            # Update statistics
            self.max_frontier_size = max(self.max_frontier_size, len(frontier))
            
            # Get node with lowest f-value
            state, f_value, (path, g_cost) = frontier.pop()
            
            # Skip if already explored
            if state in explored:
                continue
            
            # Mark as explored
            explored.add(state)
            self.nodes_expanded += 1
            self.exploration_order.append(state)
            
            # Store values for visualization
            self.g_values[state] = g_cost
            self.h_values[state] = f_value - g_cost
            self.f_values[state] = f_value
            
            # Check if goal
            if problem.is_goal_state(state):
                return path, g_cost
            
            # Expand node
            for successor, action, step_cost in problem.get_successors(state):
                if successor not in explored:
                    new_g = g_cost + step_cost
                    
                    # Only consider if we found a better path
                    if successor not in best_g or new_g < best_g[successor]:
                        best_g[successor] = new_g
                        
                        # Calculate f-value
                        h_value = self.heuristic(successor, problem)
                        f_value = new_g + h_value
                        
                        # Add to frontier
                        frontier.push(successor, f_value, (path + [action], new_g))
        
        # No solution found
        return None, float('inf')
    
    def get_statistics(self):
        return {
            'nodes_expanded': self.nodes_expanded,
            'max_frontier_size': self.max_frontier_size,
            'states_explored': len(self.exploration_order)
        }
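
The `AStarSearch` class above never revisits a state once it is in `explored`. That policy is safe with a consistent heuristic, but with a heuristic that is only admissible it can return a suboptimal path. The sketch below is a separate, compact A* on a made-up five-node graph, with a hypothetical `reopen` flag to toggle the policy. It is meant to illustrate the effect, not to replace the class.

```python
import heapq

# Hypothetical directed graph: state -> [(successor, step_cost)]
graph = {
    "S": [("A", 1), ("B", 2)],
    "A": [("C", 1)],
    "B": [("C", 2)],
    "C": [("G", 3)],
    "G": [],
}
# Admissible (h*(A) = 4) but inconsistent: h(A) > c(A, C) + h(C)
h_inconsistent = {"S": 0, "A": 4, "B": 0, "C": 0, "G": 0}
h_zero = dict.fromkeys(graph, 0)  # trivially consistent

def astar_cost(graph, h, start, goal, reopen):
    """Return the cost of the path A* finds; `reopen` allows re-expanding states."""
    best_g = {start: 0}
    explored = set()
    frontier = [(h[start], 0, start)]  # (f, g, state)
    while frontier:
        f, g, state = heapq.heappop(frontier)
        if g > best_g[state]:
            continue  # a cheaper path to this state was found later
        if state in explored and not reopen:
            continue
        explored.add(state)
        if state == goal:
            return g
        for nxt, cost in graph[state]:
            if nxt in explored and not reopen:
                continue  # same policy as the class above
            new_g = g + cost
            if new_g < best_g.get(nxt, float("inf")):
                best_g[nxt] = new_g
                heapq.heappush(frontier, (new_g + h[nxt], new_g, nxt))
    return float("inf")

print(astar_cost(graph, h_inconsistent, "S", "G", reopen=False))  # 7
print(astar_cost(graph, h_inconsistent, "S", "G", reopen=True))   # 5
print(astar_cost(graph, h_zero, "S", "G", reopen=False))          # 5
```

In the first run, C is expanded via B (g = 4) before A is expanded, so the cheaper route through A (g = 2) is discarded and the search returns 7 instead of the optimal 5.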

## 4. Grid Problem with Heuristics

In [None]:
class GridSearchProblem(SearchProblem):
    """Grid navigation problem for testing heuristics"""
    
    def __init__(self, grid, start, goal):
        self.grid = grid
        self.start = start
        self.goal = goal
        self.rows = len(grid)
        self.cols = len(grid[0])
    
    def get_start_state(self):
        return self.start
    
    def is_goal_state(self, state):
        return state == self.goal
    
    def get_successors(self, state):
        successors = []
        row, col = state
        
        # 4-directional movement
        moves = [
            ((-1, 0), 'UP', 1.0),
            ((1, 0), 'DOWN', 1.0),
            ((0, -1), 'LEFT', 1.0),
            ((0, 1), 'RIGHT', 1.0)
        ]
        
        for (dr, dc), action, cost in moves:
            new_row, new_col = row + dr, col + dc
            
            if (0 <= new_row < self.rows and 
                0 <= new_col < self.cols and 
                self.grid[new_row][new_col] != 1):
                successors.append(((new_row, new_col), action, cost))
        
        return successors

def create_heuristic(goal, distance_type='manhattan'):
    """Create a heuristic function for a specific goal"""
    
    distance_functions = {
        'manhattan': manhattan_distance,
        'euclidean': euclidean_distance,
        'chebyshev': chebyshev_distance,
        'null': lambda a, b: 0
    }
    
    distance_fn = distance_functions[distance_type]
    
    def heuristic(state, problem=None):
        return distance_fn(state, goal)
    
    return heuristic

## 5. Comparing Different Heuristics

In [None]:
# Create a test maze
maze = [
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 1, 1, 1, 1, 0, 1, 1, 1, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
    [0, 1, 1, 0, 1, 1, 1, 0, 1, 0],
    [0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
    [0, 1, 1, 1, 1, 0, 1, 1, 1, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 1, 1, 0, 1, 1, 1, 1, 1, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
]

start = (0, 0)
goal = (9, 9)
problem = GridSearchProblem(maze, start, goal)

# Test different heuristics
heuristics = ['null', 'manhattan', 'euclidean']
results = {}

print("=" * 60)
print("COMPARING HEURISTICS ON MAZE PROBLEM")
print("=" * 60)
print(f"{'Heuristic':<15} {'Nodes Expanded':<15} {'Path Length':<15} {'Path Cost'}")
print("-" * 60)

for h_type in heuristics:
    heuristic = create_heuristic(goal, h_type)
    astar = AStarSearch(heuristic)
    
    solution, cost = astar.search(problem)
    
    results[h_type] = {
        'algorithm': astar,
        'solution': solution,
        'cost': cost,
        'stats': astar.get_statistics()
    }
    
    print(f"{h_type.capitalize():<15} {astar.nodes_expanded:<15} "
          f"{len(solution):<15} {cost:.1f}")

## 6. Visualizing Search Behavior

In [None]:
def visualize_astar_comparison(problem, results):
    """Visualize how different heuristics explore the space"""
    
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    axes = axes.flatten()
    
    for idx, (h_type, data) in enumerate(results.items()):
        astar = data['algorithm']
        solution = data['solution']
        
        # Top row: Exploration order
        ax = axes[idx]
        display = np.ones((problem.rows, problem.cols))
        
        # Mark obstacles
        for i in range(problem.rows):
            for j in range(problem.cols):
                if problem.grid[i][j] == 1:
                    display[i, j] = 0
        
        # Color by exploration order
        max_order = len(astar.exploration_order)
        for order, state in enumerate(astar.exploration_order):
            if display[state[0], state[1]] == 1:
                display[state[0], state[1]] = 0.3 + (0.6 * order / max_order)
        
        im = ax.imshow(display, cmap='YlOrRd', vmin=0, vmax=1)
        ax.set_title(f"{h_type.capitalize()} - Exploration Order")
        ax.axis('off')
        
        # Mark start and goal
        ax.plot(problem.start[1], problem.start[0], 'g*', markersize=15)
        ax.plot(problem.goal[1], problem.goal[0], 'r*', markersize=15)
        
        # Bottom row: f-values
        ax = axes[idx + 3]
        display_f = np.full((problem.rows, problem.cols), np.nan)
        
        for state, f_val in astar.f_values.items():
            display_f[state[0], state[1]] = f_val
        
        # Mask non-explored areas
        masked = np.ma.masked_invalid(display_f)
        
        im = ax.imshow(masked, cmap='coolwarm')
        ax.set_title(f"{h_type.capitalize()} - f-values")
        ax.axis('off')
        
        # Add f-values as text (for small grids)
        if problem.rows <= 10:
            for state, f_val in astar.f_values.items():
                ax.text(state[1], state[0], f"{f_val:.1f}", 
                       ha='center', va='center', fontsize=7)
        
        # Mark obstacles
        for i in range(problem.rows):
            for j in range(problem.cols):
                if problem.grid[i][j] == 1:
                    rect = plt.Rectangle((j-0.5, i-0.5), 1, 1, 
                                        linewidth=0, facecolor='black')
                    ax.add_patch(rect)
    
    plt.tight_layout()
    plt.show()

# Visualize the comparison
visualize_astar_comparison(problem, results)

## 7. Heuristic Quality Analysis

In [None]:
def analyze_heuristic_quality(problem, heuristic_name, heuristic):
    """Analyze admissibility and consistency of a heuristic"""
    
    # Run A* with the heuristic under test to collect the states it expands
    astar = AStarSearch(heuristic)
    solution, _ = astar.search(problem)
    
    print(f"\n{'='*50}")
    print(f"HEURISTIC ANALYSIS: {heuristic_name}")
    print(f"{'='*50}")
    
    # Check admissibility: h(n) <= h*(n) for every expanded state
    admissible = True
    overestimates = []
    
    for state in astar.exploration_order:
        h_value = heuristic(state, problem)
        
        # True cost h*(state): run UCS (A* with the null heuristic) from this
        # state to the goal. Note that g(goal) - g(state) is only a lower bound
        # on h*(state) unless the state lies on an optimal path.
        sub_problem = GridSearchProblem(problem.grid, state, problem.goal)
        _, true_cost_to_goal = AStarSearch(null_heuristic).search(sub_problem)
        
        if h_value > true_cost_to_goal + 0.001:  # Small epsilon for floating point
            admissible = False
            overestimates.append((state, h_value, true_cost_to_goal))
    
    print(f"Admissible: {admissible}")
    if overestimates:
        print(f"Overestimates found: {len(overestimates)}")
        for state, h_val, true_val in overestimates[:3]:  # Show first 3
            print(f"  State {state}: h={h_val:.2f}, true={true_val:.2f}")
    
    # Check consistency
    consistent = True
    inconsistencies = []
    
    for state in astar.exploration_order:
        h_n = heuristic(state, problem)
        
        for successor, action, cost in problem.get_successors(state):
            h_n_prime = heuristic(successor, problem)
            
            if h_n > cost + h_n_prime + 0.001:  # Triangle inequality
                consistent = False
                inconsistencies.append((state, successor, h_n, h_n_prime, cost))
    
    print(f"Consistent: {consistent}")
    if inconsistencies:
        print(f"Inconsistencies found: {len(inconsistencies)}")
        for s1, s2, h1, h2, c in inconsistencies[:3]:  # Show first 3
            print(f"  {s1}->{s2}: h({s1})={h1:.2f} > cost={c:.2f} + h({s2})={h2:.2f}")
    
    # Heuristic statistics
    h_values = [heuristic(state, problem) for state in astar.exploration_order]
    print(f"\nHeuristic Statistics:")
    print(f"  Mean h-value: {np.mean(h_values):.2f}")
    print(f"  Max h-value: {np.max(h_values):.2f}")
    print(f"  Min h-value: {np.min(h_values):.2f}")
    print(f"  Nodes expanded: {astar.nodes_expanded}")

# Analyze each heuristic
for h_type in ['manhattan', 'euclidean']:
    heuristic = create_heuristic(goal, h_type)
    analyze_heuristic_quality(problem, h_type, heuristic)

## 8. Advanced: Pattern Database Heuristics

In [None]:
class PatternDatabaseHeuristic:
    """Pre-computed heuristic values for sliding puzzle problems"""
    
    def __init__(self, goal_state, pattern_tiles):
        """
        Pre-compute distances for a subset of tiles
        pattern_tiles: list of tile numbers to track
        """
        self.goal_state = goal_state
        self.pattern_tiles = pattern_tiles
        self.database = {}
        
        # Build pattern database using backward search from goal
        self._build_database()
    
    def _build_database(self):
        """Build database of pattern distances"""
        # Extract pattern from goal
        goal_pattern = self._extract_pattern(self.goal_state)
        
        # BFS from goal pattern
        from collections import deque
        
        queue = deque([(goal_pattern, 0)])
        self.database[goal_pattern] = 0
        
        while queue:
            pattern, dist = queue.popleft()
            
            # Generate all possible parent patterns
            for parent in self._get_parent_patterns(pattern):
                if parent not in self.database:
                    self.database[parent] = dist + 1
                    queue.append((parent, dist + 1))
    
    def _extract_pattern(self, state):
        """Extract pattern: keep tracked tiles in place, mask all others with -1"""
        # Simplified for demonstration
        return tuple(tile if tile in self.pattern_tiles else -1 
                    for tile in state)
    
    def _get_parent_patterns(self, pattern):
        """Generate patterns that could lead to this one"""
        # Simplified - would need full sliding puzzle logic.
        # While this stub returns [], the database holds only the goal pattern.
        return []
    
    def __call__(self, state, problem=None):
        """Return heuristic value for state"""
        pattern = self._extract_pattern(state)
        # With a complete database, a missing pattern means the goal is unreachable
        return self.database.get(pattern, float('inf'))

print("Pattern databases are powerful for problems like:")
print("- Sliding puzzles (8-puzzle, 15-puzzle)")
print("- Rubik's cube")
print("- Planning problems with independent subgoals")
print("\nThey trade memory for speed by pre-computing exact distances in a simplified (abstracted) problem.")
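
Because `_get_parent_patterns` above is left as a stub, here is a minimal working sketch of the same idea on a 2x3 sliding puzzle. The puzzle size, the goal layout and the choice of pattern tiles {1, 2} are assumptions made for this example. The database is built by BFS from the abstracted goal, and the last lines compare it against true distances from a full BFS to confirm it never overestimates.

```python
from collections import deque

ROWS, COLS = 2, 3
GOAL = (1, 2, 3, 4, 5, 0)  # 0 is the blank (assumed goal layout)

def neighbors(state):
    """States reachable by sliding one tile into the blank."""
    blank = state.index(0)
    r, c = divmod(blank, COLS)
    for dr, dc in ((-1, 0), (1, 0), (0, -1), (0, 1)):
        nr, nc = r + dr, c + dc
        if 0 <= nr < ROWS and 0 <= nc < COLS:
            swap = nr * COLS + nc
            s = list(state)
            s[blank], s[swap] = s[swap], s[blank]
            yield tuple(s)

def abstract(state, pattern_tiles):
    """Keep the blank and the pattern tiles; mask every other tile with -1."""
    return tuple(t if t == 0 or t in pattern_tiles else -1 for t in state)

def bfs_distances(start):
    """Moves needed from start to each reachable state (moves are reversible)."""
    dist = {start: 0}
    queue = deque([start])
    while queue:
        s = queue.popleft()
        for n in neighbors(s):
            if n not in dist:
                dist[n] = dist[s] + 1
                queue.append(n)
    return dist

pattern_tiles = {1, 2}
pdb = bfs_distances(abstract(GOAL, pattern_tiles))  # the pattern database
exact = bfs_distances(GOAL)                         # true costs, for checking only

def pdb_heuristic(state):
    return pdb[abstract(state, pattern_tiles)]

never_overestimates = all(pdb_heuristic(s) <= d for s, d in exact.items())
print("Abstract states stored:", len(pdb))
print("Admissible on all reachable states:", never_overestimates)
```

The lookup is admissible because every real move sequence is also a valid move sequence in the abstracted puzzle, so the abstract distance can only be shorter or equal.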

## 9. Greedy Best-First Search

A variant that only considers h(n), ignoring g(n):

In [None]:
class GreedyBestFirstSearch:
    """Greedy search using only heuristic value"""
    
    def __init__(self, heuristic):
        self.heuristic = heuristic
        self.nodes_expanded = 0
        self.exploration_order = []
    
    def search(self, problem):
        start_state = problem.get_start_state()
        
        if problem.is_goal_state(start_state):
            return [], 0
        
        # Priority based only on h(n)
        frontier = PriorityQueue()
        h_value = self.heuristic(start_state, problem)
        frontier.push(start_state, h_value, ([], 0))
        
        explored = set()
        
        while not frontier.is_empty():
            state, _, (path, cost) = frontier.pop()
            
            if state in explored:
                continue
            
            explored.add(state)
            self.nodes_expanded += 1
            self.exploration_order.append(state)
            
            if problem.is_goal_state(state):
                return path, cost
            
            for successor, action, step_cost in problem.get_successors(state):
                if successor not in explored:
                    h_value = self.heuristic(successor, problem)
                    frontier.push(successor, h_value, 
                                (path + [action], cost + step_cost))
        
        return None, float('inf')

# Compare Greedy with A*
print("\n" + "=" * 60)
print("GREEDY BEST-FIRST vs A*")
print("=" * 60)

manhattan = create_heuristic(goal, 'manhattan')

# Run Greedy
greedy = GreedyBestFirstSearch(manhattan)
greedy_solution, greedy_cost = greedy.search(problem)

# Run A*
astar = AStarSearch(manhattan)
astar_solution, astar_cost = astar.search(problem)

print(f"{'Algorithm':<15} {'Nodes Expanded':<15} {'Path Cost':<15} {'Optimal?'}")
print("-" * 60)
print(f"{'Greedy':<15} {greedy.nodes_expanded:<15} {greedy_cost:<15.1f} "
      f"{'No' if greedy_cost > astar_cost else 'Yes'}")
print(f"{'A*':<15} {astar.nodes_expanded:<15} {astar_cost:<15.1f} Yes")
print(f"\nGreedy expanded {greedy.nodes_expanded} nodes (A*: {astar.nodes_expanded}); "
      f"its path is {greedy_cost - astar_cost:.1f} units longer than the optimal one")
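
To see concretely why ignoring g(n) can cost optimality, the sketch below runs both priority rules on a made-up weighted graph with a consistent heuristic table. The function `best_first` and its `use_g` flag are hypothetical names for this example and are separate from the classes above.

```python
import heapq

# Hypothetical directed graph: state -> [(successor, step_cost)]
graph = {
    "S": [("A", 1), ("B", 2)],
    "A": [("G", 5)],
    "B": [("C", 1)],
    "C": [("G", 1)],
    "G": [],
}
h = {"S": 2, "A": 1, "B": 2, "C": 1, "G": 0}  # consistent on this graph

def best_first(graph, h, start, goal, use_g):
    """Best-first search; priority is g + h if use_g is True, else h alone."""
    best_g = {start: 0}
    explored = set()
    frontier = [(h[start], 0, start, [start])]  # (priority, g, state, path)
    while frontier:
        _, g, state, path = heapq.heappop(frontier)
        if state in explored:
            continue
        explored.add(state)
        if state == goal:
            return path, g
        for nxt, cost in graph[state]:
            new_g = g + cost
            if nxt not in explored and new_g < best_g.get(nxt, float("inf")):
                best_g[nxt] = new_g
                priority = (new_g if use_g else 0) + h[nxt]
                heapq.heappush(frontier, (priority, new_g, nxt, path + [nxt]))
    return None, float("inf")

greedy_path, greedy_path_cost = best_first(graph, h, "S", "G", use_g=False)
astar_path, astar_path_cost = best_first(graph, h, "S", "G", use_g=True)
print("Greedy:", greedy_path, greedy_path_cost)  # ['S', 'A', 'G'] 6
print("A*:    ", astar_path, astar_path_cost)    # ['S', 'B', 'C', 'G'] 4
```

Greedy commits to A because h(A) is the smallest estimate on the frontier, and never accounts for the expensive edge from A to G.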

## 10. Practice Exercises

### Exercise 1: Custom Heuristic Design
Design a heuristic that considers obstacles:

In [None]:
def obstacle_aware_heuristic(state, problem):
    """
    TODO: Design a heuristic that estimates better by considering obstacles
    Ideas:
    - Use wavefront/flood-fill from goal
    - Count minimum obstacles to cross
    - Combine with Manhattan distance
    """
    # Basic version - improve this!
    return manhattan_distance(state, problem.goal)

# Test your heuristic
# custom_astar = AStarSearch(obstacle_aware_heuristic)
# solution, cost = custom_astar.search(problem)
# print(f"Custom heuristic: {custom_astar.nodes_expanded} nodes expanded")

### Exercise 2: Weighted A*
Implement weighted A* that trades optimality for speed:

In [None]:
class WeightedAStar:
    """
    TODO: Implement weighted A* with f(n) = g(n) + w * h(n)
    - w > 1: Usually faster, but the path may be suboptimal
    - w = 1: Standard A*
    - Guarantee (for admissible h): cost <= w * optimal_cost
    """
    
    def __init__(self, heuristic, weight=1.0):
        self.heuristic = heuristic
        self.weight = weight
        self.nodes_expanded = 0  # read by the test loop below
    
    def search(self, problem):
        # TODO: Implement weighted A*
        pass

# Test different weights
# for w in [1.0, 1.5, 2.0, 5.0]:
#     weighted = WeightedAStar(manhattan, weight=w)
#     solution, cost = weighted.search(problem)
#     print(f"Weight {w}: cost={cost}, nodes={weighted.nodes_expanded}")

### Exercise 3: Iterative Deepening A* (IDA*)
Memory-efficient variant of A*:

In [None]:
class IDAstar:
    """
    TODO: Implement IDA*
    - Use depth-first search with f-value limit
    - Iteratively increase limit
    - Space complexity: O(bd) instead of O(b^d)
    """
    
    def __init__(self, heuristic):
        self.heuristic = heuristic
        self.nodes_expanded = 0
    
    def search(self, problem):
        # TODO: Implement iterative deepening with f-limit
        pass
    
    def depth_limited_search(self, problem, f_limit):
        # TODO: DFS with f-value cutoff
        pass

## 11. Key Takeaways

### A* Properties:
✅ **Optimal**: With an admissible heuristic (consistent, for graph search)  
✅ **Complete**: Will find a solution if one exists  
✅ **Informed**: Uses domain knowledge via heuristic  
✅ **Efficient**: Often dramatically faster than UCS  
❌ **Memory intensive**: Must keep the frontier and explored set in memory  
❌ **Heuristic dependent**: Bad heuristic → poor performance

### Heuristic Design Principles:
1. **Admissibility**: Never overestimate (for optimality)
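2. **Consistency**: Satisfy the triangle inequality (so graph search never needs to re-expand a node)  
3. **Informedness**: h₁ dominates h₂ if h₁(n) ≥ h₂(n) for all n; among admissible heuristics, the dominant one generally expands fewer nodes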
4. **Computation**: Should be fast to compute

### Algorithm Variants:
- **Greedy Best-First**: Fast but not optimal (only h(n))
- **Weighted A***: Trade optimality for speed
- **IDA***: Memory-efficient iterative deepening
- **Pattern Databases**: Pre-computed exact costs of a simplified subproblem, used as an admissible heuristic

## Next Steps
Next notebook: Game playing with Minimax and Alpha-Beta pruning!