     CatTensors,
     FlattenObservation,
     RewardScaling,
+    BinarizeReward,
 )
 from torchrl.envs.libs.gym import _has_gym
 from torchrl.envs.transforms import VecNorm, TransformedEnv
@@ -874,8 +875,27 @@ def test_noop_reset_env(self, random, device, compose):
         assert transformed_env.step_count == 30
 
     @pytest.mark.parametrize("device", get_available_devices())
-    def test_binerized_reward(self, device):
-        pass
+    @pytest.mark.parametrize("batch", [[], [4], [6, 4]])
+    def test_binarized_reward(self, device, batch):
+        torch.manual_seed(0)
+        br = BinarizeReward()
+        reward = torch.randn(*batch, 1, device=device)
+        reward_copy = reward.clone()
+        misc = torch.randn(*batch, 1, device=device)
+        misc_copy = misc.clone()
+
+        td = TensorDict(
+            {
+                "misc": misc,
+                "reward": reward,
+            },
+            batch,
+        )
+        br(td)
+        assert td["reward"] is reward
+        assert (td["reward"] != reward_copy).all()
+        assert (td["misc"] == misc_copy).all()
+        assert (torch.count_nonzero(td["reward"]) == torch.sum(reward_copy > 0)).all()
 
     @pytest.mark.parametrize("batch", [[], [2], [2, 4]])
     @pytest.mark.parametrize("scale", [0.1, 10])
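For readers unfamiliar with the transform under test, here is a minimal standalone sketch mirroring the test body above. The import paths are assumptions (they depend on the installed torchrl version and are not shown in this patch), and the expected behavior is inferred from the test's assertions: the transform rewrites the "reward" entry in place, and only strictly positive rewards remain nonzero afterwards.

```python
# Minimal usage sketch, not part of the patch.
import torch

from torchrl.data import TensorDict                   # assumed import path
from torchrl.envs.transforms import BinarizeReward    # assumed import path

br = BinarizeReward()  # per the test, defaults to the "reward" key and works in place
td = TensorDict(
    {"reward": torch.tensor([[-0.5], [0.0], [2.0]])},
    batch_size=[3],
)
br(td)
# Per the count_nonzero assertion in the test, only the strictly positive
# entry stays nonzero, i.e. roughly [[0.], [0.], [1.]].
print(td["reward"])
```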