Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 49 additions & 1 deletion tests/integration/test_dbapi_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import uuid
from datetime import date, datetime, time, timedelta, timezone
from decimal import Decimal
from typing import Tuple

import pytest
import pytz
Expand All @@ -21,7 +23,7 @@
import trino
from tests.integration.conftest import trino_version
from trino import constants
from trino.dbapi import DescribeOutput
from trino.dbapi import Cursor, DescribeOutput
from trino.exceptions import NotSupportedError, TrinoQueryError, TrinoUserError
from trino.transaction import IsolationLevel

Expand Down Expand Up @@ -1266,10 +1268,56 @@ def test_describe_table_query(run_trino):
]


def test_rowcount_select(trino_connection):
cur = trino_connection.cursor()
cur.execute("SELECT 1 as a")
cur.fetchall()
assert cur.rowcount == -1


def test_rowcount_create_table(trino_connection):
with _TestTable(trino_connection, "memory.default.test_rowcount_create_table", "(a varchar)") as (_, cur):
assert cur.rowcount == -1


def test_rowcount_create_table_as_select(trino_connection):
with _TestTable(
trino_connection,
"memory.default.test_rowcount_ctas", "AS SELECT 1 a UNION ALL SELECT 2"
) as (_, cur):
assert cur.rowcount == 2


def test_rowcount_insert(trino_connection):
with _TestTable(trino_connection, "memory.default.test_rowcount_ctas", "(a VARCHAR)") as (table, cur):
cur.execute(f"INSERT INTO {table.table_name} (a) VALUES ('test')")
assert cur.rowcount == 1


def assert_cursor_description(cur, trino_type, size=None, precision=None, scale=None):
assert cur.description[0][1] == trino_type
assert cur.description[0][2] is None
assert cur.description[0][3] is size
assert cur.description[0][4] is precision
assert cur.description[0][5] is scale
assert cur.description[0][6] is None


class _TestTable:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we have a utils module under test directory? I see some other such classes and methods which we be useful across tests.

def __init__(self, conn, table_name_prefix, table_definition) -> None:
self._conn = conn
self._table_name = table_name_prefix + '_' + str(uuid.uuid4().hex)
self._table_definition = table_definition

def __enter__(self) -> Tuple["_TestTable", Cursor]:
return (
self,
self._conn.cursor().execute(f"CREATE TABLE {self._table_name} {self._table_definition}")
)

def __exit__(self, exc_type, exc_value, exc_tb) -> None:
self._conn.cursor().execute(f"DROP TABLE {self._table_name}")

@property
def table_name(self):
return self._table_name
29 changes: 19 additions & 10 deletions trino/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import threading
import urllib.parse
import warnings
from dataclasses import dataclass
from datetime import date, datetime, time, timedelta, timezone, tzinfo
from decimal import Decimal
from time import sleep
Expand Down Expand Up @@ -290,16 +291,17 @@ def get_roles_values(headers, header):
]


class TrinoStatus(object):
def __init__(self, id, stats, warnings, info_uri, next_uri, update_type, rows, columns=None):
self.id = id
self.stats = stats
self.warnings = warnings
self.info_uri = info_uri
self.next_uri = next_uri
self.update_type = update_type
self.rows = rows
self.columns = columns
@dataclass
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice. ❤️

class TrinoStatus:
id: str
stats: Dict[str, str]
warnings: List[Any]
info_uri: str
next_uri: Optional[str]
update_type: Optional[str]
update_count: Optional[int]
rows: List[Any]
columns: List[Any]

def __repr__(self):
return (
Expand Down Expand Up @@ -665,6 +667,7 @@ def process(self, http_response) -> TrinoStatus:
info_uri=response["infoUri"],
next_uri=self._next_uri,
update_type=response.get("updateType"),
update_count=response.get("updateCount"),
rows=response.get("data", []),
columns=response.get("columns"),
)
Expand Down Expand Up @@ -742,6 +745,7 @@ def __init__(
self._cancelled = False
self._request = request
self._update_type = None
self._update_count = None
self._sql = sql
self._result: Optional[TrinoResult] = None
self._legacy_primitive_types = legacy_primitive_types
Expand All @@ -764,6 +768,10 @@ def stats(self):
def update_type(self):
return self._update_type

@property
def update_count(self):
return self._update_count

@property
def warnings(self):
return self._warnings
Expand Down Expand Up @@ -808,6 +816,7 @@ def execute(self, additional_http_headers=None) -> TrinoResult:
def _update_state(self, status):
self._stats.update(status.stats)
self._update_type = status.update_type
self._update_count = status.update_count
if not self._row_mapper and status.columns:
self._row_mapper = RowMapperFactory().create(columns=status.columns,
legacy_primitive_types=self._legacy_primitive_types)
Expand Down
17 changes: 12 additions & 5 deletions trino/dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,13 +315,20 @@ def description(self) -> List[ColumnDescription]:

@property
def rowcount(self):
"""Not supported.
"""The rowcount will be returned for INSERT, UPDATE, DELETE, MERGE
and CTAS statements based on `update_count` returned by the Trino
API.

Trino cannot reliablity determine the number of rows returned by an
operation. For example, the result of a SELECT query is streamed and
the number of rows is only knowns when all rows have been retrieved.
"""
If the rowcount can't be determined, -1 will be returned.

Trino cannot reliably determine the number of rows returned for DQL
queries. For example, the result of a SELECT query is streamed and
the number of rows is only known when all rows have been retrieved.

See https://peps.python.org/pep-0249/#rowcount
"""
if self._query is not None and self._query.update_count is not None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment above can be updated to reflect the change.

return self._query.update_count
return -1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not obvious so maybe a link to https://peps.python.org/pep-0249/#rowcount here would be useful i.e. the -1 is what the DB-API requires.


@property
Expand Down