In [3]:
import numpy as np
from numpy.linalg import svd as svd

In [52]:
n = 50


def foo(i, j, k, l):
    """Entry generator for the test tensor: A[i, j, k, l] = 1 / (1 + i + j + k + l).

    Called by ``np.fromfunction`` with one index array per axis.
    """
    denominator = 1 + i + j + k + l
    return 1 / denominator


# Dense 4-way tensor of shape (n, n, n, n) built from the index formula above.
A = np.fromfunction(foo, (n,) * 4)
d = A.ndim

In [53]:
def crop_vector(a: np.array, eps: float) -> int:
    """Return the length of the shortest prefix of ``a`` whose remaining tail sums to less than ``eps``.

    The prefix is grown one element at a time. The first length ``m`` with
    ``sum(a) - sum(a[:m]) < eps`` is returned, so a non-empty input always
    yields at least 1. If no prefix satisfies the condition, ``len(a)`` is
    returned.

    Args:
        a (np.array): 1-D vector of (expected positive) values, e.g. singular values.
        eps (float): tolerance on the sum of the discarded tail.

    Returns:
        int: number of leading entries to keep (0 only for an empty input).
    """
    total = np.sum(a)
    prefix_sum = 0
    for length, value in enumerate(a, start=1):
        prefix_sum += value
        if total - prefix_sum < eps:
            return length
    return len(a)

In [54]:
def TTSVD(A, eps):
    """Decompose a d-dimensional array into tensor-train (TT) cores via sequential truncated SVDs.

    At each of the d - 1 unfoldings, singular values are kept until the plain
    (unsquared) sum of the discarded ones is below ``eps / sqrt(d)`` (see
    ``crop_vector``). The tolerance is absolute: it is not scaled by ``||A||_F``.
    The discarded sum bounds the Frobenius norm of the discarded part. By the
    standard TT-SVD error bound, the reconstruction error should therefore stay
    below ``eps * sqrt((d - 1) / d) < eps``, up to round-off. The ranks are
    typically larger than the squared-norm criterion would give.

    Args:
        A (np.ndarray): tensor of shape (n_1, ..., n_d).
        eps (float): absolute tolerance on the reconstruction error.

    Returns:
        list[np.ndarray]: cores with shapes (n_1, r_1), (r_1, n_2, r_2), ...,
        (r_{d-2}, n_{d-1}, r_{d-1}), (r_{d-1}, n_d).
    """
    d = A.ndim
    threshold = eps / (d**0.5)

    def _truncated_svd(mat):
        # Thin SVD of `mat`, cropped to rank r. Returns (left factor, S @ Vh remainder, r).
        u, s, v = svd(mat, full_matrices=False)
        r = crop_vector(s, threshold)
        # Row-scaling by s equals np.diag(s[:r]) @ v[:r] without the r x r matmul.
        return u[:, :r], s[:r, None] * v[:r, :], r

    first_core, remainder, r_prev = _truncated_svd(A.reshape(A.shape[0], -1))
    G = [first_core]
    for k in range(1, d - 1):
        # Fold the previous rank index together with mode k (C order) into the rows.
        u, remainder, r = _truncated_svd(remainder.reshape(r_prev * A.shape[k], -1))
        G.append(u.reshape(r_prev, A.shape[k], r))
        r_prev = r
    # What is left has shape (r_{d-1}, n_d) and becomes the last core.
    G.append(remainder)
    return G

In [85]:
G = TTSVD(A, eps=1e-8)
# One shape tuple per TT core: first (n, r), middle (r, n, r), last (r, n).
print([core.shape for core in G])

[(50, 13), (13, 50, 13), (13, 50, 13), (13, 50)]


In [83]:
def unfold_TT(G):
    """Contract a list of TT cores back into the full tensor.

    The leading (rank) axis of each core is contracted with the trailing
    (rank) axis of the running product. Cores of shapes
    (n_1, r_1), (r_1, n_2, r_2), ..., (r_{d-1}, n_d) therefore yield an
    array of shape (n_1, ..., n_d).

    Args:
        G (list[np.ndarray]): TT cores, e.g. as returned by ``TTSVD``.

    Returns:
        np.ndarray: the reconstructed full tensor. A single-core list returns
        that core unchanged.

    Raises:
        IndexError: if ``G`` is empty.
    """
    full = G[0]
    for core in G[1:]:
        # axes=1: sum over the last axis of `full` and the first axis of `core`.
        full = np.tensordot(full, core, axes=1)
    return full

In [84]:
g = unfold_TT(G)
# Absolute reconstruction error: 2-norm of the flattened difference (Frobenius norm).
np.linalg.norm(A - g)

4.4758827043168295e-09