Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(hansbug): use treevalue 1.4.7 #249

Merged
merged 4 commits into from
Mar 20, 2023
Merged

Conversation

HansBug
Copy link
Contributor

@HansBug HansBug commented Feb 26, 2023

Description

The treevalue is now upgraded to 1.4.7, which supports:

  1. generic_flatten, generic_unflatten and generic_mapping functions, here are some examples. With generic_flatten and generic_unflatten, the problem of non-native instances mentioned by @XuehaiPan will be easily solved.
In [1]: from collections import namedtuple
   ...: from easydict import EasyDict
   ...: from treevalue import FastTreeValue
   ...: 
   ...: class MyTreeValue(FastTreeValue):
   ...:     pass
   ...: 
   ...: nt = namedtuple('nt', ['a', 'b'])
   ...: 
   ...: origin = {
   ...:     'a': 1,
   ...:     'b': (2, 3, 'f',),
   ...:     'c': (2, 5, 'ds', {
   ...:         'x': None,
   ...:         'z': [34, '1.2'],
   ...:     }),
   ...:     'd': nt('f', 100),  # namedtuple
   ...:     'e': MyTreeValue({'x': 1, 'y': 'dsfljk'})  # treevalue
   ...: }

In [2]: from treevalue import generic_flatten
   ...: 
   ...: v, spec = generic_flatten(origin)
   ...: v
Out[2]: [1, [2, 3, 'f'], [2, 5, 'ds', [None, [34, '1.2']]], ['f', 100], [1, 'dsfljk']]

In [3]: from treevalue import generic_unflatten
   ...: 
   ...: generic_unflatten(v, spec)
Out[3]: 
{'a': 1,
 'b': (2, 3, 'f'),
 'c': (2, 5, 'ds', {'x': None, 'z': [34, '1.2']}),
 'd': nt(a='f', b=100),
 'e': <MyTreeValue 0x7fd552b5b410>
 ├── 'x' --> 1
 └── 'y' --> 'dsfljk'}

With generic_mapping, the operations can be performed on the entire structure (even including custom classes, you can register this with register_integrate_container) with good performance, here is the metrics.

1677503026559
1677502998652
1677498651795
1677498667831

  1. We support the torch pytree's intregration, this is the newest example
In [1]: import torch
   ...: 
   ...: from treevalue import FastTreeValue
   ...: 
   ...: d = {'a': 1, 'b': {'c': 2, 'd': 3}, 'e': 4}
   ...: t = FastTreeValue(d)
   ...: 

In [2]: torch.utils._pytree.tree_flatten(t)  # Torch recognizes TreeValue as tree
Out[2]: 
([1, 2, 3, 4],
 TreeSpec(FastTreeValue, (<class 'treevalue.tree.general.fast.FastTreeValue'>, [('a',), ('b', 'c'), ('b', 'd'), ('e',)]), [*, *, *, *]))
  1. We provide unpack method for quickly unpack the values from treevalue. We should also notice that the treevalue's behaviour is similar to dict, not namedtuple. It also need to consider the nested structures like dict does.
In [1]: from treevalue import FastTreeValue
   ...: 
   ...: d = {'a': 1, 'b': {'c': 2, 'd': 3}, 'e': 4}
   ...: t = FastTreeValue(d)
   ...: 

In [2]: vc, vd = t.b.unpack('c', 'd')

In [3]: vc
Out[3]: 2

In [4]: vd
Out[4]: 3
  1. We are really really glad to see @XuehaiPan @Benjamin-eecs bring some more advise to treevalue. This will make treevalue better and better since we are able to add all these new features asap. We hope to have more in-depth exchanges with the optree team in order to significantly improve the experience of using treevalue. 😄

--------------- old description ---------------

Fix bug #246, by just use the newest version of treevalue.

Motivation and Context

This is a simple fix of bug #246 .

In the treevalue 1.4.6, the TreeValue and FastTreeValue class, which can satisfy almost all of the actual usage, are registered as jax nodes by default. If custom TreeValue-based class need to be registered, you can just do like this

from treevalue import FastTreeValue, register_for_jax


class MyTreeValue(FastTreeValue):
    pass


register_for_jax(MyTreeValue)  # such a simple operation, isn't it?

Not only that, from an engineering point of view, maintaining the original design as much as possible is a strategy that is more conducive to code maintenance, and directly upgrading the treevalue version will minimize the reconstruction to the code and technology stack. A series of convenient calculation operation provided by treevalue (see the documentation of treevalue for details) are also very tempting for the subsequent development of the project, and the performance is fully sufficient to meet the requirements. Among the users of envpool, jax is not a must, and users who do not use jax are also common. It is obvious that treevalue has better versatility and universality. In summary, we recommend upgrading the version of treevalue to fix this bug.

BTW, If you have further optimization needs and suggestions, we will respond asap.

  • I have raised an issue to propose this change (required for new features and bug fixes)

Types of changes

What types of changes does your code introduce? Put an x in all the boxes that apply:

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds core functionality)
  • New environment (non-breaking change which adds 3rd-party environment)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)
  • Example (update in the folder of example)

Checklist

Go over all the following points, and put an x in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!

  • I have read the CONTRIBUTION guide (required)
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.
  • I have reformatted the code using make format (required)
  • I have checked the code using make lint (required)
  • I have ensured make bazel-test pass. (required)

@XuehaiPan
Copy link
Contributor

XuehaiPan commented Feb 27, 2023

In the treevalue 1.4.6, the TreeValue and FastTreeValue class, which can satisfy almost all of the actual usage, are registered as jax nodes by default.

I don't think this (registering JAX pytree node) is the ultimate solution to #246.

The true fix is:

    def _to_dm(
      self: Any,
      state_values: List[np.ndarray],
      reset: bool,
      return_info: bool,
    ) -> TimeStep:
      values = map(lambda i: state_values[i], state_idx)
      state = treevalue.unflatten(
        [(path, vi) for (path, _), vi in zip(tree_pairs, values)]
      )
+     state = to_namedtuple("State", treevalue.jsonify(state))  # TreeValue -> nested dict -> nested namedtuple
      timestep = TimeStep(
        step_type=state.step_type,
        observation=state.State,
        reward=state.reward,
        discount=state.discount,
      )
      return timestep

However, to_namedtuple will dynamically create namedtuple classes each time we call _to_dm. This will introduce serious performance regression and also would be a disaster for the Python GC system.


As I commented in #246 (comment):

In _to_dm:

def _to_dm(
self: Any,
state_values: List[np.ndarray],
reset: bool,
return_info: bool,
) -> TimeStep:
values = map(lambda i: state_values[i], state_idx)
state = treevalue.unflatten(
[(path, vi) for (path, _), vi in zip(tree_pairs, values)]
)
timestep = TimeStep(
step_type=state.step_type,
observation=state.State,
reward=state.reward,
discount=state.discount,
)
return timestep

we are returning a namedtuple of TreeValue instances, which are non-jitable.

IMO, we'd better use standard Python containers (e.g., dicts or namedtuples) rather than TreeValue instances in our public API. The standard Python containers always have first-party support for many pytree libraries (jax, torch, dm-tree, optree).

The true issue here is we are returning a collection of third-party data structures (i.e., TreeValue instances) in our public API. We'd better return Python builtin types here. Otherwise, it would be a serious hindrance for downstream RL frameworks to use envpool.

I would be happy that treevalue now out-of-box supports JAX pytree system. However, it still needs some work to do.

  1. Not support the PyTorch pytree utilities and dm-tree. TreeValue instances still are considered as leaves. FYI, Ray RLlib is using the dm-tree as their pytree utility. Returning TreeValue instances would be a serious hindrance for RLlib to use envpool.

    In [1]: import treevalue  # 1.4.6
    
    In [2]: d = {'a': 1, 'b': {'c': 2, 'd': 3}, 'e': 4}
    
    In [3]: tree = treevalue.FastTreeValue(d)
    
    In [4]: import jax
    
    In [5]: jax.tree_util.tree_leaves(d)     # JAX recognizes nested dict as tree
    Out[5]: [1, 2, 3, 4]
    
    In [6]: jax.tree_util.tree_leaves(tree)  # JAX recognizes TreeValue as tree
    Out[6]: [1, 2, 3, 4]
    
    In [7]: import torch
    
    In [8]: torch.utils._pytree.tree_flatten(d)  # PyTorch recognizes nested dict as tree
    Out[8]: 
    (
        [1, 2, 3, 4],
        TreeSpec(dict, ['a', 'b', 'e'], [*, TreeSpec(dict, ['c', 'd'], [*, *]), *])
    )
    
    In [9]: torch.utils._pytree.tree_flatten(tree)  # PyTorch recognizes TreeValue as leaf
    Out[9]: 
    (
        [
            <FastTreeValue 0x7fe302d83670>
            ├── 'a' --> 1
            ├── 'b' --> <FastTreeValue 0x7fe302d83be0>
            │   ├── 'c' --> 2
            │   └── 'd' --> 3
            └── 'e' --> 4
        
        ],
        *
    )
    
    In [10]: import tree as dm_tree
    
    In [11]: dm_tree.flatten(d)     # DM-Tree recognizes nested dict as tree
    Out[11]: [1, 2, 3, 4]
    
    In [12]: dm_tree.flatten(tree)  # DM-Tree recognizes TreeValue as leaf
    Out[12]: 
    [
        <FastTreeValue 0x7fe302d83670>
        ├── 'a' --> 1
        ├── 'b' --> <FastTreeValue 0x7fe302d83be0>
        │   ├── 'c' --> 2
        │   └── 'd' --> 3
        └── 'e' --> 4
    
    ]

    It would be impossible for treevalue to out-of-box support all pytree utilities. Unless each time a new pytree package come out you do a revision bump in treevalue and ask all your dependent to do the same thing. See the fix I suggested above, return a Python builtin type in envpool API.

  2. Although TreeValue instance supports attribute access. But it cannot simulate all behavior of namedtuple, such as sequence unpacking. It should not be returned in dm_env.TimeStep.

    In [13]: tree.a  # TreeValue supports attribute access
    Out[13]: 1
    
    In [14]: a, b, e = tree  # TreeValue supports sequence unpacking but iterates the keys
    
    In [15]: a
    Out[15]: 'a'  # expected 1
    
    In [16]: b
    Out[16]: 'b'  # expected somenamedtuple(c=2, d=3)
    
    In [17]: e
    Out[17]: 'e'  # expected 4

@HansBug HansBug changed the title fix(hansbug): use treevalue 1.4.6 fix(hansbug): use treevalue 1.4.7 Feb 27, 2023
@HansBug
Copy link
Contributor Author

HansBug commented Feb 27, 2023

Treevalue 1.4.7 is now released.

@HansBug
Copy link
Contributor Author

HansBug commented Feb 28, 2023

I don't think this (registering JAX pytree node) is the ultimate solution to #246.

The true issue here is we are returning a collection of third-party data structures (i.e., TreeValue instances) in our public API. We'd better return Python building types here. Otherwise, it would be a serious hindrance for downstream RL frameworks to use envpool.

Well, there is a package called easydict. This is a simple inheritance of native dict. The values can be accessed through attributes with similar performance of native dict. So maybe TreeValue --> easydict will be another solution which is much better that TreeValue --> dict --> namedtuple. Namedtuple is unnecessary here.

@Trinkle23897
Copy link
Collaborator

triton-lang/triton#1374

@Trinkle23897 Trinkle23897 merged commit b8df995 into sail-sg:main Mar 20, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants