In [1]:
from numba import cuda
import numpy as np
from timeit import Timer

In [2]:
@cuda.jit
def increment_data(data_arr):
    blockIdx = cuda.blockIdx.x
    blockDim = cuda.blockDim.x
    threadIdx = cuda.threadIdx.x
    # Compute flattened index inside the array
    pos = blockIdx * blockDim + threadIdx
    if threadIdx < blockDim:
        data_arr[pos] += (blockIdx + threadIdx)

In [3]:
def push_random_data_and_increment_timer(
    num_runs,
    num_envs,
    num_agents
):
    blocksPerGrid = num_envs
    threadsPerBlock = num_agents
    cuda.gridDim = (num_envs,)
    cuda.blockDim = (num_agents,)
    
    def push_random_data(num_agents, num_envs, device_data):
        random_data = np.random.rand(num_envs, num_agents)

        # Flatten in "C" order
        flattened_random_data = random_data.astype(np.float32).flatten(order="C")
        
        device_data["random_data"] = cuda.to_device(
            flattened_random_data
        )
        return device_data
      
    device_data = {}
    data_push_time = Timer(
        lambda: push_random_data(num_agents, num_envs, device_data)
    ).timeit(number=num_runs)
    program_run_time = Timer(
        lambda: increment_data[blocksPerGrid, threadsPerBlock](
            cuda.as_cuda_array(device_data["random_data"]))
    ).timeit(number=num_runs)

    return {
        "data push times": data_push_time,
        "code run time": program_run_time
    }

In [4]:
num_runs = 100
times = {}

for scenario in [
    (1, 1),
    (1, 100),
    (1, 1000),
    (100, 1000),
    (1000, 1000)
]:
    num_envs, num_agents = scenario
    times.update(
        {
            f"envs={num_envs}, agents={num_agents}":
            push_random_data_and_increment_timer(
                num_runs,
                num_envs,
                num_agents,
            )            
        }
    )
    
print(f"Times for {num_runs} function calls")
print("*"*40)
for key, value in times.items():
    print(f"{key:30}: mean data push times: {value['data push times']:10.5}s,\t mean increment times: {value['code run time']:10.5}s")    



Times for 100 function calls
****************************************
envs=1, agents=1              : mean data push times:    0.15361s,	 mean increment times:    0.31087s
envs=1, agents=100            : mean data push times:   0.027738s,	 mean increment times:   0.013628s
envs=1, agents=1000           : mean data push times:   0.028772s,	 mean increment times:    0.01363s
envs=100, agents=1000         : mean data push times:    0.13333s,	 mean increment times:   0.013784s
envs=1000, agents=1000        : mean data push times:     1.4457s,	 mean increment times:   0.012887s


WarpDrive times

```
Times for 100 function calls
****************************************
envs=1, agents=1              : mean data push times:   0.025804s,	 mean increment times:  0.0012021s
envs=1, agents=100            : mean data push times:   0.027073s,	 mean increment times:  0.0012057s
envs=1, agents=1000           : mean data push times:   0.026155s,	 mean increment times:   0.001203s
envs=100, agents=1000         : mean data push times:    0.12081s,	 mean increment times:  0.0011974s
envs=1000, agents=1000        : mean data push times:    0.98343s,	 mean increment times:  0.0012207s
```