In [1]:
import torch
from typing import List
import torch.multiprocessing as mp

from torch.multiprocessing import Pool, Process

In [2]:
from dtw import dtw_fast_no_n as dtw_original

torch.random.manual_seed(45)
x = torch.randn(2, 10, 5)
y = torch.randn(2, 10, 4)

a, b = dtw_original(x, y, w=0.01, eps=1e-6)
print(b.min(), b.max())

tensor(-8.3206) tensor(7.9872)


In [3]:
from dtw import dtw_fast_full

In [4]:
a, b = dtw_fast_full(x, y, w=0.01, eps=1e-6)
print(b.min(), b.max())

tensor(-2.1949) tensor(2.3286)


In [30]:

@torch.jit.script
def dtw_compute_all_script(dtw: torch.Tensor, dist_grad: torch.Tensor, grad: torch.Tensor, w: float) -> None:
    '''
        dtw of shape (n, k, pattern_len, window_size)
        dist_grad of shape (n, k, dims, pattern_len, window_size)
        grad of shape (n, k, dims, pattern_len)
        grads of shape (n, k, dims, pattern_len, pattern_len, window_size)
    '''
    n, k, len_pattern, len_window = dtw.shape
    # very big tensor
    # grads shape (pattern_len(2), window_size, n, k, dims, pattern_len)
    grads = torch.zeros(len_pattern, len_window, n, k, grad.shape[2], len_pattern, device=grad.device)
    temp = torch.cumsum(dist_grad, dim=4).permute(3, 4, 0, 1, 2) # (pattern_len, window_size, n, k, dim)

    for i in range(len_pattern):
        grads[i, :, :, :, :, i] = temp[i, :, :, :, :]
        grads[i:, 0, :, :, :, i] = temp[i, :1, :, :, :]

    for i in range(1, len_pattern): # pl
        for j in range(1, len_window): # ws
            value = torch.minimum(w * torch.minimum(dtw[:, :, i, j-1], dtw[:, :, i-1, j-1]), dtw[:, :, i-1, j])
            temp_1 = dtw[:, :, i, j-1] < dtw[:, :, i-1, j-1] # path (i, j-1) or (i-1, j)
            temp_2 = w * dtw[:, :, i, j-1] < dtw[:, :, i-1, j] # path (i, j-1) or (i-1, j-1)
            temp_3 = w * dtw[:, :, i-1, j-1] < dtw[:, :, i-1, j] # path (i-1, j-1) or (i-1, j)

            dtw[:, :, i, j] += value

            print(temp_1.shape, grads[i, j].shape)
            grads[i, j][temp_1 & temp_2] += w * grads[i, j-1][temp_1 & temp_2]
            grads[i, j][temp_1 & temp_3] += grads[i-1, j][temp_1 & temp_3]
            grads[i, j][temp_2 & temp_3] += w * grads[i-1, j-1][temp_2 & temp_3]

    grad += grads.sum(dim=(0, 1))

def dtw_compute_all(dtw: torch.Tensor, dist_grad: torch.Tensor, grad: torch.Tensor, w: float) -> None:
    dtw_compute_all_script(dtw, dist_grad, grad, w)

@torch.jit.script
def dtw_compute_no_grad(dtw: torch.Tensor, w: float) -> None:
    '''
        dtw of shape (n, k, pattern_len, window_size)
        grad of shape (n, k, dims, pattern_len)
    '''

    n, k, len_pattern, len_window = dtw.shape

    for i in range(1, len_pattern): # pl
        for j in range(1, len_window): # ws
            value = torch.minimum(w * torch.minimum(dtw[:, :, i, j-1], dtw[:, :, i-1, j-1]), dtw[:, :, i-1, j])

            dtw[:, :, i, j] += value

@torch.jit.script
def dtw_fast_no_n(x: torch.Tensor, y: torch.Tensor, w: float, eps: float = 1e-5, compute_gradients: bool=True, batched: int = 8):
    # shape of x (n, dim, x_len) y (m, dim, y_len)

    # performs convolution-like operation, for each kernel the DF
    # (of shape (kernel_size, T)) is computed, then summed across channels
    # x has shape (batch, c, time_dimension)

    # compute pairwise diffs (squared)
    p_diff = x[:,None,:,None,:] - y[None,:,:,:,None] # shape (n, n_kernel, d, Kernel_size, T)
    euc_d = torch.square(p_diff).sum(2) # shape (n, n_kernel, kernel_size, T)

    # compute dtw
    euc_d[:,:,0,:] = torch.cumsum(euc_d[:,:,0,:], dim=2)
    euc_d[:,:,:,0] = torch.cumsum(euc_d[:,:,:,0], dim=2)

    if compute_gradients:
        # p_diff now contains the partial derivatives of DTW[n, k, i, j] wrt K[k, d, i] (dims (n, k, d, i, j))
        p_diff = p_diff / torch.sqrt(euc_d[:,:, None, :, :] + eps)
        
        grads = torch.zeros((x.shape[0], y.shape[0], y.shape[1], y.shape[2]), device=y.device) # dims (n, k, d, i)
        # grads_buffer = torch.empty((num_workers, y.shape[0], y.shape[1], y.shape[2], y.shape[2], x.shape[2]), device=y.device)

        for i in range(0, x.shape[0], batched):
            initial = i
            last = min(initial + batched, x.shape[0])
            dtw_compute_all_script(euc_d[initial:last], p_diff[initial:last], grads[initial:last], w)

            # processes = [Process(target=dtw_compute_all, args=(euc_d[]))]
            # j = min(i+batched, x.shape[0])

            # dtw_compute_all(euc_d[i:j], p_diff[i:j], grads[i:j], w)
        
        return euc_d.sqrt(), grads
    else:
        dtw_compute_no_grad(euc_d, w)

        return euc_d.sqrt(), None

In [31]:
a, b = dtw_fast_no_n(x, y, w=0.01, eps=1e-6, batched=32)
print(b.min(), b.max())

[2, 2] [2, 2, 10, 4]
[2, 2] [2, 2, 10, 4]
[2, 2] [2, 2, 10, 4]
[2, 2] [2, 2, 10, 4]
[2, 2] [2, 2, 10, 4]
[2, 2] [2, 2, 10, 4]
[2, 2] [2, 2, 10, 4]
[2, 2] [2, 2, 10, 4]
[2, 2] [2, 2, 10, 4]
[2, 2] [2, 2, 10, 4]
[2, 2] [2, 2, 10, 4]
[2, 2] [2, 2, 10, 4]
tensor(-17.1377) tensor(12.9263)
