<a href="https://colab.research.google.com/github/takayama-rado/trado_samples/blob/main/colab_files/exp_track_affine_jax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 1. Load library

In [1]:
# Standard modules.
import sys
import time
from functools import partial
from pathlib import Path

# CV/ML.
import numpy as np

import jax
import jax.numpy as jnp
from jax import jit

In [2]:
print(f"Python:{sys.version}")
print(f"Numpy:{np.__version__}")
print(f"JAX:{jax.__version__}")

Python:3.10.12 (main, Jun 11 2023, 05:26:28) [GCC 11.4.0]
Numpy:1.23.5
JAX:0.4.16


# 2. Load data

In [3]:
!wget https://github.com/takayama-rado/trado_samples/raw/main/test_data/finger_far0_non_static.npy

--2023-10-17 02:43:43--  https://github.com/takayama-rado/trado_samples/raw/main/test_data/finger_far0_non_static.npy
Resolving github.com (github.com)... 140.82.114.3
Connecting to github.com (github.com)|140.82.114.3|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/takayama-rado/trado_samples/main/test_data/finger_far0_non_static.npy [following]
--2023-10-17 02:43:43--  https://raw.githubusercontent.com/takayama-rado/trado_samples/main/test_data/finger_far0_non_static.npy
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2300608 (2.2M) [application/octet-stream]
Saving to: ‘finger_far0_non_static.npy’


2023-10-17 02:43:43 (30.9 MB/s) - ‘finger_far0_non_static.npy’ saved [2300608/2300608]



In [4]:
!wget https://github.com/takayama-rado/trado_samples/raw/main/test_data/finger_far0_non_static_affine.npy

--2023-10-17 02:43:43--  https://github.com/takayama-rado/trado_samples/raw/main/test_data/finger_far0_non_static_affine.npy
Resolving github.com (github.com)... 140.82.113.4
Connecting to github.com (github.com)|140.82.113.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/takayama-rado/trado_samples/main/test_data/finger_far0_non_static_affine.npy [following]
--2023-10-17 02:43:43--  https://raw.githubusercontent.com/takayama-rado/trado_samples/main/test_data/finger_far0_non_static_affine.npy
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2300608 (2.2M) [application/octet-stream]
Saving to: ‘finger_far0_non_static_affine.npy’


2023-10-17 02:43:44 (32.2 MB/s) - ‘finger_far0_non_static_affine.

In [5]:
!ls

finger_far0_non_static_affine.npy  finger_far0_non_static.npy  sample_data


# 3. Implement affine transformation

## 3.1 Based on define-by-run

In [6]:
def get_affine_matrix_2d_jax(center,
                             trans,
                             scale,
                             rot,
                             skew,
                             dtype=jnp.float32):
    center_m = jnp.array([[1.0, 0.0, -center[0]],
                          [0.0, 1.0, -center[1]],
                          [0.0, 0.0, 1.0]])
    scale_m = jnp.array([[scale[0], 0.0, 0.0],
                         [0.0, scale[1], 0.0],
                         [0.0, 0.0, 1.0]])
    _cos = jnp.cos(rot)
    _sin = jnp.sin(rot)
    rot_m = jnp.array([[_cos, -_sin, 0.0],
                       [_sin, _cos, 0],
                       [0.0, 0.0, 1.0]])
    _tan = jnp.tan(skew)
    skew_m = jnp.array([[1.0, _tan[0], 0.0],
                        [_tan[1], 1.0, 0.0],
                        [0.0, 0.0, 1.0]])
    move = jnp.array(center) + jnp.array(trans)
    trans_m = jnp.array([[1.0, 0.0, move[0]],
                         [0.0, 1.0, move[1]],
                         [0.0, 0.0, 1.0]])
    # Make affine matrix.
    mat = jnp.identity(3, dtype=dtype)
    mat = jnp.matmul(center_m, mat)
    mat = jnp.matmul(scale_m, mat)
    mat = jnp.matmul(rot_m, mat)
    mat = jnp.matmul(skew_m, mat)
    mat = jnp.matmul(trans_m, mat)
    return mat.astype(dtype)

In [7]:
def apply_affine_jax(inputs, mat):
    # Apply transform.
    xy = inputs[:, :, :2]
    xy = jnp.concatenate([xy, np.ones([xy.shape[0], xy.shape[1], 1])], axis=-1)
    xy = jnp.einsum("...j,ij", xy, mat)
    inputs = inputs.at[:, :, :2].set(xy[:, :, :-1])
    return inputs

In [8]:
# Load data.
trackfile = Path("./finger_far0_non_static.npy")
reffile = Path("./finger_far0_non_static_affine.npy")
trackdata = np.load(trackfile)
refdata = np.load(reffile)
print(trackdata.shape)

# Remove person axis.
trackdata = trackdata[0]
refdata = refdata[0]

# Convert to jnp.array
trackdata = jnp.array(trackdata)
refdata = jnp.array(refdata)

(1, 130, 553, 4)




In [9]:
 # Get affine matrix.
center = jnp.array([638.0, 389.0])
trans = jnp.array([100.0, 0.0])
scale = jnp.array([2.0, 0.5])
rot = float(jnp.radians(15.0))
skew = jnp.radians(jnp.array([15.0, 15.0]))
dtype = jnp.float32
print("Parameters")
print("Center:", center)
print("Trans:", trans)
print("Scale:", scale)
print("Rot:", rot)
print("Skew:", skew)

Parameters
Center: [638. 389.]
Trans: [100.   0.]
Scale: [2.  0.5]
Rot: 0.2617993950843811
Skew: [0.2617994 0.2617994]


In [10]:
testtrack = trackdata.copy()
trial = 10

# The 1st call may be slow because of the computation graph construction.
start = time.perf_counter()
mat = get_affine_matrix_2d_jax(center, trans, scale, rot, skew, dtype=dtype)
newtrack = apply_affine_jax(testtrack, mat)
interval = time.perf_counter() - start
print(f"Time of first call:{interval}")

# Evaluate difference.
diff = (jnp.round(newtrack) - jnp.round(refdata)).sum()

testtrack = trackdata.copy()

start = time.perf_counter()
for _ in range(trial):
    mat = get_affine_matrix_2d_jax(center, trans, scale, rot, skew, dtype=dtype)
    newtrack = apply_affine_jax(testtrack, mat)
interval = time.perf_counter() - start
print(f"Average time:{interval / trial}")

print(f"Sum of error:{diff}")

Time of first call:0.8782596799999993
Average time:0.04856080149999968
Sum of error:0.0


In [11]:
testtrack = trackdata.copy()
trial = 10

# The 1st call may be slow because of the computation graph construction.
start = time.perf_counter()
mat = get_affine_matrix_2d_jax(center, trans, scale, rot, skew, dtype=dtype)
newtrack = apply_affine_jax(testtrack[:-1], mat)
interval = time.perf_counter() - start
print(f"Time of first call:{interval}")

testtrack = trackdata.copy()

start = time.perf_counter()
for _ in range(trial):
    mat = get_affine_matrix_2d_jax(center, trans, scale, rot, skew, dtype=dtype)
    newtrack = apply_affine_jax(testtrack[:-1], mat)
interval = time.perf_counter() - start
print(f"Average time:{interval / trial}")

Time of first call:0.4352745529999993
Average time:0.051776887199999774


## 3.2 Based on define-and-run

In [12]:
from typing import Generic, TypeVar

T = TypeVar('T')      # Declare type variable

# Workaround to avoid unhashable error.
# https://github.com/google/jax/issues/4572
class HashableArrayWrapper(Generic[T]):
    def __init__(self, val: T):
        self.val = val

    def __getattribute__(self, prop):
        if prop == 'val' or prop == "__hash__" or prop == "__eq__":
            return super(HashableArrayWrapper, self).__getattribute__(prop)
        return getattr(self.val, prop)

    def __getitem__(self, key):
        return self.val[key]

    def __setitem__(self, key, val):
        self.val[key] = val

    def __hash__(self):
        return hash(self.val.tobytes())

    def __eq__(self, other):
        if isinstance(other, HashableArrayWrapper):
            return self.__hash__() == other.__hash__()

        f = getattr(self.val, "__eq__")
        return f(self, other)

In [13]:
@jit
def get_affine_matrix_2d_jax_jit(center,
                                 trans,
                                 scale,
                                 rot,
                                 skew):
    center_m = jnp.array([[1.0, 0.0, -center[0]],
                          [0.0, 1.0, -center[1]],
                          [0.0, 0.0, 1.0]])
    scale_m = jnp.array([[scale[0], 0.0, 0.0],
                         [0.0, scale[1], 0.0],
                         [0.0, 0.0, 1.0]])
    _cos = jnp.cos(rot)
    _sin = jnp.sin(rot)
    rot_m = jnp.array([[_cos, -_sin, 0.0],
                       [_sin, _cos, 0],
                       [0.0, 0.0, 1.0]])
    _tan = jnp.tan(skew)
    skew_m = jnp.array([[1.0, _tan[0], 0.0],
                        [_tan[1], 1.0, 0.0],
                        [0.0, 0.0, 1.0]])
    move = jnp.array(center) + jnp.array(trans)
    trans_m = jnp.array([[1.0, 0.0, move[0]],
                         [0.0, 1.0, move[1]],
                         [0.0, 0.0, 1.0]])
    # Make affine matrix.
    mat = jnp.identity(3)
    mat = jnp.matmul(center_m, mat)
    mat = jnp.matmul(scale_m, mat)
    mat = jnp.matmul(rot_m, mat)
    mat = jnp.matmul(skew_m, mat)
    mat = jnp.matmul(trans_m, mat)
    return mat

In [14]:
@partial(jit, static_argnums=(0,))
def apply_affine_jax_jit(inputs, mat):
    # Apply transform.
    xy = inputs[:, :, :2]
    xy = jnp.concatenate([xy, np.ones([xy.shape[0], xy.shape[1], 1])], axis=-1)
    xy = jnp.einsum("...j,ij", xy, mat)
    inputs = inputs.at[:, :, :2].set(xy[:, :, :-1])
    return inputs

In [15]:
testtrack = trackdata.copy()
trial = 10

# The 1st call may be slow because of the computation graph construction.
start = time.perf_counter()
mat = get_affine_matrix_2d_jax_jit(center, trans, scale, rot, skew)
newtrack = apply_affine_jax_jit(HashableArrayWrapper(testtrack), mat)
interval = time.perf_counter() - start
print(f"Time of first call:{interval}")

# Evaluate difference.
diff = (jnp.round(newtrack) - jnp.round(refdata)).sum()

testtrack = trackdata.copy()

start = time.perf_counter()
for _ in range(trial):
    mat = get_affine_matrix_2d_jax_jit(center, trans, scale, rot, skew)
    newtrack = apply_affine_jax_jit(HashableArrayWrapper(testtrack), mat)
interval = time.perf_counter() - start
print(f"Average time:{interval / trial}")

print(f"Sum of error:{diff}")

Time of first call:0.5590976609999956
Average time:0.0068611328000002915
Sum of error:0.0


In [16]:
testtrack = trackdata.copy()
trial = 10

# The 1st call may be slow because of the computation graph construction.
start = time.perf_counter()
mat = get_affine_matrix_2d_jax_jit(center, trans, scale, rot, skew)
newtrack = apply_affine_jax_jit(HashableArrayWrapper(testtrack[:-1]), mat)
interval = time.perf_counter() - start
print(f"Time of first call:{interval}")

testtrack = trackdata.copy()

start = time.perf_counter()
for _ in range(trial):
    mat = get_affine_matrix_2d_jax_jit(center, trans, scale, rot, skew)
    newtrack = apply_affine_jax_jit(HashableArrayWrapper(testtrack[:-1]), mat)
interval = time.perf_counter() - start
print(f"Average time:{interval / trial}")

Time of first call:0.42977556299999975
Average time:0.008777991899999903


# 4. Application to randomized transformation

## 4.1 Implement1: Call JIT function from a python process.

In [17]:
class RandomAffineTransform2D_JAX():
    def __init__(self,
                 center_joints,
                 apply_ratio,
                 trans_range,
                 scale_range,
                 rot_range,
                 skew_range,
                 random_seed=None):
        self.center_joints = center_joints
        self.apply_ratio = apply_ratio
        self.trans_range = trans_range
        self.scale_range = scale_range
        self.rot_range = jnp.radians(jnp.array(rot_range))
        self.skew_range = jnp.radians(jnp.array(skew_range))
        if random_seed is not None:
            self.rng = jax.random.PRNGKey(random_seed)
        else:
            self.rng = jax.random.PRNGKey(0)

    def gen_uniform_and_update_key(self, low=0.0, high=1.0, shape=(1,)):
        # Generate random value.
        val = jax.random.uniform(self.rng, shape)
        # Scale to target range.
        val = (high - low) * val + low
        # Update key.
        self.rng = jax.random.split(self.rng, num=1)[0]
        return val

    def __call__(self, inputs):
        if self.gen_uniform_and_update_key() >= self.apply_ratio:
            return inputs

        # Calculate center position.
        temp = inputs[:, self.center_joints, :]
        temp = temp.reshape([inputs.shape[0], -1, inputs.shape[-1]])
        mask = jnp.sum(temp, axis=(1, 2)) != 0
        # Use x and y only.
        center = temp[mask].mean(axis=0).mean(axis=0)[:2]

        trans = self.gen_uniform_and_update_key(
            self.trans_range[0], self.trans_range[1], (2,))
        scale = self.gen_uniform_and_update_key(
            self.scale_range[0], self.scale_range[1], (2,))
        rot = self.gen_uniform_and_update_key(
            self.rot_range[0], self.rot_range[1], (1,))[0]
        skew = self.gen_uniform_and_update_key(
            self.skew_range[0], self.skew_range[1], (2,))

        # Calculate matrix.
        mat = get_affine_matrix_2d_jax_jit(center, trans, scale, rot, skew)

        # Apply transform.
        inputs = apply_affine_jax_jit(inputs, mat)
        return inputs

In [18]:
aug_fn = RandomAffineTransform2D_JAX(
    center_joints=[11, 12],
    apply_ratio=1.0,
    trans_range=[-100.0, 100.0],
    scale_range=[0.5, 2.0],
    rot_range=[-30.0, 30.0],
    skew_range=[-30.0, 30.0])

In [19]:
trial = 10
augtracks = []

# The 1st call may be slow because of the computation graph construction.
start = time.perf_counter()
temp = aug_fn(HashableArrayWrapper(trackdata.copy().astype(jnp.float32)))
interval = time.perf_counter() - start
print(f"Time of first call:{interval}")

start = time.perf_counter()
for _ in range(trial):
  augtracks.append(aug_fn(HashableArrayWrapper(trackdata.copy().astype(jnp.float32))))
interval = time.perf_counter() - start
print(f"Average time:{interval / trial}")

Time of first call:2.0656173279999948
Average time:0.02885811359999977


In [20]:
trial = 10
augtracks = []

# The 1st call may be slow because of the computation graph construction.
start = time.perf_counter()
temp = aug_fn(HashableArrayWrapper(trackdata.copy().astype(jnp.float32)[:-1]))
interval = time.perf_counter() - start
print(f"Time of first call:{interval}")

start = time.perf_counter()
for _ in range(trial):
  augtracks.append(aug_fn(HashableArrayWrapper(trackdata.copy().astype(jnp.float32)[:-1])))
interval = time.perf_counter() - start
print(f"Average time:{interval / trial}")

Time of first call:0.579926008000001
Average time:0.04163150070000014


## 4.2 Implementation2: Apply JIT to whole affine process.

In [21]:
class RandomAffineTransform2D_JAX_JIT():
    def __init__(self,
                 center_joints,
                 apply_ratio,
                 trans_range,
                 scale_range,
                 rot_range,
                 skew_range,
                 random_seed=None,
                 dtype=np.float32):
        self.center_joints = center_joints
        self.apply_ratio = apply_ratio
        self.trans_range = trans_range
        self.scale_range = scale_range
        self.rot_range = jnp.radians(jnp.array(rot_range))
        self.skew_range = jnp.radians(jnp.array(skew_range))
        self.dtype = dtype
        if random_seed is not None:
            self.rng = jax.random.PRNGKey(random_seed)
        else:
            self.rng = jax.random.PRNGKey(0)

    def gen_uniform_and_update_key(self, rng, low=0.0, high=1.0, shape=(2,)):
        # Generate random value.
        val = jax.random.uniform(rng, shape)
        # Scale to target range.
        val = (high - low) * val + low
        # Update key.
        rng = jax.random.split(rng, num=1)[0]
        return val, rng

    def apply(self, inputs, rng):
        # Calculate center position.
        temp = inputs[:, self.center_joints, :]
        temp = temp.reshape([inputs.shape[0], -1, inputs.shape[-1]])
        mask = jnp.sum(temp, axis=(1, 2)) != 0
        mask = mask.astype(self.dtype)

        temp = temp * mask[:, None, None]
        mask_sum = jnp.sum(mask)
        # `[T, J, C] -> [J, C] -> [C]`
        center = temp.sum(axis=0) / mask_sum
        center = center.mean(axis=0)
        # Use x and y only.
        center = center[:2]

        trans, rng = self.gen_uniform_and_update_key(rng,
            self.trans_range[0], self.trans_range[1], (2,))
        scale, rng = self.gen_uniform_and_update_key(rng,
            self.scale_range[0], self.scale_range[1], (2,))
        rot, rng = self.gen_uniform_and_update_key(rng,
            self.rot_range[0], self.rot_range[1], (2,))
        rot = rot[0]
        skew, rng = self.gen_uniform_and_update_key(rng,
            self.skew_range[0], self.skew_range[1], (2,))

        # Calculate matrix.
        mat = get_affine_matrix_2d_jax_jit(center, trans, scale, rot, skew)

        # Apply transform.
        inputs = apply_affine_jax_jit(inputs, mat)
        return inputs, rng

    @partial(jit, static_argnums=(0,))
    def affine_proc(self, inputs, rng):
        val, rng = self.gen_uniform_and_update_key(rng)
        retval, rng = jax.lax.cond(
            (val >= self.apply_ratio).astype(jnp.int32)[0],
            lambda: (inputs, rng),
            lambda: self.apply(inputs, rng))
        return retval, rng

    def __call__(self, inputs):
        rng = self.rng
        retval, rng = self.affine_proc(inputs, rng)
        self.rng = rng
        return retval

In [22]:
aug_fn = RandomAffineTransform2D_JAX_JIT(
    center_joints=[11, 12],
    apply_ratio=1.0,
    trans_range=[-100.0, 100.0],
    scale_range=[0.5, 2.0],
    rot_range=[-30.0, 30.0],
    skew_range=[-30.0, 30.0],
    dtype=dtype)

In [23]:
trial = 10
augtracks = []

# The 1st call may be slow because of the computation graph construction.
start = time.perf_counter()
temp = aug_fn(trackdata.copy().astype(jnp.float32))
interval = time.perf_counter() - start
print(f"Time of first call:{interval}")

start = time.perf_counter()
for _ in range(trial):
  augtracks.append(aug_fn(trackdata.copy().astype(jnp.float32)))
interval = time.perf_counter() - start
print(f"Average time:{interval / trial}")

Time of first call:2.275207016000003
Average time:0.003445001200000064


In [24]:
trial = 10
augtracks = []

# The 1st call may be slow because of the computation graph construction.
start = time.perf_counter()
temp = aug_fn(trackdata.copy().astype(jnp.float32)[:-1])
interval = time.perf_counter() - start
print(f"Time of first call:{interval}")

start = time.perf_counter()
for _ in range(trial):
  augtracks.append(aug_fn(trackdata.copy().astype(jnp.float32)[:-1]))
interval = time.perf_counter() - start
print(f"Average time:{interval / trial}")

Time of first call:2.2039557909999985
Average time:0.004038998200000065
