In [1]:
import torch
from typing import List

In [2]:
def dtw_compute(dist_tensor: torch.Tensor, grad_tensor: torch.Tensor, w: float) -> None:
    print(dist_tensor.shape)
    for i in range(1, dist_tensor.shape[0]):
        for j in range(1, dist_tensor.shape[1]):
            value = torch.minimum(w * torch.minimum(dist_tensor[i, j-1], dist_tensor[i-1, j-1]), dist_tensor[i-1, j])
            id = (w * dist_tensor[i, j-1] < dist_tensor[i-1, j]) & (dist_tensor[i, j-1] < dist_tensor[i-1, j-1])

            dist_tensor[i, j] += value

            grad_tensor[id][:, :, i, j] += w * grad_tensor[id][:, :, i, j-1]

def torch_dtw_fast(x: torch.Tensor, y: torch.Tensor, w: float, eps: float = 1e-5):
    # shape of x (n, dim, x_len) y (m, dim, y_len)

    # performs convolution-like operation, for each kernel the DF
    # (of shape (kernel_size, T)) is computed, then summed across channels
    # x has shape (batch, c, time_dimension)

    # compute pairwise diffs (squared)
    p_diff = x[:,None,:,None,:] - y[None,:,:,:,None] # shape (n, n_kernel, d, Kernel_size, T)
    euc_d = torch.square(p_diff).sum(2) # shape (n, n_kernel, kernel_size, T)

    # p_diff contains the partial derivatives of DTW[n, k, i, j] wrt K[k, d, i] (dims (n, k, d, i, j))
    p_diff = p_diff / torch.sqrt(euc_d[:,:, None, :, :] + eps)

    # compute dtw
    euc_d[:,:,0,:] = torch.cumsum(euc_d[:,:,0,:], dim=2)
    euc_d[:,:,:,0] = torch.cumsum(euc_d[:,:,:,0], dim=2)

    # rearrange dims
    DTW = torch.permute(euc_d, (2, 3, 0, 1)).contiguous()

    dtw_compute(DTW, p_diff, w)

    # recover dimensions
    DTW = torch.permute(DTW, (2, 3, 0, 1)).contiguous()

    return DTW.sqrt(), p_diff

class torch_dtw(torch.autograd.Function):

    @staticmethod
    def forward(x, y, w):
        DTW, p_diff = torch_dtw_fast(x, y, w)
        return DTW, p_diff
    
    @staticmethod
    def setup_context(ctx, inputs, output):
        DTW, p_diff = output
        ctx.save_for_backward(p_diff)
    
    @staticmethod
    def backward(ctx, dtw_grad, p_diff_grad):
        p_diff, = ctx.saved_tensors
        mult = (p_diff * dtw_grad[:,:,None,:,:])
        return mult.mean(dim=(1, 3)), mult.mean(dim=(0, 4)), None

In [46]:
@torch.jit.script
def dtw_compute2(dtw: torch.Tensor, dist_grad: torch.Tensor, grad: torch.Tensor, w: float) -> None:
    '''
        dtw of shape (n, pattern_len, window_size)
        dist_grad of shape (n, dims, pattern_len, window_size)
        grad of shape (dims, pattern_len)
    '''
    grads = torch.zeros(dtw.shape[0], dist_grad.shape[1], dtw.shape[1], dtw.shape[1], dtw.shape[2]) # shape (n, dims, pattern_len, pattern_len, window_size)

    for i in range(dtw.shape[1]):
        grads[:, :, i, i, :] = torch.cumsum(dist_grad[:, :, i, :], dim=1)
        grads[:, :, i, i:, 0] = grads[:, :, i, i, :1]

    for i in range(1, dtw.shape[1]): # pl
        for j in range(1, dtw.shape[2]): # ws
            value = torch.minimum(w * torch.minimum(dtw[:, i, j-1], dtw[:, i-1, j-1]), dtw[:, i-1, j])
            temp_1 = dtw[:, i, j-1] < dtw[:, i-1, j-1] # path (i, j-1) or (i-1, j)
            temp_2 = w * dtw[:, i, j-1] < dtw[:, i-1, j] # path (i, j-1) or (i-1, j-1)
            temp_3 = w * dtw[:, i-1, j-1] < dtw[:, i-1, j] # path (i-1, j-1) or (i-1, j)

            dtw[:, i, j] += value

            grads[temp_1 & temp_2][:, :, i, j] += w * grads[temp_1 & temp_2][:, :, i, j-1]
            grads[temp_1 & temp_3][:, :, i, j] += grads[temp_1 & temp_3][:, :, i-1, j]
            grads[temp_2 & temp_3][:, :, i, j] += w * grads[temp_2 & temp_3][:, :, i-1, j-1]

    grad[:,:] = grads.sum(dim=(0, -2, -1))

@torch.jit.script
def torch_dtw_fast2(x: torch.Tensor, y: torch.Tensor, w: float, eps: float = 1e-5):
    # shape of x (n, dim, x_len) y (m, dim, y_len)

    # performs convolution-like operation, for each kernel the DF
    # (of shape (kernel_size, T)) is computed, then summed across channels
    # x has shape (batch, c, time_dimension)

    # compute pairwise diffs (squared)
    p_diff = x[:,None,:,None,:] - y[None,:,:,:,None] # shape (n, n_kernel, d, Kernel_size, T)
    euc_d = torch.square(p_diff).sum(2) # shape (n, n_kernel, kernel_size, T)

    # p_diff now contains the partial derivatives of DTW[n, k, i, j] wrt K[k, d, i] (dims (n, k, d, i, j))
    p_diff = p_diff / torch.sqrt(euc_d[:,:, None, :, :] + eps)

    # compute dtw
    euc_d[:,:,0,:] = torch.cumsum(euc_d[:,:,0,:], dim=2)
    euc_d[:,:,:,0] = torch.cumsum(euc_d[:,:,:,0], dim=2)

    grads = torch.empty_like(y)

    futures = [torch.jit.fork(dtw_compute2, euc_d[:, i], p_diff[:, i], grads[i], w) for i in range(y.shape[0])] # iterate through k
    results = [torch.jit.wait(future) for future in futures]

    return euc_d.sqrt(), grads

In [5]:
dtw = torch.randn(10, 15, 20)
dist_grad = torch.randn(10, 6, 15, 20)
grad = torch.randn(10, 6, 15)

In [47]:
torch.random.manual_seed(45)
x = torch.randn(128, 3, 45)
y = torch.randn(30, 3, 50)

In [25]:
48*128

6144

In [49]:
a, b = torch_dtw_fast2(x, y, w=0.01, eps=1e-6)
print(a.shape)
b.shape

torch.Size([128, 30, 50, 45])


torch.Size([30, 3, 50])

In [50]:
a

tensor([[[[ 1.2429,  2.3042,  3.5167,  ..., 17.1790, 17.4898, 17.6778],
          [ 2.7351,  0.7122,  0.3267,  ...,  2.7114,  2.0796,  1.9782],
          [ 3.0561,  0.7008,  1.2806,  ...,  2.2275,  2.3618,  1.9711],
          ...,
          [15.3980,  3.1182,  2.4288,  ...,  1.7202,  1.4019,  2.5462],
          [15.5715,  1.7193,  0.1792,  ...,  2.4399,  1.7573,  1.8377],
          [15.5822,  2.5025,  2.3162,  ...,  1.1084,  2.3524,  2.3777]],

         [[ 1.1568,  2.4744,  3.4169,  ..., 11.7298, 11.9421, 12.2152],
          [ 1.6863,  0.7662,  1.3699,  ...,  2.2225,  1.8661,  1.3088],
          [ 2.3879,  3.1689,  3.3850,  ...,  1.1849,  2.9516,  3.2711],
          ...,
          [13.7631,  2.0278,  1.8757,  ...,  2.6156,  3.2949,  3.0390],
          [13.7933,  2.0025,  1.7389,  ...,  1.3082,  1.6418,  1.7087],
          [13.8510,  3.0640,  2.9998,  ...,  1.0187,  2.6645,  2.8978]],

         [[ 2.4771,  3.6460,  4.3860,  ..., 15.1566, 15.2373, 15.4567],
          [ 3.1372,  2.4580,  

In [52]:
b

tensor([[[ -5262.5693,  -8394.1973,  -6993.0674,  ...,   1972.8269,
           -3473.3159,    684.6747],
         [-10954.0479,  -5952.3682,  -8006.1714,  ...,   3765.1443,
           -1934.5146,   -302.2203],
         [-16686.6699,  -3851.7305, -10326.7295,  ...,   7223.1758,
            -326.5641,  -1957.5271]],

        [[  3535.5203,  -5902.5161,   7318.4785,  ...,  -3640.6541,
            -357.7154,   2825.2898],
         [  3766.2673,  -9539.7080,   5372.0586,  ...,  -2416.5520,
           -1211.0760,   1828.5994],
         [  1390.8289,  -9212.5264,   1993.5671,  ...,  -5142.6699,
            -908.1261,    283.6394]],

        [[  4975.9004,   4336.7920,   6972.3286,  ...,    587.9970,
           -1977.9802,  -4613.6880],
         [  8046.6177,   8687.4863,   2552.9612,  ...,   4439.9517,
           -4243.1655,  -4720.6719],
         [ 13220.6338,   7960.4312,   3362.3420,  ...,   2201.1692,
           -5537.7651,  -6246.1221]],

        ...,

        [[ -1485.5029,  -6289.5371,

In [65]:
a1, b1  = torch_dtw_fast(x, y, w=0.01, eps=1e-6)

torch.Size([25, 20, 10, 15])


In [16]:
a.numel()

37748736