# numba

In [1]:
from numba import cuda

In [2]:
[i for i in cuda.devices.gpus]

[<numba.cuda.cudadrv.devices._DeviceContextManager at 0x7f0b5bea84c0>]

In [3]:
import numba

In [4]:
numba.cuda.detect()

Found 1 CUDA devices
id 0    b'NVIDIA GeForce RTX 2060 SUPER'                              [SUPPORTED]
                      Compute Capability: 7.5
                           PCI Device ID: 0
                              PCI Bus ID: 38
                                    UUID: GPU-f5dcddd0-bc57-5ebb-7578-229367d62be8
                                Watchdog: Enabled
             FP32/FP64 Performance Ratio: 32
Summary:
	1/1 devices are supported


True

In [5]:
device = cuda.get_current_device()

In [6]:
cd = device.get_primary_context()

In [7]:
cd.get_max_potential_block_size?

[0;31mSignature:[0m
[0mcd[0m[0;34m.[0m[0mget_max_potential_block_size[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mfunc[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mb2d_func[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mmemsize[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mblocksizelimit[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mflags[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m
Suggest a launch configuration with reasonable occupancy.
:param func: kernel for which occupancy is calculated
:param b2d_func: function that calculates how much per-block dynamic
                 shared memory 'func' uses based on the block size.
                 Can also be the address of a C function.
                 Use `0` to pass `NULL` to the underlying CUDA API.
:param memsize: per-block dynamic shared memory usage intended, in bytes
:param blocksizelimit: maximum block size the kernel is designed to
        

In [8]:
cd.get_memory_info()

MemoryInfo(free=7392985088, total=8346664960)

In [9]:
device

<weakproxy at 0x7f0b5bea97c0 to Device at 0x7f0b9462ae50>

# gradient

In [10]:
import torch

In [11]:
from pysdtw import SoftDTW

In [12]:
batch_size, seq_len_a, seq_len_b, dims = 10, 512, 1023, 15

In [13]:
A = torch.rand((batch_size, seq_len_a, dims), requires_grad=True)
B = torch.rand((batch_size, seq_len_b, dims))

In [14]:
sdtw = SoftDTW()

In [15]:
%%timeit
sdtw(A.cuda(), B.cuda())

128 5 1534




128 5 1534
128 5 1534
128 5 1534
128 5 1534
128 5 1534
128 5 1534
128 5 1534
20 ms ± 2.71 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)




In [16]:
sdtw_cpu = SoftDTW(use_cuda=False)

In [17]:
%%timeit
sdtw_cpu(A, B)

346 ms ± 5.07 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [18]:
res = sdtw(A.cuda(), B.cuda())

128 5 1534




In [19]:
loss = res.sum()

In [20]:
loss.backward()

In [21]:
A.grad[0,0]

tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan])

In [23]:
Ac = A.detach().clone().requires_grad_(True)

In [24]:
res_leg = sdtw_cpu(Ac, B)

In [25]:
loss_leg = res_leg.sum()

In [26]:
loss_leg.backward()

In [27]:
Ac.grad[0,0]

tensor([-0.9011, -1.5398,  1.4161,  1.2702, -0.8651,  1.8042, -0.5897,  1.2158,
         1.2448, -0.4859, -0.1417, -0.3461, -0.5727, -1.5326, -0.7352])

# gradient loop

In [29]:
# Toy problem sizes used to enumerate the anti-diagonals of a small grid by hand.
MAX_THREADS_PER_BLOCK = 128  # not read by the one-thread-per-row loop; reassigned to 5 before the multi-pass loop
M = 12  # later copied into max_j, the bound on the 0-based column index J
N = 21  # later copied into max_i, the bound on the 0-based row index I

In [30]:
# Exclusive upper bounds for the 0-based indices checked inside the
# anti-diagonal loops (I < max_i, J < max_j).
max_i = N
max_j = M

In [31]:
# One simulated thread per index along the longer side, so every row index
# I = thread_id below max_i gets a thread.
threads_per_block = max(N, M)
# I + J ranges over 0 .. N + M - 2, i.e. N + M - 1 anti-diagonals, one pass each.
n_passes = N + M - 1

In [32]:
antidiag = []
for p in range(n_passes):
    # Simulated thread `row` owns matrix row `row`. On pass p the only cell it
    # may touch is column p - row, and only if that column lies in [0, max_j).
    antidiag_row = [
        (row + 1, p - row + 1)  # reported as 1-based (i, j)
        for row in range(threads_per_block)
        if row < max_i and 0 <= p - row < max_j
    ]

    print(antidiag_row)
    antidiag.append(antidiag_row)

[(1, 1)]
[(1, 2), (2, 1)]
[(1, 3), (2, 2), (3, 1)]
[(1, 4), (2, 3), (3, 2), (4, 1)]
[(1, 5), (2, 4), (3, 3), (4, 2), (5, 1)]
[(1, 6), (2, 5), (3, 4), (4, 3), (5, 2), (6, 1)]
[(1, 7), (2, 6), (3, 5), (4, 4), (5, 3), (6, 2), (7, 1)]
[(1, 8), (2, 7), (3, 6), (4, 5), (5, 4), (6, 3), (7, 2), (8, 1)]
[(1, 9), (2, 8), (3, 7), (4, 6), (5, 5), (6, 4), (7, 3), (8, 2), (9, 1)]
[(1, 10), (2, 9), (3, 8), (4, 7), (5, 6), (6, 5), (7, 4), (8, 3), (9, 2), (10, 1)]
[(1, 11), (2, 10), (3, 9), (4, 8), (5, 7), (6, 6), (7, 5), (8, 4), (9, 3), (10, 2), (11, 1)]
[(1, 12), (2, 11), (3, 10), (4, 9), (5, 8), (6, 7), (7, 6), (8, 5), (9, 4), (10, 3), (11, 2), (12, 1)]
[(2, 12), (3, 11), (4, 10), (5, 9), (6, 8), (7, 7), (8, 6), (9, 5), (10, 4), (11, 3), (12, 2), (13, 1)]
[(3, 12), (4, 11), (5, 10), (6, 9), (7, 8), (8, 7), (9, 6), (10, 5), (11, 4), (12, 3), (13, 2), (14, 1)]
[(4, 12), (5, 11), (6, 10), (7, 9), (8, 8), (9, 7), (10, 6), (11, 5), (12, 4), (13, 3), (14, 2), (15, 1)]
[(5, 12), (6, 11), (7, 10), (8, 9), (

In [33]:
antidiag_leg = antidiag

In [34]:
MAX_THREADS_PER_BLOCK = 5

In [35]:
# In the multi-pass loop each simulated thread visits the columns
# J = thread_id + p * MAX_THREADS_PER_BLOCK, and J is bounded by max_j (= M).
# Thread count and pass count must therefore be derived from M alone: using
# min(N, M) leaves columns unvisited whenever N < M.
T = min(M, MAX_THREADS_PER_BLOCK)
# Ceiling division: just enough passes for T-strided threads to reach column
# M - 1, without the wasted extra pass that `// ... + 1` schedules when M is an
# exact multiple of MAX_THREADS_PER_BLOCK.
n_passes = -(-M // MAX_THREADS_PER_BLOCK)
# I + J ranges over 0 .. M + N - 2.
n_antidiag = M + N - 1

In [36]:
n_passes, T

(3, 5)

In [37]:
max_i, max_j

(21, 12)

In [44]:
antidiag = []
for a in range(n_antidiag):
    # Each simulated thread strides across the columns in steps of
    # MAX_THREADS_PER_BLOCK, one column per pass (thread-major, pass-minor
    # order). The row follows from the anti-diagonal index: row + col == a.
    columns = (
        thread_id + p * MAX_THREADS_PER_BLOCK
        for thread_id in range(T)
        for p in range(n_passes)
    )
    antidiag_row = [
        (a - col + 1, col + 1)  # reported as 1-based (i, j)
        for col in columns
        if 0 <= a - col < max_i and col < max_j
    ]

    print(antidiag_row)
    antidiag.append(antidiag_row)

[(1, 1)]
[(2, 1), (1, 2)]
[(3, 1), (2, 2), (1, 3)]
[(4, 1), (3, 2), (2, 3), (1, 4)]
[(5, 1), (4, 2), (3, 3), (2, 4), (1, 5)]
[(6, 1), (1, 6), (5, 2), (4, 3), (3, 4), (2, 5)]
[(7, 1), (2, 6), (6, 2), (1, 7), (5, 3), (4, 4), (3, 5)]
[(8, 1), (3, 6), (7, 2), (2, 7), (6, 3), (1, 8), (5, 4), (4, 5)]
[(9, 1), (4, 6), (8, 2), (3, 7), (7, 3), (2, 8), (6, 4), (1, 9), (5, 5)]
[(10, 1), (5, 6), (9, 2), (4, 7), (8, 3), (3, 8), (7, 4), (2, 9), (6, 5), (1, 10)]
[(11, 1), (6, 6), (1, 11), (10, 2), (5, 7), (9, 3), (4, 8), (8, 4), (3, 9), (7, 5), (2, 10)]
[(12, 1), (7, 6), (2, 11), (11, 2), (6, 7), (1, 12), (10, 3), (5, 8), (9, 4), (4, 9), (8, 5), (3, 10)]
[(13, 1), (8, 6), (3, 11), (12, 2), (7, 7), (2, 12), (11, 3), (6, 8), (10, 4), (5, 9), (9, 5), (4, 10)]
[(14, 1), (9, 6), (4, 11), (13, 2), (8, 7), (3, 12), (12, 3), (7, 8), (11, 4), (6, 9), (10, 5), (5, 10)]
[(15, 1), (10, 6), (5, 11), (14, 2), (9, 7), (4, 12), (13, 3), (8, 8), (12, 4), (7, 9), (11, 5), (6, 10)]
[(16, 1), (11, 6), (6, 11), (15, 2), 

In [45]:
# Row-by-row check that the multi-pass schedule visits the same cells as the
# one-thread-per-row schedule (order within an anti-diagonal is irrelevant).
for multipass_row, legacy_row in zip(antidiag, antidiag_leg):
    same_cells = set(multipass_row) == set(legacy_row)
    print(same_cells)

True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True


In [46]:
batch_size, seq_len_a, seq_len_b, dims = 2, 10, 15, 3

In [47]:
A = torch.rand((batch_size, seq_len_a, dims))
B = torch.rand((batch_size, seq_len_b, dims))

In [48]:
sdtw = SoftDTW()

In [49]:
sdtw(A.cuda(), B.cuda())

10 1 24




tensor([inf, inf], device='cuda:0')

# pairwise

In [30]:
def _euclidean_dist_func(x, y):
    """
    Calculates the Euclidean distance between each element in x and y per timestep
    """
    n = x.size(1)
    m = y.size(1)
    d = x.size(2)
    x = x.unsqueeze(2).expand(-1, n, m, d)
    y = y.unsqueeze(1).expand(-1, n, m, d)
    return torch.pow(x - y, 2).sum(3)

In [31]:
def pairwise_l2_squared_opt(x, y, theta):
    '''
    Feature-weighted squared Euclidean distance between every row of x and
    every row of y, computed via the expansion
    ||x||^2 + ||y||^2 - 2 <x, y> (each term weighted by theta).

    Adapted from
    https://discuss.pytorch.org/t/efficient-distance-matrix-computation/9065/2

    Input: x is an Nxd matrix (optionally with leading batch dimensions)
           y is an Mxd matrix (leading batch dimensions must broadcast with x's)
           theta is a per-feature weight that broadcasts against the last
           dimension: a scalar or a tensor of shape (d,)
    Output: dist is a NxM matrix (with the same leading batch dimensions) where
            dist[i,j] = sum_k theta[k] * (x[i,k] - y[j,k])^2
    '''
    # Weighted squared norms, shaped to broadcast into an (..., N, M) grid.
    x_norm = (theta * x**2).sum(-1).unsqueeze(-1)
    y_norm = (theta * y**2).sum(-1).unsqueeze(-2)
    # Weighted inner products between every pair of rows.
    cross = torch.matmul(theta * x, y.transpose(-1, -2))
    dist = x_norm + y_norm - 2.0 * cross
    # Floating-point cancellation can push near-zero distances slightly below
    # zero; clip them so the result is a valid squared distance.
    return torch.clamp(dist, min=0.0)

import numpy as np

def pairwise_l2_squared(x, y):
    '''
    Batched squared Euclidean distance matrix, computed via the expansion
    ||x||^2 + ||y||^2 - 2 <x, y>.

    Adapted from
    https://discuss.pytorch.org/t/efficient-distance-matrix-computation/9065/2

    Input: x is a BxNxd tensor
           y is a BxMxd tensor
    Output: dist is a BxNxM tensor where
            dist[b,i,j] = ||x[b,i,:]-y[b,j,:]||^2, clipped below at zero
    '''
    sq_x = x.pow(2).sum(dim=-1, keepdim=True)   # (B, N, 1)
    sq_y = y.pow(2).sum(dim=-1).unsqueeze(-2)   # (B, 1, M)
    cross = torch.bmm(x, y.transpose(-2, -1))   # (B, N, M)
    # Clip the small negative values that cancellation can produce.
    return (sq_x + sq_y - 2.0 * cross).clamp(min=0.0)

In [35]:
# %%timeit
res_a = _euclidean_dist_func(a_cpu, b_cpu)

In [36]:
# %%timeit
res_b = pairwise_l2_squared(a_cpu, b_cpu)

In [37]:
torch.allclose(res_a, res_b)

True

In [47]:
theta = torch.ones(dims)

In [50]:
# pairwise_l2_squared(a_cpu, b_cpu, torch.ones(dims)).shape

In [64]:
batch_size, seq_len_a, seq_len_b, dims = 10, 512, 1023, 15

In [65]:
a_cpu = torch.rand((batch_size, seq_len_a, dims))
b_cpu = torch.rand((batch_size, seq_len_b, dims))

In [66]:
A = a_cpu
B = b_cpu

In [81]:
A.shape

torch.Size([10, 512, 15])

In [92]:
x_norm = (A**2).sum(-1).unsqueeze(-1)

In [93]:
x_norm.shape

torch.Size([10, 512, 1])

In [94]:
B.shape

torch.Size([10, 1023, 15])

In [95]:
y_norm = (B**2).sum(-1).unsqueeze(-2)

In [96]:
y_norm.shape

torch.Size([10, 1, 1023])

In [97]:
(x_norm + y_norm).shape

torch.Size([10, 512, 1023])

In [102]:
torch.bmm(A, B.mT).shape

torch.Size([10, 512, 1023])

In [86]:
x, y = torch.rand((seq_len_a, dims)), torch.rand((seq_len_b, dims))

In [88]:
x_norm = (x**2).sum(1).view(-1, 1)
y_norm = (y**2).sum(1).view(1, -1)

In [90]:
y_t = torch.transpose(y, 0, 1)
dist = x_norm + y_norm - 2.0 * torch.mm(x, y_t)

In [91]:
x.shape, y_t.shape

(torch.Size([512, 15]), torch.Size([15, 1023]))