In [1]:
import torch
from torch.utils.cpp_extension import load
from algo import original_orth
import itertools

In [2]:
# JIT-compile the C++/CUDA sources next to this notebook and import the
# resulting extension module as `qr_func`.
qr_func = load(
    name="qr_func",
    sources=[
        "qr_orthogonalization.cpp",
        "qr_orthogonalization.cu",
    ],
)

In [3]:
# Size of the matrix used for the correctness check below: M rows, N columns.
M = 30
N = 4000
device = torch.device("cuda")

# Fix the RNG so A (and every number derived from it) is identical on each
# Restart & Run All. torch.manual_seed also seeds the CUDA generators.
SEED = 0
torch.manual_seed(SEED)

A = torch.rand((M, N), device=device)
# Keep A untouched for reference; Q is the working copy passed to the kernel.
Q = A.clone()

In [4]:
# Run the custom kernel on Q. The return value is not captured: the checks
# in the following cells read Q itself, so the kernel is relied on to have
# updated Q in place.
# NOTE(review): the meaning of the second argument (0) is not visible from
# this notebook -- confirm against the extension source.
qr_func.qr_orthogonalization(Q, 0)

In [5]:
# Gram matrix of the rows of Q; it is close to the identity when the rows
# are orthonormal.
torch.matmul(Q, Q.T)

tensor([[ 1.0000e+00, -4.6566e-10,  2.2119e-09,  5.8208e-09,  3.7253e-09,
         -4.0454e-09,  4.0745e-09, -3.4925e-09,  3.2596e-09,  1.1642e-10,
          7.5670e-10, -5.7626e-09, -2.9104e-09,  2.7358e-09,  5.8208e-10,
          4.0745e-09, -6.9849e-10,  5.5879e-09,  1.8626e-09,  4.6566e-09,
         -4.6566e-10, -2.6193e-09, -1.0477e-09,  9.8953e-10, -3.0268e-09,
         -9.3132e-10,  5.0641e-09, -7.2177e-09, -6.7521e-09,  9.8953e-10],
        [-4.6566e-10,  1.0000e+00, -8.1491e-09,  4.4238e-09,  4.6566e-10,
         -6.4028e-10, -9.3132e-09,  5.2387e-09, -8.8476e-09, -5.4715e-09,
         -3.9581e-09, -1.2806e-08, -1.2456e-08,  4.6566e-10, -3.3004e-08,
         -2.5611e-09, -3.8417e-09, -5.8208e-09, -1.9325e-08,  4.0745e-09,
         -1.6764e-08, -5.9372e-09, -1.1642e-09, -1.6880e-08, -6.9849e-10,
          6.9849e-09, -6.4611e-09,  4.4238e-09,  1.1642e-10, -2.8522e-09],
        [ 2.2119e-09, -8.1491e-09,  1.0000e+00,  6.7521e-09,  5.1223e-09,
         -5.6461e-09,  4.6566e-10, -

In [6]:
# Scalar orthogonality error: Frobenius norm of (Q Q^T - I).
(Q @ Q.T - torch.eye(M, device=device)).norm()

tensor(2.2645e-06, device='cuda:0')

In [7]:
def my_qr(A):
    """Single-argument wrapper around the custom kernel, for use with `timing`.

    Calls ``qr_func.qr_orthogonalization(A, 0)`` and discards whatever the
    extension returns, so this function itself always returns ``None``.
    """
    qr_func.qr_orthogonalization(A, 0)

In [8]:
def timing(func, A, repeat=1000, warmup=10):
    """Return the mean GPU time of one ``func(A)`` call, in microseconds.

    Args:
        func: Callable invoked as ``func(A)``; its return value is ignored.
        A: Argument handed to ``func`` on every call. It is passed through
            unchanged, so any in-place modification ``func`` makes carries
            over to the next call.
        repeat: Number of timed calls to average over. Must be positive.
        warmup: Number of untimed calls made first so that one-time costs
            (lazy initialisation, caches) stay out of the measurement.
            Use 0 to disable.

    Returns:
        float: elapsed time between the two CUDA events divided by
        ``repeat``, converted from milliseconds to microseconds.

    Raises:
        ValueError: if ``repeat`` is not positive or ``warmup`` is negative.
    """
    # Validate before touching the GPU so bad arguments fail fast and clearly
    # instead of surfacing as a ZeroDivisionError after the work has run.
    if repeat <= 0:
        raise ValueError("repeat must be a positive integer")
    if warmup < 0:
        raise ValueError("warmup must be non-negative")

    for _ in range(warmup):
        func(A)
    # Drain the warm-up work so the timed region starts from an idle device.
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True, blocking=True)
    stop = torch.cuda.Event(enable_timing=True, blocking=True)

    start.record()
    for _ in range(repeat):
        func(A)
    stop.record()
    # Wait until the stop event has actually been reached on the GPU;
    # elapsed_time() is only valid once both events have completed.
    stop.synchronize()

    # elapsed_time() reports milliseconds; convert the per-call mean to us.
    return start.elapsed_time(stop) / repeat * 1000
    

In [11]:
# Benchmark the custom kernel against torch.qr over a grid of matrix shapes.
ms = [2, 4, 8, 16, 32, 64, 128]
ns = [256, 512, 1024, 2048, 4096]
device = torch.device("cuda")

print("shape\t\tcustom (us)\ttorch.qr (us)\t\tspeedup")
# Cell-local loop names: reusing M, N and A here would overwrite the values
# from the setup cell and silently change the orthogonality check if that
# cell is re-run after the sweep.
for m, n in itertools.product(ms, ns):
    bench_input = torch.rand((m, n), device=device)
    # torch.qr is timed first, on the fresh random matrix; my_qr is handed the
    # same tensor afterwards.
    # NOTE(review): the Q @ Q.T check above suggests the custom kernel
    # orthonormalises the *rows* of its input, which would correspond to a QR
    # of the transpose rather than torch.qr(bench_input) -- confirm that the
    # two timings measure comparable work.
    torch_time = timing(torch.qr, bench_input)
    my_time = timing(my_qr, bench_input)
    speedup = torch_time / my_time

    print(f"({m:d}, {n:d}):\t {my_time:.1f} \t \t {torch_time:.1f} \t\t {speedup:.2f}x")

shape		custom (us)	torch.qr (us)		speedup
(2, 256):	 126.4 	 	 20480.7 		 162.04x
(2, 512):	 89.1 	 	 30890.0 		 346.55x


KeyboardInterrupt: 