In [1]:
import numpy as np

In [2]:
''' 
####################CLIFF WALKING ENVIRONMENT#########################

A schematic view of the environment (S = start, T = terminal, x = cliff, o = ordinary cell):

o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
S  x  x  x  x  x  x  x  x  x  x  T

Actions: 
    UP (0)
    DOWN (1)
    RIGHT (2)
    LEFT (3)

Rewards:
     0 for entering the terminal state
    -100 for falling off the cliff
    -1 for every other action, in any state

Note: An action that would take the agent outside the grid leaves the state unchanged (the -1 reward is still given).
      Falling off the cliff ends the episode, and the agent restarts from the start state.

'''
# Indices of the start and goal cells (bottom-left and bottom-right of the grid).
START_STATE = 36
TERMINAL_STATE = 47


def reward(state):
    """Return the reward received for arriving in ``state``.

    The terminal state yields 0, any state strictly between the start and
    terminal indices (the cliff) yields -100, and every other state yields -1.
    """
    if state == TERMINAL_STATE:
        return 0
    fell_off_cliff = START_STATE < state < TERMINAL_STATE
    return -100 if fell_off_cliff else -1

def env(state, action):
    """Apply one deterministic step of the cliff-walking environment.

    Args:
        state: Index of the current cell. Moving up/down changes the index by
            ``columns`` and moving right/left changes it by 1.
        action: 0 (UP), 1 (DOWN), 2 (RIGHT) or 3 (LEFT).

    Returns:
        A list ``[prob, next_state, reward, isdone]``. ``prob`` is always 1,
        ``reward`` is ``reward(next_state)`` and ``isdone`` tells whether
        ``next_state`` is a cliff cell or the terminal cell.

    Raises:
        ValueError: If ``state`` is not episode-ending and ``action`` is not
            one of 0, 1, 2, 3.
    """
    num_states = rows * columns

    def isdone(s):
        # Cliff cells and the terminal cell both end the episode.
        return s > START_STATE and s <= TERMINAL_STATE

    if isdone(state):
        # Episode-ending states are absorbing: no action moves the agent.
        next_state = state
    elif action == 0:
        # UP: blocked on the top row.
        next_state = state - columns if state - columns >= 0 else state
    elif action == 1:
        # DOWN: blocked on the bottom row.
        next_state = state + columns if state + columns < num_states else state
    elif action == 2:
        # RIGHT: blocked in the last column.
        next_state = state + 1 if (state + 1) % columns else state
    elif action == 3:
        # LEFT: blocked in the first column.
        next_state = state - 1 if state % columns else state
    else:
        # Previously this fell through and failed later with an opaque
        # UnboundLocalError on ``next_state``.
        raise ValueError(
            f"Unknown action {action!r}; expected 0 (UP), 1 (DOWN), 2 (RIGHT) or 3 (LEFT)"
        )

    # State Transition Probability is 1 because the environment is deterministic
    return [1, next_state, reward(next_state), isdone(next_state)]

In [3]:
# --- Grid geometry ---
rows, columns = 4, 12
num_states = rows * columns
num_actions = 4

# --- Learning hyper-parameters ---
alpha = 0.1         # Learning rate
epsilon = 0.1       # Exploration probability of the epsilon-greedy policy
gamma = 1           # Discount factor
episodes = 100_000  # Number of games played

In [4]:
def _epsilon_greedy(Q, state):
    """Pick an action for ``state``: greedy on ``Q`` unless a random draw falls at or below ``epsilon``."""
    if np.random.random() > epsilon:
        # Exploit: take the action with the highest current estimate.
        return np.argmax(Q[state])
    # Explore: take an action uniformly at random.
    return np.random.randint(0, num_actions)


def sarsa():
    """Estimate action values for the cliff-walking task with SARSA.

    Runs ``episodes`` episodes, each starting from ``START_STATE`` and
    following an epsilon-greedy policy derived from the current estimates.
    After every step the value of the (state, action) pair just taken is
    moved towards ``reward + gamma * Q[next_state, next_action]`` with step
    size ``alpha``.

    Returns:
        A ``(num_states, num_actions)`` array of learned action values.
    """
    # Initialize the action value function
    Q = np.zeros((num_states, num_actions))
    for _ in range(episodes):
        curr_state = START_STATE
        curr_action = _epsilon_greedy(Q, curr_state)
        while True:
            # The transition probability returned by env is not needed here.
            _, next_state, step_reward, isdone = env(curr_state, curr_action)
            # On-policy: the bootstrap action is the one that will be taken next.
            next_action = _epsilon_greedy(Q, next_state)
            # Rows of episode-ending states are never updated, so they stay 0
            # and contribute nothing to the target.
            td_target = step_reward + gamma * Q[next_state, next_action]
            Q[curr_state, curr_action] += alpha * (td_target - Q[curr_state, curr_action])
            curr_state, curr_action = next_state, next_action
            if isdone:
                break
    return Q

In [5]:
# Learn the action values, then read off the greedy action for every state.
Q = sarsa()
policy = Q.argmax(axis=1)

# Show the full Q-table followed by the greedy policy laid out on the grid.
print(f'Value Function for SARSA:\n {Q.reshape((num_states, num_actions))}')
print('\n')
print(f'Deterministic Policy:\n {policy.reshape((rows, columns))}')

Value Function for SARSA:
 [[ -15.94559988  -16.9948888   -14.69603306  -15.72345508]
 [ -15.05465572  -15.76705586  -13.54840985  -15.87710741]
 [ -13.45242819  -14.69868729  -12.16853176  -14.94160782]
 [ -12.63413295  -13.699749    -10.96799573  -13.63458863]
 [ -11.61395304  -13.18976409  -10.12519889  -12.66692852]
 [ -10.21356605  -11.41455096   -8.81595204  -11.54811983]
 [  -8.89894462   -9.44245617   -7.84409062  -10.38606989]
 [  -7.97057786   -7.50709271   -6.86752574   -9.04047756]
 [  -6.71601565   -6.28608529   -5.76857999   -8.10824048]
 [  -5.75416482   -6.05996278   -4.48956925   -6.89134512]
 [  -4.53879864   -3.37399566   -3.69865185   -5.72865438]
 [  -3.29939681   -2.29040333   -3.38761926   -4.6112411 ]
 [ -15.77520247  -18.15055322  -16.08927591  -16.81982097]
 [ -14.82251886  -20.91290962  -15.2901838   -17.36156562]
 [ -13.41420048  -19.27629725  -13.80332536  -15.83643976]
 [ -12.2886318   -14.46490268  -13.37879625  -14.82369556]
 [ -11.34638929  -18.07730175

In [6]:
# Human-readable name for each action index (0..3).
map_dict = dict(enumerate(('Up', 'Down', 'Right', 'Left')))