In [1]:
import matplotlib.pyplot as plt
import numpy as np
import teneva                          # NOTE: version 0.11.3 is required !
from time import perf_counter as tpc
np.random.seed(42)  # Seed NumPy's global RNG for reproducibility; presumably teneva.tensor_rand draws from it — confirm

In [2]:
def cross_matrix_spec(f, I1, I2, opts=None):
    """Approximate the matrix ``f([I1[i], I2[j]])`` with a QTT cross and return it densely.

    Every row of ``I1`` is paired with every row of ``I2``. The resulting
    ``n1 x n2`` matrix (``n1 = 2**q1``, ``n2 = 2**q2``) is viewed as a tensor
    with ``q1 + q2`` binary modes. That tensor is approximated with
    ``teneva.cross`` instead of being evaluated entry by entry.

    Args:
        f: batch function; receives a 2D array whose rows are the
            concatenated multi-indices ``[I1[i], I2[j]]`` and returns their values.
        I1: 2D array with ``n1`` rows (left parts of the multi-indices).
        I2: 2D array with ``n2`` rows (right parts of the multi-indices).
        opts: dict with keys ``'r'``, ``'nswp'``, ``'e'``, ``'dr_min'``,
            ``'dr_max'`` for the inner QTT cross. If None, the notebook-level
            ``opts_qtt`` is used (looked up at call time).

    Returns:
        ``(Y, m)``: ``Y`` is the dense approximation of shape ``(n1, n2)``;
        ``m`` is the evaluation count that ``teneva.cross`` reports in ``info['m']``.

    Raises:
        NotImplementedError: if ``n1`` or ``n2`` is not a positive power of 2.
    """
    def exact_log2(size):
        # Integer bit test: a positive power of two has exactly one bit set.
        # Avoids float log2, which fails with OverflowError for size == 0.
        if size < 1 or size & (size - 1):
            raise NotImplementedError('Mode size and rank should be power of 2')
        return int(size).bit_length() - 1

    n1 = I1.shape[0]
    n2 = I2.shape[0]

    q1 = exact_log2(n1)
    q2 = exact_log2(n2)
    q = q1 + q2

    if opts is None:
        # Resolved lazily so the settings cell only needs to run before the
        # first call, not before this definition.
        opts = opts_qtt

    def func(I):
        # I has q1 + q2 columns of binary (mode size 2) indices: the first q1
        # select a row of I1, the remaining q2 select a row of I2.
        I1_base = teneva.ind_qtt_to_tt(I[:, :q1], q1)[:, 0]
        I2_base = teneva.ind_qtt_to_tt(I[:, q1:], q2)[:, 0]
        I1_base = I1[I1_base, :]
        I2_base = I2[I2_base, :]
        I_base = np.hstack((I1_base, I2_base))
        return f(I_base)

    info = {}
    Y = teneva.tensor_rand([2]*q, r=opts['r'])
    Y = teneva.cross(func, Y, nswp=opts['nswp'], e=opts['e'],
        dr_min=opts['dr_min'], dr_max=opts['dr_max'],
        cache={}, info=info, log=False)
    # Fortran order: the first q1 binary modes form the row index of I1.
    Y = teneva.full(Y).reshape(n1, n2, order='F')

    return Y, info['m']

In [3]:
def _ones(k, m=1):
    return np.ones((k, m), dtype=int)

In [4]:
def _reshape(A, n, order='F'):
    return np.reshape(A, n, order=order)

In [5]:
def _func_new(f, Ig, Ir, Ic, info, cache=None):
    if cache is not None:
        raise NotImplementedError('Cache is not supported in fast cross')
    
    n = Ig.shape[0]
    r1 = Ir.shape[0] if Ir is not None else 1
    r2 = Ic.shape[0] if Ic is not None else 1

    if info['m_max'] is not None and info['m'] + r1*n*r2 > info['m_max']:
        return None

    if Ic is not None:
        I2 = Ic
    else:
        I = np.kron(np.kron(_ones(r2), Ig), _ones(r1))
        Ir_ = np.kron(_ones(n * r2), Ir)
        I = np.hstack((Ir_, I))
        
        y = f(I)
        info['m'] += len(y)
        
        return _reshape(y, (r1, n, r2))

    I1 = np.kron(Ig, _ones(r1))
    if Ir is not None:
        I1 = np.hstack((np.kron(_ones(n), Ir), I1))
    
    y, m = cross_matrix_spec(f, I1, I2)
    info['m'] += m
    
    return _reshape(y, (r1, n, r2))

In [6]:
def func_cross(func, Y0, nswp=1, with_info=False, is_new=False):
    """Run one TT-cross on a teneva demo function and collect summary stats.

    Steps:
    1. Reset the function object (``clear``) and initialize it from ``Y0`` (``prep``).
    2. Run ``func.cross`` with ``dr_max=0`` and no cache.
    3. If ``is_new`` is true, pass ``_func_new`` as the per-core evaluator and
       append ``'-FAST'`` to ``func.method`` so both variants are distinguishable.
    4. Call ``func.check()``, and optionally log via ``func.info``.

    Args:
        func: demo function object (as yielded by ``teneva.func_demo_all``).
        Y0: initial TT-tensor for the cross.
        nswp: number of sweeps.
        with_info: if true, log a summary line including ``func.m``.
        is_new: if true, use the fast per-core evaluator ``_func_new``.

    Returns:
        dict with ``'t'`` -> ``func.t``, ``'m'`` -> ``func.m``,
        ``'e'`` -> ``func.e_tst_ind``.
    """
    func.clear()
    func.prep(Y0)

    # cache must stay None: _func_new rejects any cache.
    core_evaluator = None
    if is_new:
        core_evaluator = _func_new
    func.cross(nswp=nswp, dr_max=0, cache=None, func=core_evaluator)

    if is_new:
        func.method = func.method + '-FAST'

    func.check()

    if with_info:
        func.info(f'm = {func.m:7.1e}')

    return dict(t=func.t, m=func.m, e=func.e_tst_ind)

In [30]:
# Settings for the inner QTT cross performed by cross_matrix_spec.
opts_qtt = dict(
    r=2,          # rank of the random initial QTT tensor (teneva.tensor_rand)
    e=1.E-6,      # forwarded as `e` to teneva.cross (presumably the stopping tolerance)
    nswp=999,     # forwarded as `nswp` to teneva.cross (sweep limit)
    dr_min=1,     # forwarded as `dr_min` to teneva.cross
    dr_max=1,     # forwarded as `dr_max` to teneva.cross
)

In [31]:
# Benchmark configuration.
d = 10            # Dimension
n = 256           # Grid size
r = 4             # TT-rank
nswp = 1          # Number of sweeps
with_info = True  # If true, then logs will be printed

# Rough size of one sweep: d * n * r^2 points.
sweep_size = d * n * r * r
print(f'One sweep is : {sweep_size:-7.1e}')

One sweep is : 4.1e+04


In [32]:
# Compare the standard TT-cross ('old') with the variant that fills each core
# through the inner QTT cross ('new'). Both start from the same random Y0.
result = {}

for func in teneva.func_demo_all(d):
    if func.name == 'Dixon':
        # NOTE(review): Dixon is excluded without a stated reason — document why.
        continue

    func.set_grid(n, kind='uni')
    func.build_tst_ind(1.E+3)

    Y0 = teneva.tensor_rand(func.n, r)

    # Each value is the stats dict returned by func_cross (keys 't', 'm', 'e').
    result[func.name] = {
        'old': func_cross(func, Y0, nswp, with_info),
        'new': func_cross(func, Y0, nswp, with_info, is_new=True),
    }

    if with_info:
        print()

Ackley          [CRO          ] > error: 2.6e-03 | rank:  4.0 | time:   0.116 | m = 7.0e+04
Ackley          [CRO-FAST     ] > error: 2.6e-03 | rank:  4.0 | time:   3.118 | m = 4.1e+04

Alpine          [CRO          ] > error: 1.5e-15 | rank:  2.0 | time:   0.118 | m = 7.0e+04
Alpine          [CRO-FAST     ] > error: 3.0e-08 | rank:  2.0 | time:   2.336 | m = 3.6e+04

Exponential     [CRO          ] > error: 3.0e-15 | rank:  1.0 | time:   0.111 | m = 7.0e+04
Exponential     [CRO-FAST     ] > error: 9.0e-15 | rank:  1.0 | time:   0.934 | m = 1.3e+04

Grienwank       [CRO          ] > error: 3.0e-15 | rank:  3.0 | time:   0.135 | m = 7.0e+04
Grienwank       [CRO-FAST     ] > error: 4.8e-15 | rank:  3.0 | time:   0.967 | m = 1.4e+04

Michalewicz     [CRO          ] > error: 1.7e-15 | rank:  2.0 | time:   0.146 | m = 7.0e+04
Michalewicz     [CRO-FAST     ] > error: 6.6e-04 | rank:  2.8 | time:   2.612 | m = 3.7e+04

Qing            [CRO          ] > error: 1.1e-15 | rank:  2.0 | time:   0.1