In [1]:
import optree as pt
import numpy as np

In [2]:
# A deliberately nested example PyTree: a 1-tuple wrapping a list whose items
# are (a) a list holding a dict and (b) a numpy array. The two numpy arrays are
# the leaves; tuples, lists and dicts are the containers.
pytree = (
    [
        [{"foo": np.array([2.0])}],
        np.array([3.0]),
    ],
)

In [3]:
# Element-wise square root of every leaf: tree_map builds a new PyTree with the
# same container structure as `pytree`, replacing each leaf with np.sqrt(leaf).
sqrt_pytree = pt.tree_map(
    np.sqrt,  # called once per leaf array
    pytree,
)
print(f"{sqrt_pytree=}")

sqrt_pytree=([[{'foo': array([1.41421356])}], array([1.73205081])],)


In [4]:
# Reductions over all leaves of the PyTree.

# Collapse each leaf's comparison to a single bool with np.all so tree_all is
# well-defined for leaves with more than one element (calling bool() on a
# multi-element array raises ValueError).
all_positive = pt.tree_all(pt.tree_map(lambda x: bool(np.all(x > 0.0)), pytree))
print(f"{all_positive=}")

# np.add is genuine element-wise addition. The builtin `sum(a, b)` treats `a` as
# an iterable and `b` as the start value (b + a[0] + a[1] + ...), which only
# coincides with a + b when every leaf holds a single element.
summed = pt.tree_reduce(np.add, pytree)
print(f"{summed=}")

all_positive=True
summed=array([5.])


In [5]:
# Flattening & unflattening: tree_flatten splits the PyTree into its list of
# leaves plus a PyTreeSpec recording the container structure; feeding the same
# leaves back through that spec rebuilds an equivalent PyTree.
arrays, treedef = pt.tree_flatten(pytree)
print(f"{arrays=}", f"{treedef=}", sep="\n")

pytree_reconstructed = treedef.unflatten(arrays)
print(f"\n{pytree_reconstructed=}")

arrays=[array([2.]), array([3.])]
treedef=PyTreeSpec(([[{'foo': *}], *],))

pytree_reconstructed=([[{'foo': array([2.])}], array([3.])],)


In [6]:
import optree as pt
from typing import NamedTuple, Callable
from scipy.optimize import minimize as sp_minimize


# Named container for the two optimization variables. Being a NamedTuple (and
# therefore a tuple), it can be flattened by optree into its leaves (x, y) in
# field order and rebuilt from them.
Params = NamedTuple("Params", [("x", float), ("y", float)])


def rosenbrock(params: Params) -> float:
  """
  Evaluate the 2-D Rosenbrock function at ``params``.

  f(x, y) = (1 - x)^2 + 100 * (y - x^2)^2, with global minimum f(1, 1) = 0.

  https://en.wikipedia.org/wiki/Rosenbrock_function
  """
  x, y = params.x, params.y
  first_term = (1 - x) ** 2
  second_term = 100 * (y - x**2) ** 2
  return first_term + second_term


def minimize(fun: Callable, params: Params, **minimize_kwargs) -> Params:
  """
  Minimize ``fun`` over a PyTree of numeric parameters using scipy.

  ``scipy.optimize.minimize`` works on a flat 1-D vector, so every leaf of
  ``params`` is raveled and concatenated into one vector, and ``fun`` is
  called with a PyTree of the original structure rebuilt from that vector.
  Leaves may be Python/numpy scalars or numpy arrays of any shape; scalar
  leaves come back as numpy scalars (e.g. ``np.float64``) and array leaves
  keep their original shape.

  Args:
    fun: Objective that takes a PyTree shaped like ``params`` and returns a
      scalar.
    params: Initial guess; its tree structure is preserved in the result.
    **minimize_kwargs: Forwarded unchanged to ``scipy.optimize.minimize``
      (e.g. ``method=``, ``tol=``, ``options=``).

  Returns:
    The best-fit parameters with the same tree structure as ``params``.
    Convergence (``res.success``) is not checked here.

  Raises:
    ValueError: If ``params`` has no leaves to optimize.
  """
  # flatten and store treedef
  leaves, treedef = pt.tree_flatten(params)
  if not leaves:
    raise ValueError("params has no leaves to optimize")

  # Record each leaf's shape and where it starts in the flat vector, so that
  # array-valued leaves can be split back out (not only scalar leaves).
  leaf_arrays = [np.asarray(leaf, dtype=float) for leaf in leaves]
  shapes = [arr.shape for arr in leaf_arrays]
  split_points = np.cumsum([arr.size for arr in leaf_arrays])[:-1]
  x0 = np.concatenate([arr.ravel() for arr in leaf_arrays])

  def unravel(flat_params):
    # `[()]` turns a 0-d array back into a numpy scalar and is a plain view
    # for arrays with ndim >= 1.
    chunks = np.split(np.asarray(flat_params), split_points)
    rebuilt = [chunk.reshape(shape)[()] for chunk, shape in zip(chunks, shapes)]
    return pt.tree_unflatten(treedef, rebuilt)

  # wrap fun to work with the flat vector scipy operates on
  def wrapped_fun(flat_params):
    return fun(unravel(flat_params))

  # actual minimization
  res = sp_minimize(wrapped_fun, x0, **minimize_kwargs)

  # re-wrap the bestfit values into the original tree structure
  return unravel(res.x)


# Run the PyTree-aware `minimize` wrapper from a Params starting point; the
# best-fit result comes back as a Params as well.
x0 = Params(x=0.9, y=1.2)
bestfit_params = minimize(fun=rosenbrock, params=x0)
print(bestfit_params)

Params(x=np.float64(0.999995688776513), y=np.float64(0.9999913673387226))


In [7]:
from scipy.optimize import minimize

def rosenbrock(params: tuple[float, float]) -> float:
  """
  Rosenbrock function on a plain ``(x, y)`` pair. Minimum: f(1, 1) = 0.

  NOTE: this redefines the Params-based ``rosenbrock`` from the previous
  cell; after this cell runs, ``rosenbrock`` refers to this version.

  ``params`` is any length-2 sequence; scipy's ``minimize`` passes a 1-D
  ndarray of shape (2,), which unpacks the same way. (The annotation is
  ``tuple[float, float]``: ``tuple[float]`` would denote a 1-tuple.)

  https://en.wikipedia.org/wiki/Rosenbrock_function
  """
  x, y = params
  return (1 - x) ** 2 + 100 * (y - x**2) ** 2


# In this cell `minimize` is scipy.optimize.minimize (imported at the top of the
# cell, rebinding the PyTree-aware wrapper), so the starting point is a plain
# flat (x, y) tuple.
x0 = 0.9, 1.2
res = minimize(fun=rosenbrock, x0=x0)
print(res.x)

[0.99999569 0.99999137]
