Fetching contributors…
Cannot retrieve contributors at this time
336 lines (263 sloc) 10.8 KB
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
import collections
import numpy as np
# Default policy id for single agent environments.
# NOTE(review): the definition was missing although the name is read by
# MultiAgentBatch.wrap_as_needed; the value is restored from upstream RLlib --
# confirm it matches the policy id the rest of the codebase uses.
DEFAULT_POLICY_ID = "default"
def to_float_array(v):
    """Convert ``v`` to a numpy array, narrowing float64 data to float32.

    Arrays of any other dtype are returned as produced by ``np.array``.
    The float64 -> float32 cast halves the memory of floating point columns.
    """
    converted = np.array(v)
    needs_narrowing = converted.dtype == np.float64
    return converted.astype(np.float32) if needs_narrowing else converted
class SampleBatchBuilder(object):
    """Util to build a SampleBatch incrementally.

    For efficiency, SampleBatches hold values in column form (as arrays).
    However, it is useful to add data one row (dict) at a time.

    Attributes:
        buffers (collections.defaultdict): Maps each column name to the list
            of values added for that column so far.
        count (int): Number of rows added since the last reset.
    """

    def __init__(self):
        self.buffers = collections.defaultdict(list)
        self.count = 0

    def add_values(self, **values):
        """Add the given dictionary (row) of values to this batch.

        Arguments:
            values (dict): Column name -> value for a single row.
        """
        for k, v in values.items():
            self.buffers[k].append(v)
        self.count += 1

    def add_batch(self, batch):
        """Add the given batch of values to this batch.

        Arguments:
            batch (SampleBatch): Batch whose columns are appended onto the
                buffered columns of the same name.
        """
        for k, column in batch.items():
            self.buffers[k].extend(column)
        self.count += batch.count

    def build_and_reset(self):
        """Returns a sample batch including all previously added values.

        float64 columns are narrowed to float32 via ``to_float_array``.
        Afterwards the buffers and the row count are cleared, so the next
        batch starts empty.

        Raises:
            ValueError: If nothing was added (SampleBatch rejects an empty
                batch).
        """
        batch = SampleBatch(
            {k: to_float_array(v)
             for k, v in self.buffers.items()})
        # Clear buffers together with count so they can never disagree.
        self.buffers.clear()
        self.count = 0
        return batch
class MultiAgentSampleBatchBuilder(object):
    """Util to build SampleBatches for each policy in a multi-agent env.

    Input data is per-agent, while output data is per-policy. There is an M:N
    mapping between agents and policies. We retain one local batch builder
    per agent. When an agent is done, then its local batch is appended into the
    corresponding policy batch for the agent's policy.
    """

    def __init__(self, policy_map, clip_rewards):
        """Initialize a MultiAgentSampleBatchBuilder.

        Arguments:
            policy_map (dict): Maps policy ids to policy graph instances.
            clip_rewards (bool): Whether to clip rewards before postprocessing.
        """
        self.policy_map = policy_map
        self.clip_rewards = clip_rewards
        # Output side: one builder per policy, fed only postprocessed rows.
        self.policy_builders = {
            k: SampleBatchBuilder()
            for k in policy_map.keys()
        }
        # Input side: per-agent builders holding rows not yet postprocessed.
        self.agent_builders = {}
        self.agent_to_policy = {}
        self.count = 0  # increment this manually

    def total(self):
        """Returns summed number of steps across all policy buffers.

        Rows still pending in per-agent builders (not yet postprocessed) are
        not included.
        """
        return sum(p.count for p in self.policy_builders.values())

    def has_pending_data(self):
        """Returns whether there is pending unprocessed data."""
        return len(self.agent_builders) > 0

    def add_values(self, agent_id, policy_id, **values):
        """Add the given dictionary (row) of values to this batch.

        The agent -> policy mapping is recorded from the agent's first row;
        later ``policy_id`` values for the same agent are ignored until the
        agent state is cleared by ``postprocess_batch_so_far``.

        Arguments:
            agent_id (obj): Unique id for the agent we are adding values for.
            policy_id (obj): Unique id for policy controlling the agent.
            values (dict): Row of values to add for this agent.
        """
        if agent_id not in self.agent_builders:
            self.agent_builders[agent_id] = SampleBatchBuilder()
            self.agent_to_policy[agent_id] = policy_id
        builder = self.agent_builders[agent_id]
        builder.add_values(**values)

    def postprocess_batch_so_far(self, episode):
        """Apply policy postprocessors to any unprocessed rows.

        This pushes the postprocessed per-agent batches onto the per-policy
        builders, clearing per-agent state.

        Arguments:
            episode: current MultiAgentEpisode object or None

        Raises:
            ValueError: If an agent's pending rows span more than one
                trajectory (a non-final ``dones`` flag or several
                ``eps_id`` values).
        """
        # Materialize the batches so far as (policy, batch) pairs; the pairs
        # of the other agents are handed to each postprocessor.
        pre_batches = {}
        for agent_id, builder in self.agent_builders.items():
            pre_batches[agent_id] = (
                self.policy_map[self.agent_to_policy[agent_id]],
                builder.build_and_reset())

        # Apply postprocessor
        post_batches = {}
        if self.clip_rewards:
            for _, (_, pre_batch) in pre_batches.items():
                pre_batch["rewards"] = np.sign(pre_batch["rewards"])
        for agent_id, (_, pre_batch) in pre_batches.items():
            other_batches = pre_batches.copy()
            del other_batches[agent_id]
            policy = self.policy_map[self.agent_to_policy[agent_id]]
            if any(pre_batch["dones"][:-1]) or len(set(
                    pre_batch["eps_id"])) > 1:
                raise ValueError(
                    "Batches sent to postprocessing must only contain steps "
                    "from a single trajectory.", pre_batch)
            post_batches[agent_id] = policy.postprocess_trajectory(
                pre_batch, other_batches, episode)

        # Append into policy batches (in agent-id order, for determinism)
        # and reset the per-agent state.
        for agent_id, post_batch in sorted(post_batches.items()):
            self.policy_builders[self.agent_to_policy[agent_id]].add_batch(
                post_batch)
        self.agent_builders.clear()
        self.agent_to_policy.clear()

    def build_and_reset(self, episode):
        """Returns the accumulated sample batches for each policy.

        Any unprocessed rows will be first postprocessed with a policy
        postprocessor. The internal state of this builder will be reset.

        Arguments:
            episode: current MultiAgentEpisode object or None

        Returns:
            The result of ``MultiAgentBatch.wrap_as_needed`` over the
            non-empty policy batches and the manually tracked ``count``.
        """
        self.postprocess_batch_so_far(episode)
        policy_batches = {}
        for policy_id, builder in self.policy_builders.items():
            # Policies that received no rows are omitted entirely.
            if builder.count > 0:
                policy_batches[policy_id] = builder.build_and_reset()
        old_count = self.count
        self.count = 0
        return MultiAgentBatch.wrap_as_needed(policy_batches, old_count)
class MultiAgentBatch(object):
    """A batch of experiences from multiple policies in the environment.

    Attributes:
        policy_batches (dict): Mapping from policy id to a normal SampleBatch
            of experiences. Note that these batches may be of different length.
        count (int): The number of timesteps in the environment this batch
            contains. This will be less than the number of transitions this
            batch contains across all policies in total.
    """

    def __init__(self, policy_batches, count):
        self.policy_batches = policy_batches
        self.count = count

    @staticmethod
    def wrap_as_needed(batches, count):
        """Wraps per-policy batches, unwrapping the single-default case.

        Arguments:
            batches (dict): Mapping from policy id to SampleBatch.
            count (int): Environment step count for the wrapped batch.

        Returns:
            ``batches[DEFAULT_POLICY_ID]`` itself when it is the only entry,
            otherwise a MultiAgentBatch over ``batches``.
        """
        if len(batches) == 1 and DEFAULT_POLICY_ID in batches:
            return batches[DEFAULT_POLICY_ID]
        return MultiAgentBatch(batches, count)

    @staticmethod
    def concat_samples(samples):
        """Concatenates MultiAgentBatches policy by policy.

        Arguments:
            samples (list): MultiAgentBatch instances to merge.

        Returns:
            MultiAgentBatch whose batch for each policy is the concatenation
            of that policy's batches across ``samples`` (in order), and whose
            count is the sum of the input counts.
        """
        policy_batches = collections.defaultdict(list)
        total_count = 0
        for s in samples:
            assert isinstance(s, MultiAgentBatch)
            for policy_id, batch in s.policy_batches.items():
                policy_batches[policy_id].append(batch)
            total_count += s.count
        out = {}
        for policy_id, batches in policy_batches.items():
            out[policy_id] = SampleBatch.concat_samples(batches)
        return MultiAgentBatch(out, total_count)

    def copy(self):
        """Returns a copy with every per-policy batch copied via ``.copy()``."""
        return MultiAgentBatch(
            {k: v.copy()
             for (k, v) in self.policy_batches.items()}, self.count)

    def total(self):
        """Returns the number of transitions summed over all policies."""
        return sum(batch.count for batch in self.policy_batches.values())

    def __str__(self):
        return "MultiAgentBatch({}, count={})".format(
            str(self.policy_batches), self.count)

    def __repr__(self):
        return self.__str__()
class SampleBatch(object):
    """Wrapper around a dictionary with string keys and array-like values.

    For example, {"obs": [1, 2, 3], "reward": [0, -1, 1]} is a batch of three
    samples, each with an "obs" and "reward" attribute.

    Attributes:
        data (dict): Maps each column name to a numpy array of values.
        count (int): Number of rows, i.e. the common length of all columns.
    """

    def __init__(self, *args, **kwargs):
        """Constructs a sample batch (same params as dict constructor).

        Every column is converted with ``np.asarray``, which copies only when
        the input is not already a suitable ndarray.

        Raises:
            ValueError: If the batch has no columns.
            AssertionError: If a key is not a string or the columns do not
                all have the same length.
        """
        self.data = dict(*args, **kwargs)
        lengths = []
        # Iterate over a snapshot because values are replaced in the loop.
        for k, v in list(self.data.items()):
            assert isinstance(k, six.string_types), self
            lengths.append(len(v))
            # np.asarray keeps the old ``np.array(v, copy=False)`` meaning
            # (copy only if needed); under NumPy >= 2 ``copy=False`` instead
            # raises whenever a copy is required, e.g. for plain lists.
            self.data[k] = np.asarray(v)
        if not lengths:
            raise ValueError("Empty sample batch")
        assert len(set(lengths)) == 1, "data columns must be same length"
        self.count = lengths[0]

    @staticmethod
    def concat_samples(samples):
        """Concatenates a list of batches column by column.

        Lists whose first element is a MultiAgentBatch are delegated to
        ``MultiAgentBatch.concat_samples``. Batches with zero rows are
        skipped; the remaining batches must contain every column of the
        first remaining one.

        Arguments:
            samples (list): SampleBatch (or MultiAgentBatch) instances.

        Returns:
            A batch holding all rows of ``samples`` in order.

        Raises:
            IndexError: If ``samples`` is empty or every batch is empty.
        """
        if isinstance(samples[0], MultiAgentBatch):
            return MultiAgentBatch.concat_samples(samples)
        out = {}
        samples = [s for s in samples if s.count > 0]
        for k in samples[0].keys():
            out[k] = np.concatenate([s[k] for s in samples])
        return SampleBatch(out)

    def concat(self, other):
        """Returns a new SampleBatch with each data column concatenated.

        Examples:
            >>> b1 = SampleBatch({"a": [1, 2]})
            >>> b2 = SampleBatch({"a": [3, 4, 5]})
            >>> b1.concat(b2)["a"].tolist()
            [1, 2, 3, 4, 5]
        """
        # Compare as sets: on Python 2 keys() is an order-dependent list.
        assert set(self.keys()) == set(other.keys()), "must have same columns"
        out = {}
        for k in self.keys():
            out[k] = np.concatenate([self[k], other[k]])
        return SampleBatch(out)

    def copy(self):
        """Returns a new SampleBatch holding copies of every column."""
        return SampleBatch(
            {k: np.array(v, copy=True)
             for (k, v) in self.data.items()})

    def rows(self):
        """Returns an iterator over data rows, i.e. dicts with column values.

        Examples:
            >>> batch = SampleBatch({"a": [1, 2, 3], "b": [4, 5, 6]})
            >>> [(int(r["a"]), int(r["b"])) for r in batch.rows()]
            [(1, 4), (2, 5), (3, 6)]
        """
        for i in range(self.count):
            row = {}
            for k in self.keys():
                row[k] = self[k][i]
            yield row

    def columns(self, keys):
        """Returns a list of just the specified columns.

        Examples:
            >>> batch = SampleBatch({"a": [1], "b": [2], "c": [3]})
            >>> [c.tolist() for c in batch.columns(["a", "b"])]
            [[1], [2]]
        """
        return [self[k] for k in keys]

    def shuffle(self):
        """Permutes the rows in place, applying one permutation to all columns."""
        permutation = np.random.permutation(self.count)
        for key, val in self.items():
            self[key] = val[permutation]

    def __getitem__(self, key):
        return self.data[key]

    def __setitem__(self, key, item):
        self.data[key] = item

    def __str__(self):
        return "SampleBatch({})".format(str(self.data))

    def __repr__(self):
        return self.__str__()

    def keys(self):
        return self.data.keys()

    def items(self):
        return self.data.items()

    def __iter__(self):
        return iter(self.data)

    def __contains__(self, x):
        return x in self.data