<a href="https://colab.research.google.com/github/takayama-rado/trado_samples/blob/main/colab_files/exp_track_interp_jax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 1. Load library

In [1]:
# Standard modules.
import time

# CV/ML.
import numpy as np

import jax
import jax.numpy as jnp
from jax import jit

# Enable float64.
jax.config.update("jax_enable_x64", True)

# 2. Load data

In [2]:
!wget https://github.com/takayama-rado/trado_samples/raw/main/test_data/finger_far0_non_static.npy

--2023-09-29 10:17:59--  https://github.com/takayama-rado/trado_samples/raw/main/test_data/finger_far0_non_static.npy
Resolving github.com (github.com)... 140.82.112.4
Connecting to github.com (github.com)|140.82.112.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/takayama-rado/trado_samples/main/test_data/finger_far0_non_static.npy [following]
--2023-09-29 10:17:59--  https://raw.githubusercontent.com/takayama-rado/trado_samples/main/test_data/finger_far0_non_static.npy
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2300608 (2.2M) [application/octet-stream]
Saving to: ‘finger_far0_non_static.npy’


2023-09-29 10:17:59 (47.7 MB/s) - ‘finger_far0_non_static.npy’ saved [2300608/2300608]



In [3]:
!wget https://github.com/takayama-rado/trado_samples/raw/main/test_data/finger_far0_non_static_interp.npy

--2023-09-29 10:17:59--  https://github.com/takayama-rado/trado_samples/raw/main/test_data/finger_far0_non_static_interp.npy
Resolving github.com (github.com)... 140.82.113.4
Connecting to github.com (github.com)|140.82.113.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/takayama-rado/trado_samples/main/test_data/finger_far0_non_static_interp.npy [following]
--2023-09-29 10:18:00--  https://raw.githubusercontent.com/takayama-rado/trado_samples/main/test_data/finger_far0_non_static_interp.npy
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.109.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2300608 (2.2M) [application/octet-stream]
Saving to: ‘finger_far0_non_static_interp.npy’


2023-09-29 10:18:01 (45.2 MB/s) - ‘finger_far0_non_static_interp.

In [4]:
!ls

finger_far0_non_static_interp.npy  finger_far0_non_static.npy  sample_data


# 3. JAX implementation

## 3.1 Implementation based on define-by-run mode

In [5]:
def simple_interp_jax(trackdata):
    tlength, num_joints, _ = trackdata.shape
    newtrack = jnp.zeros_like(trackdata)
    for i in range(num_joints):
        temp = trackdata[:, i, :]
        mask = temp[:, -1] != 0
        valid = mask.sum()
        if valid == tlength:
            newtrack = newtrack.at[:, i].set(temp)
            continue
        # xs = jnp.where(mask != 0)[0] <- can't be compiled.
        xs = jnp.nonzero(mask, size=valid)[0]
        # ys = temp[xs, :] <- can't be compiled.
        ys = jnp.take(temp, xs, axis=0)
        newys = jnp.zeros_like(temp)
        for j in range(temp.shape[-1]):
            newy = jnp.interp(jnp.arange(tlength), xs, ys[:, j])
            newys = newys.at[:, j].set(newy)
        newtrack = newtrack.at[:, i].set(newys)
    return newtrack


def matrix_interp_jax(track):
    orig_shape = track.shape
    tlength = orig_shape[0]
    mask = track[:, 0, -1] != 0
    valid = mask.sum()
    if valid == tlength:
        return track

    # xs = jnp.where(mask != 0)[0] <- can't be compiled.
    xs = jnp.nonzero(mask, size=valid)[0]
    # ys = track.reshape([tlength, -1])[xs, :] <- can't be compiled
    ys = jnp.take(track.reshape([tlength, -1]), xs, axis=0)
    x = jnp.arange(tlength)

    # ========================================================================
    # Interpolation.
    # ========================================================================
    xs = xs.astype(ys.dtype)
    x = x.astype(ys.dtype)
    # xs = xs.astype(jnp.float32)
    # x = x.astype(jnp.float32)
    # Pad control points for extrapolation.
    xs = jnp.concatenate([jnp.array([jnp.finfo(xs.dtype).min]), xs, jnp.array([jnp.finfo(xs.dtype).max])], axis=0)
    ys = jnp.concatenate([ys[:1], ys, ys[-1:]], axis=0)

    # Compute slopes, pad at the edges to flatten.
    sloops = (ys[1:] - ys[:-1]) / jnp.expand_dims((xs[1:] - xs[:-1]), axis=-1)
    sloops = jnp.pad(sloops[:-1], [(1, 1), (0, 0)])

    # Solve for intercepts.
    intercepts = ys - sloops * jnp.expand_dims(xs, axis=-1)

    # Search for the line parameters at each input data point.
    # Create a grid of the inputs and piece breakpoints for thresholding.
    # Rely on argmax stopping on the first true when there are duplicates,
    # which gives us an index into the parameter vectors.
    idx = jnp.argmax(jnp.expand_dims(xs, axis=-2) > jnp.expand_dims(x, axis=-1), axis=-1)
    sloop = sloops[idx]
    intercept = intercepts[idx]

    # Apply the linear mapping at each input data point.
    y = sloop * jnp.expand_dims(x, axis=-1) + intercept
    y = y.astype(ys.dtype)
    y = y.reshape(orig_shape)
    return y


def partsbased_interp_jax(trackdata):
    pose = trackdata[:, :33]
    lhand = trackdata[:, 33:33+21]
    rhand = trackdata[:, 33+21:33+21+21]
    face = trackdata[:, 33+21+21:]

    pose = matrix_interp_jax(pose)
    lhand = matrix_interp_jax(lhand)
    rhand = matrix_interp_jax(rhand)
    face = matrix_interp_jax(face)
    return jnp.concatenate([pose, lhand, rhand, face], axis=1)

In [6]:
trackdata = np.load("finger_far0_non_static.npy")
reftrack = np.load("finger_far0_non_static_interp.npy")
# Remove person axis.
trackdata = trackdata[0]
reftrack = reftrack[0]

In [7]:
# JNP function-based.
# The 1st call may be slow because of the computation graph construction.
start = time.time()
newtrack = simple_interp_jax(trackdata)
interval = time.time() - start
print(f"Time of first call:{interval}")

start = time.time()
for _ in range(10):
  newtrack = simple_interp_jax(trackdata)
interval = time.time() - start
print(f"Average time after second call:{interval / 10}")

diff = (reftrack - newtrack).sum()
print(f"Sum of error:{diff}")



Time of first call:5.231683015823364
Average time after second call:1.880511713027954
Sum of error:-6.195044477408373e-13


In [8]:
# JNP function-based.
# The 1st call may be slow because of the computation graph construction.
start = time.time()
newtrack = simple_interp_jax(trackdata[:-1])
interval = time.time() - start
print(f"Time of first call:{interval}")

start = time.time()
for _ in range(10):
  newtrack = simple_interp_jax(trackdata[:-1])
interval = time.time() - start
print(f"Average time after second call:{interval / 10}")

diff = (reftrack[:-1] - newtrack).sum()
print(f"Sum of error:{diff}")

Time of first call:2.8428306579589844
Average time after second call:1.4648744344711304
Sum of error:-6.195044477408373e-13


In [9]:
# Matrix-based.
# The 1st call may be slow because of the computation graph construction.
start = time.time()
newtrack = partsbased_interp_jax(trackdata)
interval = time.time() - start
print(f"Time of first call:{interval}")

start = time.time()
for _ in range(10):
  newtrack = partsbased_interp_jax(trackdata)
interval = time.time() - start
print(f"Average time after second call:{interval / 10}")

diff = (reftrack - newtrack).sum()
print(f"Sum of error:{diff}")

Time of first call:0.628657341003418
Average time after second call:0.01871466636657715
Sum of error:-6.935119145623503e-12


In [10]:
# Matrix-based.
# The 1st call may be slow because of the computation graph construction.
start = time.time()
newtrack = partsbased_interp_jax(trackdata[:-1])
interval = time.time() - start
print(f"Time of first call:{interval}")

start = time.time()
for _ in range(10):
  newtrack = partsbased_interp_jax(trackdata[:-1])
interval = time.time() - start
print(f"Average time after second call:{interval / 10}")

diff = (reftrack[:-1] - newtrack).sum()
print(f"Sum of error:{diff}")

Time of first call:0.5926902294158936
Average time after second call:0.0198394775390625
Sum of error:-6.935119145623503e-12


## 3.2 Implementation based on define-and-run (JIT compile)

In [11]:
from typing import Generic, TypeVar
from functools import partial

In [12]:
T = TypeVar('T')      # Declare type variable

# Workaround to avoid unhashable error.
# https://github.com/google/jax/issues/4572
class HashableArrayWrapper(Generic[T]):
    def __init__(self, val: T):
        self.val = val

    def __getattribute__(self, prop):
        if prop == 'val' or prop == "__hash__" or prop == "__eq__":
            return super(HashableArrayWrapper, self).__getattribute__(prop)
        return getattr(self.val, prop)

    def __getitem__(self, key):
        return self.val[key]

    def __setitem__(self, key, val):
        self.val[key] = val

    def __hash__(self):
        return hash(self.val.tobytes())

    def __eq__(self, other):
        if isinstance(other, HashableArrayWrapper):
            return self.__hash__() == other.__hash__()

        f = getattr(self.val, "__eq__")
        return f(self, other)

In [13]:
@partial(jit, static_argnums=(0,))
def simple_interp_jax_jit(trackdata):
    tlength, num_joints, _ = trackdata.shape
    newtrack = jnp.zeros_like(trackdata)
    for i in range(num_joints):
        temp = trackdata[:, i, :]
        mask = temp[:, -1] != 0
        valid = mask.sum()
        if valid == tlength:
            newtrack = newtrack.at[:, i].set(temp)
            continue
        # xs = jnp.where(mask != 0)[0] <- can't be compiled.
        xs = jnp.nonzero(mask, size=valid)[0]
        # ys = temp[xs, :] <- can't be compiled.
        ys = jnp.take(temp, xs, axis=0)
        newys = jnp.zeros_like(temp)
        for j in range(temp.shape[-1]):
            newy = jnp.interp(jnp.arange(tlength), xs, ys[:, j])
            newys = newys.at[:, j].set(newy)
        newtrack = newtrack.at[:, i].set(newys)
    return newtrack


def matrix_interp_jax_jit(track):
    orig_shape = track.shape
    tlength = orig_shape[0]
    mask = track[:, 0, -1] != 0
    valid = mask.sum()
    if valid == tlength:
        return track

    # xs = jnp.where(mask != 0)[0] <- can't be compiled.
    xs = jnp.nonzero(mask, size=valid)[0]
    # ys = track.reshape([tlength, -1])[xs, :] <- can't be compiled
    ys = jnp.take(track.reshape([tlength, -1]), xs, axis=0)
    x = jnp.arange(tlength)

    # ========================================================================
    # Interpolation.
    # ========================================================================
    xs = xs.astype(ys.dtype)
    x = x.astype(ys.dtype)
    # xs = xs.astype(jnp.float32)
    # x = x.astype(jnp.float32)
    # Pad control points for extrapolation.
    xs = jnp.concatenate([jnp.array([jnp.finfo(xs.dtype).min]), xs, jnp.array([jnp.finfo(xs.dtype).max])], axis=0)
    ys = jnp.concatenate([ys[:1], ys, ys[-1:]], axis=0)

    # Compute slopes, pad at the edges to flatten.
    sloops = (ys[1:] - ys[:-1]) / jnp.expand_dims((xs[1:] - xs[:-1]), axis=-1)
    sloops = jnp.pad(sloops[:-1], [(1, 1), (0, 0)])

    # Solve for intercepts.
    intercepts = ys - sloops * jnp.expand_dims(xs, axis=-1)

    # Search for the line parameters at each input data point.
    # Create a grid of the inputs and piece breakpoints for thresholding.
    # Rely on argmax stopping on the first true when there are duplicates,
    # which gives us an index into the parameter vectors.
    idx = jnp.argmax(jnp.expand_dims(xs, axis=-2) > jnp.expand_dims(x, axis=-1), axis=-1)
    sloop = sloops[idx]
    intercept = intercepts[idx]

    # Apply the linear mapping at each input data point.
    y = sloop * jnp.expand_dims(x, axis=-1) + intercept
    y = y.astype(ys.dtype)
    y = y.reshape(orig_shape)
    return y


@partial(jit, static_argnums=(0,))
def partsbased_interp_jax_jit(trackdata):
    pose = trackdata[:, :33]
    lhand = trackdata[:, 33:33+21]
    rhand = trackdata[:, 33+21:33+21+21]
    face = trackdata[:, 33+21+21:]

    pose = matrix_interp_jax_jit(pose)
    lhand = matrix_interp_jax_jit(lhand)
    rhand = matrix_interp_jax_jit(rhand)
    face = matrix_interp_jax_jit(face)
    return jnp.concatenate([pose, lhand, rhand, face], axis=1)

In [14]:
# JNP function-based.
# The 1st call may be slow because of the computation graph construction.
start = time.time()
newtrack = simple_interp_jax_jit(HashableArrayWrapper(trackdata))
interval = time.time() - start
print(f"Time of first call:{interval}")

start = time.time()
for _ in range(10):
  newtrack = simple_interp_jax_jit(HashableArrayWrapper(trackdata))
interval = time.time() - start
print(f"Average time after second call:{interval / 10}")

diff = (reftrack - newtrack).sum()
print(f"Sum of error:{diff}")

Time of first call:27.684165954589844
Average time after second call:0.006444549560546875
Sum of error:-6.195044477408373e-13


In [15]:
# JNP function-based.
# The 1st call may be slow because of the computation graph construction.
start = time.time()
newtrack = simple_interp_jax_jit(HashableArrayWrapper(trackdata[:-1]))
interval = time.time() - start
print(f"Time of first call:{interval}")

start = time.time()
for _ in range(10):
  newtrack = simple_interp_jax_jit(HashableArrayWrapper(trackdata[:-1]))
interval = time.time() - start
print(f"Average time after second call:{interval / 10}")

diff = (reftrack[:-1] - newtrack).sum()
print(f"Sum of error:{diff}")

Time of first call:26.79814577102661
Average time after second call:0.0073321819305419925
Sum of error:-6.195044477408373e-13


In [16]:
# Matrix-based.
# The 1st call may be slow because of the computation graph construction.
start = time.time()
newtrack = partsbased_interp_jax_jit(HashableArrayWrapper(trackdata))
interval = time.time() - start
print(f"Time of first call:{interval}")

start = time.time()
for _ in range(10):
  newtrack = partsbased_interp_jax_jit(HashableArrayWrapper(trackdata))
interval = time.time() - start
print(f"Average time after second call:{interval / 10}")

diff = (reftrack - newtrack).sum()
print(f"Sum of error:{diff}")

Time of first call:0.6566777229309082
Average time after second call:0.009604620933532714
Sum of error:-2.2037927038809357e-13


In [17]:
# Matrix-based.
# The 1st call may be slow because of the computation graph construction.
start = time.time()
newtrack = partsbased_interp_jax_jit(HashableArrayWrapper(trackdata[:-1]))
interval = time.time() - start
print(f"Time of first call:{interval}")

start = time.time()
for _ in range(10):
  newtrack = partsbased_interp_jax_jit(HashableArrayWrapper(trackdata[:-1]))
interval = time.time() - start
print(f"Average time after second call:{interval / 10}")

diff = (reftrack[:-1] - newtrack).sum()
print(f"Sum of error:{diff}")

Time of first call:0.8227601051330566
Average time after second call:0.01083838939666748
Sum of error:-2.2037927038809357e-13
