In [1]:
import acme

In [2]:
# Name of the environment backend to import in the import cell below; that
# cell checks for 'dm_control' and 'gym'.
environment_library: str = 'gym'

In [3]:
import IPython

from acme import environment_loop
from acme import specs
from acme import wrappers
from acme.agents.tf import d4pg
from acme.tf import networks
from acme.tf import utils as tf2_utils
from acme.utils import loggers
import numpy as np
import sonnet as snt

from absl import app
from absl import flags
import acme
from acme.agents.tf import dqn
from acme.tf import networks

# Import the selected environment lib
if environment_library == 'dm_control':
  from dm_control import suite
elif environment_library == 'gym':
  import gym

# Imports required for visualization
import pyvirtualdisplay
import imageio
import base64

# Set up a virtual display for rendering.
# The display is created with visible=0 and a 1400x900 size, then started
# immediately; whatever `.start()` returns is kept in `display`.
# NOTE(review): presumably visible=0 means headless/off-screen — confirm
# against the pyvirtualdisplay version in use.
display = pyvirtualdisplay.Display(visible=0, size=(1400, 900)).start()

In [4]:
# helper functions
import functools

from acme import wrappers
import dm_env
import gym

def make_environment(evaluation: bool = False,
                     level: str = 'PongNoFrameskip-v4') -> dm_env.Environment:
  """Builds a Gym Atari environment and applies the Acme wrapper stack.

  Args:
    evaluation: When truthy, episodes are capped at 108_000 steps; otherwise
      the cap is 50_000.
    level: Gym environment id handed to `gym.make`.

  Returns:
    The environment produced by `wrappers.wrap_all` with the adapter, Atari
    and single-precision wrappers applied in that order.
  """
  if evaluation:
    max_episode_len = 108_000
  else:
    max_episode_len = 50_000

  # Pre-bind the Atari wrapper options so it can be applied like the others.
  atari_wrapper = functools.partial(
      wrappers.AtariWrapper,
      to_float=True,
      max_episode_len=max_episode_len,
      zero_discount_on_life_loss=True,
  )
  wrapper_stack = [
      wrappers.GymAtariAdapter,
      atari_wrapper,
      wrappers.SinglePrecisionWrapper,
  ]

  gym_env = gym.make(level, full_action_space=True)
  return wrappers.wrap_all(gym_env, wrapper_stack)

In [5]:
# Flag definitions. Each definition is guarded so this cell can be re-run in
# the same kernel: absl raises DuplicateFlagError when a flag name is defined
# a second time from the same module.
# NOTE(review): the flags are never parsed or read in this notebook; the
# training cell hard-codes the level and episode count.
FLAGS = flags.FLAGS
if 'level' not in FLAGS:
  flags.DEFINE_string('level', 'PongNoFrameskip-v4', 'Which Atari level to play.')
if 'num_episodes' not in FLAGS:
  flags.DEFINE_integer('num_episodes', 10, 'Number of episodes to train for.')

In [12]:
# Build the training environment. The level must be passed by keyword: the
# first positional parameter of `make_environment` is `evaluation`, so a
# positional string would be taken as a truthy `evaluation` flag and select
# the longer evaluation episode cap instead of the training one.
environment = make_environment(level='PongNoFrameskip-v4')
environment_spec = acme.make_environment_spec(environment)
network = networks.DQNAtariNetwork(environment_spec.actions.num_values)

agent = dqn.DQN(environment_spec, network)

# Create loggers for the agent and the environment loop.
# NOTE(review): `agent_logger` is created but not handed to the agent here —
# confirm whether it should be.
agent_logger = loggers.TerminalLogger(label='agent', time_delta=10.)
env_loop_logger = loggers.TerminalLogger(label='env_loop', time_delta=10.)

# Create a loop connecting this agent to the environment created above.
env_loop = acme.EnvironmentLoop(
    environment, agent, logger=env_loop_logger)

# Run 10 training episodes.
# Note: rerunning this cell rebuilds the environment, network and agent, so
# training starts over. To train for longer, raise `num_episodes` or call
# `env_loop.run(...)` again from a separate cell.
env_loop.run(num_episodes=10)

INFO:tensorflow:Assets written to: /home/hayato/acme/155dbec6-d5f4-11eb-a11b-03d440faad90/snapshots/network/assets
INFO:tensorflow:Assets written to: /home/hayato/acme/155dbec6-d5f4-11eb-a11b-03d440faad90/snapshots/network/assets


In [13]:
# Create a simple helper function to render a frame from the current state of
# the environment.
def render(env):
  """Renders the environment held in `env.environment`.

  Args:
    env: Object exposing an `environment` attribute that has a `render` method.

  Returns:
    Whatever `env.environment.render(mode='rgb_array')` returns.
  """
  wrapped_env = env.environment
  return wrapped_env.render(mode='rgb_array')

def display_video(frames, filename='temp.mp4'):
  """Saves frames to a video file and returns an HTML element that plays it.

  Args:
    frames: Iterable of frames; each one is appended to the video writer.
    filename: Path the video is written to (and then read back from).

  Returns:
    An `IPython.display.HTML` object holding a `<video>` tag with the file
    contents embedded as a base64 data URI.
  """

  # Write video
  with imageio.get_writer(filename, fps=60) as video:
    for frame in frames:
      video.append_data(frame)

  # Read the encoded file back; the context manager closes the handle instead
  # of leaving it to the garbage collector.
  with open(filename, 'rb') as video_file:
    video_bytes = video_file.read()
  b64_video = base64.b64encode(video_bytes)

  # The element is explicitly closed so that output rendered after it is not
  # treated as fallback content of the <video> element.
  video_tag = ('<video  width="320" height="240" controls alt="test" '
               'src="data:video/mp4;base64,{0}"></video>').format(
                   b64_video.decode())

  return IPython.display.HTML(video_tag)

In [14]:
# Roll out one episode with the current agent, capturing a frame per step.
timestep = environment.reset()
frames = []
frames.append(render(environment))  # Frame for the initial state.

while not timestep.last():
  # Ask the agent for an action on the latest observation and apply it.
  action = agent.select_action(timestep.observation)
  timestep = environment.step(action)

  # Capture the frame that follows this step.
  frames.append(render(environment))

# Encode the collected frames as a video; as the last expression of the cell,
# the returned HTML object is what the notebook displays.
display_video(np.array(frames))

