In [39]:
import torch
from typing import List

In [12]:
def dtw_compute(dist_tensor: torch.Tensor, grad_tensor: torch.Tensor, w: float) -> None:
    print(dist_tensor.shape)
    for i in range(1, dist_tensor.shape[0]):
        for j in range(1, dist_tensor.shape[1]):
            value = torch.minimum(w * torch.minimum(dist_tensor[i, j-1], dist_tensor[i-1, j-1]), dist_tensor[i-1, j])
            id = (w * dist_tensor[i, j-1] < dist_tensor[i-1, j]) & (dist_tensor[i, j-1] < dist_tensor[i-1, j-1])

            dist_tensor[i, j] += value

            grad_tensor[id][:, :, i, j] += w * grad_tensor[id][:, :, i, j-1]

def torch_dtw_fast(x: torch.Tensor, y: torch.Tensor, w: float, eps: float = 1e-5):
    # shape of x (n, dim, x_len) y (m, dim, y_len)

    # performs convolution-like operation, for each kernel the DF
    # (of shape (kernel_size, T)) is computed, then summed across channels
    # x has shape (batch, c, time_dimension)

    # compute pairwise diffs (squared)
    p_diff = x[:,None,:,None,:] - y[None,:,:,:,None] # shape (n, n_kernel, d, Kernel_size, T)
    euc_d = torch.square(p_diff).sum(2) # shape (n, n_kernel, kernel_size, T)

    # p_diff contains the partial derivatives of DTW[n, k, i, j] wrt K[k, d, i] (dims (n, k, d, i, j))
    p_diff = p_diff / torch.sqrt(euc_d[:,:, None, :, :] + eps)

    # compute dtw
    euc_d[:,:,0,:] = torch.cumsum(euc_d[:,:,0,:], dim=2)
    euc_d[:,:,:,0] = torch.cumsum(euc_d[:,:,:,0], dim=2)

    # rearrange dims
    DTW = torch.permute(euc_d, (2, 3, 0, 1)).contiguous()

    dtw_compute(DTW, p_diff, w)

    # recover dimensions
    DTW = torch.permute(DTW, (2, 3, 0, 1)).contiguous()

    return DTW.sqrt(), p_diff

class torch_dtw(torch.autograd.Function):

    @staticmethod
    def forward(x, y, w):
        DTW, p_diff = torch_dtw_fast(x, y, w)
        return DTW, p_diff
    
    @staticmethod
    def setup_context(ctx, inputs, output):
        DTW, p_diff = output
        ctx.save_for_backward(p_diff)
    
    @staticmethod
    def backward(ctx, dtw_grad, p_diff_grad):
        p_diff, = ctx.saved_tensors
        mult = (p_diff * dtw_grad[:,:,None,:,:])
        return mult.mean(dim=(1, 3)), mult.mean(dim=(0, 4)), None

In [68]:
@torch.jit.script
def dtw_compute2(dtw: torch.Tensor, dist_grad: torch.Tensor, w: float) -> torch.Tensor:
    '''
        dtw of shape (num_kernels, pattern_len, window_size)
        dist_grad of shape (num_kernels, dims, pattern_len, window_size)
        grad of shape (num_kernels, dims, pattern_len)
    '''
    grad = torch.empty(dtw.shape[0], dist_grad.shape[1], dtw.shape[1]) # shape (num_kernels, dims, pattern_len)
    grads = torch.empty(dist_grad.shape[1], dtw.shape[1], dtw.shape[1], dtw.shape[2]) # shape (dims, pattern_len, pattern_len, window_size)

    for k in range(dtw.shape[0]): # num_kernels
        grads.zero_()

        for i in range(dtw.shape[1]):
            grads[:, i, i, :] = torch.cumsum(dist_grad[k, :, i, :], dim=1)
            grads[:, i, i:, 0] = grads[:, i, i, :1]

        for i in range(1, dtw.shape[1]): # wl
            for j in range(1, dtw.shape[2]): # ws
                value = torch.minimum(w * torch.minimum(dtw[k, i, j-1], dtw[k, i-1, j-1]), dtw[k, i-1, j])
                temp_1 = dtw[k, i, j-1] < dtw[k, i-1, j-1] # path (i, j-1) or (i-1, j)
                temp_2 = w * dtw[k, i, j-1] < dtw[k, i-1, j] # path (i, j-1) or (i-1, j-1)
                temp_3 = w * dtw[k, i-1, j-1] < dtw[k, i-1, j] # path (i-1, j-1) or (i-1, j)

                dtw[k, i, j] += value

                if temp_1 and temp_2: # path is (i, j-1)
                    grads[:, :, i, j] += w * grads[:, :, i, j-1]
                elif temp_1 and temp_3: # path is (i-1, j)
                    grads[:, :, i, j] += grads[:, :, i-1, j]
                else: # path is (i-1, j-1)
                    grads[:, :, i, j] += w * grads[:, :, i-1, j-1]

        grad[k] += grads.sum(dim=(-2, -1))

    return grad

@torch.jit.script
def torch_dtw_fast2(x: torch.Tensor, y: torch.Tensor, w: float, eps: float = 1e-5):
    # shape of x (n, dim, x_len) y (m, dim, y_len)

    # performs convolution-like operation, for each kernel the DF
    # (of shape (kernel_size, T)) is computed, then summed across channels
    # x has shape (batch, c, time_dimension)

    # compute pairwise diffs (squared)
    p_diff = x[:,None,:,None,:] - y[None,:,:,:,None] # shape (n, n_kernel, d, Kernel_size, T)
    euc_d = torch.square(p_diff).sum(2) # shape (n, n_kernel, kernel_size, T)

    # p_diff contains the partial derivatives of DTW[n, k, i, j] wrt K[k, d, i] (dims (n, k, d, i, j))
    p_diff = p_diff / torch.sqrt(euc_d[:,:, None, :, :] + eps)

    # compute dtw
    euc_d[:,:,0,:] = torch.cumsum(euc_d[:,:,0,:], dim=2)
    euc_d[:,:,:,0] = torch.cumsum(euc_d[:,:,:,0], dim=2)

    futures : List[torch.jit.Future[torch.Tensor]] = []
    for i in range(x.shape[0]):
        futures.append(torch.jit.fork(dtw_compute2, euc_d[i], p_diff[i], w)) # euc_d (dtw) changed in place

    results = [torch.jit.wait(future) for future in futures]

    grads = torch.stack(results).sum(0)    

    return euc_d.sqrt(), grads

In [35]:
def dtw_compute2(dtw: torch.Tensor, dist_grad: torch.Tensor, grad: torch.Tensor, w: float) -> None:
    '''
        dtw of shape (num_kernels, pattern_len, window_size)
        dist_grad of shape (num_kernels, dims, pattern_len, window_size)
        grad of shape (num_kernels, dims, pattern_len)
    '''
    grads = torch.empty(dist_grad.shape[1], dtw.shape[1], dtw.shape[1], dtw.shape[2]) # shape (dims, pattern_len, pattern_len, window_size)
    print(grads.numel())

    for k in range(dtw.shape[0]): # num_kernels
        grads.zero_()

        for i in range(dtw.shape[1]):
            grads[:, i, i, :] = torch.cumsum(dist_grad[k, :, i, :], dim=1)
            grads[:, i, i:, 0] = grads[:, i, i, :1]

        for i in range(1, dtw.shape[1]): # wl
            for j in range(1, dtw.shape[2]): # ws
                value = torch.minimum(w * torch.minimum(dtw[k, i, j-1], dtw[k, i-1, j-1]), dtw[k, i-1, j])
                temp_1 = dtw[k, i, j-1] < dtw[k, i-1, j-1] # path (i, j-1) or (i-1, j)
                temp_2 = w * dtw[k, i, j-1] < dtw[k, i-1, j] # path (i, j-1) or (i-1, j-1)
                temp_3 = w * dtw[k, i-1, j-1] < dtw[k, i-1, j] # path (i-1, j-1) or (i-1, j)

                dtw[k, i, j] += value

                if temp_1 and temp_2: # path is (i, j-1)
                    grads[:, :, i, j] += w * grads[:, :, i, j-1]
                elif temp_1 and temp_3: # path is (i-1, j)
                    grads[:, :, i, j] += grads[:, :, i-1, j]
                else: # path is (i-1, j-1)
                    grads[:, :, i, j] += w * grads[:, :, i-1, j-1]

        grad[k] += grads.sum(dim=(-2, -1))

In [36]:
dtw = torch.randn(10, 15, 20)
dist_grad = torch.randn(10, 6, 15, 20)
grad = torch.randn(10, 6, 15)

In [59]:
dtw_compute2(dtw, dist_grad, 0.1)

tensor([[[-2.0309e+01, -3.2442e+00, -1.3458e+01,  8.4600e+01, -1.1269e+01,
           1.0207e+02, -6.2906e+00, -1.4982e+01, -6.4075e+00,  9.6667e+01,
           1.1837e+02,  1.2234e+02, -2.3722e+00, -6.5837e+01, -4.9917e+01],
         [ 6.3669e-01, -5.3787e+01,  1.2285e+01,  3.3467e+01,  2.7149e+00,
           1.2115e+02,  4.8013e+00, -2.6668e+01, -1.7822e+01,  6.8604e+00,
          -4.3326e+01, -3.2684e+01, -5.1875e+01, -4.6764e+01,  7.0709e-01],
         [-1.6777e+01,  1.7212e+02,  7.3891e+01,  6.3038e+01,  5.5883e+01,
          -4.7924e+01,  8.2706e+01, -5.7632e+00, -5.7387e+01, -8.4569e+01,
           1.8049e+02,  5.5887e+02, -5.0924e-03,  2.7504e+01, -3.9577e+01],
         [ 1.2510e+02, -3.6201e+01,  3.6084e+01, -1.4702e+01,  8.0949e+01,
          -6.0405e+01,  3.1432e+01,  2.5460e+01, -1.9123e+01,  3.2267e+01,
           8.9589e+00,  8.7169e+00, -2.6847e+01,  3.2196e+01,  8.5016e+00],
         [ 1.8384e+01,  1.9731e+01, -7.6010e-01, -2.8320e+00,  2.4105e+01,
           7.9186e+01

In [78]:
x = torch.randn(10, 3, 20)
y = torch.randn(15, 3, 25)

In [17]:
def add1(x):
    x += 1

In [20]:
a = torch.ones(2, 10)

In [81]:
a, b = torch_dtw_fast2(x, y, w=0.01, eps=1e-6)

In [65]:
a1, b1  = torch_dtw_fast(x, y, w=0.01, eps=1e-6)

torch.Size([25, 20, 10, 15])


In [16]:
a.numel()

37748736