In [1]:
# !pip3 install box2d-py
# !pip3 install gym[Box_2D]
# import gym
# env = gym.make("LunarLander-v2")

In [1]:
import gym
import torch 
import collections
import os
import numpy as np
from utils import *
from exp_replay_memory import ReplayMemory




In [27]:
def sarsa_lander(env, n_episodes, gamma, lr, min_eps, print_freq=500, render_freq=500):
    q_states = collections.defaultdict(float)   # note that the first insertion of a key initializes its value to 0.0
    return_per_ep = [0.0]
    epsilon = 1.0
    num_actions = env.action_space.n
    
    for i in range(n_episodes):
        t = 0
        if (i + 1) % render_freq == 0:
            render = True
        else:
            render = False

        # Initial episode state: S
        curr_state = discretize_state(env.reset())
        # Choose A from S using policy π
        action = epsilon_greedy(q_states, curr_state, epsilon, num_actions)
        
        while True:
#             if render:
#                 env.render()
                

            # Create (S, A) pair
            
            qstate = curr_state + (action, )
            print( "qstate  " ,qstate )
            print(action )

            # Take action A, earn immediate reward R and land into next state S'
            # S --> A --> R --> S'
            
            observation, reward, done, _ = env.step(action)
            print("obs    " , observation)
            next_state = discretize_state(observation)
            print("next_state", next_state)

            # Next State: S'
            # Choose A' from S' using policy π
            next_action = epsilon_greedy(q_states, next_state, epsilon, num_actions)

            # create (S', A') pair
            new_qstate = next_state + (next_action, )

            ###################################################################
            # Policy evaluation step
            if not done:
                q_states[qstate] += lr * (reward + gamma * q_states[new_qstate] - q_states[qstate]) # (S', A') non terminal state
            else:
                q_states[qstate] += lr * (reward - q_states[qstate])    # (S', A') terminal state
            ###################################################################

            return_per_ep[-1] += reward

            if done:
                if (i + 1) % print_freq == 0:
                    print("\nEpisode finished after {} timesteps".format(t + 1))
                    print("Episode {}: Total Return = {}".format(i + 1, return_per_ep[-1]))
                    print("Total keys in q_states dictionary = {}".format(len(q_states)))

                if (i + 1) % 100 == 0:
                    mean_100ep_reward = round(np.mean(return_per_ep[-101:-1]), 1)
                    print("Last 100 episodes mean reward: {}".format(mean_100ep_reward))

                epsilon = decay_epsilon(epsilon, min_eps)
                return_per_ep.append(0.0)

                break

            curr_state = next_state
            action = next_action
            t += 1
            
    env.close()

    return return_per_ep


In [23]:
n_episodes= 1
lr = 0.1
gamma = 0.99
final_eps = 0.01
environment = gym.make("LunarLander-v2")

In [24]:
print("\nTraining Sarsa lander with arguments num_episodes={}, step-size={}, gamma={}, final_epsilon={} ..."\
                            .format(n_episodes, lr, gamma, final_eps))
total_rewards  = sarsa_lander(environment, n_episodes, gamma, lr, final_eps)
print("Done!")
environment = gym.make("LunarLander-v2")



Training Sarsa lander with arguments num_episodes=1, step-size=0.1, gamma=0.99, final_epsilon=0.01 ...
qstate   (0, 2, -2, -1, 0, 0, 0, 0, 0)
0
obs     [-0.00833111  1.4040464  -0.42135143 -0.16559404  0.00955594  0.0944603
  0.          0.        ]
next_state (0, 2, -2, -1, 0, 0, 0, 0)
qstate   (0, 2, -2, -1, 0, 0, 0, 0, 2)
2
obs     [-0.01256723  1.4002732  -0.4280549  -0.16773167  0.01395565  0.08800233
  0.          0.        ]
next_state (0, 2, -2, -1, 0, 0, 0, 0)
qstate   (0, 2, -2, -1, 0, 0, 0, 0, 3)
3
obs     [-0.01671629  1.3959023  -0.417149   -0.19428948  0.01616507  0.04419252
  0.          0.        ]
next_state (0, 2, -2, -1, 0, 0, 0, 0)
qstate   (0, 2, -2, -1, 0, 0, 0, 0, 3)
3
obs     [-2.0779228e-02  1.3909341e+00 -4.0633145e-01 -2.2080697e-01
  1.6203981e-02  7.7822659e-04  0.0000000e+00  0.0000000e+00]
next_state (0, 2, -2, -2, 0, 0, 0, 0)
qstate   (0, 2, -2, -2, 0, 0, 0, 0, 2)
2
obs     [-0.02490234  1.3865678  -0.4121232  -0.19404693  0.01601725 -0.00373467
  0.   