In [1]:
import torch
import numpy as np
from torch import nn, Tensor

In [2]:

@torch.jit.script
def dtw_compute_full2(dtw: torch.Tensor, dist_grad: torch.Tensor, w: float) -> torch.Tensor:
    '''
        Accumulate the DTW cost matrix in place, then backtrack the warping path
        and sum the local-distance gradients along it (boolean-mask variant).

        dtw of shape (n, k, pattern_len, window_size); row 0 and column 0 are
            used as given (the forward pass starts at index 1). Modified in
            place: on return, cells rejected during backtracking hold inf.
        dist_grad of shape (n, k, dims, pattern_len, window_size)
        w is the weight applied to the left and diagonal predecessors in the
            forward recurrence.

        returns grad of shape (n, k, dims, pl), allocated with dist_grad's dtype
            on dtw's device.

        NOTE(review): the backtrack picks the predecessor by comparing the
        unweighted neighbours, while the forward pass weights two of them by w;
        for w != 1 the two can disagree -- confirm this is intended.
    '''
    n, k, len_pattern, len_window = dtw.shape
    # Match dist_grad's dtype so non-float32 inputs are not silently accumulated in float32.
    grads = torch.zeros(
        (n, k, dist_grad.shape[2], len_pattern),
        dtype=dist_grad.dtype,
        device=dtw.device,
    )

    # Forward pass: add the cheapest predecessor cost to every interior cell.
    for i in range(1, len_pattern): # pl
        for j in range(1, len_window): # ws
            value = torch.minimum(w * torch.minimum(dtw[:, :, i, j-1], dtw[:, :, i-1, j-1]), dtw[:, :, i-1, j])

            dtw[:, :, i, j] += value

    # Backward pass: walk from the bottom-right corner; a cell that is still
    # finite when visited is treated as being on the path.
    for i0 in range(len_pattern-1, -1, -1):
        for j0 in range(len_window-1, -1, -1):
            mask = ~torch.isinf(dtw[:, :, i0, j0])
            grads[:, :, :, i0][mask] += dist_grad[:, :, :, i0, j0][mask]

            # Border cells have no left/up/diagonal triple to choose from.
            if j0==0 or i0==0:
                continue

            # Candidate predecessors, in order: 0 = left, 1 = up, 2 = diagonal.
            paths = torch.stack([
                dtw[:, :, i0, j0-1],
                dtw[:, :, i0-1, j0],
                dtw[:, :, i0-1, j0-1]
            ])

            step = paths.argmin(0)

            # Invalidate the rest of the row unless the path moves left, and the
            # rest of the column unless it moves up (a diagonal move clears both).
            dtw[:, :, i0, :j0][(step!=0) & mask] = float("inf")
            dtw[:, :, :i0, j0][(step!=1) & mask] = float("inf")

    return grads

@torch.jit.script
def dtw_compute_no_grad2(dtw: torch.Tensor, w: float) -> None:
    '''
        Accumulate the DTW cost matrix in place; no gradients are tracked.

        dtw of shape (n, k, pattern_len, window_size); row 0 and column 0 are
            left untouched, every other cell gets its cheapest predecessor
            added (left and diagonal predecessors are scaled by w).
    '''
    num_rows = dtw.shape[2]
    num_cols = dtw.shape[3]

    for row in range(1, num_rows): # pl
        # Views into dtw: writes through `current` update dtw itself.
        above = dtw[:, :, row - 1]
        current = dtw[:, :, row]
        for col in range(1, num_cols): # ws
            left_or_diag = torch.minimum(current[:, :, col - 1], above[:, :, col - 1])
            current[:, :, col] += torch.minimum(w * left_or_diag, above[:, :, col])
    
@torch.jit.script
def dtw_fast_no_image2(x: torch.Tensor, y: torch.Tensor, w: float, eps: float = 1e-5, compute_gradients: bool=True):
    # x has shape (n, dim, x_len) and y has shape (m, dim, y_len).
    #
    # Every y[k] is compared against every x[n]: the squared per-channel
    # differences are summed over the channel axis, accumulated into a DTW
    # matrix, and the square root of that matrix is returned. With
    # compute_gradients=True the second return value holds the gradients
    # gathered by dtw_compute_full2, otherwise it is None.

    # Pairwise differences, shape (n, m, dim, y_len, x_len).
    diff = x[:, None, :, None, :] - y[None, :, :, :, None]
    # Squared distance summed over channels, shape (n, m, y_len, x_len).
    sq_dist = diff.square().sum(dim=2)

    if compute_gradients:
        # Must run before the cumulative sums below overwrite the raw distances.
        diff.div_(torch.sqrt(sq_dist.unsqueeze(2) + eps))

    # Accumulate the first row and first column of the DTW matrix.
    first_row = sq_dist[:, :, 0, :]
    first_row.copy_(first_row.cumsum(dim=2))
    first_col = sq_dist[:, :, :, 0]
    first_col.copy_(first_col.cumsum(dim=2))

    if compute_gradients:
        # diff now holds the normalised differences, dims (n, m, dim, y_len, x_len);
        # the returned gradients have dims (n, m, dim, y_len).
        grads = dtw_compute_full2(sq_dist, diff, w)

        return torch.sqrt(sq_dist), grads
    else:
        dtw_compute_no_grad2(sq_dist, w)

        return torch.sqrt(sq_dist), None

In [3]:
@torch.jit.script
def dtw_compute_full(dtw: torch.Tensor, dist_grad: torch.Tensor, w: float) -> torch.Tensor:
    '''
        Accumulate the DTW cost matrix in place, then backtrack the warping path
        and sum the local-distance gradients along it (nonzero-index variant:
        only the (n, k) pairs whose current cell is finite are gathered).

        dtw of shape (n, k, pattern_len, window_size); row 0 and column 0 are
            used as given (the forward pass starts at index 1). Modified in
            place: on return, cells rejected during backtracking hold inf.
        dist_grad of shape (n, k, dims, pattern_len, window_size)
        w is the weight applied to the left and diagonal predecessors in the
            forward recurrence.

        returns grad of shape (n, k, dims, pl), allocated with dist_grad's dtype
            on dtw's device.

        NOTE(review): the backtrack picks the predecessor by comparing the
        unweighted neighbours, while the forward pass weights two of them by w;
        for w != 1 the two can disagree -- confirm this is intended.
    '''
    n, k, len_pattern, len_window = dtw.shape
    # Match dist_grad's dtype so non-float32 inputs are not silently accumulated in float32.
    grads = torch.zeros(
        (n, k, dist_grad.shape[2], len_pattern),
        dtype=dist_grad.dtype,
        device=dtw.device,
    )

    # Forward pass: add the cheapest predecessor cost to every interior cell.
    for i in range(1, len_pattern): # pl
        for j in range(1, len_window): # ws
            value = torch.minimum(w * torch.minimum(dtw[:, :, i, j-1], dtw[:, :, i-1, j-1]), dtw[:, :, i-1, j])

            dtw[:, :, i, j] += value

    # Backward pass: walk from the bottom-right corner; a cell that is still
    # finite when visited is treated as being on the path.
    for i0 in range(len_pattern-1, -1, -1):
        for j0 in range(len_window-1, -1, -1):
            mask = ~torch.isinf(dtw[:, :, i0, j0])
            mask_id = mask.nonzero()
            # Nothing on the path at this cell for any (n, k) pair.
            if mask_id.shape[0]==0:
                continue

            batch_idx = mask_id[:, 0]
            kernel_idx = mask_id[:, 1]

            grads[batch_idx, kernel_idx, :, i0] += dist_grad[batch_idx, kernel_idx, :, i0, j0]

            # Border cells have no left/up/diagonal triple to choose from.
            if j0==0 or i0==0:
                continue

            # Candidate predecessors, in order: 0 = left, 1 = up, 2 = diagonal.
            paths = torch.stack([
                dtw[batch_idx, kernel_idx, i0, j0-1],
                dtw[batch_idx, kernel_idx, i0-1, j0],
                dtw[batch_idx, kernel_idx, i0-1, j0-1]
            ])

            step = paths.argmin(0)

            # Invalidate the rest of the row unless the path moves left, and the
            # rest of the column unless it moves up (a diagonal move clears both).
            not_left = step!=0
            not_up = step!=1
            dtw[batch_idx[not_left], kernel_idx[not_left], i0, :j0] = float("inf")
            dtw[batch_idx[not_up], kernel_idx[not_up], :i0, j0] = float("inf")

    return grads

@torch.jit.script
def dtw_compute_no_grad(dtw: torch.Tensor, w: float) -> None:
    '''
        In-place forward DTW accumulation without any gradient bookkeeping.

        dtw of shape (n, k, pattern_len, window_size). Cells in row 0 and
            column 0 are read but never written; each remaining cell is
            increased by min(w * min(left, diagonal), up).
    '''
    pattern_len = dtw.size(2)
    window_len = dtw.size(3)

    for p in range(1, pattern_len): # pl
        for t in range(1, window_len): # ws
            left = dtw[:, :, p, t - 1]
            diagonal = dtw[:, :, p - 1, t - 1]
            up = dtw[:, :, p - 1, t]

            cheapest = torch.minimum(w * torch.minimum(left, diagonal), up)
            dtw[:, :, p, t] += cheapest
    
@torch.jit.script
def dtw_fast_no_image(x: torch.Tensor, y: torch.Tensor, w: float, eps: float = 1e-5, compute_gradients: bool=True):
    # Inputs: x of shape (n, dim, x_len), y of shape (m, dim, y_len).
    #
    # Builds, for every (x[n], y[k]) pair, the matrix of channel-summed squared
    # differences, turns it into an accumulated DTW matrix and returns its
    # element-wise square root. The second return value is the gradient tensor
    # from dtw_compute_full when compute_gradients is True, else None.

    # Broadcast to (n, m, dim, y_len, x_len) and reduce over the channel axis.
    delta = x[:, None, :, None, :] - y[None, :, :, :, None]
    cost = (delta.square()).sum(2)  # shape (n, m, y_len, x_len)

    if compute_gradients:
        # Normalise by the distance while `cost` still holds the raw local costs.
        denom = torch.sqrt(cost[:, :, None, :, :] + eps)
        delta /= denom

    # Seed the recurrence: cumulative cost along the first row, then the first column.
    cost[:, :, 0, :] = cost[:, :, 0, :].cumsum(dim=2)
    cost[:, :, :, 0] = cost[:, :, :, 0].cumsum(dim=2)

    if compute_gradients:
        # delta holds the normalised differences, dims (n, m, dim, y_len, x_len);
        # the gradients that come back have dims (n, m, dim, y_len).
        grads = dtw_compute_full(cost, delta, w)

        return cost.sqrt(), grads
    else:
        dtw_compute_no_grad(cost, w)

        return cost.sqrt(), None

In [4]:
# Seed the generator so the random inputs are the same on every Restart & Run All.
torch.manual_seed(0)
a = torch.randn(1, 2, 5)  # passed as x: (n=1, dim=2, x_len=5)
b = torch.randn(2, 2, 6)  # passed as y: (m=2, dim=2, y_len=6)

In [5]:
# Run the nonzero-index implementation on (a, b) with w = 1, requesting gradients.
dtw, grads = dtw_fast_no_image(a, b, 1, compute_gradients=True)

In [6]:
# Gradients returned by dtw_fast_no_image, shape (n, k, dims, pattern_len).
grads

tensor([[[[-0.2284, -0.6878,  0.9921, -0.7903, -0.2551,  0.3082],
          [ 0.9736, -1.6222, -0.1253,  0.6127, -0.9669, -0.9513]],

         [[ 0.8082, -0.9195, -0.6803, -0.2600,  0.4891,  0.4143],
          [-0.0071, -0.3932, -0.7328, -0.9656, -0.8722, -0.9101]]]])

In [7]:
# Square-rooted DTW matrix; cells rejected during backtracking were set to inf.
dtw

tensor([[[[2.2502,    inf,    inf,    inf,    inf],
          [   inf, 3.3711, 3.5811,    inf,    inf],
          [   inf,    inf,    inf, 3.7678,    inf],
          [   inf,    inf,    inf, 4.0152,    inf],
          [   inf,    inf,    inf, 4.3625,    inf],
          [   inf,    inf,    inf,    inf, 5.1445]],

         [[3.2596, 3.8193,    inf,    inf,    inf],
          [   inf,    inf, 4.1877,    inf,    inf],
          [   inf,    inf,    inf, 4.1945,    inf],
          [   inf,    inf,    inf,    inf, 4.3122],
          [   inf,    inf,    inf,    inf, 4.4857],
          [   inf,    inf,    inf,    inf, 5.1856]]]])

In [8]:
# Run the boolean-mask implementation on the same inputs for comparison.
dtw2, grads2 = dtw_fast_no_image2(a, b, 1, compute_gradients=True)

In [9]:
# DTW matrix from dtw_fast_no_image2; compare with the one shown for dtw_fast_no_image.
dtw2

tensor([[[[2.2502,    inf,    inf,    inf,    inf],
          [   inf, 3.3711, 3.5811,    inf,    inf],
          [   inf,    inf,    inf, 3.7678,    inf],
          [   inf,    inf,    inf, 4.0152,    inf],
          [   inf,    inf,    inf, 4.3625,    inf],
          [   inf,    inf,    inf,    inf, 5.1445]],

         [[3.2596, 3.8193,    inf,    inf,    inf],
          [   inf,    inf, 4.1877,    inf,    inf],
          [   inf,    inf,    inf, 4.1945,    inf],
          [   inf,    inf,    inf,    inf, 4.3122],
          [   inf,    inf,    inf,    inf, 4.4857],
          [   inf,    inf,    inf,    inf, 5.1856]]]])

In [10]:
# Gradients from dtw_fast_no_image2; compare with the ones shown for dtw_fast_no_image.
grads2

tensor([[[[-0.2284, -0.6878,  0.9921, -0.7903, -0.2551,  0.3082],
          [ 0.9736, -1.6222, -0.1253,  0.6127, -0.9669, -0.9513]],

         [[ 0.8082, -0.9195, -0.6803, -0.2600,  0.4891,  0.4143],
          [-0.0071, -0.3932, -0.7328, -0.9656, -0.8722, -0.9101]]]])