From 4d71fbedc0d7d37e594ca28788720f83c78b8cca Mon Sep 17 00:00:00 2001 From: pfackeldey Date: Wed, 14 May 2025 13:10:44 -0700 Subject: [PATCH 01/32] wip --- content/posts/optree/pytrees/index.md | 150 ++++++++++++++++++++++++++ 1 file changed, 150 insertions(+) create mode 100644 content/posts/optree/pytrees/index.md diff --git a/content/posts/optree/pytrees/index.md b/content/posts/optree/pytrees/index.md new file mode 100644 index 0000000..1486f61 --- /dev/null +++ b/content/posts/optree/pytrees/index.md @@ -0,0 +1,150 @@ +--- +title: "Pytrees for Scientific Python" +date: 2025-05-14T10:27:59-07:00 +draft: false +description: " +Introducing PyTrees for Scientific Python. We discuss what PyTrees are, how they're useful in the realm of scientific python, and how to work _efficiently_ with them. +" +tags: ["PyTrees", "Functional Programming", "Tree-like data manipulation"] +displayInList: true +author: ["Peter Fackeldey", "Mihai Maruseac", "Matthew Feickert"] +--- + +### Manipulating tree-like data using functional programming paradigms + +A "PyTree" is a nested collection of python containers (e.g. dicts, (named) tuples, lists, ...), where the leafs are of interest. +As you can imagine (or even experienced in the past), these arbitrary nested collections can become cumbersome to manipulate _efficiently_. +Often this requires complex recursive logic, and which usually does not generalize to other PyTree structures. + +#### PyTree Origins + +Originally, the concept of PyTrees was developed by the [JAX](https://docs.jax.dev/en/latest/) project to make nested collections of JAX arrays work transparently at the "JIT-boundary". +This was quickly adopted by AI researchers: semantically grouping layers of weights and biases in e.g. a list of named tuples (or dictionaries) is a common pattern in the JAX-AI-world, see the following pseudo-code: + +```python +from typing import NamedTuple, Callable +import jax +import jax.numpy as jnp + + +class Layer(NamedTuple): + W: jax.Array + b: jax.Array + + +layers = [ + Layer(W=jnp.array(...), b=jnp.array(...)), # first layer + Layer(W=jnp.array(...), b=jnp.array(...)), # second layer + ..., +] + + +@jax.jit +def neural_network(layers: list[Layer], x: jax.Array) -> jax.Array: + for layer in layers: + x = jnp.tanh(layer.W @ x + layer.b) + return x + + +pred = neural_network(layers=layers, x=jnp.array(...)) +``` + +Here, `layers` is a PyTree - a `list` of multiple `Layer` - and the JIT compiled `neural_network` function _just works_ with this datastructure as input. + +#### PyTrees in Scientific Python + +Wouldn't it be nice to make workflows in the scientific python ecosystem _just work_ with any PyTree? +Enabling semantic meaning through PyTrees can be useful for applications outside of AI as well. +Consider the following minimization of the [Rosenbrock](https://en.wikipedia.org/wiki/Rosenbrock_function) function: + +```Python +from scipy.optimize import minimize + +def rosenbrock(params: tuple[float]) -> float: + """ + Rosenbrock function. Minimum: f(1, 1) = 0. + + https://en.wikipedia.org/wiki/Rosenbrock_function + """ + x, y = params + return (1 - x) ** 2 + 100 * (y - x**2) ** 2 + + +x0 = (0.9, 1.2) +res = minimize(rosenbrock, x0) +print(res.x) +>> [0.99999569 0.99999137] +``` + +Now, let's turn it a minimization that uses a more complex type for the parameters - a NamedTuple that describes our fit parameters: + +```Python +import optree as pt # standalone PyTree library +from typing import NamedTuple, Callable +from scipy.optimize import minimize as sp_minimize + + +class Params(NamedTuple): + x: float + y: float + + +def rosenbrock(params: Params) -> float: + """ + Rosenbrock function. Minimum: f(1, 1) = 0. + + https://en.wikipedia.org/wiki/Rosenbrock_function + """ + return (1 - params.x) ** 2 + 100 * (params.y - params.x**2) ** 2 + + +def minimize(fun: Callable, params: Params) -> Params: + # flatten and store PyTree definition + flat_params, PyTreeDef = pt.tree_flatten(params) + + # wrap fun to work with flat_params + def wrapped_fun(flat_params): + params = pt.tree_unflatten(PyTreeDef, flat_params) + return fun(params) + + # actual minimization + res = sp_minimize(wrapped_fun, flat_params) + + # re-wrap the bestfit values into Params with stored PyTree definition + return pt.tree_unflatten(PyTreeDef, res.x) + + +# scipy minimize that works with any PyTree +x0 = Params(x=0.9, y=1.2) +bestfit_params = minimize(rosenbrock, x0) +print(bestfit_params) +>> Params(x=np.float64(0.999995688776513), y=np.float64(0.9999913673387226)) +``` + +This new `minimize` function works with _any_ PyTree, e.g.: + +```python +import numpy as np + + +def rosenbrock_modified(params: Params) -> float: + """ + Modified Rosenbrock where the x and y parameters are determined by + a non-linear transformations of two versions of each, i.e.: + x = arcsin(min(x1, x2) / max(x1, x2)) + y = sigmoid(x1 - x2) + """ + p1, p2 = params + x = np.asin(min(p1.x, p2.x) / max(p1.x, p2.x)) + y = 1.0 / (1.0 + np.exp(-(p1.y / p2.y))) + return (1 - x) ** 2 + 100 * (y - x**2) ** 2 + + +x0 = (Params(x=0.9, y=1.2), Params(x=0.8, y=1.3)) +bestfit_params = minimize(rosenbrock_modified, x0) +print(bestfit_params) +# >> ( +# Params(x=np.float64(4.686181110201706), y=np.float64(0.05129869722505759)), +# Params(x=np.float64(3.9432263101976073), y=np.float64(0.005146110126174016)), +# ) +``` From dd6c7ca4a4ab3d8906c9b1e0f91086072c9d54f6 Mon Sep 17 00:00:00 2001 From: pfackeldey Date: Wed, 14 May 2025 15:59:50 -0700 Subject: [PATCH 02/32] first complete version --- content/posts/optree/pytrees/index.md | 83 +++++++++++++++++++++------ 1 file changed, 65 insertions(+), 18 deletions(-) diff --git a/content/posts/optree/pytrees/index.md b/content/posts/optree/pytrees/index.md index 1486f61..1e2c4dc 100644 --- a/content/posts/optree/pytrees/index.md +++ b/content/posts/optree/pytrees/index.md @@ -10,16 +10,56 @@ displayInList: true author: ["Peter Fackeldey", "Mihai Maruseac", "Matthew Feickert"] --- -### Manipulating tree-like data using functional programming paradigms +## Manipulating tree-like data using functional programming paradigms A "PyTree" is a nested collection of python containers (e.g. dicts, (named) tuples, lists, ...), where the leafs are of interest. -As you can imagine (or even experienced in the past), these arbitrary nested collections can become cumbersome to manipulate _efficiently_. -Often this requires complex recursive logic, and which usually does not generalize to other PyTree structures. +As you can imagine (or even experienced in the past), such arbitrary nested collections can be cumbersome to manipulate _efficiently_. +It often requires complex recursive logic, and which usually does not generalize to other nested Python containers (PyTrees). -#### PyTree Origins +The core concept of PyTrees is being able to flatten them into a flat collection of leafs and a "blueprint" of the tree structure, and then being able to unflatten them back into the original PyTree. +This allows to apply generic transformations, e.g. through a `tree_map(fun, pytree)` operation: -Originally, the concept of PyTrees was developed by the [JAX](https://docs.jax.dev/en/latest/) project to make nested collections of JAX arrays work transparently at the "JIT-boundary". -This was quickly adopted by AI researchers: semantically grouping layers of weights and biases in e.g. a list of named tuples (or dictionaries) is a common pattern in the JAX-AI-world, see the following pseudo-code: +```python +import optree as pt +import numpy as np + +# tuple of a list of a dict with an array as value, and an array +pytree = ([[{"foo": np.array([2.0])}], np.array([3.0])],) + +# sqrt of each leaf array +sqrt_pytree = pt.tree_map(np.sqrt, pytree) +print(f"{sqrt_pytree=}") +# >> sqrt_pytree=([[{'foo': array([1.41421356])}], array([1.73205081])],) + +# reductions +all_positive = pt.tree_all(pt.tree_map(lambda x: x > 0.0, pytree)) +print(f"{all_positive=}") +# >> all_positive=True + +summed = pt.tree_reduce(sum, pytree) +print(f"{summed=}") +# >> summed=array([5.]) +``` + +The trick here is that these operations can be implemented in three steps, e.g. `tree_map`: + +```python +# step 1: +leafs, treedef = pt.tree_flatten(pytree) + +# step 2: +new_leafs = tuple(map(fun, leafs)) + +# step 3: +result_pytree = pt.tree_unflatten(treedef, new_leafs) +``` + +Here, we use [`optree`](https://github.com/metaopt/optree/tree/main/optree) - a standalone PyTree library - that enables all these manipulations. It focusses on performance, feature richness, minimal dependencies, and got adopted by PyTorch, Keras, and TensorFlow as a core dependency. + +### PyTree Origins + +Originally, the concept of PyTrees was developed by the [JAX](https://docs.jax.dev/en/latest/) project to make nested collections of JAX arrays work transparently at the "JIT-boundary" (the JAX JIT toolchain does not know about python containers, only about JAX Arrays). +However, PyTrees were quickly adopted by AI researchers for broader use-cases: semantically grouping layers of weights and biases in e.g. a list of named tuples (or dictionaries) is a common pattern in the JAX-AI-world, see the following (pseudo) Python snippet: ```python from typing import NamedTuple, Callable @@ -46,20 +86,22 @@ def neural_network(layers: list[Layer], x: jax.Array) -> jax.Array: return x -pred = neural_network(layers=layers, x=jnp.array(...)) +prediction = neural_network(layers=layers, x=jnp.array(...)) ``` Here, `layers` is a PyTree - a `list` of multiple `Layer` - and the JIT compiled `neural_network` function _just works_ with this datastructure as input. -#### PyTrees in Scientific Python +### PyTrees in Scientific Python Wouldn't it be nice to make workflows in the scientific python ecosystem _just work_ with any PyTree? -Enabling semantic meaning through PyTrees can be useful for applications outside of AI as well. + +Giving semantic meaning to numeric data through PyTrees can be useful for applications outside of AI as well. Consider the following minimization of the [Rosenbrock](https://en.wikipedia.org/wiki/Rosenbrock_function) function: -```Python +```python from scipy.optimize import minimize + def rosenbrock(params: tuple[float]) -> float: """ Rosenbrock function. Minimum: f(1, 1) = 0. @@ -73,12 +115,12 @@ def rosenbrock(params: tuple[float]) -> float: x0 = (0.9, 1.2) res = minimize(rosenbrock, x0) print(res.x) ->> [0.99999569 0.99999137] +# >> [0.99999569 0.99999137] ``` Now, let's turn it a minimization that uses a more complex type for the parameters - a NamedTuple that describes our fit parameters: -```Python +```python import optree as pt # standalone PyTree library from typing import NamedTuple, Callable from scipy.optimize import minimize as sp_minimize @@ -100,28 +142,28 @@ def rosenbrock(params: Params) -> float: def minimize(fun: Callable, params: Params) -> Params: # flatten and store PyTree definition - flat_params, PyTreeDef = pt.tree_flatten(params) + flat_params, treedef = pt.tree_flatten(params) # wrap fun to work with flat_params def wrapped_fun(flat_params): - params = pt.tree_unflatten(PyTreeDef, flat_params) - return fun(params) + params = pt.tree_unflatten(treedef, flat_params) + return fun(params) # actual minimization res = sp_minimize(wrapped_fun, flat_params) # re-wrap the bestfit values into Params with stored PyTree definition - return pt.tree_unflatten(PyTreeDef, res.x) + return pt.tree_unflatten(treedef, res.x) # scipy minimize that works with any PyTree x0 = Params(x=0.9, y=1.2) bestfit_params = minimize(rosenbrock, x0) print(bestfit_params) ->> Params(x=np.float64(0.999995688776513), y=np.float64(0.9999913673387226)) +# >> Params(x=np.float64(0.999995688776513), y=np.float64(0.9999913673387226)) ``` -This new `minimize` function works with _any_ PyTree, e.g.: +This new `minimize` function works with _any_ PyTree, let's consider a modified and more complex version of the Rosenbrock function that relies on two sets of `Params` as input: ```python import numpy as np @@ -135,8 +177,11 @@ def rosenbrock_modified(params: Params) -> float: y = sigmoid(x1 - x2) """ p1, p2 = params + + # calculate `x` and `y` from two sources: x = np.asin(min(p1.x, p2.x) / max(p1.x, p2.x)) y = 1.0 / (1.0 + np.exp(-(p1.y / p2.y))) + return (1 - x) ** 2 + 100 * (y - x**2) ** 2 @@ -148,3 +193,5 @@ print(bestfit_params) # Params(x=np.float64(3.9432263101976073), y=np.float64(0.005146110126174016)), # ) ``` + +The new `minimize` still works, because a `tuple` of `Params` is just _another_ PyTree! From e4804e18f06e0101421ebc9a0b69bb97ad9d194d Mon Sep 17 00:00:00 2001 From: pfackeldey Date: Wed, 14 May 2025 16:38:28 -0700 Subject: [PATCH 03/32] first round of review improvements --- content/posts/optree/pytrees/index.md | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/content/posts/optree/pytrees/index.md b/content/posts/optree/pytrees/index.md index 1e2c4dc..39b73c4 100644 --- a/content/posts/optree/pytrees/index.md +++ b/content/posts/optree/pytrees/index.md @@ -10,26 +10,26 @@ displayInList: true author: ["Peter Fackeldey", "Mihai Maruseac", "Matthew Feickert"] --- -## Manipulating tree-like data using functional programming paradigms +## Manipulating Tree-like Data using Functional Programming Paradigms -A "PyTree" is a nested collection of python containers (e.g. dicts, (named) tuples, lists, ...), where the leafs are of interest. +A "PyTree" is a nested collection of Python containers (e.g. dicts, (named) tuples, lists, ...), where the leafs are of interest. As you can imagine (or even experienced in the past), such arbitrary nested collections can be cumbersome to manipulate _efficiently_. -It often requires complex recursive logic, and which usually does not generalize to other nested Python containers (PyTrees). +It often requires complex recursive logic which usually does not generalize to other nested Python containers (PyTrees). The core concept of PyTrees is being able to flatten them into a flat collection of leafs and a "blueprint" of the tree structure, and then being able to unflatten them back into the original PyTree. -This allows to apply generic transformations, e.g. through a `tree_map(fun, pytree)` operation: +This allows to apply generic transformations, e.g. taking the square root of each leaf of a PyTree with a `tree_map(np.sqrt, pytree)` operation: ```python import optree as pt import numpy as np # tuple of a list of a dict with an array as value, and an array -pytree = ([[{"foo": np.array([2.0])}], np.array([3.0])],) +pytree = ([[{"foo": np.array([4.0])}], np.array([9.0])],) # sqrt of each leaf array sqrt_pytree = pt.tree_map(np.sqrt, pytree) print(f"{sqrt_pytree=}") -# >> sqrt_pytree=([[{'foo': array([1.41421356])}], array([1.73205081])],) +# >> sqrt_pytree=([[{'foo': array([2.])}], array([3.])],) # reductions all_positive = pt.tree_all(pt.tree_map(lambda x: x > 0.0, pytree)) @@ -38,7 +38,7 @@ print(f"{all_positive=}") summed = pt.tree_reduce(sum, pytree) print(f"{summed=}") -# >> summed=array([5.]) +# >> summed=array([13.]) ``` The trick here is that these operations can be implemented in three steps, e.g. `tree_map`: @@ -54,11 +54,11 @@ new_leafs = tuple(map(fun, leafs)) result_pytree = pt.tree_unflatten(treedef, new_leafs) ``` -Here, we use [`optree`](https://github.com/metaopt/optree/tree/main/optree) - a standalone PyTree library - that enables all these manipulations. It focusses on performance, feature richness, minimal dependencies, and got adopted by PyTorch, Keras, and TensorFlow as a core dependency. +Here, we use [`optree`](https://github.com/metaopt/optree/tree/main/optree) - a standalone PyTree library - that enables all these manipulations. It focusses on performance, feature richness, minimal dependencies, and got adopted by [PyTorch](https://pytorch.org), [Keras](https://keras.io), and [TensorFlow](https://github.com/tensorflow/tensorflow) (through Keras) as a core dependency. ### PyTree Origins -Originally, the concept of PyTrees was developed by the [JAX](https://docs.jax.dev/en/latest/) project to make nested collections of JAX arrays work transparently at the "JIT-boundary" (the JAX JIT toolchain does not know about python containers, only about JAX Arrays). +Originally, the concept of PyTrees was developed by the [JAX](https://docs.jax.dev/en/latest/) project to make nested collections of JAX arrays work transparently at the "JIT-boundary" (the JAX JIT toolchain does not know about Python containers, only about JAX Arrays). However, PyTrees were quickly adopted by AI researchers for broader use-cases: semantically grouping layers of weights and biases in e.g. a list of named tuples (or dictionaries) is a common pattern in the JAX-AI-world, see the following (pseudo) Python snippet: ```python @@ -93,7 +93,7 @@ Here, `layers` is a PyTree - a `list` of multiple `Layer` - and the JIT compiled ### PyTrees in Scientific Python -Wouldn't it be nice to make workflows in the scientific python ecosystem _just work_ with any PyTree? +Wouldn't it be nice to make workflows in the scientific Python ecosystem _just work_ with any PyTree? Giving semantic meaning to numeric data through PyTrees can be useful for applications outside of AI as well. Consider the following minimization of the [Rosenbrock](https://en.wikipedia.org/wiki/Rosenbrock_function) function: @@ -163,7 +163,7 @@ print(bestfit_params) # >> Params(x=np.float64(0.999995688776513), y=np.float64(0.9999913673387226)) ``` -This new `minimize` function works with _any_ PyTree, let's consider a modified and more complex version of the Rosenbrock function that relies on two sets of `Params` as input: +This new `minimize` function works with _any_ PyTree, let's consider a modified and more complex version of the Rosenbrock function that relies on two sets of `Params` as input - a common pattern for hierarchical models: ```python import numpy as np @@ -180,7 +180,7 @@ def rosenbrock_modified(params: Params) -> float: # calculate `x` and `y` from two sources: x = np.asin(min(p1.x, p2.x) / max(p1.x, p2.x)) - y = 1.0 / (1.0 + np.exp(-(p1.y / p2.y))) + y = 1 / (1 + np.exp(-(p1.y / p2.y))) return (1 - x) ** 2 + 100 * (y - x**2) ** 2 From d71c8e8fd31d1f95d92d5557f26c560e2325a8ad Mon Sep 17 00:00:00 2001 From: pfackeldey Date: Sun, 18 May 2025 16:05:37 -0400 Subject: [PATCH 04/32] add final thought section --- content/posts/optree/pytrees/index.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/content/posts/optree/pytrees/index.md b/content/posts/optree/pytrees/index.md index 39b73c4..f0fa292 100644 --- a/content/posts/optree/pytrees/index.md +++ b/content/posts/optree/pytrees/index.md @@ -195,3 +195,9 @@ print(bestfit_params) ``` The new `minimize` still works, because a `tuple` of `Params` is just _another_ PyTree! + +### Final Thought + +Working with nested data structures doesn’t have to be messy. +PyTrees let you focus on the data and the transformations you want to apply. +Whether you're building neural networks, optimizing scientific models, or just deal with complex nested python containers, they make your code cleaner, more flexible, and just nicer to work with. From 653296b993f44eff2a6af67b1e0646e69e355268 Mon Sep 17 00:00:00 2001 From: pfackeldey Date: Sun, 18 May 2025 16:13:11 -0400 Subject: [PATCH 05/32] improve code snippet for modified rosenbrock --- content/posts/optree/pytrees/index.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/content/posts/optree/pytrees/index.md b/content/posts/optree/pytrees/index.md index f0fa292..077b4bc 100644 --- a/content/posts/optree/pytrees/index.md +++ b/content/posts/optree/pytrees/index.md @@ -169,14 +169,14 @@ This new `minimize` function works with _any_ PyTree, let's consider a modified import numpy as np -def rosenbrock_modified(params: Params) -> float: +def rosenbrock_modified(two_params: tuple[Params, Params]) -> float: """ Modified Rosenbrock where the x and y parameters are determined by a non-linear transformations of two versions of each, i.e.: x = arcsin(min(x1, x2) / max(x1, x2)) y = sigmoid(x1 - x2) """ - p1, p2 = params + p1, p2 = two_params # calculate `x` and `y` from two sources: x = np.asin(min(p1.x, p2.x) / max(p1.x, p2.x)) From 628d64544089ce42e30ba64d4412b303dc502a39 Mon Sep 17 00:00:00 2001 From: pfackeldey Date: Sun, 18 May 2025 16:51:00 -0400 Subject: [PATCH 06/32] remove obsolete comment --- content/posts/optree/pytrees/index.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/content/posts/optree/pytrees/index.md b/content/posts/optree/pytrees/index.md index 077b4bc..d6d1787 100644 --- a/content/posts/optree/pytrees/index.md +++ b/content/posts/optree/pytrees/index.md @@ -121,7 +121,7 @@ print(res.x) Now, let's turn it a minimization that uses a more complex type for the parameters - a NamedTuple that describes our fit parameters: ```python -import optree as pt # standalone PyTree library +import optree as pt from typing import NamedTuple, Callable from scipy.optimize import minimize as sp_minimize From 7d6b287609791accae6316d1cebab1f9b37bb230 Mon Sep 17 00:00:00 2001 From: Peter Fackeldey Date: Mon, 19 May 2025 11:16:26 -0400 Subject: [PATCH 07/32] Update content/posts/optree/pytrees/index.md Co-authored-by: Matthew Feickert --- content/posts/optree/pytrees/index.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/content/posts/optree/pytrees/index.md b/content/posts/optree/pytrees/index.md index d6d1787..8dd22c5 100644 --- a/content/posts/optree/pytrees/index.md +++ b/content/posts/optree/pytrees/index.md @@ -54,7 +54,7 @@ new_leafs = tuple(map(fun, leafs)) result_pytree = pt.tree_unflatten(treedef, new_leafs) ``` -Here, we use [`optree`](https://github.com/metaopt/optree/tree/main/optree) - a standalone PyTree library - that enables all these manipulations. It focusses on performance, feature richness, minimal dependencies, and got adopted by [PyTorch](https://pytorch.org), [Keras](https://keras.io), and [TensorFlow](https://github.com/tensorflow/tensorflow) (through Keras) as a core dependency. +Here, we use [`optree`](https://github.com/metaopt/optree/tree/main/optree) — a standalone PyTree library — that enables all these manipulations. It focuses on performance, is feature rich, has minimal dependencies, and has been adopted by [PyTorch](https://pytorch.org), [Keras](https://keras.io), and [TensorFlow](https://github.com/tensorflow/tensorflow) (through Keras) as a core dependency. ### PyTree Origins From 99fe18c141d91e09aaa5f584c57ec806ca66e211 Mon Sep 17 00:00:00 2001 From: Peter Fackeldey Date: Mon, 19 May 2025 11:16:46 -0400 Subject: [PATCH 08/32] Update content/posts/optree/pytrees/index.md Co-authored-by: Matthew Feickert --- content/posts/optree/pytrees/index.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/content/posts/optree/pytrees/index.md b/content/posts/optree/pytrees/index.md index 8dd22c5..d35216b 100644 --- a/content/posts/optree/pytrees/index.md +++ b/content/posts/optree/pytrees/index.md @@ -59,7 +59,7 @@ Here, we use [`optree`](https://github.com/metaopt/optree/tree/main/optree) &mda ### PyTree Origins Originally, the concept of PyTrees was developed by the [JAX](https://docs.jax.dev/en/latest/) project to make nested collections of JAX arrays work transparently at the "JIT-boundary" (the JAX JIT toolchain does not know about Python containers, only about JAX Arrays). -However, PyTrees were quickly adopted by AI researchers for broader use-cases: semantically grouping layers of weights and biases in e.g. a list of named tuples (or dictionaries) is a common pattern in the JAX-AI-world, see the following (pseudo) Python snippet: +However, PyTrees were quickly adopted by AI researchers for broader use-cases: semantically grouping layers of weights and biases in a list of named tuples (or dictionaries) is a common pattern in the JAX-AI-world, see the following (pseudo) Python snippet: ```python from typing import NamedTuple, Callable From cd443a389d950d8fb49d0f9f6534620a0105a6ab Mon Sep 17 00:00:00 2001 From: Peter Fackeldey Date: Mon, 19 May 2025 11:16:59 -0400 Subject: [PATCH 09/32] Update content/posts/optree/pytrees/index.md Co-authored-by: Matthew Feickert --- content/posts/optree/pytrees/index.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/content/posts/optree/pytrees/index.md b/content/posts/optree/pytrees/index.md index d35216b..e9bdbda 100644 --- a/content/posts/optree/pytrees/index.md +++ b/content/posts/optree/pytrees/index.md @@ -89,7 +89,7 @@ def neural_network(layers: list[Layer], x: jax.Array) -> jax.Array: prediction = neural_network(layers=layers, x=jnp.array(...)) ``` -Here, `layers` is a PyTree - a `list` of multiple `Layer` - and the JIT compiled `neural_network` function _just works_ with this datastructure as input. +Here, `layers` is a PyTree — a `list` of multiple `Layer` — and the JIT compiled `neural_network` function _just works_ with this data structure as input. ### PyTrees in Scientific Python From 13d17001d4365b01f69ebb666697b14e044667a4 Mon Sep 17 00:00:00 2001 From: Peter Fackeldey Date: Mon, 19 May 2025 11:17:27 -0400 Subject: [PATCH 10/32] Update content/posts/optree/pytrees/index.md Co-authored-by: Matthew Feickert --- content/posts/optree/pytrees/index.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/content/posts/optree/pytrees/index.md b/content/posts/optree/pytrees/index.md index e9bdbda..e249cce 100644 --- a/content/posts/optree/pytrees/index.md +++ b/content/posts/optree/pytrees/index.md @@ -118,7 +118,7 @@ print(res.x) # >> [0.99999569 0.99999137] ``` -Now, let's turn it a minimization that uses a more complex type for the parameters - a NamedTuple that describes our fit parameters: +Now, let's consider a minimization that uses a more complex type for the parameters — a NamedTuple that describes our fit parameters: ```python import optree as pt From 17f2f951535d22b18330de90b354e3ea6d86abcb Mon Sep 17 00:00:00 2001 From: Peter Fackeldey Date: Mon, 19 May 2025 11:17:51 -0400 Subject: [PATCH 11/32] Update content/posts/optree/pytrees/index.md Co-authored-by: Matthew Feickert --- content/posts/optree/pytrees/index.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/content/posts/optree/pytrees/index.md b/content/posts/optree/pytrees/index.md index e249cce..d407ba1 100644 --- a/content/posts/optree/pytrees/index.md +++ b/content/posts/optree/pytrees/index.md @@ -163,7 +163,9 @@ print(bestfit_params) # >> Params(x=np.float64(0.999995688776513), y=np.float64(0.9999913673387226)) ``` -This new `minimize` function works with _any_ PyTree, let's consider a modified and more complex version of the Rosenbrock function that relies on two sets of `Params` as input - a common pattern for hierarchical models: +This new `minimize` function works with _any_ PyTree! + +Let's now consider a modified and more complex version of the Rosenbrock function that relies on two sets of `Params` as input — a common pattern for hierarchical models: ```python import numpy as np From b08dcb4168d9b0ad0f4565d0e682151d2b075eef Mon Sep 17 00:00:00 2001 From: Peter Fackeldey Date: Mon, 19 May 2025 11:18:13 -0400 Subject: [PATCH 12/32] Update content/posts/optree/pytrees/index.md Co-authored-by: Matthew Feickert --- content/posts/optree/pytrees/index.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/content/posts/optree/pytrees/index.md b/content/posts/optree/pytrees/index.md index d407ba1..c3b3c18 100644 --- a/content/posts/optree/pytrees/index.md +++ b/content/posts/optree/pytrees/index.md @@ -202,4 +202,4 @@ The new `minimize` still works, because a `tuple` of `Params` is just _another_ Working with nested data structures doesn’t have to be messy. PyTrees let you focus on the data and the transformations you want to apply. -Whether you're building neural networks, optimizing scientific models, or just deal with complex nested python containers, they make your code cleaner, more flexible, and just nicer to work with. +Whether you're building neural networks, optimizing scientific models, or just dealing with complex nested Python containers, PyTrees can make your code cleaner, more flexible, and just nicer to work with. From 76bbff496130d84380948221b78b7103e4e6aa99 Mon Sep 17 00:00:00 2001 From: Peter Fackeldey Date: Mon, 19 May 2025 11:18:19 -0400 Subject: [PATCH 13/32] Update content/posts/optree/pytrees/index.md Co-authored-by: Matthew Feickert --- content/posts/optree/pytrees/index.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/content/posts/optree/pytrees/index.md b/content/posts/optree/pytrees/index.md index c3b3c18..3a9c1cd 100644 --- a/content/posts/optree/pytrees/index.md +++ b/content/posts/optree/pytrees/index.md @@ -3,7 +3,7 @@ title: "Pytrees for Scientific Python" date: 2025-05-14T10:27:59-07:00 draft: false description: " -Introducing PyTrees for Scientific Python. We discuss what PyTrees are, how they're useful in the realm of scientific python, and how to work _efficiently_ with them. +Introducing PyTrees for Scientific Python. We discuss what PyTrees are, how they're useful in the realm of scientific Python, and how to work _efficiently_ with them. " tags: ["PyTrees", "Functional Programming", "Tree-like data manipulation"] displayInList: true From 7a124b137e6a0c1781ab0eba8a7dbdd2d46f816c Mon Sep 17 00:00:00 2001 From: Peter Fackeldey Date: Mon, 19 May 2025 13:48:37 -0400 Subject: [PATCH 14/32] Update content/posts/optree/pytrees/index.md Co-authored-by: Mihai Maruseac --- content/posts/optree/pytrees/index.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/content/posts/optree/pytrees/index.md b/content/posts/optree/pytrees/index.md index 3a9c1cd..3a46891 100644 --- a/content/posts/optree/pytrees/index.md +++ b/content/posts/optree/pytrees/index.md @@ -201,5 +201,5 @@ The new `minimize` still works, because a `tuple` of `Params` is just _another_ ### Final Thought Working with nested data structures doesn’t have to be messy. -PyTrees let you focus on the data and the transformations you want to apply. +PyTrees let you focus on the data and the transformations you want to apply, in a generic manner. Whether you're building neural networks, optimizing scientific models, or just dealing with complex nested Python containers, PyTrees can make your code cleaner, more flexible, and just nicer to work with. From 19262ddb7439eb2c385fff8bcba52a7a660d4cf3 Mon Sep 17 00:00:00 2001 From: Peter Fackeldey Date: Tue, 20 May 2025 10:15:13 -0400 Subject: [PATCH 15/32] Update content/posts/optree/pytrees/index.md Co-authored-by: Stefan van der Walt --- content/posts/optree/pytrees/index.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/content/posts/optree/pytrees/index.md b/content/posts/optree/pytrees/index.md index 3a46891..aceec6b 100644 --- a/content/posts/optree/pytrees/index.md +++ b/content/posts/optree/pytrees/index.md @@ -8,6 +8,8 @@ Introducing PyTrees for Scientific Python. We discuss what PyTrees are, how they tags: ["PyTrees", "Functional Programming", "Tree-like data manipulation"] displayInList: true author: ["Peter Fackeldey", "Mihai Maruseac", "Matthew Feickert"] +summary: | + Add summary here. --- ## Manipulating Tree-like Data using Functional Programming Paradigms From bc98572be97d5d4655a3bc691c9ae0f899179e70 Mon Sep 17 00:00:00 2001 From: pfackeldey Date: Tue, 20 May 2025 10:28:42 -0400 Subject: [PATCH 16/32] add summary --- content/posts/optree/pytrees/index.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/content/posts/optree/pytrees/index.md b/content/posts/optree/pytrees/index.md index aceec6b..7a2a0cc 100644 --- a/content/posts/optree/pytrees/index.md +++ b/content/posts/optree/pytrees/index.md @@ -9,7 +9,10 @@ tags: ["PyTrees", "Functional Programming", "Tree-like data manipulation"] displayInList: true author: ["Peter Fackeldey", "Mihai Maruseac", "Matthew Feickert"] summary: | - Add summary here. + This blog introduces PyTrees — nested Python data structures (such as lists, dicts, and tuples) with numerical leaf values — designed to simplify working with complex, hierarchically organized data. + While such structures are often cumbersome to manipulate, PyTrees make them more manageable by allowing them to be flattened into a list of leaves along with a reusable structure blueprint in a _generic_ way. + This enables flexible, generic operations like mapping and reducing from functional programming. + By bringing those functional paradigms to structured data, PyTrees let you focus on what transformations to apply, not how to traverse the structure — no matter how deeply nested or complex it is. --- ## Manipulating Tree-like Data using Functional Programming Paradigms From 11fde89e7c9042bab6d7386f74d72d5adbde7bf7 Mon Sep 17 00:00:00 2001 From: pfackeldey Date: Tue, 20 May 2025 10:30:27 -0400 Subject: [PATCH 17/32] Revert "add summary" This reverts commit bc98572be97d5d4655a3bc691c9ae0f899179e70. --- content/posts/optree/pytrees/index.md | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/content/posts/optree/pytrees/index.md b/content/posts/optree/pytrees/index.md index 7a2a0cc..aceec6b 100644 --- a/content/posts/optree/pytrees/index.md +++ b/content/posts/optree/pytrees/index.md @@ -9,10 +9,7 @@ tags: ["PyTrees", "Functional Programming", "Tree-like data manipulation"] displayInList: true author: ["Peter Fackeldey", "Mihai Maruseac", "Matthew Feickert"] summary: | - This blog introduces PyTrees — nested Python data structures (such as lists, dicts, and tuples) with numerical leaf values — designed to simplify working with complex, hierarchically organized data. - While such structures are often cumbersome to manipulate, PyTrees make them more manageable by allowing them to be flattened into a list of leaves along with a reusable structure blueprint in a _generic_ way. - This enables flexible, generic operations like mapping and reducing from functional programming. - By bringing those functional paradigms to structured data, PyTrees let you focus on what transformations to apply, not how to traverse the structure — no matter how deeply nested or complex it is. + Add summary here. --- ## Manipulating Tree-like Data using Functional Programming Paradigms From 0555d93b48b5896bee15785ac803c83d8c0e6478 Mon Sep 17 00:00:00 2001 From: pfackeldey Date: Tue, 20 May 2025 10:31:58 -0400 Subject: [PATCH 18/32] add summary --- content/posts/optree/pytrees/index.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/content/posts/optree/pytrees/index.md b/content/posts/optree/pytrees/index.md index aceec6b..7a2a0cc 100644 --- a/content/posts/optree/pytrees/index.md +++ b/content/posts/optree/pytrees/index.md @@ -9,7 +9,10 @@ tags: ["PyTrees", "Functional Programming", "Tree-like data manipulation"] displayInList: true author: ["Peter Fackeldey", "Mihai Maruseac", "Matthew Feickert"] summary: | - Add summary here. + This blog introduces PyTrees — nested Python data structures (such as lists, dicts, and tuples) with numerical leaf values — designed to simplify working with complex, hierarchically organized data. + While such structures are often cumbersome to manipulate, PyTrees make them more manageable by allowing them to be flattened into a list of leaves along with a reusable structure blueprint in a _generic_ way. + This enables flexible, generic operations like mapping and reducing from functional programming. + By bringing those functional paradigms to structured data, PyTrees let you focus on what transformations to apply, not how to traverse the structure — no matter how deeply nested or complex it is. --- ## Manipulating Tree-like Data using Functional Programming Paradigms From 4db8fba856c6a2606f7e441356d3ccaeac37dae9 Mon Sep 17 00:00:00 2001 From: Peter Fackeldey Date: Tue, 20 May 2025 13:13:21 -0400 Subject: [PATCH 19/32] Update content/posts/optree/pytrees/index.md Co-authored-by: Ross Barnowski --- content/posts/optree/pytrees/index.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/content/posts/optree/pytrees/index.md b/content/posts/optree/pytrees/index.md index 7a2a0cc..e98a2ed 100644 --- a/content/posts/optree/pytrees/index.md +++ b/content/posts/optree/pytrees/index.md @@ -17,7 +17,7 @@ summary: | ## Manipulating Tree-like Data using Functional Programming Paradigms -A "PyTree" is a nested collection of Python containers (e.g. dicts, (named) tuples, lists, ...), where the leafs are of interest. +A "PyTree" is a nested collection of Python containers (e.g. dicts, (named) tuples, lists, ...), where the leaves are of interest. As you can imagine (or even experienced in the past), such arbitrary nested collections can be cumbersome to manipulate _efficiently_. It often requires complex recursive logic which usually does not generalize to other nested Python containers (PyTrees). From 50a3b2eeab0d8cae2d91af2f4c18bc6671a5ba53 Mon Sep 17 00:00:00 2001 From: Peter Fackeldey Date: Tue, 20 May 2025 13:13:36 -0400 Subject: [PATCH 20/32] Update content/posts/optree/pytrees/index.md Co-authored-by: Ross Barnowski --- content/posts/optree/pytrees/index.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/content/posts/optree/pytrees/index.md b/content/posts/optree/pytrees/index.md index e98a2ed..3b18891 100644 --- a/content/posts/optree/pytrees/index.md +++ b/content/posts/optree/pytrees/index.md @@ -21,7 +21,7 @@ A "PyTree" is a nested collection of Python containers (e.g. dicts, (named) tupl As you can imagine (or even experienced in the past), such arbitrary nested collections can be cumbersome to manipulate _efficiently_. It often requires complex recursive logic which usually does not generalize to other nested Python containers (PyTrees). -The core concept of PyTrees is being able to flatten them into a flat collection of leafs and a "blueprint" of the tree structure, and then being able to unflatten them back into the original PyTree. +The core concept of PyTrees is being able to flatten them into a flat collection of leaves and a "blueprint" of the tree structure, and then being able to unflatten them back into the original PyTree. This allows to apply generic transformations, e.g. taking the square root of each leaf of a PyTree with a `tree_map(np.sqrt, pytree)` operation: ```python From 4dbf7fdaa5d2ed940506d7dbc66c60ad301bee93 Mon Sep 17 00:00:00 2001 From: Peter Fackeldey Date: Wed, 21 May 2025 11:37:12 -0400 Subject: [PATCH 21/32] Update content/posts/optree/pytrees/index.md Co-authored-by: Stefan van der Walt --- content/posts/optree/pytrees/index.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/content/posts/optree/pytrees/index.md b/content/posts/optree/pytrees/index.md index 3b18891..80c2dcd 100644 --- a/content/posts/optree/pytrees/index.md +++ b/content/posts/optree/pytrees/index.md @@ -22,7 +22,7 @@ As you can imagine (or even experienced in the past), such arbitrary nested coll It often requires complex recursive logic which usually does not generalize to other nested Python containers (PyTrees). The core concept of PyTrees is being able to flatten them into a flat collection of leaves and a "blueprint" of the tree structure, and then being able to unflatten them back into the original PyTree. -This allows to apply generic transformations, e.g. taking the square root of each leaf of a PyTree with a `tree_map(np.sqrt, pytree)` operation: +This allows for the application of generic transformations. For example, on a PyTree with NumPy arrays as leaves, taking the square root of each leaf with `tree_map(np.sqrt, pytree)`: ```python import optree as pt From ae7764e80cb16718a6c0ad260d16d8a332662a79 Mon Sep 17 00:00:00 2001 From: Peter Fackeldey Date: Tue, 17 Jun 2025 09:38:39 -0400 Subject: [PATCH 22/32] Update content/posts/optree/pytrees/index.md Co-authored-by: Xuehai Pan --- content/posts/optree/pytrees/index.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/content/posts/optree/pytrees/index.md b/content/posts/optree/pytrees/index.md index 80c2dcd..5810445 100644 --- a/content/posts/optree/pytrees/index.md +++ b/content/posts/optree/pytrees/index.md @@ -37,7 +37,7 @@ print(f"{sqrt_pytree=}") # >> sqrt_pytree=([[{'foo': array([2.])}], array([3.])],) # reductions -all_positive = pt.tree_all(pt.tree_map(lambda x: x > 0.0, pytree)) +all_positive = all(x > 0.0 for x in pt.tree_iter(pytree)) print(f"{all_positive=}") # >> all_positive=True From 8c33d24b731e80fae74b95ef11d00132821ab817 Mon Sep 17 00:00:00 2001 From: pfackeldey Date: Tue, 17 Jun 2025 10:12:23 -0400 Subject: [PATCH 23/32] pytree -> tree --- content/posts/optree/pytrees/index.md | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/content/posts/optree/pytrees/index.md b/content/posts/optree/pytrees/index.md index 5810445..e0d2dd1 100644 --- a/content/posts/optree/pytrees/index.md +++ b/content/posts/optree/pytrees/index.md @@ -22,26 +22,26 @@ As you can imagine (or even experienced in the past), such arbitrary nested coll It often requires complex recursive logic which usually does not generalize to other nested Python containers (PyTrees). The core concept of PyTrees is being able to flatten them into a flat collection of leaves and a "blueprint" of the tree structure, and then being able to unflatten them back into the original PyTree. -This allows for the application of generic transformations. For example, on a PyTree with NumPy arrays as leaves, taking the square root of each leaf with `tree_map(np.sqrt, pytree)`: +This allows for the application of generic transformations. For example, on a PyTree with NumPy arrays as leaves, taking the square root of each leaf with `tree_map(np.sqrt, tree)`: ```python import optree as pt import numpy as np # tuple of a list of a dict with an array as value, and an array -pytree = ([[{"foo": np.array([4.0])}], np.array([9.0])],) +tree = ([[{"foo": np.array([4.0])}], np.array([9.0])],) # sqrt of each leaf array -sqrt_pytree = pt.tree_map(np.sqrt, pytree) -print(f"{sqrt_pytree=}") -# >> sqrt_pytree=([[{'foo': array([2.])}], array([3.])],) +sqrt_tree = pt.tree_map(np.sqrt, tree) +print(f"{sqrt_tree=}") +# >> sqrt_tree=([[{'foo': array([2.])}], array([3.])],) # reductions -all_positive = all(x > 0.0 for x in pt.tree_iter(pytree)) +all_positive = all(x > 0.0 for x in pt.tree_iter(tree)) print(f"{all_positive=}") # >> all_positive=True -summed = pt.tree_reduce(sum, pytree) +summed = pt.tree_reduce(sum, tree) print(f"{summed=}") # >> summed=array([13.]) ``` @@ -50,13 +50,13 @@ The trick here is that these operations can be implemented in three steps, e.g. ```python # step 1: -leafs, treedef = pt.tree_flatten(pytree) +leafs, treedef = pt.tree_flatten(tree) # step 2: new_leafs = tuple(map(fun, leafs)) # step 3: -result_pytree = pt.tree_unflatten(treedef, new_leafs) +result_tree = pt.tree_unflatten(treedef, new_leafs) ``` Here, we use [`optree`](https://github.com/metaopt/optree/tree/main/optree) — a standalone PyTree library — that enables all these manipulations. It focuses on performance, is feature rich, has minimal dependencies, and has been adopted by [PyTorch](https://pytorch.org), [Keras](https://keras.io), and [TensorFlow](https://github.com/tensorflow/tensorflow) (through Keras) as a core dependency. From 9bde52e2714bc1b91fe1ed46dbe0a9764aba9ee5 Mon Sep 17 00:00:00 2001 From: pfackeldey Date: Tue, 17 Jun 2025 10:22:07 -0400 Subject: [PATCH 24/32] be more specific about the motivation for scientific data --- content/posts/optree/pytrees/index.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/content/posts/optree/pytrees/index.md b/content/posts/optree/pytrees/index.md index e0d2dd1..f56cd72 100644 --- a/content/posts/optree/pytrees/index.md +++ b/content/posts/optree/pytrees/index.md @@ -18,8 +18,9 @@ summary: | ## Manipulating Tree-like Data using Functional Programming Paradigms A "PyTree" is a nested collection of Python containers (e.g. dicts, (named) tuples, lists, ...), where the leaves are of interest. -As you can imagine (or even experienced in the past), such arbitrary nested collections can be cumbersome to manipulate _efficiently_. -It often requires complex recursive logic which usually does not generalize to other nested Python containers (PyTrees). +In the scientific world, such a PyTree could consist of experimental measurements of different properties at different timestamps and measurement settings resulting in a highly complex, nested and not necessarily rectangular data structure. +Such collections can be cumbersome to manipulate _efficiently_, especially if they are nested any depth. +It often requires complex recursive logic which usually does not generalize to other nested Python containers (PyTrees), e.g. for new measurements. The core concept of PyTrees is being able to flatten them into a flat collection of leaves and a "blueprint" of the tree structure, and then being able to unflatten them back into the original PyTree. This allows for the application of generic transformations. For example, on a PyTree with NumPy arrays as leaves, taking the square root of each leaf with `tree_map(np.sqrt, tree)`: From 5e59e37aa89fbb693c38dd5b0249371ccdb83aac Mon Sep 17 00:00:00 2001 From: pfackeldey Date: Tue, 17 Jun 2025 10:25:50 -0400 Subject: [PATCH 25/32] mention optree earlier --- content/posts/optree/pytrees/index.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/content/posts/optree/pytrees/index.md b/content/posts/optree/pytrees/index.md index f56cd72..534ff79 100644 --- a/content/posts/optree/pytrees/index.md +++ b/content/posts/optree/pytrees/index.md @@ -23,7 +23,9 @@ Such collections can be cumbersome to manipulate _efficiently_, especially if th It often requires complex recursive logic which usually does not generalize to other nested Python containers (PyTrees), e.g. for new measurements. The core concept of PyTrees is being able to flatten them into a flat collection of leaves and a "blueprint" of the tree structure, and then being able to unflatten them back into the original PyTree. -This allows for the application of generic transformations. For example, on a PyTree with NumPy arrays as leaves, taking the square root of each leaf with `tree_map(np.sqrt, tree)`: +This allows for the application of generic transformations. +In this blog post, we use [`optree`](https://github.com/metaopt/optree/tree/main/optree) — a standalone PyTree library — that enables these transformations. It focuses on performance, is feature rich, has minimal dependencies, and has been adopted by [PyTorch](https://pytorch.org), [Keras](https://keras.io), and [TensorFlow](https://github.com/tensorflow/tensorflow) (through Keras) as a core dependency. +For example, on a PyTree with NumPy arrays as leaves, taking the square root of each leaf with `optree.tree_map(np.sqrt, tree)`: ```python import optree as pt @@ -60,8 +62,6 @@ new_leafs = tuple(map(fun, leafs)) result_tree = pt.tree_unflatten(treedef, new_leafs) ``` -Here, we use [`optree`](https://github.com/metaopt/optree/tree/main/optree) — a standalone PyTree library — that enables all these manipulations. It focuses on performance, is feature rich, has minimal dependencies, and has been adopted by [PyTorch](https://pytorch.org), [Keras](https://keras.io), and [TensorFlow](https://github.com/tensorflow/tensorflow) (through Keras) as a core dependency. - ### PyTree Origins Originally, the concept of PyTrees was developed by the [JAX](https://docs.jax.dev/en/latest/) project to make nested collections of JAX arrays work transparently at the "JIT-boundary" (the JAX JIT toolchain does not know about Python containers, only about JAX Arrays). From e07a2acee44ac23a5db88b0a633bcfbab94f4acf Mon Sep 17 00:00:00 2001 From: pfackeldey Date: Tue, 17 Jun 2025 10:28:24 -0400 Subject: [PATCH 26/32] be a bit more specific about the use case with hierarchical models --- content/posts/optree/pytrees/index.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/content/posts/optree/pytrees/index.md b/content/posts/optree/pytrees/index.md index 534ff79..55cbccc 100644 --- a/content/posts/optree/pytrees/index.md +++ b/content/posts/optree/pytrees/index.md @@ -171,7 +171,7 @@ print(bestfit_params) This new `minimize` function works with _any_ PyTree! -Let's now consider a modified and more complex version of the Rosenbrock function that relies on two sets of `Params` as input — a common pattern for hierarchical models: +Let's now consider a modified and more complex version of the Rosenbrock function that relies on two sets of `Params` as input — a common pattern for hierarchical models (e.g. a superposition of various probability density functions): ```python import numpy as np From 4c548bc296d3c186e79bf70e50ab00e5b33f9a3a Mon Sep 17 00:00:00 2001 From: pfackeldey Date: Wed, 18 Jun 2025 07:47:52 -0400 Subject: [PATCH 27/32] leafs -> leaves --- content/posts/optree/pytrees/index.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/content/posts/optree/pytrees/index.md b/content/posts/optree/pytrees/index.md index 55cbccc..dd6f26e 100644 --- a/content/posts/optree/pytrees/index.md +++ b/content/posts/optree/pytrees/index.md @@ -53,13 +53,13 @@ The trick here is that these operations can be implemented in three steps, e.g. ```python # step 1: -leafs, treedef = pt.tree_flatten(tree) +leaves, treedef = pt.tree_flatten(tree) # step 2: -new_leafs = tuple(map(fun, leafs)) +new_leaves = tuple(map(fun, leaves)) # step 3: -result_tree = pt.tree_unflatten(treedef, new_leafs) +result_tree = pt.tree_unflatten(treedef, new_leaves) ``` ### PyTree Origins From 254ad902afc883d661c19d772b6eb0d42cf6985f Mon Sep 17 00:00:00 2001 From: Peter Fackeldey Date: Mon, 7 Jul 2025 15:47:54 -0700 Subject: [PATCH 28/32] Update content/posts/optree/pytrees/index.md Co-authored-by: Stefan van der Walt --- content/posts/optree/pytrees/index.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/content/posts/optree/pytrees/index.md b/content/posts/optree/pytrees/index.md index dd6f26e..4a5b175 100644 --- a/content/posts/optree/pytrees/index.md +++ b/content/posts/optree/pytrees/index.md @@ -65,7 +65,7 @@ result_tree = pt.tree_unflatten(treedef, new_leaves) ### PyTree Origins Originally, the concept of PyTrees was developed by the [JAX](https://docs.jax.dev/en/latest/) project to make nested collections of JAX arrays work transparently at the "JIT-boundary" (the JAX JIT toolchain does not know about Python containers, only about JAX Arrays). -However, PyTrees were quickly adopted by AI researchers for broader use-cases: semantically grouping layers of weights and biases in a list of named tuples (or dictionaries) is a common pattern in the JAX-AI-world, see the following (pseudo) Python snippet: +However, PyTrees were quickly adopted by AI researchers for broader use-cases: semantically grouping layers of weights and biases in a list of named tuples (or dictionaries) is a common pattern in the JAX-AI-world, as shown in the following (pseudo) Python snippet: ```python from typing import NamedTuple, Callable From b5acc95b0f6ccd42a6e5cb2c06a5d1f70068b809 Mon Sep 17 00:00:00 2001 From: pfackeldey Date: Mon, 7 Jul 2025 15:55:45 -0700 Subject: [PATCH 29/32] be more explicit about reducing arrays with more than 1 element --- content/posts/optree/pytrees/index.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/content/posts/optree/pytrees/index.md b/content/posts/optree/pytrees/index.md index 4a5b175..632c6ca 100644 --- a/content/posts/optree/pytrees/index.md +++ b/content/posts/optree/pytrees/index.md @@ -40,13 +40,13 @@ print(f"{sqrt_tree=}") # >> sqrt_tree=([[{'foo': array([2.])}], array([3.])],) # reductions -all_positive = all(x > 0.0 for x in pt.tree_iter(tree)) +all_positive = all(np.all(x > 0.0) for x in pt.tree_iter(tree)) print(f"{all_positive=}") # >> all_positive=True -summed = pt.tree_reduce(sum, tree) +summed = np.sum(pt.tree_reduce(sum, tree)) print(f"{summed=}") -# >> summed=array([13.]) +# >> summed=np.float64(13.0) ``` The trick here is that these operations can be implemented in three steps, e.g. `tree_map`: From 4e14a11a0f59ecaf89fa411f737703035bdd80da Mon Sep 17 00:00:00 2001 From: pfackeldey Date: Mon, 7 Jul 2025 16:52:57 -0700 Subject: [PATCH 30/32] add a sentence of how pytrees and jax.jit work together --- content/posts/optree/pytrees/index.md | 1 + 1 file changed, 1 insertion(+) diff --git a/content/posts/optree/pytrees/index.md b/content/posts/optree/pytrees/index.md index 632c6ca..252df6a 100644 --- a/content/posts/optree/pytrees/index.md +++ b/content/posts/optree/pytrees/index.md @@ -96,6 +96,7 @@ prediction = neural_network(layers=layers, x=jnp.array(...)) ``` Here, `layers` is a PyTree — a `list` of multiple `Layer` — and the JIT compiled `neural_network` function _just works_ with this data structure as input. +Although you cannot see what happens inside of `jax.jit`, `layers` is automatically flattened by the `jax.jit` decorator to a flat iterable of arrays, which are understood by the compiler in contrast to a Python `list` of `NamedTuples`. ### PyTrees in Scientific Python From aa851f03681ebe65731ec455a6059b4e36b9c62c Mon Sep 17 00:00:00 2001 From: pfackeldey Date: Tue, 8 Jul 2025 06:43:19 -0700 Subject: [PATCH 31/32] clarify what 'compiler' is meant with --- content/posts/optree/pytrees/index.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/content/posts/optree/pytrees/index.md b/content/posts/optree/pytrees/index.md index 252df6a..fc71729 100644 --- a/content/posts/optree/pytrees/index.md +++ b/content/posts/optree/pytrees/index.md @@ -96,7 +96,7 @@ prediction = neural_network(layers=layers, x=jnp.array(...)) ``` Here, `layers` is a PyTree — a `list` of multiple `Layer` — and the JIT compiled `neural_network` function _just works_ with this data structure as input. -Although you cannot see what happens inside of `jax.jit`, `layers` is automatically flattened by the `jax.jit` decorator to a flat iterable of arrays, which are understood by the compiler in contrast to a Python `list` of `NamedTuples`. +Although you cannot see what happens inside of `jax.jit`, `layers` is automatically flattened by the `jax.jit` decorator to a flat iterable of arrays, which are understood by the JAX JIT toolchain in contrast to a Python `list` of `NamedTuples`. ### PyTrees in Scientific Python From 5a569a09cdfdbdeed7633370cd2e855a87817ede Mon Sep 17 00:00:00 2001 From: Stefan van der Walt Date: Tue, 8 Jul 2025 12:40:16 -0700 Subject: [PATCH 32/32] Set date --- content/posts/optree/pytrees/index.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/content/posts/optree/pytrees/index.md b/content/posts/optree/pytrees/index.md index fc71729..d141bb6 100644 --- a/content/posts/optree/pytrees/index.md +++ b/content/posts/optree/pytrees/index.md @@ -1,6 +1,6 @@ --- title: "Pytrees for Scientific Python" -date: 2025-05-14T10:27:59-07:00 +date: 2025-07-08 draft: false description: " Introducing PyTrees for Scientific Python. We discuss what PyTrees are, how they're useful in the realm of scientific Python, and how to work _efficiently_ with them.