In [38]:
import gym
from gym.wrappers import Monitor
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import cv2
from collections import deque
cv2.ocl.setUseOpenCL(False)

In [39]:
def pre_process(frame):
    """Collapse a multi-channel frame to one channel and shrink it to 84x84.

    The channels on axis 2 are averaged into a single-channel image, which is
    then resized with OpenCV's area interpolation.

    Args:
        frame: Image array with the channel dimension on axis 2.

    Returns:
        The result of ``cv2.resize`` on the channel-averaged image, sized 84x84.
    """
    target_size = (84, 84)
    channel_mean = np.mean(frame, axis=2)
    return cv2.resize(channel_mean, target_size, interpolation=cv2.INTER_AREA)

In [42]:
class FireResetEnv(gym.Wrapper):
    """Wrapper that takes the FIRE action as part of every reset."""

    def __init__(self, env):
        """Take action on reset for environments that are fixed until firing.

        Args:
            env: Environment to wrap. Its unwrapped action list must have
                'FIRE' at index 1 and contain at least three actions.

        Raises:
            AssertionError: If the action-meaning checks above fail.
        """
        gym.Wrapper.__init__(self, env)
        assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
        assert len(env.unwrapped.get_action_meanings()) >= 3

    def reset(self, **kwargs):
        """Reset the wrapped env, then take action 1 (FIRE) followed by action 2.

        If either step ends the episode the env is reset again. When that
        happens after the second step, the observation from the fresh reset
        is returned, so the caller never receives the last frame of an
        already-finished episode.

        Args:
            **kwargs: Forwarded unchanged to the wrapped env's ``reset``.

        Returns:
            The observation after the two start-up actions (or after the
            final reset if the episode ended during start-up).
        """
        self.env.reset(**kwargs)
        obs, _, done, _ = self.env.step(1)
        if done:
            # obs is overwritten by the next step, so only the reset matters here.
            self.env.reset(**kwargs)
        obs, _, done, _ = self.env.step(2)
        if done:
            # Use the fresh observation instead of the stale terminal frame.
            obs = self.env.reset(**kwargs)
        return obs

    def step(self, ac):
        """Forward action ``ac`` to the wrapped env and return its result unchanged."""
        return self.env.step(ac)

In [None]:
# Load the saved Keras model used to choose actions.
model = tf.keras.models.load_model('../../saved_model/dqn_breakout')

# Build the environment: FIRE-on-reset wrapper inside, video Monitor outside.
env = Monitor(
    FireResetEnv(gym.make('BreakoutDeterministic-v4')),
    "./vid_breakout",
    video_callable=lambda episode_id: True,  # callable returns True for every episode id
    force=True,
)

# Start an episode, then take action 0 once and keep the frame it returns.
env.reset()
obs, _, _, _ = env.step(0)

In [37]:
# Turn the first observation into a (1, 84, 84, 1) array.
processed = pre_process(obs)
next_state = np.reshape([processed], (1, 84, 84, 1))

# Seed the rolling 4-frame history with four references to that first frame.
frame_stack = deque([next_state] * 4, maxlen=4)

In [None]:
# Play one episode, always taking the action with the highest predicted Q-value.
done = False
reward_sum = 0  # running total of the rewards returned by env.step
lives = env.unwrapped.ale.lives()
try:
    while not done:
        # Stack the 4 most recent frames on the last axis -> (1, 84, 84, 4).
        out = np.concatenate(frame_stack, axis=-1)
        q_values = model.predict(out)
        action = tf.argmax(q_values[0])
        obs, reward, done, info = env.step(action.numpy())
        reward_sum += reward

        current_lives = env.unwrapped.ale.lives()
        if current_lives < lives and not done:
            # A life was lost: take action 1 (FIRE) and keep its done flag and
            # reward so a finished episode is never stepped again.
            lives = current_lives
            obs, reward, done, _ = env.step(1)
            reward_sum += reward

        env.render()
        processed = pre_process(obs)
        next_state = np.reshape([processed], (1, 84, 84, 1))
        frame_stack.append(next_state)
finally:
    # Close the env even if the loop raises, so the Monitor can finish its recording.
    env.close()