In [1]:
from dask_cuda import LocalCUDACluster
# Local Dask cluster of CUDA workers (presumably one per visible GPU — this run
# reports 2 workers); each worker runs a single thread.
cluster = LocalCUDACluster(threads_per_worker=1)

In [2]:
from cuml import numba_utils

In [3]:
from dask.distributed import Client, wait
import time

import dask
import dask_cudf
import dask.dataframe as dd

import pandas as pd

import cudf
import numpy as np

import pandas.testing

from dask_cuml import knn as cumlKNN


In [4]:
# Connect a distributed client to the local CUDA cluster; the trailing bare
# name renders the client/cluster summary as the cell output.
client = Client(address=cluster)
client

0,1
Client  Scheduler: tcp://127.0.0.1:39587  Dashboard: http://127.0.0.1:8787/status,Cluster  Workers: 2  Cores: 2  Memory: 50.39 GB


In [5]:
def create_df(f, m, n, verbose=False, seed=None):
    """Build partition ``f`` of a random training matrix as a cudf DataFrame.

    Parameters
    ----------
    f : int
        Partition number. Rows are labelled ``f*m`` .. ``f*m + m - 1``, so
        partitions with distinct ``f`` get disjoint, contiguous labels.
    m : int
        Number of rows in this partition.
    n : int
        Number of columns, labelled ``0`` .. ``n - 1``.
    verbose : bool, optional
        Print the frame after building it. Off by default: when this runs via
        ``client.submit`` the print goes to the worker's stdout, not the
        notebook, and was only a debugging aid.
    seed : int or None, optional
        Base seed for reproducible data. Partition ``f`` is seeded with
        ``[seed, f]`` (both must be non-negative and < 2**32), so partitions
        differ from each other. ``None`` (default) draws fresh, unseeded data.

    Returns
    -------
    cudf.DataFrame
        ``m`` rows by ``n`` float32 columns of uniform samples in [0, 1).
    """
    rng = np.random.RandomState(None if seed is None else [seed, f])
    # Draw one column at a time instead of a full (m, n) float64 matrix that is
    # cast afterwards: at m=500000, n=1000 that temporary alone is 4 GB of host RAM.
    columns = [(i, rng.random_sample(m).astype(np.float32)) for i in range(n)]
    ret = cudf.DataFrame(columns,
                         index=cudf.dataframe.RangeIndex(f * m, f * m + m, 1))
    if verbose:
        print(str(ret))
    return ret

def get_meta(df):
    """Return a zero-row slice of ``df``: same columns and dtypes, no data.

    The result serves as the ``meta`` template dask uses to describe the
    partitions without computing them.
    """
    return df.iloc[:0]

In [6]:
# Worker addresses (the keys of `has_what()`); used to pin one data partition
# to each worker/GPU.
workers = client.has_what().keys()
workers

dict_keys(['tcp://127.0.0.1:45151', 'tcp://127.0.0.1:46074'])

In [7]:
# Per gpu/worker
# Training data generated on each GPU/worker: rows, columns.
train_m, train_n = 500000, 1000

In [8]:
# Query settings: presumably the number of query rows, and k = neighbours
# returned per query row.
search_m, search_k = 10000, 15

In [9]:
%%time

# Create dfs on each worker (gpu)
dfs = [client.submit(create_df, n, train_m, train_n, workers = [worker])
       for worker, n in list(zip(workers, list(range(len(workers)))))]

# Wait for completion
wait(dfs)

meta = client.submit(get_meta, dfs[0]).result()

CPU times: user 1.15 s, sys: 322 ms, total: 1.47 s
Wall time: 19.5 s


In [10]:
# KNN estimator from dask_cuml; it is fitted on the distributed frame later.
# NOTE(review): `lr` reads like "linear regression" but this is a KNN model;
# the name is kept because later cells refer to it.
lr = cumlKNN.KNN()

In [11]:
# Empty template frame: column labels and dtypes only.
print(meta)

Empty DataFrame
Columns: [0, 1, 2, 3, 4, 5, 6, 999]
Index: []


In [12]:
%%time
# Wrap the per-worker partition futures (`dfs`) as one distributed dask_cudf
# DataFrame; `meta` supplies the empty column/dtype template.
# NOTE(review): no `divisions` are passed, so partition index bounds are
# presumably unknown to dask and label-based slicing cannot skip partitions —
# consider supplying them if dask_cudf.from_delayed accepts them.
X_df = dask_cudf.from_delayed(dfs, meta=meta)

CPU times: user 981 ms, sys: 79.3 ms, total: 1.06 s
Wall time: 999 ms


In [13]:
# Peek at the first 100 rows. `head` computes from the first partition only.
# Label slicing (`X_df[0:100]`) is end-inclusive (101 rows) and may have to
# touch every partition when divisions are unknown.
X_df.head(100)

  0  1  2  3  4  5  6 ... 999
0 0.10139009 0.07201549 0.65633726 0.3071411 0.9925756 0.3539088 0.53590757 ... 0.2589649
1 0.35729495 0.30308837 0.56302357 0.10059356 0.28452167 0.96970314 0.26175326 ... 0.15090063
2 0.7250316 0.90886396 0.6061101 0.12998222 0.97684807 0.9353843 0.5272743 ... 0.026771756
3 0.52606386 0.6716579 0.37730047 0.59357494 0.13730475 0.6403551 0.59904647 ... 0.109955594
4 0.048398733 0.4157121 0.96247315 0.91701686 0.2341268 0.61787415 0.33628196 ... 0.69784194
5 0.58373994 0.08451256 0.25855947 0.5741674 0.21742156 0.9273928 0.032421775 ... 0.12900756
6 0.68150485 0.030775618 0.84499604 0.72958827 0.5053623 0.14381415 0.2295179 ... 0.078017496
7 0.329907 0.7917158 0.88575083 0.118281655 0.5324294 0.8937012 0.39055377 ... 0.20475446
8 0.04843139 0.70570797 0.60194105 0.60177773 0.48951828 0.95911366 0.8427563 ... 0.24056378
9 0.56068295 0.23696649 0.16458271 0.7515991 0.4354127 0.09067992 0.5553361 ... 0.01360116
[91 more rows]
[992 more columns]


In [14]:
# Which worker holds each partition future; expect one create_df key per worker.
client.who_has()

{'create_df-a4f799a633489b66d025a10b8f412b8c': ('tcp://127.0.0.1:45151',),
 'create_df-7c2054fd12ffd58770aac694adf2e8c3': ('tcp://127.0.0.1:46074',)}

In [15]:
%%time
# Fit the KNN model on the distributed training frame.
lr.fit(X_df)

CPU times: user 312 ms, sys: 24.5 ms, total: 337 ms
Wall time: 1.97 s


In [None]:
%%time
I, D = lr.kneighbors(X_df[0:1], search_k)

In [20]:
# Neighbour indices: one row per query row, search_k columns.
print(I.compute())

   0       1       2       3       4       5       6 ...      14
0  0  781796  523508  916968  167847   20716  596024 ...  379231
1  1  730918   27991  922692  385833  832135  308708 ...  242088
2  2  146000  423092  199800  996310  548837  444670 ...  153133
3  3  616041  413595   94004  885505  184494  749833 ...  545175
4  4   59987  274076  904959   93236  515061  136859 ...  690894
5  5   34607  846054  225128  279271  402862  639006 ...  536653
6  6  475732  789387   67332  247823  282084  153822 ...  685600
7  7  819266  667070  704379   41812  882521  456082 ...  232857
8  8  836894  781884  695143  610935  263473   60620 ...  154401
9  9   63760  443290  109338  845015  383271  590284 ...  117366
[99991 more rows]
[7 more columns]


In [21]:
# Distances to each returned neighbour, aligned with the index frame.
print(D.compute())

               0          1          2          3          4          5          6 ...         14
0   0.0004272461  137.59232  137.84723  137.86157   138.1344   138.3052  138.40436 ...  139.49414
1            0.0  139.36816  139.42822   141.5502   141.7684   141.8808  142.28348 ...   143.1145
2  0.00012207031  138.09415    138.393  139.39926   139.9206  140.04099  140.26184 ...  141.18582
3  0.00079345703  134.95456   137.4346  138.43582  138.63019  138.72491  139.38257 ...  140.39908
4  0.00048828125  142.94702   143.2362  143.62723  143.89893  144.63147  144.71918 ...  145.62146
5   0.0004272461  140.81964  142.57458  142.69263  142.84949  143.13815  144.00696 ...  144.86786
6  0.00024414062  140.68677  140.71603  140.81183  141.55954  141.75519  142.20209 ...  143.61176
7            0.0  140.56503  140.74158  141.50775  142.17569  142.33572  142.90872 ...  144.57922
8  0.00036621094   138.3765   139.6683  140.63403  140.85495  141.42828  141.44672 ...  142.57455
9            0.0  13

In [None]:
# NOTE(review): `a` is not defined anywhere in this notebook, so this cell
# raised NameError under Restart & Run All. It is disabled until the intended
# object is restored (perhaps `I.compute()` to check the last query label).
# a.index[-1]