In [9]:
import numpy as np
from collections import deque
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import Adam

# The estimator is of the form (s,a) -> scalar value

class FunctionEstimator:
    """Q-value estimator of the form (state, action) -> scalar value.

    The flattened state and the integer action are concatenated into a single
    input row and scored by a small fully connected Keras network with one
    linear output.
    """

    def __init__(self, n_actions, state_dim=80 * 80):
        """
        Args:
            n_actions: number of discrete actions; `predict` scores actions
                0 .. n_actions - 1.
            state_dim: length of the flattened state vector fed to the network
                (the action adds one more input). Defaults to 80*80, the size
                the network was originally hard-coded for.
        """
        self.n_actions = n_actions
        self.state_dim = state_dim
        self.model = self._build_model()
        # Experience-replay memory of (state, action, td_target) tuples; the
        # oldest entries are dropped once 2000 are stored.
        self.memory_buffer = deque(maxlen=2000)
        # Mini-batch sampled by the most recent call to `replay`.
        self.update_buffer = []

    def _concat(self, state, action):
        """Return a single (1, -1) input row: the state values followed by the action."""
        return np.hstack([state, action]).reshape(1, -1)

    def _build_model(self):
        """Build and compile the MLP (state_dim + 1) -> 24 -> 24 -> 1 with MSE loss and Adam."""
        model = Sequential()
        model.add(Dense(24, input_dim=self.state_dim + 1, activation='tanh'))
        model.add(Dense(24, activation='tanh'))
        model.add(Dense(1, activation='linear'))
        # `learning_rate` is the keyword current Keras accepts; `lr` is deprecated/removed.
        model.compile(loss='mse', optimizer=Adam(learning_rate=0.001))
        return model

    def update(self, buffer):
        """Fit the network towards the TD targets stored in `buffer`.

        Args:
            buffer: sequence of (state, action, td_target) tuples.

        Samples are presented one at a time in buffer order (batch_size=1,
        shuffle=False), i.e. one gradient step per sample, but in a single
        `fit` call instead of one call per sample. An empty buffer is a no-op.
        """
        if len(buffer) == 0:
            return
        inputs = np.vstack([self._concat(state, action) for state, action, _ in buffer])
        # Each sample is fit to its own target (previously a global name leaked in here).
        targets = np.array([target for _, _, target in buffer], dtype=float).reshape(-1, 1)
        self.model.fit(inputs, targets, batch_size=1, epochs=1, shuffle=False, verbose=0)

    def predict(self, state):
        """Estimate Q(state, a) for every action a in 0 .. n_actions - 1.

        Returns:
            list of length n_actions whose element a is a (1, 1) array holding
            the estimate for action a (the same structure as before; callers
            typically `np.ravel` / `np.max` / `np.argmax` it).
        """
        # Score all actions with one batched forward pass.
        inputs = np.vstack([self._concat(state, a) for a in range(self.n_actions)])
        q_values = np.asarray(self.model.predict(inputs, verbose=0)).reshape(-1)
        return [q.reshape(1, 1) for q in q_values]

    def remember(self, state, action, td_target):
        """Append one (state, action, td_target) transition to replay memory."""
        self.memory_buffer.append((state, action, td_target))

    def replay(self, batch_size):
        """Experience replay: fit on a random mini-batch drawn from memory.

        Draws min(len(memory), batch_size) stored transitions without
        replacement, stores them in `self.update_buffer` and trains on them.
        The mini-batch is rebuilt on every call (not accumulated), so each
        replay trains only on freshly sampled transitions.
        """
        n_stored = len(self.memory_buffer)
        sample_size = min(n_stored, batch_size)
        if sample_size <= 0:
            self.update_buffer = []
            return
        sampled_idxs = np.random.choice(n_stored, size=sample_size, replace=False)
        self.update_buffer = [self.memory_buffer[ix] for ix in sampled_idxs]
        self.update(self.update_buffer)


## Auxiliary function for the policy
def make_policy(estimator, n_actions, ep):
    """Build a noisy-greedy policy over the estimator's Q-values.

    The returned callable picks argmax_a (Q(state, a) + eps_a), where each eps_a
    is a standard-normal draw scaled by 1 / (ep + 1), so exploration shrinks as
    the episode counter grows.
    """
    noise_scale = 1. / (ep + 1)

    def policy_fn(state):
        q_values = np.ravel(estimator.predict(state))
        perturbation = np.random.randn(1, n_actions) * noise_scale
        return np.argmax(q_values + np.ravel(perturbation))

    return policy_fn



In [None]:
def preprocess(I):
    """Preprocess a 210x160x3 uint8 Pong frame into a 6400 (80x80) 1D float vector.

    Crops rows 35..194, keeps every second row and column of channel 0, zeroes
    the two background values (144 and 109) and sets every other non-zero pixel
    (paddles, ball) to 1.

    Args:
        I: frame array indexed as [row, col, channel]; it is NOT modified.

    Returns:
        1D float64 array of 0.0 / 1.0 values (length 6400 for a 210x160x3 frame).
    """
    # Crop + downsample by 2 in one slice; np.array copies so the caller's frame
    # is not mutated through a view by the masking below.
    frame = np.array(I[35:195:2, ::2, 0])
    frame[frame == 144] = 0  # erase background (background type 1)
    frame[frame == 109] = 0  # erase background (background type 2)
    frame[frame != 0] = 1  # everything else (paddles, ball) just set to 1
    # np.float64 instead of the removed np.float alias (NumPy >= 1.24).
    return frame.astype(np.float64).ravel()

In [None]:
import os

import gym

gym.logger.set_level(40)  # silence gym warnings; only errors (level 40) are shown

env = gym.make('Pong-v0')
n_episodes = 1000
gamma = 1
estimator = FunctionEstimator(env.action_space.n)
# Rewards of the most recent (up to) 100 episodes, for the running average.
score = deque(maxlen=100)

# save_weights below writes into ./out, which may not exist yet.
os.makedirs('./out', exist_ok=True)

for ep in range(n_episodes):
    obs = env.reset()
    state = preprocess(obs)
    done = False
    # Exploration noise shrinks as 1 / (ep + 1) (see make_policy).
    policy = make_policy(estimator, env.action_space.n, ep)
    ep_reward = 0
    while not done:
        action = policy(state)
        new_obs, reward, done, _ = env.step(action)
        new_state = preprocess(new_obs)
        ep_reward += reward
        # One-step Q-learning target: r + gamma * max_a' Q(s', a'); terminal
        # transitions do not bootstrap. np.max (the best value), not np.argmax
        # (the index of the best action).
        if done:
            td_target = reward
        else:
            td_target = reward + gamma * np.max(estimator.predict(new_state))
        # NOTE(review): the target is frozen at collection time, so replay()
        # later trains on targets computed with an older network.
        estimator.remember(state, action, td_target)
        # Advance both the raw frame and its preprocessed state to the next step.
        obs, state = new_obs, new_state
    estimator.replay(32)
    # Show stats
    score.append(ep_reward)
    if (ep + 1) % 100 == 0:
        estimator.model.save_weights('./out/pong-keras-{}.h5'.format(ep + 1))
        print("Number of episodes: {} . Average 100-episode reward: {}".format(ep+1, np.mean(score)))



In [None]:
import matplotlib.pyplot as plt

# Display the raw RGB frame currently held in `obs` (set by the training cell).
fig, ax = plt.subplots()
ax.imshow(obs)
ax.set_title('Pong observation (obs)')
ax.axis('off')
plt.show()

In [None]:
# To plot pretty figures and animations
%matplotlib nbagg
import matplotlib
import matplotlib.animation as animation
import matplotlib.pyplot as plt
# Larger axis-label and tick fonts for readability.
plt.rcParams.update({
    'axes.labelsize': 14,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
})

def update_scene(num, frames, patch):
    """Animation callback: show frame number `num` of `frames` on `patch` and return it."""
    current_frame = frames[num]
    patch.set_data(current_frame)
    return patch

def plot_animation(frames, repeat=False, interval=40):
    """Animate a sequence of images with matplotlib.

    Args:
        frames: indexable sequence of images; frames[0] initialises the plot.
        repeat: whether the animation loops.
        interval: delay between frames in milliseconds.

    Returns:
        The FuncAnimation object (keep a reference so it is not garbage-collected).
    """
    plt.close()  # or else nbagg sometimes plots in the previous cell
    figure = plt.figure()
    image = plt.imshow(frames[0])
    plt.axis('off')
    anim = animation.FuncAnimation(
        figure,
        update_scene,
        fargs=(frames, image),
        frames=len(frames),
        repeat=repeat,
        interval=interval,
    )
    return anim

In [None]:
# Roll out one episode with the weights saved after 100 training episodes
# (policy noise scale 1 / (1 + 1)) and animate the rendered RGB frames.
estimator = FunctionEstimator(env.action_space.n)
estimator.model.load_weights('./out/pong-keras-100.h5')
policy = make_policy(estimator, env.action_space.n, 1)

frames = []
obs, done = env.reset(), False
while not done:
    state = preprocess(obs)
    action = policy(state)
    obs, reward, done, _ = env.step(action)
    img = env.render(mode="rgb_array")
    frames.append(img)

plot_animation(frames)
    
