## Compute shaders

WebGPU can run *compute shaders*, which let you perform general-purpose calculations on the GPU. Floating-point work is currently limited to single precision (`f32`), because WGSL has no 64-bit float type.

In [None]:
import webgpu.jupyter # initializes the webgpu runtime in jupyter notebooks

import numpy as np
from webgpu.utils import *

# Element-wise vector addition on the GPU: res = a + b, computed by a WGSL
# compute shader and then read back into host memory.

device = get_device()

# Invocations per workgroup. Defined once here and injected into the shader
# below, so the WGSL @workgroup_size and the dispatch size cannot drift apart.
WORKGROUP_SIZE = 256

a = np.array([1, 2, 3], dtype=np.float32)
b = np.array([4, 5, 6], dtype=np.float32)

# The shader reads vec_a[tid] and vec_b[tid] for the same tid, so both inputs
# must have the same number of elements.
if a.shape != b.shape:
    raise ValueError(f"a and b must have the same shape, got {a.shape} and {b.shape}")

N = a.size
mem_size = a.nbytes  # the result holds N float32 values, same as each input

a_gpu = buffer_from_array(a)
b_gpu = buffer_from_array(b)

# STORAGE: the buffer is bound as a storage buffer in the shader.
# COPY_SRC: its contents may be copied out (presumably what read_buffer needs
# to bring the result back to the host).
res_gpu = device.createBuffer(mem_size, BufferUsage.STORAGE | BufferUsage.COPY_SRC)

# The element count is passed as a uniform, so the shader can ignore
# invocations past the end of the arrays.
uniform_N = uniform_from_array(np.array([N], dtype=np.uint32))

# The binding numbers here must match the @binding(...) attributes in the
# shader below.
bindings = [
    BufferBinding(101, a_gpu),
    BufferBinding(102, b_gpu),
    BufferBinding(103, res_gpu, read_only=False),  # shader declares read_write
    UniformBinding(104, uniform_N),
]


# Only the first line is generated from Python. The rest is a plain (non-f)
# string, so the WGSL braces need no escaping.
shader_code = f"const WORKGROUP_SIZE: u32 = {WORKGROUP_SIZE}u;\n" + """
@group(0) @binding(101) var<storage> vec_a : array<f32>;
@group(0) @binding(102) var<storage> vec_b : array<f32>;
@group(0) @binding(103) var<storage, read_write> vec_res : array<f32>;
@group(0) @binding(104) var<uniform> N : u32;


@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
fn main( @builtin(global_invocation_id) gid: vec3<u32>) {

  let tid = gid.x;
  // The last workgroup may be only partially filled; skip out-of-range ids.
  if (tid < N)
    {
      vec_res[tid] = vec_a[tid] + vec_b[tid];
    }
}
"""

# Round up so that every element gets an invocation. The `tid < N` guard in
# the shader discards the surplus invocations of the last workgroup.
n_workgroups = ((N + WORKGROUP_SIZE - 1) // WORKGROUP_SIZE, 1, 1)
run_compute_shader(shader_code, bindings, n_workgroups=n_workgroups)

data = read_buffer(res_gpu, np.float32)
print(data)  # should equal a + b
