Implementation of parallel evaluation of actions and observations (#508)

Deathn0t authored and AlexKuhnle committed Sep 22, 2018
1 parent d1a9ecf commit b15fb3d413c59596fc8a309fcddc28a2feafe9a0
@@ -63,10 +63,6 @@ def __init__(
# Batched observe for better performance with Python.
self.batched_observe = batched_observe
self.batching_capacity = batching_capacity
if self.batched_observe:
assert self.batching_capacity is not None
self.observe_terminal = list()
self.observe_reward = list()

self.current_states = None
self.current_actions = None
@@ -78,6 +74,10 @@ def __init__(
self.episode = None

self.model = self.initialize_model()
if self.batched_observe:
assert self.batching_capacity is not None
self.observe_terminal = [list() for _ in range(self.model.num_parallel)]
self.observe_reward = [list() for _ in range(self.model.num_parallel)]
self.reset()

def __str__(self):
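
For illustration only (editor's note, not part of the commit): with num_parallel equal to 3, the two buffers created above start out as three independent empty lists each, one pair per parallel index.

# Hypothetical buffer layout after __init__ for num_parallel = 3.
observe_terminal = [[], [], []]  # one terminal buffer per parallel index
observe_reward = [[], [], []]    # one reward buffer per parallel index
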
@@ -101,9 +101,9 @@ def reset(self):
self.episode, self.timestep, self.next_internals = self.model.reset()
self.current_internals = self.next_internals

def act(self, states, deterministic=False, independent=False, fetch_tensors=None, buffered=True):
def act(self, states, deterministic=False, independent=False, fetch_tensors=None, buffered=True, index=0):
"""
Return action(s) for given state(s). States preprocessing and exploration are applied if
configured accordingly.
Args:
@@ -132,7 +132,8 @@ def act(self, states, deterministic=False, independent=False, fetch_tensors=None
internals=self.current_internals,
deterministic=deterministic,
independent=independent,
fetch_tensors=fetch_tensors
fetch_tensors=fetch_tensors,
index=index
)

if self.unique_action:
@@ -145,7 +146,8 @@ def act(self, states, deterministic=False, independent=False, fetch_tensors=None
states=self.current_states,
internals=self.current_internals,
deterministic=deterministic,
independent=independent
independent=independent,
index=index
)

# Buffered mode only works single-threaded because buffer inserts
@@ -161,7 +163,7 @@ def act(self, states, deterministic=False, independent=False, fetch_tensors=None
else:
return self.current_actions, self.current_states, self.current_internals

def observe(self, terminal, reward):
def observe(self, terminal, reward, index=0):
"""
Observe experience from the environment to learn from. Optionally pre-processes rewards
Child classes should call super to get the processed reward
@@ -176,16 +178,17 @@ def observe(self, terminal, reward):

if self.batched_observe:
# Batched observe for better performance with Python.
self.observe_terminal.append(self.current_terminal)
self.observe_reward.append(self.current_reward)
self.observe_terminal[index].append(self.current_terminal)
self.observe_reward[index].append(self.current_reward)

if self.current_terminal or len(self.observe_terminal) >= self.batching_capacity:
if self.current_terminal or len(self.observe_terminal[index]) >= self.batching_capacity:
self.episode = self.model.observe(
terminal=self.observe_terminal,
reward=self.observe_reward
terminal=self.observe_terminal[index],
reward=self.observe_reward[index],
index=index
)
self.observe_terminal = list()
self.observe_reward = list()
self.observe_terminal[index] = list()
self.observe_reward[index] = list()

else:
self.episode = self.model.observe(
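
A minimal sketch (editor's addition, not part of the commit) of how the new index argument routes experience into the per-index buffers; the agent, environment, num_episodes and num_parallel names are placeholders. The ParallelRunner added below follows essentially this pattern, rotating the index once per episode.

# Hypothetical usage of the index argument added to act() and observe() above.
index = 0
for _ in range(num_episodes):
    state = environment.reset()
    terminal = False
    while not terminal:
        action = agent.act(states=state, index=index)
        state, terminal, reward = environment.execute(action=action)
        # Experience is appended to the buffers for this index and flushed
        # on terminal or when the buffer reaches batching_capacity.
        agent.observe(terminal=terminal, reward=reward, index=index)
    # Move on to the next parallel buffer for the next episode.
    index = (index + 1) % num_parallel
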
@@ -240,9 +243,9 @@ def last_observation(self):

def save_model(self, directory=None, append_timestep=True):
"""
Save TensorFlow model. If no checkpoint directory is given, the model's default saver
directory is used. Optionally appends current timestep to prevent overwriting previous
checkpoint files. Turn off to be able to load model from the same given path argument as
given here.
Args:
@@ -262,8 +265,8 @@ def save_model(self, directory=None, append_timestep=True):

def restore_model(self, directory=None, file=None):
"""
Restore TensorFlow model. If no checkpoint file is given, the latest checkpoint is
restored. If no checkpoint directory is given, the model's default saver directory is
used (unless file specifies the entire path).
Args:
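
For reference, a brief usage sketch of the save/restore API documented above (editor's addition; the checkpoint directory is a placeholder path):

# Save a checkpoint, appending the current timestep to the filename,
# then restore the latest checkpoint from the same directory.
agent.save_model(directory='./checkpoints', append_timestep=True)
agent.restore_model(directory='./checkpoints')
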
@@ -142,4 +142,9 @@ def sanity_check_execution_spec(execution_spec):
elif type_ == "single":
return execution_spec

if execution_spec.get('num_parallel') != None:
assert type(execution_spec['num_parallel']) is int, "ERROR: num_parallel needs to be of type int but is of type {}!".format(type(execution_spec['num_parallel']).__name__)
assert execution_spec['num_parallel'] > 0, "ERROR: num_parallel needs to be > 0 but is equal to {}".format(execution_spec['num_parallel'])
return execution_spec

raise TensorForceError("Unsupported execution type specified ({})!".format(type_))
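
Given the assertions above, an execution spec intended for the new branch would look roughly like the following (editor's sketch; the exact spec layout is inferred from the checks, not spelled out in the commit):

# Hypothetical execution specs against the new num_parallel checks.
execution_spec = dict(num_parallel=4)    # intended to pass: positive int
bad_value_spec = dict(num_parallel=0)    # would trip the "> 0" assertion
bad_type_spec = dict(num_parallel='4')   # would trip the int-type assertion
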
@@ -16,5 +16,6 @@
from tensorforce.execution.base_runner import BaseRunner
from tensorforce.execution.runner import Runner, SingleRunner, DistributedTFRunner
from tensorforce.execution.threaded_runner import ThreadedRunner, WorkerAgentGenerator
from tensorforce.execution.parallel_runner import ParallelRunner

__all__ = ['BaseRunner', 'SingleRunner', 'DistributedTFRunner', 'Runner', 'ThreadedRunner', 'WorkerAgentGenerator']
__all__ = ['BaseRunner', 'SingleRunner', 'DistributedTFRunner', 'Runner', 'ThreadedRunner', 'WorkerAgentGenerator', 'ParallelRunner']
@@ -0,0 +1,155 @@
# Copyright 2017 reinforce.io. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

from tensorforce.execution.base_runner import BaseRunner

import time
from six.moves import xrange
import warnings
from inspect import getargspec
from tqdm import tqdm

class ParallelRunner(BaseRunner):
"""
Simple runner for non-realtime single-process execution.
"""

def __init__(self, agent, environment, repeat_actions=1, history=None, id_=0):
"""
Initialize a single Runner object (one Agent/one Environment).
Args:
id_ (int): The ID of this Runner (for distributed TF runs).
"""
super(ParallelRunner, self).__init__(agent, environment, repeat_actions, history)

self.id = id_ # the worker's ID in a distributed run (default=0)
self.current_timestep = None # the time step in the current episode
self.episode_actions = []
self.num_parallel = self.agent.execution['num_parallel']
print('ParallelRunner with {} parallel buffers.'.format(self.num_parallel))

def close(self):
self.agent.close()
self.environment.close()

# TODO: make average reward another possible criteria for runner-termination
def run(self, num_timesteps=None, num_episodes=None, max_episode_timesteps=None, deterministic=False, episode_finished=None, summary_report=None, summary_interval=None, timesteps=None, episodes=None, testing=False, sleep=None
):
"""
Args:
timesteps (int): Deprecated; see num_timesteps.
episodes (int): Deprecated; see num_episodes.
"""

# deprecation warnings
if timesteps is not None:
num_timesteps = timesteps
warnings.warn("WARNING: `timesteps` parameter is deprecated, use `num_timesteps` instead.",
category=DeprecationWarning)
if episodes is not None:
num_episodes = episodes
warnings.warn("WARNING: `episodes` parameter is deprecated, use `num_episodes` instead.",
category=DeprecationWarning)

# figure out whether we are using the deprecated way of "episode_finished" reporting
old_episode_finished = False
if episode_finished is not None and len(getargspec(episode_finished).args) == 1:
old_episode_finished = True

# Keep track of episode reward and episode length for statistics.
self.start_time = time.time()

self.agent.reset()

if num_episodes is not None:
num_episodes += self.agent.episode

if num_timesteps is not None:
num_timesteps += self.agent.timestep

# add progress bar
with tqdm(total=num_episodes) as pbar:
# episode loop
index = 0
while True:
episode_start_time = time.time()
state = self.environment.reset()
self.agent.reset()

# Update global counters.
self.global_episode = self.agent.episode # global value (across all agents)
self.global_timestep = self.agent.timestep # global value (across all agents)

episode_reward = 0
self.current_timestep = 0

# time step (within episode) loop
while True:
action = self.agent.act(states=state, deterministic=deterministic, index=index)

reward = 0
for _ in xrange(self.repeat_actions):
state, terminal, step_reward = self.environment.execute(action=action)
reward += step_reward
if terminal:
break

if max_episode_timesteps is not None and self.current_timestep >= max_episode_timesteps:
terminal = True

if not testing:
self.agent.observe(terminal=terminal, reward=reward, index=index)

self.global_timestep += 1
self.current_timestep += 1
episode_reward += reward

if terminal or self.agent.should_stop(): # TODO: should_stop also terminate?
break

if sleep is not None:
time.sleep(sleep)

index = (index + 1) % self.num_parallel

# Update our episode stats.
time_passed = time.time() - episode_start_time
self.episode_rewards.append(episode_reward)
self.episode_timesteps.append(self.current_timestep)
self.episode_times.append(time_passed)
self.episode_actions.append(self.environment.conv_action)

self.global_episode += 1
pbar.update(1)

# Check, whether we should stop this run.
if episode_finished is not None:
# deprecated way (passing in only runner object):
if old_episode_finished:
if not episode_finished(self):
break
# new unified way (passing in BaseRunner AND some worker ID):
elif not episode_finished(self, self.id):
break
if (num_episodes is not None and self.global_episode >= num_episodes) or \
(num_timesteps is not None and self.global_timestep >= num_timesteps) or \
self.agent.should_stop():
break
pbar.update(num_episodes - self.global_episode)
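
A hypothetical end-to-end usage sketch (editor's addition): the agent class, network spec and environment are placeholders rather than part of this commit; the pieces taken from the diff are the num_parallel execution field and the ParallelRunner constructor and run() call.

# Sketch only: assumes a tensorforce Environment subclass named MyEnvironment
# and an agent that accepts an execution spec (e.g. PPOAgent in this era of the library).
from tensorforce.agents import PPOAgent
from tensorforce.execution import ParallelRunner

environment = MyEnvironment()  # placeholder environment

agent = PPOAgent(
    states=environment.states,
    actions=environment.actions,
    network=[dict(type='dense', size=32)],
    execution=dict(num_parallel=4)  # read by ParallelRunner as agent.execution['num_parallel']
)

runner = ParallelRunner(agent=agent, environment=environment)
runner.run(num_episodes=100)
runner.close()
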
@@ -318,7 +318,7 @@ def len_(cumulative, term):

def tf_reference(self, states, internals, actions, terminal, reward, next_states, next_internals, update):
"""
Creates the TensorFlow operations for obtaining the reference tensor(s), in case of a
comparative loss.
Args:
