Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions test/test_weightsync.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
import pytest
import torch
import torch.nn as nn
from mocking_classes import ContinuousActionVecMockEnv
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from torch import multiprocessing as mp
from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector
from torchrl.envs import GymEnv
from torchrl.weight_update.weight_sync_schemes import (
_resolve_model,
MPTransport,
Expand Down Expand Up @@ -274,7 +274,7 @@ def test_no_weight_sync_scheme(self):
class TestCollectorIntegration:
@pytest.fixture
def simple_env(self):
return GymEnv("CartPole-v1")
return ContinuousActionVecMockEnv()

@pytest.fixture
def simple_policy(self, simple_env):
Expand All @@ -291,7 +291,7 @@ def test_syncdatacollector_multiprocess_scheme(self, simple_policy):
scheme = MultiProcessWeightSyncScheme(strategy="state_dict")

collector = SyncDataCollector(
create_env_fn=lambda: GymEnv("CartPole-v1"),
create_env_fn=ContinuousActionVecMockEnv,
policy=simple_policy,
frames_per_batch=64,
total_frames=128,
Expand All @@ -316,8 +316,8 @@ def test_multisyncdatacollector_multiprocess_scheme(self, simple_policy):

collector = MultiSyncDataCollector(
create_env_fn=[
lambda: GymEnv("CartPole-v1"),
lambda: GymEnv("CartPole-v1"),
ContinuousActionVecMockEnv,
ContinuousActionVecMockEnv,
],
policy=simple_policy,
frames_per_batch=64,
Expand All @@ -343,8 +343,8 @@ def test_multisyncdatacollector_shared_mem_scheme(self, simple_policy):

collector = MultiSyncDataCollector(
create_env_fn=[
lambda: GymEnv("CartPole-v1"),
lambda: GymEnv("CartPole-v1"),
ContinuousActionVecMockEnv,
ContinuousActionVecMockEnv,
],
policy=simple_policy,
frames_per_batch=64,
Expand All @@ -369,7 +369,7 @@ def test_collector_no_weight_sync(self, simple_policy):
scheme = NoWeightSyncScheme()

collector = SyncDataCollector(
create_env_fn=lambda: GymEnv("CartPole-v1"),
create_env_fn=ContinuousActionVecMockEnv,
policy=simple_policy,
frames_per_batch=64,
total_frames=128,
Expand All @@ -385,7 +385,7 @@ def test_collector_no_weight_sync(self, simple_policy):

class TestMultiModelUpdates:
def test_multi_model_state_dict_updates(self):
env = GymEnv("CartPole-v1")
env = ContinuousActionVecMockEnv()

policy = TensorDictModule(
nn.Linear(
Expand All @@ -407,7 +407,7 @@ def test_multi_model_state_dict_updates(self):
}

collector = SyncDataCollector(
create_env_fn=lambda: GymEnv("CartPole-v1"),
create_env_fn=ContinuousActionVecMockEnv,
policy=policy,
frames_per_batch=64,
total_frames=128,
Expand Down Expand Up @@ -438,7 +438,7 @@ def test_multi_model_state_dict_updates(self):
env.close()

def test_multi_model_tensordict_updates(self):
env = GymEnv("CartPole-v1")
env = ContinuousActionVecMockEnv()

policy = TensorDictModule(
nn.Linear(
Expand All @@ -460,7 +460,7 @@ def test_multi_model_tensordict_updates(self):
}

collector = SyncDataCollector(
create_env_fn=lambda: GymEnv("CartPole-v1"),
create_env_fn=ContinuousActionVecMockEnv,
policy=policy,
frames_per_batch=64,
total_frames=128,
Expand Down
7 changes: 6 additions & 1 deletion torchrl/envs/libs/gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -1255,7 +1255,12 @@ def _build_gym_env(self, env, pixels_only): # noqa: F811

@property
def lib(self) -> ModuleType:
return gym_backend()
gym = gym_backend()
if gym is None:
raise RuntimeError(
"Gym backend is not available. Please install gym or gymnasium."
)
return gym

def _set_seed(self, seed: int | None) -> None: # noqa: F811
if self._seed_calls_reset is None:
Expand Down
Loading