In [None]:
# | export
from isssm.kalman import state_conditional_on_signal
from scipy.optimize import root_scalar
from tensorflow_probability.substrates.jax.distributions import Normal
from jax.scipy.optimize import minimize as minimize_jax


def continuous_quantile(cdf, p, x0=0.0):
    """Numerically invert a scalar-valued CDF at probability ``p``.

    Solves ``cdf(x) = p`` by minimising ``(cdf(x) - p) ** 2`` with JAX's BFGS.

    Parameters
    ----------
    cdf : callable
        Maps a 1-D array ``x`` to a scalar CDF value.
    p : float or scalar array
        Target probability.
    x0 : float or array, default 0.0
        Initial guess. Scalars are promoted to shape ``(1,)``, because
        ``jax.scipy.optimize.minimize`` requires a 1-D starting point. A bare
        Python float used to crash here.

    Returns
    -------
    Array
        The minimiser, with the same 1-D shape as the promoted ``x0``.

    Notes
    -----
    BFGS needs a non-zero gradient. Where the CDF is flat (tails, or
    piecewise-constant CDFs of discrete distributions), the optimiser cannot
    move away from ``x0``.
    """
    x0 = jnp.atleast_1d(jnp.asarray(x0))
    if not jnp.issubdtype(x0.dtype, jnp.inexact):
        # BFGS requires a floating starting point; integer guesses are cast.
        x0 = x0.astype(float)
    return minimize_jax(lambda x: (cdf(x) - p) ** 2, x0, method="BFGS").x


def predict_x(
    signal_samples: Float[Array, "N n+1 p"],
    log_weights: Float[Array, "N"],
    model: PGSSM,
    probs: Float[Array, "k"],
):
    """Predictive mean, standard deviation and quantiles of the states X given Y.

    Each signal sample is turned into the conditional distribution of the
    states via ``state_conditional_on_signal``. The weighted samples then form
    a Gaussian mixture, whose moments and quantiles are returned.

    Parameters
    ----------
    signal_samples : Float[Array, "N n+1 p"]
        ``N`` signal samples.
    log_weights : Float[Array, "N"]
        Log importance weights of the samples.
    model : PGSSM
        Model passed to ``state_conditional_on_signal``.
    probs : Float[Array, "k"]
        Probabilities at which to evaluate the mixture quantiles.

    Returns
    -------
    E_X_cond_Y, sd_X_cond_Y, quantiles
        Mixture mean and standard deviation, plus quantiles stacked along a
        new leading axis of length ``k``. The remaining quantile axes are
        axes 1 and 2 of the smoothed state means.
    """
    x_smooth, Xi_smooth = vmap(state_conditional_on_signal, (None, 0))(
        model, signal_samples
    )

    # Per-sample marginal variances: the diagonals of the conditional covariances.
    var_X_cond_S = vmap(vmap(jnp.diag))(Xi_smooth)
    E_X2_cond_S = var_X_cond_S + x_smooth**2

    E_X_cond_Y = mc_integration(x_smooth, log_weights)
    E_X2_cond_Y = mc_integration(E_X2_cond_S, log_weights)

    # Cancellation in E[X^2] - E[X]^2 can give tiny negative values -> NaN sd.
    var_X_cond_Y = jnp.maximum(E_X2_cond_Y - E_X_cond_Y**2, 0.0)
    sd_X_cond_Y = jnp.sqrt(var_X_cond_Y)

    # Hoisted: these are constant across every CDF evaluation inside BFGS.
    weights = normalize_weights(log_weights)
    # TFP Normal is parameterised by (loc, scale) with scale the standard
    # deviation, so the component variances must be square-rooted.
    sd_X_cond_S = jnp.sqrt(var_X_cond_S)

    def quantile(means, sds, p):
        # Mixture of N normals; component i has weight weights[i].
        components = Normal(means, sds)

        def cdf(x):
            return (components.cdf(x).reshape(-1) * weights).sum()

        # Start from the weighted average of the component quantiles.
        x0 = (components.quantile(p).reshape(-1) * weights).sum().reshape(-1)
        return continuous_quantile(cdf, p, x0=x0)

    # Innermost vmaps run over axes 1 and 2 of the smoothed means; the
    # outermost vmap runs over probs. The trailing length-1 optimiser axis
    # is squeezed away.
    quantiles = vmap(vmap(vmap(quantile, (1, 1, None)), (1, 1, None)), (None, None, 0))(
        x_smooth, sd_X_cond_S, probs
    ).squeeze(-1)
    return E_X_cond_Y, sd_X_cond_Y, quantiles

In [None]:
# | export
from tensorflow_probability.substrates.jax.distributions import LogNormal


def predict_exp_x(
    signal_samples: Float[Array, "N n+1 p"],
    log_weights: Float[Array, "N"],
    model: PGSSM,
    probs: Float[Array, "k"],
):
    """Predictive mean, standard deviation and quantiles of exp(X) given Y.

    Conditional on each signal sample, X is normal with mean ``x_smooth`` and
    marginal variance ``diag(Xi_smooth)``, so exp(X) is log-normal. The
    weighted samples form a log-normal mixture, whose moments and quantiles
    are returned.

    Parameters
    ----------
    signal_samples : Float[Array, "N n+1 p"]
        ``N`` signal samples.
    log_weights : Float[Array, "N"]
        Log importance weights of the samples.
    model : PGSSM
        Model passed to ``state_conditional_on_signal``.
    probs : Float[Array, "k"]
        Probabilities at which to evaluate the mixture quantiles.

    Returns
    -------
    E_expX_cond_Y, sd_expX_cond_Y, quantiles
        Mixture mean and standard deviation of exp(X), plus quantiles stacked
        along a new leading axis of length ``k``.
    """
    x_smooth, Xi_smooth = vmap(state_conditional_on_signal, (None, 0))(
        model, signal_samples
    )

    var_X_cond_S = vmap(vmap(jnp.diag))(Xi_smooth)

    # Log-normal moments: E[e^X] = exp(mu + s^2 / 2), E[e^{2X}] = exp(2 mu + 2 s^2).
    E_expX_cond_S = jnp.exp(x_smooth + 0.5 * var_X_cond_S)
    E_expX2_cond_S = jnp.exp(2 * x_smooth + 2 * var_X_cond_S)

    E_expX_cond_Y = mc_integration(E_expX_cond_S, log_weights)
    E_expX2_cond_Y = mc_integration(E_expX2_cond_S, log_weights)

    # Cancellation in E[Z^2] - E[Z]^2 can give tiny negative values -> NaN sd.
    var_expX_cond_Y = jnp.maximum(E_expX2_cond_Y - E_expX_cond_Y**2, 0.0)
    sd_expX_cond_Y = jnp.sqrt(var_expX_cond_Y)

    # Hoisted: these are constant across every CDF evaluation inside BFGS.
    weights = normalize_weights(log_weights)
    # TFP LogNormal(loc, scale) takes the standard deviation of the underlying
    # normal, not its variance.
    sd_X_cond_S = jnp.sqrt(var_X_cond_S)

    def quantile(means, sds, p):
        # Mixture of N log-normals; component i has weight weights[i].
        components = LogNormal(means, sds)

        def cdf(x):
            return (components.cdf(x).reshape(-1) * weights).sum()

        # Start from the weighted average of the component quantiles.
        x0 = (components.quantile(p).reshape(-1) * weights).sum().reshape(-1)
        return continuous_quantile(cdf, p, x0=x0)

    quantiles = vmap(vmap(vmap(quantile, (1, 1, None)), (1, 1, None)), (None, None, 0))(
        x_smooth, sd_X_cond_S, probs
    ).squeeze(-1)
    return E_expX_cond_Y, sd_expX_cond_Y, quantiles

In [None]:
# | export
def predict_y_prime(
    signal_samples: Float[Array, "N n+1 p"],
    log_weights: Float[Array, "N"],
    model: PGSSM,
    probs: Float[Array, "k"],
):
    """Predictive mean, standard deviation and quantiles of new observations Y'.

    Each signal sample defines the observation distribution
    ``model.dist(signal, model.xi)``. The weighted samples form a mixture of
    these, whose moments and quantiles are returned.

    Parameters
    ----------
    signal_samples : Float[Array, "N n+1 p"]
        ``N`` signal samples.
    log_weights : Float[Array, "N"]
        Log importance weights of the samples.
    model : PGSSM
        Provides the observation distribution ``model.dist`` and its
        parameters ``model.xi``.
    probs : Float[Array, "k"]
        Probabilities at which to evaluate the mixture quantiles.

    Returns
    -------
    E_Yprime_cond_Y, sd_Yprime_cond_Y, quantiles
        Mixture mean and standard deviation of Y', plus quantiles stacked
        along a new leading axis of length ``k``.

    Notes
    -----
    NOTE(review): the quantiles come from gradient-based inversion in
    ``continuous_quantile``. If ``model.dist`` is a discrete distribution, its
    CDF is piecewise constant and BFGS cannot move from the initial guess, so
    the result is just the weighted average of the per-sample quantiles.
    ``model.dist`` must also implement ``quantile``. Confirm both for the
    distribution families in use.
    """

    def conditional_moments(signal):
        dist = model.dist(signal, model.xi)
        return dist.mean(), dist.variance()

    E_Yprime_cond_S, var_Yprime_cond_S = vmap(conditional_moments)(signal_samples)
    E_Yprime2_cond_S = var_Yprime_cond_S + E_Yprime_cond_S**2

    E_Yprime_cond_Y = mc_integration(E_Yprime_cond_S, log_weights)
    E_Yprime2_cond_Y = mc_integration(E_Yprime2_cond_S, log_weights)
    # Cancellation in E[Y'^2] - E[Y']^2 can give tiny negative values -> NaN sd.
    var_Yprime_cond_Y = jnp.maximum(E_Yprime2_cond_Y - E_Yprime_cond_Y**2, 0.0)
    sd_Yprime_cond_Y = jnp.sqrt(var_Yprime_cond_Y)

    # Hoisted: constant across every CDF evaluation inside BFGS.
    weights = normalize_weights(log_weights)

    def quantile(signals, xis, p):
        # Mixture of N observation distributions; component i has weight weights[i].
        components = model.dist(signals, xis)

        def cdf(x):
            return (components.cdf(x).reshape(-1) * weights).sum()

        # Start from the weighted average of the component quantiles.
        x0 = (components.quantile(p).reshape(-1) * weights).sum().reshape(-1)
        return continuous_quantile(cdf, p, x0=x0)

    # model.xi gets a leading singleton axis so it is mapped over the same
    # axes as signal_samples and broadcasts against the N samples.
    quantiles = vmap(vmap(vmap(quantile, (1, 1, None)), (1, 1, None)), (None, None, 0))(
        signal_samples, model.xi[None], probs
    ).squeeze(-1)
    return E_Yprime_cond_Y, sd_Yprime_cond_Y, quantiles