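This notebook demonstrates a more memory- and time-efficient approach to rotating (scattering) large arrays of direction vectors: a `cupyx.jit.rawkernel` updates the vectors in place, so the kernel allocates no extra device arrays.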

In [1]:
import numpy as np
import cupy as cp
from cupyx import jit
from typing import Tuple
mempool = cp.get_default_memory_pool()
print(f"CuPy version {cp.__version__}")
# Drop any cached blocks so the counters below start from an empty pool and
# later readings reflect only this notebook's allocations.
mempool.free_all_blocks()
print(f"mempool.used_bytes {mempool.used_bytes()}")
print(f"mempool.total_bytes {mempool.total_bytes()}")

CuPy version 11.0.0
mempool.used_bytes 0
mempool.total_bytes 0


In [2]:
@jit.rawkernel(device=True)
def any_perpendicular(vx: float, vy: float, vz: float) -> Tuple[float, float, float]:
    """ Return a (not normalized) vector perpendicular to (vx, vy, vz).

    Both candidates are orthogonal to v by construction:
    (vy, -vx, 0) . v == 0 and (0, -vz, vy) . v == 0.
    The choice compares magnitudes, not signed values, so the result is
    non-zero for any non-zero v: the first candidate vanishes only when
    vx == vy == 0, which |vz| < |vx| excludes; the second vanishes only when
    vy == vz == 0, and then |vx| <= |vz| == 0 forces v == 0.
    (A signed test such as vz < vx returned (0, 0, 0) for v = (0, 0, -1).) """
    # comparing squares is |vz| < |vx| using only arithmetic in the kernel
    if vz * vz < vx * vx:
        return (vy, -vx, 0.0)
    return (0.0, -vz, vy)

@jit.rawkernel(device=True)
def normalize(x: float, y: float, z: float) -> Tuple[float, float, float]:
    """ Scale (x, y, z) to unit length; a zero vector yields NaNs (0 / 0). """
    length = cp.sqrt(x * x + y * y + z * z)
    return x / length, y / length, z / length

@jit.rawkernel(device=True)
def unitary_perpendicular(vx: float, vy: float, vz: float) -> Tuple[float, float, float]:
    """ Return a unit-length vector perpendicular to (vx, vy, vz). """
    px, py, pz = any_perpendicular(vx, vy, vz)
    return normalize(px, py, pz)

@jit.rawkernel(device=True)
def do_rotation(X: float, Y: float, Z: float,
                ux: float, uy: float, uz: float,
                theta: float) -> Tuple[float, float, float]:
    """ Rotate the vector (X, Y, Z) by angle theta (radians) about the axis
    (ux, uy, uz), using the axis-angle rotation matrix (Rodrigues' formula).

    The matrix is only a rotation when (ux, uy, uz) has unit length; a
    non-unit axis can change the length of the result.
    Returns the rotated vector as (x, y, z). """
    cost = cp.cos(theta)
    sint = cp.sin(theta)
    one_cost = 1 - cost

    # Each output component is one row of
    # R = cos(theta) * I + (1 - cos(theta)) * u u^T + sin(theta) * [u]_x
    # applied to (X, Y, Z), where [u]_x is the cross-product matrix of u.
    x = (cost + ux * ux * one_cost) * X + (ux * uy * one_cost - uz * sint) * Y + (
            ux * uz * one_cost + uy * sint) * Z
    y = (uy * ux * one_cost + uz * sint) * X + (cost + uy * uy * one_cost) * Y + (
            uy * uz * one_cost - ux * sint) * Z
    z = (uz * ux * one_cost - uy * sint) * X + (uz * uy * one_cost + ux * sint) * Y + (
            cost + uz * uz * one_cost) * Z

    return (x, y, z)

@jit.rawkernel()
def scatter(vx: cp.ndarray, vy: cp.ndarray, vz: cp.ndarray,
            theta: cp.ndarray, phi: cp.ndarray, size: int) -> None:
    """ Scatter each direction (vx[i], vy[i], vz[i]) in place.

    The direction is deflected by the angle theta[i] (radians) around a
    perpendicular axis chosen by the azimuth phi[i] (radians). The length of
    each vector is preserved, and the input need not be unit length.
    Elements are visited with a grid-stride loop, so any launch
    configuration covers all `size` elements. """
    tid = jit.blockIdx.x * jit.blockDim.x + jit.threadIdx.x
    ntid = jit.gridDim.x * jit.blockDim.x
    for i in range(tid, size, ntid):
        # read each component once; the arrays are overwritten below
        px = vx[i]
        py = vy[i]
        pz = vz[i]

        # do_rotation needs a unit axis, but the stored direction is not
        # guaranteed to be unit length, so rotate about a normalized copy
        (ax, ay, az) = normalize(px, py, pz)

        (ux, uy, uz) = unitary_perpendicular(px, py, pz)

        # first rotate the perpendicular around the photon axis
        (ux, uy, uz) = do_rotation(ux, uy, uz, ax, ay, az, phi[i])

        # then rotate the photon around that perpendicular; u is unit length,
        # so the rotation keeps |v| unchanged
        (vx[i], vy[i], vz[i]) = do_rotation(px, py, pz, ux, uy, uz, theta[i])

  cupy._util.experimental('cupyx.jit.rawkernel')


In [3]:
# Fixed seed so the generated vectors (and any checks on them) are reproducible.
cp.random.seed(0)
size = 50_000_000
vx = cp.random.random(size, dtype=np.float32)
print(f"mempool.used_bytes {mempool.used_bytes()}")
vy = cp.random.random(size, dtype=np.float32)
print(f"mempool.used_bytes {mempool.used_bytes()}")
vz = cp.random.random(size, dtype=np.float32)
print(f"mempool.used_bytes {mempool.used_bytes()}")
phi = cp.random.random(size, dtype=np.float32)
theta = cp.full(size, np.pi/2, dtype=np.float32)
print(f"mempool.used_bytes {mempool.used_bytes()}")
## Copies of the original directions for the perpendicularity check later on.
## They are commented out to keep memory use low; uncomment them (each float32
## copy of `size` elements adds 200 MB) to probe the GPU memory limit.
#rx = cp.copy(vx)
#print(rx)
#print(vx)
#ry = cp.copy(vy)
#rz = cp.copy(vz)

mempool.used_bytes 200000000
mempool.used_bytes 400000000
mempool.used_bytes 600000000
mempool.used_bytes 1000000000


In [4]:
%%time
scatter((128,),(1024,),(vx, vy, vz, theta, phi, size))

CPU times: user 312 ms, sys: 11.8 ms, total: 324 ms
Wall time: 322 ms


In [5]:
%%time
scatter((128,),(1024,),(vx, vy, vz, theta, phi, size))

CPU times: user 90 µs, sys: 0 ns, total: 90 µs
Wall time: 92.5 µs


In [6]:
# Compare with the reading taken right after allocating the inputs: the
# in-place kernel should not have increased device memory use.
print(f"mempool.used_bytes {mempool.used_bytes()}")
# check perpendicularity
# NOTE(review): with theta = pi/2 the result is guaranteed perpendicular to the
# saved copy (rx, ry, rz) only after a *single* scatter call; the two timing
# cells above each call scatter, so run only one of them before this check.
# NOTE(review): cp.amax(dot) ignores large negative dot products;
# cp.amax(cp.abs(dot)) would be the stricter test.
#print(rx)
#print(vx)
#dot = (vx * rx + vy * ry + vz * rz)
#print(cp.amax(dot))

mempool.used_bytes 1000000000


In [7]:
del vx, vy, vz, theta, phi
#del rx, ry, rz
# `del` only hands the blocks back to CuPy's memory pool, which keeps them
# cached; release the unused blocks to the device as well.
mempool.free_all_blocks()
print(f"mempool.used_bytes {mempool.used_bytes()}")
print(f"mempool.total_bytes {mempool.total_bytes()}")