In [46]:
import torch
from torch.utils.cpp_extension import load
from algo import original_orth
import timeit
import itertools

In [57]:
qr_func = load(name="qr_func", sources=["qr_orthogonalization.cpp", "qr_orthogonalization.cu"])

In [58]:
M = 4
N = 1024
device = torch.device("cuda")
dtype = torch.float32

A = torch.rand((M, N), device=device, dtype=dtype)
A.stride()

(1024, 1)

In [59]:
Q = A.T.contiguous().T
Q.stride()

(1, 4)

In [60]:
qr_func.qr_orthogonalization(Q, 0)
qr_func.qr_orthogonalization(A, 0)

In [61]:
Q == A

tensor([[True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True]], device='cuda:0')

In [62]:
torch.norm(Q @ Q.T - torch.eye(M, device = device, dtype=dtype))

tensor(2.7431e-07, device='cuda:0')

In [116]:
Q_or = original_orth(A.clone(), 0)

In [117]:
torch.norm(Q - Q_or)

tensor(0.0074, device='cuda:0')

In [118]:
torch.norm(Q_or @ Q_or.T - torch.eye(M, device = device))

tensor(0.0110, device='cuda:0')

In [64]:
ms = [4, 8, 16]
ns = [1024, 2048]

for M, N in itertools.product(ms, ns):
    A = torch.rand((N, M), device=device, dtype=dtype)
    A = A.T
    print(f"({M}, {N}): ", end='')
    %timeit torch.cuda.synchronize(); qr_func.qr_orthogonalization(A.clone(), 0); torch.cuda.synchronize();

(4, 1024): 151 µs ± 2.38 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
(4, 2048): 199 µs ± 7.66 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
(8, 1024): 318 µs ± 4.55 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
(8, 2048): 527 µs ± 1.49 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
(16, 1024): 1.99 ms ± 43.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
(16, 2048): 5.1 ms ± 94 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [126]:
%timeit torch.cuda.synchronize(); qr_func.qr_orthogonalization(A.clone(), 0); torch.cuda.synchronize();

3.72 ms ± 30.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [127]:
%timeit torch.cuda.synchronize(); torch.qr(A.clone()); torch.cuda.synchronize();

85.2 ms ± 10.5 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
