Skip to content

Commit

Permalink
[BugFix] Fix VecNorm test in test_collectors.py (#2162)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed May 16, 2024
1 parent 73d09c3 commit 3a7cf6a
Showing 1 changed file with 20 additions and 4 deletions.
24 changes: 20 additions & 4 deletions test/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -1136,17 +1136,33 @@ def test_collector_vecnorm_envcreator(static_seed):

s = c.state_dict()

td1 = s["worker0"]["env_state_dict"]["worker3"]["_extra_state"]["td"].clone()
td2 = s["worker1"]["env_state_dict"]["worker0"]["_extra_state"]["td"].clone()
td1 = (
TensorDict(s["worker0"]["env_state_dict"]["worker3"]["_extra_state"])
.unflatten_keys(VecNorm.SEP)
.clone()
)
td2 = (
TensorDict(s["worker1"]["env_state_dict"]["worker0"]["_extra_state"])
.unflatten_keys(VecNorm.SEP)
.clone()
)
assert (td1 == td2).all()

next(c_iter)
next(c_iter)

s = c.state_dict()

td3 = s["worker0"]["env_state_dict"]["worker3"]["_extra_state"]["td"].clone()
td4 = s["worker1"]["env_state_dict"]["worker0"]["_extra_state"]["td"].clone()
td3 = (
TensorDict(s["worker0"]["env_state_dict"]["worker3"]["_extra_state"])
.unflatten_keys(VecNorm.SEP)
.clone()
)
td4 = (
TensorDict(s["worker1"]["env_state_dict"]["worker0"]["_extra_state"])
.unflatten_keys(VecNorm.SEP)
.clone()
)
assert (td3 == td4).all()
assert (td1 != td4).any()
c.shutdown()
Expand Down

0 comments on commit 3a7cf6a

Please sign in to comment.