From 5086c86f4cec1b0b513ab9b3c5f5280a7d0973e5 Mon Sep 17 00:00:00 2001 From: Michiel De Smet Date: Wed, 18 Jan 2023 15:18:39 +0100 Subject: [PATCH 1/3] Refactor TrinoStatus into `@dataclass` --- trino/client.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/trino/client.py b/trino/client.py index 3ab35b09..016fd4b4 100644 --- a/trino/client.py +++ b/trino/client.py @@ -44,6 +44,7 @@ import threading import urllib.parse import warnings +from dataclasses import dataclass from datetime import date, datetime, time, timedelta, timezone, tzinfo from decimal import Decimal from time import sleep @@ -290,16 +291,16 @@ def get_roles_values(headers, header): ] -class TrinoStatus(object): - def __init__(self, id, stats, warnings, info_uri, next_uri, update_type, rows, columns=None): - self.id = id - self.stats = stats - self.warnings = warnings - self.info_uri = info_uri - self.next_uri = next_uri - self.update_type = update_type - self.rows = rows - self.columns = columns +@dataclass +class TrinoStatus: + id: str + stats: Dict[str, str] + warnings: List[Any] + info_uri: str + next_uri: Optional[str] + update_type: Optional[str] + rows: List[Any] + columns: List[Any] def __repr__(self): return ( From 6ef81169aae22062023192426c4d29f57aecb94f Mon Sep 17 00:00:00 2001 From: Michiel De Smet Date: Thu, 19 Jan 2023 16:26:29 +0100 Subject: [PATCH 2/3] Add TestTable utility class --- tests/integration/test_dbapi_integration.py | 24 ++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/tests/integration/test_dbapi_integration.py b/tests/integration/test_dbapi_integration.py index ae325778..33cd109a 100644 --- a/tests/integration/test_dbapi_integration.py +++ b/tests/integration/test_dbapi_integration.py @@ -10,8 +10,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import math +import uuid from datetime import date, datetime, time, timedelta, timezone from decimal import Decimal +from typing import Tuple import pytest import pytz @@ -21,7 +23,7 @@ import trino from tests.integration.conftest import trino_version from trino import constants -from trino.dbapi import DescribeOutput +from trino.dbapi import Cursor, DescribeOutput from trino.exceptions import NotSupportedError, TrinoQueryError, TrinoUserError from trino.transaction import IsolationLevel @@ -1273,3 +1275,23 @@ def assert_cursor_description(cur, trino_type, size=None, precision=None, scale= assert cur.description[0][4] is precision assert cur.description[0][5] is scale assert cur.description[0][6] is None + + +class _TestTable: + def __init__(self, conn, table_name_prefix, table_definition) -> None: + self._conn = conn + self._table_name = table_name_prefix + '_' + str(uuid.uuid4().hex) + self._table_definition = table_definition + + def __enter__(self) -> Tuple["_TestTable", Cursor]: + return ( + self, + self._conn.cursor().execute(f"CREATE TABLE {self._table_name} {self._table_definition}") + ) + + def __exit__(self, exc_type, exc_value, exc_tb) -> None: + self._conn.cursor().execute(f"DROP TABLE {self._table_name}") + + @property + def table_name(self): + return self._table_name From f8f9aa88c245647f2c23b954fa250f471502211d Mon Sep 17 00:00:00 2001 From: Michiel De Smet Date: Wed, 18 Jan 2023 15:19:36 +0100 Subject: [PATCH 3/3] Return row count if available in status --- tests/integration/test_dbapi_integration.py | 26 +++++++++++++++++++++ trino/client.py | 8 +++++++ trino/dbapi.py | 17 ++++++++++---- 3 files changed, 46 insertions(+), 5 deletions(-) diff --git a/tests/integration/test_dbapi_integration.py b/tests/integration/test_dbapi_integration.py index 33cd109a..dd37a14b 100644 --- a/tests/integration/test_dbapi_integration.py +++ b/tests/integration/test_dbapi_integration.py @@ -1268,6 +1268,32 @@ def test_describe_table_query(run_trino): ] +def test_rowcount_select(trino_connection): + cur = trino_connection.cursor() + cur.execute("SELECT 1 as a") + cur.fetchall() + assert cur.rowcount == -1 + + +def test_rowcount_create_table(trino_connection): + with _TestTable(trino_connection, "memory.default.test_rowcount_create_table", "(a varchar)") as (_, cur): + assert cur.rowcount == -1 + + +def test_rowcount_create_table_as_select(trino_connection): + with _TestTable( + trino_connection, + "memory.default.test_rowcount_ctas", "AS SELECT 1 a UNION ALL SELECT 2" + ) as (_, cur): + assert cur.rowcount == 2 + + +def test_rowcount_insert(trino_connection): + with _TestTable(trino_connection, "memory.default.test_rowcount_ctas", "(a VARCHAR)") as (table, cur): + cur.execute(f"INSERT INTO {table.table_name} (a) VALUES ('test')") + assert cur.rowcount == 1 + + def assert_cursor_description(cur, trino_type, size=None, precision=None, scale=None): assert cur.description[0][1] == trino_type assert cur.description[0][2] is None diff --git a/trino/client.py b/trino/client.py index 016fd4b4..a645a589 100644 --- a/trino/client.py +++ b/trino/client.py @@ -299,6 +299,7 @@ class TrinoStatus: info_uri: str next_uri: Optional[str] update_type: Optional[str] + update_count: Optional[int] rows: List[Any] columns: List[Any] @@ -666,6 +667,7 @@ def process(self, http_response) -> TrinoStatus: info_uri=response["infoUri"], next_uri=self._next_uri, update_type=response.get("updateType"), + update_count=response.get("updateCount"), rows=response.get("data", []), columns=response.get("columns"), ) @@ -743,6 +745,7 @@ def __init__( self._cancelled = False self._request = request self._update_type = None + self._update_count = None self._sql = sql self._result: Optional[TrinoResult] = None self._legacy_primitive_types = legacy_primitive_types @@ -765,6 +768,10 @@ def stats(self): def update_type(self): return self._update_type + @property + def update_count(self): + return self._update_count + @property def warnings(self): return self._warnings @@ -809,6 +816,7 @@ def execute(self, additional_http_headers=None) -> TrinoResult: def _update_state(self, status): self._stats.update(status.stats) self._update_type = status.update_type + self._update_count = status.update_count if not self._row_mapper and status.columns: self._row_mapper = RowMapperFactory().create(columns=status.columns, legacy_primitive_types=self._legacy_primitive_types) diff --git a/trino/dbapi.py b/trino/dbapi.py index 6cb2a97a..d8ae5f72 100644 --- a/trino/dbapi.py +++ b/trino/dbapi.py @@ -315,13 +315,20 @@ def description(self) -> List[ColumnDescription]: @property def rowcount(self): - """Not supported. + """The rowcount will be returned for INSERT, UPDATE, DELETE, MERGE + and CTAS statements based on `update_count` returned by the Trino + API. - Trino cannot reliablity determine the number of rows returned by an - operation. For example, the result of a SELECT query is streamed and - the number of rows is only knowns when all rows have been retrieved. - """ + If the rowcount can't be determined, -1 will be returned. + + Trino cannot reliably determine the number of rows returned for DQL + queries. For example, the result of a SELECT query is streamed and + the number of rows is only known when all rows have been retrieved. + See https://peps.python.org/pep-0249/#rowcount + """ + if self._query is not None and self._query.update_count is not None: + return self._query.update_count return -1 @property