<a href="https://colab.research.google.com/github/takayama-rado/trado_samples/blob/main/colab_files/exp_track_affine_jax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 1. Load library

In [1]:
# Standard modules.
import gc
import sys
import time
from functools import partial

# CV/ML.
import numpy as np

import jax
import jax.numpy as jnp
from jax import jit

In [2]:
print(f"Python:{sys.version}")
print(f"Numpy:{np.__version__}")
print(f"JAX:{jax.__version__}")

Python:3.10.12 (main, Jun 11 2023, 05:26:28) [GCC 11.4.0]
Numpy:1.23.5
JAX:0.4.16


# 2. Load data

In [3]:
!wget https://github.com/takayama-rado/trado_samples/raw/main/test_data/finger_far0_non_static.npy

--2023-10-31 04:21:59--  https://github.com/takayama-rado/trado_samples/raw/main/test_data/finger_far0_non_static.npy
Resolving github.com (github.com)... 192.30.255.113
Connecting to github.com (github.com)|192.30.255.113|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/takayama-rado/trado_samples/main/test_data/finger_far0_non_static.npy [following]
--2023-10-31 04:21:59--  https://raw.githubusercontent.com/takayama-rado/trado_samples/main/test_data/finger_far0_non_static.npy
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2300608 (2.2M) [application/octet-stream]
Saving to: ‘finger_far0_non_static.npy’


2023-10-31 04:22:00 (41.6 MB/s) - ‘finger_far0_non_static.npy’ saved [2300608/2300608]



In [4]:
!wget https://github.com/takayama-rado/trado_samples/raw/main/test_data/finger_far0_non_static_affine.npy

--2023-10-31 04:22:00--  https://github.com/takayama-rado/trado_samples/raw/main/test_data/finger_far0_non_static_affine.npy
Resolving github.com (github.com)... 20.29.134.23
Connecting to github.com (github.com)|20.29.134.23|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/takayama-rado/trado_samples/main/test_data/finger_far0_non_static_affine.npy [following]
--2023-10-31 04:22:00--  https://raw.githubusercontent.com/takayama-rado/trado_samples/main/test_data/finger_far0_non_static_affine.npy
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.108.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2300608 (2.2M) [application/octet-stream]
Saving to: ‘finger_far0_non_static_affine.npy’


2023-10-31 04:22:00 (37.3 MB/s) - ‘finger_far0_non_static_affine.

In [5]:
!ls

finger_far0_non_static_affine.npy  finger_far0_non_static.npy  sample_data


# 3. Evaluation settings

In [6]:
def get_perf_str(val):
    token_si = ["", "m", "µ", "n", "p"]
    exp_si = [1, 1e3, 1e6, 1e9, 1e12]
    perf_str = f"{val:3g}s"
    si = ""
    sval = val
    for token, exp in zip(token_si, exp_si):
        if val * exp > 1.0:
            si = token
            sval = val * exp
            break
    perf_str = f"{sval:3g}{si}s"
    return perf_str

In [7]:
def print_perf_time(intervals, top_k=None):
    if top_k is not None:
        intervals = np.sort(intervals)[:top_k]
    min = intervals.min()
    max = intervals.max()
    mean = intervals.mean()
    std = intervals.std()

    smin = get_perf_str(min)
    smax = get_perf_str(max)
    mean = get_perf_str(mean)
    std = get_perf_str(std)
    if top_k:
        print(f"Top {top_k} summary: Max {smax}, Min {smin}, Mean +/- Std {mean} +/- {std}")
    else:
        print(f"Overall summary: Max {smax}, Min {smin}, Mean +/- Std {mean} +/- {std}")

In [8]:
class PerfMeasure():
    def __init__(self,
                 trials=100,
                 top_k=10):
        self.trials = trials
        self.top_k = top_k

    def __call__(self, func):
        gc.collect()
        gc.disable()
        intervals = []
        for _ in range(self.trials):
            start = time.perf_counter()
            func()
            end = time.perf_counter()
            intervals.append(end - start)
        intervals = np.array(intervals)
        print_perf_time(intervals)
        if self.top_k:
            print_perf_time(intervals, self.top_k)
        gc.enable()
        gc.collect()

In [9]:
TRIALS = 100
TOPK = 10
pmeasure = PerfMeasure(TRIALS, TOPK)

# 4. Implement affine transformation

## 4.1 Based on define-by-run

In [10]:
def get_affine_matrix_2d_jax(center,
                             trans,
                             scale,
                             rot,
                             skew,
                             dtype=jnp.float32):
    center_m = jnp.array([[1.0, 0.0, -center[0]],
                          [0.0, 1.0, -center[1]],
                          [0.0, 0.0, 1.0]])
    scale_m = jnp.array([[scale[0], 0.0, 0.0],
                         [0.0, scale[1], 0.0],
                         [0.0, 0.0, 1.0]])
    _cos = jnp.cos(rot)
    _sin = jnp.sin(rot)
    rot_m = jnp.array([[_cos, -_sin, 0.0],
                       [_sin, _cos, 0],
                       [0.0, 0.0, 1.0]])
    _tan = jnp.tan(skew)
    skew_m = jnp.array([[1.0, _tan[0], 0.0],
                        [_tan[1], 1.0, 0.0],
                        [0.0, 0.0, 1.0]])
    move = jnp.array(center) + jnp.array(trans)
    trans_m = jnp.array([[1.0, 0.0, move[0]],
                         [0.0, 1.0, move[1]],
                         [0.0, 0.0, 1.0]])
    # Make affine matrix.
    mat = jnp.identity(3, dtype=dtype)
    mat = jnp.matmul(center_m, mat)
    mat = jnp.matmul(scale_m, mat)
    mat = jnp.matmul(rot_m, mat)
    mat = jnp.matmul(skew_m, mat)
    mat = jnp.matmul(trans_m, mat)
    return mat.astype(dtype)

In [11]:
def apply_affine_jax(inputs, mat):
    # Apply transform.
    xy = inputs[:, :, :2]
    xy = jnp.concatenate([xy, jnp.ones([xy.shape[0], xy.shape[1], 1])], axis=-1)
    xy = jnp.einsum("...j,ij", xy, mat)
    inputs = inputs.at[:, :, :2].set(xy[:, :, :-1])
    return inputs

In [12]:
# Load data.
trackfile = "./finger_far0_non_static.npy"
reffile = "./finger_far0_non_static_affine.npy"
trackdata = np.load(trackfile).astype(np.float32)
refdata = np.load(reffile).astype(np.float32)
print(trackdata.shape)

# Remove person axis.
trackdata = trackdata[0]
refdata = refdata[0]

# Convert to jnp.array
trackdata = jnp.array(trackdata)
refdata = jnp.array(refdata)



(1, 130, 553, 4)


In [13]:
 # Get affine matrix.
center = jnp.array([638.0, 389.0])
trans = jnp.array([100.0, 0.0])
scale = jnp.array([2.0, 0.5])
rot = float(jnp.radians(15.0))
skew = jnp.radians(jnp.array([15.0, 15.0]))
dtype = jnp.float32
print("Parameters")
print("Center:", center)
print("Trans:", trans)
print("Scale:", scale)
print("Rot:", rot)
print("Skew:", skew)

Parameters
Center: [638. 389.]
Trans: [100.   0.]
Scale: [2.  0.5]
Rot: 0.2617993950843811
Skew: [0.2617994 0.2617994]


In [14]:
def perf_wrap_func(trackdata, center, trans, scale, rot, skew, dtype):
    mat = get_affine_matrix_2d_jax(center, trans, scale, rot, skew, dtype=dtype)
    newtrack = apply_affine_jax(trackdata, mat)

In [15]:
testtrack = trackdata.copy()

# The 1st call may be slow because of the computation graph construction.
print(f"Time of first call.")
start = time.perf_counter()
mat = get_affine_matrix_2d_jax(center, trans, scale, rot, skew, dtype=dtype)
newtrack = apply_affine_jax(testtrack, mat)
interval = time.perf_counter() - start
print_perf_time(np.array([interval]))

# Evaluate difference.
diff = (jnp.round(newtrack) - jnp.round(refdata)).sum()
print(f"Sum of error:{diff}")

testtrack = trackdata.copy()

print("Time after second call.")
target_fn = partial(perf_wrap_func,
                    trackdata=testtrack,
                    center=center, trans=trans, scale=scale, rot=rot, skew=skew,
                    dtype=dtype)
pmeasure(target_fn)

Time of first call.
Overall summary: Max 662.249ms, Min 662.249ms, Mean +/- Std 662.249ms +/-   0s
Sum of error:0.0
Time after second call.
Overall summary: Max 115.266ms, Min 17.1939ms, Mean +/- Std 36.2639ms +/- 16.9179ms
Top 10 summary: Max 17.8094ms, Min 17.1939ms, Mean +/- Std 17.4581ms +/- 166.24µs


In [16]:
testtrack = trackdata.copy()

# The 1st call may be slow because of the computation graph construction.
print(f"Time of first call.")
start = time.perf_counter()
mat = get_affine_matrix_2d_jax(center, trans, scale, rot, skew, dtype=dtype)
newtrack = apply_affine_jax(testtrack[:-1], mat)
interval = time.perf_counter() - start
print_perf_time(np.array([interval]))

testtrack = trackdata.copy()

print("Time after second call.")
target_fn = partial(perf_wrap_func,
                    trackdata=testtrack[:-1],
                    center=center, trans=trans, scale=scale, rot=rot, skew=skew,
                    dtype=dtype)
pmeasure(target_fn)

Time of first call.
Overall summary: Max 324.22ms, Min 324.22ms, Mean +/- Std 324.22ms +/-   0s
Time after second call.
Overall summary: Max 142.864ms, Min 27.1088ms, Mean +/- Std 65.1514ms +/- 26.8165ms
Top 10 summary: Max 30.4011ms, Min 27.1088ms, Mean +/- Std 29.4172ms +/- 961.927µs


## 4.2 Based on define-and-run

In [17]:
@jit
def get_affine_matrix_2d_jax_jit(center,
                                 trans,
                                 scale,
                                 rot,
                                 skew):
    center_m = jnp.array([[1.0, 0.0, -center[0]],
                          [0.0, 1.0, -center[1]],
                          [0.0, 0.0, 1.0]])
    scale_m = jnp.array([[scale[0], 0.0, 0.0],
                         [0.0, scale[1], 0.0],
                         [0.0, 0.0, 1.0]])
    _cos = jnp.cos(rot)
    _sin = jnp.sin(rot)
    rot_m = jnp.array([[_cos, -_sin, 0.0],
                       [_sin, _cos, 0],
                       [0.0, 0.0, 1.0]])
    _tan = jnp.tan(skew)
    skew_m = jnp.array([[1.0, _tan[0], 0.0],
                        [_tan[1], 1.0, 0.0],
                        [0.0, 0.0, 1.0]])
    move = jnp.array(center) + jnp.array(trans)
    trans_m = jnp.array([[1.0, 0.0, move[0]],
                         [0.0, 1.0, move[1]],
                         [0.0, 0.0, 1.0]])
    # Make affine matrix.
    mat = jnp.identity(3)
    mat = jnp.matmul(center_m, mat)
    mat = jnp.matmul(scale_m, mat)
    mat = jnp.matmul(rot_m, mat)
    mat = jnp.matmul(skew_m, mat)
    mat = jnp.matmul(trans_m, mat)
    return mat

In [18]:
@jit
def apply_affine_jax_jit(inputs, mat):
    # Apply transform.
    xy = inputs[:, :, :2]
    xy = jnp.concatenate([xy, jnp.ones([xy.shape[0], xy.shape[1], 1])], axis=-1)
    xy = jnp.einsum("...j,ij", xy, mat)
    inputs = inputs.at[:, :, :2].set(xy[:, :, :-1])
    return inputs

In [19]:
def perf_wrap_func(trackdata, center, trans, scale, rot, skew):
    mat = get_affine_matrix_2d_jax_jit(center, trans, scale, rot, skew)
    newtrack = apply_affine_jax_jit(trackdata, mat)

In [20]:
testtrack = trackdata.copy()

# The 1st call may be slow because of the computation graph construction.
print(f"Time of first call.")
start = time.perf_counter()
mat = get_affine_matrix_2d_jax_jit(center, trans, scale, rot, skew)
newtrack = apply_affine_jax_jit(testtrack, mat)
interval = time.perf_counter() - start
print_perf_time(np.array([interval]))

# Evaluate difference.
diff = (jnp.round(newtrack) - jnp.round(refdata)).sum()
print(f"Sum of error:{diff}")

testtrack = trackdata.copy()

print("Time after second call.")
target_fn = partial(perf_wrap_func,
                    trackdata=testtrack,
                    center=center, trans=trans, scale=scale, rot=rot, skew=skew)
pmeasure(target_fn)

Time of first call.
Overall summary: Max 445.719ms, Min 445.719ms, Mean +/- Std 445.719ms +/-   0s
Sum of error:0.0
Time after second call.
Overall summary: Max 14.5279ms, Min 1.23261ms, Mean +/- Std 2.41537ms +/- 2.43119ms
Top 10 summary: Max 1.36594ms, Min 1.23261ms, Mean +/- Std 1.2714ms +/- 43.6211µs


In [21]:
testtrack = trackdata.copy()

# The 1st call may be slow because of the computation graph construction.
print(f"Time of first call.")
start = time.perf_counter()
mat = get_affine_matrix_2d_jax_jit(center, trans, scale, rot, skew)
newtrack = apply_affine_jax_jit(testtrack[:-1], mat)
interval = time.perf_counter() - start
print_perf_time(np.array([interval]))

# Evaluate difference.
diff = (jnp.round(newtrack) - jnp.round(refdata[:-1])).sum()
print(f"Sum of error:{diff}")

testtrack = trackdata.copy()

print("Time after second call.")
target_fn = partial(perf_wrap_func,
                    trackdata=testtrack[:-1],
                    center=center, trans=trans, scale=scale, rot=rot, skew=skew)
pmeasure(target_fn)

Time of first call.
Overall summary: Max 164.981ms, Min 164.981ms, Mean +/- Std 164.981ms +/-   0s
Sum of error:0.0
Time after second call.
Overall summary: Max 8.83706ms, Min 1.21181ms, Mean +/- Std 2.05039ms +/- 1.7484ms
Top 10 summary: Max 1.34863ms, Min 1.21181ms, Mean +/- Std 1.26267ms +/- 44.5478µs


# 5. Application to randomized transformation

## 5.1 Implement1: Call JIT function from a python process.

In [22]:
class RandomAffineTransform2D_JAX():
    def __init__(self,
                 center_joints,
                 apply_ratio,
                 trans_range,
                 scale_range,
                 rot_range,
                 skew_range,
                 random_seed=None):
        self.center_joints = center_joints
        self.apply_ratio = apply_ratio
        self.trans_range = trans_range
        self.scale_range = scale_range
        self.rot_range = jnp.radians(jnp.array(rot_range))
        self.skew_range = jnp.radians(jnp.array(skew_range))
        if random_seed is not None:
            self.rng = jax.random.PRNGKey(random_seed)
        else:
            self.rng = jax.random.PRNGKey(0)

    def gen_uniform_and_update_key(self, low=0.0, high=1.0, shape=(1,)):
        # Generate random value.
        val = jax.random.uniform(self.rng, shape)
        # Scale to target range.
        val = (high - low) * val + low
        # Update key.
        self.rng = jax.random.split(self.rng, num=1)[0]
        return val

    def __call__(self, inputs):
        if self.gen_uniform_and_update_key() >= self.apply_ratio:
            return inputs

        # Calculate center position.
        temp = inputs[:, self.center_joints, :]
        temp = temp.reshape([inputs.shape[0], -1, inputs.shape[-1]])
        mask = jnp.sum(temp, axis=(1, 2)) != 0
        # Use x and y only.
        center = temp[mask].mean(axis=0).mean(axis=0)[:2]

        trans = self.gen_uniform_and_update_key(
            self.trans_range[0], self.trans_range[1], (2,))
        scale = self.gen_uniform_and_update_key(
            self.scale_range[0], self.scale_range[1], (2,))
        rot = self.gen_uniform_and_update_key(
            self.rot_range[0], self.rot_range[1], (1,))[0]
        skew = self.gen_uniform_and_update_key(
            self.skew_range[0], self.skew_range[1], (2,))

        # Calculate matrix.
        mat = get_affine_matrix_2d_jax_jit(center, trans, scale, rot, skew)

        # Apply transform.
        inputs = apply_affine_jax_jit(inputs, mat)
        return inputs

In [23]:
aug_fn = RandomAffineTransform2D_JAX(
    center_joints=[11, 12],
    apply_ratio=1.0,
    trans_range=[-100.0, 100.0],
    scale_range=[0.5, 2.0],
    rot_range=[-30.0, 30.0],
    skew_range=[-30.0, 30.0])

In [24]:
testtrack = trackdata.copy()

# The 1st call may be slow because of the computation graph construction.
print(f"Time of first call.")
start = time.perf_counter()
temp = aug_fn(testtrack)
interval = time.perf_counter() - start
print_perf_time(np.array([interval]))

testtrack = trackdata.copy()
print("Time after second call.")
target_fn = partial(aug_fn, inputs=testtrack)
pmeasure(target_fn)

Time of first call.
Overall summary: Max 1.44494s, Min 1.44494s, Mean +/- Std 1.44494s +/-   0s
Time after second call.
Overall summary: Max 12.2906ms, Min 6.80869ms, Mean +/- Std 7.3811ms +/- 859.549µs
Top 10 summary: Max 6.9176ms, Min 6.80869ms, Mean +/- Std 6.86691ms +/- 41.6897µs


In [25]:
testtrack = trackdata.copy()

# The 1st call may be slow because of the computation graph construction.
print(f"Time of first call.")
start = time.perf_counter()
temp = aug_fn(testtrack[:-1])
interval = time.perf_counter() - start
print_perf_time(np.array([interval]))

testtrack = trackdata.copy()
print("Time after second call.")
target_fn = partial(aug_fn, inputs=testtrack[:-1])
pmeasure(target_fn)

Time of first call.
Overall summary: Max 262.183ms, Min 262.183ms, Mean +/- Std 262.183ms +/-   0s
Time after second call.
Overall summary: Max 12.4024ms, Min 6.62738ms, Mean +/- Std 7.4838ms +/- 1.09016ms
Top 10 summary: Max 6.83709ms, Min 6.62738ms, Mean +/- Std 6.77123ms +/- 66.6792µs


## 5.2 Implementation2: Apply JIT to whole affine process.

In [26]:
class RandomAffineTransform2D_JAX_JIT():
    def __init__(self,
                 center_joints,
                 apply_ratio,
                 trans_range,
                 scale_range,
                 rot_range,
                 skew_range,
                 random_seed=None,
                 dtype=np.float32):
        self.center_joints = center_joints
        self.apply_ratio = apply_ratio
        self.trans_range = trans_range
        self.scale_range = scale_range
        self.rot_range = jnp.radians(jnp.array(rot_range))
        self.skew_range = jnp.radians(jnp.array(skew_range))
        self.dtype = dtype
        if random_seed is not None:
            self.rng = jax.random.PRNGKey(random_seed)
        else:
            self.rng = jax.random.PRNGKey(0)

    def gen_uniform_and_update_key(self, rng, low=0.0, high=1.0, shape=(2,)):
        # Generate random value.
        val = jax.random.uniform(rng, shape)
        # Scale to target range.
        val = (high - low) * val + low
        # Update key.
        rng = jax.random.split(rng, num=1)[0]
        return val, rng

    def apply(self, inputs, rng):
        # Calculate center position.
        temp = inputs[:, self.center_joints, :]
        temp = temp.reshape([inputs.shape[0], -1, inputs.shape[-1]])
        mask = jnp.sum(temp, axis=(1, 2)) != 0
        mask = mask.astype(self.dtype)

        temp = temp * mask[:, None, None]
        mask_sum = jnp.sum(mask)
        # `[T, J, C] -> [J, C] -> [C]`
        center = temp.sum(axis=0) / mask_sum
        center = center.mean(axis=0)
        # Use x and y only.
        center = center[:2]

        trans, rng = self.gen_uniform_and_update_key(rng,
            self.trans_range[0], self.trans_range[1], (2,))
        scale, rng = self.gen_uniform_and_update_key(rng,
            self.scale_range[0], self.scale_range[1], (2,))
        rot, rng = self.gen_uniform_and_update_key(rng,
            self.rot_range[0], self.rot_range[1], (2,))
        rot = rot[0]
        skew, rng = self.gen_uniform_and_update_key(rng,
            self.skew_range[0], self.skew_range[1], (2,))

        # Calculate matrix.
        mat = get_affine_matrix_2d_jax_jit(center, trans, scale, rot, skew)

        # Apply transform.
        inputs = apply_affine_jax_jit(inputs, mat)
        return inputs, rng

    @partial(jit, static_argnums=(0,))
    def affine_proc(self, inputs, rng):
        val, rng = self.gen_uniform_and_update_key(rng)
        retval, rng = jax.lax.cond(
            (val >= self.apply_ratio).astype(jnp.int32)[0],
            lambda: (inputs, rng),
            lambda: self.apply(inputs, rng))
        return retval, rng

    def __call__(self, inputs):
        rng = self.rng
        retval, rng = self.affine_proc(inputs, rng)
        self.rng = rng
        return retval

In [27]:
aug_fn = RandomAffineTransform2D_JAX_JIT(
    center_joints=[11, 12],
    apply_ratio=1.0,
    trans_range=[-100.0, 100.0],
    scale_range=[0.5, 2.0],
    rot_range=[-30.0, 30.0],
    skew_range=[-30.0, 30.0],
    dtype=dtype)

In [28]:
testtrack = trackdata.copy()

# The 1st call may be slow because of the computation graph construction.
print(f"Time of first call.")
start = time.perf_counter()
temp = aug_fn(testtrack)
interval = time.perf_counter() - start
print_perf_time(np.array([interval]))

testtrack = trackdata.copy()
print("Time after second call.")
target_fn = partial(aug_fn, inputs=testtrack)
pmeasure(target_fn)

Time of first call.
Overall summary: Max 759.034ms, Min 759.034ms, Mean +/- Std 759.034ms +/-   0s
Time after second call.
Overall summary: Max 2.33457ms, Min 648.916µs, Mean +/- Std 727.373µs +/- 231.14µs
Top 10 summary: Max 657.18µs, Min 648.916µs, Mean +/- Std 654.719µs +/- 2.50747µs


In [29]:
testtrack = trackdata.copy()

# The 1st call may be slow because of the computation graph construction.
print(f"Time of first call.")
start = time.perf_counter()
temp = aug_fn(testtrack[:-1])
interval = time.perf_counter() - start
print_perf_time(np.array([interval]))

testtrack = trackdata.copy()
print("Time after second call.")
target_fn = partial(aug_fn, inputs=testtrack[:-1])
pmeasure(target_fn)

Time of first call.
Overall summary: Max 729.709ms, Min 729.709ms, Mean +/- Std 729.709ms +/-   0s
Time after second call.
Overall summary: Max 3.6329ms, Min 653.367µs, Mean +/- Std 763.583µs +/- 352.963µs
Top 10 summary: Max 664.35µs, Min 653.367µs, Mean +/- Std 662.35µs +/- 3.05068µs
