Skip to content

Commit

Permalink
set the device id for the taskflow (PaddlePaddle#1011)
Browse files Browse the repository at this point in the history
  • Loading branch information
wawltor committed Sep 12, 2021
1 parent b9a4cb1 commit 955c10b
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 1 deletion.
2 changes: 1 addition & 1 deletion paddlenlp/taskflow/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def _prepare_static_mode(self):
if place == 'cpu':
self._config.disable_gpu()
else:
self._config.enable_use_gpu(100, 0)
self._config.enable_use_gpu(100, self.kwargs['device_id'])
self._config.switch_use_feed_fetch_ops(False)
self._config.disable_glog_info()
self.predictor = paddle.inference.create_predictor(self._config)
Expand Down
1 change: 1 addition & 0 deletions paddlenlp/taskflow/taskflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ def __init__(self, task, model=None, device_id=0, **kwargs):
self.model = model
# Update the task config to kwargs
config_kwargs = TASKS[self.task]['models'][self.model]
kwargs['device_id'] = device_id
kwargs.update(config_kwargs)
self.kwargs = kwargs
task_class = TASKS[self.task]['models'][self.model]['task_class']
Expand Down

0 comments on commit 955c10b

Please sign in to comment.