In [57]:
!pip install gymnasium
import gymnasium as gym
import numpy as np



In [58]:
import matplotlib.pyplot as plt
import numpy as np
import gymnasium as gym
from IPython.display import display, clear_output
import time
import os
from PIL import Image

# Helper function for epsilon-greedy policy
def choose_action(state, q_table, action_space, epsilon):
    if np.random.uniform(0, 1) < epsilon:
        return action_space.sample()  # Explore action space
    else:
        return np.argmax(q_table[state, :]) # Exploit learned values

# Function to visualize the environment map
def visualize_env_map(env, env_name, save_dir="."):
    """
    Visualizes the environment map and saves it as a PNG.
    """
    if env_name == 'FrozenLake-v1':
        # FrozenLake has a text-based rendering that can be captured
        # We can try to render it and convert to an image
        try:
            # Attempt to get the text render
            output = env.render()
            if output:
                print(f"Map for {env_name}:")
                print(output)
                # For visual map, we'll rely on creating a visual representation later
                # or if render_mode='rgb_array' is supported and useful for a static map.
                # Since FrozenLake text render is informative, displaying it for now.
                pass
            else:
                 print(f"Could not get text rendering for {env_name}.")

        except Exception as e:
            print(f"Error rendering {env_name} for map visualization: {e}")

    elif env_name == 'Taxi-v3':
        # Taxi also has a text render
        try:
            output = env.render()
            if output:
                print(f"Map for {env_name}:")
                print(output)
                pass
            else:
                 print(f"Could not get text rendering for {env_name}.")
        except Exception as e:
            print(f"Error rendering {env_name} for map visualization: {e}")

    elif env_name == 'CliffWalking-v0':
         # CliffWalking has a text render
        try:
            output = env.render()
            if output:
                print(f"Map for {env_name}:")
                print(output)
                pass
            else:
                 print(f"Could not get text rendering for {env_name}.")
        except Exception as e:
            print(f"Error rendering {env_name} for map visualization: {e}")

    else:
        print(f"Map visualization not specifically implemented for {env_name}. Displaying basic info:")
        print(f"Observation Space: {env.observation_space}")
        print(f"Action Space: {env.action_space}")

    # Note: Saving a generic visual map as PNG for all discrete environments
    # requires more sophisticated rendering or creating a custom visualization
    # based on environment structure, which can be complex and environment-specific.
    # The text-based render is the most straightforward for these environments.
    # If a visual PNG is strictly required, this function would need significant
    # expansion or a different rendering approach. For now, focusing on the text map.

# Function to generate and save video
def create_video_from_frames(frame_folder, output_filename, fps=30):
    images = [img for img in os.listdir(frame_folder) if img.endswith(".png")]
    # Sort the frames by the numerical part of the filename
    images.sort(key=lambda x: int(x.split('_')[-1].split('.')[0]))

    if not images:
        print(f"No frames found in {frame_folder} to create video.")
        return

    # Use the first image to determine dimensions
    first_frame_path = os.path.join(frame_folder, images[0])
    try:
        frame = Image.open(first_frame_path)
        width, height = frame.size
    except FileNotFoundError:
        print(f"Error opening first frame: {first_frame_path}")
        return
    except Exception as e:
        print(f"Error processing first frame {first_frame_path}: {e}")
        return


    import subprocess
    output_path = os.path.join(".", output_filename) # Save in the current directory

    # Ensure the output directory exists
    output_dir = os.path.dirname(output_path)
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Use ffmpeg to create the MP4 video
    # -y: overwrite output file without asking
    # -f image2: input format is image sequence
    # -r {fps}: input frame rate
    # -i {frame_folder}/frame_%04d.png: input file pattern (assuming frame_0000.png, frame_0001.png, ...)
    # -vcodec libx264: video codec
    # -crf 25: constant rate factor (quality, lower is better)
    # -pix_fmt yuv420p: pixel format (for compatibility)
    command = [
        'ffmpeg', '-y', '-f', 'image2', '-r', str(fps),
        '-i', os.path.join(frame_folder, 'frame_%04d.png'),
        '-vcodec', 'libx264', '-crf', '25', '-pix_fmt', 'yuv420p',
        output_path
    ]

    try:
        subprocess.run(command, check=True, capture_output=True, text=True)
        print(f"Video successfully created: {output_path}")
    except subprocess.CalledProcessError as e:
        print(f"Error creating video: {e}")
        print(f"STDOUT: {e.stdout}")
        print(f"STDERR: {e.stderr}")
    except FileNotFoundError:
        print("ffmpeg not found. Please install ffmpeg to create videos.")
        print("You can install it in Colab using: !apt-get update && apt-get install ffmpeg")
    except Exception as e:
        print(f"An unexpected error occurred during video creation: {e}")


# Function to clean up frame images
def cleanup_frames(frame_folder):
    for img in os.listdir(frame_folder):
        if img.endswith(".png"):
            try:
                os.remove(os.path.join(frame_folder, img))
            except OSError as e:
                print(f"Error removing frame {img}: {e}")
    if os.path.exists(frame_folder) and not os.listdir(frame_folder):
         try:
             os.rmdir(frame_folder)
             print(f"Cleaned up frame folder: {frame_folder}")
         except OSError as e:
             print(f"Error removing frame folder {frame_folder}: {e}")

## Taxi

In [59]:
# Create and visualize the Taxi environment
env_name_taxi = 'Taxi-v3'
env_taxi = gym.make(env_name_taxi, render_mode='ansi')
visualize_env_map(env_taxi, env_name_taxi)

# Define hyperparameters for Taxi (can be adjusted)
alpha_taxi = 0.1
gamma_taxi = 0.99
epsilon_taxi = 0.1
num_episodes_taxi = 50000 # Taxi might need more episodes

# Initialize Q-tables for Taxi
q_table_sarsa_taxi = np.zeros((env_taxi.observation_space.n, env_taxi.action_space.n))
q_table_qlearning_taxi = np.zeros((env_taxi.observation_space.n, env_taxi.action_space.n))

# SARSA Implementation for Taxi
print("Training SARSA on Taxi...")
for i in range(num_episodes_taxi):
    state, _ = env_taxi.reset()
    action = choose_action(state, q_table_sarsa_taxi, env_taxi.action_space, epsilon_taxi)
    done = False
    while not done:
        next_state, reward, terminated, truncated, _ = env_taxi.step(action)
        done = terminated or truncated
        next_action = choose_action(next_state, q_table_sarsa_taxi, env_taxi.action_space, epsilon_taxi)

        # SARSA update
        old_value = q_table_sarsa_taxi[state, action]
        next_q = q_table_sarsa_taxi[next_state, next_action]
        new_value = old_value + alpha_taxi * (reward + gamma_taxi * next_q - old_value)
        q_table_sarsa_taxi[state, action] = new_value

        state = next_state
        action = next_action

print("SARSA training finished.")

# Q-Learning Implementation for Taxi
print("\nTraining Q-Learning on Taxi...")
for i in range(num_episodes_taxi):
    state, _ = env_taxi.reset()
    done = False
    while not done:
        action = choose_action(state, q_table_qlearning_taxi, env_taxi.action_space, epsilon_taxi)
        next_state, reward, terminated, truncated, _ = env_taxi.step(action)
        done = terminated or truncated

        # Q-Learning update
        old_value = q_table_qlearning_taxi[state, action]
        max_next_q = np.max(q_table_qlearning_taxi[next_state, :])
        new_value = old_value + alpha_taxi * (reward + gamma_taxi * max_next_q - old_value)
        q_table_qlearning_taxi[state, action] = new_value

        state = next_state

print("Q-Learning training finished.")

env_taxi.close() # Close the environment after training

Error rendering Taxi-v3 for map visualization: Cannot call `env.render()` before calling `env.reset()`, if this is an intended action, set `disable_render_order_enforcing=True` on the OrderEnforcer wrapper.
Training SARSA on Taxi...
SARSA training finished.

Training Q-Learning on Taxi...
Q-Learning training finished.


In [62]:
# Function to generate a successful episode path, specific to Taxi
def generate_successful_episode_path_taxi(env_name, q_table, max_retries=1000, max_steps_per_episode=1000):
    """
    Generates a successful episode path for Taxi environment using the learned Q-table and policy,
    retrying until a successful episode (reaching the goal) is found
    or max_retries is reached.
    """
    env = gym.make(env_name, render_mode='rgb_array') # Use rgb_array for rendering frames
    path = []
    is_successful = False
    attempt = 0

    while attempt < max_retries and not is_successful:
        state, _ = env.reset()
        path = [state]
        done = False
        steps = 0
        while not done and steps < max_steps_per_episode:
            # Choose action based on greedy policy (exploit learned knowledge)
            action = np.argmax(q_table[state, :])
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            path.append(next_state)
            state = next_state
            steps += 1

            # Check if the goal state is reached successfully in Taxi
            if env_name == 'Taxi-v3' and terminated and reward > 0:
                 is_successful = True


        attempt += 1

    env.close()
    if not is_successful:
        print(f"Warning: Could not find a successful episode for {env_name} after {max_retries} attempts.")

    return path, is_successful

# Generate successful paths for Taxi (SARSA and Q-Learning)
print("Generating successful episode paths for Taxi...")

taxi_sarsa_path, taxi_sarsa_success = generate_successful_episode_path_taxi('Taxi-v3', q_table_sarsa_taxi)
print(f"Taxi SARSA path generated. Success: {taxi_sarsa_success}")

taxi_qlearning_path, taxi_qlearning_success = generate_successful_episode_path_taxi('Taxi-v3', q_table_qlearning_taxi)
print(f"Taxi Q-Learning path generated. Success: {taxi_qlearning_success}")

print("Finished generating Taxi paths.")

Generating successful episode paths for Taxi...
Taxi SARSA path generated. Success: True
Taxi Q-Learning path generated. Success: True
Finished generating Taxi paths.


In [63]:
# Function to generate and save video
def create_video_from_frames(frame_folder, output_filename, fps=30):
    images = [img for img in os.listdir(frame_folder) if img.endswith(".png")]
    # Sort the frames by the numerical part of the filename
    images.sort(key=lambda x: int(x.split('_')[-1].split('.')[0]))

    if not images:
        print(f"No frames found in {frame_folder} to create video.")
        return

    # Use the first image to determine dimensions
    first_frame_path = os.path.join(frame_folder, images[0])
    try:
        frame = Image.open(first_frame_path)
        width, height = frame.size
    except FileNotFoundError:
        print(f"Error opening first frame: {first_frame_path}")
        return
    except Exception as e:
        print(f"Error processing first frame {first_frame_path}: {e}")
        return

    import subprocess
    output_path = os.path.join(".", output_filename) # Save in the current directory

    # Ensure the output directory exists
    output_dir = os.path.dirname(output_path)
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Use ffmpeg to create the MP4 video
    # -y: overwrite output file without asking
    # -f image2: input format is image sequence
    # -r {fps}: input frame rate
    # -i {frame_folder}/frame_%04d.png: input file pattern (assuming frame_0000.png, frame_0001.png, ...)
    # -vcodec libx264: video codec
    # -crf 25: constant rate factor (quality, lower is better)
    # -pix_fmt yuv420p: pixel format (for compatibility)
    command = [
        'ffmpeg', '-y', '-f', 'image2', '-r', str(fps),
        '-i', os.path.join(frame_folder, 'frame_%04d.png'),
        '-vcodec', 'libx264', '-crf', '25', '-pix_fmt', 'yuv420p',
        output_path
    ]

    try:
        subprocess.run(command, check=True, capture_output=True, text=True)
        print(f"Video successfully created: {output_path}")
    except subprocess.CalledProcessError as e:
        print(f"Error creating video: {e}")
        print(f"STDOUT: {e.stdout}")
        print(f"STDERR: {e.stderr}")
    except FileNotFoundError:
        print("ffmpeg not found. Please install ffmpeg to create videos.")
        print("You can install it in Colab using: !apt-get update && apt-get install ffmpeg")
    except Exception as e:
        print(f"An unexpected error occurred during video creation: {e}")


# Function to clean up frame images (kept as is)
def cleanup_frames(frame_folder):
    for img in os.listdir(frame_folder):
        if img.endswith(".png"):
            try:
                os.remove(os.path.join(frame_folder, img))
            except OSError as e:
                print(f"Error removing frame {img}: {e}")
    if os.path.exists(frame_folder) and not os.listdir(frame_folder):
         try:
             os.rmdir(frame_folder)
             print(f"Cleaned up frame folder: {frame_folder}")
         except OSError as e:
             print(f"Error removing empty frame folder {frame_folder}: {e}")


# Function to generate video from learned policy, incorporating path generation, specific to Taxi
def generate_video_from_learned_policy_and_path_taxi(env_name, q_table, output_filename, fps=2, max_retries=1000, max_steps_per_episode=1000, frame_folder="frames"):
    """
    Generates a successful episode for Taxi environment using the learned Q-table, renders frames,
    creates an MP4 video, and cleans up frames.
    Includes path generation within the function.
    """
    # Ensure frame folder is clean before starting
    if os.path.exists(frame_folder):
        cleanup_frames(frame_folder)
    else:
        os.makedirs(frame_folder)

    print(f"Attempting to generate successful episode for {env_name} for video...")
    # Use the Taxi-specific path generation function
    path, is_successful = generate_successful_episode_path_taxi(env_name, q_table, max_retries, max_steps_per_episode)

    if not is_successful:
        print(f"Skipping video generation for {env_name} as no successful path was found after {max_retries} attempts.")
        return # Exit if no successful episode found

    print(f"Successful episode found for {env_name}. Generating video...")

    # Re-run the successful episode using the generated path and render frames
    env = gym.make(env_name, render_mode='rgb_array')
    state, _ = env.reset() # Start from the actual initial state
    frames = []
    frames.append(env.render()) # Render the initial state

    # Step through the path and render
    temp_env = gym.make(env_name, render_mode='rgb_array')
    state, _ = temp_env.reset()
    frames = []
    frames.append(temp_env.render())
    frame_count = 1
    done = False
    steps = 0
    max_sim_steps = len(path) + 10 # Add a buffer

    while not done and steps < max_steps_per_episode and steps < max_sim_steps:
         action = np.argmax(q_table[state, :]) # Greedy action
         next_state, reward, terminated, truncated, _ = temp_env.step(action)
         done = terminated or truncated
         frames.append(temp_env.render())
         state = next_state
         steps += 1

         # Stop if the goal is reached successfully (Taxi specific check)
         if terminated and reward > 0:
             break


    temp_env.close()

    # Save frames if we have any
    if frames:
        for i, frame in enumerate(frames):
            img = Image.fromarray(frame)
            img.save(os.path.join(frame_folder, f"frame_{i:04d}.png"))
    else:
        print(f"No frames were generated for {env_name}.")
        # Clean up the potentially created frame folder if no frames
        if os.path.exists(frame_folder) and not os.listdir(frame_folder):
             try:
                 os.rmdir(frame_folder)
             except OSError as e:
                 print(f"Error removing empty frame folder {frame_folder}: {e}")
        return


    # Create video from saved frames
    create_video_from_frames(frame_folder, output_filename, fps=fps)

    # Clean up frames after video creation
    cleanup_frames(frame_folder)


# Generate videos for Taxi (SARSA and Q-Learning)
print("\nGenerating videos for Taxi environment...")

# Call the revised function which includes path generation and success check
generate_video_from_learned_policy_and_path_taxi('Taxi-v3', q_table_sarsa_taxi, 'taxi_sarsa_episode.mp4')
generate_video_from_learned_policy_and_path_taxi('Taxi-v3', q_table_qlearning_taxi, 'taxi_qlearning_episode.mp4')


print("Finished attempting to generate Taxi videos.")


Generating videos for Taxi environment...
Attempting to generate successful episode for Taxi-v3 for video...
Successful episode found for Taxi-v3. Generating video...
Video successfully created: ./taxi_sarsa_episode.mp4
Cleaned up frame folder: frames
Attempting to generate successful episode for Taxi-v3 for video...
Successful episode found for Taxi-v3. Generating video...


  return datetime.utcnow().replace(tzinfo=utc)


Video successfully created: ./taxi_qlearning_episode.mp4
Cleaned up frame folder: frames
Finished attempting to generate Taxi videos.


  return datetime.utcnow().replace(tzinfo=utc)


In [67]:
# Print the learned policies for SARSA and Q-Learning in Taxi

print("SARSA Learned Policy for Taxi:")
# The policy is derived by choosing the action with the maximum Q-value for each state
sarsa_policy_taxi = np.argmax(q_table_sarsa_taxi, axis=1)
print(sarsa_policy_taxi)

print("\nQ-Learning Learned Policy for Taxi:")
qlearning_policy_taxi = np.argmax(q_table_qlearning_taxi, axis=1)
print(qlearning_policy_taxi)

print("\nInterpretation of Taxi Actions:")
print("0: Move South")
print("1: Move North")
print("2: Move East")
print("3: Move West")
print("4: Pickup passenger")
print("5: Dropoff passenger")

SARSA Learned Policy for Taxi:
[0 4 4 4 0 0 2 0 0 0 0 0 0 0 0 0 5 0 0 0 0 3 3 3 0 0 0 0 0 0 0 0 0 0 0 0 0
 3 3 0 0 0 0 0 2 0 2 2 0 0 0 0 0 0 0 0 2 0 0 0 0 0 0 0 2 0 2 2 0 0 0 0 0 0
 0 0 0 2 0 0 0 0 0 3 4 0 4 4 3 3 0 3 3 0 0 0 3 5 0 0 0 1 1 1 2 0 2 2 0 0 0
 0 2 0 0 0 1 0 0 2 0 3 1 3 0 0 0 0 3 3 0 0 0 0 0 0 3 0 0 0 0 0 0 0 2 0 1 2
 0 0 0 0 2 0 0 0 0 2 0 0 0 3 0 0 1 0 1 2 0 0 0 0 0 0 0 0 0 2 0 0 0 3 3 0 1
 0 1 1 3 3 0 0 0 0 3 0 0 1 3 3 0 1 1 1 2 0 2 2 0 0 0 0 2 2 2 0 1 2 0 2 0 3
 1 1 2 0 2 2 3 3 0 3 2 2 2 0 3 2 3 2 0 3 3 3 2 0 1 2 3 3 0 3 2 2 2 0 3 2 3
 2 0 3 3 3 2 0 2 2 3 3 0 3 0 0 0 0 3 1 3 0 0 3 3 3 1 0 1 1 3 3 0 3 3 0 0 0
 3 1 3 3 0 1 1 1 1 0 1 1 0 0 0 0 1 1 1 0 1 1 0 1 0 1 1 1 1 0 1 1 1 1 0 1 1
 1 2 0 1 2 1 2 0 1 1 3 1 0 1 1 1 3 0 1 1 1 1 0 3 1 1 1 0 1 1 1 2 0 2 2 1 1
 0 1 0 0 0 0 1 1 1 0 0 3 3 1 1 0 1 1 1 3 0 3 0 0 3 0 1 1 1 3 0 1 1 1 1 0 1
 1 4 4 0 4 1 1 1 0 1 1 5 1 0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0 1 0 1 2 0 3 1 1
 1 0 1 1 1 1 0 3 1 1 1 0 1 1 1 1 0 1 1 1 1 0 2 1 1 1 0 1 4 4 4 0 1 1 