In [1]:
import numpy as np

In [2]:
''' 
####################CLIFF WALKING ENVIRONMENT#########################

A schematic view of the environment (4 rows x 12 columns; S = start, x = cliff, T = terminal):

o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
S  x  x  x  x  x  x  x  x  x  x  T

Actions: 
    UP (0)
    DOWN (1)
    RIGHT (2)
    LEFT (3)

Rewards: 
     0 for entering the Terminal state
    -100 for falling off the cliff
    -1 for all other actions in any state

Note: The state stays the same when an action would take the agent off the grid (the -1 reward is still given)
      The episode ends when the agent falls off the cliff, and the next episode restarts from the start state

'''
# State indices of the start cell and the terminal (goal) cell.
START_STATE = 36
TERMINAL_STATE = 47


def reward(state):
    """Return the immediate reward for arriving in ``state``.

    Returns:
        0 when ``state`` is the terminal state, -100 when it lies strictly
        between the start and the terminal state (the cliff), and -1 for
        every other state.
    """
    if state == TERMINAL_STATE:
        return 0
    on_cliff = START_STATE < state < TERMINAL_STATE
    return -100 if on_cliff else -1

def env(state, action):
    """Deterministic one-step transition model of the cliff-walking grid.

    States are numbered row by row, so moving up/down changes the index by
    ``columns`` and moving right/left changes it by 1. The grid size is read
    from the module-level ``rows`` and ``columns`` at call time.

    Args:
        state: Index of the current state.
        action: 0 (up), 1 (down), 2 (right) or 3 (left).

    Returns:
        A list ``[prob, next_state, reward, isdone]``. ``prob`` is always 1
        because the environment is deterministic. ``isdone`` is true when
        ``next_state`` is a cliff state or the terminal state.

    Raises:
        ValueError: If ``state`` is not absorbing and ``action`` is not one
            of 0, 1, 2, 3. (Previously this surfaced as an UnboundLocalError
            on ``next_state``.)
    """
    num_states = rows * columns

    def isdone(s):
        # Cliff states and the terminal state all end the episode.
        return START_STATE < s <= TERMINAL_STATE

    if isdone(state):
        # Absorbing state: the action is ignored and the agent stays put.
        next_state = state
    elif action == 0:
        # Up: blocked on the top row.
        next_state = state - columns if state - columns >= 0 else state
    elif action == 1:
        # Down: blocked on the bottom row.
        next_state = state + columns if state + columns < num_states else state
    elif action == 2:
        # Right: blocked when the move would wrap onto the next row.
        next_state = state + 1 if (state + 1) % columns else state
    elif action == 3:
        # Left: blocked in the first column.
        next_state = state - 1 if state % columns else state
    else:
        raise ValueError(
            f"Unknown action {action!r}; expected 0 (up), 1 (down), 2 (right) or 3 (left)"
        )

    # State transition probability is 1 because the environment is deterministic.
    return [1, next_state, reward(next_state), isdone(next_state)]

In [3]:
# --- Grid geometry ---
rows = 4
columns = 12
num_states = rows * columns  # one state per grid cell
num_actions = 4

# --- Learning hyper-parameters ---
alpha = 0.1  # learning rate (step size of each Q update)
epsilon = 0.1  # exploration probability of the epsilon-greedy policy
gamma = 1  # discount factor
episodes = 100_000  # number of games (episodes) played

In [4]:
def _epsilon_greedy_action(Q, state):
    """Choose an action for ``state`` with the epsilon-greedy rule.

    Draws one uniform sample in [0, 1); if it exceeds the module-level
    ``epsilon`` the greedy action ``argmax(Q[state])`` is returned, otherwise
    an action index is drawn uniformly from ``0 .. num_actions - 1``.
    """
    if np.random.random() > epsilon:
        # Exploit: pick the greedy action.
        return np.argmax(Q[state])
    # Explore: pick a random action.
    return np.random.randint(0, num_actions)


def sarsa(seed=None):
    """Learn an action-value table for the cliff-walking task with SARSA.

    Runs ``episodes`` episodes, each starting in ``START_STATE``. Both the
    action taken and the action used in the update target come from the same
    epsilon-greedy policy, and every step updates the table with step size
    ``alpha`` and discount ``gamma``.

    Args:
        seed: Optional seed passed to ``np.random.seed`` before training so a
            run can be reproduced. With the default ``None`` the global
            NumPy random state is left untouched.

    Returns:
        The learned table ``Q`` as an array of shape
        ``(num_states, num_actions)``.
    """
    if seed is not None:
        np.random.seed(seed)

    # Initialize the action-value function.
    Q = np.zeros((num_states, num_actions))
    for _ in range(episodes):
        # Initialize S and choose A from S.
        curr_state = START_STATE
        curr_action = _epsilon_greedy_action(Q, curr_state)
        while True:
            # Take A, observe R and S'. The transition probability is unused
            # because the environment is deterministic.
            _, next_state, step_reward, isdone = env(curr_state, curr_action)
            # Choose A' from S' with the same epsilon-greedy policy.
            next_action = _epsilon_greedy_action(Q, next_state)
            # Move Q(S, A) toward the target R + gamma * Q(S', A').
            td_target = step_reward + gamma * Q[next_state, next_action]
            Q[curr_state, curr_action] += alpha * (td_target - Q[curr_state, curr_action])
            curr_state, curr_action = next_state, next_action
            if isdone:
                break
    return Q

In [7]:
# Train the agent and show the learned action values (one row per state).
Q = sarsa()
print(f'Value Function for SARSA:\n {Q.reshape(num_states, num_actions)}')
print('\n')
# Greedy policy: the highest-valued action in every state, shown on the grid layout.
policy = Q.argmax(axis=1)
print(f'Deterministic Policy:\n {policy.reshape(rows, columns)}')

Value Function for SARSA:
 [[ -15.60555687  -16.96641271  -14.53493899  -15.73252715]
 [ -14.49554434  -15.36761577  -13.50622921  -15.7150597 ]
 [ -13.23100905  -14.4058423   -12.16701437  -14.53295737]
 [ -12.14177393  -12.13892805  -11.23893725  -13.66379029]
 [ -11.01564949  -10.73712911  -10.01043216  -12.4128524 ]
 [  -9.99746244   -9.51264392   -8.75839051  -11.20221966]
 [  -8.76153989   -8.49614833   -7.59918664  -10.16049443]
 [  -7.81443012   -7.27031718   -6.56318363   -9.11688528]
 [  -6.57847282   -5.91881158   -5.4960741    -7.97106281]
 [  -5.78924386   -4.77733112   -4.5228535    -6.94935889]
 [  -4.47987028   -3.57070257   -3.34761424   -5.94305141]
 [  -3.51003363   -2.61424542   -3.3738531    -4.73072512]
 [ -15.50078571  -18.10493193  -16.06252619  -16.57097846]
 [ -14.61090049  -17.56109788  -14.29913518  -16.61421957]
 [ -13.37697914  -14.93443129  -13.39354137  -15.07501876]
 [ -12.28509083  -17.37021221  -10.85345724  -13.57757109]
 [ -11.10047832  -17.27792207

In [8]:
# Human-readable label for each action index.
map_dict = {
    0: 'Up',
    1: 'Down',
    2: 'Right',
    3: 'Left',
}