From 955c10b67a84ec4d26ea93248293356b39e9f836 Mon Sep 17 00:00:00 2001 From: wawltor Date: Sun, 12 Sep 2021 17:45:56 +0800 Subject: [PATCH] set the device id for the taskflow (#1011) --- paddlenlp/taskflow/task.py | 2 +- paddlenlp/taskflow/taskflow.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/paddlenlp/taskflow/task.py b/paddlenlp/taskflow/task.py index 7fc0d95c24521..3bcf2aa573320 100644 --- a/paddlenlp/taskflow/task.py +++ b/paddlenlp/taskflow/task.py @@ -91,7 +91,7 @@ def _prepare_static_mode(self): if place == 'cpu': self._config.disable_gpu() else: - self._config.enable_use_gpu(100, 0) + self._config.enable_use_gpu(100, self.kwargs['device_id']) self._config.switch_use_feed_fetch_ops(False) self._config.disable_glog_info() self.predictor = paddle.inference.create_predictor(self._config) diff --git a/paddlenlp/taskflow/taskflow.py b/paddlenlp/taskflow/taskflow.py index 070e6de8974e6..ccdfd80ae9d54 100644 --- a/paddlenlp/taskflow/taskflow.py +++ b/paddlenlp/taskflow/taskflow.py @@ -135,6 +135,7 @@ def __init__(self, task, model=None, device_id=0, **kwargs): self.model = model # Update the task config to kwargs config_kwargs = TASKS[self.task]['models'][self.model] + kwargs['device_id'] = device_id kwargs.update(config_kwargs) self.kwargs = kwargs task_class = TASKS[self.task]['models'][self.model]['task_class']