@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("batch", [[], [4], [6, 4]])
def test_binarized_reward(self, device, batch):
    """Check that ``BinarizeReward`` rewrites only the ``"reward"`` entry.

    A tensordict holding a random ``"reward"`` and an unrelated ``"misc"``
    entry (both of shape ``(*batch, 1)``) is passed through the transform.
    The test then verifies that:

    * the tensor stored under ``"reward"`` is still the very same object,
      i.e. the update happened in place;
    * every reward value was changed by the transform;
    * ``"misc"`` is left untouched;
    * the reward is non-zero exactly where the original reward was
      strictly positive.

    Args:
        device: device on which the test tensors are created.
        batch: batch size of the tensordict (a list of ints, possibly empty).
    """
    # Fixed seed so the sampled rewards are reproducible across runs.
    torch.manual_seed(0)
    br = BinarizeReward()
    reward = torch.randn(*batch, 1, device=device)
    reward_copy = reward.clone()
    misc = torch.randn(*batch, 1, device=device)
    misc_copy = misc.clone()

    td = TensorDict(
        {
            "misc": misc,
            "reward": reward,
        },
        batch,
    )
    br(td)

    # The transform must write into the existing tensor, not swap it out.
    assert td["reward"] is reward
    # Every sampled reward is expected to differ from its transformed value.
    assert (td["reward"] != reward_copy).all()
    # Entries other than the reward must not be modified.
    assert (td["misc"] == misc_copy).all()
    # Aggregate check: as many non-zero outputs as positive inputs.
    assert (torch.count_nonzero(td["reward"]) == torch.sum(reward_copy > 0)).all()
    # Element-wise check: the count alone would also pass if the right
    # number of entries were set at the wrong positions.
    assert ((td["reward"] != 0) == (reward_copy > 0)).all()