### Dask LD Prune Prototype

In [1]:
import dask
import dask.array as da
import numpy as np
from numba import jit
import plotly.graph_objects as go
from IPython.display import Image
# Run all dask computations in this notebook on the single-threaded scheduler
dask.config.set(scheduler='single-threaded')

<dask.config.set at 0x7f23b4b72ad0>

### Load Random Data

In [None]:
# X = da.random.randint(0, 3, size=(500, 30), dtype=np.uint8)
# X.shape
# pos = np.array(
#     [[1, 1]] * 300 + 
#     [[2, 1]] * 100 + 
#     [[3, 1]] * 75 + 
#     [[4, 1]] * 25
# )
# len(pos)

### Load HapMap Data

In [2]:
# Load the HapMap array from a local Zarr store
# NOTE(review): this is an absolute, user-specific path; the notebook will not
# run elsewhere unless the path is made configurable
X = da.from_zarr('/home/eczech/data/gwas/benchmark/datasets/ld_prune/lsh/hapmap-sr=0.1.zarr')
# Keep every 8th row to shrink the problem for prototyping
X = X[::8]
X

Unnamed: 0,Array,Chunk
Bytes,215.66 kB,215.66 kB
Shape,"(1307, 165)","(1307, 165)"
Count,3 Tasks,1 Chunks
Type,int8,numpy.ndarray
"Array Chunk Bytes 215.66 kB 215.66 kB Shape (1307, 165) (1307, 165) Count 3 Tasks 1 Chunks Type int8 numpy.ndarray",165  1307,

Unnamed: 0,Array,Chunk
Bytes,215.66 kB,215.66 kB
Shape,"(1307, 165)","(1307, 165)"
Count,3 Tasks,1 Chunks
Type,int8,numpy.ndarray


In [3]:
# Generate a placeholder (contig, position) vector with one entry per row of X;
# every row is assigned contig 1 and position 1, so the whole array is treated
# as a single contig downstream
pos = np.array(
    [[1, 1]] * X.shape[0]
)
len(pos)

1307

In [5]:
# Every row of X needs exactly one (contig, position) entry
assert len(X) == len(pos)

### Compute Chunk Alignment

In [6]:
from dataclasses import dataclass
from typing import List

@dataclass(frozen=True)
class ChunkContigInfo:
    """Chunk layout for the rows that belong to a single contig."""
    # Ordinal of this contig in the enumeration of unique contig values
    contig_index: int
    # Contig identifier taken from the first column of the position array
    contig_value: int
    # Array-wide indices of the chunks that cover this contig
    chunk_idx: List[int]
    # Number of rows in each of those chunks, parallel to `chunk_idx`
    chunk_size: List[int]
        
@dataclass(frozen=True)
class ChunkInfo:
    """Chunk layout for a whole array: one `ChunkContigInfo` entry per contig."""
    chunks: List[ChunkContigInfo]

    def get_contig_chunk_boundary(self):
        """Map the index of each contig's final chunk to the row count of that chunk."""
        return {
            contig.chunk_idx[-1]: contig.chunk_size[-1]
            for contig in self.chunks
        }

    def get_chunk_offset(self):
        """Map every chunk index to the global row index at which that chunk starts."""
        starts = {}
        next_start = 0
        for contig in self.chunks:
            for chunk_index, n_rows in zip(contig.chunk_idx, contig.chunk_size):
                # Chunks are visited in order, so the running total is the
                # first global row of the current chunk
                starts[chunk_index] = next_start
                next_start += n_rows
        return starts

        
def get_chunk_info(pos, size):
    """Split the rows of each contig into chunks of at most `size` rows.

    Chunks never span two contigs: each contig gets `count // size` full
    chunks followed by one smaller chunk holding the remainder, if any.
    Chunk indices are numbered consecutively across contigs.

    Parameters
    ----------
    pos : 2-D array
        Position array whose first column holds the contig value of each row.
        NOTE(review): contigs are processed in the sorted order returned by
        `np.unique`, so rows are presumably expected to already be grouped by
        contig in that order -- confirm against callers.
    size : int
        Maximum number of rows per chunk; must be positive.

    Returns
    -------
    ChunkInfo
        One `ChunkContigInfo` per unique contig value.

    Raises
    ------
    ValueError
        If `size` is not positive. Without this check a zero size silently
        produced contigs with no chunks at all.
    """
    if size <= 0:
        raise ValueError(f'Chunk size must be positive, found {size}')

    chunks = []
    csct = 0  # number of chunks assigned so far across all contigs
    contig_values, contig_counts = np.unique(pos[:, 0], return_counts=True)
    for i, (v, c) in enumerate(zip(contig_values, contig_counts)):
        n_full, remainder = divmod(int(c), size)
        sizes = [size] * n_full
        if remainder > 0:
            sizes.append(remainder)
        idx = list(range(csct, csct + len(sizes)))
        csct += len(sizes)
        chunks.append(ChunkContigInfo(contig_index=i, contig_value=v, chunk_idx=idx, chunk_size=sizes))
    return ChunkInfo(chunks)

# Calculate chunk sizes with boundaries determined by contigs and a maximum
# row limit (40 rows per chunk in this example)
get_chunk_info(pos, 40).chunks

[ChunkContigInfo(contig_index=0, contig_value=1, chunk_idx=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32], chunk_size=[40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 27])]

In [7]:
# Wrapper that prepares an array for pruning: contig-aligned rechunking plus
# the metadata the per-block function needs
def prune(X, pos, window, step, normalize=True, align_chunks=True, windows_per_chunk=10):
    """Align the row chunks of `X` to contig boundaries and derive pruning metadata.

    Parameters
    ----------
    X : 2-D dask array to be pruned row-wise.
    pos : position array passed through to `get_chunk_info`.
    window : window size in rows; must be a multiple of `step`.
    step : window step in rows; must be smaller than `window`.
    normalize : accepted but not used by the current implementation.
    align_chunks : when true, rechunk `X` along axis 0 to the computed layout;
        when false, `X` must already have exactly that layout.
    windows_per_chunk : each chunk holds at most `windows_per_chunk * window` rows.

    Returns
    -------
    tuple
        `(X, chunk_info, overlap_depth)` where `overlap_depth` is
        `window - step`.

    Raises
    ------
    ValueError
        If `align_chunks` is false and the row chunks of `X` differ from the
        computed layout.
    """
    assert step < window
    assert window % step == 0
    assert X.ndim == 2

    chunk_info = get_chunk_info(pos, windows_per_chunk * window)

    # Flatten the per-contig chunk sizes into one tuple of row-chunk lengths
    flattened = []
    for contig in chunk_info.chunks:
        flattened.extend(contig.chunk_size)
    chunk_lens = tuple(flattened)

    if align_chunks:
        X = X.rechunk(chunks=(chunk_lens, X.chunks[1]))
    elif X.chunks[0] != chunk_lens:
        raise ValueError(f'Expected chunks {chunk_lens}, found {X.chunks[0]}')

    # Each block should overlap its neighbours by this many rows
    return X, chunk_info, window - step

In [8]:
# Window size and step size, both measured in rows
window = 100
step = 10
# Rechunk X so that each chunk holds at most windows_per_chunk * window = 300
# rows, and get back the chunk metadata and the overlap depth (window - step)
Xp, chunk_info, overlap_depth = prune(X, pos, window, step, windows_per_chunk=3)
Xp

Unnamed: 0,Array,Chunk
Bytes,215.66 kB,49.50 kB
Shape,"(1307, 165)","(300, 165)"
Count,13 Tasks,5 Chunks
Type,int8,numpy.ndarray
"Array Chunk Bytes 215.66 kB 49.50 kB Shape (1307, 165) (300, 165) Count 13 Tasks 5 Chunks Type int8 numpy.ndarray",165  1307,

Unnamed: 0,Array,Chunk
Bytes,215.66 kB,49.50 kB
Shape,"(1307, 165)","(300, 165)"
Count,13 Tasks,5 Chunks
Type,int8,numpy.ndarray


In [9]:
chunk_info.chunks

[ChunkContigInfo(contig_index=0, contig_value=1, chunk_idx=[0, 1, 2, 3, 4], chunk_size=[300, 300, 300, 300, 107])]

In [10]:
Xp.chunks

((300, 300, 300, 300, 107), (165,))

In [11]:
# Sanity check: report the row count of each block as seen by a mapped function
Xp.map_blocks(lambda x: np.array([x.shape[0]]), dtype=int, drop_axis=[1]).compute()

array([300, 300, 300, 300, 107])

In [12]:
chunk_info.get_contig_chunk_boundary()

{4: 107}

In [13]:
chunk_info.get_chunk_offset()

{0: 0, 1: 300, 2: 600, 3: 900, 4: 1200}

In [14]:
overlap_depth

90

### Run Pruning

In [50]:
#@jit(nopython=False)
def _prune(
    X, block_id=None, window=None, step=None, threshold=None, 
    contig_boundary=None, chunk_offset=None, overlap_depth=None,
    short_circuit=False
):
    assert block_id is not None
    assert window is not None
    assert step is not None
    assert threshold is not None
    assert contig_boundary is not None
    assert chunk_offset is not None
    assert overlap_depth is not None
    
    # Always eliminate leading padding rows
    X = X[overlap_depth:]
    # Eliminate padding rows if in last chunk
    if block_id[0] == max(chunk_offset.keys()):
        X = X[:-overlap_depth]
        
    # Determine max row index for contig (only applies to overlap)
    row_max = contig_boundary.get(block_id[0])
    # Determine global offset from original array
    row_offset = chunk_offset.get(block_id[0])
    
    # Run preprocessing for triangle inequality short-circuiting
    if short_circuit:
        Xc = (X - np.mean(X, axis=1, keepdims=True)) / np.std(X, axis=1, keepdims=True)
        Xs = Xc[::(step//2)]
        # Make sure to divide dot products by number of columns
        Xs = np.matmul(Xc, Xs.T) / Xc.shape[1]
        # Convert correlation to distance (d = 1 - corr => corr = 1 - d)
        Xs = 1 - Xs
        eps = 1e-6
        assert np.all((Xs >= -eps) & (Xs <= 2 + eps))
        Xsi = np.argmin(Xs, axis=1)
        assert Xsi.shape[0] == Xs.shape[0] == X.shape[0]
    
    n, m = X.shape
    r2m = np.ones((n, n), dtype=np.float64) * -2
    keep = np.ones(n, dtype=bool)
    # Loop over window start index
    for w_start in range(0, n, step):
        w_stop = min(w_start + window, n)
        # Loop over primary row index
        for i in range(w_start, w_stop):
            if not keep[i]:
                continue
            # Loop over secondary row index
            # TODO: Figure out how to avoid re-computation
            # cf. j_start = i+1 if w_start == 0 else max(i+1, w_start + window - step)
            j_start = i + 1
            for j in range(j_start, w_stop):
                if not keep[j]:
                    continue
                if short_circuit:
                    # Find closest vector to primary
                    min_dist_idx = Xsi[i]
                    di = Xs[i, min_dist_idx]
                    dj = Xs[j, min_dist_idx]
                    # Determine distance lower bound to secondary
                    dlb = abs(di - dj)
                    # Convert back to r2 and continue if we can be sure
                    # these two rows are sufficiently uncorrelated
                    cub = 1 - dlb
                    if cub ** 2 <= threshold:
                        continue
                xi, xj = X[i], X[j]
                mask = (xi >= 0) & (xj >= 0)
                xi, xj = xi[mask], xj[mask]
                r2 = np.corrcoef(xi, xj)[0, 1] ** 2
                assert np.isnan(r2) or -1 <= r2 <= 1, 'r2 value' + str(r2) + ' not in [-1, 1]'
                r2m[i,j] = r2
                if r2 > threshold:
                    keep[j] = False
                    
    # For now, simply return the correlation matrix
    r2m = r2m.reshape(-1)
    return np.stack([r2m, np.repeat(block_id[0], repeats=len(r2m))], axis=1)

# Apply `_prune` to every block together with `overlap_depth` rows from each
# neighbour; missing neighbours at the array edges are filled with -1.
# trim=False because `_prune` removes the padding rows itself.
g = da.map_overlap(
    Xp, _prune,
    window=window, step=step,
    overlap_depth=overlap_depth,
    depth=(overlap_depth, 0),
    threshold=.3,
    boundary=-1,
    short_circuit=True,
    contig_boundary=chunk_info.get_contig_chunk_boundary(),
    chunk_offset=chunk_info.get_chunk_offset(),
    # Declared output chunks must match what `_prune` returns: n**2 rows per
    # block, where n includes the trailing overlap rows for every block except
    # the last (whose trailing rows are boundary padding that gets stripped)
    chunks=(
        tuple((v + overlap_depth)**2 for v in Xp.chunks[0][:-1]) + (Xp.chunks[0][-1]**2,),
        2
    ),
    dtype=np.float64,
    trim=False
)

In [51]:
# Materialize the per-block results; each row is an (r2, block index) pair
R2 = g.compute()


invalid value encountered in true_divide


invalid value encountered in true_divide



In [52]:
R2.shape

(619849, 2)

In [53]:
R2[:10]

array([[-2.00000000e+00,  0.00000000e+00],
       [ 7.40086562e-03,  0.00000000e+00],
       [ 5.39528575e-02,  0.00000000e+00],
       [ 1.42916124e-02,  0.00000000e+00],
       [ 1.24899305e-02,  0.00000000e+00],
       [ 2.29085388e-04,  0.00000000e+00],
       [ 2.26030980e-03,  0.00000000e+00],
       [ 1.00771770e-02,  0.00000000e+00],
       [ 1.27525524e-03,  0.00000000e+00],
       [ 4.69770036e-03,  0.00000000e+00]])

In [54]:
# Select the rows emitted by block 0 (the second column holds the block index)
r2 = R2[R2[:,1] == 0]
# Each block emits a flattened square matrix, so the side length is the square
# root of the number of rows
r2 = r2[:,0].reshape(int(np.sqrt(len(r2))), int(np.sqrt(len(r2))))
r2.shape

(390, 390)

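### Results (No Short-Circuit)

The outputs in this section are for a run of the pruning cell with `short_circuit=False`. The cell as saved above passes `short_circuit=True`, so toggle the flag and re-run the cells from there to reproduce these numbers.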

In [55]:
# Count the pairs whose r2 was actually computed: skipped entries keep the -2
# fill value, and NaN results do not satisfy the comparison either
(r2 >= 0).sum()

31048

In [60]:
# Heatmap of the block-0 r2 matrix, with its rows passed in reverse order
fig = go.Figure(data=[go.Heatmap(z=r2[::-1])])
fig.show()

### Results (With Short-Circuit)

In [48]:
# Count the pairs whose r2 was actually computed: skipped entries keep the -2
# fill value, and NaN results do not satisfy the comparison either
(r2 >= 0).sum()

23979

In [49]:
# Heatmap of the block-0 r2 matrix, with its rows passed in reverse order
fig = go.Figure(data=[go.Heatmap(z=r2[::-1])])
fig.show()