# numba

In [1]:
from numba import cuda

In [2]:
[i for i in cuda.devices.gpus]

[<numba.cuda.cudadrv.devices._DeviceContextManager at 0x7f4625ea8d30>]

In [3]:
import numba

In [4]:
numba.cuda.detect()

Found 1 CUDA devices
id 0    b'NVIDIA GeForce RTX 2060 SUPER'                              [SUPPORTED]
                      Compute Capability: 7.5
                           PCI Device ID: 0
                              PCI Bus ID: 38
                                    UUID: GPU-f5dcddd0-bc57-5ebb-7578-229367d62be8
                                Watchdog: Enabled
             FP32/FP64 Performance Ratio: 32
Summary:
	1/1 devices are supported


True

In [5]:
device = cuda.get_current_device()

In [6]:
cd = device.get_primary_context()

In [7]:
cd.get_max_potential_block_size?

[0;31mSignature:[0m
[0mcd[0m[0;34m.[0m[0mget_max_potential_block_size[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mfunc[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mb2d_func[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mmemsize[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mblocksizelimit[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mflags[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m
Suggest a launch configuration with reasonable occupancy.
:param func: kernel for which occupancy is calculated
:param b2d_func: function that calculates how much per-block dynamic
                 shared memory 'func' uses based on the block size.
                 Can also be the address of a C function.
                 Use `0` to pass `NULL` to the underlying CUDA API.
:param memsize: per-block dynamic shared memory usage intended, in bytes
:param blocksizelimit: maximum block size the kernel is designed to
        

In [8]:
cd.get_memory_info()

MemoryInfo(free=7430602752, total=8346664960)

In [9]:
device

<weakproxy at 0x7f4625eb1f40 to Device at 0x7f46605dc880>

# gradient

In [23]:
import torch

In [41]:
from importlib import reload

In [42]:
import pysdtw
pysdtw = reload(pysdtw)

In [25]:
batch_size, seq_len_a, seq_len_b, dims = 10, 512, 1023, 15

In [26]:
A = torch.rand((batch_size, seq_len_a, dims), requires_grad=True)
B = torch.rand((batch_size, seq_len_b, dims))

In [43]:
sdtw = pysdtw.SoftDTW()

In [44]:
%%timeit
sdtw(A.cuda(), B.cuda())



11.1 ms ± 442 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [45]:
sdtw_cpu = pysdtw.SoftDTW(use_cuda=False)

In [46]:
%%timeit
sdtw_cpu(A, B)

60.1 ms ± 2.28 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [47]:
res = sdtw(A.cuda(), B.cuda())

In [48]:
loss = res.sum()

In [49]:
loss.backward()

In [50]:
A.grad[0,0]

tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan])

In [51]:
Ac = A.detach().clone().requires_grad_(True)

In [52]:
res_leg = sdtw_cpu(Ac, B)

In [53]:
loss_leg = res_leg.sum()

In [54]:
loss_leg.backward()

In [55]:
Ac.grad[0,0]

tensor([ 0.6941, -0.3536,  1.7750, -1.5856, -0.1960,  0.8659,  1.0421,  0.1999,
        -0.4437,  0.6173,  2.2783, -0.2620,  0.1124, -0.2571,  1.2449])

# forward loop

In [72]:
MAX_THREADS_PER_BLOCK = 128
M = 5
N = 10

In [85]:
max_i = M
max_j = N

In [86]:
threads_per_block = max(N, M)
n_passes = N + M - 1

In [87]:
antidiag = []
for p in range(n_passes):
    
    antidiag_row = []
    
    for thread_id in range(threads_per_block):
        I = thread_id
        tid = thread_id

        J = max(0, min(p - tid, max_j - 1))

        i = I + 1
        j = J + 1

        if I + J == p and (I < max_i and J < max_j):
            antidiag_row.append((i, j))
    
    print(antidiag_row)
    antidiag.append(antidiag_row)

[(1, 1)]
[(1, 2), (2, 1)]
[(1, 3), (2, 2), (3, 1)]
[(1, 4), (2, 3), (3, 2), (4, 1)]
[(1, 5), (2, 4), (3, 3), (4, 2), (5, 1)]
[(1, 6), (2, 5), (3, 4), (4, 3), (5, 2)]
[(1, 7), (2, 6), (3, 5), (4, 4), (5, 3)]
[(1, 8), (2, 7), (3, 6), (4, 5), (5, 4)]
[(1, 9), (2, 8), (3, 7), (4, 6), (5, 5)]
[(1, 10), (2, 9), (3, 8), (4, 7), (5, 6)]
[(2, 10), (3, 9), (4, 8), (5, 7)]
[(3, 10), (4, 9), (5, 8)]
[(4, 10), (5, 9)]
[(5, 10)]


In [88]:
antidiag_leg = antidiag

In [89]:
MAX_THREADS_PER_BLOCK = 1024

In [90]:
T = min(min(N, M), MAX_THREADS_PER_BLOCK)
n_passes = min(N, M) // MAX_THREADS_PER_BLOCK + 1
n_antidiag = M + N - 1

In [91]:
n_passes, T

(1, 5)

In [92]:
max_i, max_j

(5, 10)

In [93]:
antidiag = []
for a in range(n_antidiag):
    
    antidiag_row = []
    
    for thread_id in range(T):
        for p in range(n_passes):
            J = thread_id + p*MAX_THREADS_PER_BLOCK
            I = a - thread_id - p*MAX_THREADS_PER_BLOCK
            
            i = I + 1
            j = J + 1

            if (I + J == a) and (I < max_i and J < max_j) and (I > -1):
                antidiag_row.append((i, j))
    
    print(antidiag_row)  
    antidiag.append(antidiag_row)

[(1, 1)]
[(2, 1), (1, 2)]
[(3, 1), (2, 2), (1, 3)]
[(4, 1), (3, 2), (2, 3), (1, 4)]
[(5, 1), (4, 2), (3, 3), (2, 4), (1, 5)]
[(5, 2), (4, 3), (3, 4), (2, 5)]
[(5, 3), (4, 4), (3, 5)]
[(5, 4), (4, 5)]
[(5, 5)]
[]
[]
[]
[]
[]


In [84]:
for i, j in zip(antidiag, antidiag_leg):
    assert(set(i)==set(j))

In [39]:
batch_size, seq_len_a, seq_len_b, dims = 2, 10, 15, 3

In [59]:
A = torch.rand((batch_size, seq_len_a, dims))
B = torch.rand((batch_size, seq_len_b, dims))

In [60]:
sdtw = pysdtw.SoftDTW()

In [61]:
sdtw(A.cuda(), B.cuda())

tensor([inf, inf, inf, inf, inf, inf, inf, inf, inf, inf], device='cuda:0')

# backward loop

In [62]:
MAX_THREADS_PER_BLOCK = 128
M = 12
N = 21

In [63]:
max_i = N
max_j = M

In [64]:
threads_per_block = max(N, M)
n_passes = N + M - 1

In [65]:
antidiag = []
for p in range(n_passes):
    
    rev_p = n_passes - p - 1

    antidiag_row = []
    
    for thread_id in range(threads_per_block):
        I = thread_id
        tid = thread_id

        J = max(0, min(rev_p - tid, max_j - 1))

        i = I + 1
        j = J + 1

        # Only compute if element[i, j] is on the current anti-diagonal, and also is within bounds
        if I + J == rev_p and (I < max_i and J < max_j):
            antidiag_row.append((i, j))
    
    print(antidiag_row)
    antidiag.append(antidiag_row)

[(21, 12)]
[(20, 12), (21, 11)]
[(19, 12), (20, 11), (21, 10)]
[(18, 12), (19, 11), (20, 10), (21, 9)]
[(17, 12), (18, 11), (19, 10), (20, 9), (21, 8)]
[(16, 12), (17, 11), (18, 10), (19, 9), (20, 8), (21, 7)]
[(15, 12), (16, 11), (17, 10), (18, 9), (19, 8), (20, 7), (21, 6)]
[(14, 12), (15, 11), (16, 10), (17, 9), (18, 8), (19, 7), (20, 6), (21, 5)]
[(13, 12), (14, 11), (15, 10), (16, 9), (17, 8), (18, 7), (19, 6), (20, 5), (21, 4)]
[(12, 12), (13, 11), (14, 10), (15, 9), (16, 8), (17, 7), (18, 6), (19, 5), (20, 4), (21, 3)]
[(11, 12), (12, 11), (13, 10), (14, 9), (15, 8), (16, 7), (17, 6), (18, 5), (19, 4), (20, 3), (21, 2)]
[(10, 12), (11, 11), (12, 10), (13, 9), (14, 8), (15, 7), (16, 6), (17, 5), (18, 4), (19, 3), (20, 2), (21, 1)]
[(9, 12), (10, 11), (11, 10), (12, 9), (13, 8), (14, 7), (15, 6), (16, 5), (17, 4), (18, 3), (19, 2), (20, 1)]
[(8, 12), (9, 11), (10, 10), (11, 9), (12, 8), (13, 7), (14, 6), (15, 5), (16, 4), (17, 3), (18, 2), (19, 1)]
[(7, 12), (8, 11), (9, 10), (10,

In [66]:
MAX_THREADS_PER_BLOCK = 128

In [67]:
T = min(min(N, M), MAX_THREADS_PER_BLOCK)
n_passes = min(N, M) // MAX_THREADS_PER_BLOCK + 1
n_antidiag = M + N - 1

In [68]:
n_passes, T

(1, 12)

In [69]:
max_i, max_j

(21, 12)

In [71]:
antidiag = []
for a in range(n_antidiag):

    rev_a = n_antidiag - a - 1

    antidiag_row = []
    
    for thread_id in range(T):
        for p in range(n_passes):
            J = thread_id + p*MAX_THREADS_PER_BLOCK
            I = rev_a - thread_id - p*MAX_THREADS_PER_BLOCK
            
            i = I + 1
            j = J + 1

            if I + J == rev_a and (I < max_i and J < max_j) and (I > -1):
                antidiag_row.append((i, j))
    
    print(antidiag_row)  
    antidiag.append(antidiag_row)

[(21, 12)]
[(21, 11), (20, 12)]
[(21, 10), (20, 11), (19, 12)]
[(21, 9), (20, 10), (19, 11), (18, 12)]
[(21, 8), (20, 9), (19, 10), (18, 11), (17, 12)]
[(21, 7), (20, 8), (19, 9), (18, 10), (17, 11), (16, 12)]
[(21, 6), (20, 7), (19, 8), (18, 9), (17, 10), (16, 11), (15, 12)]
[(21, 5), (20, 6), (19, 7), (18, 8), (17, 9), (16, 10), (15, 11), (14, 12)]
[(21, 4), (20, 5), (19, 6), (18, 7), (17, 8), (16, 9), (15, 10), (14, 11), (13, 12)]
[(21, 3), (20, 4), (19, 5), (18, 6), (17, 7), (16, 8), (15, 9), (14, 10), (13, 11), (12, 12)]
[(21, 2), (20, 3), (19, 4), (18, 5), (17, 6), (16, 7), (15, 8), (14, 9), (13, 10), (12, 11), (11, 12)]
[(21, 1), (20, 2), (19, 3), (18, 4), (17, 5), (16, 6), (15, 7), (14, 8), (13, 9), (12, 10), (11, 11), (10, 12)]
[(20, 1), (19, 2), (18, 3), (17, 4), (16, 5), (15, 6), (14, 7), (13, 8), (12, 9), (11, 10), (10, 11), (9, 12)]
[(19, 1), (18, 2), (17, 3), (16, 4), (15, 5), (14, 6), (13, 7), (12, 8), (11, 9), (10, 10), (9, 11), (8, 12)]
[(18, 1), (17, 2), (16, 3), (15,

# pairwise

In [30]:
def _euclidean_dist_func(x, y):
    """
    Calculates the Euclidean distance between each element in x and y per timestep
    """
    n = x.size(1)
    m = y.size(1)
    d = x.size(2)
    x = x.unsqueeze(2).expand(-1, n, m, d)
    y = y.unsqueeze(1).expand(-1, n, m, d)
    return torch.pow(x - y, 2).sum(3)

In [31]:
def pairwise_l2_squared_opt(x, y, theta):
    '''
    https://discuss.pytorch.org/t/efficient-distance-matrix-computation/9065/2

    Input: x is an Nxd matrix
           y is an Mxd matrix
    Output: dist is a NxM matrix where dist[i,j] is the square norm between x[i,:] and y[j,:]
            if y is not given then use 'y=x'.
    i.e. dist[i,j] = ||x[i,:]-y[j,:]||^2
    '''
    x_norm = (theta * x**2).sum(1).view(-1, 1)
    y_t = torch.transpose(y, 0, 1)
    y_norm = (theta * y**2).sum(1).view(1, -1)
    dist = x_norm + y_norm - 2.0 * torch.mm(theta*x, y_t)
    # Ensure diagonal is zero if x=y
    # if y is None:
    #     dist = dist - torch.diag(dist.diag)
    return torch.clamp(dist, 0.0, np.inf)

import numpy as np

def pairwise_l2_squared(x, y):
    '''
    https://discuss.pytorch.org/t/efficient-distance-matrix-computation/9065/2

    Input: x is an Nxd matrix
           y is an Mxd matrix
    Output: dist is a NxM matrix where dist[i,j] is the square norm between x[i,:] and y[j,:]
            if y is not given then use 'y=x'.
    i.e. dist[i,j] = ||x[i,:]-y[j,:]||^2
    '''
    x_norm = (x**2).sum(-1).unsqueeze(-1)
    y_norm = (y**2).sum(-1).unsqueeze(-2)
    dist = x_norm + y_norm - 2.0 * torch.bmm(x, y.mT)
    return torch.clamp(dist, 0.0, np.inf)

In [35]:
# %%timeit
res_a = _euclidean_dist_func(a_cpu, b_cpu)

In [36]:
# %%timeit
res_b = pairwise_l2_squared(a_cpu, b_cpu)

In [37]:
torch.allclose(res_a, res_b)

True

In [47]:
theta = torch.ones(dims)

In [50]:
# pairwise_l2_squared(a_cpu, b_cpu, torch.ones(dims)).shape

In [64]:
batch_size, seq_len_a, seq_len_b, dims = 10, 512, 1023, 15

In [65]:
a_cpu = torch.rand((batch_size, seq_len_a, dims))
b_cpu = torch.rand((batch_size, seq_len_b, dims))

In [66]:
A = a_cpu
B = b_cpu

In [81]:
A.shape

torch.Size([10, 512, 15])

In [92]:
x_norm = (A**2).sum(-1).unsqueeze(-1)

In [93]:
x_norm.shape

torch.Size([10, 512, 1])

In [94]:
B.shape

torch.Size([10, 1023, 15])

In [95]:
y_norm = (B**2).sum(-1).unsqueeze(-2)

In [96]:
y_norm.shape

torch.Size([10, 1, 1023])

In [97]:
(x_norm + y_norm).shape

torch.Size([10, 512, 1023])

In [102]:
torch.bmm(A, B.mT).shape

torch.Size([10, 512, 1023])

In [86]:
x, y = torch.rand((seq_len_a, dims)), torch.rand((seq_len_b, dims))

In [88]:
x_norm = (x**2).sum(1).view(-1, 1)
y_norm = (y**2).sum(1).view(1, -1)

In [90]:
y_t = torch.transpose(y, 0, 1)
dist = x_norm + y_norm - 2.0 * torch.mm(x, y_t)

In [91]:
x.shape, y_t.shape

(torch.Size([512, 15]), torch.Size([15, 1023]))