In [31]:
import numpy as np
from numpy.linalg import svd as svd

In [151]:
n = 50  # size of every mode of the 4-D test tensor


def foo(i, j, k, l):
    """Entry generator for np.fromfunction: value 1 / (1 + i + j + k + l)."""
    denom = 1 + i + j + k + l
    return 1 / denom


A = np.fromfunction(foo, (n,) * 4)
d = len(A.shape)

# Utils

In [74]:
def crop_vector(a: np.array, eps: float) -> int:
    """Return the shortest prefix length whose discarded tail sums to less than ``eps``.

    For a prefix of length ``r`` the tail is ``sum(a) - (a[0] + ... + a[r-1])``,
    with the prefix accumulated left to right. The smallest ``r >= 1`` for which
    this tail is below ``eps`` is returned; if there is none, ``len(a)`` is returned
    (so an empty input gives 0).

    Args:
        a (np.array): 1-D vector of non-negative values (e.g. singular values)
        eps (float): tolerance on the discarded tail sum

    Returns:
        int: number of leading entries to keep
    """
    total = np.sum(a)
    prefix = np.cumsum(a)  # sequential running sum, same order as a manual loop
    hits = np.flatnonzero(total - prefix < eps)
    return int(hits[0]) + 1 if hits.size else len(a)

def croped_svd(a: np.array, eps: float, full_matrices=False):
    """Truncated SVD of a matrix using an absolute tolerance on the dropped singular values.

    The rank kept is ``crop_vector(s, eps)``: the smallest ``r`` such that the plain
    (not squared) sum of the discarded singular values is below ``eps``.

    Args:
        a (np.array): matrix to decompose
        eps (float): absolute tolerance on the sum of discarded singular values
        full_matrices (bool, optional): forwarded to ``numpy.linalg.svd``. Defaults to False.

    Returns:
        tuple: ``(u, s, v)`` restricted to the first ``r`` singular triplets, so that
        ``u @ np.diag(s) @ v`` approximates ``a``.
    """
    left, sigma, right = svd(a, full_matrices=full_matrices)
    keep = crop_vector(sigma, eps)
    return left[:, :keep], sigma[:keep], right[:keep]

# TT SVD

In [138]:
def TTSVD(A, eps): #from left to right
    """Tensor-train (TT) decomposition of a dense tensor by sequential truncated SVDs.

    Sweeps left to right. At each step the current remainder is unfolded into a
    matrix, truncated with ``croped_svd`` (absolute tolerance ``eps / sqrt(d)`` per
    step), the left factor becomes the next core, and ``diag(s) @ v`` becomes the new
    remainder. All cores except the last are left factors of SVDs, i.e. have
    orthonormal columns in their left unfolding.

    Args:
        A (array-like): dense tensor with ``d = A.ndim`` modes (intended for d >= 2).
        eps (float): absolute truncation tolerance. Every step gets ``eps / sqrt(d)``
            (slightly more conservative than the textbook ``eps / sqrt(d - 1)``).

    Returns:
        list of np.ndarray: cores ``[G_1 (n_1, r_1), G_k (r_{k-1}, n_k, r_k) ...,
        G_d (r_{d-1}, n_d)]`` -- the first and last cores are 2-D.
    """
    A = np.asarray(A)
    d = A.ndim
    delta = eps / (d ** 0.5)  # per-step tolerance, loop invariant

    # First core: unfold along mode 0 -> (n_1, n_2 * ... * n_d).
    u, s, v = croped_svd(A.reshape(A.shape[0], -1), delta)
    G = [u]
    TT = s[:, None] * v  # same values as np.diag(s) @ v, without the dense diagonal
    r_prev = len(s)

    for k in range(1, d - 1):
        # Unfold the remainder as (r_{k-1} * n_k, rest) and split off core k.
        u, s, v = croped_svd(TT.reshape(r_prev * A.shape[k], -1), delta)
        G.append(u.reshape(r_prev, A.shape[k], len(s)))
        TT = s[:, None] * v
        r_prev = len(s)

    # Whatever remains, (r_{d-1}, n_d), is the last core.
    G.append(TT)
    return G

In [139]:
G = TTSVD(A, 1e-8)  # TT cores of A with absolute tolerance 1e-8
print(list(map(np.shape, G)))

[(50, 13), (13, 50, 13), (13, 50, 13), (13, 50)]


In [140]:
def unfold_TT(G):
    """Contract a list of TT cores back into the full dense tensor.

    Consecutive cores are joined by contracting the trailing (rank) axis of the
    running result with the leading (rank) axis of the next core, i.e. the
    contraction ``...r, r... -> ......``.

    Args:
        G (list of np.ndarray): TT cores ``[G_1 (n_1, r_1), G_2 (r_1, n_2, r_2), ...,
            G_d (r_{d-1}, n_d)]`` as produced by ``TTSVD``. Must be non-empty.

    Returns:
        np.ndarray: dense tensor of shape ``(n_1, ..., n_d)``. A single core is
        returned unchanged.
    """
    full = G[0]
    for core in G[1:]:
        # axes=1: last axis of `full` against first axis of `core`; tensordot
        # evaluates this as one reshaped matrix product via np.dot.
        full = np.tensordot(full, core, axes=1)
    return full

In [141]:
g = unfold_TT(G)  # dense tensor rebuilt from the TT cores
np.linalg.norm(A - g)

4.4758827043168295e-09

In [142]:
# Squared Frobenius norms of each core, next to that of the dense reconstruction `g`.
# A core with orthonormal left-unfolding columns has squared norm equal to its right
# rank, so only the last core should carry the norm of the whole tensor.
[np.linalg.norm(core)**2 for core in G], np.linalg.norm(g)**2

([12.999999999999996,
  12.999999999999986,
  12.999999999999986,
  952.4124701079983],
 952.4124701079974)

# Recompression

In [144]:
def recompess(G, eps): #from right to left
    """Reduce TT ranks with one right-to-left sweep of truncated SVDs (TT rounding).

    Starting from the last core, each core (with the factor carried over from its
    right neighbour absorbed) is unfolded as ``(r_left, rest)`` and truncated with
    ``croped_svd`` at absolute tolerance ``eps / sqrt(d)``. The right factor becomes
    the new core, and ``u * s`` is pushed into the core on the left. The first core
    only absorbs the carried factor. Its rank ``r_1`` is already fixed by core 2, so a
    further SVD cut there would add error without shrinking any dimension.

    The per-step cuts bound the error of the full tensor only if ``G[0..d-2]`` are
    left-orthogonal, as ``TTSVD`` produces them. For arbitrary cores, orthogonalize
    first. (The function name keeps its original spelling for compatibility.)

    Args:
        G (list of np.ndarray): cores ``[G_1 (n_1, r_1), G_k (r_{k-1}, n_k, r_k) ...,
            G_d (r_{d-1}, n_d)]``, with ``d = len(G) >= 2``.
        eps (float): absolute truncation tolerance, split as ``eps / sqrt(d)`` per cut.

    Returns:
        list of np.ndarray: new cores with the same layout and possibly smaller ranks.
    """
    d = len(G)
    delta = eps / (d ** 0.5)  # per-cut tolerance, loop invariant

    # Last core (r_{d-1}, n_d): keep the right factor, carry u*s to the left.
    u, s, v = croped_svd(G[-1], delta)
    Gnew = [v]
    W = u * s  # same values as u @ np.diag(s)

    for k in range(d - 2, 0, -1):
        # (r_k, n_k, r_{k+1}) @ (r_{k+1}, r'_{k+1}) -> (r_k, n_k, r'_{k+1})
        Gk = G[k] @ W
        u, s, v = croped_svd(Gk.reshape(Gk.shape[0], -1), delta)
        Gnew.append(v.reshape(len(s), Gk.shape[1], Gk.shape[2]))
        W = u * s

    # First core: absorb the carried factor without another (useless) truncation.
    Gnew.append(G[0] @ W)
    return Gnew[::-1]

In [148]:
Gnew = recompess(G, 1e-3)  # round the TT cores to absolute tolerance 1e-3
print("old shape:", [np.shape(core) for core in G], "\nnew shape:", [np.shape(core) for core in Gnew])

old shape: [(50, 13), (13, 50, 13), (13, 50, 13), (13, 50)] 
new shape: [(50, 7), (7, 50, 8), (8, 50, 7), (7, 50)]


In [150]:
# Dense tensors before and after recompression; their distance is the rounding error.
g1 = unfold_TT(G)
g2 = unfold_TT(Gnew)
np.linalg.norm(g2 - g1)

0.0003337647085426793