In [1]:
import time
import torch


class Timer(object):
    """Context manager that prints the wall-clock time spent inside its ``with`` block.

    After the block exits, the elapsed time in seconds is also available as
    ``self.elapsed`` so it can be reused programmatically.
    """

    def __enter__(self):
        # perf_counter is monotonic and high-resolution; time.time() can jump if
        # the system clock is adjusted, which corrupts interval measurements.
        self.t0 = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.elapsed = time.perf_counter() - self.t0
        print('[time spent: {time:.2f}s]'.format(time=self.elapsed))
        # Implicitly returns None (falsy), so exceptions raised in the block propagate.

In [2]:
def operator_adaptation(Z, incoming_list):
    """Reference implementation: refit the linear operator from scratch at every step.

    Consecutive rows of ``Z`` along dim 1 form (input, target) pairs
    ``Z[:, t] -> Z[:, t + 1]``. A batch of operators ``K`` (one (D, D) matrix per
    batch element) is fitted by least squares so that ``Zp @ K ~= Zf``. Then the
    rows of ``incoming_list`` are appended to the history one at a time, and ``K``
    is refitted after each append.

    Args:
        Z: tensor of shape (B, F, D).
        incoming_list: tensor of shape (B, G, D).

    Returns:
        Tensor of shape (B, G + 1, D). Entry 0 is ``Z[:, -1] @ K`` using the initial
        fit. Entry ``i + 1`` is ``incoming_list[:, i] @ K_i``, where ``K_i`` is fitted
        on the history whose last pair ends in ``incoming_list[:, i]``.

    Raises:
        ValueError: if ``incoming_list`` is not 3-D.
    """
    _, G, _ = incoming_list.shape
    F = Z.shape[1]

    # Build the full history once. The fit for step k uses the first F - 1 + k
    # pairs, so slicing replaces re-concatenating a growing tensor on every step.
    Zp_all = torch.cat((Z, incoming_list[:, :-1]), dim=1)    # regressor rows
    Zf_all = torch.cat((Z[:, 1:], incoming_list), dim=1)     # target rows
    queries = torch.cat((Z[:, -1:], incoming_list), dim=1)   # (B, G + 1, D)

    pred_list = []
    for k in range(G + 1):
        rows = F - 1 + k
        # On CPU the default lstsq driver ('gelsy') returns the minimum-norm
        # solution, i.e. (up to rounding) the same K as torch.linalg.pinv(Zp) @ Zf.
        # NOTE(review): CUDA lstsq only offers the 'gels' driver, which assumes a
        # full-rank matrix. Confirm it supports these underdetermined systems
        # (rows < D) before running on GPU.
        K = torch.linalg.lstsq(Zp_all[:, :rows], Zf_all[:, :rows]).solution
        pred_list.append(queries[:, k:k + 1] @ K)

    return torch.cat(pred_list, dim=1)

In [3]:
def operator_adaptation_accelerate(Z, incoming_list, rtol=None):
    """Recursive (rank-one update) version of ``operator_adaptation``.

    This gives the same predictions as refitting ``K = pinv(Zp) @ Zf`` after every
    incoming row, but it avoids recomputing a pseudoinverse each step. Instead it
    updates ``K`` and the projector ``X = pinv(Zp) @ Zp`` with a Greville-style
    rank-one correction for the appended row.

    The correction divides by the squared norm of ``r``, the component of the new
    regressor row ``m`` orthogonal to the current row space of ``Zp``. That
    component can be (numerically) zero, for example:

    * once ``Zp`` already has ``D`` linearly independent rows, or
    * when ``m`` is a zero row or repeats an earlier row.

    In that case the rank-one step is undefined (0/0, or a huge noise-dominated
    step). This function then refits ``K = pinv(Zp) @ Zf`` directly for that step
    and for every remaining step.

    Args:
        Z: tensor of shape (B, F, D). Consecutive rows along dim 1 form the
            initial (input, target) pairs ``Z[:, :-1] -> Z[:, 1:]``.
        incoming_list: tensor of shape (B, G, D) whose rows are appended one at a time.
        rtol: relative tolerance. The new row is treated as lying in the current
            row space when ``||r|| <= rtol * ||m||``. Defaults to
            ``sqrt(torch.finfo(Z.dtype).eps)``.

    Returns:
        Tensor of shape (B, G + 1, D). Entry 0 is the prediction from ``Z[:, -1]``
        with the initial operator. The remaining entries are one prediction per
        incoming row.

    Raises:
        ValueError: if ``incoming_list`` is not 3-D.
    """
    _, G, _ = incoming_list.shape
    if rtol is None:
        rtol = torch.finfo(Z.dtype).eps ** 0.5

    Zp, Zf = Z[:, :-1], Z[:, 1:]
    Zp_inv = torch.linalg.pinv(Zp)
    K = Zp_inv @ Zf   # least-squares operator: Zp @ K ~= Zf
    X = Zp_inv @ Zp   # projector onto the row space of Zp
    n = Z[:, -1:]
    pred_list = [n @ K]
    exact = False     # becomes True once the rank-one update degenerates

    for i in range(G):
        # The previous query row becomes a new regressor row m, paired with target n.
        m, n = n, incoming_list[:, i].unsqueeze(1)
        if not exact:
            mt = m.transpose(1, 2)
            # Component of m orthogonal to the current row space.
            r = mt - X.transpose(1, 2) @ mt
            r_sq = r.square().sum(dim=1, keepdim=True)
            m_sq = mt.square().sum(dim=1, keepdim=True)
            exact = bool((r_sq <= rtol ** 2 * m_sq).any())
        if exact:
            # History after this append: Zp = [Z; incoming[:, :i]],
            # Zf = [Z[:, 1:]; incoming[:, :i + 1]].
            Zp_i = torch.cat((Z, incoming_list[:, :i]), dim=1)
            Zf_i = torch.cat((Z[:, 1:], incoming_list[:, :i + 1]), dim=1)
            K = torch.linalg.pinv(Zp_i) @ Zf_i
        else:
            b = r / r_sq
            K = K - b @ (m @ K - n)
            X = X - b @ (m @ X - m)
        pred_list.append(n @ K)

    return torch.concat(pred_list, dim=1)

In [4]:
# Problem sizes: batch size B, length F of the initial sequence Z (dim 1),
# feature dimension D, and number G of incoming rows appended one at a time.
B = 64
F = 10
D = 1024
G = 10

# Seed the global torch RNG so the random inputs, and every output below,
# are reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

Z = torch.randn(B, F, D)
incoming_list = torch.randn(B, G, D)

In [5]:
# Baseline timing: full least-squares refit after every incoming row.
baseline_timer = Timer()
with baseline_timer:
    pred1 = operator_adaptation(Z, incoming_list)

[time spent: 6.12s]


In [6]:
# One prediction from the initial fit plus one per incoming row.
assert pred1.shape == (B, G + 1, D), pred1.shape
pred1.shape

torch.Size([64, 11, 1024])

In [7]:
# Accelerated timing: recursive rank-one updates instead of repeated solves.
accelerated_timer = Timer()
with accelerated_timer:
    pred2 = operator_adaptation_accelerate(Z, incoming_list)

[time spent: 2.89s]


In [8]:
# Frobenius norm of the discrepancy between the baseline and accelerated predictions.
torch.norm(pred1 - pred2)

tensor(0.0001)

In [11]:
# Both methods should agree up to floating-point error. The check uses an absolute
# bound on the norm taken over all B * (G + 1) * D entries.
DIFF_NORM_TOL = 1e-3
diff_norm = torch.norm(pred1 - pred2)
assert diff_norm < DIFF_NORM_TOL, (
    f"accelerated predictions deviate from baseline: "
    f"||pred1 - pred2|| = {diff_norm.item():.3e} >= {DIFF_NORM_TOL}"
)