diff --git a/integration_tests/test_dbapi.py b/integration_tests/test_dbapi.py index dded9411..0d53ae90 100644 --- a/integration_tests/test_dbapi.py +++ b/integration_tests/test_dbapi.py @@ -68,6 +68,21 @@ def test_select_query_result_iteration(presto_connection): assert len(list(rows0)) == len(rows1) +def test_select_cursor_iteration(presto_connection): + cur0 = presto_connection.cursor() + cur0.execute("select nationkey from tpch.sf1.nation") + rows0 = [] + for row in cur0: + rows0.append(row) + + cur1 = presto_connection.cursor() + cur1.execute("select nationkey from tpch.sf1.nation") + rows1 = cur1.fetchall() + + assert len(rows0) == len(rows1) + assert sorted(rows0) == sorted(rows1) + + def test_select_query_no_result(presto_connection): cur = presto_connection.cursor() cur.execute("select * from system.runtime.nodes where false") diff --git a/presto/dbapi.py b/presto/dbapi.py index 24633e98..5b43fa24 100644 --- a/presto/dbapi.py +++ b/presto/dbapi.py @@ -189,6 +189,9 @@ def __init__(self, connection, request): self._iterator = None self._query = None + def __iter__(self): + return self._iterator + @property def connection(self): return self._connection