Skip to content

Commit

Permalink
Use download thread to speed up result retrieval
Browse files Browse the repository at this point in the history
  • Loading branch information
mdesmet committed Jan 5, 2023
1 parent e4a3f0f commit 3a5e182
Showing 1 changed file with 44 additions and 15 deletions.
59 changes: 44 additions & 15 deletions trino/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,16 @@
import copy
import functools
import os
import queue
import random
import re
import threading
import urllib.parse
from concurrent.futures import ThreadPoolExecutor
from datetime import date, datetime, time, timedelta, timezone, tzinfo
from decimal import Decimal
from time import sleep
from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union
from typing import Any, Callable, Dict, Generic, List, Optional, Tuple, TypeVar, Union

import pytz
import requests
Expand Down Expand Up @@ -666,6 +668,27 @@ def _verify_extra_credential(self, header):
raise ValueError(f"only ASCII characters are allowed in extra credential '{key}'")


class _FailedDownload:
    """Marker placed on the result queue when a fetch function raised."""

    def __init__(self, error: BaseException):
        # The exception raised on the worker thread, re-raised later by the consumer.
        self.error = error


class _ResultQueue(queue.Queue):
    """Queue whose ``get`` re-raises failures recorded by the download thread."""

    def get(self, block=True, timeout=None):
        item = super().get(block=block, timeout=timeout)
        if isinstance(item, _FailedDownload):
            # Surface the worker-thread failure in the consuming thread.
            raise item.error
        return item


class ResultDownloader():
    """
    Run fetch functions on a single background thread and hand their results
    over through ``queue``.

    Use as a context manager: the worker thread only exists between
    ``__enter__`` and ``__exit__``. Results are queued in submission order
    because the executor has exactly one worker.
    """

    def __init__(self):
        # Results (one item per submitted fetch function) in completion order.
        self.queue: queue.Queue = _ResultQueue()
        # Created in __enter__, torn down in __exit__.
        self.executor: Optional[ThreadPoolExecutor] = None

    def submit(self, fetch_func: Callable[[], List[Any]]):
        """
        Schedule ``fetch_func`` on the download thread.

        :param fetch_func: zero-argument callable whose return value is put on ``queue``.
        :raises RuntimeError: if called outside of the ``with`` block.
        """
        # Explicit check instead of ``assert`` so it still fires under ``python -O``.
        if self.executor is None:
            raise RuntimeError("ResultDownloader must be used as a context manager before calling submit()")
        self.executor.submit(self.download_task, fetch_func)

    def download_task(self, fetch_func):
        """
        Call ``fetch_func`` and enqueue its result.

        If ``fetch_func`` raises, the exception is enqueued instead, so that a
        consumer blocked in ``queue.get()`` is woken up and sees the error
        rather than waiting forever on a result that will never arrive.
        """
        try:
            result = fetch_func()
        except Exception as error:
            self.queue.put(_FailedDownload(error))
        else:
            self.queue.put(result)

    def __enter__(self):
        # One worker keeps the fetches strictly sequential.
        self.executor = ThreadPoolExecutor(max_workers=1)
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        executor, self.executor = self.executor, None
        if executor is not None:
            # Blocks until an already submitted fetch has completed.
            executor.shutdown()


class TrinoResult(object):
"""
Represent the result of a Trino query as an iterator on rows.
Expand Down Expand Up @@ -693,16 +716,21 @@ def rownumber(self) -> int:
return self._rownumber

def __iter__(self):
# A query only transitions to a FINISHED state when the results are fully consumed:
# The reception of the data is acknowledged by calling the next_uri before exposing the data through dbapi.
while not self._query.finished or self._rows is not None:
next_rows = self._query.fetch() if not self._query.finished else None
for row in self._rows:
self._rownumber += 1
logger.debug("row %s", row)
yield row
with ResultDownloader() as result_downloader:
# A query only transitions to a FINISHED state when the results are fully consumed:
# The reception of the data is acknowledged by calling the next_uri before exposing the data through dbapi.
result_downloader.submit(self._query.fetch)
while not self._query.finished or self._rows is not None:
next_rows = result_downloader.queue.get() if not self._query.finished else None
if not self._query.finished:
result_downloader.submit(self._query.fetch)

self._rows = next_rows
for row in self._rows:
self._rownumber += 1
logger.debug("row %s", row)
yield row

self._rows = next_rows


class TrinoQuery(object):
Expand Down Expand Up @@ -735,7 +763,7 @@ def columns(self):
while not self._columns and not self.finished and not self.cancelled:
# Columns are not returned immediately after query is submitted.
# Continue fetching data until columns information is available and push fetched rows into buffer.
self._result.rows += self.fetch()
self._result.rows += self.map_rows(self.fetch())
return self._columns

@property
Expand Down Expand Up @@ -784,7 +812,7 @@ def execute(self, additional_http_headers=None) -> TrinoResult:

# Execute should block until at least one row is received or query is finished or cancelled
while not self.finished and not self.cancelled and len(self._result.rows) == 0:
self._result.rows += self.fetch()
self._result.rows += self.map_rows(self.fetch())
return self._result

def _update_state(self, status):
Expand All @@ -796,19 +824,20 @@ def _update_state(self, status):
if status.columns:
self._columns = status.columns

def fetch(self) -> List[List[Any]]:
def fetch(self) -> List[Any]:
"""Continue fetching data for the current query_id"""
response = self._request.get(self._request.next_uri)
status = self._request.process(response)
self._update_state(status)
logger.debug(status)
if status.next_uri is None:
self._finished = True
return status.rows

def map_rows(self, rows: List[List[Any]]) -> List[List[Any]]:
if not self._row_mapper:
return []

return self._row_mapper.map(status.rows)
return self._row_mapper.map(rows)

def cancel(self) -> None:
"""Cancel the current query"""
Expand Down

0 comments on commit 3a5e182

Please sign in to comment.