<a href="https://colab.research.google.com/github/takayama-rado/trado_samples/blob/main/colab_files/exp_track_interp_torch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 1. Load library

In [1]:
# Standard modules.
import gc
import sys
import time
from functools import partial

# CV/ML.
import numpy as np

import torch
import torch.nn.functional as F

In [2]:
print(f"Python:{sys.version}")
print(f"Numpy:{np.__version__}")
print(f"PyTorch:{torch.__version__}")

Python:3.10.12 (main, Jun 11 2023, 05:26:28) [GCC 11.4.0]
Numpy:1.23.5
PyTorch:2.1.0+cu118


# 2. Load data

In [3]:
!wget https://github.com/takayama-rado/trado_samples/raw/main/test_data/finger_far0_non_static.npy

--2023-10-30 04:24:25--  https://github.com/takayama-rado/trado_samples/raw/main/test_data/finger_far0_non_static.npy
Resolving github.com (github.com)... 140.82.112.3
Connecting to github.com (github.com)|140.82.112.3|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/takayama-rado/trado_samples/main/test_data/finger_far0_non_static.npy [following]
--2023-10-30 04:24:25--  https://raw.githubusercontent.com/takayama-rado/trado_samples/main/test_data/finger_far0_non_static.npy
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2300608 (2.2M) [application/octet-stream]
Saving to: ‘finger_far0_non_static.npy’


2023-10-30 04:24:26 (30.7 MB/s) - ‘finger_far0_non_static.npy’ saved [2300608/2300608]



In [4]:
!wget https://github.com/takayama-rado/trado_samples/raw/main/test_data/finger_far0_non_static_interp.npy

--2023-10-30 04:24:26--  https://github.com/takayama-rado/trado_samples/raw/main/test_data/finger_far0_non_static_interp.npy
Resolving github.com (github.com)... 140.82.113.4
Connecting to github.com (github.com)|140.82.113.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/takayama-rado/trado_samples/main/test_data/finger_far0_non_static_interp.npy [following]
--2023-10-30 04:24:26--  https://raw.githubusercontent.com/takayama-rado/trado_samples/main/test_data/finger_far0_non_static_interp.npy
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.110.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2300608 (2.2M) [application/octet-stream]
Saving to: ‘finger_far0_non_static_interp.npy’


2023-10-30 04:24:26 (31.0 MB/s) - ‘finger_far0_non_static_interp.

In [5]:
!ls

finger_far0_non_static_interp.npy  finger_far0_non_static.npy  sample_data


# 3. Evaluation settings

In [6]:
def get_perf_str(val):
    token_si = ["", "m", "µ", "n", "p"]
    exp_si = [1, 1e3, 1e6, 1e9, 1e12]
    perf_str = f"{val:3g}s"
    si = ""
    sval = val
    for token, exp in zip(token_si, exp_si):
        if val * exp > 1.0:
            si = token
            sval = val * exp
            break
    perf_str = f"{sval:3g}{si}s"
    return perf_str

In [7]:
def print_perf_time(intervals, top_k=None):
    if top_k is not None:
        intervals = np.sort(intervals)[:top_k]
    min = intervals.min()
    max = intervals.max()
    mean = intervals.mean()
    std = intervals.std()

    smin = get_perf_str(min)
    smax = get_perf_str(max)
    mean = get_perf_str(mean)
    std = get_perf_str(std)
    if top_k:
        print(f"Top {top_k} summary: Max {smax}, Min {smin}, Mean +/- Std {mean} +/- {std}")
    else:
        print(f"Overall summary: Max {smax}, Min {smin}, Mean +/- Std {mean} +/- {std}")

In [8]:
class PerfMeasure():
    def __init__(self,
                 trials=100,
                 top_k=10):
        self.trials = trials
        self.top_k = top_k

    def __call__(self, func):
        gc.collect()
        gc.disable()
        intervals = []
        for _ in range(self.trials):
            start = time.perf_counter()
            func()
            end = time.perf_counter()
            intervals.append(end - start)
        intervals = np.array(intervals)
        print_perf_time(intervals)
        if self.top_k:
            print_perf_time(intervals, self.top_k)
        gc.enable()
        gc.collect()

In [9]:
TRIALS = 100
TOPK = 10
pmeasure = PerfMeasure(TRIALS, TOPK)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Target device is {DEVICE}.")
JIT_OPT = False

Target device is cpu.


# 4. PyTorch implementation

## 4.1 Implementation based on define-by-run mode

In [10]:
def matrix_interp_torch(track):
    orig_shape = track.shape
    tlength = orig_shape[0]
    mask = track[:, 0, -1] != 0
    valid = mask.sum()
    if valid == tlength:
        return track

    xs = torch.where(mask != 0)[0]
    ys = track.reshape([tlength, -1])[xs, :]
    x = torch.arange(tlength, device=xs.device)

    # ========================================================================
    # Interpolation.
    # ========================================================================
    xs = xs.to(ys.dtype)
    x = x.to(ys.dtype)
    # Pad control points for extrapolation.
    # Unexpectedly, torch.finfo(torch.float64).min returns -inf.
    # So we use torch.finfo(torch.float32).min alternatively.
    xs = torch.cat([torch.tensor([torch.finfo(torch.float32).min], device=xs.device),
                    xs,
                    torch.tensor([torch.finfo(torch.float32).max], device=xs.device)], dim=0)
    ys = torch.cat([ys[:1], ys, ys[-1:]], dim=0)

    # Compute slopes, pad at the edges to flatten.
    sloops = (ys[1:] - ys[:-1]) / torch.unsqueeze((xs[1:] - xs[:-1]), dim=-1)
    sloops = F.pad(sloops[:-1], (0, 0, 1, 1))

    # Solve for intercepts.
    intercepts = ys - sloops * torch.unsqueeze(xs, dim=-1)

    # Search for the line parameters at each input data point.
    # Create a grid of the inputs and piece breakpoints for thresholding.
    # Rely on argmax stopping on the first true when there are duplicates,
    # which gives us an index into the parameter vectors.
    mask_bk_indicator = torch.unsqueeze(xs, dim=-2) > torch.unsqueeze(x, dim=-1)
    idx = torch.argmax(mask_bk_indicator.to(torch.int32), dim=-1)
    sloop = sloops[idx]
    intercept = intercepts[idx]

    # Apply the linear mapping at each input data point.
    y = sloop * torch.unsqueeze(x, dim=-1) + intercept
    y = y.to(ys.dtype)
    y = y.reshape(orig_shape)
    return y


def partsbased_interp_torch(trackdata, device="cpu"):
    trackdata = torch.from_numpy(trackdata).to(device)
    pose = trackdata[:, :33]
    lhand = trackdata[:, 33:33+21]
    rhand = trackdata[:, 33+21:33+21+21]
    face = trackdata[:, 33+21+21:]

    pose = matrix_interp_torch(pose)
    lhand = matrix_interp_torch(lhand)
    rhand = matrix_interp_torch(rhand)
    face = matrix_interp_torch(face)
    return torch.cat([pose, lhand, rhand, face], dim=1)

In [11]:
trackdata = np.load("finger_far0_non_static.npy")
reftrack = np.load("finger_far0_non_static_interp.npy")
# Remove person axis.
trackdata = trackdata[0]
reftrack = reftrack[0]

In [12]:
# Torch.
# The 1st call may be slow because of the computation graph construction.
print(f"Time of first call.")
start = time.time()
newtrack = partsbased_interp_torch(trackdata, device=DEVICE)
interval = time.time() - start
print_perf_time(np.array([interval]))

diff = (reftrack - newtrack.detach().cpu().numpy()).sum()
print(f"Sum of error:{diff}")

print("Time after second call.")
target_fn = partial(partsbased_interp_torch, trackdata=trackdata, device=DEVICE)
pmeasure(target_fn)

Time of first call.
Overall summary: Max 147.673ms, Min 147.673ms, Mean +/- Std 147.673ms +/-   0s
Sum of error:-6.935119145623503e-12
Time after second call.
Overall summary: Max 20.5691ms, Min 1.33844ms, Mean +/- Std 2.20689ms +/- 2.63837ms
Top 10 summary: Max 1.40347ms, Min 1.33844ms, Mean +/- Std 1.37999ms +/- 22.2161µs


In [13]:
# Torch.
# The 1st call may be slow because of the computation graph construction.
print(f"Time of first call.")
start = time.time()
newtrack = partsbased_interp_torch(trackdata[:-1], device=DEVICE)
interval = time.time() - start
print_perf_time(np.array([interval]))

diff = (reftrack[:-1] - newtrack.detach().cpu().numpy()).sum()
print(f"Sum of error:{diff}")

print("Time after second call.")
target_fn = partial(partsbased_interp_torch, trackdata=trackdata[:-1], device=DEVICE)
pmeasure(target_fn)

Time of first call.
Overall summary: Max 10.7508ms, Min 10.7508ms, Mean +/- Std 10.7508ms +/-   0s
Sum of error:-6.935119145623503e-12
Time after second call.
Overall summary: Max 6.57484ms, Min 1.06022ms, Mean +/- Std 1.59762ms +/- 723.317µs
Top 10 summary: Max 1.30698ms, Min 1.06022ms, Mean +/- Std 1.23992ms +/- 67.8562µs


## 4.2 Implementation based on define-and-run (JIT compile)

In [14]:
@torch.jit.script
def matrix_interp_torch_jit(track):
    orig_shape = track.shape
    tlength = orig_shape[0]
    mask = track[:, 0, -1] != 0
    valid = mask.sum()
    if valid == tlength:
        return track

    xs = torch.where(mask != 0)[0]
    ys = track.reshape([tlength, -1])[xs, :]
    x = torch.arange(tlength, device=xs.device)

    # ========================================================================
    # Interpolation.
    # ========================================================================
    xs = xs.to(ys.dtype)
    x = x.to(ys.dtype)
    # Pad control points for extrapolation.
    # Unexpectedly, torch.finfo(torch.float64).min returns -inf.
    # So we use torch.finfo(torch.float32).min alternatively.
    # xs = torch.cat([torch.tensor([torch.finfo(torch.float32).min]), xs, torch.tensor([torch.finfo(torch.float32).max])], dim=0)
    # torch.finfo is not been supported in JIT.
    xs = torch.cat([torch.tensor([-1000], device=xs.device),
                    xs,
                    torch.tensor([1000], device=xs.device)], dim=0)
    ys = torch.cat([ys[:1], ys, ys[-1:]], dim=0)

    # Compute slopes, pad at the edges to flatten.
    sloops = (ys[1:] - ys[:-1]) / torch.unsqueeze((xs[1:] - xs[:-1]), dim=-1)
    sloops = F.pad(sloops[:-1], (0, 0, 1, 1))

    # Solve for intercepts.
    intercepts = ys - sloops * torch.unsqueeze(xs, dim=-1)

    # Search for the line parameters at each input data point.
    # Create a grid of the inputs and piece breakpoints for thresholding.
    # Rely on argmax stopping on the first true when there are duplicates,
    # which gives us an index into the parameter vectors.
    mask_bk_indicator = torch.unsqueeze(xs, dim=-2) > torch.unsqueeze(x, dim=-1)
    idx = torch.argmax(mask_bk_indicator.to(torch.int32), dim=-1)
    sloop = sloops[idx]
    intercept = intercepts[idx]

    # Apply the linear mapping at each input data point.
    y = sloop * torch.unsqueeze(x, dim=-1) + intercept
    y = y.to(ys.dtype)
    y = y.reshape(orig_shape)
    return y


def partsbased_interp_torch_jit(trackdata, device="cpu"):
    trackdata = torch.from_numpy(trackdata).to(device)
    pose = trackdata[:, :33]
    lhand = trackdata[:, 33:33+21]
    rhand = trackdata[:, 33+21:33+21+21]
    face = trackdata[:, 33+21+21:]

    pose = matrix_interp_torch_jit(pose)
    lhand = matrix_interp_torch_jit(lhand)
    rhand = matrix_interp_torch_jit(rhand)
    face = matrix_interp_torch_jit(face)
    return torch.cat([pose, lhand, rhand, face], dim=1)

In [15]:
# Torch.
with torch.jit.optimized_execution(JIT_OPT):
    # The 1st call may be slow because of the computation graph construction.
    print(f"Time of first call.")
    start = time.time()
    newtrack = partsbased_interp_torch_jit(trackdata, device=DEVICE)
    interval = time.time() - start
    print_perf_time(np.array([interval]))

    diff = (reftrack - newtrack.detach().cpu().numpy()).sum()
    print(f"Sum of error:{diff}")

    print("Time after second call.")
    target_fn = partial(partsbased_interp_torch_jit, trackdata=trackdata, device=DEVICE)
    pmeasure(target_fn)

Time of first call.
Overall summary: Max 6.38986ms, Min 6.38986ms, Mean +/- Std 6.38986ms +/-   0s
Sum of error:-6.935119145623503e-12
Time after second call.
Overall summary: Max 23.9719ms, Min 1.15678ms, Mean +/- Std 3.9527ms +/- 4.79447ms
Top 10 summary: Max 1.25467ms, Min 1.15678ms, Mean +/- Std 1.21958ms +/- 30.8053µs


In [16]:
# Torch.
with torch.jit.optimized_execution(JIT_OPT):
    # The 1st call may be slow because of the computation graph construction.
    print(f"Time of first call.")
    start = time.time()
    newtrack = partsbased_interp_torch_jit(trackdata[:-1], device=DEVICE)
    interval = time.time() - start
    print_perf_time(np.array([interval]))

    diff = (reftrack[:-1] - newtrack.detach().cpu().numpy()).sum()
    print(f"Sum of error:{diff}")

    print("Time after second call.")
    target_fn = partial(partsbased_interp_torch_jit, trackdata=trackdata[:-1], device=DEVICE)
    pmeasure(target_fn)

Time of first call.
Overall summary: Max 2.16985ms, Min 2.16985ms, Mean +/- Std 2.16985ms +/-   0s
Sum of error:-6.935119145623503e-12
Time after second call.
Overall summary: Max 22.0589ms, Min 1.09035ms, Mean +/- Std 2.68458ms +/- 3.44788ms
Top 10 summary: Max 1.16749ms, Min 1.09035ms, Mean +/- Std 1.14482ms +/- 22.0932µs
