Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions python/tests/test_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,7 @@ async def test_cursor_as_async_manager(
querystring=f"SELECT * FROM {table_name}",
fetch_number=fetch_number,
) as cursor:
async for result in cursor:
all_results.append(result) # noqa: PERF401
all_results.extend([result async for result in cursor])

assert len(all_results) == expected_num_results

Expand Down
33 changes: 32 additions & 1 deletion python/tests/test_value_converter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import datetime
import sys
import uuid
from decimal import Decimal
from enum import Enum
Expand Down Expand Up @@ -57,6 +58,7 @@

from tests.conftest import DefaultPydanticModel, DefaultPythonModelClass

uuid_ = uuid.uuid4()
pytestmark = pytest.mark.anyio
now_datetime = datetime.datetime.now() # noqa: DTZ005
now_datetime_with_tz = datetime.datetime(
Expand All @@ -69,7 +71,30 @@
142574,
tzinfo=datetime.timezone.utc,
)
uuid_ = uuid.uuid4()

now_datetime_with_tz_in_asia_jakarta = datetime.datetime(
2024,
4,
13,
17,
3,
46,
142574,
tzinfo=datetime.timezone.utc,
)
if sys.version_info >= (3, 9):
import zoneinfo

now_datetime_with_tz_in_asia_jakarta = datetime.datetime(
2024,
4,
13,
17,
3,
46,
142574,
tzinfo=zoneinfo.ZoneInfo(key="Asia/Jakarta"),
)


async def test_as_class(
Expand Down Expand Up @@ -125,6 +150,7 @@ async def test_as_class(
("TIME", now_datetime.time(), now_datetime.time()),
("TIMESTAMP", now_datetime, now_datetime),
("TIMESTAMPTZ", now_datetime_with_tz, now_datetime_with_tz),
("TIMESTAMPTZ", now_datetime_with_tz_in_asia_jakarta, now_datetime_with_tz_in_asia_jakarta),
("UUID", uuid_, str(uuid_)),
("INET", IPv4Address("192.0.0.1"), IPv4Address("192.0.0.1")),
(
Expand Down Expand Up @@ -287,6 +313,11 @@ async def test_as_class(
[now_datetime_with_tz, now_datetime_with_tz],
[now_datetime_with_tz, now_datetime_with_tz],
),
(
"TIMESTAMPTZ ARRAY",
[now_datetime_with_tz, now_datetime_with_tz_in_asia_jakarta],
[now_datetime_with_tz, now_datetime_with_tz_in_asia_jakarta],
),
(
"TIMESTAMPTZ ARRAY",
[[now_datetime_with_tz], [now_datetime_with_tz]],
Expand Down
9 changes: 0 additions & 9 deletions src/driver/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,15 +127,6 @@ impl Connection {

#[pymethods]
impl Connection {
#[must_use]
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It feels unnecessary.
aenter and aexit here the main methods, as I see

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure, It needs to be checked)

pub fn __aiter__(self_: Py<Self>) -> Py<Self> {
self_
}

fn __await__(self_: Py<Self>) -> Py<Self> {
self_
}

async fn __aenter__<'a>(self_: Py<Self>) -> RustPSQLDriverPyResult<Py<Self>> {
let (db_client, db_pool) = pyo3::Python::with_gil(|gil| {
let self_ = self_.borrow(gil);
Expand Down
16 changes: 8 additions & 8 deletions src/driver/cursor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@ use crate::{
};

/// Additional implementation for the `Object` type.
#[allow(clippy::ref_option)]
#[allow(clippy::ref_option_ref)]
trait CursorObjectTrait {
async fn cursor_start(
&self,
cursor_name: &str,
scroll: &Option<bool>,
scroll: Option<&bool>,
querystring: &str,
prepared: &Option<bool>,
parameters: &Option<Py<PyAny>>,
prepared: Option<&bool>,
parameters: Option<&Py<PyAny>>,
) -> RustPSQLDriverPyResult<()>;

async fn cursor_close(&self, closed: &bool, cursor_name: &str) -> RustPSQLDriverPyResult<()>;
Expand All @@ -34,14 +34,14 @@ impl CursorObjectTrait for Object {
///
/// # Errors
/// May return Err Result if cannot execute querystring.
#[allow(clippy::ref_option)]
#[allow(clippy::ref_option_ref)]
async fn cursor_start(
&self,
cursor_name: &str,
scroll: &Option<bool>,
scroll: Option<&bool>,
querystring: &str,
prepared: &Option<bool>,
parameters: &Option<Py<PyAny>>,
prepared: Option<&bool>,
parameters: Option<&Py<PyAny>>,
) -> RustPSQLDriverPyResult<()> {
let mut cursor_init_query = format!("DECLARE {cursor_name}");
if let Some(scroll) = scroll {
Expand Down
4 changes: 2 additions & 2 deletions src/query_result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ use crate::{exceptions::rust_errors::RustPSQLDriverPyResult, value_converter::po
/// May return Err Result if can not convert
/// postgres type to python or set new key-value pair
/// in python dict.
#[allow(clippy::ref_option)]
#[allow(clippy::ref_option_ref)]
fn row_to_dict<'a>(
py: Python<'a>,
postgres_row: &'a Row,
custom_decoders: &Option<Py<PyDict>>,
custom_decoders: Option<&Py<PyDict>>,
) -> RustPSQLDriverPyResult<pyo3::Bound<'a, PyDict>> {
let python_dict = PyDict::new_bound(py);
for (column_idx, column) in postgres_row.columns().iter().enumerate() {
Expand Down
100 changes: 92 additions & 8 deletions src/value_converter.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use chrono::{self, DateTime, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime};
use chrono::{self, DateTime, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime, TimeZone};
use chrono_tz::Tz;
use geo_types::{coord, Coord, Line as LineSegment, LineString, Point, Rect};
use itertools::Itertools;
use macaddr::{MacAddr6, MacAddr8};
Expand Down Expand Up @@ -626,8 +627,7 @@ impl ToSql for PythonDTO {
#[allow(clippy::needless_pass_by_value)]
pub fn convert_parameters(parameters: Py<PyAny>) -> RustPSQLDriverPyResult<Vec<PythonDTO>> {
let mut result_vec: Vec<PythonDTO> = vec![];

result_vec = Python::with_gil(|gil| {
Python::with_gil(|gil| {
let params = parameters.extract::<Vec<Py<PyAny>>>(gil).map_err(|_| {
RustPSQLDriverError::PyToRustValueConversionError(
"Cannot convert you parameters argument into Rust type, please use List/Tuple"
Expand All @@ -637,8 +637,9 @@ pub fn convert_parameters(parameters: Py<PyAny>) -> RustPSQLDriverPyResult<Vec<P
for parameter in params {
result_vec.push(py_to_rust(parameter.bind(gil))?);
}
Ok::<Vec<PythonDTO>, RustPSQLDriverError>(result_vec)
Ok::<(), RustPSQLDriverError>(())
})?;

Ok(result_vec)
}

Expand Down Expand Up @@ -744,6 +745,84 @@ pub fn py_sequence_into_postgres_array(
}
}

/// Extract a value from a Python object, raising an error if missing or invalid
///
/// # Type Parameters
/// - `T`: The type to which the attribute's value will be converted. This type must implement the `FromPyObject` trait
///
/// # Errors
/// This function will return `Err` in the following cases:
/// - The Python object does not have the specified attribute
/// - The attribute exists but cannot be extracted into the specified Rust type
fn extract_value_from_python_object_or_raise<'py, T>(
parameter: &'py pyo3::Bound<'_, PyAny>,
attr_name: &str,
) -> Result<T, RustPSQLDriverError>
where
T: FromPyObject<'py>,
{
parameter
.getattr(attr_name)
.ok()
.and_then(|attr| attr.extract::<T>().ok())
.ok_or_else(|| {
RustPSQLDriverError::PyToRustValueConversionError("Invalid attribute".into())
})
}

/// Extract a timezone-aware datetime from a Python object.
/// This function retrieves various datetime components (`year`, `month`, `day`, etc.)
/// from a Python object and constructs a `DateTime<FixedOffset>`
///
/// # Errors
/// This function will return `Err` in the following cases:
/// - The Python object does not contain or support one or more required datetime attributes
/// - The retrieved values are invalid for constructing a date, time, or datetime (e.g., invalid month or day)
/// - The timezone information (`tzinfo`) is not available or cannot be parsed
/// - The resulting datetime is ambiguous or invalid (e.g., due to DST transitions)
fn extract_datetime_from_python_object_attrs(
parameter: &pyo3::Bound<'_, PyAny>,
) -> Result<DateTime<FixedOffset>, RustPSQLDriverError> {
let year = extract_value_from_python_object_or_raise::<i32>(parameter, "year")?;
let month = extract_value_from_python_object_or_raise::<u32>(parameter, "month")?;
let day = extract_value_from_python_object_or_raise::<u32>(parameter, "day")?;
let hour = extract_value_from_python_object_or_raise::<u32>(parameter, "hour")?;
let minute = extract_value_from_python_object_or_raise::<u32>(parameter, "minute")?;
let second = extract_value_from_python_object_or_raise::<u32>(parameter, "second")?;
let microsecond = extract_value_from_python_object_or_raise::<u32>(parameter, "microsecond")?;

let date = NaiveDate::from_ymd_opt(year, month, day)
.ok_or_else(|| RustPSQLDriverError::PyToRustValueConversionError("Invalid date".into()))?;
let time = NaiveTime::from_hms_micro_opt(hour, minute, second, microsecond)
.ok_or_else(|| RustPSQLDriverError::PyToRustValueConversionError("Invalid time".into()))?;
let naive_datetime = NaiveDateTime::new(date, time);

let raw_timestamp_tz = parameter
.getattr("tzinfo")
.ok()
.and_then(|tzinfo| tzinfo.getattr("key").ok())
.and_then(|key| key.extract::<String>().ok())
.ok_or_else(|| {
RustPSQLDriverError::PyToRustValueConversionError("Invalid timezone info".into())
})?;

let fixed_offset_datetime = raw_timestamp_tz
.parse::<Tz>()
.map_err(|_| {
RustPSQLDriverError::PyToRustValueConversionError("Failed to parse TZ".into())
})?
.from_local_datetime(&naive_datetime)
.single()
.ok_or_else(|| {
RustPSQLDriverError::PyToRustValueConversionError(
"Ambiguous or invalid datetime".into(),
)
})?
.fixed_offset();

Ok(fixed_offset_datetime)
}

/// Convert single python parameter to `PythonDTO` enum.
///
/// # Errors
Expand Down Expand Up @@ -849,6 +928,11 @@ pub fn py_to_rust(parameter: &pyo3::Bound<'_, PyAny>) -> RustPSQLDriverPyResult<
return Ok(PythonDTO::PyDateTime(pydatetime_no_tz));
}

let timestamp_tz = extract_datetime_from_python_object_attrs(parameter);
if let Ok(pydatetime_tz) = timestamp_tz {
return Ok(PythonDTO::PyDateTimeTz(pydatetime_tz));
}

return Err(RustPSQLDriverError::PyToRustValueConversionError(
"Can not convert you datetime to rust type".into(),
));
Expand Down Expand Up @@ -1655,7 +1739,7 @@ pub fn other_postgres_bytes_to_py(
}

Err(RustPSQLDriverError::RustToPyValueConversionError(
format!("Cannot convert {type_} into Python type, please look at the custom_decoders functionality.")
format!("Cannot convert {type_} into Python type, please look at the custom_decoders functionality.")
))
}

Expand All @@ -1668,7 +1752,7 @@ pub fn composite_postgres_to_py(
py: Python<'_>,
fields: &Vec<Field>,
buf: &mut &[u8],
custom_decoders: &Option<Py<PyDict>>,
custom_decoders: Option<&Py<PyDict>>,
) -> RustPSQLDriverPyResult<Py<PyAny>> {
let result_py_dict: Bound<'_, PyDict> = PyDict::new_bound(py);

Expand Down Expand Up @@ -1737,7 +1821,7 @@ pub fn raw_bytes_data_process(
raw_bytes_data: &mut &[u8],
column_name: &str,
column_type: &Type,
custom_decoders: &Option<Py<PyDict>>,
custom_decoders: Option<&Py<PyDict>>,
) -> RustPSQLDriverPyResult<Py<PyAny>> {
if let Some(custom_decoders) = custom_decoders {
let py_encoder_func = custom_decoders
Expand Down Expand Up @@ -1776,7 +1860,7 @@ pub fn postgres_to_py(
row: &Row,
column: &Column,
column_i: usize,
custom_decoders: &Option<Py<PyDict>>,
custom_decoders: Option<&Py<PyDict>>,
) -> RustPSQLDriverPyResult<Py<PyAny>> {
let raw_bytes_data = row.col_buffer(column_i);
if let Some(mut raw_bytes_data) = raw_bytes_data {
Expand Down
Loading