# Use numba.cuda with cupy

In [1]:
import cupy as cp
from numba import cuda

@cuda.jit
def add(x, y, out):
        start = cuda.grid(1)
        stride = cuda.gridsize(1)
        for i in range(start, x.shape[0], stride):
                out[i] = x[i] + y[i]

a = cp.arange(10, dtype=float)
b = a * 2
out = cp.zeros_like(a)
add[1, 32](a, b, out)
print(out)

[ 0.  3.  6.  9. 12. 15. 18. 21. 24. 27.]


In [2]:
type(out)

cupy.core.core.ndarray

In [3]:
out.dtype

dtype('float64')

In [4]:
out.flags

  C_CONTIGUOUS : True
  F_CONTIGUOUS : True
  OWNDATA : True

# Numba share memory

## First, one without shared memory

In [7]:
from math import log2 as math_log2

@cuda.jit
def pt_wf_keep_fisrt_1_kernel(iwf, iwf_conj, rho, m, m_idx, n_idx):
    i, j = cuda.grid(2)
    if i >= rho.shape[0] or j >= rho.shape[1]:
        return
    i_shift = i << m
    j_shift = j << m
    for k in range(m_idx):
        rho[i_shift, j_shift] += iwf[i_shift + k] * iwf_conj[j_shift+ k]
def pt_wf_keep_fisrt_1(iwf: cp.ndarray, n):
    iwf_conj = iwf.conj()
    nqb = int(math_log2(iwf.shape[0]))
    m = nqb - n
    m_idx = 2 ** m
    n_idx = 2 ** n

    rho = cp.zeros(shape=(n_idx, n_idx), dtype=iwf.dtype, order='C')
    rho_max_length = rho.shape[0] * rho.shape[1]
    # Here we simply use the threadDim for i, j in the cuda code.
    threadDim = (n_idx, n_idx)
    blockDim = (1,)
    # TODO cleverly adjust the blockDim to deal with the other case
    pt_wf_keep_fisrt_1_kernel[blockDim, threadDim](
        iwf, iwf_conj, rho, m, m_idx, n_idx
    )
    return rho

In [8]:
wf = cp.ones(shape=(2)) / cp.sqrt(2)

rho = cp.ones(shape=(2,2), dtype=wf.dtype)

pt_wf_keep_fisrt_1(wf, 0)

array([[1.]])

## Second, one with cupy built-in functions

In [28]:
def pt_wf_keep_fisrt_2(iwf, n):
    n_range = 2 ** n
    m_range = 2 ** (int(math_log2(iwf.shape[0])) - n)
    iwf = iwf.reshape(n_range, m_range, order='C')
    iwf_conj = iwf.conj()
    ret = cp.zeros(shape=(n_range, n_range), dtype=iwf.dtype)
    for i in range(n_range):
        for j in range(n_range):
            ret[i,j] = iwf[i,:].dot(iwf_conj[j,:])
    return ret

In [29]:
pt_wf_keep_fisrt_2(wf, 0)

array([[1.]])

## Compare first and second

In [30]:
arr = cp.ones(shape=(2**18), dtype=cp.complex64)
arr /= cp.linalg.norm(arr, ord=2)

In [32]:
pt_wf_keep_fisrt_2(arr, 2)

array([[0.25+0.j, 0.25+0.j, 0.25+0.j, 0.25+0.j],
       [0.25+0.j, 0.25+0.j, 0.25+0.j, 0.25+0.j],
       [0.25+0.j, 0.25+0.j, 0.25+0.j, 0.25+0.j],
       [0.25+0.j, 0.25+0.j, 0.25+0.j, 0.25+0.j]], dtype=complex64)

In [33]:
from mlec.q_toolkit.cupy_impl import partial_trace_wf_keep_first_cupy

In [34]:
partial_trace_wf_keep_first_cupy(arr, 2)

array([[0.25+0.j, 0.25+0.j, 0.25+0.j, 0.25+0.j],
       [0.25+0.j, 0.25+0.j, 0.25+0.j, 0.25+0.j],
       [0.25+0.j, 0.25+0.j, 0.25+0.j, 0.25+0.j],
       [0.25+0.j, 0.25+0.j, 0.25+0.j, 0.25+0.j]], dtype=complex64)

In [35]:
def timeit(f, count=1000, **kwargs):
    import time
    t0 = time.time()
    for _ in range(count):
        f(**kwargs)
    print(f"Time costs for {count} loops: {time.time()-t0}")

In [39]:
timeit(pt_wf_keep_fisrt_2, iwf=arr, n=2)

Time costs for 1000 loops: 0.6749603748321533


In [40]:
timeit(partial_trace_wf_keep_first_cupy, iwf=arr, n=2)

Time costs for 1000 loops: 3.800832509994507


However, in my experiment with 833 code, the `partial_trace_wf_keep_first_cupy` performs way better than the `pt_wf_keep_fisrt_2`. I am not sure why, but one reason maybe that in 833 code, I need to permute the qubit indices to move the `retain_qubits` to the front, which might be causing the performance degragation.