In [1]:
import torch
import numpy as np
from torch import nn, Tensor

In [29]:

@torch.jit.script
def dtw_compute_full(dtw: torch.Tensor, dist_grad: torch.Tensor, w: float) -> torch.Tensor:
    '''
        dtw of shape (n, k, pattern_len, window_size)
        dist_grad of shape (n, k, dims, pattern_len, window_size)
        grad of shape (n, k, dims, pl)
    '''
    n, k, len_pattern, len_window = dtw.shape
    grads = torch.zeros((n, k, dist_grad.shape[2], len_pattern), device=dtw.device)

    for i in range(1, len_pattern): # pl
        for j in range(1, len_window): # ws
            value = torch.minimum(w * torch.minimum(dtw[:, :, i, j-1], dtw[:, :, i-1, j-1]), dtw[:, :, i-1, j])

            dtw[:, :, i, j] += value

    # for n0 in range(n):
    #     for k0 in range(k):
    #         i0 = len_pattern-1
    #         j0 = len_window-1
    #         while i0+j0>=0:
    #             if i0==0:
    #                 grads[n0, k0, :, i0] += dist_grad[n0, k0, :, i0, :(j0+1)].sum(1)
    #                 break
    #             if j0==0:
    #                 grads[n0, k0, :, :(i0+1)] += dist_grad[n0, k0, :, :(i0+1), 0]
    #                 break

    #             grads[n0, k0, :, i0] += dist_grad[n0, k0, :, i0, j0]

    #             paths = torch.stack([
    #                 dtw[n0, k0, i0, j0-1],
    #                 dtw[n0, k0, i0-1, j0],
    #                 dtw[n0, k0, i0-1, j0-1]            
    #             ])

    #             id = paths.argmin(0)
    #             if id!=0:
    #                 i0-=1
    #             if id!=1:
    #                 j0-=1

    i_path = torch.full((n, k), len_pattern-1, dtype=torch.int64, device=dtw.device)
    j_path = torch.full((n, k), len_window-1, dtype=torch.int64, device=dtw.device)

    n_indices = torch.arange(n, device=dtw.device).unsqueeze(1)
    k_indices = torch.arange(k, device=dtw.device).unsqueeze(0)

    for id in range(len_pattern + len_window):
        positive_i = i_path>=0
        grads[n_indices, k_indices, :, i_path[positive_i]] += dist_grad[n_indices, k_indices, :, i_path[positive_i], j_path[positive_i]]

        i_path_restricted = torch.clamp(i_path-1, min=0)
        j_path_restricted = torch.clamp(j_path-1, min=0)

        values = torch.stack([
            dtw[n_indices, k_indices, i_path_restricted, j_path_restricted],
            dtw[n_indices, k_indices, i_path, j_path_restricted],
            dtw[n_indices, k_indices, i_path_restricted, j_path],
        ])

        path = values.argmin(dim=0) # path 0 is (-1, -1) path 1 is (0, -1) path 2 is (-1, 0)

        i_path[path!=1] -= 1
        j_path[path!=2] -= 1
        

    # for i0 in range(len_pattern-1, -1, -1):
    #     for j0 in range(len_window-1, -1, -1):
    #         mask = ~torch.isinf(dtw[:, :, i0, j0])
    #         grads[:, :, :, i0][mask] += dist_grad[:, :, :, i0, j0][mask]

    #         if j0==0 or i0==0:
    #             continue

    #         paths = torch.stack([
    #             dtw[:, :, i0, j0-1],
    #             dtw[:, :, i0-1, j0],
    #             dtw[:, :, i0-1, j0-1]            
    #         ])

    #         id = paths.argmin(0)

    #         dtw[:, :, i0, :j0][(id!=1) & mask] = float("inf")
    #         dtw[:, :, :i0, j0][(id!=0) & mask] = float("inf")

    return grads

@torch.jit.script
def dtw_compute_no_grad(dtw: torch.Tensor, w: float) -> None:
    '''
        dtw of shape (n, k, pattern_len, window_size)
        grad of shape (n, k, dims, pattern_len)
    '''

    n, k, len_pattern, len_window = dtw.shape

    for i in range(1, len_pattern): # pl
        for j in range(1, len_window): # ws
            value = torch.minimum(w * torch.minimum(dtw[:, :, i, j-1], dtw[:, :, i-1, j-1]), dtw[:, :, i-1, j])

            dtw[:, :, i, j] += value
    
@torch.jit.script
def dtw_fast_no_image(x: torch.Tensor, y: torch.Tensor, w: float, eps: float = 1e-5, compute_gradients: bool=True):
    # shape of x (n, dim, x_len) y (m, dim, y_len)

    # performs convolution-like operation, for each kernel the DF
    # (of shape (kernel_size, T)) is computed, then summed across channels
    # x has shape (batch, c, time_dimension)

    # compute pairwise diffs (squared)
    p_diff = x[:,None,:,None,:] - y[None,:,:,:,None] # shape (n, n_kernel, d, Kernel_size, T)
    euc_d = torch.square(p_diff).sum(2) # shape (n, n_kernel, kernel_size, T)

    if compute_gradients:
        p_diff /= torch.sqrt(euc_d[:,:, None, :, :] + eps)

    # compute dtw
    euc_d[:,:,0,:] = torch.cumsum(euc_d[:,:,0,:], dim=2)
    euc_d[:,:,:,0] = torch.cumsum(euc_d[:,:,:,0], dim=2)

    if compute_gradients:
        # p_diff now contains the partial derivatives of DTW[n, k, i, j] wrt K[k, d, i] (dims (n, k, d, i, j))
        
        grads = dtw_compute_full(euc_d, p_diff, w) # dims (n, k, d, i, i, j)
        
        return euc_d.sqrt(), grads
    else:
        dtw_compute_no_grad(euc_d, w)

        return euc_d.sqrt(), None

In [12]:
a = torch.randn(128, 2, 32)
b = torch.randn(64, 2, 32)

In [30]:
dtw, grads = dtw_fast_no_image(a, b, 1, compute_gradients=True)

RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "/tmp/ipykernel_39942/3787422734.py", line 51, in dtw_fast_no_image
    for id in range(len_pattern + len_window):
        positive_i = i_path>=0
        grads[n_indices, k_indices, :, i_path[positive_i]] += dist_grad[n_indices, k_indices, :, i_path[positive_i], j_path[positive_i]]
                                                              ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE

        i_path_restricted = torch.clamp(i_path-1, min=0)
RuntimeError: shape mismatch: indexing tensors could not be broadcast together with shapes [128, 1], [1, 64], [8192], [8192]


In [17]:
dtw2, grads2 = dtw_fast_no_image(a, b, 1, compute_gradients=True)