### Tensor Decompositions

Tensors in *tntorch* are stored in a decomposed format (in the examples below, the tensor train (TT) format, selected via `ranks_tt`), so we can straightforwardly learn a compressed representation of a given uncompressed (*full*) tensor.
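To make "decomposed format" concrete: a TT tensor is a chain of small 3D cores, and the full tensor is recovered by contracting neighboring cores along their shared rank index. The NumPy sketch below assumes the usual core layout `(r_{n-1}, I_n, r_n)` with boundary ranks of 1; the helpers `random_tt_cores` and `tt_to_full` are hypothetical names for illustration, not tntorch's API. It compares how many numbers the cores store with how many entries the full tensor has, for the same shape and rank as the first example below.

```python
import numpy as np

# Hypothetical helper (not tntorch's API): random TT cores of shape (r_{n-1}, I_n, r_n),
# with boundary ranks r_0 = r_N = 1 and every inner rank equal to `rank`.
def random_tt_cores(shape, rank, seed=0):
    rng = np.random.default_rng(seed)
    ranks = [1] + [rank] * (len(shape) - 1) + [1]
    return [rng.standard_normal((ranks[n], shape[n], ranks[n + 1]))
            for n in range(len(shape))]

# Hypothetical helper: contract the cores left to right to get the full tensor.
def tt_to_full(cores):
    full = cores[0]  # shape (1, I_1, r_1)
    for core in cores[1:]:
        # Sum over the rank index shared by the running result and the next core.
        full = np.tensordot(full, core, axes=([-1], [0]))
    return full[0, ..., 0]  # drop the two size-1 boundary ranks

cores = random_tt_cores([32, 32, 32], rank=5)
full = tt_to_full(cores)
n_params_tt = sum(core.size for core in cores)
print(full.shape)              # (32, 32, 32)
print(n_params_tt, full.size)  # 1120 32768
```

Here the cores hold 1,120 numbers against 32,768 full entries; each entry is a product of one matrix slice per core, e.g. `full[i, j, k] == cores[0][0, i, :] @ cores[1][:, j, :] @ cores[2][:, k, 0]`.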

In [1]:
import tntorch as tn
import torch

N = 3
gt = tn.rand(shape=[32]*N, ranks_tt=5).full()  # A 3D ground-truth tensor
t = tn.rand(shape=gt.shape, ranks_tt=5, requires_grad=True)  # Initial state: random
tn.optimize(t, lambda t: tn.relative_error(t, gt))

iter: 0      | loss:   0.543997 | total time:    0.0010
iter: 500    | loss:   0.068096 | total time:    0.8705
iter: 1000   | loss:   0.043438 | total time:    1.7205
iter: 1500   | loss:   0.017763 | total time:    2.5069
iter: 1930   | loss:   0.004130 | total time:    3.2765 <- converged (tol=0.0001)
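`tn.optimize` reaches the ground truth iteratively, by gradient-based minimization of the loss (hence `requires_grad=True`). For a target with exact low rank like this one, a direct, non-iterative alternative is the TT-SVD algorithm, which compresses a full tensor through a sequence of truncated SVDs. The NumPy sketch below is that swapped-in technique, not what `tn.optimize` does; `tt_svd`, `tt_to_full` and `max_rank` are hypothetical names chosen for illustration.

```python
import numpy as np

def tt_to_full(cores):
    # Contract the TT cores left to right, then drop the size-1 boundary ranks.
    full = cores[0]
    for core in cores[1:]:
        full = np.tensordot(full, core, axes=([-1], [0]))
    return full[0, ..., 0]

def tt_svd(full, max_rank):
    # TT-SVD: split off one mode at a time with a truncated SVD of an unfolding.
    shape = full.shape
    cores, rank_prev, remainder = [], 1, full
    for n in range(len(shape) - 1):
        # Unfolding: rows = (previous rank, mode n), columns = all remaining modes.
        remainder = remainder.reshape(rank_prev * shape[n], -1)
        U, S, Vt = np.linalg.svd(remainder, full_matrices=False)
        rank = min(max_rank, len(S))
        cores.append(U[:, :rank].reshape(rank_prev, shape[n], rank))
        # Pass the kept singular values and right singular vectors to the next step.
        remainder = S[:rank, None] * Vt[:rank]
        rank_prev = rank
    cores.append(remainder.reshape(rank_prev, shape[-1], 1))
    return cores

# Ground truth with exact TT ranks (1, 5, 5, 1), as in the first example.
rng = np.random.default_rng(0)
ranks = [1, 5, 5, 1]
gt = tt_to_full([rng.standard_normal((ranks[n], 32, ranks[n + 1])) for n in range(3)])

cores = tt_svd(gt, max_rank=5)
approx = tt_to_full(cores)
rel_err = np.linalg.norm(approx - gt) / np.linalg.norm(gt)
print([core.shape for core in cores])  # [(1, 32, 5), (5, 32, 5), (5, 32, 1)]
print(rel_err)  # tiny: only numerically zero singular values were discarded
```

Because the ground truth has exact TT rank 5, truncating at `max_rank=5` throws away only round-off, so the error sits near machine precision. For tensors without exact low rank, the truncation error grows with the discarded singular values.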


Check the relative error, the root mean square error (RMSE) and the $R^2$ score of the learned tensor:

In [2]:
print(tn.relative_error(gt, t))
print(tn.rmse(gt, t))
print(tn.r_squared(gt, t))

tensor(0.0041, grad_fn=<DivBackward1>)
tensor(0.0127, grad_fn=<DivBackward0>)
tensor(0.9999, grad_fn=<AddBackward>)
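For reference, here are the textbook definitions of these three metrics. The NumPy sketch below assumes tntorch's `relative_error`, `rmse` and `r_squared` follow these standard formulas (an assumption, not checked against its source), and evaluates them on a tiny example that can be verified by hand.

```python
import numpy as np

def relative_error(gt, approx):
    # ||gt - approx|| / ||gt||, using Frobenius norms
    return np.linalg.norm(gt - approx) / np.linalg.norm(gt)

def rmse(gt, approx):
    # sqrt(mean((gt - approx)^2)), i.e. ||gt - approx|| / sqrt(number of entries)
    return np.sqrt(np.mean((gt - approx) ** 2))

def r_squared(gt, approx):
    # 1 - SS_res / SS_tot, where SS_tot measures the spread of gt around its mean
    ss_res = np.sum((gt - approx) ** 2)
    ss_tot = np.sum((gt - np.mean(gt)) ** 2)
    return 1 - ss_res / ss_tot

gt = np.array([1.0, 2.0, 3.0, 4.0])
approx = np.array([1.0, 2.0, 3.0, 5.0])
print(relative_error(gt, approx))  # ≈ 0.1826, i.e. 1/sqrt(30)
print(rmse(gt, approx))            # 0.5
print(r_squared(gt, approx))       # ≈ 0.8, i.e. 1 - 1/5
```

Relative error and RMSE both shrink to 0 for a perfect fit, while $R^2$ approaches 1; the 0.9999 above therefore means the residual is tiny compared with the ground truth's variance.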


OK, that was easy: after all, we knew the exact rank of the ground truth. How about more realistic tensors?
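Why did we "know" the rank? The minimal TT ranks of a full tensor equal the ranks of its sequential unfoldings (the tensor reshaped into a matrix whose rows index the first n modes), so they can be measured numerically. A minimal NumPy sketch, where `unfolding_ranks` and its tolerance `rtol` are hypothetical choices for illustration: a tensor built like the first ground truth gives `[5, 5]`, while unstructured noise has full unfolding ranks.

```python
import numpy as np

def unfolding_ranks(full, rtol=1e-10):
    # Numerical rank of each unfolding full[(i_1..i_n), (i_{n+1}..i_N)], n = 1..N-1.
    ranks = []
    for n in range(1, full.ndim):
        rows = int(np.prod(full.shape[:n]))
        s = np.linalg.svd(full.reshape(rows, -1), compute_uv=False)
        # Count singular values above a tolerance relative to the largest one.
        ranks.append(int(np.sum(s > rtol * s[0])))
    return ranks

rng = np.random.default_rng(0)
# A 32x32x32 tensor with TT ranks (1, 5, 5, 1), contracted from three random factors.
a = rng.standard_normal((32, 5))
b = rng.standard_normal((5, 32, 5))
c = rng.standard_normal((5, 32))
exact = np.einsum('ia,ajb,bk->ijk', a, b, c)
print(unfolding_ranks(exact))  # [5, 5]

noise = rng.standard_normal((32, 32, 32))
print(unfolding_ranks(noise))  # [32, 32]: no low-rank structure
```

For realistic tensors the singular values usually decay gradually instead of dropping to zero, so the measured ranks depend on `rtol`; choosing a rank then becomes a trade-off between compression and accuracy.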

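Let's try out an analytical 3D function, $f(x, y, z) = \sqrt{\sqrt{x}\,y + y z^2}$, sampled on a $32 \times 32 \times 32$ integer grid and fitted with TT rank 3: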

In [3]:
import numpy as np
X, Y, Z = np.meshgrid(range(32), range(32), range(32))  # default indexing='xy': X varies along axis 1, Y along axis 0
gt = torch.Tensor(np.sqrt(np.sqrt(X)*Y + Y*Z**2))
t = tn.rand(shape=gt.shape, ranks_tt=3, requires_grad=True)
tn.optimize(t, lambda t: tn.relative_error(t, gt))

iter: 0      | loss:  55.447667 | total time:    0.0010
iter: 500    | loss:  13.461909 | total time:    0.8360
iter: 1000   | loss:   7.227232 | total time:    1.8071
iter: 1500   | loss:   4.676556 | total time:    2.7783
iter: 2000   | loss:   3.282153 | total time:    3.7078
iter: 2500   | loss:   2.404226 | total time:    4.5352
iter: 3000   | loss:   1.805042 | total time:    5.4087
iter: 3500   | loss:   1.376094 | total time:    6.3303
iter: 4000   | loss:   1.060893 | total time:    7.2436
iter: 4500   | loss:   0.826793 | total time:    8.0418
iter: 5000   | loss:   0.652394 | total time:    8.9567
iter: 5500   | loss:   0.520899 | total time:    9.8962
iter: 6000   | loss:   0.416777 | total time:   10.8903
iter: 6500   | loss:   0.324939 | total time:   11.9701
iter: 7000   | loss:   0.230926 | total time:   12.9943
iter: 7500   | loss:   0.122071 | total time:   14.0333
iter: 8000   | loss:   0.010126 | total time:   14.9677
iter: 8500   | loss:   0.008826 | total time:   