Skip to content

Commit

Permalink
Update Trajectory and its docs and test
Browse files Browse the repository at this point in the history
  • Loading branch information
zuoxingdong committed Sep 5, 2018
1 parent c44c009 commit 88e2511
Show file tree
Hide file tree
Showing 3 changed files with 174 additions and 67 deletions.
45 changes: 45 additions & 0 deletions lagom/runner/base_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ def all_returns(self):
This behaves differently for :class:`Trajectory` and :class:`Segment`.
.. note::
It returns raw values instead of Tensor dtype, not to be used for backprop.
"""
raise NotImplementedError

Expand All @@ -105,6 +109,10 @@ def all_discounted_returns(self):
This behaves differently for :class:`Trajectory` and :class:`Segment`.
.. note::
It returns raw values instead of Tensor dtype, not to be used for backprop.
"""
raise NotImplementedError

Expand All @@ -118,10 +126,18 @@ def all_bootstrapped_returns(self):
.. math::
Q_t = r_t + r_{t+1} + \dots + r_T + V(s_{T+1})
.. note::
The state value for terminal state is set as zero !
.. note::
This behaves differently for :class:`Trajectory` and :class:`Segment`.
.. note::
It returns raw values instead of Tensor dtype, not to be used for backprop.
"""
raise NotImplementedError

Expand All @@ -134,10 +150,18 @@ def all_bootstrapped_discounted_returns(self):
.. math::
Q_t = r_t + \gamma r_{t+1} + \dots + \gamma^{T - t} r_T + \gamma^{T - t + 1} V(s_{T+1})
.. note::
The state value for terminal state is set as zero !
.. note::
This behaves differently for :class:`Trajectory` and :class:`Segment`.
.. note::
It returns raw values instead of Tensor dtype, not to be used for backprop.
"""
raise NotImplementedError

Expand All @@ -149,6 +173,11 @@ def all_V(self):
This behaves differently for :class:`Trajectory` and :class:`Segment`.
.. note::
It returns Tensor dtype, used for backprop to train value function. It does not set
zero value for terminal state !
"""
raise NotImplementedError

Expand All @@ -162,20 +191,36 @@ def all_TD(self):
.. math::
\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)
.. note::
The state value for terminal state is set as zero !
.. note::
This behaves differently for :class:`Trajectory` and :class:`Segment`.
.. note::
It returns raw values instead of Tensor dtype, not to be used for backprop.
"""
raise NotImplementedError

def all_GAE(self, gae_lambda):
r"""Return a list of all `generalized advantage estimates`_ (GAE) in the history including
the terminal states.
.. note::
The state value for terminal state is set as zero !
.. note::
This behaves differently for :class:`Trajectory` and :class:`Segment`.
.. note::
It returns raw values instead of Tensor dtype, not to be used for backprop.
.. _generalized advantage estimates:
https://arxiv.org/abs/1506.02438
Expand Down
79 changes: 61 additions & 18 deletions lagom/runner/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,27 @@ class Trajectory(BaseHistory):
Example::
>>> transition1 = Transition(s=1, a=0.1, r=0.5, s_next=2, done=False)
>>> transition1.add_info(name='V_s', value=10.0)
>>> transition2 = Transition(s=2, a=0.2, r=0.5, s_next=3, done=False)
>>> transition2.add_info(name='V_s', value=20.0)
>>> transition3 = Transition(s=3, a=0.3, r=1.0, s_next=4, done=True)
>>> transition3.add_info(name='V_s', value=30.0)
>>> transition3.add_info(name='V_s_next', value=40.0)
>>> trajectory = Trajectory(gamma=0.1)
>>> trajectory.add_transition(transition1)
>>> trajectory.add_transition(transition2)
>>> trajectory.add_transition(transition3)
>>> trajectory.all_discounted_returns
[0.56, 0.6, 1.0]
>>> trajectory.all_TD
[-7.5, -16.5, -29.0]
"""
def add_transition(self, transition):
# Sanity check for trajectory
Expand All @@ -41,34 +62,55 @@ def all_returns(self):
def all_discounted_returns(self):
    r"""Return discounted returns :math:`G_t` for every time step.

    Raw values (not Tensor dtype) — not meant for backprop.
    """
    discounted_cumsum = ExpFactorCumSum(self.gamma)
    return discounted_cumsum(self.all_r)

@property
def all_V(self):
return [transition.V_s for transition in self.transitions] + [self.transitions[-1].V_s_next]
def _rewards_with_bootstrapping(self):
    r"""Return all rewards with the bootstrapped state value appended.

    The appended value is ``V_s_next`` of the final transition, unwrapped
    to a raw float when it is a Tensor, and replaced by zero when the
    final transition is terminal.
    """
    final = self.transitions[-1]
    bootstrap_value = final.V_s_next
    # Unwrap Tensor dtype to a raw Python float
    if torch.is_tensor(bootstrap_value):
        bootstrap_value = bootstrap_value.item()
    assert isinstance(bootstrap_value, float), f'expected float dtype, got {type(bootstrap_value)}'

    # A terminal state has zero value by definition
    if final.done:
        bootstrap_value = 0.0

    return self.all_r + [bootstrap_value]

@property
def all_TD(self):
r"""
Return a list of TD errors for all time steps.
def all_bootstrapped_returns(self):
bootstrapped_rewards = self._rewards_with_bootstrapping()

It requires that each transition has the information with key 'V_s' and
last transition with both 'V_s' and 'V_s_next'.
out = ExpFactorCumSum(1.0)(bootstrapped_rewards)
# Take out last one, because it is just last state value itself
out = out[:-1]

If last transition with done=True, then V_s_next should be zero as terminal state value.
return out

TD error is calculated as following:
\delta_t = r_{t+1} + \gamma V(s_{t+1}) - V(s_t)
@property
def all_bootstrapped_discounted_returns(self):
bootstrapped_rewards = self._rewards_with_bootstrapping()

Note that we would like to use raw float dtype, rather than Tensor.
Because we often do not backprop via TD error.
"""
out = ExpFactorCumSum(self.gamma)(bootstrapped_rewards)
# Take out last one, because it is just last state value itself
out = out[:-1]

return out

@property
def all_V(self):
    r"""Return all state values, including ``V_s_next`` of the final transition.

    Tensor dtypes are returned as-is (usable for backprop); the terminal
    state value is NOT zeroed here.
    """
    values = [t.V_s for t in self.transitions]
    values.append(self.transitions[-1].V_s_next)
    return values

@property
def all_TD(self):
# Get all rewards
all_r = np.array(self.all_r)

# Get all state values
# Retrieve raw value if dtype is Tensor
# Get all state values with raw values if Tensor dtype
all_V = np.array([v.item() if torch.is_tensor(v) else v for v in self.all_V])
if self.all_done[-1]: # value of terminal state is zero
assert all_V[-1] == 0.0
# Set last state value as zero if terminal state
if self.all_done[-1]:
all_V[-1] = 0.0

# Unpack state values into current and next time step
all_V_s = all_V[:-1]
Expand All @@ -81,4 +123,5 @@ def all_TD(self):

def all_GAE(self, gae_lambda):
    r"""Return all generalized advantage estimates (GAE) for the trajectory.

    Args:
        gae_lambda (float): lambda decay factor for GAE.

    Raises:
        NotImplementedError: not yet implemented.
    """
    # NOTE: dropped the erroneous @property decorator — a property's getter
    # cannot accept the extra `gae_lambda` argument (attribute access would
    # raise TypeError), and the base class declares all_GAE as a plain method.
    # TODO: implement it + add to test_runner
    raise NotImplementedError
117 changes: 68 additions & 49 deletions test/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,74 @@ def test_transition(self):
assert transition.V_s_next == 10.0
assert np.allclose(transition.info['extra'], [1, 2, 3, 4])

def test_trajectory(self):
    """Exercise Trajectory end to end with gamma=0.1: metadata, transition
    accumulation, rejection of transitions after a terminal one, and all
    derived quantities (returns, discounted/bootstrapped returns, state
    values, TD errors) for both a terminal and a non-terminal final step.
    """
    # 3-step trajectory: rewards [0.5, 0.5, 1.0], terminal at step 3
    transition1 = Transition(s=1, 
                             a=0.1, 
                             r=0.5, 
                             s_next=2, 
                             done=False)
    transition1.add_info(name='V_s', value=10.0)

    transition2 = Transition(s=2, 
                             a=0.2, 
                             r=0.5, 
                             s_next=3, 
                             done=False)
    transition2.add_info(name='V_s', value=20.0)

    transition3 = Transition(s=3, 
                             a=0.3, 
                             r=1.0, 
                             s_next=4, 
                             done=True)
    transition3.add_info(name='V_s', value=30.0)
    transition3.add_info(name='V_s_next', value=40.0)  # Note that here non-zero value

    trajectory = Trajectory(gamma=0.1)

    # Fresh trajectory: no info, no transitions
    assert trajectory.gamma == 0.1
    assert len(trajectory.info) == 0
    assert trajectory.T == 0

    trajectory.add_info(name='extra', value=[1, 2, 3])
    assert len(trajectory.info) == 1
    assert np.allclose(trajectory.info['extra'], [1, 2, 3])

    trajectory.add_transition(transition=transition1)
    trajectory.add_transition(transition=transition2)
    trajectory.add_transition(transition=transition3)

    assert trajectory.T == 3

    # Test error to add one more transition, not allowed because last transition already done=True
    transition4 = Transition(s=0.1, 
                             a=0.1, 
                             r=1.0, 
                             s_next=0.2, 
                             done=False)
    with pytest.raises(AssertionError):
        trajectory.add_transition(transition=transition4)

    assert np.allclose(trajectory.all_s, [1, 2, 3, 4])
    assert np.allclose(trajectory.all_a, [0.1, 0.2, 0.3])
    assert np.allclose(trajectory.all_r, [0.5, 0.5, 1.0])
    assert np.allclose(trajectory.all_done, [False, False, True])
    assert np.allclose(trajectory.all_returns, [2.0, 1.5, 1.0])
    # gamma=0.1: first entry is 0.5 + 0.1*0.5 + 0.1^2*1.0 = 0.56
    assert np.allclose(trajectory.all_discounted_returns, [0.56, 0.6, 1.0])
    # done=True zeroes the bootstrap value, so these equal the plain returns
    assert np.allclose(trajectory.all_bootstrapped_returns, [2.0, 1.5, 1.0])
    assert np.allclose(trajectory.all_bootstrapped_discounted_returns, [0.56, 0.6, 1.0])
    # all_V does NOT zero the terminal state value, hence the trailing 40
    assert np.allclose(trajectory.all_V, [10, 20, 30, 40])
    # TD error r_t + gamma*V(s_{t+1}) - V(s_t), e.g. 0.5 + 0.1*20 - 10 = -7.5;
    # terminal next-state value counts as zero: 1.0 + 0 - 30 = -29
    assert np.allclose(trajectory.all_TD, [-7.5, -16.5, -29])
    assert np.allclose(trajectory.all_info(name='V_s'), [10, 20, 30])

    # Make last transition: done=False
    trajectory.transitions[-1].done = False
    assert np.allclose(trajectory.all_done, [False, False, False])
    # Bootstrap value 40 is now included, e.g. 0.5 + 0.5 + 1.0 + 40 = 42
    assert np.allclose(trajectory.all_bootstrapped_returns, [42, 41.5, 41])
    # e.g. 0.5 + 0.1*0.5 + 0.1^2*1.0 + 0.1^3*40 = 0.6
    assert np.allclose(trajectory.all_bootstrapped_discounted_returns, [0.6, 1, 5])
    assert np.allclose(trajectory.all_V, [10, 20, 30, 40])
    # last TD becomes 1.0 + 0.1*40 - 30 = -25
    assert np.allclose(trajectory.all_TD, [-7.5, -16.5, -25])

def test_segment(self):
# All test cases in the following use 4 transitions
# states: 10, 20, ...
Expand Down Expand Up @@ -355,55 +423,6 @@ def test_segment(self):
assert np.allclose(segment.all_V, [100, 150, 200, 300, 400, 500])
assert np.allclose(segment.all_TD, [-99, -168, -257, -396])

def test_trajectory(self):
    """Exercise Trajectory with gamma=0.1 where the final transition is
    terminal and its V_s_next is explicitly zero (terminal state value).
    """
    # 3-step trajectory: rewards [0.5, 0.5, 1.0], terminal at step 3
    transition1 = Transition(s=1, 
                             a=0.1, 
                             r=0.5, 
                             s_next=2, 
                             done=False)
    transition1.add_info(name='V_s', value=10)

    transition2 = Transition(s=2, 
                             a=0.2, 
                             r=0.5, 
                             s_next=3, 
                             done=False)
    transition2.add_info(name='V_s', value=20)

    transition3 = Transition(s=3, 
                             a=0.3, 
                             r=1.0, 
                             s_next=4, 
                             done=True)
    transition3.add_info(name='V_s', value=30)
    # Terminal state: value explicitly zero
    transition3.add_info(name='V_s_next', value=0.0)

    trajectory = Trajectory(gamma=0.1)

    # Fresh trajectory: no info, no transitions
    assert trajectory.gamma == 0.1
    assert len(trajectory.info) == 0
    assert trajectory.T == 0

    trajectory.add_info(name='extra', value=[1, 2, 3])
    assert len(trajectory.info) == 1
    assert np.allclose(trajectory.info['extra'], [1, 2, 3])

    trajectory.add_transition(transition=transition1)
    trajectory.add_transition(transition=transition2)
    trajectory.add_transition(transition=transition3)

    assert trajectory.T == 3

    assert np.allclose(trajectory.all_s, [1, 2, 3, 4])
    assert np.allclose(trajectory.all_a, [0.1, 0.2, 0.3])
    assert np.allclose(trajectory.all_r, [0.5, 0.5, 1.0])
    assert np.allclose(trajectory.all_done, [False, False, True])
    assert np.allclose(trajectory.all_returns, [2.0, 1.5, 1.0])
    # gamma=0.1: first entry is 0.5 + 0.1*0.5 + 0.1^2*1.0 = 0.56
    assert np.allclose(trajectory.all_discounted_returns, [0.56, 0.6, 1.0])
    assert np.allclose(trajectory.all_V, [10, 20, 30, 0])
    # TD error r_t + gamma*V(s_{t+1}) - V(s_t), e.g. 0.5 + 0.1*20 - 10 = -7.5
    assert np.allclose(trajectory.all_TD, [-7.5, -16.5, -29])
    assert np.allclose(trajectory.all_info(name='V_s'), [10, 20, 30])

def test_trajectoryrunner(self):
def helper(agent, env):
env = GymEnv(env)
Expand Down

0 comments on commit 88e2511

Please sign in to comment.