In [1]:
import numpy as np
from torch import nn, Tensor

import torch
import torch.cuda
from numba import cuda, jit, prange 

from s3ts.api.nets.encoders.dtw.dtw_no_matrix import dtw_fast_no_image

In [2]:
# Fix the torch RNG seed so the random inputs below are the same on every run.
torch.manual_seed(45)
# `a` plays the role of the input series: it is broadcast along the last axis
# (window/time) of the pairwise-difference tensor built further down.
a = torch.randn(1, 2, 10)
# `b` plays the role of the patterns/kernels: it is broadcast along the
# pattern-length axis. Its middle dimension (2 channels) must match `a`.
b = torch.randn(2, 2, 12)

In [3]:
print(cuda.gpus)

<Managed Device 0>


In [4]:
@cuda.jit
def dtw_forward(dtw, w):
    '''
        In-place accumulation of the DTW cost matrix, one CUDA thread per
        (series, pattern) pair.

        dtw of shape (n, k, pattern_len, window_size); row 0 and column 0 are
        read but never written here, so the caller must prepare them.
        w is a scalar multiplied onto the smaller of the left and diagonal
        predecessors before comparing against the upper predecessor.
    '''
    n, k, len_pattern, len_window = dtw.shape

    x, y = cuda.grid(2)

    # threads that fall outside the (n, k) plane have nothing to do
    if x >= n or y >= k:
        return

    for row in range(1, len_pattern): # pl
        for col in range(1, len_window): # ws
            left = dtw[x, y, row, col-1]
            diag = dtw[x, y, row-1, col-1]
            up = dtw[x, y, row-1, col]
            # each cell starts as the local distance and gains the cheapest predecessor
            dtw[x, y, row, col] += min(w * min(left, diag), up)

In [5]:
@cuda.jit
def dtw_backward(dtw, dist_grad, grad):
    '''
        Walk the accumulated DTW matrix backwards and sum dist_grad over the
        cells that have not been pruned, one CUDA thread per (x, y) pair.

        dtw of shape (n, k, pattern_len, window_size); MODIFIED IN PLACE:
            cells ruled out as predecessors are overwritten with inf.
        dist_grad of shape (n, k, dims, pattern_len, window_size)
        grad of shape (n, k, dims, pl); values are accumulated with
            atomic adds, so the caller must pass it zero-initialised.
    '''
    n, k, d, len_pattern, len_window = dist_grad.shape

    x, y = cuda.grid(2)

    if x < n and y < k:
        # Scan from the bottom-right corner towards the origin.
        for i0 in range(len_pattern-1, -1, -1):
            for j0 in range(len_window-1, -1, -1):

                # A = dtw[x, y, i0, j0-1]   (left neighbour)
                # B = dtw[x, y, i0-1, j0]   (upper neighbour)
                # C = dtw[x, y, i0-1, j0-1] (diagonal neighbour)

                # path is A if (A<B) & (A<C) -> path is not A if (A>=B) | (A>=C)
                # path is B if (B<A) & (B<C) -> path is not B if (B>=A) | (B>=C)

                # Only cells that were not pruned (still finite) contribute.
                if dtw[x, y, i0, j0] != np.inf:

                    # Sum over the window axis j0 into grad[x, y, l, i0].
                    for l in range(d):
                        cuda.atomic.add(grad, (x, y, l, i0), dist_grad[x, y, l, i0, j0])

                    # First row / first column: no predecessors to prune.
                    if j0==0 or i0==0:
                        continue

                    # NOTE: the order of the next two tests matters. The first branch
                    # may write inf into dtw[x, y, i0, j0-1], and the second test
                    # reads that same cell afterwards.
                    if dtw[x, y, i0, j0-1] >= dtw[x, y, i0-1, j0] or dtw[x, y, i0, j0-1] >= dtw[x, y, i0-1, j0-1]: # path is not A
                        # discard everything to the left in the current row
                        for j in range(j0):
                            dtw[x, y, i0, j] = np.inf
                    if dtw[x, y, i0-1, j0] >= dtw[x, y, i0, j0-1] or dtw[x, y, i0-1, j0] >= dtw[x, y, i0-1, j0-1]: # path is not B
                        # discard everything above in the current column
                        for i in range(i0):
                            dtw[x, y, i, j0] = np.inf


In [28]:
# Pairwise differences between every series in `a` and every pattern in `b`,
# broadcast to shape (n, n_kernel, d, kernel_size, T).
p_diff = a.unsqueeze(1).unsqueeze(3) - b.unsqueeze(0).unsqueeze(-1)

# Squared distance summed over the channel axis.
euc_d = p_diff.pow(2).sum(dim=2) # shape (n, n_kernel, kernel_size, T)

# compute dtw: the first row and first column are plain cumulative sums
euc_d[:, :, 0, :] = euc_d[:, :, 0, :].cumsum(dim=2)
euc_d[:, :, :, 0] = euc_d[:, :, :, 0].cumsum(dim=2)

In [29]:
euc_d[: ,:, -1, -1]

tensor([[1.1652, 8.8709]])

In [30]:
# Zero-initialised GPU buffer the backward kernel accumulates into.
# The shape is hard-coded to (n, k, dims, pattern_len) for the `a`/`b` test inputs.
grads = torch.zeros((1, 2, 2, 12), device="cuda")
# Numba device-array handle for `grads`, so the kernel can write to it.
grads_cuda = cuda.as_cuda_array(grads)
# GPU copy of the pairwise differences, exposed to Numba as the kernel's dist_grad input.
p_diff_cuda = cuda.as_cuda_array(p_diff.cuda())

In [31]:
# Work on a GPU copy so the CPU tensor `euc_d` is left untouched by the kernels.
dtw = cuda.as_cuda_array(euc_d.detach().cuda())
# Launch 16x16 blocks of 16x16 threads (256 x 256 threads in total); the kernel
# guards on (n, k), so surplus threads do nothing. The weight w is 1 here.
dtw_forward[(16, 16), (16, 16)](dtw, 1)

In [32]:
# Backward pass on the GPU: prunes `dtw` in place (non-path cells become inf)
# and accumulates into `grads_cuda`. Same fixed 256 x 256-thread launch as the forward kernel.
dtw_backward[(16, 16), (16, 16)](dtw, p_diff_cuda, grads_cuda)

In [33]:
torch.tensor(dtw.copy_to_host())[0, 0]

tensor([[ 3.1809,  3.1977,  3.2883,     inf,     inf,     inf,     inf,     inf,
             inf,     inf],
        [    inf,     inf,     inf,  4.8867,     inf,     inf,     inf,     inf,
             inf,     inf],
        [    inf,     inf,     inf,     inf,  5.7061,     inf,     inf,     inf,
             inf,     inf],
        [    inf,     inf,     inf,     inf,  6.8401,     inf,     inf,     inf,
             inf,     inf],
        [    inf,     inf,     inf,     inf,  7.5416,     inf,     inf,     inf,
             inf,     inf],
        [    inf,     inf,     inf,     inf,     inf, 11.0243,     inf,     inf,
             inf,     inf],
        [    inf,     inf,     inf,     inf,     inf,     inf, 11.9491,     inf,
             inf,     inf],
        [    inf,     inf,     inf,     inf,     inf,     inf,     inf, 17.5958,
             inf,     inf],
        [    inf,     inf,     inf,     inf,     inf,     inf,     inf,     inf,
         17.7876,     inf],
        [    inf,  

In [34]:
grads.shape

torch.Size([1, 2, 2, 12])

In [36]:
torch.tensor(grads_cuda.copy_to_host())

tensor([[[[-1.0554,  1.2604,  0.4572,  1.0310, -0.8302,  0.2612,  0.1014,
           -1.5561, -0.1494, -0.3043,  0.4818,  1.0795],
          [-1.5693, -0.0985, -0.7813,  0.2666,  0.1108, -1.8478, -0.9563,
           -1.7959,  0.4116,  0.9128, -0.6228, -0.0028]],

         [[ 1.3708,  1.5932, -0.2534, -0.5999,  2.1670,  0.6882,  0.7017,
            1.7478,  0.5878,  0.1707,  0.8489,  1.7637],
          [-0.9776,  0.8694,  0.7780,  0.2931,  0.6177,  0.1789, -0.6374,
           -0.2487,  0.7245, -1.9783, -0.1086,  2.4000]]]])

In [37]:
@torch.jit.script
def dtw_compute_full(dtw: torch.Tensor, dist_grad: torch.Tensor, w: float) -> torch.Tensor:
    '''
        Batched DTW fill followed by a backward sweep that sums dist_grad over
        the cells that survive pruning.

        dtw of shape (n, k, pattern_len, window_size); MODIFIED IN PLACE:
            first filled with accumulated costs, then pruned cells are set to inf.
        dist_grad of shape (n, k, dims, pattern_len, window_size)
        returns grad of shape (n, k, dims, pl), allocated on dtw's device
    '''
    n_series, n_patterns, rows, cols = dtw.shape
    n_dims = dist_grad.shape[2]
    grads = torch.zeros((n_series, n_patterns, n_dims, rows), device=dtw.device)

    # forward sweep: add the cheapest predecessor to every interior cell
    for r in range(1, rows): # pl
        for c in range(1, cols): # ws
            left = dtw[:, :, r, c-1]
            diag = dtw[:, :, r-1, c-1]
            up = dtw[:, :, r-1, c]
            dtw[:, :, r, c] += torch.minimum(w * torch.minimum(left, diag), up)

    # backward sweep: from the bottom-right corner towards the origin
    for r in range(rows-1, -1, -1):
        for c in range(cols-1, -1, -1):
            # (n, k) mask of pairs whose current cell has not been pruned
            alive = ~torch.isinf(dtw[:, :, r, c])
            grads[:, :, :, r][alive] += dist_grad[:, :, :, r, c][alive]

            # border cells have no predecessors to prune
            if r > 0 and c > 0:
                candidates = torch.stack([
                    dtw[:, :, r, c-1],
                    dtw[:, :, r-1, c],
                    dtw[:, :, r-1, c-1]
                ])

                # 0 -> left, 1 -> up, 2 -> diagonal
                best = candidates.argmin(0)

                # left is not the predecessor: drop the rest of this row
                dtw[:, :, r, :c][(best != 0) & alive] = float("inf")
                # up is not the predecessor: drop the rest of this column
                dtw[:, :, :r, c][(best != 1) & alive] = float("inf")

    return grads

In [38]:
# CPU reference run with w = 1. NOTE: dtw_compute_full fills and then prunes
# `euc_d` in place, so this cell is not idempotent; rebuild `euc_d` before re-running it.
grads_cpu = dtw_compute_full(euc_d, p_diff, 1)

In [39]:
grads_cpu[0, 0]

tensor([[-1.0554,  1.2604,  0.4572,  1.0310, -0.8302,  0.2612,  0.1014, -1.5561,
         -0.1494, -0.3043,  0.4818,  1.0795],
        [-1.5693, -0.0985, -0.7813,  0.2666,  0.1108, -1.8478, -0.9563, -1.7959,
          0.4116,  0.9128, -0.6228, -0.0028]])

In [None]:
class torch_dtw_cuda(torch.autograd.Function):
    '''
        Autograd wrapper returning the final DTW cost for every
        (series, pattern) pair; only `y` receives a gradient.

        NOTE(review): `dtw_fast_cuda` is neither defined nor imported in the
        visible part of this notebook (only `dtw_fast_no_image` is imported);
        confirm where it is supposed to come from before running this cell.
    '''

    @staticmethod
    def forward(ctx, x: torch.Tensor, y: torch.Tensor, w: float):
        patterns = y.detach()
        dtw_matrix, path_diff = dtw_fast_cuda(x, patterns, w, compute_gradients=y.requires_grad)

        # only the per-pattern differences are needed by backward
        ctx.save_for_backward(path_diff)

        # bottom-right cell of every DTW matrix
        return dtw_matrix[:, :, -1, -1]

    @staticmethod
    def backward(ctx, dtw_grad):
        # dtw_grad dims (n, k); path_diff presumably (n, k, d, pl) -- TODO confirm
        (path_diff,) = ctx.saved_tensors
        weighted = path_diff * dtw_grad[:, :, None, None]
        # average over the batch axis; x and w get no gradient
        grad_y = 2 * weighted.mean(0)
        return None, grad_y, None

In [61]:
@cuda.jit
def dtw_fill(dtw, w):
    '''
        Fill the DTW accumulated-cost matrix in place, one CUDA thread per
        (series, pattern) pair.

        dtw of shape (n, k, pattern_len, window_size); on entry it holds the
            local distances, with row 0 and column 0 already accumulated by
            the caller. Interior cells are overwritten with accumulated costs.
        w is a scalar multiplied onto the smaller of the left and diagonal
            predecessors before comparing against the upper predecessor.
    '''
    n, k, len_pattern, len_window = dtw.shape

    x, y = cuda.grid(2)

    if x < n and y < k:
        for i in range(1, len_pattern): # pl
            for j in range(1, len_window): # ws
                value = min(w * min(dtw[x, y, i, j-1], dtw[x, y, i-1, j-1]), dtw[x, y, i-1, j])
                dtw[x, y, i, j] += value

    # No block barrier here: each thread only touches its own dtw[x, y] slice,
    # and a syncthreads() inside the bounds check above would be reached by
    # only some threads of a block, which is undefined behaviour.

@cuda.jit
def dtw_backward(dtw, dist_grad, grad):
    '''
        Walk the accumulated DTW matrix backwards and sum dist_grad over the
        cells that have not been pruned, one CUDA thread per (x, y) pair.

        dtw of shape (n, k, pattern_len, window_size); MODIFIED IN PLACE:
            cells ruled out as predecessors are overwritten with inf.
        dist_grad of shape (n, k, dims, pattern_len, window_size)
        grad of shape (n, k, dims, pl); values are accumulated with
            atomic adds, so the caller must pass it zero-initialised.
    '''
    n, k, d, len_pattern, len_window = dist_grad.shape

    x, y = cuda.grid(2)

    if x < n and y < k:
        # Scan from the bottom-right corner towards the origin.
        for i0 in range(len_pattern-1, -1, -1):
            for j0 in range(len_window-1, -1, -1):

                # A = dtw[x, y, i0, j0-1]   (left neighbour)
                # B = dtw[x, y, i0-1, j0]   (upper neighbour)
                # C = dtw[x, y, i0-1, j0-1] (diagonal neighbour)

                # path is A if (A<B) & (A<C) -> path is not A if (A>=B) | (A>=C)
                # path is B if (B<A) & (B<C) -> path is not B if (B>=A) | (B>=C)

                # Only cells that were not pruned (still finite) contribute.
                if dtw[x, y, i0, j0] != np.inf:

                    # Sum over the window axis j0 into grad[x, y, l, i0].
                    for l in range(d):
                        cuda.atomic.add(grad, (x, y, l, i0), dist_grad[x, y, l, i0, j0])

                    # First row / first column: no predecessors to prune.
                    if j0==0 or i0==0:
                        continue

                    # NOTE: the order of the next two tests matters. The first branch
                    # may write inf into dtw[x, y, i0, j0-1], and the second test
                    # reads that same cell afterwards.
                    if dtw[x, y, i0, j0-1] >= dtw[x, y, i0-1, j0] or dtw[x, y, i0, j0-1] >= dtw[x, y, i0-1, j0-1]: # path is not A
                        # discard everything to the left in the current row
                        for j in range(j0):
                            dtw[x, y, i0, j] = np.inf
                    if dtw[x, y, i0-1, j0] >= dtw[x, y, i0, j0-1] or dtw[x, y, i0-1, j0] >= dtw[x, y, i0-1, j0-1]: # path is not B
                        # discard everything above in the current column
                        for i in range(i0):
                            dtw[x, y, i, j0] = np.inf

    # No block barrier here: each thread only touches its own [x, y] slices,
    # and a syncthreads() inside the bounds check above would be reached by
    # only some threads of a block, which is undefined behaviour.

# @torch.jit.script
def dtw_forward(x: torch.Tensor, y: torch.Tensor, w: float):
    '''
        Build the local-distance tensor between every series in x and every
        pattern in y, then fill the DTW accumulated-cost matrix on the GPU.

        x of shape (n, dim, x_len), y of shape (m, dim, y_len); both are
            expected to live on a CUDA device, because the result is handed
            to a Numba kernel.
        w is the scalar weight forwarded to the dtw_fill kernel.

        returns (euc_d, p_diff):
            euc_d of shape (n, m, y_len, x_len), accumulated DTW costs
            p_diff of shape (n, m, dim, y_len, x_len), raw differences x - y
    '''
    # performs convolution-like operation, for each kernel the DF
    # (of shape (kernel_size, T)) is computed, then summed across channels
    # x has shape (batch, c, time_dimension)

    # compute pairwise diffs (squared)
    p_diff = x[:,None,:,None,:] - y[None,:,:,:,None] # shape (n, n_kernel, d, Kernel_size, T)
    euc_d = torch.square(p_diff).sum(2) # shape (n, n_kernel, kernel_size, T)

    # if compute_gradients:
    #     p_diff /= euc_d[:,:, None, :, :] + eps

    # compute dtw: first row and first column are plain cumulative sums
    euc_d[:,:,0,:] = torch.cumsum(euc_d[:,:,0,:], dim=2)
    euc_d[:,:,:,0] = torch.cumsum(euc_d[:,:,:,0], dim=2)

    # Size the launch from the data: one thread per (series, pattern) pair.
    # A fixed 16x16 grid of 16x16 blocks would silently leave every pair
    # beyond index 255 on either axis unfilled.
    n_series, n_patterns = euc_d.shape[0], euc_d.shape[1]
    if n_series > 0 and n_patterns > 0:
        threads_per_block = (16, 16)
        blocks_per_grid = ((n_series + 15) // 16, (n_patterns + 15) // 16)
        dtw_fill[blocks_per_grid, threads_per_block](cuda.as_cuda_array(euc_d), w)

    return euc_d, p_diff
    
class torch_dtw_cuda(torch.autograd.Function):
    '''
        Autograd wrapper around the CUDA DTW kernels. Returns the final DTW
        cost for every (series, pattern) pair; only `y` receives a gradient.
    '''

    @staticmethod
    def forward(ctx, x: torch.Tensor, y: torch.Tensor, w: float = 1):
        '''
            x of shape (n, d, x_len), y of shape (k, d, pl), both on a CUDA device.
            Returns a tensor of shape (n, k): the bottom-right DTW cell per pair.
        '''
        DTW, p_diff = dtw_forward(x, y.detach(), w)

        # backward needs the full cost matrix (to retrace the path) and the raw differences
        ctx.save_for_backward(DTW, p_diff)

        return DTW[:, :, -1, -1]

    @staticmethod
    def backward(ctx, dtw_grad):
        '''
            dtw_grad of shape (n, k). Returns gradients for (x, y, w); only
            the one for y is populated, with shape (k, d, pl).
        '''
        # dtw dims (n, k, pl, ws); p_diff dims (n, k, d, pl, ws)
        dtw, p_diff = ctx.saved_tensors

        # (n, k, d, pl) accumulator; match p_diff's dtype so non-float32 inputs are not truncated
        grads = torch.zeros((dtw.shape[0],) + p_diff.shape[1:-1], device=dtw_grad.device, dtype=p_diff.dtype)

        # Size the launch from the data: a fixed 16x16 grid of 16x16 blocks
        # would leave gradients at zero for any pair beyond index 255.
        n_series, n_patterns = p_diff.shape[0], p_diff.shape[1]
        if n_series > 0 and n_patterns > 0:
            threads_per_block = (16, 16)
            blocks_per_grid = ((n_series + 15) // 16, (n_patterns + 15) // 16)
            # NOTE: the kernel overwrites pruned cells of the saved `dtw` tensor with inf.
            dtw_backward[blocks_per_grid, threads_per_block](
                cuda.as_cuda_array(dtw), cuda.as_cuda_array(p_diff), cuda.as_cuda_array(grads)
            )

        mult = (dtw_grad[:, :, None, None] * grads) # dims (n, k, d, pl)
        # NOTE(review): this averages over the batch axis instead of summing, and
        # p_diff is x - y, so d/dy of (x - y)^2 would carry a minus sign. Confirm
        # that both the 1/n scaling and the sign are intended.
        return None, 2*mult.mean(0), None # dims (k, d, pl)

In [52]:
# Input series: 3 series, 4 channels, 15 time steps, moved to the GPU (no gradient tracked).
x = torch.randn(3, 4, 15).cuda()
# Patterns: 5 patterns, 4 channels, length 20. Kept as a CPU leaf tensor so that
# the gradient ends up in y.grad after backward().
y = torch.randn(5, 4, 20, requires_grad=True)

In [56]:
res = torch_dtw_cuda.apply(x, y.cuda())

In [57]:
res2 = res.sum()

In [58]:
res2.backward()

torch.Size([3, 5, 4, 20])


In [60]:
y.grad.shape

torch.Size([5, 4, 20])

In [65]:
x.requires_grad

False

In [66]:
y._grad