From 1f403fdc3c39870a00c23588966b9223bc0ae226 Mon Sep 17 00:00:00 2001 From: albert bou Date: Sat, 24 Dec 2022 18:53:21 +0100 Subject: [PATCH 01/13] minor fix --- torchrl/collectors/collectors.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 1468babedae..0bf50a9cae8 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -457,7 +457,9 @@ def __init__( # See #505 for additional context. with torch.no_grad(): self._tensordict_out = env.fake_tensordict() + self._tensordict_out = self._tensordict_out.to(self.device) self._tensordict_out = self.policy(self._tensordict_out).unsqueeze(-1) + self._tensordict_out = self._tensordict_out.to(self.env_device) self._tensordict_out = ( self._tensordict_out.expand(*env.batch_size, self.frames_per_batch) .to_tensordict() From e4b109e1ce338439a41378cdcb30e58958851c6c Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 27 Dec 2022 11:25:42 +0100 Subject: [PATCH 02/13] added tests --- test/test_collector.py | 54 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/test/test_collector.py b/test/test_collector.py index b12e974097a..e103f1fdf7a 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -993,6 +993,60 @@ def test_collector_output_keys(collector_class, init_random_frames, explicit_spe del collector +@pytest.mark.parametrize("device", ["cuda", "cpu"]) +@pytest.mark.parametrize("passing_device", ["cuda", "cpu"]) +def test_collector_device_combinations(device, passing_device): + def env_fn(seed): + env = make_make_env("conv")() + env.set_seed(seed) + return env + policy = dummypolicy_conv() + + collector = SyncDataCollector( + create_env_fn=env_fn, + create_env_kwargs={"seed": 0}, + policy=policy, + frames_per_batch=20, + max_frames_per_traj=2000, + total_frames=20000, + device=device, + passing_device=passing_device, + pin_memory=False, + ) + batch = next(collector.iterator()) + assert batch.device == torch.device(passing_device) or batch["done"].device + collector.shutdown() + + collector = MultiSyncDataCollector( + create_env_fn=[env_fn, ], + create_env_kwargs=[{"seed": 0}, ], + policy=policy, + frames_per_batch=20, + max_frames_per_traj=2000, + total_frames=20000, + devices=[device, ], + passing_devices=[passing_device, ], + pin_memory=False, + ) + batch = next(collector.iterator()) + assert batch.device == torch.device(passing_device) or batch["done"].device + collector.shutdown() + + collector = MultiaSyncDataCollector( + create_env_fn=[env_fn, ], + create_env_kwargs=[{"seed": 0}, ], + policy=policy, + frames_per_batch=20, + max_frames_per_traj=2000, + total_frames=20000, + devices=[device, ], + passing_devices=[passing_device, ], + pin_memory=False, + ) + batch = next(collector.iterator()) + assert batch.device == torch.device(passing_device) or batch["done"].device + collector.shutdown() + @pytest.mark.skipif(not _has_gym, reason="test designed with GymEnv") @pytest.mark.parametrize( "collector_class", From 4bcef8cd9eea38eea0a523752c3b6ca9a691dcf2 Mon Sep 17 00:00:00 2001 From: albertbou92 Date: Tue, 27 Dec 2022 11:27:30 +0100 Subject: [PATCH 03/13] format --- test/test_collector.py | 34 ++++++++++++++++++++++++++-------- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/test/test_collector.py b/test/test_collector.py index e103f1fdf7a..5436bced773 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -1000,6 +1000,7 @@ def env_fn(seed): env = make_make_env("conv")() env.set_seed(seed) return env + policy = dummypolicy_conv() collector = SyncDataCollector( @@ -1018,14 +1019,22 @@ def env_fn(seed): collector.shutdown() collector = MultiSyncDataCollector( - create_env_fn=[env_fn, ], - create_env_kwargs=[{"seed": 0}, ], + create_env_fn=[ + env_fn, + ], + create_env_kwargs=[ + {"seed": 0}, + ], policy=policy, frames_per_batch=20, max_frames_per_traj=2000, total_frames=20000, - devices=[device, ], - passing_devices=[passing_device, ], + devices=[ + device, + ], + passing_devices=[ + passing_device, + ], pin_memory=False, ) batch = next(collector.iterator()) @@ -1033,20 +1042,29 @@ def env_fn(seed): collector.shutdown() collector = MultiaSyncDataCollector( - create_env_fn=[env_fn, ], - create_env_kwargs=[{"seed": 0}, ], + create_env_fn=[ + env_fn, + ], + create_env_kwargs=[ + {"seed": 0}, + ], policy=policy, frames_per_batch=20, max_frames_per_traj=2000, total_frames=20000, - devices=[device, ], - passing_devices=[passing_device, ], + devices=[ + device, + ], + passing_devices=[ + passing_device, + ], pin_memory=False, ) batch = next(collector.iterator()) assert batch.device == torch.device(passing_device) or batch["done"].device collector.shutdown() + @pytest.mark.skipif(not _has_gym, reason="test designed with GymEnv") @pytest.mark.parametrize( "collector_class", From 9e86208b83dc0e07048717fe6367b74123ec6dbf Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 27 Dec 2022 12:34:30 +0100 Subject: [PATCH 04/13] fixed device placement --- test/test_collector.py | 6 +++--- torchrl/collectors/collectors.py | 2 ++ 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/test/test_collector.py b/test/test_collector.py index 5436bced773..f7dd6603dd5 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -1015,7 +1015,7 @@ def env_fn(seed): pin_memory=False, ) batch = next(collector.iterator()) - assert batch.device == torch.device(passing_device) or batch["done"].device + assert batch.device == torch.device(passing_device) collector.shutdown() collector = MultiSyncDataCollector( @@ -1038,7 +1038,7 @@ def env_fn(seed): pin_memory=False, ) batch = next(collector.iterator()) - assert batch.device == torch.device(passing_device) or batch["done"].device + assert batch.device == torch.device(passing_device) collector.shutdown() collector = MultiaSyncDataCollector( @@ -1061,7 +1061,7 @@ def env_fn(seed): pin_memory=False, ) batch = next(collector.iterator()) - assert batch.device == torch.device(passing_device) or batch["done"].device + assert batch.device == torch.device(passing_device) collector.shutdown() diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 0bf50a9cae8..402242113e8 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -1258,12 +1258,14 @@ def iterator(self) -> Iterator[TensorDictBase]: out_buffer = torch.cat( list(out_tensordicts_shared.values()), 0, out=out_buffer ) + out_buffer = out_buffer.to(prev_device) else: out_buffer = torch.cat( [item.cpu() for item in out_tensordicts_shared.values()], 0, out=out_buffer, ) + out_buffer = out_buffer.to(torch.device("cpu")) if self.split_trajs: out = split_trajectories(out_buffer) From b6532cdc701588b17b2f6fe65915c0bbb9ecc935 Mon Sep 17 00:00:00 2001 From: albert bou Date: Mon, 2 Jan 2023 09:45:00 +0100 Subject: [PATCH 05/13] remove unnecessry check --- torchrl/collectors/collectors.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 24ba39c8b1c..3a288c589ef 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -1285,14 +1285,12 @@ def iterator(self) -> Iterator[TensorDictBase]: out_buffer = torch.cat( list(out_tensordicts_shared.values()), 0, out=out_buffer ) - out_buffer = out_buffer.to(prev_device) else: out_buffer = torch.cat( [item.cpu() for item in out_tensordicts_shared.values()], 0, out=out_buffer, ) - out_buffer = out_buffer.to(torch.device("cpu")) if self.split_trajs: out = split_trajectories(out_buffer) From ff8ac3b88d35115c2d3c1c16a92dafdf13cd828f Mon Sep 17 00:00:00 2001 From: albert bou Date: Mon, 2 Jan 2023 10:34:27 +0100 Subject: [PATCH 06/13] tests fix --- test/test_collector.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/test/test_collector.py b/test/test_collector.py index 0602355fa72..9f10283ee40 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -796,7 +796,7 @@ def test_collector_vecnorm_envcreator(static_seed): @pytest.mark.parametrize("use_async", [False, True]) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda device found") +@pytest.mark.skipif(not torch.cuda.is_availabltorch.cuda.is_availablee(), reason="no cuda device found") def test_update_weights(use_async): def create_env(): return ContinuousActionVecMockEnv() @@ -992,6 +992,11 @@ def test_collector_output_keys(collector_class, init_random_frames, explicit_spe @pytest.mark.parametrize("device", ["cuda", "cpu"]) @pytest.mark.parametrize("passing_device", ["cuda", "cpu"]) def test_collector_device_combinations(device, passing_device): + + if (device == "cuda" or "passing_device" == "cuda") and \ + not torch.cuda.is_availabltorch.cuda.is_availablee(): + pytest.skip("no cuda device found") + def env_fn(seed): env = make_make_env("conv")() env.set_seed(seed) From 6e93a8456b3a357dfb5a31387247c6be6100c60f Mon Sep 17 00:00:00 2001 From: albertbou92 Date: Mon, 2 Jan 2023 10:40:07 +0100 Subject: [PATCH 07/13] format --- test/test_collector.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/test/test_collector.py b/test/test_collector.py index 9f10283ee40..6497b5eee37 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -796,7 +796,7 @@ def test_collector_vecnorm_envcreator(static_seed): @pytest.mark.parametrize("use_async", [False, True]) -@pytest.mark.skipif(not torch.cuda.is_availabltorch.cuda.is_availablee(), reason="no cuda device found") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda device found") def test_update_weights(use_async): def create_env(): return ContinuousActionVecMockEnv() @@ -993,8 +993,9 @@ def test_collector_output_keys(collector_class, init_random_frames, explicit_spe @pytest.mark.parametrize("passing_device", ["cuda", "cpu"]) def test_collector_device_combinations(device, passing_device): - if (device == "cuda" or "passing_device" == "cuda") and \ - not torch.cuda.is_availabltorch.cuda.is_availablee(): + if ( + device == "cuda" or "passing_device" == "cuda" + ) and not torch.cuda.is_available(): pytest.skip("no cuda device found") def env_fn(seed): From ea25c522cf7f493304f965ae0d129df50874dc2f Mon Sep 17 00:00:00 2001 From: albert bou Date: Mon, 2 Jan 2023 11:37:04 +0100 Subject: [PATCH 08/13] tests fix --- test/test_collector.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/test/test_collector.py b/test/test_collector.py index 6497b5eee37..d1c4232696f 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -989,15 +989,9 @@ def test_collector_output_keys(collector_class, init_random_frames, explicit_spe del collector -@pytest.mark.parametrize("device", ["cuda", "cpu"]) -@pytest.mark.parametrize("passing_device", ["cuda", "cpu"]) +@pytest.mark.parametrize("device", ["cuda", "cpu"] if torch.cuda.is_available() else ["cpu"]) +@pytest.mark.parametrize("passing_device", ["cuda", "cpu"] if torch.cuda.is_available() else ["cpu"]) def test_collector_device_combinations(device, passing_device): - - if ( - device == "cuda" or "passing_device" == "cuda" - ) and not torch.cuda.is_available(): - pytest.skip("no cuda device found") - def env_fn(seed): env = make_make_env("conv")() env.set_seed(seed) From 06f95f4a8226050136f1b424244575b2f8eca9d8 Mon Sep 17 00:00:00 2001 From: albertbou92 Date: Mon, 2 Jan 2023 11:38:43 +0100 Subject: [PATCH 09/13] format --- test/test_collector.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/test/test_collector.py b/test/test_collector.py index d1c4232696f..ba4f1180cf8 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -989,8 +989,12 @@ def test_collector_output_keys(collector_class, init_random_frames, explicit_spe del collector -@pytest.mark.parametrize("device", ["cuda", "cpu"] if torch.cuda.is_available() else ["cpu"]) -@pytest.mark.parametrize("passing_device", ["cuda", "cpu"] if torch.cuda.is_available() else ["cpu"]) +@pytest.mark.parametrize( + "device", ["cuda", "cpu"] if torch.cuda.is_available() else ["cpu"] +) +@pytest.mark.parametrize( + "passing_device", ["cuda", "cpu"] if torch.cuda.is_available() else ["cpu"] +) def test_collector_device_combinations(device, passing_device): def env_fn(seed): env = make_make_env("conv")() From 0bc93d10af1ac76480d4da921c673226d21f7e84 Mon Sep 17 00:00:00 2001 From: albert bou Date: Mon, 2 Jan 2023 12:40:47 +0100 Subject: [PATCH 10/13] skip only cpu test --- test/test_collector.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/test/test_collector.py b/test/test_collector.py index ba4f1180cf8..aa2c0218773 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -989,12 +989,9 @@ def test_collector_output_keys(collector_class, init_random_frames, explicit_spe del collector -@pytest.mark.parametrize( - "device", ["cuda", "cpu"] if torch.cuda.is_available() else ["cpu"] -) -@pytest.mark.parametrize( - "passing_device", ["cuda", "cpu"] if torch.cuda.is_available() else ["cpu"] -) +@pytest.mark.parametrize("device", ["cuda", "cpu"]) +@pytest.mark.parametrize(["cuda", "cpu"]) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda device found") def test_collector_device_combinations(device, passing_device): def env_fn(seed): env = make_make_env("conv")() From ed4a051c3bbce9ef55f3088da9590c993267aae6 Mon Sep 17 00:00:00 2001 From: albertbou92 Date: Mon, 2 Jan 2023 12:41:55 +0100 Subject: [PATCH 11/13] format --- torchrl/objectives/ppo.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index 0ab4924e271..44de4735b33 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -126,6 +126,8 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor: state_value, loss_function=self.loss_critic_type, ) + # loss_value = (state_value - target_return).pow(2) + except KeyError: raise KeyError( f"the key {self.value_target_key} was not found in the input tensordict. " From 33bdf021383d2f5da30f68535e78ff0f2c02fd24 Mon Sep 17 00:00:00 2001 From: albertbou92 Date: Mon, 2 Jan 2023 12:45:34 +0100 Subject: [PATCH 12/13] fix --- test/test_collector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_collector.py b/test/test_collector.py index aa2c0218773..fe40e974ff4 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -990,7 +990,7 @@ def test_collector_output_keys(collector_class, init_random_frames, explicit_spe @pytest.mark.parametrize("device", ["cuda", "cpu"]) -@pytest.mark.parametrize(["cuda", "cpu"]) +@pytest.mark.parametrize("passing_device", ["cuda", "cpu"]) @pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda device found") def test_collector_device_combinations(device, passing_device): def env_fn(seed): From 106838fdfa67916aab5fa7b196d7fab751f93e1f Mon Sep 17 00:00:00 2001 From: albert bou Date: Mon, 2 Jan 2023 12:51:04 +0100 Subject: [PATCH 13/13] fix --- torchrl/objectives/ppo.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index 44de4735b33..0ab4924e271 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -126,8 +126,6 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor: state_value, loss_function=self.loss_critic_type, ) - # loss_value = (state_value - target_return).pow(2) - except KeyError: raise KeyError( f"the key {self.value_target_key} was not found in the input tensordict. "