In [1]:
import torch as tn
import datetime
import numpy as np
try:
    import torchtt as tntt
except:
    print('Installing torchTT...')
    %pip install git+https://github.com/ion-g-ion/torchTT
    import torchtt as tntt

C++ implementation not available. Using pure Python.
[0m
C++ implementation not available. Using pure Python.
[0m
C++ implementation not available. Using pure Python.
[0m
C++ implementation not available. Using pure Python.
[0m
C++ implementation not available. Using pure Python.
[0m


In [2]:
# Seed torch's global RNG so the random target (and the later random starting point)
# are reproducible under Restart & Run All.
# NOTE(review): assumes tntt.randn draws from torch's global generator — TODO confirm.
tn.manual_seed(0)

N = [10, 11, 12, 13, 14]   # mode sizes of the tensors
Rt = [1, 3, 4, 5, 6, 1]    # TT ranks used to generate the target
Rx = [1, 6, 6, 6, 6, 1]    # TT ranks of the iterate x that is optimised below
target = tntt.randn(N, Rt).round(0)


def func(x):
    """Objective: 0.5 * (x - target).norm(True).

    NOTE(review): `norm(True)` is presumably torchtt's squared Frobenius norm,
    which would make this 0.5 * ||x - target||^2 — confirm against the torchtt docs.
    """
    return 0.5 * (x - target).norm(True)

In [4]:
# Display the target TT: its mode sizes, TT ranks, device/dtype and storage compression.
target

TT with sizes and ranks:
N = [10, 11, 12, 13, 14]
R = [1, 3, 4, 5, 6, 1]

Device: cpu, dtype: torch.float64
#entries 876 compression 0.0036463536463536466

In [5]:
# Riemannian gradient descent for `func`, keeping the iterate at TT ranks Rx.
N_ITER_RIEMANNIAN = 20   # number of gradient steps
alpha = 1.0              # fixed step size (loop-invariant, so set once)

x0 = tntt.randn(N, Rx)   # random starting point; reused by the baseline comparison below
x = x0.clone()
for i in range(N_ITER_RIEMANNIAN):
    # Riemannian gradient of func at x, computed with automatic differentiation
    gr = tntt.manifold.riemannian_gradient(x, func)

    # Gradient step, then round back to ranks Rx
    # (presumably round(eps=0, rmax=Rx) — TODO confirm the argument meaning).
    x = (x - alpha * gr).round(0, Rx)
    print('Value ', func(x).numpy())

Value  166936.04052405598
Value  146525.46293949717
Value  62837.138985644546
Value  11552.887519468724
Value  1495.7546192129391
Value  0.08833064372059891
Value  4.2224995893584454e-12
Value  5.8289872931809965e-24
Value  2.543323489500019e-24
Value  1.7459105836121516e-24
Value  2.0984335632379473e-24
Value  6.986298575431298e-25
Value  4.746370819895353e-25
Value  5.381857816753326e-25
Value  1.110581856180454e-24
Value  4.4582863326927625e-25
Value  8.444872233305554e-25
Value  4.542892402847257e-25
Value  3.107769118003823e-25
Value  5.268226074594466e-25


For comparison, we run conventional gradient descent directly on the TT cores, starting from the same initial guess `x0`:

In [6]:
# Baseline: plain gradient descent on the individual TT cores, started from the same x0.
N_ITER_CORES = 1000   # number of gradient steps
REPORT_EVERY = 100    # print the objective every this many steps
alpha = 0.00001       # small fixed step size, for stability

y = x0.detach().clone()
for i in range(N_ITER_CORES):
    tntt.grad.watch(y)                # track gradients w.r.t. y's cores
    fval = func(y)
    deriv = tntt.grad.grad(fval, y)   # one gradient per core, in core order
    # Step every core. The detached cores keep the new TT free of the old autograd graph.
    # Pairing with zip avoids reusing `i` (the iteration counter) as a core index.
    y = tntt.TT([core.detach() - alpha * d_core for core, d_core in zip(y.cores, deriv)])
    if (i + 1) % REPORT_EVERY == 0:
        print(func(y))
    

tensor(167369.1985, dtype=torch.float64)
tensor(167227.2266, dtype=torch.float64)
tensor(167101.0193, dtype=torch.float64)
tensor(166421.8154, dtype=torch.float64)
tensor(146771.9458, dtype=torch.float64)
tensor(75280.4094, dtype=torch.float64)
tensor(15629.6864, dtype=torch.float64)
tensor(4262.5242, dtype=torch.float64)
tensor(2301.5486, dtype=torch.float64)
tensor(1161.8528, dtype=torch.float64)


### Manifold tensor completion

Another task where manifold learning can be applied is tensor completion.
The goal is to reconstruct a tensor in the TT format from only a small number of its entries, which may be noisy.

In [7]:
# Seed torch's global RNG so the observed indices, noise and starting point are reproducible.
# NOTE(review): assumes tntt.randn draws from torch's global generator — TODO confirm.
tn.manual_seed(0)

# ---- Target: a smooth function sampled on an N^4 grid over [0, 1]^4, compressed to TT ----
N = 25
Xs = tntt.meshgrid([tn.linspace(0, 1, N, dtype=tn.float64)] * 4)
target = Xs[0] + 1 + Xs[1] + Xs[2] + Xs[3] + Xs[0]*Xs[1] + Xs[1]*Xs[2] + tntt.TT(tn.sin(Xs[0].full()))
target = target.round(1e-10)
print(target.R)

# ---- Observations: M random entries (repeats possible), perturbed by Gaussian noise ----
M = 15000              # number of observations
indices = tn.randint(0, N, (M, 4))
sigma_noise = 0.00001  # standard deviation of the observation noise
obs = tn.normal(target.apply_mask(indices), sigma_noise)


def loss(x):
    """Squared residual between x and the noisy observations on the sampled entries.

    NOTE(review): assumes `apply_mask(indices)` returns the entries of x at `indices`.
    """
    return (x.apply_mask(indices) - obs).norm()**2


# ---- Riemannian gradient descent on the fixed-rank TT manifold ----
N_ITER = 10000       # number of gradient steps
REPORT_EVERY = 100   # log progress every this many steps
step_size = 1.0      # fixed step size (loop-invariant)

print('Riemannian gradient descent\n')
x = tntt.randn([N]*4, [1, 4, 4, 4, 1])   # random starting point

tme = datetime.datetime.now()
for i in range(N_ITER):
    gr = tntt.manifold.riemannian_gradient(x, loss)

    # Gradient step, then round back to the current ranks of x
    R = x.R
    x = (x - step_size * gr).round(0, R)

    if (i + 1) % REPORT_EVERY == 0:
        loss_value = loss(x)
        rel_err = (x - target).norm() / target.norm()
        # The last value is ||x||^2 (squared norm), so the label says so.
        print('Iteration %4d loss value %e error %e squared tensor norm %e' % (i + 1, loss_value.numpy(), rel_err, x.norm()**2))

tme = datetime.datetime.now() - tme
print('')
print('Time elapsed', tme)
print('Number of observations %d, tensor shape %s, percentage of entries observed %6.4f' % (M, str(x.N), 100*M/np.prod(x.N)))
print('Number of unknowns %d, number of observations %d, DoF/observations %.6f' % (tntt.numel(x), M, tntt.numel(x)/M))

# Report the TT ranks that survive rounding to relative accuracy 1e-6.
print('Rank after rounding', x.round(1e-6).R)

[1, 3, 3, 2, 1]
Riemannian gradient descent

Iteration  100 loss value 3.665990e+02 error 7.645483e-02 tensor norm 6.364616e+06
Iteration  200 loss value 1.839140e+02 error 5.950258e-02 tensor norm 6.554736e+06
Iteration  300 loss value 1.731703e+02 error 5.765866e-02 tensor norm 6.567661e+06
Iteration  400 loss value 1.106916e+02 error 5.210376e-02 tensor norm 6.572372e+06
Iteration  500 loss value 3.185584e+01 error 4.498703e-02 tensor norm 6.578676e+06
Iteration  600 loss value 1.956111e+01 error 4.345280e-02 tensor norm 6.578665e+06
Iteration  700 loss value 1.091974e+01 error 4.213523e-02 tensor norm 6.578725e+06
Iteration  800 loss value 7.423510e+00 error 4.119987e-02 tensor norm 6.578767e+06
Iteration  900 loss value 5.847073e+00 error 4.049400e-02 tensor norm 6.578576e+06
Iteration 1000 loss value 4.242343e+00 error 3.987350e-02 tensor norm 6.578378e+06
Iteration 1100 loss value 2.142366e+00 error 3.934743e-02 tensor norm 6.578530e+06
Iteration 1200 loss value 1.057562e+00 err

Iteration 10000 loss value 2.433817e-01 error 3.620857e-02 tensor norm 6.581354e+06

Time elapsed 0:02:46.016980
Number of observations 15000, tensor shape [25, 25, 25, 25], percentage of entries observed 3.8400
Number of unknowns 1000, number of observations 15000, DoF/observations 0.066667
Rank after rounding TT with sizes and ranks:
N = [25, 25, 25, 25]
R = [1, 4, 4, 4, 1]

Device: cpu, dtype: torch.float64
#entries 1000 compression 0.00256

