Skip to content

Commit

Permalink
feat: cursor.rowcount now returns count
Browse files Browse the repository at this point in the history
  • Loading branch information
tekumara committed Jan 28, 2024
1 parent 69f00a6 commit 8a8264e
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 18 deletions.
13 changes: 8 additions & 5 deletions fakesnow/fakes.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def __init__(
self._arraysize = 1
self._arrow_table = None
self._arrow_table_fetch_index = None
self._rowcount = None
self._converter = snowflake.connector.converter.SnowflakeConverter()

def __enter__(self) -> Self:
Expand Down Expand Up @@ -134,6 +135,8 @@ def _execute(
**kwargs: Any,
) -> FakeSnowflakeCursor:
self._arrow_table = None
self._arrow_table_fetch_index = None
self._rowcount = None

command, params = self._rewrite_with_params(command, params)
expression = parse_one(command, read="snowflake")
Expand Down Expand Up @@ -218,6 +221,7 @@ def _execute(
else:
raise e

effective_count = None
if cmd == "USE DATABASE" and (ident := expression.find(exp.Identifier)) and isinstance(ident.this, str):
self._conn.database = ident.this.upper()
self._conn.database_set = True
Expand Down Expand Up @@ -252,8 +256,8 @@ def _execute(
self._conn.schema = None

elif cmd == "INSERT":
(count,) = self._duck_conn.fetchall()[0]
result_sql = SQL_INSERTED_ROWS.substitute(count=count)
(effective_count,) = self._duck_conn.fetchall()[0]
result_sql = SQL_INSERTED_ROWS.substitute(count=effective_count)

elif cmd == "DESCRIBE TABLE":
# DESCRIBE TABLE has already been run above to detect and error if the table exists
Expand Down Expand Up @@ -283,7 +287,7 @@ def _execute(
self._duck_conn.execute(result_sql)

self._arrow_table = self._duck_conn.fetch_arrow_table()
self._arrow_table_fetch_index = None
self._rowcount = effective_count or self._arrow_table.num_rows

self._last_sql = result_sql or sql
self._last_params = params
Expand Down Expand Up @@ -347,8 +351,7 @@ def get_result_batches(self) -> list[ResultBatch] | None:

@property
def rowcount(self) -> int | None:
# TODO: return number of rows updated/inserted (using returning)
return None
return self._rowcount

@property
def sfqid(self) -> str | None:
Expand Down
25 changes: 12 additions & 13 deletions tests/test_fakes.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def test_connect_reuse_db():
def test_connect_without_database(_fakesnow_no_auto_create: None):
with snowflake.connector.connect() as conn, conn.cursor() as cur:
with pytest.raises(snowflake.connector.errors.ProgrammingError) as excinfo:
cur.execute("SELECT * FROM customers")
cur.execute("select * from customers")

# actual snowflake error message is:
#
Expand All @@ -129,7 +129,7 @@ def test_connect_without_database(_fakesnow_no_auto_create: None):
# )

with pytest.raises(snowflake.connector.errors.ProgrammingError) as excinfo:
cur.execute("SELECT * FROM jaffles.customers")
cur.execute("select * from jaffles.customers")

assert (
"090105 (22000): Cannot perform SELECT. This session does not have a current database. Call 'USE DATABASE', or use a qualified name."
Expand Down Expand Up @@ -171,7 +171,7 @@ def test_connect_without_schema(_fakesnow: None):
assert not conn.schema

with pytest.raises(snowflake.connector.errors.ProgrammingError) as excinfo:
cur.execute("SELECT * FROM customers")
cur.execute("select * from customers")

# actual snowflake error message is:
#
Expand Down Expand Up @@ -862,19 +862,18 @@ def test_random(cur: snowflake.connector.cursor.SnowflakeCursor):


def test_rowcount(cur: snowflake.connector.cursor.SnowflakeCursor):
cur.execute("create table example(id int)")
cur.execute("insert into example SELECT * FROM (VALUES (1), (2), (3));")
# TODO: rows inserted ie: 3
assert cur.rowcount is None
# TODO: selected rows
# cur.execute("SELECT * FROM example where id > 1")
# assert cur.rowcount == 2
cur.execute("create or replace table example(id int)")
cur.execute("insert into example select * from (VALUES (1), (2), (3), (4));")
assert cur.rowcount == 4
cur.execute("select * from example where id > 1")
assert cur.rowcount == 3


def test_sample(cur: snowflake.connector.cursor.SnowflakeCursor):
cur.execute("create table example(id int)")
cur.execute("insert into example SELECT * FROM (VALUES (1), (2), (3), (4));")
cur.execute("SELECT * FROM example SAMPLE (50) SEED (420)")
cur.execute("insert into example select * from (VALUES (1), (2), (3), (4));")
cur.execute("select * from example SAMPLE (50) SEED (420)")
# sampling small sizes isn't exact
assert cur.fetchall() == [(1,), (2,), (3,)]

Expand Down Expand Up @@ -1071,7 +1070,7 @@ def test_transactions(conn: snowflake.connector.SnowflakeConnection):
# transactions are per session, cursors are just different result sets,
# so a new cursor will see the uncommitted values
with conn.cursor() as cur:
cur.execute("SELECT * FROM table1")
cur.execute("select * from table1")
assert cur.fetchall() == [(2,)]

conn.commit()
Expand Down Expand Up @@ -1134,7 +1133,7 @@ def test_use_invalid_schema(_fakesnow: None):

def test_values(conn: snowflake.connector.SnowflakeConnection):
with conn.cursor(snowflake.connector.cursor.DictCursor) as cur:
cur.execute("SELECT * FROM VALUES ('Amsterdam', 1), ('London', 2)")
cur.execute("select * from VALUES ('Amsterdam', 1), ('London', 2)")

assert cur.fetchall() == [
{"COLUMN1": "Amsterdam", "COLUMN2": 1},
Expand Down

0 comments on commit 8a8264e

Please sign in to comment.