In [1]:
import numpy as np
import tensorly as tl
import torch
import pandas as pd
# Make every tl.* call in this notebook operate on PyTorch tensors.
tl.set_backend('pytorch')

In [2]:
def get_khatri_rao(X,factors,mode):
    """Matricized-tensor-times-Khatri-Rao product (MTTKRP) of ``X`` along ``mode``.

    Computes ``unfold(X, mode) @ khatri_rao(factors, skipping factors[mode])``,
    the term that forms the numerator of the multiplicative factor updates.

    Parameters
    ----------
    X : tensor
        Data tensor to unfold.
    factors : list of matrices
        One factor matrix per mode of ``X``; ``factors[mode]`` is skipped.
    mode : int
        Mode along which ``X`` is unfolded.

    Returns
    -------
    matrix
        ``X.shape[mode]`` rows, one column per CP component.
    """
    kr_product = tl.tenalg.khatri_rao(factors, skip_matrix=mode)
    return tl.dot(tl.unfold(X, mode), kr_product)

In [3]:
def get_Hadamard(factors):
    """Element-wise (Hadamard) product of the factors' Gram matrices.

    For factors ``[F1, F2, ...]`` this returns ``(F1^T F1) * (F2^T F2) * ...``,
    the ``r x r`` term used in the denominators of the multiplicative updates.

    The rank ``r`` is taken from the factors themselves rather than from the
    notebook-global ``r``, so the helper works on a fresh kernel and for any
    rank.  ``.T`` / ``@`` are native to both numpy arrays and torch tensors.

    Parameters
    ----------
    factors : sequence of 2-D tensors/arrays
        Factor matrices that all have the same number of columns.

    Returns
    -------
    2-D tensor/array of shape ``(r, r)``.

    Raises
    ------
    ValueError
        If ``factors`` is empty (the rank cannot be inferred).
    """
    grams = [fac.T @ fac for fac in factors]
    if not grams:
        raise ValueError("get_Hadamard needs at least one factor matrix")
    # Start from the first Gram matrix instead of a ones matrix: no global rank
    # is needed and no extra dtype is mixed into the product.
    H = grams[0]
    for gram in grams[1:]:
        H = H * gram
    return H

In [4]:
def reconstruct_tensor(A,B,T):
    """Rebuild both coupled CP models from the shared factor matrices.

    Returns
    -------
    (X_hat, Y_hat)
        ``X_hat = [[A, B, T]]`` and ``Y_hat = [[A, A, T]]`` (unweighted CP).
    """
    to_full = tl.cp_tensor.cp_to_tensor
    return to_full((None, [A, B, T])), to_full((None, [A, A, T]))

In [5]:
def getError(X,Y,A,B,T,alpha=None):
    """Coupled squared reconstruction error of the current factors.

    Returns ``||X - [[A,B,T]]||^2 + alpha * ||Y - [[A,A,T]]||^2``.

    Parameters
    ----------
    X, Y : tensors
        Data tensors being factorized (X ~ [[A,B,T]], Y ~ [[A,A,T]]).
    A, B, T : matrices
        Current factor matrices.
    alpha : float, optional
        Weight of the Y term.  ``None`` (default) uses the notebook-global
        ``alpha``, matching the original behaviour.
    """
    if alpha is None:
        alpha = globals()["alpha"]
    X_hat, Y_hat = reconstruct_tensor(A, B, T)
    # Residual norms are computed directly.  The previous expanded form used
    # ||X||^2 instead of ||Y||^2 in the Y term, and the expansion also suffers
    # from cancellation in low precision.
    return tl.norm(X - X_hat) ** 2 + alpha * tl.norm(Y - Y_hat) ** 2

In [6]:
def getFit(X,Y,A,B,T,alpha=None):
    """Combined fit of the coupled model.

    Returns ``1 - (||X - X_hat|| / ||X|| + alpha * ||Y - Y_hat|| / ||Y||)`` with
    ``X_hat = [[A,B,T]]`` and ``Y_hat = [[A,A,T]]``.  1 is a perfect fit.  The
    value can go negative, because the two relative errors are summed.

    Parameters
    ----------
    X, Y : tensors
        Data tensors being factorized.
    A, B, T : matrices
        Current factor matrices.
    alpha : float, optional
        Weight of the Y term.  ``None`` (default) uses the notebook-global
        ``alpha``, matching the original behaviour.
    """
    if alpha is None:
        alpha = globals()["alpha"]
    X_hat, Y_hat = reconstruct_tensor(A, B, T)
    rel_err_X = tl.norm(X - X_hat) / tl.norm(X)
    rel_err_Y = tl.norm(Y - Y_hat) / tl.norm(Y)
    return 1 - (rel_err_X + alpha * rel_err_Y)


In [7]:
# hyperparameters
N, M, K = 10, 10, 300   # X is N x M x K, Y is N x N x K
r = 5                   # CP rank: every factor matrix has r columns
d = 3                   # each multiplicative update ratio is raised to 1/d
maxiters = 15           # number of alternating update sweeps
alpha = 0.3             # weight of the coupled Y term

In [8]:
# Coupled Tensor Factorization
def _coupled_error(X, Y, A, B, T, alpha):
    """Squared coupled error ||X - [[A,B,T]]||^2 + alpha * ||Y - [[A,A,T]]||^2."""
    X_hat, Y_hat = reconstruct_tensor(A, B, T)
    return tl.norm(X - X_hat) ** 2 + alpha * tl.norm(Y - Y_hat) ** 2


def _coupled_fit(X, Y, A, B, T, alpha):
    """Fit 1 - (||X - X_hat|| / ||X|| + alpha * ||Y - Y_hat|| / ||Y||)."""
    X_hat, Y_hat = reconstruct_tensor(A, B, T)
    return 1 - (tl.norm(X - X_hat) / tl.norm(X)
                + alpha * tl.norm(Y - Y_hat) / tl.norm(Y))


def _mu_step(factor, numerator, denominator, d, eps):
    """One damped multiplicative update: factor * (num / (den + eps)) ** (1/d)."""
    return factor * (numerator / (denominator + eps)) ** (1 / d)


def RandomCoupledTensorFac(N,M,K,r,d,alpha,maxiters,seed=None,eps=1e-12):
    """Fit a coupled non-negative CP model to random data with multiplicative updates.

    Draws X (N x M x K) and Y (N x N x K) with ``rand`` (uniform on [0, 1)).
    It then alternately updates A (N x r), B (M x r) and T (K x r) so that
    X ~ [[A, B, T]] and Y ~ [[A, A, T]].  Each update multiplies a factor
    element-wise by ``(numerator / denominator) ** (1/d)``.  The coupled error
    and fit are printed after every sweep.

    Parameters
    ----------
    N, M, K : int
        Tensor dimensions (see shapes above).
    r : int
        CP rank.
    d : int or float
        Damping exponent: each update ratio is raised to ``1/d``.
    alpha : float
        Weight of the Y term.  It is used both in the updates and in the
        reported error and fit.
    maxiters : int
        Number of update sweeps.
    seed : int, optional
        If given, data and initial factors are drawn from
        ``np.random.RandomState(seed)`` for reproducibility.  ``None`` keeps
        the original behaviour, which uses the global numpy RNG.
    eps : float, optional
        Added to every update denominator to avoid division by zero.

    Returns
    -------
    (A, B, T) : the fitted factor matrices.
    """
    rand = np.random if seed is None else np.random.RandomState(seed)

    # Tensor Initialization
    X = tl.tensor(rand.rand(N, M, K))
    Y = tl.tensor(rand.rand(N, N, K))

    # Initialization (random, non-negative)
    A = tl.tensor(rand.rand(N, r))
    B = tl.tensor(rand.rand(M, r))
    T = tl.tensor(rand.rand(K, r))
    factorsX = [A, B, T]
    factorsY = [A, A, T]

    # Error and fit use this call's alpha, not the notebook-global one.
    error = _coupled_error(X, Y, A, B, T, alpha)
    print("Non-negative Couple Tensor Factorization")
    print(f'Initial Error : {error}')
    print("==========================================")

    # optimization
    for it in range(maxiters):
        # A appears in X (mode 0) and in Y (modes 0 and 1).
        # NOTE(review): only Y's mode-1 MTTKRP enters the numerator.  The
        # gradient of alpha*||Y - [[A,A,T]]||^2 involves modes 0 and 1.  For
        # symmetric Y this update equals the exact multiplicative rule with
        # the Y weight halved; for non-symmetric Y the mode-0 term is ignored.
        # Confirm against the intended derivation.
        ANum = get_khatri_rao(X, factorsX, 0) + alpha * get_khatri_rao(Y, factorsY, 1)
        ADem = tl.dot(A, get_Hadamard([B, T])) + alpha * tl.dot(A, get_Hadamard([A, T]))
        A = _mu_step(A, ANum, ADem, d, eps)
        factorsX[0] = A
        factorsY[0] = A
        factorsY[1] = A

        # B appears only in X (mode 1).
        BNum = get_khatri_rao(X, factorsX, 1)
        BDem = tl.dot(B, get_Hadamard([A, T]))
        B = _mu_step(B, BNum, BDem, d, eps)
        factorsX[1] = B

        # T is shared by X and Y (mode 2).
        TNum = get_khatri_rao(X, factorsX, 2) + alpha * get_khatri_rao(Y, factorsY, 2)
        TDem = tl.dot(T, get_Hadamard([A, B])) + alpha * tl.dot(T, get_Hadamard([A, A]))
        T = _mu_step(T, TNum, TDem, d, eps)
        factorsX[2] = T
        factorsY[2] = T

        oldError = error
        error = _coupled_error(X, Y, A, B, T, alpha)
        fit = _coupled_fit(X, Y, A, B, T, alpha)
        print(f'Iteration {it+1} error : {error:.3f}\tError change : {(oldError-error):.3f} \tRecon Tensor Fit : {fit:.3f}')
    return A,B,T

In [9]:
A, B, T = RandomCoupledTensorFac(N=N, M=M, K=K, r=r, d=d, alpha=alpha, maxiters=maxiters)

Non-negative Couple Tensor Factorization
Initial Error : 10998.6845703125
Iteration 1 error : 4553.535	Error change : 6445.150 	Recon Tensor Fit : 0.234
Iteration 2 error : 3960.599	Error change : 592.936 	Recon Tensor Fit : 0.286
Iteration 3 error : 3737.120	Error change : 223.479 	Recon Tensor Fit : 0.306
Iteration 4 error : 3617.041	Error change : 120.080 	Recon Tensor Fit : 0.318
Iteration 5 error : 3546.956	Error change : 70.084 	Recon Tensor Fit : 0.324
Iteration 6 error : 3502.577	Error change : 44.379 	Recon Tensor Fit : 0.329
Iteration 7 error : 3471.658	Error change : 30.919 	Recon Tensor Fit : 0.332
Iteration 8 error : 3448.235	Error change : 23.422 	Recon Tensor Fit : 0.334
Iteration 9 error : 3429.362	Error change : 18.873 	Recon Tensor Fit : 0.336
Iteration 10 error : 3413.492	Error change : 15.870 	Recon Tensor Fit : 0.337
Iteration 11 error : 3399.681	Error change : 13.812 	Recon Tensor Fit : 0.339
Iteration 12 error : 3387.567	Error change : 12.113 	Recon Tensor Fit : 

In [10]:
# Export T
def saveFac(T,label,np=True,out_dir='/home/seyunkim/tensorcast_py'):
    """Write a factor matrix to ``<out_dir>/T<label>.csv`` without an index column.

    Parameters
    ----------
    T : 2-D tensor or array
        Factor matrix to export; each row of ``T`` becomes one CSV row.
    label : str
        File-name suffix, e.g. ``'300'`` -> ``T300.csv``.
    np : bool, default True
        Convert ``T`` with ``tl.to_numpy`` before writing (needed for backend
        tensors such as torch).  This name shadows the numpy module inside the
        function; it is kept so existing positional/keyword calls still work.
    out_dir : str or path-like
        Destination directory.  It defaults to the previously hard-coded
        location.

    Returns
    -------
    str
        Path of the written CSV file.
    """
    import os

    if np:
        T = tl.to_numpy(T)
    path = os.path.join(out_dir, f'T{label}.csv')
    # Write exactly once; the old numpy branch wrote the same file twice.
    pd.DataFrame(T).to_csv(path, index=False)
    return path


In [11]:
# saveFac(T,'1')

In [13]:
saveFac(T, '300', np=True)