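Adapted from the CuPy version of PyTissueOptics: GPU sampling of Henyey–Greenstein scattering angles and rotation of photon direction vectors with `cupyx.jit` kernels.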

In [1]:
import numpy as np
import cupy as cp
from cupyx import jit
from typing import Tuple
# Record the CuPy version alongside the outputs (print joins the args with a space).
print("CuPy version", cp.__version__)

CuPy version 11.0.0


In [2]:
mempool = cp.get_default_memory_pool()
# Drop the pool's cached blocks, then report what it still counts as used
# (the captured output below shows 0).
mempool.free_all_blocks()
print("mempool.used_bytes", mempool.used_bytes())

mempool.used_bytes 0


In [3]:
@jit.rawkernel()
def _hanley(random_inout: cp.ndarray, g: float, size: int) -> None:
    """Turn pre-drawn random numbers into Henyey-Greenstein polar angles, in place.

    Each ``random_inout[i]`` (for i < size) must hold ``2 * g * xi`` with
    ``xi`` uniform on [0, 1); it is overwritten with ``arccos(cos_theta)``,
    where
        cos_theta = (1 + g^2 - ((1 - g^2) / (1 - g + 2 g xi))^2) / (2 g).
    g must be non-zero: the formula divides by 2 * g.
    """
    # Grid-stride loop: correct for any launch configuration and any size.
    tid = jit.blockIdx.x * jit.blockDim.x + jit.threadIdx.x
    ntid = jit.gridDim.x * jit.blockDim.x
    for i in range(tid, size, ntid):
        temp = (1 - g * g) / (1 - g + random_inout[i])
        cost = (1 + g * g - temp * temp) / (2 * g)
        # Rounding can push cost marginally outside [-1, 1], where arccos
        # returns NaN; clamp to the valid domain first.
        if cost > 1.0:
            cost = 1.0
        if cost < -1.0:
            cost = -1.0
        random_inout[i] = cp.arccos(cost)

def get_scattering_theta(g: float, size: int) -> cp.ndarray:
    """Sample `size` polar scattering angles (radians) from the Henyey-Greenstein phase function.

    Args:
        g: anisotropy parameter of the phase function, in [-1, 1].
        size: number of angles to draw.

    Returns:
        A float32 CuPy array of `size` angles.

    Raises:
        ValueError: if g is outside [-1, 1] (or is NaN).
    """
    if not -1.0 <= g <= 1.0:
        raise ValueError(f"anisotropy g must be in [-1, 1], got {g}")
    if g == 0:
        # The Henyey-Greenstein inversion in _hanley divides by 2 * g and would
        # produce NaN; for g == 0 the distribution is isotropic, i.e.
        # cos(theta) is uniform on [-1, 1].
        xi = cp.random.uniform(0, 1, size, dtype=np.float32)
        return cp.arccos(1 - 2 * xi)
    # _hanley expects 2 * g * xi with xi uniform on [0, 1) and overwrites the
    # buffer in place with theta; its grid-stride loop covers any `size`.
    random_input = cp.random.uniform(0, 2 * g, size, dtype=np.float32)
    _hanley((128,), (1024,), (random_input, g, size))
    return random_input

  cupy._util.experimental('cupyx.jit.rawkernel')


In [4]:
def get_scattering_phi(size: int, dtype=np.float64) -> cp.ndarray:
    """Sample `size` azimuthal scattering angles uniformly on [0, 2*pi).

    Args:
        size: number of angles to draw (an integer count).
        dtype: floating dtype of the result. The default, float64, is what
            ``cp.random.uniform`` produces without a dtype. Pass ``np.float32``
            to match the float32 output of ``get_scattering_theta`` and halve
            the memory used.

    Returns:
        A CuPy array of shape (size,) with angles in radians.
    """
    return cp.random.uniform(0, 2 * np.pi, size, dtype=dtype)

In [5]:
@jit.rawkernel(device=True)
def any_perpendicular(vx: float, vy: float, vz: float) -> Tuple[float, float, float]:
    """Return some (not normalized) vector orthogonal to v = (vx, vy, vz).

    Both candidates, (vy, -vx, 0) and (0, -vz, vy), are orthogonal to v.
    However, the first collapses to zero when vx == vy == 0, and the second
    when vy == vz == 0. Choosing by magnitude (|vz| < |vx|) always returns a
    non-zero vector for non-zero v. A signed comparison (vz < vx) would pick
    the collapsing candidate for, e.g., v = (0, 0, -1) or v = (-1, 0, 0).
    """
    if cp.abs(vz) < cp.abs(vx):
        return (vy, -vx, 0.0)
    return (0.0, -vz, vy)

@jit.rawkernel(device=True)
def normalize(x: float, y: float, z: float) -> Tuple[float, float, float]:
    """Scale (x, y, z) to unit length.

    There is no guard for a zero-length input: the division by zero yields
    non-finite components.
    """
    length = cp.sqrt(x * x + y * y + z * z)
    return (x / length, y / length, z / length)

@jit.rawkernel(device=True)
def unitary_perpendicular(vx: float, vy: float, vz: float) -> Tuple[float, float, float]:
    """Unit-length vector orthogonal to (vx, vy, vz): any_perpendicular, then normalize."""
    (px, py, pz) = any_perpendicular(vx, vy, vz)
    return normalize(px, py, pz)

@jit.rawkernel(device=True)
def do_rotation(X: float, Y: float, Z: float, 
                ux: float, uy: float, uz: float,
                theta: float) -> Tuple[float, float, float]:
    """ Rotate v around u.

    v = (X, Y, Z) is rotated by `theta` radians about the axis
    u = (ux, uy, uz). The rotation uses the axis-angle (Rodrigues) matrix
        R = cos(theta) I + sin(theta) [u]_x + (1 - cos(theta)) u u^T,
    which is a right-handed rotation about u. R is only a rotation when u
    has unit length. u is not normalized here, so callers must pass a unit
    axis.

    Returns:
        The rotated vector as an (x, y, z) tuple.
    """
    cost = cp.cos(theta)
    sint = cp.sin(theta)
    one_cost = 1 - cost

    # x, y, z are the three rows of R applied to (X, Y, Z).
    x = (cost + ux * ux * one_cost) * X + (ux * uy * one_cost - uz * sint) * Y + (
            ux * uz * one_cost + uy * sint) * Z
    y = (uy * ux * one_cost + uz * sint) * X + (cost + uy * uy * one_cost) * Y + (
            uy * uz * one_cost - ux * sint) * Z
    z = (uz * ux * one_cost - uy * sint) * X + (uz * uy * one_cost + ux * sint) * Y + (
            cost + uz * uz * one_cost) * Z

    return (x, y, z)

@jit.rawkernel()
def scatter(vx: cp.ndarray, vy: cp.ndarray, vz: cp.ndarray,
            theta: cp.ndarray, phi: cp.ndarray, size: int) -> None:
    """Deflect each direction (vx[i], vy[i], vz[i]) in place by (theta[i], phi[i]).

    For every i < size, the kernel builds a unit vector perpendicular to the
    direction and rotates it by phi[i] about the direction. It then rotates
    the direction by theta[i] about that rotated perpendicular. vx, vy and
    vz are overwritten with the result.

    Both rotations go through do_rotation, which needs a unit-length axis,
    so the input directions are expected to be unit vectors.
    """
    # Grid-stride loop: thread tid handles elements tid, tid + ntid, ...
    tid = jit.blockIdx.x * jit.blockDim.x + jit.threadIdx.x
    ntid = jit.gridDim.x * jit.blockDim.x
    for i in range(tid, size, ntid):
        (ux, uy, uz) = unitary_perpendicular(vx[i], vy[i], vz[i])

        # first rotate the perpendicular around the photon axis
        (ux, uy, uz) = do_rotation(ux, uy, uz, vx[i], vy[i], vz[i], phi[i])

        # then rotate the photon around that perpendicular
        (vx[i], vy[i], vz[i]) = do_rotation(vx[i], vy[i], vz[i], ux, uy, uz, theta[i])

In [6]:
# Anisotropy parameter handed to the Henyey-Greenstein sampler (get_scattering_theta).
g = 0.9
# Number of directions / angle samples processed per scattering step.
size = 50_000_000

In [7]:
# Random *unit* direction vectors. scatter/do_rotation use each direction as
# a rotation axis, which is only a rotation for unit vectors, so normalize.
# Drawing components from a normal distribution makes the directions uniform
# over the sphere (all octants). cp.random.random would cover only the
# positive octant.
vx = cp.random.standard_normal(size, dtype=np.float32)
vy = cp.random.standard_normal(size, dtype=np.float32)
vz = cp.random.standard_normal(size, dtype=np.float32)
norm = cp.sqrt(vx * vx + vy * vy + vz * vz)
vx /= norm
vy /= norm
vz /= norm
del norm  # free the temporary device buffer

In [8]:
%%time
phi = get_scattering_phi(size)
theta = get_scattering_theta(g, size)
scatter((128,),(1024,),(vx, vy, vz, theta, phi, size))

CPU times: user 133 ms, sys: 2.63 ms, total: 136 ms
Wall time: 135 ms


In [9]:
%%time
phi = get_scattering_phi(size)
theta = get_scattering_theta(g, size)
scatter((128,),(1024,),(vx, vy, vz, theta, phi, size))

CPU times: user 0 ns, sys: 1.81 ms, total: 1.81 ms
Wall time: 1.34 ms
