In [31]:
import numpy as np
import torch
import math

In [48]:
def find_lamb(d, D):

    return D / d

def find_matrix_U(d, D):
    
    G = torch.normal(0, 1, [D, D])
    Q, R = torch.qr(G)
    
    row_norms = torch.norm(Q, dim=1, keepdim=True)
    Q_normalized = Q / row_norms                   # В какой момент нужно проводить нормализацию? 
    
    U = Q_normalized[:d]

    return U

def find_params(D, U: torch.FloatTensor, X: torch.FloatTensor, eps=0.1):      
    # X - array of n d-dimentional torch.FloatTensors; A - first condition set, B - second condition set

    # delta = (1 / 5 ** 4) * ((1 - 1 / (lamb ** 0.5)) ** 2)
    # eta = (3 / 4) + (1 / 4) * (1 / (lamb ** 0.5))
    
    npU = torch.Tensor.numpy(U)
    npX = torch.Tensor.numpy(X)
        
    A_marked = [[x, np.argwhere(x == 0).shape[0]] for x in npX]
    A_sorted = sorted(A_marked, key=lambda x: x[1])
    
    if A_sorted[-1][1] / D >= 1:
        raise Exception("pu pu pu")
    delta = A_sorted[-1][1] / D + eps
    while delta >= 1:
        eps *= eps
        delta = A_sorted[-1][1] / D + eps
    
    B_marked = [[x, np.linalg.norm(np.dot(npU, x)) / np.linalg.norm(x)] for x in npX]
    B_sorted = sorted(B_marked, key=lambda x: x[1])

    if B_sorted[-1][1] >= 1:
        raise Exception("pu pu pu")
    eta = B_sorted[-1][1] + eps
    while eta >= 1:
        eps *= eps
        eta = B_sorted[-1][1] + eps

    return delta, eta

def find_eta(D, U: torch.FloatTensor, X: torch.FloatTensor):

    npU = torch.Tensor.numpy(U)
    npX = torch.Tensor.numpy(X)

    B_marked = [[x, np.linalg.norm(np.dot(npU, x)) / np.linalg.norm(x)] for x in npX]
    B_sorted = sorted(B_marked, key=lambda x: x[1])

    if B_sorted[-1][1] >= 1:
        raise Exception("pu pu pu")
    eta = B_sorted[-1][1] + eps
    while eta >= 1:
        eps *= eps
        eta = B_sorted[-1][1] + eps

    return eta

def hardcode_eta(lamb):
    return (3 / 4) + (1 / 4) * (1 / (math.sqrt(lamb)))

In [40]:
def gridsearch_all_a(X, lamb: int, r: int, delta, eta): 
    # A - array of n D-dimentional torch.FloatTensors a;
    
    A = torch.zeros(X.shape[0], D)   # X.shape[0] - количество n тензоров x 
    
    for i in range(X.shape[0]):
        x = X[i]
        
        for j in range(r):
            a = torch.zeros(D)
            M = torch.norm(x) / (math.sqrt(delta * D))
        
            b = torch.mm(U.T, x.reshape(d, 1))
            b_ = torch.sign(b) * torch.min(torch.abs(b), M)
            x = x - (torch.mm(U, b_)).T
            a = a + b_.T
            M = eta * M

        A[i] = a

    return A


def find_initial_tenzors(A, U: torch.FloatTensor):

    return torch.mm(A, U.T)

In [44]:
tensors = torch.randn(7, 10)

d = tensors.shape[1]
D = 2 * d
lamb = find_lamb(d, D)

U = find_matrix_U(d, D)

delta = 0.9
# eta = find_eta(D, U, tensors)
eta = hardcode_eta(lamb)

A = gridsearch_all_a(tensors, lamb, D, delta, eta)

In [45]:
print(tensors)

tensor([[-1.4279, -0.5669, -0.7766, -0.1700,  1.0480, -0.2874,  1.3806, -0.2000,
          1.3603, -0.7428],
        [-0.7056, -1.1293,  0.5785,  0.4943,  0.8402,  0.8910, -2.0942, -1.0379,
         -0.0473, -1.7113],
        [ 1.9014, -1.4676,  0.5668,  1.6664, -0.0883, -0.0584, -1.1967,  1.3274,
         -0.6571,  0.0211],
        [ 0.4607, -1.3493, -0.4328,  1.8162, -0.6149,  1.5118, -0.3878,  2.2324,
          1.5488,  0.7672],
        [-1.0334, -0.7379,  0.1144,  1.3409,  0.9018,  0.9193, -0.4552,  1.5643,
          0.5804,  0.9108],
        [-0.0184,  2.1789, -0.5542, -1.7763,  1.1840,  0.9995, -0.0973, -0.4935,
         -1.0382,  0.7797],
        [-0.0142, -0.4468,  0.7289, -0.0745,  0.8223, -0.0173, -0.9905,  0.2931,
         -0.1123, -0.4882]])


In [46]:
print(A)

tensor([[ 8.3115e-08, -1.3389e-07,  4.7257e-08, -4.2543e-08, -4.6363e-07,
          3.4409e-07,  3.0667e-07,  5.6985e-07,  2.6006e-07,  4.6097e-07,
          3.2806e-07, -2.3867e-07, -1.4581e-07,  3.6486e-07,  1.8839e-07,
         -2.6275e-07, -4.6939e-07, -8.6685e-08, -3.8949e-07, -2.5736e-07],
        [-4.2570e-07, -3.7640e-08,  5.3613e-07,  2.8840e-07,  3.3858e-08,
         -3.6031e-07, -4.0270e-08,  3.6227e-07,  3.8562e-07, -2.1519e-07,
          1.2471e-07, -3.2318e-08,  3.6127e-07,  5.5914e-07,  2.6575e-07,
         -4.1569e-08,  2.3398e-07, -2.6525e-07,  4.1276e-07,  3.3828e-07],
        [ 1.1175e-05,  3.5788e-06, -1.6169e-05, -4.3347e-06,  3.8772e-05,
          3.8727e-06,  3.8772e-05,  1.8761e-05, -2.4950e-05,  9.4437e-06,
          5.1580e-06,  2.2438e-05,  2.6637e-05, -2.5304e-06,  1.1489e-05,
          4.1491e-06, -1.7504e-06, -1.8929e-05,  7.0530e-06,  1.0579e-05],
        [ 4.5712e-07, -3.3297e-07,  1.1649e-06, -4.9417e-07,  4.8416e-07,
          7.6623e-07, -5.3994e-07, 

In [47]:
find_initial_tenzors(A, U)

tensor([[-9.7319e-07, -1.5987e-07, -5.9050e-08,  1.1001e-07,  8.4945e-09,
         -3.6140e-07, -6.6839e-08,  1.8970e-07, -1.4792e-07, -3.7412e-07],
        [-2.9404e-07,  3.6712e-07,  2.1016e-07,  5.2411e-07, -4.1370e-07,
          1.5960e-07, -2.6191e-07, -5.5541e-07,  2.5969e-07, -4.1262e-07],
        [ 5.8108e-06, -4.0903e-05,  9.5040e-06,  3.8832e-05,  9.6101e-07,
         -1.8322e-05, -1.9220e-05,  1.4477e-06, -1.1697e-05,  1.1844e-05],
        [ 3.1175e-08,  8.2461e-07, -4.1915e-07,  9.7516e-07, -5.0072e-07,
          8.8481e-07, -4.9853e-07,  1.4819e-06,  6.6615e-07, -5.4574e-07],
        [ 1.6440e-08,  4.3485e-07, -2.2103e-07,  5.1424e-07, -2.6405e-07,
          4.6659e-07, -2.6290e-07,  7.8144e-07,  3.5129e-07, -2.8779e-07],
        [-4.6070e-06,  3.2429e-05, -7.5351e-06, -3.0788e-05, -7.6192e-07,
          1.4526e-05,  1.5238e-05, -1.1478e-06,  9.2737e-06, -9.3904e-06],
        [-4.3234e-07,  8.0481e-07,  4.0766e-06,  1.6001e-07,  1.2993e-06,
          1.1663e-06,  3.4016e-0

In [None]:
# a = a.reshape(D, 1) 
# eta = torch.norm(torch.matmul(U, x)) / torch.norm(x)