Skip to content

Commit 0dd4594

Browse files
authored
[Test] tests for BinarizeReward (#302)
1 parent d449ede commit 0dd4594

File tree

1 file changed

+22
-2
lines changed

1 file changed

+22
-2
lines changed

test/test_transforms.py

Lines changed: 22 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -31,6 +31,7 @@
31 31
CatTensors,
32 32
FlattenObservation,
33 33
RewardScaling,
34+
BinarizeReward,
34 35
)
35 36
from torchrl.envs.libs.gym import _has_gym
36 37
from torchrl.envs.transforms import VecNorm, TransformedEnv
@@ -874,8 +875,27 @@ def test_noop_reset_env(self, random, device, compose):
874 875
assert transformed_env.step_count == 30
875 876

876 877
@pytest.mark.parametrize("device", get_available_devices())
877-
def test_binerized_reward(self, device):
878-
pass
878+
@pytest.mark.parametrize("batch", [[], [4], [6, 4]])
879+
def test_binarized_reward(self, device, batch):
880+
torch.manual_seed(0)
881+
br = BinarizeReward()
882+
reward = torch.randn(*batch, 1, device=device)
883+
reward_copy = reward.clone()
884+
misc = torch.randn(*batch, 1, device=device)
885+
misc_copy = misc.clone()
886+
887+
td = TensorDict(
888+
{
889+
"misc": misc,
890+
"reward": reward,
891+
},
892+
batch,
893+
)
894+
br(td)
895+
assert td["reward"] is reward
896+
assert (td["reward"] != reward_copy).all()
897+
assert (td["misc"] == misc_copy).all()
898+
assert (torch.count_nonzero(td["reward"]) == torch.sum(reward_copy > 0)).all()
879 899

880 900
@pytest.mark.parametrize("batch", [[], [2], [2, 4]])
881 901
@pytest.mark.parametrize("scale", [0.1, 10])

0 commit comments

Comments
 (0)