Skip to content

Commit

Permalink
Update metric: use numpify, remove redundant utils
Browse files Browse the repository at this point in the history
  • Loading branch information
zuoxingdong committed Jul 1, 2019
1 parent 0821079 commit 5174394
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 33 deletions.
9 changes: 4 additions & 5 deletions lagom/metric/returns.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import numpy as np

from lagom.transform import geometric_cumsum

from .utils import _wrap_last_V
from lagom.utils import numpify


def returns(gamma, traj):
Expand All @@ -23,10 +22,10 @@ def bootstrapped_returns(gamma, traj, last_V):
The state values for terminal states are masked out as zero !
"""
last_V = _wrap_last_V(last_V)
last_V = numpify(last_V, np.float32).item()

if traj.reach_terminal:
out = geometric_cumsum(gamma, traj.rewards + [0.0])
out = geometric_cumsum(gamma, np.append(traj.rewards, 0.0))
else:
out = geometric_cumsum(gamma, traj.rewards + [last_V])
out = geometric_cumsum(gamma, np.append(traj.rewards, last_V))
return out[0, :-1].astype(np.float32)
11 changes: 5 additions & 6 deletions lagom/metric/td.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import numpy as np

from .utils import _wrap_Vs
from .utils import _wrap_last_V
from lagom.utils import numpify


def td0_target(gamma, traj, Vs, last_V):
Expand All @@ -19,8 +18,8 @@ def td0_target(gamma, traj, Vs, last_V):
The state values for terminal states are masked out as zero !
"""
Vs = _wrap_Vs(Vs)
last_V = _wrap_last_V(last_V)
Vs = numpify(Vs, np.float32)
last_V = numpify(last_V, np.float32)

if traj.reach_terminal:
Vs = np.append(Vs, 0.0)
Expand All @@ -45,8 +44,8 @@ def td0_error(gamma, traj, Vs, last_V):
The state values for terminal states are masked out as zero !
"""
Vs = _wrap_Vs(Vs)
last_V = _wrap_last_V(last_V)
Vs = numpify(Vs, np.float32)
last_V = numpify(last_V, np.float32)

if traj.reach_terminal:
Vs = np.append(Vs, 0.0)
Expand Down
22 changes: 0 additions & 22 deletions lagom/metric/utils.py

This file was deleted.

0 comments on commit 5174394

Please sign in to comment.