In [21]:
import torch
import numpy as np
from torch import nn, Tensor

from s3ts.api.nets.encoders.dtw.dtw_no_matrix import dtw_fast_no_image

In [24]:
# Seed the global torch RNG so the two random inputs below are deterministic.
torch.manual_seed(45)
# `a` is later passed as `x` and `b` as `y` to the DTW functions. Following the
# (n, dim, x_len) / (m, dim, y_len) layout noted in dtw_fast_full, that is
# 1 series with 2 channels and 5 steps, and 2 patterns with 2 channels and 6 steps.
a = torch.randn(1, 2, 5)
b = torch.randn(2, 2, 6)

In [38]:

@torch.jit.script
def dtw_compute_full(dtw: torch.Tensor, dist_grad: torch.Tensor, dim: int, w: float) -> torch.Tensor:
    '''
        dtw of shape (n, k, pattern_len, window_size)
        dist_grad of shape (n, k, dims, pattern_len, window_size)
        grad of shape (n, k, dims, pattern_len)
    '''
    n, k, len_pattern, len_window = dtw.shape
    # very big tensor
    grads = torch.zeros(n, k, dim, len_pattern, len_pattern, len_window, device=dist_grad.device) # shape (n, k, dims, pattern_len, pattern_len, window_size)

    temp = torch.cumsum(dist_grad, dim=4)
    for i in range(len_pattern):
        grads[:, :, :, i, i, :] = temp[:, :, :, i, :]
        grads[:, :, :, i, i:, 0] = grads[:, :, :, i, i, :1]

    for i in range(1, len_pattern): # pl
        for j in range(1, len_window): # ws
            path = torch.stack([
                w * dtw[:, :, i, j-1],
                dtw[:, :, i-1, j],
                w * dtw[:, :, i-1, j-1]
            ])

            value, id = path.min(0)
            path0 = id == 0
            path1 = id == 2
            path2 = id == 1

            dtw[:, :, i, j] += value

            grads[:, :, :, :, i, j][path0] += w * grads[:, :, :, :, i, j-1][path0]
            grads[:, :, :, :, i, j][path1] += grads[:, :, :, :, i-1, j][path1]
            grads[:, :, :, :, i, j][path2] += w * grads[:, :, :, :, i-1, j-1][path2]

    return grads

@torch.jit.script
def dtw_compute_no_grad(dtw: torch.Tensor, w: float) -> None:
    '''
        Accumulate the DTW recursion in place, without tracking any gradient
        information.

        dtw of shape (n, k, pattern_len, window_size); row 0 and column 0 are
            read but never written, every other cell (i, j) is increased by
            min(w * min(dtw[i, j-1], dtw[i-1, j-1]), dtw[i-1, j])
        w is the weight shared by the (i, j-1) and (i-1, j-1) predecessors
    '''
    batch, kernels, rows, cols = dtw.shape

    for r in range(1, rows): # pattern axis
        for c in range(1, cols): # window axis
            left = dtw[:, :, r, c - 1]
            diagonal = dtw[:, :, r - 1, c - 1]
            above = dtw[:, :, r - 1, c]

            # the weight is applied once to the smaller of the two weighted predecessors
            best = torch.minimum(w * torch.minimum(left, diagonal), above)

            dtw[:, :, r, c] += best
    
@torch.jit.script
def dtw_fast_full(x: torch.Tensor, y: torch.Tensor, w: float, eps: float = 1e-5, compute_gradients: bool=True):
    '''
        DTW between every series in x and every pattern in y.

        x of shape (n, dim, x_len)
        y of shape (m, dim, y_len)
        w is forwarded unchanged to the recursion helpers
        eps is added under the square root that normalises the differences
        compute_gradients selects dtw_compute_full (True) or dtw_compute_no_grad (False)

        returns (sqrt of the accumulated cost of shape (n, m, y_len, x_len),
                 grads of dims (n, k, d, i, i, j) or None when compute_gradients is False)
    '''
    # every series against every pattern, per channel: (n, m, d, y_len, x_len)
    diff = x[:, None, :, None, :] - y[None, :, :, :, None]
    # squared euclidean local cost with the channel axis summed out: (n, m, y_len, x_len)
    cost = diff.square().sum(2)

    if compute_gradients:
        # must happen before the cumulative sums below overwrite the local costs
        diff.div_((cost.unsqueeze(2) + eps).sqrt())

    # border initialisation: running sums along the first row, then the first column
    first_row = torch.cumsum(cost[:, :, 0, :], dim=2)
    cost[:, :, 0, :] = first_row
    first_col = torch.cumsum(cost[:, :, :, 0], dim=2)
    cost[:, :, :, 0] = first_col

    if compute_gradients:
        # diff now holds the normalised differences that act as local derivative terms (dims (n, k, d, i, j))
        grads = dtw_compute_full(cost, diff, x.shape[1], w) # dims (n, k, d, i, i, j)

        return cost.sqrt(), grads
    else:
        dtw_compute_no_grad(cost, w)

        return cost.sqrt(), None

class torch_dtw(torch.autograd.Function):
    '''
        Autograd wrapper around dtw_fast_full.

        forward(x, y, w) returns (DTW, p_diff), where p_diff is the gradient
        tensor produced by dtw_fast_full, or None when y does not require grad.
        backward only produces a gradient for y; x and w receive None.
    '''

    @staticmethod
    def forward(x: torch.Tensor, y: torch.Tensor, w: float):
        # the gradient tensor is only built when the patterns are trainable
        DTW, p_diff = dtw_fast_full(x, y, w, compute_gradients=y.requires_grad)
        return DTW, p_diff

    @staticmethod
    def setup_context(ctx, inputs, output):
        DTW, p_diff = output
        ctx.save_for_backward(p_diff)

    @staticmethod
    def backward(ctx, dtw_grad, p_diff_grad):
        # dtw_grad dims (n, k, i, j)
        # p_diff dims (n, k, d, p, i, j): p indexes the pattern element, (i, j) the DTW cell
        p_diff, = ctx.saved_tensors
        if p_diff is None:
            # forward ran with compute_gradients=False, so there is nothing to propagate
            return None, None, None
        # Align dtw_grad with the two trailing (i, j) axes of p_diff. Both the
        # p and i axes have length pattern_len, so inserting the singleton in
        # the wrong place broadcasts without error but pairs each DTW cell
        # with the wrong upstream gradient.
        mult = p_diff * dtw_grad[:, :, None, None, :, :] # dims (n, k, d, p, i, j)
        # sum over the DTW cells, then average over the batch -> (k, d, p)
        return None, mult.sum(dim=(-2, -1)).mean(dim=0), None

In [41]:
# Run the imported implementation on the sample inputs with w = 1.
# NOTE(review): `dtw` and `grads` are rebound by the dtw_fast_full call in the
# following cell, so what the display cells show depends on which call ran last.
dtw, grads = dtw_fast_no_image(a, b, 1, compute_gradients=True)

In [43]:
# Same inputs and w through the notebook-local dtw_fast_full. This overwrites
# the `dtw` and `grads` produced by the dtw_fast_no_image call.
dtw, grads = dtw_fast_full(a, b, 1, compute_gradients=True)

In [42]:
# Shows the current `grads`. The output below was recorded at execution count 42,
# i.e. after the dtw_fast_no_image call (41) and before the dtw_fast_full call (43).
grads

tensor([[[[-0.2995, -0.9950,  0.5440,  0.9148, -0.8579, -0.6878],
          [-0.1696, -0.0999, -0.8391,  0.4040, -0.5138, -0.7259]],

         [[-0.9568,  1.0000,  0.3197,  0.0070,  0.2462,  0.2867],
          [ 0.2908, -0.0061,  0.9475, -1.0000,  0.9692,  0.9580]]]])

In [47]:
# Shows the current `grads`. Recorded at execution count 47, after the
# dtw_fast_full call (43); the six nesting levels of the output match the
# (n, k, d, i, i, j) layout noted in dtw_fast_full.
grads

tensor([[[[[[-0.6353, -0.2995, -0.7211, -0.3951, -0.3885],
            [-0.6353, -0.2995, -0.7211, -0.7211, -0.3885],
            [-0.6353, -0.2995, -0.2995, -0.7211, -0.7211],
            [-0.6353, -0.6353, -0.2995, -0.7211, -0.7211],
            [-0.6353, -0.6353, -0.6353, -0.7211, -0.7211],
            [-0.6353, -0.6353, -0.6353, -0.7211, -0.7211]],

           [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
            [-0.9992, -0.9849, -1.9799, -3.8073, -2.1075],
            [-0.9992, -0.9849, -0.9849, -3.8073, -3.8073],
            [-0.9992, -0.9992, -0.9849, -3.8073, -3.8073],
            [-0.9992, -0.9992, -0.9992, -3.8073, -3.8073],
            [-0.9992, -0.9992, -0.9992, -3.8073, -3.8073]],

           [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
            [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
            [-0.3507,  0.4711,  0.4761,  1.0201,  2.6376],
            [-0.3507, -0.3507,  0.4711,  1.0201,  2.6376],
            [-0.3507, -0.3507, -0.3507,  1.0201,  1.

In [35]:
# Shows the current `dtw`. Recorded at execution count 35, which precedes the
# definition cell (38) and both calls shown (41, 43).
# NOTE(review): this output is likely stale; re-run from a fresh kernel to confirm.
dtw

tensor([[[[1.5585, 1.6327,    inf,    inf,    inf],
          [   inf,    inf, 1.7728,    inf,    inf],
          [   inf,    inf,    inf, 2.7650,    inf],
          [   inf,    inf,    inf,    inf, 3.3739],
          [   inf,    inf,    inf,    inf, 3.5688],
          [   inf,    inf,    inf,    inf, 3.6651]],

         [[0.9022,    inf,    inf,    inf,    inf],
          [   inf, 1.7319,    inf,    inf,    inf],
          [   inf,    inf, 1.7665,    inf,    inf],
          [   inf,    inf,    inf, 2.5299,    inf],
          [   inf,    inf,    inf,    inf, 2.6717],
          [   inf,    inf,    inf,    inf, 2.9638]]]])

In [40]:
# Shows the current `dtw`. Recorded at execution count 40, after the definition
# cell (38) but before the calls shown at 41 and 43.
# NOTE(review): confirm which call produced this value.
dtw

tensor([[[[1.5585, 1.6327, 2.0872, 2.7839, 2.8937],
          [1.9251, 2.3111, 1.7728, 3.5062, 2.8246],
          [2.2920, 2.1137, 2.1079, 2.7650, 2.9191],
          [2.6982, 2.8980, 2.7001, 3.6185, 3.3739],
          [3.3596, 3.3751, 3.1210, 4.1523, 3.5688],
          [3.7131, 3.8596, 3.3255, 4.4251, 3.6651]],

         [[0.9022, 1.6903, 1.7581, 3.2890, 3.2952],
          [1.9109, 1.7319, 2.4691, 2.9034, 3.4215],
          [1.9715, 2.3207, 1.7665, 3.3793, 2.9791],
          [2.7682, 2.0845, 2.3754, 2.5299, 2.6941],
          [3.1356, 2.1574, 2.4405, 3.0295, 2.6717],
          [3.5953, 2.2207, 2.7035, 2.9289, 2.9638]]]])