In [1]:
import numpy as np
import os
import matplotlib.pyplot as plt
import scipy.io
import time
from scipy.optimize import minimize
from tqdm import tqdm

In [2]:
import torch
# torch.set_default_dtype(torch.float64)

# Use the first CUDA GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

# Last expression of the cell: shows whether CUDA was detected.
torch.cuda.is_available()

cuda:0


True

In [3]:
def display_progress(i, N, tstart):
    """Print loop progress and a rough estimate of the time still needed.

    Parameters
    ----------
    i : int
        Current iteration index.
    N : int
        Total number of iterations.
    tstart : float
        Value of ``time.time()`` recorded when the loop started.

    Returns
    -------
    T : float
        Seconds elapsed since ``tstart``.
    Trem : float
        Estimated seconds remaining, computed as ``T*(N - i)/(i + 1)``.
    """
    elapsed = time.time() - tstart
    done = i + 1
    remaining = elapsed*(N - done + 1)/done

    percent = (i)/N*100
    print(f'percent complete {percent:.0f}%')
    # The timing line is skipped on the very first iteration (i == 0).
    if i > 0:
        elapsed_min = elapsed/60
        remaining_min = remaining/60
        print(f'{elapsed_min:.1f} mins passed, {remaining_min:.1f} mins remaining')
    return elapsed, remaining

# Note: Python has no tic/toc, so elapsed time is measured with time.time() instead. MATLAB's nargin check
# is not needed in Python and has been removed, and fprintf has been replaced by print() with f-strings.

In [4]:
def Estep(m, f, r, A, K, a):
    """Perform one update of the variational mean ``m`` and covariance ``V``.

    Parameters
    ----------
    m : ndarray, shape (ntilde, 1)
        Current variational mean.
    f : ndarray, shape (n, 1)
        Current mean firing rate for every data point.
    r : ndarray, shape (1, n)
        Observed responses.
    A : float
        Gain applied to the latent function.
    K : ndarray, shape (ntilde, ntilde)
        Prior covariance of the inducing points.
    a : ndarray, shape (ntilde, n)
        Projection weights; the caller passes ``solve(K, k)``.

    Returns
    -------
    m : ndarray, shape (ntilde, 1)
        Updated mean.
    V : ndarray, shape (ntilde, ntilde)
        Updated covariance, symmetrised with a tiny diagonal jitter.
    """
    identity = np.eye(K.shape[0])

    # First- and second-order terms of the update.
    grad = (A*a) @ (r.T - f)          # shape (ntilde, 1)
    curv = (A**2*a) @ (a.T*f)         # shape (ntilde, ntilde)

    # V = (I + K G)^{-1} K
    cov = np.linalg.solve(identity + K @ curv, K)
    # Remove round-off asymmetry and add jitter so V stays positive definite.
    cov = (cov + cov.T)/2 + 1e-15*identity

    mean = cov @ (curv @ m + grad)    # shape (ntilde, 1)
    return mean, cov

In [5]:
def localsmoothkern(theta, n):
    """Build the smooth, spatially localised covariance ``C`` and its derivatives.

    Parameters
    ----------
    theta : sequence
        Hyperparameters: ``theta[0]`` = log(rho) (smoothness),
        ``theta[1]``, ``theta[2]`` = centre (x0, y0),
        ``theta[3]`` = log(beta) (inverse width of the localised region).
    n : int
        Number of stimulus dimensions; the grid has ``int(sqrt(n))`` points per
        side, spanning [-1, 1] in both directions.

    Returns
    -------
    C : ndarray, shape (nsel, nsel)
        Covariance restricted to the selected grid points.
    idx : ndarray of bool
        Mask over the flattened grid marking points whose localisation weight
        exceeds 0.001.
    dC : ndarray, shape (nsel, nsel, len(theta))
        Derivative of ``C`` with respect to each entry of ``theta``.
    """
    log_rho, x0, y0, log_beta = theta[0], theta[1], theta[2], theta[3]
    rho = np.exp(log_rho)      # how smooth is the RF?
    beta = np.exp(log_beta)    # how wide is the RF?

    # Square grid of coordinates in [-1, 1] x [-1, 1], flattened.
    side = int(np.sqrt(n))
    ticks = np.linspace(-1, 1, side)
    ycord, xcord = np.meshgrid(ticks, ticks)
    xcord = xcord.flatten()
    ycord = ycord.flatten()

    # Localisation envelope (log and linear scale).
    logf = -0.5*beta*((xcord - x0)**2 + (ycord - y0)**2)
    f = np.exp(logf)

    # Keep only the stimulus dimensions close to the centre.
    idx = f > 0.001
    f, logf = f[idx], logf[idx]
    xcord, ycord = xcord[idx], ycord[idx]

    # Smoothness term: squared-exponential in the pairwise grid distance.
    x_col, x_row = xcord[:, np.newaxis], xcord[np.newaxis, :]
    y_col, y_row = ycord[:, np.newaxis], ycord[np.newaxis, :]
    logK0 = -0.5*rho*((xcord - x_col)**2 + (ycord - y_col)**2)
    K0 = np.exp(logK0)

    # Combine the smoothness and localisation terms.
    C = f[:, np.newaxis]*K0*f[np.newaxis, :]

    # Derivatives, one slice per hyperparameter (computed before symmetrising C).
    slices = (
        logK0*C,
        beta*C*(x_col + x_row - 2*x0),
        beta*C*(y_col + y_row - 2*y0),
        C*(logf[:, np.newaxis] + logf[np.newaxis, :]),
    )
    dC = np.zeros((C.shape[0], C.shape[1], len(theta)))
    for p, term in enumerate(slices):
        dC[:, :, p] = term

    # Remove round-off asymmetry.
    C = (C + C.T)/2
    return C, idx, dC


In [6]:
def updateA(A, lambda0, mstar0, Vstar0, r, nit=1000, Alpha=0.25):
    """Fit the gain ``A`` and offset ``lambda0`` by damped Newton iterations.

    The objective minimised is
    ``sum(exp(A*mstar0 + 0.5*Vstar0*A**2 + lambda0)) - r @ (A*mstar0 + lambda0)``.

    Parameters
    ----------
    A, lambda0 : float
        Starting values.
    mstar0 : ndarray, shape (n, 1)
        Mean of the latent function at every data point.
    Vstar0 : ndarray, shape (n, 1)
        Variance of the latent function at every data point.
    r : ndarray, shape (1, n)
        Observed responses.
    nit : int or float
        Maximum number of Newton iterations (converted with ``int``; at least
        one iteration is always performed).
    Alpha : float
        Step-size multiplier applied to every Newton step.

    Returns
    -------
    A, lambda0 : float
        Updated values.
    L : ndarray, shape (niter, 1)
        Objective value at the start of each iteration actually performed.
    """
    # The original loop only stopped once count > nit, i.e. it tried to run
    # nit + 1 iterations and wrote past the end of L when it did not converge.
    nit = max(int(nit), 1)
    psi = [lambda0, A]
    L = np.zeros((nit, 1))

    for count in range(1, nit + 1):
        f = np.exp(psi[1]*mstar0 + 0.5*Vstar0*psi[1]**2 + psi[0])   # shape (n, 1)
        L[count-1] = np.sum(f) - r@(mstar0*psi[1] + psi[0])

        # d(exponent)/dA for every data point.
        dlambda_A = mstar0 + Vstar0*psi[1]                           # shape (n, 1)
        f_dA = (f.T@dlambda_A)[0, 0]

        # Gradient and Hessian with respect to (lambda0, A).
        g = np.array([np.sum(f - r.T), (f.T@dlambda_A - r@mstar0)[0, 0]])
        H = np.array([[np.sum(f), f_dA],
                      [f_dA, (f.T@Vstar0 + (dlambda_A*f).T@dlambda_A)[0, 0]]])

        psi = psi - Alpha*np.linalg.solve(H, g)

        # Convergence is judged on the gradient evaluated before the step.
        if np.sum(abs(g)) < 1e-6:
            break

    A = psi[1]
    lambda0 = psi[0]
    L = L[0:count]
    return A, lambda0, L

In [7]:
def acoskern(theta, x1, x2, C=None, dC=None, diag=False):
    """
    Arc-cosine covariance function (NumPy version).

    Parameters
    ----------
    theta : array-like
        hyperparameters; only ``theta[0]`` (log of the bias scale sigma_b) is
        read here, the remaining entries parameterise ``C``
    x1 : array-like
        [nx, n1] matrix, first input
    x2 : array-like
        [nx, n2] matrix, second input (ignored when ``diag`` is true)
    C : array-like, optional
        [nx, nx] matrix, smooth local covariance matrix (identity if omitted)
    dC : array-like, optional
        [nx, nx, ntheta], derivative of C with respect to its hyperparameters
    diag : bool, optional
        if true, return only k(x, x) for every column of ``x1``

    Returns
    -------
    K : ndarray
        [n1, n2] kernel, or [n1, 1] when ``diag`` is true
    dK : ndarray
        [n1, n2, ntheta + 1] (or [n1, 1, ntheta + 1]) derivative of the kernel;
        slice 0 is the derivative with respect to ``theta[0]``, slice ``j`` the
        one induced by ``dC[:, :, j-1]``. Only returned when ``dC`` is given.
    """
    if diag is None:
        diag = False

    if C is None:
        C = np.eye(x1.shape[0])

    n1 = x1.shape[1]
    sigmab = np.exp(theta[0])

    if not diag:
        n2 = x2.shape[1]

        # Norms under C (plus bias): sqrt(x' C x + sigma_b^2) for every column.
        X1 = np.sqrt(np.sum(x1*(C @ x1), axis=0) + sigmab**2)   # shape (n1,)
        X2 = np.sqrt(np.sum(x2*(C @ x2), axis=0) + sigmab**2)   # shape (n2,)
        X1X2 = np.outer(X1, X2)                                 # shape (n1, n2)
        x1x2 = x1.T @ C @ x2 + sigmab**2

        # Cosine of the angle between inputs, clipped for numerical safety.
        arg = np.clip(x1x2/(X1X2 + 1e-9), -1, 1)
        angle = np.arccos(arg)

        J = (np.sqrt(1 - arg**2) + np.pi*arg - angle*arg)/np.pi  # shape (n1, n2)
        K = X1X2*J

        if dC is not None:
            dK = np.zeros((n1, n2, dC.shape[2] + 1))
            # dJ/darg, shared by every derivative below.
            dJ_darg = -(angle - np.pi)/np.pi

            # Derivative with respect to theta[0] = log(sigma_b).
            dX1X2 = sigmab**2*(X2/X1[:, np.newaxis] + X1[:, np.newaxis]/X2)
            darg = (2*sigmab**2 - arg*dX1X2)/X1X2
            dK[:, :, 0] = X1X2*(dJ_darg*darg) + dX1X2*J

            # Derivatives with respect to the hyperparameters of C.
            for j in range(1, dC.shape[2] + 1):
                dCj = dC[:, :, j-1]
                dX1 = 0.5*np.sum(x1*np.dot(dCj, x1), axis=0)/X1   # shape (n1,)
                dX2 = 0.5*np.sum(x2*np.dot(dCj, x2), axis=0)/X2   # shape (n2,)
                dX1X2 = dX1[:, np.newaxis]*X2 + X1[:, np.newaxis]*dX2
                darg = (np.dot(x1.T, np.dot(dCj, x2)) - arg*dX1X2)/X1X2
                dK[:, :, j] = X1X2*(dJ_darg*darg) + dX1X2*J

        # Symmetrise (and add jitter) only for a true auto-covariance. The old
        # test `n1 == n2` also averaged K with its transpose for cross-kernels
        # that merely had the same number of columns, which corrupts them.
        same_inputs = x1 is x2 or (x1.shape == x2.shape and np.array_equal(x1, x2))
        if same_inputs:
            K = (K + K.T)/2 + 1e-15*np.eye(n1)

    else:
        # Diagonal of the covariance only: k(x, x) = x' C x + sigma_b^2.
        K = np.sum(x1*np.dot(C, x1), axis=0)[:, np.newaxis] + sigmab**2

        if dC is not None:
            dK = np.zeros((n1, 1, dC.shape[2] + 1))
            dK[:, :, 0] = 2*sigmab**2*np.ones((n1, 1))
            for j in range(1, dC.shape[2] + 1):
                dK[:, :, j] = np.sum(x1*np.dot(dC[:, :, j-1], x1), axis=0)[:, np.newaxis]

            # NOTE(review): np.eye(n1, 1) is a column with a single 1 in its
            # first row, so this jitter touches only K[0]; kept as in the
            # original -- confirm whether a jitter on every entry was intended.
            K += 1e-15*np.eye(n1, 1)

    if dC is not None:
        return K, dK
    return K

In [8]:
#compute lsta and connected functions
def compute_f(m, V, theta, A, lambda0, xtilde, x, kernfun=None):
    """Predict the mean and variance of the firing rate for stimuli ``x``.

    Parameters
    ----------
    m : ndarray, shape (ntilde, 1)
        Variational mean at the inducing points.
    V : ndarray, shape (ntilde, ntilde)
        Variational covariance at the inducing points.
    theta : array-like
        Hyperparameters; ``theta[0]`` is used by the kernel, ``theta[1:]`` is
        passed to ``localsmoothkern``.
    A, lambda0 : float
        Gain and offset of the log firing rate.
    xtilde : ndarray, shape (nx, ntilde)
        Inducing points.
    x : ndarray, shape (nx, n)
        Stimuli to evaluate.
    kernfun : callable, optional
        Kernel function; ``acoskern`` when omitted.

    Returns
    -------
    fmean : ndarray, shape (1, n)
        Mean firing rate ``exp(mstar + Vstar/2)``.
    sigma2_f : ndarray, shape (1, n)
        Variance of the firing rate, ``(exp(Vstar) - 1) * fmean**2``.
    """
    if kernfun is None:
        kernfun = acoskern

    nx = x.shape[0]

    # Local smooth covariance and the mask of stimulus dimensions it covers.
    C, idx, _ = localsmoothkern(theta[1:], nx)
    x_sel = x[idx, :]
    xt_sel = xtilde[idx, :]

    # Covariance terms.
    K = kernfun(theta, xt_sel, xt_sel, C)             # shape (ntilde, ntilde)
    k = kernfun(theta, xt_sel, x_sel, C)              # shape (ntilde, n)
    kstar = kernfun(theta, x_sel, x_sel, C, None, True)   # shape (n, 1)

    a = np.linalg.solve(K, k)                         # shape (ntilde, n)

    # Mean and variance of the log firing rate.
    mstar = A*a.T@m + lambda0                         # shape (n, 1)
    # `a*V@a` parsed as `(a*V)@a`, which only coincides with the intended
    # column-wise a' V a when x has a single column; parenthesise explicitly.
    Vstar = A**2*(kstar + np.sum(-k*a + a*(V@a), axis=0)[:, np.newaxis])   # shape (n, 1)

    # Log-normal moments: Vstar is already a variance, so the variance of
    # exp(.) is (exp(Vstar) - 1)*mean**2 (the original squared Vstar again).
    fmean = np.exp(mstar + 0.5*Vstar).T               # shape (1, n)
    sigma2_f = (np.exp(Vstar) - 1).T*fmean**2         # shape (1, n)
    return fmean, sigma2_f


In [9]:
def varGP(x, r, Nestep=50, Nmstep=20, Display=2, MaxIter=200, ntilde=250, *kwargs, **overrides):
    """Learn a Gaussian-process (GP) model of neural responses.

    Parameters
    ----------
    x : ndarray, shape (nx, n)
        Stimuli, one column per data point.
    r : ndarray, shape (1, n)
        Spike counts.
    Nestep : int
        Number of iterations in each E-step (updating m and V).
    Nmstep : int
        Maximum number of optimiser iterations in each M-step (updating theta).
    Display : int
        If greater than 1, plot and save the learning curves at the end.
    MaxIter : int
        Number of E-step / M-step rounds.
    ntilde : int
        Number of inducing points (capped at ``n``).
    *kwargs
        Kept for backward compatibility; forwarded to ``dict.update``.
    **overrides
        Optional settings, by keyword:
        ``theta`` (initial hyperparameters), ``A``, ``lambda0``,
        ``lb`` / ``ub`` (bounds for theta), ``kernfun`` (kernel used in the
        E-step), ``m`` / ``V`` (initial variational mean / covariance),
        ``xtilde`` (inducing points), ``figpath`` (where the figure is saved).

    Returns
    -------
    theta_track : ndarray, shape (ntheta, MaxIter)
    A_track : ndarray, shape (MaxIter, 1)
    lambda0_track : ndarray, shape (MaxIter, 1)
        Hyperparameters after every round.
    m, V : ndarray
        Mean and covariance of the variational distribution q = N(m, V).
    xtilde : ndarray
        Inducing points.
    L : ndarray, shape (MaxIter, 1)
        Objective value reported by the optimiser after every M-step.

    Raises
    ------
    TypeError
        If an unknown option name is passed by keyword.
    """
    # optional arguments (described above)
    opts = {
            'theta': [ 0, 5, 0, 0, 5.5], # sigma_b, log(rho), eps_0_x, eps_0_y, log(beta)
            'A':1e-4,
            'lambda0':-1,
            'lb': [ -float("inf"), -float("inf"), -1, -1, 4 ],
            'ub': [float("inf"), float("inf"), 1, 1, float("inf")],
            'kernfun': acoskern,
            'm': [],
            'V': [],
            'xtilde': [],
            'figpath': r'/media/samuele/Samuele_02/GP from MAtlab to python/Python/trial_3.png'}
    opts.update(kwargs)
    # `*kwargs` can never receive keyword arguments, so the documented options
    # were unreachable; accept them through **overrides instead.
    unknown = set(overrides) - set(opts)
    if unknown:
        raise TypeError(f'varGP got unexpected option(s): {sorted(unknown)}')
    opts.update(overrides)

    def _unset(value):
        # `not value` raises for multi-element ndarrays, so test emptiness explicitly.
        return value is None or np.size(value) == 0

    nx, n = x.shape     # number of stimulus dimensions, number of data points

    # inducing points (chosen at random from data points, slightly jittered)
    if _unset(opts['xtilde']):
        ntilde = min(n, ntilde)
        xtilde = x[:, np.random.permutation(n)[:ntilde]]+1e-6*np.random.randn(nx, ntilde)
    else:
        xtilde = np.asarray(opts['xtilde'])
        ntilde = xtilde.shape[1]

    theta = np.asarray(opts['theta'], dtype=float)
    A = opts['A']
    lambda0 = opts['lambda0']

    # initialize mean and covariance of q to the prior
    if _unset(opts['m']):
        m = np.zeros((ntilde, 1))
    else:
        m = np.asarray(opts['m'], dtype=float)

    if _unset(opts['V']):
        # C is computed only for the pixels around the centre in theta
        C, idx, _ = localsmoothkern(theta[1:], nx)
        V = opts['kernfun'](theta, xtilde[idx, :], xtilde[idx, :], C=C)
    else:
        V = np.asarray(opts['V'], dtype=float)

    # track loss and parameters
    L = np.zeros((MaxIter, 1))
    theta_track = np.zeros((theta.size, MaxIter))
    A_track = np.zeros((MaxIter, 1))
    lambda0_track = np.zeros((MaxIter, 1))

    # options and bounds for the optimisation in the M-step
    options = { 'maxiter': Nmstep}
    bnds = tuple(zip(opts['lb'], opts['ub']))

    for Iteration in tqdm(range(MaxIter)):

        # ---------------- E-step (update q = N(m, V)) ----------------

        # local smooth covariance function
        C, idx, _ = localsmoothkern(theta[1:], nx)

        # pre-compute kernel terms
        K = opts['kernfun'](theta, xtilde[idx, :], xtilde[idx, :], C=C)   # shape (ntilde, ntilde)
        k = opts['kernfun'](theta, xtilde[idx, :], x[idx, :], C=C)        # shape (ntilde, n)
        kstar = opts['kernfun'](theta, x[idx, :], [], C=C, diag=True)     # shape (n, 1)
        a = np.linalg.solve(K, k)                                         # shape (ntilde, n)

        for _ in range(Nestep):
            mstar0 = a.T@m                                                # mean of lambda(x), shape (n, 1)
            Vstar0 = kstar + np.sum(-k*a + a*(V@a), 0)[:, np.newaxis]     # variance of lambda(x), shape (n, 1)
            A, lambda0, _ = updateA(A, lambda0, mstar0, Vstar0, r, 1e4)

            f = np.exp(A*mstar0+ 0.5*Vstar0*A**2 + lambda0)               # mean firing rate

            m, V = Estep(m, f, r, A, K, a)

        # ---------------- M-step (update hyperparameters theta) ----------------
        # L-BFGS-B from scipy.optimize replaces MATLAB's minConf_TMP.
        # NOTE(review): the M-step always uses the torch kernel `acoskern_1`
        # (logmarginal converts its inputs to tensors), not opts['kernfun'].
        result = minimize(lambda th: logmarginal(m, V, th, A, lambda0, xtilde, x, r, acoskern_1),
                          theta, args=(), method='L-BFGS-B', jac=True, bounds=bnds, options=options)

        theta = result.x
        L[Iteration] = result.fun
        theta_track[:, Iteration] = theta
        A_track[Iteration] = A
        lambda0_track[Iteration] = lambda0

    if Display > 1:
        plt.figure(figsize=(10,10))
        plt.subplot(2, 2, 1)
        plt.plot(-L[1:])
        plt.title('loss function')
        plt.subplot(2, 2, 2)
        for hyp in range(theta_track.shape[0]):
            plt.plot(theta_track[hyp, :]- theta_track[hyp, 0])
        plt.title('hyperparameters')
        plt.subplot(2, 2, 3)
        plt.plot(A_track[1:])
        plt.title('A')
        plt.subplot(2, 2, 4)
        plt.plot(lambda0_track[1:])
        plt.title('lambda0')
        figpath = opts['figpath']
        try:
            plt.savefig(figpath)
        except OSError as err:
            # A missing or unwritable folder must not throw away the fitted model.
            print(f'could not save figure to {figpath}: {err}')
        plt.close()
        print('Done!')

    return theta_track, A_track, lambda0_track, m, V, xtilde, L

In [12]:
#varGP and connected functions
def log_det(M):
    """Return log(det(M)) of a symmetric matrix as a float64 tensor.

    The Cholesky factorisation is tried first. If it fails (``M`` is not
    numerically positive definite), the result falls back to the sum of the
    logs of the eigenvalues larger than 1e-6.

    Parameters
    ----------
    M : torch.Tensor, shape (k, k)
        Symmetric matrix.

    Returns
    -------
    torch.Tensor
        0-d float64 tensor.
    """
    # `.to` is not in-place: the original discarded the converted tensor.
    M = M.to(torch.float64)
    try:
        chol = torch.linalg.cholesky(M, upper=True)
    except RuntimeError:
        # torch.linalg.LinAlgError derives from RuntimeError. The original
        # fallback indexed the (eigenvalues, eigenvectors) tuple returned by
        # torch.linalg.eig and therefore raised instead of recovering.
        eigvals = torch.linalg.eigvalsh(M, UPLO='U')
        eigvals = eigvals[eigvals > 1e-6]
        return torch.sum(torch.log(eigvals))
    return 2*torch.sum(torch.log(torch.diag(chol)))

def logmarginal(m, V, theta, A, lambda0, xtilde, x, r, kernfun):
    """Evaluate the negative log-evidence and its gradient with respect to theta.

    All inputs are NumPy objects; they are moved to the global ``device`` as
    torch tensors, the loss is assembled there and the results are returned as
    NumPy float64 values (the ``(fun, grad)`` pair expected by
    ``scipy.optimize.minimize`` with ``jac=True``).

    Parameters
    ----------
    m, V : ndarray
        Mean and covariance of the variational distribution.
    theta : ndarray
        Hyperparameters; ``theta[1:]`` parameterises ``localsmoothkern``.
    A, lambda0 : scalar
        Gain and offset of the log firing rate.
    xtilde, x : ndarray
        Inducing points and stimuli, one column per point.
    r : ndarray
        Observed responses.
    kernfun : callable
        Torch kernel returning ``(K, dK)`` when ``dC`` is supplied.

    Returns
    -------
    L : numpy float64 scalar
        KL divergence minus log-likelihood.
    dL : ndarray of float64
        Gradient of ``L``.
    """
    def _to_device(array):
        # NumPy array -> tensor on the global device (dtype preserved).
        return torch.from_numpy(array).to(device)

    n_dims = x.shape[0]

    # Kernel ingredients are built in NumPy from the NumPy theta.
    C_np, idx, dC_np = localsmoothkern(theta[1:], n_dims)

    C = _to_device(C_np)
    dC = _to_device(dC_np)
    m = torch.from_numpy(m).to(torch.float64).to(device)
    V = torch.from_numpy(V).to(torch.float64).to(device)
    A = _to_device(np.array(A))
    lambda0 = _to_device(np.array(lambda0))
    r = _to_device(r.astype(np.float64))
    theta = _to_device(theta)
    xtilde = _to_device(xtilde)
    x = _to_device(x)

    # Restrict inducing points and stimuli to the selected dimensions.
    xtilde_sel = xtilde[idx, :]
    Sigma, dSigma = kernfun(theta, xtilde_sel, xtilde[idx, :], C, dC)
    ki, dki = kernfun(theta, xtilde_sel, x[idx, :], C, dC)
    kstar, dkstar = kernfun(theta, x[idx, :], [], C, dC, True)

    # Log-likelihood and KL divergence, each with its gradient.
    lkhd, dlkhd = lfun(m, V, A, lambda0, r, Sigma, ki, kstar, dSigma, dki, dkstar)
    Dkl, dDkl = KLdiv(m, V, Sigma, dSigma)

    grad = dDkl - dlkhd
    loss = Dkl - lkhd        # shape (1, 1)

    loss_np = loss[0,0].detach().cpu().numpy().astype(np.float64)
    grad_np = grad.detach().cpu().numpy().astype(np.float64)
    return loss_np, grad_np

def lfun(m, V, A, lambda0, r, Sigma, ki, kstar, dSigma, dki, dkstar):
    """Log-likelihood term and its gradient with respect to the kernel parameters.

    Parameters
    ----------
    m : Tensor, shape (ntilde, 1)
        Variational mean.
    V : Tensor, shape (ntilde, ntilde)
        Variational covariance (the gradient formula treats it as symmetric).
    A, lambda0 : scalar or 0-d Tensor
        Gain and offset of the log firing rate.
    r : Tensor, shape (1, n)
        Observed responses.
    Sigma : Tensor, shape (ntilde, ntilde)
        Kernel between inducing points.
    ki : Tensor, shape (ntilde, n)
        Kernel between inducing points and data.
    kstar : Tensor, shape (n, 1)
        Kernel diagonal at the data points.
    dSigma, dki, dkstar : Tensor
        Derivatives of the three kernel terms; the last axis indexes parameters.

    Returns
    -------
    Lkhd : Tensor, shape (1, 1)
        ``r @ mstar - sum(f)``.
    dLtheta : Tensor, shape (nparams,)
        Gradient of ``Lkhd``, float64 on the CPU.
    """
    # The inverse does not depend on the parameter index: compute it once.
    Sigma_inv = torch.linalg.inv(Sigma)
    a = Sigma_inv @ ki                  # shape (ntilde, n)
    Va = V @ a

    # Mean and variance of the log firing rate.
    mstar = A*a.T @ m + lambda0         # shape (n, 1)
    Vstar = A**2*(kstar + torch.sum(-ki * a + a * Va, axis=0)[:, None])   # shape (n, 1)

    # Mean firing rate and log-likelihood.
    f = torch.exp(mstar + 0.5 * Vstar)
    Lkhd = r @ mstar - torch.sum(f)
    resid = r.T - f

    nparams = dSigma.shape[2]
    # float64 buffer: the default float32 one silently truncated the gradient.
    dLtheta = torch.zeros(nparams, dtype=torch.float64)
    for i in range(nparams):
        da = Sigma_inv @ (dki[:, :, i] - dSigma[:, :, i] @ a)             # shape of a

        dVstar = A**2*(dkstar[:, :, i]
                       + torch.sum(-dki[:, :, i] * a - ki * da + 2 * da * Va, axis=0)[:, None])

        dmstar = A*da.T @ m                                               # shape of mstar

        dLtheta[i] = (-0.5 * dVstar.T @ f + dmstar.T @ resid).squeeze()

    return Lkhd, dLtheta


def KLdiv(m, V, Sigma, dSigma):
    """KL-divergence term between q = N(m, V) and the prior N(0, Sigma).

    The value returned is
    ``0.5*logdet(Sigma) - 0.5*logdet(V) + 0.5*trace(V Sigma^-1) + 0.5*m' Sigma^-1 m``.

    Parameters
    ----------
    m : Tensor, shape (ntilde, 1)
    V : Tensor, shape (ntilde, ntilde)
    Sigma : Tensor, shape (ntilde, ntilde)
    dSigma : Tensor, shape (ntilde, ntilde, nparams)
        Derivative of ``Sigma`` with respect to each parameter.

    Returns
    -------
    Dkl : Tensor, shape (1, 1)
    dDtheta : Tensor, shape (nparams,)
        Gradient of ``Dkl``, float64 on the CPU.
    """
    # The inverse does not depend on the parameter index: compute it once.
    Sigma_inv = torch.linalg.inv(Sigma)
    C = V @ Sigma_inv            # shape (ntilde, ntilde)
    b = Sigma_inv @ m            # shape (ntilde, 1)

    Dkl = 0.5*log_det(Sigma) - 0.5*log_det(V) + 0.5*torch.trace(C) + 0.5*m.T@b

    nparams = dSigma.shape[2]
    # float64 buffer: the default float32 one silently truncated the gradient.
    dDtheta = torch.zeros(nparams, dtype=torch.float64)
    for i in range(nparams):
        B = dSigma[:, :, i] @ Sigma_inv
        dDtheta[i] = (0.5*torch.trace(B) - 0.5*torch.trace(C@B) - 0.5*b.T@(B@m)).squeeze()

    return Dkl, dDtheta

In [10]:
def acoskern_1(theta, x1, x2, C=None, dC=None, diag=False):
    """
    Arc-cosine covariance function (PyTorch version).

    Parameters
    ----------
    theta : Tensor
        hyperparameters; only ``theta[0]`` (log of the bias scale sigma_b) is
        read here, the remaining entries parameterise ``C``
    x1 : Tensor
        [nx, n1] matrix, first input
    x2 : Tensor
        [nx, n2] matrix, second input (ignored when ``diag`` is true)
    C : Tensor, optional
        [nx, nx] matrix, smooth local covariance matrix (identity if omitted)
    dC : Tensor, optional
        [nx, nx, ntheta], derivative of C with respect to its hyperparameters
    diag : bool, optional
        if true, return only k(x, x) for every column of ``x1``

    Returns
    -------
    K : Tensor (float64)
        [n1, n2] kernel, or [n1, 1] when ``diag`` is true
    dK : Tensor (float64)
        [n1, n2, ntheta + 1] (or [n1, 1, ntheta + 1]) derivative of the kernel;
        slice 0 is the derivative with respect to ``theta[0]``, slice ``j`` the
        one induced by ``dC[:, :, j-1]``. Only returned when ``dC`` is given.

    All intermediate tensors are created with the dtype and device of the
    inputs instead of the notebook-level ``device`` and the default float32.
    """
    if diag is None:
        diag = False

    if C is None:
        C = torch.eye(x1.shape[0], dtype=x1.dtype, device=x1.device)

    n1 = x1.shape[1]
    sigmab = torch.exp(theta[0])

    if not diag:
        n2 = x2.shape[1]

        # Norms under C (plus bias): sqrt(x' C x + sigma_b^2) for every column.
        X1 = torch.sqrt(torch.sum(x1*(C @ x1), axis=0) + sigmab**2)   # shape (n1,)
        X2 = torch.sqrt(torch.sum(x2*(C @ x2), axis=0) + sigmab**2)   # shape (n2,)
        X1X2 = torch.outer(X1, X2)                                    # shape (n1, n2)
        x1x2 = x1.T @ C @ x2 + sigmab**2

        # Cosine of the angle between inputs, clipped for numerical safety.
        arg = torch.clip(x1x2/(X1X2 + 1e-9), -1, 1)
        angle = torch.arccos(arg)

        J = (torch.sqrt(1 - arg**2) + torch.pi*arg - angle*arg)/torch.pi   # shape (n1, n2)
        K = X1X2*J

        if dC is not None:
            # Same dtype as K: a default float32 buffer truncated the derivatives.
            dK = torch.zeros((n1, n2, dC.shape[2] + 1), dtype=K.dtype, device=K.device)
            # dJ/darg, shared by every derivative below.
            dJ_darg = -(angle - torch.pi)/torch.pi

            # Derivative with respect to theta[0] = log(sigma_b).
            dX1X2 = sigmab**2*(X2/X1[:, None] + X1[:, None]/X2)
            darg = (2*sigmab**2 - arg*dX1X2)/X1X2
            dK[:, :, 0] = X1X2*(dJ_darg*darg) + dX1X2*J

            # Derivatives with respect to the hyperparameters of C.
            for j in range(1, dC.shape[2] + 1):
                dCj = dC[:, :, j-1]
                dX1 = 0.5*torch.sum(x1*(dCj @ x1), axis=0)/X1   # shape (n1,)
                dX2 = 0.5*torch.sum(x2*(dCj @ x2), axis=0)/X2   # shape (n2,)
                dX1X2 = dX1[:, None]*X2 + X1[:, None]*dX2
                darg = (x1.T @ (dCj @ x2) - arg*dX1X2)/X1X2
                dK[:, :, j] = X1X2*(dJ_darg*darg) + dX1X2*J

        # Symmetrise (and add jitter) only for a true auto-covariance. The old
        # test `n1 == n2` also averaged K with its transpose for cross-kernels
        # that merely had the same number of columns, which corrupts them.
        same_inputs = x1 is x2 or (
            x1.shape == x2.shape and x1.dtype == x2.dtype and torch.equal(x1, x2))
        if same_inputs:
            K = (K + K.T)/2 + 1e-15*torch.eye(n1, dtype=K.dtype, device=K.device)

    else:
        # Diagonal of the covariance only: k(x, x) = x' C x + sigma_b^2.
        K = torch.sum(x1*(C @ x1), axis=0)[:, None] + sigmab**2

        if dC is not None:
            dK = torch.zeros((n1, 1, dC.shape[2] + 1), dtype=K.dtype, device=K.device)
            dK[:, :, 0] = 2*sigmab**2*torch.ones((n1, 1), dtype=K.dtype, device=K.device)
            for j in range(1, dC.shape[2] + 1):
                dK[:, :, j] = torch.sum(x1*(dC[:, :, j-1] @ x1), axis=0)[:, None]

            # NOTE(review): torch.eye(n1, 1) is a column with a single 1 in its
            # first row, so this jitter touches only K[0]; kept as in the
            # original -- confirm whether a jitter on every entry was intended.
            K += 1e-15*torch.eye(n1, 1, dtype=K.dtype, device=K.device)

    if dC is not None:
        return K.to(torch.float64), dK.to(torch.float64)
    return K.to(torch.float64)

In [14]:
# main script
#
# Either fit the GP model for one cell and save the result (save_param == 1),
# or load a previously saved fit and compute a local STA by finite differences
# (save_param == 0).

# Import of the single-cell database (per the original notes: 3190 images of
# 108x108 pixels and the 3190 responses to them).
# NOTE(review): absolute, machine-specific path -- consider a configurable data directory.
data = scipy.io.loadmat(r'/media/samuele/Samuele_02/GP from MAtlab to python/data_cell21.mat')
r = data['r']  #shape (1, 3190)
X = data['X']  #shape (108,108,3190)

Xvec = np.reshape(X, (-1, X.shape[2])) # shape(11664, 3190) each image is considered as a 1D vector

# learn hyperparams
# 1: run varGP and save the result; 0: load saved parameters and compute the local STA
save_param = 1

if save_param==1:
    results={}
    # varGP returns the per-iteration tracks of theta, A and lambda0, then m, V, xtilde and L
    theta, A, lambda0, m, V, xtilde, L = varGP(Xvec, r, ntilde=250,MaxIter=80, Nmstep=20, Nestep=50, Display=2)
    
    results['theta']=theta
    results['L']=L
    results['m']=m
    results['V']=V
    results['xtilde']=xtilde
    results['A']=A
    results['lambda0']=lambda0
    
    # np.save pickles the dict; it is read back with allow_pickle=True and .item()
    np.save('saved_params_cell21_gp-v2.npy', results)
    
elif save_param ==0:
    # NOTE(review): this loads 'saved_params_cell21.npy', whereas the branch above
    # writes 'saved_params_cell21_gp-v2.npy' -- confirm which file is intended.
    # allow_pickle=True executes pickled content: only load files you created yourself.
    data = np.load('saved_params_cell21.npy', allow_pickle=True).item()
    theta = data['theta']
    L = data['L']
    m = data['m']
    V = data['V']
    xtilde = data['xtilde']
    # keep the values from the last iteration of the saved tracks
    A=data['A'][-1]
    lambda0=data['lambda0'][-1]
    
    # compute local sta
    # NOTE(review): this analysis only runs in the load branch; it is skipped
    # right after a fresh fit (save_param == 1) -- confirm that is intended.
    X0 = Xvec[:, 31][:, np.newaxis]  #shape(11664,1)

    # indices of the stimulus dimensions selected by the final hyperparameters
    _, idx,_ = localsmoothkern(theta[1:,-1], X0.shape[0])
    idx = np.where(idx)[0]

    lsta = np.zeros((X0.shape[0],1))

    # predicted firing rate for the unperturbed stimulus
    f,sf  = compute_f(m, V, theta[:,-1], A, lambda0, xtilde, X0, kernfun=acoskern)
    # print(f,sf)
    # forward finite difference: perturb one selected pixel at a time by eta
    # NOTE(review): if X is stored with an integer dtype, adding eta to Xdash is
    # lost on assignment -- confirm X is floating point.
    eta = 1e-5
    for i in range(idx.size):         
        Xdash = np.copy(X0)
        Xdash[idx[i]] = X0[idx[i]]+eta
        fdash,_  = compute_f(m, V, theta[:,-1], A, lambda0, xtilde, Xdash, kernfun=acoskern)
        lsta[idx[i]] = (fdash-f)/eta

    #plot_lsta
    # left: the stimulus; right: the local STA (108x108 image size is hard-coded)
    plt.figure()
    plt.subplot(1,2,1)
    plt.imshow(X0.reshape(108,108))
    plt.subplot(1,2,2)
    plt.imshow(lsta[:,0].reshape(108,108))

100%|███████████████████████████████████████████| 80/80 [12:38<00:00,  9.48s/it]


Done!
