diff --git a/python/tests/test_cursor.py b/python/tests/test_cursor.py index 07fca375..b9546f22 100644 --- a/python/tests/test_cursor.py +++ b/python/tests/test_cursor.py @@ -168,8 +168,7 @@ async def test_cursor_as_async_manager( querystring=f"SELECT * FROM {table_name}", fetch_number=fetch_number, ) as cursor: - async for result in cursor: - all_results.append(result) # noqa: PERF401 + all_results.extend([result async for result in cursor]) assert len(all_results) == expected_num_results diff --git a/python/tests/test_value_converter.py b/python/tests/test_value_converter.py index f8f8ade5..cb53bad7 100644 --- a/python/tests/test_value_converter.py +++ b/python/tests/test_value_converter.py @@ -1,4 +1,5 @@ import datetime +import sys import uuid from decimal import Decimal from enum import Enum @@ -57,6 +58,7 @@ from tests.conftest import DefaultPydanticModel, DefaultPythonModelClass +uuid_ = uuid.uuid4() pytestmark = pytest.mark.anyio now_datetime = datetime.datetime.now() # noqa: DTZ005 now_datetime_with_tz = datetime.datetime( @@ -69,7 +71,30 @@ 142574, tzinfo=datetime.timezone.utc, ) -uuid_ = uuid.uuid4() + +now_datetime_with_tz_in_asia_jakarta = datetime.datetime( + 2024, + 4, + 13, + 17, + 3, + 46, + 142574, + tzinfo=datetime.timezone.utc, +) +if sys.version_info >= (3, 9): + import zoneinfo + + now_datetime_with_tz_in_asia_jakarta = datetime.datetime( + 2024, + 4, + 13, + 17, + 3, + 46, + 142574, + tzinfo=zoneinfo.ZoneInfo(key="Asia/Jakarta"), + ) async def test_as_class( @@ -125,6 +150,7 @@ async def test_as_class( ("TIME", now_datetime.time(), now_datetime.time()), ("TIMESTAMP", now_datetime, now_datetime), ("TIMESTAMPTZ", now_datetime_with_tz, now_datetime_with_tz), + ("TIMESTAMPTZ", now_datetime_with_tz_in_asia_jakarta, now_datetime_with_tz_in_asia_jakarta), ("UUID", uuid_, str(uuid_)), ("INET", IPv4Address("192.0.0.1"), IPv4Address("192.0.0.1")), ( @@ -287,6 +313,11 @@ async def test_as_class( [now_datetime_with_tz, now_datetime_with_tz], [now_datetime_with_tz, now_datetime_with_tz], ), + ( + "TIMESTAMPTZ ARRAY", + [now_datetime_with_tz, now_datetime_with_tz_in_asia_jakarta], + [now_datetime_with_tz, now_datetime_with_tz_in_asia_jakarta], + ), ( "TIMESTAMPTZ ARRAY", [[now_datetime_with_tz], [now_datetime_with_tz]], diff --git a/src/driver/connection.rs b/src/driver/connection.rs index 23e86a44..97dc66a1 100644 --- a/src/driver/connection.rs +++ b/src/driver/connection.rs @@ -127,15 +127,6 @@ impl Connection { #[pymethods] impl Connection { - #[must_use] - pub fn __aiter__(self_: Py) -> Py { - self_ - } - - fn __await__(self_: Py) -> Py { - self_ - } - async fn __aenter__<'a>(self_: Py) -> RustPSQLDriverPyResult> { let (db_client, db_pool) = pyo3::Python::with_gil(|gil| { let self_ = self_.borrow(gil); diff --git a/src/driver/cursor.rs b/src/driver/cursor.rs index 3f8008be..74e353b2 100644 --- a/src/driver/cursor.rs +++ b/src/driver/cursor.rs @@ -13,15 +13,15 @@ use crate::{ }; /// Additional implementation for the `Object` type. -#[allow(clippy::ref_option)] +#[allow(clippy::ref_option_ref)] trait CursorObjectTrait { async fn cursor_start( &self, cursor_name: &str, - scroll: &Option, + scroll: Option<&bool>, querystring: &str, - prepared: &Option, - parameters: &Option>, + prepared: Option<&bool>, + parameters: Option<&Py>, ) -> RustPSQLDriverPyResult<()>; async fn cursor_close(&self, closed: &bool, cursor_name: &str) -> RustPSQLDriverPyResult<()>; @@ -34,14 +34,14 @@ impl CursorObjectTrait for Object { /// /// # Errors /// May return Err Result if cannot execute querystring. - #[allow(clippy::ref_option)] + #[allow(clippy::ref_option_ref)] async fn cursor_start( &self, cursor_name: &str, - scroll: &Option, + scroll: Option<&bool>, querystring: &str, - prepared: &Option, - parameters: &Option>, + prepared: Option<&bool>, + parameters: Option<&Py>, ) -> RustPSQLDriverPyResult<()> { let mut cursor_init_query = format!("DECLARE {cursor_name}"); if let Some(scroll) = scroll { diff --git a/src/query_result.rs b/src/query_result.rs index 06299b86..c4025ee3 100644 --- a/src/query_result.rs +++ b/src/query_result.rs @@ -10,11 +10,11 @@ use crate::{exceptions::rust_errors::RustPSQLDriverPyResult, value_converter::po /// May return Err Result if can not convert /// postgres type to python or set new key-value pair /// in python dict. -#[allow(clippy::ref_option)] +#[allow(clippy::ref_option_ref)] fn row_to_dict<'a>( py: Python<'a>, postgres_row: &'a Row, - custom_decoders: &Option>, + custom_decoders: Option<&Py>, ) -> RustPSQLDriverPyResult> { let python_dict = PyDict::new_bound(py); for (column_idx, column) in postgres_row.columns().iter().enumerate() { diff --git a/src/value_converter.rs b/src/value_converter.rs index 74020232..b3b252fe 100644 --- a/src/value_converter.rs +++ b/src/value_converter.rs @@ -1,4 +1,5 @@ -use chrono::{self, DateTime, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime}; +use chrono::{self, DateTime, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime, TimeZone}; +use chrono_tz::Tz; use geo_types::{coord, Coord, Line as LineSegment, LineString, Point, Rect}; use itertools::Itertools; use macaddr::{MacAddr6, MacAddr8}; @@ -626,8 +627,7 @@ impl ToSql for PythonDTO { #[allow(clippy::needless_pass_by_value)] pub fn convert_parameters(parameters: Py) -> RustPSQLDriverPyResult> { let mut result_vec: Vec = vec![]; - - result_vec = Python::with_gil(|gil| { + Python::with_gil(|gil| { let params = parameters.extract::>>(gil).map_err(|_| { RustPSQLDriverError::PyToRustValueConversionError( "Cannot convert you parameters argument into Rust type, please use List/Tuple" @@ -637,8 +637,9 @@ pub fn convert_parameters(parameters: Py) -> RustPSQLDriverPyResult, RustPSQLDriverError>(result_vec) + Ok::<(), RustPSQLDriverError>(()) })?; + Ok(result_vec) } @@ -744,6 +745,84 @@ pub fn py_sequence_into_postgres_array( } } +/// Extract a value from a Python object, raising an error if missing or invalid +/// +/// # Type Parameters +/// - `T`: The type to which the attribute's value will be converted. This type must implement the `FromPyObject` trait +/// +/// # Errors +/// This function will return `Err` in the following cases: +/// - The Python object does not have the specified attribute +/// - The attribute exists but cannot be extracted into the specified Rust type +fn extract_value_from_python_object_or_raise<'py, T>( + parameter: &'py pyo3::Bound<'_, PyAny>, + attr_name: &str, +) -> Result +where + T: FromPyObject<'py>, +{ + parameter + .getattr(attr_name) + .ok() + .and_then(|attr| attr.extract::().ok()) + .ok_or_else(|| { + RustPSQLDriverError::PyToRustValueConversionError("Invalid attribute".into()) + }) +} + +/// Extract a timezone-aware datetime from a Python object. +/// This function retrieves various datetime components (`year`, `month`, `day`, etc.) +/// from a Python object and constructs a `DateTime` +/// +/// # Errors +/// This function will return `Err` in the following cases: +/// - The Python object does not contain or support one or more required datetime attributes +/// - The retrieved values are invalid for constructing a date, time, or datetime (e.g., invalid month or day) +/// - The timezone information (`tzinfo`) is not available or cannot be parsed +/// - The resulting datetime is ambiguous or invalid (e.g., due to DST transitions) +fn extract_datetime_from_python_object_attrs( + parameter: &pyo3::Bound<'_, PyAny>, +) -> Result, RustPSQLDriverError> { + let year = extract_value_from_python_object_or_raise::(parameter, "year")?; + let month = extract_value_from_python_object_or_raise::(parameter, "month")?; + let day = extract_value_from_python_object_or_raise::(parameter, "day")?; + let hour = extract_value_from_python_object_or_raise::(parameter, "hour")?; + let minute = extract_value_from_python_object_or_raise::(parameter, "minute")?; + let second = extract_value_from_python_object_or_raise::(parameter, "second")?; + let microsecond = extract_value_from_python_object_or_raise::(parameter, "microsecond")?; + + let date = NaiveDate::from_ymd_opt(year, month, day) + .ok_or_else(|| RustPSQLDriverError::PyToRustValueConversionError("Invalid date".into()))?; + let time = NaiveTime::from_hms_micro_opt(hour, minute, second, microsecond) + .ok_or_else(|| RustPSQLDriverError::PyToRustValueConversionError("Invalid time".into()))?; + let naive_datetime = NaiveDateTime::new(date, time); + + let raw_timestamp_tz = parameter + .getattr("tzinfo") + .ok() + .and_then(|tzinfo| tzinfo.getattr("key").ok()) + .and_then(|key| key.extract::().ok()) + .ok_or_else(|| { + RustPSQLDriverError::PyToRustValueConversionError("Invalid timezone info".into()) + })?; + + let fixed_offset_datetime = raw_timestamp_tz + .parse::() + .map_err(|_| { + RustPSQLDriverError::PyToRustValueConversionError("Failed to parse TZ".into()) + })? + .from_local_datetime(&naive_datetime) + .single() + .ok_or_else(|| { + RustPSQLDriverError::PyToRustValueConversionError( + "Ambiguous or invalid datetime".into(), + ) + })? + .fixed_offset(); + + Ok(fixed_offset_datetime) +} + /// Convert single python parameter to `PythonDTO` enum. /// /// # Errors @@ -849,6 +928,11 @@ pub fn py_to_rust(parameter: &pyo3::Bound<'_, PyAny>) -> RustPSQLDriverPyResult< return Ok(PythonDTO::PyDateTime(pydatetime_no_tz)); } + let timestamp_tz = extract_datetime_from_python_object_attrs(parameter); + if let Ok(pydatetime_tz) = timestamp_tz { + return Ok(PythonDTO::PyDateTimeTz(pydatetime_tz)); + } + return Err(RustPSQLDriverError::PyToRustValueConversionError( "Can not convert you datetime to rust type".into(), )); @@ -1655,7 +1739,7 @@ pub fn other_postgres_bytes_to_py( } Err(RustPSQLDriverError::RustToPyValueConversionError( - format!("Cannot convert {type_} into Python type, please look at the custom_decoders functionality.") + format!("Cannot convert {type_} into Python type, please look at the custom_decoders functionality.") )) } @@ -1668,7 +1752,7 @@ pub fn composite_postgres_to_py( py: Python<'_>, fields: &Vec, buf: &mut &[u8], - custom_decoders: &Option>, + custom_decoders: Option<&Py>, ) -> RustPSQLDriverPyResult> { let result_py_dict: Bound<'_, PyDict> = PyDict::new_bound(py); @@ -1737,7 +1821,7 @@ pub fn raw_bytes_data_process( raw_bytes_data: &mut &[u8], column_name: &str, column_type: &Type, - custom_decoders: &Option>, + custom_decoders: Option<&Py>, ) -> RustPSQLDriverPyResult> { if let Some(custom_decoders) = custom_decoders { let py_encoder_func = custom_decoders @@ -1776,7 +1860,7 @@ pub fn postgres_to_py( row: &Row, column: &Column, column_i: usize, - custom_decoders: &Option>, + custom_decoders: Option<&Py>, ) -> RustPSQLDriverPyResult> { let raw_bytes_data = row.col_buffer(column_i); if let Some(mut raw_bytes_data) = raw_bytes_data {