In [14]:
import numpy as np 
import gym   
import random
from collections import defaultdict
import matplotlib.pyplot as plt
from datetime import datetime

# Build the Taxi environment shared by every cell below.
env = gym.make("Taxi-v3")
# Reset before rendering so the displayed frame belongs to a properly
# initialised episode instead of whatever state the constructor left behind.
env.reset()
env.render()

+---------+
|[34;1mR[0m: | : :G|
| : | : : |
| : : : : |
| | : | : |
|Y|[43m [0m: |[35mB[0m: |
+---------+



In [15]:
# Sizes of the discrete action and observation spaces of `env`.
action_size, state_size = env.action_space.n, env.observation_space.n

print("Action Size: ", action_size)
print("State Size: ", state_size)

Action Size:  6
State Size:  500


In [25]:
class Sarsa_Agent():
    """Tabular SARSA (on-policy TD control) agent for a discrete Gym-style environment.

    The action-value table ``Q`` maps a state to a NumPy vector holding one
    estimate per action; states that have not been seen yet start at all zeros.

    Parameters
    ----------
    gamma : float
        Discount factor applied to the bootstrapped next-step value.
    learning_rate : float
        Step size of each TD update.
    epsilon : float
        Probability of picking a uniformly random action while training.
    nepisodes : int
        Maximum number of training episodes.
    environment : object, optional
        Environment to interact with. It must expose ``action_space.n``,
        ``reset()`` returning a state, and ``step(action)`` returning
        ``(next_state, reward, done, info)``. When omitted, the module-level
        ``env`` is used, so existing ``Sarsa_Agent()`` calls keep working.
    """

    # Training stops as soon as the greedy evaluation average lies strictly
    # between these two bounds.
    SOLVED_RANGE = (9, 9.5)

    def __init__(self, gamma = 0.95, learning_rate = 0.10, epsilon = 0.2, nepisodes = 40000, environment = None):
        self.gamma = gamma
        self.learning_rate = learning_rate
        self.epsilon = epsilon
        self.nepisodes = nepisodes
        # Keep an explicit handle on the environment instead of reaching for
        # the notebook-level global from inside every method.
        self.env = env if environment is None else environment
        n_actions = self.env.action_space.n
        self.n_actions = n_actions
        self.Q = defaultdict(lambda: np.zeros(n_actions))

    def greedy_policy(self, state):
        """Return the index of the highest-valued action for ``state``.

        ``np.argmax`` resolves ties in favour of the lowest action index.
        """
        return np.argmax(self.Q[state])

    def epsilon_greedy_policy(self, state):
        """Return a random action with probability ``epsilon``, else the greedy one."""
        if np.random.uniform() < self.epsilon:
            return np.random.choice(self.n_actions)
        return self.greedy_policy(state)

    def onpolicy_control(self, test_episodes):
        """Train with SARSA, evaluating the greedy policy after every training episode.

        After each training episode the greedy policy is run for
        ``test_episodes`` episodes and the average return is recorded. Training
        stops early when that average falls strictly inside ``SOLVED_RANGE``;
        otherwise it runs for ``nepisodes`` episodes. In both cases the recorded
        averages are plotted and the elapsed wall-clock time is printed.

        Parameters
        ----------
        test_episodes : int
            Number of greedy evaluation episodes run after each training episode.

        Returns
        -------
        collections.defaultdict
            The learned action-value table ``self.Q``.
        """
        rewards = []
        low, high = self.SOLVED_RANGE
        start = datetime.now()

        for episode in range(self.nepisodes):
            state = self.env.reset()
            done = False
            action = self.epsilon_greedy_policy(state)
            while not done:
                next_state, reward, done, _ = self.env.step(action)
                # On-policy: the action that will actually be taken next also
                # supplies the bootstrap value. The final transition of an
                # episode bootstraps from Q[next_state] as well.
                next_action = self.epsilon_greedy_policy(next_state)
                td_target = reward + self.gamma*self.Q[next_state][next_action]
                self.Q[state][action] = self.Q[state][action] + self.learning_rate*(td_target - self.Q[state][action])
                state = next_state
                action = next_action

            ep_avg, max_rew, min_rew = self.test_policy(test_episodes)
            rewards.append(ep_avg)

            if episode % 100 == 0:
                print(f'Episode: {episode} Average Reward: {ep_avg}')
                print(f'Maximum Reward : {max_rew} Minimum Reward: {min_rew}')

            if low < ep_avg < high:
                print(f'Solved! Episode: {episode} Average Reward: {ep_avg}')
                self._plot_rewards(rewards)
                end = datetime.now()
                print("Solved in: ", (end-start))
                return self.Q

        self._plot_rewards(rewards)
        end = datetime.now()
        print("Not Solved")
        print((end-start))
        return self.Q

    @staticmethod
    def _plot_rewards(rewards):
        """Plot the greedy evaluation average recorded after each training episode."""
        _, ax = plt.subplots()
        ax.plot(np.arange(len(rewards)), rewards)
        ax.set_ylabel('Average Reward')
        ax.set_xlabel('Episode #')
        plt.show()

    def test_policy(self, n):
        """Run the greedy policy for ``n`` episodes and summarise the returns.

        Parameters
        ----------
        n : int
            Number of evaluation episodes; must be positive.

        Returns
        -------
        tuple
            ``(average_return, [max_return], [min_return])``. The maximum and
            minimum are wrapped in one-element lists.

        Raises
        ------
        ValueError
            If ``n`` is not positive.
        """
        if n <= 0:
            raise ValueError("n must be a positive number of evaluation episodes")

        episode_rewards = []
        for _ in range(n):
            state = self.env.reset()
            done = False
            cum_rewards = 0
            while not done:
                action = self.greedy_policy(state)
                state, reward, done, _ = self.env.step(action)
                cum_rewards += reward
            episode_rewards.append(cum_rewards)

        return sum(episode_rewards)/n, [max(episode_rewards)], [min(episode_rewards)]

In [None]:
# Train an agent with the default hyper-parameters, running 200 greedy
# evaluation episodes after each training episode.
a = Sarsa_Agent()
Q = a.onpolicy_control(test_episodes=200)
# Final check of the greedy policy over 100 episodes; the resulting
# (average, [max], [min]) tuple is shown as the cell output.
a.test_policy(n=100)

Episode: 0 Average Reward: -209.0
Maximum Reward : [-200] Minimum Reward: [-2000]
Episode: 100 Average Reward: -200.0
Maximum Reward : [-200] Minimum Reward: [-200]
