# Implementation of the paper

In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt

# Data generation

In [29]:
#we write our custom matricization function
def matricization(A, mode):

    if mode == 1:
        return torch.cat([A.select(2, i) for i in range(A.size(2))], dim=1)
    
    elif mode == 2:
        return torch.cat([A.select(2, i).T for i in range(A.size(2))], dim=1)
    
    elif mode == 3:
        return torch.vstack([A.select(2, i).T.reshape(-1) for i in range(A.size(2))])


    """ p1, p2, p3 = A.shape
    if mode == 1:
        res = torch.zeros(p1, p2 * p3)
        for i in range(p1):
            for j in range(p2):
                for k in range(p3):
                    res[i, k * p2 + j] = A[i, j, k]
    elif mode == 2:
        res = torch.zeros(p2, p1 * p3)
        for i in range(p1):
            for j in range(p2):
                for k in range(p3):
                    res[j, k * p1 + i] = A[i, j, k]
    elif mode == 3:
        res = torch.zeros(p3, p1 * p2)
        for i in range(p1):
            for j in range(p2):
                for k in range(p3):
                    res[k, j * p1 + i] = A[i, j, k]
    return res """

In [30]:
#test matricization with the examples in the Kolda and Bader paper
A = torch.zeros(3, 4, 2)
A[:, :, 0] = torch.arange(1, 13).reshape(4, 3).transpose(0, 1)
A[:, :, 1] = torch.arange(13, 25).reshape(4, 3).transpose(0, 1)

print(A)
A1 = matricization(A, 1)
A2 = matricization(A, 2)
A3 = matricization(A, 3)

print('Matricization along mode-1:\n', A1)
print('Matricization along mode-2:\n', A2)
print('Matricization along mode-3:\n', A3)

tensor([[[ 1., 13.],
         [ 4., 16.],
         [ 7., 19.],
         [10., 22.]],

        [[ 2., 14.],
         [ 5., 17.],
         [ 8., 20.],
         [11., 23.]],

        [[ 3., 15.],
         [ 6., 18.],
         [ 9., 21.],
         [12., 24.]]])
Matricization along mode-1:
 tensor([[ 1.,  4.,  7., 10., 13., 16., 19., 22.],
        [ 2.,  5.,  8., 11., 14., 17., 20., 23.],
        [ 3.,  6.,  9., 12., 15., 18., 21., 24.]])
Matricization along mode-2:
 tensor([[ 1.,  2.,  3., 13., 14., 15.],
        [ 4.,  5.,  6., 16., 17., 18.],
        [ 7.,  8.,  9., 19., 20., 21.],
        [10., 11., 12., 22., 23., 24.]])
Matricization along mode-3:
 tensor([[ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12.],
        [13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.]])


# we also write the inverse operation, only from A(1) to A

In [125]:
def inv_matricization(A, p1, p2, p3):
    res = torch.zeros(p1, p2, p3)
    for i in range(p1):
        for j in range(p2):
            for k in range(p3):
                res[i, j, k] = A[i, k * p2 + j]
    return res

In [126]:
def generate_A(N, P, diagonal, r1, r2, r3):
    G = torch.zeros(r1, r2, r3)
    G[0, 0, 0] = diagonal[0]
    G[1, 1, 1] = diagonal[1]
    G[2, 2, 2] = diagonal[2]

    #we generate matrices until we have a stable model, according to assumption 1 in the paper
    while True:
        O1 = torch.randn(N, r1)
        O2 = torch.randn(N, r2)
        O3 = torch.randn(P, r3)

        #generate Ui as the top ri singular vectors of the matrix
        U1 = torch.linalg.svd(O1)[0][:, :r1]
        U2 = torch.linalg.svd(O2)[0][:, :r2]
        U3 = torch.linalg.svd(O3)[0][:, :r3]
        

        #build A, given its tucker decomposition above
        #mode 1 product of U1 and G
        A1 = torch.einsum('ij, jkl -> ikl', U1, G)
        #print(A1.shape)
        #mode 2 product of A1 and U2
        A2 = torch.einsum('ij, kjl -> kil', U2, A1)
        #print(A2.shape)
        #mode 3 product of A2 and U3
        A = torch.einsum('ij, klj -> kli', U3, A2)
        #print(A.shape)

        """ A1 = torch.tensordot(U1, G, dims=[[1], [0]])
        #print(A1.shape)
        A2 = torch.tensordot(U2, A1, dims=[[1], [1]])
        #print(A2.shape)
        A = torch.tensordot(A2, U3, dims=[[2], [1]])
        #print(A.shape) """

        #check if the model is stable by 
        # Computing the companion matrix
        C = torch.zeros(N*P, N*P)
        for i in range(P):
            C[:N, i*N:(i+1)*N] = A[:, :, i]
        for i in range(P-1):
            C[(i+1)*N:(i+2)*N, i*N:(i+1)*N] = torch.eye(N)
        
        # Compute the eigenvalues of the companion matrix
        eigenvalues = torch.linalg.eigvals(C)
        # Check if all eigenvalues are inside the unit circle
        if torch.all(torch.abs(eigenvalues) < 1):
            #print('The model is stable')
            return A

In [127]:
def var_generate(A, T=100):

    #A is a (N, N, P) tensor
    N, P = A.shape[0], A.shape[2]

    #let's matricize A to use the model in (3)
    Ac = matricization(A, 1)
    
    #we return x, a (T, NP) tensor
    x = torch.zeros(T, N*P)
    x[0, :] = torch.randn(N*P)

    #we also return y, a (T, N) tensor
    y = torch.zeros(T, N)
    for t in range(T):
        y[t, :] = Ac @ x[t] + torch.randn(N)
        if t < T-1:
            x[t+1, :N] = y[t, :]
            x[t+1, N:] = x[t, :-N]
    
    return x, y

In [128]:
#data generation
N, P = 10, 5
r1, r2, r3 = 3, 3, 3
diagonal_list = [(2, 2, 2), (4, 3, 2), (1, 1, 1), (2, 1, 0.5)]

#let's construct the transitions matrices as built in the paper
A_list = []
for diagonal in diagonal_list:
    A_list.append(generate_A(N, P, diagonal, r1, r2, r3))

# Initial estimator

In [132]:
def initial_estimator(x, y, lambd=0.1, max_iter=100):
    T,N = y.shape
    P = x.shape[1]//N
    #define the loss function as in the paper
    def loss(A):
        #A is of shape (N, NP)
        return torch.sum((y.T - A @ x.T)**2) + lambd * torch.norm(A, p='nuc')
    
    #we use the gradient descent algorithm to minimize the loss function
    A = torch.randn(N, N*P, requires_grad=True)
    optimizer = torch.optim.Adam([A], lr=0.1)

    #We use SGD to compute the argmin of the loss function
    for i in range(max_iter):
        l = loss(A)
        l.backward()
        optimizer.step()
        optimizer.zero_grad()
    
    return A.detach()

In [137]:
#test the initial estimator on the first transition matrix
x, y = var_generate(A_list[0], T=100)
A = initial_estimator(x, y, lambd=1, max_iter=1000)

#print the Frobenius norm of the error between the initial estimator and the true transition matrix
print('The Frobenius norm of the error between the initial estimator and the true transition matrix is:\n', torch.norm(A - matricization(A_list[0], 1)))


The Frobenius norm of the error between the initial estimator and the true transition matrix is:
 tensor(3.4645)


# Rank selection consistency

In [110]:
def multilinear_ranks_estimator(A, c):
    ranks = [0, 0, 0]
    for mode in range(3):
        try:
            eigvals = torch.linalg.svdvals(matricization(A, mode+1))
            ranks[mode] = torch.argmin((eigvals[1:]+c)/(eigvals[:-1]+c)).item() + 1
        except torch.linalg.LinAlgError as e:
            ranks[mode] = 3

    return ranks

In [53]:
#we want to estimate the consistency of the multilinear ranks estimator
#we will use the ridge-type ratio estimator, as in the paper
#Theorem 3 in the paper says that the estimator is consistent
#so as T increases, the probability of being correct tends to 1

n_replications = 10

T_list = np.arange(50, 60)
proportions = np.zeros((4, T_list.shape[0]))


for i, T in enumerate(T_list):
    c = np.sqrt(N*P*np.log(T)/(10*T))
    for j, A in enumerate(A_list):
        for n in range(n_replications):
            #we generate the data
            x, y = var_generate(A, T=T)
            #we estimate the transition matrix
            A_hat = initial_estimator(x, y, lambd=1, max_iter=1000)
            
            tensor = inv_matricization(A_hat, N, N, P)
            #we estimate the multilinear ranks
            ranks = multilinear_ranks_estimator(A, c)
            #we check if the estimator is correct
            if ranks == [3, 3, 3]:
                proportions[j, i] += 1/n_replications


#plot the proportion of correct estimations
plt.plot(T_list, proportions[0, :], label='(2, 2, 2)')
plt.plot(T_list, proportions[1, :], label='(4, 3, 2)')
plt.plot(T_list, proportions[2, :], label='(1, 1, 1)')
plt.plot(T_list, proportions[3, :], label='(2, 1, 0.5)')
plt.legend()
plt.xlabel('T')
plt.ylabel('Proportion of correct estimations')
plt.show()

KeyboardInterrupt: 

# Algorithm for the estimator A_MLR

In [138]:
def alternating_squares_MLR(y,x,X,A_0,ranks,max_iter=100):

    T = y.shape[0]
    N = y.shape[1]
    P = x.shape[1] // N
    A = A_0 #shape (N, N, P)

    #get the ranks
    r1, r2, r3 = ranks[0], ranks[1], ranks[2]
    print(r1, r2, r3)

    #compute the HOSVD of A

    A1 = matricization(A, 1)
    A2 = matricization(A, 2)
    A3 = matricization(A, 3)
    U1, S1, V1 = torch.linalg.svd(A1)
    U2, S2, V2 = torch.linalg.svd(A2)
    U3, S3, V3 = torch.linalg.svd(A3)

    #compute the core tensor
    G = torch.zeros(r1, r2, r3)
    for i in range(r1):
        for j in range(r2):
            for k in range(r3):
                G[i, j, k] = S1[i] * S2[j] * S3[k]

    #compute the matricization of G
    G1 = matricization(G, 1)

    #compute the factor matrices
    U1 = U1[:, :r1]
    U2 = U2[:, :r2]
    U3 = U3[:, :r3]

    
    for k in range(max_iter):

        #update U1
        Y = torch.zeros(N*r1)
        A = torch.zeros(N*r1, N*r1)
        H = torch.kron(U3.contiguous(), U2.contiguous()) @ G1.T
        for t in range(T):
            xt = x[t]
            At = torch.kron(xt.T @ H , torch.eye(N))
            Y += At.T @ y[t]
            A += At.T @ At
        if torch.linalg.det(A) == 0:
            A += 1e-2 * torch.eye(N*r1)
        U1 = torch.linalg.solve(A, Y).reshape(N, r1)

        #update U2
        Y = torch.zeros(N*r2)
        A = torch.zeros(N*r2, N*r2)
        H = U1 @ G1
        for t in range(T):
            Xt = X[t]
            At = H @ torch.kron((Xt @ U3).T.contiguous(), torch.eye(r2))
            Y += At.T @ y[t]
            A += At.T @ At
        if torch.linalg.det(A) == 0:
            A += 1e-2 * torch.eye(N*r2)
        U2 = torch.linalg.solve(A, Y).reshape(r2, N).T

        #update U3
        Y = torch.zeros(P*r3)
        A = torch.zeros(P*r3, P*r3)
        H = U1 @ G1
        for t in range(T):
            Xt = X[t]
            At = H @ torch.kron(torch.eye(r3), (U2.T @ Xt))
            Y += At.T @ y[t]
            A += At.T @ At
        if torch.linalg.det(A) == 0:
            A += 1e-2 * torch.eye(P*r3)
        U3 = torch.linalg.solve(A, Y).reshape(P, r3)

        #update G1
        Y = torch.zeros(r1*r2*r3)
        A = torch.zeros(r1*r2*r3, r1*r2*r3)
        for t in range(T):
            xt = x[t]
            At = torch.kron((torch.kron(U3.contiguous(), U2.contiguous()).T @ xt).T.contiguous(), U1.contiguous())
            Y += At.T @ y[t]
            A += At.T @ At
        if torch.linalg.det(A) == 0:
            A += 1e-2 * torch.eye(r1*r2*r3)
        G1 = torch.linalg.solve(A, Y).reshape(r1, r2 * r3)

    #return A, whose tucker decomposition is U1, U2, U3, G
    G = inv_matricization(G1, r1, r2, r3)

    #print(G.shape, G1.shape, U1.shape, U2.shape, U3.shape)
    #mode 1 product of U1 and G
    A1 = torch.einsum('ij, jkl -> ikl', U1, G)
    #print(A1.shape)
    #mode 2 product of A1 and U2
    A2 = torch.einsum('ij, kjl -> kil', U2, A1)
    #print(A2.shape)
    #mode 3 product of A2 and U3
    A = torch.einsum('ij, klj -> kli', U3, A2)
    #print(A.shape)
    
    return A

In [140]:
#trying the alternating squares algorithm
#we will use the same data as before
T = 100
x, y = var_generate(A_list[0], T)
X = x.reshape(T, N, P)

A_0 = initial_estimator(x, y, lambd=1, max_iter=1000)
A_0 = inv_matricization(A_0, N, N, P)
c = np.sqrt(N*P*np.log(T)/(10*T))
ranks = multilinear_ranks_estimator(A_0, c)
A = alternating_squares_MLR(y, x, X, A_0, ranks, max_iter=20)

#print the Frobenius norm of the difference between A and true_A
true_A = A_list[0]
#print(A)
print(torch.norm(A - true_A))

3 3 3
tensor(3.7943)


# Second algorithm

In [None]:
def admm_subroutine(y, X, B, N, r, T, kappa=1, lambd=1, max_iter1=100, max_iter2=100, sgd_steps=10):
    #y is a NT x 1 vector
    #X is a NT x NPT matrix
    
    #initialization

    W = B.clone()
    M = torch.zeros(N * r)
    for k in range(max_iter1):

        #update B
        #using the SOC method, cited in the paper
        I = B.clone()
        J = B.clone()
        K = B.clone()
        def loss(I):
                return (1/T)*torch.norm(y - X @ I)**2 + kappa * torch.norm(I-W+M)**2 + 0.5*torch.norm(I-J+K)**2

        for j in range(max_iter2):

            #update I
            #we do gradient descent for this one
            #Adam optimizer
            optimizer = torch.optim.Adam([I], lr=0.01)
            for _ in range(sgd_steps):
                optimizer.zero_grad()
                loss(I).backward()
                optimizer.step()
            
            #update J
            Y = I + K
            U, D, V = torch.linalg.svd(Y)
            J = U @ torch.eye(N, r) @ V.T

            #update K
            K = K + I - J
        B = I


        #update W
        #we have to solve a least squares problem with a l1 penalty
        #we use ista

        def soft_threshold(x, lambd):
            return torch.sign(x) * torch.clamp(torch.abs(x) - lambd, min=0)

        def ista(b, lambd, max_iter=100):
            x = torch.zeros(b.shape)
            L = b.shape[0]
            for i in range(max_iter):
                x = soft_threshold(x + (b - x) / L, lmbda / L)
            return x

        W = ista(B+M, lambd/kappa, max_iter=100)


        #update M
        M = M + B - W

    return B

# Algorithm for the estimator A_SHORR

In [None]:
def ADMM_SHORR(y,x,X,A_0,ranks,max_iter=100, lambd=1, rho1=1, rho2=1, rho3=1):

    T = y.shape[0]
    N = y.shape[1]
    P = x.shape[1] // N
    A = A_0 #shape (N, N, P)

    #get the ranks
    r1, r2, r3 = ranks[0], ranks[1], ranks[2]
    print(r1, r2, r3)

    #compute the HOSVD of A

    A1 = matricization(A, 1)
    A2 = matricization(A, 2)
    A3 = matricization(A, 3)
    U1, S1, V1 = torch.linalg.svd(A1)
    U2, S2, V2 = torch.linalg.svd(A2)
    U3, S3, V3 = torch.linalg.svd(A3)

    #compute the core tensor
    G = torch.zeros(r1, r2, r3)
    for i in range(r1):
        for j in range(r2):
            for k in range(r3):
                G[i, j, k] = S1[i] * S2[j] * S3[k]

    #compute the matricizations of G
    G_mats = [matricization(G, 1), matricization(G, 2), matricization(G, 3)]


    #compute the factor matrices
    U1 = U1[:, :r1]
    U2 = U2[:, :r2]
    U3 = U3[:, :r3]



    
    for k in range(max_iter):

        #update U1

        H = torch.kron(U3.contiguous(), U2.contiguous()) @ G_mats[0].T
        X1 = torch.zeros(N*T, N*r1)
        for t in range(T):
            xt = x[t]
            At = torch.kron(xt.T @ H , torch.eye(N))
            X1[t*N:(t+1)*N, :] = At
        U1 = admm_subroutine(y.reshape(-1), X1, U1.T.reshape(-1), N, r1, T, kappa=1, lambd=lambd*torch.norm(U2, ord=1)*torch.norm(U3, ord=1), max_iter1=100, max_iter2=100, sgd_steps=10)

        #update U2

        H = U1 @ G_mats[0]
        X2 = torch.zeros(N*T, N*r2)
        for t in range(T):
            Xt = X[t]
            At = H @ torch.kron((Xt @ U3).T.contiguous(), torch.eye(r2))
            X2[t*N:(t+1)*N, :] = At
        U2 = admm_subroutine(y.reshape(-1), X2, U2.T.reshape(-1), N, r2, T, kappa=1, lambd=lambd*torch.norm(U1, ord=1)*torch.norm(U3, ord=1), max_iter1=100, max_iter2=100, sgd_steps=10)

        #update U3
        H = U1 @ G_mats[0]
        X3 = torch.zeros(N*T, N*r3)
        for t in range(T):
            Xt = X[t]
            At = H @ torch.kron(torch.eye(r3), (U2.T @ Xt))
            X3[t*N:(t+1)*N, :] = At
        U3 = admm_subroutine(y.reshape(-1), X3, U3.T.reshape(-1), N, r3, T, kappa=1, lambd=lambd*torch.norm(U1, ord=1)*torch.norm(U2, ord=1), max_iter1=100, max_iter2=100, sgd_steps=10)


        #update G

        #we will use SGD on a square loss

        def loss(G):
            return torch.sum((y - X @ (U1 @ G @ U2.T @ U3.T))**2) 
        

        
        

    #return A, whose tucker decomposition is U1, U2, U3, G
    G = inv_matricization(G1, r1, r2, r3)

    #print(G.shape, G1.shape, U1.shape, U2.shape, U3.shape)
    #mode 1 product of U1 and G
    A1 = torch.einsum('ij, jkl -> ikl', U1, G)
    #print(A1.shape)
    #mode 2 product of A1 and U2
    A2 = torch.einsum('ij, kjl -> kil', U2, A1)
    #print(A2.shape)
    #mode 3 product of A2 and U3
    A = torch.einsum('ij, klj -> kli', U3, A2)
    #print(A.shape)
    
    return A

In [None]:
A_mlr1 = matricization(A_mlr, 1)
A_ols1 = matricization(A_ols, 1)
A_rrr1 = matricization(A_rrr, 1)

mlr_preds = x @ A_mlr1
ols_preds = x @ A_ols1
rrr_preds = x @ A_rrr1

#plot the results on the same plot with the true values of y
#on a big plot

plt.figure(figsize=(20, 10))

plt.plot(y[:, -3], label='True')
plt.plot(mlr_preds[:, -3], label='MLR')
plt.plot(ols_preds[:, -3], label='OLS')
plt.plot(rrr_preds[:, -3], label='RRR')
plt.legend()
plt.show()
