# In [None]:
import torch
import gpytorch
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS

# Define the GP model for X(t) with a Matérn kernel
class GPModel(gpytorch.models.ExactGP):
    """Exact GP for X(t): constant mean plus a scaled Matérn covariance.

    Args:
        train_x: Training inputs, forwarded to ``ExactGP``.
        train_y: Training targets, forwarded to ``ExactGP``.
        likelihood: Likelihood object, forwarded to ``ExactGP``.
        nu: Smoothness value handed to ``MaternKernel`` (default 2.5).
    """

    def __init__(self, train_x, train_y, likelihood, nu=2.5):
        super().__init__(train_x, train_y, likelihood)
        matern = gpytorch.kernels.MaternKernel(nu=nu)
        self.mean_module = gpytorch.means.ConstantMean()
        # ScaleKernel wraps the Matérn kernel, which stays reachable as
        # ``covar_module.base_kernel``.
        self.covar_module = gpytorch.kernels.ScaleKernel(matern)

    def forward(self, x):
        """Return the multivariate normal built from the mean and covariance at ``x``."""
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x),
            self.covar_module(x),
        )

# Helper function to calculate kernel derivatives for Xdot
def kernel_derivative(kernel, x1, x2):
    """Return unscaled Matérn covariance blocks for a 1-D process and its derivative.

    With ``d = x1 - x2``, ``r = |d|`` and ``a = sqrt(2 * nu) / lengthscale``:

    * nu = 1.5: ``k = (1 + a r) exp(-a r)``
    * nu = 2.5: ``k = (1 + a r + (a r)^2 / 3) exp(-a r)``

    The derivative blocks are the analytic derivatives of those expressions.

    Args:
        kernel: Object exposing ``base_kernel.lengthscale`` (single-element
            tensor) and ``base_kernel.nu``. Any output scale on ``kernel``
            itself is ignored; the returned blocks are for the base kernel only.
        x1: Inputs of shape ``(n,)`` or ``(n, 1)``.
        x2: Inputs of shape ``(m,)`` or ``(m, 1)``.

    Returns:
        Tuple ``(k_xx, k_xdot, k_xdotxdot)`` of ``(n, m)`` tensors:
        ``k(x1, x2)``, ``dk/dx2`` (covariance between the process at ``x1`` and
        its derivative at ``x2``) and ``d^2 k / (dx1 dx2)`` (covariance between
        the derivatives at ``x1`` and ``x2``).

    Raises:
        ValueError: If the inputs have more than one feature column.
        NotImplementedError: If ``nu`` is not 1.5 or 2.5. ``nu = 0.5`` is
            rejected because that process is not mean-square differentiable,
            so the derivative blocks do not exist.
    """
    x1 = x1.unsqueeze(-1) if x1.dim() == 1 else x1
    x2 = x2.unsqueeze(-1) if x2.dim() == 1 else x2
    if x1.size(-1) != 1 or x2.size(-1) != 1:
        raise ValueError("kernel_derivative only supports one-dimensional inputs")

    matern_kernel = kernel.base_kernel
    # .item() yields a plain Python float, so the result carries no autograd
    # dependence on the lengthscale.
    lengthscale = matern_kernel.lengthscale.item()
    nu = matern_kernel.nu

    # Signed pairwise difference (n, m); its sign is what distinguishes
    # d/dx2 from d/dx1.
    diff = x1 - x2.transpose(0, 1)
    scaled_r = None

    if nu == 1.5:
        a = 3.0 ** 0.5 / lengthscale
        scaled_r = a * diff.abs()
        decay = torch.exp(-scaled_r)
        k_xx = (1 + scaled_r) * decay
        k_xdot = (a ** 2) * diff * decay
        k_xdotxdot = (a ** 2) * (1 - scaled_r) * decay
    elif nu == 2.5:
        a = 5.0 ** 0.5 / lengthscale
        scaled_r = a * diff.abs()
        decay = torch.exp(-scaled_r)
        k_xx = (1 + scaled_r + scaled_r ** 2 / 3) * decay
        k_xdot = (a ** 2 / 3) * diff * (1 + scaled_r) * decay
        k_xdotxdot = (a ** 2 / 3) * (1 + scaled_r - scaled_r ** 2) * decay
    elif nu == 0.5:
        raise NotImplementedError(
            "Matérn ν=0.5 is not mean-square differentiable; derivative covariances are undefined"
        )
    else:
        raise NotImplementedError("Only Matérn kernels with ν=1.5 and 2.5 are implemented")

    return k_xx, k_xdot, k_xdotxdot

# Define the joint posterior model with non-Gaussian prior on theta and parameter mu for lengthscale
def joint_posterior(train_x, train_y, test_x_tau, test_xdot, y_tau, f_t):
    """Pyro model: Gamma priors on ``theta`` and ``mu`` plus a GP-based log-density factor.

    ``mu`` is used as the Matérn lengthscale of a ``GPModel``. The remaining GP
    hyperparameters are fitted with 50 Adam steps on every call. Three
    log-densities are summed and registered with ``pyro.factor`` so that they
    contribute to the density Pyro samples from:

    1. the log marginal likelihood of ``train_y`` at ``train_x``,
    2. the log-density of ``y_tau`` under the GP posterior at ``test_x_tau``,
    3. the log-density of ``f_t`` under the Gaussian conditional of the GP
       derivative at ``test_xdot`` given the training data.

    Args:
        train_x: Observation locations I for X(t).
        train_y: Observed values x(I).
        test_x_tau: Locations tau for Y(tau).
        test_xdot: Locations t for the derivative Xdot(t).
        y_tau: Observed values y(tau).
        f_t: Observed derivative values f(t), one per entry of ``test_xdot``.

    Returns:
        The negative of the summed log-density (a scalar tensor). Pyro ignores
        a model's return value; the ``pyro.factor`` call is what affects
        sampling.
    """
    # Non-Gaussian prior on theta (Gamma: shape=2.0, rate=1.0).
    # NOTE(review): theta does not enter any of the densities below, so its
    # samples come from this prior alone.
    pyro.sample("theta", dist.Gamma(2.0, 1.0))
    # Prior on mu (lengthscale of the Matérn kernel), Gamma: shape=2.0, rate=1.0.
    mu = pyro.sample("mu", dist.Gamma(2.0, 1.0))

    likelihood = gpytorch.likelihoods.GaussianLikelihood()
    model = GPModel(train_x, train_y, likelihood)
    base_kernel = model.covar_module.base_kernel
    base_kernel.lengthscale = mu  # Set the lengthscale parameter

    # Fit the remaining hyperparameters with mu held fixed: the lengthscale
    # parameter is left out of the optimiser, otherwise training would
    # overwrite the sampled value.
    model.train()
    likelihood.train()
    frozen = base_kernel.raw_lengthscale
    trainable = [p for p in model.parameters() if p is not frozen]
    optimizer = torch.optim.Adam(trainable, lr=0.1)
    mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

    for _ in range(50):
        optimizer.zero_grad()
        output = model(train_x)
        loss = -mll(output, train_y)
        loss.backward()
        optimizer.step()

    # Part 1: log p(X(I) = x(I)). Evaluated in train mode so model(train_x) is
    # the GP prior, and through the likelihood's log_prob so the value is the
    # full log-density rather than a per-data-point average.
    with torch.no_grad():
        marginal_log_prob_XI = likelihood(model(train_x)).log_prob(train_y)

    model.eval()
    likelihood.eval()

    # Part 2: log p(Y(tau) = y(tau) | X(I) = x(I)) from the GP posterior.
    with torch.no_grad(), gpytorch.settings.fast_pred_var():
        predictive_dist_y = model(test_x_tau)
        log_prob_Y_given_XI = predictive_dist_y.log_prob(y_tau)

    # Part 3: log p(Xdot(t) = f(t) | X(I) = x(I)).
    # NOTE(review): y_tau is not conditioned on here, matching the original
    # computation; confirm whether Y(tau) should enter this term.
    with torch.no_grad():
        covar = model.covar_module
        # Each block is built from the input pair it actually describes.
        K_train, _, _ = kernel_derivative(covar, train_x, train_x)       # (n, n)
        _, K_cross, _ = kernel_derivative(covar, train_x, test_xdot)     # (n, m)
        _, _, K_deriv = kernel_derivative(covar, test_xdot, test_xdot)   # (m, m)

        # kernel_derivative returns unscaled base-kernel blocks; apply the
        # fitted output scale and observation noise so this term is
        # consistent with the GP used in parts 1 and 2.
        outputscale = covar.outputscale
        noise = likelihood.noise
        n_train = K_train.size(0)
        K_obs = outputscale * K_train + noise * torch.eye(
            n_train, dtype=K_train.dtype, device=K_train.device
        )
        K_cross = outputscale * K_cross
        K_deriv = outputscale * K_deriv

        residual = train_y - model.mean_module(train_x)
        conditional_mean_Xdot = K_cross.T @ torch.linalg.solve(K_obs, residual)
        conditional_cov_Xdot = K_deriv - K_cross.T @ torch.linalg.solve(K_obs, K_cross)

        # Symmetrise and add a small jitter so the covariance passes the
        # positive-definiteness check.
        n_test = conditional_cov_Xdot.size(0)
        conditional_cov_Xdot = 0.5 * (conditional_cov_Xdot + conditional_cov_Xdot.T)
        conditional_cov_Xdot = conditional_cov_Xdot + 1e-6 * torch.eye(
            n_test, dtype=conditional_cov_Xdot.dtype, device=conditional_cov_Xdot.device
        )
        # Full Gaussian log-density, including the normalising term (which
        # depends on mu through the covariance).
        log_prob_Xdot_given_XI_Ytau = torch.distributions.MultivariateNormal(
            conditional_mean_Xdot, covariance_matrix=conditional_cov_Xdot
        ).log_prob(f_t)

    # Sum of log probabilities
    joint_log_prob = marginal_log_prob_XI + log_prob_Y_given_XI + log_prob_Xdot_given_XI_Ytau

    # Register the data term with Pyro; without this the sampler only sees
    # the two priors.
    # NOTE(review): every term above is computed under torch.no_grad(), so
    # autograd sees no dependence of this factor on mu — confirm that is
    # acceptable for gradient-based sampling.
    pyro.factor("joint_log_prob", joint_log_prob)

    # Return the negative log posterior
    return -joint_log_prob

# Define the MCMC sampling function with NUTS
def run_mcmc(train_x, train_y, test_x_tau, test_xdot, y_tau, f_t, num_samples=1000, warmup_steps=200):
    """Sample from ``joint_posterior`` with a NUTS kernel.

    Args:
        train_x, train_y, test_x_tau, test_xdot, y_tau, f_t: Forwarded
            unchanged, in this order, to ``joint_posterior``.
        num_samples: Number of samples passed to ``MCMC``.
        warmup_steps: Number of warm-up steps passed to ``MCMC``.

    Returns:
        The result of ``MCMC.get_samples()`` after the run.
    """
    model_args = (train_x, train_y, test_x_tau, test_xdot, y_tau, f_t)
    sampler = MCMC(
        NUTS(joint_posterior),
        num_samples=num_samples,
        warmup_steps=warmup_steps,
    )
    sampler.run(*model_args)
    return sampler.get_samples()

# Synthetic inputs: X(t) is observed as sin(t) on a 100-point grid over [0, 10].
train_x = torch.linspace(0, 10, steps=100)  # I: observation points for X(t)
train_y = train_x.sin()  # observed values x(I)
test_x_tau = torch.tensor([5.5, 7.5])  # tau: new points for Y(tau)
y_tau = test_x_tau.sin()  # observed values for Y(tau)
test_xdot = torch.tensor([3.0])  # t: point for Xdot(t)
f_t = torch.tensor([0.5])  # placeholder; substitute with the actual observed derivative

# Run MCMC to sample from the joint posterior
posterior_samples = run_mcmc(
    train_x,
    train_y,
    test_x_tau,
    test_xdot,
    y_tau,
    f_t,
)

# Report the draws for each sampled site
print("Posterior Samples for Theta:", posterior_samples['theta'], sep="\n")
print("Posterior Samples for Mu (lengthscale):", posterior_samples['mu'], sep="\n")
