Skip to content

Commit

Permalink
add an option deferred_fetch to Cursor.execute()
Browse files Browse the repository at this point in the history
  • Loading branch information
dungdm93 committed Sep 11, 2023
1 parent f712739 commit 3af2b28
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 15 deletions.
53 changes: 53 additions & 0 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1016,6 +1016,59 @@ def json(self):
assert isinstance(result, TrinoResult)


def test_trino_query_deferred_fetch(sample_get_response_data):
"""
Validates that the `TrinoQuery.execute` function deferred_fetch and non-block execution
"""

class MockResponse(mock.Mock):
# Fake response class
@property
def headers(self):
return {
'X-Trino-Fake-1': 'one',
'X-Trino-Fake-2': 'two',
}

def json(self):
return sample_get_response_data

rows = sample_get_response_data['data']
sample_get_response_data['data'] = []
sql = 'SELECT 1'
request = TrinoRequest(
host="coordinator",
port=8080,
client_session=ClientSession(
user="test",
source="test",
catalog="test",
schema="test",
properties={},
),
http_scheme="http",
)
query = TrinoQuery(
request=request,
query=sql
)

with (
mock.patch.object(request, 'post', return_value=MockResponse()),
mock.patch.object(query, 'fetch', return_value=rows) as mock_fetch,
):
result = query.execute()
mock_fetch.assert_called_once()
assert result.rows == rows

with (
mock.patch.object(request, 'post', return_value=MockResponse()),
mock.patch.object(query, 'fetch', return_value=rows) as mock_fetch,
):
result = query.execute(deferred_fetch=True)
mock_fetch.assert_not_called()


def test_delay_exponential_without_jitter():
max_delay = 1200.0
get_delay = _DelayExponential(base=5, jitter=False, max_delay=max_delay)
Expand Down
27 changes: 17 additions & 10 deletions trino/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,13 +775,18 @@ def result(self):
def info_uri(self):
return self._info_uri

def execute(self, additional_http_headers=None) -> TrinoResult:
"""Initiate a Trino query by sending the SQL statement
This is the first HTTP request sent to the coordinator.
It sets the query_id and returns a Result object used to
track the rows returned by the query. To fetch all rows,
call fetch() until finished is true.
def execute(
self,
additional_http_headers: Optional[Dict[str, Any]] = None,
deferred_fetch: bool = False,
) -> TrinoResult:
"""Initiate a Trino query by sending the SQL statement to the coordinator.
To fetch all rows, call fetch() until finished is true.
Parameters:
additional_http_headers: extra headers send to the Trino server.
deferred_fetch: By default, the execution is blocked until at least one row is received
or query is finished or cancelled. To continue without waiting the result, set deferred_fetch=True.
"""
if self.cancelled:
raise exceptions.TrinoUserError("Query has been cancelled", self.query_id)
Expand All @@ -802,9 +807,11 @@ def execute(self, additional_http_headers=None) -> TrinoResult:
rows = self._row_mapper.map(status.rows) if self._row_mapper else status.rows
self._result = TrinoResult(self, rows)

# Execute should block until at least one row is received or query is finished or cancelled
while not self.finished and not self.cancelled and len(self._result.rows) == 0:
self._result.rows += self.fetch()
if not deferred_fetch:
# Execute should block until at least one row is received or query is finished or cancelled
while not self.finished and not self.cancelled and len(self._result.rows) == 0:
self._result.rows += self.fetch()

return self._result

def _update_state(self, status):
Expand Down
20 changes: 16 additions & 4 deletions trino/dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,10 @@ def _deallocate_prepared_statement(self, statement_name: str) -> None:
def _generate_unique_statement_name(self):
return 'st_' + uuid.uuid4().hex.replace('-', '')

def execute(self, operation, params=None):
def execute(self, operation, params=None, **kwargs: Any):
additional_http_headers = kwargs.get("additional_http_headers", None)
deferred_fetch = kwargs.get("deferred_fetch", False)

if params:
assert isinstance(params, (list, tuple)), (
'params must be a list or tuple containing the query '
Expand All @@ -575,7 +578,10 @@ def execute(self, operation, params=None):
self._query = self._execute_prepared_statement(
statement_name, params
)
self._iterator = iter(self._query.execute())
self._iterator = iter(self._query.execute(
additional_http_headers=additional_http_headers,
deferred_fetch=deferred_fetch,
))
finally:
# Send deallocate statement
# At this point the query can be deallocated since it has already
Expand All @@ -584,12 +590,18 @@ def execute(self, operation, params=None):
self._deallocate_prepared_statement(statement_name)
else:
self._query = self._execute_immediate_statement(operation, params)
self._iterator = iter(self._query.execute())
self._iterator = iter(self._query.execute(
additional_http_headers=additional_http_headers,
deferred_fetch=deferred_fetch,
))

else:
self._query = trino.client.TrinoQuery(self._request, query=operation,
legacy_primitive_types=self._legacy_primitive_types)
self._iterator = iter(self._query.execute())
self._iterator = iter(self._query.execute(
additional_http_headers=additional_http_headers,
deferred_fetch=deferred_fetch,
))
return self

def executemany(self, operation, seq_of_params):
Expand Down
2 changes: 1 addition & 1 deletion trino/sqlalchemy/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ def _get_default_schema_name(self, connection: Connection) -> Optional[str]:
def do_execute(
self, cursor: Cursor, statement: str, parameters: Tuple[Any, ...], context: DefaultExecutionContext = None
):
cursor.execute(statement, parameters)
cursor.execute(statement, parameters, **context.execution_options)

def do_rollback(self, dbapi_connection: trino_dbapi.Connection):
if dbapi_connection.transaction is not None:
Expand Down

0 comments on commit 3af2b28

Please sign in to comment.