## LSH Prototype 2

[SimHash](https://en.wikipedia.org/wiki/SimHash) implementation over Xarray/Dask.

In [4]:
h = None              # Number of individual signed projections to compute
g = None              # Number of composite hashes (groupings of signs)
sample_rate = None    # Variant sampling rate used in extraction
ds_name = None        # Dataset name 
n_workers = 8         # Number of dask workers
mem_fraction = .9     # Maximum fraction of system memory to use
projection = 'random'

# Example Settings
# h = 24
# g = 100        
# sample_rate = .05
# ds_name = 'hapmap'
# n_workers = 8
# mem_fraction = .9
# projection = 'orthogonal'

In [5]:
import os
import pandas as pd
import numpy as np
from dask.distributed import Client
from gwas_analysis.dask import get_dask_client
import dask.array as da
import dask.dataframe as dd
import xarray as xr
%run {os.environ['NB_DIR']}/nb.py
%run $BENCHMARK_METHOD_DIR/common.py
assert h is not None
assert g is not None
assert sample_rate is not None
assert ds_name is not None
ds_config = DATASET_CONFIG[ds_name]
n_projections = h * g
ds_path = ld_prune_lsh.dataset_path(ds_name, sr=sample_rate)
ds_path

'/home/eczech/data/gwas/benchmark/datasets/ld_prune/lsh/hapmap-sr=0.05'

### Initialization

In [6]:
client = get_dask_client(n_workers=n_workers, max_mem_fraction=mem_fraction)
client

0,1
Client  Scheduler: tcp://127.0.0.1:45899  Dashboard: http://127.0.0.1:8787/status,Cluster  Workers: 8  Cores: 8  Memory: 120.00 GB


In [7]:
# Load the coded call data
X = da.from_zarr(ds_path + '.zarr')

# Rechunk to ensure that more than one worker is used downstream
# (which is the case with a relatively small number of variants < 1M)
def blocks(n, n_workers):
    if n <= n_workers:
        return n
    return n // n_workers
X = X.rechunk(chunks=(blocks(X.shape[0], n_workers), -1))
X

Unnamed: 0,Array,Chunk
Bytes,12.00 MB,1.50 MB
Shape,"(72732, 165)","(9091, 165)"
Count,20 Tasks,9 Chunks
Type,uint8,numpy.ndarray
"Array Chunk Bytes 12.00 MB 1.50 MB Shape (72732, 165) (9091, 165) Count 20 Tasks 9 Chunks Type uint8 numpy.ndarray",165  72732,

Unnamed: 0,Array,Chunk
Bytes,12.00 MB,1.50 MB
Shape,"(72732, 165)","(9091, 165)"
Count,20 Tasks,9 Chunks
Type,uint8,numpy.ndarray


In [8]:
# Convert to xarray and center variant vectors (to make dot product equivalent to un-normalized cosine)
X = xr.DataArray(
    data=X, 
    dims=['variant', 'sample'],
    coords=dict(
        variant=da.arange(X.shape[0], dtype=np.int32),
        sample=da.arange(X.shape[1], dtype=np.int32)
    ),
    name='calls'
)
X -= X.mean(dim='sample')
X /= X.std(dim='sample')
X

In [21]:
# Generate random vectors for projection
da.random.seed(seed=1)

def get_random_matrix(m, n):
    if projection == 'random':
        print('Using random projection')
        return da.random.normal(
            size=(m, n), 
            chunks='auto'
            # Split the random array into chunks here as well
            # chunks=(-1, blocks(n_projections, n_workers))
        )
    elif projection == 'orthogonal':
        print('Using orthogonal projection')
        n_batch = int(np.ceil(n / m))
        R = []
        for i in range(n_batch):
            rm = da.random.normal(
                size=(m, m), 
                chunks='auto'
            )
            R.append(da.linalg.qr(rm)[0])
        R = da.concatenate(R, axis=1)[:,:n]
        # rms = (rms - rms.mean(axis=1, keepdims=True)) / rms.std(axis=1, keepdims=True)
        return R
    else:
        raise ValueError(f'Projection type ${projection} not supported')

R = xr.DataArray(
    data=get_random_matrix(len(X['sample']), n_projections), 
    dims=('sample', 'projection'),
    coords=dict(
        sample=da.arange(len(X['sample']), dtype=np.int32),
        projection=da.arange(n_projections, dtype=np.int32)
    ),
    name='random'
)
R

Using orthogonal projection


### Compute Hash Bits

In [9]:
# Apply projection and attach hash signature grouping
P = (
    (X @ R)
    .rename('projections')
    .pipe(lambda x: x.assign_coords(hash_group=('projection', x['projection']//h)))
)
P

In [10]:
%%time
# Group by hash signature group (i.e. columns) and compute row-wise hashes
# within those columns, which will each be boolean vectors indicating sign
# TODO: This should operate on rows across hash group boundaries since the number
# of groups can be high and making many small chunks is inefficient
def hash_bits(x):
    return np.expand_dims(np.apply_along_axis(
        lambda r: hash(np.asarray(r).tobytes()), 
        axis=1, arr=x
    ), 1)
H = xr.DataArray(
    (P > 0).data.rechunk(chunks=(P.data.chunksize[0], h))
    .map_blocks(hash_bits, chunks=(P.data.chunksize[0], 1))
    .compute(),
    dims=('variant', 'hash_group'),
    coords=dict(
        variant=P['variant'],
        hash_group=np.arange(g)
    ),
    name='hash_value'
)
H

CPU times: user 4.25 s, sys: 263 ms, total: 4.51 s
Wall time: 8.27 s


Note that the above is unfortunately much faster than using the xarray API directly:

In [10]:
# %%time
# def hash_bits(x, axis=None):
#     # NOTE: It makes little difference if the np conversion is done
#     # per row or initially for this group (it must all be loaded into memory already)
#     return xr.DataArray(da.apply_along_axis(
#         lambda r: hash(np.asarray(r).tobytes()), 
#         axis=axis, arr=x
#     ))

# H = (
#     (P > 0)
#     .groupby('hash_group')
#     .reduce(hash_bits, dim='projection')
#     .rename('hash_value')
# )
# H

# For h = 24, g = 100:
# CPU times: user 24 s, sys: 1.64 s, total: 25.6 s
# Wall time: 2min 10s (compared to ~10s for Dask version w/ same parameters)

### Compute Hash Bucket

In [11]:
L = H.stack(i=('variant', 'hash_group')).reset_index('i')
L

In [12]:
%%time
def hash_bucket(x, axis=None):
    return xr.DataArray(da.apply_along_axis(
        lambda r: hash(np.asarray(r).tobytes()), 
        axis=axis, arr=x
    ))
L = L.assign_coords(hash_bucket=
    xr.concat([L, L['hash_group']], dim='component', coords='minimal').T
    .pipe(lambda x: x.reduce(hash_bucket, dim='component'))
    .rename('hash_bucket')
)
L

CPU times: user 2.66 s, sys: 522 ms, total: 3.18 s
Wall time: 30.1 s


### Export

In [13]:
df = L.to_dataset().to_dask_dataframe().drop('i', axis='columns')
df.head(8)

Unnamed: 0,variant,hash_group,hash_bucket,hash_value
0,0,0,-9112397201106482923,9211463383448702463
1,0,1,-3507714323226662759,9211463383448702463
2,0,2,8817343152979801908,9211463383448702463
3,0,3,3299391009333304604,9211463383448702463
4,0,4,555232255770496111,9211463383448702463
5,0,5,4082874934278889961,9211463383448702463
6,0,6,5795625952918698016,9211463383448702463
7,0,7,-4491241682868907049,9211463383448702463


In [17]:
%%time
path = ld_prune_lsh.dataset_path(ds_name, sr=sample_rate, h=h, g=g, p=projection) + '.parquet'
df.to_parquet(path)
path

CPU times: user 395 ms, sys: 225 ms, total: 620 ms
Wall time: 2.8 s


'/home/eczech/data/gwas/benchmark/datasets/ld_prune/lsh/hapmap-sr=0.05-h=24-g=100-p=orthogonal.parquet'

In [18]:
client.shutdown()

distributed.client - ERROR - Failed to reconnect to scheduler after 10.00 seconds, closing client
_GatheringFuture exception was never retrieved
future: <_GatheringFuture finished exception=CancelledError()>
concurrent.futures._base.CancelledError
