In [None]:
import gym
import random
import tensorflow as tf
import numpy as np
from collections import deque
import matplotlib.pyplot as plt
import keras
import cv2
from time import gmtime, strftime
from Agent import *
%matplotlib inline

In [None]:
# Create the Enduro environment and report its observation/action spaces.
env = gym.make('Enduro-v0')
# Seed BEFORE the first reset: seeding afterwards leaves that initial reset
# drawn from an unseeded RNG, which defeats the point of fixing the seed.
env.seed(0)
env.reset()
print('State shape: ', env.observation_space.shape)
print('Number of actions: ', env.action_space.n)

In [None]:
# Size of the environment's discrete action space.
actions = env.action_space.n
# Build the agent (QAgent comes from the wildcard `Agent` import).
# NOTE(review): presumably the argument is the number of outputs of the
# Q-network — confirm against Agent.QAgent.__init__.
agent = QAgent(actions)

In [None]:
def preprocess(observation):
    """Convert a raw colour game frame into a binary 84x84x1 image.

    The frame is resized to 84 columns x 110 rows, converted to grayscale,
    cropped to rows 26..109 and thresholded so that every pixel brighter
    than 1 becomes 255 and every other pixel becomes 0.

    Args:
        observation: colour frame as an array accepted by ``cv2.resize``.
            Gym Atari environments deliver frames in RGB channel order.

    Returns:
        numpy.ndarray of shape ``(84, 84, 1)`` whose values are 0 or 255.
    """
    # cv2.resize takes dsize as (width, height): result is 110 rows x 84 cols.
    resized = cv2.resize(observation, (84, 110))
    # Gym frames are RGB, not OpenCV's native BGR. COLOR_BGR2GRAY would apply
    # the red luminance weight to blue and vice versa.
    gray = cv2.cvtColor(resized, cv2.COLOR_RGB2GRAY)
    # Drop the first 26 rows so an 84x84 square remains.
    # NOTE(review): the original comment says these rows contain only the
    # score; confirm that this holds for Enduro's screen layout.
    cropped = gray[26:110, :]
    # THRESH_BINARY: pixel > thresh (1) -> maxval (255, white), else 0 (black).
    _, binary = cv2.threshold(cropped, 1, 255, cv2.THRESH_BINARY)
    return np.reshape(binary, (84, 84, 1))


# Take a single step in a fresh episode and show the frame before and after
# preprocessing, as a visual sanity check of `preprocess`.
action0 = 0  # do nothing
env.reset()
observation0, reward0, terminal, info = env.step(action0)

print(f"Before processing: {np.shape(observation0)}")
plt.imshow(np.asarray(observation0))
plt.show()

# From here on `observation0` holds the preprocessed frame.
observation0 = preprocess(observation0)
print(f"After processing: {np.shape(observation0)}")
# imshow cannot draw a trailing length-1 channel axis, so squeeze it away.
plt.imshow(np.squeeze(observation0))
plt.show()

In [None]:
# Hand the first preprocessed frame to the agent as its starting state.
agent.setInitState(observation0)
# Strip length-1 axes from the state stored by setInitState.
# NOTE(review): presumably setInitState stacks the (84, 84, 1) frame and
# leaves a singleton axis behind — confirm against Agent.setInitState.
agent.currentState = np.squeeze(agent.currentState)

In [None]:
# Training loop: runs until the kernel is interrupted.
while True:
    # NOTE(review): `action` looks like a score/one-hot vector over actions
    # (argmax is taken below) — confirm against Agent.getAction.
    action = agent.getAction()
    actionmax = np.argmax(np.array(action))

    nextObservation, reward, terminal, info = env.step(actionmax)

    # When the episode has ended, start a new one and feed its first frame on.
    nextObservation = preprocess(env.reset() if terminal else nextObservation)
    agent.setPerception(nextObservation, action, reward, terminal)

In [None]:
from JSAnimation.IPython_display import display_animation
from ipywidgets import widgets
from IPython.display import display
from matplotlib import animation

def display_frames_as_gif(frames, filename_gif = None):
    """
    Render a sequence of frames as a looping animation with playback controls.

    frames       -- indexable sequence of image arrays; the first frame's
                    shape determines the figure size (one pixel per dot at 72 dpi)
    filename_gif -- optional output path; when truthy the animation is also
                    saved there with the 'imagemagick' writer at 20 fps
    """
    first_frame = frames[0]
    height, width = first_frame.shape[0], first_frame.shape[1]

    fig = plt.figure(figsize=(width / 72.0, height / 72.0), dpi = 72)
    image = plt.imshow(first_frame)
    plt.axis('off')

    def _show_frame(index):
        # Swap the displayed pixels in place instead of redrawing the axes.
        image.set_data(frames[index])

    anim = animation.FuncAnimation(fig, _show_frame, frames = len(frames), interval=50)
    if filename_gif:
        anim.save(filename_gif, writer = 'imagemagick', fps=20)
    display(display_animation(anim, default_mode='loop'))

    
# Roll the agent out for a fixed number of steps, record the raw frames and
# render them as an animated GIF, then plot the agent's cost curve.
EVAL_EPSILON = 0.2   # value temporarily assigned to agent.epsilon for the rollout
EVAL_STEPS = 150     # number of environment steps to record

frameshistory = []
observation = env.reset()
backupepsilon = agent.epsilon

# NOTE(review): presumably `epsilon` is the agent's exploration rate —
# confirm against Agent.QAgent.
agent.epsilon = EVAL_EPSILON
try:
    for _ in range(EVAL_STEPS):
        action = agent.getAction()
        actionmax = np.argmax(np.array(action))

        nextObservation, reward, terminal, info = env.step(actionmax)
        if terminal:
            # Episode finished: continue recording from a fresh episode.
            nextObservation = env.reset()
        # Keep the raw colour frame for the GIF; the agent gets the processed one.
        frameshistory.append(nextObservation)
        nextObservation = preprocess(nextObservation)
        agent.setPerception(nextObservation, action, reward, terminal)
finally:
    # Always restore the original epsilon, even if the rollout raises or the
    # kernel is interrupted, so later training is not left on the eval value.
    agent.epsilon = backupepsilon

display_frames_as_gif(frameshistory, 'playing_enduro.gif')

agent.plot_cost()