In [1]:
"""
Manifold gradient module.

"""

import torch as tn
from torchtt._decomposition import mat_to_tt, to_tt, lr_orthogonal, round_tt, rl_orthogonal
from torchtt import TT
from torchtt.errors import *

def _delta2cores(tt_cores, R, Sds, is_ttm = False, ortho = None):
    """
    Convert the detla notation to TT.
    Implements Algorithm 5.1 from "AUTOMATIC DIFFERENTIATION FOR RIEMANNIAN OPTIMIZATION ON LOW-RANK MATRIX AND TENSOR-TRAIN MANIFOLDS".

    Args:
        tt_cores (list[torch.tensor]): the TT cores.
        R (list[int]): the rank of the tensor.
        Sds (list[torch.tensor]): deltas.
        # if the number of dimentions of tt format and approximated array match
        is_ttm (bool, optional): is TT matrix or not. Defaults to False.
        ortho (list[list[torch.tensor]], optional): the left and right orthogonal cores of tt_cores. Defaults to None.

    Returns:
        list[torch.tensor]: the resulting TT cores.
    """
    
    if ortho == None:
        l_cores,_  = lr_orthogonal(tt_cores, R, is_ttm)
        r_cores,_  = rl_orthogonal(tt_cores, R, is_ttm)
    else:
        l_cores = ortho[0]
        r_cores = ortho[1]
    
    # first
    cores_new = [tn.cat((Sds[0],l_cores[0]),2 if not is_ttm else 3)]
    # 2...d-1
    for k in range(1,len(tt_cores)-1):
        up = tn.cat((r_cores[k],tn.zeros((r_cores[k].shape),dtype = l_cores[0].dtype, device = l_cores[0].device)),2 if not is_ttm else 3)
        down = tn.cat((Sds[k],l_cores[k]),2 if not is_ttm else 3)
        cores_new.append(tn.cat((up,down),0))
    # last
    cores_new.append(tn.cat((r_cores[-1],Sds[-1]),0))
    
    return cores_new

def riemannian_gradient(x,func):
    """
    Compute the Riemannian gradient using AD.

    `func` is evaluated on the TT assembled by `_delta2cores` from a set of seed tensors (a copy of
    the first core returned by `rl_orthogonal` and zeros for the remaining ones). The gradients of
    the result with respect to the seeds are then corrected with the cores returned by
    `lr_orthogonal` and converted back to a TT.

    Args:
        x (torchtt.TT): the point on the manifold where the gradient is computed.
        func (callable): function that has to be differentiated. The function takes as only argument `torchtt.TT` instances and must return a value on which `backward()` can be called.

    Returns:
        torchtt.TT: the gradient projected on the tangent space of x.
    """

    is_ttm = x.is_ttm
    R = x.R
    d = len(x.N)

    l_cores,_  = lr_orthogonal(x.cores, R, is_ttm)
    r_cores,_  = rl_orthogonal(l_cores, R, is_ttm)

    # Differentiation seeds. They are created as fresh leaf tensors so that `.grad` gets populated
    # even when the cores of x are themselves tracked by autograd, and so that switching on
    # gradients does not modify r_cores[0] (which is also handed on through `ortho`) in place.
    Rs = [ r_cores[0].detach().clone() ]
    Rs += [ tn.zeros_like(x.cores[i]) for i in range(1,d) ]

    # AD part
    for seed in Rs:
        seed.requires_grad_(True)
    Ghats = _delta2cores(x.cores, R, Rs, is_ttm = is_ttm, ortho = [l_cores,r_cores])
    fval = func(TT(Ghats))
    fval.backward()

    Sds = [seed.grad for seed in Rs]

    # compute Sdeltas: for every core but the last, subtract UL @ (UL.T @ D), where UL is the
    # matricized core from lr_orthogonal and D the matricized gradient
    for k in range(d-1):
        D = tn.reshape(Sds[k],[-1,R[k+1]])
        UL = tn.reshape(l_cores[k],[-1,R[k+1]])
        D = D - UL @ (UL.T @ D)
        Sds[k] = tn.reshape(D,l_cores[k].shape)

    # delta to TT
    grad_cores = _delta2cores(x.cores, R, Sds, is_ttm, ortho = [l_cores,r_cores])
    return TT(grad_cores)
        
def riemannian_projection(Xspace,z):
    """
    Project the tensor z onto the tangent space defined at Xspace.

    Args:
        Xspace (torchtt.TT): the target where the tensor should be projected.
        z (torchtt.TT): the tensor that should be projected.

    Raises:
        IncompatibleTypes: Both must be of same type.

    Returns:
        torchtt.TT: the projection.
    """

    if Xspace.is_ttm != z.is_ttm:
        raise IncompatibleTypes('Both must be of same type.')

    is_ttm = Xspace.is_ttm

    l_cores,R  = lr_orthogonal(Xspace.cores, Xspace.R, is_ttm)
    r_cores,_  = rl_orthogonal(l_cores, R, is_ttm)

    d = len(Xspace.N)

    # einsum labels of the non-rank indices of a core: one for a TT, two for a TT matrix
    m = 'ij' if is_ttm else 'i'
    left_step  = 'rs,r%sR,s%sS->RS' % (m, m)
    right_step = 'RS,r%sR,s%sS->rs' % (m, m)
    apply_left = 'rs,s%sS->r%sS' % (m, m)
    apply_core = 'r%sR,RS->r%sS' % (m, m)
    apply_right = 'r%sS,RS->r%sR' % (m, m)

    one = tn.ones((1,1), dtype = Xspace.cores[0].dtype, device = Xspace.cores[0].device)

    # Pleft[k]: cores 0..k of l_cores contracted with cores 0..k of z
    Pleft = []
    tmp = one
    for k in range(d-1):
        tmp = tn.einsum(left_step, tmp, l_cores[k], z.cores[k]) # size rk x sk
        Pleft.append(tmp)

    # after the reversal, Pright[k]: cores k+1..d-1 of r_cores contracted with cores k+1..d-1 of z
    Pright = []
    tmp = one
    for k in range(d-1,0,-1):
        tmp = tn.einsum(right_step, tmp, r_cores[k], z.cores[k]) # size rk x sk
        Pright.append(tmp)
    Pright = Pright[::-1]

    # compute elements of the tangent space
    Sds = []
    for k in range(d):
        L = one if k == 0 else Pleft[k-1]
        Lz = tn.einsum(apply_left, L, z.cores[k])
        if k == d-1:
            Sds.append(Lz)
        else:
            # Pleft[k] is exactly the contraction of L, l_cores[k] and z.cores[k], so it is reused
            # here instead of being recomputed
            along_core = tn.einsum(apply_core, l_cores[k], Pleft[k])
            # the right factor gets its own name so that the rank list R is not overwritten
            Sds.append(tn.einsum(apply_right, Lz - along_core, Pright[k]))

    # convert Sds to TT
    grad_cores = _delta2cores(Xspace.cores, R, Sds, is_ttm, ortho = [l_cores,r_cores])

    return TT(grad_cores)

In [2]:
import torch as tn
import datetime
import numpy as np
import torchtt as tntt

In [3]:
# mode sizes shared by the target and the iterates
N = [10,11,12,13,14]
# ranks used to draw the random target
Rt = [1,3,4,5,6,1]
# ranks used for the starting point and passed to round() after every step
Rx = [1,6,6,6,6,1]

target = tntt.randn(N,Rt).round(0)

def func(x):
    """Objective of the rank-reduction experiment; `target` is looked up at call time."""
    return 0.5*(x-target).norm(True)

## Rank reduction

### Riemannian gradient descent

In [6]:
# random starting point; x0 is kept untouched so the classical run can start from the same tensor
x0 = tntt.randn(N,Rx)
x = x0.clone()

# stepsize length (constant over the whole run)
alpha = 1.0

for step in range(20):
    # compute riemannian gradient using AD
    gr = tntt.manifold.riemannian_gradient(x,func)

    # update step, followed by round(0,Rx)
    x = (x - alpha*gr).round(0,Rx)
    print('Value ' , func(x).numpy())

Value  244833.65627034567
Value  188978.66681149497
Value  67445.76635865941
Value  12268.368453471641
Value  3199.2757887970793
Value  152.5318579530042
Value  0.012728872401420278
Value  8.501112842678166e-15
Value  1.2381654377069425e-23
Value  6.727284734463914e-24
Value  4.765436867627989e-24
Value  4.3669895779394166e-24
Value  6.045341163657253e-24
Value  5.906954016533238e-24
Value  4.837892670979434e-24
Value  3.557827855807675e-24
Value  3.512077168862015e-24
Value  4.1102357213128275e-24
Value  4.7845678809676356e-24
Value  3.574841193177573e-24


### Classical gradient descent

In [7]:
# gradient descent directly on the TT cores, starting from the same x0 as the Riemannian run
y = x0.detach().clone()
alpha = 0.00001 # for stability

for i in range(1000):
    tntt.grad.watch(y)
    fval = func(y)
    deriv = tntt.grad.grad(fval,y)
    # one descent step per core, rebuilt from detached cores
    y = tntt.TT([y.cores[k].detach() - alpha*dk for k, dk in enumerate(deriv)])
    if (i+1)%100 == 0:
        print(func(y))

tensor(246470.4807, dtype=torch.float64)
tensor(246132.1552, dtype=torch.float64)
tensor(242583.9335, dtype=torch.float64)
tensor(138157.3856, dtype=torch.float64)
tensor(15593.6928, dtype=torch.float64)
tensor(3666.3339, dtype=torch.float64)
tensor(2173.1221, dtype=torch.float64)
tensor(1188.6520, dtype=torch.float64)
tensor(601.5732, dtype=torch.float64)
tensor(295.4684, dtype=torch.float64)


## Tensor completion

In [10]:
N = 25
# target: a function sampled on a 4d grid of N points per axis, then rounded
# (the previous random `target` drawn here was overwritten before use and has been dropped)
Xs = tntt.meshgrid([tn.linspace(0,1,N, dtype = tn.float64)]*4)
target = Xs[0]+1+Xs[1]+Xs[2]+Xs[3]+Xs[0]*Xs[1]+Xs[1]*Xs[2]+tntt.TT(tn.sin(Xs[0].full()))
target = target.round(1e-10)
print(target.R)

M = 15000 # number of observations 
# random multi-indices; tn.randint may return the same entry more than once
indices = tn.randint(0,N,(M,4))

# observations are considered to be noisy
sigma_noise = 0.00001
obs = tn.normal(target.apply_mask(indices), sigma_noise)

# define the loss function
loss = lambda x: (x.apply_mask(indices)-obs).norm()**2

#%% Manifold learning
print('Riemannian gradient descent\n')
# starting point
x = tntt.randn([N]*4,[1,4,4,4,1])

# constant step size, hoisted out of the loop
step_size = 1.0

tme = datetime.datetime.now()
# iterations
for i in range(10000):
    # manifold gradient 
    gr = tntt.manifold.riemannian_gradient(x,loss)

    # ranks of the current iterate, passed to round() after the step
    R = x.R
    # step update
    x = (x - step_size * gr).round(0,R)

    # compute loss value
    if (i+1)%100 == 0:
        loss_value = loss(x)
        print('Iteration %4d loss value %e error %e tensor norm %e'%(i+1,loss_value.numpy(),(x-target).norm()/target.norm(), x.norm()**2))

tme = datetime.datetime.now() - tme
print('')
print('Time elapsed',tme)
print('Number of observations %d, tensor shape %s, percentage of entries observed %6.4f'%(M,str(x.N),100*M/np.prod(x.N)))
print('Number of unknowns %d, number of observations %d, DoF/observations %.6f'%(tntt.numel(x),M,tntt.numel(x)/M))

print('Rank after rounding',x.round(1e-6))

[1, 3, 3, 2, 1]
Riemannian gradient descent

Iteration  100 loss value 2.669319e+02 error 5.536844e-02 tensor norm 6.433788e+06
Iteration  200 loss value 1.755638e+02 error 4.436142e-02 tensor norm 6.557243e+06
Iteration  300 loss value 8.471260e+01 error 3.680423e-02 tensor norm 6.567885e+06
Iteration  400 loss value 8.314812e+00 error 2.808722e-02 tensor norm 6.575853e+06
Iteration  500 loss value 1.522958e+00 error 2.557929e-02 tensor norm 6.576202e+06
Iteration  600 loss value 5.109728e-01 error 2.490470e-02 tensor norm 6.576524e+06
Iteration  700 loss value 2.870835e-01 error 2.460458e-02 tensor norm 6.576592e+06
Iteration  800 loss value 1.892848e-01 error 2.443801e-02 tensor norm 6.576565e+06
Iteration  900 loss value 1.423169e-01 error 2.434341e-02 tensor norm 6.576523e+06
Iteration 1000 loss value 1.186208e-01 error 2.428746e-02 tensor norm 6.576488e+06
Iteration 1100 loss value 1.061103e-01 error 2.425224e-02 tensor norm 6.576462e+06
Iteration 1200 loss value 9.920543e-02 err

Iteration 10000 loss value 8.268513e-02 error 2.411061e-02 tensor norm 6.576322e+06

Time elapsed 0:03:33.777151
Number of observations 15000, tensor shape [25, 25, 25, 25], percentage of entries observed 3.8400
Number of unknowns 1000, number of observations 15000, DoF/observations 0.066667
Rank after rounding TT with sizes and ranks:
N = [25, 25, 25, 25]
R = [1, 4, 4, 4, 1]

Device: cpu, dtype: torch.float64
#entries 1000 compression 0.00256



In [13]:
from e3nn.o3 import wigner_3j

In [19]:
# matrix product (batched over the leading index) of the (1,1,1) Wigner 3j tensor with itself
tn.matmul(wigner_3j(1, 1, 1), wigner_3j(1, 1, 1))

tensor([[[0.0000, 0.0000, 0.0000],
         [0.0000, 0.1667, 0.0000],
         [0.0000, 0.0000, 0.1667]],

        [[0.1667, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.1667]],

        [[0.1667, 0.0000, 0.0000],
         [0.0000, 0.1667, 0.0000],
         [0.0000, 0.0000, 0.0000]]])

In [21]:
# NOTE(review): this cell raises the RuntimeError recorded below (sizes 3 and 7 clash at dimension 0).
# Presumably `.T` reverses all dimensions of the 3-d tensor, so the leading (batch) dimensions of the
# two operands no longer agree. TODO confirm which contraction was intended (e.g. a sum over the last
# two indices via einsum) and replace this expression.
wigner_3j(1, 2, 3) @ wigner_3j(1, 2, 3).T

RuntimeError: The size of tensor a (3) must match the size of tensor b (7) at non-singleton dimension 0