From f73be7c6900f2c69bf8a1eb85fa8440f6105c7d7 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 22 Oct 2025 14:30:37 -0700 Subject: [PATCH] Update [ghstack-poisoned] --- test/test_weightsync.py | 24 ++++++++++++------------ torchrl/envs/libs/gym.py | 7 ++++++- 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/test/test_weightsync.py b/test/test_weightsync.py index 5da43accf7c..f5a4515f224 100644 --- a/test/test_weightsync.py +++ b/test/test_weightsync.py @@ -9,11 +9,11 @@ import pytest import torch import torch.nn as nn +from mocking_classes import ContinuousActionVecMockEnv from tensordict import TensorDict from tensordict.nn import TensorDictModule from torch import multiprocessing as mp from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector -from torchrl.envs import GymEnv from torchrl.weight_update.weight_sync_schemes import ( _resolve_model, MPTransport, @@ -274,7 +274,7 @@ def test_no_weight_sync_scheme(self): class TestCollectorIntegration: @pytest.fixture def simple_env(self): - return GymEnv("CartPole-v1") + return ContinuousActionVecMockEnv() @pytest.fixture def simple_policy(self, simple_env): @@ -291,7 +291,7 @@ def test_syncdatacollector_multiprocess_scheme(self, simple_policy): scheme = MultiProcessWeightSyncScheme(strategy="state_dict") collector = SyncDataCollector( - create_env_fn=lambda: GymEnv("CartPole-v1"), + create_env_fn=ContinuousActionVecMockEnv, policy=simple_policy, frames_per_batch=64, total_frames=128, @@ -316,8 +316,8 @@ def test_multisyncdatacollector_multiprocess_scheme(self, simple_policy): collector = MultiSyncDataCollector( create_env_fn=[ - lambda: GymEnv("CartPole-v1"), - lambda: GymEnv("CartPole-v1"), + ContinuousActionVecMockEnv, + ContinuousActionVecMockEnv, ], policy=simple_policy, frames_per_batch=64, @@ -343,8 +343,8 @@ def test_multisyncdatacollector_shared_mem_scheme(self, simple_policy): collector = MultiSyncDataCollector( create_env_fn=[ - lambda: GymEnv("CartPole-v1"), - lambda: GymEnv("CartPole-v1"), + ContinuousActionVecMockEnv, + ContinuousActionVecMockEnv, ], policy=simple_policy, frames_per_batch=64, @@ -369,7 +369,7 @@ def test_collector_no_weight_sync(self, simple_policy): scheme = NoWeightSyncScheme() collector = SyncDataCollector( - create_env_fn=lambda: GymEnv("CartPole-v1"), + create_env_fn=ContinuousActionVecMockEnv, policy=simple_policy, frames_per_batch=64, total_frames=128, @@ -385,7 +385,7 @@ def test_collector_no_weight_sync(self, simple_policy): class TestMultiModelUpdates: def test_multi_model_state_dict_updates(self): - env = GymEnv("CartPole-v1") + env = ContinuousActionVecMockEnv() policy = TensorDictModule( nn.Linear( @@ -407,7 +407,7 @@ def test_multi_model_state_dict_updates(self): } collector = SyncDataCollector( - create_env_fn=lambda: GymEnv("CartPole-v1"), + create_env_fn=ContinuousActionVecMockEnv, policy=policy, frames_per_batch=64, total_frames=128, @@ -438,7 +438,7 @@ def test_multi_model_state_dict_updates(self): env.close() def test_multi_model_tensordict_updates(self): - env = GymEnv("CartPole-v1") + env = ContinuousActionVecMockEnv() policy = TensorDictModule( nn.Linear( @@ -460,7 +460,7 @@ def test_multi_model_tensordict_updates(self): } collector = SyncDataCollector( - create_env_fn=lambda: GymEnv("CartPole-v1"), + create_env_fn=ContinuousActionVecMockEnv, policy=policy, frames_per_batch=64, total_frames=128, diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index a5c43d2ad23..ee34ab87185 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -1255,7 +1255,12 @@ def _build_gym_env(self, env, pixels_only): # noqa: F811 @property def lib(self) -> ModuleType: - return gym_backend() + gym = gym_backend() + if gym is None: + raise RuntimeError( + "Gym backend is not available. Please install gym or gymnasium." + ) + return gym def _set_seed(self, seed: int | None) -> None: # noqa: F811 if self._seed_calls_reset is None: