<a href="https://colab.research.google.com/github/takayama-rado/trado_samples/blob/main/colab_files/exp_track_interp_torch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 1. Load libraries

In [1]:
# Standard modules.
import time

# CV/ML.
import numpy as np

import torch
import torch.nn.functional as F

# 2. Load data

In [2]:
# Download the raw tracking sequence that is used as the interpolation input below.
!wget https://github.com/takayama-rado/trado_samples/raw/main/test_data/finger_far0_non_static.npy

--2023-09-29 10:10:50--  https://github.com/takayama-rado/trado_samples/raw/main/test_data/finger_far0_non_static.npy
Resolving github.com (github.com)... 140.82.112.3
Connecting to github.com (github.com)|140.82.112.3|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/takayama-rado/trado_samples/main/test_data/finger_far0_non_static.npy [following]
--2023-09-29 10:10:51--  https://raw.githubusercontent.com/takayama-rado/trado_samples/main/test_data/finger_far0_non_static.npy
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2300608 (2.2M) [application/octet-stream]
Saving to: ‘finger_far0_non_static.npy’


2023-09-29 10:10:51 (42.7 MB/s) - ‘finger_far0_non_static.npy’ saved [2300608/2300608]



In [3]:
# Download the reference interpolation result that the outputs below are compared against.
!wget https://github.com/takayama-rado/trado_samples/raw/main/test_data/finger_far0_non_static_interp.npy

--2023-09-29 10:10:51--  https://github.com/takayama-rado/trado_samples/raw/main/test_data/finger_far0_non_static_interp.npy
Resolving github.com (github.com)... 140.82.113.4
Connecting to github.com (github.com)|140.82.113.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/takayama-rado/trado_samples/main/test_data/finger_far0_non_static_interp.npy [following]
--2023-09-29 10:10:52--  https://raw.githubusercontent.com/takayama-rado/trado_samples/main/test_data/finger_far0_non_static_interp.npy
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.111.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2300608 (2.2M) [application/octet-stream]
Saving to: ‘finger_far0_non_static_interp.npy’


2023-09-29 10:10:52 (32.5 MB/s) - ‘finger_far0_non_static_interp.npy’ saved [2300608/2300608]

In [4]:
# List the working directory to confirm that both .npy files are present.
!ls

finger_far0_non_static_interp.npy  finger_far0_non_static.npy  sample_data


# 3. PyTorch implementation

## 3.1 Implementation based on define-by-run mode

In [5]:
def matrix_interp_torch(track):
    """Fill invalid frames of a tracking sequence by linear interpolation.

    A frame ``t`` is treated as valid when ``track[t, 0, -1]`` is non-zero
    (presumably a detection/confidence flag of the first landmark -- confirm
    against the data format). Every feature of every frame is then evaluated
    on the piecewise-linear curve through the valid frames. Frames before the
    first or after the last valid frame take the values of that nearest valid
    frame (flat extrapolation).

    Args:
        track: 3-D tensor indexed as ``[time, landmark, feature]``.

    Returns:
        A tensor with the same shape and dtype as ``track``. The input itself
        is returned unchanged when every frame is valid, and also when no
        frame is valid, because there is nothing to interpolate from.
    """
    orig_shape = track.shape
    tlength = orig_shape[0]
    mask = track[:, 0, -1] != 0
    num_valid = int(mask.sum())
    # All frames valid: nothing to fill.
    # No frame valid: no control points exist; the computation below would
    # fail with a shape mismatch, so hand the input back untouched.
    if num_valid == tlength or num_valid == 0:
        return track

    # Control points: indices (xs) and flattened features (ys) of valid frames.
    xs = torch.where(mask)[0]
    ys = track.reshape([tlength, -1])[xs, :]
    # Query points: every frame index. Created with the dtype/device of the
    # data so that non-CPU inputs work as well.
    x = torch.arange(tlength, dtype=ys.dtype, device=ys.device)

    # ========================================================================
    # Interpolation.
    # ========================================================================
    xs = xs.to(ys.dtype)
    # Pad control points for extrapolation.
    # Sentinels only need to lie strictly outside the query range
    # [0, tlength - 1]; -1 and tlength do so for any sequence length and any
    # floating dtype, and never produce inf/nan in the products below.
    lower = torch.tensor([-1.0], dtype=ys.dtype, device=ys.device)
    upper = torch.tensor([float(tlength)], dtype=ys.dtype, device=ys.device)
    xs = torch.cat([lower, xs, upper], dim=0)
    # Repeating the edge values makes both extrapolation segments flat.
    ys = torch.cat([ys[:1], ys, ys[-1:]], dim=0)

    # Compute slopes, pad at the edges to flatten.
    # After padding, sloops[i] is the slope of the segment that ENDS at xs[i].
    sloops = (ys[1:] - ys[:-1]) / torch.unsqueeze((xs[1:] - xs[:-1]), dim=-1)
    sloops = F.pad(sloops[:-1], (0, 0, 1, 1))

    # Solve for intercepts.
    intercepts = ys - sloops * torch.unsqueeze(xs, dim=-1)

    # Search for the line parameters at each input data point.
    # Create a grid of the inputs and piece breakpoints for thresholding.
    # Rely on argmax stopping on the first true when there are duplicates,
    # which gives us an index into the parameter vectors.
    mask_bk_indicator = torch.unsqueeze(xs, dim=-2) > torch.unsqueeze(x, dim=-1)
    idx = torch.argmax(mask_bk_indicator.to(torch.int32), dim=-1)
    sloop = sloops[idx]
    intercept = intercepts[idx]

    # Apply the linear mapping at each input data point.
    y = sloop * torch.unsqueeze(x, dim=-1) + intercept
    y = y.to(ys.dtype)
    y = y.reshape(orig_shape)
    return y


def partsbased_interp_torch(trackdata):
    """Interpolate each body-part group of a tracking array independently.

    The landmark axis (axis 1) is cut into four consecutive groups -- pose
    (first 33 landmarks), left hand (next 21), right hand (next 21) and face
    (the remainder) -- and ``matrix_interp_torch`` is applied to each group
    on its own, so every part uses its own frame-validity mask.

    Args:
        trackdata: NumPy array whose axis 0 is time and axis 1 is landmarks.

    Returns:
        Tensor with the interpolated groups concatenated back along axis 1
        in their original order.
    """
    landmarks = torch.from_numpy(trackdata)

    # Cumulative group boundaries along the landmark axis.
    pose_end = 33
    lhand_end = pose_end + 21
    rhand_end = lhand_end + 21
    part_ranges = (
        slice(None, pose_end),       # pose
        slice(pose_end, lhand_end),  # left hand
        slice(lhand_end, rhand_end), # right hand
        slice(rhand_end, None),      # face
    )

    parts = [landmarks[:, part] for part in part_ranges]
    interpolated = [matrix_interp_torch(part) for part in parts]
    return torch.cat(interpolated, dim=1)

In [6]:
# Load the raw track and the reference interpolation result.
# Both files carry a leading person axis; keep only its first entry.
trackdata = np.load("finger_far0_non_static.npy")[0]
reftrack = np.load("finger_far0_non_static_interp.npy")[0]

In [7]:
# Torch (eager implementation), full-length input.
# Time the first call separately: it may be slower than later calls because
# of one-time start-up cost.
start = time.time()
newtrack = partsbased_interp_torch(trackdata)
interval = time.time() - start
print(f"Time of first call:{interval}")

# Average wall-clock time over 10 further calls on the same input.
start = time.time()
for _ in range(10):
  newtrack = partsbased_interp_torch(trackdata)
interval = time.time() - start
print(f"Average time after second call:{interval / 10}")

# Signed sum of element-wise differences against the reference result.
# NOTE: errors of opposite sign cancel in a signed sum, so a value near zero
# is only a rough sanity check, not a bound on the per-element error.
diff = (reftrack - newtrack.detach().cpu().numpy()).sum()
print(f"Sum of error:{diff}")

Time of first call:0.15895938873291016
Average time after second call:0.0026750802993774415
Sum of error:-6.935119145623503e-12


In [8]:
# Torch (eager implementation), input with the last frame dropped.
# Repeats the measurement with a different sequence length to see whether
# a changed input shape affects the call time.
start = time.time()
newtrack = partsbased_interp_torch(trackdata[:-1])
interval = time.time() - start
print(f"Time of first call:{interval}")

# Average wall-clock time over 10 further calls on the same input.
start = time.time()
for _ in range(10):
  newtrack = partsbased_interp_torch(trackdata[:-1])
interval = time.time() - start
print(f"Average time after second call:{interval / 10}")

# Signed sum of element-wise differences against the equally truncated
# reference. NOTE: errors of opposite sign cancel in a signed sum, so this
# is only a rough sanity check.
diff = (reftrack[:-1] - newtrack.detach().cpu().numpy()).sum()
print(f"Sum of error:{diff}")

Time of first call:0.0022478103637695312
Average time after second call:0.0020278453826904296
Sum of error:-6.935119145623503e-12


## 3.2 Implementation based on define-and-run (JIT compile)

In [9]:
@torch.jit.script
def matrix_interp_torch_jit(track: torch.Tensor) -> torch.Tensor:
    """Fill invalid frames of a tracking sequence by linear interpolation.

    TorchScript-compiled variant. A frame ``t`` is treated as valid when
    ``track[t, 0, -1]`` is non-zero (presumably a detection/confidence flag
    of the first landmark -- confirm against the data format). Every feature
    of every frame is evaluated on the piecewise-linear curve through the
    valid frames. Frames before the first or after the last valid frame take
    the values of that nearest valid frame (flat extrapolation).

    Args:
        track: 3-D tensor indexed as ``[time, landmark, feature]``.

    Returns:
        A tensor with the same shape and dtype as ``track``. The input itself
        is returned unchanged when every frame is valid, and also when no
        frame is valid, because there is nothing to interpolate from.
    """
    orig_shape = track.shape
    tlength = orig_shape[0]
    mask = track[:, 0, -1] != 0
    num_valid = int(mask.sum())
    # All frames valid: nothing to fill.
    # No frame valid: no control points exist; the computation below would
    # fail with a shape mismatch, so hand the input back untouched.
    if num_valid == tlength or num_valid == 0:
        return track

    # Control points: indices (xs) and flattened features (ys) of valid frames.
    xs = torch.where(mask != 0)[0]
    ys = track.reshape([tlength, -1])[xs, :]
    # Query points: every frame index, created with the dtype/device of the
    # data so that non-CPU inputs work as well.
    x = torch.arange(tlength, dtype=ys.dtype, device=ys.device)

    # ========================================================================
    # Interpolation.
    # ========================================================================
    xs = xs.to(ys.dtype)
    # Pad control points for extrapolation.
    # torch.finfo is not supported in TorchScript, so the sentinels are
    # derived from the sequence length instead. They only need to lie
    # strictly outside the query range [0, tlength - 1]; fixed values such as
    # +-1000 would silently corrupt sequences with 1000 or more frames.
    lower = torch.tensor([-1.0], dtype=ys.dtype, device=ys.device)
    upper = torch.tensor([float(tlength)], dtype=ys.dtype, device=ys.device)
    xs = torch.cat([lower, xs, upper], dim=0)
    # Repeating the edge values makes both extrapolation segments flat.
    ys = torch.cat([ys[:1], ys, ys[-1:]], dim=0)

    # Compute slopes, pad at the edges to flatten.
    # After padding, sloops[i] is the slope of the segment that ENDS at xs[i].
    sloops = (ys[1:] - ys[:-1]) / torch.unsqueeze((xs[1:] - xs[:-1]), dim=-1)
    sloops = F.pad(sloops[:-1], [0, 0, 1, 1])

    # Solve for intercepts.
    intercepts = ys - sloops * torch.unsqueeze(xs, dim=-1)

    # Search for the line parameters at each input data point.
    # Create a grid of the inputs and piece breakpoints for thresholding.
    # Rely on argmax stopping on the first true when there are duplicates,
    # which gives us an index into the parameter vectors.
    mask_bk_indicator = torch.unsqueeze(xs, dim=-2) > torch.unsqueeze(x, dim=-1)
    idx = torch.argmax(mask_bk_indicator.to(torch.int32), dim=-1)
    sloop = sloops[idx]
    intercept = intercepts[idx]

    # Apply the linear mapping at each input data point.
    y = sloop * torch.unsqueeze(x, dim=-1) + intercept
    y = y.to(ys.dtype)
    y = y.reshape(orig_shape)
    return y


def partsbased_interp_torch_jit(trackdata):
    """Interpolate each body-part group with the TorchScript interpolator.

    Axis 1 of ``trackdata`` is divided into consecutive landmark groups of
    33 (pose), 21 (left hand), 21 (right hand) and the remainder (face).
    ``matrix_interp_torch_jit`` is run on each group separately and the
    results are joined again along axis 1 in the same order.

    Args:
        trackdata: NumPy array whose axis 0 is time and axis 1 is landmarks.

    Returns:
        Tensor holding the interpolated groups concatenated along axis 1.
    """
    tensor = torch.from_numpy(trackdata)

    # Start offsets of lhand, rhand and face on the landmark axis.
    lhand_start = 33
    rhand_start = lhand_start + 21
    face_start = rhand_start + 21

    groups = (
        tensor[:, :lhand_start],             # pose
        tensor[:, lhand_start:rhand_start],  # left hand
        tensor[:, rhand_start:face_start],   # right hand
        tensor[:, face_start:],              # face
    )

    filled = []
    for group in groups:
        filled.append(matrix_interp_torch_jit(group))
    return torch.cat(filled, dim=1)

In [10]:
# Torch (TorchScript implementation), full-length input.
# Time the first call separately: it may be slower than later calls because
# of one-time work on first execution (presumably JIT warm-up).
start = time.time()
newtrack = partsbased_interp_torch_jit(trackdata)
interval = time.time() - start
print(f"Time of first call:{interval}")

# Average wall-clock time over 10 further calls on the same input.
start = time.time()
for _ in range(10):
  newtrack = partsbased_interp_torch_jit(trackdata)
interval = time.time() - start
print(f"Average time after second call:{interval / 10}")

# Signed sum of element-wise differences against the reference result.
# NOTE: errors of opposite sign cancel in a signed sum, so a value near zero
# is only a rough sanity check, not a bound on the per-element error.
diff = (reftrack - newtrack.detach().cpu().numpy()).sum()
print(f"Sum of error:{diff}")

Time of first call:0.14833760261535645
Average time after second call:0.0033842802047729494
Sum of error:-6.935119145623503e-12


In [11]:
# Torch (TorchScript implementation), input with the last frame dropped.
# Repeats the measurement with a different sequence length to see whether
# a changed input shape affects the call time.
start = time.time()
newtrack = partsbased_interp_torch_jit(trackdata[:-1])
interval = time.time() - start
print(f"Time of first call:{interval}")

# Average wall-clock time over 10 further calls on the same input.
start = time.time()
for _ in range(10):
  newtrack = partsbased_interp_torch_jit(trackdata[:-1])
interval = time.time() - start
print(f"Average time after second call:{interval / 10}")

# Signed sum of element-wise differences against the equally truncated
# reference. NOTE: errors of opposite sign cancel in a signed sum, so this
# is only a rough sanity check.
diff = (reftrack[:-1] - newtrack.detach().cpu().numpy()).sum()
print(f"Sum of error:{diff}")

Time of first call:0.0027930736541748047
Average time after second call:0.0016836881637573241
Sum of error:-6.935119145623503e-12
