In [61]:
import torch
from torch.utils.cpp_extension import load
from algo import original_orth
import itertools

In [60]:
# JIT-compile the custom orthogonalization extension from its C++ and CUDA sources
# and bind the resulting module as `qr_func`.
EXT_SOURCES = [
    "qr_orthogonalization.cpp",
    "qr_orthogonalization.cu",
]
qr_func = load(name="qr_func", sources=EXT_SOURCES)

In [55]:
# Problem size for the correctness check: orthonormalize the M rows of an (M, N) matrix.
M = 30
N = 4000
device = torch.device("cuda")

# Seed the RNG so the random input -- and therefore the printed check results -- are
# reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

A = torch.rand((M, N), device=device)
# The custom kernel writes its result into its argument, so it is run on this copy
# and A keeps the original random input.
Q = A.clone()

In [57]:
# Run the custom kernel on Q. The result is written into Q in place (the next cells
# inspect Q directly), so no assignment is needed.
# NOTE(review): the meaning of the second argument (0) is defined in the extension
# source, which is not visible here -- TODO document it.
qr_func.qr_orthogonalization(Q, 0)

In [58]:
# Gram matrix of Q's rows: close to the identity when the rows are orthonormal.
gram = Q @ Q.T
gram

tensor([[ 1.0001e+00,  8.8540e-07,  5.1019e-08, -3.3227e-07, -6.1694e-06,
          1.0363e-06,  2.6710e-06, -4.6850e-07,  1.5151e-07, -2.0532e-06,
          4.6788e-06,  4.4151e-08, -2.1422e-06,  2.5018e-07,  2.2451e-06,
          1.3324e-05,  3.7383e-06, -1.2185e-03,  8.6240e-07,  4.0722e-05,
          2.4745e-04, -1.7693e-05, -2.2685e-06, -1.8247e-06, -2.0381e-06,
         -1.4418e-06, -2.9174e-06,  1.3256e-06, -1.8249e-06, -7.3074e-07],
        [ 8.8540e-07,  9.9999e-01,  5.7754e-06, -6.4053e-08,  1.1273e-05,
          7.2655e-06,  2.4691e-06, -7.9597e-06,  9.1083e-07,  3.9904e-06,
          5.9813e-07,  5.2894e-06,  2.6480e-06, -1.8996e-06,  1.2384e-05,
         -1.7341e-06,  6.2445e-06,  9.2773e-04,  1.0135e-05,  1.5939e-04,
         -9.3828e-05,  8.0689e-05,  5.4168e-06,  1.9406e-07,  2.4431e-05,
          2.8301e-05,  2.2813e-04,  3.6092e-05,  3.6647e-05,  5.6839e-06],
        [ 5.1019e-08,  5.7754e-06,  1.0000e+00, -2.0344e-06, -2.0909e-06,
         -5.2672e-06, -3.3649e-06, -

In [59]:
# Frobenius distance between Q Q^T and the identity (0 for perfectly orthonormal rows).
eye_M = torch.eye(M, device=device)
torch.norm(Q @ Q.T - eye_M)

tensor(0.0806, device='cuda:0')

In [63]:
def my_qr(A):
    """Orthonormalize ``A`` with the custom kernel and return it.

    ``qr_func.qr_orthogonalization`` writes its result into ``A`` itself, so the
    caller's tensor is overwritten. ``A`` is returned so the call can also be used
    as an expression.
    """
    # NOTE(review): the trailing 0 argument's meaning lives in the extension
    # source, which is not visible here -- TODO document it.
    qr_func.qr_orthogonalization(A, 0)
    return A

In [69]:
def timing(func, A, repeat=1000, warmup=10):
    """Return the average GPU time of one ``func(A)`` call, in microseconds.

    Parameters
    ----------
    func : callable
        Called as ``func(A)``; its return value is ignored.
    A : object
        Argument passed unchanged to every call (the benchmark passes a CUDA tensor).
    repeat : int
        Number of timed calls; must be positive.
    warmup : int
        Untimed calls made first, so one-time costs (e.g. library handle creation or
        caching-allocator growth on the first call) are kept out of the average.

    Raises
    ------
    ValueError
        If ``repeat`` is not positive.
    """
    if repeat <= 0:
        raise ValueError(f"repeat must be positive, got {repeat}")

    for _ in range(warmup):
        func(A)

    start = torch.cuda.Event(enable_timing=True, blocking=True)
    stop = torch.cuda.Event(enable_timing=True, blocking=True)

    start.record()
    for _ in range(repeat):
        func(A)
    stop.record()

    # Kernel launches are asynchronous: wait until the stop event has completed
    # before reading the elapsed time.
    stop.synchronize()
    # elapsed_time() reports milliseconds; convert the per-call average to microseconds.
    return start.elapsed_time(stop) / repeat * 1000
    

In [73]:
# Benchmark the custom kernel against torch.qr over a grid of (rows, cols) shapes.
# The custom kernel orthonormalizes the rows of an (m, n) matrix (Q @ Q.T ~ I above).
# The comparable torch computation is the reduced QR of the (n, m) transpose, which
# yields an (n, m) Q with orthonormal columns. torch.qr(A) on the wide matrix itself
# instead factors its columns and returns an (m, m) Q, which is a different problem.
ms = [2, 4, 8, 16, 32]
ns = [256, 512, 1024, 2048, 4096, 8192]
device = torch.device("cuda")


def torch_row_qr(x):
    """Reduced QR of ``x.T``: the torch baseline equivalent to orthonormalizing x's rows."""
    return torch.qr(x.T)


print("shape\t\tcustom (us)\ttorch.qr (us)\t\tspeedup")
# Lowercase loop names keep this cell from clobbering the M, N, A used by the
# correctness-check cells above. Re-running the norm check after this cell would
# otherwise compare a 30x30 Gram matrix with torch.eye(32).
for m, n in itertools.product(ms, ns):
    A_bench = torch.rand((m, n), device=device)
    torch_time = timing(torch_row_qr, A_bench)
    # NOTE(review): my_qr overwrites its argument, so every call after the first
    # re-processes an already-orthonormal matrix; this assumes the kernel's cost
    # does not depend on the data.
    my_time = timing(my_qr, A_bench)
    speedup = torch_time / my_time

    print("({:d}, {:d}):\t {:.1f} \t \t {:.1f} \t\t {:.2f}x".format(
        m, n, my_time, torch_time, speedup))

shape		custom (us)	torch.qr (us)		speedup
