[dask] Order the prediction result. (dmlc#5416)
trivialfis committed Mar 24, 2020
1 parent eff45ae commit f861c70
Showing 2 changed files with 72 additions and 32 deletions.
99 changes: 67 additions & 32 deletions python-package/xgboost/dask.py
@@ -142,9 +142,6 @@ class DaskDMatrix:
     '''
 
-    _feature_names = None  # for previous version's pickle
-    _feature_types = None
-
     def __init__(self,
                  client,
                  data,
@@ -156,9 +153,9 @@ def __init__(self,
         _assert_dask_support()
         _assert_client(client)
 
-        self._feature_names = feature_names
-        self._feature_types = feature_types
-        self._missing = missing
+        self.feature_names = feature_names
+        self.feature_types = feature_types
+        self.missing = missing
 
         if len(data.shape) != 2:
             raise ValueError(
@@ -240,6 +237,10 @@ def check_columns(parts):
         for part in parts:
             assert part.status == 'finished'
 
+        self.partition_order = {}
+        for i, part in enumerate(parts):
+            self.partition_order[part.key] = i
+
         key_to_partition = {part.key: part for part in parts}
         who_has = await client.scheduler.who_has(
             keys=[part.key for part in parts])
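
The mapping added here is what later lets predictions be reassembled in input order: each dask partition key is recorded against its original position. A minimal, self-contained sketch of that bookkeeping, with toy strings standing in for real dask partition keys (not the actual DaskDMatrix code):

    # Toy stand-ins for dask partition keys, listed in their original order.
    parts = ['xgboost-part-0', 'xgboost-part-1', 'xgboost-part-2']
    partition_order = {key: i for i, key in enumerate(parts)}
    # Any partition can now be mapped back to its global position.
    assert partition_order['xgboost-part-1'] == 1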
@@ -250,6 +251,16 @@ def check_columns(parts):
 
         self.worker_map = worker_map
 
+    def get_worker_x_ordered(self, worker):
+        list_of_parts = self.worker_map[worker.address]
+        client = get_client()
+        list_of_parts_value = client.gather(list_of_parts)
+        result = []
+        for i, part in enumerate(list_of_parts):
+            result.append((list_of_parts_value[i][0],
+                           self.partition_order[part.key]))
+        return result
+
     def get_worker_parts(self, worker):
         '''Get mapped parts of data in each worker.'''
         list_of_parts = self.worker_map[worker.address]
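
The new method returns, for each partition held by a worker, a (partition_data, global_order) pair. A hedged sketch of that return shape, with numpy arrays standing in for the gathered partition values (hypothetical data, not what the cluster actually holds):

    import numpy

    # Hypothetical result for a worker holding partitions 3 and 0.
    result = [(numpy.ones((2, 4)), 3), (numpy.zeros((5, 4)), 0)]
    for part, order in result:
        print(order, part.shape)   # prints: 3 (2, 4) then 0 (5, 4)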
@@ -292,8 +303,8 @@ def get_worker_data(self, worker):
                 workers=set(self.worker_map.keys()))
             logging.warning(msg)
             d = DMatrix(numpy.empty((0, 0)),
-                        feature_names=self._feature_names,
-                        feature_types=self._feature_types)
+                        feature_names=self.feature_names,
+                        feature_types=self.feature_types)
             return d
 
         data, labels, weights = self.get_worker_parts(worker)
@@ -311,9 +322,9 @@ def get_worker_data(self, worker):
         dmatrix = DMatrix(data,
                           labels,
                           weight=weights,
-                          missing=self._missing,
-                          feature_names=self._feature_names,
-                          feature_types=self._feature_types)
+                          missing=self.missing,
+                          feature_names=self.feature_names,
+                          feature_types=self.feature_types)
         return dmatrix
 
     def get_worker_data_shape(self, worker):
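
With the attributes now public, code holding a DaskDMatrix can reuse its construction parameters, which is what predict below relies on. A minimal sketch of building a local DMatrix the same way this method does, using random stand-in data (assumes only that xgboost and numpy are installed):

    import numpy
    from xgboost import DMatrix

    data = numpy.random.randn(8, 3)
    labels = numpy.random.randn(8)
    dmatrix = DMatrix(data, labels,
                      missing=numpy.nan,
                      feature_names=['f0', 'f1', 'f2'],
                      feature_types=None)
    print(dmatrix.num_row(), dmatrix.num_col())   # 8 3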
@@ -460,41 +471,65 @@ def predict(client, model, data, *args):
     worker_map = data.worker_map
     client = _xgb_get_client(client)
 
-    rabit_args = _get_rabit_args(worker_map, client)
+    missing = data.missing
+    feature_names = data.feature_names
+    feature_types = data.feature_types
 
     def dispatched_predict(worker_id):
         '''Perform prediction on each worker.'''
         logging.info('Predicting on %d', worker_id)
         worker = distributed_get_worker()
-        local_x = data.get_worker_data(worker)
-
-        with RabitContext(rabit_args):
-            local_predictions = booster.predict(
-                data=local_x, validate_features=local_x.num_row() != 0, *args)
-            return local_predictions
-
-    futures = client.map(dispatched_predict,
-                         range(len(worker_map)),
-                         pure=False,
-                         workers=list(worker_map.keys()))
+        list_of_parts = data.get_worker_x_ordered(worker)
+        predictions = []
+        for part, order in list_of_parts:
+            local_x = DMatrix(part,
+                              feature_names=feature_names,
+                              feature_types=feature_types,
+                              missing=missing)
+            predt = booster.predict(data=local_x,
+                                    validate_features=local_x.num_row() != 0,
+                                    *args)
+            ret = (delayed(predt), order)
+            predictions.append(ret)
+        return predictions
 
     def dispatched_get_shape(worker_id):
         '''Get shape of data in each worker.'''
         logging.info('Trying to get data shape on %d', worker_id)
         worker = distributed_get_worker()
-        rows, _ = data.get_worker_data_shape(worker)
-        return rows, 1  # default is 1
+        list_of_parts = data.get_worker_x_ordered(worker)
+        shapes = []
+        for part, order in list_of_parts:
+            s = part.shape
+            shapes.append((s, order))
+        return shapes
+
+    def map_function(func):
+        '''Run function for each part of the data.'''
+        futures = []
+        for wid in range(len(worker_map)):
+            list_of_workers = [list(worker_map.keys())[wid]]
+            f = client.submit(func, wid,
+                              pure=False,
+                              workers=list_of_workers)
+            futures.append(f)
+
+        # Get delayed objects
+        results = client.gather(futures)
+        results = [t for l in results for t in l]  # flatten into 1 dim list
+        # sort by order, l[0] is the delayed object, l[1] is its order
+        results = sorted(results, key=lambda l: l[1])
+        results = [predt for predt, order in results]  # remove order
+        return results
+
+    results = map_function(dispatched_predict)
+    shapes = map_function(dispatched_get_shape)
 
     # Constructing a dask array from list of numpy arrays
     # See https://docs.dask.org/en/latest/array-creation.html
-    futures_shape = client.map(dispatched_get_shape,
-                               range(len(worker_map)),
-                               pure=False,
-                               workers=list(worker_map.keys()))
-    shapes = client.gather(futures_shape)
     arrays = []
-    for i in range(len(futures_shape)):
-        arrays.append(da.from_delayed(futures[i], shape=(shapes[i][0], ),
+    for i, shape in enumerate(shapes):
+        arrays.append(da.from_delayed(results[i], shape=(shape[0], ),
                                       dtype=numpy.float32))
     predictions = da.concatenate(arrays, axis=0)
     return predictions
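
The flatten/sort/strip steps in map_function plus the from_delayed/concatenate assembly are the core of the fix. A runnable sketch of that pipeline, with plain numpy arrays in place of the per-worker futures (toy values, not real model output):

    import numpy
    import dask.array as da
    from dask import delayed

    # Each worker returns a list of (prediction, order) pairs; partitions
    # may be spread across workers in any order.
    worker_results = [
        [(numpy.array([2., 2.], dtype=numpy.float32), 1)],
        [(numpy.array([1.], dtype=numpy.float32), 0),
         (numpy.array([3., 3., 3.], dtype=numpy.float32), 2)],
    ]
    results = [t for l in worker_results for t in l]   # flatten
    results = sorted(results, key=lambda t: t[1])      # sort by order
    results = [predt for predt, order in results]      # drop the order tag

    arrays = [da.from_delayed(delayed(p), shape=p.shape, dtype=p.dtype)
              for p in results]
    print(da.concatenate(arrays, axis=0).compute())    # [1. 2. 2. 3. 3. 3.]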
5 changes: 5 additions & 0 deletions tests/python/test_with_dask.py
@@ -74,6 +74,11 @@ def test_from_dask_array():
     # force prediction to be computed
     prediction = prediction.compute()
 
+    single_node_predt = result['booster'].predict(
+        xgb.DMatrix(X.compute())
+    )
+    np.testing.assert_allclose(prediction, single_node_predt)
+
 
 def test_dask_regressor():
     with LocalCluster(n_workers=5) as cluster:
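
The new assertion is order-sensitive, which is exactly the point: output that had the right values in the wrong row order would now fail. A small sketch of the failure mode the test guards against (toy numbers, not actual predictions):

    import numpy as np

    ordered = np.array([0.1, 0.2, 0.3])
    shuffled = ordered[[2, 0, 1]]        # same values, wrong row order
    try:
        np.testing.assert_allclose(ordered, shuffled)
    except AssertionError:
        print('order mismatch detected')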
