In [1]:
import torch
from typing import List

In [2]:
def dtw_compute(dist_tensor: torch.Tensor, grad_tensor: torch.Tensor, w: float) -> None:
    print(dist_tensor.shape)
    for i in range(1, dist_tensor.shape[0]):
        for j in range(1, dist_tensor.shape[1]):
            value = torch.minimum(w * torch.minimum(dist_tensor[i, j-1], dist_tensor[i-1, j-1]), dist_tensor[i-1, j])
            id = (w * dist_tensor[i, j-1] < dist_tensor[i-1, j]) & (dist_tensor[i, j-1] < dist_tensor[i-1, j-1])

            dist_tensor[i, j] += value

            grad_tensor[id][:, :, i, j] += w * grad_tensor[id][:, :, i, j-1]

def torch_dtw_fast(x: torch.Tensor, y: torch.Tensor, w: float, eps: float = 1e-5):
    # shape of x (n, dim, x_len) y (m, dim, y_len)

    # performs convolution-like operation, for each kernel the DF
    # (of shape (kernel_size, T)) is computed, then summed across channels
    # x has shape (batch, c, time_dimension)

    # compute pairwise diffs (squared)
    p_diff = x[:,None,:,None,:] - y[None,:,:,:,None] # shape (n, n_kernel, d, Kernel_size, T)
    euc_d = torch.square(p_diff).sum(2) # shape (n, n_kernel, kernel_size, T)

    # p_diff contains the partial derivatives of DTW[n, k, i, j] wrt K[k, d, i] (dims (n, k, d, i, j))
    p_diff = p_diff / torch.sqrt(euc_d[:,:, None, :, :] + eps)

    # compute dtw
    euc_d[:,:,0,:] = torch.cumsum(euc_d[:,:,0,:], dim=2)
    euc_d[:,:,:,0] = torch.cumsum(euc_d[:,:,:,0], dim=2)

    # rearrange dims
    DTW = torch.permute(euc_d, (2, 3, 0, 1)).contiguous()

    dtw_compute(DTW, p_diff, w)

    # recover dimensions
    DTW = torch.permute(DTW, (2, 3, 0, 1)).contiguous()

    return DTW.sqrt(), p_diff

class torch_dtw(torch.autograd.Function):

    @staticmethod
    def forward(x, y, w):
        DTW, p_diff = torch_dtw_fast(x, y, w)
        return DTW, p_diff
    
    @staticmethod
    def setup_context(ctx, inputs, output):
        DTW, p_diff = output
        ctx.save_for_backward(p_diff)
    
    @staticmethod
    def backward(ctx, dtw_grad, p_diff_grad):
        p_diff, = ctx.saved_tensors
        mult = (p_diff * dtw_grad[:,:,None,:,:])
        return mult.mean(dim=(1, 3)), mult.mean(dim=(0, 4)), None

In [16]:
@torch.jit.script
def dtw_compute2(dtw: torch.Tensor, dist_grad: torch.Tensor, grad: torch.Tensor, w: float) -> None:
    '''
        dtw of shape (num_kernels, pattern_len, window_size)
        dist_grad of shape (num_kernels, dims, pattern_len, window_size)
        grad of shape (num_kernels, dims, pattern_len)
    '''
    grads = torch.empty(dtw.shape[0], dist_grad.shape[1], dtw.shape[1], dtw.shape[1], dtw.shape[2]) # shape (num_kernels, dims, pattern_len, pattern_len, window_size)

    for i in range(dtw.shape[1]):
        grads[:, :, i, i, :] = torch.cumsum(dist_grad[:, :, i, :], dim=2)
        grads[:, :, i, i:, 0] = grads[:, :, i, i, :1]

    for i in range(1, dtw.shape[1]): # pl
        for j in range(1, dtw.shape[2]): # ws
            value = torch.minimum(w * torch.minimum(dtw[:, i, j-1], dtw[:, i-1, j-1]), dtw[:, i-1, j])
            temp_1 = dtw[:, i, j-1] < dtw[:, i-1, j-1] # path (i, j-1) or (i-1, j)
            temp_2 = w * dtw[:, i, j-1] < dtw[:, i-1, j] # path (i, j-1) or (i-1, j-1)
            temp_3 = w * dtw[:, i-1, j-1] < dtw[:, i-1, j] # path (i-1, j-1) or (i-1, j)

            dtw[:, i, j] += value

            grads[temp_1 & temp_2][:, :, i, j] += w * grads[temp_1 & temp_2][:, :, i, j-1]
            grads[temp_1 & temp_3][:, :, i, j] += grads[temp_1 & temp_3][:, :, i-1, j]
            grads[temp_2 & temp_3][:, :, i, j] += w * grads[temp_2 & temp_3][:, :, i-1, j-1]

    grad = grads.sum(dim=(-2, -1))

@torch.jit.script
def torch_dtw_fast2(x: torch.Tensor, y: torch.Tensor, w: float, eps: float = 1e-5):
    # shape of x (n, dim, x_len) y (m, dim, y_len)

    # performs convolution-like operation, for each kernel the DF
    # (of shape (kernel_size, T)) is computed, then summed across channels
    # x has shape (batch, c, time_dimension)

    # compute pairwise diffs (squared)
    p_diff = x[:,None,:,None,:] - y[None,:,:,:,None] # shape (n, n_kernel, d, Kernel_size, T)
    euc_d = torch.square(p_diff).sum(2) # shape (n, n_kernel, kernel_size, T)

    # p_diff now contains the partial derivatives of DTW[n, k, i, j] wrt K[k, d, i] (dims (n, k, d, i, j))
    p_diff = p_diff / torch.sqrt(euc_d[:,:, None, :, :] + eps)

    # compute dtw
    euc_d[:,:,0,:] = torch.cumsum(euc_d[:,:,0,:], dim=2)
    euc_d[:,:,:,0] = torch.cumsum(euc_d[:,:,:,0], dim=2)

    grads = torch.empty(x.shape[0], y.shape[0], y.shape[1], y.shape[2])

    futures = [torch.jit.fork(dtw_compute2, euc_d[i], p_diff[i], grads[i], w) for i in range(x.shape[0])]
    results = [torch.jit.wait(future) for future in futures]

    return euc_d.sqrt(), grads.sum(0)

In [4]:
def dtw_compute2(dtw: torch.Tensor, dist_grad: torch.Tensor, grad: torch.Tensor, w: float) -> None:
    '''
        dtw of shape (num_kernels, pattern_len, window_size)
        dist_grad of shape (num_kernels, dims, pattern_len, window_size)
        grad of shape (num_kernels, dims, pattern_len)
    '''
    grads = torch.empty(dist_grad.shape[1], dtw.shape[1], dtw.shape[1], dtw.shape[2]) # shape (dims, pattern_len, pattern_len, window_size)
    print(grads.numel())

    for k in range(dtw.shape[0]): # num_kernels
        grads.zero_()

        for i in range(dtw.shape[1]):
            grads[:, i, i, :] = torch.cumsum(dist_grad[k, :, i, :], dim=1)
            grads[:, i, i:, 0] = grads[:, i, i, :1]

        for i in range(1, dtw.shape[1]): # wl
            for j in range(1, dtw.shape[2]): # ws
                value = torch.minimum(w * torch.minimum(dtw[k, i, j-1], dtw[k, i-1, j-1]), dtw[k, i-1, j])
                temp_1 = dtw[k, i, j-1] < dtw[k, i-1, j-1] # path (i, j-1) or (i-1, j)
                temp_2 = w * dtw[k, i, j-1] < dtw[k, i-1, j] # path (i, j-1) or (i-1, j-1)
                temp_3 = w * dtw[k, i-1, j-1] < dtw[k, i-1, j] # path (i-1, j-1) or (i-1, j)

                dtw[k, i, j] += value

                if temp_1 and temp_2: # path is (i, j-1)
                    grads[:, :, i, j] += w * grads[:, :, i, j-1]
                elif temp_1 and temp_3: # path is (i-1, j)
                    grads[:, :, i, j] += grads[:, :, i-1, j]
                else: # path is (i-1, j-1)
                    grads[:, :, i, j] += w * grads[:, :, i-1, j-1]

        grad[k] += grads.sum(dim=(-2, -1))

In [5]:
dtw = torch.randn(10, 15, 20)
dist_grad = torch.randn(10, 6, 15, 20)
grad = torch.randn(10, 6, 15)

In [35]:
x = torch.randn(128, 3, 50)
y = torch.randn(30, 3, 50)

In [37]:
a, b = torch_dtw_fast2(x, y, w=0.01, eps=1e-6)

In [65]:
a1, b1  = torch_dtw_fast(x, y, w=0.01, eps=1e-6)

torch.Size([25, 20, 10, 15])


In [16]:
a.numel()

37748736