In [4]:
import numpy as np
import gymnasium as gym
import time

# Build Taxi-v3 with text rendering and strip the wrappers right away:
# the planner below reads the tabular transition model `env.P` from the raw environment.
env = gym.make('Taxi-v3', render_mode='ansi').unwrapped

# Planning hyper-parameters
gamma = 0.9   # discount factor
theta = 1e-6  # stop sweeping once the largest value change drops below this

# One state-value estimate per discrete state, all starting at zero
V = np.zeros(env.observation_space.n)

def one_step_lookahead(state, V):
    """
    Compute the one-step-lookahead value of every action in `state`.

    For each action the expected value is summed over the outcomes listed in
    `env.P[state][action]`, each a `(prob, next_state, reward, done)` tuple:

        Q(state, a) = sum(prob * (reward + gamma * V[next_state]))

    Outcomes flagged `done` end the episode, so no future value is added for
    them; bootstrapping from `V[next_state]` there would credit the agent with
    return collected after the episode is already over.

    Args:
        state: Index of the state to evaluate.
        V: Indexable collection of current state-value estimates.

    Returns:
        A 1-D NumPy array of length `env.action_space.n` holding the action values.

    Relies on the module-level `env` and `gamma`.
    """
    A = np.zeros(env.action_space.n)
    for action in range(env.action_space.n):
        for prob, next_state, reward, done in env.P[state][action]:
            # Terminal transition: nothing follows, so the bootstrap term is zero.
            future_value = 0.0 if done else V[next_state]
            A[action] += prob * (reward + gamma * future_value)
    return A

# Value iteration: in-place Bellman optimality sweeps over all states,
# repeated until the largest single-state change in a sweep is below `theta`.
while True:
    delta = 0
    for state in range(env.observation_space.n):
        previous_value = V[state]
        # Overwrite immediately so later states in this sweep see the fresh value
        V[state] = one_step_lookahead(state, V).max()
        delta = max(delta, np.abs(V[state] - previous_value))
    if delta < theta:
        break

# Greedy policy extraction: every state gets a one-hot row that puts
# probability 1.0 on the action with the highest one-step-lookahead value.
greedy_actions = [
    np.argmax(one_step_lookahead(state, V))
    for state in range(env.observation_space.n)
]
policy = np.eye(env.action_space.n)[greedy_actions]

# Visualize agent's movement
def simulate_agent(env, policy, delay=1.0, max_steps=None):
    """
    Roll out one episode following `policy` and print every rendered frame.

    Args:
        env: Environment exposing `reset() -> (state, info)`,
            `step(action) -> (state, reward, terminated, truncated, info)`
            and `render()`.
        policy: Array indexed by state whose row holds per-action preferences;
            the action with the largest entry is taken.
        delay: Seconds to pause after each printed frame (default 1.0;
            use 0 to run without pausing).
        max_steps: Optional cap on the number of steps; `None` means no cap.

    The rollout stops as soon as the environment reports termination OR
    truncation, so a time-limited environment cannot trap the loop.
    """
    def show_frame():
        print(env.render())
        if delay > 0:
            time.sleep(delay)

    state, _ = env.reset()
    show_frame()

    steps_taken = 0
    terminated = truncated = False
    while not (terminated or truncated):
        if max_steps is not None and steps_taken >= max_steps:
            break
        action = np.argmax(policy[state])
        state, reward, terminated, truncated, info = env.step(action)
        steps_taken += 1
        show_frame()

# Roll out one episode with the one-hot policy extracted above, printing each rendered frame
simulate_agent(env, policy)


+---------+
|[35mR[0m: |[43m [0m: :[34;1mG[0m|
| : | : : |
| : : : : |
| | : | : |
|Y| : |B: |
+---------+


+---------+
|[35mR[0m: | :[43m [0m:[34;1mG[0m|
| : | : : |
| : : : : |
| | : | : |
|Y| : |B: |
+---------+
  (East)

+---------+
|[35mR[0m: | : :[34;1m[43mG[0m[0m|
| : | : : |
| : : : : |
| | : | : |
|Y| : |B: |
+---------+
  (East)

+---------+
|[35mR[0m: | : :[42mG[0m|
| : | : : |
| : : : : |
| | : | : |
|Y| : |B: |
+---------+
  (Pickup)

+---------+
|[35mR[0m: | : :G|
| : | : :[42m_[0m|
| : : : : |
| | : | : |
|Y| : |B: |
+---------+
  (South)

+---------+
|[35mR[0m: | : :G|
| : | : : |
| : : : :[42m_[0m|
| | : | : |
|Y| : |B: |
+---------+
  (South)

+---------+
|[35mR[0m: | : :G|
| : | : : |
| : : :[42m_[0m: |
| | : | : |
|Y| : |B: |
+---------+
  (West)

+---------+
|[35mR[0m: | : :G|
| : | : : |
| : :[42m_[0m: : |
| | : | : |
|Y| : |B: |
+---------+
  (West)

+---------+
|[35mR[0m: | : :G|
| : | : : |
| :[42m_[0m: : : |
| | : | :

In [49]:
import numpy as np
import gym
from gym import spaces

class CustomTaxiMDP(gym.Env):
    """
    A customizable taxi MDP environment with:
    - Variable grid size
    - Custom pickup and drop-off locations
    - No explicit pickup/dropoff actions; any step that ends on the pickup cell picks up
      the passenger, and any step that ends on the dropoff cell with the passenger on
      board completes the ride.
    - Safety incidents do not terminate the episode but penalize the agent.
    - Road types modify movement rewards.
    - A large final dropoff reward; the scenic bonus is paid on every step that ends on a
      scenic cell, so keeping it unexploitable is up to the chosen configuration.

    Actions: 0 = row - 1, 1 = row + 1, 2 = col - 1, 3 = col + 1. A move that would leave
    the grid keeps the taxi in place but is still charged like any other step.

    Observations are tuples ``(row, col, passenger_picked_up, last_action + 1)``; the last
    component is 0 until the first action of an episode has been taken.

    The same transition/reward rule drives both the sampled simulator (`step`) and the
    tabular planning model (`self.P`); the model replaces the random incident penalty
    with its expected value.
    """

    metadata = {"render_modes": ["ansi"], "render_fps": 4}

    # Text symbols for the road types that overwrite a cell in the text rendering
    _ROAD_SYMBOLS = {"highway": "H", "rough": "R"}

    def __init__(
        self,
        grid_rows=5,
        grid_cols=5,
        pickup_loc=(0,0),
        dropoff_loc=(4,4),
        incident_prob=0.001,         # base probability of a safety incident each step
        unsafe_squares=(),
        unsafe_incident_prob=0.01,   # additional probability if on unsafe square
        scenic_squares=(),
        scenic_reward=0.5,
        step_cost=-1,
        incident_cost=-10,           # cost for a safety incident
        turn_cost=-0.1,              # penalty for changing direction
        straight_reward=0.0,         # reward for going straight
        road_types=None,             # {(r,c): "highway"/"rough"/"scenic"/...}
        final_dropoff_reward=50.0,   # Large reward for completing dropoff
        seed=None
    ):
        """
        Configure the grid and rewards, sample an initial state and build `self.P`.

        Args:
            grid_rows, grid_cols: Grid dimensions (each at least 1, at least 2 cells total).
            pickup_loc, dropoff_loc: `(row, col)` cells inside the grid.
            incident_prob: Per-step incident probability on any cell.
            unsafe_squares: Iterable of `(row, col)` cells with extra incident risk.
            unsafe_incident_prob: Probability added to `incident_prob` on unsafe cells.
            scenic_squares: Iterable of `(row, col)` cells that pay `scenic_reward`.
            scenic_reward: Bonus added when a step ends on a scenic cell.
            step_cost: Reward added on every step.
            incident_cost: Reward added when an incident occurs.
            turn_cost: Reward added when the action differs from the previous one.
            straight_reward: Reward added when the action repeats the previous one.
            road_types: Optional mapping `(row, col) -> road type name`.
            final_dropoff_reward: Reward added when the ride is completed.
            seed: Seed for the environment's random generator.

        Raises:
            ValueError: If the grid is empty, a location is outside the grid, or the
                grid has no cell other than the dropoff cell to start from.
        """
        super().__init__()
        if grid_rows < 1 or grid_cols < 1:
            raise ValueError("grid_rows and grid_cols must both be at least 1")
        self.grid_rows = grid_rows
        self.grid_cols = grid_cols
        # Stored as tuples so comparisons with (row, col) positions work even when
        # the caller passed lists.
        self.pickup_loc = self._check_location("pickup_loc", pickup_loc)
        self.dropoff_loc = self._check_location("dropoff_loc", dropoff_loc)
        if grid_rows * grid_cols < 2:
            # reset() rejects the dropoff cell as a start; with a single cell it would never finish
            raise ValueError("the grid needs at least one cell besides the dropoff cell")
        self.incident_prob = incident_prob
        self.unsafe_squares = self._as_cell_set(unsafe_squares)
        self.unsafe_incident_prob = unsafe_incident_prob
        self.scenic_squares = self._as_cell_set(scenic_squares)
        self.scenic_reward = scenic_reward
        self.step_cost = step_cost
        self.incident_cost = incident_cost
        self.turn_cost = turn_cost
        self.straight_reward = straight_reward
        self.final_dropoff_reward = final_dropoff_reward

        self.road_types = {} if road_types is None else road_types

        self.action_space = spaces.Discrete(4)
        # State representation: (agent_row, agent_col, passenger_picked_up, last_action)
        self.observation_space = spaces.MultiDiscrete([self.grid_rows, self.grid_cols, 2, 5])

        self.np_random, _ = gym.utils.seeding.np_random(seed)

        self.reset()

        # Build the transition model for planning
        self._build_transition_model()

    def _check_location(self, name, loc):
        """Return `loc` as a tuple, raising ValueError unless it is a (row, col) cell inside the grid."""
        loc = tuple(loc)
        inside = (
            len(loc) == 2
            and 0 <= loc[0] < self.grid_rows
            and 0 <= loc[1] < self.grid_cols
        )
        if not inside:
            raise ValueError(
                f"{name}={loc} must be a (row, col) cell inside the "
                f"{self.grid_rows}x{self.grid_cols} grid"
            )
        return loc

    @staticmethod
    def _as_cell_set(cells):
        """Turn an iterable of (row, col) pairs (or None) into a set of tuples."""
        if cells is None:
            return set()
        return {tuple(cell) for cell in cells}

    def reset(self, *, seed=None, options=None):
        """
        Start a new episode at a random cell other than the dropoff cell.

        Args:
            seed: If given, the random generator is re-created from this seed first.
            options: Accepted for API compatibility; not used.

        Returns:
            `(observation, info)` where `info` is an empty dict.
        """
        if seed is not None:
            self.np_random, _ = gym.utils.seeding.np_random(seed)

        # Rejection-sample the start cell: draw row then column until it is not the dropoff cell
        while True:
            self.agent_row = int(self.np_random.integers(self.grid_rows))
            self.agent_col = int(self.np_random.integers(self.grid_cols))
            if (self.agent_row, self.agent_col) != self.dropoff_loc:
                break

        self.passenger_picked_up = False
        self.last_action = -1   # -1 means "no action taken yet"
        self.terminated = False

        # Performance metrics
        self.total_reward = 0.0
        self.steps = 0
        self.incidents_count = 0

        return self._get_obs(), {}

    def _get_obs(self):
        """Return the current state as `(row, col, passenger flag, last_action + 1)`."""
        return (self.agent_row, self.agent_col, int(self.passenger_picked_up), self.last_action+1)

    def _compute_road_type_reward(self, r, c):
        """Return the reward modifier of the road type at `(r, c)`: +0.5 highway, -0.5 rough, else 0."""
        rt = self.road_types.get((r, c), None)
        if rt == "highway":
            return 0.5  # partially offsets the step cost
        elif rt == "rough":
            return -0.5 # makes movement more costly
        return 0.0

    def _advance(self, s, a):
        """
        Apply action `a` to state `s` without any randomness.

        This is the single source of truth for movement and the deterministic part
        of the reward, shared by `step` and `_simulate_transition`.

        Returns:
            `(new_row, new_col, passenger_picked, reward, done, incident_prob)` where
            `reward` excludes the incident penalty and `incident_prob` is the incident
            probability on the cell the taxi ends up on.
        """
        (r, c, pu, la) = s
        last_action = la - 1

        # Move, staying in place when the move would leave the grid
        new_r, new_c = r, c
        if a == 0 and r > 0:
            new_r -= 1
        elif a == 1 and r < self.grid_rows - 1:
            new_r += 1
        elif a == 2 and c > 0:
            new_c -= 1
        elif a == 3 and c < self.grid_cols - 1:
            new_c += 1

        # Base reward
        reward = self.step_cost + self._compute_road_type_reward(new_r, new_c)

        # Turn or straight (neither applies to the first action of an episode)
        if last_action != -1 and last_action != a:
            reward += self.turn_cost
        elif last_action != -1 and last_action == a:
            reward += self.straight_reward

        # Pickup happens before the dropoff check, so a shared pickup/dropoff cell ends the ride at once
        passenger_picked = (pu == 1) or (new_r, new_c) == self.pickup_loc

        if (new_r, new_c) in self.scenic_squares:
            reward += self.scenic_reward

        done = (new_r, new_c) == self.dropoff_loc and passenger_picked
        if done:
            reward += self.final_dropoff_reward

        incident_prob = self.incident_prob
        if (new_r, new_c) in self.unsafe_squares:
            incident_prob += self.unsafe_incident_prob

        return new_r, new_c, passenger_picked, reward, done, incident_prob

    def step(self, action):
        """
        Apply `action`, sample a possible safety incident and update the metrics.

        Returns:
            `(observation, reward, terminated, truncated, info)`; `truncated` is always
            False and `info` is an empty dict.
        """
        new_row, new_col, picked, reward, done, incident_prob = self._advance(self._get_obs(), action)

        self.agent_row = new_row
        self.agent_col = new_col
        self.passenger_picked_up = picked
        if done:
            self.terminated = True

        # Incident: penalizes the step but never ends the episode
        if self.np_random.random() < incident_prob:
            reward += self.incident_cost
            self.incidents_count += 1

        self.last_action = action
        self.total_reward += reward
        self.steps += 1

        return self._get_obs(), reward, self.terminated, False, {}

    def render(self, mode='ansi'):
        """
        Return the grid as text, one line per row.

        Symbols: '.' plain, 'U' unsafe, 'S' scenic, 'US' both, 'H' highway, 'R' rough
        (road symbols overwrite U/S), 'P' pickup, 'D' dropoff, 'A' taxi, 'A*' taxi with
        the passenger on board. `mode` is accepted for compatibility and ignored.
        """
        grid = np.full((self.grid_rows, self.grid_cols), '.', dtype=object)

        for (r,c) in self.unsafe_squares:
            grid[r,c] = 'U'
        for (r,c) in self.scenic_squares:
            grid[r,c] = 'US' if grid[r,c] == 'U' else 'S'

        # Only highway/rough change the drawing; any other road type leaves the cell as is
        for (rc, cc), rt in self.road_types.items():
            if rt in self._ROAD_SYMBOLS:
                grid[rc, cc] = self._ROAD_SYMBOLS[rt]

        pr, pc = self.pickup_loc
        grid[pr, pc] = 'P'
        dr, dc = self.dropoff_loc
        grid[dr, dc] = 'D'

        ar, ac = self.agent_row, self.agent_col
        agent_marker = 'A*' if self.passenger_picked_up else 'A'
        grid[ar, ac] = agent_marker

        output = ''
        for r in range(self.grid_rows):
            row_str = ' '.join(f"{cell:2}" for cell in grid[r,:])
            output += row_str + '\n'
        return output

    def close(self):
        """Nothing to release for the text-only environment."""
        pass

    def _state_to_tuple(self, s):
        """Return state `s` as a tuple."""
        return tuple(s)

    def _get_all_states(self):
        """Enumerate every `(row, col, passenger flag, last_action + 1)` state in a fixed order."""
        states = []
        for r in range(self.grid_rows):
            for c in range(self.grid_cols):
                for pu in [0,1]:
                    for la in range(5):
                        states.append((r,c,pu,la))
        return states

    def _build_transition_model(self):
        """Fill `self.P[state][action] = (next_state, expected_reward, done)` for all states and actions."""
        self.P = {}
        all_states = self._get_all_states()
        for s in all_states:
            self.P[s] = {}
            for a in range(self.action_space.n):
                obs, reward, done = self._simulate_transition(s, a)
                self.P[s][a] = (obs, reward, done)

    def _simulate_transition(self, s, a):
        """
        Return `(next_state, expected_reward, done)` for taking action `a` in state `s`.

        The random incident of `step` is replaced by its expectation
        (`incident probability * incident_cost`), which keeps the model deterministic.
        """
        new_r, new_c, passenger_picked, reward, done, incident_prob = self._advance(s, a)

        reward += incident_prob * self.incident_cost

        new_pu = 1 if passenger_picked else 0
        new_la = a + 1
        next_s = (new_r, new_c, new_pu, new_la)

        return next_s, reward, done

    def compute_optimal_policy(self, gamma=0.99, theta=1e-6):
        """
        Run in-place value iteration on `self.P`.

        Args:
            gamma: Discount factor.
            theta: Sweeps stop once the largest value change in a sweep is below this.

        Returns:
            `(V, pi)`: dicts mapping each state tuple to its value and to its greedy
            action (a plain int). Terminal states keep value 0.0 and action 0.
        """
        all_states = self._get_all_states()
        V = {s:0.0 for s in all_states}

        def is_terminal(s):
            (r,c,pu,la) = s
            return (r,c) == self.dropoff_loc and pu == 1

        def action_values(s):
            # One-step lookahead; a finished ride contributes no future value
            values = []
            for a in range(self.action_space.n):
                (next_s, r, done) = self.P[s][a]
                values.append(r + gamma*(0 if done else V[next_s]))
            return values

        while True:
            delta = 0
            for s in all_states:
                if is_terminal(s):
                    continue
                v = V[s]
                V[s] = max(action_values(s))
                delta = max(delta, abs(v - V[s]))
            if delta < theta:
                break

        pi = {}
        for s in all_states:
            pi[s] = 0 if is_terminal(s) else int(np.argmax(action_values(s)))

        return V, pi

    def visualize_policy(self, pi, max_steps=50):
        """
        Reset the environment, follow policy `pi` for at most `max_steps` steps while
        printing each rendered grid, then print the episode metrics.

        Args:
            pi: Mapping from state tuple to action, as returned by `compute_optimal_policy`.
            max_steps: Maximum number of steps to take.
        """
        obs, info = self.reset()
        step_count = 0
        print("Initial State:")
        print(self.render())
        while step_count < max_steps:
            s = tuple(obs)
            a = pi[s]
            obs, r, done, truncated, _ = self.step(a)
            print(f"Step: {step_count}, Action: {a}, Reward: {r}")
            print(self.render())
            if done:
                print("Episode finished!")
                break
            step_count += 1
        if step_count >= max_steps:
            print("Reached max steps without termination.")

        print("Performance Metrics:")
        print(f"  Total Reward: {self.total_reward}")
        print(f"  Steps Taken: {self.steps}")
        print(f"  Safety Incidents: {self.incidents_count}")

if __name__ == "__main__":
    # Demo map for a 15x15 grid; the cell groups are listed once and reused below.
    highway_cells = [(r, c) for r in (0, 6, 11) for c in (0, 1, 2)]   # just some examples
    rough_cells = [(r, c) for r in (2, 3, 4) for c in (5, 6, 7)]      # also the unsafe zone
    scenic_cells = [
        (r, c)
        for r, cols in ((0, (11, 12, 13)), (8, (2, 3, 4)), (8, (8, 9, 10)))
        for c in cols
    ]

    road_layout = {}
    road_layout.update((cell, 'highway') for cell in highway_cells)
    road_layout.update((cell, 'rough') for cell in rough_cells)
    road_layout.update((cell, 'scenic') for cell in scenic_cells)

    # Updated realistic configuration with balanced rewards
    env = CustomTaxiMDP(
        grid_rows=15,
        grid_cols=15,
        pickup_loc=(0, 0),
        dropoff_loc=(14, 14),
        unsafe_squares=list(rough_cells),
        scenic_squares=list(scenic_cells),
        road_types=road_layout,
        # Risk model
        incident_prob=0.001,
        unsafe_incident_prob=0.02,
        incident_cost=-20,
        # Per-step rewards: the scenic +1 cancels the -1 step cost
        step_cost=-1,
        scenic_reward=1.0,
        turn_cost=-0.3,
        straight_reward=0.5,
        # Large final reward to ensure completing the trip is best
        final_dropoff_reward=500.0,
    )

    # Plan with value iteration, then replay one episode under the resulting policy
    V, pi = env.compute_optimal_policy(gamma=0.99, theta=1e-6)
    print("Optimal Policy Visualization:")
    env.visualize_policy(pi, max_steps=100)


Optimal Policy Visualization:
Initial State:
P  H  H  .  .  .  .  .  .  .  .  S  S  S  . 
.  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
.  .  .  .  .  R  R  R  .  .  .  .  .  .  . 
.  .  .  .  .  R  R  R  .  .  .  .  .  .  . 
.  .  .  .  .  R  R  R  .  .  .  .  .  .  . 
.  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
H  H  H  .  .  .  .  .  A  .  .  .  .  .  . 
.  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
.  .  S  S  S  .  .  .  S  S  S  .  .  .  . 
.  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
.  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
H  H  H  .  .  .  .  .  .  .  .  .  .  .  . 
.  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
.  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
.  .  .  .  .  .  .  .  .  .  .  .  .  .  D 

Step: 0, Action: 2, Reward: -1.0
P  H  H  .  .  .  .  .  .  .  .  S  S  S  . 
.  .  .  .  .  .  .  .  .  .  .  .  .  .  . 
.  .  .  .  .  R  R  R  .  .  .  .  .  .  . 
.  .  .  .  .  R  R  R  .  .  .  .  .  .  . 
.  .  .  .  .  R  R  R  .  .  .  .  .  .  . 
.  .  .  .  .  .  .  

In [48]:
import numpy as np
import gym
from gym import spaces
import pygame

class CustomTaxiMDP(gym.Env):
    """
    Custom Taxi MDP with improved Pygame Visualization.

    The taxi moves on a grid with four actions: 0 = row - 1, 1 = row + 1, 2 = col - 1,
    3 = col + 1; a move that would leave the grid keeps the taxi in place. Any step that
    ends on the pickup cell picks up the passenger, and any step that ends on the
    dropoff cell with the passenger on board ends the episode. Safety incidents only
    penalize the step in which they occur.

    Observations are tuples ``(row, col, passenger_picked_up, last_action + 1)``; the last
    component is 0 until the first action of an episode has been taken.

    Render modes: "ansi" returns the grid as text, "human" draws it in a pygame window.
    The same transition/reward rule drives both the sampled simulator (`step`) and the
    tabular planning model (`self.P`); the model replaces the random incident penalty
    with its expected value.
    """

    metadata = {"render_modes": ["ansi", "human"], "render_fps": 4}

    # Text symbols for the road types that overwrite a cell in the text rendering
    _ROAD_SYMBOLS = {"highway": "H", "rough": "R"}

    def __init__(
        self,
        grid_rows=5,
        grid_cols=5,
        pickup_loc=(0,0),
        dropoff_loc=(4,4),
        incident_prob=0.001,
        unsafe_squares=(),
        unsafe_incident_prob=0.01,
        scenic_squares=(),
        scenic_reward=0.5,
        step_cost=-1,
        incident_cost=-10,
        turn_cost=-0.1,
        straight_reward=0.0,
        road_types=None,
        final_dropoff_reward=50.0,
        render_mode=None,
        seed=None
    ):
        """
        Configure the grid and rewards, sample an initial state, build `self.P` and,
        for `render_mode="human"`, open the pygame window.

        Args:
            grid_rows, grid_cols: Grid dimensions (each at least 1, at least 2 cells total).
            pickup_loc, dropoff_loc: `(row, col)` cells inside the grid.
            incident_prob: Per-step incident probability on any cell.
            unsafe_squares: Iterable of `(row, col)` cells with extra incident risk.
            unsafe_incident_prob: Probability added to `incident_prob` on unsafe cells.
            scenic_squares: Iterable of `(row, col)` cells that pay `scenic_reward`.
            scenic_reward: Bonus added when a step ends on a scenic cell.
            step_cost: Reward added on every step.
            incident_cost: Reward added when an incident occurs.
            turn_cost: Reward added when the action differs from the previous one.
            straight_reward: Reward added when the action repeats the previous one.
            road_types: Optional mapping `(row, col) -> road type name`.
            final_dropoff_reward: Reward added when the ride is completed.
            render_mode: None, "ansi" or "human".
            seed: Seed for the environment's random generator.

        Raises:
            ValueError: If the grid is empty, a location is outside the grid, or the
                grid has no cell other than the dropoff cell to start from.
        """
        super().__init__()
        if grid_rows < 1 or grid_cols < 1:
            raise ValueError("grid_rows and grid_cols must both be at least 1")
        self.grid_rows = grid_rows
        self.grid_cols = grid_cols
        # Stored as tuples so comparisons with (row, col) positions work even when
        # the caller passed lists.
        self.pickup_loc = self._check_location("pickup_loc", pickup_loc)
        self.dropoff_loc = self._check_location("dropoff_loc", dropoff_loc)
        if grid_rows * grid_cols < 2:
            # reset() rejects the dropoff cell as a start; with a single cell it would never finish
            raise ValueError("the grid needs at least one cell besides the dropoff cell")
        self.incident_prob = incident_prob
        self.unsafe_squares = self._as_cell_set(unsafe_squares)
        self.unsafe_incident_prob = unsafe_incident_prob
        self.scenic_squares = self._as_cell_set(scenic_squares)
        self.scenic_reward = scenic_reward
        self.step_cost = step_cost
        self.incident_cost = incident_cost
        self.turn_cost = turn_cost
        self.straight_reward = straight_reward
        self.final_dropoff_reward = final_dropoff_reward
        self.render_mode = render_mode

        self.road_types = {} if road_types is None else road_types

        self.action_space = spaces.Discrete(4)
        self.observation_space = spaces.MultiDiscrete([self.grid_rows, self.grid_cols, 2, 5])

        self.np_random, _ = gym.utils.seeding.np_random(seed)

        self.reset()

        self._build_transition_model()

        # Pygame state; the window itself is only created when human rendering is used
        self.window = None
        self.clock = None
        self._window_closed = False   # set once the user closes the window
        self.cell_size = 40  # pixels per cell
        self.info_bar_height = 80
        self.window_size = (self.grid_cols * self.cell_size, self.grid_rows * self.cell_size + self.info_bar_height)

        if self.render_mode == "human":
            self._init_pygame()

    def _check_location(self, name, loc):
        """Return `loc` as a tuple, raising ValueError unless it is a (row, col) cell inside the grid."""
        loc = tuple(loc)
        inside = (
            len(loc) == 2
            and 0 <= loc[0] < self.grid_rows
            and 0 <= loc[1] < self.grid_cols
        )
        if not inside:
            raise ValueError(
                f"{name}={loc} must be a (row, col) cell inside the "
                f"{self.grid_rows}x{self.grid_cols} grid"
            )
        return loc

    @staticmethod
    def _as_cell_set(cells):
        """Turn an iterable of (row, col) pairs (or None) into a set of tuples."""
        if cells is None:
            return set()
        return {tuple(cell) for cell in cells}

    def _init_pygame(self):
        """Create the pygame window, clock and font used by human rendering."""
        pygame.init()
        pygame.display.set_caption("Custom Taxi MDP")
        self.window = pygame.display.set_mode(self.window_size)
        self.clock = pygame.time.Clock()
        pygame.font.init()
        self.font = pygame.font.SysFont("Arial", 20)

    def reset(self, *, seed=None, options=None):
        """
        Start a new episode at a random cell other than the dropoff cell.

        Args:
            seed: If given, the random generator is re-created from this seed first.
            options: Accepted for API compatibility; not used.

        Returns:
            `(observation, info)` where `info` is an empty dict.
        """
        if seed is not None:
            self.np_random, _ = gym.utils.seeding.np_random(seed)

        # Rejection-sample the start cell: draw row then column until it is not the dropoff cell
        while True:
            self.agent_row = int(self.np_random.integers(self.grid_rows))
            self.agent_col = int(self.np_random.integers(self.grid_cols))
            if (self.agent_row, self.agent_col) != self.dropoff_loc:
                break

        self.passenger_picked_up = False
        self.last_action = -1   # -1 means "no action taken yet"
        self.terminated = False

        # Performance metrics
        self.total_reward = 0.0
        self.steps = 0
        self.incidents_count = 0

        return self._get_obs(), {}

    def _get_obs(self):
        """Return the current state as `(row, col, passenger flag, last_action + 1)`."""
        return (self.agent_row, self.agent_col, int(self.passenger_picked_up), self.last_action+1)

    def _compute_road_type_reward(self, r, c):
        """Return the reward modifier of the road type at `(r, c)`: +0.5 highway, -0.5 rough, else 0."""
        rt = self.road_types.get((r, c), None)
        if rt == "highway":
            return 0.5
        elif rt == "rough":
            return -0.5
        return 0.0

    def _advance(self, s, a):
        """
        Apply action `a` to state `s` without any randomness.

        This is the single source of truth for movement and the deterministic part
        of the reward, shared by `step` and `_simulate_transition`.

        Returns:
            `(new_row, new_col, passenger_picked, reward, done, incident_prob)` where
            `reward` excludes the incident penalty and `incident_prob` is the incident
            probability on the cell the taxi ends up on.
        """
        (r, c, pu, la) = s
        last_action = la - 1

        # Move, staying in place when the move would leave the grid
        new_r, new_c = r, c
        if a == 0 and r > 0:
            new_r -= 1
        elif a == 1 and r < self.grid_rows - 1:
            new_r += 1
        elif a == 2 and c > 0:
            new_c -= 1
        elif a == 3 and c < self.grid_cols - 1:
            new_c += 1

        # Base reward
        reward = self.step_cost + self._compute_road_type_reward(new_r, new_c)

        # Turn or straight (neither applies to the first action of an episode)
        if last_action != -1 and last_action != a:
            reward += self.turn_cost
        elif last_action != -1 and last_action == a:
            reward += self.straight_reward

        # Pickup happens before the dropoff check, so a shared pickup/dropoff cell ends the ride at once
        passenger_picked = (pu == 1) or (new_r, new_c) == self.pickup_loc

        if (new_r, new_c) in self.scenic_squares:
            reward += self.scenic_reward

        done = (new_r, new_c) == self.dropoff_loc and passenger_picked
        if done:
            reward += self.final_dropoff_reward

        incident_prob = self.incident_prob
        if (new_r, new_c) in self.unsafe_squares:
            incident_prob += self.unsafe_incident_prob

        return new_r, new_c, passenger_picked, reward, done, incident_prob

    def step(self, action):
        """
        Apply `action`, sample a possible safety incident, update the metrics and,
        in human mode, redraw the window.

        Returns:
            `(observation, reward, terminated, truncated, info)`; `truncated` is always
            False and `info` is an empty dict.
        """
        new_row, new_col, picked, reward, done, incident_prob = self._advance(self._get_obs(), action)

        self.agent_row = new_row
        self.agent_col = new_col
        self.passenger_picked_up = picked
        if done:
            self.terminated = True

        # Incident: penalizes the step but never ends the episode
        if self.np_random.random() < incident_prob:
            reward += self.incident_cost
            self.incidents_count += 1

        self.last_action = action
        self.total_reward += reward
        self.steps += 1

        if self.render_mode == "human":
            self.render()

        return self._get_obs(), reward, self.terminated, False, {}

    def render(self, mode=None):
        """
        Render the environment.

        Args:
            mode: "ansi" or "human"; defaults to the `render_mode` given at construction.

        Returns:
            The grid as a string for "ansi", None for "human".

        Raises:
            ValueError: If the effective mode is neither "ansi" nor "human".
        """
        if mode is None:
            mode = self.render_mode

        if mode == "ansi":
            return self._render_text()
        elif mode == "human":
            return self._render_pygame()
        else:
            raise ValueError("Unsupported render mode: {}".format(mode))

    def _render_text(self):
        """
        Return the grid as text, one line per row.

        Symbols: '.' plain, 'U' unsafe, 'S' scenic, 'US' both, 'H' highway, 'R' rough
        (road symbols overwrite U/S), 'P' pickup, 'D' dropoff, 'A' taxi, 'A*' taxi with
        the passenger on board.
        """
        grid = np.full((self.grid_rows, self.grid_cols), '.', dtype=object)

        for (r,c) in self.unsafe_squares:
            grid[r,c] = 'U'
        for (r,c) in self.scenic_squares:
            grid[r,c] = 'US' if grid[r,c] == 'U' else 'S'

        # Only highway/rough change the drawing; any other road type leaves the cell as is
        for (rc, cc), rt in self.road_types.items():
            if rt in self._ROAD_SYMBOLS:
                grid[rc, cc] = self._ROAD_SYMBOLS[rt]

        pr, pc = self.pickup_loc
        grid[pr, pc] = 'P'
        dr, dc = self.dropoff_loc
        grid[dr, dc] = 'D'

        ar, ac = self.agent_row, self.agent_col
        agent_marker = 'A*' if self.passenger_picked_up else 'A'
        grid[ar, ac] = agent_marker

        output = ''
        for r in range(self.grid_rows):
            row_str = ' '.join(f"{cell:2}" for cell in grid[r,:])
            output += row_str + '\n'
        return output

    def _render_pygame(self):
        """
        Draw the current state in the pygame window and wait for the next frame slot.

        If the user closes the window, pygame is shut down through `close()` and this
        and all later calls return without drawing, so the episode can keep running.
        """
        if self._window_closed:
            return
        if self.window is None:
            self._init_pygame()

        # Handle window events before touching the display: drawing after pygame.quit() raises
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                self.close()
                self._window_closed = True
                return

        # Colors
        COLOR_BACKGROUND = (230, 230, 230)
        COLOR_HIGHWAY = (169,169,169)   # Dark grey for highways
        COLOR_ROUGH = (139, 69, 19)     # brown
        COLOR_SCENIC = (34, 139, 34)    # green
        COLOR_UNSAFE = (255, 0, 0)      # red
        COLOR_NORMAL = (200, 200, 200)  # lighter gray for normal roads
        COLOR_PICKUP = (255, 255, 0)    # yellow
        COLOR_DROPOFF = (0, 255, 255)   # cyan
        COLOR_AGENT = (0, 0, 255)       # blue
        COLOR_AGENT_WITH_PASSENGER = (0, 0, 128) # darker blue

        self.window.fill(COLOR_BACKGROUND)

        # Draw the grid and cells below the info bar
        top_offset = self.info_bar_height

        for r in range(self.grid_rows):
            for c in range(self.grid_cols):
                cell_rect = pygame.Rect(c*self.cell_size, r*self.cell_size+top_offset, self.cell_size, self.cell_size)

                # Later assignments win: road type < scenic < unsafe < pickup < dropoff
                rt = self.road_types.get((r,c), None)
                base_color = COLOR_NORMAL
                if rt == "highway":
                    base_color = COLOR_HIGHWAY
                elif rt == "rough":
                    base_color = COLOR_ROUGH

                if (r,c) in self.scenic_squares:
                    base_color = COLOR_SCENIC
                if (r,c) in self.unsafe_squares:
                    base_color = COLOR_UNSAFE

                if (r,c) == self.pickup_loc:
                    base_color = COLOR_PICKUP
                if (r,c) == self.dropoff_loc:
                    base_color = COLOR_DROPOFF

                pygame.draw.rect(self.window, base_color, cell_rect)

                # Agent: a smaller square inset inside its cell
                if (r,c) == (self.agent_row, self.agent_col):
                    agent_color = COLOR_AGENT_WITH_PASSENGER if self.passenger_picked_up else COLOR_AGENT
                    inset = 5
                    agent_rect = pygame.Rect(
                        c*self.cell_size+inset, r*self.cell_size+inset+top_offset,
                        self.cell_size-2*inset, self.cell_size-2*inset
                    )
                    pygame.draw.rect(self.window, agent_color, agent_rect)

        # Draw grid lines for a road-like structure
        line_color = (100,100,100)
        for x in range(self.grid_cols+1):
            pygame.draw.line(self.window, line_color, (x*self.cell_size, top_offset), (x*self.cell_size, self.grid_rows*self.cell_size+top_offset), 1)
        for y in range(self.grid_rows+1):
            pygame.draw.line(self.window, line_color, (0, y*self.cell_size+top_offset), (self.grid_cols*self.cell_size, y*self.cell_size+top_offset), 1)

        # Draw info bar at the top
        info_bar_rect = pygame.Rect(0, 0, self.window_size[0], top_offset)
        pygame.draw.rect(self.window, (50,50,50), info_bar_rect)

        # Show metrics
        passenger_status = "Yes" if self.passenger_picked_up else "No"
        info_text = f"Steps: {self.steps} | Total Reward: {self.total_reward:.2f} | Incidents: {self.incidents_count} | Passenger Picked: {passenger_status}"
        info_surface = self.font.render(info_text, True, (255,255,255))
        self.window.blit(info_surface, (10, 10))

        pygame.display.update()
        self.clock.tick(self.metadata["render_fps"])

    def close(self):
        """Shut down the pygame window if one is open; safe to call repeatedly."""
        if self.window is not None:
            pygame.display.quit()
            pygame.quit()
            self.window = None

    def _state_to_tuple(self, s):
        """Return state `s` as a tuple."""
        return tuple(s)

    def _get_all_states(self):
        """Enumerate every `(row, col, passenger flag, last_action + 1)` state in a fixed order."""
        states = []
        for r in range(self.grid_rows):
            for c in range(self.grid_cols):
                for pu in [0,1]:
                    for la in range(5):
                        states.append((r,c,pu,la))
        return states

    def _build_transition_model(self):
        """Fill `self.P[state][action] = (next_state, expected_reward, done)` for all states and actions."""
        self.P = {}
        all_states = self._get_all_states()
        for s in all_states:
            self.P[s] = {}
            for a in range(self.action_space.n):
                obs, reward, done = self._simulate_transition(s, a)
                self.P[s][a] = (obs, reward, done)

    def _simulate_transition(self, s, a):
        """
        Return `(next_state, expected_reward, done)` for taking action `a` in state `s`.

        The random incident of `step` is replaced by its expectation
        (`incident probability * incident_cost`), which keeps the model deterministic.
        """
        new_r, new_c, passenger_picked, reward, done, incident_prob = self._advance(s, a)

        reward += incident_prob * self.incident_cost

        new_pu = 1 if passenger_picked else 0
        new_la = a + 1
        next_s = (new_r, new_c, new_pu, new_la)

        return next_s, reward, done

    def compute_optimal_policy(self, gamma=0.99, theta=1e-6):
        """
        Run in-place value iteration on `self.P`.

        Args:
            gamma: Discount factor.
            theta: Sweeps stop once the largest value change in a sweep is below this.

        Returns:
            `(V, pi)`: dicts mapping each state tuple to its value and to its greedy
            action (a plain int). Terminal states keep value 0.0 and action 0.
        """
        all_states = self._get_all_states()
        V = {s:0.0 for s in all_states}

        def is_terminal(s):
            (r,c,pu,la) = s
            return (r,c) == self.dropoff_loc and pu == 1

        def action_values(s):
            # One-step lookahead; a finished ride contributes no future value
            values = []
            for a in range(self.action_space.n):
                (next_s, r, done) = self.P[s][a]
                values.append(r + gamma*(0 if done else V[next_s]))
            return values

        while True:
            delta = 0
            for s in all_states:
                if is_terminal(s):
                    continue
                v = V[s]
                V[s] = max(action_values(s))
                delta = max(delta, abs(v - V[s]))
            if delta < theta:
                break

        pi = {}
        for s in all_states:
            pi[s] = 0 if is_terminal(s) else int(np.argmax(action_values(s)))

        return V, pi

    def visualize_policy(self, pi, max_steps=50):
        """
        Reset the environment, follow policy `pi` for at most `max_steps` steps, then
        print the episode metrics.

        In "ansi" mode every state is printed as text. In "human" mode the initial state
        is drawn here and later frames are drawn by `step`.

        Args:
            pi: Mapping from state tuple to action, as returned by `compute_optimal_policy`.
            max_steps: Maximum number of steps to take.
        """
        obs, info = self.reset()
        step_count = 0
        print("Initial State:")
        if self.render_mode == 'ansi':
            print(self.render('ansi'))
        elif self.render_mode == 'human':
            # step() only draws after a move, so show the starting position explicitly
            self.render('human')
        while step_count < max_steps:
            s = tuple(obs)
            a = pi[s]
            obs, r, done, truncated, _ = self.step(a)
            if self.render_mode == 'ansi':
                print(f"Step: {step_count}, Action: {a}, Reward: {r}")
                print(self.render('ansi'))
            if done:
                print("Episode finished!")
                break
            step_count += 1
        if step_count >= max_steps:
            print("Reached max steps without termination.")

        print("Performance Metrics:")
        print(f"  Total Reward: {self.total_reward}")
        print(f"  Steps Taken: {self.steps}")
        print(f"  Safety Incidents: {self.incidents_count}")


if __name__ == "__main__":
    # 15x15 demo city rendered in a pygame window
    env = CustomTaxiMDP(
        grid_rows=15,
        grid_cols=15,
        pickup_loc=(0, 0),  # Pickup location at top-left corner
        dropoff_loc=(14, 14),  # Dropoff location at bottom-right corner
        unsafe_squares=[
            # Unsafe stretches representing high-traffic or industrial zones
            (2, 7), (3, 7), (4, 7), (5, 7), (6, 7), (7, 7), (8, 7),
            (11, 4), (11, 5), (11, 6), (11, 7), (11, 8)
        ],
        scenic_squares=[
            # Scenic roads on the eastern side
            (1, 12), (2, 12), (3, 12), (4, 12), (5, 12), (6, 12),
            (12, 13), (12, 14), (13, 13), (13, 14)
        ],
        road_types={
            # Highways form a cross in the center
            (7, c): 'highway' for c in range(15)  # Horizontal highway
        } | {
            (r, 7): 'highway' for r in range(15)  # Vertical highway
        } | {
            # Additional rough roads
            (5, 5): 'rough', (6, 6): 'rough', (8, 8): 'rough'
        },
        incident_prob=0.002,  # Base incident probability
        unsafe_incident_prob=0.05,  # Higher risk in unsafe squares
        scenic_reward=2.0,  # Higher reward for scenic roads
        step_cost=-1,  # Cost per step
        incident_cost=-100,  # Heavy penalty for safety incidents
        turn_cost=-0.3,  # Penalty for turns
        straight_reward=0.5,  # Reward for going straight
        final_dropoff_reward=500.0,  # Large reward for dropoff
        render_mode='human'  # Pygame visualization
    )

    try:
        # Compute optimal policy
        V, pi = env.compute_optimal_policy(gamma=0.99, theta=1e-6)

        # Visualize the policy execution
        print("Optimal Policy Visualization:")
        env.visualize_policy(pi, max_steps=100)
    finally:
        # Close the environment even if planning or rendering raised,
        # so the pygame window is never left open
        env.close()



Optimal Policy Visualization:
Initial State:
Episode finished!
Performance Metrics:
  Total Reward: 493.5
  Steps Taken: 43
  Safety Incidents: 0
