In [97]:
import torch
from torch.utils.cpp_extension import load
from algo import original_orth
import itertools

In [160]:
qr_func = load(name="qr_func", sources=["qr_orthogonalization.cpp", "qr_orthogonalization.cu"])

In [161]:
M = 8
N = 1024
device = torch.device("cuda")

A = torch.rand((M, N), device=device)
Q = torch.empty((M, N), device=device)

In [162]:
A

tensor([[0.7412, 0.2401, 0.4562,  ..., 0.4902, 0.0404, 0.6421],
        [0.0899, 0.2945, 0.1316,  ..., 0.1007, 0.6630, 0.6538],
        [0.1928, 0.0536, 0.2674,  ..., 0.2595, 0.8048, 0.0408],
        ...,
        [0.5301, 0.3110, 0.5386,  ..., 0.9007, 0.3789, 0.8954],
        [0.7383, 0.9444, 0.8923,  ..., 0.1863, 0.7965, 0.1755],
        [0.3619, 0.4864, 0.4673,  ..., 0.4036, 0.1033, 0.1025]],
       device='cuda:0')

In [163]:
qr_func.qr_orthogonalization(A, out=Q)

tensor([[-0.0399,  0.0356,  0.0134,  ..., -0.0276,  0.0023, -0.0355],
        [-0.0129, -0.0098,  0.0156,  ...,  0.0144, -0.0499, -0.0219],
        [-0.0246,  0.0158, -0.0020,  ..., -0.0028, -0.0430,  0.0461],
        ...,
        [-0.0062, -0.0654, -0.0310,  ...,  0.0573, -0.0225,  0.0387],
        [-0.0181, -0.0523,  0.0046,  ...,  0.0223, -0.0263,  0.0400],
        [-0.0086, -0.0328, -0.0090,  ..., -0.0004, -0.0490, -0.0340]],
       device='cuda:0')

In [166]:
Q @ Q.T

tensor([[ 1.0000e+00, -3.7253e-09,  1.3970e-08, -3.0268e-09,  5.5879e-09,
          6.0536e-09, -1.1176e-08,  1.8626e-09],
        [-3.7253e-09,  1.0000e+00,  1.1642e-09, -9.3132e-09, -1.0245e-08,
          6.9849e-09,  2.2817e-08,  4.6566e-10],
        [ 1.3970e-08,  1.1642e-09,  1.0000e+00, -2.5146e-08,  6.5193e-09,
          6.5193e-09,  0.0000e+00, -9.3132e-09],
        [-3.0268e-09, -9.3132e-09, -2.5146e-08,  1.0000e+00, -4.1910e-09,
         -1.0245e-08, -1.0012e-08,  1.7928e-08],
        [ 5.5879e-09, -1.0245e-08,  6.5193e-09, -4.1910e-09,  1.0000e+00,
         -3.9581e-09,  1.8626e-09, -1.6764e-08],
        [ 6.0536e-09,  6.9849e-09,  6.5193e-09, -1.0245e-08, -3.9581e-09,
          1.0000e+00, -1.8626e-09,  2.8056e-08],
        [-1.1176e-08,  2.2817e-08,  0.0000e+00, -1.0012e-08,  1.8626e-09,
         -1.8626e-09,  1.0000e+00, -2.1420e-08],
        [ 1.8626e-09,  4.6566e-10, -9.3132e-09,  1.7928e-08, -1.6764e-08,
          2.8056e-08, -2.1420e-08,  1.0000e+00]], device='cuda:0'

In [167]:
torch.norm(Q @ Q.T - torch.eye(M, device = device))

tensor(7.8690e-07, device='cuda:0')

In [168]:
def my_qr(A):
     qr_func.qr_orthogonalization(A);

In [169]:
def timing(func, A, repeat=1000):
    start = torch.cuda.Event(enable_timing=True, blocking=True)
    stop  = torch.cuda.Event(enable_timing=True, blocking=True)

    start.record()
    for n in range(repeat):
        func(A)

    stop.record()
    stop.synchronize()
    return start.elapsed_time(stop) / repeat * 1000
    

In [170]:
ms = [2, 4, 8, 16, 32, 64, 128]
ns = [256, 512, 1024, 2048, 4096]
device = torch.device("cuda")

print("shape\t\tcustom (us)\ttorch.qr (us)\t\tspeedup")
for M, N in itertools.product(ms, ns):
    A = torch.rand((M, N), device=device)
    torch_time = timing(torch.qr, A)
    my_time = timing(my_qr, A)
    speedup = torch_time / my_time

    print("({:d}, {:d}):\t {:.1f} \t \t {:.1f} \t\t {:.2f}x".format(M, N, my_time, torch_time, speedup))

shape		custom (us)	torch.qr (us)		speedup
(2, 256):	 89.5 	 	 25022.6 		 279.72x


KeyboardInterrupt: 