-
Notifications
You must be signed in to change notification settings - Fork 100
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix(hansbug): use treevalue 1.4.7 #249
Conversation
I don't think this (registering JAX pytree node) is the ultimate solution to #246. The true fix is: def _to_dm(
self: Any,
state_values: List[np.ndarray],
reset: bool,
return_info: bool,
) -> TimeStep:
values = map(lambda i: state_values[i], state_idx)
state = treevalue.unflatten(
[(path, vi) for (path, _), vi in zip(tree_pairs, values)]
)
+ state = to_namedtuple("State", treevalue.jsonify(state)) # TreeValue -> nested dict -> nested namedtuple
timestep = TimeStep(
step_type=state.step_type,
observation=state.State,
reward=state.reward,
discount=state.discount,
)
return timestep However, As I commented in #246 (comment):
The true issue here is we are returning a collection of third-party data structures (i.e., I would be happy that
|
Treevalue |
Well, there is a package called |
Description
The treevalue is now upgraded to
1.4.7
, which supports:generic_flatten
andgeneric_unflatten
, the problem of non-native instances mentioned by @XuehaiPan will be easily solved.With
generic_mapping
, the operations can be performed on the entire structure (even including custom classes, you can register this withregister_integrate_container
) with good performance, here is the metrics.unpack
method for quickly unpack the values from treevalue. We should also notice that the treevalue's behaviour is similar todict
, not namedtuple. It also need to consider the nested structures like dict does.treevalue
. This will maketreevalue
better and better since we are able to add all these new features asap. We hope to have more in-depth exchanges with the optree team in order to significantly improve the experience of using treevalue. 😄--------------- old description ---------------
Fix bug #246, by just use the newest version of treevalue.
Motivation and Context
This is a simple fix of bug #246 .
In the treevalue
1.4.6
, theTreeValue
andFastTreeValue
class, which can satisfy almost all of the actual usage, are registered as jax nodes by default. If customTreeValue
-based class need to be registered, you can just do like thisNot only that, from an engineering point of view, maintaining the original design as much as possible is a strategy that is more conducive to code maintenance, and directly upgrading the treevalue version will minimize the reconstruction to the code and technology stack. A series of convenient calculation operation provided by treevalue (see the documentation of treevalue for details) are also very tempting for the subsequent development of the project, and the performance is fully sufficient to meet the requirements. Among the users of envpool, jax is not a must, and users who do not use jax are also common. It is obvious that treevalue has better versatility and universality. In summary, we recommend upgrading the version of treevalue to fix this bug.
BTW, If you have further optimization needs and suggestions, we will respond asap.
Types of changes
What types of changes does your code introduce? Put an
x
in all the boxes that apply:Checklist
Go over all the following points, and put an
x
in all the boxes that apply.If you are unsure about any of these, don't hesitate to ask. We are here to help!
make format
(required)make lint
(required)make bazel-test
pass. (required)