In [1]:
import numpy as np
from torch import nn, Tensor

import torch
import torch.cuda
from numba import cuda, jit, prange 

from s3ts.api.nets.encoders.dtw.dtw_no_matrix import dtw_fast_no_image

In [25]:
# Seed torch's global RNG so the random inputs below are reproducible.
torch.manual_seed(45)

# `a` feeds the last (window) axis and `b` the pattern axis of the pairwise
# difference tensor built further down; axis 1 of both is the shared dims axis.
a = torch.randn((16, 5, 32))
b = torch.randn((5, 5, 64))

In [4]:
# Sanity check: show the CUDA devices numba can see before any kernel is launched.
print(cuda.gpus)

<Managed Device 0>


In [82]:
@cuda.jit
def dtw_forward(dtw, w):
    '''
        In-place DTW accumulation, one CUDA thread per (sample, kernel) pair.

        dtw: tensor of shape (n, k, pattern_len, window_size). Row 0 and
            column 0 of every matrix are never written here, so they must
            already hold their accumulated values when the kernel starts.
        w: factor applied to the smaller of the left / diagonal predecessors
            before it is compared with the upper predecessor.
    '''
    n_samples, n_kernels, n_rows, n_cols = dtw.shape

    sample, kernel = cuda.grid(2)

    # Threads that fall outside the (n, k) extent have nothing to do.
    if sample >= n_samples or kernel >= n_kernels:
        return

    for row in range(1, n_rows): # pattern axis
        for col in range(1, n_cols): # window axis
            left = dtw[sample, kernel, row, col - 1]
            diag = dtw[sample, kernel, row - 1, col - 1]
            up = dtw[sample, kernel, row - 1, col]
            # Add the cheapest predecessor to the local cost already stored here.
            dtw[sample, kernel, row, col] += min(w * min(left, diag), up)

In [243]:
@cuda.jit
def dtw_backward(dtw, dist_grad, grad):
    '''
        Backtrack through an accumulated DTW tensor and sum the distance
        gradients of every cell that is still finite when it is visited.

        Each thread handles one (x, y) pair of the first two axes and walks
        its matrix from the bottom-right corner towards the top-left one.
        Cells ruled out as predecessors are overwritten with inf, so dtw is
        modified in place. Values are only ever added to grad; the kernel
        never zeroes it, so the caller must pass a cleared buffer.

        dtw of shape (n, k, pattern_len, window_size)
        dist_grad of shape (n, k, dims, pattern_len, window_size)
        grad of shape (n, k, dims, pattern_len)
    '''
    n, k, d, len_pattern, len_window = dist_grad.shape

    x, y = cuda.grid(2)

    # Guard: the launch grid may contain more threads than (n, k) pairs.
    if x < n and y < k:
        # Visit rows bottom-up and, inside each row, columns right-to-left.
        for i0 in range(len_pattern-1, -1, -1):
            for j0 in range(len_window-1, -1, -1):

                # A = dtw[x, y, i0, j0-1]   (left neighbour)
                # B = dtw[x, y, i0-1, j0]   (upper neighbour)
                # C = dtw[x, y, i0-1, j0-1] (diagonal neighbour)

                # path is A if (A<B) & (A<C) -> path is not A if (A>=B) | (A>=C)
                # path is B if (B<A) & (B<C) -> path is not B if (B>=A) | (B>=C)

                # Cells already overwritten with inf were excluded earlier; skip them.
                if dtw[x, y, i0, j0] != np.inf:

                    # Accumulate this cell's gradient into pattern position i0
                    # for every dimension.
                    for l in range(d):
                        cuda.atomic.add(grad, (x, y, l, i0), dist_grad[x, y, l, i0, j0])

                    # Border cells have no left/upper neighbours to compare, so
                    # nothing is masked from here on.
                    if j0==0 or i0==0:
                        continue

                    if dtw[x, y, i0, j0-1] >= dtw[x, y, i0-1, j0] or dtw[x, y, i0, j0-1] >= dtw[x, y, i0-1, j0-1]: # path is not A
                        # Exclude everything to the left in this row.
                        for j in range(j0):
                            dtw[x, y, i0, j] = np.inf
                    # NOTE: this test re-reads dtw[x, y, i0, j0-1]. If the branch
                    # above just ran, that cell is now inf, so `B >= A` is false
                    # for any finite B and only `B >= C` decides. Tie handling
                    # relies on this read-after-write; do not hoist the reads
                    # into locals without re-deriving the conditions.
                    if dtw[x, y, i0-1, j0] >= dtw[x, y, i0, j0-1] or dtw[x, y, i0-1, j0] >= dtw[x, y, i0-1, j0-1]: # path is not B
                        # Exclude everything above in this column.
                        for i in range(i0):
                            dtw[x, y, i, j0] = np.inf


In [254]:
# Pairwise differences between every entry of `a` and every entry of `b`,
# broadcast to shape (n, n_kernel, dims, kernel_size, T).
p_diff = a.unsqueeze(1).unsqueeze(3) - b.unsqueeze(0).unsqueeze(-1)

# Squared distance: reduce the dims axis -> shape (n, n_kernel, kernel_size, T).
euc_d = p_diff.pow(2).sum(dim=2)

# Scale each difference in place by the (epsilon-stabilised) root of that sum.
# This must happen before the cumulative sums below overwrite the border cells.
p_diff /= (euc_d.unsqueeze(2) + 1e-6).sqrt()

# compute dtw: first row and first column are plain running sums
euc_d[:, :, 0, :] = euc_d[:, :, 0, :].cumsum(dim=2)
euc_d[:, :, :, 0] = euc_d[:, :, :, 0].cumsum(dim=2)

In [261]:
# Bottom-right cell of every (sample, kernel) matrix on the host tensor.
euc_d[..., -1, -1]

tensor([[ 8.3049, 12.5326, 14.2390, 18.5499, 21.3878],
        [11.3199, 14.6318,  3.1673,  6.2476,  7.9558],
        [ 2.9047, 13.5166,  2.5900,  7.3420, 16.9095],
        [11.8028, 19.0041,  5.9039,  7.2419, 11.8238],
        [ 5.6307,  5.2362,  6.0658, 14.2877,  8.3812],
        [11.2682,  8.7904,  6.7414, 14.0469,  9.2293],
        [15.1081,  8.2353, 16.9080, 25.9967, 10.8065],
        [32.1689, 31.2574, 15.9637, 16.0273, 33.0433],
        [18.0312, 19.2785,  6.0626,  5.7663, 14.8308],
        [26.5202, 24.1122, 14.4008, 13.6005, 12.0924],
        [18.6912, 12.9780, 12.8129, 17.3771, 17.1096],
        [11.3255, 24.8915,  2.4223,  3.5488, 27.4078],
        [ 8.2778, 18.2923, 10.6480, 13.8823, 18.5577],
        [ 6.9313,  9.3066,  2.8754,  7.0702, 14.7061],
        [ 5.4527, 14.2640,  4.5757,  8.3152, 14.1447],
        [ 8.1081, 15.0456,  0.7441,  2.7304, 17.6448]])

In [255]:
# Zero-initialised output buffer for dtw_backward, shape (n, k, dims, pattern_len).
# Shape and dtype are taken from p_diff instead of being hard-coded, so the cell
# keeps working when the input sizes change.
grads = torch.zeros(p_diff.shape[:4], dtype=p_diff.dtype, device="cuda")

# numba handle wrapping the torch tensor above, so the kernel can write into it.
grads_cuda = cuda.as_cuda_array(grads)

# Device copy of the normalised differences, consumed by the backward kernel.
p_diff_cuda = cuda.as_cuda_array(p_diff.cuda())

In [256]:
# Work on a device copy so the host tensor `euc_d` is left untouched.
dtw = cuda.as_cuda_array(euc_d.detach().cuda())

# One thread per (sample, kernel) pair. Size the grid from the data rather than
# hard-coding it, so every pair is covered whatever the batch size is and no
# needless idle blocks are launched.
threads_per_block = (16, 16)
blocks_per_grid = tuple(
    (extent + tpb - 1) // tpb
    for extent, tpb in zip(dtw.shape[:2], threads_per_block)
)
dtw_forward[blocks_per_grid, threads_per_block](dtw, 1)

In [257]:
# Backtrack on the GPU, one thread per (sample, kernel) pair. The grid is sized
# from the data so every pair is covered regardless of the batch size.
# The kernel only adds into `grads_cuda`, so re-running this cell without
# re-creating the buffer accumulates on top of the previous result.
threads_per_block = (16, 16)
blocks_per_grid = tuple(
    (extent + tpb - 1) // tpb
    for extent, tpb in zip(p_diff_cuda.shape[:2], threads_per_block)
)
dtw_backward[blocks_per_grid, threads_per_block](dtw, p_diff_cuda, grads_cuda)

In [258]:
# Copy the DTW tensor back to the host and show the matrix of pair (0, 0);
# cells excluded by the backward kernel appear as inf.
torch.from_numpy(dtw.copy_to_host())[0, 0]

tensor([[  9.3449,      inf,      inf,  ...,      inf,      inf,      inf],
        [     inf,  15.8723,      inf,  ...,      inf,      inf,      inf],
        [     inf,      inf,  19.5062,  ...,      inf,      inf,      inf],
        ...,
        [     inf,      inf,      inf,  ...,      inf,      inf, 387.8178],
        [     inf,      inf,      inf,  ...,      inf,      inf, 408.5746],
        [     inf,      inf,      inf,  ...,      inf,      inf, 416.8795]])

In [259]:
# Gradient of pair (0, 0) read through the torch tensor that backs the kernel output.
grads[0, 0]

tensor([[-7.5245e-03,  4.6715e-01,  9.3521e-01, -4.2035e-01, -3.8310e-01,
         -8.0012e-01, -2.1795e-01,  7.1087e-01,  4.9632e-01,  5.0801e-01,
          2.3749e-01,  7.4432e-01,  5.6123e-01,  5.5381e-01,  8.0801e-01,
         -4.5157e-01,  2.4809e-01,  7.4112e-01, -2.8445e-01, -1.7460e-02,
         -2.1852e-01,  6.7370e-03, -5.7965e-01,  6.2305e-01, -8.9107e-02,
          6.0938e-01,  1.0445e-01, -6.8408e-01, -2.8676e-01, -2.7234e-01,
          1.4446e-01, -2.1205e-01, -4.7354e-01, -3.4677e-01, -7.4158e-02,
         -1.8458e-01,  7.6073e-02, -2.7699e-01, -4.7787e-01,  4.2296e-04,
         -9.4835e-01, -4.7182e-01, -5.7164e-01, -3.2118e-01,  7.5695e-03,
         -2.0316e-01,  8.8038e-02, -2.7533e-01,  3.3248e-01,  5.7210e-01,
          6.6940e-01, -1.1434e-01,  1.1283e-01,  2.9612e-01,  4.3339e-01,
         -3.0245e-01, -5.4845e-01,  3.5049e-01, -2.8474e-01,  2.1408e-01,
          8.2052e-01, -4.5766e-01, -7.0214e-01, -6.1360e-01],
        [-7.2771e-01,  8.0601e-01,  7.6242e-02,  3

In [260]:
# Same gradient read through the numba handle, after an explicit copy to the host.
torch.from_numpy(grads_cuda.copy_to_host())[0, 0]

tensor([[-7.5245e-03,  4.6715e-01,  9.3521e-01, -4.2035e-01, -3.8310e-01,
         -8.0012e-01, -2.1795e-01,  7.1087e-01,  4.9632e-01,  5.0801e-01,
          2.3749e-01,  7.4432e-01,  5.6123e-01,  5.5381e-01,  8.0801e-01,
         -4.5157e-01,  2.4809e-01,  7.4112e-01, -2.8445e-01, -1.7460e-02,
         -2.1852e-01,  6.7370e-03, -5.7965e-01,  6.2305e-01, -8.9107e-02,
          6.0938e-01,  1.0445e-01, -6.8408e-01, -2.8676e-01, -2.7234e-01,
          1.4446e-01, -2.1205e-01, -4.7354e-01, -3.4677e-01, -7.4158e-02,
         -1.8458e-01,  7.6073e-02, -2.7699e-01, -4.7787e-01,  4.2296e-04,
         -9.4835e-01, -4.7182e-01, -5.7164e-01, -3.2118e-01,  7.5695e-03,
         -2.0316e-01,  8.8038e-02, -2.7533e-01,  3.3248e-01,  5.7210e-01,
          6.6940e-01, -1.1434e-01,  1.1283e-01,  2.9612e-01,  4.3339e-01,
         -3.0245e-01, -5.4845e-01,  3.5049e-01, -2.8474e-01,  2.1408e-01,
          8.2052e-01, -4.5766e-01, -7.0214e-01, -6.1360e-01],
        [-7.2771e-01,  8.0601e-01,  7.6242e-02,  3

In [67]:
@torch.jit.script
def dtw_compute_full(dtw: torch.Tensor, dist_grad: torch.Tensor, w: float) -> torch.Tensor:
    '''
        Run the DTW accumulation and its backtracking pass for a whole batch.

        Arguments:
            dtw: shape (n, k, pattern_len, window_size). Row 0 and column 0
                must already hold their accumulated values. The tensor is
                MODIFIED IN PLACE: the forward pass accumulates into it and
                the backward pass overwrites excluded cells with inf. Pass
                a clone if the original is still needed.
            dist_grad: shape (n, k, dims, pattern_len, window_size).
            w: factor applied to the smaller of the left / diagonal
                predecessors before comparing with the upper predecessor.

        Returns:
            Tensor of shape (n, k, dims, pattern_len) with the dtype of
            dist_grad, holding the sum of dist_grad over the cells that are
            still finite when visited during backtracking.
    '''
    n, k, len_pattern, len_window = dtw.shape

    # Allocate with dist_grad's dtype so float64 inputs are not silently
    # truncated to the default float32.
    grads = torch.zeros(
        (n, k, dist_grad.shape[2], len_pattern),
        dtype=dist_grad.dtype,
        device=dtw.device,
    )

    # Forward pass: accumulate the cheapest predecessor into every inner cell.
    for i in range(1, len_pattern): # pl
        for j in range(1, len_window): # ws
            value = torch.minimum(w * torch.minimum(dtw[:, :, i, j-1], dtw[:, :, i-1, j-1]), dtw[:, :, i-1, j])

            dtw[:, :, i, j] += value

    # Backward pass: walk from the bottom-right corner to the top-left one.
    for i0 in range(len_pattern-1, -1, -1):
        for j0 in range(len_window-1, -1, -1):
            # Only (sample, kernel) pairs whose current cell was not excluded.
            mask = ~torch.isinf(dtw[:, :, i0, j0])
            grads[:, :, :, i0][mask] += dist_grad[:, :, :, i0, j0][mask]

            # Border cells have no left/upper neighbours to choose between.
            if j0==0 or i0==0:
                continue

            # Candidate predecessors: 0 = left, 1 = up, 2 = diagonal.
            paths = torch.stack([
                dtw[:, :, i0, j0-1],
                dtw[:, :, i0-1, j0],
                dtw[:, :, i0-1, j0-1]
            ])

            # Index of the cheapest predecessor for every pair.
            step = paths.argmin(0)

            # Not coming from the left: exclude the rest of this row.
            dtw[:, :, i0, :j0][(step!=0) & mask] = float("inf")
            # Not coming from above: exclude the rest of this column.
            dtw[:, :, :i0, j0][(step!=1) & mask] = float("inf")

    return grads

In [241]:
# Reference result from the TorchScript implementation. dtw_compute_full
# overwrites its `dtw` argument in place, so hand it a copy: `euc_d` stays
# intact and re-running this cell gives the same answer.
grads_cpu = dtw_compute_full(euc_d.clone(), p_diff, 1.0)

In [242]:
# Reference gradient of pair (0, 0), for comparison with the CUDA kernel result.
grads_cpu[0, 0]

tensor([[-7.5245e-03,  4.6715e-01,  9.3521e-01, -4.2035e-01, -3.8310e-01,
         -8.0012e-01, -2.1795e-01,  7.1087e-01,  4.9632e-01,  5.0801e-01,
          2.3749e-01,  7.4432e-01,  5.6123e-01,  5.5381e-01,  8.0801e-01,
         -4.5157e-01,  2.4809e-01,  7.4112e-01, -2.8445e-01, -1.7460e-02,
         -2.1852e-01,  6.7370e-03, -5.7965e-01,  6.2305e-01, -8.9107e-02,
          6.0938e-01,  1.0445e-01, -6.8408e-01, -2.8676e-01, -2.7234e-01,
          1.4446e-01, -2.1205e-01, -4.7354e-01, -3.4677e-01, -7.4158e-02,
         -1.8458e-01,  7.6073e-02, -2.7699e-01, -4.7787e-01,  4.2296e-04,
         -9.4835e-01, -4.7182e-01, -5.7164e-01, -3.2118e-01,  7.5695e-03,
         -2.0316e-01,  8.8038e-02, -2.7533e-01,  3.3248e-01,  5.7210e-01,
          6.6940e-01, -1.1434e-01,  1.1283e-01,  2.9612e-01,  4.3339e-01,
         -3.0245e-01, -5.4845e-01,  3.5049e-01, -2.8474e-01,  2.1408e-01,
          8.2052e-01, -4.5766e-01, -7.0214e-01, -6.1360e-01],
        [-7.2771e-01,  8.0601e-01,  7.6242e-02,  3