In [1]:
from typing import Callable
import pytensor
import pytensor.tensor as pt
from scipy import linalg
from pytensor.scan.utils import until
from functools import partial

In [2]:
def _newton_step(func, x, args):
    """Build the symbolic graph for a single Newton update of ``func`` at ``x``.

    Parameters
    ----------
    func
        Callable invoked as ``func(x, *args)``; returns the residual whose
        root is sought.
    x
        Current iterate.
    args
        Extra arguments forwarded to ``func`` after ``x``.

    Returns
    -------
    tuple
        ``(f_x, new_x, step, jac)``: the residual at ``x``, the updated
        iterate ``x - step``, the Newton step ``solve(jac, f_x)`` and the
        Jacobian of the residual with respect to ``x``.
    """
    f_x = func(x, *args)
    jac = pt.jacobian(f_x, x)

    # TODO It would be nice to return the factored matrix for the pullback
    # TODO Handle errors of the factorization
    # The Jacobian of an arbitrary residual is not symmetric, so a symmetric
    # solver (which only reads one triangle) would give a wrong step. Use the
    # general solver; it is also valid when the Jacobian is a Hessian.
    step = pt.linalg.solve(jac, f_x, assume_a="gen")

    return f_x, x - step, step, jac

def _check_convergence(f_x, x, new_x, grad, tol):
    """Return the symbolic stopping test for the Newton iteration.

    The test is true when the 1-norm of the residual ``f_x`` is strictly
    smaller than ``tol``. ``x``, ``new_x`` and ``grad`` are accepted but not
    used by the current criterion.
    """
    # TODO What convergence criterion? Norm of grad etc...
    residual_norm = pt.linalg.norm(f_x, ord=1)
    return pt.lt(residual_norm, tol)

def _scan_step(x, n_steps, *args, func, tol):
    """Inner function of the scan loop: one Newton update plus the stop test.

    Returns the carried outputs ``(new_x, n_steps + 1, jac)`` together with an
    ``until`` condition. Both the Jacobian and the residual used in the stop
    test are evaluated at the incoming iterate ``x``, not at ``new_x``.
    """
    residual, x_next, newton_update, jacobian = _newton_step(func, x, args)
    stop_condition = until(
        _check_convergence(residual, x, x_next, newton_update, tol)
    )
    carried = (x_next, n_steps + 1, jacobian)
    return carried, stop_condition

def root(
    func: Callable,
    x0: pt.TensorVariable,  # rank 1
    args: tuple[pt.Variable, ...],
    max_iter: int = 113,
    tol: float = 1e-8,
) -> tuple[
    pt.TensorVariable, dict,
]:
    """Symbolically search for a root of ``func`` with a Newton iteration.

    The iteration is expressed as a ``pytensor.scan`` over ``_scan_step``. It
    runs for at most ``max_iter`` steps and stops earlier when the ``until``
    condition produced by ``_scan_step`` becomes true.

    Parameters
    ----------
    func
        Residual function, called as ``func(x, *args)``.
    x0
        Initial iterate.
    args
        Extra symbolic inputs, passed to the scan as non-sequences and
        forwarded to ``func``.
    max_iter
        Upper bound on the number of scan steps.
    tol
        Tolerance handed to the convergence test.

    Returns
    -------
    tuple
        The last iterate, and a dict with ``"n_steps"`` (the step counter
        after the last executed step) and ``"jac"`` (the Jacobian computed in
        the last executed step).
    """

    def step(x, n_steps, *non_sequences):
        # Bind the Python-level settings; scan only supplies symbolic inputs.
        return _scan_step(x, n_steps, *non_sequences, func=func, tol=tol)

    step_counter = pt.constant(0, dtype="int64")

    # The third output (the Jacobian) is not fed back, hence ``None``.
    (x_trace, n_steps_trace, jac_trace), updates = pytensor.scan(
        step,
        outputs_info=[x0, step_counter, None],
        non_sequences=args,
        n_steps=max_iter,
        strict=True,
    )
    assert not updates

    stats = {"n_steps": n_steps_trace[-1], "jac": jac_trace[-1]}
    return x_trace[-1], stats


def minimize(cost: Callable, x0: pt.TensorVariable, args):
    """Symbolically minimise ``cost`` by finding a root of its gradient.

    Parameters
    ----------
    cost
        Scalar cost, called as ``cost(x, *args)``.
    x0
        Initial iterate.
    args
        Extra symbolic inputs forwarded to ``cost`` after ``x``.

    Returns
    -------
    tuple
        Whatever ``root`` returns for the gradient of ``cost``.
    """

    def grad_of_cost(x, *extra):
        # ``root`` calls its function as ``func(x, *args)``, so the extra
        # arguments must be accepted here and passed on to ``cost``.
        return pt.grad(cost(x, *extra), x)

    return root(grad_of_cost, x0, args)

In [3]:
import numpy as np

In [4]:
# Symbolic inputs: a length-3 starting point and a scalar parameter.
x0 = pt.tensor("x0", shape=(3,))
mu = pt.tensor("mu", shape=())


def func(x, mu):
    """Gradient with respect to ``x`` of ``sum((x ** 2 - mu) ** 2)``."""
    residual = x ** 2 - mu
    return pt.grad(pt.sum(residual ** 2), x)


# Root of the gradient, found by the Newton iteration defined in `root`.
x_root, stats = root(func, x0, args=[mu], tol=1e-8)

# Derivative of the first root component with respect to mu, obtained by
# differentiating through the iteration itself.
[x_root_dmu] = pt.grad(x_root[0], [mu])

# Partial derivative of the residual with respect to mu at the root, with the
# root held constant.
f_x = func(x_root, mu)
dfunc_dmu = pt.jacobian(f_x, mu, consider_constant=[x_root])

In [5]:
func = pytensor.function([x0, mu], [x_root, stats["n_steps"], stats["jac"], dfunc_dmu])

In [8]:
x_root, n_steps, jac, dfunc_dmu_val = func(np.ones(3) * 3, np.full((), 5.))

In [9]:
# Dervivative of x_root with respect to mu
-linalg.solve(jac, dfunc_dmu_val, assume_a="sym")

array([0.2236068, 0.2236068, 0.2236068])