In [2]:
import torch
from typing import List

In [3]:
import torch.multiprocessing as mp

In [4]:
@torch.jit.script
def dtw_compute2(dtw: torch.Tensor, dist_grad: torch.Tensor, grad: torch.Tensor, w: float) -> None:
    '''
        dtw of shape (n, pattern_len, window_size)
        dist_grad of shape (n, dims, pattern_len, window_size)
        grad of shape (dims, pattern_len)
    '''
    grads = torch.zeros(dtw.shape[0], dist_grad.shape[1], dtw.shape[1], dtw.shape[1], dtw.shape[2]) # shape (n, dims, pattern_len, pattern_len, window_size)

    for i in range(dtw.shape[1]):
        grads[:, :, i, i, :] = torch.cumsum(dist_grad[:, :, i, :], dim=1)
        grads[:, :, i, i:, 0] = grads[:, :, i, i, :1]

    for i in range(1, dtw.shape[1]): # pl
        for j in range(1, dtw.shape[2]): # ws
            value = torch.minimum(w * torch.minimum(dtw[:, i, j-1], dtw[:, i-1, j-1]), dtw[:, i-1, j])
            temp_1 = dtw[:, i, j-1] < dtw[:, i-1, j-1] # path (i, j-1) or (i-1, j)
            temp_2 = w * dtw[:, i, j-1] < dtw[:, i-1, j] # path (i, j-1) or (i-1, j-1)
            temp_3 = w * dtw[:, i-1, j-1] < dtw[:, i-1, j] # path (i-1, j-1) or (i-1, j)

            dtw[:, i, j] += value

            grads[temp_1 & temp_2][:, :, i, j] += w * grads[temp_1 & temp_2][:, :, i, j-1]
            grads[temp_1 & temp_3][:, :, i, j] += grads[temp_1 & temp_3][:, :, i-1, j]
            grads[temp_2 & temp_3][:, :, i, j] += w * grads[temp_2 & temp_3][:, :, i-1, j-1]

    grad[:,:] = grads.sum(dim=(0, -2, -1))

@torch.jit.script
def dtw_compute_one_by_one(dtw: torch.Tensor, dist_grad: torch.Tensor, grad: torch.Tensor, w: float) -> None:
    '''
        dtw of shape (pattern_len, window_size)
        dist_grad of shape (dims, pattern_len, window_size)
        grad of shape (dims, pattern_len)
    '''
    grads = torch.zeros(dist_grad.shape[0], dtw.shape[0], dtw.shape[0], dtw.shape[1]) # shape (dims, pattern_len, pattern_len, window_size)

    for i in range(dtw.shape[0]):
        grads[:, i, i, :] = torch.cumsum(dist_grad[:, i, :], dim=1)
        grads[:, i, i:, 0] = grads[:, i, i, :1]

    for i in range(1, dtw.shape[0]): # pl
        for j in range(1, dtw.shape[1]): # ws
            value = torch.minimum(w * torch.minimum(dtw[i, j-1], dtw[i-1, j-1]), dtw[i-1, j])
            temp_1 = dtw[i, j-1] < dtw[i-1, j-1] # path (i, j-1) or (i-1, j)
            temp_2 = w * dtw[i, j-1] < dtw[i-1, j] # path (i, j-1) or (i-1, j-1)
            temp_3 = w * dtw[i-1, j-1] < dtw[i-1, j] # path (i-1, j-1) or (i-1, j)

            dtw[i, j] += value

            if temp_1 & temp_2:
                grads[:, :, i, j] += w * grads[:, :, i, j-1]
            elif temp_1 & temp_3:
                grads[:, :, i, j] += grads[:, :, i-1, j]
            else:
                grads[:, :, i, j] += w * grads[:, :, i-1, j-1]

    grad[:,:] += grads.sum(dim=(-2, -1))

@torch.jit.script
def dtw_compute_everything(dtw: torch.Tensor, dist_grad: torch.Tensor, grad: torch.Tensor, w: float) -> None:
    '''
        dtw of shape (n, k, pattern_len, window_size)
        dist_grad of shape (n, k, dims, pattern_len, window_size)
        grad of shape (n, k, dims, pattern_len)
    '''
    n, k, len_pattern, len_window = dtw.shape
    # very big tensor
    grads = torch.zeros(n, k, grad.shape[2], len_pattern, len_pattern, len_window) # shape (n, k, dims, pattern_len, pattern_len, window_size)

    for i in range(len_pattern):
        grads[:, :, :, i, i, :] = torch.cumsum(dist_grad[:, :, :, i, :], dim=-1)
        grads[:, :, :, i, i:, 0] = grads[:, :, :, i, i, :1]

    for i in range(1, len_pattern): # pl
        for j in range(1, len_window): # ws
            value = torch.minimum(w * torch.minimum(dtw[:, :, i, j-1], dtw[:, :, i-1, j-1]), dtw[:, :, i-1, j])
            temp_1 = dtw[:, :, i, j-1] < dtw[:, :, i-1, j-1] # path (i, j-1) or (i-1, j)
            temp_2 = w * dtw[:, :, i, j-1] < dtw[:, :, i-1, j] # path (i, j-1) or (i-1, j-1)
            temp_3 = w * dtw[:, :, i-1, j-1] < dtw[:, :, i-1, j] # path (i-1, j-1) or (i-1, j)

            dtw[:, :, i, j] += value

            grads[temp_1 & temp_2][:, :, :i, i, j] += w * grads[temp_1 & temp_2][:, :, :i, i, j-1]
            grads[temp_1 & temp_3][:, :, :i, i, j] += grads[temp_1 & temp_3][:, :, :i, i-1, j]
            grads[temp_2 & temp_3][:, :, :i, i, j] += w * grads[temp_2 & temp_3][:, :, :i, i-1, j-1]

    grad[:,:,:,:] = grads.sum(dim=(-2, -1))

@torch.jit.script
def torch_dtw_fast_compute_everything(x: torch.Tensor, y: torch.Tensor, w: float, eps: float = 1e-5):
    # shape of x (n, dim, x_len) y (m, dim, y_len)

    # performs convolution-like operation, for each kernel the DF
    # (of shape (kernel_size, T)) is computed, then summed across channels
    # x has shape (batch, c, time_dimension)

    # compute pairwise diffs (squared)
    p_diff = x[:,None,:,None,:] - y[None,:,:,:,None] # shape (n, n_kernel, d, Kernel_size, T)
    euc_d = torch.square(p_diff).sum(2) # shape (n, n_kernel, kernel_size, T)

    # p_diff now contains the partial derivatives of DTW[n, k, i, j] wrt K[k, d, i] (dims (n, k, d, i, j))
    p_diff = p_diff / torch.sqrt(euc_d[:,:, None, :, :] + eps)

    # compute dtw
    euc_d[:,:,0,:] = torch.cumsum(euc_d[:,:,0,:], dim=2)
    euc_d[:,:,:,0] = torch.cumsum(euc_d[:,:,:,0], dim=2)

    grads = torch.empty((x.shape[0], y.shape[0], y.shape[1], y.shape[2]))

    dtw_compute_everything(euc_d, p_diff, grads, w)

    return euc_d.sqrt(), grads

In [5]:
def torch_dtw_fast_multiprocess(x: torch.Tensor, y: torch.Tensor, w: float, eps: float = 1e-5, num_processes: int=8):
    # shape of x (n, dim, x_len) y (m, dim, y_len)

    # performs convolution-like operation, for each kernel the DF
    # (of shape (kernel_size, T)) is computed, then summed across channels
    # x has shape (batch, c, time_dimension)

    # compute pairwise diffs (squared)
    p_diff = x[:,None,:,None,:] - y[None,:,:,:,None] # shape (n, n_kernel, d, Kernel_size, T)
    euc_d = torch.square(p_diff).sum(2) # shape (n, n_kernel, kernel_size, T)

    # p_diff now contains the partial derivatives of DTW[n, k, i, j] wrt K[k, d, i] (dims (n, k, d, i, j))
    p_diff = p_diff / torch.sqrt(euc_d[:,:, None, :, :] + eps)

    # compute dtw
    euc_d[:,:,0,:] = torch.cumsum(euc_d[:,:,0,:], dim=2)
    euc_d[:,:,:,0] = torch.cumsum(euc_d[:,:,:,0], dim=2)

    grads = torch.zeros((num_processes, y.shape[0], y.shape[1], y.shape[2]))

    euc_d.share_memory_()
    grads.share_memory_()
    p_diff.share_memory_()

    i=0
    j=0
    while True:
        processes = []
        for rank in range(num_processes):
            print("create")
            p = mp.Process(target=dtw_compute_one_by_one, args=(euc_d[i, j], p_diff[i, j], grads[rank, j], w, ))
            p.start()
            processes.append(p)
            i+=1
            if i>x.shape[0]:
                i=0
                j+=1
            if j>y.shape[0]:
                break

        for p in processes:
            p.join()        

        if j>y.shape[0]:
            break

    return euc_d.sqrt(), grads.sum(0)/(x.shape[0])

In [6]:
torch.random.manual_seed(45)
x = torch.randn(100, 9, 50)
y = torch.randn(64, 9, 32)

In [7]:
a, b = torch_dtw_fast_compute_everything(x, y, w=0.01, eps=1e-6)
print(b.min(), b.max())

tensor(-1004.2710) tensor(1017.8133)


In [None]:
a, b = torch_dtw_fast_multiprocess(x, y, w=0.01, eps=1e-6, num_processes=8)

In [10]:
@torch.jit.script
def dtw_compute_no_n(dtw: torch.Tensor, dist_grad: torch.Tensor, grad: torch.Tensor, w: float) -> None:
    '''
        dtw of shape (k, pattern_len, window_size)
        dist_grad of shape (k, dims, pattern_len, window_size)
        grad of shape (k, dims, pattern_len)
    '''
    k, len_pattern, len_window = dtw.shape
    # very big tensor
    grads = torch.zeros(k, grad.shape[1], len_pattern, len_pattern, len_window) # shape (n, k, dims, pattern_len, pattern_len, window_size)

    for i in range(len_pattern):
        grads[:, :, i, i, :] = torch.cumsum(dist_grad[:, :, i, :], dim=2)
        grads[:, :, i, i:, 0] = grads[:, :, i, i, :1]

    for i in range(1, len_pattern): # pl
        for j in range(1, len_window): # ws
            value = torch.minimum(w * torch.minimum(dtw[:, i, j-1], dtw[:, i-1, j-1]), dtw[:, i-1, j])
            temp_1 = dtw[:, i, j-1] < dtw[:, i-1, j-1] # path (i, j-1) or (i-1, j)
            temp_2 = w * dtw[:, i, j-1] < dtw[:, i-1, j] # path (i, j-1) or (i-1, j-1)
            temp_3 = w * dtw[:, i-1, j-1] < dtw[:, i-1, j] # path (i-1, j-1) or (i-1, j)

            dtw[:, i, j] += value

            grads[temp_1 & temp_2][:, :i, i, j] += w * grads[temp_1 & temp_2][:, :i, i, j-1]
            grads[temp_1 & temp_3][:, :i, i, j] += grads[temp_1 & temp_3][:, :i, i-1, j]
            grads[temp_2 & temp_3][:, :i, i, j] += w * grads[temp_2 & temp_3][:, :i, i-1, j-1]

    grad[:,:,:] += grads.sum(dim=(-2, -1))

@torch.jit.script
def dtw_fast_no_n(x: torch.Tensor, y: torch.Tensor, w: float, eps: float = 1e-5):
    # shape of x (n, dim, x_len) y (m, dim, y_len)

    # performs convolution-like operation, for each kernel the DF
    # (of shape (kernel_size, T)) is computed, then summed across channels
    # x has shape (batch, c, time_dimension)

    # compute pairwise diffs (squared)
    p_diff = x[:,None,:,None,:] - y[None,:,:,:,None] # shape (n, n_kernel, d, Kernel_size, T)
    euc_d = torch.square(p_diff).sum(2) # shape (n, n_kernel, kernel_size, T)

    # p_diff now contains the partial derivatives of DTW[n, k, i, j] wrt K[k, d, i] (dims (n, k, d, i, j))
    p_diff = p_diff / torch.sqrt(euc_d[:,:, None, :, :] + eps)

    # compute dtw
    euc_d[:,:,0,:] = torch.cumsum(euc_d[:,:,0,:], dim=2)
    euc_d[:,:,:,0] = torch.cumsum(euc_d[:,:,:,0], dim=2)

    grads = torch.zeros((x.shape[0], y.shape[0], y.shape[1], y.shape[2]))

    futures = [torch.jit.fork(dtw_compute_no_n, euc_d[i], p_diff[i], grads[i], w) for i in range(x.shape[0])] 
    results = [torch.jit.wait(future) for future in futures]

    return euc_d.sqrt(), grads

In [12]:
a, b = dtw_fast_no_n(x, y, w=0.01, eps=1e-6)
print(b.min(), b.max())

tensor(-1004.2710) tensor(1017.8133)
