<a href="https://colab.research.google.com/github/takayama-rado/trado_samples/blob/main/colab_files/exp_track_affine_torch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 1. Load libraries

In [1]:
# Standard modules.
import sys
import time

# CV/ML.
import numpy as np

import torch
import torch.nn as nn

In [2]:
print(f"Python:{sys.version}")
print(f"Numpy:{np.__version__}")
print(f"Torch:{torch.__version__}")

Python:3.10.12 (main, Jun 11 2023, 05:26:28) [GCC 11.4.0]
Numpy:1.23.5
Torch:2.1.0+cu118


# 2. Load data

In [3]:
!wget https://github.com/takayama-rado/trado_samples/raw/main/test_data/finger_far0_non_static.npy

--2023-10-28 08:38:49--  https://github.com/takayama-rado/trado_samples/raw/main/test_data/finger_far0_non_static.npy
Resolving github.com (github.com)... 20.29.134.23
Connecting to github.com (github.com)|20.29.134.23|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/takayama-rado/trado_samples/main/test_data/finger_far0_non_static.npy [following]
--2023-10-28 08:38:49--  https://raw.githubusercontent.com/takayama-rado/trado_samples/main/test_data/finger_far0_non_static.npy
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2300608 (2.2M) [application/octet-stream]
Saving to: ‘finger_far0_non_static.npy’


2023-10-28 08:38:49 (31.9 MB/s) - ‘finger_far0_non_static.npy’ saved [2300608/2300608]



In [4]:
!wget https://github.com/takayama-rado/trado_samples/raw/main/test_data/finger_far0_non_static_affine.npy

--2023-10-28 08:38:49--  https://github.com/takayama-rado/trado_samples/raw/main/test_data/finger_far0_non_static_affine.npy
Resolving github.com (github.com)... 192.30.255.113
Connecting to github.com (github.com)|192.30.255.113|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/takayama-rado/trado_samples/main/test_data/finger_far0_non_static_affine.npy [following]
--2023-10-28 08:38:50--  https://raw.githubusercontent.com/takayama-rado/trado_samples/main/test_data/finger_far0_non_static_affine.npy
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2300608 (2.2M) [application/octet-stream]
Saving to: ‘finger_far0_non_static_affine.npy’


2023-10-28 08:38:50 (42.7 MB/s) - ‘finger_far0_non_static_aff

In [5]:
!ls

finger_far0_non_static_affine.npy  finger_far0_non_static.npy  sample_data


# 3. Implement affine transformation

## 3.1 Based on define-by-run

In [6]:
def get_affine_matrix_2d_torch(center,
                               trans,
                               scale,
                               rot,
                               skew,
                               dtype = torch.float32):
    """Compose a 3x3 homogeneous matrix for a 2D affine transform.

    The stages are applied to a point in this order: shift by ``-center``,
    scale, rotate, skew, then shift by ``center + trans``.

    Args:
        center: Tensor whose first two elements are the pivot (x, y).
        trans: Tensor whose first two elements are the translation (x, y).
        scale: Tensor whose first two elements are the scale factors (x, y).
        rot: Single-element tensor passed to ``torch.cos``/``torch.sin``
            (i.e. an angle in radians).
        skew: Tensor whose first two elements are passed to ``torch.tan``
            (i.e. angles in radians).
        dtype: dtype of every intermediate matrix and of the result.

    Returns:
        A ``(3, 3)`` tensor of ``dtype``.
    """
    # Every parameter is reduced to a plain Python float before the stage
    # matrices are assembled.
    neg_cx, neg_cy = float(-center[0]), float(-center[1])
    sx, sy = float(scale[0]), float(scale[1])
    cos_t = torch.cos(rot)
    sin_t = torch.sin(rot)
    cos_r, sin_r, neg_sin_r = float(cos_t), float(sin_t), float(-sin_t)
    tan_t = torch.tan(skew)
    tan_x, tan_y = float(tan_t[0]), float(tan_t[1])
    shift = center + trans
    tx, ty = float(shift[0]), float(shift[1])

    def _as_matrix(rows):
        # Small local helper: nested Python lists -> tensor of ``dtype``.
        return torch.tensor(rows, dtype=dtype)

    # Stage matrices, listed in the order they act on a point.
    stages = [
        _as_matrix([[1.0, 0.0, neg_cx],
                    [0.0, 1.0, neg_cy],
                    [0.0, 0.0, 1.0]]),
        _as_matrix([[sx, 0.0, 0.0],
                    [0.0, sy, 0.0],
                    [0.0, 0.0, 1.0]]),
        _as_matrix([[cos_r, neg_sin_r, 0.0],
                    [sin_r, cos_r, 0.0],
                    [0.0, 0.0, 1.0]]),
        _as_matrix([[1.0, tan_x, 0.0],
                    [tan_y, 1.0, 0.0],
                    [0.0, 0.0, 1.0]]),
        _as_matrix([[1.0, 0.0, tx],
                    [0.0, 1.0, ty],
                    [0.0, 0.0, 1.0]]),
    ]

    # Left-multiply each stage onto the running product, starting from identity.
    mat = torch.eye(3, 3, dtype=dtype)
    for stage in stages:
        mat = torch.matmul(stage, mat)
    return mat.to(dtype)

In [7]:
def apply_affine_torch(inputs, mat):
    """Apply a 3x3 homogeneous affine matrix to the first two channels.

    NOTE: ``inputs`` is modified in place; the returned tensor is the same
    object that was passed in.

    Args:
        inputs: Rank-3 tensor whose last axis starts with the (x, y)
            coordinates. Channels after the first two are left untouched.
        mat: ``(3, 3)`` affine matrix. It is cast to the working dtype and
            moved to the device of ``inputs`` before use.

    Returns:
        ``inputs``, with ``inputs[:, :, :2]`` overwritten by the transformed
        coordinates.
    """
    xy = inputs[:, :, :2]
    # Build the homogeneous "1" column on the same device as the data so that
    # non-CPU inputs work; torch.cat then promotes to the wider dtype.
    ones = torch.ones([xy.shape[0], xy.shape[1], 1],
                      dtype=mat.dtype, device=inputs.device)
    xy = torch.cat([xy, ones], dim=-1)
    # einsum needs both operands in one dtype / on one device.
    xy = torch.einsum("...j,ij->...i", xy, mat.to(xy))
    # Drop the homogeneous coordinate and write back in place.
    inputs[:, :, :2] = xy[:, :, :-1]
    return inputs

In [8]:
def get_perf_str(val):
    """Format a duration given in seconds using an SI prefix.

    The largest unit (s, ms, µs, ns, ps) in which the value is at least 1 is
    chosen. Values that never reach 1 in any unit (zero, negative numbers or
    anything below one picosecond) are printed in plain seconds.

    Args:
        val: Duration in seconds.

    Returns:
        The formatted string, e.g. ``"500ms"``.
    """
    units = (("", 1), ("m", 1e3), ("µ", 1e6), ("n", 1e9), ("p", 1e12))
    for prefix, factor in units:
        scaled = val * factor
        # ">=" (not ">") so that exactly one unit is shown as "1s"/"1ms"
        # rather than "1000ms"/"1000µs".
        if scaled >= 1.0:
            return f"{scaled:3g}{prefix}s"
    return f"{val:3g}s"

In [9]:
def print_perf_time(intervals, top_k=None):
    """Print max / min / mean / std of a set of timing samples.

    Args:
        intervals: Timing samples in seconds; a scalar, a sequence or a NumPy
            array.
        top_k: If given, only the ``top_k`` smallest (fastest) samples are
            summarized.
    """
    samples = np.asarray(intervals)
    if top_k is not None:
        # axis=None flattens first, so scalar (0-d) input is accepted as well.
        samples = np.sort(samples, axis=None)[:top_k]

    # Local names deliberately avoid shadowing the builtins min() and max().
    fastest = get_perf_str(samples.min())
    slowest = get_perf_str(samples.max())
    average = get_perf_str(samples.mean())
    spread = get_perf_str(samples.std())
    print(f"Summary: Max {slowest}, Min {fastest}, Mean +/- Std {average} +/- {spread}")

In [10]:
# Number of timed repetitions in every benchmark loop below.
TRIALS = 100
# Passed as top_k to print_perf_time: only the TOPK fastest trials are summarized.
TOPK = 10

In [11]:
# Read the raw track and the pre-computed affine reference from disk.
trackfile = "./finger_far0_non_static.npy"
reffile = "./finger_far0_non_static_affine.npy"
trackdata, refdata = (np.load(path) for path in (trackfile, reffile))
print(trackdata.shape)

# Keep only the first entry of the leading (person) axis.
trackdata, refdata = trackdata[0], refdata[0]

(1, 130, 553, 4)


In [12]:
 # Get affine matrix.
# Fixed affine parameters used by the benchmarks below.
dtype = torch.float32
center = torch.tensor([638.0, 389.0])
trans = torch.tensor([100.0, 0.0])
scale = torch.tensor([2.0, 0.5])
# Angles are written in degrees and converted with NumPy, so these two
# tensors are float64 (see the printed output).
rot = torch.tensor(np.radians(15.0))
skew = torch.tensor(np.radians([15.0, 15.0]))

print("Parameters")
for label, value in (("Center", center),
                     ("Trans", trans),
                     ("Scale", scale),
                     ("Rot", rot),
                     ("Skew", skew)):
    print(f"{label}:", value)

Parameters
Center: tensor([638., 389.])
Trans: tensor([100.,   0.])
Scale: tensor([2.0000, 0.5000])
Rot: tensor(0.2618, dtype=torch.float64)
Skew: tensor([0.2618, 0.2618], dtype=torch.float64)


In [13]:
# Work on a float32 tensor copy so the NumPy source array is never modified.
testtrack = torch.tensor(trackdata.copy().astype(np.float32))

# The 1st call is timed on its own because it may be much slower than later calls.
start = time.perf_counter()
mat = get_affine_matrix_2d_torch(center, trans, scale, rot, skew, dtype=dtype)
newtrack = apply_affine_torch(testtrack, mat)
interval = time.perf_counter() - start
print("Time of first call")
print_perf_time(np.array(interval))

# Evaluate difference against the reference result.
# Absolute differences are summed so that positive and negative errors
# cannot cancel each other and hide a mismatch.
diff = np.abs(np.round(newtrack.detach().cpu().numpy()) - np.round(refdata)).sum()

testtrack = torch.tensor(trackdata.copy().astype(np.float32))

# NOTE: apply_affine_torch writes into its input, so every trial transforms
# the output of the previous one. Only the timing is used from this loop.
intervals = []
for _ in range(TRIALS):
    start = time.perf_counter()
    mat = get_affine_matrix_2d_torch(center, trans, scale, rot, skew, dtype=dtype)
    newtrack = apply_affine_torch(testtrack, mat)
    end = time.perf_counter()
    intervals.append(end - start)
intervals = np.array(intervals)
print("Time after second call")
print_perf_time(intervals, TOPK)

print(f"Sum of error:{diff}")

Time of first call
Summary: Max 117.289ms, Min 117.289ms, Mean +/- Std 117.289ms +/-   0s
Time after second call
Summary: Max 2.67086ms, Min 2.25075ms, Mean +/- Std 2.53925ms +/- 140.011µs
Sum of error:0.0


In [14]:
# Same benchmark as the previous cell, but the input is testtrack[:-1]
# (last entry of the first axis dropped), i.e. a different input shape.
testtrack = torch.tensor(trackdata.copy().astype(np.float32))

# The 1st call with the new shape is timed on its own.
start = time.perf_counter()
mat = get_affine_matrix_2d_torch(center, trans, scale, rot, skew, dtype=dtype)
newtrack = apply_affine_torch(testtrack[:-1], mat)
interval = time.perf_counter() - start
print("Time of first call")
print_perf_time(np.array(interval))

# Start the repeated trials from a fresh copy of the data.
testtrack = torch.tensor(trackdata.copy().astype(np.float32))

# NOTE: testtrack[:-1] is a view and apply_affine_torch writes into its
# input, so every trial transforms the output of the previous one.
intervals = []
for _ in range(TRIALS):
    start = time.perf_counter()
    mat = get_affine_matrix_2d_torch(center, trans, scale, rot, skew, dtype=dtype)
    newtrack = apply_affine_torch(testtrack[:-1], mat)
    end = time.perf_counter()
    intervals.append(end - start)
intervals = np.array(intervals)
print("Time after second call")
print_perf_time(intervals, TOPK)

Time of first call
Summary: Max 23.3252ms, Min 23.3252ms, Mean +/- Std 23.3252ms +/-   0s
Time after second call
Summary: Max 2.80755ms, Min 2.69608ms, Mean +/- Std 2.75717ms +/- 29.3817µs


## 3.2 Based on define-and-run

In [15]:
@torch.jit.script
def get_affine_matrix_2d_torch_jit(center: torch.Tensor,
                                   trans: torch.Tensor,
                                   scale: torch.Tensor,
                                   rot: torch.Tensor,
                                   skew: torch.Tensor,
                                   dtype: torch.dtype = torch.float32):
    """TorchScript-compiled builder of a 3x3 homogeneous 2D affine matrix.

    The stages act on a point in this order: shift by ``-center``, scale,
    rotate, skew, then shift by ``center + trans``.

    Args:
        center: Tensor whose first two elements are the pivot (x, y).
        trans: Tensor whose first two elements are the translation (x, y).
        scale: Tensor whose first two elements are the scale factors (x, y).
        rot: Single-element tensor passed to ``torch.cos``/``torch.sin``.
        skew: Tensor whose first two elements are passed to ``torch.tan``.
        dtype: dtype of every intermediate matrix and of the result.

    Returns:
        A ``(3, 3)`` tensor of ``dtype``.
    """
    # Scalars are extracted first; the matrices below are built from floats.
    neg_cx = float(-center[0])
    neg_cy = float(-center[1])
    sx = float(scale[0])
    sy = float(scale[1])
    cos_t = torch.cos(rot)
    sin_t = torch.sin(rot)
    cos_r = float(cos_t)
    sin_r = float(sin_t)
    neg_sin_r = float(-sin_t)
    tan_t = torch.tan(skew)
    tan_x = float(tan_t[0])
    tan_y = float(tan_t[1])
    shift = center + trans
    tx = float(shift[0])
    ty = float(shift[1])

    to_origin = torch.tensor([[1.0, 0.0, neg_cx],
                              [0.0, 1.0, neg_cy],
                              [0.0, 0.0, 1.0]], dtype=dtype)
    scaling = torch.tensor([[sx, 0.0, 0.0],
                            [0.0, sy, 0.0],
                            [0.0, 0.0, 1.0]], dtype=dtype)
    rotation = torch.tensor([[cos_r, neg_sin_r, 0.0],
                             [sin_r, cos_r, 0.0],
                             [0.0, 0.0, 1.0]], dtype=dtype)
    shear = torch.tensor([[1.0, tan_x, 0.0],
                          [tan_y, 1.0, 0.0],
                          [0.0, 0.0, 1.0]], dtype=dtype)
    from_origin = torch.tensor([[1.0, 0.0, tx],
                                [0.0, 1.0, ty],
                                [0.0, 0.0, 1.0]], dtype=dtype)

    # Innermost product is applied first: identity -> to_origin -> scaling
    # -> rotation -> shear -> from_origin.
    mat = torch.matmul(
        from_origin,
        torch.matmul(
            shear,
            torch.matmul(
                rotation,
                torch.matmul(
                    scaling,
                    torch.matmul(to_origin, torch.eye(3, 3, dtype=dtype))))))
    return mat.to(dtype)

In [16]:
@torch.jit.script
def apply_affine_torch_jit(inputs, mat):
    """TorchScript-compiled application of a 3x3 affine matrix to (x, y).

    NOTE: ``inputs`` is modified in place; the returned tensor is the same
    tensor that was passed in.

    Args:
        inputs: Rank-3 tensor whose last axis starts with the (x, y)
            coordinates. Channels after the first two are left untouched.
        mat: ``(3, 3)`` affine matrix. It is cast to the working dtype and
            moved to the device of ``inputs`` before use.

    Returns:
        ``inputs``, with ``inputs[:, :, :2]`` overwritten by the transformed
        coordinates.
    """
    xy = inputs[:, :, :2]
    # Build the homogeneous "1" column on the same device as the data so that
    # non-CPU inputs work; torch.cat then promotes to the wider dtype.
    ones = torch.ones([xy.shape[0], xy.shape[1], 1],
                      dtype=mat.dtype, device=inputs.device)
    xy = torch.cat([xy, ones], dim=-1)
    # einsum needs both operands in one dtype / on one device.
    xy = torch.einsum("...j,ij->...i", xy, mat.to(xy))
    # Drop the homogeneous coordinate and write back in place.
    inputs[:, :, :2] = xy[:, :, :-1]
    return inputs

In [17]:
# Work on a float32 tensor copy so the NumPy source array is never modified.
testtrack = torch.tensor(trackdata.copy().astype(np.float32))

# The 1st call is timed on its own because it may be much slower than later calls.
start = time.perf_counter()
mat = get_affine_matrix_2d_torch_jit(center, trans, scale, rot, skew, dtype=dtype)
newtrack = apply_affine_torch_jit(testtrack, mat)
interval = time.perf_counter() - start
print("Time of first call")
print_perf_time(np.array(interval))

# Evaluate difference against the reference result.
# Absolute differences are summed so that positive and negative errors
# cannot cancel each other and hide a mismatch.
diff = np.abs(np.round(newtrack.detach().cpu().numpy()) - np.round(refdata)).sum()

testtrack = torch.tensor(trackdata.copy().astype(np.float32))

# NOTE: apply_affine_torch_jit writes into its input, so every trial
# transforms the output of the previous one. Only the timing is used.
intervals = []
for _ in range(TRIALS):
    start = time.perf_counter()
    mat = get_affine_matrix_2d_torch_jit(center, trans, scale, rot, skew, dtype=dtype)
    newtrack = apply_affine_torch_jit(testtrack, mat)
    end = time.perf_counter()
    intervals.append(end - start)
intervals = np.array(intervals)
print("Time after second call")
print_perf_time(intervals, TOPK)

print(f"Sum of error:{diff}")

Time of first call
Summary: Max 244.887ms, Min 244.887ms, Mean +/- Std 244.887ms +/-   0s
Time after second call
Summary: Max 2.74402ms, Min 2.58267ms, Mean +/- Std 2.66269ms +/- 51.9464µs
Sum of error:0.0


In [18]:
# Same benchmark as the previous cell, but the input is testtrack[:-1]
# (last entry of the first axis dropped), i.e. a different input shape
# for the scripted functions.
testtrack = torch.tensor(trackdata.copy().astype(np.float32))

# The 1st call with the new shape is timed on its own.
start = time.perf_counter()
mat = get_affine_matrix_2d_torch_jit(center, trans, scale, rot, skew, dtype=dtype)
newtrack = apply_affine_torch_jit(testtrack[:-1], mat)
interval = time.perf_counter() - start
print("Time of first call")
print_perf_time(np.array(interval))

# Start the repeated trials from a fresh copy of the data.
testtrack = torch.tensor(trackdata.copy().astype(np.float32))

# NOTE: testtrack[:-1] is a view and apply_affine_torch_jit writes into its
# input, so every trial transforms the output of the previous one.
intervals = []
for _ in range(TRIALS):
    start = time.perf_counter()
    mat = get_affine_matrix_2d_torch_jit(center, trans, scale, rot, skew, dtype=dtype)
    newtrack = apply_affine_torch_jit(testtrack[:-1], mat)
    end = time.perf_counter()
    intervals.append(end - start)
intervals = np.array(intervals)
print("Time after second call")
print_perf_time(intervals, TOPK)

Time of first call
Summary: Max 3.45068ms, Min 3.45068ms, Mean +/- Std 3.45068ms +/-   0s
Time after second call
Summary: Max 2.61029ms, Min 2.49749ms, Mean +/- Std 2.56817ms +/- 36.3041µs


# 4. Application to randomized transformation

## 4.1 Implementation 1: Call the JIT functions from a Python process

In [19]:
class RandomAffineTransform2D_Torch():
    """Randomly apply a 2D affine transform to the (x, y) channels of a track.

    The transform pivots around the mean position of ``center_joints`` and
    its parameters are drawn uniformly from the configured ranges.

    Args:
        center_joints: Index or list of indices (second axis of the input)
            used to compute the pivot.
        apply_ratio: The transform is applied when a uniform sample in
            [0, 1) is smaller than this value.
        trans_range: ``[low, high]`` of the translation.
        scale_range: ``[low, high]`` of the scale factor.
        rot_range: ``[low, high]`` of the rotation, in degrees.
        skew_range: ``[low, high]`` of the skew, in degrees.
        random_seed: Optional seed for the private random generator.
        device: Device of the random generator and of the sampled parameters.
        dtype: dtype of the affine matrix.
    """
    def __init__(self,
                 center_joints,
                 apply_ratio,
                 trans_range,
                 scale_range,
                 rot_range,
                 skew_range,
                 random_seed=None,
                 device="cpu",
                 dtype=torch.float32):

        self.center_joints = center_joints
        if isinstance(self.center_joints, int):
            self.center_joints = [self.center_joints]

        self.apply_ratio = apply_ratio
        self.trans_range = trans_range
        self.scale_range = scale_range
        # Angles are configured in degrees and stored in radians.
        self.rot_range = np.radians(rot_range).tolist()
        self.skew_range = np.radians(skew_range).tolist()
        self.dtype = dtype
        self.device = device
        self.rng = torch.Generator(device=device)
        if random_seed is not None:
            self.rng.manual_seed(random_seed)

    def _rand(self, size):
        """Draw ``size`` uniform samples in [0, 1) from the private generator."""
        # The sampled tensor must live on the generator's device, otherwise
        # torch.rand rejects the generator.
        return torch.rand(size, generator=self.rng, device=self.device)

    def __call__(self, inputs):
        """Transform ``inputs`` (rank-3 tensor) and return it.

        NOTE: when the transform is applied, ``inputs`` is passed to
        ``apply_affine_torch_jit``, which writes into it in place.
        Presumably the axes are (frames, joints, channels) — TODO confirm.
        """
        if self._rand(1) >= self.apply_ratio:
            return inputs

        temp = inputs[:, self.center_joints, :]
        temp = temp.reshape([inputs.shape[0], -1, inputs.shape[-1]])
        # Entries of the first axis whose center joints sum to zero are
        # excluded from the pivot computation.
        mask = temp.sum(dim=(1, 2)) != 0
        if not mask.any():
            # No usable entry: the mean would be NaN and would turn the whole
            # track into NaN, so leave the input untouched.
            return inputs
        # Use x and y only.
        center = temp[mask].mean(dim=0).mean(dim=0)[:2]

        # Random value in [0, 1).
        trans = self._rand(2)
        scale = self._rand(2)
        rot = self._rand(1)
        skew = self._rand(2)
        # Scale to target range.
        trans = (self.trans_range[1] - self.trans_range[0]) * trans + self.trans_range[0]
        scale = (self.scale_range[1] - self.scale_range[0]) * scale + self.scale_range[0]
        rot = (self.rot_range[1] - self.rot_range[0]) * rot + self.rot_range[0]
        skew = (self.skew_range[1] - self.skew_range[0]) * skew + self.skew_range[0]

        # Calculate matrix.
        mat = get_affine_matrix_2d_torch_jit(center, trans, scale, rot, skew,
            dtype=self.dtype)

        # Apply transform.
        inputs = apply_affine_torch_jit(inputs, mat)
        return inputs

In [20]:
# Augmenter settings: always apply, pivot on joints 11 and 12.
aug_kwargs = dict(
    center_joints=[11, 12],
    apply_ratio=1.0,
    trans_range=[-100.0, 100.0],
    scale_range=[0.5, 2.0],
    rot_range=[-30.0, 30.0],
    skew_range=[-30.0, 30.0],
    dtype=dtype,
)
aug_fn = RandomAffineTransform2D_Torch(**aug_kwargs)

In [21]:
# Collects the augmented result of every trial.
augtracks = []

# The 1st call is timed on its own because it may be much slower than later calls.
start = time.perf_counter()
temp = aug_fn(torch.tensor(trackdata.copy().astype(np.float32)))
interval = time.perf_counter() - start
print("Time of first call")
print_perf_time(np.array(interval))

# NOTE: the timed region also covers the array copy, the float32 cast and
# the tensor construction, not only the call to aug_fn.
intervals = []
for _ in range(TRIALS):
    start = time.perf_counter()
    augtracks.append(aug_fn(torch.tensor(trackdata.copy().astype(np.float32))))
    end = time.perf_counter()
    intervals.append(end - start)
intervals = np.array(intervals)
print("Time after second call")
print_perf_time(intervals, TOPK)

Time of first call
Summary: Max 51.1747ms, Min 51.1747ms, Mean +/- Std 51.1747ms +/-   0s
Time after second call
Summary: Max 4.80058ms, Min 4.2865ms, Mean +/- Std 4.48785ms +/- 189.306µs


In [22]:
# Same benchmark as the previous cell with the last entry of the first axis
# dropped, i.e. a different input shape.
augtracks = []

# The 1st call with the new shape is timed on its own.
# Here the NumPy array is sliced before the tensor is built.
start = time.perf_counter()
temp = aug_fn(torch.tensor(trackdata.copy().astype(np.float32)[:-1]))
interval = time.perf_counter() - start
print("Time of first call")
print_perf_time(np.array(interval))

# NOTE: in the loop the slice is taken on the tensor (a view) rather than on
# the NumPy array as above. The timed region also covers the array copy,
# the float32 cast and the tensor construction.
intervals = []
for _ in range(TRIALS):
    start = time.perf_counter()
    augtracks.append(aug_fn(torch.tensor(trackdata.copy().astype(np.float32))[:-1]))
    end = time.perf_counter()
    intervals.append(end - start)
intervals = np.array(intervals)
print("Time after second call")
print_perf_time(intervals, TOPK)

Time of first call
Summary: Max 5.91589ms, Min 5.91589ms, Mean +/- Std 5.91589ms +/-   0s
Time after second call
Summary: Max 4.17646ms, Min 4.00562ms, Mean +/- Std 4.12132ms +/- 52.0244µs


## 4.2 Implementation 2: Apply JIT to the whole affine process (JIT-compile an nn.Module)

In [23]:
class RandomAffineTransform2D_TorchModule(nn.Module):
    """``nn.Module`` variant of the random 2D affine augmentation.

    The transform pivots around the mean position of ``center_joints`` and
    its parameters are drawn uniformly from the configured ranges.

    Args:
        center_joints: Index or list of indices (second axis of the input)
            used to compute the pivot.
        apply_ratio: The transform is applied when a uniform sample in
            [0, 1) is smaller than this value.
        trans_range: ``[low, high]`` of the translation.
        scale_range: ``[low, high]`` of the scale factor.
        rot_range: ``[low, high]`` of the rotation, in degrees.
        skew_range: ``[low, high]`` of the skew, in degrees.
        random_seed: Accepted for interface parity but currently unused
            (no private generator is created).
        device: Accepted for interface parity but currently unused.
        dtype: dtype of the affine matrix.
    """
    def __init__(self,
                 center_joints,
                 apply_ratio,
                 trans_range,
                 scale_range,
                 rot_range,
                 skew_range,
                 random_seed=None,
                 device="cpu",
                 dtype=torch.float32):
        super().__init__()

        self.center_joints = center_joints
        if isinstance(self.center_joints, int):
            self.center_joints = [self.center_joints]

        self.apply_ratio = apply_ratio
        self.trans_range = trans_range
        self.scale_range = scale_range
        # Angles are configured in degrees and stored in radians.
        self.rot_range = np.radians(rot_range).tolist()
        self.skew_range = np.radians(skew_range).tolist()
        self.dtype = dtype
        # No private torch.Generator is kept here, so sampling uses the global
        # default generator. Presumably a Generator attribute could not be
        # compiled by torch.jit.script — TODO confirm.
        self.rng = None

    def forward(self, inputs):
        """Transform ``inputs`` (rank-3 tensor) and return it.

        NOTE: when the transform is applied, ``inputs`` is passed to
        ``apply_affine_torch_jit``, which writes into it in place.
        Presumably the axes are (frames, joints, channels) — TODO confirm.
        """
        if torch.rand(1, generator=self.rng) >= self.apply_ratio:
            return inputs

        temp = inputs[:, self.center_joints, :]
        temp = temp.reshape([inputs.shape[0], -1, inputs.shape[-1]])
        # Entries of the first axis whose center joints sum to zero are
        # excluded from the pivot computation.
        mask = temp.sum(dim=(1, 2)) != 0
        if not bool(mask.any()):
            # No usable entry: the mean would be NaN and would turn the whole
            # track into NaN, so leave the input untouched.
            return inputs
        # Use x and y only.
        center = temp[mask].mean(dim=0).mean(dim=0)[:2]

        # Random value in [0, 1).
        trans = torch.rand(2, generator=self.rng)
        scale = torch.rand(2, generator=self.rng)
        rot = torch.rand(1, generator=self.rng)
        skew = torch.rand(2, generator=self.rng)
        # Scale to target range.
        trans = (self.trans_range[1] - self.trans_range[0]) * trans + self.trans_range[0]
        scale = (self.scale_range[1] - self.scale_range[0]) * scale + self.scale_range[0]
        rot = (self.rot_range[1] - self.rot_range[0]) * rot + self.rot_range[0]
        skew = (self.skew_range[1] - self.skew_range[0]) * skew + self.skew_range[0]

        # Calculate matrix.
        mat = get_affine_matrix_2d_torch_jit(center, trans, scale, rot, skew,
            dtype=self.dtype)

        # Apply transform.
        inputs = apply_affine_torch_jit(inputs, mat)
        return inputs

In [24]:
# Augmenter settings: always apply, pivot on joints 11 and 12.
aug_kwargs = dict(
    center_joints=[11, 12],
    apply_ratio=1.0,
    trans_range=[-100.0, 100.0],
    scale_range=[0.5, 2.0],
    rot_range=[-30.0, 30.0],
    skew_range=[-30.0, 30.0],
    dtype=dtype,
)
# Build the module and compile the whole thing with TorchScript in one go.
aug_fn = torch.jit.script(RandomAffineTransform2D_TorchModule(**aug_kwargs))

In [25]:
# Collects the augmented result of every trial.
augtracks = []

# The 1st call is timed on its own because it may be much slower than later calls.
start = time.perf_counter()
temp = aug_fn(torch.tensor(trackdata.copy().astype(np.float32)))
interval = time.perf_counter() - start
print("Time of first call")
print_perf_time(np.array(interval))

# NOTE: the timed region also covers the array copy, the float32 cast and
# the tensor construction, not only the call to aug_fn.
intervals = []
for _ in range(TRIALS):
    start = time.perf_counter()
    augtracks.append(aug_fn(torch.tensor(trackdata.copy().astype(np.float32))))
    end = time.perf_counter()
    intervals.append(end - start)
intervals = np.array(intervals)
print("Time after second call")
print_perf_time(intervals, TOPK)

Time of first call
Summary: Max 40.84ms, Min 40.84ms, Mean +/- Std 40.84ms +/-   0s
Time after second call
Summary: Max 3.02665ms, Min 2.61142ms, Mean +/- Std 2.85113ms +/- 126.827µs


In [26]:
# Same benchmark as the previous cell with the last entry of the first axis
# dropped, i.e. a different input shape for the scripted module.
augtracks = []

# The 1st call with the new shape is timed on its own.
# Here the NumPy array is sliced before the tensor is built.
start = time.perf_counter()
temp = aug_fn(torch.tensor(trackdata.copy().astype(np.float32)[:-1]))
interval = time.perf_counter() - start
print("Time of first call")
print_perf_time(np.array(interval))

# NOTE: in the loop the slice is taken on the tensor (a view) rather than on
# the NumPy array as above. The timed region also covers the array copy,
# the float32 cast and the tensor construction.
intervals = []
for _ in range(TRIALS):
    start = time.perf_counter()
    augtracks.append(aug_fn(torch.tensor(trackdata.copy().astype(np.float32))[:-1]))
    end = time.perf_counter()
    intervals.append(end - start)
intervals = np.array(intervals)
print("Time after second call")
print_perf_time(intervals, TOPK)

Time of first call
Summary: Max 11.5391ms, Min 11.5391ms, Mean +/- Std 11.5391ms +/-   0s
Time after second call
Summary: Max 2.71104ms, Min 2.48304ms, Mean +/- Std 2.61997ms +/- 76.097µs
