In [8]:

import GPy
import numpy as np


In [10]:


def gp_classification_lambda(Z_ik, t):
    T = len(Z_ik)
    X = t.reshape(-1, 1)
    Y = Z_ik.reshape(-1, 1)
    
    kernel = GPy.kern.RBF(input_dim=1, variance=1., lengthscale=1.)
    model = GPy.models.GPClassification(X, Y, kernel=kernel)
    model.optimize()
    
    # Predict the latent function (before sigmoid)
    lambda_ik_posterior, lambda_ik_var = model.predict(X)
    
    # Ensure the output is always a 1D numpy array
    return np.atleast_1d(lambda_ik_posterior.squeeze()), np.atleast_1d(lambda_ik_var.squeeze())

# Example data
T = 10
t = np.arange(T)
Z_ik = np.array([0, 1, 0, 1, 0, 0, 1, 0, 1, 0])  # Observed z_ikt for individual i and category k

lambda_ik_posterior, lambda_ik_var = gp_classification_lambda(Z_ik, t)


AttributeError: 'float' object has no attribute 'squeeze'

In [None]:

# Calculate probabilities
probabilities = 1 / (1 + np.exp(-lambda_ik_posterior))

# Plotting
import matplotlib.pyplot as plt

plt.figure(figsize=(12, 6))

# Plot the latent function
plt.subplot(1, 2, 1)
plt.plot(t, lambda_ik_posterior, 'b-', label='Posterior Mean')
plt.fill_between(t, 
                 lambda_ik_posterior - 2*np.sqrt(lambda_ik_var),
                 lambda_ik_posterior + 2*np.sqrt(lambda_ik_var),
                 color='b', alpha=0.2, label='95% CI')
plt.scatter(t[Z_ik == 1], [3] * np.sum(Z_ik == 1), c='r', marker='o', s=100, label='Observed 1')
plt.scatter(t[Z_ik == 0], [-3] * np.sum(Z_ik == 0), c='g', marker='x', s=100, label='Observed 0')
plt.title('GP Classification: Latent Function')
plt.xlabel('Time')
plt.ylabel('Latent Function Value')
plt.legend()
plt.ylim(-4, 4)  # Adjust y-axis limits to show all points

# Plot the probabilities
plt.subplot(1, 2, 2)
plt.plot(t, probabilities, 'b-', label='Probability')
plt.scatter(t[Z_ik == 1], [1] * np.sum(Z_ik == 1), c='r', marker='o', s=100, label='Observed 1')
plt.scatter(t[Z_ik == 0], [0] * np.sum(Z_ik == 0), c='g', marker='x', s=100, label='Observed 0')
plt.title('GP Classification: Probabilities')
plt.xlabel('Time')
plt.ylabel('Probability')
plt.ylim(-0.1, 1.1)
plt.legend()

plt.tight_layout()
plt.show()