In [5]:
from copy import deepcopy as copy
import numpy as np
import teneva
from time import perf_counter as tpc
np.random.seed(42)

In [6]:
def _ones(n):
    return np.ones((n, 1), dtype=int)

In [7]:
def _range(n):
    return np.arange(n).reshape(-1, 1)

In [8]:
def optima_tt_beam(Y, k=100, power=2, ret_all=False):
    teneva.orthogonalize(Y, 0)
    
    G = Y[0]
    r1, n, r2 = G.shape
    Q = G.reshape(n, r2)
    I = _range(n)

    for G in Y[1:]:
        r1, n, r2 = G.shape
        Q = np.einsum("ij,jkl->ikl", Q, G).reshape(-1, r2)
        
        I1 = np.kron(I, _ones(n))
        I2 = np.kron(_ones(I.shape[0]), _range(n))
        I = np.hstack((I1, I2))
        
        norms = np.sum(Q**power, axis=1)
        ind = np.argsort(norms)[:-(k+1):-1]
        I = I[ind]
        Q = Q[ind]

    return I if ret_all else I[0]

In [9]:
class SolverBase:
    def __init__(self, f, n, m, m_tst=1.E+4):
        self.f = f
        self.n = n
        self.m = int(m)
        self.m_tst = int(m_tst)

        self.d = len(self.n)

        self.init()

    def check(self):
        get = teneva.getter(self.Y)

        Z = np.array([get(i) for i in self.I_trn])
        self.e_trn = np.linalg.norm(Z - self.Y_trn)          
        self.e_trn /= np.linalg.norm(self.Y_trn)

        Z = np.array([get(i) for i in self.I_tst])
        self.e_tst = np.linalg.norm(Z - self.Y_tst)          
        self.e_tst /= np.linalg.norm(self.Y_tst)

    def info(self, name='ALS-BASE'):
        if self.e_trn < 0:
            self.check()

        text = ''
        text += name + ' | '
        text += f'e_trn: {self.e_trn:-7.1e} | '
        text += f'e_tst: {self.e_tst:-7.1e} | '
        text += f't_trn: {self.t_trn:-7.1e} | '
        text += f't_tst: {self.t_tst:-7.1e} | '
        text += f'evals: {self.m_cur:-7.1e} | '
        text += f'rank: {teneva.erank(self.Y):-7.1e} | '
        print(text)

    def init(self):
        self.e_trn = -1.
        self.e_tst = -1.
        
        self.t_trn = 0.
        self.t_tst = 0.
        
        self.Y = None
        self.m_cur = 0
        
        self.I_trn = None
        self.Y_trn = None

        return self

    def solve(self, r=3, nswp=30, log=False):
        self._tst_build(self.m_tst)

        _t = tpc()
        self.I_trn, self.Y_trn = self._trn_get(self.m)
        
        self.Y = teneva.rand(self.n, r)

        self.Y = teneva.als(self.I_trn, self.Y_trn, self.Y, nswp)
        self.m_cur += self.m
        self.t_trn += tpc() - _t

        if log:
            self.check()
            self.info()
            
    def _tst_build(self, m=1.E+4):
        _t = tpc()
        I = np.vstack([np.random.choice(self.n[i], int(m)) for i in range(self.d)]).T
        Y = self.f(I)
        self.t_tst += tpc() - _t
        self.I_tst, self.Y_tst = I, Y

    def _trn_get(self, m=1.E+4):
        I = teneva.sample_lhs(self.n, int(m)) 
        Y = self.f(I)
        return I, Y

In [18]:
class Solver(SolverBase):
    def info(self, name='XXX-NEW '):
        return super().info(name)

    def solve(self, e=1.E-5, m0=1.E+2, log=False):
        self._tst_build(self.m_tst)

        t = tpc()
        
        I_trn_1, Y_trn_1 = self._trn_get(m0)
        self.Y1 = teneva.anova(I_trn_1, Y_trn_1, r=2)
        
        I_trn_2, Y_trn_2 = self._trn_get(m0)
        self.Y2 = teneva.anova(I_trn_2, Y_trn_2, r=2)
        
        self.I_trn = np.vstack((I_trn_1, I_trn_2))
        self.Y_trn = np.hstack((Y_trn_1, Y_trn_2))
        self.m_cur = int(2 * m0)
        
        self.Y = teneva.add(self.Y1, self.Y2)
        self.Y = teneva.mul(self.Y, 0.5)
        self.Y = teneva.truncate(self.Y, e)

        if log:
            self.check()
            self.info()
                
        while self.m_cur <= self.m:
            dY = teneva.sub(self.Y1, self.Y2)
            i = optima_tt_beam(dY)
            y = self.f(i)
            self._trn_add(i, y)
            # print(f'y = {y:-12.5e} | i = ', ' '.join([f'{ii:-2d}' for ii in i]), sum(i))
            
            y1 = teneva.get(self.Y1, i)
            D1 = teneva.tensor_delta(self.n, i, y - y1)
            self.Y1 = teneva.add(self.Y1, D1)
            self.Y1 = teneva.truncate(self.Y1, e)

            y2 = teneva.get(self.Y2, i)
            D2 = teneva.tensor_delta(self.n, i, y - y2)
            self.Y2 = teneva.add(self.Y2, D2)
            self.Y2 = teneva.truncate(self.Y2, e)
            
            self.Y = teneva.add(self.Y1, self.Y2)
            self.Y = teneva.mul(self.Y, 0.5)
            self.Y = teneva.truncate(self.Y, e)

            self.m_cur += 1
            self.t_trn += tpc() - t

            if log and (self.m_cur % 100 == 0 or self.m_cur == self.m):
                self.check()
                self.info()

    def _trn_add(self, i, y):
        if self.I_trn is None:
            self.I_trn = np.array(i).reshape((1, -1))
        else:
            self.I_trn = np.vstack((self.I_trn, i.reshape((1, -1))))
        if self.Y_trn is None:
            self.Y_trn = np.array(y)
        else:
            self.Y_trn = np.hstack((self.Y_trn, np.array(y)))

In [19]:
d = 7
n = 16
m = 1.E+4

for func in teneva.func_demo_all(d):
    func.set_grid(n, kind='uni')

    sl = Solver(func.get_f_ind, func.n, m)
    sl.solve(log=True)
    
    break

XXX-NEW  | e_trn: 2.4e-02 | e_tst: 2.3e-02 | t_trn: 0.0e+00 | t_tst: 5.3e-03 | evals: 2.0e+02 | rank: 2.0e+00 | 
y =  1.65840e+01 | i =   8  6  6  9  7 10  9 55
y =  1.92930e+01 | i =   8  6  6  9 12 10  9 60
y =  1.59425e+01 | i =   8  6  7  9  7 10  9 56
y =  1.71318e+01 | i =   8  6  6  9  9 10  9 57
y =  1.89956e+01 | i =   8 12  6  9  7 10  9 61
y =  1.89956e+01 | i =   8  6  7  9 12 10  9 61
y =  1.97700e+01 | i =   8  6  6  9 13 10  9 61
y =  1.89956e+01 | i =   8  3  6  9  7 10  9 52
y =  1.65840e+01 | i =   8  9  6  9  7 10  9 58
y =  1.71318e+01 | i =   8  9  6  9  9 10  9 60
y =  1.76370e+01 | i =   8  6  6  9 10 10  9 58
y =  1.94910e+01 | i =   8 13  6  9  7 10  9 62
y =  1.85321e+01 | i =   8  6  6  9  4 10  9 52
y =  1.89956e+01 | i =   8  6  6  9  7 10 12 58
y =  1.86547e+01 | i =   8 12  7  9  7 10  9 62
y =  1.92930e+01 | i =   8 12  6  9  9 10  9 63
y =  1.59425e+01 | i =   8  9  7  9  7 10  9 59
y =  1.81577e+01 | i =   8  6 11  9  7 10  9 60
y =  1.71318e+01 | i = 