## Taxi Environments

In [1]:
import numpy as np
import gymnasium as gym
import time

# Build the Taxi-v3 environment with text (ANSI) rendering and keep only the
# unwrapped core environment, which exposes the tabular model `env.P`.
env = gym.make('Taxi-v3', render_mode='ansi').unwrapped

# Value-iteration state: one value per discrete state, initialised to zero.
V = np.zeros(env.observation_space.n)
gamma, theta = 0.9, 1e-6  # discount factor, convergence threshold

def one_step_lookahead(state, V, model=None, discount=None):
    """
    Compute the expected one-step return of every action from ``state``.

    Q(s, a) = sum over (p, s', r, done) in P[s][a] of p * (r + discount * V[s'])
    where the bootstrap term is dropped for terminal transitions (``done``).

    Args:
        state: index of the state to evaluate.
        V: array of current state-value estimates, indexed by state.
        model: environment exposing ``P[state][action]`` as a list of
            ``(prob, next_state, reward, done)`` tuples and ``action_space.n``.
            Defaults to the module-level ``env``.
        discount: discount factor; defaults to the module-level ``gamma``.

    Returns:
        np.ndarray of length ``action_space.n`` with one Q-value per action.
    """
    model = env if model is None else model
    discount = gamma if discount is None else discount

    A = np.zeros(model.action_space.n)
    for action in range(model.action_space.n):
        for prob, next_state, reward, done in model.P[state][action]:
            # An episode ends on a terminal transition, so no future value follows it.
            future = 0.0 if done else discount * V[next_state]
            A[action] += prob * (reward + future)
    return A

# Value iteration: apply the Bellman optimality backup to every state,
# overwriting V in place, until no state's value moves by theta or more
# during a full sweep.
converged = False
while not converged:
    delta = 0
    for state in range(env.observation_space.n):
        best_action_value = np.max(one_step_lookahead(state, V))
        delta = max(delta, np.abs(best_action_value - V[state]))
        V[state] = best_action_value
    converged = delta < theta

# Greedy policy extraction: a one-hot row per state marking the action with
# the highest one-step lookahead value under the converged V.
n_states, n_actions = env.observation_space.n, env.action_space.n
greedy_actions = [np.argmax(one_step_lookahead(s, V)) for s in range(n_states)]
policy = np.zeros([n_states, n_actions])
policy[np.arange(n_states), greedy_actions] = 1.0

# Visualize agent's movement
def simulate_agent(env, policy, delay=1.0, max_steps=200):
    """
    Roll out one episode following ``policy`` greedily, printing every frame.

    Args:
        env: Gymnasium-style environment whose ``render()`` returns a printable
            frame (e.g. ``render_mode='ansi'``) and whose ``step`` returns
            ``(obs, reward, terminated, truncated, info)``.
        policy: array of shape (n_states, n_actions); the action taken in a
            state is the argmax of that state's row.
        delay: seconds to pause after each printed frame.
        max_steps: hard cap on the number of steps. ``env.unwrapped`` strips the
            TimeLimit wrapper, so without this cap a policy that never reaches
            the goal would loop forever. (200 is assumed to mirror Taxi-v3's
            registered episode limit — confirm against the Gymnasium registry.)

    Returns:
        The undiscounted sum of rewards collected during the episode.
    """
    state, _ = env.reset()
    print(env.render())
    time.sleep(delay)

    total_reward = 0.0
    for _ in range(max_steps):
        action = np.argmax(policy[state])
        state, reward, terminated, truncated, _ = env.step(action)
        total_reward += reward
        print(env.render())
        time.sleep(delay)
        # Stop on either natural termination or an environment-imposed cutoff.
        if terminated or truncated:
            break
    return total_reward

# Roll out one episode with the greedy policy from value iteration, printing each rendered frame.
simulate_agent(env, policy)


+---------+
|[35mR[0m: | : :G|
| : | : : |
| : :[43m [0m: : |
| | : | : |
|Y| : |[34;1mB[0m: |
+---------+


+---------+
|[35mR[0m: | : :G|
| : | : : |
| : : :[43m [0m: |
| | : | : |
|Y| : |[34;1mB[0m: |
+---------+
  (East)

+---------+
|[35mR[0m: | : :G|
| : | : : |
| : : : : |
| | : |[43m [0m: |
|Y| : |[34;1mB[0m: |
+---------+
  (South)

+---------+
|[35mR[0m: | : :G|
| : | : : |
| : : : : |
| | : | : |
|Y| : |[34;1m[43mB[0m[0m: |
+---------+
  (South)

+---------+
|[35mR[0m: | : :G|
| : | : : |
| : : : : |
| | : | : |
|Y| : |[42mB[0m: |
+---------+
  (Pickup)

+---------+
|[35mR[0m: | : :G|
| : | : : |
| : : : : |
| | : |[42m_[0m: |
|Y| : |B: |
+---------+
  (North)

+---------+
|[35mR[0m: | : :G|
| : | : : |
| : : :[42m_[0m: |
| | : | : |
|Y| : |B: |
+---------+
  (North)

+---------+
|[35mR[0m: | : :G|
| : | : : |
| : :[42m_[0m: : |
| | : | : |
|Y| : |B: |
+---------+
  (West)

+---------+
|[35mR[0m: | : :G|
| : | : : |
| :[42m_[0m: : 

## Manual Replication

In [1]:
import numpy as np
import gym
from gym import spaces

class CustomTaxiMDP(gym.Env):
    """
    A customizable grid-world taxi MDP.

    - Variable grid size.
    - Custom pickup and drop-off locations.
    - No explicit pickup/dropoff actions: ending a step on the pickup cell
      picks the passenger up; ending a step on the drop-off cell with the
      passenger aboard completes the ride and terminates the episode.
    - Safety incidents do not terminate the episode but add ``incident_cost``.
      ``step`` samples incidents at random; the planning model ``self.P``
      folds in their expected cost instead.
    - Road types ("highway" / "rough") adjust the per-step reward.
    - A large ``final_dropoff_reward`` for completing the ride.

    Actions: 0 = up (row - 1), 1 = down (row + 1), 2 = left (col - 1),
    3 = right (col + 1). Moving into the border leaves the taxi in place.
    Observations: ``(row, col, passenger_picked_up, last_action + 1)``; the
    last field is 0 before the first action of an episode.

    NOTE(review): scenic and straight-line bonuses are paid on every step that
    ends on a scenic cell / repeats the previous action, including steps that
    bump into the border and stay put. If ``scenic_reward + straight_reward``
    exceeds ``-step_cost`` such a loop is net positive, so
    ``final_dropoff_reward`` must be large enough to dominate it.
    """

    metadata = {"render_modes": ["ansi"], "render_fps": 4}

    def __init__(
        self,
        grid_rows=4,
        grid_cols=5,
        pickup_loc=(0,0),
        dropoff_loc=None,            # None -> bottom-right cell of the grid
        incident_prob=0.001,         # base probability of a safety incident each step
        unsafe_squares=(),
        unsafe_incident_prob=0.01,   # additional probability if on unsafe square
        scenic_squares=(),
        scenic_reward=0.5,
        step_cost=-1,
        incident_cost=-10,           # cost for a safety incident
        turn_cost=-0.1,              # penalty for changing direction
        straight_reward=0.0,         # reward for going straight
        road_types=None,             # {(r,c): "highway"/"rough"/"scenic"/...}
        final_dropoff_reward=50.0,   # Large reward for completing dropoff
        seed=None
    ):
        """
        Configure the grid, rewards and hazards, then build the planning model.

        Raises:
            ValueError: if ``pickup_loc`` or ``dropoff_loc`` lies outside the grid.
        """
        super().__init__()
        self.grid_rows = grid_rows
        self.grid_cols = grid_cols
        # Default to the bottom-right cell so the drop-off is always reachable
        # (a fixed default could fall outside a small grid).
        if dropoff_loc is None:
            dropoff_loc = (grid_rows - 1, grid_cols - 1)
        # Normalised to tuples: positions are compared with ``==`` against
        # (row, col) tuples, which a list would never equal.
        self.pickup_loc = self._validate_loc(pickup_loc, "pickup_loc")
        self.dropoff_loc = self._validate_loc(dropoff_loc, "dropoff_loc")
        self.incident_prob = incident_prob
        self.unsafe_squares = {tuple(sq) for sq in unsafe_squares}
        self.unsafe_incident_prob = unsafe_incident_prob
        self.scenic_squares = {tuple(sq) for sq in scenic_squares}
        self.scenic_reward = scenic_reward
        self.step_cost = step_cost
        self.incident_cost = incident_cost
        self.turn_cost = turn_cost
        self.straight_reward = straight_reward
        self.final_dropoff_reward = final_dropoff_reward

        # Copy so later mutation of the caller's dict cannot desynchronise
        # step() from the transition model built below.
        self.road_types = {} if road_types is None else dict(road_types)

        self.action_space = spaces.Discrete(4)
        # State representation: (agent_row, agent_col, passenger_picked_up, last_action)
        self.observation_space = spaces.MultiDiscrete([self.grid_rows, self.grid_cols, 2, 5])

        self.np_random, _ = gym.utils.seeding.np_random(seed)

        self.reset()

        # Build the transition model for planning
        self._build_transition_model()

    def _validate_loc(self, loc, name):
        """Return ``loc`` as a ``(row, col)`` tuple; raise ValueError if it is off the grid."""
        r, c = loc
        if not (0 <= r < self.grid_rows and 0 <= c < self.grid_cols):
            raise ValueError(
                f"{name}={(r, c)} lies outside the {self.grid_rows}x{self.grid_cols} grid"
            )
        return (r, c)

    def reset(self, *, seed=None, options=None):
        """
        Start a new episode at a random cell other than the drop-off.

        Reseeds the RNG when ``seed`` is given and clears the per-episode
        metrics. Returns ``(observation, {})``.
        """
        if seed is not None:
            self.np_random, _ = gym.utils.seeding.np_random(seed)
        self.agent_row = self.np_random.integers(self.grid_rows)
        self.agent_col = self.np_random.integers(self.grid_cols)

        # Ensure not starting at dropoff location
        while (self.agent_row, self.agent_col) == self.dropoff_loc:
            self.agent_row = self.np_random.integers(self.grid_rows)
            self.agent_col = self.np_random.integers(self.grid_cols)

        self.passenger_picked_up = False
        self.last_action = -1  # -1 = no action taken yet this episode
        self.terminated = False

        # Performance metrics
        self.total_reward = 0.0
        self.steps = 0
        self.incidents_count = 0

        return self._get_obs(), {}

    def _get_obs(self):
        """Return ``(row, col, picked_up, last_action + 1)``; the shift maps -1 to 0."""
        return (self.agent_row, self.agent_col, int(self.passenger_picked_up), self.last_action+1)

    def _compute_road_type_reward(self, r, c):
        """Reward adjustment for the road type at (r, c); unknown/other types give 0."""
        rt = self.road_types.get((r, c), None)
        if rt == "highway":
            return 0.5  # partially offsets the step cost
        elif rt == "rough":
            return -0.5 # makes movement more costly
        return 0.0

    def step(self, action):
        """
        Apply ``action`` and return ``(obs, reward, terminated, False, {})``.

        Rewards add up as: step cost + road-type adjustment + turn/straight term
        + scenic bonus + drop-off reward + (sampled) incident cost.
        """
        prev_action = self.last_action

        # Move; an action pointing off the grid leaves the taxi where it is.
        new_row = self.agent_row
        new_col = self.agent_col
        if action == 0 and self.agent_row > 0:
            new_row -= 1
        elif action == 1 and self.agent_row < self.grid_rows - 1:
            new_row += 1
        elif action == 2 and self.agent_col > 0:
            new_col -= 1
        elif action == 3 and self.agent_col < self.grid_cols - 1:
            new_col += 1

        self.agent_row = new_row
        self.agent_col = new_col

        # Base reward
        reward = self.step_cost + self._compute_road_type_reward(new_row, new_col)

        # Turn or straight (no term on the first action of an episode)
        if prev_action != -1 and prev_action != action:
            reward += self.turn_cost
        elif prev_action != -1 and prev_action == action:
            reward += self.straight_reward

        # Pickup
        if (self.agent_row, self.agent_col) == self.pickup_loc and not self.passenger_picked_up:
            self.passenger_picked_up = True

        # Scenic
        if (self.agent_row, self.agent_col) in self.scenic_squares:
            reward += self.scenic_reward

        # Check dropoff
        if (self.agent_row, self.agent_col) == self.dropoff_loc and self.passenger_picked_up:
            reward += self.final_dropoff_reward
            self.terminated = True

        # Incident Probability
        local_incident_prob = self.incident_prob
        if (self.agent_row, self.agent_col) in self.unsafe_squares:
            local_incident_prob += self.unsafe_incident_prob

        # Incident (sampled here; the planning model uses the expectation)
        if self.np_random.random() < local_incident_prob:
            reward += self.incident_cost
            self.incidents_count += 1

        self.last_action = action
        self.total_reward += reward
        self.steps += 1

        return self._get_obs(), reward, self.terminated, False, {}

    def render(self, mode='ansi'):
        """
        Return the grid as text, one row per line.

        Legend: '.' plain, 'U' unsafe, 'S' scenic, 'US' unsafe+scenic,
        'H' highway, 'R' rough, 'P' pickup, 'D' drop-off, 'A' taxi
        ('A*' with the passenger aboard). Later markers overwrite earlier ones.
        """
        grid = np.full((self.grid_rows, self.grid_cols), '.', dtype=object)

        for (r,c) in self.unsafe_squares:
            grid[r,c] = 'U'
        for (r,c) in self.scenic_squares:
            if grid[r,c] == 'U':
                grid[r,c] = 'US'
            else:
                grid[r,c] = 'S'

        for (rc, cc), rt in self.road_types.items():
            if rt == 'highway':
                symbol = 'H'
            elif rt == 'rough':
                symbol = 'R'
            else:
                symbol = grid[rc, cc]
            if grid[rc,cc] not in ['.', 'H', 'R', 'S', 'U', 'US']:
                symbol = grid[rc,cc] + rt[0].upper()
            grid[rc, cc] = symbol

        pr, pc = self.pickup_loc
        grid[pr, pc] = 'P'
        dr, dc = self.dropoff_loc
        grid[dr, dc] = 'D'

        ar, ac = self.agent_row, self.agent_col
        agent_marker = 'A*' if self.passenger_picked_up else 'A'
        grid[ar, ac] = agent_marker

        output = ''
        for r in range(self.grid_rows):
            row_str = ' '.join(f"{cell:2}" for cell in grid[r,:])
            output += row_str + '\n'
        return output

    def close(self):
        pass

    def _state_to_tuple(self, s):
        return tuple(s)

    def _get_all_states(self):
        """Enumerate every (row, col, picked_up, last_action + 1) state."""
        states = []
        for r in range(self.grid_rows):
            for c in range(self.grid_cols):
                for pu in [0,1]:
                    for la in range(5):
                        states.append((r,c,pu,la))
        return states

    def _build_transition_model(self):
        """Fill ``self.P[state][action] = (next_state, expected_reward, done)`` (deterministic moves)."""
        self.P = {}
        all_states = self._get_all_states()
        for s in all_states:
            self.P[s] = {}
            for a in range(self.action_space.n):
                obs, reward, done = self._simulate_transition(s, a)
                self.P[s][a] = (obs, reward, done)

    def _simulate_transition(self, s, a):
        """
        Mirror ``step`` for planning without touching episode state.

        Returns ``(next_state, reward, done)`` where the incident cost is
        replaced by its expectation ``incident_prob * incident_cost``.
        """
        (r,c,pu,la) = s
        last_action = la - 1

        new_r, new_c = r,c
        if a == 0 and r > 0:
            new_r -= 1
        elif a == 1 and r < self.grid_rows - 1:
            new_r += 1
        elif a == 2 and c > 0:
            new_c -= 1
        elif a == 3 and c < self.grid_cols - 1:
            new_c += 1

        reward = self.step_cost + self._compute_road_type_reward(new_r, new_c)

        if last_action != -1 and last_action != a:
            reward += self.turn_cost
        elif last_action != -1 and last_action == a:
            reward += self.straight_reward

        passenger_picked = (pu == 1)

        if (new_r, new_c) == self.pickup_loc and not passenger_picked:
            passenger_picked = True

        if (new_r, new_c) in self.scenic_squares:
            reward += self.scenic_reward

        done = False
        if (new_r, new_c) == self.dropoff_loc and passenger_picked:
            reward += self.final_dropoff_reward
            done = True

        local_incident_prob = self.incident_prob
        if (new_r, new_c) in self.unsafe_squares:
            local_incident_prob += self.unsafe_incident_prob

        reward += local_incident_prob * self.incident_cost

        new_pu = 1 if passenger_picked else 0
        new_la = a + 1
        next_s = (new_r, new_c, new_pu, new_la)

        return next_s, reward, done

    def compute_optimal_policy(self, gamma=0.99, theta=1e-6):
        """
        Run in-place value iteration on ``self.P``.

        States at the drop-off with the passenger aboard are terminal (value 0,
        policy 0). Returns ``(V, pi)`` as dicts keyed by state tuple.
        """
        all_states = self._get_all_states()
        V = {s:0.0 for s in all_states}

        def is_terminal(s):
            (r,c,pu,la) = s
            if (r,c) == self.dropoff_loc and pu == 1:
                return True
            return False

        while True:
            delta = 0
            for s in all_states:
                if is_terminal(s):
                    continue
                v = V[s]
                q_values = []
                for a in range(self.action_space.n):
                    (next_s, r, done) = self.P[s][a]
                    q_val = r + gamma*(0 if done else V[next_s])
                    q_values.append(q_val)
                V[s] = max(q_values)
                delta = max(delta, abs(v - V[s]))
            if delta < theta:
                break

        pi = {}
        for s in all_states:
            if is_terminal(s):
                pi[s] = 0
                continue
            q_values = []
            for a in range(self.action_space.n):
                (next_s, r, done) = self.P[s][a]
                q_val = r + gamma*(0 if done else V[next_s])
                q_values.append(q_val)
            pi[s] = np.argmax(q_values)

        return V, pi

    def visualize_policy(self, pi, max_steps=50):
        """Reset, follow ``pi`` for up to ``max_steps`` steps printing each frame, then print metrics."""
        obs, info = self.reset()
        step_count = 0
        print("Initial State:")
        print(self.render())
        while step_count < max_steps:
            s = tuple(obs)
            a = pi[s]
            obs, r, done, truncated, _ = self.step(a)
            print(f"Step: {step_count}, Action: {a}, Reward: {r}")
            print(self.render())
            if done:
                print("Episode finished!")
                break
            step_count += 1
        if step_count >= max_steps:
            print("Reached max steps without termination.")

        print("Performance Metrics:")
        print(f"  Total Reward: {self.total_reward}")
        print(f"  Steps Taken: {self.steps}")
        print(f"  Safety Incidents: {self.incidents_count}")

if __name__ == "__main__":
    # 15x15 city: pick up at the top-left corner, drop off at the bottom-right.
    # A 3x3 hazard block (rows 2-4, cols 5-7) that is both unsafe and rough.
    hazard_block = [(row, col) for row in range(2, 5) for col in range(5, 8)]
    # Three scenic strips of three cells each.
    scenic_strips = (
        [(0, col) for col in range(11, 14)]
        + [(8, col) for col in range(2, 5)]
        + [(8, col) for col in range(8, 11)]
    )
    # Three short highway stretches along the left edge (just some examples).
    highway_cells = [(row, col) for row in (0, 6, 11) for col in range(3)]

    layout = {cell: 'highway' for cell in highway_cells}
    layout.update({cell: 'rough' for cell in hazard_block})
    # The 'scenic' road type earns nothing in _compute_road_type_reward; the
    # scenic bonus comes from scenic_squares / scenic_reward.
    layout.update({cell: 'scenic' for cell in scenic_strips})

    env = CustomTaxiMDP(
        grid_rows=15,
        grid_cols=15,
        pickup_loc=(0, 0),
        dropoff_loc=(14, 14),
        unsafe_squares=hazard_block,
        scenic_squares=scenic_strips,
        incident_prob=0.001,
        unsafe_incident_prob=0.02,
        # +1 on a scenic cell cancels the -1 step cost; together with
        # straight_reward=0.5, repeating an action on a scenic cell (e.g.
        # driving into the border) nets roughly +0.5 per step.
        scenic_reward=1.0,
        step_cost=-1,
        incident_cost=-20,
        turn_cost=-0.3,
        straight_reward=0.5,
        # Large enough that completing the trip beats looping on bonuses.
        final_dropoff_reward=500.0,
        road_types=layout,
    )

    V, pi = env.compute_optimal_policy(gamma=0.99, theta=1e-6)
    print("Optimal Policy Visualization:")
    env.visualize_policy(pi, max_steps=100)


Optimal Policy Visualization:
Initial State:
P  H  H  .  .  .  .  .  .  .  .  S  S  S  . 
.  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
.  .  .  .  .  R  R  R  .  .  .  .  .  .  . 
.  .  .  .  .  R  R  A  .  .  .  .  .  .  . 
.  .  .  .  .  R  R  R  .  .  .  .  .  .  . 
.  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
H  H  H  .  .  .  .  .  .  .  .  .  .  .  . 
.  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
.  .  S  S  S  .  .  .  S  S  S  .  .  .  . 
.  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
.  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
H  H  H  .  .  .  .  .  .  .  .  .  .  .  . 
.  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
.  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
.  .  .  .  .  .  .  .  .  .  .  .  .  .  D 

Step: 0, Action: 0, Reward: -1.5
P  H  H  .  .  .  .  .  .  .  .  S  S  S  . 
.  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
.  .  .  .  .  R  R  A  .  .  .  .  .  .  . 
.  .  .  .  .  R  R  R  .  .  .  .  .  .  . 
.  .  .  .  .  R  R  R  .  .  .  .  .  .  . 
.  .  .  .  .  .  .  

### Pygame

In [2]:
import networkx as nx
import pygame
import math
import random
import numpy as np

class RoadNetworkMDP:
    """
    Stochastic MDP over a directed road graph.

    States are graph nodes; an action from state ``s`` is the successor node
    the driver intends to reach. Each step:

    - with probability ``delay_prob`` the car stays put (reward 0);
    - with probability ``driver_error_prob`` (only when ``s`` has more than one
      successor) it moves to a uniformly chosen *other* successor;
    - otherwise it reaches the intended successor.

    Edge rewards live in the edge attribute ``"reward"`` and are filled by
    ``set_rewards`` from the source node's features.
    """

    def __init__(self, gamma=0.95, driver_error_prob=0.1, delay_prob=0.1):
        self.graph = nx.DiGraph()
        self.gamma = gamma
        self.driver_error_prob = driver_error_prob
        self.delay_prob = delay_prob

        # Populated by build_indexing / set_feature_map.
        self.states = []
        self.actions = []
        self.state_to_idx = {}
        self.action_to_idx = {}
        self.feature_map = {}

    def add_state(self, state_id, features):
        self.graph.add_node(state_id, **features)

    def add_action(self, from_state, to_state, action):
        # Only local connections already ensured outside (no random edges).
        self.graph.add_edge(from_state, to_state, action=action)

    def set_rewards(self, reward_weights):
        """Set each edge's reward to the weighted sum of its *source* node's features."""
        for u, v, data in self.graph.edges(data=True):
            from_features = self.graph.nodes[u]
            reward = sum(reward_weights.get(k,0)*from_features.get(k,0) for k in reward_weights)
            data["reward"] = reward

    def get_states(self):
        return list(self.graph.nodes)

    def get_actions(self, s):
        return list(self.graph.successors(s))

    def is_terminal(self, s, goal):
        return s == goal

    def _outcome_probs(self, s):
        """
        Return ``(successors, p_delay, p_error, p_intended)`` for a step from ``s``.

        ``p_error`` is zero when ``s`` has at most one successor (there is no
        other road to take by mistake); ``p_intended`` is clipped at zero.
        """
        successors = list(self.graph.successors(s))
        p_delay = self.delay_prob
        p_error = self.driver_error_prob if len(successors) > 1 else 0.0
        p_intended = max(1.0 - p_delay - p_error, 0.0)
        return successors, p_delay, p_error, p_intended

    def _edge_reward(self, s, t):
        """Reward of edge ``s -> t``, or 0 when the edge does not exist."""
        return self.graph[s][t]["reward"] if self.graph.has_edge(s, t) else 0

    def transition(self, s, a, goal=None):
        """
        Sample one stochastic step from ``s`` when the driver intends to go to ``a``.

        Args:
            s: current node.
            a: intended successor node.
            goal: optional goal node; ``done`` is True when the sampled next
                state equals it. (Termination cannot be inferred from ``s`` and
                ``a`` alone, so without a goal only dead ends terminate.)

        Returns:
            ``(next_state, reward, done)``. If ``s`` has no outgoing edges,
            ``(s, 0.0, True)``.
        """
        successors, p_delay, p_error, _ = self._outcome_probs(s)
        if not successors:
            return s, 0.0, True

        def reached(node):
            return goal is not None and node == goal

        rand_val = random.random()
        if rand_val < p_delay:
            return s, 0.0, reached(s)
        if rand_val < p_delay + p_error:
            # p_error > 0 implies >1 successor, so at least one alternative exists.
            chosen = random.choice([x for x in successors if x != a])
            return chosen, self._edge_reward(s, chosen), reached(chosen)
        return a, self._edge_reward(s, a), reached(a)

    def _q_value(self, s, a, V):
        """Expected one-step return of intending ``s -> a`` under value table ``V``."""
        successors, p_delay, p_error, p_intended = self._outcome_probs(s)
        v_intended = self._edge_reward(s, a) + self.gamma * V[a]
        v_delay = self.gamma * V[s]

        v_error = 0.0
        if p_error > 0:
            possible_errors = [x for x in successors if x != a]
            if possible_errors:
                # Mistakes pick uniformly among the other successors.
                v_error = sum(
                    self._edge_reward(s, e) + self.gamma * V[e] for e in possible_errors
                ) / len(possible_errors)
            else:
                v_error = v_intended

        return p_delay*v_delay + p_error*v_error + p_intended*v_intended

    def value_iteration(self, goal, theta=1e-5):
        """
        In-place value iteration towards ``goal``.

        The goal and dead-end nodes have value 0 and policy ``None``.
        Returns ``(V, policy)`` as dicts keyed by node; ``policy[s]`` is the
        intended successor with the highest expected return (first wins ties).
        """
        states = self.get_states()
        V = {s: 0.0 for s in states}

        while True:
            delta = 0
            for s in states:
                actions = self.get_actions(s)
                if self.is_terminal(s, goal) or not actions:
                    V[s] = 0.0
                    continue
                new_val = max(self._q_value(s, a, V) for a in actions)
                delta = max(delta, abs(V[s] - new_val))
                V[s] = new_val
            if delta < theta:
                break

        policy = {}
        for s in states:
            actions = self.get_actions(s)
            if self.is_terminal(s, goal) or not actions:
                policy[s] = None
                continue
            best_act, best_val = None, float('-inf')
            for a in actions:
                ev = self._q_value(s, a, V)
                if ev > best_val:
                    best_val, best_act = ev, a
            policy[s] = best_act
        return V, policy

    def build_indexing(self):
        """Assign integer indices to states and to ``(state, successor)`` action pairs."""
        self.states = self.get_states()
        self.state_to_idx = {s: i for i, s in enumerate(self.states)}

        # dict.fromkeys de-duplicates while keeping graph order, so action
        # indices are reproducible across runs (set order depends on str hashing).
        self.actions = list(dict.fromkeys(
            (s, a) for s in self.states for a in self.get_actions(s)
        ))
        self.action_to_idx = {a: i for i, a in enumerate(self.actions)}

    def set_feature_map(self):
        """Cache ``[length, turn_penalty, road_type_val]`` feature vectors per state."""
        for s in self.states:
            node_data = self.graph.nodes[s]
            f = np.array([node_data['length'], node_data['turn_penalty'], node_data['road_type_val']])
            self.feature_map[s] = f

    def compute_rewards_from_weights(self, weights):
        """Per-state linear reward ``weights . features`` as an array indexed by state index."""
        R_s = np.zeros(len(self.states))
        for s in self.states:
            s_idx = self.state_to_idx[s]
            R_s[s_idx] = np.dot(weights, self.feature_map[s])
        return R_s

    def to_transition_matrix(self):
        """
        Build dense ``P_a[s, s_next, a]`` and expected rewards ``R_sa[s, a]``.

        Requires ``build_indexing`` to have been called. Only the action pair
        ``(s, successor)`` owned by ``s`` has non-zero probability in row ``s``.
        """
        N_S = len(self.states)
        N_A = len(self.actions)
        P_a = np.zeros((N_S, N_S, N_A))
        R_sa = np.zeros((N_S, N_A))

        for (orig_s, intended_next) in self.actions:
            s_idx = self.state_to_idx[orig_s]
            a_idx = self.action_to_idx[(orig_s, intended_next)]
            successors, p_delay, p_error, p_intended = self._outcome_probs(orig_s)
            r_intended = self._edge_reward(orig_s, intended_next)
            possible_errors = (
                [x for x in successors if x != intended_next] if p_error > 0 else []
            )

            if p_delay > 0:
                P_a[s_idx, s_idx, a_idx] += p_delay

            r_error = 0.0
            if possible_errors:
                p_each_error = p_error / len(possible_errors)
                for es in possible_errors:
                    P_a[s_idx, self.state_to_idx[es], a_idx] += p_each_error
                r_error = sum(self._edge_reward(orig_s, es) for es in possible_errors) / len(possible_errors)

            if p_intended > 0:
                P_a[s_idx, self.state_to_idx[intended_next], a_idx] += p_intended

            r_delay = 0.0
            R_sa[s_idx, a_idx] = p_delay*r_delay + p_error*r_error + p_intended*r_intended

        return P_a, R_sa

    def generate_demonstrations(self, policy, n_trajs, max_length, start_state=None):
        """
        Sample ``n_trajs`` trajectories of at most ``max_length`` steps under ``policy``.

        A trajectory ends when it reaches a state whose policy entry is
        ``None`` (goal / dead end). Each trajectory is a list of
        ``(s, a, s_next, reward)`` tuples. Start states are drawn uniformly
        from ``self.states`` unless ``start_state`` is given.
        """
        if start_state is None and not self.states:
            self.build_indexing()

        demonstrations = []
        for _ in range(n_trajs):
            s = random.choice(self.states) if start_state is None else start_state
            traj = []
            for _ in range(max_length):
                if policy[s] is None:
                    break
                a = policy[s]
                successors, p_delay, p_error, p_intended = self._outcome_probs(s)
                r_intended = self._edge_reward(s, a)

                # (probability, next_state, reward, done). Termination is detected
                # via ``policy[next_state] is None`` on the next iteration.
                outcomes = []
                if p_delay > 0:
                    outcomes.append((p_delay, s, 0.0, False))
                if p_error > 0:
                    possible_errors = [x for x in successors if x != a]
                    if possible_errors:
                        p_err_each = p_error / len(possible_errors)
                        for es in possible_errors:
                            outcomes.append((p_err_each, es, self._edge_reward(s, es), False))
                    else:
                        outcomes.append((p_error, a, r_intended, False))
                if p_intended > 0:
                    outcomes.append((p_intended, a, r_intended, False))

                total_p = sum(x[0] for x in outcomes)
                if total_p > 0:
                    outcomes = [(p/total_p, ns, rr, dd) for (p, ns, rr, dd) in outcomes]
                else:
                    outcomes = [(1.0, s, 0.0, True)]

                probs = [x[0] for x in outcomes]
                idx = np.random.choice(len(outcomes), p=probs)
                _, s_next, r, done = outcomes[idx]

                traj.append((s, a, s_next, r))
                s = s_next
                if done:
                    break
            demonstrations.append(traj)
        return demonstrations

    def compute_state_visitation_frequencies(self, policy, horizon=50):
        """
        Average state-occupancy over ``horizon`` steps under a deterministic policy.

        Starts from a uniform state distribution. States whose policy entry is
        ``None`` absorb no further mass (their outgoing probability is zero).
        Returns an array indexed by state index.
        """
        if not self.states:
            self.build_indexing()
        P_a, _ = self.to_transition_matrix()
        N_S = len(self.states)

        pi = np.zeros((N_S, len(self.actions)))
        for s in self.states:
            chosen_a = policy[s]
            if chosen_a is None:
                continue
            pi[self.state_to_idx[s], self.action_to_idx[(s, chosen_a)]] = 1.0

        # P_pi[s_next, s] = sum_a pi[s, a] * P_a[s, s_next, a]: column s is the
        # next-state distribution when leaving s, so mu_{t+1} = P_pi @ mu_t.
        P_pi = np.einsum('sa,sna->ns', pi, P_a)

        mu = np.ones(N_S) / N_S
        svf = np.zeros(N_S)
        for _ in range(horizon):
            svf += mu
            mu = P_pi.dot(mu)
        svf /= horizon
        return svf


# ---------------------------------------------
# Pygame Visualization and main example usage
# ---------------------------------------------
# Module-level pygame setup: runs as soon as this cell executes and creates a
# 1000x800 display surface that the drawing functions blit onto.
pygame.init()
WIDTH, HEIGHT = 1000, 800
screen = pygame.display.set_mode((WIDTH, HEIGHT))
pygame.display.set_caption("Road Network MDP")

# RGB colour palette shared by the node/edge drawing functions.
WHITE = (245,245,245)
BLACK = (30,30,30)
BLUE = (168,216,234)
GREEN = (118,200,147)
RED = (242,132,130)
YELLOW = (246,215,167)
GRAY = (200,200,200)
PATH_COLOR = (203,170,203)

# Shared label font and the car sprite, scaled to 30x15 px.
# NOTE(review): "data/car.png" is resolved relative to the current working
# directory; pygame.image.load raises if the file is missing.
font = pygame.font.SysFont("Arial", 24, bold=False)
car_image = pygame.image.load("data/car.png")
car_image = pygame.transform.scale(car_image, (30, 15))

def draw_nodes(screen, node_positions, graph, start, goal, visited=frozenset(), hover_node=None):
    """
    Draw every node as a filled circle, plus an info box for ``hover_node``.

    Colours: ``start`` GREEN, ``goal`` RED, nodes in ``visited`` YELLOW,
    everything else BLUE. When ``hover_node`` is given, its ``length``,
    ``turn_penalty`` and ``road_type_val`` features are shown in a grey box
    50 px above the node.

    Args:
        screen: pygame surface to draw on.
        node_positions: mapping node -> (x, y) pixel position.
        graph: graph whose ``nodes[hover_node]`` holds the node features.
        start, goal: nodes to highlight.
        visited: collection of nodes already visited (only membership is tested).
        hover_node: optional node to annotate.
    """
    for node, pos in node_positions.items():
        if node == start:
            color = GREEN
        elif node == goal:
            color = RED
        elif node in visited:
            color = YELLOW
        else:
            color = BLUE
        pygame.draw.circle(screen, color, pos, 15, 0)

    if hover_node is not None:
        features = graph.nodes[hover_node]
        info_text = f"{hover_node}: len={features['length']:.2f}, turn={features['turn_penalty']:.2f}, rt={features['road_type_val']:.2f}"
        info_label = font.render(info_text, True, BLACK)
        # Anchor the box at the hovered node's position (not the mouse cursor).
        node_x, node_y = node_positions[hover_node]
        pygame.draw.rect(screen, GRAY, (node_x, node_y-50, info_label.get_width()+10, info_label.get_height()+10), border_radius=5)
        screen.blit(info_label, (node_x+5, node_y - 50 + 5))

def draw_edges(screen, node_positions, graph, threshold=-1, path_edges=frozenset()):
    """
    Draw every edge of ``graph`` with its reward label at the midpoint.

    Edges in ``path_edges`` are drawn in PATH_COLOR, 3 px wide; other edges are
    RED when their ``"reward"`` exceeds ``threshold`` and BLACK otherwise,
    2 px wide. Edges are plain lines without arrow heads.

    Args:
        screen: pygame surface to draw on.
        node_positions: mapping node -> (x, y) pixel position.
        graph: graph whose edges carry a ``"reward"`` attribute.
        threshold: reward above which a non-path edge is drawn RED.
        path_edges: collection of ``(u, v)`` pairs to highlight.
    """
    for u, v, data in graph.edges(data=True):
        start = node_positions[u]
        end = node_positions[v]
        if (u, v) in path_edges:
            edge_color, line_width = PATH_COLOR, 3
        else:
            edge_color = RED if data["reward"] > threshold else BLACK
            line_width = 2

        # pygame.draw.aaline has no thickness argument, so draw.line is used to
        # make the path highlight actually wider than ordinary edges.
        pygame.draw.line(screen, edge_color, start, end, line_width)

        # Reward annotation, centred on the edge midpoint.
        mid_x = (start[0]+end[0])//2
        mid_y = (start[1]+end[1])//2
        reward_text = f"{data['reward']:.1f}"
        rw_label = font.render(reward_text, True, BLACK)
        screen.blit(rw_label, (int(mid_x) - rw_label.get_width()//2, int(mid_y) - rw_label.get_height()//2))

def run_episode(mdp, start, goal, node_positions, max_steps=100):
    """
    Roll out one episode from ``start`` under a uniformly random policy.

    At each step one successor of the current node is picked uniformly at
    random as the intended move and passed to ``mdp.transition``, which may
    delay or divert the car.

    Args:
        mdp: object providing ``get_actions``, ``is_terminal`` and ``transition``.
        start: initial node.
        goal: the rollout stops once this node is reached.
        node_positions: unused; kept for call-site compatibility.
        max_steps: maximum number of transitions to sample.

    Returns:
        ``(path, total_return)`` where ``path`` lists every visited node,
        including ``start``, and ``total_return`` is the undiscounted reward sum.
    """
    s = start
    path = [s]
    total_return = 0.0
    for _ in range(max_steps):
        actions = mdp.get_actions(s)
        if not actions or mdp.is_terminal(s, goal):
            break
        s, r, done = mdp.transition(s, random.choice(actions))
        total_return += r
        path.append(s)
        if done:
            break
    return path, total_return

def animate_trajectory(screen, path, node_positions, graph, path_edges, total_return,
                       *, frames_per_edge=100, frame_delay_ms=20):
    """Animate the car along ``path``, then show the episode return.

    Each consecutive pair of states is linearly interpolated over
    ``frames_per_edge`` frames, each followed by a ``frame_delay_ms`` delay.
    Nodes the car has left are passed to ``draw_nodes`` as visited. After the
    last edge, the car is drawn on the final node next to the total return.
    That final frame is held for one second.

    Args:
        screen: pygame surface to draw on.
        path: sequence of visited states; ``path[0]`` and ``path[-1]`` are
            drawn as the start and end nodes.
        node_positions: mapping from state to (x, y) pixel position.
        graph: graph passed through to ``draw_nodes`` / ``draw_edges``.
        path_edges: ``(u, v)`` pairs that ``draw_edges`` highlights.
        total_return: value shown as "Episode Return".
        frames_per_edge: interpolation frames per edge (default 100).
        frame_delay_ms: delay after each frame in milliseconds (default 20).
    """
    visited = set()
    # The status text never changes; render it once instead of every frame.
    info_label = font.render("Animating Trajectory...", True, BLACK)

    for i in range(len(path) - 1):
        start_pos = node_positions[path[i]]
        end_pos = node_positions[path[i + 1]]
        # The node being left counts as visited for the whole edge.
        visited.add(path[i])
        for t in range(frames_per_edge):
            x = start_pos[0] + (end_pos[0] - start_pos[0]) * t / frames_per_edge
            y = start_pos[1] + (end_pos[1] - start_pos[1]) * t / frames_per_edge
            screen.fill(WHITE)
            draw_nodes(screen, node_positions, graph, path[0], path[-1], visited)
            draw_edges(screen, node_positions, graph, path_edges=path_edges)
            screen.blit(car_image, (int(x) - 15, int(y) - 7))
            screen.blit(info_label, (10, 10))
            pygame.display.flip()
            # Let the OS see the window is alive during the long animation.
            # pump() keeps queued events (e.g. QUIT) for the caller to handle.
            pygame.event.pump()
            pygame.time.delay(frame_delay_ms)

    # Final frame: car parked on the last node plus the episode return.
    visited.add(path[-1])
    screen.fill(WHITE)
    draw_nodes(screen, node_positions, graph, path[0], path[-1], visited)
    draw_edges(screen, node_positions, graph, path_edges=path_edges)
    final_pos = node_positions[path[-1]]
    screen.blit(car_image, (int(final_pos[0]) - 15, int(final_pos[1]) - 7))
    return_label = font.render(f"Episode Return: {total_return:.2f}", True, BLACK)
    screen.blit(return_label, (10, 40))
    pygame.display.flip()
    pygame.time.delay(1000)

# Construct a grid MDP whose only transitions go right or down.
mdp = RoadNetworkMDP(gamma=0.95, driver_error_prob=0.1, delay_prob=0.1)

rows = 4
cols = 5
# State names in row-major order: states[0] is the top-left cell and
# states[-1] the bottom-right cell.
states = []
for r in range(rows):
    for c in range(cols):
        s = f"S{r*cols+c}"
        states.append(s)
        # Random per-state features; NOTE(review): presumably combined with
        # reward_weights by set_rewards below — confirm in RoadNetworkMDP.
        features = {
            "length": random.uniform(1,5),
            "turn_penalty": random.uniform(0,1),
            "road_type_val": random.choice([0.5,0.0,-0.5])
        }
        mdp.add_state(s, features)

# Only local connections: right and down (all labelled "Go straight").
for r in range(rows):
    for c in range(cols):
        s = f"S{r*cols+c}"
        if c < cols-1:
            right = f"S{r*cols+(c+1)}"
            mdp.add_action(s, right, "Go straight")
        if r < rows-1:
            down = f"S{(r+1)*cols+c}"
            mdp.add_action(s, down, "Go straight")

reward_weights = {"length": -1.0, "turn_penalty": -0.5, "road_type_val": 0.2}
mdp.set_rewards(reward_weights)

# Start and goal states at opposite corners, taken from the row-major list
# above: top-left is "S0", bottom-right is f"S{rows*cols-1}".
start = states[0]
goal = states[-1]

mdp.build_indexing()
mdp.set_feature_map()

# Screen layout: grid cell (row, col) is centred at
# ((col + 1) * spacing_x, (row + 1) * spacing_y), leaving a one-cell margin.
spacing_x = WIDTH//(cols+1)
spacing_y = HEIGHT//(rows+1)
node_positions = {
    f"S{row*cols+col}": ((col + 1) * spacing_x, (row + 1) * spacing_y)
    for row in range(rows)
    for col in range(cols)
}

running = True
# NOTE(review): draw_nodes references a `hover_node`; it may read this global,
# so keep it until that is confirmed.
hover_node = None

try:
    # Play episodes back to back until the window is closed.
    while running:
        # Roll out one episode with a uniformly random policy.
        path, total_return = run_episode(mdp, start, goal, node_positions)

        # Edges traversed in this episode, highlighted during the animation.
        path_edges = set(zip(path, path[1:]))

        animate_trajectory(screen, path, node_positions, mdp.graph, path_edges, total_return)

        # The next episode starts immediately (no button press). Events queued
        # while animating, including QUIT, are handled here.
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                running = False
finally:
    # Release the display even if the loop is interrupted (e.g. a notebook
    # KeyboardInterrupt); otherwise the window is left hanging.
    pygame.quit()


: 

In [2]:
import pygame
import random

# pygame must be initialised before creating the window or loading fonts.
pygame.init()

# Window for the grid debugging view.
WIDTH = 600
HEIGHT = 400
screen = pygame.display.set_mode((WIDTH, HEIGHT))
pygame.display.set_caption("Grid Debugging Visualization")

# RGB colour constants.
WHITE = (245, 245, 245)
BLUE = (168, 216, 234)
GREEN = (118, 200, 147)
RED = (242, 132, 130)
BLACK = (30, 30, 30)

# Regular-weight font (size 24) for node labels.
font = pygame.font.SysFont("Arial", 24)

# Grid of rows x cols nodes named S0 .. S{rows*cols-1} in row-major order.
rows, cols = 4, 5
spacing_x = WIDTH // (cols + 1)
spacing_y = HEIGHT // (rows + 1)

# Pixel position of every node, leaving a one-cell margin around the grid.
node_positions = {}
for r in range(rows):
    for c in range(cols):
        s = f"S{r*cols + c}"
        x = (c + 1) * spacing_x
        y = (r + 1) * spacing_y
        node_positions[s] = (x, y)

# Diagonally opposite corner pairs, derived from the grid size. The previous
# hard-coded ("S0", "S19"), ("S4", "S15") was only valid for a 4x5 grid.
top_left = "S0"
top_right = f"S{cols - 1}"
bottom_left = f"S{(rows - 1) * cols}"
bottom_right = f"S{rows * cols - 1}"
diagonal_pairs = [(top_left, bottom_right), (top_right, bottom_left)]

# Randomly select one diagonal pair as (start, goal).
start, goal = random.choice(diagonal_pairs)

# Debugging visualization loop: redraw the static grid until the window closes.
clock = pygame.time.Clock()
running = True
try:
    while running:
        screen.fill(WHITE)  # Clear the screen

        # Draw edges (right and down connections) FIRST, so the node circles
        # and labels drawn afterwards sit on top of the lines.
        for r in range(rows):
            for c in range(cols):
                current = f"S{r*cols + c}"
                if c < cols - 1:  # Right connection
                    right = f"S{r*cols + (c + 1)}"
                    pygame.draw.line(screen, BLACK, node_positions[current], node_positions[right], 2)
                if r < rows - 1:  # Down connection
                    down = f"S{(r + 1) * cols + c}"
                    pygame.draw.line(screen, BLACK, node_positions[current], node_positions[down], 2)

        # Draw nodes: start green, goal red, everything else blue.
        for node, pos in node_positions.items():
            if node == start:
                color = GREEN
            elif node == goal:
                color = RED
            else:
                color = BLUE
            pygame.draw.circle(screen, color, pos, 15, 0)

            # Node label centred on the circle.
            label = font.render(node, True, BLACK)
            screen.blit(label, (pos[0] - label.get_width() // 2, pos[1] - label.get_height() // 2))

        # Refresh the display
        pygame.display.flip()

        # Event handling for quitting
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                running = False

        # The scene is static; cap the frame rate instead of busy-looping.
        clock.tick(30)
finally:
    # Release the display even if the loop is interrupted.
    pygame.quit()


In [2]:
# Dump each node's pixel position for debugging.
print("Node positions:")
for name, xy in node_positions.items():
    print(f"{name}: {xy}")

Node positions:
S0: (166, 160)
S1: (332, 160)
S2: (498, 160)
S3: (664, 160)
S4: (830, 160)
S5: (166, 320)
S6: (332, 320)
S7: (498, 320)
S8: (664, 320)
S9: (830, 320)
S10: (166, 480)
S11: (332, 480)
S12: (498, 480)
S13: (664, 480)
S14: (830, 480)
S15: (166, 640)
S16: (332, 640)
S17: (498, 640)
S18: (664, 640)
S19: (830, 640)
