[BUG] Move away from dm-tree broke compatibility with Jax #246

Closed
3 tasks done
JesseFarebro opened this issue Feb 23, 2023 · 8 comments · Fixed by #247
Labels
bug Something isn't working

Comments

@JesseFarebro

JesseFarebro commented Feb 23, 2023

Describe the bug

It seems the move away from dm-tree broke JAX compatibility: `TreeValue` doesn't register itself as a valid PyTree node.

To Reproduce

This is a direct rip from your XLA documentation:

```python
import envpool
import jax

env = envpool.make(
    "Pong-v5",
    env_type="dm",
    num_envs=2,
)
handle, recv, send, _ = env.xla()

def actor_step(iter, loop_var):
    handle0, states = loop_var
    action = 0
    handle1 = send(handle0, action, states.observation.env_id)
    # receive on the handle returned by send() so recv is sequenced after it
    handle1, new_states = recv(handle1)
    return handle1, new_states

@jax.jit
def run_actor_loop(num_steps, init_var):
    return jax.lax.fori_loop(0, num_steps, actor_step, init_var)

env.async_reset()
handle, states = recv(handle)
run_actor_loop(100, (handle, states))
```

```
jax._src.traceback_util.UnfilteredStackTrace: TypeError: Cannot interpret value of type <class 'treevalue.tree.tree.tree.TreeValue'> as an abstract array; it does not have a dtype attribute
```
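The same `TypeError` can be triggered without envpool. A minimal sketch, assuming a recent JAX and a hypothetical `Box` class that is not registered as a PyTree node: JAX treats the unregistered object as a single leaf, and `jax.jit` then fails because that leaf is not an array. The exact wording of the message differs between JAX versions.

```python
import jax
import jax.numpy as jnp


class Box:
    """Hypothetical container that JAX knows nothing about."""

    def __init__(self, value):
        self.value = value


@jax.jit
def double(x):
    return x.value * 2


def reproduce():
    try:
        double(Box(jnp.ones(3)))
    except TypeError as err:
        # The unregistered Box is treated as one leaf, so jit tries to
        # interpret the whole object as an array and fails.
        return str(err)
    return None


if __name__ == "__main__":
    print(reproduce())
```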

Reason and Possible fixes

  • Revert back to dm-tree
  • File issue upstream
  • Implement your own PyTree class and register it with Jax

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
Trinkle23897 added the bug (Something isn't working) label Feb 23, 2023
@Trinkle23897
Collaborator

@PaParaZz1 can you take a look?

@Benjamin-eecs
Collaborator

cc @XuehaiPan

@XuehaiPan
Contributor

One solution is to register the `TreeValue` classes from `treevalue` as JAX PyTree node types. FYI, this requires registering every possible class, such as `TreeValue` and `FastTreeValue`, because the JAX PyTree registry lookup uses `type(node) is registered_type` rather than `isinstance(node, registered_type)`.
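A minimal sketch of the exact-type lookup, assuming a recent JAX; `Node` and `SubNode` are hypothetical classes used only for illustration. Registering the base class does not make its subclass a PyTree node, which is why `TreeValue` and `FastTreeValue` each need their own registration.

```python
import jax


class Node:
    """Hypothetical container holding a tuple of children."""

    def __init__(self, *children):
        self.children = children


class SubNode(Node):
    """Subclass that is NOT registered separately."""


def flatten_node(node):
    # children become the subtrees; no auxiliary data is needed
    return node.children, None


def unflatten_node(aux_data, children):
    return Node(*children)


jax.tree_util.register_pytree_node(Node, flatten_node, unflatten_node)

# The registered class is traversed into its children ...
print(jax.tree_util.tree_leaves(Node(1, 2)))  # [1, 2]
# ... but the subclass is looked up by exact type, so it is a single leaf.
print(len(jax.tree_util.tree_leaves(SubNode(1, 2))))  # 1
```

The subclass instance comes back unchanged as a leaf, so passing it to a jitted function would hit the same error as in the report.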

In _to_dm:

```python
def _to_dm(
    self: Any,
    state_values: List[np.ndarray],
    reset: bool,
    return_info: bool,
) -> TimeStep:
    values = map(lambda i: state_values[i], state_idx)
    state = treevalue.unflatten(
        [(path, vi) for (path, _), vi in zip(tree_pairs, values)]
    )
    timestep = TimeStep(
        step_type=state.step_type,
        observation=state.State,
        reward=state.reward,
        discount=state.discount,
    )
    return timestep
```

We are returning a namedtuple of `TreeValue` instances, which cannot be passed to a `jax.jit`-compiled function.

IMO, we'd better use standard Python containers (e.g., dicts or namedtuples) rather than TreeValue instances in our public API. The standard Python containers always have first-party support for many pytree libraries (jax, torch, dm-tree, optree).
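A sketch of what that could look like for `_to_dm`, assuming the same `(path, value)` pairs are available; `unflatten_to_dict`, `to_dm` and the example paths such as `("State", "env_id")` are hypothetical, not envpool's actual code. The result is a namedtuple of plain nested dicts, which needs no registration with any pytree library.

```python
from collections import namedtuple
from typing import Any, Dict, Iterable, Tuple

# Hypothetical stand-in for the dm_env-style TimeStep namedtuple.
TimeStep = namedtuple("TimeStep", ["step_type", "observation", "reward", "discount"])

Path = Tuple[str, ...]


def unflatten_to_dict(pairs: Iterable[Tuple[Path, Any]]) -> Dict[str, Any]:
    """Rebuild plain nested dicts from (path, value) pairs."""
    root: Dict[str, Any] = {}
    for path, value in pairs:
        node = root
        for key in path[:-1]:
            # create intermediate dicts on demand
            node = node.setdefault(key, {})
        node[path[-1]] = value
    return root


def to_dm(pairs: Iterable[Tuple[Path, Any]]) -> TimeStep:
    """Hypothetical variant of _to_dm that returns only standard containers."""
    state = unflatten_to_dict(pairs)
    return TimeStep(
        step_type=state["step_type"],
        observation=state["State"],
        reward=state["reward"],
        discount=state["discount"],
    )


timestep = to_dm([
    (("step_type",), [0, 0]),
    (("State", "obs"), [[0.0], [1.0]]),
    (("State", "env_id"), [0, 1]),
    (("reward",), [0.0, 0.0]),
    (("discount",), [1.0, 1.0]),
])
print(timestep.observation["env_id"])  # [0, 1]
```

One consequence of this design is that nested fields are read with item access (`timestep.observation["env_id"]`) rather than attribute access.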

Also, note that treevalue only supports nested dicts with str keys. It does not support arbitrary nested Python containers:

```python
In [1]: import treevalue

In [2]: tree = {1: 'a', 2: 'b'}

In [3]: treevalue.FastTreeValue(tree)
TypeError: Expected unicode, got int

In [4]: tree = [{'a': 1}, {'a': 2}]

In [5]: treevalue.FastTreeValue(tree)
TypeError: Unknown initialization type for tree value - 'list'.
```
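For comparison, a sketch (assuming a recent JAX) showing that JAX's built-in PyTree handling accepts both structures rejected above: dicts with non-`str` keys, and lists of dicts.

```python
import jax

# Integer keys are fine: JAX flattens dicts in sorted-key order.
int_key_leaves = jax.tree_util.tree_leaves({1: "a", 2: "b"})
print(int_key_leaves)  # ['a', 'b']

# A list of dicts is an ordinary PyTree as well.
tree = [{"a": 1}, {"a": 2}]
print(jax.tree_util.tree_leaves(tree))  # [1, 2]
print(jax.tree_util.tree_map(lambda x: x * 10, tree))  # [{'a': 10}, {'a': 20}]
```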

@PaParaZz1

> @PaParaZz1 can you take a look?

@HansBug and I are working to fix this compatibility problem with JAX. At present, it seems that the solution should be to register TreeValue in JAX.

@HansBug
Contributor

HansBug commented Feb 26, 2023

We are adding a `penetrate` function so that `jax.jit` can support `FastTreeValue` (see: opendilab/treevalue#77). Here is the usage: https://opendilab.github.io/treevalue/dev/wrap/api_doc/tree/tree.html#penetrate

This will be released in the next version.

```python
import jax
import numpy as np

from treevalue import FastTreeValue, PENETRATE_SESSIONID_ARGNAME, penetrate


@penetrate(jax.jit, static_argnames=PENETRATE_SESSIONID_ARGNAME)
def double(x):
    return x * 2


t = FastTreeValue({
    'a': np.random.randint(0, 10, (2, 3)),
    'b': {
        'x': 233,
        'y': np.random.randn(2, 3)
    }
})

print(t)
print(double(t))
print(double(t + 1))
```

@HansBug
Contributor

HansBug commented Feb 26, 2023

Another solution based on native JAX registration:

```python
import jax
import numpy as np

from treevalue import FastTreeValue, flatten, unflatten, TreeValue


def flatten_treevalue(container):
    contents = []
    paths = []
    for path, value in flatten(container):
        paths.append(path)
        contents.append(value)

    # aux_data should be hashable, so store the paths as a tuple
    return contents, (type(container), tuple(paths))


def unflatten_treevalue(aux_data, flat_contents):
    type_, paths = aux_data
    return unflatten(zip(paths, flat_contents), return_type=type_)


# JAX matches node types exactly, so each class is registered separately
jax.tree_util.register_pytree_node(TreeValue, flatten_treevalue, unflatten_treevalue)
jax.tree_util.register_pytree_node(FastTreeValue, flatten_treevalue, unflatten_treevalue)


data = {
    'a': np.random.randint(0, 10, (2, 3)),
    'b': {
        'x': 233,
        'y': np.random.randn(2, 3)
    }
}
t = FastTreeValue(data)

@jax.jit
def double(x):
    return x * 2


print(double(t))
```
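The same flatten-by-paths idea can be tried without treevalue installed. A minimal sketch, assuming a recent JAX; `PathTree` is a hypothetical class, not part of treevalue or envpool. Since `PathTree` has no arithmetic operators, the jitted function maps over the leaves with `jax.tree_util.tree_map` instead of writing `x * 2`.

```python
from typing import Any, Dict, Tuple

import jax
import jax.numpy as jnp

Path = Tuple[str, ...]


class PathTree:
    """Hypothetical minimal tree that stores leaves under tuple-of-str paths."""

    def __init__(self, items: Dict[Path, Any]):
        self.items = dict(items)


def flatten_pathtree(tree: PathTree):
    # sort the paths so the leaf order is deterministic
    paths = tuple(sorted(tree.items))
    leaves = [tree.items[p] for p in paths]
    # aux_data is a hashable tuple: (concrete class, paths)
    return leaves, (type(tree), paths)


def unflatten_pathtree(aux_data, leaves):
    type_, paths = aux_data
    return type_(dict(zip(paths, leaves)))


jax.tree_util.register_pytree_node(PathTree, flatten_pathtree, unflatten_pathtree)


@jax.jit
def double(x):
    return jax.tree_util.tree_map(lambda v: v * 2, x)


t = PathTree({("a",): jnp.arange(3), ("b", "x"): jnp.float32(1.5)})
out = double(t)
print(out.items[("a",)])
```

Storing the concrete class in `aux_data` lets the same functions rebuild a subclass, but each subclass would still need its own `register_pytree_node` call.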

@HansBug
Contributor

HansBug commented Feb 28, 2023

@JesseFarebro, treevalue 1.4.7 now supports passing `FastTreeValue` through `jax.jit`:

```python
import jax

from treevalue import FastTreeValue

d = {'a': 1, 'b': {'c': 2, 'd': 3}, 'e': 4}
t = FastTreeValue(d)


@jax.jit
def double(x):
    return x * 2


if __name__ == '__main__':
    print(double(t))
```

If you need to register a custom treevalue class, just use `register_treevalue_class`:

```python
import jax

from treevalue import FastTreeValue, register_treevalue_class


class MyTreeValue(FastTreeValue):
    pass


register_treevalue_class(MyTreeValue)

d = {'a': 1, 'b': {'c': 2, 'd': 3}, 'e': 4}
t = MyTreeValue(d)


@jax.jit
def double(x):
    return x * 2


if __name__ == '__main__':
    print(double(t))
```

@JesseFarebro
Author

Hi all,

Thanks for the swift action on this issue. One comment: I agree with @XuehaiPan's comments here, #249 (comment), that standard container types should be used for the public-facing API.

I appreciate the emphasis on performance but I think the tradeoff for user-facing APIs isn't worth it. Another example of custom tree-like data structures getting in the way can be seen in Flax's recent move away from their custom FrozenDict structure to regular dicts. There were some issues irrespective of immutability that spurred on this change (e.g., see long-standing issues in Optax RE: Flax FrozenDict).

Trinkle23897 pushed a commit that referenced this issue Mar 20, 2023
## Description

treevalue is now upgraded to `1.4.7`, which supports:
1. `generic_flatten`, `generic_unflatten` and `generic_mapping` functions;
here are [some
examples](https://opendilab.github.io/treevalue/v1.4.7/api_doc/tree/integration.html#generic-flatten).
**With `generic_flatten` and `generic_unflatten`, the problem of
non-native instances mentioned by @XuehaiPan is easily solved**.

```python
In [1]: from collections import namedtuple
   ...: from easydict import EasyDict
   ...: from treevalue import FastTreeValue
   ...: 
   ...: class MyTreeValue(FastTreeValue):
   ...:     pass
   ...: 
   ...: nt = namedtuple('nt', ['a', 'b'])
   ...: 
   ...: origin = {
   ...:     'a': 1,
   ...:     'b': (2, 3, 'f',),
   ...:     'c': (2, 5, 'ds', {
   ...:         'x': None,
   ...:         'z': [34, '1.2'],
   ...:     }),
   ...:     'd': nt('f', 100),  # namedtuple
   ...:     'e': MyTreeValue({'x': 1, 'y': 'dsfljk'})  # treevalue
   ...: }

In [2]: from treevalue import generic_flatten
   ...: 
   ...: v, spec = generic_flatten(origin)
   ...: v
Out[2]: [1, [2, 3, 'f'], [2, 5, 'ds', [None, [34, '1.2']]], ['f', 100], [1, 'dsfljk']]

In [3]: from treevalue import generic_unflatten
   ...: 
   ...: generic_unflatten(v, spec)
Out[3]: 
{'a': 1,
 'b': (2, 3, 'f'),
 'c': (2, 5, 'ds', {'x': None, 'z': [34, '1.2']}),
 'd': nt(a='f', b=100),
 'e': <MyTreeValue 0x7fd552b5b410>
 ├── 'x' --> 1
 └── 'y' --> 'dsfljk'}
```

With `generic_mapping`, **operations can be performed on the entire
structure** (even including custom classes, which you can register with
[`register_integrate_container`](https://opendilab.github.io/treevalue/v1.4.7/api_doc/tree/integration.html#register-integrate-container))
with good performance. Here are the metrics:


![1677503026559](https://user-images.githubusercontent.com/20508435/221570959-550940e8-22aa-4ee3-9967-260c6dfb0199.png)

![1677502998652](https://user-images.githubusercontent.com/20508435/221570895-efa6b35f-3663-42e1-8837-763579b2e850.png)

![1677498651795](https://user-images.githubusercontent.com/20508435/221570554-18361578-323d-421b-b932-c898b5ee8a2d.png)

![1677498667831](https://user-images.githubusercontent.com/20508435/221570567-60cee69b-cebc-48ab-9eb8-ea7f3c1ec158.png)

2. We support integration with torch pytree; here is the newest
example:

```python
In [1]: import torch
   ...: 
   ...: from treevalue import FastTreeValue
   ...: 
   ...: d = {'a': 1, 'b': {'c': 2, 'd': 3}, 'e': 4}
   ...: t = FastTreeValue(d)
   ...: 

In [2]: torch.utils._pytree.tree_flatten(t)  # Torch recognizes TreeValue as tree
Out[2]: 
([1, 2, 3, 4],
 TreeSpec(FastTreeValue, (<class 'treevalue.tree.general.fast.FastTreeValue'>, [('a',), ('b', 'c'), ('b', 'd'), ('e',)]), [*, *, *, *]))
```

3. We provide an `unpack` method for quickly unpacking values from a
treevalue. Note that treevalue behaves like `dict`, not like a
namedtuple, so nested structures have to be handled the same way as
with a dict.

```python
In [1]: from treevalue import FastTreeValue
   ...: 
   ...: d = {'a': 1, 'b': {'c': 2, 'd': 3}, 'e': 4}
   ...: t = FastTreeValue(d)
   ...: 

In [2]: vc, vd = t.b.unpack('c', 'd')

In [3]: vc
Out[3]: 2

In [4]: vd
Out[4]: 3
```

4. We are really glad to see @XuehaiPan and @Benjamin-eecs bring more
advice to `treevalue`. This will make `treevalue` better and better,
since we are able to add all these new features asap. We hope to have
more in-depth exchanges with the optree team in order to significantly
improve the experience of using treevalue. 😄


--------------- old description ---------------

Fix bug #246 by simply using the
newest version of treevalue.

## Motivation and Context

This is a simple fix of bug
#246 .

In treevalue `1.4.6`, the `TreeValue` and `FastTreeValue` classes,
which cover almost all practical usage, are registered as JAX
nodes by default. If a custom `TreeValue`-based class needs to be
registered, you can simply do this:

```python
from treevalue import FastTreeValue, register_for_jax


class MyTreeValue(FastTreeValue):
    pass


register_for_jax(MyTreeValue)  # such a simple operation, isn't it?
```

Not only that: from an engineering point of view, preserving the
original design as much as possible makes the code easier to maintain,
and directly upgrading the treevalue version will
**minimize changes to the code and technology stack**. The convenient
calculation operations provided by treevalue (see
the [documentation of
treevalue](https://opendilab.github.io/treevalue/main/index.html) for
details) are also very attractive for the subsequent development of the
project, and the performance fully meets the
requirements. Among envpool users, jax is not a must, and **users
who do not use jax are also common**, so treevalue offers
**better versatility and universality**. In summary, **we recommend
upgrading treevalue to fix this bug**.

BTW, if you have further optimization needs or suggestions, we will
respond asap.

- [x] I have raised an issue to propose this change
([required](https://envpool.readthedocs.io/en/latest/pages/contributing.html)
for new features and bug fixes)

## Types of changes

What types of changes does your code introduce? Put an `x` in all the
boxes that apply:

- [x] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds core functionality)
- [ ] New environment (non-breaking change which adds 3rd-party
environment)
- [ ] Breaking change (fix or feature that would cause existing
functionality to change)
- [ ] Documentation (update in the documentation)
- [ ] Example (update in the folder of example)

## Checklist

Go over all the following points, and put an `x` in all the boxes that
apply.
If you are unsure about any of these, don't hesitate to ask. We are here
to help!

- [x] I have read the
[CONTRIBUTION](https://envpool.readthedocs.io/en/latest/pages/contributing.html)
guide (**required**)
- [ ] My change requires a change to the documentation.
- [x] I have updated the tests accordingly (*required for a bug fix or a
new feature*).
- [ ] I have updated the documentation accordingly.
- [x] I have reformatted the code using `make format` (**required**)
- [x] I have checked the code using `make lint` (**required**)
- [x] I have ensured `make bazel-test` pass. (**required**)
Trinkle23897 pushed a commit that referenced this issue Mar 20, 2023
## Description

- close #246 
- further speedup
- [support custom node type
registration](pytorch/pytorch#65761 (comment))

## Motivation and Context


Initial Test Results:

### Test gym

TreeValue:

<img width="1073" alt="image"
src="https://user-images.githubusercontent.com/32269413/221405036-ed1d1692-714b-4af7-a011-8accb36b8427.png">

OpTree:

<img width="1074" alt="image"
src="https://user-images.githubusercontent.com/32269413/221405130-521b48fe-7760-4ef0-8abb-bc9468c7bdeb.png">

### Test dmc

TreeValue:

<img width="1061" alt="image"
src="https://user-images.githubusercontent.com/32269413/221405292-69f2b96e-d590-42a5-b073-da4721ec6148.png">

OpTree:

<img width="1076" alt="image"
src="https://user-images.githubusercontent.com/32269413/221405305-e1d53f94-5f97-4c11-a20c-c167fdd9d45b.png">






- [x] I have raised an issue to propose this change
([required](https://envpool.readthedocs.io/en/latest/pages/contributing.html)
for new features and bug fixes)

## Types of changes

What types of changes does your code introduce? Put an `x` in all the
boxes that apply:

- [x] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds core functionality)
- [x] Breaking change (fix or feature that would cause existing
functionality to change)
- [ ] Example (update in the folder of example)

## Checklist

Go over all the following points, and put an `x` in all the boxes that
apply.
If you are unsure about any of these, don't hesitate to ask. We are here
to help!

- [x] I have read the
[CONTRIBUTION](https://envpool.readthedocs.io/en/latest/pages/contributing.html)
guide (**required**)
- [ ] My change requires a change to the documentation.
- [x] I have updated the tests accordingly (*required for a bug fix or a
new feature*).
- [ ] I have updated the documentation accordingly.
- [x] I have reformatted the code using `make format` (**required**)
- [x] I have checked the code using `make lint` (**required**)
- [x] I have ensured `make bazel-test` pass. (**required**)

cc @XuehaiPan

---------

Co-authored-by: Xuehai Pan <XuehaiPan@outlook.com>

6 participants