<a href="https://colab.research.google.com/github/takayama-rado/trado_samples/blob/main/colab_files/exp_track_interp_jax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 1. Load libraries

In [None]:
# Standard modules.
import gc
import sys
import time
from functools import partial

# CV/ML.
import numpy as np

import jax
import jax.numpy as jnp
from jax import jit

# Enable float64.
# JAX computes in 32-bit precision unless this flag is set, and it should be
# set at startup before any JAX arrays are created. 64-bit precision lets the
# outputs be compared against the reference tracks at ~1e-12 precision.
jax.config.update("jax_enable_x64", True)

In [None]:
# Record the interpreter and library versions these benchmarks were run with.
print(
    f"Python:{sys.version}",
    f"Numpy:{np.__version__}",
    f"JAX:{jax.__version__}",
    sep="\n",
)

Python:3.10.12 (main, Jun 11 2023, 05:26:28) [GCC 11.4.0]
Numpy:1.23.5
JAX:0.4.16


# 2. Load data

In [None]:
!wget https://github.com/takayama-rado/trado_samples/raw/main/test_data/finger_far0_non_static.npy

--2023-10-29 11:46:09--  https://github.com/takayama-rado/trado_samples/raw/main/test_data/finger_far0_non_static.npy
Resolving github.com (github.com)... 140.82.112.3
Connecting to github.com (github.com)|140.82.112.3|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/takayama-rado/trado_samples/main/test_data/finger_far0_non_static.npy [following]
--2023-10-29 11:46:09--  https://raw.githubusercontent.com/takayama-rado/trado_samples/main/test_data/finger_far0_non_static.npy
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2300608 (2.2M) [application/octet-stream]
Saving to: ‘finger_far0_non_static.npy’


2023-10-29 11:46:09 (31.0 MB/s) - ‘finger_far0_non_static.npy’ saved [2300608/2300608]



In [None]:
!wget https://github.com/takayama-rado/trado_samples/raw/main/test_data/finger_far0_non_static_interp.npy

--2023-10-29 11:46:09--  https://github.com/takayama-rado/trado_samples/raw/main/test_data/finger_far0_non_static_interp.npy
Resolving github.com (github.com)... 140.82.112.3
Connecting to github.com (github.com)|140.82.112.3|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/takayama-rado/trado_samples/main/test_data/finger_far0_non_static_interp.npy [following]
--2023-10-29 11:46:09--  https://raw.githubusercontent.com/takayama-rado/trado_samples/main/test_data/finger_far0_non_static_interp.npy
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2300608 (2.2M) [application/octet-stream]
Saving to: ‘finger_far0_non_static_interp.npy’


2023-10-29 11:46:09 (31.5 MB/s) - ‘finger_far0_non_static_interp.

In [None]:
# Confirm both downloaded .npy files are present in the working directory.
!ls

finger_far0_non_static_interp.npy  finger_far0_non_static.npy  sample_data


# 3. Evaluation settings

In [None]:
def get_perf_str(val):
    """Format a duration in seconds with an SI prefix (s, ms, µs, ns, ps).

    The largest unit whose scaled value is at least 1 is chosen, e.g.
    ``0.25 -> "250ms"`` and ``1.0 -> "  1s"``. The number uses the ``3g``
    format spec (general format, minimum field width 3), so short numbers are
    left-padded with spaces. Values for which no prefix reaches 1 (e.g. 0 or
    sub-picosecond durations) are printed in plain seconds.

    Args:
        val: Duration in seconds.

    Returns:
        The formatted string.
    """
    token_si = ["", "m", "µ", "n", "p"]
    exp_si = [1, 1e3, 1e6, 1e9, 1e12]
    # Fallback when no prefix gives a scaled value >= 1.
    si = ""
    sval = val
    for token, exp in zip(token_si, exp_si):
        # ">=" so that exactly one unit is shown as "1s", not "1000ms".
        if val * exp >= 1.0:
            si = token
            sval = val * exp
            break
    return f"{sval:3g}{si}s"

In [None]:
def print_perf_time(intervals, top_k=None):
    """Print max / min / mean / std of measured durations on one line.

    Args:
        intervals: Sequence of durations in seconds (converted with
            ``np.asarray``, so plain lists are accepted too).
        top_k: If a positive integer, only the ``top_k`` smallest (fastest)
            intervals are summarised and the line starts with
            "Top {top_k} summary"; if None or 0, all intervals are summarised
            under "Overall summary".
    """
    intervals = np.asarray(intervals)
    # Use one truthiness test for both the slicing and the label: top_k=0
    # means "no top-k" rather than slicing to an empty array (where .min()
    # would raise).
    if top_k:
        intervals = np.sort(intervals)[:top_k]
        label = f"Top {top_k} summary"
    else:
        label = "Overall summary"

    # Avoid shadowing the builtins min/max.
    smax = get_perf_str(intervals.max())
    smin = get_perf_str(intervals.min())
    smean = get_perf_str(intervals.mean())
    sstd = get_perf_str(intervals.std())
    print(f"{label}: Max {smax}, Min {smin}, Mean +/- Std {smean} +/- {sstd}")

In [None]:
class PerfMeasure():
    """Time repeated calls of a zero-argument callable and print summaries.

    Garbage collection is disabled during the timed loop so collector pauses
    do not appear as outliers. The previous GC state is restored afterwards,
    even if the callable raises.

    Args:
        trials: Number of timed calls.
        top_k: If truthy, additionally print a summary of the ``top_k``
            fastest calls.
    """

    def __init__(self,
                 trials=100,
                 top_k=10):
        self.trials = trials
        self.top_k = top_k

    def __call__(self, func):
        """Call ``func()`` ``self.trials`` times and print timing summaries.

        JAX dispatches work asynchronously, so a call can return before its
        computation has finished. If the returned object has a
        ``block_until_ready`` method (as JAX arrays do), it is called inside
        the timed region so each interval covers the actual computation.
        Other return values (including tuples of arrays) are not waited on.
        """
        gc.collect()
        gc_was_enabled = gc.isenabled()
        gc.disable()
        try:
            intervals = []
            for _ in range(self.trials):
                start = time.perf_counter()
                out = func()
                block = getattr(out, "block_until_ready", None)
                if block is not None:
                    block()
                end = time.perf_counter()
                intervals.append(end - start)
        finally:
            # Restore the caller's GC state even if func() raised.
            if gc_was_enabled:
                gc.enable()
        intervals = np.array(intervals)
        print_perf_time(intervals)
        if self.top_k:
            print_perf_time(intervals, self.top_k)
        gc.collect()

In [None]:
# Benchmark settings: every measurement times TRIALS calls and additionally
# summarises the TOPK fastest of them.
TRIALS, TOPK = 100, 10
pmeasure = PerfMeasure(trials=TRIALS, top_k=TOPK)

# 4. JAX implementation

## 4.1 Implementation based on define-by-run (eager) mode

In [None]:
def simple_interp_jax(trackdata):
    """Fill missing frames of each joint by per-channel linear interpolation.

    A frame of joint ``i`` counts as valid when its last channel is non-zero.
    For every joint with some (but not all) frames missing, each channel is
    interpolated over the frame index with ``jnp.interp`` from the valid
    frames. Frames before the first / after the last valid frame take the
    first / last valid value. Joints that are fully valid, or that have no
    valid frame at all (nothing to interpolate from), are copied unchanged.

    Args:
        trackdata: Array of shape ``(tlength, num_joints, channels)``. The last
            channel serves as the validity flag (presumably a confidence /
            visibility score — TODO confirm).

    Returns:
        A JAX array with the same shape, missing frames filled in.
    """
    tlength, num_joints, _ = trackdata.shape
    frame_index = jnp.arange(tlength)
    newtrack = jnp.zeros_like(trackdata)
    for i in range(num_joints):
        temp = trackdata[:, i, :]
        mask = temp[:, -1] != 0
        # Concrete count (eager mode), so it can drive Python control flow.
        valid = int(mask.sum())
        # valid == 0 would hand an empty `xp` to jnp.interp, which raises.
        if valid == tlength or valid == 0:
            newtrack = newtrack.at[:, i].set(temp)
            continue
        # `size` must be a concrete int; it makes the output shape static.
        xs = jnp.where(mask, size=valid)[0]
        # ys = temp[xs, :] <- can't be compiled.
        ys = jnp.take(temp, xs, axis=0)
        newys = jnp.zeros_like(temp)
        for j in range(temp.shape[-1]):
            newy = jnp.interp(frame_index, xs, ys[:, j])
            newys = newys.at[:, j].set(newy)
        newtrack = newtrack.at[:, i].set(newys)
    return newtrack


def matrix_interp_jax(track):
    """Fill missing frames of one body part with vectorised linear interpolation.

    Validity is decided per frame from the last channel of the part's first
    joint (``track[:, 0, -1] != 0``), i.e. all joints of the part are assumed
    to be missing or present together. All joints and channels are then
    interpolated at once as a piecewise-linear function of the frame index.
    Outside the valid range the first / last valid value is held constant.

    Args:
        track: Array of shape ``(tlength, num_joints, channels)``.

    Returns:
        ``track`` itself when every frame is valid or none is (nothing to
        interpolate from); otherwise a JAX array of the same shape.
    """
    orig_shape = track.shape
    tlength = orig_shape[0]
    mask = track[:, 0, -1] != 0
    valid = int(mask.sum())
    # Without the valid == 0 check, an all-missing part leads to
    # shape-mismatched arrays below (empty ys vs. padded slopes) and raises.
    if valid == tlength or valid == 0:
        return track

    # `size` must be a concrete int; it keeps the output shape static.
    xs = jnp.where(mask, size=valid)[0]
    # ys = track.reshape([tlength, -1])[xs, :] <- can't be compiled
    # Flatten joints x channels so every series is interpolated at once.
    ys = jnp.take(track.reshape([tlength, -1]), xs, axis=0)
    x = jnp.arange(tlength)

    # ========================================================================
    # Interpolation.
    # ========================================================================
    xs = xs.astype(ys.dtype)
    x = x.astype(ys.dtype)
    # Pad control points at the float min/max, repeating the first and last
    # valid values, so queries outside the valid range are held constant.
    xs = jnp.concatenate([jnp.array([jnp.finfo(xs.dtype).min]), xs, jnp.array([jnp.finfo(xs.dtype).max])], axis=0)
    ys = jnp.concatenate([ys[:1], ys, ys[-1:]], axis=0)

    # slopes[k] is the slope of the segment ending at control point k; the
    # first and last entries are padded with 0 to keep extrapolation flat.
    slopes = (ys[1:] - ys[:-1]) / jnp.expand_dims((xs[1:] - xs[:-1]), axis=-1)
    slopes = jnp.pad(slopes[:-1], [(1, 1), (0, 0)])

    # Intercept of the line through control point k with slope slopes[k].
    intercepts = ys - slopes * jnp.expand_dims(xs, axis=-1)

    # For each query frame, pick the first control point strictly to its
    # right (argmax returns the first True); that point ends the segment the
    # frame lies on. This builds a (tlength, valid + 2) comparison matrix.
    idx = jnp.argmax(jnp.expand_dims(xs, axis=-2) > jnp.expand_dims(x, axis=-1), axis=-1)
    slope = slopes[idx]
    intercept = intercepts[idx]

    # Apply the linear mapping at each query frame.
    y = slope * jnp.expand_dims(x, axis=-1) + intercept
    y = y.astype(ys.dtype)
    y = y.reshape(orig_shape)
    return y


def partsbased_interp_jax(trackdata, part_sizes=(33, 21, 21)):
    """Interpolate each body part of a track separately.

    The joint axis (axis 1) is cut into consecutive parts of ``part_sizes``
    joints, plus a final part holding all remaining joints. Each part is
    filled with ``matrix_interp_jax``, which derives one validity mask per
    part, and the results are concatenated back in the original joint order.
    The default sizes reproduce the original split: pose (33), left hand (21),
    right hand (21) and face (the rest).

    Args:
        trackdata: Array of shape ``(tlength, num_joints, channels)``.
        part_sizes: Joint counts of the leading parts; the remainder forms
            the last part.

    Returns:
        A JAX array with the same shape as ``trackdata``.
    """
    parts = []
    start = 0
    for size in part_sizes:
        parts.append(trackdata[:, start:start + size])
        start += size
    # The last part (face, with the default sizes) takes all remaining joints.
    parts.append(trackdata[:, start:])
    return jnp.concatenate([matrix_interp_jax(part) for part in parts], axis=1)

In [None]:
trackdata = np.load("finger_far0_non_static.npy")
reftrack = np.load("finger_far0_non_static_interp.npy")
# Remove person axis.
trackdata = trackdata[0]
reftrack = reftrack[0]
# The reference must align frame-by-frame and joint-by-joint with the input
# for the error checks below to be meaningful.
assert trackdata.shape == reftrack.shape, (trackdata.shape, reftrack.shape)

In [None]:
# JNP function-based.
# The 1st call may be slow because of the computation graph construction.
print("Time of first call.")
start = time.perf_counter()
# JAX dispatches asynchronously: block so the interval covers the computation.
newtrack = simple_interp_jax(trackdata).block_until_ready()
interval = time.perf_counter() - start
print_perf_time(np.array([interval]))

# Sum absolute differences so positive and negative errors cannot cancel out.
diff = np.abs(reftrack - np.asarray(newtrack)).sum()
print(f"Sum of absolute error:{diff}")

print("Time after second call.")
target_fn = partial(simple_interp_jax, trackdata=trackdata)
pmeasure(target_fn)

Time of first call.




Overall summary: Max 3.16399s, Min 3.16399s, Mean +/- Std 3.16399s +/-   0s
Sum of error:-6.195044477408373e-13
Time after second call.
Overall summary: Max 3.83771s, Min 884.631ms, Mean +/- Std 1.11187s +/- 372.653ms
Top 10 summary: Max 900.855ms, Min 884.631ms, Mean +/- Std 894.798ms +/- 5.34435ms


In [None]:
# JNP function-based.
# Same benchmark with the last frame dropped, presumably to measure the first
# call on an input shape that has not been seen before.
# The 1st call may be slow because of the computation graph construction.
print("Time of first call.")
start = time.perf_counter()
# JAX dispatches asynchronously: block so the interval covers the computation.
newtrack = simple_interp_jax(trackdata[:-1]).block_until_ready()
interval = time.perf_counter() - start
print_perf_time(np.array([interval]))

# Sum absolute differences so positive and negative errors cannot cancel out.
diff = np.abs(reftrack[:-1] - np.asarray(newtrack)).sum()
print(f"Sum of absolute error:{diff}")

print("Time after second call.")
target_fn = partial(simple_interp_jax, trackdata=trackdata[:-1])
pmeasure(target_fn)

Time of first call.
Overall summary: Max 1.97251s, Min 1.97251s, Mean +/- Std 1.97251s +/-   0s
Sum of error:-6.195044477408373e-13
Time after second call.
Overall summary: Max 1.35904s, Min 882.56ms, Mean +/- Std 993.649ms +/- 128.017ms
Top 10 summary: Max 906.912ms, Min 882.56ms, Mean +/- Std 900.938ms +/- 6.57869ms


In [None]:
# Matrix-based.
# The 1st call may be slow because of the computation graph construction.
print("Time of first call.")
start = time.perf_counter()
# JAX dispatches asynchronously: block so the interval covers the computation.
newtrack = partsbased_interp_jax(trackdata).block_until_ready()
interval = time.perf_counter() - start
print_perf_time(np.array([interval]))

# Sum absolute differences so positive and negative errors cannot cancel out.
diff = np.abs(reftrack - np.asarray(newtrack)).sum()
print(f"Sum of absolute error:{diff}")

print("Time after second call.")
target_fn = partial(partsbased_interp_jax, trackdata=trackdata)
pmeasure(target_fn)

Time of first call.
Overall summary: Max 660.329ms, Min 660.329ms, Mean +/- Std 660.329ms +/-   0s
Sum of error:-6.935119145623503e-12
Time after second call.
Overall summary: Max 18.4818ms, Min 9.62976ms, Mean +/- Std 11.2587ms +/- 1.82495ms
Top 10 summary: Max 9.92951ms, Min 9.62976ms, Mean +/- Std 9.77754ms +/- 98.2471µs


In [None]:
# Matrix-based.
# Same benchmark with the last frame dropped, presumably to measure the first
# call on an input shape that has not been seen before.
# The 1st call may be slow because of the computation graph construction.
print("Time of first call.")
start = time.perf_counter()
# JAX dispatches asynchronously: block so the interval covers the computation.
newtrack = partsbased_interp_jax(trackdata[:-1]).block_until_ready()
interval = time.perf_counter() - start
print_perf_time(np.array([interval]))

# Sum absolute differences so positive and negative errors cannot cancel out.
diff = np.abs(reftrack[:-1] - np.asarray(newtrack)).sum()
print(f"Sum of absolute error:{diff}")

print("Time after second call.")
target_fn = partial(partsbased_interp_jax, trackdata=trackdata[:-1])
pmeasure(target_fn)

Time of first call.
Overall summary: Max 596.675ms, Min 596.675ms, Mean +/- Std 596.675ms +/-   0s
Sum of error:-6.935119145623503e-12
Time after second call.
Overall summary: Max 15.9955ms, Min 9.28973ms, Mean +/- Std 10.7838ms +/- 1.23746ms
Top 10 summary: Max 9.70015ms, Min 9.28973ms, Mean +/- Std 9.60985ms +/- 124.862µs


## 4.2 Implementation based on define-and-run mode (JIT compilation)

In [None]:
from typing import Generic, TypeVar
from functools import partial

In [None]:
T = TypeVar('T')      # Declare type variable

# Workaround to avoid unhashable error.
# https://github.com/google/jax/issues/4572
class HashableArrayWrapper(Generic[T]):
    """Make an array usable as a ``jit`` static argument.

    ``jax.jit`` requires static arguments to be hashable, which NumPy / JAX
    arrays are not. This wrapper hashes and compares the wrapped array by
    content (shape, dtype and raw bytes). All other attribute and item access
    is forwarded to the wrapped array, so the wrapper can stand in for it.

    Note: the hash reflects the contents at call time. Do not mutate the
    array (e.g. via ``__setitem__``) after passing it to a jitted function,
    or a cache entry compiled for the old contents may be reused.
    """

    def __init__(self, val: T):
        self.val = val

    def __getattribute__(self, prop):
        # Only these names live on the wrapper itself; everything else is
        # looked up on the wrapped array.
        if prop == 'val' or prop == "__hash__" or prop == "__eq__":
            return super(HashableArrayWrapper, self).__getattribute__(prop)
        return getattr(self.val, prop)

    def __getitem__(self, key):
        return self.val[key]

    def __setitem__(self, key, val):
        self.val[key] = val

    def __hash__(self):
        arr = np.asarray(self.val)
        # Include shape and dtype: e.g. a reshaped view has the same bytes.
        return hash((arr.shape, arr.dtype.str, arr.tobytes()))

    def __eq__(self, other):
        if not isinstance(other, HashableArrayWrapper):
            return NotImplemented
        lhs = np.asarray(self.val)
        rhs = np.asarray(other.val)
        # Compare the same content key the hash uses, but exactly rather than
        # via hash values, so hash collisions cannot make arrays "equal".
        return (lhs.shape == rhs.shape
                and lhs.dtype == rhs.dtype
                and lhs.tobytes() == rhs.tobytes())

In [None]:
@partial(jit, static_argnums=(0,))
def simple_interp_jax_jit(trackdata):
    """JIT-compiled per-joint, per-channel linear interpolation of missing frames.

    Same algorithm as the eager ``simple_interp_jax``: a frame is valid when
    the joint's last channel is non-zero. Each channel of a partially valid
    joint is filled with ``jnp.interp`` over the frame index. Fully valid
    joints, and joints with no valid frame at all, are copied unchanged.

    ``trackdata`` is a static argument, so it must be hashable (e.g. wrapped
    in ``HashableArrayWrapper``). Indexing such a wrapper returns the
    underlying array, so the validity counts are concrete at trace time and
    can drive Python control flow. The Python loops are traced (unrolled)
    into the compiled program.

    Args:
        trackdata: Hashable array-like of shape
            ``(tlength, num_joints, channels)``.

    Returns:
        A JAX array with the same shape, missing frames filled in.
    """
    tlength, num_joints, _ = trackdata.shape
    frame_index = jnp.arange(tlength)
    newtrack = jnp.zeros_like(trackdata)
    for i in range(num_joints):
        temp = trackdata[:, i, :]
        mask = temp[:, -1] != 0
        valid = int(mask.sum())
        # valid == 0 would hand an empty `xp` to jnp.interp, which raises.
        if valid == tlength or valid == 0:
            newtrack = newtrack.at[:, i].set(temp)
            continue
        # `size` must be a concrete int; it makes the output shape static.
        xs = jnp.where(mask, size=valid)[0]
        # ys = temp[xs, :] <- can't be compiled.
        ys = jnp.take(temp, xs, axis=0)
        newys = jnp.zeros_like(temp)
        for j in range(temp.shape[-1]):
            newy = jnp.interp(frame_index, xs, ys[:, j])
            newys = newys.at[:, j].set(newy)
        newtrack = newtrack.at[:, i].set(newys)
    return newtrack


def matrix_interp_jax_jit(track):
    """Vectorised linear interpolation of one body part (JIT section copy).

    Same algorithm as ``matrix_interp_jax``. It is not jitted itself; it is
    traced as part of ``partsbased_interp_jax_jit``. Validity is decided per
    frame from the last channel of the part's first joint
    (``track[:, 0, -1] != 0``), so all joints of the part are assumed to be
    missing or present together. All joints and channels are interpolated at
    once over the frame index, holding the first / last valid value constant
    outside the valid range.

    Args:
        track: Array of shape ``(tlength, num_joints, channels)``.

    Returns:
        ``track`` itself when every frame is valid or none is (nothing to
        interpolate from); otherwise a JAX array of the same shape.
    """
    orig_shape = track.shape
    tlength = orig_shape[0]
    mask = track[:, 0, -1] != 0
    valid = int(mask.sum())
    # Without the valid == 0 check, an all-missing part leads to
    # shape-mismatched arrays below (empty ys vs. padded slopes) and raises.
    if valid == tlength or valid == 0:
        return track

    # `size` must be a concrete int; it keeps the output shape static.
    xs = jnp.where(mask, size=valid)[0]
    # ys = track.reshape([tlength, -1])[xs, :] <- can't be compiled
    # Flatten joints x channels so every series is interpolated at once.
    ys = jnp.take(track.reshape([tlength, -1]), xs, axis=0)
    x = jnp.arange(tlength)

    # ========================================================================
    # Interpolation.
    # ========================================================================
    xs = xs.astype(ys.dtype)
    x = x.astype(ys.dtype)
    # Pad control points at the float min/max, repeating the first and last
    # valid values, so queries outside the valid range are held constant.
    xs = jnp.concatenate([jnp.array([jnp.finfo(xs.dtype).min]), xs, jnp.array([jnp.finfo(xs.dtype).max])], axis=0)
    ys = jnp.concatenate([ys[:1], ys, ys[-1:]], axis=0)

    # slopes[k] is the slope of the segment ending at control point k; the
    # first and last entries are padded with 0 to keep extrapolation flat.
    slopes = (ys[1:] - ys[:-1]) / jnp.expand_dims((xs[1:] - xs[:-1]), axis=-1)
    slopes = jnp.pad(slopes[:-1], [(1, 1), (0, 0)])

    # Intercept of the line through control point k with slope slopes[k].
    intercepts = ys - slopes * jnp.expand_dims(xs, axis=-1)

    # For each query frame, pick the first control point strictly to its
    # right (argmax returns the first True); that point ends the segment the
    # frame lies on. This builds a (tlength, valid + 2) comparison matrix.
    idx = jnp.argmax(jnp.expand_dims(xs, axis=-2) > jnp.expand_dims(x, axis=-1), axis=-1)
    slope = slopes[idx]
    intercept = intercepts[idx]

    # Apply the linear mapping at each query frame.
    y = slope * jnp.expand_dims(x, axis=-1) + intercept
    y = y.astype(ys.dtype)
    y = y.reshape(orig_shape)
    return y


@partial(jit, static_argnums=(0,))
def partsbased_interp_jax_jit(trackdata):
    """JIT-compiled part-wise interpolation.

    Splits the joint axis (axis 1) into pose (first 33 joints), left hand
    (next 21), right hand (next 21) and face (all remaining joints). Each
    part is filled with ``matrix_interp_jax_jit``, and the parts are
    concatenated back in their original order. ``trackdata`` is a static
    argument and must therefore be hashable (e.g. a ``HashableArrayWrapper``).
    """
    n_pose, n_hand = 33, 21
    # (start, stop) joint ranges in the order pose, lhand, rhand, face.
    spans = (
        (0, n_pose),
        (n_pose, n_pose + n_hand),
        (n_pose + n_hand, n_pose + 2 * n_hand),
        (n_pose + 2 * n_hand, None),
    )
    filled = [matrix_interp_jax_jit(trackdata[:, lo:hi]) for lo, hi in spans]
    return jnp.concatenate(filled, axis=1)

In [None]:
# JNP function-based.
# The 1st call may be slow: it traces the unrolled Python loops and compiles
# the result with XLA.
print("Time of first call.")
# Reuse one wrapper so the timed calls below pass exactly the same argument.
wrapped = HashableArrayWrapper(trackdata)
start = time.perf_counter()
# JAX dispatches asynchronously: block so the interval covers the computation.
newtrack = simple_interp_jax_jit(wrapped).block_until_ready()
interval = time.perf_counter() - start
print_perf_time(np.array([interval]))

# Sum absolute differences so positive and negative errors cannot cancel out.
diff = np.abs(reftrack - np.asarray(newtrack)).sum()
print(f"Sum of absolute error:{diff}")

print("Time after second call.")
# Pass the static argument positionally, as in the first call. jit caches
# compilations per call signature, so switching to a keyword argument
# (trackdata=...) forces a fresh trace + compile inside the timed loop.
target_fn = partial(simple_interp_jax_jit, wrapped)
pmeasure(target_fn)

Time of first call.
Overall summary: Max 31.0618s, Min 31.0618s, Mean +/- Std 31.0618s +/-   0s
Sum of error:-6.195044477408373e-13
Time after second call.
Overall summary: Max 27.1741s, Min 2.90357ms, Mean +/- Std 274.897ms +/- 2.70347s
Top 10 summary: Max 2.98753ms, Min 2.90357ms, Mean +/- Std 2.94133ms +/- 32.6188µs


In [None]:
# JNP function-based.
# Same benchmark with the last frame dropped: a different static argument,
# so the first call has to trace and compile again.
print("Time of first call.")
# Reuse one wrapper so the timed calls below pass exactly the same argument.
wrapped = HashableArrayWrapper(trackdata[:-1])
start = time.perf_counter()
# JAX dispatches asynchronously: block so the interval covers the computation.
newtrack = simple_interp_jax_jit(wrapped).block_until_ready()
interval = time.perf_counter() - start
print_perf_time(np.array([interval]))

# Sum absolute differences so positive and negative errors cannot cancel out.
diff = np.abs(reftrack[:-1] - np.asarray(newtrack)).sum()
print(f"Sum of absolute error:{diff}")

print("Time after second call.")
# Pass the static argument positionally, as in the first call. jit caches
# compilations per call signature, so switching to a keyword argument
# (trackdata=...) forces a fresh trace + compile inside the timed loop.
target_fn = partial(simple_interp_jax_jit, wrapped)
pmeasure(target_fn)

Time of first call.
Overall summary: Max 26.039s, Min 26.039s, Mean +/- Std 26.039s +/-   0s
Sum of error:-6.195044477408373e-13
Time after second call.
Overall summary: Max 25.4215s, Min 2.93153ms, Mean +/- Std 257.416ms +/- 2.52909s
Top 10 summary: Max 3.00317ms, Min 2.93153ms, Mean +/- Std 2.97888ms +/- 21.8667µs


In [None]:
# Matrix-based.
# The 1st call may be slow: it traces the function and compiles it with XLA.
print("Time of first call.")
# Reuse one wrapper so the timed calls below pass exactly the same argument.
wrapped = HashableArrayWrapper(trackdata)
start = time.perf_counter()
# JAX dispatches asynchronously: block so the interval covers the computation.
newtrack = partsbased_interp_jax_jit(wrapped).block_until_ready()
interval = time.perf_counter() - start
print_perf_time(np.array([interval]))

# Sum absolute differences so positive and negative errors cannot cancel out.
diff = np.abs(reftrack - np.asarray(newtrack)).sum()
print(f"Sum of absolute error:{diff}")

print("Time after second call.")
# Pass the static argument positionally, as in the first call. jit caches
# compilations per call signature, so switching to a keyword argument
# (trackdata=...) forces a fresh trace + compile inside the timed loop.
target_fn = partial(partsbased_interp_jax_jit, wrapped)
pmeasure(target_fn)

Time of first call.
Overall summary: Max 361.726ms, Min 361.726ms, Mean +/- Std 361.726ms +/-   0s
Sum of error:-2.2037927038809357e-13
Time after second call.
Overall summary: Max 349.023ms, Min 2.00301ms, Mean +/- Std 5.59448ms +/- 34.5168ms
Top 10 summary: Max 2.03389ms, Min 2.00301ms, Mean +/- Std 2.0225ms +/- 9.97664µs


In [None]:
# Matrix-based.
# Same benchmark with the last frame dropped: a different static argument,
# so the first call has to trace and compile again.
print("Time of first call.")
# Reuse one wrapper so the timed calls below pass exactly the same argument.
wrapped = HashableArrayWrapper(trackdata[:-1])
start = time.perf_counter()
# JAX dispatches asynchronously: block so the interval covers the computation.
newtrack = partsbased_interp_jax_jit(wrapped).block_until_ready()
interval = time.perf_counter() - start
print_perf_time(np.array([interval]))

# Sum absolute differences so positive and negative errors cannot cancel out.
diff = np.abs(reftrack[:-1] - np.asarray(newtrack)).sum()
print(f"Sum of absolute error:{diff}")

print("Time after second call.")
# Pass the static argument positionally, as in the first call. jit caches
# compilations per call signature, so switching to a keyword argument
# (trackdata=...) forces a fresh trace + compile inside the timed loop.
target_fn = partial(partsbased_interp_jax_jit, wrapped)
pmeasure(target_fn)

Time of first call.
Overall summary: Max 345.807ms, Min 345.807ms, Mean +/- Std 345.807ms +/-   0s
Sum of error:-2.2037927038809357e-13
Time after second call.
Overall summary: Max 333.857ms, Min 2.06338ms, Mean +/- Std 5.57588ms +/- 32.9948ms
Top 10 summary: Max 2.12235ms, Min 2.06338ms, Mean +/- Std 2.10369ms +/- 18.2378µs
