From c8ac5a30e7990caff92f55dd6b5579a3de62b4a6 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Thu, 16 Oct 2025 23:12:46 +0200 Subject: [PATCH 1/3] Test --- test/test_libs.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index 8ade2ad056b..49a62adf5a8 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -2790,7 +2790,7 @@ class TestVmas: @pytest.mark.parametrize("scenario_name", VmasWrapper.available_envs) @pytest.mark.parametrize("continuous_actions", [True, False]) def test_all_vmas_scenarios(self, scenario_name, continuous_actions): - + env = VmasEnv( scenario=scenario_name, continuous_actions=continuous_actions, @@ -2814,9 +2814,13 @@ def test_vmas_seeding(self, scenario_name): scenario=scenario_name, num_envs=4, ) + + def policy(td, env=env): + return env.action_spec.zero() + final_seed.append(env.set_seed(0)) tdreset.append(env.reset()) - tdrollout.append(env.rollout(max_steps=10)) + tdrollout.append(env.rollout(max_steps=10, policy=policy)) env.close() del env assert final_seed[0] == final_seed[1] From e3909c5d337fa9b1e24a91b320a4664617683e43 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Thu, 16 Oct 2025 23:24:20 +0200 Subject: [PATCH 2/3] Test --- test/test_libs.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index 49a62adf5a8..34e4ec0dbc3 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -5,6 +5,7 @@ from __future__ import annotations import collections +import copy import functools import gc import importlib.util @@ -2809,14 +2810,24 @@ def test_vmas_seeding(self, scenario_name): final_seed = [] tdreset = [] tdrollout = [] - for _ in range(2): - env = VmasEnv( + rollout_length = 10 + + def create_env(): + return VmasEnv( scenario=scenario_name, num_envs=4, ) - def policy(td, env=env): - return env.action_spec.zero() + env = create_env() + td_actions = [env.action_spec.rand() for _ in range(rollout_length)] + td_actions_buffer = copy.deepcopy(td_actions) + + for _ in range(2): + env = create_env() + td_actions_buffer = copy.deepcopy(td_actions) + + def policy(td, actions=td_actions_buffer): + return actions.pop(0) final_seed.append(env.set_seed(0)) tdreset.append(env.reset()) From a6770b5e2d1d7c3d9d596b2818d613f4d9fb8de3 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Thu, 16 Oct 2025 23:25:07 +0200 Subject: [PATCH 3/3] Test --- test/test_libs.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index 34e4ec0dbc3..69eb8029de7 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -2820,7 +2820,6 @@ def create_env(): env = create_env() td_actions = [env.action_spec.rand() for _ in range(rollout_length)] - td_actions_buffer = copy.deepcopy(td_actions) for _ in range(2): env = create_env() @@ -2831,7 +2830,7 @@ def policy(td, actions=td_actions_buffer): final_seed.append(env.set_seed(0)) tdreset.append(env.reset()) - tdrollout.append(env.rollout(max_steps=10, policy=policy)) + tdrollout.append(env.rollout(max_steps=rollout_length, policy=policy)) env.close() del env assert final_seed[0] == final_seed[1]