In [1]:
import cupy as cp
import numpy as np
from scipy.spatial.transform import Rotation as R

In [2]:
mempool = cp.get_default_memory_pool()
pinned_mempool = cp.get_default_pinned_memory_pool()

In [3]:
ang = 30
x_angle = []
y_angle = []
z_angle = []

i = 0
while i < 360:
    x_angle.append(i)
    y_angle.append(i)
    i += ang

i = 0
while i <= 180:
    z_angle.append(i)
    i += ang

angle_comb = np.array(np.meshgrid(x_angle, y_angle, z_angle)).T.reshape(-1, 3)

In [4]:
def ang_to_mtx_ZYX(angle):
    r = R.from_euler("ZYX", angle, degrees=True)
    mtx = r.as_matrix()
    mtx[np.isclose(mtx, 0, atol=1e-15)] = 0
    return np.flip(mtx).T
    return np.flip(mtx).T.astype(np.float32)

In [5]:
%%time
rot_mtx_tensor = cp.array([ang_to_mtx_ZYX(ang) for ang in angle_comb])
rot_mtx_tensor.shape

Wall time: 214 ms


(1008, 3, 3)

In [6]:
%%time
dim = 48
cent = 0.5 * float(dim)
new_pos_arr = np.array(
    np.meshgrid(np.arange(dim), np.arange(dim), np.arange(dim),)
).T.reshape(-1, 3)
new_pos_arr = new_pos_arr - cent
new_pos_arr = new_pos_arr.astype(np.float32)

Wall time: 2.93 ms


In [7]:
new_pos_arr_gpu = cp.asarray(new_pos_arr)

In [8]:
%%time
new_pos_tensor = cp.repeat(
    new_pos_arr_gpu[cp.newaxis, :, :], len(rot_mtx_tensor), axis=0
)
new_pos_tensor.shape

Wall time: 125 ms


(1008, 110592, 3)

In [9]:
%%time
old_pos_tensor = cp.einsum("ikl,ijk->ijl", rot_mtx_tensor, new_pos_tensor, optimize=True).astype(cp.int32)
old_pos_tensor.shape

Wall time: 757 ms


(1008, 110592, 3)

In [10]:
print(mempool.used_bytes() / (1024**3))
print(mempool.total_bytes() / (1024**3))

2.4930028915405273
6.230551719665527


In [11]:
%%time
combined = np.hstack((old_pos_tensor.get(), new_pos_tensor.get()))

Wall time: 1.81 s


In [12]:
%%time
bool_mask_arr = (
    (old_pos_tensor[..., 0] >= 0)
    & (old_pos_tensor[..., 1] >= 0)
    & (old_pos_tensor[..., 2] >= 0)
    & (old_pos_tensor[..., 0] < dim)
    & (old_pos_tensor[..., 1] < dim)
    & (old_pos_tensor[..., 2] < dim)
)

Wall time: 31.3 ms


In [13]:
bool_mask_arr.shape

(1008, 110592)

In [14]:
print(mempool.used_bytes() / (1024**3))
print(mempool.total_bytes() / (1024**3))

2.5968236923217773
6.230551719665527


In [15]:
new_pos_tensor = None
old_pos_tensor = None
bool_mask_arr = None
mempool.free_all_blocks()
pinned_mempool.free_all_blocks()