### Load necessary modules

In [None]:
from numba import cuda
import cupy as cp
import numpy as np
import math

### Kernel to evaluate $y_i = x_{i}^2$

In [None]:
@cuda.jit
def eval_function(y, x, N):
    """CUDA kernel: store the element-wise square of ``x`` in ``y``.

    Each thread handles at most one index. Threads whose global index is
    ``N`` or larger return without touching memory, so the launch grid may
    contain more threads than there are elements.
    """
    # Global 1D thread index across all blocks of the launch.
    tid = cuda.threadIdx.x + cuda.blockDim.x * cuda.blockIdx.x
    if tid >= N:
        return

    value = x[tid]
    y[tid] = value * value

### Kernel to build matrix with $a_{i,j} = \exp\left(-\left(x^{(1)}_i-x^{(2)}_j\right)^2\right)$

In [None]:
@cuda.jit
def fill_matrix(A, X1, X2, N1, N2):
    """CUDA kernel: fill the flat array ``A`` with Gaussian kernel values.

    ``A`` is addressed as an ``N1`` x ``N2`` matrix stored in one dimension:
    the entry at flat position ``row * N2 + col`` receives
    ``exp(-(X1[row] - X2[col])**2)``. One thread writes one entry; threads
    whose global index is ``N1 * N2`` or larger do nothing.
    """
    # Global 1D thread index across all blocks of the launch.
    flat = cuda.threadIdx.x + cuda.blockDim.x * cuda.blockIdx.x
    if flat >= N1 * N2:
        return

    # Recover (row, col) from the flat index: flat = row * N2 + col.
    row = flat // N2
    col = flat % N2

    diff = X1[row] - X2[col]
    A[flat] = math.exp(-diff * diff)

### Train kernel ridge regression model

In [None]:
# Number of training samples.
N = 10

cp.random.seed(0)

A = cp.empty(N * N)  # flat 1D buffer that fill_matrix addresses as N x N
X = cp.random.rand(N)
Y = cp.empty(N)

# fill kernel matrix: one thread per matrix entry
block_size = 1024
grid_size = (N * N + (block_size - 1)) // block_size  # ceil(N*N / block_size)
fill_matrix[grid_size, block_size](A, X, X, N, N)
cuda.synchronize()

# evaluate function to get training outputs: one thread per sample
block_size = 1024
grid_size = (N + (block_size - 1)) // block_size  # ceil(N / block_size)
eval_function[grid_size, block_size](Y, X, N)
cuda.synchronize()

# Reindex the 1D array as an (N, N) matrix. A is a CuPy array, so reshape it
# with cp.reshape instead of relying on NumPy dispatching to CuPy.
A = cp.reshape(A, (N, N))

# Solve A @ alpha = Y for the model coefficients (no regularization term is
# added to A before solving).
alpha = cp.linalg.solve(A, Y)

### Evaluate prediction error

In [None]:
# Number of evaluation points.
N_eval = 10000

cp.random.seed(42)
X_eval = cp.random.rand(N_eval)
Y_exact = cp.empty(N_eval)

# evaluate function to get exact solution: one thread per evaluation point
block_size = 1024
grid_size = (N_eval + (block_size - 1)) // block_size  # ceil(N_eval / block_size)
eval_function[grid_size, block_size](Y_exact, X_eval, N_eval)
cuda.synchronize()

# allocate evaluation matrix as a flat 1D buffer of N_eval * N entries
A_eval = cp.empty(N_eval * N)

# fill evaluation matrix: one thread per matrix entry
block_size = 1024
grid_size = (N_eval * N + (block_size - 1)) // block_size
fill_matrix[grid_size, block_size](A_eval, X_eval, X, N_eval, N)
cuda.synchronize()

A_eval = cp.reshape(A_eval, (N_eval, N))  # reindexing 1D array to matrix

# evaluate trained model; the product allocates its own result array, so no
# output buffer is pre-allocated for Y_eval
Y_eval = A_eval @ alpha

# compute mean squared error between exact and predicted values
residual = Y_exact - Y_eval
error = (residual * residual).sum() / N_eval

print("error: %e" % error)

error: 8.346831e-09


<a style='text-decoration:none;line-height:16px;display:flex;color:#5B5B62;padding:10px;justify-content:end;' href='https://deepnote.com?utm_source=created-in-deepnote-cell&projectId=a8cf5cf7-1e95-4ea0-8cd8-79ded01cd257' target="_blank">
Created in <span style='font-weight:600;margin-left:4px;'>Deepnote</span></a>