New Feature: Dict Observations + Rewritten Wrappers
This is a major commit that adds the Dict Observation capability to Digideep. It also rewrites many of the wrappers to make them compatible with the new feature.
sharif1093 committed May 28, 2019
1 parent 41a2bf4 commit dede9a0
Showing 19 changed files with 1,230 additions and 693 deletions.
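For orientation, here is a hedged sketch of the data flow this commit enables. The "camera" entry and the array shapes are invented for illustration; only the "agent" key and the default observation_path="/agent" appear in the diff below. The environment now returns an observation dict, the explorer flattens it into path-style keys, and each agent picks its own entry through an observation_path parameter.

import numpy as np
from collections import OrderedDict

# Observation as the wrapped environment might return it (shapes illustrative;
# the "camera" stream is hypothetical, added only to motivate Dict observations).
obs = OrderedDict([
    ("agent",  np.zeros(24, dtype=np.float64)),        # proprioceptive vector
    ("camera", np.zeros((64, 64, 3), dtype=np.uint8)), # pixel stream
])

def flatten_dict(dic, sep="/", prefix=""):
    """Minimal re-implementation of data_helpers.flatten_dict, for illustration only."""
    res = OrderedDict()
    for key, value in dic.items():
        if isinstance(value, dict):
            res.update(flatten_dict(value, sep=sep, prefix=prefix + sep + key))
        else:
            res[prefix + sep + key] = value
    return res

flat = flatten_dict(obs)             # {"/agent": ..., "/camera": ...}
observation_path = "/agent"          # per-agent parameter, default "/agent"
agent_obs = flat[observation_path]   # this agent ignores the camera stream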
16 changes: 10 additions & 6 deletions digideep/agent/ddpg.py
@@ -89,7 +89,11 @@ def action_generator(self, observations, hidden_state, masks, deterministic=Fals
dict: ``{"actions":...,"hidden_state":...}``
"""
observations_ = torch.from_numpy(observations).to(self.device)

observation_path = self.params.get("observation_path", "/agent")
observations_ = observations[observation_path].astype(np.float32)

observations_ = torch.from_numpy(observations_).to(self.device)

action = self.policy.generate_actions(observations_, deterministic=deterministic)
action = action.cpu().data.numpy()
@@ -104,10 +108,10 @@ def action_generator(self, observations, hidden_state, masks, deterministic=Fals
def step(self):
"""This function needs the following key values in the batch of memory:
* ``/observations``
* ``/obs_with_key``
* ``/rewards``
* ``/agents/<agent_name>/actions``
* ``/observations_2``
* ``/obs_with_key_2``
The first three keys are generated by the :class:`~digideep.environment.explorer.Explorer`
and the last key is added by the sampler.
@@ -120,12 +124,12 @@


with KeepTime("/update/step/to_torch"):
# ['/observations', '/masks', '/agents/agent/actions', '/agents/agent/hidden_state', '/rewards', '/observations_2']
# ['/obs_with_key', '/masks', '/agents/agent/actions', '/agents/agent/hidden_state', '/rewards', '/obs_with_key_2']

o1 = torch.from_numpy(batch["/observations"]).to(self.device)
o1 = torch.from_numpy(batch["/obs_with_key"]).to(self.device)
r1 = torch.from_numpy(batch["/rewards"]).to(self.device)
a1 = torch.from_numpy(batch["/agents/"+self.params["name"]+"/actions"]).to(self.device)
o2 = torch.from_numpy(batch["/observations_2"]).to(self.device)
o2 = torch.from_numpy(batch["/obs_with_key_2"]).to(self.device)
masks = torch.from_numpy(batch["/masks"]).to(self.device).view(-1)

# o1.clamp_(min=-self.params["trainer"]["clamp_obs"], max= self.params["trainer"]["clamp_obs"])
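The ddpg.py hunks above now index the incoming observation dict by observation_path and cast the slice to float32 before building the tensor. A minimal, self-contained sketch of that step (worker count and shapes are invented):

import numpy as np
import torch

# Flattened observation dict as the explorer hands it over (shapes illustrative).
observations = {"/agent": np.zeros((4, 24))}      # 4 workers, 24-dim obs, float64

observation_path = "/agent"                       # params.get("observation_path", "/agent")
obs = observations[observation_path].astype(np.float32)
obs_t = torch.from_numpy(obs)                     # float32 tensor, ready for the policy

# dm_control observations are typically float64 while the policy networks expect
# float32; without the cast, torch.from_numpy would yield a float64 tensor and the
# first linear layer would raise a dtype mismatch.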
7 changes: 5 additions & 2 deletions digideep/agent/ppo.py
@@ -88,8 +88,11 @@ def action_generator(self, observations, hidden_state, masks, deterministic=Fals
"""

observation_path = self.params.get("observation_path", "/agent")
observations_ = observations[observation_path].astype(np.float32)

with KeepTime("/explore/step/prestep/gen_action/to_torch"):
observations_ = torch.from_numpy(observations).to(self.device)
observations_ = torch.from_numpy(observations_).to(self.device)
hidden_state_ = torch.from_numpy(hidden_state).to(self.device)
masks_ = torch.from_numpy(masks).to(self.device)

@@ -137,7 +140,7 @@ def step(self):
for batch in data_sampler:
with KeepTime("/update/step/batches/to_torch"):
# Environment
observations = torch.from_numpy(batch["/observations"]).to(self.device)
observations = torch.from_numpy(batch["/observations"+self.params["observation_path"]]).to(self.device)
masks = torch.from_numpy(batch["/masks"]).to(self.device)
# Agent
hidden_state = torch.from_numpy(batch["/agents/"+self.params["name"]+"/hidden_state"]).to(self.device)
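Unlike DDPG and SAC, whose samplers alias the selected column under a fixed key (see the sampler file further down), PPO's step composes the batch key directly from the "/observations" prefix and the agent's observation_path. A quick sketch of the composed key, assuming the default path:

# Key composition used when indexing the rollout batch (values illustrative).
observation_path = "/agent"                   # self.params["observation_path"]
key = "/observations" + observation_path      # -> "/observations/agent"
# batch[key] then holds only this agent's slice of the stored observations.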
11 changes: 7 additions & 4 deletions digideep/agent/sac.py
@@ -102,7 +102,10 @@ def action_generator(self, observations, hidden_state, masks, deterministic=Fals
dict: ``{"actions":...,"hidden_state":...}``
"""
observations_ = torch.from_numpy(observations).to(self.device)
observation_path = self.params.get("observation_path", "/agent")
observations_ = observations[observation_path].astype(np.float32)

observations_ = torch.from_numpy(observations_).to(self.device)
action = self.policy.generate_actions(observations_, deterministic=deterministic)
action = action.cpu().numpy()

@@ -132,11 +135,11 @@ def step(self):


with KeepTime("/update/step/to_torch"):
# ['/observations', '/masks', '/agents/agent/actions', '/agents/agent/hidden_state', '/rewards', '/observations_2']
state = torch.from_numpy(batch["/observations"]).to(self.device)
# ['/obs_with_key', '/masks', '/agents/agent/actions', '/agents/agent/hidden_state', '/rewards', '/obs_with_key_2', ...]
state = torch.from_numpy(batch["/obs_with_key"]).to(self.device)
action = torch.from_numpy(batch["/agents/"+self.params["name"]+"/actions"]).to(self.device)
reward = torch.from_numpy(batch["/rewards"]).to(self.device)
next_state = torch.from_numpy(batch["/observations_2"]).to(self.device)
next_state = torch.from_numpy(batch["/obs_with_key_2"]).to(self.device)
# masks = torch.from_numpy(batch["/masks"]).to(self.device).view(-1)
masks = torch.from_numpy(batch["/masks"]).to(self.device)

8 changes: 6 additions & 2 deletions digideep/agent/samplers/ddpg.py
@@ -21,8 +21,9 @@ def get_sample_memory(buffer, info):
"""
batch_size = info["batch_size"]

num_workers = info["num_workers"]
observation_path = info["observation_path"]

N = info["num_records"] - 1 # We don't want to consider the last "incomplete" record, hence "-1"

masks_arr = buffer["/masks"][:,:N]
@@ -48,8 +49,11 @@ def get_sample_memory(buffer, info):
batch = {}
for key in buffer:
batch[key] = buffer[key][sample_tabular[0],sample_tabular[1]]

observation_path = "/observations" + observation_path
batch["/obs_with_key"] = batch[observation_path]
# Adding predictive keys
batch["/observations_2"] = buffer["/observations"][sample_tabular_2[0],sample_tabular_2[1]]
batch["/obs_with_key_2"] = buffer[observation_path][sample_tabular_2[0],sample_tabular_2[1]]

batch = flatten_first_two(batch)
return batch
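A hedged sketch of what the sampler hunk above produces, using toy shapes. The buffer layout [workers, records, ...] and the one-step offset for the successor observation are inferred from the indexing in the hunk, not spelled out in the diff.

import numpy as np

num_workers, num_records, obs_dim = 2, 5, 3
observation_path = "/agent"
buffer = {
    "/observations/agent": np.random.randn(num_workers, num_records, obs_dim),
    "/masks": np.ones((num_workers, num_records, 1), dtype=np.float32),
}

# Indices of sampled transitions (t) and their successors (t+1), per worker.
worker_idx = np.array([0, 1])
t_idx      = np.array([1, 2])

batch = {key: buffer[key][worker_idx, t_idx] for key in buffer}

full_path = "/observations" + observation_path
batch["/obs_with_key"]   = batch[full_path]                          # o_t for this agent
batch["/obs_with_key_2"] = buffer[full_path][worker_idx, t_idx + 1]  # o_{t+1}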
9 changes: 7 additions & 2 deletions digideep/environment/data_helpers.py
@@ -7,6 +7,7 @@
"""

import numpy as np
from collections import OrderedDict
from digideep.utility.logging import logger


@@ -45,7 +46,7 @@ def flatten_dict(dic, sep="/", prefix=""):
{"/a":1, "/b/c":1, "/b/d/e":2, "/b/d/f":3}
"""
res = {}
res = OrderedDict()
for key, value in dic.items():
if isinstance(value, dict):
tmp = flatten_dict(value, sep=sep, prefix=join_keys(prefix,key,sep))
@@ -324,7 +325,7 @@ def list_of_dicts_to_flattened_dict_of_lists(List, length):
# This is used for info. But can be used for other list of dicts
if isinstance(List, dict):
return List
Dict = {}
Dict = OrderedDict()
for i in range(len(List)):
update_dict_of_lists(Dict, flatten_dict(List[i]), index=i)
# Here, complete_dict_of_list cannot be in the loop.
@@ -337,3 +338,7 @@ def list_of_dicts_to_flattened_dict_of_lists(List, length):
complete_dict_of_list(Dict, length=length)
return Dict

def flattened_dict_of_lists_to_dict_of_numpy(dic):
for key in dic:
dic[key] = np.asarray(dic[key], dtype=np.float32)
return dic
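A toy run of the two helpers touched above: the ordered flattening of per-worker info dicts and the new numpy conversion. The exact padding value for workers with no episode entry isn't shown in this diff, so the None placeholder (and its cast to NaN) is an assumption.

import numpy as np
from collections import OrderedDict

# Per-worker infos as a vectorized env might return them (values invented).
infos = [{"episode": {"r": 1.5}}, {}]

# list_of_dicts_to_flattened_dict_of_lists(infos, length=2) would produce
# something like the following, padding the worker that had no episode entry:
flat = OrderedDict([("/episode/r", [1.5, None])])

def flattened_dict_of_lists_to_dict_of_numpy(dic):
    for key in dic:
        dic[key] = np.asarray(dic[key], dtype=np.float32)   # None casts to nan
    return dic

arrays = flattened_dict_of_lists_to_dict_of_numpy(flat)
# arrays["/episode/r"] -> array([1.5, nan], dtype=float32); report_rewards later
# filters the nan entries with np.isnan.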
3 changes: 2 additions & 1 deletion digideep/environment/dmc2gym/__init__.py
@@ -11,7 +11,8 @@
id=gym_id,
entry_point="digideep.environment.dmc2gym.wrapper:DmControlWrapper",
kwargs={'dmcenv_creator':EnvCreatorSuite(domain_name, task_name, task_kwargs=None, environment_kwargs=None, visualize_reward=True),
'flat_observation':True}
'flat_observation':True, # Should be True
'observation_key':"agent"}
)


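A hedged sketch of what the new observation_key kwarg controls: with flat_observation=True the wrapper calls dm_control's flatten_observation with output_key="agent", so all task observation entries are concatenated into one vector stored under that single key. The entries and shapes below are illustrative.

import numpy as np
from collections import OrderedDict

# A dm_control task observation before flattening (entries illustrative).
raw = OrderedDict([("position", np.zeros(3)), ("velocity", np.zeros(3))])

# Roughly what flatten_observation(raw, output_key="agent") returns:
flat = OrderedDict([("agent", np.concatenate([np.asarray(v).ravel() for v in raw.values()]))])
# -> a single 6-dim vector stored under the configured key "agent"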
41 changes: 34 additions & 7 deletions digideep/environment/dmc2gym/wrapper.py
@@ -11,12 +11,28 @@
import collections
import sys
import copy
import six

from .spec2space import spec2space
from .viewer import Viewer

from dm_control.rl.control import FLAT_OBSERVATION_KEY, flatten_observation, _spec_from_observation
from dm_control.rl.control import flatten_observation
# from dm_control.rl.control import FLAT_OBSERVATION_KEY
from dm_control.rl.control import PhysicsError
from dm_control.rl import specs


def _spec_from_observation(observation):
result = collections.OrderedDict()
for key, value in six.iteritems(observation):
if isinstance(value, collections.OrderedDict):
result[key] = _spec_from_observation(value)
elif isinstance(value, dict):
raise NotImplementedError("'dict' types in observations are not supported as they may not preserve order. Use OrderedDict instead.")
else:
result[key] = specs.ArraySpec(value.shape, value.dtype, name=key)
return result


class DmControlWrapper(Env, EzPickle):
"""Class to convert dm_control environments into gym environments.
@@ -29,9 +45,10 @@ class DmControlWrapper(Env, EzPickle):
A callable object can delay the creation of the environment until the time we need it.
flat_observation (bool): Whether to flatten the observation dict or not.
"""
def __init__(self, dmcenv_creator, flat_observation=False):
def __init__(self, dmcenv_creator, flat_observation=False, observation_key="agent"):
self.dmcenv = dmcenv_creator()
self._flat_observation = flat_observation
self._observation_key = observation_key
# NOTE: We do not use the following to flatten observation to have more control over flattening and extracting "info".
## The next line will flatten the observations if we really need it.
#### self.dmcenv._flat_observation = self._flat_observation
@@ -68,6 +85,9 @@ def _delayed_init(self):
command. Only then we can modify those parameters. All of the stuff in the
"self.spec" are from that category; we should wait until make is called on
the environment and then update those at the first time reset is called.
The attributes which are added by the "_delayed_init" function may be used in
wrappers. However, they shouldn't be used in the wrapper initialization.
"""
if self._delayed_init_flag:
return
@@ -121,10 +141,15 @@ def _get_observation_spec(self):
observation = self.dmcenv.task.get_observation(self.dmcenv.physics)
self._extract_obs_info(observation)
if self._flat_observation:
observation = flatten_observation(observation)
return _spec_from_observation(observation)[FLAT_OBSERVATION_KEY]
else:
return _spec_from_observation(observation)
# observation = flatten_observation(observation)
# return _spec_from_observation(observation)[FLAT_OBSERVATION_KEY]

observation = flatten_observation(observation, output_key=self._observation_key)
# return _spec_from_observation(observation)
# else:
# return _spec_from_observation(observation)
specs = _spec_from_observation(observation)
return specs

def _get_observation(self, timestep):
""" This function will extract the observation from the output of the ``dmcenv.step``'s timestep.
@@ -134,7 +159,9 @@ def _get_observation(self, timestep):
"""
info = self._extract_obs_info(timestep.observation)
if self._flat_observation:
return flatten_observation(timestep.observation)[FLAT_OBSERVATION_KEY], info
# return flatten_observation(timestep.observation)[FLAT_OBSERVATION_KEY], info

return flatten_observation(timestep.observation, output_key=self._observation_key), info
else:
return timestep.observation, info

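Taken together, the wrapper hunks change the observation spec from a single flat spec to an ordered dict of per-key specs; presumably spec2space then maps this to a gym Dict observation space rather than a flat Box. An illustrative sketch, with an assumed 6-dim shape:

import numpy as np
from gym import spaces

# Illustrative result: instead of a flat Box, the wrapper would expose one
# space per observation key (here only the configured "agent" key exists).
observation_space = spaces.Dict({
    "agent": spaces.Box(low=-np.inf, high=np.inf, shape=(6,), dtype=np.float64),
})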
51 changes: 34 additions & 17 deletions digideep/environment/explorer.py
@@ -1,8 +1,10 @@
import numpy as np
import time
from collections import OrderedDict

from digideep.environment import MakeEnvironment
from .data_helpers import flatten_dict, update_dict_of_lists, complete_dict_of_list, convert_time_to_batch_major, extract_keywise, dict_of_lists_to_list_of_dicts, list_of_dicts_to_flattened_dict_of_lists
from .data_helpers import flatten_dict, update_dict_of_lists, complete_dict_of_list, convert_time_to_batch_major, extract_keywise

# from mujoco_py import MujocoException
# from dm_control.rl.control import PhysicsError

@@ -89,15 +91,25 @@ def report_rewards(self, infos):
"""This function will extract episode information from infos and will send them to
:class:`~digideep.utility.monitoring.Monitor` class.
"""
for info in infos:
# print(infos)
if '/episode/r' in infos.keys():
rewards = infos['/episode/r']
for rew in rewards:
if not np.isnan(rew):
self.state["n_episode"] += 1
monitor("/explore/reward/"+self.params["mode"], rew)
# print("Everything is here:", episode)
# # for e in
# # self.state["n_episode"] += 1
# # r = info['episode']['r']
# # monitor("/explore/reward/"+self.params["mode"], r)

# for info in infos:
# This episode keyword only exists if we use a Monitor wrapper.
# This keyword will only appear at the "reset" times.
# TODO: If this is a true multi-agent system, then the rewards
# must be separated as well!
if 'episode' in info.keys():
self.state["n_episode"] += 1
r = info['episode']['r']
monitor("/explore/reward/"+self.params["mode"], r)
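The rewritten report_rewards above expects infos already flattened into a dict of per-worker arrays, with NaN for workers whose episode did not end on this step (the structure is inferred from the np.isnan filter; the numbers are invented):

import numpy as np

infos = {"/episode/r": np.array([np.nan, 12.5], dtype=np.float32)}

if '/episode/r' in infos.keys():
    for rew in infos['/episode/r']:
        if not np.isnan(rew):
            print("finished episode, return =", rew)   # the real code calls monitor(...)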


def close(self):
"""It closes all environments.
@@ -107,9 +119,8 @@ def close(self):
def reset(self):
"""Will reset the Explorer and all of its states. Will set ``was_reset`` to ``True`` to prevent immediate resets.
"""

self.state["observations"] = self.envs.reset()
self.state["masks"] = np.array([[1]]*self.params["num_workers"], dtype=np.float32)
self.state["masks"] = np.array([[0]]*self.params["num_workers"], dtype=np.float32)

# The initial hidden_state is not saved in the memory. The only use for it is
# getting passed to the action_generator.
@@ -135,7 +146,9 @@ def prestep(self, final_step=False):

with KeepTime("/explore/step/prestep/to_numpy"):
# TODO: Is it necessary for conversion of obs?
observations = np.array(self.state["observations"], dtype=np.float32)
# NOTE: The np conversion will not work if observation is a dictionary.
# observations = np.array(self.state["observations"], dtype=np.float32)
observations = self.state["observations"]
masks = self.state["masks"]
hidden_state = self.state["hidden_state"]

@@ -187,16 +200,15 @@ def step(self):

with KeepTime("/explore/step/envstep"):
# Prepare actions
actions_dict = extract_keywise(pre_transition["agents"], "actions")
actions = dict_of_lists_to_list_of_dicts(actions_dict, self.params["num_workers"])
actions = extract_keywise(pre_transition["agents"], "actions")

# Step
self.state["observations"], rewards, dones, infos = self.envs.step(actions)
# Post-step
self.state["hidden_state"] = extract_keywise(pre_transition["agents"], "hidden_state")
self.state["masks"] = np.array([0.0 if done_ else 1.0 for done_ in dones], dtype=np.float32).reshape((-1,1))
rewards = rewards.reshape((-1,1))

# TODO: Adapt with the new dict_of_lists data structure.
with KeepTime("/explore/step/report_reward"):
self.report_rewards(infos)

@@ -212,17 +224,22 @@
# # return self.run()

with KeepTime("/explore/step/poststep"):
if np.isnan(self.state["observations"]).any():
logger.warn('NaN caught in observations during rollout generation.', 'step =', self.state["steps"])
raise ValueError
# TODO: Sometimes the type of observations is "dict" which shouldn't be. Investigate the reason.
if isinstance(self.state["observations"], OrderedDict) or isinstance(self.state["observations"], dict):
for key in self.state["observations"]:
if np.isnan(self.state["observations"][key]).any():
logger.warn('NaN caught in observations during rollout generation.', 'step =', self.state["steps"])
raise ValueError
else:
if np.isnan(self.state["observations"]).any():
logger.warn('NaN caught in observations during rollout generation.', 'step =', self.state["steps"])
raise ValueError
## Retry??
# return self.run()

self.state["steps"] += 1
self.state["timesteps"] += self.params["num_workers"]

infos = list_of_dicts_to_flattened_dict_of_lists(infos, length=self.params["num_workers"])

transition = dict(**pre_transition,
rewards=rewards,
infos=infos)
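Since the observation can now be either a dict of arrays or a single array, the NaN guard above branches on the type. A compact equivalent, written as a standalone helper purely for illustration (the repo keeps the check inline):

import numpy as np

def has_nan(observations):
    """True if any NaN appears in an array or in any value of a dict of arrays."""
    if isinstance(observations, dict):            # OrderedDict is a dict subclass
        return any(np.isnan(v).any() for v in observations.values())
    return np.isnan(observations).any()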
