[Feature] Use ObservationNorm.init_stats for stats computation in example scripts #715
Conversation
vmoens left a comment
I think that here, if we spawn multiple processes, each env on each process will run its own init_env_steps and hence they will all have a different set of summary stats.
We should compute the stats in the main process like we did before (but using the method you provided), then pass these stats to each env.
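A minimal sketch of that flow (the factory, the env name, and the exact ObservationNorm/init_stats signatures are assumptions, not the actual helpers in this PR):

```python
# Assumed names/signatures: GymEnv, ObservationNorm(in_keys=...), init_stats(num_iter=...).
# Idea: compute the stats once in the main process, then build every worker env
# with the precomputed loc/scale so all processes share the same summary stats.
from torchrl.envs import TransformedEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.transforms import ObservationNorm

# 1) one "proof" env, used only for the rollout-based initialisation
proof_norm = ObservationNorm(in_keys=["observation"])
proof_env = TransformedEnv(GymEnv("Pendulum-v1"), proof_norm)
proof_norm.init_stats(num_iter=1000)
loc, scale = proof_norm.loc, proof_norm.scale

# 2) worker envs reuse the precomputed stats instead of running init_stats themselves
def make_worker_env():
    return TransformedEnv(
        GymEnv("Pendulum-v1"),
        ObservationNorm(loc=loc, scale=scale, in_keys=["observation"]),
    )
```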
…tionNorm transforms. Add checks in init_stats to ensure proper initialization
vmoens left a comment
LGTM
Do you think things would be simpler with state dict?
We could
(1) create a dummy env, compute stats
(2) get the state dict of this dummy env
(3) load the state dict on every env created subsequently
LMK what you think
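A rough sketch of that state-dict flow (assuming loc/scale end up in the transform's state_dict; make_env and the init_stats arguments are placeholders):

```python
# Placeholder factory standing in for whatever the example scripts build.
def make_env():
    ...  # returns TransformedEnv(..., ObservationNorm(...))

# (1) dummy env, used only to compute the stats
dummy_env = make_env()
dummy_env.transform.init_stats(num_iter=1000)  # assumed signature

# (2) capture loc/scale through the transform's state dict
stats_state_dict = dummy_env.transform.state_dict()

# (3) load that state dict into every env created afterwards
def make_initialized_env():
    env = make_env()
    env.transform.load_state_dict(stats_state_dict)
    return env
```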
torchrl/trainers/helpers/envs.py (Outdated)
raise AttributeError("init_env_steps missing from arguments.")

if (
    type(proof_environment.transform) != Compose
Why not isinstance?
Why equality and not `is not`?
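For illustration only (not the final code in this PR), the check the comment is pointing at could read:

```python
# isinstance also accepts subclasses of Compose/ObservationNorm,
# whereas type(...) != Compose rejects them.
if not isinstance(proof_environment.transform, (Compose, ObservationNorm)):
    raise AttributeError("...")  # placeholder for the actual error message
```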
torchrl/trainers/helpers/envs.py (Outdated)
if (
    type(proof_environment.transform) != Compose
    and type(proof_environment.transform) != ObservationNorm
Same here
torchrl/trainers/helpers/envs.py (Outdated)
)

obs_norm_transforms = []
if type(proof_environment.transform) == Compose:
Same
torchrl/trainers/helpers/envs.py (Outdated)
obs_norm_transforms.append((0, proof_environment.transform))

stats = []
for (idx, transform) in obs_norm_transforms:
Upon reflection, we could simply take the state dict of the transforms and load it, no?
Wouldn't that be simpler than this?
If loc and scale are buffers, it could simplify things a bit.
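A toy, plain-PyTorch illustration of that last point (not the actual torchrl classes): tensors registered as buffers appear in state_dict() and can be copied with load_state_dict().

```python
import torch
from torch import nn

class ToyNorm(nn.Module):
    def __init__(self):
        super().__init__()
        # buffers are part of state_dict() without being trainable parameters
        self.register_buffer("loc", torch.zeros(3))
        self.register_buffer("scale", torch.ones(3))

src, dst = ToyNorm(), ToyNorm()
src.loc += 1.0                          # pretend init_stats filled this in
dst.load_state_dict(src.state_dict())   # dst now carries the same loc/scale
print(dst.loc)                          # tensor([1., 1., 1.])
```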
acf16b8 to 17b1b15
Codecov Report
@@            Coverage Diff             @@
##             main     #715      +/-   ##
==========================================
- Coverage   88.71%   88.66%    -0.05%
==========================================
  Files         120      120
  Lines       20240    20386     +146
==========================================
+ Hits        17955    18075     +120
- Misses       2285     2311      +26
torchrl/trainers/helpers/envs.py (Outdated)

def retrieve_observation_norms_state_dict(proof_environment: TransformedEnv):
    """Traverse the transforms of the environment and retrieve the ObservationNorm state dicts.
Maybe "Traverses" + "retrieves"?
Also, for code objects we can use :obj:`ObservationNorm` for a cleaner rendering in the doc.
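For instance, the docstring could read (wording only, following the two suggestions above):

```python
def retrieve_observation_norms_state_dict(proof_environment: TransformedEnv):
    """Traverses the transforms of the environment and retrieves the
    :obj:`ObservationNorm` state dicts.
    """
```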
torchrl/trainers/helpers/envs.py (Outdated)
    num_iter: int = 1000,
    key: Union[str, Tuple[str, ...]] = None,
):
    """Calling init_stats on all uninitialised ObservationNorms transform of a TransformedEnv.
Maybe
Calls :obj:`ObservationNorm.init_stats` on all uninitialized :obj:`ObservationNorm` instances of a :obj:`TransformedEnv`.
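A hypothetical usage of the helper this docstring describes (import path taken from the file shown above; the env and default arguments are assumptions):

```python
from torchrl.envs import TransformedEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.transforms import ObservationNorm
from torchrl.trainers.helpers.envs import initialize_observation_norm_transforms

env = TransformedEnv(GymEnv("Pendulum-v1"), ObservationNorm(in_keys=["observation"]))
# runs init_stats on every uninitialized ObservationNorm found in env.transform
initialize_observation_norm_transforms(proof_environment=env, num_iter=1000, key="observation")
```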
torchrl/trainers/helpers/envs.py (Outdated)
):
    """Calling init_stats on all uninitialised ObservationNorms transform of a TransformedEnv.
    If an ObservationNorm already has non-null loc or stats, it will be skipped.
If an :obj:`ObservationNorm` already has a non-null loc or scale, a call to :obj:`initialize_observation_norm_transforms` will be a no-op.
Similarly, if the transformed environment does not contain any :obj:`ObservationNorm`, a call to this function will have no effect.
vmoens left a comment
LGTM!
Description
Refactor examples to use ObservationNorm.init_stats instead of get_stats_random_rollout.
Motivation and Context
Closes #699
Types of changes
What types of changes does your code introduce? Remove all that do not apply:
Checklist
Go over all the following points, and put an x in all the boxes that apply. If you are unsure about any of these, don't hesitate to ask. We are here to help!