In [1]:
import numpy as np

def _truncation_rank(s, delta=None, rmax=None):
    """Choose how many singular triples to keep in one TT-SVD step.

    Parameters
    ----------
    s : 1-D array of singular values in the non-increasing order returned
        by ``np.linalg.svd``.
    delta : absolute Frobenius budget for the discarded tail of this step,
        or None to skip error-based truncation.
    rmax : hard cap on the rank, or None.

    Returns the smallest rank r with ``sqrt(sum(s[r:]**2)) <= delta`` (or
    ``len(s)`` when ``delta`` is None), then capped by ``rmax`` and clamped to
    the valid range ``1 <= r <= len(s)``.
    """
    r = len(s)
    if delta is not None:
        # tail_sq[k] = sum(s[k:]**2) for k = 0..len(s). The final entry is 0,
        # so the condition holds somewhere and argmax returns a valid index
        # (no off-the-end rank, unlike a cumulative-energy searchsorted).
        sq = s ** 2
        tail_sq = np.append(np.cumsum(sq[::-1])[::-1], 0.0)
        r = int(np.argmax(tail_sq <= delta ** 2))
    if rmax is not None:
        r = min(r, rmax)
    # Keep at least one triple (e.g. for an all-zero tensor, where every
    # tail already fits the budget) and never more than exist.
    return max(1, min(r, len(s)))


def tt_svd_3d(T, rmax=None, eps=None):
    """
    TT-SVD for a 3D tensor T of shape (n1, n2, n3).

    Returns TT cores G1, G2, G3 with shapes:
      G1: (1, n1, r1)
      G2: (r1, n2, r2)
      G3: (r2, n3, 1)
    G1 and G2 have orthonormal columns when unfolded to (n1, r1) and
    (r1*n2, r2); the singular-value weight is carried into G3.

    Truncation (applied at each of the two SVD steps):
      - rmax: hard cap on TT ranks (must be >= 1).
      - eps:  relative Frobenius error target (must be >= 0). Each step
              drops the largest tail of singular values whose norm is at
              most eps * ||T||_F / sqrt(2), which guarantees
              ||T - T_hat||_F <= eps * ||T||_F when rmax does not bind
              (the standard TT-SVD bound for d = 3).
      When both are given, rmax is applied after eps and may push the error
      above eps. Ranks are always at least 1.

    Raises:
      ValueError: if T is not 3D, is empty, rmax < 1, or eps < 0.
    """
    T = np.asarray(T)
    if T.ndim != 3:
        raise ValueError(f"expected a 3D tensor, got shape {T.shape}")
    if T.size == 0:
        raise ValueError("cannot decompose an empty tensor")
    if rmax is not None and rmax < 1:
        raise ValueError(f"rmax must be >= 1, got {rmax}")
    if eps is not None and eps < 0:
        raise ValueError(f"eps must be >= 0, got {eps}")

    n1, n2, n3 = T.shape

    # Step 1: unfold into a (n1, n2*n3) matrix and SVD.
    U1, S1, Vh1 = np.linalg.svd(T.reshape(n1, n2 * n3), full_matrices=False)

    delta = None
    if eps is not None:
        # The singular values of any unfolding have 2-norm ||T||_F. Splitting
        # the squared budget evenly over the d - 1 = 2 steps gives the bound.
        delta = eps * np.linalg.norm(S1) / np.sqrt(2)

    r1 = _truncation_rank(S1, delta, rmax)
    G1 = U1[:, :r1].reshape(1, n1, r1)

    # Carry the remainder diag(S1) @ Vh1 (scaling rows by broadcasting),
    # regrouped from (r1, n2, n3) into a (r1*n2, n3) matrix.
    M2 = (S1[:r1, None] * Vh1[:r1, :]).reshape(r1 * n2, n3)

    # Step 2: SVD on (r1*n2, n3).
    U2, S2, Vh2 = np.linalg.svd(M2, full_matrices=False)
    r2 = _truncation_rank(S2, delta, rmax)

    G2 = U2[:, :r2].reshape(r1, n2, r2)
    G3 = (S2[:r2, None] * Vh2[:r2, :]).reshape(r2, n3, 1)

    return G1, G2, G3

def tt_reconstruct_3d(G1, G2, G3):
    """Contract three TT cores back into a dense tensor T[i, j, k].

    Expects cores shaped (1, n1, r1), (r1, n2, r2) and (r2, n3, 1), as
    produced by ``tt_svd_3d``; returns an array of shape (n1, n2, n3).
    """
    # Peel off the size-1 boundary ranks up front instead of at the end.
    head = G1[0]          # (n1, r1)
    tail = G3[:, :, 0]    # (r2, n3)
    # Sum over bond r1: (n1, r1) . (r1, n2, r2) -> (n1, n2, r2)
    partial = np.tensordot(head, G2, axes=1)
    # Sum over bond r2: (n1, n2, r2) . (r2, n3) -> (n1, n2, n3)
    return np.tensordot(partial, tail, axes=1)

def tt_num_params(G1, G2, G3):
    """Return the total number of stored entries across the three TT cores."""
    return sum(core.size for core in (G1, G2, G3))

# ----------------------------
# Demo: make a toy 3D tensor
# ----------------------------
# Seed NumPy's legacy global RNG so the printed table below is reproducible.
np.random.seed(0)
n1, n2, n3 = 20, 20, 20

# A low-rank-ish tensor (sum of a few separable terms) + small noise
R_true = 3
# Row r of a, b, c holds the factor vectors of the r-th rank-1 term.
a = np.random.randn(R_true, n1)
b = np.random.randn(R_true, n2)
c = np.random.randn(R_true, n3)

# Noiseless part: sum of R_true outer products a[r] x b[r] x c[r], so every
# unfolding of it has rank <= R_true. Once rmax >= R_true the remaining
# error is dominated by the 0.01-scale noise added afterwards.
T = np.zeros((n1, n2, n3))
for r in range(R_true):
    T += np.einsum("i,j,k->ijk", a[r], b[r], c[r])
T += 0.01 * np.random.randn(n1, n2, n3)

# Full tensor params:
full_params = T.size

# Sweep the TT rank cap. Compression is dense size / TT size. At rmax=20
# the ranks are untruncated (r1 = min(n1, n2*n3), r2 = min(r1*n2, n3)),
# so the error is pure roundoff and the cores hold more numbers than T.
for rmax in [1, 2, 4, 8, 20]:
    G1, G2, G3 = tt_svd_3d(T, rmax=rmax)
    T_hat = tt_reconstruct_3d(G1, G2, G3)

    # Relative Frobenius-norm reconstruction error.
    rel_err = np.linalg.norm(T - T_hat) / np.linalg.norm(T)
    tt_params = tt_num_params(G1, G2, G3)

    print(f"rmax={rmax:>2} | TT params={tt_params:>5} vs full={full_params:>5} "
          f"| compression={full_params/tt_params:6.2f}x | rel_err={rel_err:.3e}")

# Recompute the rmax=4 decomposition just to display its core shapes.
print("\nCore shapes:")
G1, G2, G3 = tt_svd_3d(T, rmax=4)
print("G1:", G1.shape, "G2:", G2.shape, "G3:", G3.shape)


rmax= 1 | TT params=   60 vs full= 8000 | compression=133.33x | rel_err=7.118e-01
rmax= 2 | TT params=  160 vs full= 8000 | compression= 50.00x | rel_err=3.556e-01
rmax= 4 | TT params=  480 vs full= 8000 | compression= 16.67x | rel_err=5.286e-03
rmax= 8 | TT params= 1600 vs full= 8000 | compression=  5.00x | rel_err=4.746e-03
rmax=20 | TT params= 8800 vs full= 8000 | compression=  0.91x | rel_err=1.693e-15

Core shapes:
G1: (1, 20, 4) G2: (4, 20, 4) G3: (4, 20, 1)
