Skip to content
Branch: master
Find file Copy path
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
83 lines (70 sloc) 2.15 KB
import gym
import numpy as np
class SavableEnv(gym.Wrapper):
    """Wrapper that can snapshot and restore the underlying emulator state.

    ``save`` stores ``env.unwrapped.clone_state()`` and returns it. ``load``
    restores either a checkpoint passed in or the last one stored.
    """

    def __init__(self, env):
        super().__init__(env)
        # Last checkpoint produced by save() or passed to load(); None until then.
        self._checkpoint = None

    def reset(self, **kwargs):
        """Reset the wrapped environment, forwarding all keyword arguments."""
        return self.env.reset(**kwargs)

    def step(self, action):
        """Step the wrapped environment unchanged."""
        return self.env.step(action)

    def save(self):
        """Snapshot the raw emulator state, remember it, and return it."""
        self._checkpoint = self.env.unwrapped.clone_state()
        return self._checkpoint

    def load(self, checkpoint=None):
        """Restore the emulator to ``checkpoint``, or to the last saved one.

        Args:
            checkpoint: a value previously returned by ``save()`` (or by
                ``clone_state()``). When given, it also becomes the stored
                checkpoint.

        Raises:
            ValueError: if no checkpoint is given and none was saved before.
        """
        # Compare against None explicitly: clone_state() may return an array
        # whose truth value is ambiguous, so a bare ``if checkpoint:`` can raise.
        if checkpoint is not None:
            self._checkpoint = checkpoint
        if self._checkpoint is None:
            raise ValueError("no checkpoint to load: call save() first or pass one")
        self.env.unwrapped.restore_state(self._checkpoint)
class ObsTreeNode:
    """One node of a look-ahead tree built by stepping an emulator from saved states.

    Each non-terminal node records the emulator state it was reached in
    (``env.clone_state()``) and the observation at that point
    (``env._get_obs()``). When ``height > 0`` the node expands one child per
    discrete action: it restores its own state, steps that action, and builds
    the child with ``height - 1``.

    Args:
        env: environment to expand from. Anything that is not an ``AtariEnv``
            is replaced by ``env.unwrapped``.
        height: remaining depth below this node; 0 makes a leaf.
        parent: parent node, or None for the root.
        action: action taken from ``parent`` to reach this node.
        reward: reward returned by that step.
        terminal: whether that step reported a terminal state. Terminal nodes
            store no state or observation and are never expanded.
        info: info object returned by that step.
        total_reward: the parent's cumulative reward; this node adds ``reward``.
        tree_height: height of the whole tree. A negative value means "use
            ``height``"; children inherit their parent's value.
    """

    def __init__(self, env, height=0, parent=None,
                 action=0, reward=0, terminal=False, info=None,
                 total_reward=0, tree_height=-1):
        # clone_state/restore_state/_get_obs are called on the raw emulator,
        # so strip wrappers unless we already hold the AtariEnv itself.
        if not isinstance(env, gym.envs.atari.atari_env.AtariEnv):
            env = env.unwrapped

        self.height = height
        if tree_height < 0:
            self.tree_height = self.height
        else:
            self.tree_height = tree_height

        self.parent = parent
        self.is_root = parent is None
        self.is_child = not self.is_root
        self.n_children = env.action_space.n

        # Heap-style index in a complete n-ary tree: the root is 0 and the
        # children of node i are i*n + 1 ... i*n + n, ordered by action.
        # NOTE(review): this line was garbled in the scraped source; confirm
        # the attribute name against callers.
        self.node_id = 0
        if self.is_child:
            self.node_id = parent.node_id * self.n_children + 1 + action

        self.action = action  # edge connected from parent
        self.reward = reward
        # Undiscounted sum of rewards along the path from the root.
        self.total_reward = total_reward + reward
        self.terminal = terminal
        self.info = info

        self.state = None
        self.obs = None
        if not terminal:
            self.state = env.clone_state()
            self.obs = env._get_obs()

        self.children = None
        if not terminal and height > 0:
            self._extends(env)
        self.is_leaf = self.children is None

    def _extends(self, env):
        """Create one child per action, each stepped from this node's state.

        The emulator is restored to ``self.state`` before every action so all
        siblings branch from the same point. It is restored once more at the
        end so the caller's environment is left where this node found it.
        """
        self.children = []
        for a in range(self.n_children):
            env.restore_state(self.state)
            _, reward, terminal, info = env.step(a)
            child = ObsTreeNode(env, self.height - 1, self,
                                a, reward, terminal, info,
                                self.total_reward, self.tree_height)
            self.children.append(child)
        env.restore_state(self.state)
class ObsTreeEnv(gym.Wrapper):
    """Wrapper whose observations are look-ahead trees instead of frames.

    After ``reset`` and every ``step``, the wrapper discards the raw
    observation and returns an ``ObsTreeNode`` rooted at the current emulator
    state and expanded ``depth`` levels deep.
    """

    def __init__(self, env, depth=0):
        super().__init__(env)
        self.depth = depth  # expansion depth of each returned tree; 0 = root only

    def reset(self, **kwargs):
        """Reset the wrapped env and return the tree for the initial state."""
        _ = self.env.reset(**kwargs)
        return self._get_tree()

    def step(self, action):
        """Step the wrapped env; the observation is replaced by a fresh tree."""
        _, reward, terminal, info = self.env.step(action)
        return self._get_tree(), reward, terminal, info

    def _get_tree(self):
        # Must use the wrapped env; a bare ``env`` here is an undefined name.
        return ObsTreeNode(self.env, self.depth)
You can’t perform that action at this time.