In [1]:
import numpy as np
import pytensor
import pytensor.tensor as pt

from pytensor.graph import FunctionGraph

import pymc as pm

from pymc.model.transform.conditioning import remove_value_transforms
from pymc.pytensorf import toposort_replace

In [2]:
# %load_ext autoreload
# %autoreload 2

In [3]:
class CenterTransform(pm.distributions.transforms.Transform):
    ndim_supp = 0

    name = "CenterTransform"

    def __init__(self, trafo_param, sigma_fn=lambda args: args[-1]):
        self._trafo_params = (trafo_param,)
        self._trafo_param = trafo_param
        self.sigma_fn = sigma_fn

    def get_sigma(self, args):
        # *rv_params, hyper = args
        return self.sigma_fn(args)

    def get_hyperparam(self):
        return pt.sigmoid(self._trafo_param)

    def forward(self, x, *params):
        sigma = self.get_sigma(params)
        hyper = self.get_hyperparam()
        return x / (sigma**hyper)

    def backward(self, y, *params):
        sigma = self.get_sigma(params)
        hyper = self.get_hyperparam()
        return y * (sigma**hyper)


def forward_and_grad(transform, constrained_point, constrained_grad, *params):
    if transform is None:
        return constrained_point.copy(), constrained_grad.copy(), 0.0

    unconstrained_point = transform.forward(constrained_point, *params)
    # Redefine forward so that L_op considers these separate path
    backward_log_jac_det = transform.log_jac_det(
        transform.forward(constrained_point, *params), *params
    )
    unconstrained_grad = pytensor.gradient.Lop(
        f=[unconstrained_point, backward_log_jac_det],
        wrt=constrained_point,
        eval_points=[constrained_grad, pt.ones_like(backward_log_jac_det)],
    )

    return unconstrained_point, unconstrained_grad, -backward_log_jac_det

Steps:

Split the freeRVs into two groups:
    - learnable: RVs that have hyper parameters and are bijections
    - constant: Remaining

In the following, treat the constant trafos as usual, so always transform.

Compile functions

- new_transformation:
    Given a point in the untransformed parameter space, initialize the hyper parameters
    and store them in an array.
- transform_position_and_gradient:
    Given a hyper parameter vector and untransformed position and gradient
    as vector, compute the transformed position and gradient as vector. Also
    compute the sum of logdets of the transforms.
- init_from_untransformed:
    Given a hyper parameter vector and an untransformed point, compute
    - The untransformed total logp and gradient.
    Reuse transform_position_and_gradient to get:
        - The transformed point and gradient as vectors and the sum of all logdets.
    
- init_from_transformed:
    Given a hyper parameter vector and a transformed positon, compute the other three.
    Also compute the total logp and the sum of all logdets.

- update_transformation:
    Given a set of points and gradients on the untransformed space, return optimized hyper params.

In [4]:
with pm.Model() as unconstrained_model:
    sigma = pm.HalfNormal("sigma")

    x_hyper = pytensor.shared(np.array(0.5), name="x_hyper")
    trafo = CenterTransform(x_hyper)
    pm.Normal("x", mu=0, sigma=sigma, transform=trafo, shape=(3,))

In [5]:
constrained_point_value = {"sigma": np.exp(-0.3), "x": [-1.0, 0.0, 1.0]}

constrained_model = remove_value_transforms(unconstrained_model)

constrained_model_logp_value = constrained_model.compile_logp()(constrained_point_value)

raveled_dlogp = constrained_model.compile_dlogp()(constrained_point_value)
constrained_point_grad_value = {
    "sigma": np.asarray(raveled_dlogp[0]),
    "x": raveled_dlogp[1:],
}

In [6]:
constrained_point_value, constrained_point_grad_value, constrained_model_logp_value

({'sigma': 0.7408182206817179, 'x': [-1.0, 0.0, 1.0]},
 {'sigma': array(0.12881158),
  'x': array([ 1.8221188, -0.       , -1.8221188])},
 array(-4.17913157))

In [7]:
logp_f = unconstrained_model.compile_logp()
[x_hyper] = logp_f.f.get_shared()

# ip = unconstrained_model.initial_point()
ip = {"sigma_log__": np.array(-0.3), "x_CenterTransform__": np.array([0.0, 0.0, 0.0])}

x_hyper.set_value(np.array(0.5))
print(logp_f(ip))

x_hyper.set_value(np.array(1 - 0.99999))
print(logp_f(ip))

-3.2172261683874277
-3.107015020305758


In [8]:
# Avoid mutating variables in place
um_copy = unconstrained_model.copy()

constrained_points = []  # root variables, created in the loop
constrained_grads = []  # root variables, created in the loop
unconstrained_points = []
unconstrained_grads = []
sum_log_det_jacobians = 0.0
for rv in um_copy.free_RVs:
    transform = um_copy.rvs_to_transforms[rv]
    constrained_point = rv.type(name=rv.name)
    constrained_grad = rv.type(name=f"{rv.name}_grad")
    unconstrained_point, unconstrained_grad, log_det_jacobian = forward_and_grad(
        transform, constrained_point, constrained_grad, *rv.owner.inputs
    )
    unconstrained_point.name = f"{rv.name}_unconstrained"
    unconstrained_grad.name = f"{rv.name}_grad_unconstrained"

    constrained_points.append(constrained_point)
    constrained_grads.append(constrained_grad)
    unconstrained_points.append(unconstrained_point)
    unconstrained_grads.append(unconstrained_grad)
    sum_log_det_jacobians += log_det_jacobian.sum()

# Replace rvs by the constrained_points
fgraph = FunctionGraph(
    outputs=[*unconstrained_points, *unconstrained_grads, sum_log_det_jacobians], clone=False
)
toposort_replace(fgraph, tuple(zip(um_copy.free_RVs, constrained_points)))

# From constrained space to unconstrained
pullback_grads_f = pytensor.function(
    [*constrained_points, *constrained_grads],
    fgraph.outputs,
    # mode=get_mode("FAST_RUN").excluding("fusion"),
)

In [9]:
x_hyper.set_value(np.array(0.5))
pullback_grads_f(*constrained_point_value.values(), *constrained_point_grad_value.values())

[array(-0.3),
 array([-1.20531121,  0.        ,  1.20531121]),
 array(1.52373625),
 array([ 2.19622022, -0.        , -2.19622022]),
 array(0.8602134)]

In [10]:
x_hyper.set_value(np.array(1.0))
pullback_grads_f(*constrained_point_value.values(), *constrained_point_grad_value.values())

[array(-0.3),
 array([-1.24522667,  0.        ,  1.24522667]),
 array(1.52373625),
 array([ 2.26895092, -0.        , -2.26895092]),
 array(0.95795272)]

In [11]:
loss = pt.add(*[pt.sum((g + p) ** 2) for g, p in zip(unconstrained_grads, unconstrained_points)])
loss_grad = pt.grad(loss, wrt=x_hyper)

loss_fn = pytensor.function([*constrained_points, *constrained_grads], [loss, loss_grad])

In [12]:
x_hyper.set_value(np.array(0.5))
loss_fn(*constrained_point_value.values(), *constrained_point_grad_value.values())

[array(3.46133173), array(0.27690036)]

In [13]:
x_hyper.set_value(x_hyper.get_value() - 0.2769 * 10)
loss_fn(*constrained_point_value.values(), *constrained_point_grad_value.values())

[array(2.92748161), array(0.07287526)]

    fn inv_transform_normalize(
        &mut self,
        params: &Self::TransformParams,
        untransformed_position: &Self::Vector,
        untransofrmed_gradient: &Self::Vector,
        transformed_position: &mut Self::Vector,
        transformed_gradient: &mut Self::Vector,
    ) -> Result<f64, Self::LogpErr>;

    fn init_from_untransformed_position(
        &mut self,
        params: &Self::TransformParams,
        untransformed_position: &Self::Vector,
        untransformed_gradient: &mut Self::Vector,
        transformed_position: &mut Self::Vector,
        transformed_gradient: &mut Self::Vector,
    ) -> Result<(f64, f64), Self::LogpErr>;

    fn init_from_transformed_position(
        &mut self,
        params: &Self::TransformParams,
        untransformed_position: &mut Self::Vector,
        untransformed_gradient: &mut Self::Vector,
        transformed_position: &Self::Vector,
        transformed_gradient: &mut Self::Vector,
    ) -> Result<(f64, f64), Self::LogpErr>;

    fn update_transformation<'a, R: rand::Rng + ?Sized>(
        &'a mut self,
        rng: &mut R,
        untransformed_positions: impl ExactSizeIterator<Item = &'a Self::Vector>,
        untransformed_gradients: impl ExactSizeIterator<Item = &'a Self::Vector>,
        untransformed_logps: impl ExactSizeIterator<Item = &'a f64>,
        params: &'a mut Self::TransformParams,
    ) -> Result<(), Self::LogpErr>;

    fn new_transformation<R: rand::Rng + ?Sized>(
        &mut self,
        rng: &mut R,
        untransformed_position: &Self::Vector,
        untransfogmed_gradient: &Self::Vector,
        chain: u64,
    ) -> Result<Self::TransformParams, Self::LogpErr>;
