From 8b7286afe55c56da8db7a58a5ca2c65a1cf6db26 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Wed, 24 Sep 2025 10:25:28 +0200 Subject: [PATCH 01/12] Makefile: update only stable toolchain It's annoying to have to wait for nightly updates. --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 15d7d333..c644cdf8 100644 --- a/Makefile +++ b/Makefile @@ -262,7 +262,7 @@ build-examples: _update-rust-tooling: @echo "Run rustup update" - @rustup update + @rustup update stable check-cargo: install-cargo-if-missing _update-rust-tooling @echo "Running \"cargo check\" in ./scylla-rust-wrapper" From a51cdde0491b300bf5451c1427c6fd4955f0804d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Wed, 24 Sep 2025 10:24:24 +0200 Subject: [PATCH 02/12] tree: extract cql_types module There are plenty of top-level modules in the crate that contain CQL types (collection, tuple, user_type, uuid, inet, date_time, value). This change moves them all into a single cql_types module, which should make it easier to find them and understand their purpose. This is a purely organizational change, no logic was changed. --- scylla-rust-wrapper/src/api.rs | 14 +++++----- scylla-rust-wrapper/src/batch.rs | 2 +- scylla-rust-wrapper/src/binding.rs | 28 +++++++++---------- scylla-rust-wrapper/src/cluster.rs | 2 +- .../src/{ => cql_types}/collection.rs | 6 ++-- .../src/{ => cql_types}/date_time.rs | 0 .../src/{ => cql_types}/inet.rs | 0 scylla-rust-wrapper/src/cql_types/mod.rs | 7 +++++ .../src/{ => cql_types}/tuple.rs | 4 +-- .../src/{ => cql_types}/user_type.rs | 4 +-- .../src/{ => cql_types}/uuid.rs | 0 .../src/{ => cql_types}/value.rs | 2 +- scylla-rust-wrapper/src/future.rs | 2 +- scylla-rust-wrapper/src/lib.rs | 8 +----- scylla-rust-wrapper/src/query_result.rs | 4 +-- scylla-rust-wrapper/src/ser_de_tests.rs | 4 +-- scylla-rust-wrapper/src/session.rs | 2 +- scylla-rust-wrapper/src/statement.rs | 8 +++--- 18 files changed, 49 insertions(+), 48 deletions(-) rename scylla-rust-wrapper/src/{ => cql_types}/collection.rs (99%) rename scylla-rust-wrapper/src/{ => cql_types}/date_time.rs (100%) rename scylla-rust-wrapper/src/{ => cql_types}/inet.rs (100%) create mode 100644 scylla-rust-wrapper/src/cql_types/mod.rs rename scylla-rust-wrapper/src/{ => cql_types}/tuple.rs (98%) rename scylla-rust-wrapper/src/{ => cql_types}/user_type.rs (98%) rename scylla-rust-wrapper/src/{ => cql_types}/uuid.rs (100%) rename scylla-rust-wrapper/src/{ => cql_types}/value.rs (99%) diff --git a/scylla-rust-wrapper/src/api.rs b/scylla-rust-wrapper/src/api.rs index 220e4d3e..4ba30f7c 100644 --- a/scylla-rust-wrapper/src/api.rs +++ b/scylla-rust-wrapper/src/api.rs @@ -562,7 +562,7 @@ pub mod data_type { pub mod collection { // Disabling rustfmt to have one item per line for better readability. #[rustfmt::skip] - pub use crate::collection::{ + pub use crate::cql_types::collection::{ CassCollection, cass_collection_append_bool, cass_collection_append_bytes, @@ -594,7 +594,7 @@ pub mod collection { pub mod tuple { // Disabling rustfmt to have one item per line for better readability. #[rustfmt::skip] - pub use crate::tuple::{ + pub use crate::cql_types::tuple::{ CassTuple, cass_tuple_data_type, cass_tuple_free, @@ -627,7 +627,7 @@ pub mod tuple { pub mod user_type { // Disabling rustfmt to have one item per line for better readability. #[rustfmt::skip] - pub use crate::user_type::{ + pub use crate::cql_types::user_type::{ CassUserType, cass_user_type_data_type, cass_user_type_free, @@ -832,7 +832,7 @@ pub mod value { pub mod uuid_gen { // Disabling rustfmt to have one item per line for better readability. #[rustfmt::skip] - pub use crate::uuid::{ + pub use crate::cql_types::uuid::{ CassUuidGen, cass_uuid_gen_free, cass_uuid_gen_from_time, @@ -846,7 +846,7 @@ pub mod uuid_gen { pub mod uuid { // Disabling rustfmt to have one item per line for better readability. #[rustfmt::skip] - pub use crate::uuid::{ + pub use crate::cql_types::uuid::{ cass_uuid_from_string, cass_uuid_from_string_n, cass_uuid_max_from_time, @@ -927,7 +927,7 @@ pub mod log { pub mod inet { // Disabling rustfmt to have one item per line for better readability. #[rustfmt::skip] - pub use crate::inet::{ + pub use crate::cql_types::inet::{ CassInet, cass_inet_from_string, cass_inet_from_string_n, @@ -940,7 +940,7 @@ pub mod inet { pub mod date_time { // Disabling rustfmt to have one item per line for better readability. #[rustfmt::skip] - pub use crate::date_time::{ + pub use crate::cql_types::date_time::{ cass_date_from_epoch, cass_date_time_to_epoch, cass_time_from_epoch diff --git a/scylla-rust-wrapper/src/batch.rs b/scylla-rust-wrapper/src/batch.rs index cf1fba17..0e3f1cf8 100644 --- a/scylla-rust-wrapper/src/batch.rs +++ b/scylla-rust-wrapper/src/batch.rs @@ -6,11 +6,11 @@ use crate::cass_error::CassError; pub use crate::cass_types::CassBatchType; use crate::cass_types::{CassConsistency, make_batch_type}; use crate::config_value::{MaybeUnsetConfig, RequestTimeout}; +use crate::cql_types::value::CassCqlValue; use crate::exec_profile::PerStatementExecProfile; use crate::retry_policy::CassRetryPolicy; use crate::statement::{BoundStatement, CassStatement}; use crate::types::*; -use crate::value::CassCqlValue; use scylla::statement::batch::Batch; use scylla::statement::{Consistency, SerialConsistency}; use scylla::value::MaybeUnset; diff --git a/scylla-rust-wrapper/src/binding.rs b/scylla-rust-wrapper/src/binding.rs index 2574c28a..c4eb1acf 100644 --- a/scylla-rust-wrapper/src/binding.rs +++ b/scylla-rust-wrapper/src/binding.rs @@ -59,7 +59,7 @@ macro_rules! make_index_binder { ) -> CassError { // For some reason detected as unused, which is not true #[allow(unused_imports)] - use crate::value::CassCqlValue::*; + use crate::cql_types::value::CassCqlValue::*; let Some(this) = BoxFFI::as_mut_ref(this) else { tracing::error!("Provided null pointer to {}!", stringify!($fn_by_idx)); return CassError::CASS_ERROR_LIB_BAD_PARAMS; @@ -83,7 +83,7 @@ macro_rules! make_name_binder { ) -> CassError { // For some reason detected as unused, which is not true #[allow(unused_imports)] - use crate::value::CassCqlValue::*; + use crate::cql_types::value::CassCqlValue::*; let Some(this) = BoxFFI::as_mut_ref(this) else { tracing::error!("Provided null pointer to {}!", stringify!($fn_by_name)); return CassError::CASS_ERROR_LIB_BAD_PARAMS; @@ -109,7 +109,7 @@ macro_rules! make_name_n_binder { ) -> CassError { // For some reason detected as unused, which is not true #[allow(unused_imports)] - use crate::value::CassCqlValue::*; + use crate::cql_types::value::CassCqlValue::*; let Some(this) = BoxFFI::as_mut_ref(this) else { tracing::error!("Provided null pointer to {}!", stringify!($fn_by_name_n)); return CassError::CASS_ERROR_LIB_BAD_PARAMS; @@ -133,7 +133,7 @@ macro_rules! make_appender { ) -> CassError { // For some reason detected as unused, which is not true #[allow(unused_imports)] - use crate::value::CassCqlValue::*; + use crate::cql_types::value::CassCqlValue::*; let Some(this) = BoxFFI::as_mut_ref(this) else { tracing::error!("Provided null pointer to {}!", stringify!($fn_append)); return CassError::CASS_ERROR_LIB_BAD_PARAMS; @@ -263,8 +263,8 @@ macro_rules! invoke_binder_maker_macro_with_type { $this, $consume_v, $fn, - |v: crate::uuid::CassUuid| Ok(Some(Uuid(v.into()))), - [v @ crate::uuid::CassUuid] + |v: crate::cql_types::uuid::CassUuid| Ok(Some(Uuid(v.into()))), + [v @ crate::cql_types::uuid::CassUuid] ); }; (inet, $macro_name:ident, $this:ty, $consume_v:expr, $fn:ident) => { @@ -272,7 +272,7 @@ macro_rules! invoke_binder_maker_macro_with_type { $this, $consume_v, $fn, - |v: crate::inet::CassInet| { + |v: crate::cql_types::inet::CassInet| { // Err if length in struct is invalid. // cppdriver doesn't check this - it encodes any length given to it // but it doesn't seem like something we wanna do. Also, rust driver can't @@ -282,7 +282,7 @@ macro_rules! invoke_binder_maker_macro_with_type { Err(_) => Err(CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE), } }, - [v @ crate::inet::CassInet] + [v @ crate::cql_types::inet::CassInet] ); }; (duration, $macro_name:ident, $this:ty, $consume_v:expr, $fn:ident) => { @@ -318,10 +318,10 @@ macro_rules! invoke_binder_maker_macro_with_type { $this, $consume_v, $fn, - |p: CassBorrowedSharedPtr| { + |p: CassBorrowedSharedPtr| { Ok(Some(std::convert::Into::into(BoxFFI::as_ref(p).unwrap()))) }, - [p @ CassBorrowedSharedPtr] + [p @ CassBorrowedSharedPtr] ); }; (tuple, $macro_name:ident, $this:ty, $consume_v:expr, $fn:ident) => { @@ -329,10 +329,10 @@ macro_rules! invoke_binder_maker_macro_with_type { $this, $consume_v, $fn, - |p: CassBorrowedSharedPtr| { + |p: CassBorrowedSharedPtr| { Ok(Some(BoxFFI::as_ref(p).unwrap().into())) }, - [p @ CassBorrowedSharedPtr] + [p @ CassBorrowedSharedPtr] ); }; (user_type, $macro_name:ident, $this:ty, $consume_v:expr, $fn:ident) => { @@ -340,10 +340,10 @@ macro_rules! invoke_binder_maker_macro_with_type { $this, $consume_v, $fn, - |p: CassBorrowedSharedPtr| { + |p: CassBorrowedSharedPtr| { Ok(Some(BoxFFI::as_ref(p).unwrap().into())) }, - [p @ CassBorrowedSharedPtr] + [p @ CassBorrowedSharedPtr] ); }; } diff --git a/scylla-rust-wrapper/src/cluster.rs b/scylla-rust-wrapper/src/cluster.rs index 8a2f9169..e1d19f9c 100644 --- a/scylla-rust-wrapper/src/cluster.rs +++ b/scylla-rust-wrapper/src/cluster.rs @@ -2,6 +2,7 @@ use crate::argconv::*; use crate::cass_error::CassError; use crate::cass_types::CassConsistency; use crate::config_value::MaybeUnsetConfig; +use crate::cql_types::uuid::CassUuid; use crate::exec_profile::{CassExecProfile, ExecProfileName, exec_profile_builder_modify}; use crate::load_balancing::{ CassHostFilter, DcRestriction, LoadBalancingConfig, LoadBalancingKind, @@ -11,7 +12,6 @@ use crate::runtime::{RUNTIMES, Runtime}; use crate::ssl::CassSsl; use crate::timestamp_generator::CassTimestampGen; use crate::types::*; -use crate::uuid::CassUuid; use openssl::ssl::SslContextBuilder; use openssl_sys::SSL_CTX_up_ref; use scylla::client::execution_profile::ExecutionProfileBuilder; diff --git a/scylla-rust-wrapper/src/collection.rs b/scylla-rust-wrapper/src/cql_types/collection.rs similarity index 99% rename from scylla-rust-wrapper/src/collection.rs rename to scylla-rust-wrapper/src/cql_types/collection.rs index 7daffb54..c5b89c9d 100644 --- a/scylla-rust-wrapper/src/collection.rs +++ b/scylla-rust-wrapper/src/cql_types/collection.rs @@ -1,9 +1,9 @@ +use crate::argconv::*; use crate::cass_collection_types::CassCollectionType; use crate::cass_error::CassError; use crate::cass_types::{CassDataType, CassDataTypeInner, MapDataType}; +use crate::cql_types::value::{self, CassCqlValue}; use crate::types::*; -use crate::value::CassCqlValue; -use crate::{argconv::*, value}; use std::convert::TryFrom; use std::sync::Arc; use std::sync::LazyLock; @@ -285,7 +285,7 @@ mod tests { CassDataType, CassDataTypeInner, CassValueType, MapDataType, cass_data_type_add_sub_type, cass_data_type_free, cass_data_type_new, }, - collection::{ + cql_types::collection::{ cass_collection_append_double, cass_collection_append_float, cass_collection_free, }, testing::assert_cass_error_eq, diff --git a/scylla-rust-wrapper/src/date_time.rs b/scylla-rust-wrapper/src/cql_types/date_time.rs similarity index 100% rename from scylla-rust-wrapper/src/date_time.rs rename to scylla-rust-wrapper/src/cql_types/date_time.rs diff --git a/scylla-rust-wrapper/src/inet.rs b/scylla-rust-wrapper/src/cql_types/inet.rs similarity index 100% rename from scylla-rust-wrapper/src/inet.rs rename to scylla-rust-wrapper/src/cql_types/inet.rs diff --git a/scylla-rust-wrapper/src/cql_types/mod.rs b/scylla-rust-wrapper/src/cql_types/mod.rs new file mode 100644 index 00000000..72e5a469 --- /dev/null +++ b/scylla-rust-wrapper/src/cql_types/mod.rs @@ -0,0 +1,7 @@ +pub(crate) mod collection; +pub(crate) mod date_time; +pub(crate) mod inet; +pub(crate) mod tuple; +pub(crate) mod user_type; +pub(crate) mod uuid; +pub(crate) mod value; diff --git a/scylla-rust-wrapper/src/tuple.rs b/scylla-rust-wrapper/src/cql_types/tuple.rs similarity index 98% rename from scylla-rust-wrapper/src/tuple.rs rename to scylla-rust-wrapper/src/cql_types/tuple.rs index e22a8866..f11a435b 100644 --- a/scylla-rust-wrapper/src/tuple.rs +++ b/scylla-rust-wrapper/src/cql_types/tuple.rs @@ -2,9 +2,9 @@ use crate::argconv::*; use crate::cass_error::CassError; use crate::cass_types::CassDataType; use crate::cass_types::CassDataTypeInner; +use crate::cql_types::value; +use crate::cql_types::value::CassCqlValue; use crate::types::*; -use crate::value; -use crate::value::CassCqlValue; use std::sync::Arc; use std::sync::LazyLock; diff --git a/scylla-rust-wrapper/src/user_type.rs b/scylla-rust-wrapper/src/cql_types/user_type.rs similarity index 98% rename from scylla-rust-wrapper/src/user_type.rs rename to scylla-rust-wrapper/src/cql_types/user_type.rs index 946ab101..f2fd60fe 100644 --- a/scylla-rust-wrapper/src/user_type.rs +++ b/scylla-rust-wrapper/src/cql_types/user_type.rs @@ -1,8 +1,8 @@ +use crate::argconv::*; use crate::cass_error::CassError; use crate::cass_types::{CassDataType, CassDataTypeInner}; +use crate::cql_types::value::{self, CassCqlValue}; use crate::types::*; -use crate::value::CassCqlValue; -use crate::{argconv::*, value}; use std::os::raw::c_char; use std::sync::Arc; diff --git a/scylla-rust-wrapper/src/uuid.rs b/scylla-rust-wrapper/src/cql_types/uuid.rs similarity index 100% rename from scylla-rust-wrapper/src/uuid.rs rename to scylla-rust-wrapper/src/cql_types/uuid.rs diff --git a/scylla-rust-wrapper/src/value.rs b/scylla-rust-wrapper/src/cql_types/value.rs similarity index 99% rename from scylla-rust-wrapper/src/value.rs rename to scylla-rust-wrapper/src/cql_types/value.rs index 06933778..ef480132 100644 --- a/scylla-rust-wrapper/src/value.rs +++ b/scylla-rust-wrapper/src/cql_types/value.rs @@ -417,7 +417,7 @@ mod tests { use crate::{ cass_types::{CassDataType, CassDataTypeInner, CassValueType, MapDataType, UdtDataType}, - value::{CassCqlValue, is_type_compatible}, + cql_types::value::{CassCqlValue, is_type_compatible}, }; fn all_value_data_types() -> Vec { diff --git a/scylla-rust-wrapper/src/future.rs b/scylla-rust-wrapper/src/future.rs index 971bf95c..db97a247 100644 --- a/scylla-rust-wrapper/src/future.rs +++ b/scylla-rust-wrapper/src/future.rs @@ -1,10 +1,10 @@ use crate::argconv::*; use crate::cass_error::{CassError, CassErrorMessage, CassErrorResult, ToCassError as _}; +use crate::cql_types::uuid::CassUuid; use crate::prepared::CassPrepared; use crate::query_result::{CassNode, CassResult}; use crate::runtime::Runtime; use crate::types::*; -use crate::uuid::CassUuid; use futures::future; use std::future::Future; use std::mem; diff --git a/scylla-rust-wrapper/src/lib.rs b/scylla-rust-wrapper/src/lib.rs index af0665a0..593c6d71 100644 --- a/scylla-rust-wrapper/src/lib.rs +++ b/scylla-rust-wrapper/src/lib.rs @@ -14,12 +14,10 @@ pub(crate) mod batch; pub(crate) mod cass_error; pub(crate) mod cass_types; pub(crate) mod cluster; -pub(crate) mod collection; pub(crate) mod config_value; -pub(crate) mod date_time; +pub(crate) mod cql_types; pub(crate) mod exec_profile; pub(crate) mod future; -pub(crate) mod inet; #[cfg(cpp_integration_testing)] pub(crate) mod integration_testing; pub(crate) mod iterator; @@ -39,10 +37,6 @@ pub(crate) mod statement; #[cfg(test)] pub(crate) mod testing; pub(crate) mod timestamp_generator; -pub(crate) mod tuple; -pub(crate) mod user_type; -pub(crate) mod uuid; -pub(crate) mod value; /// Includes a file generated by bindgen called `filename`. macro_rules! include_bindgen_generated { diff --git a/scylla-rust-wrapper/src/query_result.rs b/scylla-rust-wrapper/src/query_result.rs index 4c6615bf..89965526 100644 --- a/scylla-rust-wrapper/src/query_result.rs +++ b/scylla-rust-wrapper/src/query_result.rs @@ -6,9 +6,9 @@ use crate::cass_types::{ CassColumnSpec, CassDataType, CassDataTypeInner, MapDataType, cass_data_type_type, get_column_type, }; -use crate::inet::CassInet; +use crate::cql_types::inet::CassInet; +use crate::cql_types::uuid::CassUuid; use crate::types::*; -use crate::uuid::CassUuid; use cass_raw_value::CassRawValue; use row_with_self_borrowed_result_data::RowWithSelfBorrowedResultData; use scylla::cluster::metadata::{ColumnType, NativeType}; diff --git a/scylla-rust-wrapper/src/ser_de_tests.rs b/scylla-rust-wrapper/src/ser_de_tests.rs index 6102b555..e2f0a3e5 100644 --- a/scylla-rust-wrapper/src/ser_de_tests.rs +++ b/scylla-rust-wrapper/src/ser_de_tests.rs @@ -27,7 +27,8 @@ use crate::argconv::{ }; use crate::cass_error::CassError; use crate::cass_types::get_column_type; -use crate::inet::CassInet; +use crate::cql_types::inet::CassInet; +use crate::cql_types::uuid::CassUuid; use crate::iterator::{ CassIterator, CassIteratorType, cass_iterator_fields_from_user_type, cass_iterator_free, cass_iterator_from_collection, cass_iterator_from_map, cass_iterator_from_tuple, @@ -44,7 +45,6 @@ use crate::query_result::{ }; use crate::testing::{assert_cass_error_eq, setup_tracing}; use crate::types::size_t; -use crate::uuid::CassUuid; fn do_serialize(t: T, typ: &ColumnType) -> Vec { let mut ret = Vec::new(); diff --git a/scylla-rust-wrapper/src/session.rs b/scylla-rust-wrapper/src/session.rs index 19a1aa3a..50112fa9 100644 --- a/scylla-rust-wrapper/src/session.rs +++ b/scylla-rust-wrapper/src/session.rs @@ -4,6 +4,7 @@ use crate::cass_error::*; use crate::cass_metrics_types::CassMetrics; use crate::cass_types::get_column_type; use crate::cluster::CassCluster; +use crate::cql_types::uuid::CassUuid; use crate::exec_profile::{CassExecProfile, ExecProfileName, PerStatementExecProfile}; use crate::future::{CassFuture, CassFutureResult, CassResultValue}; use crate::metadata::create_table_metadata; @@ -13,7 +14,6 @@ use crate::query_result::{CassResult, CassResultKind, CassResultMetadata}; use crate::runtime::Runtime; use crate::statement::{BoundStatement, CassStatement, SimpleQueryRowSerializer}; use crate::types::size_t; -use crate::uuid::CassUuid; use scylla::client::execution_profile::ExecutionProfileHandle; use scylla::client::session::Session; use scylla::client::session_builder::SessionBuilder; diff --git a/scylla-rust-wrapper/src/statement.rs b/scylla-rust-wrapper/src/statement.rs index 7378b79c..66687170 100644 --- a/scylla-rust-wrapper/src/statement.rs +++ b/scylla-rust-wrapper/src/statement.rs @@ -1,14 +1,14 @@ +use crate::argconv::*; use crate::cass_error::CassError; use crate::cass_types::CassConsistency; use crate::config_value::{MaybeUnsetConfig, RequestTimeout}; +use crate::cql_types::inet::CassInet; +use crate::cql_types::value::{self, CassCqlValue}; use crate::exec_profile::PerStatementExecProfile; -use crate::inet::CassInet; use crate::prepared::CassPrepared; use crate::query_result::{CassNode, CassResult}; use crate::retry_policy::CassRetryPolicy; use crate::types::*; -use crate::value::CassCqlValue; -use crate::{argconv::*, value}; use scylla::frame::types::Consistency; use scylla::policies::load_balancing::{NodeIdentifier, SingleTargetLoadBalancingPolicy}; use scylla::response::{PagingState, PagingStateResponse}; @@ -875,7 +875,7 @@ mod tests { use crate::argconv::{BoxFFI, RefFFI}; use crate::cass_error::CassError; - use crate::inet::CassInet; + use crate::cql_types::inet::CassInet; use crate::statement::{ cass_statement_set_host, cass_statement_set_host_inet, cass_statement_set_node, }; From a094617e74e6fd967c02dfda94e3d1349dcdc8fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Wed, 24 Sep 2025 10:39:51 +0200 Subject: [PATCH 03/12] tree: extract `data_type` module from cass_types Most of the `cass_types` module is CassDataType-related entities. Those were moved to the newly created `cql_types::data_type` module. --- scylla-rust-wrapper/src/api.rs | 2 +- scylla-rust-wrapper/src/argconv.rs | 4 +- scylla-rust-wrapper/src/batch.rs | 3 +- scylla-rust-wrapper/src/cass_types.rs | 938 ------------------ .../src/cql_types/collection.rs | 11 +- .../src/cql_types/data_type.rs | 906 +++++++++++++++++ scylla-rust-wrapper/src/cql_types/mod.rs | 1 + scylla-rust-wrapper/src/cql_types/tuple.rs | 10 +- .../src/cql_types/user_type.rs | 2 +- scylla-rust-wrapper/src/cql_types/value.rs | 10 +- scylla-rust-wrapper/src/iterator.rs | 3 +- scylla-rust-wrapper/src/metadata.rs | 3 +- scylla-rust-wrapper/src/prepared.rs | 2 +- scylla-rust-wrapper/src/query_result.rs | 4 +- scylla-rust-wrapper/src/ser_de_tests.rs | 2 +- scylla-rust-wrapper/src/session.rs | 2 +- 16 files changed, 940 insertions(+), 963 deletions(-) create mode 100644 scylla-rust-wrapper/src/cql_types/data_type.rs diff --git a/scylla-rust-wrapper/src/api.rs b/scylla-rust-wrapper/src/api.rs index 4ba30f7c..eddbf51f 100644 --- a/scylla-rust-wrapper/src/api.rs +++ b/scylla-rust-wrapper/src/api.rs @@ -526,7 +526,7 @@ pub mod batch { pub mod data_type { // Disabling rustfmt to have one item per line for better readability. #[rustfmt::skip] - pub use crate::cass_types::{ + pub use crate::cql_types::data_type::{ CassDataType, cass_data_sub_type_count, cass_data_type_add_sub_type, diff --git a/scylla-rust-wrapper/src/argconv.rs b/scylla-rust-wrapper/src/argconv.rs index 3c9a9849..887be340 100644 --- a/scylla-rust-wrapper/src/argconv.rs +++ b/scylla-rust-wrapper/src/argconv.rs @@ -91,7 +91,7 @@ mod sealed { /// There is no way to obtain a mutable reference from such pointer. /// /// In some cases, we need to be able to mutate the data behind a shared pointer. -/// There is an example of such use case - namely [`crate::cass_types::CassDataType`]. +/// There is an example of such use case - namely [`crate::cql_types::data_type::CassDataType`]. /// argconv API does not provide a way to mutate such pointer - one can only convert the pointer /// to [`Arc`] or &. It is the API user's responsibility to implement sound interior mutability /// pattern in such case. This is what we currently do - CassDataType wraps CassDataTypeInner @@ -626,7 +626,7 @@ impl BoxFFI for T where T: FFI {} /// C API user should be responsible for freeing (decreasing reference count of) /// associated memory manually via corresponding API call. /// -/// An example of such implementor would be [`CassDataType`](crate::cass_types::CassDataType): +/// An example of such implementor would be [`CassDataType`](crate::cql_types::data_type::CassDataType): /// - it is allocated on the heap via [`Arc::new`] /// - there are multiple owners of the shared CassDataType object /// - some API functions require to increase a reference count of the object diff --git a/scylla-rust-wrapper/src/batch.rs b/scylla-rust-wrapper/src/batch.rs index 0e3f1cf8..33244d2e 100644 --- a/scylla-rust-wrapper/src/batch.rs +++ b/scylla-rust-wrapper/src/batch.rs @@ -4,8 +4,9 @@ use crate::argconv::{ }; use crate::cass_error::CassError; pub use crate::cass_types::CassBatchType; -use crate::cass_types::{CassConsistency, make_batch_type}; +use crate::cass_types::CassConsistency; use crate::config_value::{MaybeUnsetConfig, RequestTimeout}; +use crate::cql_types::data_type::make_batch_type; use crate::cql_types::value::CassCqlValue; use crate::exec_profile::PerStatementExecProfile; use crate::retry_policy::CassRetryPolicy; diff --git a/scylla-rust-wrapper/src/cass_types.rs b/scylla-rust-wrapper/src/cass_types.rs index b2ec9e17..4ba981b2 100644 --- a/scylla-rust-wrapper/src/cass_types.rs +++ b/scylla-rust-wrapper/src/cass_types.rs @@ -1,941 +1,3 @@ -use crate::argconv::*; -use crate::cass_error::CassError; -use crate::types::*; -use scylla::cluster::metadata::{CollectionType, NativeType}; -use scylla::frame::response::result::ColumnType; -use scylla::statement::batch::BatchType; -use std::cell::UnsafeCell; -use std::os::raw::c_char; -use std::sync::Arc; - pub use crate::cass_batch_types::CassBatchType; pub use crate::cass_consistency_types::CassConsistency; pub use crate::cass_data_types::CassValueType; - -#[derive(Clone, Debug)] -#[cfg_attr(test, derive(PartialEq, Eq))] -pub(crate) struct UdtDataType { - // Vec to preserve the order of types - pub(crate) field_types: Vec<(String, Arc)>, - - pub(crate) keyspace: String, - pub(crate) name: String, - pub(crate) frozen: bool, -} - -impl UdtDataType { - pub(crate) fn new() -> UdtDataType { - UdtDataType { - field_types: Vec::new(), - keyspace: "".to_string(), - name: "".to_string(), - frozen: false, - } - } - - pub(crate) fn with_capacity(capacity: usize) -> UdtDataType { - UdtDataType { - field_types: Vec::with_capacity(capacity), - keyspace: "".to_string(), - name: "".to_string(), - frozen: false, - } - } - - pub(crate) fn get_field_by_name(&self, name: &str) -> Option<&Arc> { - self.field_types - .iter() - .find(|(field_name, _)| field_name == name) - .map(|(_, t)| t) - } - - fn typecheck_equals(&self, other: &UdtDataType) -> bool { - // See: https://github.com/scylladb/cpp-driver/blob/master/src/data_type.hpp#L354-L386 - - if !any_string_empty_or_both_equal(&self.keyspace, &other.keyspace) { - return false; - } - if !any_string_empty_or_both_equal(&self.name, &other.name) { - return false; - } - - // A comment from cpp-driver: - //// UDT's can be considered equal as long as the mutual first fields shared - //// between them are equal. UDT's are append only as far as fields go, so a - //// newer 'version' of the UDT data type after a schema change event should be - //// treated as equivalent in this scenario, by simply looking at the first N - //// mutual fields they should share. - // - // Iterator returned from zip() is perfect for checking the first mutual fields. - for (field, other_field) in self.field_types.iter().zip(other.field_types.iter()) { - // Compare field names. - if field.0 != other_field.0 { - return false; - } - // Compare field types. - if unsafe { - !field - .1 - .get_unchecked() - .typecheck_equals(other_field.1.get_unchecked()) - } { - return false; - } - } - - true - } -} - -fn any_string_empty_or_both_equal(s1: &str, s2: &str) -> bool { - s1.is_empty() || s2.is_empty() || s1 == s2 -} - -impl Default for UdtDataType { - fn default() -> Self { - Self::new() - } -} - -#[derive(Clone, Debug)] -#[cfg_attr(test, derive(PartialEq, Eq))] -pub(crate) enum MapDataType { - Untyped, - Key(Arc), - KeyAndValue(Arc, Arc), -} - -#[derive(Debug)] -pub(crate) struct CassColumnSpec { - pub(crate) name: String, - pub(crate) data_type: Arc, -} - -#[derive(Clone, Debug)] -#[cfg_attr(test, derive(PartialEq, Eq))] -pub(crate) enum CassDataTypeInner { - Value(CassValueType), - Udt(UdtDataType), - List { - // None stands for untyped list. - typ: Option>, - frozen: bool, - }, - Set { - // None stands for untyped set. - typ: Option>, - frozen: bool, - }, - Map { - typ: MapDataType, - frozen: bool, - }, - // Empty vector stands for untyped tuple. - Tuple(Vec>), - Custom(String), -} - -impl FFI for CassDataType { - type Origin = FromArc; -} - -impl CassDataTypeInner { - /// Checks for equality during typechecks. - /// - /// This takes into account the fact that tuples/collections may be untyped. - pub(crate) fn typecheck_equals(&self, other: &CassDataTypeInner) -> bool { - match self { - CassDataTypeInner::Value(t) => *t == other.get_value_type(), - CassDataTypeInner::Udt(udt) => match other { - CassDataTypeInner::Udt(other_udt) => udt.typecheck_equals(other_udt), - _ => false, - }, - CassDataTypeInner::List { typ, .. } | CassDataTypeInner::Set { typ, .. } => match other - { - CassDataTypeInner::List { typ: other_typ, .. } - | CassDataTypeInner::Set { typ: other_typ, .. } => { - // If one of them is list, and the other is set, fail the typecheck. - if self.get_value_type() != other.get_value_type() { - return false; - } - match (typ, other_typ) { - // One of them is untyped, skip the typecheck for subtype. - (None, _) | (_, None) => true, - (Some(typ), Some(other_typ)) => unsafe { - typ.get_unchecked() - .typecheck_equals(other_typ.get_unchecked()) - }, - } - } - _ => false, - }, - CassDataTypeInner::Map { typ: t, .. } => match other { - CassDataTypeInner::Map { typ: t_other, .. } => match (t, t_other) { - // See https://github.com/scylladb/cpp-driver/blob/master/src/data_type.hpp#L218 - // In cpp-driver the types are held in a vector. - // The logic is following: - - // If either of vectors is empty, skip the typecheck. - (MapDataType::Untyped, _) => true, - (_, MapDataType::Untyped) => true, - - // Otherwise, the vectors should have equal length and we perform the typecheck for subtypes. - (MapDataType::Key(k), MapDataType::Key(k_other)) => unsafe { - k.get_unchecked().typecheck_equals(k_other.get_unchecked()) - }, - ( - MapDataType::KeyAndValue(k, v), - MapDataType::KeyAndValue(k_other, v_other), - ) => unsafe { - k.get_unchecked().typecheck_equals(k_other.get_unchecked()) - && v.get_unchecked().typecheck_equals(v_other.get_unchecked()) - }, - _ => false, - }, - _ => false, - }, - CassDataTypeInner::Tuple(sub) => match other { - CassDataTypeInner::Tuple(other_sub) => { - // If either of tuples is untyped, skip the typecheck for subtypes. - if sub.is_empty() || other_sub.is_empty() { - return true; - } - - // If both are non-empty, check for subtypes equality. - if sub.len() != other_sub.len() { - return false; - } - sub.iter() - .zip(other_sub.iter()) - .all(|(typ, other_typ)| unsafe { - typ.get_unchecked() - .typecheck_equals(other_typ.get_unchecked()) - }) - } - _ => false, - }, - CassDataTypeInner::Custom(_) => { - unimplemented!("cpp-rs-driver does not support custom types!") - } - } - } -} - -#[derive(Debug)] -#[repr(transparent)] -pub struct CassDataType(UnsafeCell); - -/// PartialEq and Eq for test purposes. -#[cfg(test)] -impl PartialEq for CassDataType { - fn eq(&self, other: &Self) -> bool { - unsafe { self.get_unchecked() == other.get_unchecked() } - } -} -#[cfg(test)] -impl Eq for CassDataType {} - -unsafe impl Sync for CassDataType {} - -impl CassDataType { - pub(crate) unsafe fn get_unchecked(&self) -> &CassDataTypeInner { - unsafe { &*self.0.get() } - } - - #[allow(clippy::mut_from_ref)] - pub(crate) unsafe fn get_mut_unchecked(&self) -> &mut CassDataTypeInner { - unsafe { &mut *self.0.get() } - } - - pub(crate) const fn new(inner: CassDataTypeInner) -> CassDataType { - CassDataType(UnsafeCell::new(inner)) - } - - pub(crate) fn new_arced(inner: CassDataTypeInner) -> Arc { - Arc::new(CassDataType(UnsafeCell::new(inner))) - } -} - -fn native_type_to_cass_value_type(native_type: &NativeType) -> CassValueType { - use NativeType::*; - match native_type { - Ascii => CassValueType::CASS_VALUE_TYPE_ASCII, - Boolean => CassValueType::CASS_VALUE_TYPE_BOOLEAN, - Blob => CassValueType::CASS_VALUE_TYPE_BLOB, - Counter => CassValueType::CASS_VALUE_TYPE_COUNTER, - Date => CassValueType::CASS_VALUE_TYPE_DATE, - Decimal => CassValueType::CASS_VALUE_TYPE_DECIMAL, - Double => CassValueType::CASS_VALUE_TYPE_DOUBLE, - Duration => CassValueType::CASS_VALUE_TYPE_DURATION, - Float => CassValueType::CASS_VALUE_TYPE_FLOAT, - Int => CassValueType::CASS_VALUE_TYPE_INT, - BigInt => CassValueType::CASS_VALUE_TYPE_BIGINT, - // Rust Driver unifies both VARCHAR and TEXT into NativeType::Text. - // CPP Driver, in accordance to the CQL protocol, has separate types for VARCHAR and TEXT. - // Even worse, Rust Driver even does not handle CQL TEXT correctly! It errors out on TEXT - // type... - // As the DBs (Cassandra and ScyllaDB) seem to send the VARCHAR type in the protocol, - // we will assume that the NativeType::Text is actually a VARCHAR type. - Text => CassValueType::CASS_VALUE_TYPE_VARCHAR, - Timestamp => CassValueType::CASS_VALUE_TYPE_TIMESTAMP, - Inet => CassValueType::CASS_VALUE_TYPE_INET, - SmallInt => CassValueType::CASS_VALUE_TYPE_SMALL_INT, - TinyInt => CassValueType::CASS_VALUE_TYPE_TINY_INT, - Time => CassValueType::CASS_VALUE_TYPE_TIME, - Timeuuid => CassValueType::CASS_VALUE_TYPE_TIMEUUID, - Uuid => CassValueType::CASS_VALUE_TYPE_UUID, - Varint => CassValueType::CASS_VALUE_TYPE_VARINT, - - // NativeType is non_exhaustive - _ => CassValueType::CASS_VALUE_TYPE_UNKNOWN, - } -} - -impl CassDataTypeInner { - fn get_sub_data_type(&self, index: usize) -> Option<&Arc> { - match self { - CassDataTypeInner::Udt(udt_data_type) => { - udt_data_type.field_types.get(index).map(|(_, b)| b) - } - CassDataTypeInner::List { typ, .. } | CassDataTypeInner::Set { typ, .. } => { - if index > 0 { None } else { typ.as_ref() } - } - CassDataTypeInner::Map { - typ: MapDataType::Untyped, - .. - } => None, - CassDataTypeInner::Map { - typ: MapDataType::Key(k), - .. - } => (index == 0).then_some(k), - CassDataTypeInner::Map { - typ: MapDataType::KeyAndValue(k, v), - .. - } => match index { - 0 => Some(k), - 1 => Some(v), - _ => None, - }, - CassDataTypeInner::Tuple(v) => v.get(index), - _ => None, - } - } - - fn add_sub_data_type(&mut self, sub_type: Arc) -> Result<(), CassError> { - match self { - CassDataTypeInner::List { typ, .. } | CassDataTypeInner::Set { typ, .. } => match typ { - Some(_) => { - tracing::error!("Trying to add sub-type to already typed list/set!"); - Err(CassError::CASS_ERROR_LIB_BAD_PARAMS) - } - None => { - *typ = Some(sub_type); - Ok(()) - } - }, - CassDataTypeInner::Map { - typ: MapDataType::KeyAndValue(_, _), - .. - } => { - tracing::error!("Trying to add sub-type to already fully typed map!"); - Err(CassError::CASS_ERROR_LIB_BAD_PARAMS) - } - CassDataTypeInner::Map { - typ: MapDataType::Key(k), - frozen, - } => { - *self = CassDataTypeInner::Map { - typ: MapDataType::KeyAndValue(k.clone(), sub_type), - frozen: *frozen, - }; - Ok(()) - } - CassDataTypeInner::Map { - typ: MapDataType::Untyped, - frozen, - } => { - *self = CassDataTypeInner::Map { - typ: MapDataType::Key(sub_type), - frozen: *frozen, - }; - Ok(()) - } - CassDataTypeInner::Tuple(types) => { - types.push(sub_type); - Ok(()) - } - _ => { - tracing::error!("Trying to add sub-type to non-collection/tuple data type!"); - Err(CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE) - } - } - } - - pub(crate) fn get_udt_type(&self) -> &UdtDataType { - match self { - CassDataTypeInner::Udt(udt) => udt, - _ => panic!("Can get UDT out of non-UDT data type"), - } - } - - pub(crate) fn get_value_type(&self) -> CassValueType { - match &self { - CassDataTypeInner::Value(value_data_type) => *value_data_type, - CassDataTypeInner::Udt { .. } => CassValueType::CASS_VALUE_TYPE_UDT, - CassDataTypeInner::List { .. } => CassValueType::CASS_VALUE_TYPE_LIST, - CassDataTypeInner::Set { .. } => CassValueType::CASS_VALUE_TYPE_SET, - CassDataTypeInner::Map { .. } => CassValueType::CASS_VALUE_TYPE_MAP, - CassDataTypeInner::Tuple(..) => CassValueType::CASS_VALUE_TYPE_TUPLE, - CassDataTypeInner::Custom(..) => CassValueType::CASS_VALUE_TYPE_CUSTOM, - } - } -} - -pub(crate) fn get_column_type(column_type: &ColumnType) -> CassDataType { - use CollectionType::*; - use ColumnType::*; - let inner = match column_type { - Native(native) => CassDataTypeInner::Value(native_type_to_cass_value_type(native)), - Collection { - typ: List(boxed_type), - frozen, - } => CassDataTypeInner::List { - typ: Some(Arc::new(get_column_type(boxed_type.as_ref()))), - frozen: *frozen, - }, - Collection { - typ: Map(key, value), - frozen, - } => CassDataTypeInner::Map { - typ: MapDataType::KeyAndValue( - Arc::new(get_column_type(key.as_ref())), - Arc::new(get_column_type(value.as_ref())), - ), - frozen: *frozen, - }, - Collection { - typ: Set(boxed_type), - frozen, - } => CassDataTypeInner::Set { - typ: Some(Arc::new(get_column_type(boxed_type.as_ref()))), - frozen: *frozen, - }, - UserDefinedType { definition, frozen } => CassDataTypeInner::Udt(UdtDataType { - field_types: definition - .field_types - .iter() - .map(|(name, col_type)| { - ( - name.clone().into_owned(), - Arc::new(get_column_type(col_type)), - ) - }) - .collect(), - keyspace: definition.keyspace.clone().into_owned(), - name: definition.name.clone().into_owned(), - frozen: *frozen, - }), - Tuple(v) => CassDataTypeInner::Tuple( - v.iter() - .map(|col_type| Arc::new(get_column_type(col_type))) - .collect(), - ), - - // ColumnType is non_exhaustive. - _ => CassDataTypeInner::Value(CassValueType::CASS_VALUE_TYPE_UNKNOWN), - }; - - CassDataType::new(inner) -} - -#[unsafe(no_mangle)] -pub unsafe extern "C" fn cass_data_type_new( - value_type: CassValueType, -) -> CassOwnedSharedPtr { - let inner = match value_type { - CassValueType::CASS_VALUE_TYPE_LIST => CassDataTypeInner::List { - typ: None, - frozen: false, - }, - CassValueType::CASS_VALUE_TYPE_SET => CassDataTypeInner::Set { - typ: None, - frozen: false, - }, - CassValueType::CASS_VALUE_TYPE_TUPLE => CassDataTypeInner::Tuple(Vec::new()), - CassValueType::CASS_VALUE_TYPE_MAP => CassDataTypeInner::Map { - typ: MapDataType::Untyped, - frozen: false, - }, - CassValueType::CASS_VALUE_TYPE_UDT => CassDataTypeInner::Udt(UdtDataType::new()), - CassValueType::CASS_VALUE_TYPE_CUSTOM => CassDataTypeInner::Custom("".to_string()), - CassValueType::CASS_VALUE_TYPE_UNKNOWN => return ArcFFI::null(), - t if t < CassValueType::CASS_VALUE_TYPE_LAST_ENTRY => CassDataTypeInner::Value(t), - _ => return ArcFFI::null(), - }; - ArcFFI::into_ptr(CassDataType::new_arced(inner)) -} - -#[unsafe(no_mangle)] -pub unsafe extern "C" fn cass_data_type_new_from_existing( - data_type: CassBorrowedSharedPtr, -) -> CassOwnedSharedPtr { - let Some(data_type) = ArcFFI::as_ref(data_type) else { - tracing::error!("Provided null data type pointer to cass_data_type_new_from_existing!"); - return ArcFFI::null(); - }; - - ArcFFI::into_ptr(CassDataType::new_arced( - unsafe { data_type.get_unchecked() }.clone(), - )) -} - -#[unsafe(no_mangle)] -pub unsafe extern "C" fn cass_data_type_new_tuple( - item_count: size_t, -) -> CassOwnedSharedPtr { - ArcFFI::into_ptr(CassDataType::new_arced(CassDataTypeInner::Tuple( - Vec::with_capacity(item_count as usize), - ))) -} - -#[unsafe(no_mangle)] -pub unsafe extern "C" fn cass_data_type_new_udt( - field_count: size_t, -) -> CassOwnedSharedPtr { - ArcFFI::into_ptr(CassDataType::new_arced(CassDataTypeInner::Udt( - UdtDataType::with_capacity(field_count as usize), - ))) -} - -#[unsafe(no_mangle)] -pub unsafe extern "C" fn cass_data_type_free(data_type: CassOwnedSharedPtr) { - ArcFFI::free(data_type); -} - -#[unsafe(no_mangle)] -pub unsafe extern "C" fn cass_data_type_type( - data_type: CassBorrowedSharedPtr, -) -> CassValueType { - let Some(data_type) = ArcFFI::as_ref(data_type) else { - tracing::error!("Provided null data type pointer to cass_data_type_type!"); - return CassValueType::CASS_VALUE_TYPE_UNKNOWN; - }; - - unsafe { data_type.get_unchecked() }.get_value_type() -} - -#[unsafe(no_mangle)] -pub unsafe extern "C" fn cass_data_type_is_frozen( - data_type: CassBorrowedSharedPtr, -) -> cass_bool_t { - let Some(data_type) = ArcFFI::as_ref(data_type) else { - tracing::error!("Provided null data type pointer to cass_data_type_is_frozen!"); - return cass_false; - }; - - let is_frozen = match unsafe { data_type.get_unchecked() } { - CassDataTypeInner::Udt(udt) => udt.frozen, - CassDataTypeInner::List { frozen, .. } => *frozen, - CassDataTypeInner::Set { frozen, .. } => *frozen, - CassDataTypeInner::Map { frozen, .. } => *frozen, - _ => false, - }; - - is_frozen as cass_bool_t -} - -#[unsafe(no_mangle)] -pub unsafe extern "C" fn cass_data_type_type_name( - data_type: CassBorrowedSharedPtr, - type_name: *mut *const c_char, - type_name_length: *mut size_t, -) -> CassError { - let Some(data_type) = ArcFFI::as_ref(data_type) else { - tracing::error!("Provided null data type pointer to cass_data_type_type_name!"); - return CassError::CASS_ERROR_LIB_BAD_PARAMS; - }; - - match unsafe { data_type.get_unchecked() } { - CassDataTypeInner::Udt(UdtDataType { name, .. }) => { - unsafe { write_str_to_c(name, type_name, type_name_length) }; - CassError::CASS_OK - } - _ => { - tracing::error!("Trying to get type name from non-UDT data type!"); - CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE - } - } -} - -#[unsafe(no_mangle)] -pub unsafe extern "C" fn cass_data_type_set_type_name( - data_type: CassBorrowedSharedPtr, - type_name: *const c_char, -) -> CassError { - unsafe { cass_data_type_set_type_name_n(data_type, type_name, strlen(type_name)) } -} - -#[unsafe(no_mangle)] -pub unsafe extern "C" fn cass_data_type_set_type_name_n( - data_type_raw: CassBorrowedSharedPtr, - type_name: *const c_char, - type_name_length: size_t, -) -> CassError { - let Some(data_type) = ArcFFI::as_ref(data_type_raw) else { - tracing::error!("Provided null data type pointer to cass_data_type_set_type_name_n!"); - return CassError::CASS_ERROR_LIB_BAD_PARAMS; - }; - - let type_name_string = unsafe { ptr_to_cstr_n(type_name, type_name_length) } - .unwrap() - .to_string(); - - match unsafe { data_type.get_mut_unchecked() } { - CassDataTypeInner::Udt(udt_data_type) => { - udt_data_type.name = type_name_string; - CassError::CASS_OK - } - _ => { - tracing::error!("Trying to set type name on non-UDT data type!"); - CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE - } - } -} - -#[unsafe(no_mangle)] -pub unsafe extern "C" fn cass_data_type_keyspace( - data_type: CassBorrowedSharedPtr, - keyspace: *mut *const c_char, - keyspace_length: *mut size_t, -) -> CassError { - let Some(data_type) = ArcFFI::as_ref(data_type) else { - tracing::error!("Provided null data type pointer to cass_data_type_keyspace!"); - return CassError::CASS_ERROR_LIB_BAD_PARAMS; - }; - - match unsafe { data_type.get_unchecked() } { - CassDataTypeInner::Udt(UdtDataType { name, .. }) => { - unsafe { write_str_to_c(name, keyspace, keyspace_length) }; - CassError::CASS_OK - } - _ => { - tracing::error!("Trying to get keyspace from non-UDT data type!"); - CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE - } - } -} - -#[unsafe(no_mangle)] -pub unsafe extern "C" fn cass_data_type_set_keyspace( - data_type: CassBorrowedSharedPtr, - keyspace: *const c_char, -) -> CassError { - unsafe { cass_data_type_set_keyspace_n(data_type, keyspace, strlen(keyspace)) } -} - -#[unsafe(no_mangle)] -pub unsafe extern "C" fn cass_data_type_set_keyspace_n( - data_type: CassBorrowedSharedPtr, - keyspace: *const c_char, - keyspace_length: size_t, -) -> CassError { - let Some(data_type) = ArcFFI::as_ref(data_type) else { - tracing::error!("Provided null data type pointer to cass_data_type_set_keyspace_n!"); - return CassError::CASS_ERROR_LIB_BAD_PARAMS; - }; - - let keyspace_string = unsafe { ptr_to_cstr_n(keyspace, keyspace_length) } - .unwrap() - .to_string(); - - match unsafe { data_type.get_mut_unchecked() } { - CassDataTypeInner::Udt(udt_data_type) => { - udt_data_type.keyspace = keyspace_string; - CassError::CASS_OK - } - _ => { - tracing::error!("Trying to set keyspace on non-UDT data type!"); - CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE - } - } -} - -#[unsafe(no_mangle)] -pub unsafe extern "C" fn cass_data_type_class_name( - data_type: CassBorrowedSharedPtr, - class_name: *mut *const ::std::os::raw::c_char, - class_name_length: *mut size_t, -) -> CassError { - let Some(data_type) = ArcFFI::as_ref(data_type) else { - tracing::error!("Provided null data type pointer to cass_data_type_class_name!"); - return CassError::CASS_ERROR_LIB_BAD_PARAMS; - }; - - match unsafe { data_type.get_unchecked() } { - CassDataTypeInner::Custom(name) => { - unsafe { write_str_to_c(name, class_name, class_name_length) }; - CassError::CASS_OK - } - _ => { - tracing::error!("Trying to get class name from non-Custom data type!"); - CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE - } - } -} - -#[unsafe(no_mangle)] -pub unsafe extern "C" fn cass_data_type_set_class_name( - data_type: CassBorrowedSharedPtr, - class_name: *const ::std::os::raw::c_char, -) -> CassError { - unsafe { cass_data_type_set_class_name_n(data_type, class_name, strlen(class_name)) } -} - -#[unsafe(no_mangle)] -pub unsafe extern "C" fn cass_data_type_set_class_name_n( - data_type: CassBorrowedSharedPtr, - class_name: *const ::std::os::raw::c_char, - class_name_length: size_t, -) -> CassError { - let Some(data_type) = ArcFFI::as_ref(data_type) else { - tracing::error!("Provided null data type pointer to cass_data_type_set_class_name_n!"); - return CassError::CASS_ERROR_LIB_BAD_PARAMS; - }; - - let class_string = unsafe { ptr_to_cstr_n(class_name, class_name_length) } - .unwrap() - .to_string(); - match unsafe { data_type.get_mut_unchecked() } { - CassDataTypeInner::Custom(name) => { - *name = class_string; - CassError::CASS_OK - } - _ => { - tracing::error!("Trying to set class name on non-Custom data type!"); - CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE - } - } -} - -#[unsafe(no_mangle)] -pub unsafe extern "C" fn cass_data_type_sub_type_count( - data_type: CassBorrowedSharedPtr, -) -> size_t { - let Some(data_type) = ArcFFI::as_ref(data_type) else { - tracing::error!("Provided null data type pointer to cass_data_type_sub_type_count!"); - return 0; - }; - - match unsafe { data_type.get_unchecked() } { - CassDataTypeInner::Value(..) => 0, - CassDataTypeInner::Udt(udt_data_type) => udt_data_type.field_types.len() as size_t, - CassDataTypeInner::List { typ, .. } | CassDataTypeInner::Set { typ, .. } => { - typ.is_some() as size_t - } - CassDataTypeInner::Map { typ, .. } => match typ { - MapDataType::Untyped => 0, - MapDataType::Key(_) => 1, - MapDataType::KeyAndValue(_, _) => 2, - }, - CassDataTypeInner::Tuple(v) => v.len() as size_t, - CassDataTypeInner::Custom(..) => 0, - } -} - -#[unsafe(no_mangle)] -pub unsafe extern "C" fn cass_data_sub_type_count( - data_type: CassBorrowedSharedPtr, -) -> size_t { - unsafe { cass_data_type_sub_type_count(data_type) } -} - -#[unsafe(no_mangle)] -pub unsafe extern "C" fn cass_data_type_sub_data_type( - data_type: CassBorrowedSharedPtr, - index: size_t, -) -> CassBorrowedSharedPtr { - let Some(data_type) = ArcFFI::as_ref(data_type) else { - tracing::error!("Provided null data type pointer to cass_data_type_sub_data_type!"); - return ArcFFI::null(); - }; - - let sub_type: Option<&Arc> = - unsafe { data_type.get_unchecked() }.get_sub_data_type(index as usize); - - match sub_type { - None => ArcFFI::null(), - // Semantic from cppdriver which also returns non-owning pointer - Some(arc) => ArcFFI::as_ptr(arc), - } -} - -#[unsafe(no_mangle)] -pub unsafe extern "C" fn cass_data_type_sub_data_type_by_name( - data_type: CassBorrowedSharedPtr, - name: *const ::std::os::raw::c_char, -) -> CassBorrowedSharedPtr { - unsafe { cass_data_type_sub_data_type_by_name_n(data_type, name, strlen(name)) } -} - -#[unsafe(no_mangle)] -pub unsafe extern "C" fn cass_data_type_sub_data_type_by_name_n( - data_type: CassBorrowedSharedPtr, - name: *const ::std::os::raw::c_char, - name_length: size_t, -) -> CassBorrowedSharedPtr { - let Some(data_type) = ArcFFI::as_ref(data_type) else { - tracing::error!( - "Provided null data type pointer to cass_data_type_sub_data_type_by_name_n!" - ); - return ArcFFI::null(); - }; - - let name_str = unsafe { ptr_to_cstr_n(name, name_length) }.unwrap(); - match unsafe { data_type.get_unchecked() } { - CassDataTypeInner::Udt(udt) => match udt.get_field_by_name(name_str) { - None => ArcFFI::null(), - Some(t) => ArcFFI::as_ptr(t), - }, - _ => ArcFFI::null(), - } -} - -#[unsafe(no_mangle)] -pub unsafe extern "C" fn cass_data_type_sub_type_name( - data_type: CassBorrowedSharedPtr, - index: size_t, - name: *mut *const ::std::os::raw::c_char, - name_length: *mut size_t, -) -> CassError { - let Some(data_type) = ArcFFI::as_ref(data_type) else { - tracing::error!("Provided null data type pointer to cass_data_type_sub_type_name!"); - return CassError::CASS_ERROR_LIB_BAD_PARAMS; - }; - - match unsafe { data_type.get_unchecked() } { - CassDataTypeInner::Udt(udt) => match udt.field_types.get(index as usize) { - None => CassError::CASS_ERROR_LIB_INDEX_OUT_OF_BOUNDS, - Some((field_name, _)) => { - unsafe { write_str_to_c(field_name, name, name_length) }; - CassError::CASS_OK - } - }, - _ => { - tracing::error!("Trying to get sub-type name from non-UDT data type!"); - CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE - } - } -} - -#[unsafe(no_mangle)] -pub unsafe extern "C" fn cass_data_type_add_sub_type( - data_type: CassBorrowedSharedPtr, - sub_data_type: CassBorrowedSharedPtr, -) -> CassError { - let Some(data_type) = ArcFFI::as_ref(data_type) else { - tracing::error!("Provided null data type pointer to cass_data_type_add_sub_type!"); - return CassError::CASS_ERROR_LIB_BAD_PARAMS; - }; - let Some(sub_data_type) = ArcFFI::cloned_from_ptr(sub_data_type) else { - tracing::error!("Provided null sub data type pointer to cass_data_type_add_sub_type!"); - return CassError::CASS_ERROR_LIB_BAD_PARAMS; - }; - - match unsafe { data_type.get_mut_unchecked() }.add_sub_data_type(sub_data_type) { - Ok(()) => CassError::CASS_OK, - Err(e) => e, - } -} - -#[unsafe(no_mangle)] -pub unsafe extern "C" fn cass_data_type_add_sub_type_by_name( - data_type: CassBorrowedSharedPtr, - name: *const c_char, - sub_data_type: CassBorrowedSharedPtr, -) -> CassError { - unsafe { cass_data_type_add_sub_type_by_name_n(data_type, name, strlen(name), sub_data_type) } -} - -#[unsafe(no_mangle)] -pub unsafe extern "C" fn cass_data_type_add_sub_type_by_name_n( - data_type_raw: CassBorrowedSharedPtr, - name: *const c_char, - name_length: size_t, - sub_data_type_raw: CassBorrowedSharedPtr, -) -> CassError { - let Some(data_type) = ArcFFI::as_ref(data_type_raw) else { - tracing::error!( - "Provided null data type pointer to cass_data_type_add_sub_type_by_name_n!" - ); - return CassError::CASS_ERROR_LIB_BAD_PARAMS; - }; - let Some(sub_data_type) = ArcFFI::cloned_from_ptr(sub_data_type_raw) else { - tracing::error!( - "Provided null sub data type pointer to cass_data_type_add_sub_type_by_name_n!" - ); - return CassError::CASS_ERROR_LIB_BAD_PARAMS; - }; - - let name_string = unsafe { ptr_to_cstr_n(name, name_length) } - .unwrap() - .to_string(); - - match unsafe { data_type.get_mut_unchecked() } { - CassDataTypeInner::Udt(udt_data_type) => { - // The Cpp Driver does not check whether field_types size - // exceeded field_count. - udt_data_type.field_types.push((name_string, sub_data_type)); - CassError::CASS_OK - } - _ => { - tracing::error!("Trying to add sub-type by name to non-UDT data type!"); - CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE - } - } -} - -#[unsafe(no_mangle)] -pub unsafe extern "C" fn cass_data_type_add_sub_value_type( - data_type: CassBorrowedSharedPtr, - sub_value_type: CassValueType, -) -> CassError { - let sub_data_type = CassDataType::new_arced(CassDataTypeInner::Value(sub_value_type)); - unsafe { cass_data_type_add_sub_type(data_type, ArcFFI::as_ptr(&sub_data_type)) } -} - -#[unsafe(no_mangle)] -pub unsafe extern "C" fn cass_data_type_add_sub_value_type_by_name( - data_type: CassBorrowedSharedPtr, - name: *const c_char, - sub_value_type: CassValueType, -) -> CassError { - let sub_data_type = CassDataType::new_arced(CassDataTypeInner::Value(sub_value_type)); - unsafe { cass_data_type_add_sub_type_by_name(data_type, name, ArcFFI::as_ptr(&sub_data_type)) } -} - -#[unsafe(no_mangle)] -pub unsafe extern "C" fn cass_data_type_add_sub_value_type_by_name_n( - data_type: CassBorrowedSharedPtr, - name: *const c_char, - name_length: size_t, - sub_value_type: CassValueType, -) -> CassError { - let sub_data_type = CassDataType::new_arced(CassDataTypeInner::Value(sub_value_type)); - unsafe { - cass_data_type_add_sub_type_by_name_n( - data_type, - name, - name_length, - ArcFFI::as_ptr(&sub_data_type), - ) - } -} - -pub(crate) fn make_batch_type(type_: CassBatchType) -> Result { - match type_ { - CassBatchType::CASS_BATCH_TYPE_LOGGED => Ok(BatchType::Logged), - CassBatchType::CASS_BATCH_TYPE_UNLOGGED => Ok(BatchType::Unlogged), - CassBatchType::CASS_BATCH_TYPE_COUNTER => Ok(BatchType::Counter), - _ => Err(()), - } -} diff --git a/scylla-rust-wrapper/src/cql_types/collection.rs b/scylla-rust-wrapper/src/cql_types/collection.rs index c5b89c9d..562e9df8 100644 --- a/scylla-rust-wrapper/src/cql_types/collection.rs +++ b/scylla-rust-wrapper/src/cql_types/collection.rs @@ -1,7 +1,7 @@ use crate::argconv::*; use crate::cass_collection_types::CassCollectionType; use crate::cass_error::CassError; -use crate::cass_types::{CassDataType, CassDataTypeInner, MapDataType}; +use crate::cql_types::data_type::{CassDataType, CassDataTypeInner, MapDataType}; use crate::cql_types::value::{self, CassCqlValue}; use crate::types::*; use std::convert::TryFrom; @@ -281,13 +281,14 @@ mod tests { use crate::{ argconv::ArcFFI, cass_error::CassError, - cass_types::{ - CassDataType, CassDataTypeInner, CassValueType, MapDataType, - cass_data_type_add_sub_type, cass_data_type_free, cass_data_type_new, - }, + cass_types::CassValueType, cql_types::collection::{ cass_collection_append_double, cass_collection_append_float, cass_collection_free, }, + cql_types::data_type::{ + CassDataType, CassDataTypeInner, MapDataType, cass_data_type_add_sub_type, + cass_data_type_free, cass_data_type_new, + }, testing::assert_cass_error_eq, }; diff --git a/scylla-rust-wrapper/src/cql_types/data_type.rs b/scylla-rust-wrapper/src/cql_types/data_type.rs new file mode 100644 index 00000000..bff491d5 --- /dev/null +++ b/scylla-rust-wrapper/src/cql_types/data_type.rs @@ -0,0 +1,906 @@ +use crate::argconv::*; +use crate::batch::CassBatchType; +use crate::cass_error::CassError; +use crate::cass_types::CassValueType; +use crate::types::*; +use scylla::cluster::metadata::{CollectionType, NativeType}; +use scylla::frame::response::result::ColumnType; +use scylla::statement::batch::BatchType; +use std::cell::UnsafeCell; +use std::os::raw::c_char; +use std::sync::Arc; + +#[derive(Clone, Debug)] +#[cfg_attr(test, derive(PartialEq, Eq))] +pub(crate) struct UdtDataType { + // Vec to preserve the order of types + pub(crate) field_types: Vec<(String, Arc)>, + + pub(crate) keyspace: String, + pub(crate) name: String, + pub(crate) frozen: bool, +} + +impl UdtDataType { + pub(crate) fn new() -> UdtDataType { + UdtDataType { + field_types: Vec::new(), + keyspace: "".to_string(), + name: "".to_string(), + frozen: false, + } + } + + pub(crate) fn with_capacity(capacity: usize) -> UdtDataType { + UdtDataType { + field_types: Vec::with_capacity(capacity), + keyspace: "".to_string(), + name: "".to_string(), + frozen: false, + } + } + + pub(crate) fn get_field_by_name(&self, name: &str) -> Option<&Arc> { + self.field_types + .iter() + .find(|(field_name, _)| field_name == name) + .map(|(_, t)| t) + } + + fn typecheck_equals(&self, other: &UdtDataType) -> bool { + // See: https://github.com/scylladb/cpp-driver/blob/master/src/data_type.hpp#L354-L386 + + if !any_string_empty_or_both_equal(&self.keyspace, &other.keyspace) { + return false; + } + if !any_string_empty_or_both_equal(&self.name, &other.name) { + return false; + } + + // A comment from cpp-driver: + //// UDT's can be considered equal as long as the mutual first fields shared + //// between them are equal. UDT's are append only as far as fields go, so a + //// newer 'version' of the UDT data type after a schema change event should be + //// treated as equivalent in this scenario, by simply looking at the first N + //// mutual fields they should share. + // + // Iterator returned from zip() is perfect for checking the first mutual fields. + for (field, other_field) in self.field_types.iter().zip(other.field_types.iter()) { + // Compare field names. + if field.0 != other_field.0 { + return false; + } + // Compare field types. + if unsafe { + !field + .1 + .get_unchecked() + .typecheck_equals(other_field.1.get_unchecked()) + } { + return false; + } + } + + true + } +} + +fn any_string_empty_or_both_equal(s1: &str, s2: &str) -> bool { + s1.is_empty() || s2.is_empty() || s1 == s2 +} + +impl Default for UdtDataType { + fn default() -> Self { + Self::new() + } +} + +#[derive(Clone, Debug)] +#[cfg_attr(test, derive(PartialEq, Eq))] +pub(crate) enum MapDataType { + Untyped, + Key(Arc), + KeyAndValue(Arc, Arc), +} + +#[derive(Debug)] +pub(crate) struct CassColumnSpec { + pub(crate) name: String, + pub(crate) data_type: Arc, +} + +#[derive(Clone, Debug)] +#[cfg_attr(test, derive(PartialEq, Eq))] +pub(crate) enum CassDataTypeInner { + Value(CassValueType), + Udt(UdtDataType), + List { + // None stands for untyped list. + typ: Option>, + frozen: bool, + }, + Set { + // None stands for untyped set. + typ: Option>, + frozen: bool, + }, + Map { + typ: MapDataType, + frozen: bool, + }, + // Empty vector stands for untyped tuple. + Tuple(Vec>), + Custom(String), +} + +impl FFI for CassDataType { + type Origin = FromArc; +} + +impl CassDataTypeInner { + /// Checks for equality during typechecks. + /// + /// This takes into account the fact that tuples/collections may be untyped. + pub(crate) fn typecheck_equals(&self, other: &CassDataTypeInner) -> bool { + match self { + CassDataTypeInner::Value(t) => *t == other.get_value_type(), + CassDataTypeInner::Udt(udt) => match other { + CassDataTypeInner::Udt(other_udt) => udt.typecheck_equals(other_udt), + _ => false, + }, + CassDataTypeInner::List { typ, .. } | CassDataTypeInner::Set { typ, .. } => match other + { + CassDataTypeInner::List { typ: other_typ, .. } + | CassDataTypeInner::Set { typ: other_typ, .. } => { + // If one of them is list, and the other is set, fail the typecheck. + if self.get_value_type() != other.get_value_type() { + return false; + } + match (typ, other_typ) { + // One of them is untyped, skip the typecheck for subtype. + (None, _) | (_, None) => true, + (Some(typ), Some(other_typ)) => unsafe { + typ.get_unchecked() + .typecheck_equals(other_typ.get_unchecked()) + }, + } + } + _ => false, + }, + CassDataTypeInner::Map { typ: t, .. } => match other { + CassDataTypeInner::Map { typ: t_other, .. } => match (t, t_other) { + // See https://github.com/scylladb/cpp-driver/blob/master/src/data_type.hpp#L218 + // In cpp-driver the types are held in a vector. + // The logic is following: + + // If either of vectors is empty, skip the typecheck. + (MapDataType::Untyped, _) => true, + (_, MapDataType::Untyped) => true, + + // Otherwise, the vectors should have equal length and we perform the typecheck for subtypes. + (MapDataType::Key(k), MapDataType::Key(k_other)) => unsafe { + k.get_unchecked().typecheck_equals(k_other.get_unchecked()) + }, + ( + MapDataType::KeyAndValue(k, v), + MapDataType::KeyAndValue(k_other, v_other), + ) => unsafe { + k.get_unchecked().typecheck_equals(k_other.get_unchecked()) + && v.get_unchecked().typecheck_equals(v_other.get_unchecked()) + }, + _ => false, + }, + _ => false, + }, + CassDataTypeInner::Tuple(sub) => match other { + CassDataTypeInner::Tuple(other_sub) => { + // If either of tuples is untyped, skip the typecheck for subtypes. + if sub.is_empty() || other_sub.is_empty() { + return true; + } + + // If both are non-empty, check for subtypes equality. + if sub.len() != other_sub.len() { + return false; + } + sub.iter() + .zip(other_sub.iter()) + .all(|(typ, other_typ)| unsafe { + typ.get_unchecked() + .typecheck_equals(other_typ.get_unchecked()) + }) + } + _ => false, + }, + CassDataTypeInner::Custom(_) => { + unimplemented!("cpp-rs-driver does not support custom types!") + } + } + } +} + +#[derive(Debug)] +#[repr(transparent)] +pub struct CassDataType(UnsafeCell); + +/// PartialEq and Eq for test purposes. +#[cfg(test)] +impl PartialEq for CassDataType { + fn eq(&self, other: &Self) -> bool { + unsafe { self.get_unchecked() == other.get_unchecked() } + } +} +#[cfg(test)] +impl Eq for CassDataType {} + +unsafe impl Sync for CassDataType {} + +impl CassDataType { + pub(crate) unsafe fn get_unchecked(&self) -> &CassDataTypeInner { + unsafe { &*self.0.get() } + } + + #[allow(clippy::mut_from_ref)] + pub(crate) unsafe fn get_mut_unchecked(&self) -> &mut CassDataTypeInner { + unsafe { &mut *self.0.get() } + } + + pub(crate) const fn new(inner: CassDataTypeInner) -> CassDataType { + CassDataType(UnsafeCell::new(inner)) + } + + pub(crate) fn new_arced(inner: CassDataTypeInner) -> Arc { + Arc::new(CassDataType(UnsafeCell::new(inner))) + } +} + +fn native_type_to_cass_value_type(native_type: &NativeType) -> CassValueType { + use NativeType::*; + match native_type { + Ascii => CassValueType::CASS_VALUE_TYPE_ASCII, + Boolean => CassValueType::CASS_VALUE_TYPE_BOOLEAN, + Blob => CassValueType::CASS_VALUE_TYPE_BLOB, + Counter => CassValueType::CASS_VALUE_TYPE_COUNTER, + Date => CassValueType::CASS_VALUE_TYPE_DATE, + Decimal => CassValueType::CASS_VALUE_TYPE_DECIMAL, + Double => CassValueType::CASS_VALUE_TYPE_DOUBLE, + Duration => CassValueType::CASS_VALUE_TYPE_DURATION, + Float => CassValueType::CASS_VALUE_TYPE_FLOAT, + Int => CassValueType::CASS_VALUE_TYPE_INT, + BigInt => CassValueType::CASS_VALUE_TYPE_BIGINT, + // Rust Driver unifies both VARCHAR and TEXT into NativeType::Text. + // CPP Driver, in accordance to the CQL protocol, has separate types for VARCHAR and TEXT. + // Even worse, Rust Driver even does not handle CQL TEXT correctly! It errors out on TEXT + // type... + // As the DBs (Cassandra and ScyllaDB) seem to send the VARCHAR type in the protocol, + // we will assume that the NativeType::Text is actually a VARCHAR type. + Text => CassValueType::CASS_VALUE_TYPE_VARCHAR, + Timestamp => CassValueType::CASS_VALUE_TYPE_TIMESTAMP, + Inet => CassValueType::CASS_VALUE_TYPE_INET, + SmallInt => CassValueType::CASS_VALUE_TYPE_SMALL_INT, + TinyInt => CassValueType::CASS_VALUE_TYPE_TINY_INT, + Time => CassValueType::CASS_VALUE_TYPE_TIME, + Timeuuid => CassValueType::CASS_VALUE_TYPE_TIMEUUID, + Uuid => CassValueType::CASS_VALUE_TYPE_UUID, + Varint => CassValueType::CASS_VALUE_TYPE_VARINT, + + // NativeType is non_exhaustive + _ => CassValueType::CASS_VALUE_TYPE_UNKNOWN, + } +} + +impl CassDataTypeInner { + fn get_sub_data_type(&self, index: usize) -> Option<&Arc> { + match self { + CassDataTypeInner::Udt(udt_data_type) => { + udt_data_type.field_types.get(index).map(|(_, b)| b) + } + CassDataTypeInner::List { typ, .. } | CassDataTypeInner::Set { typ, .. } => { + if index > 0 { None } else { typ.as_ref() } + } + CassDataTypeInner::Map { + typ: MapDataType::Untyped, + .. + } => None, + CassDataTypeInner::Map { + typ: MapDataType::Key(k), + .. + } => (index == 0).then_some(k), + CassDataTypeInner::Map { + typ: MapDataType::KeyAndValue(k, v), + .. + } => match index { + 0 => Some(k), + 1 => Some(v), + _ => None, + }, + CassDataTypeInner::Tuple(v) => v.get(index), + _ => None, + } + } + + fn add_sub_data_type(&mut self, sub_type: Arc) -> Result<(), CassError> { + match self { + CassDataTypeInner::List { typ, .. } | CassDataTypeInner::Set { typ, .. } => match typ { + Some(_) => Err(CassError::CASS_ERROR_LIB_BAD_PARAMS), + None => { + *typ = Some(sub_type); + Ok(()) + } + }, + CassDataTypeInner::Map { + typ: MapDataType::KeyAndValue(_, _), + .. + } => Err(CassError::CASS_ERROR_LIB_BAD_PARAMS), + CassDataTypeInner::Map { + typ: MapDataType::Key(k), + frozen, + } => { + *self = CassDataTypeInner::Map { + typ: MapDataType::KeyAndValue(k.clone(), sub_type), + frozen: *frozen, + }; + Ok(()) + } + CassDataTypeInner::Map { + typ: MapDataType::Untyped, + frozen, + } => { + *self = CassDataTypeInner::Map { + typ: MapDataType::Key(sub_type), + frozen: *frozen, + }; + Ok(()) + } + CassDataTypeInner::Tuple(types) => { + types.push(sub_type); + Ok(()) + } + _ => Err(CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE), + } + } + + pub(crate) fn get_udt_type(&self) -> &UdtDataType { + match self { + CassDataTypeInner::Udt(udt) => udt, + _ => panic!("Can get UDT out of non-UDT data type"), + } + } + + pub(crate) fn get_value_type(&self) -> CassValueType { + match &self { + CassDataTypeInner::Value(value_data_type) => *value_data_type, + CassDataTypeInner::Udt { .. } => CassValueType::CASS_VALUE_TYPE_UDT, + CassDataTypeInner::List { .. } => CassValueType::CASS_VALUE_TYPE_LIST, + CassDataTypeInner::Set { .. } => CassValueType::CASS_VALUE_TYPE_SET, + CassDataTypeInner::Map { .. } => CassValueType::CASS_VALUE_TYPE_MAP, + CassDataTypeInner::Tuple(..) => CassValueType::CASS_VALUE_TYPE_TUPLE, + CassDataTypeInner::Custom(..) => CassValueType::CASS_VALUE_TYPE_CUSTOM, + } + } +} + +pub(crate) fn get_column_type(column_type: &ColumnType) -> CassDataType { + use CollectionType::*; + use ColumnType::*; + let inner = match column_type { + Native(native) => CassDataTypeInner::Value(native_type_to_cass_value_type(native)), + Collection { + typ: List(boxed_type), + frozen, + } => CassDataTypeInner::List { + typ: Some(Arc::new(get_column_type(boxed_type.as_ref()))), + frozen: *frozen, + }, + Collection { + typ: Map(key, value), + frozen, + } => CassDataTypeInner::Map { + typ: MapDataType::KeyAndValue( + Arc::new(get_column_type(key.as_ref())), + Arc::new(get_column_type(value.as_ref())), + ), + frozen: *frozen, + }, + Collection { + typ: Set(boxed_type), + frozen, + } => CassDataTypeInner::Set { + typ: Some(Arc::new(get_column_type(boxed_type.as_ref()))), + frozen: *frozen, + }, + UserDefinedType { definition, frozen } => CassDataTypeInner::Udt(UdtDataType { + field_types: definition + .field_types + .iter() + .map(|(name, col_type)| { + ( + name.clone().into_owned(), + Arc::new(get_column_type(col_type)), + ) + }) + .collect(), + keyspace: definition.keyspace.clone().into_owned(), + name: definition.name.clone().into_owned(), + frozen: *frozen, + }), + Tuple(v) => CassDataTypeInner::Tuple( + v.iter() + .map(|col_type| Arc::new(get_column_type(col_type))) + .collect(), + ), + + // ColumnType is non_exhaustive. + _ => CassDataTypeInner::Value(CassValueType::CASS_VALUE_TYPE_UNKNOWN), + }; + + CassDataType::new(inner) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn cass_data_type_new( + value_type: CassValueType, +) -> CassOwnedSharedPtr { + let inner = match value_type { + CassValueType::CASS_VALUE_TYPE_LIST => CassDataTypeInner::List { + typ: None, + frozen: false, + }, + CassValueType::CASS_VALUE_TYPE_SET => CassDataTypeInner::Set { + typ: None, + frozen: false, + }, + CassValueType::CASS_VALUE_TYPE_TUPLE => CassDataTypeInner::Tuple(Vec::new()), + CassValueType::CASS_VALUE_TYPE_MAP => CassDataTypeInner::Map { + typ: MapDataType::Untyped, + frozen: false, + }, + CassValueType::CASS_VALUE_TYPE_UDT => CassDataTypeInner::Udt(UdtDataType::new()), + CassValueType::CASS_VALUE_TYPE_CUSTOM => CassDataTypeInner::Custom("".to_string()), + CassValueType::CASS_VALUE_TYPE_UNKNOWN => return ArcFFI::null(), + t if t < CassValueType::CASS_VALUE_TYPE_LAST_ENTRY => CassDataTypeInner::Value(t), + _ => return ArcFFI::null(), + }; + ArcFFI::into_ptr(CassDataType::new_arced(inner)) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn cass_data_type_new_from_existing( + data_type: CassBorrowedSharedPtr, +) -> CassOwnedSharedPtr { + let Some(data_type) = ArcFFI::as_ref(data_type) else { + tracing::error!("Provided null data type pointer to cass_data_type_new_from_existing!"); + return ArcFFI::null(); + }; + + ArcFFI::into_ptr(CassDataType::new_arced( + unsafe { data_type.get_unchecked() }.clone(), + )) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn cass_data_type_new_tuple( + item_count: size_t, +) -> CassOwnedSharedPtr { + ArcFFI::into_ptr(CassDataType::new_arced(CassDataTypeInner::Tuple( + Vec::with_capacity(item_count as usize), + ))) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn cass_data_type_new_udt( + field_count: size_t, +) -> CassOwnedSharedPtr { + ArcFFI::into_ptr(CassDataType::new_arced(CassDataTypeInner::Udt( + UdtDataType::with_capacity(field_count as usize), + ))) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn cass_data_type_free(data_type: CassOwnedSharedPtr) { + ArcFFI::free(data_type); +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn cass_data_type_type( + data_type: CassBorrowedSharedPtr, +) -> CassValueType { + let Some(data_type) = ArcFFI::as_ref(data_type) else { + tracing::error!("Provided null data type pointer to cass_data_type_type!"); + return CassValueType::CASS_VALUE_TYPE_UNKNOWN; + }; + + unsafe { data_type.get_unchecked() }.get_value_type() +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn cass_data_type_is_frozen( + data_type: CassBorrowedSharedPtr, +) -> cass_bool_t { + let Some(data_type) = ArcFFI::as_ref(data_type) else { + tracing::error!("Provided null data type pointer to cass_data_type_is_frozen!"); + return cass_false; + }; + + let is_frozen = match unsafe { data_type.get_unchecked() } { + CassDataTypeInner::Udt(udt) => udt.frozen, + CassDataTypeInner::List { frozen, .. } => *frozen, + CassDataTypeInner::Set { frozen, .. } => *frozen, + CassDataTypeInner::Map { frozen, .. } => *frozen, + _ => false, + }; + + is_frozen as cass_bool_t +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn cass_data_type_type_name( + data_type: CassBorrowedSharedPtr, + type_name: *mut *const c_char, + type_name_length: *mut size_t, +) -> CassError { + let Some(data_type) = ArcFFI::as_ref(data_type) else { + tracing::error!("Provided null data type pointer to cass_data_type_type_name!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + + match unsafe { data_type.get_unchecked() } { + CassDataTypeInner::Udt(UdtDataType { name, .. }) => { + unsafe { write_str_to_c(name, type_name, type_name_length) }; + CassError::CASS_OK + } + _ => CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE, + } +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn cass_data_type_set_type_name( + data_type: CassBorrowedSharedPtr, + type_name: *const c_char, +) -> CassError { + unsafe { cass_data_type_set_type_name_n(data_type, type_name, strlen(type_name)) } +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn cass_data_type_set_type_name_n( + data_type_raw: CassBorrowedSharedPtr, + type_name: *const c_char, + type_name_length: size_t, +) -> CassError { + let Some(data_type) = ArcFFI::as_ref(data_type_raw) else { + tracing::error!("Provided null data type pointer to cass_data_type_set_type_name_n!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + + let type_name_string = unsafe { ptr_to_cstr_n(type_name, type_name_length) } + .unwrap() + .to_string(); + + match unsafe { data_type.get_mut_unchecked() } { + CassDataTypeInner::Udt(udt_data_type) => { + udt_data_type.name = type_name_string; + CassError::CASS_OK + } + _ => CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE, + } +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn cass_data_type_keyspace( + data_type: CassBorrowedSharedPtr, + keyspace: *mut *const c_char, + keyspace_length: *mut size_t, +) -> CassError { + let Some(data_type) = ArcFFI::as_ref(data_type) else { + tracing::error!("Provided null data type pointer to cass_data_type_keyspace!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + + match unsafe { data_type.get_unchecked() } { + CassDataTypeInner::Udt(UdtDataType { name, .. }) => { + unsafe { write_str_to_c(name, keyspace, keyspace_length) }; + CassError::CASS_OK + } + _ => CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE, + } +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn cass_data_type_set_keyspace( + data_type: CassBorrowedSharedPtr, + keyspace: *const c_char, +) -> CassError { + unsafe { cass_data_type_set_keyspace_n(data_type, keyspace, strlen(keyspace)) } +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn cass_data_type_set_keyspace_n( + data_type: CassBorrowedSharedPtr, + keyspace: *const c_char, + keyspace_length: size_t, +) -> CassError { + let Some(data_type) = ArcFFI::as_ref(data_type) else { + tracing::error!("Provided null data type pointer to cass_data_type_set_keyspace_n!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + + let keyspace_string = unsafe { ptr_to_cstr_n(keyspace, keyspace_length) } + .unwrap() + .to_string(); + + match unsafe { data_type.get_mut_unchecked() } { + CassDataTypeInner::Udt(udt_data_type) => { + udt_data_type.keyspace = keyspace_string; + CassError::CASS_OK + } + _ => CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE, + } +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn cass_data_type_class_name( + data_type: CassBorrowedSharedPtr, + class_name: *mut *const ::std::os::raw::c_char, + class_name_length: *mut size_t, +) -> CassError { + let Some(data_type) = ArcFFI::as_ref(data_type) else { + tracing::error!("Provided null data type pointer to cass_data_type_class_name!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + + match unsafe { data_type.get_unchecked() } { + CassDataTypeInner::Custom(name) => { + unsafe { write_str_to_c(name, class_name, class_name_length) }; + CassError::CASS_OK + } + _ => CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE, + } +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn cass_data_type_set_class_name( + data_type: CassBorrowedSharedPtr, + class_name: *const ::std::os::raw::c_char, +) -> CassError { + unsafe { cass_data_type_set_class_name_n(data_type, class_name, strlen(class_name)) } +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn cass_data_type_set_class_name_n( + data_type: CassBorrowedSharedPtr, + class_name: *const ::std::os::raw::c_char, + class_name_length: size_t, +) -> CassError { + let Some(data_type) = ArcFFI::as_ref(data_type) else { + tracing::error!("Provided null data type pointer to cass_data_type_set_class_name_n!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + + let class_string = unsafe { ptr_to_cstr_n(class_name, class_name_length) } + .unwrap() + .to_string(); + match unsafe { data_type.get_mut_unchecked() } { + CassDataTypeInner::Custom(name) => { + *name = class_string; + CassError::CASS_OK + } + _ => CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE, + } +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn cass_data_type_sub_type_count( + data_type: CassBorrowedSharedPtr, +) -> size_t { + let Some(data_type) = ArcFFI::as_ref(data_type) else { + tracing::error!("Provided null data type pointer to cass_data_type_sub_type_count!"); + return 0; + }; + + match unsafe { data_type.get_unchecked() } { + CassDataTypeInner::Value(..) => 0, + CassDataTypeInner::Udt(udt_data_type) => udt_data_type.field_types.len() as size_t, + CassDataTypeInner::List { typ, .. } | CassDataTypeInner::Set { typ, .. } => { + typ.is_some() as size_t + } + CassDataTypeInner::Map { typ, .. } => match typ { + MapDataType::Untyped => 0, + MapDataType::Key(_) => 1, + MapDataType::KeyAndValue(_, _) => 2, + }, + CassDataTypeInner::Tuple(v) => v.len() as size_t, + CassDataTypeInner::Custom(..) => 0, + } +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn cass_data_sub_type_count( + data_type: CassBorrowedSharedPtr, +) -> size_t { + unsafe { cass_data_type_sub_type_count(data_type) } +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn cass_data_type_sub_data_type( + data_type: CassBorrowedSharedPtr, + index: size_t, +) -> CassBorrowedSharedPtr { + let Some(data_type) = ArcFFI::as_ref(data_type) else { + tracing::error!("Provided null data type pointer to cass_data_type_sub_data_type!"); + return ArcFFI::null(); + }; + + let sub_type: Option<&Arc> = + unsafe { data_type.get_unchecked() }.get_sub_data_type(index as usize); + + match sub_type { + None => ArcFFI::null(), + // Semantic from cppdriver which also returns non-owning pointer + Some(arc) => ArcFFI::as_ptr(arc), + } +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn cass_data_type_sub_data_type_by_name( + data_type: CassBorrowedSharedPtr, + name: *const ::std::os::raw::c_char, +) -> CassBorrowedSharedPtr { + unsafe { cass_data_type_sub_data_type_by_name_n(data_type, name, strlen(name)) } +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn cass_data_type_sub_data_type_by_name_n( + data_type: CassBorrowedSharedPtr, + name: *const ::std::os::raw::c_char, + name_length: size_t, +) -> CassBorrowedSharedPtr { + let Some(data_type) = ArcFFI::as_ref(data_type) else { + tracing::error!( + "Provided null data type pointer to cass_data_type_sub_data_type_by_name_n!" + ); + return ArcFFI::null(); + }; + + let name_str = unsafe { ptr_to_cstr_n(name, name_length) }.unwrap(); + match unsafe { data_type.get_unchecked() } { + CassDataTypeInner::Udt(udt) => match udt.get_field_by_name(name_str) { + None => ArcFFI::null(), + Some(t) => ArcFFI::as_ptr(t), + }, + _ => ArcFFI::null(), + } +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn cass_data_type_sub_type_name( + data_type: CassBorrowedSharedPtr, + index: size_t, + name: *mut *const ::std::os::raw::c_char, + name_length: *mut size_t, +) -> CassError { + let Some(data_type) = ArcFFI::as_ref(data_type) else { + tracing::error!("Provided null data type pointer to cass_data_type_sub_type_name!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + + match unsafe { data_type.get_unchecked() } { + CassDataTypeInner::Udt(udt) => match udt.field_types.get(index as usize) { + None => CassError::CASS_ERROR_LIB_INDEX_OUT_OF_BOUNDS, + Some((field_name, _)) => { + unsafe { write_str_to_c(field_name, name, name_length) }; + CassError::CASS_OK + } + }, + _ => CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE, + } +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn cass_data_type_add_sub_type( + data_type: CassBorrowedSharedPtr, + sub_data_type: CassBorrowedSharedPtr, +) -> CassError { + let Some(data_type) = ArcFFI::as_ref(data_type) else { + tracing::error!("Provided null data type pointer to cass_data_type_add_sub_type!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + let Some(sub_data_type) = ArcFFI::cloned_from_ptr(sub_data_type) else { + tracing::error!("Provided null sub data type pointer to cass_data_type_add_sub_type!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + + match unsafe { data_type.get_mut_unchecked() }.add_sub_data_type(sub_data_type) { + Ok(()) => CassError::CASS_OK, + Err(e) => e, + } +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn cass_data_type_add_sub_type_by_name( + data_type: CassBorrowedSharedPtr, + name: *const c_char, + sub_data_type: CassBorrowedSharedPtr, +) -> CassError { + unsafe { cass_data_type_add_sub_type_by_name_n(data_type, name, strlen(name), sub_data_type) } +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn cass_data_type_add_sub_type_by_name_n( + data_type_raw: CassBorrowedSharedPtr, + name: *const c_char, + name_length: size_t, + sub_data_type_raw: CassBorrowedSharedPtr, +) -> CassError { + let Some(data_type) = ArcFFI::as_ref(data_type_raw) else { + tracing::error!( + "Provided null data type pointer to cass_data_type_add_sub_type_by_name_n!" + ); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + let Some(sub_data_type) = ArcFFI::cloned_from_ptr(sub_data_type_raw) else { + tracing::error!( + "Provided null sub data type pointer to cass_data_type_add_sub_type_by_name_n!" + ); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + + let name_string = unsafe { ptr_to_cstr_n(name, name_length) } + .unwrap() + .to_string(); + + match unsafe { data_type.get_mut_unchecked() } { + CassDataTypeInner::Udt(udt_data_type) => { + // The Cpp Driver does not check whether field_types size + // exceeded field_count. + udt_data_type.field_types.push((name_string, sub_data_type)); + CassError::CASS_OK + } + _ => CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE, + } +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn cass_data_type_add_sub_value_type( + data_type: CassBorrowedSharedPtr, + sub_value_type: CassValueType, +) -> CassError { + let sub_data_type = CassDataType::new_arced(CassDataTypeInner::Value(sub_value_type)); + unsafe { cass_data_type_add_sub_type(data_type, ArcFFI::as_ptr(&sub_data_type)) } +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn cass_data_type_add_sub_value_type_by_name( + data_type: CassBorrowedSharedPtr, + name: *const c_char, + sub_value_type: CassValueType, +) -> CassError { + let sub_data_type = CassDataType::new_arced(CassDataTypeInner::Value(sub_value_type)); + unsafe { cass_data_type_add_sub_type_by_name(data_type, name, ArcFFI::as_ptr(&sub_data_type)) } +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn cass_data_type_add_sub_value_type_by_name_n( + data_type: CassBorrowedSharedPtr, + name: *const c_char, + name_length: size_t, + sub_value_type: CassValueType, +) -> CassError { + let sub_data_type = CassDataType::new_arced(CassDataTypeInner::Value(sub_value_type)); + unsafe { + cass_data_type_add_sub_type_by_name_n( + data_type, + name, + name_length, + ArcFFI::as_ptr(&sub_data_type), + ) + } +} + +pub(crate) fn make_batch_type(type_: CassBatchType) -> Result { + match type_ { + CassBatchType::CASS_BATCH_TYPE_LOGGED => Ok(BatchType::Logged), + CassBatchType::CASS_BATCH_TYPE_UNLOGGED => Ok(BatchType::Unlogged), + CassBatchType::CASS_BATCH_TYPE_COUNTER => Ok(BatchType::Counter), + _ => Err(()), + } +} diff --git a/scylla-rust-wrapper/src/cql_types/mod.rs b/scylla-rust-wrapper/src/cql_types/mod.rs index 72e5a469..0d3f8439 100644 --- a/scylla-rust-wrapper/src/cql_types/mod.rs +++ b/scylla-rust-wrapper/src/cql_types/mod.rs @@ -1,4 +1,5 @@ pub(crate) mod collection; +pub(crate) mod data_type; pub(crate) mod date_time; pub(crate) mod inet; pub(crate) mod tuple; diff --git a/scylla-rust-wrapper/src/cql_types/tuple.rs b/scylla-rust-wrapper/src/cql_types/tuple.rs index f11a435b..00c00be4 100644 --- a/scylla-rust-wrapper/src/cql_types/tuple.rs +++ b/scylla-rust-wrapper/src/cql_types/tuple.rs @@ -1,7 +1,6 @@ use crate::argconv::*; use crate::cass_error::CassError; -use crate::cass_types::CassDataType; -use crate::cass_types::CassDataTypeInner; +use crate::cql_types::data_type::{CassDataType, CassDataTypeInner}; use crate::cql_types::value; use crate::cql_types::value::CassCqlValue; use crate::types::*; @@ -136,8 +135,11 @@ make_binders!(user_type, cass_tuple_set_user_type); #[cfg(test)] mod tests { - use crate::cass_types::{ - CassValueType, cass_data_type_add_sub_type, cass_data_type_free, cass_data_type_new, + use crate::{ + cass_types::CassValueType, + cql_types::data_type::{ + cass_data_type_add_sub_type, cass_data_type_free, cass_data_type_new, + }, }; use super::{cass_tuple_data_type, cass_tuple_new}; diff --git a/scylla-rust-wrapper/src/cql_types/user_type.rs b/scylla-rust-wrapper/src/cql_types/user_type.rs index f2fd60fe..8ab8e79a 100644 --- a/scylla-rust-wrapper/src/cql_types/user_type.rs +++ b/scylla-rust-wrapper/src/cql_types/user_type.rs @@ -1,6 +1,6 @@ use crate::argconv::*; use crate::cass_error::CassError; -use crate::cass_types::{CassDataType, CassDataTypeInner}; +use crate::cql_types::data_type::{CassDataType, CassDataTypeInner}; use crate::cql_types::value::{self, CassCqlValue}; use crate::types::*; use std::os::raw::c_char; diff --git a/scylla-rust-wrapper/src/cql_types/value.rs b/scylla-rust-wrapper/src/cql_types/value.rs index ef480132..3cd2ee62 100644 --- a/scylla-rust-wrapper/src/cql_types/value.rs +++ b/scylla-rust-wrapper/src/cql_types/value.rs @@ -11,7 +11,8 @@ use scylla::serialize::writers::{CellWriter, WrittenCellProof}; use scylla::value::{CqlDate, CqlDecimal, CqlDuration}; use uuid::Uuid; -use crate::cass_types::{CassDataType, CassValueType}; +use crate::cass_types::CassValueType; +use crate::cql_types::data_type::CassDataType; /// A narrower version of rust driver's CqlValue. /// @@ -416,8 +417,11 @@ mod tests { use scylla::value::{CqlDate, CqlDecimal, CqlDuration}; use crate::{ - cass_types::{CassDataType, CassDataTypeInner, CassValueType, MapDataType, UdtDataType}, - cql_types::value::{CassCqlValue, is_type_compatible}, + cass_types::CassValueType, + cql_types::{ + data_type::{CassDataType, CassDataTypeInner, MapDataType, UdtDataType}, + value::{CassCqlValue, is_type_compatible}, + }, }; fn all_value_data_types() -> Vec { diff --git a/scylla-rust-wrapper/src/iterator.rs b/scylla-rust-wrapper/src/iterator.rs index e19bb807..d840b756 100644 --- a/scylla-rust-wrapper/src/iterator.rs +++ b/scylla-rust-wrapper/src/iterator.rs @@ -6,7 +6,8 @@ use crate::argconv::{ CassOwnedExclusivePtr, FFI, FromBox, RefFFI, write_str_to_c, }; use crate::cass_error::CassError; -use crate::cass_types::{CassDataType, CassDataTypeInner, CassValueType, MapDataType}; +use crate::cass_types::CassValueType; +use crate::cql_types::data_type::{CassDataType, CassDataTypeInner, MapDataType}; use crate::metadata::{ CassColumnMeta, CassKeyspaceMeta, CassMaterializedViewMeta, CassSchemaMeta, CassTableMeta, }; diff --git a/scylla-rust-wrapper/src/metadata.rs b/scylla-rust-wrapper/src/metadata.rs index ca68a71e..667434f4 100644 --- a/scylla-rust-wrapper/src/metadata.rs +++ b/scylla-rust-wrapper/src/metadata.rs @@ -1,7 +1,6 @@ use crate::argconv::*; use crate::cass_column_types::CassColumnType; -use crate::cass_types::CassDataType; -use crate::cass_types::get_column_type; +use crate::cql_types::data_type::{CassDataType, get_column_type}; use crate::types::*; use scylla::cluster::metadata::{ColumnKind, Table}; use std::collections::HashMap; diff --git a/scylla-rust-wrapper/src/prepared.rs b/scylla-rust-wrapper/src/prepared.rs index 06ef9e0d..8d5a70a5 100644 --- a/scylla-rust-wrapper/src/prepared.rs +++ b/scylla-rust-wrapper/src/prepared.rs @@ -4,7 +4,7 @@ use std::{os::raw::c_char, sync::Arc}; use crate::{ argconv::*, cass_error::CassError, - cass_types::{CassDataType, get_column_type}, + cql_types::data_type::{CassDataType, get_column_type}, query_result::CassResultMetadata, statement::{BoundPreparedStatement, CassStatement}, types::size_t, diff --git a/scylla-rust-wrapper/src/query_result.rs b/scylla-rust-wrapper/src/query_result.rs index 89965526..34321023 100644 --- a/scylla-rust-wrapper/src/query_result.rs +++ b/scylla-rust-wrapper/src/query_result.rs @@ -2,7 +2,7 @@ use crate::argconv::*; use crate::cass_error::CassErrorResult; use crate::cass_error::{CassError, ToCassError}; pub use crate::cass_types::CassValueType; -use crate::cass_types::{ +use crate::cql_types::data_type::{ CassColumnSpec, CassDataType, CassDataTypeInner, MapDataType, cass_data_type_type, get_column_type, }; @@ -1178,7 +1178,7 @@ mod tests { use scylla::response::query_result::ColumnSpecs; use crate::argconv::{CConst, CassBorrowedSharedPtr, ptr_to_cstr_n}; - use crate::cass_types::{CassDataType, CassDataTypeInner}; + use crate::cql_types::data_type::{CassDataType, CassDataTypeInner}; use crate::{ argconv::{ArcFFI, RefFFI}, cass_error::CassError, diff --git a/scylla-rust-wrapper/src/ser_de_tests.rs b/scylla-rust-wrapper/src/ser_de_tests.rs index e2f0a3e5..7a67398a 100644 --- a/scylla-rust-wrapper/src/ser_de_tests.rs +++ b/scylla-rust-wrapper/src/ser_de_tests.rs @@ -26,7 +26,7 @@ use crate::argconv::{ CConst, CMut, CassBorrowedExclusivePtr, CassBorrowedSharedPtr, CassOwnedExclusivePtr, RefFFI, }; use crate::cass_error::CassError; -use crate::cass_types::get_column_type; +use crate::cql_types::data_type::get_column_type; use crate::cql_types::inet::CassInet; use crate::cql_types::uuid::CassUuid; use crate::iterator::{ diff --git a/scylla-rust-wrapper/src/session.rs b/scylla-rust-wrapper/src/session.rs index 50112fa9..6728d7b5 100644 --- a/scylla-rust-wrapper/src/session.rs +++ b/scylla-rust-wrapper/src/session.rs @@ -2,8 +2,8 @@ use crate::argconv::*; use crate::batch::CassBatch; use crate::cass_error::*; use crate::cass_metrics_types::CassMetrics; -use crate::cass_types::get_column_type; use crate::cluster::CassCluster; +use crate::cql_types::data_type::get_column_type; use crate::cql_types::uuid::CassUuid; use crate::exec_profile::{CassExecProfile, ExecProfileName, PerStatementExecProfile}; use crate::future::{CassFuture, CassFutureResult, CassResultValue}; From 340ae56c2235a82f449dc3560d30216720d1074e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Wed, 24 Sep 2025 10:51:40 +0200 Subject: [PATCH 04/12] tree: decompose `cass_types` module The 3 remaining items there are just public re-exports. They are now moved to other (reasonable) modules, and `cass_types` module is deleted. --- scylla-rust-wrapper/src/api.rs | 5 ++++- scylla-rust-wrapper/src/batch.rs | 4 ++-- scylla-rust-wrapper/src/cass_error.rs | 2 +- scylla-rust-wrapper/src/cass_types.rs | 3 --- scylla-rust-wrapper/src/cluster.rs | 2 +- scylla-rust-wrapper/src/config_value.rs | 2 +- scylla-rust-wrapper/src/cql_types/collection.rs | 16 +++++++++------- scylla-rust-wrapper/src/cql_types/data_type.rs | 2 +- scylla-rust-wrapper/src/cql_types/mod.rs | 3 +++ scylla-rust-wrapper/src/cql_types/tuple.rs | 2 +- scylla-rust-wrapper/src/cql_types/value.rs | 4 ++-- scylla-rust-wrapper/src/exec_profile.rs | 6 +++--- scylla-rust-wrapper/src/iterator.rs | 2 +- scylla-rust-wrapper/src/lib.rs | 1 - scylla-rust-wrapper/src/misc.rs | 3 ++- scylla-rust-wrapper/src/query_result.rs | 4 ++-- scylla-rust-wrapper/src/session.rs | 4 ++-- scylla-rust-wrapper/src/statement.rs | 2 +- 18 files changed, 36 insertions(+), 31 deletions(-) delete mode 100644 scylla-rust-wrapper/src/cass_types.rs diff --git a/scylla-rust-wrapper/src/api.rs b/scylla-rust-wrapper/src/api.rs index eddbf51f..405ad219 100644 --- a/scylla-rust-wrapper/src/api.rs +++ b/scylla-rust-wrapper/src/api.rs @@ -895,8 +895,11 @@ pub mod custom_payload { pub mod consistency { // Disabling rustfmt to have one item per line for better readability. #[rustfmt::skip] + pub use crate::cql_types::{ + CassConsistency + }; + #[rustfmt::skip] pub use crate::misc::{ - CassConsistency, cass_consistency_string }; } diff --git a/scylla-rust-wrapper/src/batch.rs b/scylla-rust-wrapper/src/batch.rs index 33244d2e..21c8f357 100644 --- a/scylla-rust-wrapper/src/batch.rs +++ b/scylla-rust-wrapper/src/batch.rs @@ -2,10 +2,10 @@ use crate::argconv::{ ArcFFI, BoxFFI, CMut, CassBorrowedExclusivePtr, CassBorrowedSharedPtr, CassOwnedExclusivePtr, FFI, FromBox, }; +pub use crate::cass_batch_types::CassBatchType; use crate::cass_error::CassError; -pub use crate::cass_types::CassBatchType; -use crate::cass_types::CassConsistency; use crate::config_value::{MaybeUnsetConfig, RequestTimeout}; +use crate::cql_types::CassConsistency; use crate::cql_types::data_type::make_batch_type; use crate::cql_types::value::CassCqlValue; use crate::exec_profile::PerStatementExecProfile; diff --git a/scylla-rust-wrapper/src/cass_error.rs b/scylla-rust-wrapper/src/cass_error.rs index fb05a183..d697b712 100644 --- a/scylla-rust-wrapper/src/cass_error.rs +++ b/scylla-rust-wrapper/src/cass_error.rs @@ -1,5 +1,5 @@ use crate::argconv::*; -use crate::cass_types::CassConsistency; +use crate::cql_types::CassConsistency; use crate::misc::CassWriteType; use crate::types::*; use libc::c_char; diff --git a/scylla-rust-wrapper/src/cass_types.rs b/scylla-rust-wrapper/src/cass_types.rs deleted file mode 100644 index 4ba981b2..00000000 --- a/scylla-rust-wrapper/src/cass_types.rs +++ /dev/null @@ -1,3 +0,0 @@ -pub use crate::cass_batch_types::CassBatchType; -pub use crate::cass_consistency_types::CassConsistency; -pub use crate::cass_data_types::CassValueType; diff --git a/scylla-rust-wrapper/src/cluster.rs b/scylla-rust-wrapper/src/cluster.rs index e1d19f9c..41220dde 100644 --- a/scylla-rust-wrapper/src/cluster.rs +++ b/scylla-rust-wrapper/src/cluster.rs @@ -1,7 +1,7 @@ use crate::argconv::*; use crate::cass_error::CassError; -use crate::cass_types::CassConsistency; use crate::config_value::MaybeUnsetConfig; +use crate::cql_types::CassConsistency; use crate::cql_types::uuid::CassUuid; use crate::exec_profile::{CassExecProfile, ExecProfileName, exec_profile_builder_modify}; use crate::load_balancing::{ diff --git a/scylla-rust-wrapper/src/config_value.rs b/scylla-rust-wrapper/src/config_value.rs index 73c30b61..3549d7cf 100644 --- a/scylla-rust-wrapper/src/config_value.rs +++ b/scylla-rust-wrapper/src/config_value.rs @@ -5,7 +5,7 @@ use scylla::{ statement::{Consistency, SerialConsistency}, }; -use crate::{cass_types::CassConsistency, retry_policy::CassRetryPolicy, types::cass_uint64_t}; +use crate::{cql_types::CassConsistency, retry_policy::CassRetryPolicy, types::cass_uint64_t}; /// Represents a configuration value that may or may not be set. /// If a configuration value is unset, it means that the default value diff --git a/scylla-rust-wrapper/src/cql_types/collection.rs b/scylla-rust-wrapper/src/cql_types/collection.rs index 562e9df8..3d462c56 100644 --- a/scylla-rust-wrapper/src/cql_types/collection.rs +++ b/scylla-rust-wrapper/src/cql_types/collection.rs @@ -281,13 +281,15 @@ mod tests { use crate::{ argconv::ArcFFI, cass_error::CassError, - cass_types::CassValueType, - cql_types::collection::{ - cass_collection_append_double, cass_collection_append_float, cass_collection_free, - }, - cql_types::data_type::{ - CassDataType, CassDataTypeInner, MapDataType, cass_data_type_add_sub_type, - cass_data_type_free, cass_data_type_new, + cql_types::{ + CassValueType, + collection::{ + cass_collection_append_double, cass_collection_append_float, cass_collection_free, + }, + data_type::{ + CassDataType, CassDataTypeInner, MapDataType, cass_data_type_add_sub_type, + cass_data_type_free, cass_data_type_new, + }, }, testing::assert_cass_error_eq, }; diff --git a/scylla-rust-wrapper/src/cql_types/data_type.rs b/scylla-rust-wrapper/src/cql_types/data_type.rs index bff491d5..5cb9885b 100644 --- a/scylla-rust-wrapper/src/cql_types/data_type.rs +++ b/scylla-rust-wrapper/src/cql_types/data_type.rs @@ -1,7 +1,7 @@ use crate::argconv::*; use crate::batch::CassBatchType; use crate::cass_error::CassError; -use crate::cass_types::CassValueType; +use crate::cql_types::CassValueType; use crate::types::*; use scylla::cluster::metadata::{CollectionType, NativeType}; use scylla::frame::response::result::ColumnType; diff --git a/scylla-rust-wrapper/src/cql_types/mod.rs b/scylla-rust-wrapper/src/cql_types/mod.rs index 0d3f8439..47ac0cd9 100644 --- a/scylla-rust-wrapper/src/cql_types/mod.rs +++ b/scylla-rust-wrapper/src/cql_types/mod.rs @@ -6,3 +6,6 @@ pub(crate) mod tuple; pub(crate) mod user_type; pub(crate) mod uuid; pub(crate) mod value; + +pub use crate::cass_consistency_types::CassConsistency; +pub use crate::cass_data_types::CassValueType; diff --git a/scylla-rust-wrapper/src/cql_types/tuple.rs b/scylla-rust-wrapper/src/cql_types/tuple.rs index 00c00be4..610d7d06 100644 --- a/scylla-rust-wrapper/src/cql_types/tuple.rs +++ b/scylla-rust-wrapper/src/cql_types/tuple.rs @@ -136,7 +136,7 @@ make_binders!(user_type, cass_tuple_set_user_type); #[cfg(test)] mod tests { use crate::{ - cass_types::CassValueType, + cql_types::CassValueType, cql_types::data_type::{ cass_data_type_add_sub_type, cass_data_type_free, cass_data_type_new, }, diff --git a/scylla-rust-wrapper/src/cql_types/value.rs b/scylla-rust-wrapper/src/cql_types/value.rs index 3cd2ee62..3a8bfe46 100644 --- a/scylla-rust-wrapper/src/cql_types/value.rs +++ b/scylla-rust-wrapper/src/cql_types/value.rs @@ -11,7 +11,7 @@ use scylla::serialize::writers::{CellWriter, WrittenCellProof}; use scylla::value::{CqlDate, CqlDecimal, CqlDuration}; use uuid::Uuid; -use crate::cass_types::CassValueType; +use crate::cql_types::CassValueType; use crate::cql_types::data_type::CassDataType; /// A narrower version of rust driver's CqlValue. @@ -417,7 +417,7 @@ mod tests { use scylla::value::{CqlDate, CqlDecimal, CqlDuration}; use crate::{ - cass_types::CassValueType, + cql_types::CassValueType, cql_types::{ data_type::{CassDataType, CassDataTypeInner, MapDataType, UdtDataType}, value::{CassCqlValue, is_type_compatible}, diff --git a/scylla-rust-wrapper/src/exec_profile.rs b/scylla-rust-wrapper/src/exec_profile.rs index 337704f6..36ca67d6 100644 --- a/scylla-rust-wrapper/src/exec_profile.rs +++ b/scylla-rust-wrapper/src/exec_profile.rs @@ -20,11 +20,11 @@ use crate::argconv::{ }; use crate::batch::CassBatch; use crate::cass_error::CassError; -use crate::cass_types::CassConsistency; use crate::cluster::{ set_load_balance_dc_aware_n, set_load_balance_rack_aware_n, update_comma_delimited_list, }; use crate::config_value::{MaybeUnsetConfig, RequestTimeout}; +use crate::cql_types::CassConsistency; use crate::load_balancing::{LoadBalancingConfig, LoadBalancingKind}; use crate::retry_policy::CassRetryPolicy; use crate::session::CassConnectedSession; @@ -869,14 +869,14 @@ mod tests { use super::*; use crate::argconv::CassPtr; + use crate::cql_types::CassConsistency; use crate::retry_policy::{ cass_retry_policy_downgrading_consistency_new, cass_retry_policy_free, }; use crate::testing::{assert_cass_error_eq, setup_tracing}; use crate::{ argconv::{make_c_str, str_to_c_str_n}, - batch::{cass_batch_add_statement, cass_batch_free, cass_batch_new}, - cass_types::CassBatchType, + batch::{CassBatchType, cass_batch_add_statement, cass_batch_free, cass_batch_new}, statement::{cass_statement_free, cass_statement_new}, }; diff --git a/scylla-rust-wrapper/src/iterator.rs b/scylla-rust-wrapper/src/iterator.rs index d840b756..636c8302 100644 --- a/scylla-rust-wrapper/src/iterator.rs +++ b/scylla-rust-wrapper/src/iterator.rs @@ -6,7 +6,7 @@ use crate::argconv::{ CassOwnedExclusivePtr, FFI, FromBox, RefFFI, write_str_to_c, }; use crate::cass_error::CassError; -use crate::cass_types::CassValueType; +use crate::cql_types::CassValueType; use crate::cql_types::data_type::{CassDataType, CassDataTypeInner, MapDataType}; use crate::metadata::{ CassColumnMeta, CassKeyspaceMeta, CassMaterializedViewMeta, CassSchemaMeta, CassTableMeta, diff --git a/scylla-rust-wrapper/src/lib.rs b/scylla-rust-wrapper/src/lib.rs index 593c6d71..0d8424a8 100644 --- a/scylla-rust-wrapper/src/lib.rs +++ b/scylla-rust-wrapper/src/lib.rs @@ -12,7 +12,6 @@ pub mod api; pub mod argconv; pub(crate) mod batch; pub(crate) mod cass_error; -pub(crate) mod cass_types; pub(crate) mod cluster; pub(crate) mod config_value; pub(crate) mod cql_types; diff --git a/scylla-rust-wrapper/src/misc.rs b/scylla-rust-wrapper/src/misc.rs index ef61f0ef..d7496ca4 100644 --- a/scylla-rust-wrapper/src/misc.rs +++ b/scylla-rust-wrapper/src/misc.rs @@ -1,6 +1,7 @@ use std::ffi::{CStr, c_char}; -pub use crate::{cass_error_types::CassWriteType, cass_types::CassConsistency}; +pub use crate::cass_error_types::CassWriteType; +use crate::cql_types::CassConsistency; impl CassConsistency { pub(crate) fn as_cstr(&self) -> &'static CStr { diff --git a/scylla-rust-wrapper/src/query_result.rs b/scylla-rust-wrapper/src/query_result.rs index 34321023..7bce5280 100644 --- a/scylla-rust-wrapper/src/query_result.rs +++ b/scylla-rust-wrapper/src/query_result.rs @@ -1,7 +1,7 @@ use crate::argconv::*; use crate::cass_error::CassErrorResult; use crate::cass_error::{CassError, ToCassError}; -pub use crate::cass_types::CassValueType; +pub use crate::cql_types::CassValueType; use crate::cql_types::data_type::{ CassColumnSpec, CassDataType, CassDataTypeInner, MapDataType, cass_data_type_type, get_column_type, @@ -1182,7 +1182,7 @@ mod tests { use crate::{ argconv::{ArcFFI, RefFFI}, cass_error::CassError, - cass_types::CassValueType, + cql_types::CassValueType, query_result::{ cass_result_column_data_type, cass_result_column_name, cass_result_first_row, size_t, }, diff --git a/scylla-rust-wrapper/src/session.rs b/scylla-rust-wrapper/src/session.rs index 6728d7b5..54aa50ce 100644 --- a/scylla-rust-wrapper/src/session.rs +++ b/scylla-rust-wrapper/src/session.rs @@ -886,9 +886,9 @@ mod tests { use crate::{ argconv::make_c_str, batch::{ - cass_batch_add_statement, cass_batch_free, cass_batch_new, cass_batch_set_retry_policy, + CassBatchType, cass_batch_add_statement, cass_batch_free, cass_batch_new, + cass_batch_set_retry_policy, }, - cass_types::CassBatchType, cluster::{ cass_cluster_free, cass_cluster_new, cass_cluster_set_client_id, cass_cluster_set_contact_points_n, cass_cluster_set_execution_profile, diff --git a/scylla-rust-wrapper/src/statement.rs b/scylla-rust-wrapper/src/statement.rs index 66687170..020c258e 100644 --- a/scylla-rust-wrapper/src/statement.rs +++ b/scylla-rust-wrapper/src/statement.rs @@ -1,7 +1,7 @@ use crate::argconv::*; use crate::cass_error::CassError; -use crate::cass_types::CassConsistency; use crate::config_value::{MaybeUnsetConfig, RequestTimeout}; +use crate::cql_types::CassConsistency; use crate::cql_types::inet::CassInet; use crate::cql_types::value::{self, CassCqlValue}; use crate::exec_profile::PerStatementExecProfile; From 926f2814018b82ff5257dbdd93e1cf74f197caa3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Wed, 24 Sep 2025 11:00:43 +0200 Subject: [PATCH 05/12] tree: move `misc` mod contents to `cql_types` mod These entities are all related to CQL protocol types, so it makes sense to put them there. This annihilates the `misc` module, which is good because the name was not very descriptive. --- scylla-rust-wrapper/src/api.rs | 9 ++--- scylla-rust-wrapper/src/cass_error.rs | 3 +- scylla-rust-wrapper/src/cql_types/mod.rs | 49 +++++++++++++++++++++++ scylla-rust-wrapper/src/lib.rs | 1 - scylla-rust-wrapper/src/misc.rs | 50 ------------------------ 5 files changed, 53 insertions(+), 59 deletions(-) delete mode 100644 scylla-rust-wrapper/src/misc.rs diff --git a/scylla-rust-wrapper/src/api.rs b/scylla-rust-wrapper/src/api.rs index 405ad219..c2d494a2 100644 --- a/scylla-rust-wrapper/src/api.rs +++ b/scylla-rust-wrapper/src/api.rs @@ -896,18 +896,15 @@ pub mod consistency { // Disabling rustfmt to have one item per line for better readability. #[rustfmt::skip] pub use crate::cql_types::{ - CassConsistency - }; - #[rustfmt::skip] - pub use crate::misc::{ - cass_consistency_string + CassConsistency, + cass_consistency_string, }; } pub mod write_type { // Disabling rustfmt to have one item per line for better readability. #[rustfmt::skip] - pub use crate::misc::{ + pub use crate::cql_types::{ CassWriteType, cass_write_type_string }; diff --git a/scylla-rust-wrapper/src/cass_error.rs b/scylla-rust-wrapper/src/cass_error.rs index d697b712..904341d3 100644 --- a/scylla-rust-wrapper/src/cass_error.rs +++ b/scylla-rust-wrapper/src/cass_error.rs @@ -1,6 +1,5 @@ use crate::argconv::*; -use crate::cql_types::CassConsistency; -use crate::misc::CassWriteType; +use crate::cql_types::{CassConsistency, CassWriteType}; use crate::types::*; use libc::c_char; use scylla::deserialize::DeserializationError; diff --git a/scylla-rust-wrapper/src/cql_types/mod.rs b/scylla-rust-wrapper/src/cql_types/mod.rs index 47ac0cd9..36732950 100644 --- a/scylla-rust-wrapper/src/cql_types/mod.rs +++ b/scylla-rust-wrapper/src/cql_types/mod.rs @@ -9,3 +9,52 @@ pub(crate) mod value; pub use crate::cass_consistency_types::CassConsistency; pub use crate::cass_data_types::CassValueType; +pub use crate::cass_error_types::CassWriteType; + +use std::ffi::{CStr, c_char}; + +impl CassConsistency { + pub(crate) fn as_cstr(&self) -> &'static CStr { + match *self { + Self::CASS_CONSISTENCY_UNKNOWN => c"UNKNOWN", + Self::CASS_CONSISTENCY_ANY => c"ANY", + Self::CASS_CONSISTENCY_ONE => c"ONE", + Self::CASS_CONSISTENCY_TWO => c"TWO", + Self::CASS_CONSISTENCY_THREE => c"THREE", + Self::CASS_CONSISTENCY_QUORUM => c"QUORUM", + Self::CASS_CONSISTENCY_ALL => c"ALL", + Self::CASS_CONSISTENCY_LOCAL_QUORUM => c"LOCAL_QUORUM", + Self::CASS_CONSISTENCY_EACH_QUORUM => c"EACH_QUORUM", + Self::CASS_CONSISTENCY_SERIAL => c"SERIAL", + Self::CASS_CONSISTENCY_LOCAL_SERIAL => c"LOCAL_SERIAL", + Self::CASS_CONSISTENCY_LOCAL_ONE => c"LOCAL_ONE", + _ => c"", + } + } +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn cass_consistency_string(consistency: CassConsistency) -> *const c_char { + consistency.as_cstr().as_ptr() as *const c_char +} + +impl CassWriteType { + pub(crate) fn as_cstr(&self) -> &'static CStr { + match *self { + Self::CASS_WRITE_TYPE_SIMPLE => c"SIMPLE", + Self::CASS_WRITE_TYPE_BATCH => c"BATCH", + Self::CASS_WRITE_TYPE_UNLOGGED_BATCH => c"UNLOGGED_BATCH", + Self::CASS_WRITE_TYPE_COUNTER => c"COUNTER", + Self::CASS_WRITE_TYPE_BATCH_LOG => c"BATCH_LOG", + Self::CASS_WRITE_TYPE_CAS => c"CAS", + Self::CASS_WRITE_TYPE_VIEW => c"VIEW", + Self::CASS_WRITE_TYPE_CDC => c"CDC", + _ => c"", + } + } +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn cass_write_type_string(write_type: CassWriteType) -> *const c_char { + write_type.as_cstr().as_ptr() as *const c_char +} diff --git a/scylla-rust-wrapper/src/lib.rs b/scylla-rust-wrapper/src/lib.rs index 0d8424a8..1cb9e77e 100644 --- a/scylla-rust-wrapper/src/lib.rs +++ b/scylla-rust-wrapper/src/lib.rs @@ -23,7 +23,6 @@ pub(crate) mod iterator; mod load_balancing; mod logging; pub(crate) mod metadata; -pub(crate) mod misc; pub(crate) mod prepared; pub(crate) mod query_result; pub(crate) mod retry_policy; diff --git a/scylla-rust-wrapper/src/misc.rs b/scylla-rust-wrapper/src/misc.rs deleted file mode 100644 index d7496ca4..00000000 --- a/scylla-rust-wrapper/src/misc.rs +++ /dev/null @@ -1,50 +0,0 @@ -use std::ffi::{CStr, c_char}; - -pub use crate::cass_error_types::CassWriteType; -use crate::cql_types::CassConsistency; - -impl CassConsistency { - pub(crate) fn as_cstr(&self) -> &'static CStr { - match *self { - Self::CASS_CONSISTENCY_UNKNOWN => c"UNKNOWN", - Self::CASS_CONSISTENCY_ANY => c"ANY", - Self::CASS_CONSISTENCY_ONE => c"ONE", - Self::CASS_CONSISTENCY_TWO => c"TWO", - Self::CASS_CONSISTENCY_THREE => c"THREE", - Self::CASS_CONSISTENCY_QUORUM => c"QUORUM", - Self::CASS_CONSISTENCY_ALL => c"ALL", - Self::CASS_CONSISTENCY_LOCAL_QUORUM => c"LOCAL_QUORUM", - Self::CASS_CONSISTENCY_EACH_QUORUM => c"EACH_QUORUM", - Self::CASS_CONSISTENCY_SERIAL => c"SERIAL", - Self::CASS_CONSISTENCY_LOCAL_SERIAL => c"LOCAL_SERIAL", - Self::CASS_CONSISTENCY_LOCAL_ONE => c"LOCAL_ONE", - _ => c"", - } - } -} - -#[unsafe(no_mangle)] -pub unsafe extern "C" fn cass_consistency_string(consistency: CassConsistency) -> *const c_char { - consistency.as_cstr().as_ptr() as *const c_char -} - -impl CassWriteType { - pub(crate) fn as_cstr(&self) -> &'static CStr { - match *self { - Self::CASS_WRITE_TYPE_SIMPLE => c"SIMPLE", - Self::CASS_WRITE_TYPE_BATCH => c"BATCH", - Self::CASS_WRITE_TYPE_UNLOGGED_BATCH => c"UNLOGGED_BATCH", - Self::CASS_WRITE_TYPE_COUNTER => c"COUNTER", - Self::CASS_WRITE_TYPE_BATCH_LOG => c"BATCH_LOG", - Self::CASS_WRITE_TYPE_CAS => c"CAS", - Self::CASS_WRITE_TYPE_VIEW => c"VIEW", - Self::CASS_WRITE_TYPE_CDC => c"CDC", - _ => c"", - } - } -} - -#[unsafe(no_mangle)] -pub unsafe extern "C" fn cass_write_type_string(write_type: CassWriteType) -> *const c_char { - write_type.as_cstr().as_ptr() as *const c_char -} From 2213a54a60bd78f4308f5f979dbde3f2b275618b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Wed, 24 Sep 2025 10:55:52 +0200 Subject: [PATCH 06/12] tree: group statements in one module It makes sense to put all statement-related types (simple statements, prepared statements, and batches) in one module - `statements`. --- scylla-rust-wrapper/src/api.rs | 6 +++--- scylla-rust-wrapper/src/cass_error.rs | 2 +- scylla-rust-wrapper/src/cql_types/data_type.rs | 2 +- scylla-rust-wrapper/src/exec_profile.rs | 10 ++++++---- scylla-rust-wrapper/src/future.rs | 2 +- scylla-rust-wrapper/src/integration_testing.rs | 4 ++-- scylla-rust-wrapper/src/lib.rs | 4 +--- scylla-rust-wrapper/src/session.rs | 18 ++++++++++-------- .../src/{ => statements}/batch.rs | 2 +- scylla-rust-wrapper/src/statements/mod.rs | 3 +++ .../src/{ => statements}/prepared.rs | 2 +- .../src/{ => statements}/statement.rs | 4 ++-- 12 files changed, 32 insertions(+), 27 deletions(-) rename scylla-rust-wrapper/src/{ => statements}/batch.rs (99%) create mode 100644 scylla-rust-wrapper/src/statements/mod.rs rename scylla-rust-wrapper/src/{ => statements}/prepared.rs (98%) rename scylla-rust-wrapper/src/{ => statements}/statement.rs (99%) diff --git a/scylla-rust-wrapper/src/api.rs b/scylla-rust-wrapper/src/api.rs index c2d494a2..77d5a7ae 100644 --- a/scylla-rust-wrapper/src/api.rs +++ b/scylla-rust-wrapper/src/api.rs @@ -387,7 +387,7 @@ pub mod future { pub mod statement { // Disabling rustfmt to have one item per line for better readability. #[rustfmt::skip] - pub use crate::statement::{ + pub use crate::statements::statement::{ CassStatement, cass_statement_bind_bool, cass_statement_bind_bool_by_name, @@ -483,7 +483,7 @@ pub mod statement { pub mod prepared { // Disabling rustfmt to have one item per line for better readability. #[rustfmt::skip] - pub use crate::prepared::{ + pub use crate::statements::prepared::{ CassPrepared, cass_prepared_bind, cass_prepared_free, @@ -497,7 +497,7 @@ pub mod prepared { pub mod batch { // Disabling rustfmt to have one item per line for better readability. #[rustfmt::skip] - pub use crate::batch::{ + pub use crate::statements::batch::{ CassBatch, CassBatchType, cass_batch_add_statement, diff --git a/scylla-rust-wrapper/src/cass_error.rs b/scylla-rust-wrapper/src/cass_error.rs index 904341d3..74eeecac 100644 --- a/scylla-rust-wrapper/src/cass_error.rs +++ b/scylla-rust-wrapper/src/cass_error.rs @@ -14,7 +14,7 @@ use thiserror::Error; // Re-export error types. pub use crate::cass_error_types::{CassError, CassErrorSource}; -use crate::statement::UnknownNamedParameterError; +use crate::statements::statement::UnknownNamedParameterError; pub(crate) trait ToCassError { fn to_cass_error(&self) -> CassError; diff --git a/scylla-rust-wrapper/src/cql_types/data_type.rs b/scylla-rust-wrapper/src/cql_types/data_type.rs index 5cb9885b..a9ffa06c 100644 --- a/scylla-rust-wrapper/src/cql_types/data_type.rs +++ b/scylla-rust-wrapper/src/cql_types/data_type.rs @@ -1,7 +1,7 @@ use crate::argconv::*; -use crate::batch::CassBatchType; use crate::cass_error::CassError; use crate::cql_types::CassValueType; +use crate::statements::batch::CassBatchType; use crate::types::*; use scylla::cluster::metadata::{CollectionType, NativeType}; use scylla::frame::response::result::ColumnType; diff --git a/scylla-rust-wrapper/src/exec_profile.rs b/scylla-rust-wrapper/src/exec_profile.rs index 36ca67d6..0fbe4e52 100644 --- a/scylla-rust-wrapper/src/exec_profile.rs +++ b/scylla-rust-wrapper/src/exec_profile.rs @@ -18,7 +18,6 @@ use crate::argconv::{ ArcFFI, BoxFFI, CMut, CassBorrowedExclusivePtr, CassBorrowedSharedPtr, CassOwnedExclusivePtr, FFI, FromBox, ptr_to_cstr_n, strlen, }; -use crate::batch::CassBatch; use crate::cass_error::CassError; use crate::cluster::{ set_load_balance_dc_aware_n, set_load_balance_rack_aware_n, update_comma_delimited_list, @@ -28,7 +27,8 @@ use crate::cql_types::CassConsistency; use crate::load_balancing::{LoadBalancingConfig, LoadBalancingKind}; use crate::retry_policy::CassRetryPolicy; use crate::session::CassConnectedSession; -use crate::statement::CassStatement; +use crate::statements::batch::CassBatch; +use crate::statements::statement::CassStatement; use crate::types::{ cass_bool_t, cass_double_t, cass_int32_t, cass_int64_t, cass_uint32_t, cass_uint64_t, size_t, }; @@ -876,8 +876,10 @@ mod tests { use crate::testing::{assert_cass_error_eq, setup_tracing}; use crate::{ argconv::{make_c_str, str_to_c_str_n}, - batch::{CassBatchType, cass_batch_add_statement, cass_batch_free, cass_batch_new}, - statement::{cass_statement_free, cass_statement_new}, + statements::batch::{ + CassBatchType, cass_batch_add_statement, cass_batch_free, cass_batch_new, + }, + statements::statement::{cass_statement_free, cass_statement_new}, }; use assert_matches::assert_matches; diff --git a/scylla-rust-wrapper/src/future.rs b/scylla-rust-wrapper/src/future.rs index db97a247..a524ce54 100644 --- a/scylla-rust-wrapper/src/future.rs +++ b/scylla-rust-wrapper/src/future.rs @@ -1,9 +1,9 @@ use crate::argconv::*; use crate::cass_error::{CassError, CassErrorMessage, CassErrorResult, ToCassError as _}; use crate::cql_types::uuid::CassUuid; -use crate::prepared::CassPrepared; use crate::query_result::{CassNode, CassResult}; use crate::runtime::Runtime; +use crate::statements::prepared::CassPrepared; use crate::types::*; use futures::future; use std::future::Future; diff --git a/scylla-rust-wrapper/src/integration_testing.rs b/scylla-rust-wrapper/src/integration_testing.rs index c5cd3056..d350b7fb 100644 --- a/scylla-rust-wrapper/src/integration_testing.rs +++ b/scylla-rust-wrapper/src/integration_testing.rs @@ -12,13 +12,13 @@ use crate::argconv::{ ArcFFI, BoxFFI, CConst, CMut, CassBorrowedExclusivePtr, CassBorrowedSharedPtr, CassOwnedSharedPtr, }; -use crate::batch::CassBatch; use crate::cluster::CassCluster; use crate::future::{CassFuture, CassResultValue}; use crate::retry_policy::CassRetryPolicy; #[cfg(test)] use crate::runtime::Runtime; -use crate::statement::{BoundStatement, CassStatement}; +use crate::statements::batch::CassBatch; +use crate::statements::statement::{BoundStatement, CassStatement}; use crate::types::{cass_bool_t, cass_int32_t, cass_uint16_t, cass_uint64_t, size_t}; #[unsafe(no_mangle)] diff --git a/scylla-rust-wrapper/src/lib.rs b/scylla-rust-wrapper/src/lib.rs index 1cb9e77e..04f7a11a 100644 --- a/scylla-rust-wrapper/src/lib.rs +++ b/scylla-rust-wrapper/src/lib.rs @@ -10,7 +10,6 @@ mod binding; pub mod api; // pub, because doctests defined in `argconv` module need to access it. pub mod argconv; -pub(crate) mod batch; pub(crate) mod cass_error; pub(crate) mod cluster; pub(crate) mod config_value; @@ -23,7 +22,6 @@ pub(crate) mod iterator; mod load_balancing; mod logging; pub(crate) mod metadata; -pub(crate) mod prepared; pub(crate) mod query_result; pub(crate) mod retry_policy; pub(crate) mod runtime; @@ -31,7 +29,7 @@ pub(crate) mod runtime; mod ser_de_tests; pub(crate) mod session; pub(crate) mod ssl; -pub(crate) mod statement; +pub(crate) mod statements; #[cfg(test)] pub(crate) mod testing; pub(crate) mod timestamp_generator; diff --git a/scylla-rust-wrapper/src/session.rs b/scylla-rust-wrapper/src/session.rs index 54aa50ce..b625f808 100644 --- a/scylla-rust-wrapper/src/session.rs +++ b/scylla-rust-wrapper/src/session.rs @@ -1,5 +1,4 @@ use crate::argconv::*; -use crate::batch::CassBatch; use crate::cass_error::*; use crate::cass_metrics_types::CassMetrics; use crate::cluster::CassCluster; @@ -9,10 +8,11 @@ use crate::exec_profile::{CassExecProfile, ExecProfileName, PerStatementExecProf use crate::future::{CassFuture, CassFutureResult, CassResultValue}; use crate::metadata::create_table_metadata; use crate::metadata::{CassKeyspaceMeta, CassMaterializedViewMeta, CassSchemaMeta}; -use crate::prepared::CassPrepared; use crate::query_result::{CassResult, CassResultKind, CassResultMetadata}; use crate::runtime::Runtime; -use crate::statement::{BoundStatement, CassStatement, SimpleQueryRowSerializer}; +use crate::statements::batch::CassBatch; +use crate::statements::prepared::CassPrepared; +use crate::statements::statement::{BoundStatement, CassStatement, SimpleQueryRowSerializer}; use crate::types::size_t; use scylla::client::execution_profile::ExecutionProfileHandle; use scylla::client::session::Session; @@ -885,10 +885,6 @@ mod tests { use super::*; use crate::{ argconv::make_c_str, - batch::{ - CassBatchType, cass_batch_add_statement, cass_batch_free, cass_batch_new, - cass_batch_set_retry_policy, - }, cluster::{ cass_cluster_free, cass_cluster_new, cass_cluster_set_client_id, cass_cluster_set_contact_points_n, cass_cluster_set_execution_profile, @@ -907,7 +903,13 @@ mod tests { retry_policy::{ CassRetryPolicy, cass_retry_policy_default_new, cass_retry_policy_fallthrough_new, }, - statement::{cass_statement_free, cass_statement_new, cass_statement_set_retry_policy}, + statements::batch::{ + CassBatchType, cass_batch_add_statement, cass_batch_free, cass_batch_new, + cass_batch_set_retry_policy, + }, + statements::statement::{ + cass_statement_free, cass_statement_new, cass_statement_set_retry_policy, + }, testing::{ assert_cass_error_eq, cass_future_wait_check_and_free, generic_drop_queries_rules, handshake_rules, mock_init_rules, setup_tracing, test_with_one_proxy, diff --git a/scylla-rust-wrapper/src/batch.rs b/scylla-rust-wrapper/src/statements/batch.rs similarity index 99% rename from scylla-rust-wrapper/src/batch.rs rename to scylla-rust-wrapper/src/statements/batch.rs index 21c8f357..d490cdf0 100644 --- a/scylla-rust-wrapper/src/batch.rs +++ b/scylla-rust-wrapper/src/statements/batch.rs @@ -10,7 +10,7 @@ use crate::cql_types::data_type::make_batch_type; use crate::cql_types::value::CassCqlValue; use crate::exec_profile::PerStatementExecProfile; use crate::retry_policy::CassRetryPolicy; -use crate::statement::{BoundStatement, CassStatement}; +use crate::statements::statement::{BoundStatement, CassStatement}; use crate::types::*; use scylla::statement::batch::Batch; use scylla::statement::{Consistency, SerialConsistency}; diff --git a/scylla-rust-wrapper/src/statements/mod.rs b/scylla-rust-wrapper/src/statements/mod.rs new file mode 100644 index 00000000..be8a6349 --- /dev/null +++ b/scylla-rust-wrapper/src/statements/mod.rs @@ -0,0 +1,3 @@ +pub(crate) mod batch; +pub(crate) mod prepared; +pub(crate) mod statement; diff --git a/scylla-rust-wrapper/src/prepared.rs b/scylla-rust-wrapper/src/statements/prepared.rs similarity index 98% rename from scylla-rust-wrapper/src/prepared.rs rename to scylla-rust-wrapper/src/statements/prepared.rs index 8d5a70a5..c40497ef 100644 --- a/scylla-rust-wrapper/src/prepared.rs +++ b/scylla-rust-wrapper/src/statements/prepared.rs @@ -6,7 +6,7 @@ use crate::{ cass_error::CassError, cql_types::data_type::{CassDataType, get_column_type}, query_result::CassResultMetadata, - statement::{BoundPreparedStatement, CassStatement}, + statements::statement::{BoundPreparedStatement, CassStatement}, types::size_t, }; use scylla::statement::prepared::PreparedStatement; diff --git a/scylla-rust-wrapper/src/statement.rs b/scylla-rust-wrapper/src/statements/statement.rs similarity index 99% rename from scylla-rust-wrapper/src/statement.rs rename to scylla-rust-wrapper/src/statements/statement.rs index 020c258e..25cdd0e5 100644 --- a/scylla-rust-wrapper/src/statement.rs +++ b/scylla-rust-wrapper/src/statements/statement.rs @@ -5,9 +5,9 @@ use crate::cql_types::CassConsistency; use crate::cql_types::inet::CassInet; use crate::cql_types::value::{self, CassCqlValue}; use crate::exec_profile::PerStatementExecProfile; -use crate::prepared::CassPrepared; use crate::query_result::{CassNode, CassResult}; use crate::retry_policy::CassRetryPolicy; +use crate::statements::prepared::CassPrepared; use crate::types::*; use scylla::frame::types::Consistency; use scylla::policies::load_balancing::{NodeIdentifier, SingleTargetLoadBalancingPolicy}; @@ -876,7 +876,7 @@ mod tests { use crate::argconv::{BoxFFI, RefFFI}; use crate::cass_error::CassError; use crate::cql_types::inet::CassInet; - use crate::statement::{ + use crate::statements::statement::{ cass_statement_set_host, cass_statement_set_host_inet, cass_statement_set_node, }; use crate::testing::assert_cass_error_eq; From 4581d2138a08c497700a61a018a1afdcc66a7392 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Wed, 24 Sep 2025 11:13:08 +0200 Subject: [PATCH 07/12] tree: extract testing module 1. A new module `testing/mod.rs` is created to re-export submodules. 2. `testing.rs` is renamed to `testing/utils.rs`. 3. `integration_testing.rs` is moved to `testing/integration.rs`. 4. `ser_de_tests.rs` is moved to `testing/ser_de_tests.rs`. 5. All imports in the codebase are updated to reflect the new module structure. --- scylla-rust-wrapper/src/api.rs | 4 ++-- scylla-rust-wrapper/src/cluster.rs | 2 +- scylla-rust-wrapper/src/cql_types/collection.rs | 2 +- scylla-rust-wrapper/src/exec_profile.rs | 2 +- scylla-rust-wrapper/src/future.rs | 8 ++++---- scylla-rust-wrapper/src/lib.rs | 6 +----- scylla-rust-wrapper/src/retry_policy.rs | 2 +- scylla-rust-wrapper/src/runtime.rs | 2 +- scylla-rust-wrapper/src/session.rs | 4 ++-- scylla-rust-wrapper/src/statements/statement.rs | 2 +- .../{integration_testing.rs => testing/integration.rs} | 0 scylla-rust-wrapper/src/testing/mod.rs | 8 ++++++++ scylla-rust-wrapper/src/{ => testing}/ser_de_tests.rs | 2 +- scylla-rust-wrapper/src/{testing.rs => testing/utils.rs} | 0 14 files changed, 24 insertions(+), 20 deletions(-) rename scylla-rust-wrapper/src/{integration_testing.rs => testing/integration.rs} (100%) create mode 100644 scylla-rust-wrapper/src/testing/mod.rs rename scylla-rust-wrapper/src/{ => testing}/ser_de_tests.rs (99%) rename scylla-rust-wrapper/src/{testing.rs => testing/utils.rs} (100%) diff --git a/scylla-rust-wrapper/src/api.rs b/scylla-rust-wrapper/src/api.rs index 77d5a7ae..b0127f5a 100644 --- a/scylla-rust-wrapper/src/api.rs +++ b/scylla-rust-wrapper/src/api.rs @@ -955,7 +955,7 @@ pub mod alloc { pub mod integration_testing { // Disabling rustfmt to have one item per line for better readability. #[rustfmt::skip] - pub use crate::integration_testing::{ + pub use crate::testing::integration::{ IgnoringRetryPolicy, testing_batch_set_sleeping_history_listener, testing_cluster_get_connect_timeout, @@ -974,7 +974,7 @@ pub mod integration_testing { /// and at the same time the functions are not yet implemented in the wrapper. // Disabling rustfmt to have one item per line for better readability. #[rustfmt::skip] - pub use crate::integration_testing::stubs::{ + pub use crate::testing::integration::stubs::{ CassAggregateMeta, CassAuthenticator, CassCustomPayload, diff --git a/scylla-rust-wrapper/src/cluster.rs b/scylla-rust-wrapper/src/cluster.rs index 41220dde..5f3cd233 100644 --- a/scylla-rust-wrapper/src/cluster.rs +++ b/scylla-rust-wrapper/src/cluster.rs @@ -1617,7 +1617,7 @@ pub unsafe extern "C" fn cass_cluster_set_metadata_request_serverside_timeout( #[cfg(test)] mod tests { - use crate::testing::{assert_cass_error_eq, setup_tracing}; + use crate::testing::utils::{assert_cass_error_eq, setup_tracing}; use super::*; use crate::{ diff --git a/scylla-rust-wrapper/src/cql_types/collection.rs b/scylla-rust-wrapper/src/cql_types/collection.rs index 3d462c56..c61d5bed 100644 --- a/scylla-rust-wrapper/src/cql_types/collection.rs +++ b/scylla-rust-wrapper/src/cql_types/collection.rs @@ -291,7 +291,7 @@ mod tests { cass_data_type_free, cass_data_type_new, }, }, - testing::assert_cass_error_eq, + testing::utils::assert_cass_error_eq, }; use super::{ diff --git a/scylla-rust-wrapper/src/exec_profile.rs b/scylla-rust-wrapper/src/exec_profile.rs index 0fbe4e52..f560b437 100644 --- a/scylla-rust-wrapper/src/exec_profile.rs +++ b/scylla-rust-wrapper/src/exec_profile.rs @@ -873,7 +873,7 @@ mod tests { use crate::retry_policy::{ cass_retry_policy_downgrading_consistency_new, cass_retry_policy_free, }; - use crate::testing::{assert_cass_error_eq, setup_tracing}; + use crate::testing::utils::{assert_cass_error_eq, setup_tracing}; use crate::{ argconv::{make_c_str, str_to_c_str_n}, statements::batch::{ diff --git a/scylla-rust-wrapper/src/future.rs b/scylla-rust-wrapper/src/future.rs index a524ce54..a39f6fd7 100644 --- a/scylla-rust-wrapper/src/future.rs +++ b/scylla-rust-wrapper/src/future.rs @@ -84,7 +84,7 @@ struct ResolvableFuture { wait_for_value: Condvar, #[cfg(cpp_integration_testing)] - recording_listener: Option>, + recording_listener: Option>, } pub struct CassFuture { @@ -119,7 +119,7 @@ impl CassFuture { runtime: Arc, fut: impl Future + Send + 'static, #[cfg(cpp_integration_testing)] recording_listener: Option< - Arc, + Arc, >, ) -> CassOwnedSharedPtr { Self::new_from_future( @@ -135,7 +135,7 @@ impl CassFuture { runtime: Arc, fut: impl Future + Send + 'static, #[cfg(cpp_integration_testing)] recording_listener: Option< - Arc, + Arc, >, ) -> Arc { let cass_fut = Arc::new(CassFuture { @@ -675,7 +675,7 @@ pub unsafe extern "C" fn cass_future_coordinator( #[cfg(test)] mod tests { - use crate::testing::{assert_cass_error_eq, assert_cass_future_error_message_eq}; + use crate::testing::utils::{assert_cass_error_eq, assert_cass_future_error_message_eq}; use super::*; use std::{ diff --git a/scylla-rust-wrapper/src/lib.rs b/scylla-rust-wrapper/src/lib.rs index 04f7a11a..096f7c7d 100644 --- a/scylla-rust-wrapper/src/lib.rs +++ b/scylla-rust-wrapper/src/lib.rs @@ -16,8 +16,6 @@ pub(crate) mod config_value; pub(crate) mod cql_types; pub(crate) mod exec_profile; pub(crate) mod future; -#[cfg(cpp_integration_testing)] -pub(crate) mod integration_testing; pub(crate) mod iterator; mod load_balancing; mod logging; @@ -25,12 +23,10 @@ pub(crate) mod metadata; pub(crate) mod query_result; pub(crate) mod retry_policy; pub(crate) mod runtime; -#[cfg(test)] -mod ser_de_tests; pub(crate) mod session; pub(crate) mod ssl; pub(crate) mod statements; -#[cfg(test)] +#[cfg(any(test, cpp_integration_testing))] pub(crate) mod testing; pub(crate) mod timestamp_generator; diff --git a/scylla-rust-wrapper/src/retry_policy.rs b/scylla-rust-wrapper/src/retry_policy.rs index 2a85a128..364a602c 100644 --- a/scylla-rust-wrapper/src/retry_policy.rs +++ b/scylla-rust-wrapper/src/retry_policy.rs @@ -69,7 +69,7 @@ pub enum CassRetryPolicy { DowngradingConsistency(Arc), Logging(Arc), #[cfg(cpp_integration_testing)] - Ignoring(Arc), + Ignoring(Arc), } impl RetryPolicy for CassRetryPolicy { diff --git a/scylla-rust-wrapper/src/runtime.rs b/scylla-rust-wrapper/src/runtime.rs index 20f92b9d..343eea33 100644 --- a/scylla-rust-wrapper/src/runtime.rs +++ b/scylla-rust-wrapper/src/runtime.rs @@ -136,7 +136,7 @@ mod tests { cluster::{cass_cluster_free, cass_cluster_new, cass_cluster_set_contact_points_n}, future::cass_future_free, session::{cass_session_close, cass_session_connect, cass_session_free, cass_session_new}, - testing::{ + testing::utils::{ assert_cass_error_eq, cass_future_wait_check_and_free, mock_init_rules, rusty_fork_test_with_proxy, setup_tracing, test_with_one_proxy_at_ip, }, diff --git a/scylla-rust-wrapper/src/session.rs b/scylla-rust-wrapper/src/session.rs index b625f808..3e0ca68a 100644 --- a/scylla-rust-wrapper/src/session.rs +++ b/scylla-rust-wrapper/src/session.rs @@ -389,7 +389,7 @@ pub unsafe extern "C" fn cass_session_execute( #[cfg(cpp_integration_testing)] let recording_listener = statement_opt.record_hosts.then(|| { let recording_listener = - Arc::new(crate::integration_testing::RecordingHistoryListener::new()); + Arc::new(crate::testing::integration::RecordingHistoryListener::new()); match statement { BoundStatement::Simple(ref mut unprepared) => { unprepared @@ -910,7 +910,7 @@ mod tests { statements::statement::{ cass_statement_free, cass_statement_new, cass_statement_set_retry_policy, }, - testing::{ + testing::utils::{ assert_cass_error_eq, cass_future_wait_check_and_free, generic_drop_queries_rules, handshake_rules, mock_init_rules, setup_tracing, test_with_one_proxy, }, diff --git a/scylla-rust-wrapper/src/statements/statement.rs b/scylla-rust-wrapper/src/statements/statement.rs index 25cdd0e5..643c8c1b 100644 --- a/scylla-rust-wrapper/src/statements/statement.rs +++ b/scylla-rust-wrapper/src/statements/statement.rs @@ -879,7 +879,7 @@ mod tests { use crate::statements::statement::{ cass_statement_set_host, cass_statement_set_host_inet, cass_statement_set_node, }; - use crate::testing::assert_cass_error_eq; + use crate::testing::utils::assert_cass_error_eq; use super::{cass_statement_free, cass_statement_new}; diff --git a/scylla-rust-wrapper/src/integration_testing.rs b/scylla-rust-wrapper/src/testing/integration.rs similarity index 100% rename from scylla-rust-wrapper/src/integration_testing.rs rename to scylla-rust-wrapper/src/testing/integration.rs diff --git a/scylla-rust-wrapper/src/testing/mod.rs b/scylla-rust-wrapper/src/testing/mod.rs new file mode 100644 index 00000000..fee08a88 --- /dev/null +++ b/scylla-rust-wrapper/src/testing/mod.rs @@ -0,0 +1,8 @@ +#[cfg(cpp_integration_testing)] +pub(crate) mod integration; + +#[cfg(test)] +pub(crate) mod utils; + +#[cfg(test)] +mod ser_de_tests; diff --git a/scylla-rust-wrapper/src/ser_de_tests.rs b/scylla-rust-wrapper/src/testing/ser_de_tests.rs similarity index 99% rename from scylla-rust-wrapper/src/ser_de_tests.rs rename to scylla-rust-wrapper/src/testing/ser_de_tests.rs index 7a67398a..d74e6a58 100644 --- a/scylla-rust-wrapper/src/ser_de_tests.rs +++ b/scylla-rust-wrapper/src/testing/ser_de_tests.rs @@ -43,7 +43,7 @@ use crate::query_result::{ cass_value_get_int8, cass_value_get_int16, cass_value_get_int32, cass_value_get_int64, cass_value_get_string, cass_value_get_uuid, cass_value_is_null, cass_value_item_count, }; -use crate::testing::{assert_cass_error_eq, setup_tracing}; +use crate::testing::utils::{assert_cass_error_eq, setup_tracing}; use crate::types::size_t; fn do_serialize(t: T, typ: &ColumnType) -> Vec { diff --git a/scylla-rust-wrapper/src/testing.rs b/scylla-rust-wrapper/src/testing/utils.rs similarity index 100% rename from scylla-rust-wrapper/src/testing.rs rename to scylla-rust-wrapper/src/testing/utils.rs From 31685452d1a27bdcb196c471bff94c103f4452a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Wed, 24 Sep 2025 13:53:46 +0200 Subject: [PATCH 08/12] api: expose CassUuid, which was mistakenly private --- scylla-rust-wrapper/src/api.rs | 1 + scylla-rust-wrapper/src/cql_types/uuid.rs | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/scylla-rust-wrapper/src/api.rs b/scylla-rust-wrapper/src/api.rs index b0127f5a..5a21166b 100644 --- a/scylla-rust-wrapper/src/api.rs +++ b/scylla-rust-wrapper/src/api.rs @@ -847,6 +847,7 @@ pub mod uuid { // Disabling rustfmt to have one item per line for better readability. #[rustfmt::skip] pub use crate::cql_types::uuid::{ + CassUuid, cass_uuid_from_string, cass_uuid_from_string_n, cass_uuid_max_from_time, diff --git a/scylla-rust-wrapper/src/cql_types/uuid.rs b/scylla-rust-wrapper/src/cql_types/uuid.rs index 7439cb32..91d20234 100644 --- a/scylla-rust-wrapper/src/cql_types/uuid.rs +++ b/scylla-rust-wrapper/src/cql_types/uuid.rs @@ -10,7 +10,7 @@ use std::sync::atomic::{AtomicU64, Ordering}; use std::time::{SystemTime, UNIX_EPOCH}; use uuid::Uuid; -pub(crate) use crate::cass_uuid_types::CassUuid; +pub use crate::cass_uuid_types::CassUuid; pub struct CassUuidGen { pub(crate) clock_seq_and_node: cass_uint64_t, From a28b7c093d212597a29dc2d043a7b421fcd9f3af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Wed, 24 Sep 2025 13:56:50 +0200 Subject: [PATCH 09/12] integration: drop redundant ShardAwareness usage ShardAwareness does nothing in dry mode proxy, so we can drop it. --- .../tests/integration/consistency.rs | 110 +++++++++--------- .../tests/integration/utils.rs | 17 +-- 2 files changed, 59 insertions(+), 68 deletions(-) diff --git a/scylla-rust-wrapper/tests/integration/consistency.rs b/scylla-rust-wrapper/tests/integration/consistency.rs index 4d2374f3..05b93960 100644 --- a/scylla-rust-wrapper/tests/integration/consistency.rs +++ b/scylla-rust-wrapper/tests/integration/consistency.rs @@ -34,7 +34,6 @@ use scylla_cpp_driver::api::statement::{ cass_statement_set_execution_profile, cass_statement_set_serial_consistency, }; use scylla_cpp_driver::argconv::{CConst, CMut, CassBorrowedExclusivePtr, CassBorrowedSharedPtr}; -use scylla_proxy::ShardAwareness; use scylla_proxy::{ Condition, ProxyError, Reaction, RequestFrame, RequestOpcode, RequestReaction, RequestRule, TargetShard, WorkerError, @@ -610,69 +609,66 @@ fn check_for_all_consistencies_and_setting_options( #[ntest::timeout(60000)] async fn consistency_is_correctly_set_in_cql_requests() { setup_tracing(); - let res = test_with_3_node_dry_mode_cluster( - ShardAwareness::QueryNode, - |proxy_uris, mut running_proxy| async move { - let request_rules = |tx| { - handshake_rules() - .into_iter() - .chain(drop_metadata_queries_rules()) - .chain([ - RequestRule( - Condition::and( - Condition::not(Condition::ConnectionRegisteredAnyEvent), - Condition::RequestOpcode(RequestOpcode::Prepare), - ), - // Respond to a PREPARE request with a prepared statement ID. - // This assumes 0 bind variables and 0 returned columns. - RequestReaction::forge_response(Arc::new(forge_prepare_response)), + let res = test_with_3_node_dry_mode_cluster(|proxy_uris, mut running_proxy| async move { + let request_rules = |tx| { + handshake_rules() + .into_iter() + .chain(drop_metadata_queries_rules()) + .chain([ + RequestRule( + Condition::and( + Condition::not(Condition::ConnectionRegisteredAnyEvent), + Condition::RequestOpcode(RequestOpcode::Prepare), ), - RequestRule( - Condition::and( - Condition::not(Condition::ConnectionRegisteredAnyEvent), + // Respond to a PREPARE request with a prepared statement ID. + // This assumes 0 bind variables and 0 returned columns. + RequestReaction::forge_response(Arc::new(forge_prepare_response)), + ), + RequestRule( + Condition::and( + Condition::not(Condition::ConnectionRegisteredAnyEvent), + Condition::or( + Condition::RequestOpcode(RequestOpcode::Execute), Condition::or( - Condition::RequestOpcode(RequestOpcode::Execute), - Condition::or( - Condition::RequestOpcode(RequestOpcode::Batch), - Condition::and( - Condition::RequestOpcode(RequestOpcode::Query), - Condition::BodyContainsCaseSensitive(Box::new( - *b"INTO consistency_tests", - )), - ), + Condition::RequestOpcode(RequestOpcode::Batch), + Condition::and( + Condition::RequestOpcode(RequestOpcode::Query), + Condition::BodyContainsCaseSensitive(Box::new( + *b"INTO consistency_tests", + )), ), ), ), - RequestReaction::forge() - .server_error() - .with_feedback_when_performed(tx), ), - ]) - .collect::>() - }; - - // Set the rules for the requests. - // This has the following effect: - // 1. PREPARE requests will be answered with a forged response. - // 2. EXECUTE, BATCH and QUERY requests will be replied with a forged error response, - // but additionally will send a feedback to the channel `tx`, which will be used - // to verify the consistency and serial consistency set in the request. - let (request_tx, request_rx) = mpsc::unbounded_channel(); - for running_node in running_proxy.running_nodes.iter_mut() { - running_node.change_request_rules(Some(request_rules(request_tx.clone()))); - } + RequestReaction::forge() + .server_error() + .with_feedback_when_performed(tx), + ), + ]) + .collect::>() + }; + + // Set the rules for the requests. + // This has the following effect: + // 1. PREPARE requests will be answered with a forged response. + // 2. EXECUTE, BATCH and QUERY requests will be replied with a forged error response, + // but additionally will send a feedback to the channel `tx`, which will be used + // to verify the consistency and serial consistency set in the request. + let (request_tx, request_rx) = mpsc::unbounded_channel(); + for running_node in running_proxy.running_nodes.iter_mut() { + running_node.change_request_rules(Some(request_rules(request_tx.clone()))); + } + + // The test must be executed in a blocking context, because otherwise the tokio runtime + // will panic on blocking operations that C API performs. + tokio::task::spawn_blocking(move || { + check_for_all_consistencies_and_setting_options(request_rx, proxy_uris) + }) + .await + .unwrap(); - // The test must be executed in a blocking context, because otherwise the tokio runtime - // will panic on blocking operations that C API performs. - tokio::task::spawn_blocking(move || { - check_for_all_consistencies_and_setting_options(request_rx, proxy_uris) - }) - .await - .unwrap(); - - running_proxy - }, - ) + running_proxy + }) .await; match res { diff --git a/scylla-rust-wrapper/tests/integration/utils.rs b/scylla-rust-wrapper/tests/integration/utils.rs index 091bd8fa..f47b8c10 100644 --- a/scylla-rust-wrapper/tests/integration/utils.rs +++ b/scylla-rust-wrapper/tests/integration/utils.rs @@ -16,7 +16,7 @@ use std::sync::Arc; use scylla_proxy::{ Condition, Node, Proxy, ProxyError, Reaction as _, RequestFrame, RequestOpcode, - RequestReaction, RequestRule, ResponseFrame, RunningProxy, ShardAwareness, + RequestReaction, RequestRule, ResponseFrame, RunningProxy, }; pub(crate) fn setup_tracing() { @@ -26,10 +26,7 @@ pub(crate) fn setup_tracing() { .try_init(); } -pub(crate) async fn test_with_3_node_dry_mode_cluster( - shard_awareness: ShardAwareness, - test: F, -) -> Result<(), ProxyError> +pub(crate) async fn test_with_3_node_dry_mode_cluster(test: F) -> Result<(), ProxyError> where F: FnOnce([String; 3], RunningProxy) -> Fut, Fut: Future, @@ -42,12 +39,10 @@ where let proxy2_addr = SocketAddr::from_str(proxy2_uri.as_str()).unwrap(); let proxy3_addr = SocketAddr::from_str(proxy3_uri.as_str()).unwrap(); - let proxy = Proxy::new([proxy1_addr, proxy2_addr, proxy3_addr].map(|proxy_addr| { - Node::builder() - .proxy_address(proxy_addr) - .shard_awareness(shard_awareness) - .build_dry_mode() - })); + let proxy = Proxy::new( + [proxy1_addr, proxy2_addr, proxy3_addr] + .map(|proxy_addr| Node::builder().proxy_address(proxy_addr).build_dry_mode()), + ); let running_proxy = proxy.run().await.unwrap(); From ad9c0a5078fc5bca0ab24d9f32d0826ef48d13f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Wed, 24 Sep 2025 14:12:14 +0200 Subject: [PATCH 10/12] integration: proxy helper spawns blocking task As the code in the closure passed to `test_with_3_node_dry_mode_cluster` is allowed to perform blocking operations (and this often happens when calling C API functions), we need to spawn a blocking task for it. Now, the closure itself is not async anymore, as it is executed in a blocking task. This avoids the pitfall of forgetting to spawn a blocking task inside the closure, which would lead to panics in certain situations. --- scylla-rust-wrapper/tests/integration/consistency.rs | 10 ++-------- scylla-rust-wrapper/tests/integration/utils.rs | 11 ++++++----- 2 files changed, 8 insertions(+), 13 deletions(-) diff --git a/scylla-rust-wrapper/tests/integration/consistency.rs b/scylla-rust-wrapper/tests/integration/consistency.rs index 05b93960..e4c0f3a1 100644 --- a/scylla-rust-wrapper/tests/integration/consistency.rs +++ b/scylla-rust-wrapper/tests/integration/consistency.rs @@ -609,7 +609,7 @@ fn check_for_all_consistencies_and_setting_options( #[ntest::timeout(60000)] async fn consistency_is_correctly_set_in_cql_requests() { setup_tracing(); - let res = test_with_3_node_dry_mode_cluster(|proxy_uris, mut running_proxy| async move { + let res = test_with_3_node_dry_mode_cluster(|proxy_uris, mut running_proxy| { let request_rules = |tx| { handshake_rules() .into_iter() @@ -659,13 +659,7 @@ async fn consistency_is_correctly_set_in_cql_requests() { running_node.change_request_rules(Some(request_rules(request_tx.clone()))); } - // The test must be executed in a blocking context, because otherwise the tokio runtime - // will panic on blocking operations that C API performs. - tokio::task::spawn_blocking(move || { - check_for_all_consistencies_and_setting_options(request_rx, proxy_uris) - }) - .await - .unwrap(); + check_for_all_consistencies_and_setting_options(request_rx, proxy_uris); running_proxy }) diff --git a/scylla-rust-wrapper/tests/integration/utils.rs b/scylla-rust-wrapper/tests/integration/utils.rs index f47b8c10..e731557e 100644 --- a/scylla-rust-wrapper/tests/integration/utils.rs +++ b/scylla-rust-wrapper/tests/integration/utils.rs @@ -1,5 +1,4 @@ use bytes::BytesMut; -use futures::Future; use libc::c_char; use scylla_cpp_driver::api::error::{CassError, cass_error_desc}; use scylla_cpp_driver::api::future::{ @@ -26,10 +25,9 @@ pub(crate) fn setup_tracing() { .try_init(); } -pub(crate) async fn test_with_3_node_dry_mode_cluster(test: F) -> Result<(), ProxyError> +pub(crate) async fn test_with_3_node_dry_mode_cluster(test: F) -> Result<(), ProxyError> where - F: FnOnce([String; 3], RunningProxy) -> Fut, - Fut: Future, + F: FnOnce([String; 3], RunningProxy) -> RunningProxy + Send + 'static, { let proxy1_uri = format!("{}:9042", scylla_proxy::get_exclusive_local_address()); let proxy2_uri = format!("{}:9042", scylla_proxy::get_exclusive_local_address()); @@ -46,7 +44,10 @@ where let running_proxy = proxy.run().await.unwrap(); - let running_proxy = test([proxy1_uri, proxy2_uri, proxy3_uri], running_proxy).await; + let running_proxy = + tokio::task::spawn_blocking(|| test([proxy1_uri, proxy2_uri, proxy3_uri], running_proxy)) + .await + .unwrap(); running_proxy.finish().await } From b3f4471a7dd02ee03344911873dd67c1076dace9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Wed, 24 Sep 2025 14:14:55 +0200 Subject: [PATCH 11/12] integration: proxy helper accepts initial request rules Although the only existing integration test - consistency.rs - does not benefit from this change (as it must create request rules dynamically), this change makes it easier to write future tests that do not need to create request rules dynamically. It will be used in subsequent commits, when we move more tests from the lib crate to the integration tests. --- .../tests/integration/consistency.rs | 95 ++++++++++--------- .../tests/integration/utils.rs | 16 +++- 2 files changed, 60 insertions(+), 51 deletions(-) diff --git a/scylla-rust-wrapper/tests/integration/consistency.rs b/scylla-rust-wrapper/tests/integration/consistency.rs index e4c0f3a1..53b8d79f 100644 --- a/scylla-rust-wrapper/tests/integration/consistency.rs +++ b/scylla-rust-wrapper/tests/integration/consistency.rs @@ -609,60 +609,63 @@ fn check_for_all_consistencies_and_setting_options( #[ntest::timeout(60000)] async fn consistency_is_correctly_set_in_cql_requests() { setup_tracing(); - let res = test_with_3_node_dry_mode_cluster(|proxy_uris, mut running_proxy| { - let request_rules = |tx| { - handshake_rules() - .into_iter() - .chain(drop_metadata_queries_rules()) - .chain([ - RequestRule( - Condition::and( - Condition::not(Condition::ConnectionRegisteredAnyEvent), - Condition::RequestOpcode(RequestOpcode::Prepare), + let res = test_with_3_node_dry_mode_cluster( + || None, + |proxy_uris, mut running_proxy| { + let request_rules = |tx| { + handshake_rules() + .into_iter() + .chain(drop_metadata_queries_rules()) + .chain([ + RequestRule( + Condition::and( + Condition::not(Condition::ConnectionRegisteredAnyEvent), + Condition::RequestOpcode(RequestOpcode::Prepare), + ), + // Respond to a PREPARE request with a prepared statement ID. + // This assumes 0 bind variables and 0 returned columns. + RequestReaction::forge_response(Arc::new(forge_prepare_response)), ), - // Respond to a PREPARE request with a prepared statement ID. - // This assumes 0 bind variables and 0 returned columns. - RequestReaction::forge_response(Arc::new(forge_prepare_response)), - ), - RequestRule( - Condition::and( - Condition::not(Condition::ConnectionRegisteredAnyEvent), - Condition::or( - Condition::RequestOpcode(RequestOpcode::Execute), + RequestRule( + Condition::and( + Condition::not(Condition::ConnectionRegisteredAnyEvent), Condition::or( - Condition::RequestOpcode(RequestOpcode::Batch), - Condition::and( - Condition::RequestOpcode(RequestOpcode::Query), - Condition::BodyContainsCaseSensitive(Box::new( - *b"INTO consistency_tests", - )), + Condition::RequestOpcode(RequestOpcode::Execute), + Condition::or( + Condition::RequestOpcode(RequestOpcode::Batch), + Condition::and( + Condition::RequestOpcode(RequestOpcode::Query), + Condition::BodyContainsCaseSensitive(Box::new( + *b"INTO consistency_tests", + )), + ), ), ), ), + RequestReaction::forge() + .server_error() + .with_feedback_when_performed(tx), ), - RequestReaction::forge() - .server_error() - .with_feedback_when_performed(tx), - ), - ]) - .collect::>() - }; - - // Set the rules for the requests. - // This has the following effect: - // 1. PREPARE requests will be answered with a forged response. - // 2. EXECUTE, BATCH and QUERY requests will be replied with a forged error response, - // but additionally will send a feedback to the channel `tx`, which will be used - // to verify the consistency and serial consistency set in the request. - let (request_tx, request_rx) = mpsc::unbounded_channel(); - for running_node in running_proxy.running_nodes.iter_mut() { - running_node.change_request_rules(Some(request_rules(request_tx.clone()))); - } + ]) + .collect::>() + }; + + // Set the rules for the requests. + // This has the following effect: + // 1. PREPARE requests will be answered with a forged response. + // 2. EXECUTE, BATCH and QUERY requests will be replied with a forged error response, + // but additionally will send a feedback to the channel `tx`, which will be used + // to verify the consistency and serial consistency set in the request. + let (request_tx, request_rx) = mpsc::unbounded_channel(); + for running_node in running_proxy.running_nodes.iter_mut() { + running_node.change_request_rules(Some(request_rules(request_tx.clone()))); + } - check_for_all_consistencies_and_setting_options(request_rx, proxy_uris); + check_for_all_consistencies_and_setting_options(request_rx, proxy_uris); - running_proxy - }) + running_proxy + }, + ) .await; match res { diff --git a/scylla-rust-wrapper/tests/integration/utils.rs b/scylla-rust-wrapper/tests/integration/utils.rs index e731557e..d2f9859f 100644 --- a/scylla-rust-wrapper/tests/integration/utils.rs +++ b/scylla-rust-wrapper/tests/integration/utils.rs @@ -25,8 +25,12 @@ pub(crate) fn setup_tracing() { .try_init(); } -pub(crate) async fn test_with_3_node_dry_mode_cluster(test: F) -> Result<(), ProxyError> +pub(crate) async fn test_with_3_node_dry_mode_cluster( + initial_request_rules: impl Fn() -> I, + test: F, +) -> Result<(), ProxyError> where + I: IntoIterator, F: FnOnce([String; 3], RunningProxy) -> RunningProxy + Send + 'static, { let proxy1_uri = format!("{}:9042", scylla_proxy::get_exclusive_local_address()); @@ -37,10 +41,12 @@ where let proxy2_addr = SocketAddr::from_str(proxy2_uri.as_str()).unwrap(); let proxy3_addr = SocketAddr::from_str(proxy3_uri.as_str()).unwrap(); - let proxy = Proxy::new( - [proxy1_addr, proxy2_addr, proxy3_addr] - .map(|proxy_addr| Node::builder().proxy_address(proxy_addr).build_dry_mode()), - ); + let proxy = Proxy::new([proxy1_addr, proxy2_addr, proxy3_addr].map(|proxy_addr| { + Node::builder() + .proxy_address(proxy_addr) + .request_rules(Vec::from_iter(initial_request_rules())) + .build_dry_mode() + })); let running_proxy = proxy.run().await.unwrap(); From 9d1e98d4c9c707e3643be265096e55fd29e1ae6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Wed, 24 Sep 2025 14:20:21 +0200 Subject: [PATCH 12/12] session: move some tests to integration It's a good practice to have tests which use only the public API in the integration tests, to ensure that the public API is usable and works as expected. Move some session-related tests to the integration tests. --- scylla-rust-wrapper/src/session.rs | 690 +--------------- scylla-rust-wrapper/tests/integration/main.rs | 1 + .../tests/integration/session.rs | 734 ++++++++++++++++++ .../tests/integration/utils.rs | 52 ++ 4 files changed, 794 insertions(+), 683 deletions(-) create mode 100644 scylla-rust-wrapper/tests/integration/session.rs diff --git a/scylla-rust-wrapper/src/session.rs b/scylla-rust-wrapper/src/session.rs index 3e0ca68a..a208d99e 100644 --- a/scylla-rust-wrapper/src/session.rs +++ b/scylla-rust-wrapper/src/session.rs @@ -876,9 +876,6 @@ pub unsafe extern "C" fn cass_session_get_metrics( #[cfg(test)] mod tests { - use rusty_fork::rusty_fork_test; - use scylla::errors::DbError; - use scylla::frame::types::Consistency; use scylla_proxy::{Condition, RequestOpcode, RequestReaction, RequestRule, RunningProxy}; use tracing::instrument::WithSubscriber; @@ -886,44 +883,25 @@ mod tests { use crate::{ argconv::make_c_str, cluster::{ - cass_cluster_free, cass_cluster_new, cass_cluster_set_client_id, - cass_cluster_set_contact_points_n, cass_cluster_set_execution_profile, - cass_cluster_set_latency_aware_routing, cass_cluster_set_retry_policy, + cass_cluster_free, cass_cluster_new, cass_cluster_set_contact_points_n, + cass_cluster_set_execution_profile, }, exec_profile::{ ExecProfileName, cass_batch_set_execution_profile, cass_batch_set_execution_profile_n, cass_execution_profile_free, cass_execution_profile_new, - cass_execution_profile_set_latency_aware_routing, - cass_execution_profile_set_retry_policy, cass_statement_set_execution_profile, - cass_statement_set_execution_profile_n, - }, - future::{ - cass_future_error_code, cass_future_free, cass_future_set_callback, cass_future_wait, - }, - retry_policy::{ - CassRetryPolicy, cass_retry_policy_default_new, cass_retry_policy_fallthrough_new, + cass_statement_set_execution_profile, cass_statement_set_execution_profile_n, }, + future::cass_future_error_code, statements::batch::{ CassBatchType, cass_batch_add_statement, cass_batch_free, cass_batch_new, - cass_batch_set_retry_policy, - }, - statements::statement::{ - cass_statement_free, cass_statement_new, cass_statement_set_retry_policy, }, + statements::statement::{cass_statement_free, cass_statement_new}, testing::utils::{ assert_cass_error_eq, cass_future_wait_check_and_free, generic_drop_queries_rules, - handshake_rules, mock_init_rules, setup_tracing, test_with_one_proxy, + handshake_rules, setup_tracing, test_with_one_proxy, }, - types::cass_bool_t, - }; - use std::{ - collections::HashSet, - convert::{TryFrom, TryInto}, - ffi::{CStr, c_void}, - iter, - net::SocketAddr, - sync::atomic::{AtomicUsize, Ordering}, }; + use std::{collections::HashSet, iter, net::SocketAddr}; #[tokio::test] #[ntest::timeout(5000)] @@ -1306,658 +1284,4 @@ mod tests { } proxy } - - #[tokio::test] - #[ntest::timeout(30000)] - async fn retry_policy_on_statement_and_batch_is_handled_properly() { - setup_tracing(); - test_with_one_proxy( - retry_policy_on_statement_and_batch_is_handled_properly_do, - retry_policy_on_statement_and_batch_is_handled_properly_rules(), - ) - .with_current_subscriber() - .await; - } - - fn retry_policy_on_statement_and_batch_is_handled_properly_rules() - -> impl IntoIterator { - handshake_rules() - .into_iter() - .chain(iter::once(RequestRule( - Condition::RequestOpcode(RequestOpcode::Query) - .or(Condition::RequestOpcode(RequestOpcode::Batch)) - .and(Condition::BodyContainsCaseInsensitive(Box::new( - *b"SELECT host_id FROM system.", - ))) - // this 1 differentiates Fallthrough and Default retry policies. - .and(Condition::TrueForLimitedTimes(1)), - // We simulate the read timeout error in order to trigger DefaultRetryPolicy's - // retry on the same node. - // We don't use the example ReadTimeout error that is included in proxy, - // because in order to trigger a retry we need data_present=false. - RequestReaction::forge_with_error(DbError::ReadTimeout { - consistency: Consistency::All, - received: 1, - required: 1, - data_present: false, - }), - ))) - .chain(iter::once(RequestRule( - Condition::RequestOpcode(RequestOpcode::Query) - .or(Condition::RequestOpcode(RequestOpcode::Batch)) - .and(Condition::BodyContainsCaseInsensitive(Box::new( - *b"SELECT host_id FROM system.", - ))), - // We make the second attempt return a hard, nonrecoverable error. - RequestReaction::forge().read_failure(), - ))) - .chain(generic_drop_queries_rules()) - } - - // This test aims to verify that the retry policy emulation works properly, - // in any sequence of actions mutating the retry policy for a query. - // - // Below, the consecutive states of the test case are illustrated: - // Retry policy set on: ('F' - Fallthrough, 'D' - Default, '-' - no policy set) - // session default exec profile: F F F F F F F F F F F F F F - // per stmt/batch exec profile: - D - - D D D D D - - - D D - // stmt/batch (emulated): - - - F F - F D F F - D D - - fn retry_policy_on_statement_and_batch_is_handled_properly_do( - node_addr: SocketAddr, - mut proxy: RunningProxy, - ) -> RunningProxy { - unsafe { - let mut cluster_raw = cass_cluster_new(); - let ip = node_addr.ip().to_string(); - let (c_ip, c_ip_len) = str_to_c_str_n(ip.as_str()); - - assert_cass_error_eq!( - cass_cluster_set_contact_points_n(cluster_raw.borrow_mut(), c_ip, c_ip_len,), - CassError::CASS_OK - ); - - let fallthrough_policy = cass_retry_policy_fallthrough_new(); - let default_policy = cass_retry_policy_default_new(); - cass_cluster_set_retry_policy(cluster_raw.borrow_mut(), fallthrough_policy.borrow()); - - let session_raw = cass_session_new(); - - let mut profile_raw = cass_execution_profile_new(); - // A name of a profile that will have been registered in the Cluster. - let profile_name_c_str = make_c_str!("profile"); - - assert_cass_error_eq!( - cass_execution_profile_set_retry_policy( - profile_raw.borrow_mut(), - default_policy.borrow() - ), - CassError::CASS_OK - ); - - let query = make_c_str!("SELECT host_id FROM system.local WHERE key='local'"); - let mut statement_raw = cass_statement_new(query, 0); - let mut batch_raw = cass_batch_new(CassBatchType::CASS_BATCH_TYPE_LOGGED); - assert_cass_error_eq!( - cass_batch_add_statement(batch_raw.borrow_mut(), statement_raw.borrow()), - CassError::CASS_OK - ); - - assert_cass_error_eq!( - cass_cluster_set_execution_profile( - cluster_raw.borrow_mut(), - profile_name_c_str, - profile_raw.borrow_mut(), - ), - CassError::CASS_OK - ); - - cass_future_wait_check_and_free(cass_session_connect( - session_raw.borrow(), - cluster_raw.borrow().into_c_const(), - )); - { - unsafe fn execute_query( - session_raw: CassBorrowedSharedPtr, - statement_raw: CassBorrowedSharedPtr, - ) -> CassError { - unsafe { - cass_future_error_code( - cass_session_execute(session_raw, statement_raw).borrow(), - ) - } - } - unsafe fn execute_batch( - session_raw: CassBorrowedSharedPtr, - batch_raw: CassBorrowedSharedPtr, - ) -> CassError { - unsafe { - cass_future_error_code( - cass_session_execute_batch(session_raw, batch_raw).borrow(), - ) - } - } - - fn reset_proxy_rules(proxy: &mut RunningProxy) { - proxy.running_nodes[0].change_request_rules(Some( - retry_policy_on_statement_and_batch_is_handled_properly_rules() - .into_iter() - .collect(), - )) - } - - unsafe fn assert_query_with_fallthrough_policy( - proxy: &mut RunningProxy, - session_raw: CassBorrowedSharedPtr, - statement_raw: CassBorrowedSharedPtr, - batch_raw: CassBorrowedSharedPtr, - ) { - reset_proxy_rules(&mut *proxy); - unsafe { - assert_cass_error_eq!( - execute_query(session_raw.borrow(), statement_raw), - CassError::CASS_ERROR_SERVER_READ_TIMEOUT, - ); - reset_proxy_rules(&mut *proxy); - assert_cass_error_eq!( - execute_batch(session_raw, batch_raw), - CassError::CASS_ERROR_SERVER_READ_TIMEOUT, - ); - } - } - - unsafe fn assert_query_with_default_policy( - proxy: &mut RunningProxy, - session_raw: CassBorrowedSharedPtr, - statement_raw: CassBorrowedSharedPtr, - batch_raw: CassBorrowedSharedPtr, - ) { - reset_proxy_rules(&mut *proxy); - unsafe { - assert_cass_error_eq!( - execute_query(session_raw.borrow(), statement_raw), - CassError::CASS_ERROR_SERVER_READ_FAILURE - ); - reset_proxy_rules(&mut *proxy); - assert_cass_error_eq!( - execute_batch(session_raw, batch_raw), - CassError::CASS_ERROR_SERVER_READ_FAILURE - ); - } - } - - unsafe fn set_provided_exec_profile( - name: *const i8, - statement_raw: CassBorrowedExclusivePtr, - batch_raw: CassBorrowedExclusivePtr, - ) { - // Set statement/batch exec profile. - unsafe { - assert_cass_error_eq!( - cass_statement_set_execution_profile(statement_raw, name,), - CassError::CASS_OK - ); - assert_cass_error_eq!( - cass_batch_set_execution_profile(batch_raw, name,), - CassError::CASS_OK - ); - } - } - unsafe fn set_exec_profile( - profile_name_c_str: *const c_char, - statement_raw: CassBorrowedExclusivePtr, - batch_raw: CassBorrowedExclusivePtr, - ) { - unsafe { - set_provided_exec_profile(profile_name_c_str, statement_raw, batch_raw) - }; - } - unsafe fn unset_exec_profile( - statement_raw: CassBorrowedExclusivePtr, - batch_raw: CassBorrowedExclusivePtr, - ) { - unsafe { - set_provided_exec_profile(std::ptr::null::(), statement_raw, batch_raw) - }; - } - unsafe fn set_retry_policy_on_stmt( - policy: CassBorrowedSharedPtr, - statement_raw: CassBorrowedExclusivePtr, - batch_raw: CassBorrowedExclusivePtr, - ) { - unsafe { - assert_cass_error_eq!( - cass_statement_set_retry_policy(statement_raw, policy.borrow()), - CassError::CASS_OK - ); - assert_cass_error_eq!( - cass_batch_set_retry_policy(batch_raw, policy,), - CassError::CASS_OK - ); - } - } - unsafe fn unset_retry_policy_on_stmt( - statement_raw: CassBorrowedExclusivePtr, - batch_raw: CassBorrowedExclusivePtr, - ) { - unsafe { set_retry_policy_on_stmt(ArcFFI::null(), statement_raw, batch_raw) }; - } - - // ### START TESTING - - // With no exec profile nor retry policy set on statement/batch, - // the default cluster-wide retry policy should be used: in this case, fallthrough. - - // F - - - assert_query_with_fallthrough_policy( - &mut proxy, - session_raw.borrow(), - statement_raw.borrow().into_c_const(), - batch_raw.borrow().into_c_const(), - ); - - // F D - - set_exec_profile( - profile_name_c_str, - statement_raw.borrow_mut(), - batch_raw.borrow_mut(), - ); - assert_query_with_default_policy( - &mut proxy, - session_raw.borrow(), - statement_raw.borrow().into_c_const(), - batch_raw.borrow().into_c_const(), - ); - - // F - - - unset_exec_profile(statement_raw.borrow_mut(), batch_raw.borrow_mut()); - assert_query_with_fallthrough_policy( - &mut proxy, - session_raw.borrow(), - statement_raw.borrow().into_c_const(), - batch_raw.borrow().into_c_const(), - ); - - // F - F - set_retry_policy_on_stmt( - fallthrough_policy.borrow(), - statement_raw.borrow_mut(), - batch_raw.borrow_mut(), - ); - assert_query_with_fallthrough_policy( - &mut proxy, - session_raw.borrow(), - statement_raw.borrow().into_c_const(), - batch_raw.borrow().into_c_const(), - ); - - // F D F - set_exec_profile( - profile_name_c_str, - statement_raw.borrow_mut(), - batch_raw.borrow_mut(), - ); - assert_query_with_fallthrough_policy( - &mut proxy, - session_raw.borrow(), - statement_raw.borrow().into_c_const(), - batch_raw.borrow().into_c_const(), - ); - - // F D - - unset_retry_policy_on_stmt(statement_raw.borrow_mut(), batch_raw.borrow_mut()); - assert_query_with_default_policy( - &mut proxy, - session_raw.borrow(), - statement_raw.borrow().into_c_const(), - batch_raw.borrow().into_c_const(), - ); - - // F D F - set_retry_policy_on_stmt( - fallthrough_policy.borrow(), - statement_raw.borrow_mut(), - batch_raw.borrow_mut(), - ); - assert_query_with_fallthrough_policy( - &mut proxy, - session_raw.borrow(), - statement_raw.borrow().into_c_const(), - batch_raw.borrow().into_c_const(), - ); - - // F D D - set_retry_policy_on_stmt( - default_policy.borrow(), - statement_raw.borrow_mut(), - batch_raw.borrow_mut(), - ); - assert_query_with_default_policy( - &mut proxy, - session_raw.borrow(), - statement_raw.borrow().into_c_const(), - batch_raw.borrow().into_c_const(), - ); - - // F D F - set_retry_policy_on_stmt( - fallthrough_policy.borrow(), - statement_raw.borrow_mut(), - batch_raw.borrow_mut(), - ); - assert_query_with_fallthrough_policy( - &mut proxy, - session_raw.borrow(), - statement_raw.borrow().into_c_const(), - batch_raw.borrow().into_c_const(), - ); - - // F - F - unset_exec_profile(statement_raw.borrow_mut(), batch_raw.borrow_mut()); - assert_query_with_fallthrough_policy( - &mut proxy, - session_raw.borrow(), - statement_raw.borrow().into_c_const(), - batch_raw.borrow().into_c_const(), - ); - - // F - - - unset_retry_policy_on_stmt(statement_raw.borrow_mut(), batch_raw.borrow_mut()); - assert_query_with_fallthrough_policy( - &mut proxy, - session_raw.borrow(), - statement_raw.borrow().into_c_const(), - batch_raw.borrow().into_c_const(), - ); - - // F - D - set_retry_policy_on_stmt( - default_policy.borrow(), - statement_raw.borrow_mut(), - batch_raw.borrow_mut(), - ); - assert_query_with_default_policy( - &mut proxy, - session_raw.borrow(), - statement_raw.borrow().into_c_const(), - batch_raw.borrow().into_c_const(), - ); - - // F D D - set_exec_profile( - profile_name_c_str, - statement_raw.borrow_mut(), - batch_raw.borrow_mut(), - ); - assert_query_with_default_policy( - &mut proxy, - session_raw.borrow(), - statement_raw.borrow().into_c_const(), - batch_raw.borrow().into_c_const(), - ); - - // F D - - unset_retry_policy_on_stmt(statement_raw.borrow_mut(), batch_raw.borrow_mut()); - assert_query_with_default_policy( - &mut proxy, - session_raw.borrow(), - statement_raw.borrow().into_c_const(), - batch_raw.borrow().into_c_const(), - ); - } - - cass_future_wait_check_and_free(cass_session_close(session_raw.borrow())); - cass_execution_profile_free(profile_raw); - cass_statement_free(statement_raw); - cass_batch_free(batch_raw); - cass_session_free(session_raw); - cass_cluster_free(cluster_raw); - } - proxy - } - - #[test] - #[ntest::timeout(5000)] - fn session_with_latency_aware_load_balancing_does_not_panic() { - unsafe { - let mut cluster_raw = cass_cluster_new(); - - // An IP with very little chance of having a ScyllaDB node listening - let ip = "127.0.1.231"; - let (c_ip, c_ip_len) = str_to_c_str_n(ip); - - assert_cass_error_eq!( - cass_cluster_set_contact_points_n(cluster_raw.borrow_mut(), c_ip, c_ip_len), - CassError::CASS_OK - ); - cass_cluster_set_latency_aware_routing(cluster_raw.borrow_mut(), true as cass_bool_t); - let session_raw = cass_session_new(); - let mut profile_raw = cass_execution_profile_new(); - assert_cass_error_eq!( - cass_execution_profile_set_latency_aware_routing( - profile_raw.borrow_mut(), - true as cass_bool_t - ), - CassError::CASS_OK - ); - let profile_name = make_c_str!("latency_aware"); - cass_cluster_set_execution_profile( - cluster_raw.borrow_mut(), - profile_name, - profile_raw.borrow_mut(), - ); - { - let cass_future = - cass_session_connect(session_raw.borrow(), cluster_raw.borrow().into_c_const()); - cass_future_wait(cass_future.borrow()); - // The exact outcome is not important, we only test that we don't panic. - } - cass_execution_profile_free(profile_raw); - cass_session_free(session_raw); - cass_cluster_free(cluster_raw); - } - } - - rusty_fork_test! { - #![rusty_fork(timeout_ms = 1000)] - #[test] - fn cluster_is_not_referenced_by_session_connect_future() { - // An IP with very little chance of having a ScyllaDB node listening - let ip = "127.0.1.231"; - let (c_ip, c_ip_len) = str_to_c_str_n(ip); - let profile_name = make_c_str!("latency_aware"); - - unsafe { - let mut cluster_raw = cass_cluster_new(); - - assert_cass_error_eq!( - cass_cluster_set_contact_points_n(cluster_raw.borrow_mut(), c_ip, c_ip_len), - CassError::CASS_OK - ); - cass_cluster_set_latency_aware_routing(cluster_raw.borrow_mut(), true as cass_bool_t); - let session_raw = cass_session_new(); - let mut profile_raw = cass_execution_profile_new(); - assert_cass_error_eq!( - cass_execution_profile_set_latency_aware_routing(profile_raw.borrow_mut(), true as cass_bool_t), - CassError::CASS_OK - ); - cass_cluster_set_execution_profile(cluster_raw.borrow_mut(), profile_name, profile_raw.borrow_mut()); - { - let cass_future = cass_session_connect(session_raw.borrow(), cluster_raw.borrow().into_c_const()); - - // This checks that we don't use-after-free the cluster inside the future. - cass_cluster_free(cluster_raw); - - cass_future_wait(cass_future.borrow()); - // The exact outcome is not important, we only test that we don't segfault. - } - cass_execution_profile_free(profile_raw); - cass_session_free(session_raw); - } - } - } - - #[tokio::test] - #[ntest::timeout(5000)] - async fn test_cass_session_get_client_id_on_disconnected_session() { - setup_tracing(); - test_with_one_proxy( - |node_addr: SocketAddr, proxy: RunningProxy| unsafe { - let session_raw = cass_session_new(); - - // Check that we can get a client ID from a disconnected session. - let _random_client_id = cass_session_get_client_id(session_raw.borrow()); - - let mut cluster_raw = cass_cluster_new(); - let ip = node_addr.ip().to_string(); - let (c_ip, c_ip_len) = str_to_c_str_n(ip.as_str()); - assert_cass_error_eq!( - cass_cluster_set_contact_points_n(cluster_raw.borrow_mut(), c_ip, c_ip_len), - CassError::CASS_OK - ); - - let cluster_client_id = CassUuid { - time_and_version: 2137, - clock_seq_and_node: 7312, - }; - cass_cluster_set_client_id(cluster_raw.borrow_mut(), cluster_client_id); - - let connect_fut = - cass_session_connect(session_raw.borrow(), cluster_raw.borrow().into_c_const()); - assert_cass_error_eq!(cass_future_error_code(connect_fut), CassError::CASS_OK); - - // Verify that the session inherits the client ID from the cluster. - let session_client_id = cass_session_get_client_id(session_raw.borrow()); - assert_eq!(session_client_id, cluster_client_id); - - // Verify that we can still get a client ID after disconnecting. - let session_client_id = cass_session_get_client_id(session_raw.borrow()); - assert_eq!(session_client_id, cluster_client_id); - - cass_session_free(session_raw); - cass_cluster_free(cluster_raw); - - proxy - }, - mock_init_rules(), - ) - .with_current_subscriber() - .await; - } - - #[tokio::test] - #[ntest::timeout(5000)] - async fn session_free_waits_for_requests_to_complete() { - setup_tracing(); - test_with_one_proxy( - session_free_waits_for_requests_to_complete_do, - mock_init_rules(), - ) - .with_current_subscriber() - .await; - } - - fn session_free_waits_for_requests_to_complete_do( - node_addr: SocketAddr, - proxy: RunningProxy, - ) -> RunningProxy { - unsafe { - let mut cluster_raw = cass_cluster_new(); - let ip = node_addr.ip().to_string(); - let (c_ip, c_ip_len) = str_to_c_str_n(ip.as_str()); - - assert_cass_error_eq!( - cass_cluster_set_contact_points_n(cluster_raw.borrow_mut(), c_ip, c_ip_len), - CassError::CASS_OK - ); - let session_raw = cass_session_new(); - cass_future_wait_check_and_free(cass_session_connect( - session_raw.borrow(), - cluster_raw.borrow().into_c_const(), - )); - - tracing::debug!("Session connected, starting to execute requests..."); - - let statement = c"SELECT host_id FROM system.local WHERE key='local'" as *const CStr - as *const c_char; - let statement_raw = cass_statement_new(statement, 0); - - let mut batch_raw = cass_batch_new(CassBatchType::CASS_BATCH_TYPE_LOGGED); - // This batch is obviously invalid, because it contains a SELECT statement. This is OK for us, - // because we anyway expect the batch to fail. The goal is to have the future set, no matter if it's - // set with a success or an error. - cass_batch_add_statement(batch_raw.borrow_mut(), statement_raw.borrow()); - - let finished_executions = AtomicUsize::new(0); - unsafe extern "C" fn finished_execution_callback( - _future_raw: CassBorrowedSharedPtr, - data: *mut c_void, - ) { - let finished_executions = unsafe { &*(data as *const AtomicUsize) }; - finished_executions.fetch_add(1, Ordering::SeqCst); - } - - const ITERATIONS: usize = 1; - const EXECUTIONS: usize = 3 * ITERATIONS; // One prepare, one statement and one batch per iteration. - - let futures = (0..ITERATIONS) - .flat_map(|_| { - // Prepare a statement - let prepare_fut = cass_session_prepare(session_raw.borrow(), statement); - - // Execute a statement - let statement_fut = cass_session_execute( - session_raw.borrow(), - statement_raw.borrow().into_c_const(), - ); - - // Execute a batch - let batch_fut = cass_session_execute_batch( - session_raw.borrow(), - batch_raw.borrow().into_c_const(), - ); - for fut in [ - prepare_fut.borrow(), - statement_fut.borrow(), - batch_fut.borrow(), - ] { - cass_future_set_callback( - fut, - Some(finished_execution_callback), - std::ptr::addr_of!(finished_executions) as _, - ); - } - - [prepare_fut, statement_fut, batch_fut] - }) - .collect::>(); - - tracing::debug!("Started all requests. Now, freeing statements and session..."); - - // Free the statement - cass_statement_free(statement_raw); - // Free the batch - cass_batch_free(batch_raw); - - // Session is freed, but the requests may still be in-flight. - cass_session_free(session_raw); - - tracing::debug!("Session freed."); - - // Assert that the session awaited completion of all requests. - let actually_finished_executions = finished_executions.load(Ordering::SeqCst); - assert_eq!( - actually_finished_executions, EXECUTIONS, - "Expected {} requests to complete before the session was freed, but only {} did.", - EXECUTIONS, actually_finished_executions - ); - - futures.into_iter().for_each(|fut| { - // As per cassandra.h, "a future can be freed anytime". - cass_future_free(fut); - }); - - cass_cluster_free(cluster_raw); - } - proxy - } } diff --git a/scylla-rust-wrapper/tests/integration/main.rs b/scylla-rust-wrapper/tests/integration/main.rs index 1a5c3dda..7186cbab 100644 --- a/scylla-rust-wrapper/tests/integration/main.rs +++ b/scylla-rust-wrapper/tests/integration/main.rs @@ -1,2 +1,3 @@ mod consistency; +mod session; mod utils; diff --git a/scylla-rust-wrapper/tests/integration/session.rs b/scylla-rust-wrapper/tests/integration/session.rs new file mode 100644 index 00000000..006e0ee1 --- /dev/null +++ b/scylla-rust-wrapper/tests/integration/session.rs @@ -0,0 +1,734 @@ +use std::{ + ffi::{CStr, c_void}, + sync::atomic::{AtomicUsize, Ordering}, +}; + +use libc::c_char; +use rusty_fork::rusty_fork_test; +use scylla::errors::DbError; +use scylla_cpp_driver::{ + api::{ + batch::{ + CassBatch, CassBatchType, cass_batch_add_statement, cass_batch_free, cass_batch_new, + cass_batch_set_execution_profile, cass_batch_set_retry_policy, + }, + cluster::{ + cass_cluster_free, cass_cluster_new, cass_cluster_set_client_id, + cass_cluster_set_contact_points, cass_cluster_set_contact_points_n, + cass_cluster_set_execution_profile, cass_cluster_set_latency_aware_routing, + cass_cluster_set_retry_policy, + }, + error::CassError, + execution_profile::{ + cass_execution_profile_free, cass_execution_profile_new, + cass_execution_profile_set_latency_aware_routing, + cass_execution_profile_set_retry_policy, + }, + future::{ + CassFuture, cass_future_error_code, cass_future_free, cass_future_set_callback, + cass_future_wait, + }, + retry_policy::{ + CassRetryPolicy, cass_retry_policy_default_new, cass_retry_policy_fallthrough_new, + }, + session::{ + CassSession, cass_session_close, cass_session_connect, cass_session_execute, + cass_session_execute_batch, cass_session_free, cass_session_get_client_id, + cass_session_new, cass_session_prepare, + }, + statement::{ + CassStatement, cass_statement_free, cass_statement_new, + cass_statement_set_execution_profile, cass_statement_set_retry_policy, + }, + uuid::CassUuid, + }, + argconv::{ArcFFI, CConst, CMut, CassBorrowedExclusivePtr, CassBorrowedSharedPtr}, + types::cass_bool_t, +}; +use scylla_cql::Consistency; +use scylla_proxy::{ + Condition, ProxyError, RequestOpcode, RequestReaction, RequestRule, RunningProxy, WorkerError, +}; +use tracing::instrument::WithSubscriber as _; + +use crate::utils::{ + assert_cass_error_eq, cass_future_wait_check_and_free, generic_drop_queries_rules, + handshake_rules, make_c_str, mock_init_rules, proxy_uris_to_contact_points, setup_tracing, + str_to_c_str_n, test_with_3_node_dry_mode_cluster, +}; + +#[tokio::test] +#[ntest::timeout(30000)] +async fn retry_policy_on_statement_and_batch_is_handled_properly() { + setup_tracing(); + let res = test_with_3_node_dry_mode_cluster( + retry_policy_on_statement_and_batch_is_handled_properly_rules, + retry_policy_on_statement_and_batch_is_handled_properly_do, + ) + .with_current_subscriber() + .await; + + match res { + Ok(()) => (), + Err(ProxyError::Worker(WorkerError::DriverDisconnected(_))) => (), + Err(err) => panic!("{}", err), + } +} + +fn retry_policy_on_statement_and_batch_is_handled_properly_rules() +-> impl IntoIterator { + handshake_rules() + .into_iter() + .chain(std::iter::once(RequestRule( + Condition::RequestOpcode(RequestOpcode::Query) + .or(Condition::RequestOpcode(RequestOpcode::Batch)) + .and(Condition::BodyContainsCaseInsensitive(Box::new( + *b"SELECT host_id FROM system.", + ))) + // this 1 differentiates Fallthrough and Default retry policies. + .and(Condition::TrueForLimitedTimes(1)), + // We simulate the read timeout error in order to trigger DefaultRetryPolicy's + // retry on the same node. + // We don't use the example ReadTimeout error that is included in proxy, + // because in order to trigger a retry we need data_present=false. + RequestReaction::forge_with_error(DbError::ReadTimeout { + consistency: Consistency::All, + received: 1, + required: 1, + data_present: false, + }), + ))) + .chain(std::iter::once(RequestRule( + Condition::RequestOpcode(RequestOpcode::Query) + .or(Condition::RequestOpcode(RequestOpcode::Batch)) + .and(Condition::BodyContainsCaseInsensitive(Box::new( + *b"SELECT host_id FROM system.", + ))), + // We make the second attempt return a hard, nonrecoverable error. + RequestReaction::forge().read_failure(), + ))) + .chain(generic_drop_queries_rules()) +} + +// This test aims to verify that the retry policy emulation works properly, +// in any sequence of actions mutating the retry policy for a query. +// +// Below, the consecutive states of the test case are illustrated: +// Retry policy set on: ('F' - Fallthrough, 'D' - Default, '-' - no policy set) +// session default exec profile: F F F F F F F F F F F F F F +// per stmt/batch exec profile: - D - - D D D D D - - - D D +// stmt/batch (emulated): - - - F F - F D F F - D D - +fn retry_policy_on_statement_and_batch_is_handled_properly_do( + proxy_uris: [String; 3], + mut proxy: RunningProxy, +) -> RunningProxy { + unsafe { + let mut cluster_raw = cass_cluster_new(); + let contact_points = proxy_uris_to_contact_points(proxy_uris); + + assert_cass_error_eq( + cass_cluster_set_contact_points(cluster_raw.borrow_mut(), contact_points.as_ptr()), + CassError::CASS_OK, + ); + + let fallthrough_policy = cass_retry_policy_fallthrough_new(); + let default_policy = cass_retry_policy_default_new(); + cass_cluster_set_retry_policy(cluster_raw.borrow_mut(), fallthrough_policy.borrow()); + + let session_raw = cass_session_new(); + + let mut profile_raw = cass_execution_profile_new(); + // A name of a profile that will have been registered in the Cluster. + let profile_name_c_str = make_c_str!("profile"); + + assert_cass_error_eq( + cass_execution_profile_set_retry_policy( + profile_raw.borrow_mut(), + default_policy.borrow(), + ), + CassError::CASS_OK, + ); + + let query = make_c_str!("SELECT host_id FROM system.local WHERE key='local'"); + let mut statement_raw = cass_statement_new(query, 0); + let mut batch_raw = cass_batch_new(CassBatchType::CASS_BATCH_TYPE_LOGGED); + assert_cass_error_eq( + cass_batch_add_statement(batch_raw.borrow_mut(), statement_raw.borrow()), + CassError::CASS_OK, + ); + + assert_cass_error_eq( + cass_cluster_set_execution_profile( + cluster_raw.borrow_mut(), + profile_name_c_str, + profile_raw.borrow_mut(), + ), + CassError::CASS_OK, + ); + + cass_future_wait_check_and_free(cass_session_connect( + session_raw.borrow(), + cluster_raw.borrow().into_c_const(), + )); + { + unsafe fn execute_query( + session_raw: CassBorrowedSharedPtr, + statement_raw: CassBorrowedSharedPtr, + ) -> CassError { + unsafe { + cass_future_error_code( + cass_session_execute(session_raw, statement_raw).borrow(), + ) + } + } + unsafe fn execute_batch( + session_raw: CassBorrowedSharedPtr, + batch_raw: CassBorrowedSharedPtr, + ) -> CassError { + unsafe { + cass_future_error_code( + cass_session_execute_batch(session_raw, batch_raw).borrow(), + ) + } + } + + fn reset_proxy_rules(proxy: &mut RunningProxy) { + proxy.running_nodes.iter_mut().for_each(|node| { + node.change_request_rules(Some( + retry_policy_on_statement_and_batch_is_handled_properly_rules() + .into_iter() + .collect(), + )) + }) + } + + unsafe fn assert_query_with_fallthrough_policy( + proxy: &mut RunningProxy, + session_raw: CassBorrowedSharedPtr, + statement_raw: CassBorrowedSharedPtr, + batch_raw: CassBorrowedSharedPtr, + ) { + reset_proxy_rules(&mut *proxy); + unsafe { + assert_cass_error_eq( + execute_query(session_raw.borrow(), statement_raw), + CassError::CASS_ERROR_SERVER_READ_TIMEOUT, + ); + reset_proxy_rules(&mut *proxy); + assert_cass_error_eq( + execute_batch(session_raw, batch_raw), + CassError::CASS_ERROR_SERVER_READ_TIMEOUT, + ); + } + } + + unsafe fn assert_query_with_default_policy( + proxy: &mut RunningProxy, + session_raw: CassBorrowedSharedPtr, + statement_raw: CassBorrowedSharedPtr, + batch_raw: CassBorrowedSharedPtr, + ) { + reset_proxy_rules(&mut *proxy); + unsafe { + assert_cass_error_eq( + execute_query(session_raw.borrow(), statement_raw), + CassError::CASS_ERROR_SERVER_READ_FAILURE, + ); + reset_proxy_rules(&mut *proxy); + assert_cass_error_eq( + execute_batch(session_raw, batch_raw), + CassError::CASS_ERROR_SERVER_READ_FAILURE, + ); + } + } + + unsafe fn set_provided_exec_profile( + name: *const i8, + statement_raw: CassBorrowedExclusivePtr, + batch_raw: CassBorrowedExclusivePtr, + ) { + // Set statement/batch exec profile. + unsafe { + assert_cass_error_eq( + cass_statement_set_execution_profile(statement_raw, name), + CassError::CASS_OK, + ); + assert_cass_error_eq( + cass_batch_set_execution_profile(batch_raw, name), + CassError::CASS_OK, + ); + } + } + unsafe fn set_exec_profile( + profile_name_c_str: *const c_char, + statement_raw: CassBorrowedExclusivePtr, + batch_raw: CassBorrowedExclusivePtr, + ) { + unsafe { set_provided_exec_profile(profile_name_c_str, statement_raw, batch_raw) }; + } + unsafe fn unset_exec_profile( + statement_raw: CassBorrowedExclusivePtr, + batch_raw: CassBorrowedExclusivePtr, + ) { + unsafe { + set_provided_exec_profile(std::ptr::null::(), statement_raw, batch_raw) + }; + } + unsafe fn set_retry_policy_on_stmt( + policy: CassBorrowedSharedPtr, + statement_raw: CassBorrowedExclusivePtr, + batch_raw: CassBorrowedExclusivePtr, + ) { + unsafe { + assert_cass_error_eq( + cass_statement_set_retry_policy(statement_raw, policy.borrow()), + CassError::CASS_OK, + ); + assert_cass_error_eq( + cass_batch_set_retry_policy(batch_raw, policy), + CassError::CASS_OK, + ); + } + } + unsafe fn unset_retry_policy_on_stmt( + statement_raw: CassBorrowedExclusivePtr, + batch_raw: CassBorrowedExclusivePtr, + ) { + unsafe { set_retry_policy_on_stmt(ArcFFI::null(), statement_raw, batch_raw) }; + } + + // ### START TESTING + + // With no exec profile nor retry policy set on statement/batch, + // the default cluster-wide retry policy should be used: in this case, fallthrough. + + // F - - + assert_query_with_fallthrough_policy( + &mut proxy, + session_raw.borrow(), + statement_raw.borrow().into_c_const(), + batch_raw.borrow().into_c_const(), + ); + + // F D - + set_exec_profile( + profile_name_c_str, + statement_raw.borrow_mut(), + batch_raw.borrow_mut(), + ); + assert_query_with_default_policy( + &mut proxy, + session_raw.borrow(), + statement_raw.borrow().into_c_const(), + batch_raw.borrow().into_c_const(), + ); + + // F - - + unset_exec_profile(statement_raw.borrow_mut(), batch_raw.borrow_mut()); + assert_query_with_fallthrough_policy( + &mut proxy, + session_raw.borrow(), + statement_raw.borrow().into_c_const(), + batch_raw.borrow().into_c_const(), + ); + + // F - F + set_retry_policy_on_stmt( + fallthrough_policy.borrow(), + statement_raw.borrow_mut(), + batch_raw.borrow_mut(), + ); + assert_query_with_fallthrough_policy( + &mut proxy, + session_raw.borrow(), + statement_raw.borrow().into_c_const(), + batch_raw.borrow().into_c_const(), + ); + + // F D F + set_exec_profile( + profile_name_c_str, + statement_raw.borrow_mut(), + batch_raw.borrow_mut(), + ); + assert_query_with_fallthrough_policy( + &mut proxy, + session_raw.borrow(), + statement_raw.borrow().into_c_const(), + batch_raw.borrow().into_c_const(), + ); + + // F D - + unset_retry_policy_on_stmt(statement_raw.borrow_mut(), batch_raw.borrow_mut()); + assert_query_with_default_policy( + &mut proxy, + session_raw.borrow(), + statement_raw.borrow().into_c_const(), + batch_raw.borrow().into_c_const(), + ); + + // F D F + set_retry_policy_on_stmt( + fallthrough_policy.borrow(), + statement_raw.borrow_mut(), + batch_raw.borrow_mut(), + ); + assert_query_with_fallthrough_policy( + &mut proxy, + session_raw.borrow(), + statement_raw.borrow().into_c_const(), + batch_raw.borrow().into_c_const(), + ); + + // F D D + set_retry_policy_on_stmt( + default_policy.borrow(), + statement_raw.borrow_mut(), + batch_raw.borrow_mut(), + ); + assert_query_with_default_policy( + &mut proxy, + session_raw.borrow(), + statement_raw.borrow().into_c_const(), + batch_raw.borrow().into_c_const(), + ); + + // F D F + set_retry_policy_on_stmt( + fallthrough_policy.borrow(), + statement_raw.borrow_mut(), + batch_raw.borrow_mut(), + ); + assert_query_with_fallthrough_policy( + &mut proxy, + session_raw.borrow(), + statement_raw.borrow().into_c_const(), + batch_raw.borrow().into_c_const(), + ); + + // F - F + unset_exec_profile(statement_raw.borrow_mut(), batch_raw.borrow_mut()); + assert_query_with_fallthrough_policy( + &mut proxy, + session_raw.borrow(), + statement_raw.borrow().into_c_const(), + batch_raw.borrow().into_c_const(), + ); + + // F - - + unset_retry_policy_on_stmt(statement_raw.borrow_mut(), batch_raw.borrow_mut()); + assert_query_with_fallthrough_policy( + &mut proxy, + session_raw.borrow(), + statement_raw.borrow().into_c_const(), + batch_raw.borrow().into_c_const(), + ); + + // F - D + set_retry_policy_on_stmt( + default_policy.borrow(), + statement_raw.borrow_mut(), + batch_raw.borrow_mut(), + ); + assert_query_with_default_policy( + &mut proxy, + session_raw.borrow(), + statement_raw.borrow().into_c_const(), + batch_raw.borrow().into_c_const(), + ); + + // F D D + set_exec_profile( + profile_name_c_str, + statement_raw.borrow_mut(), + batch_raw.borrow_mut(), + ); + assert_query_with_default_policy( + &mut proxy, + session_raw.borrow(), + statement_raw.borrow().into_c_const(), + batch_raw.borrow().into_c_const(), + ); + + // F D - + unset_retry_policy_on_stmt(statement_raw.borrow_mut(), batch_raw.borrow_mut()); + assert_query_with_default_policy( + &mut proxy, + session_raw.borrow(), + statement_raw.borrow().into_c_const(), + batch_raw.borrow().into_c_const(), + ); + } + + cass_future_wait_check_and_free(cass_session_close(session_raw.borrow())); + cass_execution_profile_free(profile_raw); + cass_statement_free(statement_raw); + cass_batch_free(batch_raw); + cass_session_free(session_raw); + cass_cluster_free(cluster_raw); + } + + proxy +} + +#[test] +#[ntest::timeout(5000)] +fn session_with_latency_aware_load_balancing_does_not_panic() { + unsafe { + let mut cluster_raw = cass_cluster_new(); + + // An IP with very little chance of having a ScyllaDB node listening + let ip = "127.0.1.231"; + let (c_ip, c_ip_len) = str_to_c_str_n(ip); + + assert_cass_error_eq( + cass_cluster_set_contact_points_n(cluster_raw.borrow_mut(), c_ip, c_ip_len), + CassError::CASS_OK, + ); + cass_cluster_set_latency_aware_routing(cluster_raw.borrow_mut(), true as cass_bool_t); + let session_raw = cass_session_new(); + let mut profile_raw = cass_execution_profile_new(); + assert_cass_error_eq( + cass_execution_profile_set_latency_aware_routing( + profile_raw.borrow_mut(), + true as cass_bool_t, + ), + CassError::CASS_OK, + ); + let profile_name = make_c_str!("latency_aware"); + cass_cluster_set_execution_profile( + cluster_raw.borrow_mut(), + profile_name, + profile_raw.borrow_mut(), + ); + { + let cass_future = + cass_session_connect(session_raw.borrow(), cluster_raw.borrow().into_c_const()); + cass_future_wait(cass_future.borrow()); + // The exact outcome is not important, we only test that we don't panic. + } + cass_execution_profile_free(profile_raw); + cass_session_free(session_raw); + cass_cluster_free(cluster_raw); + } +} + +rusty_fork_test! { + #![rusty_fork(timeout_ms = 1000)] + #[test] + fn cluster_is_not_referenced_by_session_connect_future() { + // An IP with very little chance of having a ScyllaDB node listening + let ip = "127.0.1.231"; + let (c_ip, c_ip_len) = str_to_c_str_n(ip); + let profile_name = make_c_str!("latency_aware"); + + unsafe { + let mut cluster_raw = cass_cluster_new(); + + assert_cass_error_eq( + cass_cluster_set_contact_points_n(cluster_raw.borrow_mut(), c_ip, c_ip_len), + CassError::CASS_OK + ); + cass_cluster_set_latency_aware_routing(cluster_raw.borrow_mut(), true as cass_bool_t); + let session_raw = cass_session_new(); + let mut profile_raw = cass_execution_profile_new(); + assert_cass_error_eq( + cass_execution_profile_set_latency_aware_routing(profile_raw.borrow_mut(), true as cass_bool_t), + CassError::CASS_OK + ); + cass_cluster_set_execution_profile(cluster_raw.borrow_mut(), profile_name, profile_raw.borrow_mut()); + { + let cass_future = cass_session_connect(session_raw.borrow(), cluster_raw.borrow().into_c_const()); + + // This checks that we don't use-after-free the cluster inside the future. + cass_cluster_free(cluster_raw); + + cass_future_wait(cass_future.borrow()); + // The exact outcome is not important, we only test that we don't segfault. + } + cass_execution_profile_free(profile_raw); + cass_session_free(session_raw); + } + } +} + +#[tokio::test] +#[ntest::timeout(5000)] +async fn test_cass_session_get_client_id_on_disconnected_session() { + setup_tracing(); + let res = test_with_3_node_dry_mode_cluster( + mock_init_rules, + |proxy_uris: [String; 3], proxy: RunningProxy| { + unsafe { + let session_raw = cass_session_new(); + + // Check that we can get a client ID from a disconnected session. + let _random_client_id = cass_session_get_client_id(session_raw.borrow()); + + let mut cluster_raw = cass_cluster_new(); + let contact_points = proxy_uris_to_contact_points(proxy_uris); + assert_cass_error_eq( + cass_cluster_set_contact_points( + cluster_raw.borrow_mut(), + contact_points.as_ptr(), + ), + CassError::CASS_OK, + ); + + let cluster_client_id = CassUuid { + time_and_version: 2137, + clock_seq_and_node: 7312, + }; + cass_cluster_set_client_id(cluster_raw.borrow_mut(), cluster_client_id); + + let connect_fut = + cass_session_connect(session_raw.borrow(), cluster_raw.borrow().into_c_const()); + assert_cass_error_eq(cass_future_error_code(connect_fut), CassError::CASS_OK); + + // Verify that the session inherits the client ID from the cluster. + let session_client_id = cass_session_get_client_id(session_raw.borrow()); + assert_eq!(session_client_id, cluster_client_id); + + // Verify that we can still get a client ID after disconnecting. + let session_client_id = cass_session_get_client_id(session_raw.borrow()); + assert_eq!(session_client_id, cluster_client_id); + + cass_session_free(session_raw); + cass_cluster_free(cluster_raw); + } + + proxy + }, + ) + .with_current_subscriber() + .await; + + match res { + Ok(()) => (), + Err(ProxyError::Worker(WorkerError::DriverDisconnected(_))) => (), + Err(err) => panic!("{}", err), + } +} + +#[tokio::test] +#[ntest::timeout(50000)] +async fn session_free_waits_for_requests_to_complete() { + setup_tracing(); + let res = test_with_3_node_dry_mode_cluster( + mock_init_rules, + session_free_waits_for_requests_to_complete_do, + ) + .with_current_subscriber() + .await; + + match res { + Ok(()) => (), + Err(ProxyError::Worker(WorkerError::DriverDisconnected(_))) => (), + Err(err) => panic!("{}", err), + } +} + +fn session_free_waits_for_requests_to_complete_do( + proxy_uris: [String; 3], + proxy: RunningProxy, +) -> RunningProxy { + unsafe { + let mut cluster_raw = cass_cluster_new(); + let contact_points = proxy_uris_to_contact_points(proxy_uris); + + assert_cass_error_eq( + cass_cluster_set_contact_points(cluster_raw.borrow_mut(), contact_points.as_ptr()), + CassError::CASS_OK, + ); + let session_raw = cass_session_new(); + cass_future_wait_check_and_free(cass_session_connect( + session_raw.borrow(), + cluster_raw.borrow().into_c_const(), + )); + + tracing::debug!("Session connected, starting to execute requests..."); + + let statement = + c"SELECT host_id FROM system.local WHERE key='local'" as *const CStr as *const c_char; + let statement_raw = cass_statement_new(statement, 0); + + let mut batch_raw = cass_batch_new(CassBatchType::CASS_BATCH_TYPE_LOGGED); + // This batch is obviously invalid, because it contains a SELECT statement. This is OK for us, + // because we anyway expect the batch to fail. The goal is to have the future set, no matter if it's + // set with a success or an error. + cass_batch_add_statement(batch_raw.borrow_mut(), statement_raw.borrow()); + + let finished_executions = AtomicUsize::new(0); + unsafe extern "C" fn finished_execution_callback( + _future_raw: CassBorrowedSharedPtr, + data: *mut c_void, + ) { + let finished_executions = unsafe { &*(data as *const AtomicUsize) }; + finished_executions.fetch_add(1, Ordering::SeqCst); + } + + const ITERATIONS: usize = 1; + const EXECUTIONS: usize = 3 * ITERATIONS; // One prepare, one statement and one batch per iteration. + + let futures = (0..ITERATIONS) + .flat_map(|_| { + // Prepare a statement + let prepare_fut = cass_session_prepare(session_raw.borrow(), statement); + + // Execute a statement + let statement_fut = cass_session_execute( + session_raw.borrow(), + statement_raw.borrow().into_c_const(), + ); + + // Execute a batch + let batch_fut = cass_session_execute_batch( + session_raw.borrow(), + batch_raw.borrow().into_c_const(), + ); + for fut in [ + prepare_fut.borrow(), + statement_fut.borrow(), + batch_fut.borrow(), + ] { + cass_future_set_callback( + fut, + Some(finished_execution_callback), + std::ptr::addr_of!(finished_executions) as _, + ); + } + + [prepare_fut, statement_fut, batch_fut] + }) + .collect::>(); + + tracing::debug!("Started all requests. Now, freeing statements and session..."); + + // Free the statement + cass_statement_free(statement_raw); + // Free the batch + cass_batch_free(batch_raw); + + // Session is freed, but the requests may still be in-flight. + cass_session_free(session_raw); + + tracing::debug!("Session freed."); + + // Assert that the session awaited completion of all requests. + let actually_finished_executions = finished_executions.load(Ordering::SeqCst); + assert_eq!( + actually_finished_executions, EXECUTIONS, + "Expected {} requests to complete before the session was freed, but only {} did.", + EXECUTIONS, actually_finished_executions + ); + + futures.into_iter().for_each(|fut| { + // As per cassandra.h, "a future can be freed anytime". + cass_future_free(fut); + }); + + cass_cluster_free(cluster_raw); + } + + proxy +} diff --git a/scylla-rust-wrapper/tests/integration/utils.rs b/scylla-rust-wrapper/tests/integration/utils.rs index d2f9859f..87e2d7f7 100644 --- a/scylla-rust-wrapper/tests/integration/utils.rs +++ b/scylla-rust-wrapper/tests/integration/utils.rs @@ -25,6 +25,30 @@ pub(crate) fn setup_tracing() { .try_init(); } +unsafe fn write_str_to_c(s: &str, c_str: *mut *const c_char, c_strlen: *mut size_t) { + unsafe { + *c_str = s.as_ptr() as *const c_char; + *c_strlen = s.len() as u64; + } +} + +pub(crate) fn str_to_c_str_n(s: &str) -> (*const c_char, size_t) { + let mut c_str = std::ptr::null(); + let mut c_strlen = size_t::default(); + + // SAFETY: The pointers that are passed to `write_str_to_c` are compile-checked references. + unsafe { write_str_to_c(s, &mut c_str, &mut c_strlen) }; + + (c_str, c_strlen) +} + +macro_rules! make_c_str { + ($str:literal) => { + concat!($str, "\0").as_ptr() as *const c_char + }; +} +pub(crate) use make_c_str; + pub(crate) async fn test_with_3_node_dry_mode_cluster( initial_request_rules: impl Fn() -> I, test: F, @@ -58,6 +82,7 @@ where running_proxy.finish().await } +#[track_caller] pub(crate) fn assert_cass_error_eq(errcode1: CassError, errcode2: CassError) { unsafe { assert_eq!( @@ -130,6 +155,33 @@ pub(crate) fn handshake_rules() -> impl IntoIterator { ] } +// As these are very generic, they should be put last in the rules Vec. +pub(crate) fn generic_drop_queries_rules() -> impl IntoIterator { + [RequestRule( + Condition::RequestOpcode(RequestOpcode::Query), + // We won't respond to any queries (including metadata fetch), + // but the driver will manage to continue with dummy metadata. + RequestReaction::forge().server_error(), + )] +} + +/// A set of rules that are needed to finish session initialization. +// They are used in tests that require a session to be connected. +// All connections are successfully negotiated. +// All requests are replied with a server error. +pub(crate) fn mock_init_rules() -> impl IntoIterator { + handshake_rules() + .into_iter() + .chain(std::iter::once(RequestRule( + Condition::RequestOpcode(RequestOpcode::Query) + .or(Condition::RequestOpcode(RequestOpcode::Prepare)) + .or(Condition::RequestOpcode(RequestOpcode::Batch)), + // We won't respond to any queries (including metadata fetch), + // but the driver will manage to continue with dummy metadata. + RequestReaction::forge().server_error(), + ))) +} + pub(crate) fn drop_metadata_queries_rules() -> impl IntoIterator { [RequestRule( Condition::ConnectionRegisteredAnyEvent.and(Condition::or(