In [1]:
import numpy as np
import torch
import time

In [2]:
def groundtruth(left, right, matrix):
    """
    Black-box style: first assemble the query indices (tensor of shape Rl x 2 x Rr x |F|), then compute the function on them: for each x, f(x) = x^T matrix x.

    Each query is the concatenation [left[i], b, right[j]] with b in {0, 1};
    results are flattened in (i, b, j) row-major order.

    :param left: a matrix of shape Rl x M (the left cross interface)
    :param right: a matrix of shape Rr x (|F|-1-M) (the right cross interface)
    :param matrix: the QUBO interaction matrix of shape |F| x |F|
    :return: a vector of length Rl*2*Rr holding f(x) for every assembled query x
    """
    Rl = left.shape[0]
    Rr = right.shape[0]
    F = matrix.shape[0]
    # The middle bit takes values {0, 1}; create it with the inputs' dtype/device
    # so the concatenation also works for float64 or GPU tensors.
    middle = torch.arange(2, dtype=left.dtype, device=left.device)
    # expand() broadcasts without copying; torch.cat materializes the result once.
    indices = torch.cat([
        left[:, None, None, :].expand(Rl, 2, Rr, -1),
        middle[None, :, None, None].expand(Rl, 2, Rr, 1),
        right[None, None, :, :].expand(Rl, 2, Rr, -1),
    ], dim=-1)
    queries = indices.reshape(-1, F)
    return torch.einsum('ri,ij,rj->r', queries, matrix, queries)


def faster(left, right, matrix):
    """
    "Intrusive" approach: use the fact the query indices are structured (there is redundant information) to make the evaluation faster.

    Writing each query as x = [l, b, r], the row vector x^T matrix splits into
    l^T matrix[:M] + b * matrix[M] + r^T matrix[M+1:], so each partial product is
    computed once per distinct l, b and r and then broadcast over all queries.

    :param left: a matrix of shape Rl x M (the left cross interface)
    :param right: a matrix of shape Rr x (|F|-1-M) (the right cross interface)
    :param matrix: the QUBO interaction matrix of shape |F| x |F|
    :return: a vector of length Rl*2*Rr holding f(x) = x^T matrix x, flattened in
        (i, b, j) row-major order, like `groundtruth`
    """
    Rl = left.shape[0]
    Rr = right.shape[0]
    F = matrix.shape[0]
    # The split point is the width of the left interface. Derive it from the
    # input instead of relying on a global `M` defined in a notebook cell.
    M = left.shape[1]
    # Middle bit {0, 1}, matching the matrix's dtype/device so the matmul below
    # also works for float64 or GPU tensors.
    middle = torch.arange(2, dtype=matrix.dtype, device=matrix.device)
    # expand() broadcasts without copying; torch.cat materializes the result once.
    indices = torch.cat([
        left[:, None, None, :].expand(Rl, 2, Rr, -1),
        middle[None, :, None, None].expand(Rl, 2, Rr, 1),
        right[None, None, :, :].expand(Rl, 2, Rr, -1),
    ], dim=-1)

    U1 = left @ matrix[:M, :]  # Cost: O(Rl x M x |F|)
    U2 = middle[:, None] @ matrix[M:M+1, :]  # Cost: O(2 x 1 x |F|)
    U3 = right @ matrix[M+1:, :]  # Cost: O(Rr x (|F|-1-M) x |F|)

    # Row vectors x^T matrix for every query, assembled by broadcasting.
    row_products = U1[:, None, None, :] + U2[None, :, None, :] + U3[None, None, :, :]  # Cost: O(Rl x 2 x Rr x |F|)
    return torch.einsum('rn,rn->r', row_products.reshape(-1, F), indices.reshape(-1, F))  # Cost: O(Rl x 2 x Rr x |F|)

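Let's test it: we benchmark both implementations on a random QUBO instance and check that they return the same values (up to floating-point round-off).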

In [4]:
# Benchmark both implementations on a random QUBO instance and check they agree.
torch.manual_seed(0)  # make the random instance reproducible

Rl = 25  # Tensor rank from the left
Rr = 30  # Tensor rank from the right
F = 8000  # Number of QUBO features
M = F//2  # Number of features covered by the left interface
matrix = torch.rand(F, F)  # The QUBO interaction matrix

left = torch.randint(0, 2, [Rl, M]).float()  # Left interface of cross
right = torch.randint(0, 2, [Rr, F-1-M]).float()  # Right interface of cross

# perf_counter() is monotonic and high-resolution, unlike the wall clock time.time().
start = time.perf_counter()
result_groundtruth = groundtruth(left, right, matrix)
print('Elapsed (groundtruth):', time.perf_counter()-start)

start = time.perf_counter()
result_faster = faster(left, right, matrix)
print('Elapsed (faster):', time.perf_counter()-start)

relative_error = torch.linalg.norm(result_faster - result_groundtruth) / torch.linalg.norm(result_groundtruth)
print('Relative error:', relative_error)
# float32 round-off is ~1e-7 here; anything much larger means the two methods disagree.
assert relative_error < 1e-5, 'faster() disagrees with groundtruth()'

4000
Elapsed (groundtruth): 1.706371784210205
Elapsed (faster): 0.10840201377868652
Relative error: tensor(6.3996e-08)


In [3]:
# Re-run of the same benchmark (identical setup); only the timings vary between runs.
torch.manual_seed(0)  # make the random instance reproducible

Rl = 25  # Tensor rank from the left
Rr = 30  # Tensor rank from the right
F = 8000  # Number of QUBO features
M = F//2  # Number of features covered by the left interface
matrix = torch.rand(F, F)  # The QUBO interaction matrix

left = torch.randint(0, 2, [Rl, M]).float()  # Left interface of cross
right = torch.randint(0, 2, [Rr, F-1-M]).float()  # Right interface of cross

# perf_counter() is monotonic and high-resolution, unlike the wall clock time.time().
start = time.perf_counter()
result_groundtruth = groundtruth(left, right, matrix)
print('Elapsed (groundtruth):', time.perf_counter()-start)

start = time.perf_counter()
result_faster = faster(left, right, matrix)
print('Elapsed (faster):', time.perf_counter()-start)

relative_error = torch.linalg.norm(result_faster - result_groundtruth) / torch.linalg.norm(result_groundtruth)
print('Relative error:', relative_error)
# float32 round-off is ~1e-7 here; anything much larger means the two methods disagree.
assert relative_error < 1e-5, 'faster() disagrees with groundtruth()'

Elapsed (groundtruth): 0.78399658203125
Elapsed (faster): 0.0443267822265625
Relative error: tensor(6.2843e-08)
