Skip to content

Commit

Permalink
Merge pull request #7 from cnheider/develop
Browse files Browse the repository at this point in the history
sanitize sanity
  • Loading branch information
cnheider committed Oct 1, 2020
2 parents 2abcdbe + 650952c commit ddb899a
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def __init__(
signal_space: SignalSpace,
policy_weight: float,
weight: float,
intrinsic_reward_integration: float,
intrinsic_signal_factor: float,
hidden_dim: int = 128,
):
"""
Expand All @@ -134,7 +134,7 @@ def __init__(
:param signal_space: used for scaling the intrinsic reward returned by this module. Can be used to control how
the fluctuation scale of the intrinsic signal
:param weight: balances the importance between forward and inverse model
:param intrinsic_reward_integration: balances the importance between extrinsic and intrinsic signal.
:param intrinsic_signal_factor: balances the importance between extrinsic and intrinsic signal.
"""

assert (
Expand All @@ -148,7 +148,7 @@ def __init__(
self.policy_weight = policy_weight
self.reward_scale = signal_space.span
self.weight = weight
self.intrinsic_signal_integration = intrinsic_reward_integration
self.intrinsic_signal_factor = intrinsic_signal_factor

self.encoder = nn.Sequential(
nn.Linear(observation_space.shape[0], hidden_dim),
Expand Down Expand Up @@ -233,8 +233,8 @@ def sample(
writer.scalar("icm/signal", intrinsic_signal.mean().item())

return (
1.0 - self.intrinsic_signal_integration
) * signals + self.intrinsic_signal_integration * intrinsic_signal
1.0 - self.intrinsic_signal_factor
) * signals + self.intrinsic_signal_factor * intrinsic_signal

def loss(
self,
Expand Down
5 changes: 4 additions & 1 deletion tests/test_sanity.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
import pytest

__author__ = "Christian Heider Nielsen"
__doc__ = r"""
Created on 01/08/2020
"""

def test_sanity():
assert True
Expand All @@ -21,7 +24,7 @@ def test_print(capsys):
print(text)
sys.stderr.write("world")
captured = capsys.readouterr()
assert text in captured.head
assert text in captured.out
assert err in captured.err


Expand Down

0 comments on commit ddb899a

Please sign in to comment.