From 3af2b282b38e1ab855f50a910a5f109f03bec52e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C4=90=E1=BA=B7ng=20Minh=20D=C5=A9ng?= Date: Tue, 8 Aug 2023 15:05:03 +0700 Subject: [PATCH] add an option `deferred_fetch` to `Cursor.execute()` --- tests/unit/test_client.py | 53 +++++++++++++++++++++++++++++++++++++ trino/client.py | 27 ++++++++++++------- trino/dbapi.py | 20 +++++++++++--- trino/sqlalchemy/dialect.py | 2 +- 4 files changed, 87 insertions(+), 15 deletions(-) diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index bbd953e8..42400ae7 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -1016,6 +1016,59 @@ def json(self): assert isinstance(result, TrinoResult) +def test_trino_query_deferred_fetch(sample_get_response_data): + """ + Validates that `TrinoQuery.execute` fetches rows by default and skips the fetch when `deferred_fetch=True` + """ + + class MockResponse(mock.Mock): + # Fake response class + @property + def headers(self): + return { + 'X-Trino-Fake-1': 'one', + 'X-Trino-Fake-2': 'two', + } + + def json(self): + return sample_get_response_data + + rows = sample_get_response_data['data'] + sample_get_response_data['data'] = [] + sql = 'SELECT 1' + request = TrinoRequest( + host="coordinator", + port=8080, + client_session=ClientSession( + user="test", + source="test", + catalog="test", + schema="test", + properties={}, + ), + http_scheme="http", + ) + query = TrinoQuery( + request=request, + query=sql + ) + + with ( + mock.patch.object(request, 'post', return_value=MockResponse()), + mock.patch.object(query, 'fetch', return_value=rows) as mock_fetch, + ): + result = query.execute() + mock_fetch.assert_called_once() + assert result.rows == rows + + with ( + mock.patch.object(request, 'post', return_value=MockResponse()), + mock.patch.object(query, 'fetch', return_value=rows) as mock_fetch, + ): + result = query.execute(deferred_fetch=True) + mock_fetch.assert_not_called() + + def test_delay_exponential_without_jitter(): max_delay = 
1200.0 get_delay = _DelayExponential(base=5, jitter=False, max_delay=max_delay) diff --git a/trino/client.py b/trino/client.py index 48368b7a..2bd92f2e 100644 --- a/trino/client.py +++ b/trino/client.py @@ -775,13 +775,18 @@ def result(self): def info_uri(self): return self._info_uri - def execute(self, additional_http_headers=None) -> TrinoResult: - """Initiate a Trino query by sending the SQL statement - - This is the first HTTP request sent to the coordinator. - It sets the query_id and returns a Result object used to - track the rows returned by the query. To fetch all rows, - call fetch() until finished is true. + def execute( + self, + additional_http_headers: Optional[Dict[str, Any]] = None, + deferred_fetch: bool = False, + ) -> TrinoResult: + """Initiate a Trino query by sending the SQL statement to the coordinator. + To fetch all rows, call fetch() until finished is true. + + Parameters: + additional_http_headers: extra headers sent to the Trino server. + deferred_fetch: By default, the execution is blocked until at least one row is received + or the query is finished or cancelled. To continue without waiting for the result, set deferred_fetch=True. 
""" if self.cancelled: raise exceptions.TrinoUserError("Query has been cancelled", self.query_id) @@ -802,9 +807,11 @@ def execute(self, additional_http_headers=None) -> TrinoResult: rows = self._row_mapper.map(status.rows) if self._row_mapper else status.rows self._result = TrinoResult(self, rows) - # Execute should block until at least one row is received or query is finished or cancelled - while not self.finished and not self.cancelled and len(self._result.rows) == 0: - self._result.rows += self.fetch() + if not deferred_fetch: + # Execute should block until at least one row is received or query is finished or cancelled + while not self.finished and not self.cancelled and len(self._result.rows) == 0: + self._result.rows += self.fetch() + return self._result def _update_state(self, status): diff --git a/trino/dbapi.py b/trino/dbapi.py index 62ce893b..25db342d 100644 --- a/trino/dbapi.py +++ b/trino/dbapi.py @@ -558,7 +558,10 @@ def _deallocate_prepared_statement(self, statement_name: str) -> None: def _generate_unique_statement_name(self): return 'st_' + uuid.uuid4().hex.replace('-', '') - def execute(self, operation, params=None): + def execute(self, operation, params=None, **kwargs: Any): + additional_http_headers = kwargs.get("additional_http_headers", None) + deferred_fetch = kwargs.get("deferred_fetch", False) + if params: assert isinstance(params, (list, tuple)), ( 'params must be a list or tuple containing the query ' @@ -575,7 +578,10 @@ def execute(self, operation, params=None): self._query = self._execute_prepared_statement( statement_name, params ) - self._iterator = iter(self._query.execute()) + self._iterator = iter(self._query.execute( + additional_http_headers=additional_http_headers, + deferred_fetch=deferred_fetch, + )) finally: # Send deallocate statement # At this point the query can be deallocated since it has already @@ -584,12 +590,18 @@ def execute(self, operation, params=None): self._deallocate_prepared_statement(statement_name) else: 
self._query = self._execute_immediate_statement(operation, params) - self._iterator = iter(self._query.execute()) + self._iterator = iter(self._query.execute( + additional_http_headers=additional_http_headers, + deferred_fetch=deferred_fetch, + )) else: self._query = trino.client.TrinoQuery(self._request, query=operation, legacy_primitive_types=self._legacy_primitive_types) - self._iterator = iter(self._query.execute()) + self._iterator = iter(self._query.execute( + additional_http_headers=additional_http_headers, + deferred_fetch=deferred_fetch, + )) return self def executemany(self, operation, seq_of_params): diff --git a/trino/sqlalchemy/dialect.py b/trino/sqlalchemy/dialect.py index febfac59..0e5ec62e 100644 --- a/trino/sqlalchemy/dialect.py +++ b/trino/sqlalchemy/dialect.py @@ -387,7 +387,7 @@ def _get_default_schema_name(self, connection: Connection) -> Optional[str]: def do_execute( self, cursor: Cursor, statement: str, parameters: Tuple[Any, ...], context: DefaultExecutionContext = None ): - cursor.execute(statement, parameters) + cursor.execute(statement, parameters, **context.execution_options) def do_rollback(self, dbapi_connection: trino_dbapi.Connection): if dbapi_connection.transaction is not None: