### How are column and row vectors handled in Matlab versus Torch

In [1]:
import torch
from utils import *

# Variables initialized in Matlab. 
# Copying these lines in matlab (making tensors -> arrays and making vectors m -> COLUMN vectors)

a_mat = torch.tensor([[1.0,2.0,3.0,4],[5.0,6.0,7.0,8],[9.0,10.0,11.0,12]])

Sigma = torch.tensor([[1.0,2.0,3.0,],[5.0,6.0,7.0],[8.0,7.0,6.0,]]) # shape (ntilde, ntilde)
dSigma = torch.tensor([[4.0,5.0,6.0],[5.0,6.0,7.0],[6.0,7.0,8.0]]) # shape (ntilde, ntilde)

Sigma = Sigma + Sigma.T
dSigma = dSigma + dSigma.T

m = torch.tensor([1.0,2.0,3.0]) # shape ntilde

dki = torch.tensor([[6.0,7.0,8.0,9.0],[2.0,3.0,4.0,5.0],[9.0,8.0,7.0,6]]) # shape ntilde, nt
ki =dki # ntilde, nt

dkstar = torch.tensor([3,4,5,6])

invV = Sigma*3 # ntilde ntilde

# These are the variables defined in funtion lFunc of matthews code to calculate gradients with respect to mean and variance of the avg lambda_tilde 
# They are in the lfunc (loglikelihood) function cause they are used for the gradient of the loglikelihood

# In the following 3 lines I just copied the code in matlab and translated
da_matlab =  torch.linalg.pinv(Sigma)@( dki - dSigma@a_mat) 
dlambda_m_matlab = da_matlab.T@m
dlambda_var_matlab = dkstar + torch.sum(-dki*a_mat - ki*da_matlab + 2*da_matlab*torch.linalg.solve(invV, a_mat ) , 0 ).T


# In torch vectors and matrices are often transposed, starting from a, which in my code corresponds to K@K_tilde_inv like in the notes

a = a_mat.T # shape ( nt, ntilde ), its KKtildeinv

K_tilde = Sigma
dK_tilde = dSigma

dK = dki.T
K = ki.T

dK_vec = dkstar
a@dK_tilde

K_tilde_inv = torch.linalg.pinv(K_tilde)

V = torch.linalg.inv(invV)

da_torch = (dK - a@dK_tilde)@K_tilde_inv # TODO check if it can be made more efficient puling dK out of the parenthesis


dlambda_m_torch = da_torch@m
dlambda_var_torch = dK_vec + torch.einsum( 'ij,ji->i', 2*da_torch, V@a.T) - torch.einsum( 'ij,ij->i', dK,a ) - torch.einsum( 'ij,ij->i', K, da_torch )

# If tranlation is ok the following shoudl be equal. Of course da is transposed in between the two codes. The other two are just vectors

print( f' da_matlab:\n {da_matlab} \n da_torch.T:\n {da_torch.T} \n')
print( f' dlambda_m_torch:\n {dlambda_m_torch} \n dlambda_m_matlab:\n {dlambda_m_matlab} \n')
print( f' dlambda_var_torch:\n {dlambda_var_torch} \n dlambda_var_matlab:\n {dlambda_var_matlab} \n')


Using device: cuda:0 (from utils.py)
 da_matlab:
 tensor([[ 176.5000,  219.8333,  263.1667,  306.5000],
        [-251.2500, -313.0833, -374.9167, -436.7500],
        [ 113.2500,  142.0833,  170.9167,  199.7500]]) 
 da_torch.T:
 tensor([[ 176.5000,  219.8333,  263.1667,  306.5000],
        [-251.2500, -313.0833, -374.9167, -436.7500],
        [ 113.2500,  142.0833,  170.9167,  199.7500]]) 

 dlambda_m_torch:
 tensor([13.7500, 19.9167, 26.0833, 32.2500]) 
 dlambda_m_matlab:
 tensor([13.7500, 19.9167, 26.0833, 32.2500]) 

 dlambda_var_torch:
 tensor([-2199.1944, -3161.6944, -4291.7500, -5589.3611]) 
 dlambda_var_matlab:
 tensor([-2199.1944, -3161.6944, -4291.7500, -5589.3611]) 



In [16]:
# Gradients of the Loglikelihood

lambda_m = torch.tensor([1.0,2.0,3.0,4.0]) # shape nt

# in Matlab the code is

# dLtheta = -0.5*dVstar'*f + dmstar'*(r'-f); where r is a row vector and f is a column vector, thats why r is trasposed

# In the matlab script the formula is

f = torch.tensor([55.0,4.0,22.0,5.0]) # shape nt
r = torch.tensor([23.0,47.0,2.0,1.0]) # shape nt
Lkhd_matlab = r@lambda_m - torch.sum(f);
dlogLK_matlab = -0.5*dlambda_var_matlab.T@f + dlambda_m_matlab.T@(r-f)

# while in the torch script the formula is

f_mean = f

A=1
lambda0=0 # there are no f parameters on matlab so I have to put A = 1 and lambda0 = 0
logLK_torch     = A*r@lambda_m   + lambda0*torch.sum(r) - torch.sum(f_mean)
dlogLK_torch = r@dlambda_m_torch - A*torch.dot(f_mean, dlambda_m_torch) - 0.5*A*A*torch.dot(f_mean, dlambda_var_torch)

print( f' logLK_torch:\n {logLK_torch} \n Lkhd_matlab:\n {Lkhd_matlab} \n')
print( f' dlogLK_matlab:\n {dlogLK_matlab} \n dlogLK_torch:\n {dlogLK_torch} \n')



 logLK_torch:
 41.0 
 Lkhd_matlab:
 41.0 

 dlogLK_matlab:
 127749.63888888988 
 dlogLK_torch:
 127749.63888888987 



In [10]:
# Gradients of the KL divergence
# In matlab
C_matlab = V@torch.linalg.pinv(Sigma)
b_matlab = torch.linalg.pinv(Sigma)@m

B_matlab = dSigma@torch.linalg.pinv(Sigma) # this is written as B = dSigma(:, :, i)/Sigma; in matlab

# Adding elements to the diagonal only to make this example work with non positive definite matrices
KLD_matlab = 0.5*log_det(Sigma + 10*torch.eye(Sigma.shape[0]))  - 0.5*log_det(V+ 10*torch.eye(V.shape[0])) + 0.5*torch.sum(torch.linalg.eig(C_matlab)[0])+ 0.5*m.T@b_matlab

dKL_matlab = 0.5*torch.trace(B_matlab) - 0.5*torch.trace(C_matlab@B_matlab)  - 0.5*b_matlab.T@B_matlab@m


# In torch:
c_torch = V @ K_tilde_inv # shape (ntilde, ntilde)
b_torch = K_tilde_inv @ m

B_torch = dK_tilde@K_tilde_inv # Shape (ntilde, ntilde

KLD_torch = -0.5*(log_det(V + 10*torch.eye(V.shape[0])) - log_det(K_tilde+ 10*torch.eye(K_tilde.shape[0]))) + 0.5*torch.matmul(m.T, b_torch) + 0.5*torch.trace(c_torch) 
dKL_torch = 0.5*torch.trace(B_torch) - 0.5*torch.trace(c_torch@B_torch) - 0.5*b_torch.T@(B_torch@m)

print(f' C_matlab:\n {C_matlab} \n c_torch:\n {c_torch} \n b_matlab:\n {b_matlab} \n b_torch:\n {b_torch} \n')
print(f' B_matlab:\n {B_matlab} \n B_torch:\n {B_torch} \n')
print(f' KLD_matlab:\n {KLD_matlab} \n KLD_torch:\n {KLD_torch} \n')
print(f' dKL_matlab:\n {dKL_matlab} \n dKL_torch:\n {dKL_torch} \n')


# It looks like gradients are correct

 C_matlab:
 tensor([[ 20.2778, -28.0000,  14.0000],
        [-28.0000,  38.6806, -19.3472],
        [ 14.0000, -19.3472,   9.6806]]) 
 c_torch:
 tensor([[ 20.2778, -28.0000,  14.0000],
        [-28.0000,  38.6806, -19.3472],
        [ 14.0000, -19.3472,   9.6806]]) 
 b_matlab:
 tensor([-1.1667,  1.9167, -0.9167]) 
 b_torch:
 tensor([-1.1667,  1.9167, -0.9167]) 

 B_matlab:
 tensor([[-10.3333,  14.8333,  -6.8333],
        [-13.0000,  18.5000,  -8.5000],
        [-15.6667,  22.1667, -10.1667]]) 
 B_torch:
 tensor([[-10.3333,  14.8333,  -6.8333],
        [-13.0000,  18.5000,  -8.5000],
        [-15.6667,  22.1667, -10.1667]]) 

 KLD_matlab:
 (34.919132834503294+0j) 
 KLD_torch:
 34.91913283450331 

 dKL_matlab:
 110.47222222222383 
 dKL_torch:
 110.47222222222383 

