In [1]:
import torch
import numpy as np
from torch import nn, Tensor

In [2]:

@torch.jit.script
def dtw_compute_full2(dtw: torch.Tensor, dist_grad: torch.Tensor, w: float) -> torch.Tensor:
    '''
        dtw of shape (n, k, pattern_len, window_size)
        dist_grad of shape (n, k, dims, pattern_len, window_size)
        grad of shape (n, k, dims, pl)
    '''
    n, k, len_pattern, len_window = dtw.shape
    grads = torch.zeros((n, k, dist_grad.shape[2], len_pattern), device=dtw.device)

    for i in range(1, len_pattern): # pl
        for j in range(1, len_window): # ws
            value = torch.minimum(w * torch.minimum(dtw[:, :, i, j-1], dtw[:, :, i-1, j-1]), dtw[:, :, i-1, j])

            dtw[:, :, i, j] += value

    summed_times = 0
    for n0 in range(n):
        for k0 in range(k):
            i0 = len_pattern-1
            j0 = len_window-1
            while i0+j0>=0:
                if i0==0:
                    summed_times +=grads[n0, k0, :, i0].numel()
                    grads[n0, k0, :, i0] += dist_grad[n0, k0, :, i0, :(j0+1)].sum(1)
                    break
                if j0==0:
                    summed_times += grads[n0, k0, :, i0].numel()
                    grads[n0, k0, :, :(i0+1)] += dist_grad[n0, k0, :, :(i0+1), 0]
                    break

                summed_times += grads[n0, k0, :, i0].numel()
                grads[n0, k0, :, i0] += dist_grad[n0, k0, :, i0, j0]

                paths = torch.stack([
                    dtw[n0, k0, i0, j0-1],
                    dtw[n0, k0, i0-1, j0],
                    dtw[n0, k0, i0-1, j0-1]            
                ])

                id = paths.argmin(0)
                if id!=0:
                    i0-=1
                if id!=1:
                    j0-=1

    print(summed_times)
    return grads

@torch.jit.script
def dtw_compute_no_grad2(dtw: torch.Tensor, w: float) -> None:
    '''
        dtw of shape (n, k, pattern_len, window_size)
        grad of shape (n, k, dims, pattern_len)
    '''

    n, k, len_pattern, len_window = dtw.shape

    for i in range(1, len_pattern): # pl
        for j in range(1, len_window): # ws
            value = torch.minimum(w * torch.minimum(dtw[:, :, i, j-1], dtw[:, :, i-1, j-1]), dtw[:, :, i-1, j])

            dtw[:, :, i, j] += value
    
@torch.jit.script
def dtw_fast_no_image2(x: torch.Tensor, y: torch.Tensor, w: float, eps: float = 1e-5, compute_gradients: bool=True):
    # shape of x (n, dim, x_len) y (m, dim, y_len)

    # performs convolution-like operation, for each kernel the DF
    # (of shape (kernel_size, T)) is computed, then summed across channels
    # x has shape (batch, c, time_dimension)

    # compute pairwise diffs (squared)
    p_diff = x[:,None,:,None,:] - y[None,:,:,:,None] # shape (n, n_kernel, d, Kernel_size, T)
    euc_d = torch.square(p_diff).sum(2) # shape (n, n_kernel, kernel_size, T)

    if compute_gradients:
        p_diff /= torch.sqrt(euc_d[:,:, None, :, :] + eps)

    # compute dtw
    euc_d[:,:,0,:] = torch.cumsum(euc_d[:,:,0,:], dim=2)
    euc_d[:,:,:,0] = torch.cumsum(euc_d[:,:,:,0], dim=2)

    if compute_gradients:
        # p_diff now contains the partial derivatives of DTW[n, k, i, j] wrt K[k, d, i] (dims (n, k, d, i, j))
        
        grads = dtw_compute_full2(euc_d, p_diff, w) # dims (n, k, d, i, i, j)
        
        return euc_d.sqrt(), grads
    else:
        dtw_compute_no_grad2(euc_d, w)

        return euc_d.sqrt(), None

In [15]:

@torch.jit.script
def dtw_compute_full(dtw: torch.Tensor, dist_grad: torch.Tensor, w: float) -> torch.Tensor:
    '''
        dtw of shape (n, k, pattern_len, window_size)
        dist_grad of shape (n, k, dims, pattern_len, window_size)
        grad of shape (n, k, dims, pl)
    '''
    n, k, len_pattern, len_window = dtw.shape
    grads = torch.zeros((n, k, dist_grad.shape[2], len_pattern), device=dtw.device)

    for i in range(1, len_pattern): # pl
        for j in range(1, len_window): # ws
            value = torch.minimum(w * torch.minimum(dtw[:, :, i, j-1], dtw[:, :, i-1, j-1]), dtw[:, :, i-1, j])

            dtw[:, :, i, j] += value
       
    summed_timems = 0
    for i0 in range(len_pattern-1, -1, -1):
        for j0 in range(len_window-1, -1, -1):
            mask = ~torch.isinf(dtw[:, :, i0, j0])
            summed_timems += grads[:, :, :, i0][mask].numel()
            grads[:, :, :, i0][mask] += dist_grad[:, :, :, i0, j0][mask]

            if j0==0 or i0==0:
                continue

            paths = torch.stack([
                dtw[:, :, i0, j0-1],
                dtw[:, :, i0-1, j0],
                dtw[:, :, i0-1, j0-1]
            ])

            id = paths.argmin(0)

            dtw[:, :, i0, :j0][(id!=0) & mask] = float("inf")
            dtw[:, :, :i0, j0][(id!=1) & mask] = float("inf")
    print(summed_timems)
    return grads

@torch.jit.script
def dtw_compute_no_grad(dtw: torch.Tensor, w: float) -> None:
    '''
        dtw of shape (n, k, pattern_len, window_size)
        grad of shape (n, k, dims, pattern_len)
    '''

    n, k, len_pattern, len_window = dtw.shape

    for i in range(1, len_pattern): # pl
        for j in range(1, len_window): # ws
            value = torch.minimum(w * torch.minimum(dtw[:, :, i, j-1], dtw[:, :, i-1, j-1]), dtw[:, :, i-1, j])

            dtw[:, :, i, j] += value
    
@torch.jit.script
def dtw_fast_no_image(x: torch.Tensor, y: torch.Tensor, w: float, eps: float = 1e-5, compute_gradients: bool=True):
    # shape of x (n, dim, x_len) y (m, dim, y_len)

    # performs convolution-like operation, for each kernel the DF
    # (of shape (kernel_size, T)) is computed, then summed across channels
    # x has shape (batch, c, time_dimension)

    # compute pairwise diffs (squared)
    p_diff = x[:,None,:,None,:] - y[None,:,:,:,None] # shape (n, n_kernel, d, Kernel_size, T)
    euc_d = torch.square(p_diff).sum(2) # shape (n, n_kernel, kernel_size, T)

    if compute_gradients:
        p_diff /= torch.sqrt(euc_d[:,:, None, :, :] + eps)

    # compute dtw
    euc_d[:,:,0,:] = torch.cumsum(euc_d[:,:,0,:], dim=2)
    euc_d[:,:,:,0] = torch.cumsum(euc_d[:,:,:,0], dim=2)

    if compute_gradients:
        # p_diff now contains the partial derivatives of DTW[n, k, i, j] wrt K[k, d, i] (dims (n, k, d, i, j))
        
        grads = dtw_compute_full(euc_d, p_diff, w) # dims (n, k, d, i, i, j)
        
        return euc_d.sqrt(), grads
    else:
        dtw_compute_no_grad(euc_d, w)

        return euc_d.sqrt(), None

In [8]:
a = torch.randn(1, 2, 5)
b = torch.randn(2, 2, 6)

In [16]:
dtw, grads = dtw_fast_no_image(a, b, 1, compute_gradients=True)

28


In [17]:
dtw

tensor([[[[1.0344, 1.4252,    inf,    inf,    inf],
          [   inf,    inf, 1.4503,    inf,    inf],
          [   inf,    inf,    inf, 1.8254,    inf],
          [   inf,    inf,    inf,    inf, 1.8379],
          [   inf,    inf,    inf,    inf, 2.2167],
          [   inf,    inf,    inf,    inf, 2.7907]],

         [[1.9512,    inf,    inf,    inf,    inf],
          [3.1375,    inf,    inf,    inf,    inf],
          [3.1697,    inf,    inf,    inf,    inf],
          [   inf, 3.2347, 3.3188,    inf,    inf],
          [   inf,    inf,    inf, 3.3505,    inf],
          [   inf,    inf,    inf,    inf, 3.3848]]]])

In [18]:
grads

tensor([[[[ 0.5242,  0.6484,  0.4469, -0.6742, -0.6949,  0.9127],
          [ 1.2423, -0.7612, -0.8946,  0.7384,  0.7191, -0.4086]],

         [[ 0.9442,  0.8270, -0.9976, -0.2119,  0.7793,  0.9939],
          [-0.3293, -0.5622, -0.0694, -0.6793, -0.6267,  0.1103]]]])

In [12]:
dtw2, grads2 = dtw_fast_no_image2(a, b, 1, compute_gradients=True)

22


In [13]:
dtw2

tensor([[[[1.0344, 1.4252, 1.6123, 2.2851, 2.7376],
          [1.0378, 1.8451, 1.4503, 1.7515, 2.0362],
          [1.9630, 2.7984, 2.4076, 1.8254, 2.2029],
          [2.2078, 2.2315, 2.4624, 1.8830, 1.8379],
          [2.7859, 2.2417, 2.7192, 2.3990, 2.2167],
          [3.1308, 3.5783, 2.8172, 2.8042, 2.7907]],

         [[1.9512, 3.9517, 4.4983, 5.1169, 5.7590],
          [3.1375, 4.3891, 4.7841, 5.2580, 5.9098],
          [3.1697, 3.3288, 3.3588, 3.4224, 3.4751],
          [3.2884, 3.2347, 3.3188, 3.4439, 3.5011],
          [3.4339, 3.7666, 3.4591, 3.3505, 3.4249],
          [3.4807, 3.7467, 3.5414, 3.3679, 3.3848]]]])

In [14]:
grads2

tensor([[[[ 0.5242,  0.6484,  0.4469, -0.6742, -0.6949,  0.9127],
          [ 1.2423, -0.7612, -0.8946,  0.7384,  0.7191, -0.4086]],

         [[ 0.9442,  0.8270, -0.9976, -0.2119,  0.7793,  0.9939],
          [-0.3293, -0.5622, -0.0694, -0.6793, -0.6267,  0.1103]]]])