From 77220de89b6c522fa1a698e2e78f2342866843d4 Mon Sep 17 00:00:00 2001 From: Srikanth MG Date: Wed, 20 Jul 2022 10:09:17 +0200 Subject: [PATCH 1/3] Added test for BinarizeReward --- test/test_transforms.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index a597d5afc4b..c4086d4228d 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -31,6 +31,7 @@ CatTensors, FlattenObservation, RewardScaling, + BinarizeReward ) from torchrl.envs.libs.gym import _has_gym from torchrl.envs.transforms import VecNorm, TransformedEnv @@ -874,8 +875,26 @@ def test_noop_reset_env(self, random, device, compose): assert transformed_env.step_count == 30 @pytest.mark.parametrize("device", get_available_devices()) - def test_binerized_reward(self, device): - pass + @pytest.mark.parametrize("batch", [[], [4], [6, 4]]) + def test_binarized_reward(self, device, batch): + torch.manual_seed(0) + br = BinarizeReward() + reward = torch.randn(*batch, 1, device=device) + reward_copy = reward.clone() + misc = torch.randn(*batch, 1, device=device) + misc_copy = misc.clone() + + td = TensorDict( + { + "misc": misc, + "reward": reward, + }, + batch, + ) + br(td) + assert td["reward"] is reward + assert (td["reward"] != reward_copy).all() + assert (td["misc"] == misc_copy).all() @pytest.mark.parametrize("batch", [[], [2], [2, 4]]) @pytest.mark.parametrize("scale", [0.1, 10]) From fe4400c358a1eecc80eed0af281653ea40559c2d Mon Sep 17 00:00:00 2001 From: Srikanth MG Date: Wed, 20 Jul 2022 10:21:21 +0200 Subject: [PATCH 2/3] Fixed Lint errors --- test/test_transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index c4086d4228d..119bac4f567 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -31,7 +31,7 @@ CatTensors, FlattenObservation, RewardScaling, - BinarizeReward + BinarizeReward, ) from torchrl.envs.libs.gym import _has_gym from torchrl.envs.transforms import VecNorm, TransformedEnv From 1833a3dbf39afdc635fd4c357ff172b974a1c02f Mon Sep 17 00:00:00 2001 From: Srikanth MG Date: Wed, 20 Jul 2022 13:36:13 +0200 Subject: [PATCH 3/3] Added test for binarization of reward --- test/test_transforms.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_transforms.py b/test/test_transforms.py index 119bac4f567..18e84c53bb4 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -895,6 +895,7 @@ def test_binarized_reward(self, device, batch): assert td["reward"] is reward assert (td["reward"] != reward_copy).all() assert (td["misc"] == misc_copy).all() + assert (torch.count_nonzero(td["reward"]) == torch.sum(reward_copy > 0)).all() @pytest.mark.parametrize("batch", [[], [2], [2, 4]]) @pytest.mark.parametrize("scale", [0.1, 10])