In [16]:
import numpy as np
import gymnasium as gym
import csv
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import os
from taxi_environment import TaxiEnvironment

def animate_training_progression(all_grids, log_interval=1000, save_path="training_progression.gif", label="Training"):
    """Stitch recorded episode frames into a single GIF with a per-episode title.

    Args:
        all_grids: Sequence of episodes; each episode is an iterable of frames
            that ``ax.imshow`` can display.
        log_interval: Multiplier applied to the 1-based position of an episode
            in ``all_grids`` to obtain the episode number shown in the title.
        save_path: Destination file for the animation.
        label: Prefix of the title text.

    Returns:
        None. When no frames are present a message is printed and nothing is
        saved. Errors raised while saving are printed rather than propagated.
    """
    # Pair every frame with the caption of the episode it belongs to.
    captioned_frames = [
        (frame, f"{label} Episode: {position * log_interval}")
        for position, episode_frames in enumerate(all_grids, start=1)
        for frame in episode_frames
    ]
    if not captioned_frames:
        print("No frames found to animate.")
        return

    first_frame, first_caption = captioned_frames[0]
    fig, ax = plt.subplots()
    ax.axis('off')
    image = ax.imshow(first_frame)
    heading = ax.set_title(first_caption)

    def update(frame_idx):
        # Swap in the image and caption for the requested animation frame.
        frame, caption = captioned_frames[frame_idx]
        image.set_array(frame)
        heading.set_text(caption)
        return [image, heading]

    ani = animation.FuncAnimation(
        fig,
        update,
        frames=len(captioned_frames),
        interval=100,
        blit=False,
    )
    try:
        print(f"Saving full training progression to {save_path} (this might take a minute)...")
        ani.save(save_path, writer='pillow', fps=10)
        print(f"Success! Saved to {save_path}")
    except Exception as e:
        # Best effort: a failed export is reported but does not abort the run.
        print(f"Error saving animation: {e}")
    finally:
        # Always release the figure, whether or not the save succeeded.
        plt.close(fig)

def train_agent(
    total_episodes=25000,
    max_steps=200,
    learning_rate=0.1,
    discount_rate=0.99,
    max_epsilon=1.0,
    min_epsilon=0.01,
    decay_rate=0.0005,
    log_interval=1000,
    env=None,
    log_filename="training_log.csv",
):
    """Train a tabular Q-learning agent with action masking and epsilon-greedy exploration.

    Every ``log_interval``-th episode is recorded: its frames (one per step,
    including the final one) are collected from ``env.render()``, a summary
    line is printed, and a row is appended to the CSV log.

    Args:
        total_episodes: Number of training episodes.
        max_steps: Maximum number of steps per episode.
        learning_rate: Step size of the Q-value update.
        discount_rate: Discount factor applied to the best next Q-value.
        max_epsilon: Initial (and maximum) exploration probability.
        min_epsilon: Floor of the exploration probability.
        decay_rate: Exponential decay rate of epsilon per episode.
        log_interval: Record every ``log_interval``-th episode (positive int).
        env: Environment to train on. It must expose ``action_space.n``,
            ``observation_space.n``, ``action_space.sample()``, ``reset()``
            returning ``(state, info)``, ``step(action)`` returning
            ``(state, reward, terminated, truncated, info)`` with
            ``info["action_mask"]``, and ``render()``. Defaults to a new
            ``TaxiEnvironment(render_mode="human")``.
        log_filename: CSV file that receives the header and one row per
            recorded episode. It is overwritten at the start of training.

    Returns:
        Tuple ``(qtable, grids)``: the learned Q-table of shape
        ``(n_states, n_actions)`` and a list holding, for each recorded
        episode, the list of frames returned by ``env.render()``.
    """
    if env is None:
        # NOTE(review): frames are taken from env.render(); confirm that
        # TaxiEnvironment returns an image in "human" render mode.
        env = TaxiEnvironment(render_mode="human")

    action_size = env.action_space.n
    state_size = env.observation_space.n
    qtable = np.zeros((state_size, action_size))
    epsilon = max_epsilon
    grids = []

    # Start a fresh log; rows are appended as recorded episodes complete so
    # that partial results survive an interrupted run.
    with open(log_filename, mode="w", newline="") as file:
        csv.writer(file).writerow(["Episode", "Total Reward", "Steps"])

    print("Starting Training")
    for episode in range(total_episodes):
        state, info = env.reset()
        total_rewards = 0
        steps_taken = 0
        save_episode = (episode + 1) % log_interval == 0
        episode_grids = []

        for step in range(max_steps):
            action_mask = np.asarray(info["action_mask"])
            if np.random.uniform(0, 1) > epsilon:
                # Exploitation: best Q-value among the currently valid actions.
                masked_q_values = qtable[state, :].copy()
                masked_q_values[action_mask == 0] = -np.inf
                action = np.argmax(masked_q_values)
            else:
                # Exploration: uniform choice among the currently valid actions.
                valid_actions = np.where(action_mask == 1)[0]
                if valid_actions.size > 0:
                    action = np.random.choice(valid_actions)
                else:
                    action = env.action_space.sample()

            new_state, reward, terminated, truncated, info = env.step(action)
            # Count steps rather than reporting the zero-based loop index.
            steps_taken = step + 1

            if terminated:
                # No future value beyond a terminal state.
                target_q = reward
            else:
                next_q_values = qtable[new_state, :].copy()
                next_q_values[np.asarray(info["action_mask"]) == 0] = -np.inf
                best_next = np.max(next_q_values)
                if not np.isfinite(best_next):
                    # Every next action is masked: do not bootstrap from -inf,
                    # which would permanently corrupt the Q-table.
                    best_next = 0.0
                target_q = reward + discount_rate * best_next

            if save_episode:
                # Capture every step, including the terminal one.
                episode_grids.append(env.render())

            qtable[state, action] += learning_rate * (target_q - qtable[state, action])
            total_rewards += reward
            state = new_state
            if terminated or truncated:
                break

        # Exponentially decay the exploration probability towards min_epsilon.
        epsilon = min_epsilon + (max_epsilon - min_epsilon) * np.exp(-decay_rate * episode)
        if save_episode:
            print(f"Episode {episode + 1}: Reward: {total_rewards}, Epsilon: {epsilon:.4f}")
            grids.append(episode_grids)
            with open(log_filename, mode="a", newline="") as file:
                csv.writer(file).writerow([episode + 1, total_rewards, steps_taken])
    print("Training Finished.")
    return (qtable, grids)

def validate_agent(qtable, num_episodes=5, env=None, max_steps=50):
    """Run the greedy (argmax) policy of ``qtable`` and report its average reward.

    Args:
        qtable: Q-table indexed as ``qtable[state, action]``.
        num_episodes: Number of validation episodes to run.
        env: Environment to evaluate on. It must provide ``reset()`` returning
            ``(state, info)``, ``step(action)`` returning
            ``(state, reward, terminated, truncated, info)`` with
            ``info["action_mask"]``, and ``render()``. Defaults to a new
            ``TaxiEnvironment(render_mode="human")``.
        max_steps: Maximum number of steps per validation episode.

    Returns:
        Tuple ``(avg_reward, validation_grids)``: the mean total reward over
        the episodes (NaN when ``num_episodes`` is 0) and, per episode, the
        list of frames returned by ``env.render()`` after each step.
    """
    print("\n--- Starting Validation ---")
    if env is None:
        # NOTE(review): frames are taken from env.render(); confirm that
        # TaxiEnvironment returns an image in "human" render mode.
        env = TaxiEnvironment(render_mode="human")

    total_test_rewards = []
    # Local to this call: must not alias the training frames held by the caller.
    validation_grids = []
    for episode in range(num_episodes):
        state, info = env.reset()
        episode_reward = 0
        episode_grids = []
        for _ in range(max_steps):
            # Greedy action restricted to the actions the mask allows.
            action_mask = np.asarray(info["action_mask"])
            masked_q_values = qtable[state, :].copy()
            masked_q_values[action_mask == 0] = -np.inf
            action = np.argmax(masked_q_values)

            state, reward, terminated, truncated, info = env.step(action)
            episode_reward += reward
            episode_grids.append(env.render())
            if terminated or truncated:
                break
        total_test_rewards.append(episode_reward)
        print(f"Validation Episode {episode + 1}: Total Reward = {episode_reward}")
        validation_grids.append(episode_grids)

    # Avoid numpy's empty-mean warning when no episodes were requested.
    avg_reward = np.mean(total_test_rewards) if total_test_rewards else float("nan")
    print(f"\nValidation Finished. Average Reward: {avg_reward}")
    if avg_reward > 10:
        print("Result: SUCCESS - The agent has learned the task!")
    else:
        print("Result: FAILURE - The agent is still struggling.")
    return (avg_reward, validation_grids)
        
# Train the agent, then export the recorded training episodes as a GIF.
qtable, grids = train_agent()
animate_training_progression(
    grids,
    log_interval=1000,
    save_path="full_training_movie.gif",
    label="Training",
)

# Evaluate the learned Q-table over 50 episodes and export those episodes as a GIF.
avg_reward, validate_grids = validate_agent(qtable, num_episodes=50)
animate_training_progression(
    validate_grids,
    log_interval=1,
    save_path="full_validation_movie.gif",
    label="Validation",
)




Starting Training
Episode 1000: Reward: -20, Epsilon: 0.6108
Episode 2000: Reward: 8, Epsilon: 0.3744
Episode 3000: Reward: 12, Epsilon: 0.2310
Episode 4000: Reward: 18, Epsilon: 0.1440
Episode 5000: Reward: 21, Epsilon: 0.0913
Episode 6000: Reward: 15, Epsilon: 0.0593
Episode 7000: Reward: 18, Epsilon: 0.0399
Episode 8000: Reward: 24, Epsilon: 0.0281
Episode 9000: Reward: 16, Epsilon: 0.0210
Episode 10000: Reward: 23, Epsilon: 0.0167
Episode 11000: Reward: 21, Epsilon: 0.0140
Episode 12000: Reward: 13, Epsilon: 0.0125
Episode 13000: Reward: 22, Epsilon: 0.0115
Episode 14000: Reward: 14, Epsilon: 0.0109
Episode 15000: Reward: 21, Epsilon: 0.0105
Episode 16000: Reward: 17, Epsilon: 0.0103
Episode 17000: Reward: 22, Epsilon: 0.0102
Episode 18000: Reward: 12, Epsilon: 0.0101
Episode 19000: Reward: 15, Epsilon: 0.0101
Episode 20000: Reward: 21, Epsilon: 0.0100
Episode 21000: Reward: 20, Epsilon: 0.0100
Episode 22000: Reward: 18, Epsilon: 0.0100
Episode 23000: Reward: 24, Epsilon: 0.0100
Ep