<a href="https://colab.research.google.com/github/takayama-rado/trado_samples/blob/main/colab_files/exp_track_interp_tensorflow.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 1. Load libraries

In [1]:
# Standard modules.
import time

# CV/ML.
import numpy as np

import tensorflow as tf

# 2. Load data

In [2]:
# Download the test tracking data that is fed to the interpolation functions.
!wget https://github.com/takayama-rado/trado_samples/raw/main/test_data/finger_far0_non_static.npy

--2023-09-29 10:03:05--  https://github.com/takayama-rado/trado_samples/raw/main/test_data/finger_far0_non_static.npy
Resolving github.com (github.com)... 140.82.113.4
Connecting to github.com (github.com)|140.82.113.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/takayama-rado/trado_samples/main/test_data/finger_far0_non_static.npy [following]
--2023-09-29 10:03:05--  https://raw.githubusercontent.com/takayama-rado/trado_samples/main/test_data/finger_far0_non_static.npy
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2300608 (2.2M) [application/octet-stream]
Saving to: ‘finger_far0_non_static.npy’


2023-09-29 10:03:06 (27.8 MB/s) - ‘finger_far0_non_static.npy’ saved [2300608/2300608]



In [3]:
# Download the reference interpolation result that the outputs are compared against.
!wget https://github.com/takayama-rado/trado_samples/raw/main/test_data/finger_far0_non_static_interp.npy

--2023-09-29 10:03:06--  https://github.com/takayama-rado/trado_samples/raw/main/test_data/finger_far0_non_static_interp.npy
Resolving github.com (github.com)... 140.82.113.4
Connecting to github.com (github.com)|140.82.113.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/takayama-rado/trado_samples/main/test_data/finger_far0_non_static_interp.npy [following]
--2023-09-29 10:03:06--  https://raw.githubusercontent.com/takayama-rado/trado_samples/main/test_data/finger_far0_non_static_interp.npy
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.111.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2300608 (2.2M) [application/octet-stream]
Saving to: ‘finger_far0_non_static_interp.npy’


2023-09-29 10:03:06 (26.0 MB/s) - ‘finger_far0_non_static_interp.npy’ saved [2300608/2300608]

In [4]:
# List the working directory to confirm that both .npy files were downloaded.
!ls

finger_far0_non_static_interp.npy  finger_far0_non_static.npy  sample_data


# 3. TensorFlow implementation

## 3.1 Implementation based on define-by-run (eager execution)

In [5]:
def matrix_interp_tf_eager(track):
    """Fill invalid frames of a tracking tensor by linear interpolation (eager mode).

    A frame ``t`` is treated as valid when ``track[t, 0, -1] != 0``, i.e. when
    the last channel of the first joint is non-zero (presumably a
    confidence/visibility flag -- confirm against the data format).  Every
    frame is re-evaluated on the piecewise-linear curve through the valid
    frames, so valid frames are reproduced up to floating-point rounding.
    Frames before the first valid frame take the first valid frame's values,
    and frames after the last valid frame take the last valid frame's values.

    Args:
        track: Rank-3 tensor indexed as ``track[frame, joint, channel]``.

    Returns:
        A tensor with the same shape and dtype as ``track``.  The input is
        returned unchanged when every frame is valid, and also when no frame
        is valid (there is nothing to interpolate from).
    """
    orig_shape = tf.shape(track)
    tlength = orig_shape[0]
    mask = track[:, 0, -1] != 0
    valid = tf.reduce_sum(tf.cast(mask, dtype=tf.int32))
    # All frames valid: nothing to do.
    # No frame valid: the control-point set would be empty and the intercept
    # computation below would fail with a shape mismatch, so pass the input
    # through untouched instead of raising.
    if tf.logical_or(tf.equal(valid, tlength), tf.equal(valid, 0)):
        y = track
    else:
        # control points: indices (xs) and flattened features (ys) of valid frames
        xs = tf.where(mask)
        xs = tf.reshape(xs, [valid])
        ys = tf.reshape(track, [tlength, -1])
        ys = tf.gather(ys, xs, axis=0)
        # query points: every frame index
        x = tf.range(tlength)
        # the output data type follows the input features
        dtype_ys = ys.dtype

        # normalize data types
        xs = tf.cast(xs, dtype_ys)
        x = tf.cast(x, dtype_ys)

        # pad control points for extrapolation
        xs = tf.concat([[xs.dtype.min], xs, [xs.dtype.max]], axis=0)
        ys = tf.concat([ys[:1], ys, ys[-1:]], axis=0)

        # compute slopes, pad at the edges to flatten
        ms = (ys[1:] - ys[:-1]) / tf.expand_dims((xs[1:] - xs[:-1]), axis=-1)
        ms = tf.pad(ms[:-1], [(1, 1), (0, 0)])

        # solve for intercepts
        bs = ys - ms * tf.expand_dims(xs, axis=-1)

        # search for the line parameters at each input data point
        # create a grid of the inputs and piece breakpoints for thresholding
        # rely on argmax stopping on the first true when there are duplicates,
        # which gives us an index into the parameter vectors
        i = tf.math.argmax(tf.expand_dims(xs, axis=-2) > tf.expand_dims(x, axis=-1), axis=-1)
        m = tf.gather(ms, i, axis=0)
        b = tf.gather(bs, i, axis=0)

        # apply the linear mapping at each input data point
        y = m*tf.expand_dims(x, axis=-1) + b
        y = tf.cast(y, dtype_ys)
        y = tf.reshape(y, orig_shape)
    return y


def partsbased_interp_tf_eager(trackdata):
    """Interpolate each body part of the tracking data independently (eager mode).

    The joint axis (axis 1) is split into four consecutive groups -- pose,
    left hand, right hand and face -- each group is passed through
    ``matrix_interp_tf_eager`` on its own, and the results are concatenated
    back along the joint axis in the original order.

    Args:
        trackdata: Array or tensor whose axis 1 enumerates the joints.

    Returns:
        Tensor with the interpolated groups concatenated along axis 1.
    """
    num_joints = trackdata.shape[1]
    trackdata = tf.convert_to_tensor(trackdata)

    # Joint-index boundaries: pose [0, 33), left hand [33, 54),
    # right hand [54, 75), face [75, num_joints).
    bounds = [0, 33, 33+21, 33+21+21, num_joints]

    parts = []
    for begin, end in zip(bounds[:-1], bounds[1:]):
        part = tf.gather(trackdata, tf.range(begin, end), axis=1)
        parts.append(matrix_interp_tf_eager(part))
    return tf.concat(parts, axis=1)

In [6]:
# Load the raw tracking data and the reference interpolation result.
# Indexing with [0] removes the leading person axis of each array.
trackdata = np.load("finger_far0_non_static.npy")[0]
reftrack = np.load("finger_far0_non_static_interp.npy")[0]

In [7]:
# Benchmark the eager implementation on the full sequence.
# The first call is timed on its own because it may be slower than later calls.
start = time.perf_counter()
newtrack = partsbased_interp_tf_eager(trackdata)
interval = time.perf_counter() - start
print(f"Time of first call:{interval}")

# Average the wall-clock time over 10 further calls.
start = time.perf_counter()
for _ in range(10):
    newtrack = partsbased_interp_tf_eager(trackdata)
interval = time.perf_counter() - start
print(f"Average time after second call:{interval / 10}")

# Compare against the reference result.
# NOTE: this is a signed sum, so positive and negative deviations can cancel.
diff = np.sum(reftrack - newtrack.numpy())
print(f"Sum of error:{diff}")

Time of first call:0.24913760500000137
Average time after second call:0.013600808399999664
Sum of error:-6.935119145623503e-12


In [8]:
# Benchmark the eager implementation again on an input that is one frame shorter,
# to see how a change of the sequence length affects the timings.
start = time.perf_counter()
newtrack = partsbased_interp_tf_eager(trackdata[:-1])
interval = time.perf_counter() - start
print(f"Time of first call:{interval}")

# Average the wall-clock time over 10 further calls.
start = time.perf_counter()
for _ in range(10):
    newtrack = partsbased_interp_tf_eager(trackdata[:-1])
interval = time.perf_counter() - start
print(f"Average time after second call:{interval / 10}")

# Compare against the equally shortened reference result.
# NOTE: this is a signed sum, so positive and negative deviations can cancel.
diff = np.sum(reftrack[:-1] - newtrack.numpy())
print(f"Sum of error:{diff}")

Time of first call:0.017207738999999833
Average time after second call:0.014746782699999983
Sum of error:-6.935119145623503e-12


## 3.2 Implementation based on define-and-run (tf.function)

In [9]:
# If input_signature is omitted, re-tracing is performed whenever the tensor's shape changes.
@tf.function(input_signature=(tf.TensorSpec(shape=[None, None, 4], dtype=tf.float64),))
def matrix_interp_tf(track):
    """Fill invalid frames of a tracking tensor by linear interpolation (tf.function).

    A frame ``t`` is treated as valid when ``track[t, 0, -1] != 0``, i.e. when
    the last channel of the first joint is non-zero (presumably a
    confidence/visibility flag -- confirm against the data format).  Every
    frame is re-evaluated on the piecewise-linear curve through the valid
    frames, so valid frames are reproduced up to floating-point rounding.
    Frames before the first valid frame take the first valid frame's values,
    and frames after the last valid frame take the last valid frame's values.

    Args:
        track: ``tf.float64`` tensor of shape ``[frames, joints, 4]``, as
            fixed by the ``input_signature`` of the decorator.

    Returns:
        A tensor with the same shape and dtype as ``track``.  The input is
        returned unchanged when every frame is valid, and also when no frame
        is valid (there is nothing to interpolate from).
    """
    orig_shape = tf.shape(track)
    tlength = orig_shape[0]
    mask = track[:, 0, -1] != 0
    valid = tf.reduce_sum(tf.cast(mask, dtype=tf.int32))
    # All frames valid: nothing to do.
    # No frame valid: the control-point set would be empty and the intercept
    # computation below would fail with a shape mismatch, so pass the input
    # through untouched instead of raising.
    # The condition is built from tensor ops so that AutoGraph can lower the
    # `if` to a graph conditional.
    if tf.logical_or(tf.equal(valid, tlength), tf.equal(valid, 0)):
        y = track
    else:
        # control points: indices (xs) and flattened features (ys) of valid frames
        xs = tf.where(mask)
        xs = tf.reshape(xs, [valid])
        ys = tf.reshape(track, [tlength, -1])
        ys = tf.gather(ys, xs, axis=0)
        # query points: every frame index
        x = tf.range(tlength)
        # the output data type follows the input features
        dtype_ys = ys.dtype

        # normalize data types
        xs = tf.cast(xs, dtype_ys)
        x = tf.cast(x, dtype_ys)

        # pad control points for extrapolation
        xs = tf.concat([[xs.dtype.min], xs, [xs.dtype.max]], axis=0)
        ys = tf.concat([ys[:1], ys, ys[-1:]], axis=0)

        # compute slopes, pad at the edges to flatten
        ms = (ys[1:] - ys[:-1]) / tf.expand_dims((xs[1:] - xs[:-1]), axis=-1)
        ms = tf.pad(ms[:-1], [(1, 1), (0, 0)])

        # solve for intercepts
        bs = ys - ms * tf.expand_dims(xs, axis=-1)

        # search for the line parameters at each input data point
        # create a grid of the inputs and piece breakpoints for thresholding
        # rely on argmax stopping on the first true when there are duplicates,
        # which gives us an index into the parameter vectors
        i = tf.math.argmax(tf.expand_dims(xs, axis=-2) > tf.expand_dims(x, axis=-1), axis=-1)
        m = tf.gather(ms, i, axis=0)
        b = tf.gather(bs, i, axis=0)

        # apply the linear mapping at each input data point
        y = m*tf.expand_dims(x, axis=-1) + b
        y = tf.cast(y, dtype_ys)
        y = tf.reshape(y, orig_shape)
    return y


def partsbased_interp_tf(trackdata):
    """Interpolate each body part of the tracking data independently (tf.function).

    The joint axis (axis 1) is split into four consecutive groups -- pose,
    left hand, right hand and face -- each group is passed through
    ``matrix_interp_tf`` on its own, and the results are concatenated back
    along the joint axis in the original order.

    Args:
        trackdata: Array or tensor whose axis 1 enumerates the joints.

    Returns:
        Tensor with the interpolated groups concatenated along axis 1.
    """
    num_joints = trackdata.shape[1]
    trackdata = tf.convert_to_tensor(trackdata)

    # Joint-index boundaries: pose [0, 33), left hand [33, 54),
    # right hand [54, 75), face [75, num_joints).
    bounds = [0, 33, 33+21, 33+21+21, num_joints]

    parts = []
    for begin, end in zip(bounds[:-1], bounds[1:]):
        part = tf.gather(trackdata, tf.range(begin, end), axis=1)
        parts.append(matrix_interp_tf(part))
    return tf.concat(parts, axis=1)

In [10]:
# Benchmark the tf.function implementation on the full sequence.
# The first call is timed on its own: it may be slow because the computation graph is built then.
start = time.perf_counter()
newtrack = partsbased_interp_tf(trackdata)
interval = time.perf_counter() - start
print(f"Time of first call:{interval}")

# Average the wall-clock time over 10 further calls.
start = time.perf_counter()
for _ in range(10):
    newtrack = partsbased_interp_tf(trackdata)
interval = time.perf_counter() - start
print(f"Average time after second call:{interval / 10}")

# Compare against the reference result.
# NOTE: this is a signed sum, so positive and negative deviations can cancel.
diff = np.sum(reftrack - newtrack.numpy())
print(f"Sum of error:{diff}")

Time of first call:0.3889126279999999
Average time after second call:0.006037442600000986
Sum of error:-6.935119145623503e-12


In [11]:
# Benchmark the tf.function implementation again on an input that is one frame shorter.
# With the input_signature in place, the changed sequence length should not trigger re-tracing.
start = time.perf_counter()
newtrack = partsbased_interp_tf(trackdata[:-1])
interval = time.perf_counter() - start
print(f"Time of first call:{interval}")

# Average the wall-clock time over 10 further calls.
start = time.perf_counter()
for _ in range(10):
    newtrack = partsbased_interp_tf(trackdata[:-1])
interval = time.perf_counter() - start
print(f"Average time after second call:{interval / 10}")

# Compare against the equally shortened reference result.
# NOTE: this is a signed sum, so positive and negative deviations can cancel.
diff = np.sum(reftrack[:-1] - newtrack.numpy())
print(f"Sum of error:{diff}")

Time of first call:0.012303197999997906
Average time after second call:0.006144258700000194
Sum of error:-6.935119145623503e-12
