In [5]:
import numpy as np
from numba import cuda

- Should take the LD matrix as an argument, since that is easier to make fully block-scalable
- Must take a "score" argument ordered by variant index
- Must be sorted by contig and position
- Masking? Hail does mean imputation in `ld_matrix`
- Make sure the call array is row-major ("C" order)

#def ld_prune(ds, window=1000, max_distance=None, impute=False, score=None, fields={'contig': 'contig', 'position': 'position'}):

In [10]:
@cuda.jit
def _ld_prune_block_gpu_kernel(arr, groups, positions, window, max_distance, out):
    """Write the dot product of rows ``i`` and ``i + j`` of ``arr`` to ``out[i, j]``.

    Each thread of the 2-D launch grid handles one ``(i, j)`` cell of ``out``.
    Cells whose partner row ``i + j`` lies past the last row of ``arr`` are
    set to 0 so that ``out`` is fully defined even when the caller allocated
    it without initialising it.

    ``groups``, ``positions``, ``window`` and ``max_distance`` are accepted
    but not read by the current implementation.
    """
    # (i, j) identifies the pair of rows (i, i + j) being compared
    i, j = cuda.grid(2)
    # The grid is rounded up to whole thread blocks, so skip surplus threads
    if i < out.shape[0] and j < out.shape[1]:
        tmp = 0.
        # Only accumulate when the partner row exists; otherwise store 0
        if i + j < arr.shape[0]:
            # Loop through columns of data matrix for the pair of rows
            for k in range(arr.shape[1]):
                tmp += arr[i, k] * arr[i + j, k]
        out[i, j] = tmp


def _ld_prune_block_gpu(arr, groups, positions, window, scores=None, max_distance=None):
    """Run the pairwise row dot-product kernel on the GPU and return the result.

    Parameters
    ----------
    arr : 2-D array
        Data matrix; row ``i`` is compared against rows ``i .. i + window - 1``.
    groups, positions : array
        Copied to the device and passed to the kernel.
    window : int
        Number of columns in the output, i.e. comparisons per row.
    scores : array, optional
        Copied to the device when given; not passed to the kernel at present.
    max_distance : optional
        Passed to the kernel, with ``None`` (or any falsy value) sent as 0.

    Returns
    -------
    numpy.ndarray
        Host array of shape ``(arr.shape[0], window)`` where entry ``[i, j]``
        is the dot product of rows ``i`` and ``i + j``; entries whose partner
        row does not exist are 0.
    """
    # Unpacking also enforces that `arr` is two-dimensional
    nr, _nc = arr.shape

    arr = cuda.to_device(arr)
    groups = cuda.to_device(groups)
    positions = cuda.to_device(positions)
    if scores is not None:
        scores = cuda.to_device(scores)

    # Output is (n rows, n pairwise comparisons in window).
    # It is created from zeros rather than cuda.device_array because the
    # latter is uninitialised and the kernel may leave out-of-range cells
    # untouched.
    out = cuda.to_device(np.zeros((nr, window)))

    tpb = (32, 32) # threads per block
    # Integer ceiling division: enough blocks to cover every output cell
    bpgr = (nr + tpb[0] - 1) // tpb[0] # blocks per grid (rows)
    bpgc = (window + tpb[1] - 1) // tpb[1] # blocks per grid (columns)
    bpg = (bpgr, bpgc)

    _ld_prune_block_gpu_kernel[bpg, tpb](arr, groups, positions, window, max_distance or 0, out)

    return out.copy_to_host()

In [11]:
# Small all-ones input used to exercise the GPU block routine
nr, nc = 10, 5
arr = np.full((nr, nc), 1.0)
groups = np.full(nr, 1.0)
positions = np.full(nr, 1.0)

res = _ld_prune_block_gpu(arr, groups, positions, window=3)
res

array([[5., 5., 5.],
       [5., 5., 5.],
       [5., 5., 5.],
       [5., 5., 5.],
       [5., 5., 5.],
       [5., 5., 5.],
       [5., 5., 5.],
       [5., 5., 5.],
       [5., 5., 0.],
       [5., 0., 0.]])

In [1]:
# from: https://nyu-cds.github.io/python-numba/05-cuda/
from __future__ import division
from numba import cuda
import numpy
import math

# CUDA kernel
@cuda.jit
def matmul(A, B, C):
    """Compute the matrix product C = A * B, one element of C per thread.
    """
    r, c = cuda.grid(2)
    # Threads that fall outside the bounds of C have nothing to write
    if r >= C.shape[0] or c >= C.shape[1]:
        return
    acc = 0.
    # Inner product of row r of A with column c of B
    for k in range(A.shape[1]):
        acc += A[r, k] * B[k, c]
    C[r, c] = acc
        
# Host code

# Initialize the data arrays
# numpy.float64 is used because the `numpy.float` alias was removed in NumPy 1.24
A = numpy.full((24, 12), 3, numpy.float64) # matrix containing all 3's
B = numpy.full((12, 22), 4, numpy.float64) # matrix containing all 4's

# Copy the arrays to the device
A_global_mem = cuda.to_device(A)
B_global_mem = cuda.to_device(B)

# Allocate memory on the device for the result (rows of A x columns of B)
C_global_mem = cuda.device_array((A.shape[0], B.shape[1]))

# Configure the blocks: round up so the grid covers every element of C
threadsperblock = (16, 16)
blockspergrid_x = (A.shape[0] + threadsperblock[0] - 1) // threadsperblock[0]
blockspergrid_y = (B.shape[1] + threadsperblock[1] - 1) // threadsperblock[1]
blockspergrid = (blockspergrid_x, blockspergrid_y)

# Start the kernel 
matmul[blockspergrid, threadsperblock](A_global_mem, B_global_mem, C_global_mem)

# Copy the result back to the host
C = C_global_mem.copy_to_host()

print(C)

[[144. 144. 144. 144. 144. 144. 144. 144. 144. 144. 144. 144. 144. 144.
  144. 144. 144. 144. 144. 144. 144. 144.]
 [144. 144. 144. 144. 144. 144. 144. 144. 144. 144. 144. 144. 144. 144.
  144. 144. 144. 144. 144. 144. 144. 144.]
 [144. 144. 144. 144. 144. 144. 144. 144. 144. 144. 144. 144. 144. 144.
  144. 144. 144. 144. 144. 144. 144. 144.]
 [144. 144. 144. 144. 144. 144. 144. 144. 144. 144. 144. 144. 144. 144.
  144. 144. 144. 144. 144. 144. 144. 144.]
 [144. 144. 144. 144. 144. 144. 144. 144. 144. 144. 144. 144. 144. 144.
  144. 144. 144. 144. 144. 144. 144. 144.]
 [144. 144. 144. 144. 144. 144. 144. 144. 144. 144. 144. 144. 144. 144.
  144. 144. 144. 144. 144. 144. 144. 144.]
 [144. 144. 144. 144. 144. 144. 144. 144. 144. 144. 144. 144. 144. 144.
  144. 144. 144. 144. 144. 144. 144. 144.]
 [144. 144. 144. 144. 144. 144. 144. 144. 144. 144. 144. 144. 144. 144.
  144. 144. 144. 144. 144. 144. 144. 144.]
 [144. 144. 144. 144. 144. 144. 144. 144. 144. 144. 144. 144. 144. 144.
  144. 1