In [10]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_spd_matrix

In [11]:
DIM = 200          # size of the integer index grid 0..DIM-1 (and of the prior covariance matrix)
PRED_DIM = 10      # number of grid indices held out as prediction points; the rest are "training"
LENGTH_SCALE = 3   # kernel length scale; squared_exp_kernel also uses it as the amplitude
SEED = 0           # seed for NumPy's global RNG

# Every random draw in this notebook goes through NumPy's global RNG
# (np.random.normal / choice / uniform / multivariate_normal), so seeding it here
# makes Restart & Run All reproducible.
np.random.seed(SEED)

def data_generator(x, noise=True, noise_std=0.1):
    """Evaluate the toy target f(x) = sin(x / 10) * x / 10, optionally with Gaussian noise.

    Parameters
    ----------
    x : array_like
        Input locations (scalar, list or array).
    noise : bool, default True
        If True, add i.i.d. N(0, noise_std**2) noise to every output.
    noise_std : float, default 0.1
        Standard deviation of the additive noise.

    Returns
    -------
    numpy.ndarray
        f(x), plus noise if requested, with the same shape as ``x``.
    """
    x = np.asarray(x, dtype=float)
    clean = np.sin(x / 10) * x / 10
    if not noise:
        # Skip the draw entirely so noiseless calls do not advance the global RNG.
        return clean
    return clean + np.random.normal(0, noise_std, x.shape)

In [12]:
def squared_exp_kernel(x, x_dash, length_scale, variance=None):
    """Squared-exponential (RBF) covariance between ``x`` and ``x_dash``.

    k(x, x') = variance * exp(-(x - x')**2 / (2 * length_scale**2))

    Parameters
    ----------
    x, x_dash : float or numpy.ndarray
        Inputs; arrays broadcast element-wise.
    length_scale : float
        Controls how quickly correlation decays with distance.
    variance : float, optional
        Signal variance (amplitude). Defaults to ``length_scale``, which reproduces
        this notebook's original behaviour. Pass e.g. ``1.0`` to decouple the prior
        variance from the length scale, as in the textbook kernel.

    Returns
    -------
    float or numpy.ndarray
        The kernel value(s).
    """
    amplitude = length_scale if variance is None else variance
    return amplitude * np.exp(-(x - x_dash) ** 2 / (2 * length_scale ** 2))

def gen_cov_matrix(dim, kernel, length_scale):
    """Build the dim x dim covariance matrix over the integer inputs 0, 1, ..., dim - 1.

    Entry (i, j) is ``kernel(i, j, length_scale)``; the result is always a float array.
    """
    matrix = np.zeros((dim, dim))
    for row in range(dim):
        # Fill one whole row at a time, evaluating the kernel in row-major order.
        matrix[row, :] = [kernel(row, col, length_scale) for col in range(dim)]
    return matrix

# Prior covariance over the integer grid 0..DIM-1.
cov_matrix = gen_cov_matrix(dim=DIM, kernel=squared_exp_kernel, length_scale=LENGTH_SCALE)

To sample from a multivariate normal distribution conditioned on its values being known in some of its dimensions, we partition the covariance matrix $\Sigma$ into blocks for the test (prediction) points and the training (observed) points. We obtain:

$$\Sigma = \begin{bmatrix} \Sigma_{11} & \Sigma_{12} \\ \Sigma_{21} & \Sigma_{22} \end{bmatrix}$$

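$\Sigma_{11}$ is the covariance matrix between test points. $\Sigma_{12}$ and $\Sigma_{21}$ are the covariances between test and training points and between training and test points, respectively (for a symmetric kernel, $\Sigma_{21} = \Sigma_{12}^\top$). $\Sigma_{22}$ is the covariance matrix between training points.

Given the observed training values $\mathbf{y}_2$ and a zero prior mean, the test values are distributed as $\mathcal{N}(\boldsymbol{\mu}_{1|2}, \Sigma_{1|2})$ with

$$\boldsymbol{\mu}_{1|2} = \Sigma_{12}\Sigma_{22}^{-1}\mathbf{y}_2, \qquad \Sigma_{1|2} = \Sigma_{11} - \Sigma_{12}\Sigma_{22}^{-1}\Sigma_{21}.$$

In the code, the pseudo-inverse of $\Sigma_{22}$ is used in place of $\Sigma_{22}^{-1}$ for numerical robustness.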

In [13]:
# Randomly keep DIM - PRED_DIM grid indices as observed ("training") points;
# the remaining PRED_DIM indices are the prediction points.
total_idx = np.arange(DIM)
train_idx = np.sort(np.random.choice(total_idx, DIM - PRED_DIM, replace=False))
# setdiff1d returns the complement as a sorted integer array. Iterating a Python set
# has no guaranteed order, and an empty set would yield a float array that cannot index.
pred_idx = np.setdiff1d(total_idx, train_idx)

train_data = data_generator(train_idx)
test_data = data_generator(pred_idx)

# Partition the prior covariance: index 1 = prediction points, index 2 = training points.
# np.ix_ selects each sub-block in one step instead of two chained fancy-index copies.
sigma_11 = cov_matrix[np.ix_(pred_idx, pred_idx)]
sigma_12 = cov_matrix[np.ix_(pred_idx, train_idx)]
sigma_21 = cov_matrix[np.ix_(train_idx, pred_idx)]
sigma_22 = cov_matrix[np.ix_(train_idx, train_idx)]

class GaussianProcess():
    """Gaussian-process regressor (no observation-noise term) over a user-supplied kernel.

    Prediction conditions the joint zero-mean Gaussian prior over
    (prediction inputs, training inputs) on the training targets:

        mean = Sigma_12 Sigma_22^+ y
        cov  = Sigma_11 - Sigma_12 Sigma_22^+ Sigma_21

    Index 1 refers to the prediction inputs and index 2 to the training inputs.
    ``^+`` is the Moore-Penrose pseudo-inverse.

    Parameters
    ----------
    kernel : callable
        ``kernel(a, b, length_scale)`` returning the covariance between two scalar
        inputs. Assumed symmetric: kernel(a, b) == kernel(b, a).
    length_scale : float
        Passed unchanged as the third argument of ``kernel``.
    """

    def __init__(self, kernel, length_scale):
        self.kernel = kernel
        self.length_scale = length_scale
        # Populated by fit(); None means the model has not been fitted yet.
        self.sigma_22 = None
        self.sigma_22_inv = None
        self.X_train = None
        self.y_train = None

    def get_cov_matrix(self, X_1, X_2):
        """Return the (len(X_1), len(X_2)) matrix with entries kernel(X_1[i], X_2[j])."""
        X_1 = np.asarray(X_1)
        X_2 = np.asarray(X_2)
        # Size each axis from its own input: prediction and training sets generally
        # differ in length, so a single shared dimension is wrong.
        cov_matrix = np.zeros((len(X_1), len(X_2)))
        for i, a in enumerate(X_1):
            for j, b in enumerate(X_2):
                cov_matrix[i, j] = self.kernel(a, b, self.length_scale)
        return cov_matrix

    def fit(self, X, y):
        """Store the training data and precompute the pseudo-inverse of its covariance.

        Parameters
        ----------
        X : array_like, shape (n,)
            Training inputs.
        y : array_like, shape (n,)
            Training targets.

        Raises
        ------
        ValueError
            If ``X`` and ``y`` have different lengths.
        """
        X = np.asarray(X)
        y = np.asarray(y, dtype=float)
        if len(X) != len(y):
            raise ValueError("X and y must have the same length, got %d and %d" % (len(X), len(y)))
        self.X_train = X
        self.y_train = y
        self.sigma_22 = self.get_cov_matrix(X, X)
        # pinv rather than inv tolerates the near-singular matrices produced by
        # closely spaced inputs under a smooth kernel.
        self.sigma_22_inv = np.linalg.pinv(self.sigma_22)

    def predict(self, X, n_samples=30):
        """Return the posterior mean, posterior std, and joint posterior draws at ``X``.

        Parameters
        ----------
        X : array_like, shape (m,)
            Prediction inputs.
        n_samples : int, default 30
            Number of joint posterior samples to draw.

        Returns
        -------
        predicted_means : numpy.ndarray, shape (m, 1)
        pred_std : numpy.ndarray, shape (m,)
            Square root of the diagonal of the posterior covariance.
        pred_samples : list of numpy.ndarray, each of shape (m,)

        Raises
        ------
        RuntimeError
            If called before fit().
        """
        if self.sigma_22_inv is None:
            raise RuntimeError("GaussianProcess.predict() called before fit()")

        sigma_11 = self.get_cov_matrix(X, X)
        sigma_12 = self.get_cov_matrix(X, self.X_train)
        # Symmetric kernel: the train/test block is the transpose of test/train.
        sigma_21 = sigma_12.T

        # Shared by the mean and the covariance; compute it once.
        sigma_12_sigma_22_inv = np.matmul(sigma_12, self.sigma_22_inv)

        predicted_means = np.matmul(sigma_12_sigma_22_inv, self.y_train.reshape(-1, 1))
        pred_cov_matrix = sigma_11 - np.matmul(sigma_12_sigma_22_inv, sigma_21)

        # Exact marginal std instead of a noisy estimate from a handful of draws.
        # Clip round-off negatives (the variance is ~0 at training inputs).
        pred_std = np.sqrt(np.clip(np.diag(pred_cov_matrix), 0.0, None))

        mean_vector = predicted_means.reshape(-1)
        pred_samples = [
            np.random.multivariate_normal(mean_vector, pred_cov_matrix)
            for _ in range(n_samples)
        ]

        return predicted_means, pred_std, pred_samples

In [14]:
# GP with the squared-exponential kernel at the notebook's configured length scale.
gp = GaussianProcess(kernel=squared_exp_kernel, length_scale=LENGTH_SCALE)

In [15]:
# Evenly spaced evaluation grid on [0, 10] with noisy targets.
x = np.linspace(0, 10, 100)
y = data_generator(x)

# 100 uniformly random observation locations on [0, 10], sorted in place, with noisy targets.
x_test = np.random.uniform(0, 10, 100)
x_test.sort()
y_test = data_generator(x_test)

In [16]:
# Condition the GP on the random observations (x_test, y_test).
gp.fit(x_test, y_test)

# Posterior mean, std and samples on the evenly spaced grid `x` ...
train_preds, train_std, train_pred_samples = gp.predict(x)
# ... and at the observation points themselves.
test_preds, test_std, test_pred_samples = gp.predict(x_test)



In [17]:
fig, ax = plt.subplots()

test_mean = test_preds.reshape(-1)
ax.plot(x_test, test_mean, label="Predicted curve")
# Pointwise 95% interval from the posterior standard deviation.
ax.fill_between(
    x_test,
    test_mean - 1.96 * test_std,
    test_mean + 1.96 * test_std,
    alpha=0.3,
    label="95% interval",
)

# Dense grid for drawing the noiseless target function smoothly.
x_dense = np.linspace(0, 10, 1000)
ax.plot(x_dense, data_generator(x_dense, False), label="f(x)", linestyle="--", color="red")
ax.scatter(x_test, y_test, s=4, c="black", label="Ground truths")

# Unlabelled posterior draws; matplotlib leaves them out of the legend.
for sample in test_pred_samples:
    ax.plot(x_test, sample, alpha=0.2)

ax.set(title="GP posterior at the fitted points", xlabel="x", ylabel="y")
ax.legend()
plt.show()