In [1]:
# Record which GPU the timings below were taken on (name only: `cut` drops the trailing '(...)' part).
! nvidia-smi -L | cut -d '(' -f 1

GPU 0: GeForce RTX 2070 SUPER 


In [13]:
import torch
import time
import itertools

nb = 500  # timed iterations per (batch shape, matrix size, dtype) configuration


def main(s: str):
    """Benchmark ``torch.inverse`` on CPU vs. CUDA over a grid of shapes.

    For every (batch shape, matrix size, dtype) combination this times ``nb`` calls of
    ``torch.inverse`` on the CPU and on the GPU. It verifies that the GPU input was not
    modified and that the GPU result agrees with the CPU result. It then prints one row
    to stdout and writes ``"<batch> <n> <dtype>; <cpu_ms>, <gpu_ms>"`` to ``s + '.txt'``,
    which is the format parsed by ``readfile``.

    Args:
        s: run label; printed as a header and used as the output file stem.

    Raises:
        ValueError: if the module-level ``nb`` is smaller than 1.
        RuntimeError: if ``torch.inverse`` modified its GPU input.
    """
    if nb < 1:
        # Otherwise the per-call averages below divide by zero.
        raise ValueError(f'nb must be >= 1, got {nb}')

    def prof(b_, n_, dtype, f):
        """Time one configuration, print its row and append it to the open file ``f``."""
        x = torch.randn(*b_, n_, n_, device='cuda', dtype=dtype)
        xc = x.clone().cpu()

        # CPU: one untimed warm-up call (mirrors the GPU warm-up below), then nb timed calls.
        torch.inverse(xc)
        t1 = time.perf_counter()
        for _ in range(nb):
            yc = torch.inverse(xc)
        t2 = time.perf_counter()
        cpu_time = (t2 - t1) / nb * 1e3

        # GPU: untimed warm-up pass so one-time setup cost is not part of the measurement.
        for _ in range(nb):
            y = torch.inverse(x)
        torch.cuda.synchronize()

        # torch.inverse must not touch its input. The CPU copy round-trips bit-for-bit,
        # so any difference at all means the input was modified: compare exactly.
        x_ref = xc.cuda()
        if not torch.equal(x_ref, x):
            print('original matrix compare')
            print('max abs diff:', (x_ref - x).abs().max().item())
            raise RuntimeError('original value modified')

        # Drain any queued GPU work before starting the clock.
        torch.cuda.synchronize()
        t1 = time.perf_counter()
        for _ in range(nb):
            y = torch.inverse(x)
        # CUDA calls return before the kernels finish; wait for them before stopping the clock.
        torch.cuda.synchronize()
        t2 = time.perf_counter()
        gpu_time = (t2 - t1) / nb * 1e3

        y_ref = yc.cuda()
        if not torch.allclose(y_ref, y, rtol=1e-3, atol=1e-3):
            print('numerical mismatch: inverse value compare')
            print('max abs diff:', (y_ref - y).abs().max().item())

        print(f'{b_} {n_} {dtype}'.ljust(35) + f'{cpu_time : .3f}  {gpu_time : .3f}')
        f.write(f'{b_} {n_} {dtype}; ' + f'{cpu_time : .3e}, {gpu_time : .3e}\n')

    print(s)
    print(torch.__version__)
    print()
    print('batch_size, matrix_size, dtype'.ljust(35) + 'cpu_time(ms), gpu_time(ms)')

    # Batch shapes: [] (unbatched), [1], [2], [4]; matrix sizes 2, 4, ..., 1024; float32 only.
    shapes = itertools.product(
        [[]] + [[2**x] for x in range(3)],
        [2**i for i in range(1, 11)],
        [torch.float]
    )

    with open(s + '.txt', 'w') as f:
        for b, n, dtype in shapes:
            # Skip large batched configurations. With the current grid (batch <= 4,
            # n <= 1024) this never triggers; it only matters if the grid is enlarged.
            if len(b) > 0 and b[0] * n >= 2**15:
                continue
            prof(b, n, dtype, f)


In [2]:
# Baseline run (GPU inverse via MAGMA per the comparison table below); writes before.txt.
main('before')

before
1.7.0a0+4ae832e

batch_size, matrix_size, dtype     cpu_time(ms), gpu_time(ms)
[] 2 torch.float32                  0.011   7.446
[] 4 torch.float32                  0.009   7.427
[] 8 torch.float32                  0.011   7.571
[] 16 torch.float32                 0.016   7.522
[] 32 torch.float32                 0.033   7.548
[] 64 torch.float32                 0.072   7.708
[] 128 torch.float32                0.352   8.024
[] 256 torch.float32                1.141   11.338
[] 512 torch.float32                5.312   15.013
[] 1024 torch.float32               19.364   19.271
[1] 2 torch.float32                 0.009   0.114
[1] 4 torch.float32                 0.009   0.117
[1] 8 torch.float32                 0.011   0.126
[1] 16 torch.float32                0.017   0.127
[1] 32 torch.float32                0.033   0.178
[1] 64 torch.float32                0.072   0.420
[1] 128 torch.float32               0.294   0.801
[1] 256 torch.float32               1.044   1.674
[1] 512 to

In [16]:
# Same benchmark on the modified PyTorch build (see the printed version); writes after.txt.
main('after')

after
1.7.0a0+de440b5

batch_size, matrix_size, dtype     cpu_time(ms), gpu_time(ms)
[] 2 torch.float32                  0.022   0.168
[] 4 torch.float32                  0.014   0.151
[] 8 torch.float32                  0.016   0.152
[] 16 torch.float32                 0.021   0.154
[] 32 torch.float32                 0.040   0.204
[] 64 torch.float32                 0.083   0.292
[] 128 torch.float32                0.314   0.515
[] 256 torch.float32                1.059   1.095
[] 512 torch.float32                4.988   2.602
[] 1024 torch.float32               17.429   7.000
[1] 2 torch.float32                 0.009   0.142
[1] 4 torch.float32                 0.009   0.142
[1] 8 torch.float32                 0.011   0.140
[1] 16 torch.float32                0.015   0.130
[1] 32 torch.float32                0.032   0.173
[1] 64 torch.float32                0.069   0.270
[1] 128 torch.float32               0.387   0.492
[1] 256 torch.float32               1.223   1.101
[1] 512 torch.

In [17]:
import re

def readfile(fn):
    """Parse a benchmark results file written by ``main``.

    Each non-blank line has the form ``"<key>; <cpu_time>, <gpu_time>"``, where
    ``<key>`` is ``"<batch shape> <matrix size> <dtype>"`` (e.g. ``"[] 2 torch.float32"``).

    Args:
        fn: path of the results file.

    Returns:
        ``(dc, dg)``: two dicts mapping each key to its CPU and GPU time (floats).
    """
    dc = {}
    dg = {}
    with open(fn, 'r') as f:
        for line in f:
            line = line.rstrip()
            if not line:
                continue
            # Split on the LAST ';' only. The key embeds str(batch_shape), which
            # contains ', ' for multi-dim batches (e.g. "[2, 3] 8 torch.float32"),
            # so splitting on every ',' would break the unpacking.
            key, times = line.rsplit(';', 1)
            cpu_time, gpu_time = times.split(',')
            dc[key] = float(cpu_time)
            dg[key] = float(gpu_time)

    return (dc, dg)

def compare(f, before: str, *afters):
    """Print a comparison of benchmark result files and write it to ``f`` as a markdown table.

    Args:
        f: writable text file that receives the markdown table.
        before: baseline results file (the "gpu_time_before (magma)" column).
        *afters: one or more results files; each gets its own GPU-time column,
            labelled with the file name minus a trailing ``.txt``.

    The CPU column is the mean CPU time over ``before`` and all ``afters``. A row is
    marked regressed when the FIRST ``after`` file's GPU time exceeds the baseline's.
    Every key in ``before`` must also be present in every ``after`` file (KeyError otherwise).

    Raises:
        AssertionError: if no ``after`` file is given.
    """
    assert len(afters) >= 1, 'provide at least one after data'

    def _label(fn):
        # Remove the '.txt' suffix exactly. str.rstrip('.txt') strips any trailing run
        # of '.', 't', 'x' characters instead, e.g. 'after_fast.txt' -> 'after_fas'.
        return fn[:-len('.txt')] if fn.endswith('.txt') else fn

    labels = [_label(after) for after in afters]

    print('shape'.ljust(26), 'cpu_time, gpu_time_before (magma)', end='')
    f.write('| shape | cpu_time (ms) | gpu_time_before (magma) (ms) |')
    for label in labels:
        print(', gpu_time_' + label, end='')
        f.write(' gpu_time_' + label + ' (ms) |')
    print()
    f.write('\n')
    # Markdown separator row: shape, cpu, before, plus one column per after file.
    f.write('| --- ' * (len(afters) + 3) + '| \n')

    dc_b, dg_b = readfile(before)
    dc_as = []
    dg_as = []
    for after in afters:
        dc_a, dg_a = readfile(after)
        dc_as.append(dc_a)
        dg_as.append(dg_a)

    for key in dc_b:
        # CPU timings should not depend on the GPU change, so average all runs.
        cpu_time = (dc_b[key] + sum(dc_a[key] for dc_a in dc_as)) / (1 + len(dc_as))
        gpu_time_before = dg_b[key]
        gpu_time_after = dg_as[0][key]

        # Only the first after-run is checked for regressions.
        if gpu_time_after > gpu_time_before:
            gs = ' ' * 5 + '*' * 20 + ' regressed'
            gss = '***regressed'
        else:
            gs = ''
            gss = ''

        print(f'{key: <26} {cpu_time: .3f}, {gpu_time_before: .3f}, {gpu_time_after: .3f}, ' + ' '*5, end='')
        f.write(f'| {key} | {cpu_time: .3f} | {gpu_time_before: .3f} | {gpu_time_after: .3f} {gss} | ')
        for dg_a in dg_as[1:]:
            gpu_time_after = dg_a[key]
            print(f'{gpu_time_after: .3f}, ', end='')
            f.write(f'{gpu_time_after: .3f} |')
        print(gs)
        f.write('\n')

# Build table.md comparing before.txt (baseline) against after.txt; compare() also echoes the table to stdout.
with open('table.md', 'w') as f:
    compare(f, 'before.txt', 'after.txt')

shape                      cpu_time, gpu_time_before (magma), gpu_time_after
[] 2 torch.float32          0.016,  7.446,  0.168,      
[] 4 torch.float32          0.012,  7.427,  0.150,      
[] 8 torch.float32          0.013,  7.571,  0.152,      
[] 16 torch.float32         0.018,  7.522,  0.154,      
[] 32 torch.float32         0.036,  7.548,  0.204,      
[] 64 torch.float32         0.078,  7.708,  0.292,      
[] 128 torch.float32        0.333,  8.024,  0.515,      
[] 256 torch.float32        1.100,  11.340,  1.095,      
[] 512 torch.float32        5.150,  15.010,  2.602,      
[] 1024 torch.float32       18.395,  19.270,  7.000,      
[1] 2 torch.float32         0.009,  0.114,  0.142,           ******************** regressed
[1] 4 torch.float32         0.009,  0.117,  0.142,           ******************** regressed
[1] 8 torch.float32         0.011,  0.126,  0.140,           ******************** regressed
[1] 16 torch.float32        0.016,  0.127,  0.130,           ************