In [12]:
import numpy as np

import pytensor
import pytensor.tensor as pt
import pymc as pm

from pytensor.tensor.slinalg import cholesky
from pytensor.graph import FunctionGraph
from pytensor.graph.rewriting.basic import node_rewriter, in2out
from pytensor.tensor.rewriting.basic import register_canonicalize
from pytensor.graph.op import Op, Apply
from pymc.gp.util import stabilize
from pymc.logprob.abstract import _logprob, _get_measurable_outputs, MeasurableVariable
from pymc.logprob.abstract import get_measurable_outputs

In [16]:
class Cov(Op):
    """Symbolic placeholder for a covariance function of a lengthscale.

    The Op only records *which* covariance constructor to use (``fn``) and the
    lengthscale it is applied to; it carries no numerical implementation.
    Code elsewhere reads ``op.fn`` and the node's single input to build the
    actual covariance expression.

    Parameters
    ----------
    fn
        Covariance constructor stored on the Op and exposed through
        ``__props__``.
    """

    __props__ = ("fn",)

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, ls=1.0):
        """Apply the Op to ``ls`` (defaults to ``1.0``)."""
        return Op.__call__(self, ls)

    def make_node(self, ls):
        """Build the Apply node: one lengthscale input, one matrix output."""
        lengthscale = pt.as_tensor(ls)
        # Output is a matrix whose two dimensions are both unknown statically.
        placeholder = pt.matrix(shape=(None, None))
        return Apply(self, [lengthscale], [placeholder])

    def do_constant_folding(self, fgraph, node):
        """Always decline constant folding for this Op."""
        return False

    def perform(self, node, inputs, output_storage):
        """Never computes anything; always raises ``NotImplementedError``."""
        raise NotImplementedError("You should convert Cov into a TensorVariable expression!")


class GP(Op):
    """Symbolic placeholder for a Gaussian process built from a mean and a ``Cov``.

    Parameters
    ----------
    approx
        Label of the approximation to use; stored on the Op and exposed
        through ``__props__``.
    """

    __props__ = ("approx",)

    def __init__(self, approx):
        self.approx = approx

    def make_node(self, mean, cov):
        """Build the Apply node for ``(mean, cov)``.

        Raises
        ------
        ValueError
            If ``cov`` was not produced by a ``Cov`` Op.
        """
        mean_var = pt.as_tensor(mean)
        cov_var = pt.as_tensor(cov)

        # Only outputs of a Cov node are accepted as the covariance argument.
        producer = cov_var.owner
        produced_by_cov = bool(producer) and isinstance(producer.op, Cov)
        if not produced_by_cov:
            raise ValueError("Second argument should be a Cov output.")

        # Output is a vector of statically unknown length.
        placeholder = pt.vector(shape=(None,))
        return Apply(self, [mean_var, cov_var], [placeholder])

    def do_constant_folding(self, fgraph, node):
        """Always decline constant folding for this Op."""
        return False

    def perform(self, node, inputs, output_storage):
        """Never computes anything; always raises ``NotImplementedError``."""
        raise NotImplementedError("You cannot evaluate a GP, not enough RAM in the Universe.")


class PriorFromGP(Op):
    """Placeholder for a GP prior evaluated at the input locations ``x``.

    This Op will be replaced by the right MvNormal. It has no numerical
    implementation of its own: ``perform`` always raises.

    The node inputs are ``(gp, x, rng)`` and the single output has the same
    type as ``x``.
    """

    # The Op has no parameters. Declaring an empty ``__props__`` gives instances
    # structural equality/hash, consistent with the sibling ``Cov`` and ``GP``
    # Ops, instead of falling back to identity comparison.
    __props__ = ()

    def make_node(self, gp, x, rng):
        """Build the Apply node for ``(gp, x, rng)``.

        Raises
        ------
        ValueError
            If ``gp`` was not produced by a ``GP`` Op.
        """
        gp = pt.as_tensor(gp)
        if not (gp.owner and isinstance(gp.owner.op, GP)):
            raise ValueError("First argument should be a GP output.")

        # TODO: Assert RNG has the right type
        x = pt.as_tensor(x)
        # The output mirrors the type of ``x``.
        out = x.type()

        return Apply(self, [gp, x, rng], [out])

    def __call__(self, gp, x, rng=None):
        """Apply the Op, creating a fresh shared NumPy Generator when ``rng`` is omitted."""
        if rng is None:
            rng = pytensor.shared(np.random.default_rng())
        return super().__call__(gp, x, rng)

    def perform(self, node, inputs, output_storage):
        """Never computes anything; always raises ``NotImplementedError``."""
        raise NotImplementedError("You should convert PriorFromGP into a MvNormal!")

    def do_constant_folding(self, fgraph, node):
        """Always decline constant folding for this Op."""
        return False


# Concrete Op instances used by the cells below.
cov_op = Cov(fn=pm.gp.cov.ExpQuad)
gp_op = GP(approx="vanilla")
# SymbolicRandomVariable.register(type(gp_op))
prior_from_gp = PriorFromGP()

# Register the PriorFromGP class with MeasurableVariable.
MeasurableVariable.register(PriorFromGP)


@_get_measurable_outputs.register(PriorFromGP)
def gp_measurable_outputs(op, node):
    """Return all outputs of a ``PriorFromGP`` node as its measurable outputs."""
    measurable = node.outputs
    return measurable

In [17]:
# Symbolic inputs: a mean vector, a length-50 vector of input locations and a
# scalar lengthscale.
mean = pt.vector("mean")
x = pt.vector("x", shape=(50,))
ls = pt.scalar("ls")

# Chain the placeholder Ops: Cov -> GP -> PriorFromGP evaluated at x.
cov = cov_op(ls)
gp = gp_op(mean, cov)
f = prior_from_gp(gp, x)
# Print the graph together with the variable types.
pytensor.dprint(f, print_type=True)

PriorFromGP [id A] <TensorType(float64, (50,))>
 |GP{approx='vanilla'} [id B] <TensorType(float64, (?,))>
 | |mean [id C] <TensorType(float64, (?,))>
 | |Cov{fn=<class 'pymc.gp.cov.ExpQuad'>} [id D] <TensorType(float64, (?, ?))>
 |   |ls [id E] <TensorType(float64, ())>
 |x [id F] <TensorType(float64, (50,))>
 |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7FE619D9DD20>) [id G] <RandomGeneratorType>


<ipykernel.iostream.OutStream at 0x7fe6947d53c0>

In [18]:
# Check which outputs of the PriorFromGP node are reported as measurable.
get_measurable_outputs(f.owner.op, f.owner)

[PriorFromGP.0]

In [28]:
# You can only run once
# (presumably because registering the same rewriter name a second time is
# rejected -- restart the kernel before re-running this cell).
@register_canonicalize
@node_rewriter(tracks=[PriorFromGP])
def prior_from_gp_to_mvnormal(fgraph: FunctionGraph, node: Apply):
    """Replace a ``PriorFromGP`` node by an explicit Gaussian draw.

    The replacement is ``mean + L @ v`` where ``L`` is the Cholesky factor of
    the stabilized covariance matrix evaluated at ``X`` and ``v`` is a vector
    of standard-normal draws taken from the node's RNG input.

    Returns
    -------
    list or False
        A one-element list with the replacement variable, or ``False`` when
        the node cannot be rewritten: the first input is no longer produced
        by a ``GP`` Op, the GP's covariance is no longer produced by a ``Cov``
        Op, or the GP approximation is not ``"vanilla"``.
    """
    gp, X, rng = node.inputs

    # Other rewrites may have replaced the GP/Cov placeholders; in that case
    # decline the rewrite instead of failing on a missing owner/attribute.
    gp_node = gp.owner
    if gp_node is None or not isinstance(gp_node.op, GP):
        return False

    if gp_node.op.approx != "vanilla":
        return False

    mean, cov = gp_node.inputs
    cov_node = cov.owner
    if cov_node is None or not isinstance(cov_node.op, Cov):
        return False

    # Materialize cov: build the covariance object from the stored constructor
    # and evaluate the full matrix on X reshaped to a single column.
    ls = cov_node.inputs[0]
    cov_matrix = cov_node.op.fn(input_dim=1, ls=ls).full(X[:, None])

    size = pt.shape(X)[0]
    fgraph.add_input(rng)

    # TODO: Give names
    L = cholesky(stabilize(cov_matrix))
    #     L.name = "L"
    v = pm.Normal.dist(0, 1, size=size, rng=rng)
    draw = mean + pt.dot(L, v)

    return [draw]

In [29]:
# Display the object produced by the decorators above.
prior_from_gp_to_mvnormal

FromFunctionNodeRewriter(<function prior_from_gp_to_mvnormal at 0x7fe617259510>, [<class '__main__.PriorFromGP'>], ())

In [30]:
# Wrap `f` in a FunctionGraph without cloning, then apply the rewriter by hand
# to the node that produces the graph's output.
fg = FunctionGraph(outputs=[f], clone=False)
# The rewriter returns a one-element list; unpack the replacement variable.
[out] = prior_from_gp_to_mvnormal.transform(fg, fg.outputs[0].owner)
# pytensor.dprint(out)

In [67]:
@_logprob.register(PriorFromGP)
def prior_gp_logprob(op, values, gp, X, rng, **kwargs):
    """Log-probability of a ``PriorFromGP`` output.

    The density is that of an ``MvNormal`` with the GP mean and the stabilized
    covariance matrix evaluated at ``X``. ``rng`` is part of the node inputs
    but is not used here.

    Raises
    ------
    NotImplementedError
        If ``gp`` is no longer produced by a ``GP`` Op, if its covariance is
        no longer produced by a ``Cov`` Op, or if the GP approximation is not
        ``"vanilla"``.
    """
    [value] = values

    # Fail with a clear message if the placeholders were replaced, rather than
    # with an AttributeError on a missing owner/attribute.
    gp_node = gp.owner
    if gp_node is None or not isinstance(gp_node.op, GP):
        raise NotImplementedError("PriorFromGP logprob requires its first input to be a GP output.")

    if gp_node.op.approx != "vanilla":
        raise NotImplementedError(
            f"PriorFromGP logprob is not implemented for approx={gp_node.op.approx!r}."
        )

    mean, cov = gp_node.inputs
    cov_node = cov.owner
    if cov_node is None or not isinstance(cov_node.op, Cov):
        raise NotImplementedError("PriorFromGP logprob requires the GP covariance to be a Cov output.")

    # Materialize cov: build the covariance object from the stored constructor
    # and evaluate the full matrix on X reshaped to a single column.
    ls = cov_node.inputs[0]
    cov_matrix = cov_node.op.fn(input_dim=1, ls=ls).full(X[:, None])

    mv_normal = pm.MvNormal.dist(mu=mean, cov=stabilize(cov_matrix))
    return pm.logp(mv_normal, value)

In [68]:
with pm.Model() as m:
    mean = pt.zeros(())
    # Keep the raw input locations so the initial value below is derived from
    # them rather than repeating the number of points.
    x_values = np.linspace(0, 10, 20)
    x = pm.ConstantData("x", x_values)
    ls = pm.Gamma("ls", alpha=4, beta=1)

    cov = cov_op(ls)
    gp = gp_op(mean, cov)
    f = prior_from_gp(gp, x)

    # PriorFromGP is a custom Op, so it is registered with the model by hand.
    m.register_rv(f, name="f", initval=np.zeros_like(x_values))

m.basic_RVs

[ls ~ Gamma(4, f()), f]

In [69]:
# Compute the model's initial point and display it.
ip = m.initial_point()
ip

{'ls_log__': array(1.38629436),
 'f': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0.])}

In [70]:
# Compile the model logp function and evaluate it at the initial point.
m.compile_logp()(m.initial_point())

array(78.72276943)

In [71]:
# Draw prior predictive samples from the model.
with m:
    idata = pm.sample_prior_predictive()

Sampling: [f, ls]


In [72]:
# idata.prior["f"].mean(("chain", "draw"))

In [73]:
# Build the logp of `f` at a vector of 20 ones; the result is a symbolic
# variable, not a number.
pm.logp(f, np.ones(20))

Check{posdef}.0

In [74]:
# Symbolic logp of the whole model.
m.logp()

__logp

In [None]:
# Run the default sampler on the model.
# NOTE: this rebinds `idata`, replacing the prior-predictive result assigned earlier.
with m:
    idata = pm.sample()

Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [ls, f]


In [13]:
# pytensor.dprint(out)

In [14]:
# out.eval({mean: np.ones(50), x: np.linspace(0, 10, 50), ls: 1.0})