### Initialize the tensor train cores

In [83]:
import numpy as np

np.random.seed(32)

d = 5

ranks = [1, 4, 4, 4, 4, 1]
Nc = [10, 10, 10, 10, 10]

TT_cores = [np.random.randn(ranks[i], Nc[i], ranks[i+1]) for i in range(d)]

In [84]:
cores0 = TT_cores[0].reshape((ranks[0]*Nc[0], ranks[1]))
Q, R  = np.linalg.qr(cores0)

In [85]:
Q.shape, R.shape

((10, 4), (4, 4))

In [86]:
import copy

# orthogonalization
left = copy.deepcopy(TT_cores)
right = copy.deepcopy(TT_cores)

for i in range(len(left)):
    core = left[i].reshape((ranks[i]*Nc[i], ranks[i+1]))
    Q, R  = np.linalg.qr(core)
    
    print(Q.shape, R.shape)
    left[i] = Q.reshape((ranks[i], Nc[i], ranks[i+1]))
    
    if i < len(left) - 1:
        left[i+1] = np.einsum('ab,bic->aic', R, left[i+1])

        
for i in range(len(right) - 1, -1, -1):
    core = right[i].reshape((ranks[i], Nc[i]*ranks[i+1]))
    Q, R  = np.linalg.qr(core.T)
    
    print(Q.shape, R.shape)
    right[i] = Q.reshape((ranks[i], Nc[i], ranks[i+1]))
    
    if i > 0:
        right[i - 1] = np.einsum('aib, bc ->aic', right[i - 1], R)
        

(10, 4) (4, 4)
(40, 4) (4, 4)
(40, 4) (4, 4)
(40, 4) (4, 4)
(40, 1) (1, 1)
(10, 4) (4, 4)
(40, 4) (4, 4)
(40, 4) (4, 4)
(40, 4) (4, 4)
(40, 1) (1, 1)


In [87]:
for i in range(len(left)):
    core = left[i].reshape((ranks[i]*Nc[i], ranks[i+1]))
    if np.allclose(np.eye(ranks[i+1], ranks[i+1]), core.T @ core):
        print("Core i is left orthogonal")

for i in range(len(right)):
    core = right[i].reshape((ranks[i], Nc[i]*ranks[i+1]))
    print(core @ core.T)
    
    #if np.allclose(np.eye(ranks[i], ranks[i]), core @ core):
    #    print("Core i is orthogonal")
        
    

Core i is left orthogonal
Core i is left orthogonal
Core i is left orthogonal
Core i is left orthogonal
Core i is left orthogonal
[[1.]]
[[ 1.00353156 -0.25523121 -0.07039797  0.01045898]
 [-0.25523121  0.91730163  0.13731065  0.03299841]
 [-0.07039797  0.13731065  1.18016707 -0.03853188]
 [ 0.01045898  0.03299841 -0.03853188  0.89899974]]
[[ 0.90378627  0.08911938 -0.02333987  0.13062166]
 [ 0.08911938  1.02171845 -0.14391855  0.20259527]
 [-0.02333987 -0.14391855  0.97712963  0.1654789 ]
 [ 0.13062166  0.20259527  0.1654789   1.09736565]]
[[ 1.27963043 -0.0626705   0.0148422  -0.09742441]
 [-0.0626705   0.63115522 -0.02501967  0.1629379 ]
 [ 0.0148422  -0.02501967  0.76472507  0.0741632 ]
 [-0.09742441  0.1629379   0.0741632   1.32448928]]
[[ 0.68474995 -0.01773649 -0.23606134  0.07932537]
 [-0.01773649  1.13380579 -0.43168276 -0.04947189]
 [-0.23606134 -0.43168276  1.11581594  0.48299579]
 [ 0.07932537 -0.04947189  0.48299579  1.06562832]]
