In [None]:
import numpy as np
import gym
import random
from tqdm import tqdm
import matplotlib.pyplot as plt

In [None]:
%%html
<img src="img/eps-greedy.png">
<img src="img/algo.png">

In [None]:
# Gym environment id to train on. The policy printout in the last cell was written for
# "FrozenLake-v0" (4x4 map) and "FrozenLake8x8-v0" (8x8 map).
env_type = "FrozenLake8x8-v0"

In [None]:
env = gym.make(env_type)

# --- Reproducibility: seed every RNG the training / evaluation loops draw from ---
SEED = 42
random.seed(SEED)            # explore/exploit coin flip in eps_greedy_policy
np.random.seed(SEED)         # numpy's global RNG (seeded for completeness)
env.seed(SEED)               # the environment's own RNG (state transitions)
if hasattr(env.action_space, "seed"):
    env.action_space.seed(SEED)  # env.action_space.sample() during exploration

# --- SARSA(lambda) hyper-parameters ---
alpha = 0.05           # learning rate
max_steps = 250        # per-episode step cap
episodes = 500000      # maximum number of training episodes
epsilon_max = 1.0      # exploration rate at the start of training
epsilon_min = 0.1      # floor the decay schedule approaches
epsilon = epsilon_max  # current exploration rate; mutated by tune_params / tune_params2
gamma = 0.95           # discount factor
_lambda = 0.0          # trace decay; 0.0 makes traces vanish after each step (one-step SARSA)
# The epsilon decay schedule itself is implemented in tune_params / tune_params2.

NUMBER_OF_EVAL_SIMS = 100  # rollouts per greedy-policy evaluation
history = []
is_decay = True            # decay epsilon after every episode

In [None]:
def Initialize_Q():
    """Create an all-zero action-value table Q[state, action] sized from the global ``env``."""
    n_states, n_actions = env.env.nS, env.env.nA
    return np.zeros(shape=(n_states, n_actions))

def Initialize_E():
    """Create an all-zero eligibility-trace table E[state, action] sized from the global ``env``.

    It has the same shape as the Q table; Sarsa_lambda builds a fresh one at the
    start of every episode.
    """
    table_shape = (env.env.nS, env.env.nA)
    traces = np.zeros(table_shape)
    return traces

def Initalize_Policy(env):
    """Build a random deterministic policy: one action sampled from ``env.action_space`` per state.

    Returns a numpy array of length ``env.env.nS`` whose i-th entry is the action for state i.
    """
    actions = []
    for _state_index in range(env.env.nS):
        actions.append(env.action_space.sample())
    return np.array(actions)

def tune_params(decay_factor=0.999985):
    """Multiplicatively decay the global exploration rate ``epsilon``.

    Args:
        decay_factor: multiplier applied to ``epsilon`` on each call. The default
            keeps the original hard-coded value.

    Returns:
        The updated ``epsilon`` (the global is also updated in place).
    """
    global epsilon
    epsilon = epsilon * decay_factor
    return epsilon
    
def tune_params2(episode, decay_rate=0.00005):
    """Set the global ``epsilon`` from an exponential decay schedule.

    epsilon(episode) = epsilon_min + (epsilon_max - epsilon_min) * exp(-decay_rate * episode)

    The schedule equals ``epsilon_max`` at episode 0 and approaches ``epsilon_min``
    as ``episode`` grows.

    Args:
        episode: episode index fed into the schedule.
        decay_rate: exponential decay rate. The default keeps the original
            hard-coded value.

    Returns:
        The updated ``epsilon`` (the global is also updated in place).
    """
    global epsilon
    epsilon = epsilon_min + (epsilon_max - epsilon_min) * np.exp(-decay_rate * episode)
    return epsilon

def eps_greedy_policy(state, Q, eps=None):
    """Pick an action for ``state`` epsilon-greedily with respect to ``Q``.

    With probability ``eps`` the action is sampled from the global
    ``env.action_space`` (explore). Otherwise an action maximising ``Q[state, a]``
    is returned (exploit).

    Ties between maximal actions are broken uniformly at random. ``np.argmax``
    alone always returns the first maximum, so with a zero-initialised Q table
    every greedy choice would be action 0 until the values separate.

    Args:
        state: row index into ``Q``.
        Q: 2-D action-value table indexed as ``Q[state, action]``.
        eps: exploration probability. Defaults to the global ``epsilon``.

    Returns:
        The chosen action index.
    """
    if eps is None:
        eps = epsilon
    if random.uniform(0, 1) < eps:
        return env.action_space.sample()  # explore

    q_row = Q[state, :]
    best_actions = np.flatnonzero(q_row == q_row.max())
    if best_actions.size == 0:
        # Only possible if the row contains NaN; fall back to argmax's choice.
        return int(np.argmax(q_row))
    return int(random.choice(best_actions))  # exploit, random tie-break

def policy_eval(policy, n_sims=None, eval_env=None):
    """Estimate a deterministic policy's mean return with Monte Carlo rollouts.

    Args:
        policy: indexable by state, giving the action to take in that state
            (its length should be the number of states).
        n_sims: number of evaluation episodes. Defaults to the global
            ``NUMBER_OF_EVAL_SIMS``.
        eval_env: environment to roll out in. Defaults to the global ``env``.
            This function calls ``reset()`` and ``step()`` on it, so pass a
            separate instance when the default env is mid-episode elsewhere.

    Returns:
        Mean undiscounted return over the ``n_sims`` episodes.

    Raises:
        ValueError: if ``n_sims`` is smaller than 1. The mean would otherwise be NaN.

    Each rollout runs until the env reports ``done``. This relies on the env
    terminating on its own (presumably via gym.make's time-limit wrapper); an env
    that never terminates would hang here.
    """
    if n_sims is None:
        n_sims = NUMBER_OF_EVAL_SIMS
    if eval_env is None:
        eval_env = env
    if n_sims < 1:
        raise ValueError("n_sims must be at least 1, got {}".format(n_sims))

    returns = []
    for _ in range(n_sims):
        state = eval_env.reset()
        run_reward = 0
        is_done = False
        while not is_done:
            state, reward, is_done, _info = eval_env.step(policy[state])
            run_reward += reward
        returns.append(run_reward)

    return np.mean(returns)

# Live learning curve, extended point by point by updateLine():
# x = total training steps, y = mean return of the greedy policy over the evaluation rollouts.
fig_curve, ax_curve = plt.subplots()
ax_curve.set_xlabel("total training steps")
ax_curve.set_ylabel("mean evaluation return")
ax_curve.set_title("Greedy-policy evaluation during SARSA(lambda) training")
hl, = ax_curve.plot(np.array([]), np.array([]))

import time
from IPython import display

def updateLine(x,y,print_to_screen=False):
    """Append the point (x, y) to the learning-curve line ``hl`` and optionally redraw it.

    Args:
        x: x-coordinate of the new point (Sarsa_lambda passes the total step count).
        y: y-coordinate of the new point (Sarsa_lambda passes the evaluation score).
        print_to_screen: if True, rescale the axes, replace the previous cell output
            with the updated figure, and show the current hyper-parameters.
    """
    hl.set_xdata(np.append(hl.get_xdata(), x))
    hl.set_ydata(np.append(hl.get_ydata(), y))
    if not print_to_screen:
        return

    # Rescale the existing line's axes rather than re-plotting the history. A fresh
    # plt.plot per refresh adds a new Line2D each time, so artists pile up and every
    # redraw gets slower.
    ax = hl.axes
    ax.relim()
    ax.autoscale_view()

    display.clear_output(wait=True)
    display.display(hl.figure)
    display.display("x:{0} y:{1:.3f} alpha:{2:.3f} lambda:{3:.3f} starting epsilon:{4:.3f} current epsilon:{5:.3f}"
    .format(x,y,alpha,_lambda,epsilon_max,epsilon))
    time.sleep(0.15)
        
def Sarsa_lambda(episodes=episodes, max_steps=max_steps, is_decay=True,
                 eval_every=1000, max_total_steps=1e6):
    """Train a tabular SARSA(lambda) agent with accumulating eligibility traces.

    Reads the module-level ``env``, ``alpha``, ``gamma`` and ``_lambda``. Through
    ``eps_greedy_policy`` and ``tune_params2`` it also reads and updates the global
    ``epsilon``.

    Args:
        episodes: maximum number of training episodes.
        max_steps: per-episode step cap. The env may end an episode sooner.
        is_decay: if True, reset epsilon to the start of the ``tune_params2``
            schedule, then decay it after every episode.
        eval_every: positive int. The greedy policy is evaluated (``policy_eval``)
            and plotted (``updateLine``) whenever the global step counter hits a
            multiple of this value. The evaluation runs at the end of that episode
            and is plotted at the step where it was triggered. If the counter
            crosses several multiples in one episode, only the last is plotted.
        max_total_steps: stop once more than this many env steps have been taken.
            This is checked at the end of each episode.

    Returns:
        (Q, policy): the learned Q table and its greedy policy (argmax over actions
        for each state).
    """
    Q = Initialize_Q()
    total_steps = 0
    updateLine(0, 0)

    if is_decay:
        # Start from the top of the schedule; otherwise a re-run of this cell would
        # begin with whatever epsilon the previous run left in the global.
        tune_params2(0)

    for k in range(episodes):
        # Fresh traces and a fresh episode each time.
        E = Initialize_E()
        state = env.reset()
        action = eps_greedy_policy(state, Q)
        eval_at_step = None

        for _step in range(max_steps):
            # Take action A, observe R and S', then choose A' on-policy.
            new_state, reward, done, _info = env.step(action)
            new_action = eps_greedy_policy(new_state, Q)

            delta_error = reward + gamma * Q[new_state, new_action] - Q[state, action]
            E[state, action] += 1          # accumulating trace
            Q += alpha * delta_error * E   # update every traced (state, action) pair
            E *= gamma * _lambda           # decay all traces

            state, action = new_state, new_action
            total_steps += 1

            if total_steps % eval_every == 0:
                # policy_eval resets and steps the shared `env`. Running it here
                # would hijack the episode in progress, so defer it to the
                # episode's end.
                eval_at_step = total_steps

            if done:
                break

        if eval_at_step is not None:
            policy_evaluate = policy_eval(np.argmax(Q, axis=1))
            updateLine(eval_at_step, policy_evaluate, True)

        if is_decay:
            tune_params2(k)

        if total_steps > max_total_steps:
            break

    return Q, np.argmax(Q, axis=1)  # returns Q and the greedy policy

# Pass the config cell's `is_decay` flag explicitly so toggling it there takes effect.
Q, Policy = Sarsa_lambda(episodes, max_steps, is_decay=is_decay)


In [None]:
# Static copy of the learning curve recorded in `hl` during training, saved to disk.
fig_final, ax_final = plt.subplots()
ax_final.plot(hl.get_xdata(), hl.get_ydata())
ax_final.set_xlabel("total training steps")
ax_final.set_ylabel("mean evaluation return")
ax_final.set_title("alpha: {}, gamma: {}, lambda: {}".format(alpha, gamma, _lambda))
fig_final.savefig("policy_eval_over_steps.png")

In [None]:
# NOTE(review): dead cell. In the visible code `history` is only ever initialised to []
# (never appended to), and `epsilon_origin` is never defined, so re-enabling this as-is
# would plot nothing or raise NameError. The previous cell already saves the learning
# curve to the same filename; candidate for deletion.
# history = np.array(history)

# plt.plot(history[:, 0], history[:, 1], '.') 
# plt.xlabel('steps:')
# plt.ylabel('Cumulative rewards')

# str_epsilon = epsilon
# if is_decay:
#     str_epsilon = str(epsilon_origin)+"*0.99^n"

# title = "epsilon: {}, alpha: {}, lambda: {}".format(str_epsilon, alpha, _lambda)
# plt.title(title)
# plt.savefig("policy_eval_over_steps.png")
# plt.show()

In [None]:
def decode_action(action):
    """Translate an action index into an arrow glyph for printing the policy grid.

    0 -> "<", 1 -> "v", 2 -> ">", 3 -> "^"; any other value gives "Unknown".
    """
    glyphs = ("<", "v", ">", "^")
    for code, glyph in enumerate(glyphs):
        if action == code:
            return glyph
    return "Unknown"

env.reset()
env.render()

# Print the greedy policy laid out like the map. The grid shape is read from the env's
# map instead of being hard-coded per env id, so any FrozenLake map size prints.
# Assumes FrozenLake's row-major state numbering (state = row * n_cols + col), which
# the per-map reshapes this replaces also relied on.
n_rows, n_cols = env.env.desc.shape
print(np.array([decode_action(action) for action in Policy]).reshape((n_rows, n_cols)))