In [9]:
import torch
from torch.utils.cpp_extension import load
from algo import original_orth
import timeit
import itertools

In [10]:
# Build the C++/CUDA extension just-in-time and import it as the module `qr_func`.
qr_func = load(
    name="qr_func",
    sources=["qr_orthogonalization.cpp", "qr_orthogonalization.cu"],
)

In [12]:
M = 4
N = 1024
device = torch.device("cuda")
dtype = torch.float32

A = torch.rand((M, N), device=device, dtype=dtype)
Q = torch.zeros((M,N), device=device, dtype=dtype)

In [13]:
# Run the extension on a copy of A so A itself stays available for the
# reference comparison further down. The orthogonality check that follows
# indicates the call overwrites Q in place (a random A is not orthonormal).
# NOTE(review): meaning of the second argument (0) is not visible here — TODO document.
Q = torch.clone(A)
qr_func.qr_orthogonalization(Q, 0)

In [14]:
# Frobenius norm of Q Q^T - I: close to 0 when the M rows of Q are orthonormal.
(Q @ Q.T).sub(torch.eye(M, device = device, dtype=dtype)).norm()

tensor(4.5132e-07, device='cuda:0')

In [116]:
# Reference result from algo.original_orth, computed on a fresh copy so A is
# unaffected whether or not original_orth modifies its input.
Q_or = original_orth(torch.clone(A), 0)

In [117]:
# Frobenius distance between the extension's result and the reference result.
(Q - Q_or).norm()

tensor(0.0074, device='cuda:0')

In [118]:
# Same orthogonality check as for the extension's Q, applied to the reference
# result. The identity is built in A's dtype so this check stays consistent
# with the extension check if `dtype` is changed in the setup cell.
torch.norm(Q_or @ Q_or.T - torch.eye(M, device=device, dtype=dtype))

tensor(0.0110, device='cuda:0')

In [17]:
ms = [4, 8, 16]
ns = [1024, 2048]

for M, N in itertools.product(ms, ns):
    A = torch.rand((N, M), device=device, dtype=dtype)
    A = A.T
    print(f"({M}, {N}): ", end='')
    %timeit torch.cuda.synchronize(); qr_func.qr_orthogonalization(A.contiguous(), 0); torch.cuda.synchronize();

(4, 1024): 166 µs ± 3.66 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
(4, 2048): 205 µs ± 4.44 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
(8, 1024): 284 µs ± 7.18 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
(8, 2048): 425 µs ± 14.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
(16, 1024): 1.56 ms ± 17.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
(16, 2048): 2.65 ms ± 30.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [126]:
%timeit torch.cuda.synchronize(); qr_func.qr_orthogonalization(A.clone(), 0); torch.cuda.synchronize();

3.72 ms ± 30.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [127]:
%timeit torch.cuda.synchronize(); torch.qr(A.clone()); torch.cuda.synchronize();

85.2 ms ± 10.5 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
