[BUG] Move away from dm-tree broke compatibility with Jax #246
Comments
@PaParaZz1 can you take a look?
cc @XuehaiPan
One solution is to register the …

`envpool/envpool/python/dm_envpool.py`, lines 74 to 90 in `cd2ece0`:

we are returning a namedtuple of … IMO, we'd better use standard Python containers (e.g., …). Also, note that `FastTreeValue` fails on some standard containers:

```python
In [1]: import treevalue

In [2]: tree = {1: 'a', 2: 'b'}

In [3]: treevalue.FastTreeValue(tree)
TypeError: Expected unicode, got int

In [4]: tree = [{'a': 1}, {'a', 2}]

In [5]: treevalue.FastTreeValue(tree)
TypeError: Unknown initialization type for tree value - 'list'.
```
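To make the point about standard containers concrete, here is a minimal stdlib-only sketch of a leaf flattener that accepts exactly the kinds of input rejected above (a dict with `int` keys and a list of dicts). It is a toy for illustration, not the implementation used by treevalue, optree or JAX; `flatten_leaves` is a hypothetical name, and dict values are visited in insertion order (JAX, by contrast, sorts dict keys).

```python
from typing import Any, List


def flatten_leaves(tree: Any) -> List[Any]:
    """Toy flatten: collect the leaves of nested dicts, lists and tuples.

    Dict keys may be of any type (e.g. ints); values are visited in
    insertion order. Anything that is not a dict, list or tuple is a leaf.
    """
    if isinstance(tree, dict):
        leaves: List[Any] = []
        for value in tree.values():
            leaves.extend(flatten_leaves(value))
        return leaves
    if isinstance(tree, (list, tuple)):
        leaves = []
        for item in tree:
            leaves.extend(flatten_leaves(item))
        return leaves
    return [tree]


print(flatten_leaves({1: 'a', 2: 'b'}))         # ['a', 'b']
print(flatten_leaves([{'a': 1}, {'a': 2}]))     # [1, 2]
```

Because the flattener dispatches on the built-in container types rather than on a custom tree class, non-string keys and top-level lists need no special handling.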
@HansBug and I are working to fix this compatibility problem with JAX. At present, it seems that the solution should be to register …
We are adding … This will be released in the next version.

```python
import jax
import numpy as np
from treevalue import FastTreeValue, PENETRATE_SESSIONID_ARGNAME, penetrate


@penetrate(jax.jit, static_argnames=PENETRATE_SESSIONID_ARGNAME)
def double(x):
    return x * 2


t = FastTreeValue({
    'a': np.random.randint(0, 10, (2, 3)),
    'b': {
        'x': 233,
        'y': np.random.randn(2, 3)
    }
})
print(t)
print(double(t))
print(double(t + 1))
```
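Assuming `double(t)` returns a tree shaped like `t` with every leaf doubled (the usual pytree semantics), the expected result can be modeled with a small stdlib-only sketch over nested dicts, using scalars instead of NumPy arrays. `tree_map` here is a hypothetical helper, not the treevalue or JAX function of the same idea.

```python
from typing import Any, Callable


def tree_map(fn: Callable[[Any], Any], tree: Any) -> Any:
    """Apply ``fn`` to every leaf of a nested dict, preserving the nesting.

    Hypothetical helper: any non-dict value is treated as a leaf.
    """
    if isinstance(tree, dict):
        return {key: tree_map(fn, value) for key, value in tree.items()}
    return fn(tree)


def double(x):
    return x * 2


t = {'a': 3, 'b': {'x': 233, 'y': 1.5}}
print(tree_map(double, t))  # {'a': 6, 'b': {'x': 466, 'y': 3.0}}
```

The input tree is not mutated; a new dict with the same keys is built at every level.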
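Another solution, based on native `jax.tree_util` pytree registration:

```python
import jax
import numpy as np
from treevalue import FastTreeValue, TreeValue, flatten, unflatten


def flatten_treevalue(container):
    # Split the tree into parallel lists of leaf paths and leaf values.
    contents = []
    paths = []
    for path, value in flatten(container):
        paths.append(path)
        contents.append(value)
    # JAX expects (children, aux_data). aux_data must be hashable for jit
    # caching, so the paths are stored as a tuple; the concrete class is kept
    # so that unflattening rebuilds the same TreeValue subclass.
    return contents, (type(container), tuple(paths))


def unflatten_treevalue(aux_data, flat_contents):
    type_, paths = aux_data
    return unflatten(zip(paths, flat_contents), return_type=type_)


# Pytree handlers are looked up by exact type, so each class is registered.
jax.tree_util.register_pytree_node(TreeValue, flatten_treevalue, unflatten_treevalue)
jax.tree_util.register_pytree_node(FastTreeValue, flatten_treevalue, unflatten_treevalue)

data = {
    'a': np.random.randint(0, 10, (2, 3)),
    'b': {
        'x': 233,
        'y': np.random.randn(2, 3)
    }
}
t = FastTreeValue(data)


@jax.jit
def double(x):
    return x * 2


print(double(t))
```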
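The contract that `jax.tree_util.register_pytree_node` relies on is that the flatten function returns `(children, aux_data)` and the unflatten function rebuilds the node from `(aux_data, children)`. The round trip can be checked without JAX or treevalue using a stdlib-only sketch over nested dicts. It stores only the paths as auxiliary data (the example above also stores the container type), all names in it are hypothetical, and empty sub-dicts have no leaves and are therefore dropped.

```python
from typing import Any, Dict, Iterable, List, Tuple

Path = Tuple[str, ...]


def flatten_with_paths(tree: Dict[str, Any], prefix: Path = ()) -> List[Tuple[Path, Any]]:
    """Depth-first (path, leaf) pairs of a nested dict, in insertion order."""
    pairs: List[Tuple[Path, Any]] = []
    for key, value in tree.items():
        path = prefix + (key,)
        if isinstance(value, dict):
            pairs.extend(flatten_with_paths(value, path))
        else:
            pairs.append((path, value))
    return pairs


def unflatten_from_paths(pairs: Iterable[Tuple[Path, Any]]) -> Dict[str, Any]:
    """Rebuild a nested dict from (path, leaf) pairs."""
    root: Dict[str, Any] = {}
    for path, value in pairs:
        node = root
        for key in path[:-1]:
            node = node.setdefault(key, {})
        node[path[-1]] = value
    return root


# Same (children, aux_data) / (aux_data, children) shape as the
# flatten_treevalue / unflatten_treevalue pair above.
def flatten_node(tree: Dict[str, Any]) -> Tuple[List[Any], Tuple[Path, ...]]:
    pairs = flatten_with_paths(tree)
    return [value for _, value in pairs], tuple(path for path, _ in pairs)


def unflatten_node(paths: Tuple[Path, ...], leaves: List[Any]) -> Dict[str, Any]:
    return unflatten_from_paths(zip(paths, leaves))


data = {'a': 1, 'b': {'x': 233, 'y': 2.5}}
leaves, paths = flatten_node(data)
doubled = unflatten_node(paths, [leaf * 2 for leaf in leaves])
print(doubled)  # {'a': 2, 'b': {'x': 466, 'y': 5.0}}
```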
Now treevalue registers `TreeValue` and `FastTreeValue` as JAX pytree nodes by default:

```python
import jax
from treevalue import FastTreeValue

d = {'a': 1, 'b': {'c': 2, 'd': 3}, 'e': 4}
t = FastTreeValue(d)


@jax.jit
def double(x):
    return x * 2


if __name__ == '__main__':
    print(double(t))
```

If you need to register a custom treevalue class, just use `register_treevalue_class`:

```python
import jax
from treevalue import FastTreeValue, register_treevalue_class


class MyTreeValue(FastTreeValue):
    pass


register_treevalue_class(MyTreeValue)

d = {'a': 1, 'b': {'c': 2, 'd': 3}, 'e': 4}
t = MyTreeValue(d)


@jax.jit
def double(x):
    return x * 2


if __name__ == '__main__':
    print(double(t))
```
Hi all, thanks for the swift action on this issue. One comment: I agree with @XuehaiPan's comment here, #249 (comment), that standard container types should be used for the public-facing API. I appreciate the emphasis on performance, but I don't think the tradeoff is worth it for user-facing APIs. Another example of custom tree-like data structures getting in the way can be seen in Flax's recent move away from their custom …
## Description

treevalue is now upgraded to `1.4.7`, which supports:

1. `generic_flatten`, `generic_unflatten` and `generic_mapping` functions; here are [some examples](https://opendilab.github.io/treevalue/v1.4.7/api_doc/tree/integration.html#generic-flatten). **With `generic_flatten` and `generic_unflatten`, the problem of non-native instances mentioned by @XuehaiPan can be easily solved**.

```python
In [1]: from collections import namedtuple
   ...: from easydict import EasyDict
   ...: from treevalue import FastTreeValue
   ...:
   ...: class MyTreeValue(FastTreeValue):
   ...:     pass
   ...:
   ...: nt = namedtuple('nt', ['a', 'b'])
   ...:
   ...: origin = {
   ...:     'a': 1,
   ...:     'b': (2, 3, 'f',),
   ...:     'c': (2, 5, 'ds', {
   ...:         'x': None,
   ...:         'z': [34, '1.2'],
   ...:     }),
   ...:     'd': nt('f', 100),  # namedtuple
   ...:     'e': MyTreeValue({'x': 1, 'y': 'dsfljk'})  # treevalue
   ...: }

In [2]: from treevalue import generic_flatten
   ...:
   ...: v, spec = generic_flatten(origin)
   ...: v
Out[2]: [1, [2, 3, 'f'], [2, 5, 'ds', [None, [34, '1.2']]], ['f', 100], [1, 'dsfljk']]

In [3]: from treevalue import generic_unflatten
   ...:
   ...: generic_unflatten(v, spec)
Out[3]:
{'a': 1,
 'b': (2, 3, 'f'),
 'c': (2, 5, 'ds', {'x': None, 'z': [34, '1.2']}),
 'd': nt(a='f', b=100),
 'e': <MyTreeValue 0x7fd552b5b410>
 ├── 'x' --> 1
 └── 'y' --> 'dsfljk'}
```

With `generic_mapping`, **operations can be performed on the entire structure** (including custom classes, which you can register with [`register_integrate_container`](https://opendilab.github.io/treevalue/v1.4.7/api_doc/tree/integration.html#register-integrate-container)) with good performance; here are the metrics:
![1677503026559](https://user-images.githubusercontent.com/20508435/221570959-550940e8-22aa-4ee3-9967-260c6dfb0199.png)
![1677502998652](https://user-images.githubusercontent.com/20508435/221570895-efa6b35f-3663-42e1-8837-763579b2e850.png)
![1677498651795](https://user-images.githubusercontent.com/20508435/221570554-18361578-323d-421b-b932-c898b5ee8a2d.png)
![1677498667831](https://user-images.githubusercontent.com/20508435/221570567-60cee69b-cebc-48ab-9eb8-ea7f3c1ec158.png)

2. We support integration with torch's pytree; here is the newest example:

```python
In [1]: import torch
   ...:
   ...: from treevalue import FastTreeValue
   ...:
   ...: d = {'a': 1, 'b': {'c': 2, 'd': 3}, 'e': 4}
   ...: t = FastTreeValue(d)
   ...:

In [2]: torch.utils._pytree.tree_flatten(t)  # Torch recognizes TreeValue as tree
Out[2]: ([1, 2, 3, 4], TreeSpec(FastTreeValue, (<class 'treevalue.tree.general.fast.FastTreeValue'>, [('a',), ('b', 'c'), ('b', 'd'), ('e',)]), [*, *, *, *]))
```

3. We provide an `unpack` method for quickly unpacking values from a treevalue. Note that treevalue's behaviour is similar to `dict`, not namedtuple, so it also needs to handle nested structures the way `dict` does.

```python
In [1]: from treevalue import FastTreeValue
   ...:
   ...: d = {'a': 1, 'b': {'c': 2, 'd': 3}, 'e': 4}
   ...: t = FastTreeValue(d)
   ...:

In [2]: vc, vd = t.b.unpack('c', 'd')

In [3]: vc
Out[3]: 2

In [4]: vd
Out[4]: 3
```

4. We are really glad to see @XuehaiPan and @Benjamin-eecs bring more advice to `treevalue`. It keeps making `treevalue` better, since we are able to add all these new features asap. We hope to have more in-depth exchanges with the optree team in order to significantly improve the experience of using treevalue. 😄

--------------- old description ---------------

Fix bug #246 by simply using the newest version of treevalue.

## Motivation and Context

This is a simple fix of bug #246.
In treevalue `1.4.6`, the `TreeValue` and `FastTreeValue` classes, which cover almost all actual usage, are registered as JAX nodes by default. If a custom `TreeValue`-based class needs to be registered, you can simply do this:

```python
from treevalue import FastTreeValue, register_for_jax


class MyTreeValue(FastTreeValue):
    pass


register_for_jax(MyTreeValue)  # such a simple operation, isn't it?
```

Beyond that, from an engineering point of view, preserving the original design as much as possible makes the code easier to maintain, and directly upgrading the treevalue version will **minimize changes to the code and technology stack**. The convenient calculation operations provided by treevalue (see the [documentation of treevalue](https://opendilab.github.io/treevalue/main/index.html) for details) are also attractive for the project's future development, and their performance is fully sufficient to meet the requirements.

Among envpool users, JAX is not a must, and **users who do not use JAX are also common**. Treevalue therefore has **better versatility and universality**. In summary, **we recommend upgrading treevalue to fix this bug**. BTW, if you have further optimization needs or suggestions, we will respond asap.

- [x] I have raised an issue to propose this change ([required](https://envpool.readthedocs.io/en/latest/pages/contributing.html) for new features and bug fixes)

## Types of changes

What types of changes does your code introduce?
Put an `x` in all the boxes that apply:

- [x] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds core functionality)
- [ ] New environment (non-breaking change which adds 3rd-party environment)
- [ ] Breaking change (fix or feature that would cause existing functionality to change)
- [ ] Documentation (update in the documentation)
- [ ] Example (update in the folder of example)

## Checklist

Go over all the following points, and put an `x` in all the boxes that apply. If you are unsure about any of these, don't hesitate to ask. We are here to help!

- [x] I have read the [CONTRIBUTION](https://envpool.readthedocs.io/en/latest/pages/contributing.html) guide (**required**)
- [ ] My change requires a change to the documentation.
- [x] I have updated the tests accordingly (*required for a bug fix or a new feature*).
- [ ] I have updated the documentation accordingly.
- [x] I have reformatted the code using `make format` (**required**)
- [x] I have checked the code using `make lint` (**required**)
- [x] I have ensured `make bazel-test` pass. (**required**)
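The `generic_flatten` / `generic_unflatten` round trip shown in the PR description above can be modeled roughly with the standard library alone. This is not treevalue's implementation: the function names are hypothetical, `MyTreeValue` and `EasyDict` are left out, and every dict, list, tuple and namedtuple is simply replaced by a plain list of its children, which reproduces the shape of `Out[2]` for the remaining keys.

```python
from collections import namedtuple
from typing import Any, Optional, Tuple

# A spec is None for a leaf, or (kind, container_type, keys, child_specs).
Spec = Optional[Tuple[str, type, Optional[list], list]]


def generic_flatten_sketch(obj: Any) -> Tuple[Any, Spec]:
    """Replace every dict, list, tuple and namedtuple with a plain list of its children."""
    if isinstance(obj, dict):
        keys = list(obj.keys())
        flat = [generic_flatten_sketch(obj[key]) for key in keys]
        return [v for v, _ in flat], ('dict', type(obj), keys, [s for _, s in flat])
    if isinstance(obj, (list, tuple)):
        flat = [generic_flatten_sketch(item) for item in obj]
        return [v for v, _ in flat], ('seq', type(obj), None, [s for _, s in flat])
    return obj, None  # anything else is a leaf


def generic_unflatten_sketch(values: Any, spec: Spec) -> Any:
    """Inverse of generic_flatten_sketch."""
    if spec is None:
        return values
    kind, cls, keys, child_specs = spec
    children = [generic_unflatten_sketch(v, s) for v, s in zip(values, child_specs)]
    if kind == 'dict':
        return cls(zip(keys, children))  # assumes a dict-like constructor
    if hasattr(cls, '_fields'):
        return cls(*children)  # namedtuples take positional fields
    return cls(children)  # plain list or tuple


nt = namedtuple('nt', ['a', 'b'])
origin = {
    'a': 1,
    'b': (2, 3, 'f'),
    'c': (2, 5, 'ds', {'x': None, 'z': [34, '1.2']}),
    'd': nt('f', 100),
}
values, spec = generic_flatten_sketch(origin)
print(values)  # [1, [2, 3, 'f'], [2, 5, 'ds', [None, [34, '1.2']]], ['f', 100]]
print(generic_unflatten_sketch(values, spec) == origin)  # True
```

Keeping the container types in a separate spec is what lets the flattened values be plain lists while the original namedtuple and tuple types still come back on unflatten.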
## Description

- close #246
- further speedup
- [support custom node type registration](pytorch/pytorch#65761 (comment))

## Motivation and Context

Initial Test Results:

### Test gym

TreeValue:

<img width="1073" alt="image" src="https://user-images.githubusercontent.com/32269413/221405036-ed1d1692-714b-4af7-a011-8accb36b8427.png">

OpTree:

<img width="1074" alt="image" src="https://user-images.githubusercontent.com/32269413/221405130-521b48fe-7760-4ef0-8abb-bc9468c7bdeb.png">

### Test dmc

TreeValue:

<img width="1061" alt="image" src="https://user-images.githubusercontent.com/32269413/221405292-69f2b96e-d590-42a5-b073-da4721ec6148.png">

OpTree:

<img width="1076" alt="image" src="https://user-images.githubusercontent.com/32269413/221405305-e1d53f94-5f97-4c11-a20c-c167fdd9d45b.png">

- [x] I have raised an issue to propose this change ([required](https://envpool.readthedocs.io/en/latest/pages/contributing.html) for new features and bug fixes)

## Types of changes

What types of changes does your code introduce? Put an `x` in all the boxes that apply:

- [x] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds core functionality)
- [x] Breaking change (fix or feature that would cause existing functionality to change)
- [ ] Example (update in the folder of example)

## Checklist

Go over all the following points, and put an `x` in all the boxes that apply. If you are unsure about any of these, don't hesitate to ask. We are here to help!

- [x] I have read the [CONTRIBUTION](https://envpool.readthedocs.io/en/latest/pages/contributing.html) guide (**required**)
- [ ] My change requires a change to the documentation.
- [x] I have updated the tests accordingly (*required for a bug fix or a new feature*).
- [ ] I have updated the documentation accordingly.
- [x] I have reformatted the code using `make format` (**required**)
- [x] I have checked the code using `make lint` (**required**)
- [x] I have ensured `make bazel-test` pass. (**required**)

cc @XuehaiPan

---------

Co-authored-by: Xuehai Pan <XuehaiPan@outlook.com>
## Describe the bug

It seems the move away from dm-tree caused some issues, as `TreeValue` doesn't register itself as a valid PyTree node.

## To Reproduce

This is a direct rip from your XLA documentation:

## Reason and Possible fixes

## Checklist