diff --git a/tests/integration/test_dbapi_integration.py b/tests/integration/test_dbapi_integration.py index 4e3d40df..09d496bd 100644 --- a/tests/integration/test_dbapi_integration.py +++ b/tests/integration/test_dbapi_integration.py @@ -102,8 +102,19 @@ def test_select_query_result_iteration(trino_connection): assert len(list(rows0)) == len(rows1) -def test_select_query_result_iteration_statement_params(trino_connection): - cur = trino_connection.cursor() +@pytest.mark.parametrize( + "legacy_prepared_statements", + [ + True, + pytest.param(False, marks=pytest.mark.skipif( + trino_version() <= '417', + reason="EXECUTE IMMEDIATE was introduced in version 418")), + None + ] +) +def test_select_query_result_iteration_statement_params(legacy_prepared_statements, run_trino): + cur = get_cursor(legacy_prepared_statements, run_trino) + cur.execute( """ SELECT * FROM ( @@ -118,10 +129,25 @@ def test_select_query_result_iteration_statement_params(trino_connection): """, params=(3,) # expecting all the rows with id >= 3 ) + rows = cur.fetchall() + assert len(rows) == 3 + assert [3, 'three', 'c'] in rows + assert [4, 'four', 'd'] in rows + assert [5, 'five', 'e'] in rows -def test_none_query_param(trino_connection): - cur = trino_connection.cursor() +@pytest.mark.parametrize( + "legacy_prepared_statements", + [ + True, + pytest.param(False, marks=pytest.mark.skipif( + trino_version() <= '417', + reason="EXECUTE IMMEDIATE was introduced in version 418")), + None + ] +) +def test_none_query_param(legacy_prepared_statements, run_trino): + cur = get_cursor(legacy_prepared_statements, run_trino) cur.execute("SELECT ?", params=(None,)) rows = cur.fetchall() @@ -129,8 +155,18 @@ def test_none_query_param(trino_connection): assert_cursor_description(cur, trino_type="unknown") -def test_string_query_param(trino_connection): - cur = trino_connection.cursor() +@pytest.mark.parametrize( + "legacy_prepared_statements", + [ + True, + pytest.param(False, marks=pytest.mark.skipif( + trino_version() <= '417', + reason="EXECUTE IMMEDIATE was introduced in version 418")), + None + ] +) +def test_string_query_param(legacy_prepared_statements, run_trino): + cur = get_cursor(legacy_prepared_statements, run_trino) cur.execute("SELECT ?", params=("six'",)) rows = cur.fetchall() @@ -139,9 +175,20 @@ def test_string_query_param(trino_connection): assert_cursor_description(cur, trino_type="varchar(4)", size=4) -def test_execute_many(trino_connection): +@pytest.mark.parametrize( + "legacy_prepared_statements", + [ + True, + pytest.param(False, marks=pytest.mark.skipif( + trino_version() <= '417', + reason="EXECUTE IMMEDIATE was introduced in version 418")), + None + ] +) +def test_execute_many(legacy_prepared_statements, run_trino): + try: - cur = trino_connection.cursor() + cur = get_cursor(legacy_prepared_statements, run_trino) cur.execute("CREATE TABLE memory.default.test_execute_many (key int, value varchar)") cur.fetchall() operation = "INSERT INTO memory.default.test_execute_many (key, value) VALUES (?, ?)" @@ -163,13 +210,24 @@ def test_execute_many(trino_connection): assert rows[1] == [2, "value2"] assert rows[2] == [3, "value3"] finally: - cur = trino_connection.cursor() + cur = get_cursor(legacy_prepared_statements, run_trino) cur.execute("DROP TABLE IF EXISTS memory.default.test_execute_many") -def test_execute_many_without_params(trino_connection): +@pytest.mark.parametrize( + "legacy_prepared_statements", + [ + True, + pytest.param(False, marks=pytest.mark.skipif( + trino_version() <= '417', + reason="EXECUTE IMMEDIATE was introduced in version 418")), + None + ] +) +def test_execute_many_without_params(legacy_prepared_statements, run_trino): + try: - cur = trino_connection.cursor() + cur = get_cursor(legacy_prepared_statements, run_trino) cur.execute("CREATE TABLE memory.default.test_execute_many_without_param (value varchar)") cur.fetchall() with pytest.raises(TrinoUserError) as e: @@ -177,12 +235,22 @@ def test_execute_many_without_params(trino_connection): cur.fetchall() assert "Incorrect number of parameters: expected 1 but found 0" in str(e.value) finally: - cur = trino_connection.cursor() + cur = get_cursor(legacy_prepared_statements, run_trino) cur.execute("DROP TABLE IF EXISTS memory.default.test_execute_many_without_param") -def test_execute_many_select(trino_connection): - cur = trino_connection.cursor() +@pytest.mark.parametrize( + "legacy_prepared_statements", + [ + True, + pytest.param(False, marks=pytest.mark.skipif( + trino_version() <= '417', + reason="EXECUTE IMMEDIATE was introduced in version 418")), + None + ] +) +def test_execute_many_select(legacy_prepared_statements, run_trino): + cur = get_cursor(legacy_prepared_statements, run_trino) with pytest.raises(NotSupportedError) as e: cur.executemany("SELECT ?, ?", [(1, "value1"), (2, "value2")]) assert "Query must return update type" in str(e.value) @@ -255,8 +323,18 @@ def test_legacy_primitive_types_with_connection_and_cursor( assert rows[0][6] == '-2001-08-22' -def test_decimal_query_param(trino_connection): - cur = trino_connection.cursor() +@pytest.mark.parametrize( + "legacy_prepared_statements", + [ + True, + pytest.param(False, marks=pytest.mark.skipif( + trino_version() <= '417', + reason="EXECUTE IMMEDIATE was introduced in version 418")), + None + ] +) +def test_decimal_query_param(legacy_prepared_statements, run_trino): + cur = get_cursor(legacy_prepared_statements, run_trino) cur.execute("SELECT ?", params=(Decimal('1112.142857'),)) rows = cur.fetchall() @@ -265,8 +343,18 @@ def test_decimal_query_param(trino_connection): assert_cursor_description(cur, trino_type="decimal(10, 6)", precision=10, scale=6) -def test_decimal_scientific_notation_query_param(trino_connection): - cur = trino_connection.cursor() +@pytest.mark.parametrize( + "legacy_prepared_statements", + [ + True, + pytest.param(False, marks=pytest.mark.skipif( + trino_version() <= '417', + reason="EXECUTE IMMEDIATE was introduced in version 418")), + None + ] +) +def test_decimal_scientific_notation_query_param(legacy_prepared_statements, run_trino): + cur = get_cursor(legacy_prepared_statements, run_trino) cur.execute("SELECT ?", params=(Decimal('0E-10'),)) rows = cur.fetchall() @@ -294,8 +382,18 @@ def test_null_decimal(trino_connection): assert_cursor_description(cur, trino_type="decimal(38, 0)", precision=38, scale=0) -def test_biggest_decimal(trino_connection): - cur = trino_connection.cursor() +@pytest.mark.parametrize( + "legacy_prepared_statements", + [ + True, + pytest.param(False, marks=pytest.mark.skipif( + trino_version() <= '417', + reason="EXECUTE IMMEDIATE was introduced in version 418")), + None + ] +) +def test_biggest_decimal(legacy_prepared_statements, run_trino): + cur = get_cursor(legacy_prepared_statements, run_trino) params = Decimal('99999999999999999999999999999999999999') cur.execute("SELECT ?", params=(params,)) @@ -305,8 +403,18 @@ def test_biggest_decimal(trino_connection): assert_cursor_description(cur, trino_type="decimal(38, 0)", precision=38, scale=0) -def test_smallest_decimal(trino_connection): - cur = trino_connection.cursor() +@pytest.mark.parametrize( + "legacy_prepared_statements", + [ + True, + pytest.param(False, marks=pytest.mark.skipif( + trino_version() <= '417', + reason="EXECUTE IMMEDIATE was introduced in version 418")), + None + ] +) +def test_smallest_decimal(legacy_prepared_statements, run_trino): + cur = get_cursor(legacy_prepared_statements, run_trino) params = Decimal('-99999999999999999999999999999999999999') cur.execute("SELECT ?", params=(params,)) @@ -316,8 +424,18 @@ def test_smallest_decimal(trino_connection): assert_cursor_description(cur, trino_type="decimal(38, 0)", precision=38, scale=0) -def test_highest_precision_decimal(trino_connection): - cur = trino_connection.cursor() +@pytest.mark.parametrize( + "legacy_prepared_statements", + [ + True, + pytest.param(False, marks=pytest.mark.skipif( + trino_version() <= '417', + reason="EXECUTE IMMEDIATE was introduced in version 418")), + None + ] +) +def test_highest_precision_decimal(legacy_prepared_statements, run_trino): + cur = get_cursor(legacy_prepared_statements, run_trino) params = Decimal('0.99999999999999999999999999999999999999') cur.execute("SELECT ?", params=(params,)) @@ -327,8 +445,18 @@ def test_highest_precision_decimal(trino_connection): assert_cursor_description(cur, trino_type="decimal(38, 38)", precision=38, scale=38) -def test_datetime_query_param(trino_connection): - cur = trino_connection.cursor() +@pytest.mark.parametrize( + "legacy_prepared_statements", + [ + True, + pytest.param(False, marks=pytest.mark.skipif( + trino_version() <= '417', + reason="EXECUTE IMMEDIATE was introduced in version 418")), + None + ] +) +def test_datetime_query_param(legacy_prepared_statements, run_trino): + cur = get_cursor(legacy_prepared_statements, run_trino) params = datetime(2020, 1, 1, 16, 43, 22, 320000) @@ -339,8 +467,18 @@ def test_datetime_query_param(trino_connection): assert_cursor_description(cur, trino_type="timestamp(6)", precision=6) -def test_datetime_with_utc_time_zone_query_param(trino_connection): - cur = trino_connection.cursor() +@pytest.mark.parametrize( + "legacy_prepared_statements", + [ + True, + pytest.param(False, marks=pytest.mark.skipif( + trino_version() <= '417', + reason="EXECUTE IMMEDIATE was introduced in version 418")), + None + ] +) +def test_datetime_with_utc_time_zone_query_param(legacy_prepared_statements, run_trino): + cur = get_cursor(legacy_prepared_statements, run_trino) params = datetime(2020, 1, 1, 16, 43, 22, 320000, tzinfo=ZoneInfo('UTC')) @@ -351,8 +489,18 @@ def test_datetime_with_utc_time_zone_query_param(trino_connection): assert_cursor_description(cur, trino_type="timestamp(6) with time zone", precision=6) -def test_datetime_with_numeric_offset_time_zone_query_param(trino_connection): - cur = trino_connection.cursor() +@pytest.mark.parametrize( + "legacy_prepared_statements", + [ + True, + pytest.param(False, marks=pytest.mark.skipif( + trino_version() <= '417', + reason="EXECUTE IMMEDIATE was introduced in version 418")), + None + ] +) +def test_datetime_with_numeric_offset_time_zone_query_param(legacy_prepared_statements, run_trino): + cur = get_cursor(legacy_prepared_statements, run_trino) tz = timezone(-timedelta(hours=5, minutes=30)) @@ -365,8 +513,18 @@ def test_datetime_with_numeric_offset_time_zone_query_param(trino_connection): assert_cursor_description(cur, trino_type="timestamp(6) with time zone", precision=6) -def test_datetime_with_named_time_zone_query_param(trino_connection): - cur = trino_connection.cursor() +@pytest.mark.parametrize( + "legacy_prepared_statements", + [ + True, + pytest.param(False, marks=pytest.mark.skipif( + trino_version() <= '417', + reason="EXECUTE IMMEDIATE was introduced in version 418")), + None + ] +) +def test_datetime_with_named_time_zone_query_param(legacy_prepared_statements, run_trino): + cur = get_cursor(legacy_prepared_statements, run_trino) params = datetime(2020, 1, 1, 16, 43, 22, 320000, tzinfo=ZoneInfo('America/Los_Angeles')) @@ -407,8 +565,18 @@ def test_datetime_with_time_zone_numeric_offset(trino_connection): assert_cursor_description(cur, trino_type="timestamp(3) with time zone", precision=3) -def test_datetimes_with_time_zone_in_dst_gap_query_param(trino_connection): - cur = trino_connection.cursor() +@pytest.mark.parametrize( + "legacy_prepared_statements", + [ + True, + pytest.param(False, marks=pytest.mark.skipif( + trino_version() <= '417', + reason="EXECUTE IMMEDIATE was introduced in version 418")), + None + ] +) +def test_datetimes_with_time_zone_in_dst_gap_query_param(legacy_prepared_statements, run_trino): + cur = get_cursor(legacy_prepared_statements, run_trino) # This is a datetime that lies within a DST transition and not actually exists. params = datetime(2021, 3, 28, 2, 30, 0, tzinfo=ZoneInfo('Europe/Brussels')) @@ -417,11 +585,21 @@ def test_datetimes_with_time_zone_in_dst_gap_query_param(trino_connection): cur.fetchall() +@pytest.mark.parametrize( + "legacy_prepared_statements", + [ + True, + pytest.param(False, marks=pytest.mark.skipif( + trino_version() <= '417', + reason="EXECUTE IMMEDIATE was introduced in version 418")), + None + ] +) @pytest.mark.parametrize('fold', [0, 1]) -def test_doubled_datetimes(trino_connection, fold): +def test_doubled_datetimes(fold, legacy_prepared_statements, run_trino): # Trino doesn't distinguish between doubled datetimes that lie within a DST transition. # See also https://github.com/trinodb/trino/issues/5781 - cur = trino_connection.cursor() + cur = get_cursor(legacy_prepared_statements, run_trino) params = datetime(2002, 10, 27, 1, 30, 0, tzinfo=ZoneInfo('US/Eastern'), fold=fold) @@ -431,8 +609,18 @@ def test_doubled_datetimes(trino_connection, fold): assert rows[0][0] == datetime(2002, 10, 27, 1, 30, 0, tzinfo=ZoneInfo('US/Eastern')) -def test_date_query_param(trino_connection): - cur = trino_connection.cursor() +@pytest.mark.parametrize( + "legacy_prepared_statements", + [ + True, + pytest.param(False, marks=pytest.mark.skipif( + trino_version() <= '417', + reason="EXECUTE IMMEDIATE was introduced in version 418")), + None + ] +) +def test_date_query_param(legacy_prepared_statements, run_trino): + cur = get_cursor(legacy_prepared_statements, run_trino) params = datetime(2020, 1, 1, 0, 0, 0).date() @@ -469,8 +657,18 @@ def test_unsupported_python_dates(trino_connection): cur.fetchall() -def test_supported_special_dates_query_param(trino_connection): - cur = trino_connection.cursor() +@pytest.mark.parametrize( + "legacy_prepared_statements", + [ + True, + pytest.param(False, marks=pytest.mark.skipif( + trino_version() <= '417', + reason="EXECUTE IMMEDIATE was introduced in version 418")), + None + ] +) +def test_supported_special_dates_query_param(legacy_prepared_statements, run_trino): + cur = get_cursor(legacy_prepared_statements, run_trino) for params in ( # min python date @@ -510,8 +708,18 @@ def test_char(trino_connection): assert_cursor_description(cur, trino_type="char(5)", size=5) -def test_time_query_param(trino_connection): - cur = trino_connection.cursor() +@pytest.mark.parametrize( + "legacy_prepared_statements", + [ + True, + pytest.param(False, marks=pytest.mark.skipif( + trino_version() <= '417', + reason="EXECUTE IMMEDIATE was introduced in version 418")), + None + ] +) +def test_time_query_param(legacy_prepared_statements, run_trino): + cur = get_cursor(legacy_prepared_statements, run_trino) params = time(12, 3, 44, 333000) @@ -522,8 +730,18 @@ def test_time_query_param(trino_connection): assert_cursor_description(cur, trino_type="time(6)", precision=6) -def test_time_with_named_time_zone_query_param(trino_connection): - cur = trino_connection.cursor() +@pytest.mark.parametrize( + "legacy_prepared_statements", + [ + True, + pytest.param(False, marks=pytest.mark.skipif( + trino_version() <= '417', + reason="EXECUTE IMMEDIATE was introduced in version 418")), + None + ] +) +def test_time_with_named_time_zone_query_param(legacy_prepared_statements, run_trino): + cur = get_cursor(legacy_prepared_statements, run_trino) params = time(16, 43, 22, 320000, tzinfo=ZoneInfo('Asia/Shanghai')) @@ -534,8 +752,18 @@ def test_time_with_named_time_zone_query_param(trino_connection): assert rows[0][0].tzinfo == timezone(timedelta(seconds=28800)) -def test_time_with_numeric_offset_time_zone_query_param(trino_connection): - cur = trino_connection.cursor() +@pytest.mark.parametrize( + "legacy_prepared_statements", + [ + True, + pytest.param(False, marks=pytest.mark.skipif( + trino_version() <= '417', + reason="EXECUTE IMMEDIATE was introduced in version 418")), + None + ] +) +def test_time_with_numeric_offset_time_zone_query_param(legacy_prepared_statements, run_trino): + cur = get_cursor(legacy_prepared_statements, run_trino) tz = timezone(-timedelta(hours=8, minutes=0)) params = time(16, 43, 22, 320000, tzinfo=tz) @@ -600,6 +828,16 @@ def test_null_date_with_time_zone(trino_connection): assert_cursor_description(cur, trino_type="time(3) with time zone", precision=3) +@pytest.mark.parametrize( + "legacy_prepared_statements", + [ + True, + pytest.param(False, marks=pytest.mark.skipif( + trino_version() <= '417', + reason="EXECUTE IMMEDIATE was introduced in version 418")), + None + ] +) @pytest.mark.parametrize( "binary_input", [ @@ -610,8 +848,8 @@ def test_null_date_with_time_zone(trino_connection): bytearray([1, 2, 3]), ], ) -def test_binary_query_param(trino_connection, binary_input): - cur = trino_connection.cursor() +def test_binary_query_param(binary_input, legacy_prepared_statements, run_trino): + cur = get_cursor(legacy_prepared_statements, run_trino) cur.execute("SELECT ?", params=(binary_input,)) rows = cur.fetchall() @@ -619,8 +857,18 @@ def test_binary_query_param(trino_connection, binary_input): assert rows[0][0] == binary_input -def test_array_query_param(trino_connection): - cur = trino_connection.cursor() +@pytest.mark.parametrize( + "legacy_prepared_statements", + [ + True, + pytest.param(False, marks=pytest.mark.skipif( + trino_version() <= '417', + reason="EXECUTE IMMEDIATE was introduced in version 418")), + None + ] +) +def test_array_query_param(legacy_prepared_statements, run_trino): + cur = get_cursor(legacy_prepared_statements, run_trino) cur.execute("SELECT ?", params=([1, 2, 3],)) rows = cur.fetchall() @@ -638,8 +886,18 @@ def test_array_query_param(trino_connection): assert rows[0][0] == "array(integer)" -def test_array_none_query_param(trino_connection): - cur = trino_connection.cursor() +@pytest.mark.parametrize( + "legacy_prepared_statements", + [ + True, + pytest.param(False, marks=pytest.mark.skipif( + trino_version() <= '417', + reason="EXECUTE IMMEDIATE was introduced in version 418")), + None + ] +) +def test_array_none_query_param(legacy_prepared_statements, run_trino): + cur = get_cursor(legacy_prepared_statements, run_trino) params = [None, None] @@ -654,8 +912,18 @@ def test_array_none_query_param(trino_connection): assert rows[0][0] == "array(unknown)" -def test_array_none_and_another_type_query_param(trino_connection): - cur = trino_connection.cursor() +@pytest.mark.parametrize( + "legacy_prepared_statements", + [ + True, + pytest.param(False, marks=pytest.mark.skipif( + trino_version() <= '417', + reason="EXECUTE IMMEDIATE was introduced in version 418")), + None + ] +) +def test_array_none_and_another_type_query_param(legacy_prepared_statements, run_trino): + cur = get_cursor(legacy_prepared_statements, run_trino) params = [None, 1] @@ -670,8 +938,18 @@ def test_array_none_and_another_type_query_param(trino_connection): assert rows[0][0] == "array(integer)" -def test_array_timestamp_query_param(trino_connection): - cur = trino_connection.cursor() +@pytest.mark.parametrize( + "legacy_prepared_statements", + [ + True, + pytest.param(False, marks=pytest.mark.skipif( + trino_version() <= '417', + reason="EXECUTE IMMEDIATE was introduced in version 418")), + None + ] +) +def test_array_timestamp_query_param(legacy_prepared_statements, run_trino): + cur = get_cursor(legacy_prepared_statements, run_trino) params = [datetime(2020, 1, 1, 0, 0, 0), datetime(2020, 1, 2, 0, 0, 0)] @@ -686,8 +964,18 @@ def test_array_timestamp_query_param(trino_connection): assert rows[0][0] == "array(timestamp(6))" -def test_array_timestamp_with_timezone_query_param(trino_connection): - cur = trino_connection.cursor() +@pytest.mark.parametrize( + "legacy_prepared_statements", + [ + True, + pytest.param(False, marks=pytest.mark.skipif( + trino_version() <= '417', + reason="EXECUTE IMMEDIATE was introduced in version 418")), + None + ] +) +def test_array_timestamp_with_timezone_query_param(legacy_prepared_statements, run_trino): + cur = get_cursor(legacy_prepared_statements, run_trino) params = [ datetime(2020, 1, 1, 0, 0, 0, tzinfo=ZoneInfo('UTC')), @@ -705,8 +993,18 @@ def test_array_timestamp_with_timezone_query_param(trino_connection): assert rows[0][0] == "array(timestamp(6) with time zone)" -def test_dict_query_param(trino_connection): - cur = trino_connection.cursor() +@pytest.mark.parametrize( + "legacy_prepared_statements", + [ + True, + pytest.param(False, marks=pytest.mark.skipif( + trino_version() <= '417', + reason="EXECUTE IMMEDIATE was introduced in version 418")), + None + ] +) +def test_dict_query_param(legacy_prepared_statements, run_trino): + cur = get_cursor(legacy_prepared_statements, run_trino) cur.execute("SELECT ?", params=({"foo": "bar"},)) rows = cur.fetchall() @@ -719,8 +1017,18 @@ def test_dict_query_param(trino_connection): assert rows[0][0] == "map(varchar(3), varchar(3))" -def test_dict_timestamp_query_param_types(trino_connection): - cur = trino_connection.cursor() +@pytest.mark.parametrize( + "legacy_prepared_statements", + [ + True, + pytest.param(False, marks=pytest.mark.skipif( + trino_version() <= '417', + reason="EXECUTE IMMEDIATE was introduced in version 418")), + None + ] +) +def test_dict_timestamp_query_param_types(legacy_prepared_statements, run_trino): + cur = get_cursor(legacy_prepared_statements, run_trino) params = {"foo": datetime(2020, 1, 1, 16, 43, 22, 320000)} cur.execute("SELECT ?", params=(params,)) @@ -729,8 +1037,18 @@ def test_dict_timestamp_query_param_types(trino_connection): assert rows[0][0] == params -def test_boolean_query_param(trino_connection): - cur = trino_connection.cursor() +@pytest.mark.parametrize( + "legacy_prepared_statements", + [ + True, + pytest.param(False, marks=pytest.mark.skipif( + trino_version() <= '417', + reason="EXECUTE IMMEDIATE was introduced in version 418")), + None + ] +) +def test_boolean_query_param(legacy_prepared_statements, run_trino): + cur = get_cursor(legacy_prepared_statements, run_trino) cur.execute("SELECT ?", params=(True,)) rows = cur.fetchall() @@ -743,8 +1061,18 @@ def test_boolean_query_param(trino_connection): assert rows[0][0] is False -def test_row(trino_connection): - cur = trino_connection.cursor() +@pytest.mark.parametrize( + "legacy_prepared_statements", + [ + True, + pytest.param(False, marks=pytest.mark.skipif( + trino_version() <= '417', + reason="EXECUTE IMMEDIATE was introduced in version 418")), + None + ] +) +def test_row(legacy_prepared_statements, run_trino): + cur = get_cursor(legacy_prepared_statements, run_trino) params = (1, Decimal("2.0"), datetime(2020, 1, 1, 0, 0, 0)) cur.execute("SELECT ?", (params,)) rows = cur.fetchall() @@ -752,8 +1080,18 @@ def test_row(trino_connection): assert rows[0][0] == params -def test_nested_row(trino_connection): - cur = trino_connection.cursor() +@pytest.mark.parametrize( + "legacy_prepared_statements", + [ + True, + pytest.param(False, marks=pytest.mark.skipif( + trino_version() <= '417', + reason="EXECUTE IMMEDIATE was introduced in version 418")), + None + ] +) +def test_nested_row(legacy_prepared_statements, run_trino): + cur = get_cursor(legacy_prepared_statements, run_trino) params = ((1, "test", Decimal("3.1")), Decimal("2.0"), datetime(2020, 1, 1, 0, 0, 0)) cur.execute("SELECT ?", (params,)) rows = cur.fetchall() @@ -812,8 +1150,18 @@ def test_nested_named_row(trino_connection): assert str(rows[0][0]) == "(x: Decimal('2.30'), y: (x: 1, y: 'test'))" -def test_float_query_param(trino_connection): - cur = trino_connection.cursor() +@pytest.mark.parametrize( + "legacy_prepared_statements", + [ + True, + pytest.param(False, marks=pytest.mark.skipif( + trino_version() <= '417', + reason="EXECUTE IMMEDIATE was introduced in version 418")), + None + ] +) +def test_float_query_param(legacy_prepared_statements, run_trino): + cur = get_cursor(legacy_prepared_statements, run_trino) cur.execute("SELECT ?", params=(1.1,)) rows = cur.fetchall() @@ -821,8 +1169,18 @@ def test_float_query_param(trino_connection): assert rows[0][0] == 1.1 -def test_float_nan_query_param(trino_connection): - cur = trino_connection.cursor() +@pytest.mark.parametrize( + "legacy_prepared_statements", + [ + True, + pytest.param(False, marks=pytest.mark.skipif( + trino_version() <= '417', + reason="EXECUTE IMMEDIATE was introduced in version 418")), + None + ] +) +def test_float_nan_query_param(legacy_prepared_statements, run_trino): + cur = get_cursor(legacy_prepared_statements, run_trino) cur.execute("SELECT ?", params=(float("nan"),)) rows = cur.fetchall() @@ -831,8 +1189,18 @@ def test_float_nan_query_param(trino_connection): assert math.isnan(rows[0][0]) -def test_float_inf_query_param(trino_connection): - cur = trino_connection.cursor() +@pytest.mark.parametrize( + "legacy_prepared_statements", + [ + True, + pytest.param(False, marks=pytest.mark.skipif( + trino_version() <= '417', + reason="EXECUTE IMMEDIATE was introduced in version 418")), + None + ] +) +def test_float_inf_query_param(legacy_prepared_statements, run_trino): + cur = get_cursor(legacy_prepared_statements, run_trino) cur.execute("SELECT ?", params=(float("inf"),)) rows = cur.fetchall() @@ -845,8 +1213,18 @@ def test_float_inf_query_param(trino_connection): assert rows[0][0] == float("-inf") -def test_int_query_param(trino_connection): - cur = trino_connection.cursor() +@pytest.mark.parametrize( + "legacy_prepared_statements", + [ + True, + pytest.param(False, marks=pytest.mark.skipif( + trino_version() <= '417', + reason="EXECUTE IMMEDIATE was introduced in version 418")), + None + ] +) +def test_int_query_param(legacy_prepared_statements, run_trino): + cur = get_cursor(legacy_prepared_statements, run_trino) cur.execute("SELECT ?", params=(3,)) rows = cur.fetchall() @@ -860,13 +1238,23 @@ def test_int_query_param(trino_connection): assert_cursor_description(cur, trino_type="bigint") +@pytest.mark.parametrize( + "legacy_prepared_statements", + [ + True, + pytest.param(False, marks=pytest.mark.skipif( + trino_version() <= '417', + reason="EXECUTE IMMEDIATE was introduced in version 418")), + None + ] +) @pytest.mark.parametrize('params', [ 'NOT A LIST OR TUPPLE', {'invalid', 'params'}, object, ]) -def test_select_query_invalid_params(trino_connection, params): - cur = trino_connection.cursor() +def test_select_query_invalid_params(params, legacy_prepared_statements, run_trino): + cur = get_cursor(legacy_prepared_statements, run_trino) with pytest.raises(AssertionError): cur.execute('SELECT ?', params=params) @@ -1036,10 +1424,20 @@ def test_transaction_autocommit(trino_connection_in_autocommit): in str(transaction_error.value) -def test_invalid_query_throws_correct_error(trino_connection): +@pytest.mark.parametrize( + "legacy_prepared_statements", + [ + True, + pytest.param(False, marks=pytest.mark.skipif( + trino_version() <= '417', + reason="EXECUTE IMMEDIATE was introduced in version 418")), + None + ] +) +def test_invalid_query_throws_correct_error(legacy_prepared_statements, run_trino): """Tests that an invalid query raises the correct exception """ - cur = trino_connection.cursor() + cur = get_cursor(legacy_prepared_statements, run_trino) with pytest.raises(TrinoQueryError): cur.execute( """ @@ -1225,13 +1623,17 @@ def assert_role_headers(cursor, expected_header): assert cursor._request.http_headers[constants.HEADER_ROLE] == expected_header -def test_prepared_statements(run_trino): - _, host, port = run_trino - - trino_connection = trino.dbapi.Connection( - host=host, port=port, user="test", catalog="tpch", - ) - cur = trino_connection.cursor() +@pytest.mark.parametrize( + "legacy_prepared_statements", + [ + True, + pytest.param(None, marks=pytest.mark.skipif( + trino_version() > '417', + reason="This would use EXECUTE IMMEDIATE")) + ] +) +def test_prepared_statements(legacy_prepared_statements, run_trino): + cur = get_cursor(legacy_prepared_statements, run_trino) # Implicit prepared statements must work and deallocate statements on finish assert cur._request._client_session.prepared_statements == {} @@ -1380,6 +1782,18 @@ def test_rowcount_insert(trino_connection): assert cur.rowcount == 1 +def get_cursor(legacy_prepared_statements, run_trino): + _, host, port = run_trino + + connection = trino.dbapi.Connection( + host=host, + port=port, + user="test", + legacy_prepared_statements=legacy_prepared_statements, + ) + return connection.cursor() + + def assert_cursor_description(cur, trino_type, size=None, precision=None, scale=None): assert cur.description[0][1] == trino_type assert cur.description[0][2] is None diff --git a/tests/integration/test_sqlalchemy_integration.py b/tests/integration/test_sqlalchemy_integration.py index 53fc0af2..5c08a154 100644 --- a/tests/integration/test_sqlalchemy_integration.py +++ b/tests/integration/test_sqlalchemy_integration.py @@ -23,8 +23,11 @@ @pytest.fixture def trino_connection(run_trino, request): _, host, port = run_trino + connect_args = {"source": "test", "max_attempts": 1} + if trino_version() <= '417': + connect_args["legacy_prepared_statements"] = True engine = sqla.create_engine(f"trino://test@{host}:{port}/{request.param}", - connect_args={"source": "test", "max_attempts": 1}) + connect_args=connect_args) yield engine, engine.connect() diff --git a/trino/dbapi.py b/trino/dbapi.py index d62b567b..bcc85e8c 100644 --- a/trino/dbapi.py +++ b/trino/dbapi.py @@ -21,7 +21,10 @@ import datetime import math import uuid +from collections import OrderedDict from decimal import Decimal +from threading import Lock +from time import time from typing import Any, Dict, List, NamedTuple, Optional # NOQA for mypy types from urllib.parse import urlparse @@ -78,6 +81,39 @@ logger = trino.logging.get_logger(__name__) +class TimeBoundLRUCache: + def __init__(self, capacity: int, ttl_seconds: int): + self.capacity = capacity + self.ttl_seconds = ttl_seconds + self.cache = OrderedDict() + self.lock = Lock() + + def get(self, key): + with self.lock: + if key not in self.cache: + return None + value, timestamp = self.cache[key] + if time() - timestamp > self.ttl_seconds: + self.cache.pop(key) + return None + self.cache.move_to_end(key) + self.cache[key] = value, time() + return value + + def put(self, key, value): + with self.lock: + self.cache[key] = value, time() + self.cache.move_to_end(key) + if len(self.cache) > self.capacity: + self.cache.popitem(last=False) + + def __repr__(self): + return f"LRUCache(capacity: {self.capacity}, ttl: {self.ttl_seconds} seconds, {self.cache})" + + +legacyPreparedStatementsCache = TimeBoundLRUCache(1024, 3600) + + def connect(*args, **kwargs): """Constructor for creating a connection to the database. @@ -117,6 +153,7 @@ def __init__( http_session=None, client_tags=None, legacy_primitive_types=False, + legacy_prepared_statements=None, roles=None, timezone=None, ): @@ -162,6 +199,7 @@ def __init__( self._request = None self._transaction = None self.legacy_primitive_types = legacy_primitive_types + self.legacy_prepared_statements = legacy_prepared_statements @property def isolation_level(self): @@ -228,10 +266,31 @@ def cursor(self, legacy_primitive_types: bool = None): return Cursor( self, request, - # if legacy_primitive_types is not explicitly set in Cursor, take from Connection + # if legacy params are not explicitly set in Cursor, take them from Connection legacy_primitive_types if legacy_primitive_types is not None else self.legacy_primitive_types ) + def _use_legacy_prepared_statements(self): + if self.legacy_prepared_statements is not None: + return self.legacy_prepared_statements + + value = legacyPreparedStatementsCache.get((self.host, self.port)) + if value is None: + try: + query = trino.client.TrinoQuery( + self._create_request(), + # The version() function was introduced in Trino version 352 + query="SELECT node_version FROM system.runtime.nodes WHERE coordinator = true") + rows = query.execute().rows + if rows: + version = rows[0][0] + value = version <= "417" + legacyPreparedStatementsCache.put((self.host, self.port), value) + except Exception: + # not updating the cache + value = False + return value + class DescribeOutput(NamedTuple): name: str @@ -280,7 +339,11 @@ class Cursor(object): """ - def __init__(self, connection, request, legacy_primitive_types: bool = False): + def __init__( + self, + connection, + request, + legacy_primitive_types: bool = False): if not isinstance(connection, Connection): raise ValueError( "connection must be a Connection object: {}".format(type(connection)) @@ -391,6 +454,18 @@ def _execute_prepared_statement( sql = 'EXECUTE ' + statement_name + ' USING ' + ','.join(map(self._format_prepared_param, params)) return trino.client.TrinoQuery(self._request, query=sql, legacy_primitive_types=self._legacy_primitive_types) + def _execute_immediate_statement(self, statement: str, params): + """ + Binds parameters and executes a statement in one call. + + :param statement: sql to be executed. + :param params: parameters to be bound. + """ + sql = "EXECUTE IMMEDIATE '" + statement.replace("'", "''") + \ + "' USING " + ",".join(map(self._format_prepared_param, params)) + return trino.client.TrinoQuery( + self.connection._create_request(), query=sql, legacy_primitive_types=self._legacy_primitive_types) + def _format_prepared_param(self, param): """ Formats parameters to be passed in an @@ -492,22 +567,26 @@ def execute(self, operation, params=None): 'parameter values' ) - statement_name = self._generate_unique_statement_name() - self._prepare_statement(operation, statement_name) - - try: - # Send execute statement and assign the return value to `results` - # as it will be returned by the function - self._query = self._execute_prepared_statement( - statement_name, params - ) + if self.connection._use_legacy_prepared_statements(): + statement_name = self._generate_unique_statement_name() + self._prepare_statement(operation, statement_name) + + try: + # Send execute statement and assign the return value to `results` + # as it will be returned by the function + self._query = self._execute_prepared_statement( + statement_name, params + ) + self._iterator = iter(self._query.execute()) + finally: + # Send deallocate statement + # At this point the query can be deallocated since it has already + # been executed + # TODO: Consider caching prepared statements if requested by caller + self._deallocate_prepared_statement(statement_name) + else: + self._query = self._execute_immediate_statement(operation, params) self._iterator = iter(self._query.execute()) - finally: - # Send deallocate statement - # At this point the query can be deallocated since it has already - # been executed - # TODO: Consider caching prepared statements if requested by caller - self._deallocate_prepared_statement(statement_name) else: self._query = trino.client.TrinoQuery(self._request, query=operation, diff --git a/trino/sqlalchemy/dialect.py b/trino/sqlalchemy/dialect.py index d5900119..baccbf75 100644 --- a/trino/sqlalchemy/dialect.py +++ b/trino/sqlalchemy/dialect.py @@ -128,6 +128,9 @@ def create_connect_args(self, url: URL) -> Tuple[Sequence[Any], Mapping[str, Any if "legacy_primitive_types" in url.query: kwargs["legacy_primitive_types"] = json.loads(unquote_plus(url.query["legacy_primitive_types"])) + if "legacy_prepared_statements" in url.query: + kwargs["legacy_prepared_statements"] = unquote_plus(url.query["legacy_prepared_statements"]) + if "verify" in url.query: kwargs["verify"] = json.loads(unquote_plus(url.query["verify"])) diff --git a/trino/sqlalchemy/util.py b/trino/sqlalchemy/util.py index cfa9c6b1..e69e4262 100644 --- a/trino/sqlalchemy/util.py +++ b/trino/sqlalchemy/util.py @@ -23,6 +23,7 @@ def _url( extra_credential: Optional[List[Tuple[str, str]]] = None, client_tags: Optional[List[str]] = None, legacy_primitive_types: Optional[bool] = None, + legacy_prepared_statements: Optional[bool] = None, access_token: Optional[str] = None, cert: Optional[str] = None, key: Optional[str] = None, @@ -85,6 +86,9 @@ def _url( if legacy_primitive_types is not None: trino_url += f"&legacy_primitive_types={json.dumps(legacy_primitive_types)}" + if legacy_prepared_statements is not None: + trino_url += f"&legacy_prepared_statements={json.dumps(legacy_prepared_statements)}" + if access_token is not None: trino_url += f"&access_token={quote_plus(access_token)}"