In [None]:
import gym
import numpy as np
import math
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import display

env = gym.make('CartPole-v1')

# configurations
SEED = 0
numEpoch = 10000  # maximum number of random-search trials
numTargetFrame = 400  # steps an episode must last to count as solved (also the per-episode step cap)

# Seed numpy (used for the agents' random weights / exploration) and the env so that
# Restart & Run All reproduces the same search. env.seed() exists in the pre-0.26 gym
# API this notebook targets (4-tuple env.step, render(mode=...)).
np.random.seed(SEED)
env.seed(SEED)
env.action_space.seed(SEED)

# Discretization grid used by discretize_state: number of buckets per observation
# dimension and the [lower, upper] range mapped onto those buckets. Dimensions 0 and 2
# use the env's own observation bounds; 1 and 3 use hand-picked ranges (±0.5, ±50°).
buckets = [1, 1, 6, 12]
upper_bounds = [env.observation_space.high[0], 0.5, env.observation_space.high[2], math.radians(50)]
lower_bounds = [env.observation_space.low[0], -0.5, env.observation_space.low[2], -math.radians(50)]

In [None]:
# A naive agent that goes to right or left depending on the tilt angle.
class AgentBasic(object):
    """Hand-coded policy: push left (0) when ``states[2]`` is negative, else right (1).

    Provides the same ``train`` / ``get_params`` methods as ``AgentRandomSearch`` so it
    can be passed to ``run_episode``.

    Args:
        env: optional environment used by ``train``; its ``step(action)`` must return a
            4-tuple ``(obs, reward, done, info)``. Not needed for ``get_action``.
    """

    def __init__(self, env=None):
        self.env = env

    def get_action(self, states):
        """Return 0 if ``states[2]`` (the tilt angle) is negative, else 1."""
        return 0 if states[2] < 0 else 1

    def get_params(self):
        """This policy has no tunable parameters, so there is nothing to report."""
        return None

    def train(self, state):
        """Act once in ``self.env`` and return ``(next_state, reward, terminal)``.

        Raises:
            ValueError: if the agent was constructed without an env.
        """
        if self.env is None:
            raise ValueError("AgentBasic.train needs an env; construct it with AgentBasic(env)")
        action = self.get_action(state)
        state_next, reward, terminal, _info = self.env.step(action)
        return state_next, reward, terminal

# An agent that uses random search: a fixed linear policy whose weights are either given
# or drawn once at construction; train() only acts, it never updates the weights.
class AgentRandomSearch(object):
    """Linear-threshold policy: action 0 if ``parameters . state < 0``, else 1.

    Args:
        env: environment whose ``step(action)`` returns a 4-tuple
            ``(obs, reward, done, info)``.
        parameters: optional weight vector (list or numpy array). When omitted, four
            weights are drawn uniformly from [-1, 1) via ``np.random.rand``.
    """

    def __init__(self, env, parameters=None):
        self.env = env
        # `is None`, not `== None`: for a numpy array `==` compares element-wise and
        # using the result in `if` raises "truth value of an array is ambiguous".
        if parameters is None:
            self.parameters = np.random.rand(4) * 2 - 1
        else:
            self.parameters = parameters

    def get_params(self):
        """Return the current weight vector."""
        return self.parameters

    def set_params(self, parameters):
        """Replace the weight vector."""
        self.parameters = parameters

    def get_action(self, states):
        """Return 0 if the weighted sum of ``states`` is negative, else 1."""
        return 0 if np.matmul(self.parameters, states) < 0 else 1

    def train(self, state):
        """Act once in ``self.env`` and return ``(next_state, reward, terminal)``.

        No learning happens here; random search only evaluates fixed weights.
        """
        action = self.get_action(state)
        state_next, reward, terminal, _info = self.env.step(action)
        return state_next, reward, terminal
    
class AgentSarsa(object):
    """Tabular SARSA agent backed by a numpy Q-table.

    Args:
        env: environment whose ``step(action)`` returns a 4-tuple
            ``(obs, reward, done, info)``.
        state_space: discrete state space, either with ``.n`` (a state is one integer)
            or with ``.nvec`` (a state is a tuple of integers, one per dimension).
        action_space: discrete action space exposing ``.n`` and ``.sample()``.
        alpha: learning rate.
        gamma: discount factor.
        epsilon: probability of acting greedily; otherwise a random action is sampled.
        state_mapper: optional callable turning a raw env observation into a Q-table
            state, e.g. ``lambda o: discretize_state(env, o)``. Defaults to identity,
            i.e. the env already emits discrete states.
    """

    def __init__(self, env, state_space, action_space, alpha=0.01, gamma=0.9, epsilon=0.9,
                 state_mapper=None):
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        if hasattr(state_space, 'nvec'):
            state_shape = tuple(int(n) for n in np.ravel(state_space.nvec))
        else:
            state_shape = (int(state_space.n),)
        # One axis per state dimension, plus a final axis of action values.
        self.q_table = np.zeros(state_shape + (int(action_space.n),))
        self.actions = action_space
        self.env = env
        self.state_mapper = state_mapper
        # (state key, action) picked as A' in the latest update. It is reused as the
        # next executed action so the update bootstraps from the action actually taken.
        self._pending = None

    @staticmethod
    def _key(state):
        """Turn an int or a sequence of ints into a tuple index into the Q-table."""
        return tuple(int(s) for s in np.ravel(state))

    def _to_state(self, observation):
        """Map a raw env observation to a Q-table state (identity if no mapper)."""
        return observation if self.state_mapper is None else self.state_mapper(observation)

    def get_action(self, state):
        """Epsilon-greedy action for an already-mapped ``state``; greedy ties broken at random."""
        if np.random.rand() < self.epsilon:
            q_values = self.q_table[self._key(state)]
            best = np.flatnonzero(q_values == q_values.max())
            return int(np.random.choice(best))
        return self.actions.sample()

    def get_params(self):
        """Return the Q-table (reported by ``run_episode`` as the agent's parameters)."""
        return self.q_table

    def update_q_table(self, state, action, reward, state_next, action_next, terminal):
        """Apply one SARSA update ``Q(s,a) += alpha * (target - Q(s,a))``.

        ``target`` is ``reward`` on a terminal transition and
        ``reward + gamma * Q(s', a')`` otherwise. States are already mapped.
        """
        index = self._key(state) + (int(action),)
        q_value_predict = self.q_table[index]
        if terminal:
            q_value_real = reward
        else:
            next_index = self._key(state_next) + (int(action_next),)
            q_value_real = reward + self.gamma * self.q_table[next_index]
        self.q_table[index] += self.alpha * (q_value_real - q_value_predict)

    def train(self, state):
        """Take one step, apply a SARSA update, and return ``(next_observation, reward, terminal)``.

        ``state`` and the returned observation use the env's own representation;
        ``state_mapper`` converts them to Q-table states internally.
        """
        s = self._to_state(state)
        key = self._key(s)
        if self._pending is not None and self._pending[0] == key:
            action = self._pending[1]
        else:
            action = self.get_action(s)
        observation_next, reward, terminal, _info = self.env.step(action)
        s_next = self._to_state(observation_next)
        action_next = self.get_action(s_next)
        self.update_q_table(s, action, reward, s_next, action_next, terminal)
        # Carry A' forward (on-policy SARSA); forget it once the episode ends.
        self._pending = None if terminal else (self._key(s_next), action_next)
        return observation_next, reward, terminal

In [None]:
def discretize_state(env, obs, n_buckets=None, lows=None, highs=None):
    """Map a continuous observation to a tuple of integer bucket indices.

    Each component ``obs[i]`` is scaled linearly from ``[lows[i], highs[i]]`` onto
    ``0 .. n_buckets[i] - 1``. It is then rounded with Python's ``round`` (ties go to the
    even index) and clamped into that range.

    Args:
        env: unused; kept so existing callers keep working.
        obs: sequence of numbers, one per dimension.
        n_buckets: per-dimension bucket counts (default: module-level ``buckets``).
        lows: per-dimension lower bounds (default: module-level ``lower_bounds``).
        highs: per-dimension upper bounds (default: module-level ``upper_bounds``).

    Returns:
        tuple of ints, one per component of ``obs``.
    """
    n_buckets = buckets if n_buckets is None else n_buckets
    lows = lower_bounds if lows is None else lows
    highs = upper_bounds if highs is None else highs

    discretized = []
    for i in range(len(obs)):
        # Offset from the lower bound. `obs + abs(low)` would only be right for low <= 0.
        scaling = (obs[i] - lows[i]) / (highs[i] - lows[i])
        index = int(round((n_buckets[i] - 1) * scaling))
        # Clamping is needed: observations are not guaranteed to stay inside the
        # configured [low, high] range, which would give out-of-range indices.
        discretized.append(min(n_buckets[i] - 1, max(0, index)))
    return tuple(discretized)

def run_episode(env, agent, max_steps=None):
    """Run one episode and return ``(total_reward, agent.get_params())``.

    The agent is first given the observation from ``env.reset()``. On every later step
    it gets the ``next_state`` returned by its own ``train`` call, so all steps see the
    same representation. This is the raw observation, which is also what
    ``show_episode`` feeds the random-search policy on replay. Agents that need
    discrete states should map observations themselves.

    Args:
        env: environment to ``reset()``; ``agent`` is expected to step this same env.
        agent: object with ``train(state) -> (next_state, reward, done)`` and
            ``get_params()``.
        max_steps: maximum number of steps (default: module-level ``numTargetFrame``).

    Returns:
        ``(total_reward, params)``.
    """
    if max_steps is None:
        max_steps = numTargetFrame
    state = env.reset()
    totalreward = 0
    for _ in range(max_steps):
        state, reward, done = agent.train(state)
        totalreward += reward
        if done:
            break
    return totalreward, agent.get_params()

In [None]:
# Random search: draw fresh linear-policy weights every epoch, keep the best-scoring set,
# and stop as soon as one episode survives the full numTargetFrame steps.
bestparams = None
bestreward = 0
succeed = 0
for i in range(numEpoch):
    agent = AgentRandomSearch(env)
    reward, parameters = run_episode(env, agent)
    if reward > bestreward:
        bestreward, bestparams = reward, parameters
        if reward == numTargetFrame:  # solved: the episode lasted the whole target horizon
            succeed = 1
            break
    # Progress log: every epoch below 10, every 10th below 100, every 100th below 1000,
    # and every 1000th after that.
    if i < 10 or (i < 100 and i % 10 == 0) or (i < 1000 and i % 100 == 0) or i % 1000 == 0:
        print("Running epoch # {}...".format(i))

if succeed == 1:
    # epochs are numbered from 0
    print("Finished running and solution found in epoch # {}! =D \n".format(i))
else:
    print("Finished running but solution not found. =\\ \n")

print("\n".join([
    "#################################",
    "#                               #",
    "#        Done training!!        #",
    "#                               #",
    "#################################",
]))

In [None]:
def show_episode(env, parameters):
    """Replay the linear threshold policy given by ``parameters`` and capture RGB frames.

    Resets ``env``, then renders one frame after the reset and one after every step.
    It stops after ``numTargetFrame`` steps or when the episode reports done.

    Returns:
        list of frames from ``env.render(mode='rgb_array')``.
    """
    obs = env.reset()
    captured = [env.render(mode='rgb_array')]
    step, done = 0, False
    while step < numTargetFrame and not done:
        score = np.matmul(parameters, obs)
        action = 0 if score < 0 else 1
        obs, _reward, done, _info = env.step(action)
        captured.append(env.render(mode='rgb_array'))
        step += 1
    return captured

def display_frames_as_gif(frames, filename_gif = None):
    """Build a matplotlib animation from ``frames`` and optionally save it as a GIF.

    The figure is sized at 72 px per inch of frame and drawn at ``dpi=36``, so output
    frames are half the input resolution. Nothing is displayed by this function itself.
    To view the animation inline, use the returned object (e.g. ``anim.to_jshtml()``).

    Args:
        frames: non-empty sequence of images accepted by ``plt.imshow`` (e.g. the
            RGB arrays returned by ``show_episode``).
        filename_gif: if given, path the animation is written to with the ``pillow``
            writer at 10 fps.

    Returns:
        The ``FuncAnimation``. Keep a reference to it, or matplotlib may discard the animation.

    Raises:
        ValueError: if ``frames`` is empty.
    """
    if len(frames) == 0:
        raise ValueError("display_frames_as_gif needs at least one frame")
    fig = plt.figure(figsize=(frames[0].shape[1] / 72.0, frames[0].shape[0] / 72.0), dpi=36)
    patch = plt.imshow(frames[0])
    plt.axis('off')

    def animate(i):
        patch.set_data(frames[i])
        return (patch,)

    anim = animation.FuncAnimation(fig, animate, frames=len(frames), interval=50, repeat=False)
    if filename_gif:
        print("Saving animation...")
        anim.save(filename_gif, writer='pillow', fps=10)
        print("Animation saved as gif at: {}".format(filename_gif))
    return anim
        
# Replay the best random-search weights, write the rollout to a GIF, then release the env.
frames = show_episode(env, bestparams)
display_frames_as_gif(frames, filename_gif="random_search_play.gif")
env.close()

print("\n".join([
    "###############################",
    "#                             #",
    "#        Done saving!!        #",
    "#                             #",
    "###############################",
]))