In [8]:
import torch
torch.autograd.set_detect_anomaly(True)
torch.pi = torch.acos(torch.zeros(1)) * 2
print(torch.pi)
import numpy as np
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt

tensor([3.1416])


In [67]:
class LLS():
    def __init__(self, X, y, n_local):
        self.input_dim = X.shape[1]
        self.n_local = n_local
        self.X = X
        self.y = y
        
        # Decide latent locations
        m = KMeans(self.n_local)
        self.X_bar = torch.tensor(m.fit(self.X).cluster_centers_, dtype=torch.float32)
        
        # Initialize parameters
        self.init(0)
        
    def init(self, seed):
        # Initialize learnable parameters
        torch.manual_seed(seed)
        
        self.sigma_f = torch.distributions.Uniform(0.9, 1.1).sample((1,))
        self.sigma_n = torch.distributions.Uniform(0.01, 0.1).sample((1,))
        self.sigma_l_bar = torch.distributions.Uniform(0.9, 1.1).sample((self.input_dim,))
        self.l_bar = torch.distributions.Uniform(0.9, 1.1).sample((self.n_local, self.input_dim))
        
        # Turn on gradient
        self.sigma_f.requires_grad = True
        self.sigma_l_bar.requires_grad = True
        self.l_bar.requires_grad = True
        self.sigma_n.requires_grad = True
        
    def local_kernel(self, x1, x2, i): # tau is scaled distance
        d = x1 - x2.T
        d.where(d == 0, torch.as_tensor(1e-20))
        tau = torch.square((d)/torch.sqrt(torch.tensor(2))/self.sigma_l_bar[i])
        
        return torch.exp(-tau)
    
    def mll(self):
        K = 1.
        mll2 = 1.
        for i in range(self.X.shape[1]):
            x = self.X[:,i].view(-1,1)
            x_bar = self.X_bar[:,i].view(-1,1)
            
            k_star = self.local_kernel(x, x_bar, i)
            k = self.local_kernel(x_bar, x_bar, i)
            k = k + 10**-5*torch.eye(k.shape[0])
#             print(k, torch.det(k))
#             lk = torch.cholesky(k)
            mll2 = mll2 + (0.5*(torch.log(torch.det(k)) + k.shape[0]*torch.log(2*torch.pi)))
            
            k_inv = torch.inverse(k)
            ls = torch.exp(k_star@k_inv@torch.log(self.l_bar[:,i].view(-1,1)))
            
            numer = ls@ls.T
            denom = torch.square(ls + ls.T)/2
            
            tau = torch.square(x-x.T)/denom
            
            K = K*(torch.sqrt(numer/denom) * torch.exp(-tau))
        
#         print(self.sigma_n, self.sigma_f, self.sigma_l_bar, self.l_bar)
#         print(K[:5,:1])
        K = self.sigma_f**2*(K + (torch.square(self.sigma_n)*torch.eye(K.shape[0])))
        
        L = torch.cholesky(K)
        alpha = torch.cholesky_solve(self.y, L)
#         print('alpha', alpha.shape)
        mll1 = 0.5*(self.y.T@alpha + torch.sum(torch.log(L.diag())) + K.shape[0]*torch.log(2*torch.pi))
        return (mll1 + mll2)[0,0]
    
    def fit(self, epochs=100):
        opt = torch.optim.LBFGS
        optimizer = opt([self.sigma_f, self.sigma_n, self.sigma_l_bar, self.l_bar], 0.01, max_iter=epochs)
        losses = []
        def closure():
            optimizer.zero_grad()
            loss = self.mll()
#             print(self.sigma_f, self.sigma_n)
            losses.append(loss.item())
            loss.backward()
            return loss
        optimizer.step(closure)
        plt.plot(losses)
    def predict(self, X_new):
        with torch.no_grad():
            K = 1.
            K_ = 1.
            K__ = 1.
            mll2 = 1.
            for i in range(self.X.shape[1]):
                #######################################################################
                x = self.X[:,i].view(-1,1)
                x_bar = self.X_bar[:,i].view(-1,1)

                k_star = self.local_kernel(x, x_bar, i)
                k = self.local_kernel(x_bar, x_bar, i)
                k += torch.eye(k.shape[0])*10**-5

                k_inv = torch.inverse(k)
                ls = torch.exp(k_star@k_inv@torch.log(self.l_bar[:,i].view(-1,1)))

                numer = ls@ls.T
                denom = torch.square(ls + ls.T)/2

                tau = torch.square(x-x.T)/denom

                K *= (torch.sqrt(numer/denom) * torch.exp(-tau))
                #######################################################################
                x_ = X_new[:,i].view(-1,1)
                k_star = self.local_kernel(x_, x_bar, i)

                ls_ = torch.exp(k_star@k_inv@torch.log(self.l_bar[:,i].view(-1,1)))

                numer = ls_@ls.T
                denom = torch.square(ls_ + ls.T)/2

                tau = torch.square(x_-x.T)/denom

                K_ *= (torch.sqrt(numer/denom) * torch.exp(-tau))
                #######################################################################
                numer = ls_@ls_.T
                denom = torch.square(ls_ + ls_.T)/2

                tau = torch.square(x_-x_.T)/denom

                K__ *= (torch.sqrt(numer/denom) * torch.exp(-tau))

            K += (torch.eye(K.shape[0])*self.sigma_n**2)
            K_inv = torch.inverse(K)
            K__ += (torch.eye(K__.shape[0])*self.sigma_n**2)

            mean = K_@K_inv@self.y
            var = K__ - K_@K_inv@K_.T
            
            self.ls_ = ls_
            self.ls = ls
            
            return mean, var

In [68]:
device = 'cuda'
np.random.seed(0)
N = 500
X = np.sort(np.random.normal(size=(N,1)), axis=0)
y = np.sin(5*X) + np.random.rand(N,1)
# X_new = np.random.rand(15,2)
model = LLS(X, y, 3)
model.fit(200)

plt.figure()
plt.scatter(X, y)


mean, var = model.predict(torch.tensor(X, dtype=torch.float32))

std2 = torch.sqrt(var.diagonal())*2

plt.scatter(X, mean, label='pred')
plt.fill_between(X.ravel(), mean.ravel()-std2, mean.ravel()+std2, alpha=.5)
plt.legend();

ValueError: Type must be a sub-type of ndarray type