In [2]:
import wgpu
from functools import reduce
import numpy as np
import matplotlib.pyplot as plt
from pprint import pprint
import math
import time
import torch
import torch_directml

In [3]:
def _request_gpu(power_preference):
    """Request an adapter with the given power preference and open a device on it.

    The device is created with the timestamp-query feature required, which the
    GPU timing cell relies on. Returns the ``(adapter, device)`` pair.
    """
    gpu_adapter = wgpu.gpu.request_adapter_sync(power_preference=power_preference)
    gpu_device = gpu_adapter.request_device_sync(
        required_features=[wgpu.FeatureName.timestamp_query]
    )
    return gpu_adapter, gpu_device


# Slot 0 asks for the high-performance adapter, slot 1 for the low-power one.
# Both may resolve to the same physical GPU on a single-GPU machine.
adapter_0, device_0 = _request_gpu("high-performance")
pprint(adapter_0.info)

adapter_1, device_1 = _request_gpu("low-power")
pprint(adapter_1.info)

{'adapter_type': 'IntegratedGPU',
 'architecture': '',
 'backend_type': 'Vulkan',
 'description': '24.12.1 (AMD proprietary shader compiler)',
 'device': 'AMD Radeon(TM) Graphics',
 'device_id': 5761,
 'vendor': 'AMD proprietary driver',
 'vendor_id': 4098}
{'adapter_type': 'IntegratedGPU',
 'architecture': '',
 'backend_type': 'Vulkan',
 'description': '24.12.1 (AMD proprietary shader compiler)',
 'device': 'AMD Radeon(TM) Graphics',
 'device_id': 5761,
 'vendor': 'AMD proprietary driver',
 'vendor_id': 4098}


In [4]:
# Choose which adapter/device pair the rest of the notebook runs on.
# Set the flag to True to use the pair requested with "high-performance".
use_high_performance = False
adapter, device = (
    (adapter_0, device_0) if use_high_performance else (adapter_1, device_1)
)

pprint(adapter.info)

{'adapter_type': 'IntegratedGPU',
 'architecture': '',
 'backend_type': 'Vulkan',
 'description': '24.12.1 (AMD proprietary shader compiler)',
 'device': 'AMD Radeon(TM) Graphics',
 'device_id': 5761,
 'vendor': 'AMD proprietary driver',
 'vendor_id': 4098}


In [5]:
# Grid extent along (x, y, z). The host arrays below are C-ordered with shape
# (x, y, z, component), so z varies fastest among the grid axes.
grid_size = [32, 128, 1024]
#grid_size = [17, 107, 152]
total_cells = reduce(lambda a, b: a * b, grid_size)

print(f"grid_size={grid_size}")
print(f"total_cells={total_cells}")

# Host-side fields: n_dims float32 components per cell.
n_dims = 3
x_cpu = np.zeros(grid_size + [n_dims], dtype=np.float32)
y_cpu = np.zeros(grid_size + [n_dims], dtype=np.float32)

# Fill the input through a flat view of x_cpu with a deterministic,
# non-constant pattern (a descending ramp folded into [0, 0.88490)).
x = x_cpu.reshape(n_dims*total_cells)
x[:] = (-1.0*np.arange(n_dims*total_cells, dtype=np.float32) + 0.5) % 0.88490

# Device buffers: x is read-only shader input, y is the shader output and the
# source of the copy into the mappable staging buffer y_gpu_readback.
x_gpu = device.create_buffer_with_data(
    data=x_cpu.data,
    usage=wgpu.BufferUsage.STORAGE,
)
y_gpu = device.create_buffer_with_data(
    data=y_cpu.data,
    usage=wgpu.BufferUsage.STORAGE | wgpu.BufferUsage.COPY_SRC,
)
y_gpu_readback = device.create_buffer(
    size=y_cpu.nbytes,
    usage=wgpu.BufferUsage.MAP_READ | wgpu.BufferUsage.COPY_DST,
)

# Bind group layout: binding 0 = read-only storage (x), binding 1 = storage (y).
binding_layouts = [
    {
        "binding": slot,
        "visibility": wgpu.ShaderStage.COMPUTE,
        "buffer": {"type": buffer_type},
    }
    for slot, buffer_type in enumerate(
        [wgpu.BufferBindingType.read_only_storage, wgpu.BufferBindingType.storage]
    )
]

# Layout objects consumed by the compute pipeline.
bind_group_layout = device.create_bind_group_layout(entries=binding_layouts)
pipeline_layout = device.create_pipeline_layout(bind_group_layouts=[bind_group_layout])

grid_size=[32, 128, 1024]
total_cells=4194304


In [6]:
# Build the stencil compute shader for the configured grid and time it on the
# GPU with timestamp queries.
#
# NOTE: the shader accumulates into y (`y[...] += ...`), so after this cell
# y_gpu holds len(gpu_samples) applications of the stencil; the reference cells
# replay the same number of iterations. Re-run the buffer-creation cell before
# re-running this one, otherwise y_gpu keeps accumulating on top of the
# previous run and the comparisons below no longer match.

# Workgroup shape (x, y, z) and the number of workgroups needed to cover the grid.
# local_size = [1, 8, 128]
local_size = [1, 4, 64]
dispatch_size = [math.ceil(g/l) for g,l in zip(grid_size, local_size)]
global_size = [d*l for d,l in zip(dispatch_size, local_size)]
# Per cell: 3 output components, each 3 subtractions plus 1 accumulate.
flops_per_cell = 12

print("### SEARCH ###")
print(f"local_size={local_size}")
print(f"dispatch_size={dispatch_size}")
print(f"grid_size={grid_size}")
print(f"global_size={global_size}")
# global_size may exceed grid_size when the grid is not a multiple of
# local_size; the bounds checks at the top of `main` discard those invocations.

# The grid extents are baked into the WGSL source. Neighbour indices are wrapped
# with a modulo in get_offset, i.e. the stencil uses periodic boundaries.
shader_source = f"""
@group(0) @binding(0)
var<storage,read> x: array<f32>;

@group(0) @binding(1)
var<storage,read_write> y: array<f32>;

fn get_offset(i: vec3<u32>) -> u32 {{
    let x: u32 = i.x % {grid_size[0]};
    let y: u32 = i.y % {grid_size[1]};
    let z: u32 = i.z % {grid_size[2]};
    let offset: u32 = z + y*{grid_size[2]} + x*{grid_size[2]*grid_size[1]};
    return offset*{n_dims};
}}

@compute
@workgroup_size({",".join(map(str, local_size))})
fn main(@builtin(global_invocation_id) i0: vec3<u32>) {{
    if (i0.x >= {grid_size[0]}) {{ return; }}
    if (i0.y >= {grid_size[1]}) {{ return; }}
    if (i0.z >= {grid_size[2]}) {{ return; }}
    let i = get_offset(i0);
    let iz = get_offset(i0 + vec3(0,0,1));
    let iy = get_offset(i0 + vec3(0,1,0));
    let ix = get_offset(i0 + vec3(1,0,0));
    y[i+0] += (x[i+2]-x[iy+2]) - (x[i+1]-x[iz+1]);
    y[i+1] += (x[i+0]-x[iz+0]) - (x[i+2]-x[ix+2]);
    y[i+2] += (x[i+1]-x[ix+1]) - (x[i+0]-x[iy+0]);
}}
"""

cshader = device.create_shader_module(code=shader_source)

# Create the compute pipeline
compute_pipeline = device.create_compute_pipeline(
    layout=pipeline_layout,
    compute={"module": cshader, "entry_point": "main"},
)

# QuerySet holding the 'beginning_of_pass' and 'end_of_pass' timestamps,
# hence a count of 2.
query_count = 2
query_set = device.create_query_set(type=wgpu.QueryType.timestamp, count=query_count)

# Buffer receiving the resolved query results. Each timestamp is 8 bytes.
# QUERY_RESOLVE is needed because the buffer is the destination of
# resolve_query_set; COPY_SRC lets queue.read_buffer copy it back to the host.
query_buf = device.create_buffer(
    size=8*query_set.count,
    usage=wgpu.BufferUsage.QUERY_RESOLVE | wgpu.BufferUsage.COPY_SRC,
)

# The bind group does not change between samples, so build it once up front
# instead of once per iteration.
bindings = [
    {
        "binding": 0,
        "resource": {"buffer": x_gpu, "offset": 0, "size": x_gpu.size},
    },
    {
        "binding": 1,
        "resource": {"buffer": y_gpu, "offset": 0, "size": y_gpu.size},
    },
]
bind_group = device.create_bind_group(layout=bind_group_layout, entries=bindings)

# Sampling policy: collect up to max_samples, but stop early once at least
# min_samples were taken and the accumulated GPU time exceeds the timeout.
gpu_samples = []
total_ns = 0
max_samples = 128
min_samples = 3
sampling_timeout_ms = 5

for sample_idx in range(max_samples):
    command_encoder = device.create_command_encoder()

    # Timestamps are written to query indices 0 (pass begin) and 1 (pass end).
    compute_pass_0 = command_encoder.begin_compute_pass(
        timestamp_writes={
            "query_set": query_set,
            "beginning_of_pass_write_index": 0,
            "end_of_pass_write_index": 1,
        }
    )
    compute_pass_0.set_pipeline(compute_pipeline)
    compute_pass_0.set_bind_group(0, bind_group)
    compute_pass_0.dispatch_workgroups(*dispatch_size)  # x y z
    compute_pass_0.end()

    # Resolve the queries into the destination buffer created above.
    command_encoder.resolve_query_set(
        query_set=query_set,
        first_query=0,
        query_count=query_count,
        destination=query_buf,
        destination_offset=0,
    )
    device.queue.submit([command_encoder.finish()])

    # Read the query buffer: index 0 is the beginning timestamp, index 1 the end.
    timestamps_ns = device.queue.read_buffer(query_buf).cast("Q").tolist()
    delta_ns = timestamps_ns[1] - timestamps_ns[0]
    gpu_samples.append(delta_ns)
    total_ns += delta_ns
    # Compare against the number of samples actually collected; testing
    # sample_idx here would require one sample more than min_samples.
    if len(gpu_samples) >= min_samples and total_ns > sampling_timeout_ms*1e6:
        print("Exceeded timeout limit for sample collection")
        break

# Only the final contents of y are read back later, so stage the copy into the
# mappable buffer once after sampling rather than on every iteration.
readback_encoder = device.create_command_encoder()
readback_encoder.copy_buffer_to_buffer(
    source=y_gpu,
    source_offset=0,
    destination=y_gpu_readback,
    destination_offset=0,
    size=y_cpu.data.nbytes
)
device.queue.submit([readback_encoder.finish()])

gpu_delta_ns = np.array(gpu_samples)
gpu_delta_ns_avg = np.mean(gpu_delta_ns)
print(f"cell_count={total_cells}")
print(f"gpu_delta_avg={gpu_delta_ns_avg*1e-3:.3f} us")
gpu_cell_rate = total_cells / (gpu_delta_ns_avg*1e-9)
print(f"gpu_cell_rate={gpu_cell_rate*1e-6:.3f} M/s")
print(f"total_samples={len(gpu_samples)}")
gpu_flops = flops_per_cell*gpu_cell_rate
print(f"gpu_flops={gpu_flops*1e-9:.3f} GFlops")

### SEARCH ###
local_size=[1, 4, 64]
dispatch_size=[32, 32, 16]
grid_size=[32, 128, 1024]
global_size=[32, 128, 1024]
Exceeded timeout limit for sample collection
cell_count=4194304
gpu_delta_avg=368.223 us
gpu_cell_rate=11390.658 M/s
total_samples=14
gpu_flops=136.688 GFlops


In [7]:
# Map the staging buffer, copy its bytes to the host, and view them as a
# float32 array shaped like the grid with n_dims components per cell.
y_nbytes = y_cpu.nbytes
y_gpu_readback.map_sync(wgpu.MapMode.READ, 0, y_nbytes)
y_gpu_pred_memview = y_gpu_readback.read_mapped(buffer_offset=0, size=y_nbytes, copy=True)
y_gpu_pred = np.frombuffer(y_gpu_pred_memview, dtype=np.float32).reshape(grid_size + [n_dims])
# The data was copied out (copy=True), so the buffer can be released right away.
y_gpu_readback.unmap()

In [8]:
def cpu_shader(x, y, wrap_around=True):
    """Accumulate the forward-difference stencil of ``x`` into ``y`` in place.

    Both arguments are indexed as ``[i, j, k, component]`` with components
    0..2. Writing ``d_a(c)`` for ``x[..., c]`` minus its next neighbour along
    grid axis ``a``, the update is::

        y[..., 0] += d_1(2) - d_2(1)
        y[..., 1] += d_2(0) - d_0(2)
        y[..., 2] += d_0(1) - d_1(0)

    Args:
        x: input field; only read.
        y: output field; updated in place.
        wrap_around: when true, the neighbour of the last index along an axis
            is index 0. When false, only cells for which both required
            neighbours exist are updated and the remaining entries of ``y``
            are left untouched.

    Returns:
        None. Only slicing and in-place arithmetic are used, so the same code
        is applied to both NumPy arrays and torch tensors in this notebook.
    """
    everything = slice(None)
    head = slice(None, -1)  # every index except the last
    tail = slice(1, None)   # every index except the first

    if not wrap_around:
        y[everything, head, head, 0] += (
            (x[everything, head, head, 2] - x[everything, tail, head, 2])
            - (x[everything, head, head, 1] - x[everything, head, tail, 1])
        )
        y[head, everything, head, 1] += (
            (x[head, everything, head, 0] - x[head, everything, tail, 0])
            - (x[head, everything, head, 2] - x[tail, everything, head, 2])
        )
        y[head, head, everything, 2] += (
            (x[head, head, everything, 1] - x[tail, head, everything, 1])
            - (x[head, head, everything, 0] - x[head, tail, everything, 0])
        )
        return

    def at(axis, where, component):
        # Index tuple selecting `where` along one grid axis of one component.
        index = [everything, everything, everything, component]
        index[axis] = where
        return tuple(index)

    # (output component, (added x component, axis), (subtracted x component, axis))
    terms = (
        (0, (2, 1), (1, 2)),
        (1, (0, 2), (2, 0)),
        (2, (1, 0), (0, 1)),
    )
    for out, added, subtracted in terms:
        # Added term first, then the subtracted one; within each, the interior
        # cells come before the seam (last index paired with index 0).
        src, axis = added
        y[at(axis, head, out)] += x[at(axis, head, src)] - x[at(axis, tail, src)]
        y[at(axis, -1, out)] += x[at(axis, -1, src)] - x[at(axis, 0, src)]

        src, axis = subtracted
        y[at(axis, head, out)] -= x[at(axis, head, src)] - x[at(axis, tail, src)]
        y[at(axis, -1, out)] -= x[at(axis, -1, src)] - x[at(axis, 0, src)]

# NumPy reference. The stencil accumulates into its output, so it is replayed
# once per GPU sample to end up with the same total as the GPU buffer.
y_cpu_pred = np.zeros_like(y_cpu)

cpu_samples = []
for _ in range(len(gpu_samples)):
    start_ns = time.perf_counter_ns()
    cpu_shader(x_cpu, y_cpu_pred, wrap_around=True)
    end_ns = time.perf_counter_ns()
    delta_ns = end_ns - start_ns
    cpu_samples.append(delta_ns)

# Throughput figures derived from the mean wall time per application.
cpu_delta_ns = np.array(cpu_samples)
cpu_delta_ns_avg = cpu_delta_ns.mean()
cpu_cell_rate = total_cells / (cpu_delta_ns_avg*1e-9)
cpu_flops = cpu_cell_rate*flops_per_cell
print(f"cpu_delta_avg={cpu_delta_ns_avg*1e-3:.3f} us")
print(f"cpu_cell_rate={cpu_cell_rate*1e-6:.3f} M/s")
print(f"cpu_flops={cpu_flops*1e-9:.3f} GFlops")
print(f"gpu/cpu = {gpu_cell_rate/cpu_cell_rate:.2f}x")

# Element-wise difference between the GPU result and the reference. This only
# reports the statistics; nothing is asserted.
error = y_gpu_pred - y_cpu_pred
#error = error[:-1,:-1,:-1,:] # skip last dimension on curl
error_max = error.max()
error_min = error.min()
error_abs = np.absolute(error)
error_avg = error.mean()
error_abs_avg = error_abs.mean()

print(f"error_min={error_min:.3e}")
print(f"error_max={error_max:.3e}")
print(f"error_avg={error_avg:.3e}")
print(f"error_abs_avg={error_abs_avg:.3e}")

# Spot check: first and last n_read cells along every grid axis, GPU then reference.
n_read = 1
first_cells = (slice(None, n_read),) * 3
last_cells = (slice(-n_read, None),) * 3
for corner in (first_cells, last_cells):
    print(y_gpu_pred[corner])
    print(y_cpu_pred[corner])

cpu_delta_avg=83240.843 us
cpu_cell_rate=50.388 M/s
cpu_flops=0.605 GFlops
gpu/cpu = 226.06x
error_min=-5.722e-06
error_max=5.722e-06
error_avg=1.376e-07
error_abs_avg=4.631e-07
[[[[-10.052131    3.7448397   6.307289 ]]]]
[[[[-10.052131   3.74484    6.307289]]]]
[[[[-2.3594213  1.2679145 -9.685694 ]]]]
[[[[-2.3594213  1.2679145 -9.685691 ]]]]


In [9]:
# Same stencil through PyTorch on the CPU, timed the same way as the NumPy
# reference and replayed once per GPU sample.
torch_device = "cpu"
x_torch_cpu = torch.empty(x_cpu.shape, dtype=torch.float32, device=torch_device)
x_torch_cpu.copy_(torch.from_numpy(x_cpu))
y_torch_cpu_pred = torch.zeros(y_cpu.shape, dtype=torch.float32, device=torch_device)

torch_cpu_samples = []
for _ in range(len(gpu_samples)):
    start_ns = time.perf_counter_ns()
    cpu_shader(x_torch_cpu, y_torch_cpu_pred, wrap_around=True)
    end_ns = time.perf_counter_ns()
    delta_ns = end_ns - start_ns
    torch_cpu_samples.append(delta_ns)

# Throughput figures derived from the mean wall time per application.
torch_cpu_delta_ns = np.array(torch_cpu_samples)
torch_cpu_delta_ns_avg = torch_cpu_delta_ns.mean()
torch_cpu_cell_rate = total_cells / (torch_cpu_delta_ns_avg*1e-9)
torch_cpu_flops = torch_cpu_cell_rate*flops_per_cell
print(f"torch_cpu_delta_avg={torch_cpu_delta_ns_avg*1e-3:.3f} us")
print(f"torch_cpu_cell_rate={torch_cpu_cell_rate*1e-6:.3f} M/s")
print(f"torch_cpu_flops={torch_cpu_flops*1e-9:.3f} GFlops")
print(f"gpu/torch_cpu = {gpu_cell_rate/torch_cpu_cell_rate:.2f}x")

# From here on the name refers to a NumPy array rather than a tensor.
y_torch_cpu_pred = y_torch_cpu_pred.numpy()

# Element-wise difference between the wgpu result and the torch CPU result.
# This only reports the statistics; nothing is asserted.
error = y_gpu_pred - y_torch_cpu_pred
#error = error[:-1,:-1,:-1,:] # skip last dimension on curl
error_max = error.max()
error_min = error.min()
error_abs = np.absolute(error)
error_avg = error.mean()
error_abs_avg = error_abs.mean()

print(f"error_min={error_min:.3e}")
print(f"error_max={error_max:.3e}")
print(f"error_avg={error_avg:.3e}")
print(f"error_abs_avg={error_abs_avg:.3e}")

# Spot check: first and last n_read cells along every grid axis, wgpu then torch.
n_read = 1
first_cells = (slice(None, n_read),) * 3
last_cells = (slice(-n_read, None),) * 3
for corner in (first_cells, last_cells):
    print(y_gpu_pred[corner])
    print(y_torch_cpu_pred[corner])

torch_cpu_delta_avg=36977.993 us
torch_cpu_cell_rate=113.427 M/s
torch_cpu_flops=1.361 GFlops
gpu/torch_cpu = 100.42x
error_min=-5.722e-06
error_max=5.722e-06
error_avg=1.376e-07
error_abs_avg=4.631e-07
[[[[-10.052131    3.7448397   6.307289 ]]]]
[[[[-10.052131   3.74484    6.307289]]]]
[[[[-2.3594213  1.2679145 -9.685694 ]]]]
[[[[-2.3594213  1.2679145 -9.685691 ]]]]


In [10]:
# Run the same stencil through PyTorch on the DirectML device (not the CPU)
# and compare it against the wgpu result.
TORCH_DIRECTML_DEVICE = torch_directml.device(0)
torch_device = TORCH_DIRECTML_DEVICE
x_torch_gpu = torch.zeros(*x_cpu.shape, dtype=torch.float32, device=torch_device)
y_torch_gpu_pred = torch.zeros(*y_cpu.shape, dtype=torch.float32, device=torch_device)
x_torch_gpu.copy_(torch.from_numpy(x_cpu))

# The stencil accumulates into its output, so it is replayed once per wgpu sample.
# NOTE(review): only host-side wall time is measured and nothing in the loop
# waits for the device; if DirectML queues work asynchronously these samples
# may not cover the full execution time - TODO confirm.
torch_gpu_samples = []
for _ in range(len(gpu_samples)):
    start_ns = time.perf_counter_ns()
    cpu_shader(x_torch_gpu, y_torch_gpu_pred, wrap_around=True)
    end_ns = time.perf_counter_ns()
    delta_ns = end_ns - start_ns
    torch_gpu_samples.append(delta_ns)
torch_gpu_delta_ns = np.array(torch_gpu_samples)
torch_gpu_delta_ns_avg = np.mean(torch_gpu_delta_ns)
torch_gpu_cell_rate = total_cells / (torch_gpu_delta_ns_avg*1e-9)
torch_gpu_flops = torch_gpu_cell_rate*flops_per_cell
print(f"torch_gpu_delta_avg={torch_gpu_delta_ns_avg*1e-3:.3f} us")
print(f"torch_gpu_cell_rate={torch_gpu_cell_rate*1e-6:.3f} M/s")
print(f"torch_gpu_flops={torch_gpu_flops*1e-9:.3f} GFlops")
print(f"gpu/torch_gpu = {gpu_cell_rate/torch_gpu_cell_rate:.2f}x")

# Bring the result back to the host; from here on the name is a NumPy array.
y_torch_gpu_pred = y_torch_gpu_pred.cpu().numpy()

# Element-wise difference between the wgpu result and the DirectML result.
error_full = y_gpu_pred - y_torch_gpu_pred
# The headline statistics leave out the last index along each grid axis, i.e.
# the cells whose stencil reaches across the wrap-around seam. In the run this
# notebook was saved from, the two results disagree in the very last cell;
# presumably the seam updates are not applied correctly on this backend -
# TODO confirm. The full-domain figure printed below keeps that visible.
error = error_full[:-1,:-1,:-1,:]
error_max = np.max(error)
error_min = np.min(error)
error_abs = np.abs(error)
error_avg = np.mean(error)
error_abs_avg = np.mean(error_abs)

print(f"error_min={error_min:.3e}")
print(f"error_max={error_max:.3e}")
print(f"error_avg={error_avg:.3e}")
print(f"error_abs_avg={error_abs_avg:.3e}")

# Largest absolute difference over the whole grid, seam cells included, so a
# mismatch hidden by the crop above cannot go unnoticed.
error_full_abs_max = np.max(np.abs(error_full))
print(f"error_abs_max_full_domain={error_full_abs_max:.3e}")

# Spot check: first and last n_read cells along every grid axis, wgpu then DirectML.
n_read = 1
print(y_gpu_pred[:n_read, :n_read, :n_read, :])
print(y_torch_gpu_pred[:n_read, :n_read, :n_read, :])
print(y_gpu_pred[-n_read:, -n_read:, -n_read:,:])
print(y_torch_gpu_pred[-n_read:, -n_read:, -n_read:,:])

torch_gpu_delta_avg=421550.671 us
torch_gpu_cell_rate=9.950 M/s
torch_gpu_flops=0.119 GFlops
gpu/torch_gpu = 1144.82x
error_min=-5.722e-06
error_max=3.338e-06
error_avg=1.334e-07
error_abs_avg=4.592e-07
[[[[-10.052131    3.7448397   6.307289 ]]]]
[[[[-10.052131   3.74484    6.307289]]]]
[[[[-2.3594213  1.2679145 -9.685694 ]]]]
[[[[ 0.         -0.72506833  0.        ]]]]
