In [1]:
import gym
from gym.core import ObservationWrapper
import numpy as np
import math
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import display

from models.agentRandomSearch import AgentRandomSearch
from models.agentSarsa import AgentSarsa

# ---- configuration -------------------------------------------------------
numEpoch = 10**6          # upper limit on the number of training episodes
numTargetFrame = 200      # an episode lasting this many steps counts as a success
numRecordStep = 10**2     # store the reward for plotting every this many epochs
numSucceed = 3            # number of epochs it reaches the target
agent_choice = 'sarsa'    # choose from 'random_search', 'sarsa', 'qlearning'
alpha = 0.1               # passed to the agent as alpha (presumably the learning rate)
gamma = 0.99              # passed to the agent as gamma (presumably the discount factor)

# number of discretization buckets for each observation component
buckets = [2, 2, 8, 4]
# buckets = [1, 10, 10, 10]

In [2]:
# Maps the continuous observation to a tuple of integer bucket indices, so it can be used
# as a key of a tabular value table.
class DiscretizeStateWrapper(gym.ObservationWrapper):
    """Observation wrapper that discretizes every observation component into buckets.

    Args:
        env: environment to wrap. When bounds are not supplied, the defaults read
            ``env.observation_space.high`` / ``low`` at indices 0 and 2.
        buckets: number of buckets for each observation component
            (default ``[1, 1, 6, 12]``).
        upper_bounds, lower_bounds: per-component range that is mapped onto the buckets.
            When omitted (None or empty), components 0 and 2 use the observation-space
            bounds, component 1 uses +/-0.5 and component 3 uses +/-50 degrees in radians.
            Values outside the range are clamped into the first / last bucket.
    """

    def __init__(self, env, buckets=None, upper_bounds=None, lower_bounds=None):
        # Plain super() so gym.ObservationWrapper's own initialiser is not skipped
        # (the previous super(ObservationWrapper, self) started the lookup *after* it).
        super().__init__(env)
        # None defaults avoid sharing one mutable list between all instances.
        self.buckets = buckets if buckets is not None else [1, 1, 6, 12]
        # len() instead of `== []` so numpy arrays are accepted as bounds too;
        # an explicit empty list still selects the defaults, as before.
        if upper_bounds is None or len(upper_bounds) == 0:
            self.upper_bounds = [env.observation_space.high[0], 0.5,
                                 env.observation_space.high[2], math.radians(50) / 1.]
        else:
            self.upper_bounds = upper_bounds
        if lower_bounds is None or len(lower_bounds) == 0:
            self.lower_bounds = [env.observation_space.low[0], -0.5,
                                 env.observation_space.low[2], -math.radians(50) / 1.]
        else:
            self.lower_bounds = lower_bounds

    def observation(self, obs):
        """Return a tuple with one bucket index in ``[0, buckets[i] - 1]`` per component."""
        discretized = []
        for i in range(len(obs)):
            low, high = self.lower_bounds[i], self.upper_bounds[i]
            # Relative position of obs[i] inside [low, high]. The former
            # `obs[i] + abs(low)` was only correct when `low` is non-positive.
            scaling = (obs[i] - low) / (high - low)
            new_obs = int(round((self.buckets[i] - 1) * scaling))
            # Clamp is required: values outside [low, high] would otherwise yield
            # negative or too-large bucket indices.
            new_obs = min(self.buckets[i] - 1, max(0, new_obs))
            discretized.append(new_obs)
        return tuple(discretized)

# # doesn't do anything
# class RewardWrapper(gym.RewardWrapper):
#     def __init__(self, env):
#         super().__init__(env)
    
#     def reward(self, rew):
#         # modify rew
#         return rew

# # doesn't do anything
# class ActionWrapper(gym.ActionWrapper):
#     def __init__(self, env):
#         super().__init__(env)
    
#     def action(self, act):
#         # modify act
#         return act
    
# CartPole with its continuous observations mapped onto the configured buckets.
env = DiscretizeStateWrapper(gym.make('CartPole-v1'), buckets=buckets)


In [3]:
def check_if_state_exist(self, state, num_actions=None):
    """Ensure ``self.q_table`` has a row for ``state``, initialising it with zeros.

    Args:
        self: object with a dict-like ``q_table`` attribute (this is a module-level
            function, so the agent is passed explicitly).
        state: hashable state key, e.g. the tuple returned by the discretizing wrapper.
        num_actions: length of a newly created row; defaults to the notebook-global
            ``env.action_space.n``.
    """
    # The original tested the undefined name `state3`, which raised NameError on every call.
    if state not in self.q_table:
        if num_actions is None:
            num_actions = env.action_space.n
        self.q_table[state] = np.zeros(num_actions)


In [4]:
# # A naive agent that goes to right or left depending on the tilt angle
# class AgentBasic(object):
#     def get_action(self, states):
#         return 0 if states[2] < 0 else 1

In [5]:
def run_episode(env, agent, max_steps=None):
    """Run one training episode and return the total reward collected.

    Args:
        env: environment whose ``reset()`` returns the initial state.
        agent: object whose ``train(state)`` returns ``(next_state, reward, done)``
            (presumably it steps the environment and updates itself).
        max_steps: maximum number of steps; defaults to the notebook-global
            ``numTargetFrame``.

    Returns:
        The sum of rewards until ``done`` is reported or ``max_steps`` steps elapse.
    """
    if max_steps is None:
        max_steps = numTargetFrame
    state = env.reset()
    total_reward = 0
    for _ in range(max_steps):
        state, reward, done = agent.train(state)
        total_reward += reward
        if done:
            break
    return total_reward

In [None]:
# Train until the agent lasts numTargetFrame steps in numSucceed episodes, or numEpoch
# episodes have been played.
epoch_list = []   # epochs at which a reward sample was stored (plotted below)
reward_list = []  # rewards obtained at those epochs
bestreward = 0
succeed = 0       # episodes so far that lasted the full numTargetFrame steps

if agent_choice == 'random_search':
    agent = AgentRandomSearch(env)
elif agent_choice == 'sarsa':
    agent = AgentSarsa(env, alpha=alpha, gamma=gamma)
elif agent_choice == 'qlearning':
    # AgentQlearning was never imported in this notebook (NameError before).
    # NOTE(review): module path assumed to follow the sibling agents in models/ — confirm.
    from models.agentQlearning import AgentQlearning
    agent = AgentQlearning(env, alpha=alpha, gamma=gamma)
else:
    raise ValueError("Unknown agent_choice: {!r}".format(agent_choice))

for i in range(numEpoch):
    # Reduce exploration as training progresses (0.37 when epoch=100000). This must use the
    # current epoch `i`; using numEpoch froze epsilon at exp(-10) for the whole run.
    agent.set_epsilon(math.exp(-i / 100000))

    reward = run_episode(env, agent)
    if reward > bestreward:
        bestreward = reward

    # Considered solved if the agent lasts for the required number of timesteps. Checked
    # independently of `bestreward`: once the best reward equals the target it can never be
    # exceeded again, so nesting this under `reward > bestreward` capped `succeed` at 1.
    if reward == numTargetFrame:
        succeed += 1
        if succeed == numSucceed:
            reward_list.append(reward)
            epoch_list.append(i)
            break

    if i < 10 or (i < 100 and i % 10 == 0) or (i < 1000 and i % 100 == 0) or i % 1000 == 0:
        print("Running epoch # {}...".format(i))
    if i % numRecordStep == 0:
        reward_list.append(reward)
        epoch_list.append(i)

if succeed == numSucceed:
    print("Finished running and solution found in epoch # {}! =D \n".format(i)) # first epoch starts from label 0
else:
    print("Finished running but solution not found. =\\ \n")
print("The best reward was {} steps.".format(bestreward))

print("#################################")
print("#                               #")
print("#        Done training!!        #")
print("#                               #")
print("#################################")

Running epoch # 0...
Running epoch # 1...
Running epoch # 2...
Running epoch # 3...
Running epoch # 4...
Running epoch # 5...
Running epoch # 6...
Running epoch # 7...
Running epoch # 8...
Running epoch # 9...
Running epoch # 10...
Running epoch # 20...
Running epoch # 30...
Running epoch # 40...
Running epoch # 50...
Running epoch # 60...
Running epoch # 70...
Running epoch # 80...
Running epoch # 90...
Running epoch # 100...
Running epoch # 200...
Running epoch # 300...
Running epoch # 400...
Running epoch # 500...
Running epoch # 600...
Running epoch # 700...
Running epoch # 800...
Running epoch # 900...
Running epoch # 1000...
Running epoch # 2000...
Running epoch # 3000...
Running epoch # 4000...
Running epoch # 5000...
Running epoch # 6000...
Running epoch # 7000...
Running epoch # 8000...
Running epoch # 9000...
Running epoch # 10000...
Running epoch # 11000...
Running epoch # 12000...
Running epoch # 13000...
Running epoch # 14000...
Running epoch # 15000...
Running epoch # 160

In [None]:
# Reward sampled every numRecordStep epochs (plus the final, successful epoch if any).
fig, ax = plt.subplots()
ax.plot(epoch_list, reward_list)
ax.set_xlabel("Epoch")
ax.set_ylabel("Reward")
ax.set_title("Reward vs. epoch ({} agent)".format(agent_choice))
ax.set_ylim([0, numTargetFrame * 1.2])  # 20% headroom above the target episode length
fig.savefig('reward_vs_epoch.png')
# epoch_list / reward_list are intentionally kept: deleting them made this cell fail on re-run.

In [None]:
def show_episode(env, policy=None, max_steps=None):
    """Play one episode and return the rendered RGB frames, including the initial one.

    Args:
        env: environment using the old gym API (``step`` returns a 4-tuple,
            ``render(mode='rgb_array')`` returns a frame).
        policy: object with ``get_action(state)``; defaults to the notebook-global ``agent``.
        max_steps: maximum number of steps; defaults to the notebook-global ``numTargetFrame``.

    Returns:
        List of frames: one for the reset state plus one per step taken.
    """
    # Resolve the global `agent` lazily, only when no explicit policy is supplied.
    actor = agent if policy is None else policy
    if max_steps is None:
        max_steps = numTargetFrame

    state = env.reset()
    frames = [env.render(mode='rgb_array')]
    for _ in range(max_steps):
        action = actor.get_action(state)
        state, _reward, done, _info = env.step(action)
        frames.append(env.render(mode='rgb_array'))
        if done:
            break
    return frames

def display_frames_as_gif(frames, filename_gif=None):
    """Build a matplotlib animation from RGB frames and optionally save it as a GIF.

    Args:
        frames: non-empty sequence of image arrays; the first frame's height and width
            set the figure size.
        filename_gif: if given, the animation is written to this path with the
            Pillow writer at 10 fps.

    Returns:
        The ``FuncAnimation`` object. Keep a reference to it (or display it) so the
        animation is not garbage-collected.

    Raises:
        ValueError: if ``frames`` is empty.
    """
    if len(frames) == 0:
        raise ValueError("display_frames_as_gif needs at least one frame")

    height, width = frames[0].shape[:2]
    # figsize in inches = pixels / 72; at dpi=36 the figure is half the frame resolution.
    fig = plt.figure(figsize=(width / 72.0, height / 72.0), dpi=36)
    ax = fig.gca()
    patch = ax.imshow(frames[0])
    ax.axis('off')

    def animate(i):
        patch.set_data(frames[i])
        return (patch,)

    # interval (ms) governs on-screen playback; the saved GIF uses fps=10 instead.
    anim = animation.FuncAnimation(fig, animate, frames=len(frames), interval=50, repeat=False)
    if filename_gif:
        print("Saving animation...")
        anim.save(filename_gif, writer='pillow', fps=10)
        print("Animation saved as gif at: {}".format(filename_gif))
    return anim
        
# Record one episode with the trained agent and save it as a GIF named after the agent type
# (it was always "random_search_play.gif", even when another agent was trained).
frames = show_episode(env)
display_frames_as_gif(frames, filename_gif="{}_play.gif".format(agent_choice))
env.close()
print("###############################")
print("#                             #")
print("#        Done saving!!        #")
print("#                             #")
print("###############################")