In [1]:
from dask_cuda import LocalCUDACluster
cluster = LocalCUDACluster(threads_per_worker=1)

In [2]:
from dask.distributed import Client, wait
import time

import dask
import dask_cudf
import dask.dataframe as dd

import pandas as pd

import cudf
import numpy as np

import pandas.testing

from dask_cuml.neighbors import NearestNeighbors as cumlNN


In [3]:
client = Client(cluster)
client

0,1
Client  Scheduler: tcp://127.0.0.1:45059  Dashboard: http://127.0.0.1:8787/status,Cluster  Workers: 2  Cores: 2  Memory: 50.39 GB


In [4]:
def create_df(f, m, n):
    X = np.random.rand(m, n)
    ret = cudf.DataFrame([(i,X[:,i].astype(np.float32)) for i in range(n)], 
                        index = cudf.dataframe.RangeIndex(f*m, f*m+m, 1))
    
    print(str(ret))
    return ret

def get_meta(df):
    ret =  df.iloc[:0]
    return ret

In [5]:
workers = client.has_what().keys()
workers

dict_keys(['tcp://127.0.0.1:40175', 'tcp://127.0.0.1:45830'])

In [6]:
# Per gpu/worker
train_m = 10000 
train_n = 1000

In [7]:
search_m = 10000
search_k = 15

In [8]:
%%time

# Create dfs on each worker (gpu)
dfs = [client.submit(create_df, n, train_m, train_n, workers = [worker])
       for worker, n in list(zip(workers, list(range(len(workers)))))]

# Wait for completion
wait(dfs)

meta = client.submit(get_meta, dfs[0]).result()

CPU times: user 602 ms, sys: 227 ms, total: 829 ms
Wall time: 7.4 s


In [9]:
lr = cumlNN()

In [10]:
%%time
X_df = dask_cudf.from_delayed(dfs, meta=meta)

CPU times: user 975 ms, sys: 98.7 ms, total: 1.07 s
Wall time: 1.01 s


In [11]:
client.who_has()

{'create_df-0c0691e76bf8e8aefca6bc2147638b31': ('tcp://127.0.0.1:40175',),
 'create_df-2605e0a451580da6f6b3aa8a6d0e2a19': ('tcp://127.0.0.1:45830',)}

In [12]:
%%time
lr.fit(X_df)

CPU times: user 327 ms, sys: 21.2 ms, total: 348 ms
Wall time: 1.95 s


In [13]:
%%time
D, I = lr.kneighbors(X_df[0:10000], search_k)

CPU times: user 1.61 s, sys: 140 ms, total: 1.75 s
Wall time: 4.15 s


In [14]:
print(str(I.compute()))

   0      1      2      3      4      5      6 ...     14
0  0   9016   9405  11897   1371  13547   6468 ...  19704
1  1  10142  10023   7644  13958  15802   3967 ...  16032
2  2    286  18403   9083   6032   9032  10069 ...   3166
3  3  16007   2833  13047  15114  18320    650 ...   7501
4  4  18806  13650  15136  15123  13922  16986 ...   6567
5  5   6808  15007   5450  15308   5794  16032 ...  12078
6  6  19063   5271   8378  18952   2689  15615 ...  12079
7  7  15295   3458  16853  15895   6249   8212 ...   9562
8  8   5088  12899   6415   5032  17537  17187 ...  17811
9  9  12551  16713   1014   6612  14545   3024 ...   1707
[9991 more rows]
[7 more columns]


In [15]:
print(str(D.compute()))

              0          1          2          3          4          5          6 ...         14
0 0.00024414062  140.76175  143.24838  143.61707  144.56842  145.26196  145.91473 ...  147.12753
1           0.0  143.01147   145.5264  145.59384    146.544  146.58359  146.75842 ...  148.08487
2           0.0  146.16101   146.8891  147.22748  147.99466  148.10764  148.26846 ...  149.67548
3           0.0  145.51044  146.15027  147.76392  148.02557   148.1903   148.7699 ...  150.17789
4           0.0  146.84421  147.10855  147.18796  147.67963  147.73804  148.30591 ...  149.42575
5           0.0  138.38107  139.62338  141.24573  141.67313  142.08456  142.35028 ...  143.81967
6 6.1035156e-05   145.6785  145.69165  146.08551  146.15036  146.59009  146.96014 ...  147.98499
7           0.0  145.36646  145.88678    146.182  146.21704  146.67093  146.93567 ...  147.65881
8           0.0  145.90506  146.64432  146.69531  147.10669  147.43063  147.66815 ...   148.3883
9           0.0  147.21524  14

In [16]:
np.array(D.compute().as_gpu_matrix())

array([[2.44140625e-04, 1.40761749e+02, 1.43248383e+02, ...,
        1.46407104e+02, 1.46910248e+02, 1.47127533e+02],
       [0.00000000e+00, 1.43011475e+02, 1.45526398e+02, ...,
        1.47958069e+02, 1.48002197e+02, 1.48084869e+02],
       [0.00000000e+00, 1.46161011e+02, 1.46889099e+02, ...,
        1.49300842e+02, 1.49607635e+02, 1.49675476e+02],
       ...,
       [0.00000000e+00, 1.46034424e+02, 1.47559448e+02, ...,
        1.50026215e+02, 1.50099823e+02, 1.50158936e+02],
       [6.10351562e-05, 1.46264130e+02, 1.47796509e+02, ...,
        1.50346466e+02, 1.50458038e+02, 1.50826172e+02],
       [0.00000000e+00, 1.43787262e+02, 1.44267120e+02, ...,
        1.47506989e+02, 1.47538055e+02, 1.47698059e+02]])