From 61ef25bc2664fb6fc65dbac86cf9909b4867f4ae Mon Sep 17 00:00:00 2001 From: jkomyno Date: Fri, 10 Nov 2023 13:49:23 +0100 Subject: [PATCH 001/134] feat(quaint): allow wasm32-unknown-unknown compilation; currently fails on native --- Cargo.lock | 1 + Cargo.toml | 1 + quaint/Cargo.toml | 41 +- quaint/src/connector.rs | 50 ++- quaint/src/connector/mssql.rs | 363 +-------------- quaint/src/connector/mssql_wasm.rs | 383 ++++++++++++++++ quaint/src/connector/mysql.rs | 297 +------------ quaint/src/connector/mysql_wasm.rs | 318 +++++++++++++ quaint/src/connector/postgres.rs | 423 +----------------- quaint/src/connector/postgres_wasm.rs | 612 ++++++++++++++++++++++++++ quaint/src/connector/sqlite.rs | 104 +---- quaint/src/connector/sqlite_wasm.rs | 103 +++++ quaint/src/error.rs | 6 +- quaint/src/pooled/manager.rs | 30 +- quaint/src/single.rs | 10 +- 15 files changed, 1533 insertions(+), 1209 deletions(-) create mode 100644 quaint/src/connector/mssql_wasm.rs create mode 100644 quaint/src/connector/mysql_wasm.rs create mode 100644 quaint/src/connector/postgres_wasm.rs create mode 100644 quaint/src/connector/sqlite_wasm.rs diff --git a/Cargo.lock b/Cargo.lock index 35eff530999a..ff8323e356e5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3570,6 +3570,7 @@ dependencies = [ "connection-string", "either", "futures", + "getrandom 0.2.10", "hex", "indoc 0.3.6", "lru-cache", diff --git a/Cargo.toml b/Cargo.toml index 4a3cd1450caf..66f4399ff6db 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -68,6 +68,7 @@ features = [ "pooled", "postgresql", "sqlite", + "connectors", ] [profile.dev.package.backtrace] diff --git a/quaint/Cargo.toml b/quaint/Cargo.toml index b699518d0910..2da9ec0929c0 100644 --- a/quaint/Cargo.toml +++ b/quaint/Cargo.toml @@ -23,20 +23,28 @@ resolver = "2" features = ["docs", "all"] [features] -default = [] +default = ["mysql", "postgresql", "mssql", "sqlite"] docs = [] # Expose the underlying database drivers when a connector is enabled. This is a # way to access database-specific methods when you need extra control. 
expose-drivers = [] -all = ["mssql", "mysql", "pooled", "postgresql", "sqlite"] +connectors = [ + "postgresql-connector", + "mysql-connector", + "mssql-connector", + "sqlite-connector", +] + +all = ["connectors", "pooled"] vendored-openssl = [ "postgres-native-tls/vendored-openssl", "mysql_async/vendored-openssl", ] -postgresql = [ +postgresql-connector = [ + "postgresql", "native-tls", "tokio-postgres", "postgres-types", @@ -47,11 +55,24 @@ postgresql = [ "lru-cache", "byteorder", ] +postgresql = [] + +mssql-connector = [ + "mssql", + "tiberius", + "tokio-util", + "tokio/time", + "tokio/net", +] +mssql = [] + +mysql-connector = ["mysql", "mysql_async", "tokio/time", "lru-cache"] +mysql = ["chrono/std"] -mssql = ["tiberius", "tokio-util", "tokio/time", "tokio/net", "either"] -mysql = ["mysql_async", "tokio/time", "lru-cache"] pooled = ["mobc"] -sqlite = ["rusqlite", "tokio/sync"] +sqlite-connector = ["sqlite", "rusqlite", "tokio/sync"] +sqlite = [] + fmt-sql = ["sqlformat"] [dependencies] @@ -67,7 +88,7 @@ futures = "0.3" url = "2.1" hex = "0.4" -either = { version = "1.6", optional = true } +either = { version = "1.6" } base64 = { version = "0.12.3" } chrono = { version = "0.4", default-features = false, features = ["serde"] } lru-cache = { version = "0.1", optional = true } @@ -88,7 +109,11 @@ paste = "1.0" serde = { version = "1.0", features = ["derive"] } quaint-test-macros = { path = "quaint-test-macros" } quaint-test-setup = { path = "quaint-test-setup" } -tokio = { version = "1.0", features = ["rt-multi-thread", "macros", "time"] } +tokio = { version = "1.0", features = ["macros", "time"] } + +[target.'cfg(target_arch = "wasm32")'.dependencies.getrandom] +version = "0.2" +features = ["js"] [dependencies.byteorder] default-features = false diff --git a/quaint/src/connector.rs b/quaint/src/connector.rs index de8bc64d22bb..898aac8fcb46 100644 --- a/quaint/src/connector.rs +++ b/quaint/src/connector.rs @@ -10,36 +10,62 @@ //! querying interface. 
mod connection_info; + pub mod metrics; mod queryable; mod result_set; -#[cfg(any(feature = "mssql", feature = "postgresql", feature = "mysql"))] +#[cfg(any( + feature = "mssql-connector", + feature = "postgresql-connector", + feature = "mysql-connector" +))] mod timeout; mod transaction; mod type_identifier; -#[cfg(feature = "mssql")] +#[cfg(feature = "mssql-connector")] pub(crate) mod mssql; -#[cfg(feature = "mysql")] +#[cfg(feature = "mssql")] +pub(crate) mod mssql_wasm; +#[cfg(feature = "mysql-connector")] pub(crate) mod mysql; -#[cfg(feature = "postgresql")] +#[cfg(feature = "mysql")] +pub(crate) mod mysql_wasm; +#[cfg(feature = "postgresql-connector")] pub(crate) mod postgres; -#[cfg(feature = "sqlite")] +#[cfg(feature = "postgresql")] +pub(crate) mod postgres_wasm; +#[cfg(feature = "sqlite-connector")] pub(crate) mod sqlite; +#[cfg(feature = "sqlite")] +pub(crate) mod sqlite_wasm; -#[cfg(feature = "mysql")] +#[cfg(feature = "mysql-connector")] pub use self::mysql::*; -#[cfg(feature = "postgresql")] +#[cfg(feature = "mysql")] +pub use self::mysql_wasm::*; +#[cfg(feature = "postgresql-connector")] pub use self::postgres::*; +#[cfg(feature = "postgresql")] +pub use self::postgres_wasm::*; +#[cfg(feature = "mssql-connector")] +pub use mssql::*; +#[cfg(feature = "mssql")] +pub use mssql_wasm::*; +#[cfg(feature = "sqlite-connector")] +pub use sqlite::*; +#[cfg(feature = "sqlite")] +pub use sqlite_wasm::*; + pub use self::result_set::*; pub use connection_info::*; -#[cfg(feature = "mssql")] -pub use mssql::*; pub use queryable::*; -#[cfg(feature = "sqlite")] -pub use sqlite::*; pub use transaction::*; -#[cfg(any(feature = "sqlite", feature = "mysql", feature = "postgresql"))] +#[cfg(any( + feature = "mssql-connector", + feature = "postgresql-connector", + feature = "mysql-connector" +))] #[allow(unused_imports)] pub(crate) use type_identifier::*; diff --git a/quaint/src/connector/mssql.rs b/quaint/src/connector/mssql.rs index cef092edb9d7..16c31551768c 100644 --- a/quaint/src/connector/mssql.rs +++ b/quaint/src/connector/mssql.rs @@ -1,21 +1,19 @@ mod conversion; mod error; +pub(crate) use super::mssql_wasm::MssqlUrl; use super::{IsolationLevel, Transaction, TransactionOptions}; use crate::{ ast::{Query, Value}, connector::{metrics, queryable::*, DefaultTransaction, ResultSet}, - error::{Error, ErrorKind}, visitor::{self, Visitor}, }; use async_trait::async_trait; -use connection_string::JdbcString; use futures::lock::Mutex; use std::{ convert::TryFrom, fmt, future::Future, - str::FromStr, sync::atomic::{AtomicBool, Ordering}, time::Duration, }; @@ -27,69 +25,6 @@ use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt}; #[cfg(feature = "expose-drivers")] pub use tiberius; -/// Wraps a connection url and exposes the parsing logic used by Quaint, -/// including default values. -#[derive(Debug, Clone)] -pub struct MssqlUrl { - connection_string: String, - query_params: MssqlQueryParams, -} - -/// TLS mode when connecting to SQL Server. -#[derive(Debug, Clone, Copy)] -pub enum EncryptMode { - /// All traffic is encrypted. - On, - /// Only the login credentials are encrypted. - Off, - /// Nothing is encrypted. 
- DangerPlainText, -} - -impl fmt::Display for EncryptMode { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::On => write!(f, "true"), - Self::Off => write!(f, "false"), - Self::DangerPlainText => write!(f, "DANGER_PLAINTEXT"), - } - } -} - -impl FromStr for EncryptMode { - type Err = Error; - - fn from_str(s: &str) -> crate::Result { - let mode = match s.parse::() { - Ok(true) => Self::On, - _ if s == "DANGER_PLAINTEXT" => Self::DangerPlainText, - _ => Self::Off, - }; - - Ok(mode) - } -} - -#[derive(Debug, Clone)] -pub(crate) struct MssqlQueryParams { - encrypt: EncryptMode, - port: Option, - host: Option, - user: Option, - password: Option, - database: String, - schema: String, - trust_server_certificate: bool, - trust_server_certificate_ca: Option, - connection_limit: Option, - socket_timeout: Option, - connect_timeout: Option, - pool_timeout: Option, - transaction_isolation_level: Option, - max_connection_lifetime: Option, - max_idle_connection_lifetime: Option, -} - static SQL_SERVER_DEFAULT_ISOLATION: IsolationLevel = IsolationLevel::ReadCommitted; #[async_trait] @@ -114,158 +49,6 @@ impl TransactionCapable for Mssql { } } -impl MssqlUrl { - /// Maximum number of connections the pool can have (if used together with - /// pooled Quaint). - pub fn connection_limit(&self) -> Option { - self.query_params.connection_limit() - } - - /// A duration how long one query can take. - pub fn socket_timeout(&self) -> Option { - self.query_params.socket_timeout() - } - - /// A duration how long we can try to connect to the database. - pub fn connect_timeout(&self) -> Option { - self.query_params.connect_timeout() - } - - /// A pool check_out timeout. - pub fn pool_timeout(&self) -> Option { - self.query_params.pool_timeout() - } - - /// The isolation level of a transaction. - fn transaction_isolation_level(&self) -> Option { - self.query_params.transaction_isolation_level - } - - /// Name of the database. - pub fn dbname(&self) -> &str { - self.query_params.database() - } - - /// The prefix which to use when querying database. - pub fn schema(&self) -> &str { - self.query_params.schema() - } - - /// Database hostname. - pub fn host(&self) -> &str { - self.query_params.host() - } - - /// The username to use when connecting to the database. - pub fn username(&self) -> Option<&str> { - self.query_params.user() - } - - /// The password to use when connecting to the database. - pub fn password(&self) -> Option<&str> { - self.query_params.password() - } - - /// The TLS mode to use when connecting to the database. - pub fn encrypt(&self) -> EncryptMode { - self.query_params.encrypt() - } - - /// If true, we allow invalid certificates (self-signed, or otherwise - /// dangerous) when connecting. Should be true only for development and - /// testing. - pub fn trust_server_certificate(&self) -> bool { - self.query_params.trust_server_certificate() - } - - /// Path to a custom server certificate file. - pub fn trust_server_certificate_ca(&self) -> Option<&str> { - self.query_params.trust_server_certificate_ca() - } - - /// Database port. 
- pub fn port(&self) -> u16 { - self.query_params.port() - } - - /// The JDBC connection string - pub fn connection_string(&self) -> &str { - &self.connection_string - } - - /// The maximum connection lifetime - pub fn max_connection_lifetime(&self) -> Option { - self.query_params.max_connection_lifetime() - } - - /// The maximum idle connection lifetime - pub fn max_idle_connection_lifetime(&self) -> Option { - self.query_params.max_idle_connection_lifetime() - } -} - -impl MssqlQueryParams { - fn port(&self) -> u16 { - self.port.unwrap_or(1433) - } - - fn host(&self) -> &str { - self.host.as_deref().unwrap_or("localhost") - } - - fn user(&self) -> Option<&str> { - self.user.as_deref() - } - - fn password(&self) -> Option<&str> { - self.password.as_deref() - } - - fn encrypt(&self) -> EncryptMode { - self.encrypt - } - - fn trust_server_certificate(&self) -> bool { - self.trust_server_certificate - } - - fn trust_server_certificate_ca(&self) -> Option<&str> { - self.trust_server_certificate_ca.as_deref() - } - - fn database(&self) -> &str { - &self.database - } - - fn schema(&self) -> &str { - &self.schema - } - - fn socket_timeout(&self) -> Option { - self.socket_timeout - } - - fn connect_timeout(&self) -> Option { - self.connect_timeout - } - - fn connection_limit(&self) -> Option { - self.connection_limit - } - - fn pool_timeout(&self) -> Option { - self.pool_timeout - } - - fn max_connection_lifetime(&self) -> Option { - self.max_connection_lifetime - } - - fn max_idle_connection_lifetime(&self) -> Option { - self.max_idle_connection_lifetime - } -} - /// A connector interface for the SQL Server database. #[derive(Debug)] pub struct Mssql { @@ -452,150 +235,6 @@ impl Queryable for Mssql { } } -impl MssqlUrl { - pub fn new(jdbc_connection_string: &str) -> crate::Result { - let query_params = Self::parse_query_params(jdbc_connection_string)?; - let connection_string = Self::with_jdbc_prefix(jdbc_connection_string); - - Ok(Self { - connection_string, - query_params, - }) - } - - fn with_jdbc_prefix(input: &str) -> String { - if input.starts_with("jdbc:sqlserver") { - input.into() - } else { - format!("jdbc:{input}") - } - } - - fn parse_query_params(input: &str) -> crate::Result { - let mut conn = JdbcString::from_str(&Self::with_jdbc_prefix(input))?; - - let host = conn.server_name().map(|server_name| match conn.instance_name() { - Some(instance_name) => format!(r#"{server_name}\{instance_name}"#), - None => server_name.to_string(), - }); - - let port = conn.port(); - let props = conn.properties_mut(); - let user = props.remove("user"); - let password = props.remove("password"); - let database = props.remove("database").unwrap_or_else(|| String::from("master")); - let schema = props.remove("schema").unwrap_or_else(|| String::from("dbo")); - - let connection_limit = props - .remove("connectionlimit") - .or_else(|| props.remove("connection_limit")) - .map(|param| param.parse()) - .transpose()?; - - let transaction_isolation_level = props - .remove("isolationlevel") - .or_else(|| props.remove("isolation_level")) - .map(|level| { - IsolationLevel::from_str(&level).map_err(|_| { - let kind = ErrorKind::database_url_is_invalid(format!("Invalid isolation level `{level}`")); - Error::builder(kind).build() - }) - }) - .transpose()?; - - let mut connect_timeout = props - .remove("logintimeout") - .or_else(|| props.remove("login_timeout")) - .or_else(|| props.remove("connecttimeout")) - .or_else(|| props.remove("connect_timeout")) - .or_else(|| props.remove("connectiontimeout")) - .or_else(|| 
props.remove("connection_timeout")) - .map(|param| param.parse().map(Duration::from_secs)) - .transpose()?; - - match connect_timeout { - None => connect_timeout = Some(Duration::from_secs(5)), - Some(dur) if dur.as_secs() == 0 => connect_timeout = None, - _ => (), - } - - let mut pool_timeout = props - .remove("pooltimeout") - .or_else(|| props.remove("pool_timeout")) - .map(|param| param.parse().map(Duration::from_secs)) - .transpose()?; - - match pool_timeout { - None => pool_timeout = Some(Duration::from_secs(10)), - Some(dur) if dur.as_secs() == 0 => pool_timeout = None, - _ => (), - } - - let socket_timeout = props - .remove("sockettimeout") - .or_else(|| props.remove("socket_timeout")) - .map(|param| param.parse().map(Duration::from_secs)) - .transpose()?; - - let encrypt = props - .remove("encrypt") - .map(|param| EncryptMode::from_str(¶m)) - .transpose()? - .unwrap_or(EncryptMode::On); - - let trust_server_certificate = props - .remove("trustservercertificate") - .or_else(|| props.remove("trust_server_certificate")) - .map(|param| param.parse()) - .transpose()? - .unwrap_or(false); - - let trust_server_certificate_ca: Option = props - .remove("trustservercertificateca") - .or_else(|| props.remove("trust_server_certificate_ca")); - - let mut max_connection_lifetime = props - .remove("max_connection_lifetime") - .map(|param| param.parse().map(Duration::from_secs)) - .transpose()?; - - match max_connection_lifetime { - Some(dur) if dur.as_secs() == 0 => max_connection_lifetime = None, - _ => (), - } - - let mut max_idle_connection_lifetime = props - .remove("max_idle_connection_lifetime") - .map(|param| param.parse().map(Duration::from_secs)) - .transpose()?; - - match max_idle_connection_lifetime { - None => max_idle_connection_lifetime = Some(Duration::from_secs(300)), - Some(dur) if dur.as_secs() == 0 => max_idle_connection_lifetime = None, - _ => (), - } - - Ok(MssqlQueryParams { - encrypt, - port, - host, - user, - password, - database, - schema, - trust_server_certificate, - trust_server_certificate_ca, - connection_limit, - socket_timeout, - connect_timeout, - pool_timeout, - transaction_isolation_level, - max_connection_lifetime, - max_idle_connection_lifetime, - }) - } -} - #[cfg(test)] mod tests { use crate::tests::test_api::mssql::CONN_STR; diff --git a/quaint/src/connector/mssql_wasm.rs b/quaint/src/connector/mssql_wasm.rs new file mode 100644 index 000000000000..d9f7dc27865b --- /dev/null +++ b/quaint/src/connector/mssql_wasm.rs @@ -0,0 +1,383 @@ +#![cfg_attr(target_arch = "wasm32", allow(dead_code))] + +use super::IsolationLevel; + +use crate::error::{Error, ErrorKind}; +use connection_string::JdbcString; +use std::{fmt, str::FromStr, time::Duration}; + +/// Wraps a connection url and exposes the parsing logic used by Quaint, +/// including default values. +#[derive(Debug, Clone)] +pub struct MssqlUrl { + pub(super) connection_string: String, + pub(super) query_params: MssqlQueryParams, +} + +/// TLS mode when connecting to SQL Server. +#[derive(Debug, Clone, Copy)] +pub enum EncryptMode { + /// All traffic is encrypted. + On, + /// Only the login credentials are encrypted. + Off, + /// Nothing is encrypted. 
+    DangerPlainText,
+}
+
+impl fmt::Display for EncryptMode {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        match self {
+            Self::On => write!(f, "true"),
+            Self::Off => write!(f, "false"),
+            Self::DangerPlainText => write!(f, "DANGER_PLAINTEXT"),
+        }
+    }
+}
+
+impl FromStr for EncryptMode {
+    type Err = Error;
+
+    fn from_str(s: &str) -> crate::Result<Self> {
+        let mode = match s.parse::<bool>() {
+            Ok(true) => Self::On,
+            _ if s == "DANGER_PLAINTEXT" => Self::DangerPlainText,
+            _ => Self::Off,
+        };
+
+        Ok(mode)
+    }
+}
+
+#[derive(Debug, Clone)]
+pub(crate) struct MssqlQueryParams {
+    pub(super) encrypt: EncryptMode,
+    pub(super) port: Option<u16>,
+    pub(super) host: Option<String>,
+    pub(super) user: Option<String>,
+    pub(super) password: Option<String>,
+    pub(super) database: String,
+    pub(super) schema: String,
+    pub(super) trust_server_certificate: bool,
+    pub(super) trust_server_certificate_ca: Option<String>,
+    pub(super) connection_limit: Option<usize>,
+    pub(super) socket_timeout: Option<Duration>,
+    pub(super) connect_timeout: Option<Duration>,
+    pub(super) pool_timeout: Option<Duration>,
+    pub(super) transaction_isolation_level: Option<IsolationLevel>,
+    pub(super) max_connection_lifetime: Option<Duration>,
+    pub(super) max_idle_connection_lifetime: Option<Duration>,
+}
+
+impl MssqlUrl {
+    /// Maximum number of connections the pool can have (if used together with
+    /// pooled Quaint).
+    pub fn connection_limit(&self) -> Option<usize> {
+        self.query_params.connection_limit()
+    }
+
+    /// A duration how long one query can take.
+    pub fn socket_timeout(&self) -> Option<Duration> {
+        self.query_params.socket_timeout()
+    }
+
+    /// A duration how long we can try to connect to the database.
+    pub fn connect_timeout(&self) -> Option<Duration> {
+        self.query_params.connect_timeout()
+    }
+
+    /// A pool check_out timeout.
+    pub fn pool_timeout(&self) -> Option<Duration> {
+        self.query_params.pool_timeout()
+    }
+
+    /// The isolation level of a transaction.
+    pub(crate) fn transaction_isolation_level(&self) -> Option<IsolationLevel> {
+        self.query_params.transaction_isolation_level
+    }
+
+    /// Name of the database.
+    pub fn dbname(&self) -> &str {
+        self.query_params.database()
+    }
+
+    /// The prefix which to use when querying database.
+    pub fn schema(&self) -> &str {
+        self.query_params.schema()
+    }
+
+    /// Database hostname.
+    pub fn host(&self) -> &str {
+        self.query_params.host()
+    }
+
+    /// The username to use when connecting to the database.
+    pub fn username(&self) -> Option<&str> {
+        self.query_params.user()
+    }
+
+    /// The password to use when connecting to the database.
+    pub fn password(&self) -> Option<&str> {
+        self.query_params.password()
+    }
+
+    /// The TLS mode to use when connecting to the database.
+    pub fn encrypt(&self) -> EncryptMode {
+        self.query_params.encrypt()
+    }
+
+    /// If true, we allow invalid certificates (self-signed, or otherwise
+    /// dangerous) when connecting. Should be true only for development and
+    /// testing.
+    pub fn trust_server_certificate(&self) -> bool {
+        self.query_params.trust_server_certificate()
+    }
+
+    /// Path to a custom server certificate file.
+    pub fn trust_server_certificate_ca(&self) -> Option<&str> {
+        self.query_params.trust_server_certificate_ca()
+    }
+
+    /// Database port.
+ pub fn port(&self) -> u16 { + self.query_params.port() + } + + /// The JDBC connection string + pub fn connection_string(&self) -> &str { + &self.connection_string + } + + /// The maximum connection lifetime + pub fn max_connection_lifetime(&self) -> Option { + self.query_params.max_connection_lifetime() + } + + /// The maximum idle connection lifetime + pub fn max_idle_connection_lifetime(&self) -> Option { + self.query_params.max_idle_connection_lifetime() + } +} + +impl MssqlQueryParams { + fn port(&self) -> u16 { + self.port.unwrap_or(1433) + } + + fn host(&self) -> &str { + self.host.as_deref().unwrap_or("localhost") + } + + fn user(&self) -> Option<&str> { + self.user.as_deref() + } + + fn password(&self) -> Option<&str> { + self.password.as_deref() + } + + fn encrypt(&self) -> EncryptMode { + self.encrypt + } + + fn trust_server_certificate(&self) -> bool { + self.trust_server_certificate + } + + fn trust_server_certificate_ca(&self) -> Option<&str> { + self.trust_server_certificate_ca.as_deref() + } + + fn database(&self) -> &str { + &self.database + } + + fn schema(&self) -> &str { + &self.schema + } + + fn socket_timeout(&self) -> Option { + self.socket_timeout + } + + fn connect_timeout(&self) -> Option { + self.connect_timeout + } + + fn connection_limit(&self) -> Option { + self.connection_limit + } + + fn pool_timeout(&self) -> Option { + self.pool_timeout + } + + fn max_connection_lifetime(&self) -> Option { + self.max_connection_lifetime + } + + fn max_idle_connection_lifetime(&self) -> Option { + self.max_idle_connection_lifetime + } +} + +impl MssqlUrl { + pub fn new(jdbc_connection_string: &str) -> crate::Result { + let query_params = Self::parse_query_params(jdbc_connection_string)?; + let connection_string = Self::with_jdbc_prefix(jdbc_connection_string); + + Ok(Self { + connection_string, + query_params, + }) + } + + fn with_jdbc_prefix(input: &str) -> String { + if input.starts_with("jdbc:sqlserver") { + input.into() + } else { + format!("jdbc:{input}") + } + } + + fn parse_query_params(input: &str) -> crate::Result { + let mut conn = JdbcString::from_str(&Self::with_jdbc_prefix(input))?; + + let host = conn.server_name().map(|server_name| match conn.instance_name() { + Some(instance_name) => format!(r#"{server_name}\{instance_name}"#), + None => server_name.to_string(), + }); + + let port = conn.port(); + let props = conn.properties_mut(); + let user = props.remove("user"); + let password = props.remove("password"); + let database = props.remove("database").unwrap_or_else(|| String::from("master")); + let schema = props.remove("schema").unwrap_or_else(|| String::from("dbo")); + + let connection_limit = props + .remove("connectionlimit") + .or_else(|| props.remove("connection_limit")) + .map(|param| param.parse()) + .transpose()?; + + let transaction_isolation_level = props + .remove("isolationlevel") + .or_else(|| props.remove("isolation_level")) + .map(|level| { + IsolationLevel::from_str(&level).map_err(|_| { + let kind = ErrorKind::database_url_is_invalid(format!("Invalid isolation level `{level}`")); + Error::builder(kind).build() + }) + }) + .transpose()?; + + let mut connect_timeout = props + .remove("logintimeout") + .or_else(|| props.remove("login_timeout")) + .or_else(|| props.remove("connecttimeout")) + .or_else(|| props.remove("connect_timeout")) + .or_else(|| props.remove("connectiontimeout")) + .or_else(|| props.remove("connection_timeout")) + .map(|param| param.parse().map(Duration::from_secs)) + .transpose()?; + + match connect_timeout { + None => 
connect_timeout = Some(Duration::from_secs(5)), + Some(dur) if dur.as_secs() == 0 => connect_timeout = None, + _ => (), + } + + let mut pool_timeout = props + .remove("pooltimeout") + .or_else(|| props.remove("pool_timeout")) + .map(|param| param.parse().map(Duration::from_secs)) + .transpose()?; + + match pool_timeout { + None => pool_timeout = Some(Duration::from_secs(10)), + Some(dur) if dur.as_secs() == 0 => pool_timeout = None, + _ => (), + } + + let socket_timeout = props + .remove("sockettimeout") + .or_else(|| props.remove("socket_timeout")) + .map(|param| param.parse().map(Duration::from_secs)) + .transpose()?; + + let encrypt = props + .remove("encrypt") + .map(|param| EncryptMode::from_str(¶m)) + .transpose()? + .unwrap_or(EncryptMode::On); + + let trust_server_certificate = props + .remove("trustservercertificate") + .or_else(|| props.remove("trust_server_certificate")) + .map(|param| param.parse()) + .transpose()? + .unwrap_or(false); + + let trust_server_certificate_ca: Option = props + .remove("trustservercertificateca") + .or_else(|| props.remove("trust_server_certificate_ca")); + + let mut max_connection_lifetime = props + .remove("max_connection_lifetime") + .map(|param| param.parse().map(Duration::from_secs)) + .transpose()?; + + match max_connection_lifetime { + Some(dur) if dur.as_secs() == 0 => max_connection_lifetime = None, + _ => (), + } + + let mut max_idle_connection_lifetime = props + .remove("max_idle_connection_lifetime") + .map(|param| param.parse().map(Duration::from_secs)) + .transpose()?; + + match max_idle_connection_lifetime { + None => max_idle_connection_lifetime = Some(Duration::from_secs(300)), + Some(dur) if dur.as_secs() == 0 => max_idle_connection_lifetime = None, + _ => (), + } + + Ok(MssqlQueryParams { + encrypt, + port, + host, + user, + password, + database, + schema, + trust_server_certificate, + trust_server_certificate_ca, + connection_limit, + socket_timeout, + connect_timeout, + pool_timeout, + transaction_isolation_level, + max_connection_lifetime, + max_idle_connection_lifetime, + }) + } +} + +#[cfg(test)] +mod tests { + use crate::tests::test_api::mssql::CONN_STR; + use crate::{error::*, single::Quaint}; + + #[tokio::test] + async fn should_map_wrong_credentials_error() { + let url = CONN_STR.replace("user=SA", "user=WRONG"); + + let res = Quaint::new(url.as_str()).await; + assert!(res.is_err()); + + let err = res.unwrap_err(); + assert!(matches!(err.kind(), ErrorKind::AuthenticationFailed { user } if user == &Name::available("WRONG"))); + } +} diff --git a/quaint/src/connector/mysql.rs b/quaint/src/connector/mysql.rs index 4b6f27a583da..a9a829404e76 100644 --- a/quaint/src/connector/mysql.rs +++ b/quaint/src/connector/mysql.rs @@ -1,6 +1,7 @@ mod conversion; mod error; +pub(crate) use super::mysql_wasm::MysqlUrl; use crate::{ ast::{Query, Value}, connector::{metrics, queryable::*, ResultSet}, @@ -13,16 +14,12 @@ use mysql_async::{ self as my, prelude::{Query as _, Queryable as _}, }; -use percent_encoding::percent_decode; use std::{ - borrow::Cow, future::Future, - path::{Path, PathBuf}, sync::atomic::{AtomicBool, Ordering}, time::Duration, }; use tokio::sync::Mutex; -use url::{Host, Url}; pub use error::MysqlError; @@ -33,293 +30,11 @@ pub use mysql_async; use super::IsolationLevel; -/// A connector interface for the MySQL database. 
-#[derive(Debug)] -pub struct Mysql { - pub(crate) conn: Mutex, - pub(crate) url: MysqlUrl, - socket_timeout: Option, - is_healthy: AtomicBool, - statement_cache: Mutex>, -} - -/// Wraps a connection url and exposes the parsing logic used by quaint, including default values. -#[derive(Debug, Clone)] -pub struct MysqlUrl { - url: Url, - query_params: MysqlUrlQueryParams, -} - impl MysqlUrl { - /// Parse `Url` to `MysqlUrl`. Returns error for mistyped connection - /// parameters. - pub fn new(url: Url) -> Result { - let query_params = Self::parse_query_params(&url)?; - - Ok(Self { url, query_params }) - } - - /// The bare `Url` to the database. - pub fn url(&self) -> &Url { - &self.url - } - - /// The percent-decoded database username. - pub fn username(&self) -> Cow { - match percent_decode(self.url.username().as_bytes()).decode_utf8() { - Ok(username) => username, - Err(_) => { - tracing::warn!("Couldn't decode username to UTF-8, using the non-decoded version."); - - self.url.username().into() - } - } - } - - /// The percent-decoded database password. - pub fn password(&self) -> Option> { - match self - .url - .password() - .and_then(|pw| percent_decode(pw.as_bytes()).decode_utf8().ok()) - { - Some(password) => Some(password), - None => self.url.password().map(|s| s.into()), - } - } - - /// Name of the database connected. Defaults to `mysql`. - pub fn dbname(&self) -> &str { - match self.url.path_segments() { - Some(mut segments) => segments.next().unwrap_or("mysql"), - None => "mysql", - } - } - - /// The database host. If `socket` and `host` are not set, defaults to `localhost`. - pub fn host(&self) -> &str { - match (self.url.host(), self.url.host_str()) { - (Some(Host::Ipv6(_)), Some(host)) => { - // The `url` crate may return an IPv6 address in brackets, which must be stripped. - if host.starts_with('[') && host.ends_with(']') { - &host[1..host.len() - 1] - } else { - host - } - } - (_, Some(host)) => host, - _ => "localhost", - } - } - - /// If set, connected to the database through a Unix socket. - pub fn socket(&self) -> &Option { - &self.query_params.socket - } - - /// The database port, defaults to `3306`. - pub fn port(&self) -> u16 { - self.url.port().unwrap_or(3306) - } - - /// The connection timeout. 
- pub fn connect_timeout(&self) -> Option { - self.query_params.connect_timeout - } - - /// The pool check_out timeout - pub fn pool_timeout(&self) -> Option { - self.query_params.pool_timeout - } - - /// The socket timeout - pub fn socket_timeout(&self) -> Option { - self.query_params.socket_timeout - } - - /// Prefer socket connection - pub fn prefer_socket(&self) -> Option { - self.query_params.prefer_socket - } - - /// The maximum connection lifetime - pub fn max_connection_lifetime(&self) -> Option { - self.query_params.max_connection_lifetime - } - - /// The maximum idle connection lifetime - pub fn max_idle_connection_lifetime(&self) -> Option { - self.query_params.max_idle_connection_lifetime - } - - fn statement_cache_size(&self) -> usize { - self.query_params.statement_cache_size - } - pub(crate) fn cache(&self) -> LruCache { LruCache::new(self.query_params.statement_cache_size) } - fn parse_query_params(url: &Url) -> Result { - let mut ssl_opts = my::SslOpts::default(); - ssl_opts = ssl_opts.with_danger_accept_invalid_certs(true); - - let mut connection_limit = None; - let mut use_ssl = false; - let mut socket = None; - let mut socket_timeout = None; - let mut connect_timeout = Some(Duration::from_secs(5)); - let mut pool_timeout = Some(Duration::from_secs(10)); - let mut max_connection_lifetime = None; - let mut max_idle_connection_lifetime = Some(Duration::from_secs(300)); - let mut prefer_socket = None; - let mut statement_cache_size = 100; - let mut identity: Option<(Option, Option)> = None; - - for (k, v) in url.query_pairs() { - match k.as_ref() { - "connection_limit" => { - let as_int: usize = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - connection_limit = Some(as_int); - } - "statement_cache_size" => { - statement_cache_size = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - } - "sslcert" => { - use_ssl = true; - ssl_opts = ssl_opts.with_root_cert_path(Some(Path::new(&*v).to_path_buf())); - } - "sslidentity" => { - use_ssl = true; - - identity = match identity { - Some((_, pw)) => Some((Some(Path::new(&*v).to_path_buf()), pw)), - None => Some((Some(Path::new(&*v).to_path_buf()), None)), - }; - } - "sslpassword" => { - use_ssl = true; - - identity = match identity { - Some((path, _)) => Some((path, Some(v.to_string()))), - None => Some((None, Some(v.to_string()))), - }; - } - "socket" => { - socket = Some(v.replace(['(', ')'], "")); - } - "socket_timeout" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - socket_timeout = Some(Duration::from_secs(as_int)); - } - "prefer_socket" => { - let as_bool = v - .parse::() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - prefer_socket = Some(as_bool) - } - "connect_timeout" => { - let as_int = v - .parse::() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - connect_timeout = match as_int { - 0 => None, - _ => Some(Duration::from_secs(as_int)), - }; - } - "pool_timeout" => { - let as_int = v - .parse::() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - pool_timeout = match as_int { - 0 => None, - _ => Some(Duration::from_secs(as_int)), - }; - } - "sslaccept" => { - use_ssl = true; - match v.as_ref() { - "strict" => { - ssl_opts = ssl_opts.with_danger_accept_invalid_certs(false); - } - "accept_invalid_certs" => {} - _ => { - tracing::debug!( - message = "Unsupported SSL 
accept mode, defaulting to `accept_invalid_certs`", - mode = &*v - ); - } - }; - } - "max_connection_lifetime" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - if as_int == 0 { - max_connection_lifetime = None; - } else { - max_connection_lifetime = Some(Duration::from_secs(as_int)); - } - } - "max_idle_connection_lifetime" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - if as_int == 0 { - max_idle_connection_lifetime = None; - } else { - max_idle_connection_lifetime = Some(Duration::from_secs(as_int)); - } - } - _ => { - tracing::trace!(message = "Discarding connection string param", param = &*k); - } - }; - } - - ssl_opts = match identity { - Some((Some(path), Some(pw))) => { - let identity = mysql_async::ClientIdentity::new(path).with_password(pw); - ssl_opts.with_client_identity(Some(identity)) - } - Some((Some(path), None)) => { - let identity = mysql_async::ClientIdentity::new(path); - ssl_opts.with_client_identity(Some(identity)) - } - _ => ssl_opts, - }; - - Ok(MysqlUrlQueryParams { - ssl_opts, - connection_limit, - use_ssl, - socket, - socket_timeout, - connect_timeout, - pool_timeout, - max_connection_lifetime, - max_idle_connection_lifetime, - prefer_socket, - statement_cache_size, - }) - } - - #[cfg(feature = "pooled")] - pub(crate) fn connection_limit(&self) -> Option { - self.query_params.connection_limit - } - pub(crate) fn to_opts_builder(&self) -> my::OptsBuilder { let mut config = my::OptsBuilder::default() .stmt_cache_size(Some(0)) @@ -365,6 +80,16 @@ pub(crate) struct MysqlUrlQueryParams { statement_cache_size: usize, } +/// A connector interface for the MySQL database. +#[derive(Debug)] +pub struct Mysql { + pub(crate) conn: Mutex, + pub(crate) url: MysqlUrl, + socket_timeout: Option, + is_healthy: AtomicBool, + statement_cache: Mutex>, +} + impl Mysql { /// Create a new MySQL connection using `OptsBuilder` from the `mysql` crate. pub async fn new(url: MysqlUrl) -> crate::Result { diff --git a/quaint/src/connector/mysql_wasm.rs b/quaint/src/connector/mysql_wasm.rs new file mode 100644 index 000000000000..24cd525fea33 --- /dev/null +++ b/quaint/src/connector/mysql_wasm.rs @@ -0,0 +1,318 @@ +#![cfg_attr(target_arch = "wasm32", allow(dead_code))] + +use crate::error::{Error, ErrorKind}; +use percent_encoding::percent_decode; +use std::{ + borrow::Cow, + path::{Path, PathBuf}, + time::Duration, +}; +use url::{Host, Url}; + +/// Wraps a connection url and exposes the parsing logic used by quaint, including default values. +#[derive(Debug, Clone)] +pub struct MysqlUrl { + url: Url, + pub(super) query_params: MysqlUrlQueryParams, +} + +impl MysqlUrl { + /// Parse `Url` to `MysqlUrl`. Returns error for mistyped connection + /// parameters. + pub fn new(url: Url) -> Result { + let query_params = Self::parse_query_params(&url)?; + + Ok(Self { url, query_params }) + } + + /// The bare `Url` to the database. + pub fn url(&self) -> &Url { + &self.url + } + + /// The percent-decoded database username. + pub fn username(&self) -> Cow { + match percent_decode(self.url.username().as_bytes()).decode_utf8() { + Ok(username) => username, + Err(_) => { + tracing::warn!("Couldn't decode username to UTF-8, using the non-decoded version."); + + self.url.username().into() + } + } + } + + /// The percent-decoded database password. 
+    pub fn password(&self) -> Option<Cow<str>> {
+        match self
+            .url
+            .password()
+            .and_then(|pw| percent_decode(pw.as_bytes()).decode_utf8().ok())
+        {
+            Some(password) => Some(password),
+            None => self.url.password().map(|s| s.into()),
+        }
+    }
+
+    /// Name of the database connected. Defaults to `mysql`.
+    pub fn dbname(&self) -> &str {
+        match self.url.path_segments() {
+            Some(mut segments) => segments.next().unwrap_or("mysql"),
+            None => "mysql",
+        }
+    }
+
+    /// The database host. If `socket` and `host` are not set, defaults to `localhost`.
+    pub fn host(&self) -> &str {
+        match (self.url.host(), self.url.host_str()) {
+            (Some(Host::Ipv6(_)), Some(host)) => {
+                // The `url` crate may return an IPv6 address in brackets, which must be stripped.
+                if host.starts_with('[') && host.ends_with(']') {
+                    &host[1..host.len() - 1]
+                } else {
+                    host
+                }
+            }
+            (_, Some(host)) => host,
+            _ => "localhost",
+        }
+    }
+
+    /// If set, connected to the database through a Unix socket.
+    pub fn socket(&self) -> &Option<String> {
+        &self.query_params.socket
+    }
+
+    /// The database port, defaults to `3306`.
+    pub fn port(&self) -> u16 {
+        self.url.port().unwrap_or(3306)
+    }
+
+    /// The connection timeout.
+    pub fn connect_timeout(&self) -> Option<Duration> {
+        self.query_params.connect_timeout
+    }
+
+    /// The pool check_out timeout
+    pub fn pool_timeout(&self) -> Option<Duration> {
+        self.query_params.pool_timeout
+    }
+
+    /// The socket timeout
+    pub fn socket_timeout(&self) -> Option<Duration> {
+        self.query_params.socket_timeout
+    }
+
+    /// Prefer socket connection
+    pub fn prefer_socket(&self) -> Option<bool> {
+        self.query_params.prefer_socket
+    }
+
+    /// The maximum connection lifetime
+    pub fn max_connection_lifetime(&self) -> Option<Duration> {
+        self.query_params.max_connection_lifetime
+    }
+
+    /// The maximum idle connection lifetime
+    pub fn max_idle_connection_lifetime(&self) -> Option<Duration> {
+        self.query_params.max_idle_connection_lifetime
+    }
+
+    pub(super) fn statement_cache_size(&self) -> usize {
+        self.query_params.statement_cache_size
+    }
+
+    fn parse_query_params(url: &Url) -> Result<MysqlUrlQueryParams, Error> {
+        #[cfg(feature = "mysql-connector")]
+        let mut ssl_opts = {
+            let mut ssl_opts = mysql_async::SslOpts::default();
+            ssl_opts = ssl_opts.with_danger_accept_invalid_certs(true);
+            ssl_opts
+        };
+
+        let mut connection_limit = None;
+        let mut use_ssl = false;
+        let mut socket = None;
+        let mut socket_timeout = None;
+        let mut connect_timeout = Some(Duration::from_secs(5));
+        let mut pool_timeout = Some(Duration::from_secs(10));
+        let mut max_connection_lifetime = None;
+        let mut max_idle_connection_lifetime = Some(Duration::from_secs(300));
+        let mut prefer_socket = None;
+        let mut statement_cache_size = 100;
+        let mut identity: Option<(Option<PathBuf>, Option<String>)> = None;
+
+        for (k, v) in url.query_pairs() {
+            match k.as_ref() {
+                "connection_limit" => {
+                    let as_int: usize = v
+                        .parse()
+                        .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
+
+                    connection_limit = Some(as_int);
+                }
+                "statement_cache_size" => {
+                    statement_cache_size = v
+                        .parse()
+                        .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
+                }
+                "sslcert" => {
+                    use_ssl = true;
+
+                    #[cfg(feature = "mysql-connector")]
+                    {
+                        ssl_opts = ssl_opts.with_root_cert_path(Some(Path::new(&*v).to_path_buf()));
+                    }
+                }
+                "sslidentity" => {
+                    use_ssl = true;
+
+                    identity = match identity {
+                        Some((_, pw)) => Some((Some(Path::new(&*v).to_path_buf()), pw)),
+                        None => Some((Some(Path::new(&*v).to_path_buf()), None)),
+                    };
+                }
+                "sslpassword" => {
+                    use_ssl = true;
+
+                    identity = match identity {
+ Some((path, _)) => Some((path, Some(v.to_string()))), + None => Some((None, Some(v.to_string()))), + }; + } + "socket" => { + socket = Some(v.replace(['(', ')'], "")); + } + "socket_timeout" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + socket_timeout = Some(Duration::from_secs(as_int)); + } + "prefer_socket" => { + let as_bool = v + .parse::() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + prefer_socket = Some(as_bool) + } + "connect_timeout" => { + let as_int = v + .parse::() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + connect_timeout = match as_int { + 0 => None, + _ => Some(Duration::from_secs(as_int)), + }; + } + "pool_timeout" => { + let as_int = v + .parse::() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + pool_timeout = match as_int { + 0 => None, + _ => Some(Duration::from_secs(as_int)), + }; + } + "sslaccept" => { + use_ssl = true; + match v.as_ref() { + "strict" => { + #[cfg(feature = "mysql-connector")] + { + ssl_opts = ssl_opts.with_danger_accept_invalid_certs(false); + } + } + "accept_invalid_certs" => {} + _ => { + tracing::debug!( + message = "Unsupported SSL accept mode, defaulting to `accept_invalid_certs`", + mode = &*v + ); + } + }; + } + "max_connection_lifetime" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + if as_int == 0 { + max_connection_lifetime = None; + } else { + max_connection_lifetime = Some(Duration::from_secs(as_int)); + } + } + "max_idle_connection_lifetime" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + if as_int == 0 { + max_idle_connection_lifetime = None; + } else { + max_idle_connection_lifetime = Some(Duration::from_secs(as_int)); + } + } + _ => { + tracing::trace!(message = "Discarding connection string param", param = &*k); + } + }; + } + + // Wrapping this in a block, as attributes on expressions are still experimental + // See: https://github.com/rust-lang/rust/issues/15701 + #[cfg(feature = "mysql-connector")] + { + ssl_opts = match identity { + Some((Some(path), Some(pw))) => { + let identity = mysql_async::ClientIdentity::new(path).with_password(pw); + ssl_opts.with_client_identity(Some(identity)) + } + Some((Some(path), None)) => { + let identity = mysql_async::ClientIdentity::new(path); + ssl_opts.with_client_identity(Some(identity)) + } + _ => ssl_opts, + }; + } + + Ok(MysqlUrlQueryParams { + #[cfg(feature = "mysql-connector")] + ssl_opts, + connection_limit, + use_ssl, + socket, + socket_timeout, + connect_timeout, + pool_timeout, + max_connection_lifetime, + max_idle_connection_lifetime, + prefer_socket, + statement_cache_size, + }) + } + + #[cfg(feature = "pooled")] + pub(crate) fn connection_limit(&self) -> Option { + self.query_params.connection_limit + } +} + +#[derive(Debug, Clone)] +pub(crate) struct MysqlUrlQueryParams { + pub(crate) connection_limit: Option, + pub(crate) use_ssl: bool, + pub(crate) socket: Option, + pub(crate) socket_timeout: Option, + pub(crate) connect_timeout: Option, + pub(crate) pool_timeout: Option, + pub(crate) max_connection_lifetime: Option, + pub(crate) max_idle_connection_lifetime: Option, + pub(crate) prefer_socket: Option, + pub(crate) statement_cache_size: usize, + + #[cfg(feature = "mysql-connector")] + pub(crate) ssl_opts: mysql_async::SslOpts, +} diff --git 
a/quaint/src/connector/postgres.rs b/quaint/src/connector/postgres.rs index 766be38b27e4..7a83e61218f6 100644 --- a/quaint/src/connector/postgres.rs +++ b/quaint/src/connector/postgres.rs @@ -1,6 +1,9 @@ mod conversion; mod error; +use super::postgres_wasm::{Hidden, SslAcceptMode, SslParams}; +pub(crate) use super::postgres_wasm::{PostgresFlavour, PostgresUrl}; + use crate::{ ast::{Query, Value}, connector::{metrics, queryable::*, ResultSet}, @@ -11,26 +14,19 @@ use async_trait::async_trait; use futures::{future::FutureExt, lock::Mutex}; use lru_cache::LruCache; use native_tls::{Certificate, Identity, TlsConnector}; -use percent_encoding::percent_decode; use postgres_native_tls::MakeTlsConnector; use std::{ - borrow::{Borrow, Cow}, + borrow::Borrow, fmt::{Debug, Display}, fs, future::Future, sync::atomic::{AtomicBool, Ordering}, time::Duration, }; -use tokio_postgres::{ - config::{ChannelBinding, SslMode}, - Client, Config, Statement, -}; -use url::{Host, Url}; +use tokio_postgres::{config::ChannelBinding, Client, Config, Statement}; pub use error::PostgresError; -pub(crate) const DEFAULT_SCHEMA: &str = "public"; - /// The underlying postgres driver. Only available with the `expose-drivers` /// Cargo feature. #[cfg(feature = "expose-drivers")] @@ -38,15 +34,6 @@ pub use tokio_postgres; use super::{IsolationLevel, Transaction}; -#[derive(Clone)] -struct Hidden(T); - -impl Debug for Hidden { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str("") - } -} - struct PostgresClient(Client); impl Debug for PostgresClient { @@ -65,20 +52,6 @@ pub struct PostgreSql { is_healthy: AtomicBool, } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum SslAcceptMode { - Strict, - AcceptInvalidCerts, -} - -#[derive(Debug, Clone)] -pub struct SslParams { - certificate_file: Option, - identity_file: Option, - identity_password: Hidden>, - ssl_accept_mode: SslAcceptMode, -} - #[derive(Debug)] struct SslAuth { certificate: Hidden>, @@ -146,166 +119,7 @@ impl SslParams { } } -#[derive(Debug, Clone, Copy)] -pub enum PostgresFlavour { - Postgres, - Cockroach, - Unknown, -} - -impl PostgresFlavour { - /// Returns `true` if the postgres flavour is [`Postgres`]. - /// - /// [`Postgres`]: PostgresFlavour::Postgres - fn is_postgres(&self) -> bool { - matches!(self, Self::Postgres) - } - - /// Returns `true` if the postgres flavour is [`Cockroach`]. - /// - /// [`Cockroach`]: PostgresFlavour::Cockroach - fn is_cockroach(&self) -> bool { - matches!(self, Self::Cockroach) - } - - /// Returns `true` if the postgres flavour is [`Unknown`]. - /// - /// [`Unknown`]: PostgresFlavour::Unknown - fn is_unknown(&self) -> bool { - matches!(self, Self::Unknown) - } -} - -/// Wraps a connection url and exposes the parsing logic used by Quaint, -/// including default values. -#[derive(Debug, Clone)] -pub struct PostgresUrl { - url: Url, - query_params: PostgresUrlQueryParams, - flavour: PostgresFlavour, -} - impl PostgresUrl { - /// Parse `Url` to `PostgresUrl`. Returns error for mistyped connection - /// parameters. - pub fn new(url: Url) -> Result { - let query_params = Self::parse_query_params(&url)?; - - Ok(Self { - url, - query_params, - flavour: PostgresFlavour::Unknown, - }) - } - - /// The bare `Url` to the database. - pub fn url(&self) -> &Url { - &self.url - } - - /// The percent-decoded database username. 
- pub fn username(&self) -> Cow { - match percent_decode(self.url.username().as_bytes()).decode_utf8() { - Ok(username) => username, - Err(_) => { - tracing::warn!("Couldn't decode username to UTF-8, using the non-decoded version."); - - self.url.username().into() - } - } - } - - /// The database host. Taken first from the `host` query parameter, then - /// from the `host` part of the URL. For socket connections, the query - /// parameter must be used. - /// - /// If none of them are set, defaults to `localhost`. - pub fn host(&self) -> &str { - match (self.query_params.host.as_ref(), self.url.host_str(), self.url.host()) { - (Some(host), _, _) => host.as_str(), - (None, Some(""), _) => "localhost", - (None, None, _) => "localhost", - (None, Some(host), Some(Host::Ipv6(_))) => { - // The `url` crate may return an IPv6 address in brackets, which must be stripped. - if host.starts_with('[') && host.ends_with(']') { - &host[1..host.len() - 1] - } else { - host - } - } - (None, Some(host), _) => host, - } - } - - /// Name of the database connected. Defaults to `postgres`. - pub fn dbname(&self) -> &str { - match self.url.path_segments() { - Some(mut segments) => segments.next().unwrap_or("postgres"), - None => "postgres", - } - } - - /// The percent-decoded database password. - pub fn password(&self) -> Cow { - match self - .url - .password() - .and_then(|pw| percent_decode(pw.as_bytes()).decode_utf8().ok()) - { - Some(password) => password, - None => self.url.password().unwrap_or("").into(), - } - } - - /// The database port, defaults to `5432`. - pub fn port(&self) -> u16 { - self.url.port().unwrap_or(5432) - } - - /// The database schema, defaults to `public`. - pub fn schema(&self) -> &str { - self.query_params.schema.as_deref().unwrap_or(DEFAULT_SCHEMA) - } - - /// Whether the pgbouncer mode is enabled. - pub fn pg_bouncer(&self) -> bool { - self.query_params.pg_bouncer - } - - /// The connection timeout. - pub fn connect_timeout(&self) -> Option { - self.query_params.connect_timeout - } - - /// Pool check_out timeout - pub fn pool_timeout(&self) -> Option { - self.query_params.pool_timeout - } - - /// The socket timeout - pub fn socket_timeout(&self) -> Option { - self.query_params.socket_timeout - } - - /// The maximum connection lifetime - pub fn max_connection_lifetime(&self) -> Option { - self.query_params.max_connection_lifetime - } - - /// The maximum idle connection lifetime - pub fn max_idle_connection_lifetime(&self) -> Option { - self.query_params.max_idle_connection_lifetime - } - - /// The custom application name - pub fn application_name(&self) -> Option<&str> { - self.query_params.application_name.as_deref() - } - - pub fn channel_binding(&self) -> ChannelBinding { - self.query_params.channel_binding - } - pub(crate) fn cache(&self) -> LruCache { if self.query_params.pg_bouncer { LruCache::new(0) @@ -314,208 +128,8 @@ impl PostgresUrl { } } - pub(crate) fn options(&self) -> Option<&str> { - self.query_params.options.as_deref() - } - - /// Sets whether the URL points to a Postgres, Cockroach or Unknown database. - /// This is used to avoid a network roundtrip at connection to set the search path. - /// - /// The different behaviours are: - /// - Postgres: Always avoid a network roundtrip by setting the search path through client connection parameters. - /// - Cockroach: Avoid a network roundtrip if the schema name is deemed "safe" (i.e. no escape quoting required). Otherwise, set the search path through a database query. 
- /// - Unknown: Always add a network roundtrip by setting the search path through a database query. - pub fn set_flavour(&mut self, flavour: PostgresFlavour) { - self.flavour = flavour; - } - - fn parse_query_params(url: &Url) -> Result { - let mut connection_limit = None; - let mut schema = None; - let mut certificate_file = None; - let mut identity_file = None; - let mut identity_password = None; - let mut ssl_accept_mode = SslAcceptMode::AcceptInvalidCerts; - let mut ssl_mode = SslMode::Prefer; - let mut host = None; - let mut application_name = None; - let mut channel_binding = ChannelBinding::Prefer; - let mut socket_timeout = None; - let mut connect_timeout = Some(Duration::from_secs(5)); - let mut pool_timeout = Some(Duration::from_secs(10)); - let mut pg_bouncer = false; - let mut statement_cache_size = 100; - let mut max_connection_lifetime = None; - let mut max_idle_connection_lifetime = Some(Duration::from_secs(300)); - let mut options = None; - - for (k, v) in url.query_pairs() { - match k.as_ref() { - "pgbouncer" => { - pg_bouncer = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - } - "sslmode" => { - match v.as_ref() { - "disable" => ssl_mode = SslMode::Disable, - "prefer" => ssl_mode = SslMode::Prefer, - "require" => ssl_mode = SslMode::Require, - _ => { - tracing::debug!(message = "Unsupported SSL mode, defaulting to `prefer`", mode = &*v); - } - }; - } - "sslcert" => { - certificate_file = Some(v.to_string()); - } - "sslidentity" => { - identity_file = Some(v.to_string()); - } - "sslpassword" => { - identity_password = Some(v.to_string()); - } - "statement_cache_size" => { - statement_cache_size = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - } - "sslaccept" => { - match v.as_ref() { - "strict" => { - ssl_accept_mode = SslAcceptMode::Strict; - } - "accept_invalid_certs" => { - ssl_accept_mode = SslAcceptMode::AcceptInvalidCerts; - } - _ => { - tracing::debug!( - message = "Unsupported SSL accept mode, defaulting to `strict`", - mode = &*v - ); - - ssl_accept_mode = SslAcceptMode::Strict; - } - }; - } - "schema" => { - schema = Some(v.to_string()); - } - "connection_limit" => { - let as_int: usize = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - connection_limit = Some(as_int); - } - "host" => { - host = Some(v.to_string()); - } - "socket_timeout" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - socket_timeout = Some(Duration::from_secs(as_int)); - } - "connect_timeout" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - if as_int == 0 { - connect_timeout = None; - } else { - connect_timeout = Some(Duration::from_secs(as_int)); - } - } - "pool_timeout" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - if as_int == 0 { - pool_timeout = None; - } else { - pool_timeout = Some(Duration::from_secs(as_int)); - } - } - "max_connection_lifetime" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - if as_int == 0 { - max_connection_lifetime = None; - } else { - max_connection_lifetime = Some(Duration::from_secs(as_int)); - } - } - "max_idle_connection_lifetime" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - if 
as_int == 0 { - max_idle_connection_lifetime = None; - } else { - max_idle_connection_lifetime = Some(Duration::from_secs(as_int)); - } - } - "application_name" => { - application_name = Some(v.to_string()); - } - "channel_binding" => { - match v.as_ref() { - "disable" => channel_binding = ChannelBinding::Disable, - "prefer" => channel_binding = ChannelBinding::Prefer, - "require" => channel_binding = ChannelBinding::Require, - _ => { - tracing::debug!( - message = "Unsupported Channel Binding {channel_binding}, defaulting to `prefer`", - channel_binding = &*v - ); - } - }; - } - "options" => { - options = Some(v.to_string()); - } - _ => { - tracing::trace!(message = "Discarding connection string param", param = &*k); - } - }; - } - - Ok(PostgresUrlQueryParams { - ssl_params: SslParams { - certificate_file, - identity_file, - ssl_accept_mode, - identity_password: Hidden(identity_password), - }, - connection_limit, - schema, - ssl_mode, - host, - connect_timeout, - pool_timeout, - socket_timeout, - pg_bouncer, - statement_cache_size, - max_connection_lifetime, - max_idle_connection_lifetime, - application_name, - channel_binding, - options, - }) - } - - pub(crate) fn ssl_params(&self) -> &SslParams { - &self.query_params.ssl_params - } - - #[cfg(feature = "pooled")] - pub(crate) fn connection_limit(&self) -> Option { - self.query_params.connection_limit + pub fn channel_binding(&self) -> ChannelBinding { + self.query_params.channel_binding } /// On Postgres, we set the SEARCH_PATH and client-encoding through client connection parameters to save a network roundtrip on connection. @@ -569,29 +183,6 @@ impl PostgresUrl { config } - - pub fn flavour(&self) -> PostgresFlavour { - self.flavour - } -} - -#[derive(Debug, Clone)] -pub(crate) struct PostgresUrlQueryParams { - ssl_params: SslParams, - connection_limit: Option, - schema: Option, - ssl_mode: SslMode, - pg_bouncer: bool, - host: Option, - socket_timeout: Option, - connect_timeout: Option, - pool_timeout: Option, - statement_cache_size: usize, - max_connection_lifetime: Option, - max_idle_connection_lifetime: Option, - application_name: Option, - channel_binding: ChannelBinding, - options: Option, } impl PostgreSql { diff --git a/quaint/src/connector/postgres_wasm.rs b/quaint/src/connector/postgres_wasm.rs new file mode 100644 index 000000000000..4c67b98cfa42 --- /dev/null +++ b/quaint/src/connector/postgres_wasm.rs @@ -0,0 +1,612 @@ +use std::{ + borrow::Cow, + fmt::{Debug, Display}, + time::Duration, +}; + +use percent_encoding::percent_decode; +use url::{Host, Url}; + +use crate::error::{Error, ErrorKind}; + +#[cfg(feature = "postgresql-connector")] +use tokio_postgres::config::{ChannelBinding, SslMode}; + +#[derive(Clone)] +pub(crate) struct Hidden(pub(crate) T); + +impl Debug for Hidden { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("") + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SslAcceptMode { + Strict, + AcceptInvalidCerts, +} + +#[derive(Debug, Clone)] +pub struct SslParams { + pub(super) certificate_file: Option, + pub(super) identity_file: Option, + pub(super) identity_password: Hidden>, + pub(super) ssl_accept_mode: SslAcceptMode, +} + +#[derive(Debug, Clone, Copy)] +pub enum PostgresFlavour { + Postgres, + Cockroach, + Unknown, +} + +impl PostgresFlavour { + /// Returns `true` if the postgres flavour is [`Postgres`]. 
+ /// + /// [`Postgres`]: PostgresFlavour::Postgres + pub(super) fn is_postgres(&self) -> bool { + matches!(self, Self::Postgres) + } + + /// Returns `true` if the postgres flavour is [`Cockroach`]. + /// + /// [`Cockroach`]: PostgresFlavour::Cockroach + pub(super) fn is_cockroach(&self) -> bool { + matches!(self, Self::Cockroach) + } + + /// Returns `true` if the postgres flavour is [`Unknown`]. + /// + /// [`Unknown`]: PostgresFlavour::Unknown + pub(super) fn is_unknown(&self) -> bool { + matches!(self, Self::Unknown) + } +} + +/// Wraps a connection url and exposes the parsing logic used by Quaint, +/// including default values. +#[derive(Debug, Clone)] +pub struct PostgresUrl { + pub(super) url: Url, + pub(super) query_params: PostgresUrlQueryParams, + pub(super) flavour: PostgresFlavour, +} + +pub(crate) const DEFAULT_SCHEMA: &str = "public"; + +impl PostgresUrl { + /// Parse `Url` to `PostgresUrl`. Returns error for mistyped connection + /// parameters. + pub fn new(url: Url) -> Result { + let query_params = Self::parse_query_params(&url)?; + + Ok(Self { + url, + query_params, + flavour: PostgresFlavour::Unknown, + }) + } + + /// The bare `Url` to the database. + pub fn url(&self) -> &Url { + &self.url + } + + /// The percent-decoded database username. + pub fn username(&self) -> Cow { + match percent_decode(self.url.username().as_bytes()).decode_utf8() { + Ok(username) => username, + Err(_) => { + tracing::warn!("Couldn't decode username to UTF-8, using the non-decoded version."); + + self.url.username().into() + } + } + } + + /// The database host. Taken first from the `host` query parameter, then + /// from the `host` part of the URL. For socket connections, the query + /// parameter must be used. + /// + /// If none of them are set, defaults to `localhost`. + pub fn host(&self) -> &str { + match (self.query_params.host.as_ref(), self.url.host_str(), self.url.host()) { + (Some(host), _, _) => host.as_str(), + (None, Some(""), _) => "localhost", + (None, None, _) => "localhost", + (None, Some(host), Some(Host::Ipv6(_))) => { + // The `url` crate may return an IPv6 address in brackets, which must be stripped. + if host.starts_with('[') && host.ends_with(']') { + &host[1..host.len() - 1] + } else { + host + } + } + (None, Some(host), _) => host, + } + } + + /// Name of the database connected. Defaults to `postgres`. + pub fn dbname(&self) -> &str { + match self.url.path_segments() { + Some(mut segments) => segments.next().unwrap_or("postgres"), + None => "postgres", + } + } + + /// The percent-decoded database password. + pub fn password(&self) -> Cow { + match self + .url + .password() + .and_then(|pw| percent_decode(pw.as_bytes()).decode_utf8().ok()) + { + Some(password) => password, + None => self.url.password().unwrap_or("").into(), + } + } + + /// The database port, defaults to `5432`. + pub fn port(&self) -> u16 { + self.url.port().unwrap_or(5432) + } + + /// The database schema, defaults to `public`. + pub fn schema(&self) -> &str { + self.query_params.schema.as_deref().unwrap_or(DEFAULT_SCHEMA) + } + + /// Whether the pgbouncer mode is enabled. + pub fn pg_bouncer(&self) -> bool { + self.query_params.pg_bouncer + } + + /// The connection timeout. 
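+    /// Defaults to 5 seconds; setting `connect_timeout=0` in the connection string disables the timeout.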
+ pub fn connect_timeout(&self) -> Option { + self.query_params.connect_timeout + } + + /// Pool check_out timeout + pub fn pool_timeout(&self) -> Option { + self.query_params.pool_timeout + } + + /// The socket timeout + pub fn socket_timeout(&self) -> Option { + self.query_params.socket_timeout + } + + /// The maximum connection lifetime + pub fn max_connection_lifetime(&self) -> Option { + self.query_params.max_connection_lifetime + } + + /// The maximum idle connection lifetime + pub fn max_idle_connection_lifetime(&self) -> Option { + self.query_params.max_idle_connection_lifetime + } + + /// The custom application name + pub fn application_name(&self) -> Option<&str> { + self.query_params.application_name.as_deref() + } + + pub(crate) fn options(&self) -> Option<&str> { + self.query_params.options.as_deref() + } + + /// Sets whether the URL points to a Postgres, Cockroach or Unknown database. + /// This is used to avoid a network roundtrip at connection to set the search path. + /// + /// The different behaviours are: + /// - Postgres: Always avoid a network roundtrip by setting the search path through client connection parameters. + /// - Cockroach: Avoid a network roundtrip if the schema name is deemed "safe" (i.e. no escape quoting required). Otherwise, set the search path through a database query. + /// - Unknown: Always add a network roundtrip by setting the search path through a database query. + pub fn set_flavour(&mut self, flavour: PostgresFlavour) { + self.flavour = flavour; + } + + fn parse_query_params(url: &Url) -> Result { + #[cfg(feature = "postgresql-connector")] + let mut ssl_mode = SslMode::Prefer; + #[cfg(feature = "postgresql-connector")] + let mut channel_binding = ChannelBinding::Prefer; + + let mut connection_limit = None; + let mut schema = None; + let mut certificate_file = None; + let mut identity_file = None; + let mut identity_password = None; + let mut ssl_accept_mode = SslAcceptMode::AcceptInvalidCerts; + let mut host = None; + let mut application_name = None; + let mut socket_timeout = None; + let mut connect_timeout = Some(Duration::from_secs(5)); + let mut pool_timeout = Some(Duration::from_secs(10)); + let mut pg_bouncer = false; + let mut statement_cache_size = 100; + let mut max_connection_lifetime = None; + let mut max_idle_connection_lifetime = Some(Duration::from_secs(300)); + let mut options = None; + + for (k, v) in url.query_pairs() { + match k.as_ref() { + "pgbouncer" => { + pg_bouncer = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + } + #[cfg(feature = "postgresql-connector")] + "sslmode" => { + match v.as_ref() { + "disable" => ssl_mode = SslMode::Disable, + "prefer" => ssl_mode = SslMode::Prefer, + "require" => ssl_mode = SslMode::Require, + _ => { + tracing::debug!(message = "Unsupported SSL mode, defaulting to `prefer`", mode = &*v); + } + }; + } + "sslcert" => { + certificate_file = Some(v.to_string()); + } + "sslidentity" => { + identity_file = Some(v.to_string()); + } + "sslpassword" => { + identity_password = Some(v.to_string()); + } + "statement_cache_size" => { + statement_cache_size = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + } + "sslaccept" => { + match v.as_ref() { + "strict" => { + ssl_accept_mode = SslAcceptMode::Strict; + } + "accept_invalid_certs" => { + ssl_accept_mode = SslAcceptMode::AcceptInvalidCerts; + } + _ => { + tracing::debug!( + message = "Unsupported SSL accept mode, defaulting to `strict`", + mode = &*v + ); + + 
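+                            // Note: an unrecognized explicit value falls back to `Strict`,
+                            // unlike the implicit default of `AcceptInvalidCerts`.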
ssl_accept_mode = SslAcceptMode::Strict; + } + }; + } + "schema" => { + schema = Some(v.to_string()); + } + "connection_limit" => { + let as_int: usize = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + connection_limit = Some(as_int); + } + "host" => { + host = Some(v.to_string()); + } + "socket_timeout" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + socket_timeout = Some(Duration::from_secs(as_int)); + } + "connect_timeout" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + if as_int == 0 { + connect_timeout = None; + } else { + connect_timeout = Some(Duration::from_secs(as_int)); + } + } + "pool_timeout" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + if as_int == 0 { + pool_timeout = None; + } else { + pool_timeout = Some(Duration::from_secs(as_int)); + } + } + "max_connection_lifetime" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + if as_int == 0 { + max_connection_lifetime = None; + } else { + max_connection_lifetime = Some(Duration::from_secs(as_int)); + } + } + "max_idle_connection_lifetime" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + if as_int == 0 { + max_idle_connection_lifetime = None; + } else { + max_idle_connection_lifetime = Some(Duration::from_secs(as_int)); + } + } + "application_name" => { + application_name = Some(v.to_string()); + } + #[cfg(feature = "postgresql-connector")] + "channel_binding" => { + match v.as_ref() { + "disable" => channel_binding = ChannelBinding::Disable, + "prefer" => channel_binding = ChannelBinding::Prefer, + "require" => channel_binding = ChannelBinding::Require, + _ => { + tracing::debug!( + message = "Unsupported Channel Binding {channel_binding}, defaulting to `prefer`", + channel_binding = &*v + ); + } + }; + } + "options" => { + options = Some(v.to_string()); + } + _ => { + tracing::trace!(message = "Discarding connection string param", param = &*k); + } + }; + } + + Ok(PostgresUrlQueryParams { + ssl_params: SslParams { + certificate_file, + identity_file, + ssl_accept_mode, + identity_password: Hidden(identity_password), + }, + connection_limit, + schema, + host, + connect_timeout, + pool_timeout, + socket_timeout, + pg_bouncer, + statement_cache_size, + max_connection_lifetime, + max_idle_connection_lifetime, + application_name, + options, + #[cfg(feature = "postgresql-connector")] + channel_binding, + #[cfg(feature = "postgresql-connector")] + ssl_mode, + }) + } + + pub(crate) fn ssl_params(&self) -> &SslParams { + &self.query_params.ssl_params + } + + #[cfg(feature = "pooled")] + pub(crate) fn connection_limit(&self) -> Option { + self.query_params.connection_limit + } + + pub fn flavour(&self) -> PostgresFlavour { + self.flavour + } +} + +#[derive(Debug, Clone)] +pub(crate) struct PostgresUrlQueryParams { + pub(crate) ssl_params: SslParams, + pub(crate) connection_limit: Option, + pub(crate) schema: Option, + pub(crate) pg_bouncer: bool, + pub(crate) host: Option, + pub(crate) socket_timeout: Option, + pub(crate) connect_timeout: Option, + pub(crate) pool_timeout: Option, + pub(crate) statement_cache_size: usize, + pub(crate) max_connection_lifetime: Option, + pub(crate) max_idle_connection_lifetime: Option, + pub(crate) application_name: Option, + 
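+    // Raw `options` value taken verbatim from the connection string, if provided.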
pub(crate) options: Option, + + #[cfg(feature = "postgresql-connector")] + pub(crate) channel_binding: ChannelBinding, + + #[cfg(feature = "postgresql-connector")] + pub(crate) ssl_mode: SslMode, +} + +// A SearchPath connection parameter (Display-impl) for connection initialization. +struct CockroachSearchPath<'a>(&'a str); + +impl Display for CockroachSearchPath<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.0) + } +} + +// A SearchPath connection parameter (Display-impl) for connection initialization. +struct PostgresSearchPath<'a>(&'a str); + +impl Display for PostgresSearchPath<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("\"")?; + f.write_str(self.0)?; + f.write_str("\"")?; + + Ok(()) + } +} + +// A SetSearchPath statement (Display-impl) for connection initialization. +struct SetSearchPath<'a>(Option<&'a str>); + +impl Display for SetSearchPath<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if let Some(schema) = self.0 { + f.write_str("SET search_path = \"")?; + f.write_str(schema)?; + f.write_str("\";\n")?; + } + + Ok(()) + } +} + +/// Sorted list of CockroachDB's reserved keywords. +/// Taken from https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#keywords +const RESERVED_KEYWORDS: [&str; 79] = [ + "all", + "analyse", + "analyze", + "and", + "any", + "array", + "as", + "asc", + "asymmetric", + "both", + "case", + "cast", + "check", + "collate", + "column", + "concurrently", + "constraint", + "create", + "current_catalog", + "current_date", + "current_role", + "current_schema", + "current_time", + "current_timestamp", + "current_user", + "default", + "deferrable", + "desc", + "distinct", + "do", + "else", + "end", + "except", + "false", + "fetch", + "for", + "foreign", + "from", + "grant", + "group", + "having", + "in", + "initially", + "intersect", + "into", + "lateral", + "leading", + "limit", + "localtime", + "localtimestamp", + "not", + "null", + "offset", + "on", + "only", + "or", + "order", + "placing", + "primary", + "references", + "returning", + "select", + "session_user", + "some", + "symmetric", + "table", + "then", + "to", + "trailing", + "true", + "union", + "unique", + "user", + "using", + "variadic", + "when", + "where", + "window", + "with", +]; + +/// Sorted list of CockroachDB's reserved type function names. +/// Taken from https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#keywords +const RESERVED_TYPE_FUNCTION_NAMES: [&str; 18] = [ + "authorization", + "collation", + "cross", + "full", + "ilike", + "inner", + "is", + "isnull", + "join", + "left", + "like", + "natural", + "none", + "notnull", + "outer", + "overlaps", + "right", + "similar", +]; + +/// Returns true if a Postgres identifier is considered "safe". +/// +/// In this context, "safe" means that the value of an identifier would be the same quoted and unquoted or that it's not part of reserved keywords. In other words, that it does _not_ need to be quoted. +/// +/// Spec can be found here: https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS +/// or here: https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#rules-for-identifiers +fn is_safe_identifier(ident: &str) -> bool { + if ident.is_empty() { + return false; + } + + // 1. Not equal any SQL keyword unless the keyword is accepted by the element's syntax. For example, name accepts Unreserved or Column Name keywords. 
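+    //    e.g. `select` and `user` fail this check, while `users` passes.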
+ if RESERVED_KEYWORDS.binary_search(&ident).is_ok() || RESERVED_TYPE_FUNCTION_NAMES.binary_search(&ident).is_ok() { + return false; + } + + let mut chars = ident.chars(); + + let first = chars.next().unwrap(); + + // 2. SQL identifiers must begin with a letter (a-z, but also letters with diacritical marks and non-Latin letters) or an underscore (_). + if (!first.is_alphabetic() || !first.is_lowercase()) && first != '_' { + return false; + } + + for c in chars { + // 3. Subsequent characters in an identifier can be letters, underscores, digits (0-9), or dollar signs ($). + if (!c.is_alphabetic() || !c.is_lowercase()) && c != '_' && !c.is_ascii_digit() && c != '$' { + return false; + } + } + + true +} diff --git a/quaint/src/connector/sqlite.rs b/quaint/src/connector/sqlite.rs index 3a1ef72b4883..fc993c1eaf0e 100644 --- a/quaint/src/connector/sqlite.rs +++ b/quaint/src/connector/sqlite.rs @@ -1,6 +1,7 @@ mod conversion; mod error; +pub(crate) use super::sqlite_wasm::{SqliteParams, DEFAULT_SQLITE_SCHEMA_NAME}; pub use error::SqliteError; pub use rusqlite::{params_from_iter, version as sqlite_version}; @@ -13,11 +14,9 @@ use crate::{ visitor::{self, Visitor}, }; use async_trait::async_trait; -use std::{convert::TryFrom, path::Path, time::Duration}; +use std::convert::TryFrom; use tokio::sync::Mutex; -pub(crate) const DEFAULT_SQLITE_SCHEMA_NAME: &str = "main"; - /// The underlying sqlite driver. Only available with the `expose-drivers` Cargo feature. #[cfg(feature = "expose-drivers")] pub use rusqlite; @@ -27,105 +26,6 @@ pub struct Sqlite { pub(crate) client: Mutex, } -/// Wraps a connection url and exposes the parsing logic used by Quaint, -/// including default values. -#[derive(Debug)] -pub struct SqliteParams { - pub connection_limit: Option, - /// This is not a `PathBuf` because we need to `ATTACH` the database to the path, and this can - /// only be done with UTF-8 paths. 
- pub file_path: String, - pub db_name: String, - pub socket_timeout: Option, - pub max_connection_lifetime: Option, - pub max_idle_connection_lifetime: Option, -} - -impl TryFrom<&str> for SqliteParams { - type Error = Error; - - fn try_from(path: &str) -> crate::Result { - let path = if path.starts_with("file:") { - path.trim_start_matches("file:") - } else { - path.trim_start_matches("sqlite:") - }; - - let path_parts: Vec<&str> = path.split('?').collect(); - let path_str = path_parts[0]; - let path = Path::new(path_str); - - if path.is_dir() { - Err(Error::builder(ErrorKind::DatabaseUrlIsInvalid(path.to_str().unwrap().to_string())).build()) - } else { - let mut connection_limit = None; - let mut socket_timeout = None; - let mut max_connection_lifetime = None; - let mut max_idle_connection_lifetime = None; - - if path_parts.len() > 1 { - let params = path_parts.last().unwrap().split('&').map(|kv| { - let splitted: Vec<&str> = kv.split('=').collect(); - (splitted[0], splitted[1]) - }); - - for (k, v) in params { - match k { - "connection_limit" => { - let as_int: usize = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - connection_limit = Some(as_int); - } - "socket_timeout" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - socket_timeout = Some(Duration::from_secs(as_int)); - } - "max_connection_lifetime" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - if as_int == 0 { - max_connection_lifetime = None; - } else { - max_connection_lifetime = Some(Duration::from_secs(as_int)); - } - } - "max_idle_connection_lifetime" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - if as_int == 0 { - max_idle_connection_lifetime = None; - } else { - max_idle_connection_lifetime = Some(Duration::from_secs(as_int)); - } - } - _ => { - tracing::trace!(message = "Discarding connection string param", param = k); - } - }; - } - } - - Ok(Self { - connection_limit, - file_path: path_str.to_owned(), - db_name: DEFAULT_SQLITE_SCHEMA_NAME.to_owned(), - socket_timeout, - max_connection_lifetime, - max_idle_connection_lifetime, - }) - } - } -} - impl TryFrom<&str> for Sqlite { type Error = Error; diff --git a/quaint/src/connector/sqlite_wasm.rs b/quaint/src/connector/sqlite_wasm.rs new file mode 100644 index 000000000000..10c174480785 --- /dev/null +++ b/quaint/src/connector/sqlite_wasm.rs @@ -0,0 +1,103 @@ +use crate::error::{Error, ErrorKind}; +use std::{convert::TryFrom, path::Path, time::Duration}; + +pub(crate) const DEFAULT_SQLITE_SCHEMA_NAME: &str = "main"; + +/// Wraps a connection url and exposes the parsing logic used by Quaint, +/// including default values. +#[derive(Debug)] +pub struct SqliteParams { + pub connection_limit: Option, + /// This is not a `PathBuf` because we need to `ATTACH` the database to the path, and this can + /// only be done with UTF-8 paths. 
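+    ///
+    /// The `file:`/`sqlite:` prefix and any query string are stripped first, so e.g.
+    /// `file:dev.db?connection_limit=5` is stored here as `dev.db`.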
+ pub file_path: String, + pub db_name: String, + pub socket_timeout: Option, + pub max_connection_lifetime: Option, + pub max_idle_connection_lifetime: Option, +} + +impl TryFrom<&str> for SqliteParams { + type Error = Error; + + fn try_from(path: &str) -> crate::Result { + let path = if path.starts_with("file:") { + path.trim_start_matches("file:") + } else { + path.trim_start_matches("sqlite:") + }; + + let path_parts: Vec<&str> = path.split('?').collect(); + let path_str = path_parts[0]; + let path = Path::new(path_str); + + if path.is_dir() { + Err(Error::builder(ErrorKind::DatabaseUrlIsInvalid(path.to_str().unwrap().to_string())).build()) + } else { + let mut connection_limit = None; + let mut socket_timeout = None; + let mut max_connection_lifetime = None; + let mut max_idle_connection_lifetime = None; + + if path_parts.len() > 1 { + let params = path_parts.last().unwrap().split('&').map(|kv| { + let splitted: Vec<&str> = kv.split('=').collect(); + (splitted[0], splitted[1]) + }); + + for (k, v) in params { + match k { + "connection_limit" => { + let as_int: usize = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + connection_limit = Some(as_int); + } + "socket_timeout" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + socket_timeout = Some(Duration::from_secs(as_int)); + } + "max_connection_lifetime" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + if as_int == 0 { + max_connection_lifetime = None; + } else { + max_connection_lifetime = Some(Duration::from_secs(as_int)); + } + } + "max_idle_connection_lifetime" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + if as_int == 0 { + max_idle_connection_lifetime = None; + } else { + max_idle_connection_lifetime = Some(Duration::from_secs(as_int)); + } + } + _ => { + tracing::trace!(message = "Discarding connection string param", param = k); + } + }; + } + } + + Ok(Self { + connection_limit, + file_path: path_str.to_owned(), + db_name: DEFAULT_SQLITE_SCHEMA_NAME.to_owned(), + socket_timeout, + max_connection_lifetime, + max_idle_connection_lifetime, + }) + } + } +} diff --git a/quaint/src/error.rs b/quaint/src/error.rs index 705bb6b37ee0..785fcc22ffe3 100644 --- a/quaint/src/error.rs +++ b/quaint/src/error.rs @@ -6,9 +6,9 @@ use thiserror::Error; #[cfg(feature = "pooled")] use std::time::Duration; -pub use crate::connector::mysql::MysqlError; -pub use crate::connector::postgres::PostgresError; -pub use crate::connector::sqlite::SqliteError; +// pub use crate::connector::mysql::MysqlError; +// pub use crate::connector::postgres::PostgresError; +// pub use crate::connector::sqlite::SqliteError; #[derive(Debug, PartialEq, Eq)] pub enum DatabaseConstraint { diff --git a/quaint/src/pooled/manager.rs b/quaint/src/pooled/manager.rs index c0aa8c93b75d..c31fd44fbcae 100644 --- a/quaint/src/pooled/manager.rs +++ b/quaint/src/pooled/manager.rs @@ -1,8 +1,8 @@ -#[cfg(feature = "mssql")] +#[cfg(feature = "mssql-connector")] use crate::connector::MssqlUrl; -#[cfg(feature = "mysql")] +#[cfg(feature = "mysql-connector")] use crate::connector::MysqlUrl; -#[cfg(feature = "postgresql")] +#[cfg(feature = "postgresql-connector")] use crate::connector::PostgresUrl; use crate::{ ast, @@ -97,7 +97,7 @@ impl Manager for QuaintManager { async fn connect(&self) -> crate::Result { let conn = match self { - #[cfg(feature 
= "sqlite")] + #[cfg(feature = "sqlite-connector")] QuaintManager::Sqlite { url, .. } => { use crate::connector::Sqlite; @@ -106,19 +106,19 @@ impl Manager for QuaintManager { Ok(Box::new(conn) as Self::Connection) } - #[cfg(feature = "mysql")] + #[cfg(feature = "mysql-connector")] QuaintManager::Mysql { url } => { use crate::connector::Mysql; Ok(Box::new(Mysql::new(url.clone()).await?) as Self::Connection) } - #[cfg(feature = "postgresql")] + #[cfg(feature = "postgresql-connector")] QuaintManager::Postgres { url } => { use crate::connector::PostgreSql; Ok(Box::new(PostgreSql::new(url.clone()).await?) as Self::Connection) } - #[cfg(feature = "mssql")] + #[cfg(feature = "mssql-connector")] QuaintManager::Mssql { url } => { use crate::connector::Mssql; Ok(Box::new(Mssql::new(url.clone()).await?) as Self::Connection) @@ -146,7 +146,7 @@ mod tests { use crate::pooled::Quaint; #[tokio::test] - #[cfg(feature = "mysql")] + #[cfg(feature = "mysql-connector")] async fn mysql_default_connection_limit() { let conn_string = std::env::var("TEST_MYSQL").expect("TEST_MYSQL connection string not set."); @@ -156,7 +156,7 @@ mod tests { } #[tokio::test] - #[cfg(feature = "mysql")] + #[cfg(feature = "mysql-connector")] async fn mysql_custom_connection_limit() { let conn_string = format!( "{}?connection_limit=10", @@ -169,7 +169,7 @@ mod tests { } #[tokio::test] - #[cfg(feature = "postgresql")] + #[cfg(feature = "postgresql-connector")] async fn psql_default_connection_limit() { let conn_string = std::env::var("TEST_PSQL").expect("TEST_PSQL connection string not set."); @@ -179,7 +179,7 @@ mod tests { } #[tokio::test] - #[cfg(feature = "postgresql")] + #[cfg(feature = "postgresql-connector")] async fn psql_custom_connection_limit() { let conn_string = format!( "{}?connection_limit=10", @@ -192,7 +192,7 @@ mod tests { } #[tokio::test] - #[cfg(feature = "mssql")] + #[cfg(feature = "mssql-connector")] async fn mssql_default_connection_limit() { let conn_string = std::env::var("TEST_MSSQL").expect("TEST_MSSQL connection string not set."); @@ -202,7 +202,7 @@ mod tests { } #[tokio::test] - #[cfg(feature = "mssql")] + #[cfg(feature = "mssql-connector")] async fn mssql_custom_connection_limit() { let conn_string = format!( "{};connectionLimit=10", @@ -215,7 +215,7 @@ mod tests { } #[tokio::test] - #[cfg(feature = "sqlite")] + #[cfg(feature = "sqlite-connector")] async fn test_default_connection_limit() { let conn_string = "file:db/test.db".to_string(); let pool = Quaint::builder(&conn_string).unwrap().build(); @@ -224,7 +224,7 @@ mod tests { } #[tokio::test] - #[cfg(feature = "sqlite")] + #[cfg(feature = "sqlite-connector")] async fn test_custom_connection_limit() { let conn_string = "file:db/test.db?connection_limit=10".to_string(); let pool = Quaint::builder(&conn_string).unwrap().build(); diff --git a/quaint/src/single.rs b/quaint/src/single.rs index 82042f58010b..2f234e40fd74 100644 --- a/quaint/src/single.rs +++ b/quaint/src/single.rs @@ -130,27 +130,27 @@ impl Quaint { #[allow(unreachable_code)] pub async fn new(url_str: &str) -> crate::Result { let inner = match url_str { - #[cfg(feature = "sqlite")] + #[cfg(feature = "sqlite-connector")] s if s.starts_with("file") => { let params = connector::SqliteParams::try_from(s)?; let sqlite = connector::Sqlite::new(¶ms.file_path)?; Arc::new(sqlite) as Arc } - #[cfg(feature = "mysql")] + #[cfg(feature = "mysql-connector")] s if s.starts_with("mysql") => { let url = connector::MysqlUrl::new(url::Url::parse(s)?)?; let mysql = connector::Mysql::new(url).await?; 
Arc::new(mysql) as Arc } - #[cfg(feature = "postgresql")] + #[cfg(feature = "postgresql-connector")] s if s.starts_with("postgres") || s.starts_with("postgresql") => { let url = connector::PostgresUrl::new(url::Url::parse(s)?)?; let psql = connector::PostgreSql::new(url).await?; Arc::new(psql) as Arc } - #[cfg(feature = "mssql")] + #[cfg(feature = "mssql-connector")] s if s.starts_with("jdbc:sqlserver") | s.starts_with("sqlserver") => { let url = connector::MssqlUrl::new(s)?; let psql = connector::Mssql::new(url).await?; @@ -166,7 +166,7 @@ impl Quaint { Ok(Self { inner, connection_info }) } - #[cfg(feature = "sqlite")] + #[cfg(feature = "sqlite-connector")] /// Open a new SQLite database in memory. pub fn new_in_memory() -> crate::Result { Ok(Quaint { From 055e696e40adb7294da2337ab86ded2d333ef5a8 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Fri, 10 Nov 2023 15:02:09 +0100 Subject: [PATCH 002/134] feat(quaint): split postgres connector into native and wasm submodules --- quaint/src/connector.rs | 23 +- quaint/src/connector/postgres.rs | 1187 +---------------- .../postgres/{ => native}/conversion.rs | 0 .../{ => native}/conversion/decimal.rs | 0 quaint/src/connector/postgres/native/error.rs | 126 ++ quaint/src/connector/postgres/native/mod.rs | 1184 ++++++++++++++++ quaint/src/connector/postgres/wasm/common.rs | 612 +++++++++ .../connector/postgres/{ => wasm}/error.rs | 124 +- quaint/src/connector/postgres/wasm/mod.rs | 6 + quaint/src/error.rs | 2 +- 10 files changed, 1950 insertions(+), 1314 deletions(-) rename quaint/src/connector/postgres/{ => native}/conversion.rs (100%) rename quaint/src/connector/postgres/{ => native}/conversion/decimal.rs (100%) create mode 100644 quaint/src/connector/postgres/native/error.rs create mode 100644 quaint/src/connector/postgres/native/mod.rs create mode 100644 quaint/src/connector/postgres/wasm/common.rs rename quaint/src/connector/postgres/{ => wasm}/error.rs (66%) create mode 100644 quaint/src/connector/postgres/wasm/mod.rs diff --git a/quaint/src/connector.rs b/quaint/src/connector.rs index 898aac8fcb46..71bba2d098ed 100644 --- a/quaint/src/connector.rs +++ b/quaint/src/connector.rs @@ -31,10 +31,10 @@ pub(crate) mod mssql_wasm; pub(crate) mod mysql; #[cfg(feature = "mysql")] pub(crate) mod mysql_wasm; -#[cfg(feature = "postgresql-connector")] -pub(crate) mod postgres; -#[cfg(feature = "postgresql")] -pub(crate) mod postgres_wasm; +// #[cfg(feature = "postgresql-connector")] +// pub(crate) mod postgres; +// #[cfg(feature = "postgresql")] +// pub(crate) mod postgres_wasm; #[cfg(feature = "sqlite-connector")] pub(crate) mod sqlite; #[cfg(feature = "sqlite")] @@ -44,10 +44,10 @@ pub(crate) mod sqlite_wasm; pub use self::mysql::*; #[cfg(feature = "mysql")] pub use self::mysql_wasm::*; -#[cfg(feature = "postgresql-connector")] -pub use self::postgres::*; -#[cfg(feature = "postgresql")] -pub use self::postgres_wasm::*; +// #[cfg(feature = "postgresql-connector")] +// pub use self::postgres::*; +// #[cfg(feature = "postgresql")] +// pub use self::postgres_wasm::*; #[cfg(feature = "mssql-connector")] pub use mssql::*; #[cfg(feature = "mssql")] @@ -70,3 +70,10 @@ pub use transaction::*; pub(crate) use type_identifier::*; pub use self::metrics::query; + +#[cfg(feature = "postgresql")] +pub(crate) mod postgres; +#[cfg(feature = "postgresql-connector")] +pub use postgres::native::*; +#[cfg(feature = "postgresql")] +pub use postgres::wasm::*; diff --git a/quaint/src/connector/postgres.rs b/quaint/src/connector/postgres.rs index 7a83e61218f6..9f4d4d496f2b 100644 
--- a/quaint/src/connector/postgres.rs +++ b/quaint/src/connector/postgres.rs @@ -1,1184 +1,7 @@ -mod conversion; -mod error; +pub use wasm::error::PostgresError; -use super::postgres_wasm::{Hidden, SslAcceptMode, SslParams}; -pub(crate) use super::postgres_wasm::{PostgresFlavour, PostgresUrl}; +#[cfg(feature = "postgresql")] +pub(crate) mod wasm; -use crate::{ - ast::{Query, Value}, - connector::{metrics, queryable::*, ResultSet}, - error::{Error, ErrorKind}, - visitor::{self, Visitor}, -}; -use async_trait::async_trait; -use futures::{future::FutureExt, lock::Mutex}; -use lru_cache::LruCache; -use native_tls::{Certificate, Identity, TlsConnector}; -use postgres_native_tls::MakeTlsConnector; -use std::{ - borrow::Borrow, - fmt::{Debug, Display}, - fs, - future::Future, - sync::atomic::{AtomicBool, Ordering}, - time::Duration, -}; -use tokio_postgres::{config::ChannelBinding, Client, Config, Statement}; - -pub use error::PostgresError; - -/// The underlying postgres driver. Only available with the `expose-drivers` -/// Cargo feature. -#[cfg(feature = "expose-drivers")] -pub use tokio_postgres; - -use super::{IsolationLevel, Transaction}; - -struct PostgresClient(Client); - -impl Debug for PostgresClient { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str("PostgresClient") - } -} - -/// A connector interface for the PostgreSQL database. -#[derive(Debug)] -pub struct PostgreSql { - client: PostgresClient, - pg_bouncer: bool, - socket_timeout: Option, - statement_cache: Mutex>, - is_healthy: AtomicBool, -} - -#[derive(Debug)] -struct SslAuth { - certificate: Hidden>, - identity: Hidden>, - ssl_accept_mode: SslAcceptMode, -} - -impl Default for SslAuth { - fn default() -> Self { - Self { - certificate: Hidden(None), - identity: Hidden(None), - ssl_accept_mode: SslAcceptMode::AcceptInvalidCerts, - } - } -} - -impl SslAuth { - fn certificate(&mut self, certificate: Certificate) -> &mut Self { - self.certificate = Hidden(Some(certificate)); - self - } - - fn identity(&mut self, identity: Identity) -> &mut Self { - self.identity = Hidden(Some(identity)); - self - } - - fn accept_mode(&mut self, mode: SslAcceptMode) -> &mut Self { - self.ssl_accept_mode = mode; - self - } -} - -impl SslParams { - async fn into_auth(self) -> crate::Result { - let mut auth = SslAuth::default(); - auth.accept_mode(self.ssl_accept_mode); - - if let Some(ref cert_file) = self.certificate_file { - let cert = fs::read(cert_file).map_err(|err| { - Error::builder(ErrorKind::TlsError { - message: format!("cert file not found ({err})"), - }) - .build() - })?; - - auth.certificate(Certificate::from_pem(&cert)?); - } - - if let Some(ref identity_file) = self.identity_file { - let db = fs::read(identity_file).map_err(|err| { - Error::builder(ErrorKind::TlsError { - message: format!("identity file not found ({err})"), - }) - .build() - })?; - let password = self.identity_password.0.as_deref().unwrap_or(""); - let identity = Identity::from_pkcs12(&db, password)?; - - auth.identity(identity); - } - - Ok(auth) - } -} - -impl PostgresUrl { - pub(crate) fn cache(&self) -> LruCache { - if self.query_params.pg_bouncer { - LruCache::new(0) - } else { - LruCache::new(self.query_params.statement_cache_size) - } - } - - pub fn channel_binding(&self) -> ChannelBinding { - self.query_params.channel_binding - } - - /// On Postgres, we set the SEARCH_PATH and client-encoding through client connection parameters to save a network roundtrip on connection. 
- /// We can't always do it for CockroachDB because it does not expect quotes for unsafe identifiers (https://github.com/cockroachdb/cockroach/issues/101328), which might change once the issue is fixed. - /// To circumvent that problem, we only set the SEARCH_PATH through client connection parameters for Cockroach when the identifier is safe, so that the quoting does not matter. - fn set_search_path(&self, config: &mut Config) { - // PGBouncer does not support the search_path connection parameter. - // https://www.pgbouncer.org/config.html#ignore_startup_parameters - if self.query_params.pg_bouncer { - return; - } - - if let Some(schema) = &self.query_params.schema { - if self.flavour().is_cockroach() && is_safe_identifier(schema) { - config.search_path(CockroachSearchPath(schema).to_string()); - } - - if self.flavour().is_postgres() { - config.search_path(PostgresSearchPath(schema).to_string()); - } - } - } - - pub(crate) fn to_config(&self) -> Config { - let mut config = Config::new(); - - config.user(self.username().borrow()); - config.password(self.password().borrow() as &str); - config.host(self.host()); - config.port(self.port()); - config.dbname(self.dbname()); - config.pgbouncer_mode(self.query_params.pg_bouncer); - - if let Some(options) = self.options() { - config.options(options); - } - - if let Some(application_name) = self.application_name() { - config.application_name(application_name); - } - - if let Some(connect_timeout) = self.query_params.connect_timeout { - config.connect_timeout(connect_timeout); - } - - self.set_search_path(&mut config); - - config.ssl_mode(self.query_params.ssl_mode); - - config.channel_binding(self.query_params.channel_binding); - - config - } -} - -impl PostgreSql { - /// Create a new connection to the database. - pub async fn new(url: PostgresUrl) -> crate::Result { - let config = url.to_config(); - - let mut tls_builder = TlsConnector::builder(); - - { - let ssl_params = url.ssl_params(); - let auth = ssl_params.to_owned().into_auth().await?; - - if let Some(certificate) = auth.certificate.0 { - tls_builder.add_root_certificate(certificate); - } - - tls_builder.danger_accept_invalid_certs(auth.ssl_accept_mode == SslAcceptMode::AcceptInvalidCerts); - - if let Some(identity) = auth.identity.0 { - tls_builder.identity(identity); - } - } - - let tls = MakeTlsConnector::new(tls_builder.build()?); - let (client, conn) = super::timeout::connect(url.connect_timeout(), config.connect(tls)).await?; - - tokio::spawn(conn.map(|r| match r { - Ok(_) => (), - Err(e) => { - tracing::error!("Error in PostgreSQL connection: {:?}", e); - } - })); - - // On Postgres, we set the SEARCH_PATH and client-encoding through client connection parameters to save a network roundtrip on connection. - // We can't always do it for CockroachDB because it does not expect quotes for unsafe identifiers (https://github.com/cockroachdb/cockroach/issues/101328), which might change once the issue is fixed. - // To circumvent that problem, we only set the SEARCH_PATH through client connection parameters for Cockroach when the identifier is safe, so that the quoting does not matter. - // Finally, to ensure backward compatibility, we keep sending a database query in case the flavour is set to Unknown. - if let Some(schema) = &url.query_params.schema { - // PGBouncer does not support the search_path connection parameter. 
- // https://www.pgbouncer.org/config.html#ignore_startup_parameters - if url.query_params.pg_bouncer - || url.flavour().is_unknown() - || (url.flavour().is_cockroach() && !is_safe_identifier(schema)) - { - let session_variables = format!( - r##"{set_search_path}"##, - set_search_path = SetSearchPath(url.query_params.schema.as_deref()) - ); - - client.simple_query(session_variables.as_str()).await?; - } - } - - Ok(Self { - client: PostgresClient(client), - socket_timeout: url.query_params.socket_timeout, - pg_bouncer: url.query_params.pg_bouncer, - statement_cache: Mutex::new(url.cache()), - is_healthy: AtomicBool::new(true), - }) - } - - /// The underlying tokio_postgres::Client. Only available with the - /// `expose-drivers` Cargo feature. This is a lower level API when you need - /// to get into database specific features. - #[cfg(feature = "expose-drivers")] - pub fn client(&self) -> &tokio_postgres::Client { - &self.client.0 - } - - async fn fetch_cached(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - let mut cache = self.statement_cache.lock().await; - let capacity = cache.capacity(); - let stored = cache.len(); - - match cache.get_mut(sql) { - Some(stmt) => { - tracing::trace!( - message = "CACHE HIT!", - query = sql, - capacity = capacity, - stored = stored, - ); - - Ok(stmt.clone()) // arc'd - } - None => { - tracing::trace!( - message = "CACHE MISS!", - query = sql, - capacity = capacity, - stored = stored, - ); - - let param_types = conversion::params_to_types(params); - let stmt = self.perform_io(self.client.0.prepare_typed(sql, ¶m_types)).await?; - - cache.insert(sql.to_string(), stmt.clone()); - - Ok(stmt) - } - } - } - - async fn perform_io(&self, fut: F) -> crate::Result - where - F: Future>, - { - match super::timeout::socket(self.socket_timeout, fut).await { - Err(e) if e.is_closed() => { - self.is_healthy.store(false, Ordering::SeqCst); - Err(e) - } - res => res, - } - } - - fn check_bind_variables_len(&self, params: &[Value<'_>]) -> crate::Result<()> { - if params.len() > i16::MAX as usize { - // tokio_postgres would return an error here. Let's avoid calling the driver - // and return an error early. - let kind = ErrorKind::QueryInvalidInput(format!( - "too many bind variables in prepared statement, expected maximum of {}, received {}", - i16::MAX, - params.len() - )); - Err(Error::builder(kind).build()) - } else { - Ok(()) - } - } -} - -// A SearchPath connection parameter (Display-impl) for connection initialization. -struct CockroachSearchPath<'a>(&'a str); - -impl Display for CockroachSearchPath<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str(self.0) - } -} - -// A SearchPath connection parameter (Display-impl) for connection initialization. -struct PostgresSearchPath<'a>(&'a str); - -impl Display for PostgresSearchPath<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str("\"")?; - f.write_str(self.0)?; - f.write_str("\"")?; - - Ok(()) - } -} - -// A SetSearchPath statement (Display-impl) for connection initialization. 
-struct SetSearchPath<'a>(Option<&'a str>); - -impl Display for SetSearchPath<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - if let Some(schema) = self.0 { - f.write_str("SET search_path = \"")?; - f.write_str(schema)?; - f.write_str("\";\n")?; - } - - Ok(()) - } -} - -impl_default_TransactionCapable!(PostgreSql); - -#[async_trait] -impl Queryable for PostgreSql { - async fn query(&self, q: Query<'_>) -> crate::Result { - let (sql, params) = visitor::Postgres::build(q)?; - - self.query_raw(sql.as_str(), ¶ms[..]).await - } - - async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - self.check_bind_variables_len(params)?; - - metrics::query("postgres.query_raw", sql, params, move || async move { - let stmt = self.fetch_cached(sql, &[]).await?; - - if stmt.params().len() != params.len() { - let kind = ErrorKind::IncorrectNumberOfParameters { - expected: stmt.params().len(), - actual: params.len(), - }; - - return Err(Error::builder(kind).build()); - } - - let rows = self - .perform_io(self.client.0.query(&stmt, conversion::conv_params(params).as_slice())) - .await?; - - let mut result = ResultSet::new(stmt.to_column_names(), Vec::new()); - - for row in rows { - result.rows.push(row.get_result_row()?); - } - - Ok(result) - }) - .await - } - - async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - self.check_bind_variables_len(params)?; - - metrics::query("postgres.query_raw", sql, params, move || async move { - let stmt = self.fetch_cached(sql, params).await?; - - if stmt.params().len() != params.len() { - let kind = ErrorKind::IncorrectNumberOfParameters { - expected: stmt.params().len(), - actual: params.len(), - }; - - return Err(Error::builder(kind).build()); - } - - let rows = self - .perform_io(self.client.0.query(&stmt, conversion::conv_params(params).as_slice())) - .await?; - - let mut result = ResultSet::new(stmt.to_column_names(), Vec::new()); - - for row in rows { - result.rows.push(row.get_result_row()?); - } - - Ok(result) - }) - .await - } - - async fn execute(&self, q: Query<'_>) -> crate::Result { - let (sql, params) = visitor::Postgres::build(q)?; - - self.execute_raw(sql.as_str(), ¶ms[..]).await - } - - async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - self.check_bind_variables_len(params)?; - - metrics::query("postgres.execute_raw", sql, params, move || async move { - let stmt = self.fetch_cached(sql, &[]).await?; - - if stmt.params().len() != params.len() { - let kind = ErrorKind::IncorrectNumberOfParameters { - expected: stmt.params().len(), - actual: params.len(), - }; - - return Err(Error::builder(kind).build()); - } - - let changes = self - .perform_io(self.client.0.execute(&stmt, conversion::conv_params(params).as_slice())) - .await?; - - Ok(changes) - }) - .await - } - - async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - self.check_bind_variables_len(params)?; - - metrics::query("postgres.execute_raw", sql, params, move || async move { - let stmt = self.fetch_cached(sql, params).await?; - - if stmt.params().len() != params.len() { - let kind = ErrorKind::IncorrectNumberOfParameters { - expected: stmt.params().len(), - actual: params.len(), - }; - - return Err(Error::builder(kind).build()); - } - - let changes = self - .perform_io(self.client.0.execute(&stmt, conversion::conv_params(params).as_slice())) - .await?; - - Ok(changes) - }) - .await - } - - async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> { - 
metrics::query("postgres.raw_cmd", cmd, &[], move || async move { - self.perform_io(self.client.0.simple_query(cmd)).await?; - Ok(()) - }) - .await - } - - async fn version(&self) -> crate::Result> { - let query = r#"SELECT version()"#; - let rows = self.query_raw(query, &[]).await?; - - let version_string = rows - .get(0) - .and_then(|row| row.get("version").and_then(|version| version.to_string())); - - Ok(version_string) - } - - fn is_healthy(&self) -> bool { - self.is_healthy.load(Ordering::SeqCst) - } - - async fn server_reset_query(&self, tx: &dyn Transaction) -> crate::Result<()> { - if self.pg_bouncer { - tx.raw_cmd("DEALLOCATE ALL").await - } else { - Ok(()) - } - } - - async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { - if matches!(isolation_level, IsolationLevel::Snapshot) { - return Err(Error::builder(ErrorKind::invalid_isolation_level(&isolation_level)).build()); - } - - self.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation_level}")) - .await?; - - Ok(()) - } - - fn requires_isolation_first(&self) -> bool { - false - } -} - -/// Sorted list of CockroachDB's reserved keywords. -/// Taken from https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#keywords -const RESERVED_KEYWORDS: [&str; 79] = [ - "all", - "analyse", - "analyze", - "and", - "any", - "array", - "as", - "asc", - "asymmetric", - "both", - "case", - "cast", - "check", - "collate", - "column", - "concurrently", - "constraint", - "create", - "current_catalog", - "current_date", - "current_role", - "current_schema", - "current_time", - "current_timestamp", - "current_user", - "default", - "deferrable", - "desc", - "distinct", - "do", - "else", - "end", - "except", - "false", - "fetch", - "for", - "foreign", - "from", - "grant", - "group", - "having", - "in", - "initially", - "intersect", - "into", - "lateral", - "leading", - "limit", - "localtime", - "localtimestamp", - "not", - "null", - "offset", - "on", - "only", - "or", - "order", - "placing", - "primary", - "references", - "returning", - "select", - "session_user", - "some", - "symmetric", - "table", - "then", - "to", - "trailing", - "true", - "union", - "unique", - "user", - "using", - "variadic", - "when", - "where", - "window", - "with", -]; - -/// Sorted list of CockroachDB's reserved type function names. -/// Taken from https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#keywords -const RESERVED_TYPE_FUNCTION_NAMES: [&str; 18] = [ - "authorization", - "collation", - "cross", - "full", - "ilike", - "inner", - "is", - "isnull", - "join", - "left", - "like", - "natural", - "none", - "notnull", - "outer", - "overlaps", - "right", - "similar", -]; - -/// Returns true if a Postgres identifier is considered "safe". -/// -/// In this context, "safe" means that the value of an identifier would be the same quoted and unquoted or that it's not part of reserved keywords. In other words, that it does _not_ need to be quoted. -/// -/// Spec can be found here: https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS -/// or here: https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#rules-for-identifiers -fn is_safe_identifier(ident: &str) -> bool { - if ident.is_empty() { - return false; - } - - // 1. Not equal any SQL keyword unless the keyword is accepted by the element's syntax. For example, name accepts Unreserved or Column Name keywords. 
- if RESERVED_KEYWORDS.binary_search(&ident).is_ok() || RESERVED_TYPE_FUNCTION_NAMES.binary_search(&ident).is_ok() { - return false; - } - - let mut chars = ident.chars(); - - let first = chars.next().unwrap(); - - // 2. SQL identifiers must begin with a letter (a-z, but also letters with diacritical marks and non-Latin letters) or an underscore (_). - if (!first.is_alphabetic() || !first.is_lowercase()) && first != '_' { - return false; - } - - for c in chars { - // 3. Subsequent characters in an identifier can be letters, underscores, digits (0-9), or dollar signs ($). - if (!c.is_alphabetic() || !c.is_lowercase()) && c != '_' && !c.is_ascii_digit() && c != '$' { - return false; - } - } - - true -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::tests::test_api::postgres::CONN_STR; - use crate::tests::test_api::CRDB_CONN_STR; - use crate::{connector::Queryable, error::*, single::Quaint}; - use url::Url; - - #[test] - fn should_parse_socket_url() { - let url = PostgresUrl::new(Url::parse("postgresql:///dbname?host=/var/run/psql.sock").unwrap()).unwrap(); - assert_eq!("dbname", url.dbname()); - assert_eq!("/var/run/psql.sock", url.host()); - } - - #[test] - fn should_parse_escaped_url() { - let url = PostgresUrl::new(Url::parse("postgresql:///dbname?host=%2Fvar%2Frun%2Fpostgresql").unwrap()).unwrap(); - assert_eq!("dbname", url.dbname()); - assert_eq!("/var/run/postgresql", url.host()); - } - - #[test] - fn should_allow_changing_of_cache_size() { - let url = - PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?statement_cache_size=420").unwrap()).unwrap(); - assert_eq!(420, url.cache().capacity()); - } - - #[test] - fn should_have_default_cache_size() { - let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo").unwrap()).unwrap(); - assert_eq!(100, url.cache().capacity()); - } - - #[test] - fn should_have_application_name() { - let url = - PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?application_name=test").unwrap()).unwrap(); - assert_eq!(Some("test"), url.application_name()); - } - - #[test] - fn should_have_channel_binding() { - let url = - PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?channel_binding=require").unwrap()).unwrap(); - assert_eq!(ChannelBinding::Require, url.channel_binding()); - } - - #[test] - fn should_have_default_channel_binding() { - let url = - PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?channel_binding=invalid").unwrap()).unwrap(); - assert_eq!(ChannelBinding::Prefer, url.channel_binding()); - - let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo").unwrap()).unwrap(); - assert_eq!(ChannelBinding::Prefer, url.channel_binding()); - } - - #[test] - fn should_not_enable_caching_with_pgbouncer() { - let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?pgbouncer=true").unwrap()).unwrap(); - assert_eq!(0, url.cache().capacity()); - } - - #[test] - fn should_parse_default_host() { - let url = PostgresUrl::new(Url::parse("postgresql:///dbname").unwrap()).unwrap(); - assert_eq!("dbname", url.dbname()); - assert_eq!("localhost", url.host()); - } - - #[test] - fn should_parse_ipv6_host() { - let url = PostgresUrl::new(Url::parse("postgresql://[2001:db8:1234::ffff]:5432/dbname").unwrap()).unwrap(); - assert_eq!("2001:db8:1234::ffff", url.host()); - } - - #[test] - fn should_handle_options_field() { - let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432?options=--cluster%3Dmy_cluster").unwrap()) - .unwrap(); - - 
assert_eq!("--cluster=my_cluster", url.options().unwrap()); - } - - #[tokio::test] - async fn test_custom_search_path_pg() { - async fn test_path(schema_name: &str) -> Option { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.query_pairs_mut().append_pair("schema", schema_name); - - let mut pg_url = PostgresUrl::new(url).unwrap(); - pg_url.set_flavour(PostgresFlavour::Postgres); - - let client = PostgreSql::new(pg_url).await.unwrap(); - - let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); - let row = result_set.first().unwrap(); - - row[0].typed.to_string() - } - - // Safe - assert_eq!(test_path("hello").await.as_deref(), Some("\"hello\"")); - assert_eq!(test_path("_hello").await.as_deref(), Some("\"_hello\"")); - assert_eq!(test_path("àbracadabra").await.as_deref(), Some("\"àbracadabra\"")); - assert_eq!(test_path("h3ll0").await.as_deref(), Some("\"h3ll0\"")); - assert_eq!(test_path("héllo").await.as_deref(), Some("\"héllo\"")); - assert_eq!(test_path("héll0$").await.as_deref(), Some("\"héll0$\"")); - assert_eq!(test_path("héll_0$").await.as_deref(), Some("\"héll_0$\"")); - - // Not safe - assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\"")); - assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\"")); - assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\"")); - assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\"")); - assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\"")); - assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\"")); - assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\"")); - assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\"")); - assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\"")); - assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\"")); - assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\"")); - assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\"")); - assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\"")); - assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\"")); - - for ident in RESERVED_KEYWORDS { - assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); - } - - for ident in RESERVED_TYPE_FUNCTION_NAMES { - assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); - } - } - - #[tokio::test] - async fn test_custom_search_path_pg_pgbouncer() { - async fn test_path(schema_name: &str) -> Option { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.query_pairs_mut().append_pair("schema", schema_name); - url.query_pairs_mut().append_pair("pbbouncer", "true"); - - let mut pg_url = PostgresUrl::new(url).unwrap(); - pg_url.set_flavour(PostgresFlavour::Postgres); - - let client = PostgreSql::new(pg_url).await.unwrap(); - - let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); - let row = result_set.first().unwrap(); - - row[0].typed.to_string() - } - - // Safe - assert_eq!(test_path("hello").await.as_deref(), Some("\"hello\"")); - assert_eq!(test_path("_hello").await.as_deref(), Some("\"_hello\"")); - assert_eq!(test_path("àbracadabra").await.as_deref(), Some("\"àbracadabra\"")); - assert_eq!(test_path("h3ll0").await.as_deref(), Some("\"h3ll0\"")); - assert_eq!(test_path("héllo").await.as_deref(), Some("\"héllo\"")); - assert_eq!(test_path("héll0$").await.as_deref(), Some("\"héll0$\"")); - assert_eq!(test_path("héll_0$").await.as_deref(), Some("\"héll_0$\"")); - - // Not safe - 
assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\"")); - assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\"")); - assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\"")); - assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\"")); - assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\"")); - assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\"")); - assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\"")); - assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\"")); - assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\"")); - assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\"")); - assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\"")); - assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\"")); - assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\"")); - assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\"")); - - for ident in RESERVED_KEYWORDS { - assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); - } - - for ident in RESERVED_TYPE_FUNCTION_NAMES { - assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); - } - } - - #[tokio::test] - async fn test_custom_search_path_crdb() { - async fn test_path(schema_name: &str) -> Option { - let mut url = Url::parse(&CRDB_CONN_STR).unwrap(); - url.query_pairs_mut().append_pair("schema", schema_name); - - let mut pg_url = PostgresUrl::new(url).unwrap(); - pg_url.set_flavour(PostgresFlavour::Cockroach); - - let client = PostgreSql::new(pg_url).await.unwrap(); - - let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); - let row = result_set.first().unwrap(); - - row[0].typed.to_string() - } - - // Safe - assert_eq!(test_path("hello").await.as_deref(), Some("hello")); - assert_eq!(test_path("_hello").await.as_deref(), Some("_hello")); - assert_eq!(test_path("àbracadabra").await.as_deref(), Some("àbracadabra")); - assert_eq!(test_path("h3ll0").await.as_deref(), Some("h3ll0")); - assert_eq!(test_path("héllo").await.as_deref(), Some("héllo")); - assert_eq!(test_path("héll0$").await.as_deref(), Some("héll0$")); - assert_eq!(test_path("héll_0$").await.as_deref(), Some("héll_0$")); - - // Not safe - assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\"")); - assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\"")); - assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\"")); - assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\"")); - assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\"")); - assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\"")); - assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\"")); - assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\"")); - assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\"")); - assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\"")); - assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\"")); - assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\"")); - assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\"")); - assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\"")); - - for ident in RESERVED_KEYWORDS { - assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); - } - - for ident in RESERVED_TYPE_FUNCTION_NAMES { - 
assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); - } - } - - #[tokio::test] - async fn test_custom_search_path_unknown_pg() { - async fn test_path(schema_name: &str) -> Option { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.query_pairs_mut().append_pair("schema", schema_name); - - let mut pg_url = PostgresUrl::new(url).unwrap(); - pg_url.set_flavour(PostgresFlavour::Unknown); - - let client = PostgreSql::new(pg_url).await.unwrap(); - - let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); - let row = result_set.first().unwrap(); - - row[0].typed.to_string() - } - - // Safe - assert_eq!(test_path("hello").await.as_deref(), Some("hello")); - assert_eq!(test_path("_hello").await.as_deref(), Some("_hello")); - assert_eq!(test_path("àbracadabra").await.as_deref(), Some("\"àbracadabra\"")); - assert_eq!(test_path("h3ll0").await.as_deref(), Some("h3ll0")); - assert_eq!(test_path("héllo").await.as_deref(), Some("\"héllo\"")); - assert_eq!(test_path("héll0$").await.as_deref(), Some("\"héll0$\"")); - assert_eq!(test_path("héll_0$").await.as_deref(), Some("\"héll_0$\"")); - - // Not safe - assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\"")); - assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\"")); - assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\"")); - assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\"")); - assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\"")); - assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\"")); - assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\"")); - assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\"")); - assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\"")); - assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\"")); - assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\"")); - assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\"")); - assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\"")); - assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\"")); - - for ident in RESERVED_KEYWORDS { - assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); - } - - for ident in RESERVED_TYPE_FUNCTION_NAMES { - assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); - } - } - - #[tokio::test] - async fn test_custom_search_path_unknown_crdb() { - async fn test_path(schema_name: &str) -> Option { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.query_pairs_mut().append_pair("schema", schema_name); - - let mut pg_url = PostgresUrl::new(url).unwrap(); - pg_url.set_flavour(PostgresFlavour::Unknown); - - let client = PostgreSql::new(pg_url).await.unwrap(); - - let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); - let row = result_set.first().unwrap(); - - row[0].typed.to_string() - } - - // Safe - assert_eq!(test_path("hello").await.as_deref(), Some("hello")); - assert_eq!(test_path("_hello").await.as_deref(), Some("_hello")); - assert_eq!(test_path("àbracadabra").await.as_deref(), Some("\"àbracadabra\"")); - assert_eq!(test_path("h3ll0").await.as_deref(), Some("h3ll0")); - assert_eq!(test_path("héllo").await.as_deref(), Some("\"héllo\"")); - assert_eq!(test_path("héll0$").await.as_deref(), Some("\"héll0$\"")); - assert_eq!(test_path("héll_0$").await.as_deref(), Some("\"héll_0$\"")); - - // Not safe - 
assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\"")); - assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\"")); - assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\"")); - assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\"")); - assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\"")); - assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\"")); - assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\"")); - assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\"")); - assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\"")); - assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\"")); - assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\"")); - assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\"")); - assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\"")); - assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\"")); - - for ident in RESERVED_KEYWORDS { - assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); - } - - for ident in RESERVED_TYPE_FUNCTION_NAMES { - assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); - } - } - - #[tokio::test] - async fn should_map_nonexisting_database_error() { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.set_path("/this_does_not_exist"); - - let res = Quaint::new(url.as_str()).await; - - assert!(res.is_err()); - - match res { - Ok(_) => unreachable!(), - Err(e) => match e.kind() { - ErrorKind::DatabaseDoesNotExist { db_name } => { - assert_eq!(Some("3D000"), e.original_code()); - assert_eq!( - Some("database \"this_does_not_exist\" does not exist"), - e.original_message() - ); - assert_eq!(&Name::available("this_does_not_exist"), db_name) - } - kind => panic!("Expected `DatabaseDoesNotExist`, got {:?}", kind), - }, - } - } - - #[tokio::test] - async fn should_map_wrong_credentials_error() { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.set_username("WRONG").unwrap(); - - let res = Quaint::new(url.as_str()).await; - assert!(res.is_err()); - - let err = res.unwrap_err(); - assert!(matches!(err.kind(), ErrorKind::AuthenticationFailed { user } if user == &Name::available("WRONG"))); - } - - #[tokio::test] - async fn should_map_tls_errors() { - let mut url = Url::parse(&CONN_STR).expect("parsing url"); - url.set_query(Some("sslmode=require&sslaccept=strict")); - - let res = Quaint::new(url.as_str()).await; - - assert!(res.is_err()); - - match res { - Ok(_) => unreachable!(), - Err(e) => match e.kind() { - ErrorKind::TlsError { .. 
} => (), - other => panic!("{:#?}", other), - }, - } - } - - #[tokio::test] - async fn should_map_incorrect_parameters_error() { - let url = Url::parse(&CONN_STR).unwrap(); - let conn = Quaint::new(url.as_str()).await.unwrap(); - - let res = conn.query_raw("SELECT $1", &[Value::int32(1), Value::int32(2)]).await; - - assert!(res.is_err()); - - match res { - Ok(_) => unreachable!(), - Err(e) => match e.kind() { - ErrorKind::IncorrectNumberOfParameters { expected, actual } => { - assert_eq!(1, *expected); - assert_eq!(2, *actual); - } - other => panic!("{:#?}", other), - }, - } - } - - #[test] - fn test_safe_ident() { - // Safe - assert!(is_safe_identifier("hello")); - assert!(is_safe_identifier("_hello")); - assert!(is_safe_identifier("àbracadabra")); - assert!(is_safe_identifier("h3ll0")); - assert!(is_safe_identifier("héllo")); - assert!(is_safe_identifier("héll0$")); - assert!(is_safe_identifier("héll_0$")); - assert!(is_safe_identifier("disconnect_security_must_honor_connect_scope_one2m")); - - // Not safe - assert!(!is_safe_identifier("")); - assert!(!is_safe_identifier("Hello")); - assert!(!is_safe_identifier("hEllo")); - assert!(!is_safe_identifier("$hello")); - assert!(!is_safe_identifier("hello!")); - assert!(!is_safe_identifier("hello#")); - assert!(!is_safe_identifier("he llo")); - assert!(!is_safe_identifier(" hello")); - assert!(!is_safe_identifier("he-llo")); - assert!(!is_safe_identifier("hÉllo")); - assert!(!is_safe_identifier("1337")); - assert!(!is_safe_identifier("_HELLO")); - assert!(!is_safe_identifier("HELLO")); - assert!(!is_safe_identifier("HELLO$")); - assert!(!is_safe_identifier("ÀBRACADABRA")); - - for ident in RESERVED_KEYWORDS { - assert!(!is_safe_identifier(ident)); - } - - for ident in RESERVED_TYPE_FUNCTION_NAMES { - assert!(!is_safe_identifier(ident)); - } - } - - #[test] - fn search_path_pgbouncer_should_be_set_with_query() { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.query_pairs_mut().append_pair("schema", "hello"); - url.query_pairs_mut().append_pair("pgbouncer", "true"); - - let mut pg_url = PostgresUrl::new(url).unwrap(); - pg_url.set_flavour(PostgresFlavour::Postgres); - - let config = pg_url.to_config(); - - // PGBouncer does not support the `search_path` connection parameter. - // When `pgbouncer=true`, config.search_path should be None, - // And the `search_path` should be set via a db query after connection. - assert_eq!(config.get_search_path(), None); - } - - #[test] - fn search_path_pg_should_be_set_with_param() { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.query_pairs_mut().append_pair("schema", "hello"); - - let mut pg_url = PostgresUrl::new(url).unwrap(); - pg_url.set_flavour(PostgresFlavour::Postgres); - - let config = pg_url.to_config(); - - // Postgres supports setting the search_path via a connection parameter. - assert_eq!(config.get_search_path(), Some(&"\"hello\"".to_owned())); - } - - #[test] - fn search_path_crdb_safe_ident_should_be_set_with_param() { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.query_pairs_mut().append_pair("schema", "hello"); - - let mut pg_url = PostgresUrl::new(url).unwrap(); - pg_url.set_flavour(PostgresFlavour::Cockroach); - - let config = pg_url.to_config(); - - // CRDB supports setting the search_path via a connection parameter if the identifier is safe. 
- assert_eq!(config.get_search_path(), Some(&"hello".to_owned())); - } - - #[test] - fn search_path_crdb_unsafe_ident_should_be_set_with_query() { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.query_pairs_mut().append_pair("schema", "HeLLo"); - - let mut pg_url = PostgresUrl::new(url).unwrap(); - pg_url.set_flavour(PostgresFlavour::Cockroach); - - let config = pg_url.to_config(); - - // CRDB does NOT support setting the search_path via a connection parameter if the identifier is unsafe. - assert_eq!(config.get_search_path(), None); - } -} +#[cfg(feature = "postgresql-connector")] +pub(crate) mod native; diff --git a/quaint/src/connector/postgres/conversion.rs b/quaint/src/connector/postgres/native/conversion.rs similarity index 100% rename from quaint/src/connector/postgres/conversion.rs rename to quaint/src/connector/postgres/native/conversion.rs diff --git a/quaint/src/connector/postgres/conversion/decimal.rs b/quaint/src/connector/postgres/native/conversion/decimal.rs similarity index 100% rename from quaint/src/connector/postgres/conversion/decimal.rs rename to quaint/src/connector/postgres/native/conversion/decimal.rs diff --git a/quaint/src/connector/postgres/native/error.rs b/quaint/src/connector/postgres/native/error.rs new file mode 100644 index 000000000000..ec3b18483746 --- /dev/null +++ b/quaint/src/connector/postgres/native/error.rs @@ -0,0 +1,126 @@ +use tokio_postgres::error::DbError; + +use crate::{ + connector::error::PostgresError, + error::{Error, ErrorKind}, +}; + +impl From<&DbError> for PostgresError { + fn from(value: &DbError) -> Self { + PostgresError { + code: value.code().code().to_string(), + severity: value.severity().to_string(), + message: value.message().to_string(), + detail: value.detail().map(ToString::to_string), + column: value.column().map(ToString::to_string), + hint: value.hint().map(ToString::to_string), + } + } +} + +impl From for Error { + fn from(e: tokio_postgres::error::Error) -> Error { + if e.is_closed() { + return Error::builder(ErrorKind::ConnectionClosed).build(); + } + + if let Some(db_error) = e.as_db_error() { + return PostgresError::from(db_error).into(); + } + + if let Some(tls_error) = try_extracting_tls_error(&e) { + return tls_error; + } + + // Same for IO errors. + if let Some(io_error) = try_extracting_io_error(&e) { + return io_error; + } + + if let Some(uuid_error) = try_extracting_uuid_error(&e) { + return uuid_error; + } + + let reason = format!("{e}"); + let code = e.code().map(|c| c.code()); + + match reason.as_str() { + "error connecting to server: timed out" => { + let mut builder = Error::builder(ErrorKind::ConnectTimeout); + + if let Some(code) = code { + builder.set_original_code(code); + }; + + builder.set_original_message(reason); + builder.build() + } // sigh... 
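+            // Like the arm above, the next arm can only be recognised by its
+            // message text, hence the string match on `reason`.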
+            // https://github.com/sfackler/rust-postgres/blob/0c84ed9f8201f4e5b4803199a24afa2c9f3723b2/tokio-postgres/src/connect_tls.rs#L37
+            "error performing TLS handshake: server does not support TLS" => {
+                let mut builder = Error::builder(ErrorKind::TlsError {
+                    message: reason.clone(),
+                });
+
+                if let Some(code) = code {
+                    builder.set_original_code(code);
+                };
+
+                builder.set_original_message(reason);
+                builder.build()
+            } // double sigh
+            _ => {
+                let code = code.map(|c| c.to_string());
+                let mut builder = Error::builder(ErrorKind::QueryError(e.into()));
+
+                if let Some(code) = code {
+                    builder.set_original_code(code);
+                };
+
+                builder.set_original_message(reason);
+                builder.build()
+            }
+        }
+    }
+}
+
+fn try_extracting_uuid_error(err: &tokio_postgres::error::Error) -> Option<Error> {
+    use std::error::Error as _;
+
+    err.source()
+        .and_then(|err| err.downcast_ref::<uuid::Error>())
+        .map(|err| ErrorKind::UUIDError(format!("{err}")))
+        .map(|kind| Error::builder(kind).build())
+}
+
+fn try_extracting_tls_error(err: &tokio_postgres::error::Error) -> Option<Error> {
+    use std::error::Error;
+
+    err.source()
+        .and_then(|err| err.downcast_ref::<native_tls::Error>())
+        .map(|err| err.into())
+}
+
+fn try_extracting_io_error(err: &tokio_postgres::error::Error) -> Option<Error> {
+    use std::error::Error as _;
+
+    err.source()
+        .and_then(|err| err.downcast_ref::<std::io::Error>())
+        .map(|err| ErrorKind::ConnectionError(Box::new(std::io::Error::new(err.kind(), format!("{err}")))))
+        .map(|kind| Error::builder(kind).build())
+}
+
+impl From<native_tls::Error> for Error {
+    fn from(e: native_tls::Error) -> Error {
+        Error::from(&e)
+    }
+}
+
+impl From<&native_tls::Error> for Error {
+    fn from(e: &native_tls::Error) -> Error {
+        let kind = ErrorKind::TlsError {
+            message: format!("{e}"),
+        };
+
+        Error::builder(kind).build()
+    }
+}
diff --git a/quaint/src/connector/postgres/native/mod.rs b/quaint/src/connector/postgres/native/mod.rs
new file mode 100644
index 000000000000..8f1645ca4123
--- /dev/null
+++ b/quaint/src/connector/postgres/native/mod.rs
@@ -0,0 +1,1184 @@
+//! Definitions for the Postgres connector.
+//! This module is not compatible with wasm32-* targets.
+//! This module is only available with the `postgresql-connector` feature.
+mod conversion;
+mod error;
+
+use crate::connector::postgres::wasm::common::{Hidden, SslAcceptMode, SslParams};
+pub(crate) use crate::connector::postgres::wasm::common::{PostgresFlavour, PostgresUrl};
+use crate::connector::{timeout, IsolationLevel, Transaction};
+
+use crate::{
+    ast::{Query, Value},
+    connector::{metrics, queryable::*, ResultSet},
+    error::{Error, ErrorKind},
+    visitor::{self, Visitor},
+};
+use async_trait::async_trait;
+use futures::{future::FutureExt, lock::Mutex};
+use lru_cache::LruCache;
+use native_tls::{Certificate, Identity, TlsConnector};
+use postgres_native_tls::MakeTlsConnector;
+use std::{
+    borrow::Borrow,
+    fmt::{Debug, Display},
+    fs,
+    future::Future,
+    sync::atomic::{AtomicBool, Ordering},
+    time::Duration,
+};
+use tokio_postgres::{config::ChannelBinding, Client, Config, Statement};
+
+/// The underlying postgres driver. Only available with the `expose-drivers`
+/// Cargo feature.
+#[cfg(feature = "expose-drivers")]
+pub use tokio_postgres;
+
+struct PostgresClient(Client);
+
+impl Debug for PostgresClient {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.write_str("PostgresClient")
+    }
+}
+
+/// A connector interface for the PostgreSQL database.
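+///
+/// A minimal usage sketch (the connection URL is a placeholder, imports and
+/// error handling are elided, and this must run inside an async context):
+///
+/// ```ignore
+/// let url = url::Url::parse("postgresql://postgres:prisma@localhost:5432/postgres?schema=public").unwrap();
+/// let mut pg_url = PostgresUrl::new(url).unwrap();
+/// pg_url.set_flavour(PostgresFlavour::Postgres);
+///
+/// let client = PostgreSql::new(pg_url).await.unwrap();
+/// let result = client.query_raw("SELECT 1", &[]).await.unwrap();
+/// ```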
+#[derive(Debug)] +pub struct PostgreSql { + client: PostgresClient, + pg_bouncer: bool, + socket_timeout: Option, + statement_cache: Mutex>, + is_healthy: AtomicBool, +} + +#[derive(Debug)] +struct SslAuth { + certificate: Hidden>, + identity: Hidden>, + ssl_accept_mode: SslAcceptMode, +} + +impl Default for SslAuth { + fn default() -> Self { + Self { + certificate: Hidden(None), + identity: Hidden(None), + ssl_accept_mode: SslAcceptMode::AcceptInvalidCerts, + } + } +} + +impl SslAuth { + fn certificate(&mut self, certificate: Certificate) -> &mut Self { + self.certificate = Hidden(Some(certificate)); + self + } + + fn identity(&mut self, identity: Identity) -> &mut Self { + self.identity = Hidden(Some(identity)); + self + } + + fn accept_mode(&mut self, mode: SslAcceptMode) -> &mut Self { + self.ssl_accept_mode = mode; + self + } +} + +impl SslParams { + async fn into_auth(self) -> crate::Result { + let mut auth = SslAuth::default(); + auth.accept_mode(self.ssl_accept_mode); + + if let Some(ref cert_file) = self.certificate_file { + let cert = fs::read(cert_file).map_err(|err| { + Error::builder(ErrorKind::TlsError { + message: format!("cert file not found ({err})"), + }) + .build() + })?; + + auth.certificate(Certificate::from_pem(&cert)?); + } + + if let Some(ref identity_file) = self.identity_file { + let db = fs::read(identity_file).map_err(|err| { + Error::builder(ErrorKind::TlsError { + message: format!("identity file not found ({err})"), + }) + .build() + })?; + let password = self.identity_password.0.as_deref().unwrap_or(""); + let identity = Identity::from_pkcs12(&db, password)?; + + auth.identity(identity); + } + + Ok(auth) + } +} + +impl PostgresUrl { + pub(crate) fn cache(&self) -> LruCache { + if self.query_params.pg_bouncer { + LruCache::new(0) + } else { + LruCache::new(self.query_params.statement_cache_size) + } + } + + pub fn channel_binding(&self) -> ChannelBinding { + self.query_params.channel_binding + } + + /// On Postgres, we set the SEARCH_PATH and client-encoding through client connection parameters to save a network roundtrip on connection. + /// We can't always do it for CockroachDB because it does not expect quotes for unsafe identifiers (https://github.com/cockroachdb/cockroach/issues/101328), which might change once the issue is fixed. + /// To circumvent that problem, we only set the SEARCH_PATH through client connection parameters for Cockroach when the identifier is safe, so that the quoting does not matter. + fn set_search_path(&self, config: &mut Config) { + // PGBouncer does not support the search_path connection parameter. 
+ // https://www.pgbouncer.org/config.html#ignore_startup_parameters + if self.query_params.pg_bouncer { + return; + } + + if let Some(schema) = &self.query_params.schema { + if self.flavour().is_cockroach() && is_safe_identifier(schema) { + config.search_path(CockroachSearchPath(schema).to_string()); + } + + if self.flavour().is_postgres() { + config.search_path(PostgresSearchPath(schema).to_string()); + } + } + } + + pub(crate) fn to_config(&self) -> Config { + let mut config = Config::new(); + + config.user(self.username().borrow()); + config.password(self.password().borrow() as &str); + config.host(self.host()); + config.port(self.port()); + config.dbname(self.dbname()); + config.pgbouncer_mode(self.query_params.pg_bouncer); + + if let Some(options) = self.options() { + config.options(options); + } + + if let Some(application_name) = self.application_name() { + config.application_name(application_name); + } + + if let Some(connect_timeout) = self.query_params.connect_timeout { + config.connect_timeout(connect_timeout); + } + + self.set_search_path(&mut config); + + config.ssl_mode(self.query_params.ssl_mode); + + config.channel_binding(self.query_params.channel_binding); + + config + } +} + +impl PostgreSql { + /// Create a new connection to the database. + pub async fn new(url: PostgresUrl) -> crate::Result { + let config = url.to_config(); + + let mut tls_builder = TlsConnector::builder(); + + { + let ssl_params = url.ssl_params(); + let auth = ssl_params.to_owned().into_auth().await?; + + if let Some(certificate) = auth.certificate.0 { + tls_builder.add_root_certificate(certificate); + } + + tls_builder.danger_accept_invalid_certs(auth.ssl_accept_mode == SslAcceptMode::AcceptInvalidCerts); + + if let Some(identity) = auth.identity.0 { + tls_builder.identity(identity); + } + } + + let tls = MakeTlsConnector::new(tls_builder.build()?); + let (client, conn) = timeout::connect(url.connect_timeout(), config.connect(tls)).await?; + + tokio::spawn(conn.map(|r| match r { + Ok(_) => (), + Err(e) => { + tracing::error!("Error in PostgreSQL connection: {:?}", e); + } + })); + + // On Postgres, we set the SEARCH_PATH and client-encoding through client connection parameters to save a network roundtrip on connection. + // We can't always do it for CockroachDB because it does not expect quotes for unsafe identifiers (https://github.com/cockroachdb/cockroach/issues/101328), which might change once the issue is fixed. + // To circumvent that problem, we only set the SEARCH_PATH through client connection parameters for Cockroach when the identifier is safe, so that the quoting does not matter. + // Finally, to ensure backward compatibility, we keep sending a database query in case the flavour is set to Unknown. + if let Some(schema) = &url.query_params.schema { + // PGBouncer does not support the search_path connection parameter. 
+ // https://www.pgbouncer.org/config.html#ignore_startup_parameters + if url.query_params.pg_bouncer + || url.flavour().is_unknown() + || (url.flavour().is_cockroach() && !is_safe_identifier(schema)) + { + let session_variables = format!( + r##"{set_search_path}"##, + set_search_path = SetSearchPath(url.query_params.schema.as_deref()) + ); + + client.simple_query(session_variables.as_str()).await?; + } + } + + Ok(Self { + client: PostgresClient(client), + socket_timeout: url.query_params.socket_timeout, + pg_bouncer: url.query_params.pg_bouncer, + statement_cache: Mutex::new(url.cache()), + is_healthy: AtomicBool::new(true), + }) + } + + /// The underlying tokio_postgres::Client. Only available with the + /// `expose-drivers` Cargo feature. This is a lower level API when you need + /// to get into database specific features. + #[cfg(feature = "expose-drivers")] + pub fn client(&self) -> &tokio_postgres::Client { + &self.client.0 + } + + async fn fetch_cached(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + let mut cache = self.statement_cache.lock().await; + let capacity = cache.capacity(); + let stored = cache.len(); + + match cache.get_mut(sql) { + Some(stmt) => { + tracing::trace!( + message = "CACHE HIT!", + query = sql, + capacity = capacity, + stored = stored, + ); + + Ok(stmt.clone()) // arc'd + } + None => { + tracing::trace!( + message = "CACHE MISS!", + query = sql, + capacity = capacity, + stored = stored, + ); + + let param_types = conversion::params_to_types(params); + let stmt = self.perform_io(self.client.0.prepare_typed(sql, ¶m_types)).await?; + + cache.insert(sql.to_string(), stmt.clone()); + + Ok(stmt) + } + } + } + + async fn perform_io(&self, fut: F) -> crate::Result + where + F: Future>, + { + match timeout::socket(self.socket_timeout, fut).await { + Err(e) if e.is_closed() => { + self.is_healthy.store(false, Ordering::SeqCst); + Err(e) + } + res => res, + } + } + + fn check_bind_variables_len(&self, params: &[Value<'_>]) -> crate::Result<()> { + if params.len() > i16::MAX as usize { + // tokio_postgres would return an error here. Let's avoid calling the driver + // and return an error early. + let kind = ErrorKind::QueryInvalidInput(format!( + "too many bind variables in prepared statement, expected maximum of {}, received {}", + i16::MAX, + params.len() + )); + Err(Error::builder(kind).build()) + } else { + Ok(()) + } + } +} + +// A SearchPath connection parameter (Display-impl) for connection initialization. +struct CockroachSearchPath<'a>(&'a str); + +impl Display for CockroachSearchPath<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.0) + } +} + +// A SearchPath connection parameter (Display-impl) for connection initialization. +struct PostgresSearchPath<'a>(&'a str); + +impl Display for PostgresSearchPath<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("\"")?; + f.write_str(self.0)?; + f.write_str("\"")?; + + Ok(()) + } +} + +// A SetSearchPath statement (Display-impl) for connection initialization. 
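+// For `Some("public")` this renders `SET search_path = "public";` (plus a trailing
+// newline); for `None` it renders nothing.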
+struct SetSearchPath<'a>(Option<&'a str>); + +impl Display for SetSearchPath<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if let Some(schema) = self.0 { + f.write_str("SET search_path = \"")?; + f.write_str(schema)?; + f.write_str("\";\n")?; + } + + Ok(()) + } +} + +impl_default_TransactionCapable!(PostgreSql); + +#[async_trait] +impl Queryable for PostgreSql { + async fn query(&self, q: Query<'_>) -> crate::Result { + let (sql, params) = visitor::Postgres::build(q)?; + + self.query_raw(sql.as_str(), ¶ms[..]).await + } + + async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + self.check_bind_variables_len(params)?; + + metrics::query("postgres.query_raw", sql, params, move || async move { + let stmt = self.fetch_cached(sql, &[]).await?; + + if stmt.params().len() != params.len() { + let kind = ErrorKind::IncorrectNumberOfParameters { + expected: stmt.params().len(), + actual: params.len(), + }; + + return Err(Error::builder(kind).build()); + } + + let rows = self + .perform_io(self.client.0.query(&stmt, conversion::conv_params(params).as_slice())) + .await?; + + let mut result = ResultSet::new(stmt.to_column_names(), Vec::new()); + + for row in rows { + result.rows.push(row.get_result_row()?); + } + + Ok(result) + }) + .await + } + + async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + self.check_bind_variables_len(params)?; + + metrics::query("postgres.query_raw", sql, params, move || async move { + let stmt = self.fetch_cached(sql, params).await?; + + if stmt.params().len() != params.len() { + let kind = ErrorKind::IncorrectNumberOfParameters { + expected: stmt.params().len(), + actual: params.len(), + }; + + return Err(Error::builder(kind).build()); + } + + let rows = self + .perform_io(self.client.0.query(&stmt, conversion::conv_params(params).as_slice())) + .await?; + + let mut result = ResultSet::new(stmt.to_column_names(), Vec::new()); + + for row in rows { + result.rows.push(row.get_result_row()?); + } + + Ok(result) + }) + .await + } + + async fn execute(&self, q: Query<'_>) -> crate::Result { + let (sql, params) = visitor::Postgres::build(q)?; + + self.execute_raw(sql.as_str(), ¶ms[..]).await + } + + async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + self.check_bind_variables_len(params)?; + + metrics::query("postgres.execute_raw", sql, params, move || async move { + let stmt = self.fetch_cached(sql, &[]).await?; + + if stmt.params().len() != params.len() { + let kind = ErrorKind::IncorrectNumberOfParameters { + expected: stmt.params().len(), + actual: params.len(), + }; + + return Err(Error::builder(kind).build()); + } + + let changes = self + .perform_io(self.client.0.execute(&stmt, conversion::conv_params(params).as_slice())) + .await?; + + Ok(changes) + }) + .await + } + + async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + self.check_bind_variables_len(params)?; + + metrics::query("postgres.execute_raw", sql, params, move || async move { + let stmt = self.fetch_cached(sql, params).await?; + + if stmt.params().len() != params.len() { + let kind = ErrorKind::IncorrectNumberOfParameters { + expected: stmt.params().len(), + actual: params.len(), + }; + + return Err(Error::builder(kind).build()); + } + + let changes = self + .perform_io(self.client.0.execute(&stmt, conversion::conv_params(params).as_slice())) + .await?; + + Ok(changes) + }) + .await + } + + async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> { + 
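+        // `cmd` is sent through the simple-query protocol (no bind parameters),
+        // which is why no statement is prepared or cached here.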
metrics::query("postgres.raw_cmd", cmd, &[], move || async move { + self.perform_io(self.client.0.simple_query(cmd)).await?; + Ok(()) + }) + .await + } + + async fn version(&self) -> crate::Result> { + let query = r#"SELECT version()"#; + let rows = self.query_raw(query, &[]).await?; + + let version_string = rows + .get(0) + .and_then(|row| row.get("version").and_then(|version| version.to_string())); + + Ok(version_string) + } + + fn is_healthy(&self) -> bool { + self.is_healthy.load(Ordering::SeqCst) + } + + async fn server_reset_query(&self, tx: &dyn Transaction) -> crate::Result<()> { + if self.pg_bouncer { + tx.raw_cmd("DEALLOCATE ALL").await + } else { + Ok(()) + } + } + + async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { + if matches!(isolation_level, IsolationLevel::Snapshot) { + return Err(Error::builder(ErrorKind::invalid_isolation_level(&isolation_level)).build()); + } + + self.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation_level}")) + .await?; + + Ok(()) + } + + fn requires_isolation_first(&self) -> bool { + false + } +} + +/// Sorted list of CockroachDB's reserved keywords. +/// Taken from https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#keywords +const RESERVED_KEYWORDS: [&str; 79] = [ + "all", + "analyse", + "analyze", + "and", + "any", + "array", + "as", + "asc", + "asymmetric", + "both", + "case", + "cast", + "check", + "collate", + "column", + "concurrently", + "constraint", + "create", + "current_catalog", + "current_date", + "current_role", + "current_schema", + "current_time", + "current_timestamp", + "current_user", + "default", + "deferrable", + "desc", + "distinct", + "do", + "else", + "end", + "except", + "false", + "fetch", + "for", + "foreign", + "from", + "grant", + "group", + "having", + "in", + "initially", + "intersect", + "into", + "lateral", + "leading", + "limit", + "localtime", + "localtimestamp", + "not", + "null", + "offset", + "on", + "only", + "or", + "order", + "placing", + "primary", + "references", + "returning", + "select", + "session_user", + "some", + "symmetric", + "table", + "then", + "to", + "trailing", + "true", + "union", + "unique", + "user", + "using", + "variadic", + "when", + "where", + "window", + "with", +]; + +/// Sorted list of CockroachDB's reserved type function names. +/// Taken from https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#keywords +const RESERVED_TYPE_FUNCTION_NAMES: [&str; 18] = [ + "authorization", + "collation", + "cross", + "full", + "ilike", + "inner", + "is", + "isnull", + "join", + "left", + "like", + "natural", + "none", + "notnull", + "outer", + "overlaps", + "right", + "similar", +]; + +/// Returns true if a Postgres identifier is considered "safe". +/// +/// In this context, "safe" means that the value of an identifier would be the same quoted and unquoted or that it's not part of reserved keywords. In other words, that it does _not_ need to be quoted. +/// +/// Spec can be found here: https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS +/// or here: https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#rules-for-identifiers +fn is_safe_identifier(ident: &str) -> bool { + if ident.is_empty() { + return false; + } + + // 1. Not equal any SQL keyword unless the keyword is accepted by the element's syntax. For example, name accepts Unreserved or Column Name keywords. 
+ if RESERVED_KEYWORDS.binary_search(&ident).is_ok() || RESERVED_TYPE_FUNCTION_NAMES.binary_search(&ident).is_ok() { + return false; + } + + let mut chars = ident.chars(); + + let first = chars.next().unwrap(); + + // 2. SQL identifiers must begin with a letter (a-z, but also letters with diacritical marks and non-Latin letters) or an underscore (_). + if (!first.is_alphabetic() || !first.is_lowercase()) && first != '_' { + return false; + } + + for c in chars { + // 3. Subsequent characters in an identifier can be letters, underscores, digits (0-9), or dollar signs ($). + if (!c.is_alphabetic() || !c.is_lowercase()) && c != '_' && !c.is_ascii_digit() && c != '$' { + return false; + } + } + + true +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::tests::test_api::postgres::CONN_STR; + use crate::tests::test_api::CRDB_CONN_STR; + use crate::{connector::Queryable, error::*, single::Quaint}; + use url::Url; + + #[test] + fn should_parse_socket_url() { + let url = PostgresUrl::new(Url::parse("postgresql:///dbname?host=/var/run/psql.sock").unwrap()).unwrap(); + assert_eq!("dbname", url.dbname()); + assert_eq!("/var/run/psql.sock", url.host()); + } + + #[test] + fn should_parse_escaped_url() { + let url = PostgresUrl::new(Url::parse("postgresql:///dbname?host=%2Fvar%2Frun%2Fpostgresql").unwrap()).unwrap(); + assert_eq!("dbname", url.dbname()); + assert_eq!("/var/run/postgresql", url.host()); + } + + #[test] + fn should_allow_changing_of_cache_size() { + let url = + PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?statement_cache_size=420").unwrap()).unwrap(); + assert_eq!(420, url.cache().capacity()); + } + + #[test] + fn should_have_default_cache_size() { + let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo").unwrap()).unwrap(); + assert_eq!(100, url.cache().capacity()); + } + + #[test] + fn should_have_application_name() { + let url = + PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?application_name=test").unwrap()).unwrap(); + assert_eq!(Some("test"), url.application_name()); + } + + #[test] + fn should_have_channel_binding() { + let url = + PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?channel_binding=require").unwrap()).unwrap(); + assert_eq!(ChannelBinding::Require, url.channel_binding()); + } + + #[test] + fn should_have_default_channel_binding() { + let url = + PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?channel_binding=invalid").unwrap()).unwrap(); + assert_eq!(ChannelBinding::Prefer, url.channel_binding()); + + let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo").unwrap()).unwrap(); + assert_eq!(ChannelBinding::Prefer, url.channel_binding()); + } + + #[test] + fn should_not_enable_caching_with_pgbouncer() { + let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?pgbouncer=true").unwrap()).unwrap(); + assert_eq!(0, url.cache().capacity()); + } + + #[test] + fn should_parse_default_host() { + let url = PostgresUrl::new(Url::parse("postgresql:///dbname").unwrap()).unwrap(); + assert_eq!("dbname", url.dbname()); + assert_eq!("localhost", url.host()); + } + + #[test] + fn should_parse_ipv6_host() { + let url = PostgresUrl::new(Url::parse("postgresql://[2001:db8:1234::ffff]:5432/dbname").unwrap()).unwrap(); + assert_eq!("2001:db8:1234::ffff", url.host()); + } + + #[test] + fn should_handle_options_field() { + let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432?options=--cluster%3Dmy_cluster").unwrap()) + .unwrap(); + + 
assert_eq!("--cluster=my_cluster", url.options().unwrap()); + } + + #[tokio::test] + async fn test_custom_search_path_pg() { + async fn test_path(schema_name: &str) -> Option { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.query_pairs_mut().append_pair("schema", schema_name); + + let mut pg_url = PostgresUrl::new(url).unwrap(); + pg_url.set_flavour(PostgresFlavour::Postgres); + + let client = PostgreSql::new(pg_url).await.unwrap(); + + let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); + let row = result_set.first().unwrap(); + + row[0].typed.to_string() + } + + // Safe + assert_eq!(test_path("hello").await.as_deref(), Some("\"hello\"")); + assert_eq!(test_path("_hello").await.as_deref(), Some("\"_hello\"")); + assert_eq!(test_path("àbracadabra").await.as_deref(), Some("\"àbracadabra\"")); + assert_eq!(test_path("h3ll0").await.as_deref(), Some("\"h3ll0\"")); + assert_eq!(test_path("héllo").await.as_deref(), Some("\"héllo\"")); + assert_eq!(test_path("héll0$").await.as_deref(), Some("\"héll0$\"")); + assert_eq!(test_path("héll_0$").await.as_deref(), Some("\"héll_0$\"")); + + // Not safe + assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\"")); + assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\"")); + assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\"")); + assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\"")); + assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\"")); + assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\"")); + assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\"")); + assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\"")); + assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\"")); + assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\"")); + assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\"")); + assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\"")); + assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\"")); + assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\"")); + + for ident in RESERVED_KEYWORDS { + assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); + } + + for ident in RESERVED_TYPE_FUNCTION_NAMES { + assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); + } + } + + #[tokio::test] + async fn test_custom_search_path_pg_pgbouncer() { + async fn test_path(schema_name: &str) -> Option { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.query_pairs_mut().append_pair("schema", schema_name); + url.query_pairs_mut().append_pair("pbbouncer", "true"); + + let mut pg_url = PostgresUrl::new(url).unwrap(); + pg_url.set_flavour(PostgresFlavour::Postgres); + + let client = PostgreSql::new(pg_url).await.unwrap(); + + let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); + let row = result_set.first().unwrap(); + + row[0].typed.to_string() + } + + // Safe + assert_eq!(test_path("hello").await.as_deref(), Some("\"hello\"")); + assert_eq!(test_path("_hello").await.as_deref(), Some("\"_hello\"")); + assert_eq!(test_path("àbracadabra").await.as_deref(), Some("\"àbracadabra\"")); + assert_eq!(test_path("h3ll0").await.as_deref(), Some("\"h3ll0\"")); + assert_eq!(test_path("héllo").await.as_deref(), Some("\"héllo\"")); + assert_eq!(test_path("héll0$").await.as_deref(), Some("\"héll0$\"")); + assert_eq!(test_path("héll_0$").await.as_deref(), Some("\"héll_0$\"")); + + // Not safe + 
assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\"")); + assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\"")); + assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\"")); + assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\"")); + assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\"")); + assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\"")); + assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\"")); + assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\"")); + assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\"")); + assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\"")); + assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\"")); + assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\"")); + assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\"")); + assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\"")); + + for ident in RESERVED_KEYWORDS { + assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); + } + + for ident in RESERVED_TYPE_FUNCTION_NAMES { + assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); + } + } + + #[tokio::test] + async fn test_custom_search_path_crdb() { + async fn test_path(schema_name: &str) -> Option { + let mut url = Url::parse(&CRDB_CONN_STR).unwrap(); + url.query_pairs_mut().append_pair("schema", schema_name); + + let mut pg_url = PostgresUrl::new(url).unwrap(); + pg_url.set_flavour(PostgresFlavour::Cockroach); + + let client = PostgreSql::new(pg_url).await.unwrap(); + + let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); + let row = result_set.first().unwrap(); + + row[0].typed.to_string() + } + + // Safe + assert_eq!(test_path("hello").await.as_deref(), Some("hello")); + assert_eq!(test_path("_hello").await.as_deref(), Some("_hello")); + assert_eq!(test_path("àbracadabra").await.as_deref(), Some("àbracadabra")); + assert_eq!(test_path("h3ll0").await.as_deref(), Some("h3ll0")); + assert_eq!(test_path("héllo").await.as_deref(), Some("héllo")); + assert_eq!(test_path("héll0$").await.as_deref(), Some("héll0$")); + assert_eq!(test_path("héll_0$").await.as_deref(), Some("héll_0$")); + + // Not safe + assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\"")); + assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\"")); + assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\"")); + assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\"")); + assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\"")); + assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\"")); + assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\"")); + assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\"")); + assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\"")); + assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\"")); + assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\"")); + assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\"")); + assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\"")); + assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\"")); + + for ident in RESERVED_KEYWORDS { + assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); + } + + for ident in RESERVED_TYPE_FUNCTION_NAMES { + 
assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); + } + } + + #[tokio::test] + async fn test_custom_search_path_unknown_pg() { + async fn test_path(schema_name: &str) -> Option { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.query_pairs_mut().append_pair("schema", schema_name); + + let mut pg_url = PostgresUrl::new(url).unwrap(); + pg_url.set_flavour(PostgresFlavour::Unknown); + + let client = PostgreSql::new(pg_url).await.unwrap(); + + let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); + let row = result_set.first().unwrap(); + + row[0].typed.to_string() + } + + // Safe + assert_eq!(test_path("hello").await.as_deref(), Some("hello")); + assert_eq!(test_path("_hello").await.as_deref(), Some("_hello")); + assert_eq!(test_path("àbracadabra").await.as_deref(), Some("\"àbracadabra\"")); + assert_eq!(test_path("h3ll0").await.as_deref(), Some("h3ll0")); + assert_eq!(test_path("héllo").await.as_deref(), Some("\"héllo\"")); + assert_eq!(test_path("héll0$").await.as_deref(), Some("\"héll0$\"")); + assert_eq!(test_path("héll_0$").await.as_deref(), Some("\"héll_0$\"")); + + // Not safe + assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\"")); + assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\"")); + assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\"")); + assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\"")); + assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\"")); + assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\"")); + assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\"")); + assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\"")); + assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\"")); + assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\"")); + assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\"")); + assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\"")); + assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\"")); + assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\"")); + + for ident in RESERVED_KEYWORDS { + assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); + } + + for ident in RESERVED_TYPE_FUNCTION_NAMES { + assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); + } + } + + #[tokio::test] + async fn test_custom_search_path_unknown_crdb() { + async fn test_path(schema_name: &str) -> Option { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.query_pairs_mut().append_pair("schema", schema_name); + + let mut pg_url = PostgresUrl::new(url).unwrap(); + pg_url.set_flavour(PostgresFlavour::Unknown); + + let client = PostgreSql::new(pg_url).await.unwrap(); + + let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); + let row = result_set.first().unwrap(); + + row[0].typed.to_string() + } + + // Safe + assert_eq!(test_path("hello").await.as_deref(), Some("hello")); + assert_eq!(test_path("_hello").await.as_deref(), Some("_hello")); + assert_eq!(test_path("àbracadabra").await.as_deref(), Some("\"àbracadabra\"")); + assert_eq!(test_path("h3ll0").await.as_deref(), Some("h3ll0")); + assert_eq!(test_path("héllo").await.as_deref(), Some("\"héllo\"")); + assert_eq!(test_path("héll0$").await.as_deref(), Some("\"héll0$\"")); + assert_eq!(test_path("héll_0$").await.as_deref(), Some("\"héll_0$\"")); + + // Not safe + 
assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\"")); + assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\"")); + assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\"")); + assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\"")); + assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\"")); + assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\"")); + assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\"")); + assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\"")); + assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\"")); + assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\"")); + assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\"")); + assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\"")); + assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\"")); + assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\"")); + + for ident in RESERVED_KEYWORDS { + assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); + } + + for ident in RESERVED_TYPE_FUNCTION_NAMES { + assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); + } + } + + #[tokio::test] + async fn should_map_nonexisting_database_error() { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.set_path("/this_does_not_exist"); + + let res = Quaint::new(url.as_str()).await; + + assert!(res.is_err()); + + match res { + Ok(_) => unreachable!(), + Err(e) => match e.kind() { + ErrorKind::DatabaseDoesNotExist { db_name } => { + assert_eq!(Some("3D000"), e.original_code()); + assert_eq!( + Some("database \"this_does_not_exist\" does not exist"), + e.original_message() + ); + assert_eq!(&Name::available("this_does_not_exist"), db_name) + } + kind => panic!("Expected `DatabaseDoesNotExist`, got {:?}", kind), + }, + } + } + + #[tokio::test] + async fn should_map_wrong_credentials_error() { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.set_username("WRONG").unwrap(); + + let res = Quaint::new(url.as_str()).await; + assert!(res.is_err()); + + let err = res.unwrap_err(); + assert!(matches!(err.kind(), ErrorKind::AuthenticationFailed { user } if user == &Name::available("WRONG"))); + } + + #[tokio::test] + async fn should_map_tls_errors() { + let mut url = Url::parse(&CONN_STR).expect("parsing url"); + url.set_query(Some("sslmode=require&sslaccept=strict")); + + let res = Quaint::new(url.as_str()).await; + + assert!(res.is_err()); + + match res { + Ok(_) => unreachable!(), + Err(e) => match e.kind() { + ErrorKind::TlsError { .. 
} => (), + other => panic!("{:#?}", other), + }, + } + } + + #[tokio::test] + async fn should_map_incorrect_parameters_error() { + let url = Url::parse(&CONN_STR).unwrap(); + let conn = Quaint::new(url.as_str()).await.unwrap(); + + let res = conn.query_raw("SELECT $1", &[Value::int32(1), Value::int32(2)]).await; + + assert!(res.is_err()); + + match res { + Ok(_) => unreachable!(), + Err(e) => match e.kind() { + ErrorKind::IncorrectNumberOfParameters { expected, actual } => { + assert_eq!(1, *expected); + assert_eq!(2, *actual); + } + other => panic!("{:#?}", other), + }, + } + } + + #[test] + fn test_safe_ident() { + // Safe + assert!(is_safe_identifier("hello")); + assert!(is_safe_identifier("_hello")); + assert!(is_safe_identifier("àbracadabra")); + assert!(is_safe_identifier("h3ll0")); + assert!(is_safe_identifier("héllo")); + assert!(is_safe_identifier("héll0$")); + assert!(is_safe_identifier("héll_0$")); + assert!(is_safe_identifier("disconnect_security_must_honor_connect_scope_one2m")); + + // Not safe + assert!(!is_safe_identifier("")); + assert!(!is_safe_identifier("Hello")); + assert!(!is_safe_identifier("hEllo")); + assert!(!is_safe_identifier("$hello")); + assert!(!is_safe_identifier("hello!")); + assert!(!is_safe_identifier("hello#")); + assert!(!is_safe_identifier("he llo")); + assert!(!is_safe_identifier(" hello")); + assert!(!is_safe_identifier("he-llo")); + assert!(!is_safe_identifier("hÉllo")); + assert!(!is_safe_identifier("1337")); + assert!(!is_safe_identifier("_HELLO")); + assert!(!is_safe_identifier("HELLO")); + assert!(!is_safe_identifier("HELLO$")); + assert!(!is_safe_identifier("ÀBRACADABRA")); + + for ident in RESERVED_KEYWORDS { + assert!(!is_safe_identifier(ident)); + } + + for ident in RESERVED_TYPE_FUNCTION_NAMES { + assert!(!is_safe_identifier(ident)); + } + } + + #[test] + fn search_path_pgbouncer_should_be_set_with_query() { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.query_pairs_mut().append_pair("schema", "hello"); + url.query_pairs_mut().append_pair("pgbouncer", "true"); + + let mut pg_url = PostgresUrl::new(url).unwrap(); + pg_url.set_flavour(PostgresFlavour::Postgres); + + let config = pg_url.to_config(); + + // PGBouncer does not support the `search_path` connection parameter. + // When `pgbouncer=true`, config.search_path should be None, + // And the `search_path` should be set via a db query after connection. + assert_eq!(config.get_search_path(), None); + } + + #[test] + fn search_path_pg_should_be_set_with_param() { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.query_pairs_mut().append_pair("schema", "hello"); + + let mut pg_url = PostgresUrl::new(url).unwrap(); + pg_url.set_flavour(PostgresFlavour::Postgres); + + let config = pg_url.to_config(); + + // Postgres supports setting the search_path via a connection parameter. + assert_eq!(config.get_search_path(), Some(&"\"hello\"".to_owned())); + } + + #[test] + fn search_path_crdb_safe_ident_should_be_set_with_param() { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.query_pairs_mut().append_pair("schema", "hello"); + + let mut pg_url = PostgresUrl::new(url).unwrap(); + pg_url.set_flavour(PostgresFlavour::Cockroach); + + let config = pg_url.to_config(); + + // CRDB supports setting the search_path via a connection parameter if the identifier is safe. 
+ assert_eq!(config.get_search_path(), Some(&"hello".to_owned())); + } + + #[test] + fn search_path_crdb_unsafe_ident_should_be_set_with_query() { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.query_pairs_mut().append_pair("schema", "HeLLo"); + + let mut pg_url = PostgresUrl::new(url).unwrap(); + pg_url.set_flavour(PostgresFlavour::Cockroach); + + let config = pg_url.to_config(); + + // CRDB does NOT support setting the search_path via a connection parameter if the identifier is unsafe. + assert_eq!(config.get_search_path(), None); + } +} diff --git a/quaint/src/connector/postgres/wasm/common.rs b/quaint/src/connector/postgres/wasm/common.rs new file mode 100644 index 000000000000..46d327c0183d --- /dev/null +++ b/quaint/src/connector/postgres/wasm/common.rs @@ -0,0 +1,612 @@ +use std::{ + borrow::Cow, + fmt::{Debug, Display}, + time::Duration, +}; + +use percent_encoding::percent_decode; +use url::{Host, Url}; + +use crate::error::{Error, ErrorKind}; + +#[cfg(feature = "postgresql-connector")] +use tokio_postgres::config::{ChannelBinding, SslMode}; + +#[derive(Clone)] +pub(crate) struct Hidden(pub(crate) T); + +impl Debug for Hidden { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("") + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SslAcceptMode { + Strict, + AcceptInvalidCerts, +} + +#[derive(Debug, Clone)] +pub struct SslParams { + pub(crate) certificate_file: Option, + pub(crate) identity_file: Option, + pub(crate) identity_password: Hidden>, + pub(crate) ssl_accept_mode: SslAcceptMode, +} + +#[derive(Debug, Clone, Copy)] +pub enum PostgresFlavour { + Postgres, + Cockroach, + Unknown, +} + +impl PostgresFlavour { + /// Returns `true` if the postgres flavour is [`Postgres`]. + /// + /// [`Postgres`]: PostgresFlavour::Postgres + pub(crate) fn is_postgres(&self) -> bool { + matches!(self, Self::Postgres) + } + + /// Returns `true` if the postgres flavour is [`Cockroach`]. + /// + /// [`Cockroach`]: PostgresFlavour::Cockroach + pub(crate) fn is_cockroach(&self) -> bool { + matches!(self, Self::Cockroach) + } + + /// Returns `true` if the postgres flavour is [`Unknown`]. + /// + /// [`Unknown`]: PostgresFlavour::Unknown + pub(crate) fn is_unknown(&self) -> bool { + matches!(self, Self::Unknown) + } +} + +/// Wraps a connection url and exposes the parsing logic used by Quaint, +/// including default values. +#[derive(Debug, Clone)] +pub struct PostgresUrl { + pub(crate) url: Url, + pub(crate) query_params: PostgresUrlQueryParams, + pub(crate) flavour: PostgresFlavour, +} + +pub(crate) const DEFAULT_SCHEMA: &str = "public"; + +impl PostgresUrl { + /// Parse `Url` to `PostgresUrl`. Returns error for mistyped connection + /// parameters. + pub fn new(url: Url) -> Result { + let query_params = Self::parse_query_params(&url)?; + + Ok(Self { + url, + query_params, + flavour: PostgresFlavour::Unknown, + }) + } + + /// The bare `Url` to the database. + pub fn url(&self) -> &Url { + &self.url + } + + /// The percent-decoded database username. + pub fn username(&self) -> Cow { + match percent_decode(self.url.username().as_bytes()).decode_utf8() { + Ok(username) => username, + Err(_) => { + tracing::warn!("Couldn't decode username to UTF-8, using the non-decoded version."); + + self.url.username().into() + } + } + } + + /// The database host. Taken first from the `host` query parameter, then + /// from the `host` part of the URL. For socket connections, the query + /// parameter must be used. 
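+    /// For example, `postgresql:///dbname?host=/var/run/psql.sock` resolves the
+    /// host to `/var/run/psql.sock`.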
+ /// + /// If none of them are set, defaults to `localhost`. + pub fn host(&self) -> &str { + match (self.query_params.host.as_ref(), self.url.host_str(), self.url.host()) { + (Some(host), _, _) => host.as_str(), + (None, Some(""), _) => "localhost", + (None, None, _) => "localhost", + (None, Some(host), Some(Host::Ipv6(_))) => { + // The `url` crate may return an IPv6 address in brackets, which must be stripped. + if host.starts_with('[') && host.ends_with(']') { + &host[1..host.len() - 1] + } else { + host + } + } + (None, Some(host), _) => host, + } + } + + /// Name of the database connected. Defaults to `postgres`. + pub fn dbname(&self) -> &str { + match self.url.path_segments() { + Some(mut segments) => segments.next().unwrap_or("postgres"), + None => "postgres", + } + } + + /// The percent-decoded database password. + pub fn password(&self) -> Cow { + match self + .url + .password() + .and_then(|pw| percent_decode(pw.as_bytes()).decode_utf8().ok()) + { + Some(password) => password, + None => self.url.password().unwrap_or("").into(), + } + } + + /// The database port, defaults to `5432`. + pub fn port(&self) -> u16 { + self.url.port().unwrap_or(5432) + } + + /// The database schema, defaults to `public`. + pub fn schema(&self) -> &str { + self.query_params.schema.as_deref().unwrap_or(DEFAULT_SCHEMA) + } + + /// Whether the pgbouncer mode is enabled. + pub fn pg_bouncer(&self) -> bool { + self.query_params.pg_bouncer + } + + /// The connection timeout. + pub fn connect_timeout(&self) -> Option { + self.query_params.connect_timeout + } + + /// Pool check_out timeout + pub fn pool_timeout(&self) -> Option { + self.query_params.pool_timeout + } + + /// The socket timeout + pub fn socket_timeout(&self) -> Option { + self.query_params.socket_timeout + } + + /// The maximum connection lifetime + pub fn max_connection_lifetime(&self) -> Option { + self.query_params.max_connection_lifetime + } + + /// The maximum idle connection lifetime + pub fn max_idle_connection_lifetime(&self) -> Option { + self.query_params.max_idle_connection_lifetime + } + + /// The custom application name + pub fn application_name(&self) -> Option<&str> { + self.query_params.application_name.as_deref() + } + + pub(crate) fn options(&self) -> Option<&str> { + self.query_params.options.as_deref() + } + + /// Sets whether the URL points to a Postgres, Cockroach or Unknown database. + /// This is used to avoid a network roundtrip at connection to set the search path. + /// + /// The different behaviours are: + /// - Postgres: Always avoid a network roundtrip by setting the search path through client connection parameters. + /// - Cockroach: Avoid a network roundtrip if the schema name is deemed "safe" (i.e. no escape quoting required). Otherwise, set the search path through a database query. + /// - Unknown: Always add a network roundtrip by setting the search path through a database query. 
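+    ///
+    /// A short sketch of the intended call pattern (the URL value is illustrative):
+    ///
+    /// ```ignore
+    /// let mut pg_url = PostgresUrl::new(url)?;
+    /// pg_url.set_flavour(PostgresFlavour::Cockroach);
+    /// ```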
+ pub fn set_flavour(&mut self, flavour: PostgresFlavour) { + self.flavour = flavour; + } + + fn parse_query_params(url: &Url) -> Result { + #[cfg(feature = "postgresql-connector")] + let mut ssl_mode = SslMode::Prefer; + #[cfg(feature = "postgresql-connector")] + let mut channel_binding = ChannelBinding::Prefer; + + let mut connection_limit = None; + let mut schema = None; + let mut certificate_file = None; + let mut identity_file = None; + let mut identity_password = None; + let mut ssl_accept_mode = SslAcceptMode::AcceptInvalidCerts; + let mut host = None; + let mut application_name = None; + let mut socket_timeout = None; + let mut connect_timeout = Some(Duration::from_secs(5)); + let mut pool_timeout = Some(Duration::from_secs(10)); + let mut pg_bouncer = false; + let mut statement_cache_size = 100; + let mut max_connection_lifetime = None; + let mut max_idle_connection_lifetime = Some(Duration::from_secs(300)); + let mut options = None; + + for (k, v) in url.query_pairs() { + match k.as_ref() { + "pgbouncer" => { + pg_bouncer = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + } + #[cfg(feature = "postgresql-connector")] + "sslmode" => { + match v.as_ref() { + "disable" => ssl_mode = SslMode::Disable, + "prefer" => ssl_mode = SslMode::Prefer, + "require" => ssl_mode = SslMode::Require, + _ => { + tracing::debug!(message = "Unsupported SSL mode, defaulting to `prefer`", mode = &*v); + } + }; + } + "sslcert" => { + certificate_file = Some(v.to_string()); + } + "sslidentity" => { + identity_file = Some(v.to_string()); + } + "sslpassword" => { + identity_password = Some(v.to_string()); + } + "statement_cache_size" => { + statement_cache_size = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + } + "sslaccept" => { + match v.as_ref() { + "strict" => { + ssl_accept_mode = SslAcceptMode::Strict; + } + "accept_invalid_certs" => { + ssl_accept_mode = SslAcceptMode::AcceptInvalidCerts; + } + _ => { + tracing::debug!( + message = "Unsupported SSL accept mode, defaulting to `strict`", + mode = &*v + ); + + ssl_accept_mode = SslAcceptMode::Strict; + } + }; + } + "schema" => { + schema = Some(v.to_string()); + } + "connection_limit" => { + let as_int: usize = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + connection_limit = Some(as_int); + } + "host" => { + host = Some(v.to_string()); + } + "socket_timeout" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + socket_timeout = Some(Duration::from_secs(as_int)); + } + "connect_timeout" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + if as_int == 0 { + connect_timeout = None; + } else { + connect_timeout = Some(Duration::from_secs(as_int)); + } + } + "pool_timeout" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + if as_int == 0 { + pool_timeout = None; + } else { + pool_timeout = Some(Duration::from_secs(as_int)); + } + } + "max_connection_lifetime" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + if as_int == 0 { + max_connection_lifetime = None; + } else { + max_connection_lifetime = Some(Duration::from_secs(as_int)); + } + } + "max_idle_connection_lifetime" => { + let as_int = v + .parse() + .map_err(|_| 
Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + if as_int == 0 { + max_idle_connection_lifetime = None; + } else { + max_idle_connection_lifetime = Some(Duration::from_secs(as_int)); + } + } + "application_name" => { + application_name = Some(v.to_string()); + } + #[cfg(feature = "postgresql-connector")] + "channel_binding" => { + match v.as_ref() { + "disable" => channel_binding = ChannelBinding::Disable, + "prefer" => channel_binding = ChannelBinding::Prefer, + "require" => channel_binding = ChannelBinding::Require, + _ => { + tracing::debug!( + message = "Unsupported Channel Binding {channel_binding}, defaulting to `prefer`", + channel_binding = &*v + ); + } + }; + } + "options" => { + options = Some(v.to_string()); + } + _ => { + tracing::trace!(message = "Discarding connection string param", param = &*k); + } + }; + } + + Ok(PostgresUrlQueryParams { + ssl_params: SslParams { + certificate_file, + identity_file, + ssl_accept_mode, + identity_password: Hidden(identity_password), + }, + connection_limit, + schema, + host, + connect_timeout, + pool_timeout, + socket_timeout, + pg_bouncer, + statement_cache_size, + max_connection_lifetime, + max_idle_connection_lifetime, + application_name, + options, + #[cfg(feature = "postgresql-connector")] + channel_binding, + #[cfg(feature = "postgresql-connector")] + ssl_mode, + }) + } + + pub(crate) fn ssl_params(&self) -> &SslParams { + &self.query_params.ssl_params + } + + #[cfg(feature = "pooled")] + pub(crate) fn connection_limit(&self) -> Option { + self.query_params.connection_limit + } + + pub fn flavour(&self) -> PostgresFlavour { + self.flavour + } +} + +#[derive(Debug, Clone)] +pub(crate) struct PostgresUrlQueryParams { + pub(crate) ssl_params: SslParams, + pub(crate) connection_limit: Option, + pub(crate) schema: Option, + pub(crate) pg_bouncer: bool, + pub(crate) host: Option, + pub(crate) socket_timeout: Option, + pub(crate) connect_timeout: Option, + pub(crate) pool_timeout: Option, + pub(crate) statement_cache_size: usize, + pub(crate) max_connection_lifetime: Option, + pub(crate) max_idle_connection_lifetime: Option, + pub(crate) application_name: Option, + pub(crate) options: Option, + + #[cfg(feature = "postgresql-connector")] + pub(crate) channel_binding: ChannelBinding, + + #[cfg(feature = "postgresql-connector")] + pub(crate) ssl_mode: SslMode, +} + +// A SearchPath connection parameter (Display-impl) for connection initialization. +struct CockroachSearchPath<'a>(&'a str); + +impl Display for CockroachSearchPath<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.0) + } +} + +// A SearchPath connection parameter (Display-impl) for connection initialization. +struct PostgresSearchPath<'a>(&'a str); + +impl Display for PostgresSearchPath<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("\"")?; + f.write_str(self.0)?; + f.write_str("\"")?; + + Ok(()) + } +} + +// A SetSearchPath statement (Display-impl) for connection initialization. +struct SetSearchPath<'a>(Option<&'a str>); + +impl Display for SetSearchPath<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if let Some(schema) = self.0 { + f.write_str("SET search_path = \"")?; + f.write_str(schema)?; + f.write_str("\";\n")?; + } + + Ok(()) + } +} + +/// Sorted list of CockroachDB's reserved keywords. 
+/// Taken from https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#keywords +const RESERVED_KEYWORDS: [&str; 79] = [ + "all", + "analyse", + "analyze", + "and", + "any", + "array", + "as", + "asc", + "asymmetric", + "both", + "case", + "cast", + "check", + "collate", + "column", + "concurrently", + "constraint", + "create", + "current_catalog", + "current_date", + "current_role", + "current_schema", + "current_time", + "current_timestamp", + "current_user", + "default", + "deferrable", + "desc", + "distinct", + "do", + "else", + "end", + "except", + "false", + "fetch", + "for", + "foreign", + "from", + "grant", + "group", + "having", + "in", + "initially", + "intersect", + "into", + "lateral", + "leading", + "limit", + "localtime", + "localtimestamp", + "not", + "null", + "offset", + "on", + "only", + "or", + "order", + "placing", + "primary", + "references", + "returning", + "select", + "session_user", + "some", + "symmetric", + "table", + "then", + "to", + "trailing", + "true", + "union", + "unique", + "user", + "using", + "variadic", + "when", + "where", + "window", + "with", +]; + +/// Sorted list of CockroachDB's reserved type function names. +/// Taken from https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#keywords +const RESERVED_TYPE_FUNCTION_NAMES: [&str; 18] = [ + "authorization", + "collation", + "cross", + "full", + "ilike", + "inner", + "is", + "isnull", + "join", + "left", + "like", + "natural", + "none", + "notnull", + "outer", + "overlaps", + "right", + "similar", +]; + +/// Returns true if a Postgres identifier is considered "safe". +/// +/// In this context, "safe" means that the value of an identifier would be the same quoted and unquoted or that it's not part of reserved keywords. In other words, that it does _not_ need to be quoted. +/// +/// Spec can be found here: https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS +/// or here: https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#rules-for-identifiers +fn is_safe_identifier(ident: &str) -> bool { + if ident.is_empty() { + return false; + } + + // 1. Not equal any SQL keyword unless the keyword is accepted by the element's syntax. For example, name accepts Unreserved or Column Name keywords. + if RESERVED_KEYWORDS.binary_search(&ident).is_ok() || RESERVED_TYPE_FUNCTION_NAMES.binary_search(&ident).is_ok() { + return false; + } + + let mut chars = ident.chars(); + + let first = chars.next().unwrap(); + + // 2. SQL identifiers must begin with a letter (a-z, but also letters with diacritical marks and non-Latin letters) or an underscore (_). + if (!first.is_alphabetic() || !first.is_lowercase()) && first != '_' { + return false; + } + + for c in chars { + // 3. Subsequent characters in an identifier can be letters, underscores, digits (0-9), or dollar signs ($). 
+ if (!c.is_alphabetic() || !c.is_lowercase()) && c != '_' && !c.is_ascii_digit() && c != '$' { + return false; + } + } + + true +} diff --git a/quaint/src/connector/postgres/error.rs b/quaint/src/connector/postgres/wasm/error.rs similarity index 66% rename from quaint/src/connector/postgres/error.rs rename to quaint/src/connector/postgres/wasm/error.rs index d4e5ec7837fe..ab6ec7b07847 100644 --- a/quaint/src/connector/postgres/error.rs +++ b/quaint/src/connector/postgres/wasm/error.rs @@ -1,7 +1,5 @@ use std::fmt::{Display, Formatter}; -use tokio_postgres::error::DbError; - use crate::error::{DatabaseConstraint, Error, ErrorKind, Name}; #[derive(Debug)] @@ -17,7 +15,7 @@ pub struct PostgresError { impl std::error::Error for PostgresError {} impl Display for PostgresError { - // copy of DbError::fmt + // copy of tokio_postgres::error::DbError::fmt fn fmt(&self, fmt: &mut Formatter<'_>) -> std::fmt::Result { write!(fmt, "{}: {}", self.severity, self.message)?; if let Some(detail) = &self.detail { @@ -30,19 +28,6 @@ impl Display for PostgresError { } } -impl From<&DbError> for PostgresError { - fn from(value: &DbError) -> Self { - PostgresError { - code: value.code().code().to_string(), - severity: value.severity().to_string(), - message: value.message().to_string(), - detail: value.detail().map(ToString::to_string), - column: value.column().map(ToString::to_string), - hint: value.hint().map(ToString::to_string), - } - } -} - impl From for Error { fn from(value: PostgresError) -> Self { match value.code.as_str() { @@ -245,110 +230,3 @@ impl From for Error { } } } - -impl From for Error { - fn from(e: tokio_postgres::error::Error) -> Error { - if e.is_closed() { - return Error::builder(ErrorKind::ConnectionClosed).build(); - } - - if let Some(db_error) = e.as_db_error() { - return PostgresError::from(db_error).into(); - } - - if let Some(tls_error) = try_extracting_tls_error(&e) { - return tls_error; - } - - // Same for IO errors. - if let Some(io_error) = try_extracting_io_error(&e) { - return io_error; - } - - if let Some(uuid_error) = try_extracting_uuid_error(&e) { - return uuid_error; - } - - let reason = format!("{e}"); - let code = e.code().map(|c| c.code()); - - match reason.as_str() { - "error connecting to server: timed out" => { - let mut builder = Error::builder(ErrorKind::ConnectTimeout); - - if let Some(code) = code { - builder.set_original_code(code); - }; - - builder.set_original_message(reason); - builder.build() - } // sigh... 
- // https://github.com/sfackler/rust-postgres/blob/0c84ed9f8201f4e5b4803199a24afa2c9f3723b2/tokio-postgres/src/connect_tls.rs#L37 - "error performing TLS handshake: server does not support TLS" => { - let mut builder = Error::builder(ErrorKind::TlsError { - message: reason.clone(), - }); - - if let Some(code) = code { - builder.set_original_code(code); - }; - - builder.set_original_message(reason); - builder.build() - } // double sigh - _ => { - let code = code.map(|c| c.to_string()); - let mut builder = Error::builder(ErrorKind::QueryError(e.into())); - - if let Some(code) = code { - builder.set_original_code(code); - }; - - builder.set_original_message(reason); - builder.build() - } - } - } -} - -fn try_extracting_uuid_error(err: &tokio_postgres::error::Error) -> Option { - use std::error::Error as _; - - err.source() - .and_then(|err| err.downcast_ref::()) - .map(|err| ErrorKind::UUIDError(format!("{err}"))) - .map(|kind| Error::builder(kind).build()) -} - -fn try_extracting_tls_error(err: &tokio_postgres::error::Error) -> Option { - use std::error::Error; - - err.source() - .and_then(|err| err.downcast_ref::()) - .map(|err| err.into()) -} - -fn try_extracting_io_error(err: &tokio_postgres::error::Error) -> Option { - use std::error::Error as _; - - err.source() - .and_then(|err| err.downcast_ref::()) - .map(|err| ErrorKind::ConnectionError(Box::new(std::io::Error::new(err.kind(), format!("{err}"))))) - .map(|kind| Error::builder(kind).build()) -} - -impl From for Error { - fn from(e: native_tls::Error) -> Error { - Error::from(&e) - } -} - -impl From<&native_tls::Error> for Error { - fn from(e: &native_tls::Error) -> Error { - let kind = ErrorKind::TlsError { - message: format!("{e}"), - }; - - Error::builder(kind).build() - } -} diff --git a/quaint/src/connector/postgres/wasm/mod.rs b/quaint/src/connector/postgres/wasm/mod.rs new file mode 100644 index 000000000000..5b330861a199 --- /dev/null +++ b/quaint/src/connector/postgres/wasm/mod.rs @@ -0,0 +1,6 @@ +///! Wasm-compatible definitions for the Postgres connector. +/// /// This module is only available with the `postgresql` feature. 
+pub(crate) mod common; +pub mod error; + +pub use common::PostgresUrl; diff --git a/quaint/src/error.rs b/quaint/src/error.rs index 785fcc22ffe3..73bf5c405c66 100644 --- a/quaint/src/error.rs +++ b/quaint/src/error.rs @@ -7,7 +7,7 @@ use thiserror::Error; use std::time::Duration; // pub use crate::connector::mysql::MysqlError; -// pub use crate::connector::postgres::PostgresError; +pub use crate::connector::postgres::PostgresError; // pub use crate::connector::sqlite::SqliteError; #[derive(Debug, PartialEq, Eq)] From 12c6ebb15a31d42b5be0301892a122526b6521c7 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Fri, 10 Nov 2023 15:49:39 +0100 Subject: [PATCH 003/134] feat(quaint): split mysql connector into native and wasm submodules --- quaint/src/connector.rs | 25 +- quaint/src/connector/mysql.rs | 398 +----------- .../mysql/{ => native}/conversion.rs | 0 quaint/src/connector/mysql/native/error.rs | 36 ++ quaint/src/connector/mysql/native/mod.rs | 392 +++++++++++ quaint/src/connector/mysql/wasm/common.rs | 316 +++++++++ .../src/connector/mysql/{ => wasm}/error.rs | 65 +- quaint/src/connector/mysql/wasm/mod.rs | 6 + quaint/src/connector/postgres.rs | 1 + quaint/src/connector/postgres/native/error.rs | 2 +- quaint/src/connector/postgres_wasm.rs | 612 ------------------ quaint/src/error.rs | 2 +- 12 files changed, 792 insertions(+), 1063 deletions(-) rename quaint/src/connector/mysql/{ => native}/conversion.rs (100%) create mode 100644 quaint/src/connector/mysql/native/error.rs create mode 100644 quaint/src/connector/mysql/native/mod.rs create mode 100644 quaint/src/connector/mysql/wasm/common.rs rename quaint/src/connector/mysql/{ => wasm}/error.rs (81%) create mode 100644 quaint/src/connector/mysql/wasm/mod.rs delete mode 100644 quaint/src/connector/postgres_wasm.rs diff --git a/quaint/src/connector.rs b/quaint/src/connector.rs index 71bba2d098ed..d0e4d7e57bdc 100644 --- a/quaint/src/connector.rs +++ b/quaint/src/connector.rs @@ -27,10 +27,10 @@ mod type_identifier; pub(crate) mod mssql; #[cfg(feature = "mssql")] pub(crate) mod mssql_wasm; -#[cfg(feature = "mysql-connector")] -pub(crate) mod mysql; -#[cfg(feature = "mysql")] -pub(crate) mod mysql_wasm; +// #[cfg(feature = "mysql-connector")] +// pub(crate) mod mysql; +// #[cfg(feature = "mysql")] +// pub(crate) mod mysql_wasm; // #[cfg(feature = "postgresql-connector")] // pub(crate) mod postgres; // #[cfg(feature = "postgresql")] @@ -40,10 +40,10 @@ pub(crate) mod sqlite; #[cfg(feature = "sqlite")] pub(crate) mod sqlite_wasm; -#[cfg(feature = "mysql-connector")] -pub use self::mysql::*; -#[cfg(feature = "mysql")] -pub use self::mysql_wasm::*; +// #[cfg(feature = "mysql-connector")] +// pub use self::mysql::*; +// #[cfg(feature = "mysql")] +// pub use self::mysql_wasm::*; // #[cfg(feature = "postgresql-connector")] // pub use self::postgres::*; // #[cfg(feature = "postgresql")] @@ -76,4 +76,11 @@ pub(crate) mod postgres; #[cfg(feature = "postgresql-connector")] pub use postgres::native::*; #[cfg(feature = "postgresql")] -pub use postgres::wasm::*; +pub use postgres::wasm::common::*; + +#[cfg(feature = "mysql")] +pub(crate) mod mysql; +#[cfg(feature = "mysql-connector")] +pub use mysql::native::*; +#[cfg(feature = "mysql")] +pub use mysql::wasm::common::*; diff --git a/quaint/src/connector/mysql.rs b/quaint/src/connector/mysql.rs index a9a829404e76..1794cc738b1e 100644 --- a/quaint/src/connector/mysql.rs +++ b/quaint/src/connector/mysql.rs @@ -1,394 +1,8 @@ -mod conversion; -mod error; +pub use wasm::common::MysqlUrl; +pub use 
wasm::error::MysqlError; -pub(crate) use super::mysql_wasm::MysqlUrl; -use crate::{ - ast::{Query, Value}, - connector::{metrics, queryable::*, ResultSet}, - error::{Error, ErrorKind}, - visitor::{self, Visitor}, -}; -use async_trait::async_trait; -use lru_cache::LruCache; -use mysql_async::{ - self as my, - prelude::{Query as _, Queryable as _}, -}; -use std::{ - future::Future, - sync::atomic::{AtomicBool, Ordering}, - time::Duration, -}; -use tokio::sync::Mutex; +#[cfg(feature = "mysql")] +pub(crate) mod wasm; -pub use error::MysqlError; - -/// The underlying MySQL driver. Only available with the `expose-drivers` -/// Cargo feature. -#[cfg(feature = "expose-drivers")] -pub use mysql_async; - -use super::IsolationLevel; - -impl MysqlUrl { - pub(crate) fn cache(&self) -> LruCache { - LruCache::new(self.query_params.statement_cache_size) - } - - pub(crate) fn to_opts_builder(&self) -> my::OptsBuilder { - let mut config = my::OptsBuilder::default() - .stmt_cache_size(Some(0)) - .user(Some(self.username())) - .pass(self.password()) - .db_name(Some(self.dbname())); - - match self.socket() { - Some(ref socket) => { - config = config.socket(Some(socket)); - } - None => { - config = config.ip_or_hostname(self.host()).tcp_port(self.port()); - } - } - - config = config.conn_ttl(Some(Duration::from_secs(5))); - - if self.query_params.use_ssl { - config = config.ssl_opts(Some(self.query_params.ssl_opts.clone())); - } - - if self.query_params.prefer_socket.is_some() { - config = config.prefer_socket(self.query_params.prefer_socket); - } - - config - } -} - -#[derive(Debug, Clone)] -pub(crate) struct MysqlUrlQueryParams { - ssl_opts: my::SslOpts, - connection_limit: Option, - use_ssl: bool, - socket: Option, - socket_timeout: Option, - connect_timeout: Option, - pool_timeout: Option, - max_connection_lifetime: Option, - max_idle_connection_lifetime: Option, - prefer_socket: Option, - statement_cache_size: usize, -} - -/// A connector interface for the MySQL database. -#[derive(Debug)] -pub struct Mysql { - pub(crate) conn: Mutex, - pub(crate) url: MysqlUrl, - socket_timeout: Option, - is_healthy: AtomicBool, - statement_cache: Mutex>, -} - -impl Mysql { - /// Create a new MySQL connection using `OptsBuilder` from the `mysql` crate. - pub async fn new(url: MysqlUrl) -> crate::Result { - let conn = super::timeout::connect(url.connect_timeout(), my::Conn::new(url.to_opts_builder())).await?; - - Ok(Self { - socket_timeout: url.query_params.socket_timeout, - conn: Mutex::new(conn), - statement_cache: Mutex::new(url.cache()), - url, - is_healthy: AtomicBool::new(true), - }) - } - - /// The underlying mysql_async::Conn. Only available with the - /// `expose-drivers` Cargo feature. This is a lower level API when you need - /// to get into database specific features. - #[cfg(feature = "expose-drivers")] - pub fn conn(&self) -> &Mutex { - &self.conn - } - - async fn perform_io(&self, op: U) -> crate::Result - where - F: Future>, - U: FnOnce() -> F, - { - match super::timeout::socket(self.socket_timeout, op()).await { - Err(e) if e.is_closed() => { - self.is_healthy.store(false, Ordering::SeqCst); - Err(e) - } - res => Ok(res?), - } - } - - async fn prepared(&self, sql: &str, op: U) -> crate::Result - where - F: Future>, - U: Fn(my::Statement) -> F, - { - if self.url.statement_cache_size() == 0 { - self.perform_io(|| async move { - let stmt = { - let mut conn = self.conn.lock().await; - conn.prep(sql).await? 
- }; - - let res = op(stmt.clone()).await; - - { - let mut conn = self.conn.lock().await; - conn.close(stmt).await?; - } - - res - }) - .await - } else { - self.perform_io(|| async move { - let stmt = self.fetch_cached(sql).await?; - op(stmt).await - }) - .await - } - } - - async fn fetch_cached(&self, sql: &str) -> crate::Result { - let mut cache = self.statement_cache.lock().await; - let capacity = cache.capacity(); - let stored = cache.len(); - - match cache.get_mut(sql) { - Some(stmt) => { - tracing::trace!( - message = "CACHE HIT!", - query = sql, - capacity = capacity, - stored = stored, - ); - - Ok(stmt.clone()) // arc'd - } - None => { - tracing::trace!( - message = "CACHE MISS!", - query = sql, - capacity = capacity, - stored = stored, - ); - - let mut conn = self.conn.lock().await; - if cache.capacity() == cache.len() { - if let Some((_, stmt)) = cache.remove_lru() { - conn.close(stmt).await?; - } - } - - let stmt = conn.prep(sql).await?; - cache.insert(sql.to_string(), stmt.clone()); - - Ok(stmt) - } - } - } -} - -impl_default_TransactionCapable!(Mysql); - -#[async_trait] -impl Queryable for Mysql { - async fn query(&self, q: Query<'_>) -> crate::Result { - let (sql, params) = visitor::Mysql::build(q)?; - self.query_raw(&sql, ¶ms).await - } - - async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - metrics::query("mysql.query_raw", sql, params, move || async move { - self.prepared(sql, |stmt| async move { - let mut conn = self.conn.lock().await; - let rows: Vec = conn.exec(&stmt, conversion::conv_params(params)?).await?; - let columns = stmt.columns().iter().map(|s| s.name_str().into_owned()).collect(); - - let last_id = conn.last_insert_id(); - let mut result_set = ResultSet::new(columns, Vec::new()); - - for mut row in rows { - result_set.rows.push(row.take_result_row()?); - } - - if let Some(id) = last_id { - result_set.set_last_insert_id(id); - }; - - Ok(result_set) - }) - .await - }) - .await - } - - async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - self.query_raw(sql, params).await - } - - async fn execute(&self, q: Query<'_>) -> crate::Result { - let (sql, params) = visitor::Mysql::build(q)?; - self.execute_raw(&sql, ¶ms).await - } - - async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - metrics::query("mysql.execute_raw", sql, params, move || async move { - self.prepared(sql, |stmt| async move { - let mut conn = self.conn.lock().await; - conn.exec_drop(stmt, conversion::conv_params(params)?).await?; - - Ok(conn.affected_rows()) - }) - .await - }) - .await - } - - async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - self.execute_raw(sql, params).await - } - - async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> { - metrics::query("mysql.raw_cmd", cmd, &[], move || async move { - self.perform_io(|| async move { - let mut conn = self.conn.lock().await; - let mut result = cmd.run(&mut *conn).await?; - - loop { - result.map(drop).await?; - - if result.is_empty() { - result.map(drop).await?; - break; - } - } - - Ok(()) - }) - .await - }) - .await - } - - async fn version(&self) -> crate::Result> { - let query = r#"SELECT @@GLOBAL.version version"#; - let rows = super::timeout::socket(self.socket_timeout, self.query_raw(query, &[])).await?; - - let version_string = rows - .get(0) - .and_then(|row| row.get("version").and_then(|version| version.typed.to_string())); - - Ok(version_string) - } - - fn is_healthy(&self) -> bool { - 
self.is_healthy.load(Ordering::SeqCst) - } - - async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { - if matches!(isolation_level, IsolationLevel::Snapshot) { - return Err(Error::builder(ErrorKind::invalid_isolation_level(&isolation_level)).build()); - } - - self.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation_level}")) - .await?; - - Ok(()) - } - - fn requires_isolation_first(&self) -> bool { - true - } -} - -#[cfg(test)] -mod tests { - use super::MysqlUrl; - use crate::tests::test_api::mysql::CONN_STR; - use crate::{error::*, single::Quaint}; - use url::Url; - - #[test] - fn should_parse_socket_url() { - let url = MysqlUrl::new(Url::parse("mysql://root@localhost/dbname?socket=(/tmp/mysql.sock)").unwrap()).unwrap(); - assert_eq!("dbname", url.dbname()); - assert_eq!(&Some(String::from("/tmp/mysql.sock")), url.socket()); - } - - #[test] - fn should_parse_prefer_socket() { - let url = - MysqlUrl::new(Url::parse("mysql://root:root@localhost:3307/testdb?prefer_socket=false").unwrap()).unwrap(); - assert!(!url.prefer_socket().unwrap()); - } - - #[test] - fn should_parse_sslaccept() { - let url = - MysqlUrl::new(Url::parse("mysql://root:root@localhost:3307/testdb?sslaccept=strict").unwrap()).unwrap(); - assert!(url.query_params.use_ssl); - assert!(!url.query_params.ssl_opts.skip_domain_validation()); - assert!(!url.query_params.ssl_opts.accept_invalid_certs()); - } - - #[test] - fn should_parse_ipv6_host() { - let url = MysqlUrl::new(Url::parse("mysql://[2001:db8:1234::ffff]:5432/testdb").unwrap()).unwrap(); - assert_eq!("2001:db8:1234::ffff", url.host()); - } - - #[test] - fn should_allow_changing_of_cache_size() { - let url = MysqlUrl::new(Url::parse("mysql:///root:root@localhost:3307/foo?statement_cache_size=420").unwrap()) - .unwrap(); - assert_eq!(420, url.cache().capacity()); - } - - #[test] - fn should_have_default_cache_size() { - let url = MysqlUrl::new(Url::parse("mysql:///root:root@localhost:3307/foo").unwrap()).unwrap(); - assert_eq!(100, url.cache().capacity()); - } - - #[tokio::test] - async fn should_map_nonexisting_database_error() { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.set_username("root").unwrap(); - url.set_path("/this_does_not_exist"); - - let url = url.as_str().to_string(); - let res = Quaint::new(&url).await; - - let err = res.unwrap_err(); - - match err.kind() { - ErrorKind::DatabaseDoesNotExist { db_name } => { - assert_eq!(Some("1049"), err.original_code()); - assert_eq!(Some("Unknown database \'this_does_not_exist\'"), err.original_message()); - assert_eq!(&Name::available("this_does_not_exist"), db_name) - } - e => panic!("Expected `DatabaseDoesNotExist`, got {:?}", e), - } - } - - #[tokio::test] - async fn should_map_wrong_credentials_error() { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.set_username("WRONG").unwrap(); - - let res = Quaint::new(url.as_str()).await; - assert!(res.is_err()); - - let err = res.unwrap_err(); - assert!(matches!(err.kind(), ErrorKind::AuthenticationFailed { user } if user == &Name::available("WRONG"))); - } -} +#[cfg(feature = "mysql-connector")] +pub(crate) mod native; diff --git a/quaint/src/connector/mysql/conversion.rs b/quaint/src/connector/mysql/native/conversion.rs similarity index 100% rename from quaint/src/connector/mysql/conversion.rs rename to quaint/src/connector/mysql/native/conversion.rs diff --git a/quaint/src/connector/mysql/native/error.rs b/quaint/src/connector/mysql/native/error.rs new file mode 100644 index 000000000000..e00ff1e0aa74 --- 
/dev/null +++ b/quaint/src/connector/mysql/native/error.rs @@ -0,0 +1,36 @@ +use crate::{ + connector::mysql::wasm::error::MysqlError, + error::{Error, ErrorKind}, +}; +use mysql_async as my; + +impl From<&my::ServerError> for MysqlError { + fn from(value: &my::ServerError) -> Self { + MysqlError { + code: value.code, + message: value.message.to_owned(), + state: value.state.to_owned(), + } + } +} + +impl From for Error { + fn from(e: my::Error) -> Error { + match e { + my::Error::Io(my::IoError::Tls(err)) => Error::builder(ErrorKind::TlsError { + message: err.to_string(), + }) + .build(), + my::Error::Io(my::IoError::Io(err)) if err.kind() == std::io::ErrorKind::UnexpectedEof => { + Error::builder(ErrorKind::ConnectionClosed).build() + } + my::Error::Io(io_error) => Error::builder(ErrorKind::ConnectionError(io_error.into())).build(), + my::Error::Driver(e) => Error::builder(ErrorKind::QueryError(e.into())).build(), + my::Error::Server(ref server_error) => { + let mysql_error: MysqlError = server_error.into(); + mysql_error.into() + } + e => Error::builder(ErrorKind::QueryError(e.into())).build(), + } + } +} diff --git a/quaint/src/connector/mysql/native/mod.rs b/quaint/src/connector/mysql/native/mod.rs new file mode 100644 index 000000000000..1a9652b628f8 --- /dev/null +++ b/quaint/src/connector/mysql/native/mod.rs @@ -0,0 +1,392 @@ +mod conversion; +mod error; + +pub(crate) use crate::connector::mysql::wasm::common::MysqlUrl; +use crate::connector::{timeout, IsolationLevel}; + +use crate::{ + ast::{Query, Value}, + connector::{metrics, queryable::*, ResultSet}, + error::{Error, ErrorKind}, + visitor::{self, Visitor}, +}; +use async_trait::async_trait; +use lru_cache::LruCache; +use mysql_async::{ + self as my, + prelude::{Query as _, Queryable as _}, +}; +use std::{ + future::Future, + sync::atomic::{AtomicBool, Ordering}, + time::Duration, +}; +use tokio::sync::Mutex; + +/// The underlying MySQL driver. Only available with the `expose-drivers` +/// Cargo feature. +#[cfg(feature = "expose-drivers")] +pub use mysql_async; + +impl MysqlUrl { + pub(crate) fn cache(&self) -> LruCache { + LruCache::new(self.query_params.statement_cache_size) + } + + pub(crate) fn to_opts_builder(&self) -> my::OptsBuilder { + let mut config = my::OptsBuilder::default() + .stmt_cache_size(Some(0)) + .user(Some(self.username())) + .pass(self.password()) + .db_name(Some(self.dbname())); + + match self.socket() { + Some(ref socket) => { + config = config.socket(Some(socket)); + } + None => { + config = config.ip_or_hostname(self.host()).tcp_port(self.port()); + } + } + + config = config.conn_ttl(Some(Duration::from_secs(5))); + + if self.query_params.use_ssl { + config = config.ssl_opts(Some(self.query_params.ssl_opts.clone())); + } + + if self.query_params.prefer_socket.is_some() { + config = config.prefer_socket(self.query_params.prefer_socket); + } + + config + } +} + +#[derive(Debug, Clone)] +pub(crate) struct MysqlUrlQueryParams { + ssl_opts: my::SslOpts, + connection_limit: Option, + use_ssl: bool, + socket: Option, + socket_timeout: Option, + connect_timeout: Option, + pool_timeout: Option, + max_connection_lifetime: Option, + max_idle_connection_lifetime: Option, + prefer_socket: Option, + statement_cache_size: usize, +} + +/// A connector interface for the MySQL database. 
+#[derive(Debug)] +pub struct Mysql { + pub(crate) conn: Mutex, + pub(crate) url: MysqlUrl, + socket_timeout: Option, + is_healthy: AtomicBool, + statement_cache: Mutex>, +} + +impl Mysql { + /// Create a new MySQL connection using `OptsBuilder` from the `mysql` crate. + pub async fn new(url: MysqlUrl) -> crate::Result { + let conn = timeout::connect(url.connect_timeout(), my::Conn::new(url.to_opts_builder())).await?; + + Ok(Self { + socket_timeout: url.query_params.socket_timeout, + conn: Mutex::new(conn), + statement_cache: Mutex::new(url.cache()), + url, + is_healthy: AtomicBool::new(true), + }) + } + + /// The underlying mysql_async::Conn. Only available with the + /// `expose-drivers` Cargo feature. This is a lower level API when you need + /// to get into database specific features. + #[cfg(feature = "expose-drivers")] + pub fn conn(&self) -> &Mutex { + &self.conn + } + + async fn perform_io(&self, op: U) -> crate::Result + where + F: Future>, + U: FnOnce() -> F, + { + match timeout::socket(self.socket_timeout, op()).await { + Err(e) if e.is_closed() => { + self.is_healthy.store(false, Ordering::SeqCst); + Err(e) + } + res => Ok(res?), + } + } + + async fn prepared(&self, sql: &str, op: U) -> crate::Result + where + F: Future>, + U: Fn(my::Statement) -> F, + { + if self.url.statement_cache_size() == 0 { + self.perform_io(|| async move { + let stmt = { + let mut conn = self.conn.lock().await; + conn.prep(sql).await? + }; + + let res = op(stmt.clone()).await; + + { + let mut conn = self.conn.lock().await; + conn.close(stmt).await?; + } + + res + }) + .await + } else { + self.perform_io(|| async move { + let stmt = self.fetch_cached(sql).await?; + op(stmt).await + }) + .await + } + } + + async fn fetch_cached(&self, sql: &str) -> crate::Result { + let mut cache = self.statement_cache.lock().await; + let capacity = cache.capacity(); + let stored = cache.len(); + + match cache.get_mut(sql) { + Some(stmt) => { + tracing::trace!( + message = "CACHE HIT!", + query = sql, + capacity = capacity, + stored = stored, + ); + + Ok(stmt.clone()) // arc'd + } + None => { + tracing::trace!( + message = "CACHE MISS!", + query = sql, + capacity = capacity, + stored = stored, + ); + + let mut conn = self.conn.lock().await; + if cache.capacity() == cache.len() { + if let Some((_, stmt)) = cache.remove_lru() { + conn.close(stmt).await?; + } + } + + let stmt = conn.prep(sql).await?; + cache.insert(sql.to_string(), stmt.clone()); + + Ok(stmt) + } + } + } +} + +impl_default_TransactionCapable!(Mysql); + +#[async_trait] +impl Queryable for Mysql { + async fn query(&self, q: Query<'_>) -> crate::Result { + let (sql, params) = visitor::Mysql::build(q)?; + self.query_raw(&sql, ¶ms).await + } + + async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + metrics::query("mysql.query_raw", sql, params, move || async move { + self.prepared(sql, |stmt| async move { + let mut conn = self.conn.lock().await; + let rows: Vec = conn.exec(&stmt, conversion::conv_params(params)?).await?; + let columns = stmt.columns().iter().map(|s| s.name_str().into_owned()).collect(); + + let last_id = conn.last_insert_id(); + let mut result_set = ResultSet::new(columns, Vec::new()); + + for mut row in rows { + result_set.rows.push(row.take_result_row()?); + } + + if let Some(id) = last_id { + result_set.set_last_insert_id(id); + }; + + Ok(result_set) + }) + .await + }) + .await + } + + async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + self.query_raw(sql, params).await + } + + async 
fn execute(&self, q: Query<'_>) -> crate::Result { + let (sql, params) = visitor::Mysql::build(q)?; + self.execute_raw(&sql, ¶ms).await + } + + async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + metrics::query("mysql.execute_raw", sql, params, move || async move { + self.prepared(sql, |stmt| async move { + let mut conn = self.conn.lock().await; + conn.exec_drop(stmt, conversion::conv_params(params)?).await?; + + Ok(conn.affected_rows()) + }) + .await + }) + .await + } + + async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + self.execute_raw(sql, params).await + } + + async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> { + metrics::query("mysql.raw_cmd", cmd, &[], move || async move { + self.perform_io(|| async move { + let mut conn = self.conn.lock().await; + let mut result = cmd.run(&mut *conn).await?; + + loop { + result.map(drop).await?; + + if result.is_empty() { + result.map(drop).await?; + break; + } + } + + Ok(()) + }) + .await + }) + .await + } + + async fn version(&self) -> crate::Result> { + let query = r#"SELECT @@GLOBAL.version version"#; + let rows = timeout::socket(self.socket_timeout, self.query_raw(query, &[])).await?; + + let version_string = rows + .get(0) + .and_then(|row| row.get("version").and_then(|version| version.typed.to_string())); + + Ok(version_string) + } + + fn is_healthy(&self) -> bool { + self.is_healthy.load(Ordering::SeqCst) + } + + async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { + if matches!(isolation_level, IsolationLevel::Snapshot) { + return Err(Error::builder(ErrorKind::invalid_isolation_level(&isolation_level)).build()); + } + + self.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation_level}")) + .await?; + + Ok(()) + } + + fn requires_isolation_first(&self) -> bool { + true + } +} + +#[cfg(test)] +mod tests { + use super::MysqlUrl; + use crate::tests::test_api::mysql::CONN_STR; + use crate::{error::*, single::Quaint}; + use url::Url; + + #[test] + fn should_parse_socket_url() { + let url = MysqlUrl::new(Url::parse("mysql://root@localhost/dbname?socket=(/tmp/mysql.sock)").unwrap()).unwrap(); + assert_eq!("dbname", url.dbname()); + assert_eq!(&Some(String::from("/tmp/mysql.sock")), url.socket()); + } + + #[test] + fn should_parse_prefer_socket() { + let url = + MysqlUrl::new(Url::parse("mysql://root:root@localhost:3307/testdb?prefer_socket=false").unwrap()).unwrap(); + assert!(!url.prefer_socket().unwrap()); + } + + #[test] + fn should_parse_sslaccept() { + let url = + MysqlUrl::new(Url::parse("mysql://root:root@localhost:3307/testdb?sslaccept=strict").unwrap()).unwrap(); + assert!(url.query_params.use_ssl); + assert!(!url.query_params.ssl_opts.skip_domain_validation()); + assert!(!url.query_params.ssl_opts.accept_invalid_certs()); + } + + #[test] + fn should_parse_ipv6_host() { + let url = MysqlUrl::new(Url::parse("mysql://[2001:db8:1234::ffff]:5432/testdb").unwrap()).unwrap(); + assert_eq!("2001:db8:1234::ffff", url.host()); + } + + #[test] + fn should_allow_changing_of_cache_size() { + let url = MysqlUrl::new(Url::parse("mysql:///root:root@localhost:3307/foo?statement_cache_size=420").unwrap()) + .unwrap(); + assert_eq!(420, url.cache().capacity()); + } + + #[test] + fn should_have_default_cache_size() { + let url = MysqlUrl::new(Url::parse("mysql:///root:root@localhost:3307/foo").unwrap()).unwrap(); + assert_eq!(100, url.cache().capacity()); + } + + #[tokio::test] + async fn should_map_nonexisting_database_error() { + let 
mut url = Url::parse(&CONN_STR).unwrap(); + url.set_username("root").unwrap(); + url.set_path("/this_does_not_exist"); + + let url = url.as_str().to_string(); + let res = Quaint::new(&url).await; + + let err = res.unwrap_err(); + + match err.kind() { + ErrorKind::DatabaseDoesNotExist { db_name } => { + assert_eq!(Some("1049"), err.original_code()); + assert_eq!(Some("Unknown database \'this_does_not_exist\'"), err.original_message()); + assert_eq!(&Name::available("this_does_not_exist"), db_name) + } + e => panic!("Expected `DatabaseDoesNotExist`, got {:?}", e), + } + } + + #[tokio::test] + async fn should_map_wrong_credentials_error() { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.set_username("WRONG").unwrap(); + + let res = Quaint::new(url.as_str()).await; + assert!(res.is_err()); + + let err = res.unwrap_err(); + assert!(matches!(err.kind(), ErrorKind::AuthenticationFailed { user } if user == &Name::available("WRONG"))); + } +} diff --git a/quaint/src/connector/mysql/wasm/common.rs b/quaint/src/connector/mysql/wasm/common.rs new file mode 100644 index 000000000000..fe60fd24cfc1 --- /dev/null +++ b/quaint/src/connector/mysql/wasm/common.rs @@ -0,0 +1,316 @@ +use crate::error::{Error, ErrorKind}; +use percent_encoding::percent_decode; +use std::{ + borrow::Cow, + path::{Path, PathBuf}, + time::Duration, +}; +use url::{Host, Url}; + +/// Wraps a connection url and exposes the parsing logic used by quaint, including default values. +#[derive(Debug, Clone)] +pub struct MysqlUrl { + url: Url, + pub(crate) query_params: MysqlUrlQueryParams, +} + +impl MysqlUrl { + /// Parse `Url` to `MysqlUrl`. Returns error for mistyped connection + /// parameters. + pub fn new(url: Url) -> Result { + let query_params = Self::parse_query_params(&url)?; + + Ok(Self { url, query_params }) + } + + /// The bare `Url` to the database. + pub fn url(&self) -> &Url { + &self.url + } + + /// The percent-decoded database username. + pub fn username(&self) -> Cow { + match percent_decode(self.url.username().as_bytes()).decode_utf8() { + Ok(username) => username, + Err(_) => { + tracing::warn!("Couldn't decode username to UTF-8, using the non-decoded version."); + + self.url.username().into() + } + } + } + + /// The percent-decoded database password. + pub fn password(&self) -> Option> { + match self + .url + .password() + .and_then(|pw| percent_decode(pw.as_bytes()).decode_utf8().ok()) + { + Some(password) => Some(password), + None => self.url.password().map(|s| s.into()), + } + } + + /// Name of the database connected. Defaults to `mysql`. + pub fn dbname(&self) -> &str { + match self.url.path_segments() { + Some(mut segments) => segments.next().unwrap_or("mysql"), + None => "mysql", + } + } + + /// The database host. If `socket` and `host` are not set, defaults to `localhost`. + pub fn host(&self) -> &str { + match (self.url.host(), self.url.host_str()) { + (Some(Host::Ipv6(_)), Some(host)) => { + // The `url` crate may return an IPv6 address in brackets, which must be stripped. + if host.starts_with('[') && host.ends_with(']') { + &host[1..host.len() - 1] + } else { + host + } + } + (_, Some(host)) => host, + _ => "localhost", + } + } + + /// If set, connected to the database through a Unix socket. + pub fn socket(&self) -> &Option { + &self.query_params.socket + } + + /// The database port, defaults to `3306`. + pub fn port(&self) -> u16 { + self.url.port().unwrap_or(3306) + } + + /// The connection timeout. 
+ pub fn connect_timeout(&self) -> Option { + self.query_params.connect_timeout + } + + /// The pool check_out timeout + pub fn pool_timeout(&self) -> Option { + self.query_params.pool_timeout + } + + /// The socket timeout + pub fn socket_timeout(&self) -> Option { + self.query_params.socket_timeout + } + + /// Prefer socket connection + pub fn prefer_socket(&self) -> Option { + self.query_params.prefer_socket + } + + /// The maximum connection lifetime + pub fn max_connection_lifetime(&self) -> Option { + self.query_params.max_connection_lifetime + } + + /// The maximum idle connection lifetime + pub fn max_idle_connection_lifetime(&self) -> Option { + self.query_params.max_idle_connection_lifetime + } + + pub(crate) fn statement_cache_size(&self) -> usize { + self.query_params.statement_cache_size + } + + fn parse_query_params(url: &Url) -> Result { + #[cfg(feature = "mysql-connector")] + let mut ssl_opts = { + let mut ssl_opts = mysql_async::SslOpts::default(); + ssl_opts = ssl_opts.with_danger_accept_invalid_certs(true); + ssl_opts + }; + + let mut connection_limit = None; + let mut use_ssl = false; + let mut socket = None; + let mut socket_timeout = None; + let mut connect_timeout = Some(Duration::from_secs(5)); + let mut pool_timeout = Some(Duration::from_secs(10)); + let mut max_connection_lifetime = None; + let mut max_idle_connection_lifetime = Some(Duration::from_secs(300)); + let mut prefer_socket = None; + let mut statement_cache_size = 100; + let mut identity: Option<(Option, Option)> = None; + + for (k, v) in url.query_pairs() { + match k.as_ref() { + "connection_limit" => { + let as_int: usize = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + connection_limit = Some(as_int); + } + "statement_cache_size" => { + statement_cache_size = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + } + "sslcert" => { + use_ssl = true; + + #[cfg(feature = "mysql-connector")] + { + ssl_opts = ssl_opts.with_root_cert_path(Some(Path::new(&*v).to_path_buf())); + } + } + "sslidentity" => { + use_ssl = true; + + identity = match identity { + Some((_, pw)) => Some((Some(Path::new(&*v).to_path_buf()), pw)), + None => Some((Some(Path::new(&*v).to_path_buf()), None)), + }; + } + "sslpassword" => { + use_ssl = true; + + identity = match identity { + Some((path, _)) => Some((path, Some(v.to_string()))), + None => Some((None, Some(v.to_string()))), + }; + } + "socket" => { + socket = Some(v.replace(['(', ')'], "")); + } + "socket_timeout" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + socket_timeout = Some(Duration::from_secs(as_int)); + } + "prefer_socket" => { + let as_bool = v + .parse::() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + prefer_socket = Some(as_bool) + } + "connect_timeout" => { + let as_int = v + .parse::() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + connect_timeout = match as_int { + 0 => None, + _ => Some(Duration::from_secs(as_int)), + }; + } + "pool_timeout" => { + let as_int = v + .parse::() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + pool_timeout = match as_int { + 0 => None, + _ => Some(Duration::from_secs(as_int)), + }; + } + "sslaccept" => { + use_ssl = true; + match v.as_ref() { + "strict" => { + #[cfg(feature = "mysql-connector")] + { + ssl_opts = ssl_opts.with_danger_accept_invalid_certs(false); + } + 
} + "accept_invalid_certs" => {} + _ => { + tracing::debug!( + message = "Unsupported SSL accept mode, defaulting to `accept_invalid_certs`", + mode = &*v + ); + } + }; + } + "max_connection_lifetime" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + if as_int == 0 { + max_connection_lifetime = None; + } else { + max_connection_lifetime = Some(Duration::from_secs(as_int)); + } + } + "max_idle_connection_lifetime" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + if as_int == 0 { + max_idle_connection_lifetime = None; + } else { + max_idle_connection_lifetime = Some(Duration::from_secs(as_int)); + } + } + _ => { + tracing::trace!(message = "Discarding connection string param", param = &*k); + } + }; + } + + // Wrapping this in a block, as attributes on expressions are still experimental + // See: https://github.com/rust-lang/rust/issues/15701 + #[cfg(feature = "mysql-connector")] + { + ssl_opts = match identity { + Some((Some(path), Some(pw))) => { + let identity = mysql_async::ClientIdentity::new(path).with_password(pw); + ssl_opts.with_client_identity(Some(identity)) + } + Some((Some(path), None)) => { + let identity = mysql_async::ClientIdentity::new(path); + ssl_opts.with_client_identity(Some(identity)) + } + _ => ssl_opts, + }; + } + + Ok(MysqlUrlQueryParams { + #[cfg(feature = "mysql-connector")] + ssl_opts, + connection_limit, + use_ssl, + socket, + socket_timeout, + connect_timeout, + pool_timeout, + max_connection_lifetime, + max_idle_connection_lifetime, + prefer_socket, + statement_cache_size, + }) + } + + #[cfg(feature = "pooled")] + pub(crate) fn connection_limit(&self) -> Option { + self.query_params.connection_limit + } +} + +#[derive(Debug, Clone)] +pub(crate) struct MysqlUrlQueryParams { + pub(crate) connection_limit: Option, + pub(crate) use_ssl: bool, + pub(crate) socket: Option, + pub(crate) socket_timeout: Option, + pub(crate) connect_timeout: Option, + pub(crate) pool_timeout: Option, + pub(crate) max_connection_lifetime: Option, + pub(crate) max_idle_connection_lifetime: Option, + pub(crate) prefer_socket: Option, + pub(crate) statement_cache_size: usize, + + #[cfg(feature = "mysql-connector")] + pub(crate) ssl_opts: mysql_async::SslOpts, +} diff --git a/quaint/src/connector/mysql/error.rs b/quaint/src/connector/mysql/wasm/error.rs similarity index 81% rename from quaint/src/connector/mysql/error.rs rename to quaint/src/connector/mysql/wasm/error.rs index dd7c3d3bfa66..c09ec84d7a7b 100644 --- a/quaint/src/connector/mysql/error.rs +++ b/quaint/src/connector/mysql/wasm/error.rs @@ -1,5 +1,4 @@ use crate::error::{DatabaseConstraint, Error, ErrorKind}; -use mysql_async as my; pub struct MysqlError { pub code: u16, @@ -7,16 +6,6 @@ pub struct MysqlError { pub state: String, } -impl From<&my::ServerError> for MysqlError { - fn from(value: &my::ServerError) -> Self { - MysqlError { - code: value.code, - message: value.message.to_owned(), - state: value.state.to_owned(), - } - } -} - impl From for Error { fn from(error: MysqlError) -> Self { let code = error.code; @@ -230,43 +219,23 @@ impl From for Error { builder.set_original_message(error.message); builder.build() } - _ => { - let kind = ErrorKind::QueryError( - my::Error::Server(my::ServerError { - message: error.message.clone(), - code, - state: error.state.clone(), - }) - .into(), - ); - - let mut builder = Error::builder(kind); - builder.set_original_code(format!("{code}")); - 
builder.set_original_message(error.message); - - builder.build() - } - } - } -} - -impl From for Error { - fn from(e: my::Error) -> Error { - match e { - my::Error::Io(my::IoError::Tls(err)) => Error::builder(ErrorKind::TlsError { - message: err.to_string(), - }) - .build(), - my::Error::Io(my::IoError::Io(err)) if err.kind() == std::io::ErrorKind::UnexpectedEof => { - Error::builder(ErrorKind::ConnectionClosed).build() - } - my::Error::Io(io_error) => Error::builder(ErrorKind::ConnectionError(io_error.into())).build(), - my::Error::Driver(e) => Error::builder(ErrorKind::QueryError(e.into())).build(), - my::Error::Server(ref server_error) => { - let mysql_error: MysqlError = server_error.into(); - mysql_error.into() - } - e => Error::builder(ErrorKind::QueryError(e.into())).build(), + _ => unimplemented!(), + // _ => { + // let kind = ErrorKind::QueryError( + // my::Error::Server(my::ServerError { + // message: error.message.clone(), + // code, + // state: error.state.clone(), + // }) + // .into(), + // ); + + // let mut builder = Error::builder(kind); + // builder.set_original_code(format!("{code}")); + // builder.set_original_message(error.message); + + // builder.build() + // } } } } diff --git a/quaint/src/connector/mysql/wasm/mod.rs b/quaint/src/connector/mysql/wasm/mod.rs new file mode 100644 index 000000000000..da9a57a53876 --- /dev/null +++ b/quaint/src/connector/mysql/wasm/mod.rs @@ -0,0 +1,6 @@ +///! Wasm-compatible definitions for the MySQL connector. +/// /// This module is only available with the `mysql` feature. +pub(crate) mod common; +pub mod error; + +pub use common::MysqlUrl; diff --git a/quaint/src/connector/postgres.rs b/quaint/src/connector/postgres.rs index 9f4d4d496f2b..0f4da84a7c67 100644 --- a/quaint/src/connector/postgres.rs +++ b/quaint/src/connector/postgres.rs @@ -1,3 +1,4 @@ +pub use wasm::common::PostgresUrl; pub use wasm::error::PostgresError; #[cfg(feature = "postgresql")] diff --git a/quaint/src/connector/postgres/native/error.rs b/quaint/src/connector/postgres/native/error.rs index ec3b18483746..05b792e27900 100644 --- a/quaint/src/connector/postgres/native/error.rs +++ b/quaint/src/connector/postgres/native/error.rs @@ -1,7 +1,7 @@ use tokio_postgres::error::DbError; use crate::{ - connector::error::PostgresError, + connector::postgres::wasm::error::PostgresError, error::{Error, ErrorKind}, }; diff --git a/quaint/src/connector/postgres_wasm.rs b/quaint/src/connector/postgres_wasm.rs deleted file mode 100644 index 4c67b98cfa42..000000000000 --- a/quaint/src/connector/postgres_wasm.rs +++ /dev/null @@ -1,612 +0,0 @@ -use std::{ - borrow::Cow, - fmt::{Debug, Display}, - time::Duration, -}; - -use percent_encoding::percent_decode; -use url::{Host, Url}; - -use crate::error::{Error, ErrorKind}; - -#[cfg(feature = "postgresql-connector")] -use tokio_postgres::config::{ChannelBinding, SslMode}; - -#[derive(Clone)] -pub(crate) struct Hidden(pub(crate) T); - -impl Debug for Hidden { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str("") - } -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum SslAcceptMode { - Strict, - AcceptInvalidCerts, -} - -#[derive(Debug, Clone)] -pub struct SslParams { - pub(super) certificate_file: Option, - pub(super) identity_file: Option, - pub(super) identity_password: Hidden>, - pub(super) ssl_accept_mode: SslAcceptMode, -} - -#[derive(Debug, Clone, Copy)] -pub enum PostgresFlavour { - Postgres, - Cockroach, - Unknown, -} - -impl PostgresFlavour { - /// Returns `true` if the postgres flavour 
is [`Postgres`]. - /// - /// [`Postgres`]: PostgresFlavour::Postgres - pub(super) fn is_postgres(&self) -> bool { - matches!(self, Self::Postgres) - } - - /// Returns `true` if the postgres flavour is [`Cockroach`]. - /// - /// [`Cockroach`]: PostgresFlavour::Cockroach - pub(super) fn is_cockroach(&self) -> bool { - matches!(self, Self::Cockroach) - } - - /// Returns `true` if the postgres flavour is [`Unknown`]. - /// - /// [`Unknown`]: PostgresFlavour::Unknown - pub(super) fn is_unknown(&self) -> bool { - matches!(self, Self::Unknown) - } -} - -/// Wraps a connection url and exposes the parsing logic used by Quaint, -/// including default values. -#[derive(Debug, Clone)] -pub struct PostgresUrl { - pub(super) url: Url, - pub(super) query_params: PostgresUrlQueryParams, - pub(super) flavour: PostgresFlavour, -} - -pub(crate) const DEFAULT_SCHEMA: &str = "public"; - -impl PostgresUrl { - /// Parse `Url` to `PostgresUrl`. Returns error for mistyped connection - /// parameters. - pub fn new(url: Url) -> Result { - let query_params = Self::parse_query_params(&url)?; - - Ok(Self { - url, - query_params, - flavour: PostgresFlavour::Unknown, - }) - } - - /// The bare `Url` to the database. - pub fn url(&self) -> &Url { - &self.url - } - - /// The percent-decoded database username. - pub fn username(&self) -> Cow { - match percent_decode(self.url.username().as_bytes()).decode_utf8() { - Ok(username) => username, - Err(_) => { - tracing::warn!("Couldn't decode username to UTF-8, using the non-decoded version."); - - self.url.username().into() - } - } - } - - /// The database host. Taken first from the `host` query parameter, then - /// from the `host` part of the URL. For socket connections, the query - /// parameter must be used. - /// - /// If none of them are set, defaults to `localhost`. - pub fn host(&self) -> &str { - match (self.query_params.host.as_ref(), self.url.host_str(), self.url.host()) { - (Some(host), _, _) => host.as_str(), - (None, Some(""), _) => "localhost", - (None, None, _) => "localhost", - (None, Some(host), Some(Host::Ipv6(_))) => { - // The `url` crate may return an IPv6 address in brackets, which must be stripped. - if host.starts_with('[') && host.ends_with(']') { - &host[1..host.len() - 1] - } else { - host - } - } - (None, Some(host), _) => host, - } - } - - /// Name of the database connected. Defaults to `postgres`. - pub fn dbname(&self) -> &str { - match self.url.path_segments() { - Some(mut segments) => segments.next().unwrap_or("postgres"), - None => "postgres", - } - } - - /// The percent-decoded database password. - pub fn password(&self) -> Cow { - match self - .url - .password() - .and_then(|pw| percent_decode(pw.as_bytes()).decode_utf8().ok()) - { - Some(password) => password, - None => self.url.password().unwrap_or("").into(), - } - } - - /// The database port, defaults to `5432`. - pub fn port(&self) -> u16 { - self.url.port().unwrap_or(5432) - } - - /// The database schema, defaults to `public`. - pub fn schema(&self) -> &str { - self.query_params.schema.as_deref().unwrap_or(DEFAULT_SCHEMA) - } - - /// Whether the pgbouncer mode is enabled. - pub fn pg_bouncer(&self) -> bool { - self.query_params.pg_bouncer - } - - /// The connection timeout. 
- pub fn connect_timeout(&self) -> Option { - self.query_params.connect_timeout - } - - /// Pool check_out timeout - pub fn pool_timeout(&self) -> Option { - self.query_params.pool_timeout - } - - /// The socket timeout - pub fn socket_timeout(&self) -> Option { - self.query_params.socket_timeout - } - - /// The maximum connection lifetime - pub fn max_connection_lifetime(&self) -> Option { - self.query_params.max_connection_lifetime - } - - /// The maximum idle connection lifetime - pub fn max_idle_connection_lifetime(&self) -> Option { - self.query_params.max_idle_connection_lifetime - } - - /// The custom application name - pub fn application_name(&self) -> Option<&str> { - self.query_params.application_name.as_deref() - } - - pub(crate) fn options(&self) -> Option<&str> { - self.query_params.options.as_deref() - } - - /// Sets whether the URL points to a Postgres, Cockroach or Unknown database. - /// This is used to avoid a network roundtrip at connection to set the search path. - /// - /// The different behaviours are: - /// - Postgres: Always avoid a network roundtrip by setting the search path through client connection parameters. - /// - Cockroach: Avoid a network roundtrip if the schema name is deemed "safe" (i.e. no escape quoting required). Otherwise, set the search path through a database query. - /// - Unknown: Always add a network roundtrip by setting the search path through a database query. - pub fn set_flavour(&mut self, flavour: PostgresFlavour) { - self.flavour = flavour; - } - - fn parse_query_params(url: &Url) -> Result { - #[cfg(feature = "postgresql-connector")] - let mut ssl_mode = SslMode::Prefer; - #[cfg(feature = "postgresql-connector")] - let mut channel_binding = ChannelBinding::Prefer; - - let mut connection_limit = None; - let mut schema = None; - let mut certificate_file = None; - let mut identity_file = None; - let mut identity_password = None; - let mut ssl_accept_mode = SslAcceptMode::AcceptInvalidCerts; - let mut host = None; - let mut application_name = None; - let mut socket_timeout = None; - let mut connect_timeout = Some(Duration::from_secs(5)); - let mut pool_timeout = Some(Duration::from_secs(10)); - let mut pg_bouncer = false; - let mut statement_cache_size = 100; - let mut max_connection_lifetime = None; - let mut max_idle_connection_lifetime = Some(Duration::from_secs(300)); - let mut options = None; - - for (k, v) in url.query_pairs() { - match k.as_ref() { - "pgbouncer" => { - pg_bouncer = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - } - #[cfg(feature = "postgresql-connector")] - "sslmode" => { - match v.as_ref() { - "disable" => ssl_mode = SslMode::Disable, - "prefer" => ssl_mode = SslMode::Prefer, - "require" => ssl_mode = SslMode::Require, - _ => { - tracing::debug!(message = "Unsupported SSL mode, defaulting to `prefer`", mode = &*v); - } - }; - } - "sslcert" => { - certificate_file = Some(v.to_string()); - } - "sslidentity" => { - identity_file = Some(v.to_string()); - } - "sslpassword" => { - identity_password = Some(v.to_string()); - } - "statement_cache_size" => { - statement_cache_size = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - } - "sslaccept" => { - match v.as_ref() { - "strict" => { - ssl_accept_mode = SslAcceptMode::Strict; - } - "accept_invalid_certs" => { - ssl_accept_mode = SslAcceptMode::AcceptInvalidCerts; - } - _ => { - tracing::debug!( - message = "Unsupported SSL accept mode, defaulting to `strict`", - mode = &*v - ); - - 
ssl_accept_mode = SslAcceptMode::Strict; - } - }; - } - "schema" => { - schema = Some(v.to_string()); - } - "connection_limit" => { - let as_int: usize = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - connection_limit = Some(as_int); - } - "host" => { - host = Some(v.to_string()); - } - "socket_timeout" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - socket_timeout = Some(Duration::from_secs(as_int)); - } - "connect_timeout" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - if as_int == 0 { - connect_timeout = None; - } else { - connect_timeout = Some(Duration::from_secs(as_int)); - } - } - "pool_timeout" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - if as_int == 0 { - pool_timeout = None; - } else { - pool_timeout = Some(Duration::from_secs(as_int)); - } - } - "max_connection_lifetime" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - if as_int == 0 { - max_connection_lifetime = None; - } else { - max_connection_lifetime = Some(Duration::from_secs(as_int)); - } - } - "max_idle_connection_lifetime" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - if as_int == 0 { - max_idle_connection_lifetime = None; - } else { - max_idle_connection_lifetime = Some(Duration::from_secs(as_int)); - } - } - "application_name" => { - application_name = Some(v.to_string()); - } - #[cfg(feature = "postgresql-connector")] - "channel_binding" => { - match v.as_ref() { - "disable" => channel_binding = ChannelBinding::Disable, - "prefer" => channel_binding = ChannelBinding::Prefer, - "require" => channel_binding = ChannelBinding::Require, - _ => { - tracing::debug!( - message = "Unsupported Channel Binding {channel_binding}, defaulting to `prefer`", - channel_binding = &*v - ); - } - }; - } - "options" => { - options = Some(v.to_string()); - } - _ => { - tracing::trace!(message = "Discarding connection string param", param = &*k); - } - }; - } - - Ok(PostgresUrlQueryParams { - ssl_params: SslParams { - certificate_file, - identity_file, - ssl_accept_mode, - identity_password: Hidden(identity_password), - }, - connection_limit, - schema, - host, - connect_timeout, - pool_timeout, - socket_timeout, - pg_bouncer, - statement_cache_size, - max_connection_lifetime, - max_idle_connection_lifetime, - application_name, - options, - #[cfg(feature = "postgresql-connector")] - channel_binding, - #[cfg(feature = "postgresql-connector")] - ssl_mode, - }) - } - - pub(crate) fn ssl_params(&self) -> &SslParams { - &self.query_params.ssl_params - } - - #[cfg(feature = "pooled")] - pub(crate) fn connection_limit(&self) -> Option { - self.query_params.connection_limit - } - - pub fn flavour(&self) -> PostgresFlavour { - self.flavour - } -} - -#[derive(Debug, Clone)] -pub(crate) struct PostgresUrlQueryParams { - pub(crate) ssl_params: SslParams, - pub(crate) connection_limit: Option, - pub(crate) schema: Option, - pub(crate) pg_bouncer: bool, - pub(crate) host: Option, - pub(crate) socket_timeout: Option, - pub(crate) connect_timeout: Option, - pub(crate) pool_timeout: Option, - pub(crate) statement_cache_size: usize, - pub(crate) max_connection_lifetime: Option, - pub(crate) max_idle_connection_lifetime: Option, - pub(crate) application_name: Option, - 
pub(crate) options: Option, - - #[cfg(feature = "postgresql-connector")] - pub(crate) channel_binding: ChannelBinding, - - #[cfg(feature = "postgresql-connector")] - pub(crate) ssl_mode: SslMode, -} - -// A SearchPath connection parameter (Display-impl) for connection initialization. -struct CockroachSearchPath<'a>(&'a str); - -impl Display for CockroachSearchPath<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str(self.0) - } -} - -// A SearchPath connection parameter (Display-impl) for connection initialization. -struct PostgresSearchPath<'a>(&'a str); - -impl Display for PostgresSearchPath<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str("\"")?; - f.write_str(self.0)?; - f.write_str("\"")?; - - Ok(()) - } -} - -// A SetSearchPath statement (Display-impl) for connection initialization. -struct SetSearchPath<'a>(Option<&'a str>); - -impl Display for SetSearchPath<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - if let Some(schema) = self.0 { - f.write_str("SET search_path = \"")?; - f.write_str(schema)?; - f.write_str("\";\n")?; - } - - Ok(()) - } -} - -/// Sorted list of CockroachDB's reserved keywords. -/// Taken from https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#keywords -const RESERVED_KEYWORDS: [&str; 79] = [ - "all", - "analyse", - "analyze", - "and", - "any", - "array", - "as", - "asc", - "asymmetric", - "both", - "case", - "cast", - "check", - "collate", - "column", - "concurrently", - "constraint", - "create", - "current_catalog", - "current_date", - "current_role", - "current_schema", - "current_time", - "current_timestamp", - "current_user", - "default", - "deferrable", - "desc", - "distinct", - "do", - "else", - "end", - "except", - "false", - "fetch", - "for", - "foreign", - "from", - "grant", - "group", - "having", - "in", - "initially", - "intersect", - "into", - "lateral", - "leading", - "limit", - "localtime", - "localtimestamp", - "not", - "null", - "offset", - "on", - "only", - "or", - "order", - "placing", - "primary", - "references", - "returning", - "select", - "session_user", - "some", - "symmetric", - "table", - "then", - "to", - "trailing", - "true", - "union", - "unique", - "user", - "using", - "variadic", - "when", - "where", - "window", - "with", -]; - -/// Sorted list of CockroachDB's reserved type function names. -/// Taken from https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#keywords -const RESERVED_TYPE_FUNCTION_NAMES: [&str; 18] = [ - "authorization", - "collation", - "cross", - "full", - "ilike", - "inner", - "is", - "isnull", - "join", - "left", - "like", - "natural", - "none", - "notnull", - "outer", - "overlaps", - "right", - "similar", -]; - -/// Returns true if a Postgres identifier is considered "safe". -/// -/// In this context, "safe" means that the value of an identifier would be the same quoted and unquoted or that it's not part of reserved keywords. In other words, that it does _not_ need to be quoted. -/// -/// Spec can be found here: https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS -/// or here: https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#rules-for-identifiers -fn is_safe_identifier(ident: &str) -> bool { - if ident.is_empty() { - return false; - } - - // 1. Not equal any SQL keyword unless the keyword is accepted by the element's syntax. For example, name accepts Unreserved or Column Name keywords. 
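// Editor's note: a minimal, self-contained sketch of the flavour-specific
// search-path strategy documented on `set_flavour` above. The `Flavour` enum
// and the helper name are illustrative stand-ins rather than quaint's actual
// API; `is_safe_identifier` is the function defined in this file.
#[derive(Debug, Clone, Copy)]
enum Flavour {
    Postgres,
    Cockroach,
    Unknown,
}

/// Returns either the value to pass as a client connection parameter (no
/// network roundtrip) or the `SET search_path` statement to run after
/// connecting.
fn search_path_strategy(flavour: Flavour, schema: &str) -> (Option<String>, Option<String>) {
    match flavour {
        // Postgres: always quote and pass the schema as a startup parameter.
        Flavour::Postgres => (Some(format!("\"{schema}\"")), None),
        // Cockroach: pass it unquoted only when no quoting would be required.
        Flavour::Cockroach if is_safe_identifier(schema) => (Some(schema.to_string()), None),
        // Otherwise (and for unknown databases) fall back to a post-connect query.
        Flavour::Cockroach | Flavour::Unknown => (None, Some(format!("SET search_path = \"{schema}\";"))),
    }
}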
- if RESERVED_KEYWORDS.binary_search(&ident).is_ok() || RESERVED_TYPE_FUNCTION_NAMES.binary_search(&ident).is_ok() { - return false; - } - - let mut chars = ident.chars(); - - let first = chars.next().unwrap(); - - // 2. SQL identifiers must begin with a letter (a-z, but also letters with diacritical marks and non-Latin letters) or an underscore (_). - if (!first.is_alphabetic() || !first.is_lowercase()) && first != '_' { - return false; - } - - for c in chars { - // 3. Subsequent characters in an identifier can be letters, underscores, digits (0-9), or dollar signs ($). - if (!c.is_alphabetic() || !c.is_lowercase()) && c != '_' && !c.is_ascii_digit() && c != '$' { - return false; - } - } - - true -} diff --git a/quaint/src/error.rs b/quaint/src/error.rs index 73bf5c405c66..f8202b030466 100644 --- a/quaint/src/error.rs +++ b/quaint/src/error.rs @@ -6,7 +6,7 @@ use thiserror::Error; #[cfg(feature = "pooled")] use std::time::Duration; -// pub use crate::connector::mysql::MysqlError; +pub use crate::connector::mysql::MysqlError; pub use crate::connector::postgres::PostgresError; // pub use crate::connector::sqlite::SqliteError; From 060486d74e7525d7cd61c51accfb2604d629ad75 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Fri, 10 Nov 2023 16:05:41 +0100 Subject: [PATCH 004/134] feat(quaint): recover wasm error for mysql --- quaint/src/connector/mysql/wasm/error.rs | 43 ++++++++++++++---------- 1 file changed, 26 insertions(+), 17 deletions(-) diff --git a/quaint/src/connector/mysql/wasm/error.rs b/quaint/src/connector/mysql/wasm/error.rs index c09ec84d7a7b..615f0c69dda4 100644 --- a/quaint/src/connector/mysql/wasm/error.rs +++ b/quaint/src/connector/mysql/wasm/error.rs @@ -1,5 +1,15 @@ use crate::error::{DatabaseConstraint, Error, ErrorKind}; +use thiserror::Error; +#[derive(Debug, Error)] +enum MysqlAsyncError { + #[error("Server error: `{}'", _0)] + Server(#[source] MysqlError), +} + +/// This type represents MySql server error. 
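// Editor's note: the wasm build cannot depend on mysql_async's error types, so
// the server error is mirrored locally and wrapped with thiserror, as in the
// definitions around this note. Below is a tiny self-contained illustration of
// that pattern with local, hypothetical names; the error code and SQL state
// are sample data only.
mod thiserror_pattern_demo {
    use thiserror::Error;

    #[derive(Debug, Error, Clone, PartialEq, Eq)]
    #[error("ERROR {} ({}): {}", state, code, message)]
    pub struct ServerError {
        pub code: u16,
        pub message: String,
        pub state: String,
    }

    #[derive(Debug, Error)]
    pub enum DriverError {
        #[error("Server error: `{0}'")]
        Server(#[source] ServerError),
    }

    #[allow(dead_code)]
    fn demo() {
        let err = DriverError::Server(ServerError {
            code: 1062,
            message: "Duplicate entry".into(),
            state: "23000".into(),
        });

        // thiserror derives `Display` from the `#[error]` attributes and wires
        // up `Error::source()` through `#[source]`.
        assert_eq!(err.to_string(), "Server error: `ERROR 23000 (1062): Duplicate entry'");
    }
}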
+#[derive(Debug, Error, Clone, Eq, PartialEq)] +#[error("ERROR {} ({}): {}", state, code, message)] pub struct MysqlError { pub code: u16, pub message: String, @@ -219,23 +229,22 @@ impl From for Error { builder.set_original_message(error.message); builder.build() } - _ => unimplemented!(), - // _ => { - // let kind = ErrorKind::QueryError( - // my::Error::Server(my::ServerError { - // message: error.message.clone(), - // code, - // state: error.state.clone(), - // }) - // .into(), - // ); - - // let mut builder = Error::builder(kind); - // builder.set_original_code(format!("{code}")); - // builder.set_original_message(error.message); - - // builder.build() - // } + _ => { + let kind = ErrorKind::QueryError( + MysqlAsyncError::Server(MysqlError { + message: error.message.clone(), + code, + state: error.state.clone(), + }) + .into(), + ); + + let mut builder = Error::builder(kind); + builder.set_original_code(format!("{code}")); + builder.set_original_message(error.message); + + builder.build() + } } } } From 5de1dc0c34513191b1e635a5b4a13c9236fe1399 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Fri, 10 Nov 2023 18:19:23 +0100 Subject: [PATCH 005/134] feat(quaint): split mssql connector into native and wasm submodules --- quaint/src/connector.rs | 23 +- quaint/src/connector/mssql.rs | 256 +------------- .../src/connector/mssql/native/conversion.rs | 87 +++++ quaint/src/connector/mssql/native/error.rs | 247 ++++++++++++++ quaint/src/connector/mssql/native/mod.rs | 253 ++++++++++++++ .../{mssql_wasm.rs => mssql/wasm/common.rs} | 45 ++- quaint/src/connector/mssql/wasm/mod.rs | 5 + quaint/src/connector/mysql_wasm.rs | 318 ------------------ 8 files changed, 634 insertions(+), 600 deletions(-) create mode 100644 quaint/src/connector/mssql/native/conversion.rs create mode 100644 quaint/src/connector/mssql/native/error.rs create mode 100644 quaint/src/connector/mssql/native/mod.rs rename quaint/src/connector/{mssql_wasm.rs => mssql/wasm/common.rs} (91%) create mode 100644 quaint/src/connector/mssql/wasm/mod.rs delete mode 100644 quaint/src/connector/mysql_wasm.rs diff --git a/quaint/src/connector.rs b/quaint/src/connector.rs index d0e4d7e57bdc..32f9e6186890 100644 --- a/quaint/src/connector.rs +++ b/quaint/src/connector.rs @@ -23,10 +23,10 @@ mod timeout; mod transaction; mod type_identifier; -#[cfg(feature = "mssql-connector")] -pub(crate) mod mssql; -#[cfg(feature = "mssql")] -pub(crate) mod mssql_wasm; +// #[cfg(feature = "mssql-connector")] +// pub(crate) mod mssql; +// #[cfg(feature = "mssql")] +// pub(crate) mod mssql_wasm; // #[cfg(feature = "mysql-connector")] // pub(crate) mod mysql; // #[cfg(feature = "mysql")] @@ -48,10 +48,10 @@ pub(crate) mod sqlite_wasm; // pub use self::postgres::*; // #[cfg(feature = "postgresql")] // pub use self::postgres_wasm::*; -#[cfg(feature = "mssql-connector")] -pub use mssql::*; -#[cfg(feature = "mssql")] -pub use mssql_wasm::*; +// #[cfg(feature = "mssql-connector")] +// pub use mssql::*; +// #[cfg(feature = "mssql")] +// pub use mssql_wasm::*; #[cfg(feature = "sqlite-connector")] pub use sqlite::*; #[cfg(feature = "sqlite")] @@ -84,3 +84,10 @@ pub(crate) mod mysql; pub use mysql::native::*; #[cfg(feature = "mysql")] pub use mysql::wasm::common::*; + +#[cfg(feature = "mssql")] +pub(crate) mod mssql; +#[cfg(feature = "mssql-connector")] +pub use mssql::native::*; +#[cfg(feature = "mssql")] +pub use mssql::wasm::common::*; diff --git a/quaint/src/connector/mssql.rs b/quaint/src/connector/mssql.rs index 16c31551768c..ea681bd08d18 100644 --- 
a/quaint/src/connector/mssql.rs +++ b/quaint/src/connector/mssql.rs @@ -1,253 +1,7 @@ -mod conversion; -mod error; +pub use wasm::common::MssqlUrl; -pub(crate) use super::mssql_wasm::MssqlUrl; -use super::{IsolationLevel, Transaction, TransactionOptions}; -use crate::{ - ast::{Query, Value}, - connector::{metrics, queryable::*, DefaultTransaction, ResultSet}, - visitor::{self, Visitor}, -}; -use async_trait::async_trait; -use futures::lock::Mutex; -use std::{ - convert::TryFrom, - fmt, - future::Future, - sync::atomic::{AtomicBool, Ordering}, - time::Duration, -}; -use tiberius::*; -use tokio::net::TcpStream; -use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt}; +#[cfg(feature = "mssql")] +pub(crate) mod wasm; -/// The underlying SQL Server driver. Only available with the `expose-drivers` Cargo feature. -#[cfg(feature = "expose-drivers")] -pub use tiberius; - -static SQL_SERVER_DEFAULT_ISOLATION: IsolationLevel = IsolationLevel::ReadCommitted; - -#[async_trait] -impl TransactionCapable for Mssql { - async fn start_transaction<'a>( - &'a self, - isolation: Option, - ) -> crate::Result> { - // Isolation levels in SQL Server are set on the connection and live until they're changed. - // Always explicitly setting the isolation level each time a tx is started (either to the given value - // or by using the default/connection string value) prevents transactions started on connections from - // the pool to have unexpected isolation levels set. - let isolation = isolation - .or(self.url.query_params.transaction_isolation_level) - .or(Some(SQL_SERVER_DEFAULT_ISOLATION)); - - let opts = TransactionOptions::new(isolation, self.requires_isolation_first()); - - Ok(Box::new( - DefaultTransaction::new(self, self.begin_statement(), opts).await?, - )) - } -} - -/// A connector interface for the SQL Server database. -#[derive(Debug)] -pub struct Mssql { - client: Mutex>>, - url: MssqlUrl, - socket_timeout: Option, - is_healthy: AtomicBool, -} - -impl Mssql { - /// Creates a new connection to SQL Server. - pub async fn new(url: MssqlUrl) -> crate::Result { - let config = Config::from_jdbc_string(&url.connection_string)?; - let tcp = TcpStream::connect_named(&config).await?; - let socket_timeout = url.socket_timeout(); - - let connecting = async { - match Client::connect(config, tcp.compat_write()).await { - Ok(client) => Ok(client), - Err(tiberius::error::Error::Routing { host, port }) => { - let mut config = Config::from_jdbc_string(&url.connection_string)?; - config.host(host); - config.port(port); - - let tcp = TcpStream::connect_named(&config).await?; - Client::connect(config, tcp.compat_write()).await - } - Err(e) => Err(e), - } - }; - - let client = super::timeout::connect(url.connect_timeout(), connecting).await?; - - let this = Self { - client: Mutex::new(client), - url, - socket_timeout, - is_healthy: AtomicBool::new(true), - }; - - if let Some(isolation) = this.url.transaction_isolation_level() { - this.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation}")) - .await?; - }; - - Ok(this) - } - - /// The underlying Tiberius client. Only available with the `expose-drivers` Cargo feature. - /// This is a lower level API when you need to get into database specific features. 
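// Editor's note: a hedged sketch of what the `expose-drivers` escape hatch
// mentioned just below enables for callers — borrowing the raw tiberius client
// for driver-specific work. The function name is illustrative; the calls
// mirror the ones used by `raw_cmd` in this connector.
#[cfg(feature = "expose-drivers")]
async fn run_driver_specific_query(conn: &Mssql) -> quaint::Result<()> {
    // `client()` hands out the tiberius client guarded by a futures-aware mutex.
    let mut client = conn.client().lock().await;

    // From here on, any tiberius API is available directly.
    client.simple_query("SELECT 1").await?.into_results().await?;

    Ok(())
}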
- #[cfg(feature = "expose-drivers")] - pub fn client(&self) -> &Mutex>> { - &self.client - } - - async fn perform_io(&self, fut: F) -> crate::Result - where - F: Future>, - { - match super::timeout::socket(self.socket_timeout, fut).await { - Err(e) if e.is_closed() => { - self.is_healthy.store(false, Ordering::SeqCst); - Err(e) - } - res => res, - } - } -} - -#[async_trait] -impl Queryable for Mssql { - async fn query(&self, q: Query<'_>) -> crate::Result { - let (sql, params) = visitor::Mssql::build(q)?; - self.query_raw(&sql, ¶ms[..]).await - } - - async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - metrics::query("mssql.query_raw", sql, params, move || async move { - let mut client = self.client.lock().await; - - let mut query = tiberius::Query::new(sql); - - for param in params { - query.bind(param); - } - - let mut results = self.perform_io(query.query(&mut client)).await?.into_results().await?; - - match results.pop() { - Some(rows) => { - let mut columns_set = false; - let mut columns = Vec::new(); - let mut result_rows = Vec::with_capacity(rows.len()); - - for row in rows.into_iter() { - if !columns_set { - columns = row.columns().iter().map(|c| c.name().to_string()).collect(); - columns_set = true; - } - - let mut values: Vec> = Vec::with_capacity(row.len()); - - for val in row.into_iter() { - values.push(Value::try_from(val)?); - } - - result_rows.push(values); - } - - Ok(ResultSet::new(columns, result_rows)) - } - None => Ok(ResultSet::new(Vec::new(), Vec::new())), - } - }) - .await - } - - async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - self.query_raw(sql, params).await - } - - async fn execute(&self, q: Query<'_>) -> crate::Result { - let (sql, params) = visitor::Mssql::build(q)?; - self.execute_raw(&sql, ¶ms[..]).await - } - - async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - metrics::query("mssql.execute_raw", sql, params, move || async move { - let mut query = tiberius::Query::new(sql); - - for param in params { - query.bind(param); - } - - let mut client = self.client.lock().await; - let changes = self.perform_io(query.execute(&mut client)).await?.total(); - - Ok(changes) - }) - .await - } - - async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - self.execute_raw(sql, params).await - } - - async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> { - metrics::query("mssql.raw_cmd", cmd, &[], move || async move { - let mut client = self.client.lock().await; - self.perform_io(client.simple_query(cmd)).await?.into_results().await?; - Ok(()) - }) - .await - } - - async fn version(&self) -> crate::Result> { - let query = r#"SELECT @@VERSION AS version"#; - let rows = self.query_raw(query, &[]).await?; - - let version_string = rows - .get(0) - .and_then(|row| row.get("version").and_then(|version| version.to_string())); - - Ok(version_string) - } - - fn is_healthy(&self) -> bool { - self.is_healthy.load(Ordering::SeqCst) - } - - async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { - self.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation_level}")) - .await?; - - Ok(()) - } - - fn begin_statement(&self) -> &'static str { - "BEGIN TRAN" - } - - fn requires_isolation_first(&self) -> bool { - true - } -} - -#[cfg(test)] -mod tests { - use crate::tests::test_api::mssql::CONN_STR; - use crate::{error::*, single::Quaint}; - - #[tokio::test] - async fn should_map_wrong_credentials_error() { - let url = 
CONN_STR.replace("user=SA", "user=WRONG"); - - let res = Quaint::new(url.as_str()).await; - assert!(res.is_err()); - - let err = res.unwrap_err(); - assert!(matches!(err.kind(), ErrorKind::AuthenticationFailed { user } if user == &Name::available("WRONG"))); - } -} +#[cfg(feature = "mssql-connector")] +pub(crate) mod native; diff --git a/quaint/src/connector/mssql/native/conversion.rs b/quaint/src/connector/mssql/native/conversion.rs new file mode 100644 index 000000000000..870654ad5de3 --- /dev/null +++ b/quaint/src/connector/mssql/native/conversion.rs @@ -0,0 +1,87 @@ +use crate::ast::{Value, ValueType}; + +use bigdecimal::BigDecimal; +use std::{borrow::Cow, convert::TryFrom}; + +use tiberius::ToSql; +use tiberius::{ColumnData, FromSql, IntoSql}; + +impl<'a> IntoSql<'a> for &'a Value<'a> { + fn into_sql(self) -> ColumnData<'a> { + match &self.typed { + ValueType::Int32(val) => val.into_sql(), + ValueType::Int64(val) => val.into_sql(), + ValueType::Float(val) => val.into_sql(), + ValueType::Double(val) => val.into_sql(), + ValueType::Text(val) => val.as_deref().into_sql(), + ValueType::Bytes(val) => val.as_deref().into_sql(), + ValueType::Enum(val, _) => val.as_deref().into_sql(), + ValueType::Boolean(val) => val.into_sql(), + ValueType::Char(val) => val.as_ref().map(|val| format!("{val}")).into_sql(), + ValueType::Xml(val) => val.as_deref().into_sql(), + ValueType::Array(_) | ValueType::EnumArray(_, _) => panic!("Arrays are not supported on SQL Server."), + ValueType::Numeric(val) => (*val).to_sql(), + ValueType::Json(val) => val.as_ref().map(|val| serde_json::to_string(&val).unwrap()).into_sql(), + ValueType::Uuid(val) => val.into_sql(), + ValueType::DateTime(val) => val.into_sql(), + ValueType::Date(val) => val.into_sql(), + ValueType::Time(val) => val.into_sql(), + } + } +} + +impl TryFrom> for Value<'static> { + type Error = crate::error::Error; + + fn try_from(cd: ColumnData<'static>) -> crate::Result { + let res = match cd { + ColumnData::U8(num) => ValueType::Int32(num.map(i32::from)), + ColumnData::I16(num) => ValueType::Int32(num.map(i32::from)), + ColumnData::I32(num) => ValueType::Int32(num.map(i32::from)), + ColumnData::I64(num) => ValueType::Int64(num.map(i64::from)), + ColumnData::F32(num) => ValueType::Float(num), + ColumnData::F64(num) => ValueType::Double(num), + ColumnData::Bit(b) => ValueType::Boolean(b), + ColumnData::String(s) => ValueType::Text(s), + ColumnData::Guid(uuid) => ValueType::Uuid(uuid), + ColumnData::Binary(bytes) => ValueType::Bytes(bytes), + numeric @ ColumnData::Numeric(_) => ValueType::Numeric(BigDecimal::from_sql(&numeric)?), + dt @ ColumnData::DateTime(_) => { + use tiberius::time::chrono::{DateTime, NaiveDateTime, Utc}; + + let dt = NaiveDateTime::from_sql(&dt)?.map(|dt| DateTime::::from_utc(dt, Utc)); + ValueType::DateTime(dt) + } + dt @ ColumnData::SmallDateTime(_) => { + use tiberius::time::chrono::{DateTime, NaiveDateTime, Utc}; + + let dt = NaiveDateTime::from_sql(&dt)?.map(|dt| DateTime::::from_utc(dt, Utc)); + ValueType::DateTime(dt) + } + dt @ ColumnData::Time(_) => { + use tiberius::time::chrono::NaiveTime; + + ValueType::Time(NaiveTime::from_sql(&dt)?) + } + dt @ ColumnData::Date(_) => { + use tiberius::time::chrono::NaiveDate; + ValueType::Date(NaiveDate::from_sql(&dt)?) 
+ } + dt @ ColumnData::DateTime2(_) => { + use tiberius::time::chrono::{DateTime, NaiveDateTime, Utc}; + + let dt = NaiveDateTime::from_sql(&dt)?.map(|dt| DateTime::::from_utc(dt, Utc)); + + ValueType::DateTime(dt) + } + dt @ ColumnData::DateTimeOffset(_) => { + use tiberius::time::chrono::{DateTime, Utc}; + + ValueType::DateTime(DateTime::::from_sql(&dt)?) + } + ColumnData::Xml(cow) => ValueType::Xml(cow.map(|xml_data| Cow::Owned(xml_data.into_owned().into_string()))), + }; + + Ok(Value::from(res)) + } +} diff --git a/quaint/src/connector/mssql/native/error.rs b/quaint/src/connector/mssql/native/error.rs new file mode 100644 index 000000000000..f9b6f5e95ab6 --- /dev/null +++ b/quaint/src/connector/mssql/native/error.rs @@ -0,0 +1,247 @@ +use crate::error::{DatabaseConstraint, Error, ErrorKind}; +use tiberius::error::IoErrorKind; + +impl From for Error { + fn from(e: tiberius::error::Error) -> Error { + match e { + tiberius::error::Error::Io { + kind: IoErrorKind::UnexpectedEof, + message, + } => { + let mut builder = Error::builder(ErrorKind::ConnectionClosed); + builder.set_original_message(message); + builder.build() + } + e @ tiberius::error::Error::Io { .. } => Error::builder(ErrorKind::ConnectionError(e.into())).build(), + tiberius::error::Error::Tls(message) => { + let message = format!( + "The TLS settings didn't allow the connection to be established. Please review your connection string. (error: {message})" + ); + + Error::builder(ErrorKind::TlsError { message }).build() + } + tiberius::error::Error::Server(e) if [3902u32, 3903u32, 3971u32].iter().any(|code| e.code() == *code) => { + let kind = ErrorKind::TransactionAlreadyClosed(e.message().to_string()); + + let mut builder = Error::builder(kind); + + builder.set_original_code(format!("{}", e.code())); + builder.set_original_message(e.message().to_string()); + + builder.build() + } + tiberius::error::Error::Server(e) if e.code() == 8169 => { + let kind = ErrorKind::conversion(e.message().to_string()); + + let mut builder = Error::builder(kind); + + builder.set_original_code(format!("{}", e.code())); + builder.set_original_message(e.message().to_string()); + + builder.build() + } + tiberius::error::Error::Server(e) if e.code() == 18456 => { + let user = e.message().split('\'').nth(1).into(); + let kind = ErrorKind::AuthenticationFailed { user }; + + let mut builder = Error::builder(kind); + + builder.set_original_code(format!("{}", e.code())); + builder.set_original_message(e.message().to_string()); + + builder.build() + } + tiberius::error::Error::Server(e) if e.code() == 4060 => { + let db_name = e.message().split('"').nth(1).into(); + let kind = ErrorKind::DatabaseDoesNotExist { db_name }; + let mut builder = Error::builder(kind); + + builder.set_original_code(format!("{}", e.code())); + builder.set_original_message(e.message().to_string()); + + builder.build() + } + tiberius::error::Error::Server(e) if e.code() == 515 => { + let constraint = e + .message() + .split_whitespace() + .nth(7) + .and_then(|s| s.split('\'').nth(1)) + .map(|s| DatabaseConstraint::fields(Some(s))) + .unwrap_or(DatabaseConstraint::CannotParse); + + let kind = ErrorKind::NullConstraintViolation { constraint }; + let mut builder = Error::builder(kind); + + builder.set_original_code(format!("{}", e.code())); + builder.set_original_message(e.message().to_string()); + + builder.build() + } + tiberius::error::Error::Server(e) if e.code() == 1801 => { + let db_name = e.message().split('\'').nth(1).into(); + let kind = ErrorKind::DatabaseAlreadyExists { 
db_name }; + + let mut builder = Error::builder(kind); + + builder.set_original_code(format!("{}", e.code())); + builder.set_original_message(e.message().to_string()); + + builder.build() + } + tiberius::error::Error::Server(e) if e.code() == 2627 => { + let constraint = e + .message() + .split(". ") + .nth(1) + .and_then(|s| s.split(' ').last()) + .and_then(|s| s.split('\'').nth(1)) + .map(ToString::to_string) + .map(DatabaseConstraint::Index) + .unwrap_or(DatabaseConstraint::CannotParse); + + let kind = ErrorKind::UniqueConstraintViolation { constraint }; + let mut builder = Error::builder(kind); + + builder.set_original_code(format!("{}", e.code())); + builder.set_original_message(e.message().to_string()); + + builder.build() + } + tiberius::error::Error::Server(e) if e.code() == 547 => { + let constraint = e + .message() + .split('.') + .next() + .and_then(|s| s.split_whitespace().last()) + .and_then(|s| s.split('\"').nth(1)) + .map(ToString::to_string) + .map(DatabaseConstraint::Index) + .unwrap_or(DatabaseConstraint::CannotParse); + + let kind = ErrorKind::ForeignKeyConstraintViolation { constraint }; + let mut builder = Error::builder(kind); + + builder.set_original_code(format!("{}", e.code())); + builder.set_original_message(e.message().to_string()); + + builder.build() + } + tiberius::error::Error::Server(e) if e.code() == 1505 => { + let constraint = e + .message() + .split('\'') + .nth(3) + .map(ToString::to_string) + .map(DatabaseConstraint::Index) + .unwrap_or(DatabaseConstraint::CannotParse); + + let kind = ErrorKind::UniqueConstraintViolation { constraint }; + let mut builder = Error::builder(kind); + + builder.set_original_code(format!("{}", e.code())); + builder.set_original_message(e.message().to_string()); + + builder.build() + } + tiberius::error::Error::Server(e) if e.code() == 2601 => { + let constraint = e + .message() + .split_whitespace() + .nth(11) + .and_then(|s| s.split('\'').nth(1)) + .map(ToString::to_string) + .map(DatabaseConstraint::Index) + .unwrap_or(DatabaseConstraint::CannotParse); + + let kind = ErrorKind::UniqueConstraintViolation { constraint }; + let mut builder = Error::builder(kind); + + builder.set_original_code(format!("{}", e.code())); + builder.set_original_message(e.message().to_string()); + + builder.build() + } + tiberius::error::Error::Server(e) if e.code() == 1801 => { + let db_name = e.message().split('\'').nth(1).into(); + let kind = ErrorKind::DatabaseAlreadyExists { db_name }; + + let mut builder = Error::builder(kind); + + builder.set_original_code(format!("{}", e.code())); + builder.set_original_message(e.message().to_string()); + + builder.build() + } + tiberius::error::Error::Server(e) if e.code() == 2628 => { + let column = e.message().split('\'').nth(3).into(); + let kind = ErrorKind::LengthMismatch { column }; + + let mut builder = Error::builder(kind); + + builder.set_original_code(format!("{}", e.code())); + builder.set_original_message(e.message().to_string()); + + builder.build() + } + tiberius::error::Error::Server(e) if e.code() == 208 => { + let table = e + .message() + .split_whitespace() + .nth(3) + .and_then(|s| s.split('\'').nth(1)) + .into(); + + let kind = ErrorKind::TableDoesNotExist { table }; + let mut builder = Error::builder(kind); + + builder.set_original_code(format!("{}", e.code())); + builder.set_original_message(e.message().to_string()); + + builder.build() + } + tiberius::error::Error::Server(e) if e.code() == 207 => { + let column = e + .message() + .split_whitespace() + .nth(3) + .and_then(|s| 
s.split('\'').nth(1)) + .into(); + + let kind = ErrorKind::ColumnNotFound { column }; + let mut builder = Error::builder(kind); + + builder.set_original_code(format!("{}", e.code())); + builder.set_original_message(e.message().to_string()); + + builder.build() + } + tiberius::error::Error::Server(e) if e.code() == 1205 => { + let mut builder = Error::builder(ErrorKind::TransactionWriteConflict); + + builder.set_original_code(format!("{}", e.code())); + builder.set_original_message(e.message().to_string()); + + builder.build() + } + tiberius::error::Error::Server(e) if e.code() == 3903 => { + let mut builder = Error::builder(ErrorKind::RollbackWithoutBegin); + + builder.set_original_code(format!("{}", e.code())); + builder.set_original_message(e.message().to_string()); + + builder.build() + } + tiberius::error::Error::Server(e) => { + let kind = ErrorKind::QueryError(e.clone().into()); + + let mut builder = Error::builder(kind); + builder.set_original_code(format!("{}", e.code())); + builder.set_original_message(e.message().to_string()); + + builder.build() + } + e => Error::builder(ErrorKind::QueryError(e.into())).build(), + } + } +} diff --git a/quaint/src/connector/mssql/native/mod.rs b/quaint/src/connector/mssql/native/mod.rs new file mode 100644 index 000000000000..a1ea3bd5394d --- /dev/null +++ b/quaint/src/connector/mssql/native/mod.rs @@ -0,0 +1,253 @@ +mod conversion; +mod error; + +pub(crate) use crate::connector::mssql::wasm::common::MssqlUrl; +use crate::connector::{timeout, IsolationLevel, Transaction, TransactionOptions}; + +use crate::{ + ast::{Query, Value}, + connector::{metrics, queryable::*, DefaultTransaction, ResultSet}, + visitor::{self, Visitor}, +}; +use async_trait::async_trait; +use futures::lock::Mutex; +use std::{ + convert::TryFrom, + future::Future, + sync::atomic::{AtomicBool, Ordering}, + time::Duration, +}; +use tiberius::*; +use tokio::net::TcpStream; +use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt}; + +/// The underlying SQL Server driver. Only available with the `expose-drivers` Cargo feature. +#[cfg(feature = "expose-drivers")] +pub use tiberius; + +static SQL_SERVER_DEFAULT_ISOLATION: IsolationLevel = IsolationLevel::ReadCommitted; + +#[async_trait] +impl TransactionCapable for Mssql { + async fn start_transaction<'a>( + &'a self, + isolation: Option, + ) -> crate::Result> { + // Isolation levels in SQL Server are set on the connection and live until they're changed. + // Always explicitly setting the isolation level each time a tx is started (either to the given value + // or by using the default/connection string value) prevents transactions started on connections from + // the pool to have unexpected isolation levels set. + let isolation = isolation + .or(self.url.query_params.transaction_isolation_level) + .or(Some(SQL_SERVER_DEFAULT_ISOLATION)); + + let opts = TransactionOptions::new(isolation, self.requires_isolation_first()); + + Ok(Box::new( + DefaultTransaction::new(self, self.begin_statement(), opts).await?, + )) + } +} + +/// A connector interface for the SQL Server database. +#[derive(Debug)] +pub struct Mssql { + client: Mutex>>, + url: MssqlUrl, + socket_timeout: Option, + is_healthy: AtomicBool, +} + +impl Mssql { + /// Creates a new connection to SQL Server. 
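// Editor's note: `new` below funnels the connect future through
// `timeout::connect`, which applies the configured connect timeout only when
// one is set. A stand-alone sketch of that pattern (helper name and error type
// are illustrative, not quaint's internals):
use std::{future::Future, time::Duration};

async fn with_optional_timeout<T, F>(timeout: Option<Duration>, fut: F) -> Result<T, &'static str>
where
    F: Future<Output = T>,
{
    match timeout {
        // A configured deadline turns a hung connection attempt into an error.
        Some(duration) => tokio::time::timeout(duration, fut)
            .await
            .map_err(|_| "connection attempt timed out"),
        // No deadline configured: just await the future.
        None => Ok(fut.await),
    }
}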
+ pub async fn new(url: MssqlUrl) -> crate::Result { + let config = Config::from_jdbc_string(&url.connection_string)?; + let tcp = TcpStream::connect_named(&config).await?; + let socket_timeout = url.socket_timeout(); + + let connecting = async { + match Client::connect(config, tcp.compat_write()).await { + Ok(client) => Ok(client), + Err(tiberius::error::Error::Routing { host, port }) => { + let mut config = Config::from_jdbc_string(&url.connection_string)?; + config.host(host); + config.port(port); + + let tcp = TcpStream::connect_named(&config).await?; + Client::connect(config, tcp.compat_write()).await + } + Err(e) => Err(e), + } + }; + + let client = timeout::connect(url.connect_timeout(), connecting).await?; + + let this = Self { + client: Mutex::new(client), + url, + socket_timeout, + is_healthy: AtomicBool::new(true), + }; + + if let Some(isolation) = this.url.transaction_isolation_level() { + this.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation}")) + .await?; + }; + + Ok(this) + } + + /// The underlying Tiberius client. Only available with the `expose-drivers` Cargo feature. + /// This is a lower level API when you need to get into database specific features. + #[cfg(feature = "expose-drivers")] + pub fn client(&self) -> &Mutex>> { + &self.client + } + + async fn perform_io(&self, fut: F) -> crate::Result + where + F: Future>, + { + match timeout::socket(self.socket_timeout, fut).await { + Err(e) if e.is_closed() => { + self.is_healthy.store(false, Ordering::SeqCst); + Err(e) + } + res => res, + } + } +} + +#[async_trait] +impl Queryable for Mssql { + async fn query(&self, q: Query<'_>) -> crate::Result { + let (sql, params) = visitor::Mssql::build(q)?; + self.query_raw(&sql, ¶ms[..]).await + } + + async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + metrics::query("mssql.query_raw", sql, params, move || async move { + let mut client = self.client.lock().await; + + let mut query = tiberius::Query::new(sql); + + for param in params { + query.bind(param); + } + + let mut results = self.perform_io(query.query(&mut client)).await?.into_results().await?; + + match results.pop() { + Some(rows) => { + let mut columns_set = false; + let mut columns = Vec::new(); + let mut result_rows = Vec::with_capacity(rows.len()); + + for row in rows.into_iter() { + if !columns_set { + columns = row.columns().iter().map(|c| c.name().to_string()).collect(); + columns_set = true; + } + + let mut values: Vec> = Vec::with_capacity(row.len()); + + for val in row.into_iter() { + values.push(Value::try_from(val)?); + } + + result_rows.push(values); + } + + Ok(ResultSet::new(columns, result_rows)) + } + None => Ok(ResultSet::new(Vec::new(), Vec::new())), + } + }) + .await + } + + async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + self.query_raw(sql, params).await + } + + async fn execute(&self, q: Query<'_>) -> crate::Result { + let (sql, params) = visitor::Mssql::build(q)?; + self.execute_raw(&sql, ¶ms[..]).await + } + + async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + metrics::query("mssql.execute_raw", sql, params, move || async move { + let mut query = tiberius::Query::new(sql); + + for param in params { + query.bind(param); + } + + let mut client = self.client.lock().await; + let changes = self.perform_io(query.execute(&mut client)).await?.total(); + + Ok(changes) + }) + .await + } + + async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + self.execute_raw(sql, 
params).await + } + + async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> { + metrics::query("mssql.raw_cmd", cmd, &[], move || async move { + let mut client = self.client.lock().await; + self.perform_io(client.simple_query(cmd)).await?.into_results().await?; + Ok(()) + }) + .await + } + + async fn version(&self) -> crate::Result> { + let query = r#"SELECT @@VERSION AS version"#; + let rows = self.query_raw(query, &[]).await?; + + let version_string = rows + .get(0) + .and_then(|row| row.get("version").and_then(|version| version.to_string())); + + Ok(version_string) + } + + fn is_healthy(&self) -> bool { + self.is_healthy.load(Ordering::SeqCst) + } + + async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { + self.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation_level}")) + .await?; + + Ok(()) + } + + fn begin_statement(&self) -> &'static str { + "BEGIN TRAN" + } + + fn requires_isolation_first(&self) -> bool { + true + } +} + +#[cfg(test)] +mod tests { + use crate::tests::test_api::mssql::CONN_STR; + use crate::{error::*, single::Quaint}; + + #[tokio::test] + async fn should_map_wrong_credentials_error() { + let url = CONN_STR.replace("user=SA", "user=WRONG"); + + let res = Quaint::new(url.as_str()).await; + assert!(res.is_err()); + + let err = res.unwrap_err(); + assert!(matches!(err.kind(), ErrorKind::AuthenticationFailed { user } if user == &Name::available("WRONG"))); + } +} diff --git a/quaint/src/connector/mssql_wasm.rs b/quaint/src/connector/mssql/wasm/common.rs similarity index 91% rename from quaint/src/connector/mssql_wasm.rs rename to quaint/src/connector/mssql/wasm/common.rs index d9f7dc27865b..5b6ee881d3e9 100644 --- a/quaint/src/connector/mssql_wasm.rs +++ b/quaint/src/connector/mssql/wasm/common.rs @@ -1,8 +1,7 @@ -#![cfg_attr(target_arch = "wasm32", allow(dead_code))] - -use super::IsolationLevel; - -use crate::error::{Error, ErrorKind}; +use crate::{ + connector::IsolationLevel, + error::{Error, ErrorKind}, +}; use connection_string::JdbcString; use std::{fmt, str::FromStr, time::Duration}; @@ -10,8 +9,8 @@ use std::{fmt, str::FromStr, time::Duration}; /// including default values. #[derive(Debug, Clone)] pub struct MssqlUrl { - pub(super) connection_string: String, - pub(super) query_params: MssqlQueryParams, + pub(crate) connection_string: String, + pub(crate) query_params: MssqlQueryParams, } /// TLS mode when connecting to SQL Server. 
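// Editor's note: a hedged end-to-end sketch of how a connection string of the
// kind `MssqlUrl` parses here is used through the single-connection `Quaint`
// handle, mirroring the test in the native module above. Host, credentials and
// the exact prelude imports are placeholders/assumptions.
use quaint::{prelude::*, single::Quaint};

async fn example() -> quaint::Result<()> {
    let url = "sqlserver://localhost:1433;database=master;user=SA;password=<PASSWORD>;trustServerCertificate=true";
    let conn = Quaint::new(url).await?;

    let rows = conn.query_raw("SELECT @@VERSION AS version", &[]).await?;
    let version = rows
        .get(0)
        .and_then(|row| row.get("version").and_then(|version| version.to_string()));

    println!("{version:?}");
    Ok(())
}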
@@ -51,22 +50,22 @@ impl FromStr for EncryptMode { #[derive(Debug, Clone)] pub(crate) struct MssqlQueryParams { - pub(super) encrypt: EncryptMode, - pub(super) port: Option, - pub(super) host: Option, - pub(super) user: Option, - pub(super) password: Option, - pub(super) database: String, - pub(super) schema: String, - pub(super) trust_server_certificate: bool, - pub(super) trust_server_certificate_ca: Option, - pub(super) connection_limit: Option, - pub(super) socket_timeout: Option, - pub(super) connect_timeout: Option, - pub(super) pool_timeout: Option, - pub(super) transaction_isolation_level: Option, - pub(super) max_connection_lifetime: Option, - pub(super) max_idle_connection_lifetime: Option, + pub(crate) encrypt: EncryptMode, + pub(crate) port: Option, + pub(crate) host: Option, + pub(crate) user: Option, + pub(crate) password: Option, + pub(crate) database: String, + pub(crate) schema: String, + pub(crate) trust_server_certificate: bool, + pub(crate) trust_server_certificate_ca: Option, + pub(crate) connection_limit: Option, + pub(crate) socket_timeout: Option, + pub(crate) connect_timeout: Option, + pub(crate) pool_timeout: Option, + pub(crate) transaction_isolation_level: Option, + pub(crate) max_connection_lifetime: Option, + pub(crate) max_idle_connection_lifetime: Option, } impl MssqlUrl { diff --git a/quaint/src/connector/mssql/wasm/mod.rs b/quaint/src/connector/mssql/wasm/mod.rs new file mode 100644 index 000000000000..69f1f46f7d21 --- /dev/null +++ b/quaint/src/connector/mssql/wasm/mod.rs @@ -0,0 +1,5 @@ +///! Wasm-compatible definitions for the MSSQL connector. +/// This module is only available with the `mssql` feature. +pub(crate) mod common; + +pub use common::MssqlUrl; diff --git a/quaint/src/connector/mysql_wasm.rs b/quaint/src/connector/mysql_wasm.rs deleted file mode 100644 index 24cd525fea33..000000000000 --- a/quaint/src/connector/mysql_wasm.rs +++ /dev/null @@ -1,318 +0,0 @@ -#![cfg_attr(target_arch = "wasm32", allow(dead_code))] - -use crate::error::{Error, ErrorKind}; -use percent_encoding::percent_decode; -use std::{ - borrow::Cow, - path::{Path, PathBuf}, - time::Duration, -}; -use url::{Host, Url}; - -/// Wraps a connection url and exposes the parsing logic used by quaint, including default values. -#[derive(Debug, Clone)] -pub struct MysqlUrl { - url: Url, - pub(super) query_params: MysqlUrlQueryParams, -} - -impl MysqlUrl { - /// Parse `Url` to `MysqlUrl`. Returns error for mistyped connection - /// parameters. - pub fn new(url: Url) -> Result { - let query_params = Self::parse_query_params(&url)?; - - Ok(Self { url, query_params }) - } - - /// The bare `Url` to the database. - pub fn url(&self) -> &Url { - &self.url - } - - /// The percent-decoded database username. - pub fn username(&self) -> Cow { - match percent_decode(self.url.username().as_bytes()).decode_utf8() { - Ok(username) => username, - Err(_) => { - tracing::warn!("Couldn't decode username to UTF-8, using the non-decoded version."); - - self.url.username().into() - } - } - } - - /// The percent-decoded database password. - pub fn password(&self) -> Option> { - match self - .url - .password() - .and_then(|pw| percent_decode(pw.as_bytes()).decode_utf8().ok()) - { - Some(password) => Some(password), - None => self.url.password().map(|s| s.into()), - } - } - - /// Name of the database connected. Defaults to `mysql`. - pub fn dbname(&self) -> &str { - match self.url.path_segments() { - Some(mut segments) => segments.next().unwrap_or("mysql"), - None => "mysql", - } - } - - /// The database host. 
If `socket` and `host` are not set, defaults to `localhost`. - pub fn host(&self) -> &str { - match (self.url.host(), self.url.host_str()) { - (Some(Host::Ipv6(_)), Some(host)) => { - // The `url` crate may return an IPv6 address in brackets, which must be stripped. - if host.starts_with('[') && host.ends_with(']') { - &host[1..host.len() - 1] - } else { - host - } - } - (_, Some(host)) => host, - _ => "localhost", - } - } - - /// If set, connected to the database through a Unix socket. - pub fn socket(&self) -> &Option { - &self.query_params.socket - } - - /// The database port, defaults to `3306`. - pub fn port(&self) -> u16 { - self.url.port().unwrap_or(3306) - } - - /// The connection timeout. - pub fn connect_timeout(&self) -> Option { - self.query_params.connect_timeout - } - - /// The pool check_out timeout - pub fn pool_timeout(&self) -> Option { - self.query_params.pool_timeout - } - - /// The socket timeout - pub fn socket_timeout(&self) -> Option { - self.query_params.socket_timeout - } - - /// Prefer socket connection - pub fn prefer_socket(&self) -> Option { - self.query_params.prefer_socket - } - - /// The maximum connection lifetime - pub fn max_connection_lifetime(&self) -> Option { - self.query_params.max_connection_lifetime - } - - /// The maximum idle connection lifetime - pub fn max_idle_connection_lifetime(&self) -> Option { - self.query_params.max_idle_connection_lifetime - } - - pub(super) fn statement_cache_size(&self) -> usize { - self.query_params.statement_cache_size - } - - fn parse_query_params(url: &Url) -> Result { - #[cfg(feature = "mysql-connector")] - let mut ssl_opts = { - let mut ssl_opts = mysql_async::SslOpts::default(); - ssl_opts = ssl_opts.with_danger_accept_invalid_certs(true); - ssl_opts - }; - - let mut connection_limit = None; - let mut use_ssl = false; - let mut socket = None; - let mut socket_timeout = None; - let mut connect_timeout = Some(Duration::from_secs(5)); - let mut pool_timeout = Some(Duration::from_secs(10)); - let mut max_connection_lifetime = None; - let mut max_idle_connection_lifetime = Some(Duration::from_secs(300)); - let mut prefer_socket = None; - let mut statement_cache_size = 100; - let mut identity: Option<(Option, Option)> = None; - - for (k, v) in url.query_pairs() { - match k.as_ref() { - "connection_limit" => { - let as_int: usize = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - connection_limit = Some(as_int); - } - "statement_cache_size" => { - statement_cache_size = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - } - "sslcert" => { - use_ssl = true; - - #[cfg(feature = "mysql-connector")] - { - ssl_opts = ssl_opts.with_root_cert_path(Some(Path::new(&*v).to_path_buf())); - } - } - "sslidentity" => { - use_ssl = true; - - identity = match identity { - Some((_, pw)) => Some((Some(Path::new(&*v).to_path_buf()), pw)), - None => Some((Some(Path::new(&*v).to_path_buf()), None)), - }; - } - "sslpassword" => { - use_ssl = true; - - identity = match identity { - Some((path, _)) => Some((path, Some(v.to_string()))), - None => Some((None, Some(v.to_string()))), - }; - } - "socket" => { - socket = Some(v.replace(['(', ')'], "")); - } - "socket_timeout" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - socket_timeout = Some(Duration::from_secs(as_int)); - } - "prefer_socket" => { - let as_bool = v - .parse::() - .map_err(|_| 
Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - prefer_socket = Some(as_bool) - } - "connect_timeout" => { - let as_int = v - .parse::() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - connect_timeout = match as_int { - 0 => None, - _ => Some(Duration::from_secs(as_int)), - }; - } - "pool_timeout" => { - let as_int = v - .parse::() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - pool_timeout = match as_int { - 0 => None, - _ => Some(Duration::from_secs(as_int)), - }; - } - "sslaccept" => { - use_ssl = true; - match v.as_ref() { - "strict" => { - #[cfg(feature = "mysql-connector")] - { - ssl_opts = ssl_opts.with_danger_accept_invalid_certs(false); - } - } - "accept_invalid_certs" => {} - _ => { - tracing::debug!( - message = "Unsupported SSL accept mode, defaulting to `accept_invalid_certs`", - mode = &*v - ); - } - }; - } - "max_connection_lifetime" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - if as_int == 0 { - max_connection_lifetime = None; - } else { - max_connection_lifetime = Some(Duration::from_secs(as_int)); - } - } - "max_idle_connection_lifetime" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - if as_int == 0 { - max_idle_connection_lifetime = None; - } else { - max_idle_connection_lifetime = Some(Duration::from_secs(as_int)); - } - } - _ => { - tracing::trace!(message = "Discarding connection string param", param = &*k); - } - }; - } - - // Wrapping this in a block, as attributes on expressions are still experimental - // See: https://github.com/rust-lang/rust/issues/15701 - #[cfg(feature = "mysql-connector")] - { - ssl_opts = match identity { - Some((Some(path), Some(pw))) => { - let identity = mysql_async::ClientIdentity::new(path).with_password(pw); - ssl_opts.with_client_identity(Some(identity)) - } - Some((Some(path), None)) => { - let identity = mysql_async::ClientIdentity::new(path); - ssl_opts.with_client_identity(Some(identity)) - } - _ => ssl_opts, - }; - } - - Ok(MysqlUrlQueryParams { - #[cfg(feature = "mysql-connector")] - ssl_opts, - connection_limit, - use_ssl, - socket, - socket_timeout, - connect_timeout, - pool_timeout, - max_connection_lifetime, - max_idle_connection_lifetime, - prefer_socket, - statement_cache_size, - }) - } - - #[cfg(feature = "pooled")] - pub(crate) fn connection_limit(&self) -> Option { - self.query_params.connection_limit - } -} - -#[derive(Debug, Clone)] -pub(crate) struct MysqlUrlQueryParams { - pub(crate) connection_limit: Option, - pub(crate) use_ssl: bool, - pub(crate) socket: Option, - pub(crate) socket_timeout: Option, - pub(crate) connect_timeout: Option, - pub(crate) pool_timeout: Option, - pub(crate) max_connection_lifetime: Option, - pub(crate) max_idle_connection_lifetime: Option, - pub(crate) prefer_socket: Option, - pub(crate) statement_cache_size: usize, - - #[cfg(feature = "mysql-connector")] - pub(crate) ssl_opts: mysql_async::SslOpts, -} From 8ecbc5c37d71513eeea080fd3bb1ef08618b080d Mon Sep 17 00:00:00 2001 From: jkomyno Date: Mon, 13 Nov 2023 12:25:20 +0100 Subject: [PATCH 006/134] feat(quaint): split sqlite connector into native and wasm submodules --- quaint/Cargo.toml | 6 +- quaint/src/connector.rs | 23 +- quaint/src/connector/sqlite.rs | 256 +----------------- .../sqlite/{ => native}/conversion.rs | 0 quaint/src/connector/sqlite/native/error.rs | 49 ++++ 
quaint/src/connector/sqlite/native/mod.rs | 252 +++++++++++++++++ .../{sqlite_wasm.rs => sqlite/wasm/common.rs} | 0 .../src/connector/sqlite/{ => wasm}/error.rs | 62 +---- quaint/src/connector/sqlite/wasm/mod.rs | 4 + quaint/src/error.rs | 2 +- 10 files changed, 336 insertions(+), 318 deletions(-) rename quaint/src/connector/sqlite/{ => native}/conversion.rs (100%) create mode 100644 quaint/src/connector/sqlite/native/error.rs create mode 100644 quaint/src/connector/sqlite/native/mod.rs rename quaint/src/connector/{sqlite_wasm.rs => sqlite/wasm/common.rs} (100%) rename quaint/src/connector/sqlite/{ => wasm}/error.rs (69%) create mode 100644 quaint/src/connector/sqlite/wasm/mod.rs diff --git a/quaint/Cargo.toml b/quaint/Cargo.toml index 2da9ec0929c0..abe9fece9746 100644 --- a/quaint/Cargo.toml +++ b/quaint/Cargo.toml @@ -70,8 +70,8 @@ mysql-connector = ["mysql", "mysql_async", "tokio/time", "lru-cache"] mysql = ["chrono/std"] pooled = ["mobc"] -sqlite-connector = ["sqlite", "rusqlite", "tokio/sync"] -sqlite = [] +sqlite-connector = ["sqlite", "rusqlite/bundled", "tokio/sync"] +sqlite = ["rusqlite"] fmt-sql = ["sqlformat"] @@ -127,7 +127,7 @@ branch = "vendored-openssl" [dependencies.rusqlite] version = "0.29" -features = ["chrono", "bundled", "column_decltype"] +features = ["chrono", "column_decltype"] optional = true [target.'cfg(not(any(target_os = "macos", target_os = "ios")))'.dependencies.tiberius] diff --git a/quaint/src/connector.rs b/quaint/src/connector.rs index 32f9e6186890..b182e60a4387 100644 --- a/quaint/src/connector.rs +++ b/quaint/src/connector.rs @@ -35,10 +35,10 @@ mod type_identifier; // pub(crate) mod postgres; // #[cfg(feature = "postgresql")] // pub(crate) mod postgres_wasm; -#[cfg(feature = "sqlite-connector")] -pub(crate) mod sqlite; -#[cfg(feature = "sqlite")] -pub(crate) mod sqlite_wasm; +// #[cfg(feature = "sqlite-connector")] +// pub(crate) mod sqlite; +// #[cfg(feature = "sqlite")] +// pub(crate) mod sqlite_wasm; // #[cfg(feature = "mysql-connector")] // pub use self::mysql::*; @@ -52,10 +52,10 @@ pub(crate) mod sqlite_wasm; // pub use mssql::*; // #[cfg(feature = "mssql")] // pub use mssql_wasm::*; -#[cfg(feature = "sqlite-connector")] -pub use sqlite::*; -#[cfg(feature = "sqlite")] -pub use sqlite_wasm::*; +// #[cfg(feature = "sqlite-connector")] +// pub use sqlite::*; +// #[cfg(feature = "sqlite")] +// pub use sqlite_wasm::*; pub use self::result_set::*; pub use connection_info::*; @@ -85,6 +85,13 @@ pub use mysql::native::*; #[cfg(feature = "mysql")] pub use mysql::wasm::common::*; +#[cfg(feature = "sqlite")] +pub(crate) mod sqlite; +#[cfg(feature = "sqlite-connector")] +pub use sqlite::native::*; +#[cfg(feature = "sqlite")] +pub use sqlite::wasm::common::*; + #[cfg(feature = "mssql")] pub(crate) mod mssql; #[cfg(feature = "mssql-connector")] diff --git a/quaint/src/connector/sqlite.rs b/quaint/src/connector/sqlite.rs index fc993c1eaf0e..0e699c211878 100644 --- a/quaint/src/connector/sqlite.rs +++ b/quaint/src/connector/sqlite.rs @@ -1,253 +1,7 @@ -mod conversion; -mod error; +pub use wasm::error::SqliteError; -pub(crate) use super::sqlite_wasm::{SqliteParams, DEFAULT_SQLITE_SCHEMA_NAME}; -pub use error::SqliteError; +#[cfg(feature = "sqlite")] +pub(crate) mod wasm; -pub use rusqlite::{params_from_iter, version as sqlite_version}; - -use super::IsolationLevel; -use crate::{ - ast::{Query, Value}, - connector::{metrics, queryable::*, ResultSet}, - error::{Error, ErrorKind}, - visitor::{self, Visitor}, -}; -use async_trait::async_trait; -use 
std::convert::TryFrom; -use tokio::sync::Mutex; - -/// The underlying sqlite driver. Only available with the `expose-drivers` Cargo feature. -#[cfg(feature = "expose-drivers")] -pub use rusqlite; - -/// A connector interface for the SQLite database -pub struct Sqlite { - pub(crate) client: Mutex, -} - -impl TryFrom<&str> for Sqlite { - type Error = Error; - - fn try_from(path: &str) -> crate::Result { - let params = SqliteParams::try_from(path)?; - let file_path = params.file_path; - - let conn = rusqlite::Connection::open(file_path.as_str())?; - - if let Some(timeout) = params.socket_timeout { - conn.busy_timeout(timeout)?; - }; - - let client = Mutex::new(conn); - - Ok(Sqlite { client }) - } -} - -impl Sqlite { - pub fn new(file_path: &str) -> crate::Result { - Self::try_from(file_path) - } - - /// Open a new SQLite database in memory. - pub fn new_in_memory() -> crate::Result { - let client = rusqlite::Connection::open_in_memory()?; - - Ok(Sqlite { - client: Mutex::new(client), - }) - } - - /// The underlying rusqlite::Connection. Only available with the `expose-drivers` Cargo - /// feature. This is a lower level API when you need to get into database specific features. - #[cfg(feature = "expose-drivers")] - pub fn connection(&self) -> &Mutex { - &self.client - } -} - -impl_default_TransactionCapable!(Sqlite); - -#[async_trait] -impl Queryable for Sqlite { - async fn query(&self, q: Query<'_>) -> crate::Result { - let (sql, params) = visitor::Sqlite::build(q)?; - self.query_raw(&sql, ¶ms).await - } - - async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - metrics::query("sqlite.query_raw", sql, params, move || async move { - let client = self.client.lock().await; - - let mut stmt = client.prepare_cached(sql)?; - - let mut rows = stmt.query(params_from_iter(params.iter()))?; - let mut result = ResultSet::new(rows.to_column_names(), Vec::new()); - - while let Some(row) = rows.next()? { - result.rows.push(row.get_result_row()?); - } - - result.set_last_insert_id(u64::try_from(client.last_insert_rowid()).unwrap_or(0)); - - Ok(result) - }) - .await - } - - async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - self.query_raw(sql, params).await - } - - async fn execute(&self, q: Query<'_>) -> crate::Result { - let (sql, params) = visitor::Sqlite::build(q)?; - self.execute_raw(&sql, ¶ms).await - } - - async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - metrics::query("sqlite.query_raw", sql, params, move || async move { - let client = self.client.lock().await; - let mut stmt = client.prepare_cached(sql)?; - let res = u64::try_from(stmt.execute(params_from_iter(params.iter()))?)?; - - Ok(res) - }) - .await - } - - async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - self.execute_raw(sql, params).await - } - - async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> { - metrics::query("sqlite.raw_cmd", cmd, &[], move || async move { - let client = self.client.lock().await; - client.execute_batch(cmd)?; - Ok(()) - }) - .await - } - - async fn version(&self) -> crate::Result> { - Ok(Some(rusqlite::version().into())) - } - - fn is_healthy(&self) -> bool { - true - } - - async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { - // SQLite is always "serializable", other modes involve pragmas - // and shared cache mode, which is out of scope for now and should be implemented - // as part of a separate effort. 
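// Editor's note: a hedged illustration of the restriction implemented just
// below — SQLite transactions are always serializable, so any other requested
// level is rejected with an invalid-isolation-level error. Paths are written
// from a downstream crate's point of view and may need adjusting.
async fn isolation_demo(conn: &quaint::connector::Sqlite) {
    use quaint::connector::IsolationLevel;
    use quaint::prelude::*;

    // The only level SQLite actually supports.
    assert!(conn
        .set_tx_isolation_level(IsolationLevel::Serializable)
        .await
        .is_ok());

    // Anything else maps to ErrorKind::invalid_isolation_level.
    assert!(conn
        .set_tx_isolation_level(IsolationLevel::ReadCommitted)
        .await
        .is_err());
}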
- if !matches!(isolation_level, IsolationLevel::Serializable) { - let kind = ErrorKind::invalid_isolation_level(&isolation_level); - return Err(Error::builder(kind).build()); - } - - Ok(()) - } - - fn requires_isolation_first(&self) -> bool { - false - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{ - ast::*, - connector::Queryable, - error::{ErrorKind, Name}, - }; - - #[test] - fn sqlite_params_from_str_should_resolve_path_correctly_with_file_scheme() { - let path = "file:dev.db"; - let params = SqliteParams::try_from(path).unwrap(); - assert_eq!(params.file_path, "dev.db"); - } - - #[test] - fn sqlite_params_from_str_should_resolve_path_correctly_with_sqlite_scheme() { - let path = "sqlite:dev.db"; - let params = SqliteParams::try_from(path).unwrap(); - assert_eq!(params.file_path, "dev.db"); - } - - #[test] - fn sqlite_params_from_str_should_resolve_path_correctly_with_no_scheme() { - let path = "dev.db"; - let params = SqliteParams::try_from(path).unwrap(); - assert_eq!(params.file_path, "dev.db"); - } - - #[tokio::test] - async fn unknown_table_should_give_a_good_error() { - let conn = Sqlite::try_from("file:db/test.db").unwrap(); - let select = Select::from_table("not_there"); - - let err = conn.select(select).await.unwrap_err(); - - match err.kind() { - ErrorKind::TableDoesNotExist { table } => { - assert_eq!(&Name::available("not_there"), table); - } - e => panic!("Expected error TableDoesNotExist, got {:?}", e), - } - } - - #[tokio::test] - async fn in_memory_sqlite_works() { - let conn = Sqlite::new_in_memory().unwrap(); - - conn.raw_cmd("CREATE TABLE test (id INTEGER PRIMARY KEY, txt TEXT NOT NULL);") - .await - .unwrap(); - - let insert = Insert::single_into("test").value("txt", "henlo"); - conn.insert(insert.into()).await.unwrap(); - - let select = Select::from_table("test").value(asterisk()); - let result = conn.select(select.clone()).await.unwrap(); - let result = result.into_single().unwrap(); - - assert_eq!(result.get("id").unwrap(), &Value::int32(1)); - assert_eq!(result.get("txt").unwrap(), &Value::text("henlo")); - - // Check that we do get a separate, new database. - let other_conn = Sqlite::new_in_memory().unwrap(); - - let err = other_conn.select(select).await.unwrap_err(); - assert!(matches!(err.kind(), ErrorKind::TableDoesNotExist { .. 
})); - } - - #[tokio::test] - async fn quoting_in_returning_in_sqlite_works() { - let conn = Sqlite::new_in_memory().unwrap(); - - conn.raw_cmd("CREATE TABLE test (id INTEGER PRIMARY KEY, `txt space` TEXT NOT NULL);") - .await - .unwrap(); - - let insert = Insert::single_into("test").value("txt space", "henlo"); - conn.insert(insert.into()).await.unwrap(); - - let select = Select::from_table("test").value(asterisk()); - let result = conn.select(select.clone()).await.unwrap(); - let result = result.into_single().unwrap(); - - assert_eq!(result.get("id").unwrap(), &Value::int32(1)); - assert_eq!(result.get("txt space").unwrap(), &Value::text("henlo")); - - let insert = Insert::single_into("test").value("txt space", "henlo"); - let insert: Insert = Insert::from(insert).returning(["txt space"]); - - let result = conn.insert(insert).await.unwrap(); - let result = result.into_single().unwrap(); - - assert_eq!(result.get("txt space").unwrap(), &Value::text("henlo")); - } -} +#[cfg(feature = "sqlite-connector")] +pub(crate) mod native; diff --git a/quaint/src/connector/sqlite/conversion.rs b/quaint/src/connector/sqlite/native/conversion.rs similarity index 100% rename from quaint/src/connector/sqlite/conversion.rs rename to quaint/src/connector/sqlite/native/conversion.rs diff --git a/quaint/src/connector/sqlite/native/error.rs b/quaint/src/connector/sqlite/native/error.rs new file mode 100644 index 000000000000..9e2b2e7c3ea1 --- /dev/null +++ b/quaint/src/connector/sqlite/native/error.rs @@ -0,0 +1,49 @@ +use crate::connector::sqlite::wasm::error::SqliteError; + +use crate::error::*; + +impl From for Error { + fn from(e: rusqlite::Error) -> Error { + match e { + rusqlite::Error::ToSqlConversionFailure(error) => match error.downcast::() { + Ok(error) => *error, + Err(error) => { + let mut builder = Error::builder(ErrorKind::QueryError(error)); + + builder.set_original_message("Could not interpret parameters in an SQLite query."); + + builder.build() + } + }, + rusqlite::Error::InvalidQuery => { + let mut builder = Error::builder(ErrorKind::QueryError(e.into())); + + builder.set_original_message( + "Could not interpret the query or its parameters. Check the syntax and parameter types.", + ); + + builder.build() + } + rusqlite::Error::ExecuteReturnedResults => { + let mut builder = Error::builder(ErrorKind::QueryError(e.into())); + builder.set_original_message("Execute returned results, which is not allowed in SQLite."); + + builder.build() + } + + rusqlite::Error::QueryReturnedNoRows => Error::builder(ErrorKind::NotFound).build(), + + rusqlite::Error::SqliteFailure(rusqlite::ffi::Error { code: _, extended_code }, message) => { + SqliteError::new(extended_code, message).into() + } + + rusqlite::Error::SqlInputError { + error: rusqlite::ffi::Error { extended_code, .. }, + msg, + .. 
+ } => SqliteError::new(extended_code, Some(msg)).into(), + + e => Error::builder(ErrorKind::QueryError(e.into())).build(), + } + } +} diff --git a/quaint/src/connector/sqlite/native/mod.rs b/quaint/src/connector/sqlite/native/mod.rs new file mode 100644 index 000000000000..e11f6cd021bc --- /dev/null +++ b/quaint/src/connector/sqlite/native/mod.rs @@ -0,0 +1,252 @@ +mod conversion; +mod error; + +use crate::connector::sqlite::wasm::common::SqliteParams; +use crate::connector::IsolationLevel; + +pub use rusqlite::{params_from_iter, version as sqlite_version}; + +use crate::{ + ast::{Query, Value}, + connector::{metrics, queryable::*, ResultSet}, + error::{Error, ErrorKind}, + visitor::{self, Visitor}, +}; +use async_trait::async_trait; +use std::convert::TryFrom; +use tokio::sync::Mutex; + +/// The underlying sqlite driver. Only available with the `expose-drivers` Cargo feature. +#[cfg(feature = "expose-drivers")] +pub use rusqlite; + +/// A connector interface for the SQLite database +pub struct Sqlite { + pub(crate) client: Mutex, +} + +impl TryFrom<&str> for Sqlite { + type Error = Error; + + fn try_from(path: &str) -> crate::Result { + let params = SqliteParams::try_from(path)?; + let file_path = params.file_path; + + let conn = rusqlite::Connection::open(file_path.as_str())?; + + if let Some(timeout) = params.socket_timeout { + conn.busy_timeout(timeout)?; + }; + + let client = Mutex::new(conn); + + Ok(Sqlite { client }) + } +} + +impl Sqlite { + pub fn new(file_path: &str) -> crate::Result { + Self::try_from(file_path) + } + + /// Open a new SQLite database in memory. + pub fn new_in_memory() -> crate::Result { + let client = rusqlite::Connection::open_in_memory()?; + + Ok(Sqlite { + client: Mutex::new(client), + }) + } + + /// The underlying rusqlite::Connection. Only available with the `expose-drivers` Cargo + /// feature. This is a lower level API when you need to get into database specific features. + #[cfg(feature = "expose-drivers")] + pub fn connection(&self) -> &Mutex { + &self.client + } +} + +impl_default_TransactionCapable!(Sqlite); + +#[async_trait] +impl Queryable for Sqlite { + async fn query(&self, q: Query<'_>) -> crate::Result { + let (sql, params) = visitor::Sqlite::build(q)?; + self.query_raw(&sql, ¶ms).await + } + + async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + metrics::query("sqlite.query_raw", sql, params, move || async move { + let client = self.client.lock().await; + + let mut stmt = client.prepare_cached(sql)?; + + let mut rows = stmt.query(params_from_iter(params.iter()))?; + let mut result = ResultSet::new(rows.to_column_names(), Vec::new()); + + while let Some(row) = rows.next()? 
{ + result.rows.push(row.get_result_row()?); + } + + result.set_last_insert_id(u64::try_from(client.last_insert_rowid()).unwrap_or(0)); + + Ok(result) + }) + .await + } + + async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + self.query_raw(sql, params).await + } + + async fn execute(&self, q: Query<'_>) -> crate::Result { + let (sql, params) = visitor::Sqlite::build(q)?; + self.execute_raw(&sql, ¶ms).await + } + + async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + metrics::query("sqlite.query_raw", sql, params, move || async move { + let client = self.client.lock().await; + let mut stmt = client.prepare_cached(sql)?; + let res = u64::try_from(stmt.execute(params_from_iter(params.iter()))?)?; + + Ok(res) + }) + .await + } + + async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + self.execute_raw(sql, params).await + } + + async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> { + metrics::query("sqlite.raw_cmd", cmd, &[], move || async move { + let client = self.client.lock().await; + client.execute_batch(cmd)?; + Ok(()) + }) + .await + } + + async fn version(&self) -> crate::Result> { + Ok(Some(rusqlite::version().into())) + } + + fn is_healthy(&self) -> bool { + true + } + + async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { + // SQLite is always "serializable", other modes involve pragmas + // and shared cache mode, which is out of scope for now and should be implemented + // as part of a separate effort. + if !matches!(isolation_level, IsolationLevel::Serializable) { + let kind = ErrorKind::invalid_isolation_level(&isolation_level); + return Err(Error::builder(kind).build()); + } + + Ok(()) + } + + fn requires_isolation_first(&self) -> bool { + false + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + ast::*, + connector::Queryable, + error::{ErrorKind, Name}, + }; + + #[test] + fn sqlite_params_from_str_should_resolve_path_correctly_with_file_scheme() { + let path = "file:dev.db"; + let params = SqliteParams::try_from(path).unwrap(); + assert_eq!(params.file_path, "dev.db"); + } + + #[test] + fn sqlite_params_from_str_should_resolve_path_correctly_with_sqlite_scheme() { + let path = "sqlite:dev.db"; + let params = SqliteParams::try_from(path).unwrap(); + assert_eq!(params.file_path, "dev.db"); + } + + #[test] + fn sqlite_params_from_str_should_resolve_path_correctly_with_no_scheme() { + let path = "dev.db"; + let params = SqliteParams::try_from(path).unwrap(); + assert_eq!(params.file_path, "dev.db"); + } + + #[tokio::test] + async fn unknown_table_should_give_a_good_error() { + let conn = Sqlite::try_from("file:db/test.db").unwrap(); + let select = Select::from_table("not_there"); + + let err = conn.select(select).await.unwrap_err(); + + match err.kind() { + ErrorKind::TableDoesNotExist { table } => { + assert_eq!(&Name::available("not_there"), table); + } + e => panic!("Expected error TableDoesNotExist, got {:?}", e), + } + } + + #[tokio::test] + async fn in_memory_sqlite_works() { + let conn = Sqlite::new_in_memory().unwrap(); + + conn.raw_cmd("CREATE TABLE test (id INTEGER PRIMARY KEY, txt TEXT NOT NULL);") + .await + .unwrap(); + + let insert = Insert::single_into("test").value("txt", "henlo"); + conn.insert(insert.into()).await.unwrap(); + + let select = Select::from_table("test").value(asterisk()); + let result = conn.select(select.clone()).await.unwrap(); + let result = result.into_single().unwrap(); + + 
assert_eq!(result.get("id").unwrap(), &Value::int32(1)); + assert_eq!(result.get("txt").unwrap(), &Value::text("henlo")); + + // Check that we do get a separate, new database. + let other_conn = Sqlite::new_in_memory().unwrap(); + + let err = other_conn.select(select).await.unwrap_err(); + assert!(matches!(err.kind(), ErrorKind::TableDoesNotExist { .. })); + } + + #[tokio::test] + async fn quoting_in_returning_in_sqlite_works() { + let conn = Sqlite::new_in_memory().unwrap(); + + conn.raw_cmd("CREATE TABLE test (id INTEGER PRIMARY KEY, `txt space` TEXT NOT NULL);") + .await + .unwrap(); + + let insert = Insert::single_into("test").value("txt space", "henlo"); + conn.insert(insert.into()).await.unwrap(); + + let select = Select::from_table("test").value(asterisk()); + let result = conn.select(select.clone()).await.unwrap(); + let result = result.into_single().unwrap(); + + assert_eq!(result.get("id").unwrap(), &Value::int32(1)); + assert_eq!(result.get("txt space").unwrap(), &Value::text("henlo")); + + let insert = Insert::single_into("test").value("txt space", "henlo"); + let insert: Insert = Insert::from(insert).returning(["txt space"]); + + let result = conn.insert(insert).await.unwrap(); + let result = result.into_single().unwrap(); + + assert_eq!(result.get("txt space").unwrap(), &Value::text("henlo")); + } +} diff --git a/quaint/src/connector/sqlite_wasm.rs b/quaint/src/connector/sqlite/wasm/common.rs similarity index 100% rename from quaint/src/connector/sqlite_wasm.rs rename to quaint/src/connector/sqlite/wasm/common.rs diff --git a/quaint/src/connector/sqlite/error.rs b/quaint/src/connector/sqlite/wasm/error.rs similarity index 69% rename from quaint/src/connector/sqlite/error.rs rename to quaint/src/connector/sqlite/wasm/error.rs index c10b335cb3c0..9cd0ef64e8a4 100644 --- a/quaint/src/connector/sqlite/error.rs +++ b/quaint/src/connector/sqlite/wasm/error.rs @@ -1,8 +1,6 @@ use std::fmt; use crate::error::*; -use rusqlite::ffi; -use rusqlite::types::FromSqlError; #[derive(Debug)] pub struct SqliteError { @@ -16,7 +14,7 @@ impl fmt::Display for SqliteError { f, "Error code {}: {}", self.extended_code, - ffi::code_to_str(self.extended_code) + rusqlite::ffi::code_to_str(self.extended_code) ) } } @@ -37,7 +35,7 @@ impl From for Error { fn from(error: SqliteError) -> Self { match error { SqliteError { - extended_code: ffi::SQLITE_CONSTRAINT_UNIQUE | ffi::SQLITE_CONSTRAINT_PRIMARYKEY, + extended_code: rusqlite::ffi::SQLITE_CONSTRAINT_UNIQUE | rusqlite::ffi::SQLITE_CONSTRAINT_PRIMARYKEY, message: Some(description), } => { let constraint = description @@ -58,7 +56,7 @@ impl From for Error { } SqliteError { - extended_code: ffi::SQLITE_CONSTRAINT_NOTNULL, + extended_code: rusqlite::ffi::SQLITE_CONSTRAINT_NOTNULL, message: Some(description), } => { let constraint = description @@ -79,7 +77,7 @@ impl From for Error { } SqliteError { - extended_code: ffi::SQLITE_CONSTRAINT_FOREIGNKEY | ffi::SQLITE_CONSTRAINT_TRIGGER, + extended_code: rusqlite::ffi::SQLITE_CONSTRAINT_FOREIGNKEY | rusqlite::ffi::SQLITE_CONSTRAINT_TRIGGER, message: Some(description), } => { let mut builder = Error::builder(ErrorKind::ForeignKeyConstraintViolation { @@ -92,7 +90,7 @@ impl From for Error { builder.build() } - SqliteError { extended_code, message } if error.primary_code() == ffi::SQLITE_BUSY => { + SqliteError { extended_code, message } if error.primary_code() == rusqlite::ffi::SQLITE_BUSY => { let mut builder = Error::builder(ErrorKind::SocketTimeout); builder.set_original_code(format!("{extended_code}")); @@ 
-153,54 +151,8 @@ impl From for Error { } } -impl From for Error { - fn from(e: rusqlite::Error) -> Error { - match e { - rusqlite::Error::ToSqlConversionFailure(error) => match error.downcast::() { - Ok(error) => *error, - Err(error) => { - let mut builder = Error::builder(ErrorKind::QueryError(error)); - - builder.set_original_message("Could not interpret parameters in an SQLite query."); - - builder.build() - } - }, - rusqlite::Error::InvalidQuery => { - let mut builder = Error::builder(ErrorKind::QueryError(e.into())); - - builder.set_original_message( - "Could not interpret the query or its parameters. Check the syntax and parameter types.", - ); - - builder.build() - } - rusqlite::Error::ExecuteReturnedResults => { - let mut builder = Error::builder(ErrorKind::QueryError(e.into())); - builder.set_original_message("Execute returned results, which is not allowed in SQLite."); - - builder.build() - } - - rusqlite::Error::QueryReturnedNoRows => Error::builder(ErrorKind::NotFound).build(), - - rusqlite::Error::SqliteFailure(ffi::Error { code: _, extended_code }, message) => { - SqliteError::new(extended_code, message).into() - } - - rusqlite::Error::SqlInputError { - error: ffi::Error { extended_code, .. }, - msg, - .. - } => SqliteError::new(extended_code, Some(msg)).into(), - - e => Error::builder(ErrorKind::QueryError(e.into())).build(), - } - } -} - -impl From for Error { - fn from(e: FromSqlError) -> Error { +impl From for Error { + fn from(e: rusqlite::types::FromSqlError) -> Error { Error::builder(ErrorKind::ColumnReadFailure(e.into())).build() } } diff --git a/quaint/src/connector/sqlite/wasm/mod.rs b/quaint/src/connector/sqlite/wasm/mod.rs new file mode 100644 index 000000000000..0dbbcd76daec --- /dev/null +++ b/quaint/src/connector/sqlite/wasm/mod.rs @@ -0,0 +1,4 @@ +///! Wasm-compatible definitions for the SQLite connector. +/// /// This module is only available with the `sqlite` feature. 
+pub(crate) mod common; +pub mod error; diff --git a/quaint/src/error.rs b/quaint/src/error.rs index f8202b030466..705bb6b37ee0 100644 --- a/quaint/src/error.rs +++ b/quaint/src/error.rs @@ -8,7 +8,7 @@ use std::time::Duration; pub use crate::connector::mysql::MysqlError; pub use crate::connector::postgres::PostgresError; -// pub use crate::connector::sqlite::SqliteError; +pub use crate::connector::sqlite::SqliteError; #[derive(Debug, PartialEq, Eq)] pub enum DatabaseConstraint { From 45df24fdd0ba18c192ff3bbf4d363d69b8ae4f5e Mon Sep 17 00:00:00 2001 From: jkomyno Date: Mon, 13 Nov 2023 13:32:01 +0100 Subject: [PATCH 007/134] chore(quaint): fix clippy when compiling natively --- quaint/src/connector.rs | 34 --- quaint/src/connector/mssql/conversion.rs | 87 ------- quaint/src/connector/mssql/error.rs | 247 ------------------- quaint/src/connector/mssql/native/mod.rs | 3 + quaint/src/connector/mssql/wasm/mod.rs | 4 +- quaint/src/connector/mysql/native/mod.rs | 18 +- quaint/src/connector/mysql/wasm/mod.rs | 4 +- quaint/src/connector/postgres/native/mod.rs | 9 +- quaint/src/connector/postgres/wasm/common.rs | 142 ----------- quaint/src/connector/postgres/wasm/mod.rs | 4 +- quaint/src/connector/sqlite/native/mod.rs | 3 + quaint/src/connector/sqlite/wasm/mod.rs | 4 +- quaint/src/single.rs | 5 +- 13 files changed, 24 insertions(+), 540 deletions(-) delete mode 100644 quaint/src/connector/mssql/conversion.rs delete mode 100644 quaint/src/connector/mssql/error.rs diff --git a/quaint/src/connector.rs b/quaint/src/connector.rs index b182e60a4387..0aaa19aa463b 100644 --- a/quaint/src/connector.rs +++ b/quaint/src/connector.rs @@ -23,40 +23,6 @@ mod timeout; mod transaction; mod type_identifier; -// #[cfg(feature = "mssql-connector")] -// pub(crate) mod mssql; -// #[cfg(feature = "mssql")] -// pub(crate) mod mssql_wasm; -// #[cfg(feature = "mysql-connector")] -// pub(crate) mod mysql; -// #[cfg(feature = "mysql")] -// pub(crate) mod mysql_wasm; -// #[cfg(feature = "postgresql-connector")] -// pub(crate) mod postgres; -// #[cfg(feature = "postgresql")] -// pub(crate) mod postgres_wasm; -// #[cfg(feature = "sqlite-connector")] -// pub(crate) mod sqlite; -// #[cfg(feature = "sqlite")] -// pub(crate) mod sqlite_wasm; - -// #[cfg(feature = "mysql-connector")] -// pub use self::mysql::*; -// #[cfg(feature = "mysql")] -// pub use self::mysql_wasm::*; -// #[cfg(feature = "postgresql-connector")] -// pub use self::postgres::*; -// #[cfg(feature = "postgresql")] -// pub use self::postgres_wasm::*; -// #[cfg(feature = "mssql-connector")] -// pub use mssql::*; -// #[cfg(feature = "mssql")] -// pub use mssql_wasm::*; -// #[cfg(feature = "sqlite-connector")] -// pub use sqlite::*; -// #[cfg(feature = "sqlite")] -// pub use sqlite_wasm::*; - pub use self::result_set::*; pub use connection_info::*; pub use queryable::*; diff --git a/quaint/src/connector/mssql/conversion.rs b/quaint/src/connector/mssql/conversion.rs deleted file mode 100644 index 870654ad5de3..000000000000 --- a/quaint/src/connector/mssql/conversion.rs +++ /dev/null @@ -1,87 +0,0 @@ -use crate::ast::{Value, ValueType}; - -use bigdecimal::BigDecimal; -use std::{borrow::Cow, convert::TryFrom}; - -use tiberius::ToSql; -use tiberius::{ColumnData, FromSql, IntoSql}; - -impl<'a> IntoSql<'a> for &'a Value<'a> { - fn into_sql(self) -> ColumnData<'a> { - match &self.typed { - ValueType::Int32(val) => val.into_sql(), - ValueType::Int64(val) => val.into_sql(), - ValueType::Float(val) => val.into_sql(), - ValueType::Double(val) => val.into_sql(), - 
ValueType::Text(val) => val.as_deref().into_sql(), - ValueType::Bytes(val) => val.as_deref().into_sql(), - ValueType::Enum(val, _) => val.as_deref().into_sql(), - ValueType::Boolean(val) => val.into_sql(), - ValueType::Char(val) => val.as_ref().map(|val| format!("{val}")).into_sql(), - ValueType::Xml(val) => val.as_deref().into_sql(), - ValueType::Array(_) | ValueType::EnumArray(_, _) => panic!("Arrays are not supported on SQL Server."), - ValueType::Numeric(val) => (*val).to_sql(), - ValueType::Json(val) => val.as_ref().map(|val| serde_json::to_string(&val).unwrap()).into_sql(), - ValueType::Uuid(val) => val.into_sql(), - ValueType::DateTime(val) => val.into_sql(), - ValueType::Date(val) => val.into_sql(), - ValueType::Time(val) => val.into_sql(), - } - } -} - -impl TryFrom> for Value<'static> { - type Error = crate::error::Error; - - fn try_from(cd: ColumnData<'static>) -> crate::Result { - let res = match cd { - ColumnData::U8(num) => ValueType::Int32(num.map(i32::from)), - ColumnData::I16(num) => ValueType::Int32(num.map(i32::from)), - ColumnData::I32(num) => ValueType::Int32(num.map(i32::from)), - ColumnData::I64(num) => ValueType::Int64(num.map(i64::from)), - ColumnData::F32(num) => ValueType::Float(num), - ColumnData::F64(num) => ValueType::Double(num), - ColumnData::Bit(b) => ValueType::Boolean(b), - ColumnData::String(s) => ValueType::Text(s), - ColumnData::Guid(uuid) => ValueType::Uuid(uuid), - ColumnData::Binary(bytes) => ValueType::Bytes(bytes), - numeric @ ColumnData::Numeric(_) => ValueType::Numeric(BigDecimal::from_sql(&numeric)?), - dt @ ColumnData::DateTime(_) => { - use tiberius::time::chrono::{DateTime, NaiveDateTime, Utc}; - - let dt = NaiveDateTime::from_sql(&dt)?.map(|dt| DateTime::::from_utc(dt, Utc)); - ValueType::DateTime(dt) - } - dt @ ColumnData::SmallDateTime(_) => { - use tiberius::time::chrono::{DateTime, NaiveDateTime, Utc}; - - let dt = NaiveDateTime::from_sql(&dt)?.map(|dt| DateTime::::from_utc(dt, Utc)); - ValueType::DateTime(dt) - } - dt @ ColumnData::Time(_) => { - use tiberius::time::chrono::NaiveTime; - - ValueType::Time(NaiveTime::from_sql(&dt)?) - } - dt @ ColumnData::Date(_) => { - use tiberius::time::chrono::NaiveDate; - ValueType::Date(NaiveDate::from_sql(&dt)?) - } - dt @ ColumnData::DateTime2(_) => { - use tiberius::time::chrono::{DateTime, NaiveDateTime, Utc}; - - let dt = NaiveDateTime::from_sql(&dt)?.map(|dt| DateTime::::from_utc(dt, Utc)); - - ValueType::DateTime(dt) - } - dt @ ColumnData::DateTimeOffset(_) => { - use tiberius::time::chrono::{DateTime, Utc}; - - ValueType::DateTime(DateTime::::from_sql(&dt)?) - } - ColumnData::Xml(cow) => ValueType::Xml(cow.map(|xml_data| Cow::Owned(xml_data.into_owned().into_string()))), - }; - - Ok(Value::from(res)) - } -} diff --git a/quaint/src/connector/mssql/error.rs b/quaint/src/connector/mssql/error.rs deleted file mode 100644 index f9b6f5e95ab6..000000000000 --- a/quaint/src/connector/mssql/error.rs +++ /dev/null @@ -1,247 +0,0 @@ -use crate::error::{DatabaseConstraint, Error, ErrorKind}; -use tiberius::error::IoErrorKind; - -impl From for Error { - fn from(e: tiberius::error::Error) -> Error { - match e { - tiberius::error::Error::Io { - kind: IoErrorKind::UnexpectedEof, - message, - } => { - let mut builder = Error::builder(ErrorKind::ConnectionClosed); - builder.set_original_message(message); - builder.build() - } - e @ tiberius::error::Error::Io { .. 
} => Error::builder(ErrorKind::ConnectionError(e.into())).build(), - tiberius::error::Error::Tls(message) => { - let message = format!( - "The TLS settings didn't allow the connection to be established. Please review your connection string. (error: {message})" - ); - - Error::builder(ErrorKind::TlsError { message }).build() - } - tiberius::error::Error::Server(e) if [3902u32, 3903u32, 3971u32].iter().any(|code| e.code() == *code) => { - let kind = ErrorKind::TransactionAlreadyClosed(e.message().to_string()); - - let mut builder = Error::builder(kind); - - builder.set_original_code(format!("{}", e.code())); - builder.set_original_message(e.message().to_string()); - - builder.build() - } - tiberius::error::Error::Server(e) if e.code() == 8169 => { - let kind = ErrorKind::conversion(e.message().to_string()); - - let mut builder = Error::builder(kind); - - builder.set_original_code(format!("{}", e.code())); - builder.set_original_message(e.message().to_string()); - - builder.build() - } - tiberius::error::Error::Server(e) if e.code() == 18456 => { - let user = e.message().split('\'').nth(1).into(); - let kind = ErrorKind::AuthenticationFailed { user }; - - let mut builder = Error::builder(kind); - - builder.set_original_code(format!("{}", e.code())); - builder.set_original_message(e.message().to_string()); - - builder.build() - } - tiberius::error::Error::Server(e) if e.code() == 4060 => { - let db_name = e.message().split('"').nth(1).into(); - let kind = ErrorKind::DatabaseDoesNotExist { db_name }; - let mut builder = Error::builder(kind); - - builder.set_original_code(format!("{}", e.code())); - builder.set_original_message(e.message().to_string()); - - builder.build() - } - tiberius::error::Error::Server(e) if e.code() == 515 => { - let constraint = e - .message() - .split_whitespace() - .nth(7) - .and_then(|s| s.split('\'').nth(1)) - .map(|s| DatabaseConstraint::fields(Some(s))) - .unwrap_or(DatabaseConstraint::CannotParse); - - let kind = ErrorKind::NullConstraintViolation { constraint }; - let mut builder = Error::builder(kind); - - builder.set_original_code(format!("{}", e.code())); - builder.set_original_message(e.message().to_string()); - - builder.build() - } - tiberius::error::Error::Server(e) if e.code() == 1801 => { - let db_name = e.message().split('\'').nth(1).into(); - let kind = ErrorKind::DatabaseAlreadyExists { db_name }; - - let mut builder = Error::builder(kind); - - builder.set_original_code(format!("{}", e.code())); - builder.set_original_message(e.message().to_string()); - - builder.build() - } - tiberius::error::Error::Server(e) if e.code() == 2627 => { - let constraint = e - .message() - .split(". 
") - .nth(1) - .and_then(|s| s.split(' ').last()) - .and_then(|s| s.split('\'').nth(1)) - .map(ToString::to_string) - .map(DatabaseConstraint::Index) - .unwrap_or(DatabaseConstraint::CannotParse); - - let kind = ErrorKind::UniqueConstraintViolation { constraint }; - let mut builder = Error::builder(kind); - - builder.set_original_code(format!("{}", e.code())); - builder.set_original_message(e.message().to_string()); - - builder.build() - } - tiberius::error::Error::Server(e) if e.code() == 547 => { - let constraint = e - .message() - .split('.') - .next() - .and_then(|s| s.split_whitespace().last()) - .and_then(|s| s.split('\"').nth(1)) - .map(ToString::to_string) - .map(DatabaseConstraint::Index) - .unwrap_or(DatabaseConstraint::CannotParse); - - let kind = ErrorKind::ForeignKeyConstraintViolation { constraint }; - let mut builder = Error::builder(kind); - - builder.set_original_code(format!("{}", e.code())); - builder.set_original_message(e.message().to_string()); - - builder.build() - } - tiberius::error::Error::Server(e) if e.code() == 1505 => { - let constraint = e - .message() - .split('\'') - .nth(3) - .map(ToString::to_string) - .map(DatabaseConstraint::Index) - .unwrap_or(DatabaseConstraint::CannotParse); - - let kind = ErrorKind::UniqueConstraintViolation { constraint }; - let mut builder = Error::builder(kind); - - builder.set_original_code(format!("{}", e.code())); - builder.set_original_message(e.message().to_string()); - - builder.build() - } - tiberius::error::Error::Server(e) if e.code() == 2601 => { - let constraint = e - .message() - .split_whitespace() - .nth(11) - .and_then(|s| s.split('\'').nth(1)) - .map(ToString::to_string) - .map(DatabaseConstraint::Index) - .unwrap_or(DatabaseConstraint::CannotParse); - - let kind = ErrorKind::UniqueConstraintViolation { constraint }; - let mut builder = Error::builder(kind); - - builder.set_original_code(format!("{}", e.code())); - builder.set_original_message(e.message().to_string()); - - builder.build() - } - tiberius::error::Error::Server(e) if e.code() == 1801 => { - let db_name = e.message().split('\'').nth(1).into(); - let kind = ErrorKind::DatabaseAlreadyExists { db_name }; - - let mut builder = Error::builder(kind); - - builder.set_original_code(format!("{}", e.code())); - builder.set_original_message(e.message().to_string()); - - builder.build() - } - tiberius::error::Error::Server(e) if e.code() == 2628 => { - let column = e.message().split('\'').nth(3).into(); - let kind = ErrorKind::LengthMismatch { column }; - - let mut builder = Error::builder(kind); - - builder.set_original_code(format!("{}", e.code())); - builder.set_original_message(e.message().to_string()); - - builder.build() - } - tiberius::error::Error::Server(e) if e.code() == 208 => { - let table = e - .message() - .split_whitespace() - .nth(3) - .and_then(|s| s.split('\'').nth(1)) - .into(); - - let kind = ErrorKind::TableDoesNotExist { table }; - let mut builder = Error::builder(kind); - - builder.set_original_code(format!("{}", e.code())); - builder.set_original_message(e.message().to_string()); - - builder.build() - } - tiberius::error::Error::Server(e) if e.code() == 207 => { - let column = e - .message() - .split_whitespace() - .nth(3) - .and_then(|s| s.split('\'').nth(1)) - .into(); - - let kind = ErrorKind::ColumnNotFound { column }; - let mut builder = Error::builder(kind); - - builder.set_original_code(format!("{}", e.code())); - builder.set_original_message(e.message().to_string()); - - builder.build() - } - tiberius::error::Error::Server(e) if 
e.code() == 1205 => { - let mut builder = Error::builder(ErrorKind::TransactionWriteConflict); - - builder.set_original_code(format!("{}", e.code())); - builder.set_original_message(e.message().to_string()); - - builder.build() - } - tiberius::error::Error::Server(e) if e.code() == 3903 => { - let mut builder = Error::builder(ErrorKind::RollbackWithoutBegin); - - builder.set_original_code(format!("{}", e.code())); - builder.set_original_message(e.message().to_string()); - - builder.build() - } - tiberius::error::Error::Server(e) => { - let kind = ErrorKind::QueryError(e.clone().into()); - - let mut builder = Error::builder(kind); - builder.set_original_code(format!("{}", e.code())); - builder.set_original_message(e.message().to_string()); - - builder.build() - } - e => Error::builder(ErrorKind::QueryError(e.into())).build(), - } - } -} diff --git a/quaint/src/connector/mssql/native/mod.rs b/quaint/src/connector/mssql/native/mod.rs index a1ea3bd5394d..6a1019c4f594 100644 --- a/quaint/src/connector/mssql/native/mod.rs +++ b/quaint/src/connector/mssql/native/mod.rs @@ -1,3 +1,6 @@ +//! Definitions for the MSSQL connector. +//! This module is not compatible with wasm32-* targets. +//! This module is only available with the `mssql-connector` feature. mod conversion; mod error; diff --git a/quaint/src/connector/mssql/wasm/mod.rs b/quaint/src/connector/mssql/wasm/mod.rs index 69f1f46f7d21..5a25a32836c2 100644 --- a/quaint/src/connector/mssql/wasm/mod.rs +++ b/quaint/src/connector/mssql/wasm/mod.rs @@ -1,5 +1,5 @@ -///! Wasm-compatible definitions for the MSSQL connector. -/// This module is only available with the `mssql` feature. +//! Wasm-compatible definitions for the MSSQL connector. +//! This module is only available with the `mssql` feature. pub(crate) mod common; pub use common::MssqlUrl; diff --git a/quaint/src/connector/mysql/native/mod.rs b/quaint/src/connector/mysql/native/mod.rs index 1a9652b628f8..234f7fb3d74f 100644 --- a/quaint/src/connector/mysql/native/mod.rs +++ b/quaint/src/connector/mysql/native/mod.rs @@ -1,3 +1,6 @@ +//! Definitions for the MySQL connector. +//! This module is not compatible with wasm32-* targets. +//! This module is only available with the `mysql-connector` feature. mod conversion; mod error; @@ -63,21 +66,6 @@ impl MysqlUrl { } } -#[derive(Debug, Clone)] -pub(crate) struct MysqlUrlQueryParams { - ssl_opts: my::SslOpts, - connection_limit: Option, - use_ssl: bool, - socket: Option, - socket_timeout: Option, - connect_timeout: Option, - pool_timeout: Option, - max_connection_lifetime: Option, - max_idle_connection_lifetime: Option, - prefer_socket: Option, - statement_cache_size: usize, -} - /// A connector interface for the MySQL database. #[derive(Debug)] pub struct Mysql { diff --git a/quaint/src/connector/mysql/wasm/mod.rs b/quaint/src/connector/mysql/wasm/mod.rs index da9a57a53876..4f73f82031d5 100644 --- a/quaint/src/connector/mysql/wasm/mod.rs +++ b/quaint/src/connector/mysql/wasm/mod.rs @@ -1,5 +1,5 @@ -///! Wasm-compatible definitions for the MySQL connector. -/// /// This module is only available with the `mysql` feature. +//! Wasm-compatible definitions for the MySQL connector. +//! This module is only available with the `mysql` feature. pub(crate) mod common; pub mod error; diff --git a/quaint/src/connector/postgres/native/mod.rs b/quaint/src/connector/postgres/native/mod.rs index 8f1645ca4123..a6628086aaae 100644 --- a/quaint/src/connector/postgres/native/mod.rs +++ b/quaint/src/connector/postgres/native/mod.rs @@ -1,11 +1,11 @@ -///! 
Definitions for the Postgres connector. -/// This module is not compatible with wasm32-* targets. -/// This module is only available with the `postgresql-connector` feature. +//! Definitions for the Postgres connector. +//! This module is not compatible with wasm32-* targets. +//! This module is only available with the `postgresql-connector` feature. mod conversion; mod error; +pub(crate) use crate::connector::postgres::wasm::common::PostgresUrl; use crate::connector::postgres::wasm::common::{Hidden, SslAcceptMode, SslParams}; -pub(crate) use crate::connector::postgres::wasm::common::{PostgresFlavour, PostgresUrl}; use crate::connector::{timeout, IsolationLevel, Transaction}; use crate::{ @@ -670,6 +670,7 @@ fn is_safe_identifier(ident: &str) -> bool { #[cfg(test)] mod tests { use super::*; + pub(crate) use crate::connector::postgres::wasm::common::PostgresFlavour; use crate::tests::test_api::postgres::CONN_STR; use crate::tests::test_api::CRDB_CONN_STR; use crate::{connector::Queryable, error::*, single::Quaint}; diff --git a/quaint/src/connector/postgres/wasm/common.rs b/quaint/src/connector/postgres/wasm/common.rs index 46d327c0183d..88145beb40de 100644 --- a/quaint/src/connector/postgres/wasm/common.rs +++ b/quaint/src/connector/postgres/wasm/common.rs @@ -468,145 +468,3 @@ impl Display for SetSearchPath<'_> { Ok(()) } } - -/// Sorted list of CockroachDB's reserved keywords. -/// Taken from https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#keywords -const RESERVED_KEYWORDS: [&str; 79] = [ - "all", - "analyse", - "analyze", - "and", - "any", - "array", - "as", - "asc", - "asymmetric", - "both", - "case", - "cast", - "check", - "collate", - "column", - "concurrently", - "constraint", - "create", - "current_catalog", - "current_date", - "current_role", - "current_schema", - "current_time", - "current_timestamp", - "current_user", - "default", - "deferrable", - "desc", - "distinct", - "do", - "else", - "end", - "except", - "false", - "fetch", - "for", - "foreign", - "from", - "grant", - "group", - "having", - "in", - "initially", - "intersect", - "into", - "lateral", - "leading", - "limit", - "localtime", - "localtimestamp", - "not", - "null", - "offset", - "on", - "only", - "or", - "order", - "placing", - "primary", - "references", - "returning", - "select", - "session_user", - "some", - "symmetric", - "table", - "then", - "to", - "trailing", - "true", - "union", - "unique", - "user", - "using", - "variadic", - "when", - "where", - "window", - "with", -]; - -/// Sorted list of CockroachDB's reserved type function names. -/// Taken from https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#keywords -const RESERVED_TYPE_FUNCTION_NAMES: [&str; 18] = [ - "authorization", - "collation", - "cross", - "full", - "ilike", - "inner", - "is", - "isnull", - "join", - "left", - "like", - "natural", - "none", - "notnull", - "outer", - "overlaps", - "right", - "similar", -]; - -/// Returns true if a Postgres identifier is considered "safe". -/// -/// In this context, "safe" means that the value of an identifier would be the same quoted and unquoted or that it's not part of reserved keywords. In other words, that it does _not_ need to be quoted. 
-/// -/// Spec can be found here: https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS -/// or here: https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#rules-for-identifiers -fn is_safe_identifier(ident: &str) -> bool { - if ident.is_empty() { - return false; - } - - // 1. Not equal any SQL keyword unless the keyword is accepted by the element's syntax. For example, name accepts Unreserved or Column Name keywords. - if RESERVED_KEYWORDS.binary_search(&ident).is_ok() || RESERVED_TYPE_FUNCTION_NAMES.binary_search(&ident).is_ok() { - return false; - } - - let mut chars = ident.chars(); - - let first = chars.next().unwrap(); - - // 2. SQL identifiers must begin with a letter (a-z, but also letters with diacritical marks and non-Latin letters) or an underscore (_). - if (!first.is_alphabetic() || !first.is_lowercase()) && first != '_' { - return false; - } - - for c in chars { - // 3. Subsequent characters in an identifier can be letters, underscores, digits (0-9), or dollar signs ($). - if (!c.is_alphabetic() || !c.is_lowercase()) && c != '_' && !c.is_ascii_digit() && c != '$' { - return false; - } - } - - true -} diff --git a/quaint/src/connector/postgres/wasm/mod.rs b/quaint/src/connector/postgres/wasm/mod.rs index 5b330861a199..859de8f6fd3c 100644 --- a/quaint/src/connector/postgres/wasm/mod.rs +++ b/quaint/src/connector/postgres/wasm/mod.rs @@ -1,5 +1,5 @@ -///! Wasm-compatible definitions for the Postgres connector. -/// /// This module is only available with the `postgresql` feature. +//! Wasm-compatible definitions for the Postgres connector. +//! This module is only available with the `postgresql` feature. pub(crate) mod common; pub mod error; diff --git a/quaint/src/connector/sqlite/native/mod.rs b/quaint/src/connector/sqlite/native/mod.rs index e11f6cd021bc..66f0e6d840df 100644 --- a/quaint/src/connector/sqlite/native/mod.rs +++ b/quaint/src/connector/sqlite/native/mod.rs @@ -1,3 +1,6 @@ +//! Definitions for the SQLite connector. +//! This module is not compatible with wasm32-* targets. +//! This module is only available with the `sqlite-connector` feature. mod conversion; mod error; diff --git a/quaint/src/connector/sqlite/wasm/mod.rs b/quaint/src/connector/sqlite/wasm/mod.rs index 0dbbcd76daec..45307cccd0a3 100644 --- a/quaint/src/connector/sqlite/wasm/mod.rs +++ b/quaint/src/connector/sqlite/wasm/mod.rs @@ -1,4 +1,4 @@ -///! Wasm-compatible definitions for the SQLite connector. -/// /// This module is only available with the `sqlite` feature. +//! Wasm-compatible definitions for the SQLite connector. +//! This module is only available with the `sqlite` feature. pub(crate) mod common; pub mod error; diff --git a/quaint/src/single.rs b/quaint/src/single.rs index 2f234e40fd74..12bcf65c460a 100644 --- a/quaint/src/single.rs +++ b/quaint/src/single.rs @@ -1,7 +1,5 @@ //! A single connection abstraction to a SQL database. -#[cfg(feature = "sqlite")] -use crate::connector::DEFAULT_SQLITE_SCHEMA_NAME; use crate::{ ast, connector::{self, impl_default_TransactionCapable, ConnectionInfo, IsolationLevel, Queryable, TransactionCapable}, @@ -9,7 +7,6 @@ use crate::{ use async_trait::async_trait; use std::{fmt, sync::Arc}; -#[cfg(feature = "sqlite")] use std::convert::TryFrom; /// The main entry point and an abstraction over a database connection. @@ -169,6 +166,8 @@ impl Quaint { #[cfg(feature = "sqlite-connector")] /// Open a new SQLite database in memory. 
pub fn new_in_memory() -> crate::Result { + use crate::connector::DEFAULT_SQLITE_SCHEMA_NAME; + Ok(Quaint { inner: Arc::new(connector::Sqlite::new_in_memory()?), connection_info: Arc::new(ConnectionInfo::InMemorySqlite { From 6a1f733241372c0459797a215c11443d0e130bcf Mon Sep 17 00:00:00 2001 From: jkomyno Date: Mon, 13 Nov 2023 14:35:42 +0100 Subject: [PATCH 008/134] chore(quaint): fix clippy when compiling to wasm32-unknown-unknown --- quaint/src/connector/mssql/wasm/common.rs | 2 ++ quaint/src/connector/mysql/wasm/common.rs | 2 ++ quaint/src/connector/postgres/wasm/common.rs | 2 ++ quaint/src/connector/sqlite/wasm/common.rs | 2 ++ quaint/src/error.rs | 2 +- quaint/src/single.rs | 2 ++ 6 files changed, 11 insertions(+), 1 deletion(-) diff --git a/quaint/src/connector/mssql/wasm/common.rs b/quaint/src/connector/mssql/wasm/common.rs index 5b6ee881d3e9..42cc0868f9bf 100644 --- a/quaint/src/connector/mssql/wasm/common.rs +++ b/quaint/src/connector/mssql/wasm/common.rs @@ -1,3 +1,5 @@ +#![cfg_attr(target_arch = "wasm32", allow(dead_code))] + use crate::{ connector::IsolationLevel, error::{Error, ErrorKind}, diff --git a/quaint/src/connector/mysql/wasm/common.rs b/quaint/src/connector/mysql/wasm/common.rs index fe60fd24cfc1..58598d6509ac 100644 --- a/quaint/src/connector/mysql/wasm/common.rs +++ b/quaint/src/connector/mysql/wasm/common.rs @@ -1,3 +1,5 @@ +#![cfg_attr(target_arch = "wasm32", allow(dead_code))] + use crate::error::{Error, ErrorKind}; use percent_encoding::percent_decode; use std::{ diff --git a/quaint/src/connector/postgres/wasm/common.rs b/quaint/src/connector/postgres/wasm/common.rs index 88145beb40de..c90826c40548 100644 --- a/quaint/src/connector/postgres/wasm/common.rs +++ b/quaint/src/connector/postgres/wasm/common.rs @@ -1,3 +1,5 @@ +#![cfg_attr(target_arch = "wasm32", allow(dead_code))] + use std::{ borrow::Cow, fmt::{Debug, Display}, diff --git a/quaint/src/connector/sqlite/wasm/common.rs b/quaint/src/connector/sqlite/wasm/common.rs index 10c174480785..46fb5c08f669 100644 --- a/quaint/src/connector/sqlite/wasm/common.rs +++ b/quaint/src/connector/sqlite/wasm/common.rs @@ -1,3 +1,5 @@ +#![cfg_attr(target_arch = "wasm32", allow(dead_code))] + use crate::error::{Error, ErrorKind}; use std::{convert::TryFrom, path::Path, time::Duration}; diff --git a/quaint/src/error.rs b/quaint/src/error.rs index 705bb6b37ee0..f6ae3b3ee34a 100644 --- a/quaint/src/error.rs +++ b/quaint/src/error.rs @@ -282,7 +282,7 @@ pub enum ErrorKind { } impl ErrorKind { - #[cfg(feature = "mysql")] + #[cfg(feature = "mysql-connector")] pub(crate) fn value_out_of_range(msg: impl Into) -> Self { Self::ValueOutOfRange { message: msg.into() } } diff --git a/quaint/src/single.rs b/quaint/src/single.rs index 12bcf65c460a..e4e72ab614fa 100644 --- a/quaint/src/single.rs +++ b/quaint/src/single.rs @@ -7,6 +7,7 @@ use crate::{ use async_trait::async_trait; use std::{fmt, sync::Arc}; +#[cfg(feature = "sqlite-connector")] use std::convert::TryFrom; /// The main entry point and an abstraction over a database connection. @@ -124,6 +125,7 @@ impl Quaint { /// - `isolationLevel` the transaction isolation level. Possible values: /// `READ UNCOMMITTED`, `READ COMMITTED`, `REPEATABLE READ`, `SNAPSHOT`, /// `SERIALIZABLE`. 
+ #[cfg_attr(target_arch = "wasm32", allow(unused_variables))] #[allow(unreachable_code)] pub async fn new(url_str: &str) -> crate::Result { let inner = match url_str { From e61bf75be0c36fd603e37441965ab8935c99c487 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Mon, 13 Nov 2023 16:01:12 +0100 Subject: [PATCH 009/134] chore(quaint): update README --- quaint/README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/quaint/README.md b/quaint/README.md index 92033db269b1..3a9b41c65751 100644 --- a/quaint/README.md +++ b/quaint/README.md @@ -16,9 +16,13 @@ Quaint is an abstraction over certain SQL databases. It provides: ### Feature flags - `mysql`: Support for MySQL databases. + - On non-WebAssembly targets, choose `mysql-connector` instead. - `postgresql`: Support for PostgreSQL databases. + - On non-WebAssembly targets, choose `postgresql-connector` instead. - `sqlite`: Support for SQLite databases. + - On non-WebAssembly targets, choose `sqlite-connector` instead. - `mssql`: Support for Microsoft SQL Server databases. + - On non-WebAssembly targets, choose `mssql-connector` instead. - `pooled`: A connection pool in `pooled::Quaint`. - `vendored-openssl`: Statically links against a vendored OpenSSL library on non-Windows or non-Apple platforms. From 257c4c86e10bae7e61a0a32d5b5069c3f84f407f Mon Sep 17 00:00:00 2001 From: jkomyno Date: Tue, 14 Nov 2023 09:45:23 +0100 Subject: [PATCH 010/134] chore(quaint): rename "*-connector" feature flag to "*-native" --- Cargo.toml | 2 +- quaint/Cargo.toml | 20 ++++++------- quaint/README.md | 8 +++--- quaint/src/connector.rs | 20 ++++--------- quaint/src/connector/mssql.rs | 2 +- quaint/src/connector/mssql/native/mod.rs | 2 +- quaint/src/connector/mysql.rs | 2 +- quaint/src/connector/mysql/native/mod.rs | 2 +- quaint/src/connector/mysql/wasm/common.rs | 12 ++++---- quaint/src/connector/postgres.rs | 2 +- quaint/src/connector/postgres/native/mod.rs | 2 +- quaint/src/connector/postgres/wasm/common.rs | 18 ++++++------ quaint/src/connector/sqlite.rs | 2 +- quaint/src/connector/sqlite/native/mod.rs | 2 +- quaint/src/error.rs | 2 +- quaint/src/pooled/manager.rs | 30 ++++++++++---------- quaint/src/single.rs | 12 ++++---- 17 files changed, 66 insertions(+), 74 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 66f4399ff6db..b32a1a85cf18 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -68,7 +68,7 @@ features = [ "pooled", "postgresql", "sqlite", - "connectors", + "native", ] [profile.dev.package.backtrace] diff --git a/quaint/Cargo.toml b/quaint/Cargo.toml index abe9fece9746..7c804add2f5e 100644 --- a/quaint/Cargo.toml +++ b/quaint/Cargo.toml @@ -29,21 +29,21 @@ docs = [] # way to access database-specific methods when you need extra control. 
expose-drivers = [] -connectors = [ - "postgresql-connector", - "mysql-connector", - "mssql-connector", - "sqlite-connector", +native = [ + "postgresql-native", + "mysql-native", + "mssql-native", + "sqlite-native", ] -all = ["connectors", "pooled"] +all = ["native", "pooled"] vendored-openssl = [ "postgres-native-tls/vendored-openssl", "mysql_async/vendored-openssl", ] -postgresql-connector = [ +postgresql-native = [ "postgresql", "native-tls", "tokio-postgres", @@ -57,7 +57,7 @@ postgresql-connector = [ ] postgresql = [] -mssql-connector = [ +mssql-native = [ "mssql", "tiberius", "tokio-util", @@ -66,11 +66,11 @@ mssql-connector = [ ] mssql = [] -mysql-connector = ["mysql", "mysql_async", "tokio/time", "lru-cache"] +mysql-native = ["mysql", "mysql_async", "tokio/time", "lru-cache"] mysql = ["chrono/std"] pooled = ["mobc"] -sqlite-connector = ["sqlite", "rusqlite/bundled", "tokio/sync"] +sqlite-native = ["sqlite", "rusqlite/bundled", "tokio/sync"] sqlite = ["rusqlite"] fmt-sql = ["sqlformat"] diff --git a/quaint/README.md b/quaint/README.md index 3a9b41c65751..03108d9090d3 100644 --- a/quaint/README.md +++ b/quaint/README.md @@ -16,13 +16,13 @@ Quaint is an abstraction over certain SQL databases. It provides: ### Feature flags - `mysql`: Support for MySQL databases. - - On non-WebAssembly targets, choose `mysql-connector` instead. + - On non-WebAssembly targets, choose `mysql-native` instead. - `postgresql`: Support for PostgreSQL databases. - - On non-WebAssembly targets, choose `postgresql-connector` instead. + - On non-WebAssembly targets, choose `postgresql-native` instead. - `sqlite`: Support for SQLite databases. - - On non-WebAssembly targets, choose `sqlite-connector` instead. + - On non-WebAssembly targets, choose `sqlite-native` instead. - `mssql`: Support for Microsoft SQL Server databases. - - On non-WebAssembly targets, choose `mssql-connector` instead. + - On non-WebAssembly targets, choose `mssql-native` instead. - `pooled`: A connection pool in `pooled::Quaint`. - `vendored-openssl`: Statically links against a vendored OpenSSL library on non-Windows or non-Apple platforms. 
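A note on the feature rename above, before the next file diff: the matrix reads as a driver-free base flag per database (`sqlite`, `postgresql`, `mysql`, `mssql`) plus a driver-backed superset (`*-native`) that pulls in rusqlite, tokio-postgres, mysql_async or tiberius. As an illustration only — the dependency path and the particular feature picks below are assumptions for the sketch, not part of this patch series — a downstream crate could select between the two per compilation target in its own Cargo.toml:

# Hypothetical consumer manifest (illustrative path/features, not from this patch):
# native targets get the full driver stack, wasm32-unknown-unknown keeps only the
# driver-free connection-string parsing and error types behind the base flags.
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
quaint = { path = "../quaint", features = ["postgresql-native", "sqlite-native"] }

[target.'cfg(target_arch = "wasm32")'.dependencies]
quaint = { path = "../quaint", features = ["postgresql", "sqlite"] }

Because each `*-native` feature enables its base flag, code that only needs `PostgresUrl`, `SqliteParams` or the error types compiles identically under either configuration.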
diff --git a/quaint/src/connector.rs b/quaint/src/connector.rs index 0aaa19aa463b..7903d23931c0 100644 --- a/quaint/src/connector.rs +++ b/quaint/src/connector.rs @@ -14,11 +14,7 @@ mod connection_info; pub mod metrics; mod queryable; mod result_set; -#[cfg(any( - feature = "mssql-connector", - feature = "postgresql-connector", - feature = "mysql-connector" -))] +#[cfg(any(feature = "mssql-native", feature = "postgresql-native", feature = "mysql-native"))] mod timeout; mod transaction; mod type_identifier; @@ -27,11 +23,7 @@ pub use self::result_set::*; pub use connection_info::*; pub use queryable::*; pub use transaction::*; -#[cfg(any( - feature = "mssql-connector", - feature = "postgresql-connector", - feature = "mysql-connector" -))] +#[cfg(any(feature = "mssql-native", feature = "postgresql-native", feature = "mysql-native"))] #[allow(unused_imports)] pub(crate) use type_identifier::*; @@ -39,28 +31,28 @@ pub use self::metrics::query; #[cfg(feature = "postgresql")] pub(crate) mod postgres; -#[cfg(feature = "postgresql-connector")] +#[cfg(feature = "postgresql-native")] pub use postgres::native::*; #[cfg(feature = "postgresql")] pub use postgres::wasm::common::*; #[cfg(feature = "mysql")] pub(crate) mod mysql; -#[cfg(feature = "mysql-connector")] +#[cfg(feature = "mysql-native")] pub use mysql::native::*; #[cfg(feature = "mysql")] pub use mysql::wasm::common::*; #[cfg(feature = "sqlite")] pub(crate) mod sqlite; -#[cfg(feature = "sqlite-connector")] +#[cfg(feature = "sqlite-native")] pub use sqlite::native::*; #[cfg(feature = "sqlite")] pub use sqlite::wasm::common::*; #[cfg(feature = "mssql")] pub(crate) mod mssql; -#[cfg(feature = "mssql-connector")] +#[cfg(feature = "mssql-native")] pub use mssql::native::*; #[cfg(feature = "mssql")] pub use mssql::wasm::common::*; diff --git a/quaint/src/connector/mssql.rs b/quaint/src/connector/mssql.rs index ea681bd08d18..c83b5f1f7266 100644 --- a/quaint/src/connector/mssql.rs +++ b/quaint/src/connector/mssql.rs @@ -3,5 +3,5 @@ pub use wasm::common::MssqlUrl; #[cfg(feature = "mssql")] pub(crate) mod wasm; -#[cfg(feature = "mssql-connector")] +#[cfg(feature = "mssql-native")] pub(crate) mod native; diff --git a/quaint/src/connector/mssql/native/mod.rs b/quaint/src/connector/mssql/native/mod.rs index 6a1019c4f594..8458935814b4 100644 --- a/quaint/src/connector/mssql/native/mod.rs +++ b/quaint/src/connector/mssql/native/mod.rs @@ -1,6 +1,6 @@ //! Definitions for the MSSQL connector. //! This module is not compatible with wasm32-* targets. -//! This module is only available with the `mssql-connector` feature. +//! This module is only available with the `mssql-native` feature. mod conversion; mod error; diff --git a/quaint/src/connector/mysql.rs b/quaint/src/connector/mysql.rs index 1794cc738b1e..1e52af6a83a0 100644 --- a/quaint/src/connector/mysql.rs +++ b/quaint/src/connector/mysql.rs @@ -4,5 +4,5 @@ pub use wasm::error::MysqlError; #[cfg(feature = "mysql")] pub(crate) mod wasm; -#[cfg(feature = "mysql-connector")] +#[cfg(feature = "mysql-native")] pub(crate) mod native; diff --git a/quaint/src/connector/mysql/native/mod.rs b/quaint/src/connector/mysql/native/mod.rs index 234f7fb3d74f..e72a2c47a9a1 100644 --- a/quaint/src/connector/mysql/native/mod.rs +++ b/quaint/src/connector/mysql/native/mod.rs @@ -1,6 +1,6 @@ //! Definitions for the MySQL connector. //! This module is not compatible with wasm32-* targets. -//! This module is only available with the `mysql-connector` feature. +//! This module is only available with the `mysql-native` feature. 
mod conversion; mod error; diff --git a/quaint/src/connector/mysql/wasm/common.rs b/quaint/src/connector/mysql/wasm/common.rs index 58598d6509ac..c17b2224c0ef 100644 --- a/quaint/src/connector/mysql/wasm/common.rs +++ b/quaint/src/connector/mysql/wasm/common.rs @@ -123,7 +123,7 @@ impl MysqlUrl { } fn parse_query_params(url: &Url) -> Result { - #[cfg(feature = "mysql-connector")] + #[cfg(feature = "mysql-native")] let mut ssl_opts = { let mut ssl_opts = mysql_async::SslOpts::default(); ssl_opts = ssl_opts.with_danger_accept_invalid_certs(true); @@ -159,7 +159,7 @@ impl MysqlUrl { "sslcert" => { use_ssl = true; - #[cfg(feature = "mysql-connector")] + #[cfg(feature = "mysql-native")] { ssl_opts = ssl_opts.with_root_cert_path(Some(Path::new(&*v).to_path_buf())); } @@ -219,7 +219,7 @@ impl MysqlUrl { use_ssl = true; match v.as_ref() { "strict" => { - #[cfg(feature = "mysql-connector")] + #[cfg(feature = "mysql-native")] { ssl_opts = ssl_opts.with_danger_accept_invalid_certs(false); } @@ -263,7 +263,7 @@ impl MysqlUrl { // Wrapping this in a block, as attributes on expressions are still experimental // See: https://github.com/rust-lang/rust/issues/15701 - #[cfg(feature = "mysql-connector")] + #[cfg(feature = "mysql-native")] { ssl_opts = match identity { Some((Some(path), Some(pw))) => { @@ -279,7 +279,7 @@ impl MysqlUrl { } Ok(MysqlUrlQueryParams { - #[cfg(feature = "mysql-connector")] + #[cfg(feature = "mysql-native")] ssl_opts, connection_limit, use_ssl, @@ -313,6 +313,6 @@ pub(crate) struct MysqlUrlQueryParams { pub(crate) prefer_socket: Option, pub(crate) statement_cache_size: usize, - #[cfg(feature = "mysql-connector")] + #[cfg(feature = "mysql-native")] pub(crate) ssl_opts: mysql_async::SslOpts, } diff --git a/quaint/src/connector/postgres.rs b/quaint/src/connector/postgres.rs index 0f4da84a7c67..73a8547b8a65 100644 --- a/quaint/src/connector/postgres.rs +++ b/quaint/src/connector/postgres.rs @@ -4,5 +4,5 @@ pub use wasm::error::PostgresError; #[cfg(feature = "postgresql")] pub(crate) mod wasm; -#[cfg(feature = "postgresql-connector")] +#[cfg(feature = "postgresql-native")] pub(crate) mod native; diff --git a/quaint/src/connector/postgres/native/mod.rs b/quaint/src/connector/postgres/native/mod.rs index a6628086aaae..fbb4760ed19f 100644 --- a/quaint/src/connector/postgres/native/mod.rs +++ b/quaint/src/connector/postgres/native/mod.rs @@ -1,6 +1,6 @@ //! Definitions for the Postgres connector. //! This module is not compatible with wasm32-* targets. -//! This module is only available with the `postgresql-connector` feature. +//! This module is only available with the `postgresql-native` feature. 
mod conversion; mod error; diff --git a/quaint/src/connector/postgres/wasm/common.rs b/quaint/src/connector/postgres/wasm/common.rs index c90826c40548..7b9b3aafabb4 100644 --- a/quaint/src/connector/postgres/wasm/common.rs +++ b/quaint/src/connector/postgres/wasm/common.rs @@ -11,7 +11,7 @@ use url::{Host, Url}; use crate::error::{Error, ErrorKind}; -#[cfg(feature = "postgresql-connector")] +#[cfg(feature = "postgresql-native")] use tokio_postgres::config::{ChannelBinding, SslMode}; #[derive(Clone)] @@ -211,9 +211,9 @@ impl PostgresUrl { } fn parse_query_params(url: &Url) -> Result { - #[cfg(feature = "postgresql-connector")] + #[cfg(feature = "postgresql-native")] let mut ssl_mode = SslMode::Prefer; - #[cfg(feature = "postgresql-connector")] + #[cfg(feature = "postgresql-native")] let mut channel_binding = ChannelBinding::Prefer; let mut connection_limit = None; @@ -240,7 +240,7 @@ impl PostgresUrl { .parse() .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; } - #[cfg(feature = "postgresql-connector")] + #[cfg(feature = "postgresql-native")] "sslmode" => { match v.as_ref() { "disable" => ssl_mode = SslMode::Disable, @@ -348,7 +348,7 @@ impl PostgresUrl { "application_name" => { application_name = Some(v.to_string()); } - #[cfg(feature = "postgresql-connector")] + #[cfg(feature = "postgresql-native")] "channel_binding" => { match v.as_ref() { "disable" => channel_binding = ChannelBinding::Disable, @@ -390,9 +390,9 @@ impl PostgresUrl { max_idle_connection_lifetime, application_name, options, - #[cfg(feature = "postgresql-connector")] + #[cfg(feature = "postgresql-native")] channel_binding, - #[cfg(feature = "postgresql-connector")] + #[cfg(feature = "postgresql-native")] ssl_mode, }) } @@ -427,10 +427,10 @@ pub(crate) struct PostgresUrlQueryParams { pub(crate) application_name: Option, pub(crate) options: Option, - #[cfg(feature = "postgresql-connector")] + #[cfg(feature = "postgresql-native")] pub(crate) channel_binding: ChannelBinding, - #[cfg(feature = "postgresql-connector")] + #[cfg(feature = "postgresql-native")] pub(crate) ssl_mode: SslMode, } diff --git a/quaint/src/connector/sqlite.rs b/quaint/src/connector/sqlite.rs index 0e699c211878..45611aab9357 100644 --- a/quaint/src/connector/sqlite.rs +++ b/quaint/src/connector/sqlite.rs @@ -3,5 +3,5 @@ pub use wasm::error::SqliteError; #[cfg(feature = "sqlite")] pub(crate) mod wasm; -#[cfg(feature = "sqlite-connector")] +#[cfg(feature = "sqlite-native")] pub(crate) mod native; diff --git a/quaint/src/connector/sqlite/native/mod.rs b/quaint/src/connector/sqlite/native/mod.rs index 66f0e6d840df..bdf5c473fd4d 100644 --- a/quaint/src/connector/sqlite/native/mod.rs +++ b/quaint/src/connector/sqlite/native/mod.rs @@ -1,6 +1,6 @@ //! Definitions for the SQLite connector. //! This module is not compatible with wasm32-* targets. -//! This module is only available with the `sqlite-connector` feature. +//! This module is only available with the `sqlite-native` feature. 
mod conversion; mod error; diff --git a/quaint/src/error.rs b/quaint/src/error.rs index f6ae3b3ee34a..a77513876726 100644 --- a/quaint/src/error.rs +++ b/quaint/src/error.rs @@ -282,7 +282,7 @@ pub enum ErrorKind { } impl ErrorKind { - #[cfg(feature = "mysql-connector")] + #[cfg(feature = "mysql-native")] pub(crate) fn value_out_of_range(msg: impl Into) -> Self { Self::ValueOutOfRange { message: msg.into() } } diff --git a/quaint/src/pooled/manager.rs b/quaint/src/pooled/manager.rs index c31fd44fbcae..73441b7609ba 100644 --- a/quaint/src/pooled/manager.rs +++ b/quaint/src/pooled/manager.rs @@ -1,8 +1,8 @@ -#[cfg(feature = "mssql-connector")] +#[cfg(feature = "mssql-native")] use crate::connector::MssqlUrl; -#[cfg(feature = "mysql-connector")] +#[cfg(feature = "mysql-native")] use crate::connector::MysqlUrl; -#[cfg(feature = "postgresql-connector")] +#[cfg(feature = "postgresql-native")] use crate::connector::PostgresUrl; use crate::{ ast, @@ -97,7 +97,7 @@ impl Manager for QuaintManager { async fn connect(&self) -> crate::Result { let conn = match self { - #[cfg(feature = "sqlite-connector")] + #[cfg(feature = "sqlite-native")] QuaintManager::Sqlite { url, .. } => { use crate::connector::Sqlite; @@ -106,19 +106,19 @@ impl Manager for QuaintManager { Ok(Box::new(conn) as Self::Connection) } - #[cfg(feature = "mysql-connector")] + #[cfg(feature = "mysql-native")] QuaintManager::Mysql { url } => { use crate::connector::Mysql; Ok(Box::new(Mysql::new(url.clone()).await?) as Self::Connection) } - #[cfg(feature = "postgresql-connector")] + #[cfg(feature = "postgresql-native")] QuaintManager::Postgres { url } => { use crate::connector::PostgreSql; Ok(Box::new(PostgreSql::new(url.clone()).await?) as Self::Connection) } - #[cfg(feature = "mssql-connector")] + #[cfg(feature = "mssql-native")] QuaintManager::Mssql { url } => { use crate::connector::Mssql; Ok(Box::new(Mssql::new(url.clone()).await?) 
as Self::Connection) @@ -146,7 +146,7 @@ mod tests { use crate::pooled::Quaint; #[tokio::test] - #[cfg(feature = "mysql-connector")] + #[cfg(feature = "mysql-native")] async fn mysql_default_connection_limit() { let conn_string = std::env::var("TEST_MYSQL").expect("TEST_MYSQL connection string not set."); @@ -156,7 +156,7 @@ mod tests { } #[tokio::test] - #[cfg(feature = "mysql-connector")] + #[cfg(feature = "mysql-native")] async fn mysql_custom_connection_limit() { let conn_string = format!( "{}?connection_limit=10", @@ -169,7 +169,7 @@ mod tests { } #[tokio::test] - #[cfg(feature = "postgresql-connector")] + #[cfg(feature = "postgresql-native")] async fn psql_default_connection_limit() { let conn_string = std::env::var("TEST_PSQL").expect("TEST_PSQL connection string not set."); @@ -179,7 +179,7 @@ mod tests { } #[tokio::test] - #[cfg(feature = "postgresql-connector")] + #[cfg(feature = "postgresql-native")] async fn psql_custom_connection_limit() { let conn_string = format!( "{}?connection_limit=10", @@ -192,7 +192,7 @@ mod tests { } #[tokio::test] - #[cfg(feature = "mssql-connector")] + #[cfg(feature = "mssql-native")] async fn mssql_default_connection_limit() { let conn_string = std::env::var("TEST_MSSQL").expect("TEST_MSSQL connection string not set."); @@ -202,7 +202,7 @@ mod tests { } #[tokio::test] - #[cfg(feature = "mssql-connector")] + #[cfg(feature = "mssql-native")] async fn mssql_custom_connection_limit() { let conn_string = format!( "{};connectionLimit=10", @@ -215,7 +215,7 @@ mod tests { } #[tokio::test] - #[cfg(feature = "sqlite-connector")] + #[cfg(feature = "sqlite-native")] async fn test_default_connection_limit() { let conn_string = "file:db/test.db".to_string(); let pool = Quaint::builder(&conn_string).unwrap().build(); @@ -224,7 +224,7 @@ mod tests { } #[tokio::test] - #[cfg(feature = "sqlite-connector")] + #[cfg(feature = "sqlite-native")] async fn test_custom_connection_limit() { let conn_string = "file:db/test.db?connection_limit=10".to_string(); let pool = Quaint::builder(&conn_string).unwrap().build(); diff --git a/quaint/src/single.rs b/quaint/src/single.rs index e4e72ab614fa..1a4dbdf52a61 100644 --- a/quaint/src/single.rs +++ b/quaint/src/single.rs @@ -7,7 +7,7 @@ use crate::{ use async_trait::async_trait; use std::{fmt, sync::Arc}; -#[cfg(feature = "sqlite-connector")] +#[cfg(feature = "sqlite-native")] use std::convert::TryFrom; /// The main entry point and an abstraction over a database connection. 
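Because every arm of these matches (in `QuaintManager::connect` above and `Quaint::new` below) is feature-gated, a downstream crate can build quaint for wasm32-unknown-unknown with only the base connector features and still parse and validate connection URLs; opening a real connection requires the corresponding `*-native` feature. A minimal sketch of that split — the helper names `validate` and `connect` are illustrative, not part of the patch:

    // Compiles with just the `postgresql` feature, including on wasm32:
    // URL validation only, no driver involved.
    #[cfg(feature = "postgresql")]
    fn validate(url_str: &str) -> bool {
        url::Url::parse(url_str)
            .ok()
            .and_then(|u| quaint::connector::PostgresUrl::new(u).ok())
            .is_some()
    }

    // Requires `postgresql-native`: establishes a connection via tokio-postgres.
    #[cfg(feature = "postgresql-native")]
    async fn connect(url_str: &str) -> quaint::Result<quaint::single::Quaint> {
        quaint::single::Quaint::new(url_str).await
    }
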
@@ -129,27 +129,27 @@ impl Quaint { #[allow(unreachable_code)] pub async fn new(url_str: &str) -> crate::Result { let inner = match url_str { - #[cfg(feature = "sqlite-connector")] + #[cfg(feature = "sqlite-native")] s if s.starts_with("file") => { let params = connector::SqliteParams::try_from(s)?; let sqlite = connector::Sqlite::new(¶ms.file_path)?; Arc::new(sqlite) as Arc } - #[cfg(feature = "mysql-connector")] + #[cfg(feature = "mysql-native")] s if s.starts_with("mysql") => { let url = connector::MysqlUrl::new(url::Url::parse(s)?)?; let mysql = connector::Mysql::new(url).await?; Arc::new(mysql) as Arc } - #[cfg(feature = "postgresql-connector")] + #[cfg(feature = "postgresql-native")] s if s.starts_with("postgres") || s.starts_with("postgresql") => { let url = connector::PostgresUrl::new(url::Url::parse(s)?)?; let psql = connector::PostgreSql::new(url).await?; Arc::new(psql) as Arc } - #[cfg(feature = "mssql-connector")] + #[cfg(feature = "mssql-native")] s if s.starts_with("jdbc:sqlserver") | s.starts_with("sqlserver") => { let url = connector::MssqlUrl::new(s)?; let psql = connector::Mssql::new(url).await?; @@ -165,7 +165,7 @@ impl Quaint { Ok(Self { inner, connection_info }) } - #[cfg(feature = "sqlite-connector")] + #[cfg(feature = "sqlite-native")] /// Open a new SQLite database in memory. pub fn new_in_memory() -> crate::Result { use crate::connector::DEFAULT_SQLITE_SCHEMA_NAME; From 5ab6d9636220469772b7969d8c6db84701e6a196 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Tue, 14 Nov 2023 12:37:51 +0100 Subject: [PATCH 011/134] feat(quaint): enable pure Wasm SqliteError --- quaint/Cargo.toml | 2 +- quaint/src/connector/sqlite/native/error.rs | 17 +++++++++++++ quaint/src/connector/sqlite/wasm/error.rs | 28 ++++++--------------- quaint/src/connector/sqlite/wasm/ffi.rs | 7 ++++++ quaint/src/connector/sqlite/wasm/mod.rs | 1 + 5 files changed, 34 insertions(+), 21 deletions(-) create mode 100644 quaint/src/connector/sqlite/wasm/ffi.rs diff --git a/quaint/Cargo.toml b/quaint/Cargo.toml index 7c804add2f5e..52a7edf72aca 100644 --- a/quaint/Cargo.toml +++ b/quaint/Cargo.toml @@ -71,7 +71,7 @@ mysql = ["chrono/std"] pooled = ["mobc"] sqlite-native = ["sqlite", "rusqlite/bundled", "tokio/sync"] -sqlite = ["rusqlite"] +sqlite = [] fmt-sql = ["sqlformat"] diff --git a/quaint/src/connector/sqlite/native/error.rs b/quaint/src/connector/sqlite/native/error.rs index 9e2b2e7c3ea1..d09e2959ce28 100644 --- a/quaint/src/connector/sqlite/native/error.rs +++ b/quaint/src/connector/sqlite/native/error.rs @@ -2,6 +2,17 @@ use crate::connector::sqlite::wasm::error::SqliteError; use crate::error::*; +impl std::fmt::Display for SqliteError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "Error code {}: {}", + self.extended_code, + rusqlite::ffi::code_to_str(self.extended_code) + ) + } +} + impl From for Error { fn from(e: rusqlite::Error) -> Error { match e { @@ -47,3 +58,9 @@ impl From for Error { } } } + +impl From for Error { + fn from(e: rusqlite::types::FromSqlError) -> Error { + Error::builder(ErrorKind::ColumnReadFailure(e.into())).build() + } +} diff --git a/quaint/src/connector/sqlite/wasm/error.rs b/quaint/src/connector/sqlite/wasm/error.rs index 9cd0ef64e8a4..2c6ff11350fd 100644 --- a/quaint/src/connector/sqlite/wasm/error.rs +++ b/quaint/src/connector/sqlite/wasm/error.rs @@ -1,5 +1,3 @@ -use std::fmt; - use crate::error::*; #[derive(Debug)] @@ -8,14 +6,10 @@ pub struct SqliteError { pub message: Option, } -impl fmt::Display for SqliteError { - fn 
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "Error code {}: {}", - self.extended_code, - rusqlite::ffi::code_to_str(self.extended_code) - ) +#[cfg(not(feature = "sqlite-native"))] +impl std::fmt::Display for SqliteError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Error code {}", self.extended_code) } } @@ -35,7 +29,7 @@ impl From for Error { fn from(error: SqliteError) -> Self { match error { SqliteError { - extended_code: rusqlite::ffi::SQLITE_CONSTRAINT_UNIQUE | rusqlite::ffi::SQLITE_CONSTRAINT_PRIMARYKEY, + extended_code: super::ffi::SQLITE_CONSTRAINT_UNIQUE | super::ffi::SQLITE_CONSTRAINT_PRIMARYKEY, message: Some(description), } => { let constraint = description @@ -56,7 +50,7 @@ impl From for Error { } SqliteError { - extended_code: rusqlite::ffi::SQLITE_CONSTRAINT_NOTNULL, + extended_code: super::ffi::SQLITE_CONSTRAINT_NOTNULL, message: Some(description), } => { let constraint = description @@ -77,7 +71,7 @@ impl From for Error { } SqliteError { - extended_code: rusqlite::ffi::SQLITE_CONSTRAINT_FOREIGNKEY | rusqlite::ffi::SQLITE_CONSTRAINT_TRIGGER, + extended_code: super::ffi::SQLITE_CONSTRAINT_FOREIGNKEY | super::ffi::SQLITE_CONSTRAINT_TRIGGER, message: Some(description), } => { let mut builder = Error::builder(ErrorKind::ForeignKeyConstraintViolation { @@ -90,7 +84,7 @@ impl From for Error { builder.build() } - SqliteError { extended_code, message } if error.primary_code() == rusqlite::ffi::SQLITE_BUSY => { + SqliteError { extended_code, message } if error.primary_code() == super::ffi::SQLITE_BUSY => { let mut builder = Error::builder(ErrorKind::SocketTimeout); builder.set_original_code(format!("{extended_code}")); @@ -150,9 +144,3 @@ impl From for Error { } } } - -impl From for Error { - fn from(e: rusqlite::types::FromSqlError) -> Error { - Error::builder(ErrorKind::ColumnReadFailure(e.into())).build() - } -} diff --git a/quaint/src/connector/sqlite/wasm/ffi.rs b/quaint/src/connector/sqlite/wasm/ffi.rs new file mode 100644 index 000000000000..bddfd4354237 --- /dev/null +++ b/quaint/src/connector/sqlite/wasm/ffi.rs @@ -0,0 +1,7 @@ +//! This is a partial copy of `rusqlite::ffi::*`. +pub const SQLITE_BUSY: i32 = 5; +pub const SQLITE_CONSTRAINT_FOREIGNKEY: i32 = 787; +pub const SQLITE_CONSTRAINT_NOTNULL: i32 = 1299; +pub const SQLITE_CONSTRAINT_PRIMARYKEY: i32 = 1555; +pub const SQLITE_CONSTRAINT_TRIGGER: i32 = 1811; +pub const SQLITE_CONSTRAINT_UNIQUE: i32 = 2067; diff --git a/quaint/src/connector/sqlite/wasm/mod.rs b/quaint/src/connector/sqlite/wasm/mod.rs index 45307cccd0a3..662237af30a1 100644 --- a/quaint/src/connector/sqlite/wasm/mod.rs +++ b/quaint/src/connector/sqlite/wasm/mod.rs @@ -2,3 +2,4 @@ //! This module is only available with the `sqlite` feature. 
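The constants vendored in `ffi.rs` are SQLite extended result codes: the low byte is the primary code and the high bits identify the specific violation, e.g. `2067 & 0xFF == 19` (`SQLITE_CONSTRAINT`), while plain `SQLITE_BUSY` is `5`. The `primary_code()` helper used in the match above is therefore expected to be the usual mask; a sketch assuming that implementation (the actual body is not shown in this hunk):

    impl SqliteError {
        /// Extended result codes embed the primary code in the least-significant byte.
        pub fn primary_code(&self) -> i32 {
            self.extended_code & 0xFF
        }
    }
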
pub(crate) mod common; pub mod error; +mod ffi; From ab65c9539cb5bfef59fd2c3f2187ec83d415e3fd Mon Sep 17 00:00:00 2001 From: jkomyno Date: Tue, 14 Nov 2023 12:38:24 +0100 Subject: [PATCH 012/134] feat(query-connect): allow wasm32-unknown-unknown compilation --- libs/user-facing-errors/Cargo.toml | 2 +- query-engine/connectors/query-connector/Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/libs/user-facing-errors/Cargo.toml b/libs/user-facing-errors/Cargo.toml index 9900892209c6..3049a19712b1 100644 --- a/libs/user-facing-errors/Cargo.toml +++ b/libs/user-facing-errors/Cargo.toml @@ -11,7 +11,7 @@ backtrace = "0.3.40" tracing = "0.1" indoc.workspace = true itertools = "0.10" -quaint = { workspace = true, optional = true } +quaint = { path = "../../quaint", optional = true } [features] default = [] diff --git a/query-engine/connectors/query-connector/Cargo.toml b/query-engine/connectors/query-connector/Cargo.toml index d16771aa3daf..788b8ca65576 100644 --- a/query-engine/connectors/query-connector/Cargo.toml +++ b/query-engine/connectors/query-connector/Cargo.toml @@ -14,6 +14,6 @@ prisma-value = {path = "../../../libs/prisma-value"} serde.workspace = true serde_json.workspace = true thiserror = "1.0" -user-facing-errors = {path = "../../../libs/user-facing-errors"} +user-facing-errors = {path = "../../../libs/user-facing-errors", features = ["sql"]} uuid = "1" indexmap = "1.7" From cfb550743b7c39dadf87672141bf2bc16c9318d4 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Tue, 14 Nov 2023 14:54:54 +0100 Subject: [PATCH 013/134] feat(sql-query-connector): allow wasm32-unknown-unknown compilation --- .../connectors/sql-query-connector/Cargo.toml | 6 +++-- .../sql-query-connector/src/database/mod.rs | 24 ++++++++++++------- .../src/database/{ => native}/mssql.rs | 4 ++-- .../src/database/{ => native}/mysql.rs | 4 ++-- .../src/database/{ => native}/postgresql.rs | 4 ++-- .../src/database/{ => native}/sqlite.rs | 4 ++-- .../src/database/operations/write.rs | 21 +++++++++++++++- .../connectors/sql-query-connector/src/lib.rs | 5 +++- 8 files changed, 52 insertions(+), 20 deletions(-) rename query-engine/connectors/sql-query-connector/src/database/{ => native}/mssql.rs (94%) rename query-engine/connectors/sql-query-connector/src/database/{ => native}/mysql.rs (95%) rename query-engine/connectors/sql-query-connector/src/database/{ => native}/postgresql.rs (95%) rename query-engine/connectors/sql-query-connector/src/database/{ => native}/sqlite.rs (96%) diff --git a/query-engine/connectors/sql-query-connector/Cargo.toml b/query-engine/connectors/sql-query-connector/Cargo.toml index 62d0be640761..fa9c32ef88e1 100644 --- a/query-engine/connectors/sql-query-connector/Cargo.toml +++ b/query-engine/connectors/sql-query-connector/Cargo.toml @@ -5,6 +5,8 @@ version = "0.1.0" [features] vendored-openssl = ["quaint/vendored-openssl"] + +# Enable Driver Adapters driver-adapters = [] [dependencies] @@ -18,13 +20,13 @@ once_cell = "1.3" rand = "0.7" serde_json = {version = "1.0", features = ["float_roundtrip"]} thiserror = "1.0" -tokio.workspace = true +tokio = { version = "1.0", features = ["macros", "time"] } tracing = "0.1" tracing-futures = "0.2" uuid.workspace = true opentelemetry = { version = "0.17", features = ["tokio"] } tracing-opentelemetry = "0.17.3" -quaint.workspace = true +quaint = { path = "../../../quaint" } cuid = { git = "https://github.com/prisma/cuid-rust", branch = "wasm32-support" } [dependencies.connector-interface] diff --git 
a/query-engine/connectors/sql-query-connector/src/database/mod.rs b/query-engine/connectors/sql-query-connector/src/database/mod.rs index 695db13b6620..7172e0101400 100644 --- a/query-engine/connectors/sql-query-connector/src/database/mod.rs +++ b/query-engine/connectors/sql-query-connector/src/database/mod.rs @@ -1,12 +1,16 @@ mod connection; #[cfg(feature = "driver-adapters")] mod js; -mod mssql; -mod mysql; -mod postgresql; -mod sqlite; mod transaction; +#[cfg(not(target_arch = "wasm32"))] +pub(crate) mod native { + pub(crate) mod mssql; + pub(crate) mod mysql; + pub(crate) mod postgresql; + pub(crate) mod sqlite; +} + pub(crate) mod operations; use async_trait::async_trait; @@ -14,10 +18,14 @@ use connector_interface::{error::ConnectorError, Connector}; #[cfg(feature = "driver-adapters")] pub use js::*; -pub use mssql::*; -pub use mysql::*; -pub use postgresql::*; -pub use sqlite::*; + +#[cfg(not(target_arch = "wasm32"))] +pub use native::{mssql::*, mysql::*, postgresql::*, sqlite::*}; + +// pub use mssql::*; +// pub use mysql::*; +// pub use postgresql::*; +// pub use sqlite::*; #[async_trait] pub trait FromSource { diff --git a/query-engine/connectors/sql-query-connector/src/database/mssql.rs b/query-engine/connectors/sql-query-connector/src/database/native/mssql.rs similarity index 94% rename from query-engine/connectors/sql-query-connector/src/database/mssql.rs rename to query-engine/connectors/sql-query-connector/src/database/native/mssql.rs index 9655d205e4ca..bdb6e2ee103c 100644 --- a/query-engine/connectors/sql-query-connector/src/database/mssql.rs +++ b/query-engine/connectors/sql-query-connector/src/database/native/mssql.rs @@ -1,4 +1,4 @@ -use super::connection::SqlConnection; +use super::super::connection::SqlConnection; use crate::{FromSource, SqlError}; use async_trait::async_trait; use connector_interface::{ @@ -60,7 +60,7 @@ impl FromSource for Mssql { #[async_trait] impl Connector for Mssql { async fn get_connection<'a>(&'a self) -> connector::Result> { - super::catch(self.connection_info.clone(), async move { + super::super::catch(self.connection_info.clone(), async move { let conn = self.pool.check_out().await.map_err(SqlError::from)?; let conn = SqlConnection::new(conn, &self.connection_info, self.features); diff --git a/query-engine/connectors/sql-query-connector/src/database/mysql.rs b/query-engine/connectors/sql-query-connector/src/database/native/mysql.rs similarity index 95% rename from query-engine/connectors/sql-query-connector/src/database/mysql.rs rename to query-engine/connectors/sql-query-connector/src/database/native/mysql.rs index deb3e6a4f35f..a1cd585c0005 100644 --- a/query-engine/connectors/sql-query-connector/src/database/mysql.rs +++ b/query-engine/connectors/sql-query-connector/src/database/native/mysql.rs @@ -1,4 +1,4 @@ -use super::connection::SqlConnection; +use super::super::connection::SqlConnection; use crate::{FromSource, SqlError}; use async_trait::async_trait; use connector_interface::{ @@ -65,7 +65,7 @@ impl FromSource for Mysql { #[async_trait] impl Connector for Mysql { async fn get_connection<'a>(&'a self) -> connector::Result> { - super::catch(self.connection_info.clone(), async move { + super::super::catch(self.connection_info.clone(), async move { let runtime_conn = self.pool.check_out().await?; // Note: `runtime_conn` must be `Sized`, as that's required by `TransactionCapable` diff --git a/query-engine/connectors/sql-query-connector/src/database/postgresql.rs 
b/query-engine/connectors/sql-query-connector/src/database/native/postgresql.rs similarity index 95% rename from query-engine/connectors/sql-query-connector/src/database/postgresql.rs rename to query-engine/connectors/sql-query-connector/src/database/native/postgresql.rs index 242b2b63090e..80025add046f 100644 --- a/query-engine/connectors/sql-query-connector/src/database/postgresql.rs +++ b/query-engine/connectors/sql-query-connector/src/database/native/postgresql.rs @@ -1,4 +1,4 @@ -use super::connection::SqlConnection; +use super::super::connection::SqlConnection; use crate::{FromSource, SqlError}; use async_trait::async_trait; use connector_interface::{ @@ -67,7 +67,7 @@ impl FromSource for PostgreSql { #[async_trait] impl Connector for PostgreSql { async fn get_connection<'a>(&'a self) -> connector_interface::Result> { - super::catch(self.connection_info.clone(), async move { + super::super::catch(self.connection_info.clone(), async move { let conn = self.pool.check_out().await.map_err(SqlError::from)?; let conn = SqlConnection::new(conn, &self.connection_info, self.features); Ok(Box::new(conn) as Box) diff --git a/query-engine/connectors/sql-query-connector/src/database/sqlite.rs b/query-engine/connectors/sql-query-connector/src/database/native/sqlite.rs similarity index 96% rename from query-engine/connectors/sql-query-connector/src/database/sqlite.rs rename to query-engine/connectors/sql-query-connector/src/database/native/sqlite.rs index 6be9faeac54d..b1250b18b2be 100644 --- a/query-engine/connectors/sql-query-connector/src/database/sqlite.rs +++ b/query-engine/connectors/sql-query-connector/src/database/native/sqlite.rs @@ -1,4 +1,4 @@ -use super::connection::SqlConnection; +use super::super::connection::SqlConnection; use crate::{FromSource, SqlError}; use async_trait::async_trait; use connector_interface::{ @@ -80,7 +80,7 @@ fn invalid_file_path_error(file_path: &str, connection_info: &ConnectionInfo) -> #[async_trait] impl Connector for Sqlite { async fn get_connection<'a>(&'a self) -> connector::Result> { - super::catch(self.connection_info().clone(), async move { + super::super::catch(self.connection_info().clone(), async move { let conn = self.pool.check_out().await.map_err(SqlError::from)?; let conn = SqlConnection::new(conn, self.connection_info(), self.features); diff --git a/query-engine/connectors/sql-query-connector/src/database/operations/write.rs b/query-engine/connectors/sql-query-connector/src/database/operations/write.rs index 425f4ac1d4b3..611557c4f3ba 100644 --- a/query-engine/connectors/sql-query-connector/src/database/operations/write.rs +++ b/query-engine/connectors/sql-query-connector/src/database/operations/write.rs @@ -18,9 +18,28 @@ use std::{ ops::Deref, usize, }; -use tracing::log::trace; use user_facing_errors::query_engine::DatabaseConstraint; +#[cfg(target_arch = "wasm32")] +macro_rules! trace { + (target: $target:expr, $($arg:tt)+) => {{ + // No-op in WebAssembly + }}; + ($($arg:tt)+) => {{ + // No-op in WebAssembly + }}; +} + +#[cfg(not(target_arch = "wasm32"))] +macro_rules! 
trace { + (target: $target:expr, $($arg:tt)+) => { + tracing::log::trace!(target: $target, $($arg)+); + }; + ($($arg:tt)+) => { + tracing::log::trace!($($arg)+); + }; +} + async fn generate_id( conn: &dyn Queryable, id_field: &FieldSelection, diff --git a/query-engine/connectors/sql-query-connector/src/lib.rs b/query-engine/connectors/sql-query-connector/src/lib.rs index ed1528ded6b5..74c0a4aab5d3 100644 --- a/query-engine/connectors/sql-query-connector/src/lib.rs +++ b/query-engine/connectors/sql-query-connector/src/lib.rs @@ -22,9 +22,12 @@ mod value_ext; use self::{column_metadata::*, context::Context, query_ext::QueryExt, row::*}; use quaint::prelude::Queryable; +pub use database::FromSource; #[cfg(feature = "driver-adapters")] pub use database::{activate_driver_adapter, Js}; -pub use database::{FromSource, Mssql, Mysql, PostgreSql, Sqlite}; pub use error::SqlError; +#[cfg(not(target_arch = "wasm32"))] +pub use database::{Mssql, Mysql, PostgreSql, Sqlite}; + type Result = std::result::Result; From e7df5a3d0c219d7c59efa0822a3352ed30965e0f Mon Sep 17 00:00:00 2001 From: jkomyno Date: Tue, 14 Nov 2023 14:55:37 +0100 Subject: [PATCH 014/134] chore(query-engine-wasm): add currently unused local crates to test wasm32-unknown-unknown compilation --- query-engine/query-engine-wasm/Cargo.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/query-engine/query-engine-wasm/Cargo.toml b/query-engine/query-engine-wasm/Cargo.toml index a8bc393aee3f..f65f31c2d63b 100644 --- a/query-engine/query-engine-wasm/Cargo.toml +++ b/query-engine/query-engine-wasm/Cargo.toml @@ -14,6 +14,9 @@ async-trait = "0.1" user-facing-errors = { path = "../../libs/user-facing-errors" } psl.workspace = true prisma-models = { path = "../prisma-models" } +quaint = { path = "../../quaint" } +connector = { path = "../connectors/query-connector", package = "query-connector" } +sql-query-connector = { path = "../connectors/sql-query-connector" } thiserror = "1" connection-string.workspace = true From 8c5d3dc999167815c0dbc3f4f5fe31557a6086e0 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Tue, 14 Nov 2023 14:55:48 +0100 Subject: [PATCH 015/134] chore: update Cargo.lock --- Cargo.lock | 3 +++ 1 file changed, 3 insertions(+) diff --git a/Cargo.lock b/Cargo.lock index 4c59bfea573b..b88de804c816 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3822,9 +3822,12 @@ dependencies = [ "log", "prisma-models", "psl", + "quaint", + "query-connector", "serde", "serde-wasm-bindgen", "serde_json", + "sql-query-connector", "thiserror", "tokio", "tracing", From 6648a882b4e2d0d8aa5449b2e94875cb807f2949 Mon Sep 17 00:00:00 2001 From: Alberto Schiabel Date: Tue, 14 Nov 2023 14:58:04 +0100 Subject: [PATCH 016/134] chore: remove leftover comments --- .../connectors/sql-query-connector/src/database/mod.rs | 5 ----- 1 file changed, 5 deletions(-) diff --git a/query-engine/connectors/sql-query-connector/src/database/mod.rs b/query-engine/connectors/sql-query-connector/src/database/mod.rs index 7172e0101400..e693769373b0 100644 --- a/query-engine/connectors/sql-query-connector/src/database/mod.rs +++ b/query-engine/connectors/sql-query-connector/src/database/mod.rs @@ -22,11 +22,6 @@ pub use js::*; #[cfg(not(target_arch = "wasm32"))] pub use native::{mssql::*, mysql::*, postgresql::*, sqlite::*}; -// pub use mssql::*; -// pub use mysql::*; -// pub use postgresql::*; -// pub use sqlite::*; - #[async_trait] pub trait FromSource { /// Instantiate a query connector from a Datasource. 
From 754746ecdaf0dadae3e44532bc268715ab3ce813 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Tue, 14 Nov 2023 16:38:30 +0100 Subject: [PATCH 017/134] feat(query-core): allow wasm32-unknown-unknown compilation --- Cargo.lock | 3 ++ .../query-tests-setup/Cargo.toml | 2 +- query-engine/core-tests/Cargo.toml | 2 +- query-engine/core/Cargo.toml | 11 +++- .../core/src/executor/execute_operation.rs | 11 ++++ query-engine/core/src/executor/mod.rs | 51 +++++++++++++++++++ .../interactive_transactions/actor_manager.rs | 2 +- .../src/interactive_transactions/actors.rs | 15 ++++-- query-engine/core/src/lib.rs | 7 ++- query-engine/query-engine-node-api/Cargo.toml | 2 +- query-engine/query-engine-wasm/Cargo.toml | 1 + query-engine/query-engine/Cargo.toml | 2 +- query-engine/request-handlers/Cargo.toml | 2 +- 13 files changed, 97 insertions(+), 14 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b88de804c816..50df863820fd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3680,6 +3680,7 @@ dependencies = [ "once_cell", "opentelemetry", "petgraph 0.4.13", + "pin-project", "prisma-models", "psl", "query-connector", @@ -3695,6 +3696,7 @@ dependencies = [ "tracing-subscriber", "user-facing-errors", "uuid", + "wasm-bindgen-futures", ] [[package]] @@ -3824,6 +3826,7 @@ dependencies = [ "psl", "quaint", "query-connector", + "query-core", "serde", "serde-wasm-bindgen", "serde_json", diff --git a/query-engine/connector-test-kit-rs/query-tests-setup/Cargo.toml b/query-engine/connector-test-kit-rs/query-tests-setup/Cargo.toml index 088a0d4b2d34..f257d9e52162 100644 --- a/query-engine/connector-test-kit-rs/query-tests-setup/Cargo.toml +++ b/query-engine/connector-test-kit-rs/query-tests-setup/Cargo.toml @@ -10,7 +10,7 @@ once_cell = "1" qe-setup = { path = "../qe-setup" } request-handlers = { path = "../../request-handlers" } tokio.workspace = true -query-core = { path = "../../core" } +query-core = { path = "../../core", features = ["metrics"] } sql-query-connector = { path = "../../connectors/sql-query-connector" } query-engine = { path = "../../query-engine"} psl.workspace = true diff --git a/query-engine/core-tests/Cargo.toml b/query-engine/core-tests/Cargo.toml index 9a2c3f5686eb..bac9219c3522 100644 --- a/query-engine/core-tests/Cargo.toml +++ b/query-engine/core-tests/Cargo.toml @@ -9,7 +9,7 @@ edition = "2021" dissimilar = "1.0.4" user-facing-errors = { path = "../../libs/user-facing-errors" } request-handlers = { path = "../request-handlers" } -query-core = { path = "../core" } +query-core = { path = "../core", features = ["metrics"] } schema = { path = "../schema" } psl.workspace = true serde_json.workspace = true diff --git a/query-engine/core/Cargo.toml b/query-engine/core/Cargo.toml index caadf6cdba00..6441abf8ca3a 100644 --- a/query-engine/core/Cargo.toml +++ b/query-engine/core/Cargo.toml @@ -3,6 +3,10 @@ edition = "2021" name = "query-core" version = "0.1.0" +[features] +# default = ["metrics"] +metrics = ["query-engine-metrics"] + [dependencies] async-trait = "0.1" bigdecimal = "0.3" @@ -18,11 +22,11 @@ once_cell = "1" petgraph = "0.4" prisma-models = { path = "../prisma-models", features = ["default_generators"] } opentelemetry = { version = "0.17.0", features = ["rt-tokio", "serialize"] } -query-engine-metrics = {path = "../metrics"} +query-engine-metrics = { path = "../metrics", optional = true } serde.workspace = true serde_json.workspace = true thiserror = "1.0" -tokio.workspace = true +tokio = { version = "1.0", features = ["macros", "time"] } tracing = { version = "0.1", features = 
["attributes"] } tracing-futures = "0.2" tracing-subscriber = { version = "0.3", features = ["env-filter"] } @@ -34,3 +38,6 @@ schema = { path = "../schema" } lru = "0.7.7" enumflags2 = "0.7" +[target.'cfg(target_arch = "wasm32")'.dependencies] +pin-project = "1" +wasm-bindgen-futures = "0.4" diff --git a/query-engine/core/src/executor/execute_operation.rs b/query-engine/core/src/executor/execute_operation.rs index 06452fcdd865..6ba21d37f9ff 100644 --- a/query-engine/core/src/executor/execute_operation.rs +++ b/query-engine/core/src/executor/execute_operation.rs @@ -1,3 +1,5 @@ +#![cfg_attr(target_arch = "wasm32", allow(unused_variables))] + use super::pipeline::QueryPipeline; use crate::{ executor::request_context, protocol::EngineProtocol, CoreError, IrSerializer, Operation, QueryGraph, @@ -5,9 +7,12 @@ use crate::{ }; use connector::{Connection, ConnectionLike, Connector}; use futures::future; + +#[cfg(feature = "metrics")] use query_engine_metrics::{ histogram, increment_counter, metrics, PRISMA_CLIENT_QUERIES_DURATION_HISTOGRAM_MS, PRISMA_CLIENT_QUERIES_TOTAL, }; + use schema::{QuerySchema, QuerySchemaRef}; use std::time::{Duration, Instant}; use tracing::Instrument; @@ -24,6 +29,7 @@ pub async fn execute_single_operation( let (graph, serializer) = build_graph(&query_schema, operation.clone())?; let result = execute_on(conn, graph, serializer, query_schema.as_ref(), trace_id).await; + #[cfg(feature = "metrics")] histogram!(PRISMA_CLIENT_QUERIES_DURATION_HISTOGRAM_MS, operation_timer.elapsed()); result @@ -45,6 +51,8 @@ pub async fn execute_many_operations( for (i, (graph, serializer)) in queries.into_iter().enumerate() { let operation_timer = Instant::now(); let result = execute_on(conn, graph, serializer, query_schema.as_ref(), trace_id.clone()).await; + + #[cfg(feature = "metrics")] histogram!(PRISMA_CLIENT_QUERIES_DURATION_HISTOGRAM_MS, operation_timer.elapsed()); match result { @@ -98,6 +106,7 @@ pub async fn execute_many_self_contained( let dispatcher = crate::get_current_dispatcher(); for op in operations { + #[cfg(feature = "metrics")] increment_counter!(PRISMA_CLIENT_QUERIES_TOTAL); let conn_span = info_span!( @@ -158,6 +167,7 @@ async fn execute_self_contained( execute_self_contained_without_retry(conn, graph, serializer, force_transactions, &query_schema, trace_id).await }; + #[cfg(feature = "metrics")] histogram!(PRISMA_CLIENT_QUERIES_DURATION_HISTOGRAM_MS, operation_timer.elapsed()); result @@ -259,6 +269,7 @@ async fn execute_on<'a>( query_schema: &'a QuerySchema, trace_id: Option, ) -> crate::Result { + #[cfg(feature = "metrics")] increment_counter!(PRISMA_CLIENT_QUERIES_TOTAL); let interpreter = QueryInterpreter::new(conn); diff --git a/query-engine/core/src/executor/mod.rs b/query-engine/core/src/executor/mod.rs index ddbb7dfc8429..5ff9830013d6 100644 --- a/query-engine/core/src/executor/mod.rs +++ b/query-engine/core/src/executor/mod.rs @@ -12,6 +12,7 @@ mod pipeline; mod request_context; pub use self::{execute_operation::*, interpreting_executor::InterpretingExecutor}; +use futures::Future; pub(crate) use request_context::*; @@ -131,3 +132,53 @@ pub trait TransactionManager { pub fn get_current_dispatcher() -> Dispatch { tracing::dispatcher::get_default(|current| current.clone()) } + +#[cfg(not(target_arch = "wasm32"))] +pub(crate) mod task { + use super::*; + + pub type JoinHandle = tokio::task::JoinHandle; + + pub fn spawn(future: T) -> JoinHandle + where + T: Future + Send + 'static, + T::Output: Send + 'static, + { + tokio::spawn(future) + } +} + 
+#[cfg(target_arch = "wasm32")] +pub(crate) mod task { + use super::*; + + #[pin_project::pin_project] + pub struct JoinHandle(#[pin] tokio::sync::oneshot::Receiver); + + impl Future for JoinHandle { + type Output = Result; + fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll { + let this = self.project(); + this.0.poll(cx) + } + } + + impl JoinHandle { + pub fn abort(&mut self) { + // abort is noop for WASM builds + } + } + + pub fn spawn(future: T) -> JoinHandle + where + T: Future + Send + 'static, + T::Output: Send + 'static, + { + let (tx, rx) = tokio::sync::oneshot::channel(); + wasm_bindgen_futures::spawn_local(async move { + let result = future.await; + tx.send(result).ok(); + }); + JoinHandle(rx) + } +} diff --git a/query-engine/core/src/interactive_transactions/actor_manager.rs b/query-engine/core/src/interactive_transactions/actor_manager.rs index 98208343d28a..105733be4166 100644 --- a/query-engine/core/src/interactive_transactions/actor_manager.rs +++ b/query-engine/core/src/interactive_transactions/actor_manager.rs @@ -1,3 +1,4 @@ +use crate::executor::task::JoinHandle; use crate::{protocol::EngineProtocol, ClosedTx, Operation, ResponseData}; use connector::Connection; use lru::LruCache; @@ -9,7 +10,6 @@ use tokio::{ mpsc::{channel, Sender}, RwLock, }, - task::JoinHandle, time::Duration, }; diff --git a/query-engine/core/src/interactive_transactions/actors.rs b/query-engine/core/src/interactive_transactions/actors.rs index 88402d86fedd..104ffc26812f 100644 --- a/query-engine/core/src/interactive_transactions/actors.rs +++ b/query-engine/core/src/interactive_transactions/actors.rs @@ -1,7 +1,8 @@ use super::{CachedTx, TransactionError, TxOpRequest, TxOpRequestMsg, TxOpResponse}; +use crate::executor::task::{spawn, JoinHandle}; use crate::{ - execute_many_operations, execute_single_operation, protocol::EngineProtocol, - telemetry::helpers::set_span_link_from_traceparent, ClosedTx, Operation, ResponseData, TxId, + execute_many_operations, execute_single_operation, protocol::EngineProtocol, ClosedTx, Operation, ResponseData, + TxId, }; use connector::Connection; use schema::QuerySchemaRef; @@ -11,13 +12,15 @@ use tokio::{ mpsc::{channel, Receiver, Sender}, oneshot, RwLock, }, - task::JoinHandle, time::{self, Duration, Instant}, }; use tracing::Span; use tracing_futures::Instrument; use tracing_futures::WithSubscriber; +#[cfg(feature = "metrics")] +use crate::telemetry::helpers::set_span_link_from_traceparent; + #[derive(PartialEq)] enum RunState { Continue, @@ -81,6 +84,8 @@ impl<'a> ITXServer<'a> { traceparent: Option, ) -> crate::Result { let span = info_span!("prisma:engine:itx_query_builder", user_facing = true); + + #[cfg(feature = "metrics")] set_span_link_from_traceparent(&span, traceparent.clone()); let conn = self.cached_tx.as_open()?; @@ -267,7 +272,7 @@ pub(crate) async fn spawn_itx_actor( }; let (open_transaction_send, open_transaction_rcv) = oneshot::channel(); - tokio::task::spawn( + spawn( crate::executor::with_request_context(engine_protocol, async move { // We match on the result in order to send the error to the parent task and abort this // task, on error. 
This is a separate task (actor), not a function where we can just bubble up the @@ -380,7 +385,7 @@ pub(crate) fn spawn_client_list_clear_actor( closed_txs: Arc>>>, mut rx: Receiver<(TxId, Option)>, ) -> JoinHandle<()> { - tokio::task::spawn(async move { + spawn(async move { loop { if let Some((id, closed_tx)) = rx.recv().await { trace!("removing {} from client list", id); diff --git a/query-engine/core/src/lib.rs b/query-engine/core/src/lib.rs index 7970c96139b7..38f39e9fb5d9 100644 --- a/query-engine/core/src/lib.rs +++ b/query-engine/core/src/lib.rs @@ -9,6 +9,8 @@ pub mod protocol; pub mod query_document; pub mod query_graph_builder; pub mod response_ir; + +#[cfg(feature = "metrics")] pub mod telemetry; pub use self::{ @@ -16,8 +18,11 @@ pub use self::{ executor::{QueryExecutor, TransactionOptions}, interactive_transactions::{ExtendedTransactionUserFacingError, TransactionError, TxId}, query_document::*, - telemetry::*, }; + +#[cfg(feature = "metrics")] +pub use self::telemetry::*; + pub use connector::{ error::{ConnectorError, ErrorKind as ConnectorErrorKind}, Connector, diff --git a/query-engine/query-engine-node-api/Cargo.toml b/query-engine/query-engine-node-api/Cargo.toml index 74f9686189fc..0eaed9eff7ce 100644 --- a/query-engine/query-engine-node-api/Cargo.toml +++ b/query-engine/query-engine-node-api/Cargo.toml @@ -16,7 +16,7 @@ driver-adapters = ["request-handlers/driver-adapters", "sql-connector/driver-ada [dependencies] anyhow = "1" async-trait = "0.1" -query-core = { path = "../core" } +query-core = { path = "../core", features = ["metrics"] } request-handlers = { path = "../request-handlers" } query-connector = { path = "../connectors/query-connector" } user-facing-errors = { path = "../../libs/user-facing-errors" } diff --git a/query-engine/query-engine-wasm/Cargo.toml b/query-engine/query-engine-wasm/Cargo.toml index f65f31c2d63b..c8bc6e2b5178 100644 --- a/query-engine/query-engine-wasm/Cargo.toml +++ b/query-engine/query-engine-wasm/Cargo.toml @@ -17,6 +17,7 @@ prisma-models = { path = "../prisma-models" } quaint = { path = "../../quaint" } connector = { path = "../connectors/query-connector", package = "query-connector" } sql-query-connector = { path = "../connectors/sql-query-connector" } +query-core = { path = "../core" } thiserror = "1" connection-string.workspace = true diff --git a/query-engine/query-engine/Cargo.toml b/query-engine/query-engine/Cargo.toml index be36e4f842dc..c70d8590d0ff 100644 --- a/query-engine/query-engine/Cargo.toml +++ b/query-engine/query-engine/Cargo.toml @@ -20,7 +20,7 @@ enumflags2 = { version = "0.7"} psl.workspace = true graphql-parser = { git = "https://github.com/prisma/graphql-parser" } mongodb-connector = { path = "../connectors/mongodb-query-connector", optional = true, package = "mongodb-query-connector" } -query-core = { path = "../core" } +query-core = { path = "../core", features = ["metrics"] } request-handlers = { path = "../request-handlers" } serde.workspace = true serde_json.workspace = true diff --git a/query-engine/request-handlers/Cargo.toml b/query-engine/request-handlers/Cargo.toml index f5fb433b13ba..e6545eda2234 100644 --- a/query-engine/request-handlers/Cargo.toml +++ b/query-engine/request-handlers/Cargo.toml @@ -5,7 +5,7 @@ edition = "2021" [dependencies] prisma-models = { path = "../prisma-models" } -query-core = { path = "../core" } +query-core = { path = "../core", features = ["metrics"] } user-facing-errors = { path = "../../libs/user-facing-errors" } psl.workspace = true dmmf_crate = { path = "../dmmf", 
package = "dmmf" } From fe2fb8bd412dfa7273e9cb140f515f30fc6c7072 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Tue, 14 Nov 2023 16:42:24 +0100 Subject: [PATCH 018/134] chore(sql-query-connector): fix clipppy on wasm32 --- .../connectors/sql-query-connector/src/database/connection.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/query-engine/connectors/sql-query-connector/src/database/connection.rs b/query-engine/connectors/sql-query-connector/src/database/connection.rs index 0247e8c4b601..7895e838399a 100644 --- a/query-engine/connectors/sql-query-connector/src/database/connection.rs +++ b/query-engine/connectors/sql-query-connector/src/database/connection.rs @@ -1,3 +1,5 @@ +#![cfg_attr(target_arch = "wasm32", allow(dead_code))] + use super::{catch, transaction::SqlConnectorTransaction}; use crate::{database::operations::*, Context, SqlError}; use async_trait::async_trait; From 37bd8d1b6e71ad9033ce6837e588d91f4a4194da Mon Sep 17 00:00:00 2001 From: Alberto Schiabel Date: Tue, 14 Nov 2023 16:51:08 +0100 Subject: [PATCH 019/134] chore: remove leftover comment --- query-engine/core/Cargo.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/query-engine/core/Cargo.toml b/query-engine/core/Cargo.toml index 6441abf8ca3a..7ccf1a293411 100644 --- a/query-engine/core/Cargo.toml +++ b/query-engine/core/Cargo.toml @@ -4,7 +4,6 @@ name = "query-core" version = "0.1.0" [features] -# default = ["metrics"] metrics = ["query-engine-metrics"] [dependencies] From 9c41dc1fba3c560819740d5506fb075ac0310099 Mon Sep 17 00:00:00 2001 From: Alberto Schiabel Date: Tue, 14 Nov 2023 16:51:08 +0100 Subject: [PATCH 020/134] chore: remove leftover comment --- query-engine/core/Cargo.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/query-engine/core/Cargo.toml b/query-engine/core/Cargo.toml index 6441abf8ca3a..7ccf1a293411 100644 --- a/query-engine/core/Cargo.toml +++ b/query-engine/core/Cargo.toml @@ -4,7 +4,6 @@ name = "query-core" version = "0.1.0" [features] -# default = ["metrics"] metrics = ["query-engine-metrics"] [dependencies] From b69bb840f0f58731a3ff9f4663dcd812a3937c81 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Wed, 15 Nov 2023 11:00:17 +0100 Subject: [PATCH 021/134] feat(driver-adapters): enable Wasm on request-handlers --- query-engine/request-handlers/Cargo.toml | 9 +- .../request-handlers/src/connector_mode.rs | 1 + .../request-handlers/src/load_executor.rs | 162 +++++++++--------- 3 files changed, 90 insertions(+), 82 deletions(-) diff --git a/query-engine/request-handlers/Cargo.toml b/query-engine/request-handlers/Cargo.toml index e6545eda2234..f04d742c448e 100644 --- a/query-engine/request-handlers/Cargo.toml +++ b/query-engine/request-handlers/Cargo.toml @@ -5,8 +5,9 @@ edition = "2021" [dependencies] prisma-models = { path = "../prisma-models" } -query-core = { path = "../core", features = ["metrics"] } +query-core = { path = "../core" } user-facing-errors = { path = "../../libs/user-facing-errors" } +quaint = { path = "../../quaint" } psl.workspace = true dmmf_crate = { path = "../dmmf", package = "dmmf" } itertools = "0.10" @@ -20,7 +21,6 @@ thiserror = "1" tracing = "0.1" url = "2" connection-string.workspace = true -quaint.workspace = true once_cell = "1.15" mongodb-query-connector = { path = "../connectors/mongodb-query-connector", optional = true } @@ -32,10 +32,11 @@ schema = { path = "../schema" } codspeed-criterion-compat = "1.1.0" [features] -default = ["mongodb", "sql"] +default = ["sql", "mongodb", "native"] mongodb = ["mongodb-query-connector"] sql = ["sql-query-connector"] 
-driver-adapters = ["sql-query-connector"] +driver-adapters = ["sql-query-connector/driver-adapters"] +native = ["mongodb", "sql-query-connector", "quaint/native", "query-core/metrics"] [[bench]] name = "query_planning_bench" diff --git a/query-engine/request-handlers/src/connector_mode.rs b/query-engine/request-handlers/src/connector_mode.rs index 00e0515a596e..be03fbab5820 100644 --- a/query-engine/request-handlers/src/connector_mode.rs +++ b/query-engine/request-handlers/src/connector_mode.rs @@ -1,6 +1,7 @@ #[derive(Copy, Clone, PartialEq, Eq)] pub enum ConnectorMode { /// Indicates that Rust drivers are used in Query Engine. + #[cfg(feature = "native")] Rust, /// Indicates that JS drivers are used in Query Engine. diff --git a/query-engine/request-handlers/src/load_executor.rs b/query-engine/request-handlers/src/load_executor.rs index 652ad3108f0d..26728605f92a 100644 --- a/query-engine/request-handlers/src/load_executor.rs +++ b/query-engine/request-handlers/src/load_executor.rs @@ -1,14 +1,12 @@ +#![allow(unused_imports)] + use psl::{builtin_connectors::*, Datasource, PreviewFeatures}; use query_core::{executor::InterpretingExecutor, Connector, QueryExecutor}; use sql_query_connector::*; use std::collections::HashMap; use std::env; -use tracing::trace; use url::Url; -#[cfg(feature = "mongodb")] -use mongodb_query_connector::MongoDb; - use super::ConnectorMode; /// Loads a query executor based on the parsed Prisma schema (datasource). @@ -27,6 +25,7 @@ pub async fn load( driver_adapter(source, url, features).await } + #[cfg(feature = "native")] ConnectorMode::Rust => { if let Ok(value) = env::var("PRISMA_DISABLE_QUAINT_EXECUTORS") { let disable = value.to_uppercase(); @@ -36,14 +35,14 @@ pub async fn load( } match source.active_provider { - p if SQLITE.is_provider(p) => sqlite(source, url, features).await, - p if MYSQL.is_provider(p) => mysql(source, url, features).await, - p if POSTGRES.is_provider(p) => postgres(source, url, features).await, - p if MSSQL.is_provider(p) => mssql(source, url, features).await, - p if COCKROACH.is_provider(p) => postgres(source, url, features).await, + p if SQLITE.is_provider(p) => native::sqlite(source, url, features).await, + p if MYSQL.is_provider(p) => native::mysql(source, url, features).await, + p if POSTGRES.is_provider(p) => native::postgres(source, url, features).await, + p if MSSQL.is_provider(p) => native::mssql(source, url, features).await, + p if COCKROACH.is_provider(p) => native::postgres(source, url, features).await, #[cfg(feature = "mongodb")] - p if MONGODB.is_provider(p) => mongodb(source, url, features).await, + p if MONGODB.is_provider(p) => native::mongodb(source, url, features).await, x => Err(query_core::CoreError::ConfigurationError(format!( "Unsupported connector type: {x}" @@ -53,57 +52,88 @@ pub async fn load( } } -async fn sqlite( +#[cfg(feature = "driver-adapters")] +async fn driver_adapter( source: &Datasource, url: &str, features: PreviewFeatures, -) -> query_core::Result> { - trace!("Loading SQLite query connector..."); - let sqlite = Sqlite::from_source(source, url, features).await?; - trace!("Loaded SQLite query connector."); - Ok(executor_for(sqlite, false)) +) -> Result, query_core::CoreError> { + let js = Js::from_source(source, url, features).await?; + Ok(executor_for(js, false)) } -async fn postgres( - source: &Datasource, - url: &str, - features: PreviewFeatures, -) -> query_core::Result> { - trace!("Loading Postgres query connector..."); - let database_str = url; - let psql = PostgreSql::from_source(source, 
url, features).await?; - - let url = Url::parse(database_str) - .map_err(|err| query_core::CoreError::ConfigurationError(format!("Error parsing connection string: {err}")))?; - let params: HashMap = url.query_pairs().into_owned().collect(); - - let force_transactions = params - .get("pgbouncer") - .and_then(|flag| flag.parse().ok()) - .unwrap_or(false); - trace!("Loaded Postgres query connector."); - Ok(executor_for(psql, force_transactions)) -} +#[cfg(feature = "native")] +mod native { + use super::*; + use tracing::trace; + + pub(crate) async fn sqlite( + source: &Datasource, + url: &str, + features: PreviewFeatures, + ) -> query_core::Result> { + trace!("Loading SQLite query connector..."); + let sqlite = Sqlite::from_source(source, url, features).await?; + trace!("Loaded SQLite query connector."); + Ok(executor_for(sqlite, false)) + } -async fn mysql( - source: &Datasource, - url: &str, - features: PreviewFeatures, -) -> query_core::Result> { - let mysql = Mysql::from_source(source, url, features).await?; - trace!("Loaded MySQL query connector."); - Ok(executor_for(mysql, false)) -} + pub(crate) async fn postgres( + source: &Datasource, + url: &str, + features: PreviewFeatures, + ) -> query_core::Result> { + trace!("Loading Postgres query connector..."); + let database_str = url; + let psql = PostgreSql::from_source(source, url, features).await?; + + let url = Url::parse(database_str).map_err(|err| { + query_core::CoreError::ConfigurationError(format!("Error parsing connection string: {err}")) + })?; + let params: HashMap = url.query_pairs().into_owned().collect(); + + let force_transactions = params + .get("pgbouncer") + .and_then(|flag| flag.parse().ok()) + .unwrap_or(false); + trace!("Loaded Postgres query connector."); + Ok(executor_for(psql, force_transactions)) + } -async fn mssql( - source: &Datasource, - url: &str, - features: PreviewFeatures, -) -> query_core::Result> { - trace!("Loading SQL Server query connector..."); - let mssql = Mssql::from_source(source, url, features).await?; - trace!("Loaded SQL Server query connector."); - Ok(executor_for(mssql, false)) + pub(crate) async fn mysql( + source: &Datasource, + url: &str, + features: PreviewFeatures, + ) -> query_core::Result> { + let mysql = Mysql::from_source(source, url, features).await?; + trace!("Loaded MySQL query connector."); + Ok(executor_for(mysql, false)) + } + + pub(crate) async fn mssql( + source: &Datasource, + url: &str, + features: PreviewFeatures, + ) -> query_core::Result> { + trace!("Loading SQL Server query connector..."); + let mssql = Mssql::from_source(source, url, features).await?; + trace!("Loaded SQL Server query connector."); + Ok(executor_for(mssql, false)) + } + + #[cfg(feature = "mongodb")] + pub(crate) async fn mongodb( + source: &Datasource, + url: &str, + _features: PreviewFeatures, + ) -> query_core::Result> { + use mongodb_query_connector::MongoDb; + + trace!("Loading MongoDB query connector..."); + let mongo = MongoDb::new(source, url).await?; + trace!("Loaded MongoDB query connector."); + Ok(executor_for(mongo, false)) + } } fn executor_for(connector: T, force_transactions: bool) -> Box @@ -112,27 +142,3 @@ where { Box::new(InterpretingExecutor::new(connector, force_transactions)) } - -#[cfg(feature = "mongodb")] -async fn mongodb( - source: &Datasource, - url: &str, - _features: PreviewFeatures, -) -> query_core::Result> { - trace!("Loading MongoDB query connector..."); - let mongo = MongoDb::new(source, url).await?; - trace!("Loaded MongoDB query connector."); - Ok(executor_for(mongo, 
false)) -} - -#[cfg(feature = "driver-adapters")] -async fn driver_adapter( - source: &Datasource, - url: &str, - features: PreviewFeatures, -) -> Result, query_core::CoreError> { - trace!("Loading driver adapter..."); - let js = Js::from_source(source, url, features).await?; - trace!("Loaded driver adapter..."); - Ok(executor_for(js, false)) -} From c987dceb3895fa57f0e16fa84d72d217fd186673 Mon Sep 17 00:00:00 2001 From: Miguel Fernandez Date: Wed, 15 Nov 2023 12:51:00 +0100 Subject: [PATCH 022/134] WIP: refactor mysql module to flatten its structure --- quaint/src/connector.rs | 4 ++-- quaint/src/connector/mysql.rs | 11 +++++++---- quaint/src/connector/mysql/{wasm => }/error.rs | 0 quaint/src/connector/mysql/native/error.rs | 2 +- quaint/src/connector/mysql/native/mod.rs | 2 +- quaint/src/connector/mysql/{wasm/common.rs => url.rs} | 0 quaint/src/connector/mysql/wasm/mod.rs | 6 ------ 7 files changed, 11 insertions(+), 14 deletions(-) rename quaint/src/connector/mysql/{wasm => }/error.rs (100%) rename quaint/src/connector/mysql/{wasm/common.rs => url.rs} (100%) delete mode 100644 quaint/src/connector/mysql/wasm/mod.rs diff --git a/quaint/src/connector.rs b/quaint/src/connector.rs index 7903d23931c0..a2ee455fee22 100644 --- a/quaint/src/connector.rs +++ b/quaint/src/connector.rs @@ -38,10 +38,10 @@ pub use postgres::wasm::common::*; #[cfg(feature = "mysql")] pub(crate) mod mysql; +#[cfg(feature = "mysql")] +pub use mysql::*; #[cfg(feature = "mysql-native")] pub use mysql::native::*; -#[cfg(feature = "mysql")] -pub use mysql::wasm::common::*; #[cfg(feature = "sqlite")] pub(crate) mod sqlite; diff --git a/quaint/src/connector/mysql.rs b/quaint/src/connector/mysql.rs index 1e52af6a83a0..0834be88949e 100644 --- a/quaint/src/connector/mysql.rs +++ b/quaint/src/connector/mysql.rs @@ -1,8 +1,11 @@ -pub use wasm::common::MysqlUrl; -pub use wasm::error::MysqlError; +//! Wasm-compatible definitions for the MySQL connector. +//! This module is only available with the `mysql` feature. 
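With the flattened layout, the wasm-compatible pieces live directly under `connector::mysql` (and are re-exported from `connector`), while the mysql_async-based driver stays behind `mysql-native`. The resulting downstream imports, as a sketch:

    // Plain `mysql` feature: URL parsing and error mapping, wasm-compatible.
    use quaint::connector::{MysqlError, MysqlUrl};

    // `mysql-native` feature: the actual connection type backed by mysql_async.
    #[cfg(feature = "mysql-native")]
    use quaint::connector::Mysql;
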
+pub mod error; +pub(crate) mod url; -#[cfg(feature = "mysql")] -pub(crate) mod wasm; +pub use error::MysqlError; +pub use url::MysqlUrl; #[cfg(feature = "mysql-native")] pub(crate) mod native; + diff --git a/quaint/src/connector/mysql/wasm/error.rs b/quaint/src/connector/mysql/error.rs similarity index 100% rename from quaint/src/connector/mysql/wasm/error.rs rename to quaint/src/connector/mysql/error.rs diff --git a/quaint/src/connector/mysql/native/error.rs b/quaint/src/connector/mysql/native/error.rs index e00ff1e0aa74..89c21fb706f6 100644 --- a/quaint/src/connector/mysql/native/error.rs +++ b/quaint/src/connector/mysql/native/error.rs @@ -1,5 +1,5 @@ use crate::{ - connector::mysql::wasm::error::MysqlError, + connector::mysql::error::MysqlError, error::{Error, ErrorKind}, }; use mysql_async as my; diff --git a/quaint/src/connector/mysql/native/mod.rs b/quaint/src/connector/mysql/native/mod.rs index e72a2c47a9a1..7a95ee47b614 100644 --- a/quaint/src/connector/mysql/native/mod.rs +++ b/quaint/src/connector/mysql/native/mod.rs @@ -4,7 +4,7 @@ mod conversion; mod error; -pub(crate) use crate::connector::mysql::wasm::common::MysqlUrl; +pub(crate) use crate::connector::mysql::MysqlUrl; use crate::connector::{timeout, IsolationLevel}; use crate::{ diff --git a/quaint/src/connector/mysql/wasm/common.rs b/quaint/src/connector/mysql/url.rs similarity index 100% rename from quaint/src/connector/mysql/wasm/common.rs rename to quaint/src/connector/mysql/url.rs diff --git a/quaint/src/connector/mysql/wasm/mod.rs b/quaint/src/connector/mysql/wasm/mod.rs deleted file mode 100644 index 4f73f82031d5..000000000000 --- a/quaint/src/connector/mysql/wasm/mod.rs +++ /dev/null @@ -1,6 +0,0 @@ -//! Wasm-compatible definitions for the MySQL connector. -//! This module is only available with the `mysql` feature. -pub(crate) mod common; -pub mod error; - -pub use common::MysqlUrl; From 626bc1ef904d0e46fec7046a64cb3927889b6452 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Wed, 15 Nov 2023 13:17:23 +0100 Subject: [PATCH 023/134] feat(quaint): flatten mssql connector module --- quaint/src/connector.rs | 6 +++--- quaint/src/connector/mssql.rs | 7 ++++--- quaint/src/connector/mssql/native/mod.rs | 2 +- quaint/src/connector/mssql/{wasm/common.rs => url.rs} | 0 quaint/src/connector/mssql/wasm/mod.rs | 5 ----- quaint/src/connector/mysql.rs | 1 - 6 files changed, 8 insertions(+), 13 deletions(-) rename quaint/src/connector/mssql/{wasm/common.rs => url.rs} (100%) delete mode 100644 quaint/src/connector/mssql/wasm/mod.rs diff --git a/quaint/src/connector.rs b/quaint/src/connector.rs index a2ee455fee22..97643978228b 100644 --- a/quaint/src/connector.rs +++ b/quaint/src/connector.rs @@ -38,10 +38,10 @@ pub use postgres::wasm::common::*; #[cfg(feature = "mysql")] pub(crate) mod mysql; -#[cfg(feature = "mysql")] -pub use mysql::*; #[cfg(feature = "mysql-native")] pub use mysql::native::*; +#[cfg(feature = "mysql")] +pub use mysql::*; #[cfg(feature = "sqlite")] pub(crate) mod sqlite; @@ -55,4 +55,4 @@ pub(crate) mod mssql; #[cfg(feature = "mssql-native")] pub use mssql::native::*; #[cfg(feature = "mssql")] -pub use mssql::wasm::common::*; +pub use mssql::*; diff --git a/quaint/src/connector/mssql.rs b/quaint/src/connector/mssql.rs index c83b5f1f7266..09f589192676 100644 --- a/quaint/src/connector/mssql.rs +++ b/quaint/src/connector/mssql.rs @@ -1,7 +1,8 @@ -pub use wasm::common::MssqlUrl; +//! Wasm-compatible definitions for the MSSQL connector. +//! This module is only available with the `mssql` feature. 
+pub(crate) mod url; -#[cfg(feature = "mssql")] -pub(crate) mod wasm; +pub use url::MssqlUrl; #[cfg(feature = "mssql-native")] pub(crate) mod native; diff --git a/quaint/src/connector/mssql/native/mod.rs b/quaint/src/connector/mssql/native/mod.rs index 8458935814b4..d7052d5e5180 100644 --- a/quaint/src/connector/mssql/native/mod.rs +++ b/quaint/src/connector/mssql/native/mod.rs @@ -4,7 +4,7 @@ mod conversion; mod error; -pub(crate) use crate::connector::mssql::wasm::common::MssqlUrl; +pub(crate) use crate::connector::mssql::MssqlUrl; use crate::connector::{timeout, IsolationLevel, Transaction, TransactionOptions}; use crate::{ diff --git a/quaint/src/connector/mssql/wasm/common.rs b/quaint/src/connector/mssql/url.rs similarity index 100% rename from quaint/src/connector/mssql/wasm/common.rs rename to quaint/src/connector/mssql/url.rs diff --git a/quaint/src/connector/mssql/wasm/mod.rs b/quaint/src/connector/mssql/wasm/mod.rs deleted file mode 100644 index 5a25a32836c2..000000000000 --- a/quaint/src/connector/mssql/wasm/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -//! Wasm-compatible definitions for the MSSQL connector. -//! This module is only available with the `mssql` feature. -pub(crate) mod common; - -pub use common::MssqlUrl; diff --git a/quaint/src/connector/mysql.rs b/quaint/src/connector/mysql.rs index 0834be88949e..5ca2c3551f29 100644 --- a/quaint/src/connector/mysql.rs +++ b/quaint/src/connector/mysql.rs @@ -8,4 +8,3 @@ pub use url::MysqlUrl; #[cfg(feature = "mysql-native")] pub(crate) mod native; - From a9f8ba841de6f1715b6e1002f67d22a3b60c5c6d Mon Sep 17 00:00:00 2001 From: jkomyno Date: Wed, 15 Nov 2023 13:23:19 +0100 Subject: [PATCH 024/134] feat(quaint): flatten postgres connector module --- quaint/src/connector.rs | 2 +- quaint/src/connector/mysql.rs | 2 +- quaint/src/connector/postgres.rs | 10 ++++++---- quaint/src/connector/postgres/{wasm => }/error.rs | 0 quaint/src/connector/postgres/native/error.rs | 2 +- quaint/src/connector/postgres/native/mod.rs | 6 +++--- .../src/connector/postgres/{wasm/common.rs => url.rs} | 0 quaint/src/connector/postgres/wasm/mod.rs | 6 ------ 8 files changed, 12 insertions(+), 16 deletions(-) rename quaint/src/connector/postgres/{wasm => }/error.rs (100%) rename quaint/src/connector/postgres/{wasm/common.rs => url.rs} (100%) delete mode 100644 quaint/src/connector/postgres/wasm/mod.rs diff --git a/quaint/src/connector.rs b/quaint/src/connector.rs index 97643978228b..82b1437b6c03 100644 --- a/quaint/src/connector.rs +++ b/quaint/src/connector.rs @@ -34,7 +34,7 @@ pub(crate) mod postgres; #[cfg(feature = "postgresql-native")] pub use postgres::native::*; #[cfg(feature = "postgresql")] -pub use postgres::wasm::common::*; +pub use postgres::*; #[cfg(feature = "mysql")] pub(crate) mod mysql; diff --git a/quaint/src/connector/mysql.rs b/quaint/src/connector/mysql.rs index 5ca2c3551f29..23fed3c70bd3 100644 --- a/quaint/src/connector/mysql.rs +++ b/quaint/src/connector/mysql.rs @@ -1,6 +1,6 @@ //! Wasm-compatible definitions for the MySQL connector. //! This module is only available with the `mysql` feature. -pub mod error; +pub(crate) mod error; pub(crate) mod url; pub use error::MysqlError; diff --git a/quaint/src/connector/postgres.rs b/quaint/src/connector/postgres.rs index 73a8547b8a65..71d40e71ba0f 100644 --- a/quaint/src/connector/postgres.rs +++ b/quaint/src/connector/postgres.rs @@ -1,8 +1,10 @@ -pub use wasm::common::PostgresUrl; -pub use wasm::error::PostgresError; +//! Wasm-compatible definitions for the PostgreSQL connector. +//! 
This module is only available with the `postgresql` feature. +pub(crate) mod error; +pub(crate) mod url; -#[cfg(feature = "postgresql")] -pub(crate) mod wasm; +pub use error::PostgresError; +pub use url::{PostgresFlavour, PostgresUrl}; #[cfg(feature = "postgresql-native")] pub(crate) mod native; diff --git a/quaint/src/connector/postgres/wasm/error.rs b/quaint/src/connector/postgres/error.rs similarity index 100% rename from quaint/src/connector/postgres/wasm/error.rs rename to quaint/src/connector/postgres/error.rs diff --git a/quaint/src/connector/postgres/native/error.rs b/quaint/src/connector/postgres/native/error.rs index 05b792e27900..c353e397705c 100644 --- a/quaint/src/connector/postgres/native/error.rs +++ b/quaint/src/connector/postgres/native/error.rs @@ -1,7 +1,7 @@ use tokio_postgres::error::DbError; use crate::{ - connector::postgres::wasm::error::PostgresError, + connector::postgres::error::PostgresError, error::{Error, ErrorKind}, }; diff --git a/quaint/src/connector/postgres/native/mod.rs b/quaint/src/connector/postgres/native/mod.rs index fbb4760ed19f..5dbf67a91cdf 100644 --- a/quaint/src/connector/postgres/native/mod.rs +++ b/quaint/src/connector/postgres/native/mod.rs @@ -4,8 +4,8 @@ mod conversion; mod error; -pub(crate) use crate::connector::postgres::wasm::common::PostgresUrl; -use crate::connector::postgres::wasm::common::{Hidden, SslAcceptMode, SslParams}; +pub(crate) use crate::connector::postgres::url::PostgresUrl; +use crate::connector::postgres::url::{Hidden, SslAcceptMode, SslParams}; use crate::connector::{timeout, IsolationLevel, Transaction}; use crate::{ @@ -670,7 +670,7 @@ fn is_safe_identifier(ident: &str) -> bool { #[cfg(test)] mod tests { use super::*; - pub(crate) use crate::connector::postgres::wasm::common::PostgresFlavour; + pub(crate) use crate::connector::postgres::url::PostgresFlavour; use crate::tests::test_api::postgres::CONN_STR; use crate::tests::test_api::CRDB_CONN_STR; use crate::{connector::Queryable, error::*, single::Quaint}; diff --git a/quaint/src/connector/postgres/wasm/common.rs b/quaint/src/connector/postgres/url.rs similarity index 100% rename from quaint/src/connector/postgres/wasm/common.rs rename to quaint/src/connector/postgres/url.rs diff --git a/quaint/src/connector/postgres/wasm/mod.rs b/quaint/src/connector/postgres/wasm/mod.rs deleted file mode 100644 index 859de8f6fd3c..000000000000 --- a/quaint/src/connector/postgres/wasm/mod.rs +++ /dev/null @@ -1,6 +0,0 @@ -//! Wasm-compatible definitions for the Postgres connector. -//! This module is only available with the `postgresql` feature. 
-pub(crate) mod common; -pub mod error; - -pub use common::PostgresUrl; From 3c1a1008c915f1baae29c39dbc82f55fb6b0945a Mon Sep 17 00:00:00 2001 From: jkomyno Date: Wed, 15 Nov 2023 13:28:58 +0100 Subject: [PATCH 025/134] feat(quaint): flatten sqlite connector module --- quaint/src/connector.rs | 2 +- quaint/src/connector/sqlite.rs | 10 +++++++--- quaint/src/connector/sqlite/{wasm => }/error.rs | 0 quaint/src/connector/sqlite/{wasm => }/ffi.rs | 0 quaint/src/connector/sqlite/native/error.rs | 2 +- quaint/src/connector/sqlite/native/mod.rs | 2 +- .../src/connector/sqlite/{wasm/common.rs => params.rs} | 0 quaint/src/connector/sqlite/wasm/mod.rs | 5 ----- 8 files changed, 10 insertions(+), 11 deletions(-) rename quaint/src/connector/sqlite/{wasm => }/error.rs (100%) rename quaint/src/connector/sqlite/{wasm => }/ffi.rs (100%) rename quaint/src/connector/sqlite/{wasm/common.rs => params.rs} (100%) delete mode 100644 quaint/src/connector/sqlite/wasm/mod.rs diff --git a/quaint/src/connector.rs b/quaint/src/connector.rs index 82b1437b6c03..dddb3c953ad7 100644 --- a/quaint/src/connector.rs +++ b/quaint/src/connector.rs @@ -48,7 +48,7 @@ pub(crate) mod sqlite; #[cfg(feature = "sqlite-native")] pub use sqlite::native::*; #[cfg(feature = "sqlite")] -pub use sqlite::wasm::common::*; +pub use sqlite::*; #[cfg(feature = "mssql")] pub(crate) mod mssql; diff --git a/quaint/src/connector/sqlite.rs b/quaint/src/connector/sqlite.rs index 45611aab9357..c59c947b8dc1 100644 --- a/quaint/src/connector/sqlite.rs +++ b/quaint/src/connector/sqlite.rs @@ -1,7 +1,11 @@ -pub use wasm::error::SqliteError; +//! Wasm-compatible definitions for the SQLite connector. +//! This module is only available with the `sqlite` feature. +pub(crate) mod error; +mod ffi; +pub(crate) mod params; -#[cfg(feature = "sqlite")] -pub(crate) mod wasm; +pub use error::SqliteError; +pub use params::*; #[cfg(feature = "sqlite-native")] pub(crate) mod native; diff --git a/quaint/src/connector/sqlite/wasm/error.rs b/quaint/src/connector/sqlite/error.rs similarity index 100% rename from quaint/src/connector/sqlite/wasm/error.rs rename to quaint/src/connector/sqlite/error.rs diff --git a/quaint/src/connector/sqlite/wasm/ffi.rs b/quaint/src/connector/sqlite/ffi.rs similarity index 100% rename from quaint/src/connector/sqlite/wasm/ffi.rs rename to quaint/src/connector/sqlite/ffi.rs diff --git a/quaint/src/connector/sqlite/native/error.rs b/quaint/src/connector/sqlite/native/error.rs index d09e2959ce28..51b2417ed821 100644 --- a/quaint/src/connector/sqlite/native/error.rs +++ b/quaint/src/connector/sqlite/native/error.rs @@ -1,4 +1,4 @@ -use crate::connector::sqlite::wasm::error::SqliteError; +use crate::connector::sqlite::error::SqliteError; use crate::error::*; diff --git a/quaint/src/connector/sqlite/native/mod.rs b/quaint/src/connector/sqlite/native/mod.rs index bdf5c473fd4d..4b686f5968d6 100644 --- a/quaint/src/connector/sqlite/native/mod.rs +++ b/quaint/src/connector/sqlite/native/mod.rs @@ -4,7 +4,7 @@ mod conversion; mod error; -use crate::connector::sqlite::wasm::common::SqliteParams; +use crate::connector::sqlite::params::SqliteParams; use crate::connector::IsolationLevel; pub use rusqlite::{params_from_iter, version as sqlite_version}; diff --git a/quaint/src/connector/sqlite/wasm/common.rs b/quaint/src/connector/sqlite/params.rs similarity index 100% rename from quaint/src/connector/sqlite/wasm/common.rs rename to quaint/src/connector/sqlite/params.rs diff --git a/quaint/src/connector/sqlite/wasm/mod.rs 
b/quaint/src/connector/sqlite/wasm/mod.rs deleted file mode 100644 index 662237af30a1..000000000000 --- a/quaint/src/connector/sqlite/wasm/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -//! Wasm-compatible definitions for the SQLite connector. -//! This module is only available with the `sqlite` feature. -pub(crate) mod common; -pub mod error; -mod ffi; From 7f4c8f943142d45340dbd2c4c621093998130a72 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Wed, 15 Nov 2023 13:30:47 +0100 Subject: [PATCH 026/134] chore(quaint): export all public definitions in connector "url" modules --- quaint/src/connector/mssql.rs | 2 +- quaint/src/connector/mysql.rs | 2 +- quaint/src/connector/postgres.rs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/quaint/src/connector/mssql.rs b/quaint/src/connector/mssql.rs index 09f589192676..e18b68fb2ce1 100644 --- a/quaint/src/connector/mssql.rs +++ b/quaint/src/connector/mssql.rs @@ -2,7 +2,7 @@ //! This module is only available with the `mssql` feature. pub(crate) mod url; -pub use url::MssqlUrl; +pub use url::*; #[cfg(feature = "mssql-native")] pub(crate) mod native; diff --git a/quaint/src/connector/mysql.rs b/quaint/src/connector/mysql.rs index 23fed3c70bd3..0dc504dd2d11 100644 --- a/quaint/src/connector/mysql.rs +++ b/quaint/src/connector/mysql.rs @@ -4,7 +4,7 @@ pub(crate) mod error; pub(crate) mod url; pub use error::MysqlError; -pub use url::MysqlUrl; +pub use url::*; #[cfg(feature = "mysql-native")] pub(crate) mod native; diff --git a/quaint/src/connector/postgres.rs b/quaint/src/connector/postgres.rs index 71d40e71ba0f..d1694108a1b7 100644 --- a/quaint/src/connector/postgres.rs +++ b/quaint/src/connector/postgres.rs @@ -4,7 +4,7 @@ pub(crate) mod error; pub(crate) mod url; pub use error::PostgresError; -pub use url::{PostgresFlavour, PostgresUrl}; +pub use url::*; #[cfg(feature = "postgresql-native")] pub(crate) mod native; From 95a4e28c89a1cda3c7bdf6f15b7ae543ae5da780 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Wed, 15 Nov 2023 17:06:22 +0100 Subject: [PATCH 027/134] chore(quaint): refactor tests for connectors, addressing feedback --- quaint/src/connector/mssql/native/mod.rs | 17 -- quaint/src/connector/mysql/native/mod.rs | 83 -------- quaint/src/connector/mysql/url.rs | 83 ++++++++ quaint/src/connector/postgres/native/mod.rs | 215 +------------------ quaint/src/connector/postgres/url.rs | 223 ++++++++++++++++++++ quaint/src/connector/sqlite/native/mod.rs | 21 -- quaint/src/connector/sqlite/params.rs | 26 +++ 7 files changed, 333 insertions(+), 335 deletions(-) diff --git a/quaint/src/connector/mssql/native/mod.rs b/quaint/src/connector/mssql/native/mod.rs index d7052d5e5180..d22aa7a15dd6 100644 --- a/quaint/src/connector/mssql/native/mod.rs +++ b/quaint/src/connector/mssql/native/mod.rs @@ -237,20 +237,3 @@ impl Queryable for Mssql { true } } - -#[cfg(test)] -mod tests { - use crate::tests::test_api::mssql::CONN_STR; - use crate::{error::*, single::Quaint}; - - #[tokio::test] - async fn should_map_wrong_credentials_error() { - let url = CONN_STR.replace("user=SA", "user=WRONG"); - - let res = Quaint::new(url.as_str()).await; - assert!(res.is_err()); - - let err = res.unwrap_err(); - assert!(matches!(err.kind(), ErrorKind::AuthenticationFailed { user } if user == &Name::available("WRONG"))); - } -} diff --git a/quaint/src/connector/mysql/native/mod.rs b/quaint/src/connector/mysql/native/mod.rs index 7a95ee47b614..fdcc3a6276d1 100644 --- a/quaint/src/connector/mysql/native/mod.rs +++ b/quaint/src/connector/mysql/native/mod.rs @@ -295,86 +295,3 @@ 
impl Queryable for Mysql { true } } - -#[cfg(test)] -mod tests { - use super::MysqlUrl; - use crate::tests::test_api::mysql::CONN_STR; - use crate::{error::*, single::Quaint}; - use url::Url; - - #[test] - fn should_parse_socket_url() { - let url = MysqlUrl::new(Url::parse("mysql://root@localhost/dbname?socket=(/tmp/mysql.sock)").unwrap()).unwrap(); - assert_eq!("dbname", url.dbname()); - assert_eq!(&Some(String::from("/tmp/mysql.sock")), url.socket()); - } - - #[test] - fn should_parse_prefer_socket() { - let url = - MysqlUrl::new(Url::parse("mysql://root:root@localhost:3307/testdb?prefer_socket=false").unwrap()).unwrap(); - assert!(!url.prefer_socket().unwrap()); - } - - #[test] - fn should_parse_sslaccept() { - let url = - MysqlUrl::new(Url::parse("mysql://root:root@localhost:3307/testdb?sslaccept=strict").unwrap()).unwrap(); - assert!(url.query_params.use_ssl); - assert!(!url.query_params.ssl_opts.skip_domain_validation()); - assert!(!url.query_params.ssl_opts.accept_invalid_certs()); - } - - #[test] - fn should_parse_ipv6_host() { - let url = MysqlUrl::new(Url::parse("mysql://[2001:db8:1234::ffff]:5432/testdb").unwrap()).unwrap(); - assert_eq!("2001:db8:1234::ffff", url.host()); - } - - #[test] - fn should_allow_changing_of_cache_size() { - let url = MysqlUrl::new(Url::parse("mysql:///root:root@localhost:3307/foo?statement_cache_size=420").unwrap()) - .unwrap(); - assert_eq!(420, url.cache().capacity()); - } - - #[test] - fn should_have_default_cache_size() { - let url = MysqlUrl::new(Url::parse("mysql:///root:root@localhost:3307/foo").unwrap()).unwrap(); - assert_eq!(100, url.cache().capacity()); - } - - #[tokio::test] - async fn should_map_nonexisting_database_error() { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.set_username("root").unwrap(); - url.set_path("/this_does_not_exist"); - - let url = url.as_str().to_string(); - let res = Quaint::new(&url).await; - - let err = res.unwrap_err(); - - match err.kind() { - ErrorKind::DatabaseDoesNotExist { db_name } => { - assert_eq!(Some("1049"), err.original_code()); - assert_eq!(Some("Unknown database \'this_does_not_exist\'"), err.original_message()); - assert_eq!(&Name::available("this_does_not_exist"), db_name) - } - e => panic!("Expected `DatabaseDoesNotExist`, got {:?}", e), - } - } - - #[tokio::test] - async fn should_map_wrong_credentials_error() { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.set_username("WRONG").unwrap(); - - let res = Quaint::new(url.as_str()).await; - assert!(res.is_err()); - - let err = res.unwrap_err(); - assert!(matches!(err.kind(), ErrorKind::AuthenticationFailed { user } if user == &Name::available("WRONG"))); - } -} diff --git a/quaint/src/connector/mysql/url.rs b/quaint/src/connector/mysql/url.rs index c17b2224c0ef..f0756fa95833 100644 --- a/quaint/src/connector/mysql/url.rs +++ b/quaint/src/connector/mysql/url.rs @@ -316,3 +316,86 @@ pub(crate) struct MysqlUrlQueryParams { #[cfg(feature = "mysql-native")] pub(crate) ssl_opts: mysql_async::SslOpts, } + +#[cfg(test)] +mod tests { + use super::MysqlUrl; + use crate::tests::test_api::mysql::CONN_STR; + use crate::{error::*, single::Quaint}; + use url::Url; + + #[test] + fn should_parse_socket_url() { + let url = MysqlUrl::new(Url::parse("mysql://root@localhost/dbname?socket=(/tmp/mysql.sock)").unwrap()).unwrap(); + assert_eq!("dbname", url.dbname()); + assert_eq!(&Some(String::from("/tmp/mysql.sock")), url.socket()); + } + + #[test] + fn should_parse_prefer_socket() { + let url = + 
MysqlUrl::new(Url::parse("mysql://root:root@localhost:3307/testdb?prefer_socket=false").unwrap()).unwrap(); + assert!(!url.prefer_socket().unwrap()); + } + + #[test] + fn should_parse_sslaccept() { + let url = + MysqlUrl::new(Url::parse("mysql://root:root@localhost:3307/testdb?sslaccept=strict").unwrap()).unwrap(); + assert!(url.query_params.use_ssl); + assert!(!url.query_params.ssl_opts.skip_domain_validation()); + assert!(!url.query_params.ssl_opts.accept_invalid_certs()); + } + + #[test] + fn should_parse_ipv6_host() { + let url = MysqlUrl::new(Url::parse("mysql://[2001:db8:1234::ffff]:5432/testdb").unwrap()).unwrap(); + assert_eq!("2001:db8:1234::ffff", url.host()); + } + + #[test] + fn should_allow_changing_of_cache_size() { + let url = MysqlUrl::new(Url::parse("mysql:///root:root@localhost:3307/foo?statement_cache_size=420").unwrap()) + .unwrap(); + assert_eq!(420, url.cache().capacity()); + } + + #[test] + fn should_have_default_cache_size() { + let url = MysqlUrl::new(Url::parse("mysql:///root:root@localhost:3307/foo").unwrap()).unwrap(); + assert_eq!(100, url.cache().capacity()); + } + + #[tokio::test] + async fn should_map_nonexisting_database_error() { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.set_username("root").unwrap(); + url.set_path("/this_does_not_exist"); + + let url = url.as_str().to_string(); + let res = Quaint::new(&url).await; + + let err = res.unwrap_err(); + + match err.kind() { + ErrorKind::DatabaseDoesNotExist { db_name } => { + assert_eq!(Some("1049"), err.original_code()); + assert_eq!(Some("Unknown database \'this_does_not_exist\'"), err.original_message()); + assert_eq!(&Name::available("this_does_not_exist"), db_name) + } + e => panic!("Expected `DatabaseDoesNotExist`, got {:?}", e), + } + } + + #[tokio::test] + async fn should_map_wrong_credentials_error() { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.set_username("WRONG").unwrap(); + + let res = Quaint::new(url.as_str()).await; + assert!(res.is_err()); + + let err = res.unwrap_err(); + assert!(matches!(err.kind(), ErrorKind::AuthenticationFailed { user } if user == &Name::available("WRONG"))); + } +} diff --git a/quaint/src/connector/postgres/native/mod.rs b/quaint/src/connector/postgres/native/mod.rs index 5dbf67a91cdf..30f34e7002be 100644 --- a/quaint/src/connector/postgres/native/mod.rs +++ b/quaint/src/connector/postgres/native/mod.rs @@ -671,89 +671,11 @@ fn is_safe_identifier(ident: &str) -> bool { mod tests { use super::*; pub(crate) use crate::connector::postgres::url::PostgresFlavour; + use crate::connector::Queryable; use crate::tests::test_api::postgres::CONN_STR; use crate::tests::test_api::CRDB_CONN_STR; - use crate::{connector::Queryable, error::*, single::Quaint}; use url::Url; - #[test] - fn should_parse_socket_url() { - let url = PostgresUrl::new(Url::parse("postgresql:///dbname?host=/var/run/psql.sock").unwrap()).unwrap(); - assert_eq!("dbname", url.dbname()); - assert_eq!("/var/run/psql.sock", url.host()); - } - - #[test] - fn should_parse_escaped_url() { - let url = PostgresUrl::new(Url::parse("postgresql:///dbname?host=%2Fvar%2Frun%2Fpostgresql").unwrap()).unwrap(); - assert_eq!("dbname", url.dbname()); - assert_eq!("/var/run/postgresql", url.host()); - } - - #[test] - fn should_allow_changing_of_cache_size() { - let url = - PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?statement_cache_size=420").unwrap()).unwrap(); - assert_eq!(420, url.cache().capacity()); - } - - #[test] - fn should_have_default_cache_size() { - let url = 
PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo").unwrap()).unwrap(); - assert_eq!(100, url.cache().capacity()); - } - - #[test] - fn should_have_application_name() { - let url = - PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?application_name=test").unwrap()).unwrap(); - assert_eq!(Some("test"), url.application_name()); - } - - #[test] - fn should_have_channel_binding() { - let url = - PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?channel_binding=require").unwrap()).unwrap(); - assert_eq!(ChannelBinding::Require, url.channel_binding()); - } - - #[test] - fn should_have_default_channel_binding() { - let url = - PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?channel_binding=invalid").unwrap()).unwrap(); - assert_eq!(ChannelBinding::Prefer, url.channel_binding()); - - let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo").unwrap()).unwrap(); - assert_eq!(ChannelBinding::Prefer, url.channel_binding()); - } - - #[test] - fn should_not_enable_caching_with_pgbouncer() { - let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?pgbouncer=true").unwrap()).unwrap(); - assert_eq!(0, url.cache().capacity()); - } - - #[test] - fn should_parse_default_host() { - let url = PostgresUrl::new(Url::parse("postgresql:///dbname").unwrap()).unwrap(); - assert_eq!("dbname", url.dbname()); - assert_eq!("localhost", url.host()); - } - - #[test] - fn should_parse_ipv6_host() { - let url = PostgresUrl::new(Url::parse("postgresql://[2001:db8:1234::ffff]:5432/dbname").unwrap()).unwrap(); - assert_eq!("2001:db8:1234::ffff", url.host()); - } - - #[test] - fn should_handle_options_field() { - let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432?options=--cluster%3Dmy_cluster").unwrap()) - .unwrap(); - - assert_eq!("--cluster=my_cluster", url.options().unwrap()); - } - #[tokio::test] async fn test_custom_search_path_pg() { async fn test_path(schema_name: &str) -> Option { @@ -1010,82 +932,6 @@ mod tests { } } - #[tokio::test] - async fn should_map_nonexisting_database_error() { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.set_path("/this_does_not_exist"); - - let res = Quaint::new(url.as_str()).await; - - assert!(res.is_err()); - - match res { - Ok(_) => unreachable!(), - Err(e) => match e.kind() { - ErrorKind::DatabaseDoesNotExist { db_name } => { - assert_eq!(Some("3D000"), e.original_code()); - assert_eq!( - Some("database \"this_does_not_exist\" does not exist"), - e.original_message() - ); - assert_eq!(&Name::available("this_does_not_exist"), db_name) - } - kind => panic!("Expected `DatabaseDoesNotExist`, got {:?}", kind), - }, - } - } - - #[tokio::test] - async fn should_map_wrong_credentials_error() { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.set_username("WRONG").unwrap(); - - let res = Quaint::new(url.as_str()).await; - assert!(res.is_err()); - - let err = res.unwrap_err(); - assert!(matches!(err.kind(), ErrorKind::AuthenticationFailed { user } if user == &Name::available("WRONG"))); - } - - #[tokio::test] - async fn should_map_tls_errors() { - let mut url = Url::parse(&CONN_STR).expect("parsing url"); - url.set_query(Some("sslmode=require&sslaccept=strict")); - - let res = Quaint::new(url.as_str()).await; - - assert!(res.is_err()); - - match res { - Ok(_) => unreachable!(), - Err(e) => match e.kind() { - ErrorKind::TlsError { .. 
} => (), - other => panic!("{:#?}", other), - }, - } - } - - #[tokio::test] - async fn should_map_incorrect_parameters_error() { - let url = Url::parse(&CONN_STR).unwrap(); - let conn = Quaint::new(url.as_str()).await.unwrap(); - - let res = conn.query_raw("SELECT $1", &[Value::int32(1), Value::int32(2)]).await; - - assert!(res.is_err()); - - match res { - Ok(_) => unreachable!(), - Err(e) => match e.kind() { - ErrorKind::IncorrectNumberOfParameters { expected, actual } => { - assert_eq!(1, *expected); - assert_eq!(2, *actual); - } - other => panic!("{:#?}", other), - }, - } - } - #[test] fn test_safe_ident() { // Safe @@ -1123,63 +969,4 @@ mod tests { assert!(!is_safe_identifier(ident)); } } - - #[test] - fn search_path_pgbouncer_should_be_set_with_query() { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.query_pairs_mut().append_pair("schema", "hello"); - url.query_pairs_mut().append_pair("pgbouncer", "true"); - - let mut pg_url = PostgresUrl::new(url).unwrap(); - pg_url.set_flavour(PostgresFlavour::Postgres); - - let config = pg_url.to_config(); - - // PGBouncer does not support the `search_path` connection parameter. - // When `pgbouncer=true`, config.search_path should be None, - // And the `search_path` should be set via a db query after connection. - assert_eq!(config.get_search_path(), None); - } - - #[test] - fn search_path_pg_should_be_set_with_param() { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.query_pairs_mut().append_pair("schema", "hello"); - - let mut pg_url = PostgresUrl::new(url).unwrap(); - pg_url.set_flavour(PostgresFlavour::Postgres); - - let config = pg_url.to_config(); - - // Postgres supports setting the search_path via a connection parameter. - assert_eq!(config.get_search_path(), Some(&"\"hello\"".to_owned())); - } - - #[test] - fn search_path_crdb_safe_ident_should_be_set_with_param() { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.query_pairs_mut().append_pair("schema", "hello"); - - let mut pg_url = PostgresUrl::new(url).unwrap(); - pg_url.set_flavour(PostgresFlavour::Cockroach); - - let config = pg_url.to_config(); - - // CRDB supports setting the search_path via a connection parameter if the identifier is safe. - assert_eq!(config.get_search_path(), Some(&"hello".to_owned())); - } - - #[test] - fn search_path_crdb_unsafe_ident_should_be_set_with_query() { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.query_pairs_mut().append_pair("schema", "HeLLo"); - - let mut pg_url = PostgresUrl::new(url).unwrap(); - pg_url.set_flavour(PostgresFlavour::Cockroach); - - let config = pg_url.to_config(); - - // CRDB does NOT support setting the search_path via a connection parameter if the identifier is unsafe. 
- assert_eq!(config.get_search_path(), None); - } } diff --git a/quaint/src/connector/postgres/url.rs b/quaint/src/connector/postgres/url.rs index 7b9b3aafabb4..f0b60d88a848 100644 --- a/quaint/src/connector/postgres/url.rs +++ b/quaint/src/connector/postgres/url.rs @@ -470,3 +470,226 @@ impl Display for SetSearchPath<'_> { Ok(()) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::ast::Value; + pub(crate) use crate::connector::postgres::url::PostgresFlavour; + use crate::tests::test_api::postgres::CONN_STR; + use crate::{connector::Queryable, error::*, single::Quaint}; + use url::Url; + + #[test] + fn should_parse_socket_url() { + let url = PostgresUrl::new(Url::parse("postgresql:///dbname?host=/var/run/psql.sock").unwrap()).unwrap(); + assert_eq!("dbname", url.dbname()); + assert_eq!("/var/run/psql.sock", url.host()); + } + + #[test] + fn should_parse_escaped_url() { + let url = PostgresUrl::new(Url::parse("postgresql:///dbname?host=%2Fvar%2Frun%2Fpostgresql").unwrap()).unwrap(); + assert_eq!("dbname", url.dbname()); + assert_eq!("/var/run/postgresql", url.host()); + } + + #[test] + fn should_allow_changing_of_cache_size() { + let url = + PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?statement_cache_size=420").unwrap()).unwrap(); + assert_eq!(420, url.cache().capacity()); + } + + #[test] + fn should_have_default_cache_size() { + let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo").unwrap()).unwrap(); + assert_eq!(100, url.cache().capacity()); + } + + #[test] + fn should_have_application_name() { + let url = + PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?application_name=test").unwrap()).unwrap(); + assert_eq!(Some("test"), url.application_name()); + } + + #[test] + fn should_have_channel_binding() { + let url = + PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?channel_binding=require").unwrap()).unwrap(); + assert_eq!(ChannelBinding::Require, url.channel_binding()); + } + + #[test] + fn should_have_default_channel_binding() { + let url = + PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?channel_binding=invalid").unwrap()).unwrap(); + assert_eq!(ChannelBinding::Prefer, url.channel_binding()); + + let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo").unwrap()).unwrap(); + assert_eq!(ChannelBinding::Prefer, url.channel_binding()); + } + + #[test] + fn should_not_enable_caching_with_pgbouncer() { + let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?pgbouncer=true").unwrap()).unwrap(); + assert_eq!(0, url.cache().capacity()); + } + + #[test] + fn should_parse_default_host() { + let url = PostgresUrl::new(Url::parse("postgresql:///dbname").unwrap()).unwrap(); + assert_eq!("dbname", url.dbname()); + assert_eq!("localhost", url.host()); + } + + #[test] + fn should_parse_ipv6_host() { + let url = PostgresUrl::new(Url::parse("postgresql://[2001:db8:1234::ffff]:5432/dbname").unwrap()).unwrap(); + assert_eq!("2001:db8:1234::ffff", url.host()); + } + + #[test] + fn should_handle_options_field() { + let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432?options=--cluster%3Dmy_cluster").unwrap()) + .unwrap(); + + assert_eq!("--cluster=my_cluster", url.options().unwrap()); + } + + #[tokio::test] + async fn should_map_nonexisting_database_error() { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.set_path("/this_does_not_exist"); + + let res = Quaint::new(url.as_str()).await; + + assert!(res.is_err()); + + match res { + Ok(_) => unreachable!(), + Err(e) => 
match e.kind() { + ErrorKind::DatabaseDoesNotExist { db_name } => { + assert_eq!(Some("3D000"), e.original_code()); + assert_eq!( + Some("database \"this_does_not_exist\" does not exist"), + e.original_message() + ); + assert_eq!(&Name::available("this_does_not_exist"), db_name) + } + kind => panic!("Expected `DatabaseDoesNotExist`, got {:?}", kind), + }, + } + } + + #[tokio::test] + async fn should_map_wrong_credentials_error() { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.set_username("WRONG").unwrap(); + + let res = Quaint::new(url.as_str()).await; + assert!(res.is_err()); + + let err = res.unwrap_err(); + assert!(matches!(err.kind(), ErrorKind::AuthenticationFailed { user } if user == &Name::available("WRONG"))); + } + + #[tokio::test] + async fn should_map_tls_errors() { + let mut url = Url::parse(&CONN_STR).expect("parsing url"); + url.set_query(Some("sslmode=require&sslaccept=strict")); + + let res = Quaint::new(url.as_str()).await; + + assert!(res.is_err()); + + match res { + Ok(_) => unreachable!(), + Err(e) => match e.kind() { + ErrorKind::TlsError { .. } => (), + other => panic!("{:#?}", other), + }, + } + } + + #[tokio::test] + async fn should_map_incorrect_parameters_error() { + let url = Url::parse(&CONN_STR).unwrap(); + let conn = Quaint::new(url.as_str()).await.unwrap(); + + let res = conn.query_raw("SELECT $1", &[Value::int32(1), Value::int32(2)]).await; + + assert!(res.is_err()); + + match res { + Ok(_) => unreachable!(), + Err(e) => match e.kind() { + ErrorKind::IncorrectNumberOfParameters { expected, actual } => { + assert_eq!(1, *expected); + assert_eq!(2, *actual); + } + other => panic!("{:#?}", other), + }, + } + } + + #[test] + fn search_path_pgbouncer_should_be_set_with_query() { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.query_pairs_mut().append_pair("schema", "hello"); + url.query_pairs_mut().append_pair("pgbouncer", "true"); + + let mut pg_url = PostgresUrl::new(url).unwrap(); + pg_url.set_flavour(PostgresFlavour::Postgres); + + let config = pg_url.to_config(); + + // PGBouncer does not support the `search_path` connection parameter. + // When `pgbouncer=true`, config.search_path should be None, + // And the `search_path` should be set via a db query after connection. + assert_eq!(config.get_search_path(), None); + } + + #[test] + fn search_path_pg_should_be_set_with_param() { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.query_pairs_mut().append_pair("schema", "hello"); + + let mut pg_url = PostgresUrl::new(url).unwrap(); + pg_url.set_flavour(PostgresFlavour::Postgres); + + let config = pg_url.to_config(); + + // Postgres supports setting the search_path via a connection parameter. + assert_eq!(config.get_search_path(), Some(&"\"hello\"".to_owned())); + } + + #[test] + fn search_path_crdb_safe_ident_should_be_set_with_param() { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.query_pairs_mut().append_pair("schema", "hello"); + + let mut pg_url = PostgresUrl::new(url).unwrap(); + pg_url.set_flavour(PostgresFlavour::Cockroach); + + let config = pg_url.to_config(); + + // CRDB supports setting the search_path via a connection parameter if the identifier is safe. 
+ assert_eq!(config.get_search_path(), Some(&"hello".to_owned())); + } + + #[test] + fn search_path_crdb_unsafe_ident_should_be_set_with_query() { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.query_pairs_mut().append_pair("schema", "HeLLo"); + + let mut pg_url = PostgresUrl::new(url).unwrap(); + pg_url.set_flavour(PostgresFlavour::Cockroach); + + let config = pg_url.to_config(); + + // CRDB does NOT support setting the search_path via a connection parameter if the identifier is unsafe. + assert_eq!(config.get_search_path(), None); + } +} diff --git a/quaint/src/connector/sqlite/native/mod.rs b/quaint/src/connector/sqlite/native/mod.rs index 4b686f5968d6..3bf0c46a7db5 100644 --- a/quaint/src/connector/sqlite/native/mod.rs +++ b/quaint/src/connector/sqlite/native/mod.rs @@ -165,27 +165,6 @@ mod tests { error::{ErrorKind, Name}, }; - #[test] - fn sqlite_params_from_str_should_resolve_path_correctly_with_file_scheme() { - let path = "file:dev.db"; - let params = SqliteParams::try_from(path).unwrap(); - assert_eq!(params.file_path, "dev.db"); - } - - #[test] - fn sqlite_params_from_str_should_resolve_path_correctly_with_sqlite_scheme() { - let path = "sqlite:dev.db"; - let params = SqliteParams::try_from(path).unwrap(); - assert_eq!(params.file_path, "dev.db"); - } - - #[test] - fn sqlite_params_from_str_should_resolve_path_correctly_with_no_scheme() { - let path = "dev.db"; - let params = SqliteParams::try_from(path).unwrap(); - assert_eq!(params.file_path, "dev.db"); - } - #[tokio::test] async fn unknown_table_should_give_a_good_error() { let conn = Sqlite::try_from("file:db/test.db").unwrap(); diff --git a/quaint/src/connector/sqlite/params.rs b/quaint/src/connector/sqlite/params.rs index 46fb5c08f669..f024aa97a694 100644 --- a/quaint/src/connector/sqlite/params.rs +++ b/quaint/src/connector/sqlite/params.rs @@ -103,3 +103,29 @@ impl TryFrom<&str> for SqliteParams { } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn sqlite_params_from_str_should_resolve_path_correctly_with_file_scheme() { + let path = "file:dev.db"; + let params = SqliteParams::try_from(path).unwrap(); + assert_eq!(params.file_path, "dev.db"); + } + + #[test] + fn sqlite_params_from_str_should_resolve_path_correctly_with_sqlite_scheme() { + let path = "sqlite:dev.db"; + let params = SqliteParams::try_from(path).unwrap(); + assert_eq!(params.file_path, "dev.db"); + } + + #[test] + fn sqlite_params_from_str_should_resolve_path_correctly_with_no_scheme() { + let path = "dev.db"; + let params = SqliteParams::try_from(path).unwrap(); + assert_eq!(params.file_path, "dev.db"); + } +} From bacb635367bb994939c9cdcf530033d334a3224b Mon Sep 17 00:00:00 2001 From: jkomyno Date: Wed, 15 Nov 2023 17:19:27 +0100 Subject: [PATCH 028/134] chore: add comment on MysqlAsyncError --- quaint/src/connector/mysql/error.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/quaint/src/connector/mysql/error.rs b/quaint/src/connector/mysql/error.rs index 615f0c69dda4..7b4813bf0223 100644 --- a/quaint/src/connector/mysql/error.rs +++ b/quaint/src/connector/mysql/error.rs @@ -1,6 +1,8 @@ use crate::error::{DatabaseConstraint, Error, ErrorKind}; use thiserror::Error; +// This is a partial copy of the `mysql_async::Error` using only the enum variant used by Prisma. +// This avoids pulling in `mysql_async`, which would break Wasm compilation. 
#[derive(Debug, Error)] enum MysqlAsyncError { #[error("Server error: `{}'", _0)] From 263bab0c84a396cf5867dd65911c2af79cad7824 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Wed, 15 Nov 2023 17:22:52 +0100 Subject: [PATCH 029/134] chore: add comment on ffi.rs for sqlite --- quaint/src/connector/sqlite/ffi.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/quaint/src/connector/sqlite/ffi.rs b/quaint/src/connector/sqlite/ffi.rs index bddfd4354237..c510a459be81 100644 --- a/quaint/src/connector/sqlite/ffi.rs +++ b/quaint/src/connector/sqlite/ffi.rs @@ -1,4 +1,5 @@ -//! This is a partial copy of `rusqlite::ffi::*`. +//! Here, we export only the constants we need to avoid pulling in `rusqlite::ffi::*`, in the sibling `error.rs` file, +//! which would break Wasm compilation. pub const SQLITE_BUSY: i32 = 5; pub const SQLITE_CONSTRAINT_FOREIGNKEY: i32 = 787; pub const SQLITE_CONSTRAINT_NOTNULL: i32 = 1299; From 76816fdc30780b0b026d6db1307ff8e0ade77510 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Wed, 15 Nov 2023 17:52:05 +0100 Subject: [PATCH 030/134] chore: replace awkward "super::super::" with "crate::..." --- .../sql-query-connector/src/database/native/mssql.rs | 4 ++-- .../sql-query-connector/src/database/native/mysql.rs | 4 ++-- .../sql-query-connector/src/database/native/postgresql.rs | 4 ++-- .../sql-query-connector/src/database/native/sqlite.rs | 4 ++-- .../sql_schema_calculator_flavour/mssql.rs | 2 +- .../sql_schema_calculator_flavour/postgres.rs | 2 +- 6 files changed, 10 insertions(+), 10 deletions(-) diff --git a/query-engine/connectors/sql-query-connector/src/database/native/mssql.rs b/query-engine/connectors/sql-query-connector/src/database/native/mssql.rs index bdb6e2ee103c..19d3580bba9f 100644 --- a/query-engine/connectors/sql-query-connector/src/database/native/mssql.rs +++ b/query-engine/connectors/sql-query-connector/src/database/native/mssql.rs @@ -1,4 +1,4 @@ -use super::super::connection::SqlConnection; +use crate::database::{catch, connection::SqlConnection}; use crate::{FromSource, SqlError}; use async_trait::async_trait; use connector_interface::{ @@ -60,7 +60,7 @@ impl FromSource for Mssql { #[async_trait] impl Connector for Mssql { async fn get_connection<'a>(&'a self) -> connector::Result> { - super::super::catch(self.connection_info.clone(), async move { + catch(self.connection_info.clone(), async move { let conn = self.pool.check_out().await.map_err(SqlError::from)?; let conn = SqlConnection::new(conn, &self.connection_info, self.features); diff --git a/query-engine/connectors/sql-query-connector/src/database/native/mysql.rs b/query-engine/connectors/sql-query-connector/src/database/native/mysql.rs index a1cd585c0005..477d687b995b 100644 --- a/query-engine/connectors/sql-query-connector/src/database/native/mysql.rs +++ b/query-engine/connectors/sql-query-connector/src/database/native/mysql.rs @@ -1,4 +1,4 @@ -use super::super::connection::SqlConnection; +use crate::database::{catch, connection::SqlConnection}; use crate::{FromSource, SqlError}; use async_trait::async_trait; use connector_interface::{ @@ -65,7 +65,7 @@ impl FromSource for Mysql { #[async_trait] impl Connector for Mysql { async fn get_connection<'a>(&'a self) -> connector::Result> { - super::super::catch(self.connection_info.clone(), async move { + catch(self.connection_info.clone(), async move { let runtime_conn = self.pool.check_out().await?; // Note: `runtime_conn` must be `Sized`, as that's required by `TransactionCapable` diff --git 
a/query-engine/connectors/sql-query-connector/src/database/native/postgresql.rs b/query-engine/connectors/sql-query-connector/src/database/native/postgresql.rs index 80025add046f..0e49a1de8bbd 100644 --- a/query-engine/connectors/sql-query-connector/src/database/native/postgresql.rs +++ b/query-engine/connectors/sql-query-connector/src/database/native/postgresql.rs @@ -1,4 +1,4 @@ -use super::super::connection::SqlConnection; +use crate::database::{catch, connection::SqlConnection}; use crate::{FromSource, SqlError}; use async_trait::async_trait; use connector_interface::{ @@ -67,7 +67,7 @@ impl FromSource for PostgreSql { #[async_trait] impl Connector for PostgreSql { async fn get_connection<'a>(&'a self) -> connector_interface::Result> { - super::super::catch(self.connection_info.clone(), async move { + catch(self.connection_info.clone(), async move { let conn = self.pool.check_out().await.map_err(SqlError::from)?; let conn = SqlConnection::new(conn, &self.connection_info, self.features); Ok(Box::new(conn) as Box) diff --git a/query-engine/connectors/sql-query-connector/src/database/native/sqlite.rs b/query-engine/connectors/sql-query-connector/src/database/native/sqlite.rs index b1250b18b2be..e38bccb861f4 100644 --- a/query-engine/connectors/sql-query-connector/src/database/native/sqlite.rs +++ b/query-engine/connectors/sql-query-connector/src/database/native/sqlite.rs @@ -1,4 +1,4 @@ -use super::super::connection::SqlConnection; +use crate::database::{catch, connection::SqlConnection}; use crate::{FromSource, SqlError}; use async_trait::async_trait; use connector_interface::{ @@ -80,7 +80,7 @@ fn invalid_file_path_error(file_path: &str, connection_info: &ConnectionInfo) -> #[async_trait] impl Connector for Sqlite { async fn get_connection<'a>(&'a self) -> connector::Result> { - super::super::catch(self.connection_info().clone(), async move { + catch(self.connection_info().clone(), async move { let conn = self.pool.check_out().await.map_err(SqlError::from)?; let conn = SqlConnection::new(conn, self.connection_info(), self.features); diff --git a/schema-engine/connectors/sql-schema-connector/src/sql_schema_calculator/sql_schema_calculator_flavour/mssql.rs b/schema-engine/connectors/sql-schema-connector/src/sql_schema_calculator/sql_schema_calculator_flavour/mssql.rs index 18a0b8e94b3c..51a8f5ef54be 100644 --- a/schema-engine/connectors/sql-schema-connector/src/sql_schema_calculator/sql_schema_calculator_flavour/mssql.rs +++ b/schema-engine/connectors/sql-schema-connector/src/sql_schema_calculator/sql_schema_calculator_flavour/mssql.rs @@ -23,7 +23,7 @@ impl SqlSchemaCalculatorFlavour for MssqlFlavour { } } - fn push_connector_data(&self, context: &mut super::super::Context<'_>) { + fn push_connector_data(&self, context: &mut crate::sql_schema_calculator::Context<'_>) { let mut data = MssqlSchemaExt::default(); for model in context.datamodel.db.walk_models() { diff --git a/schema-engine/connectors/sql-schema-connector/src/sql_schema_calculator/sql_schema_calculator_flavour/postgres.rs b/schema-engine/connectors/sql-schema-connector/src/sql_schema_calculator/sql_schema_calculator_flavour/postgres.rs index 40577d68a35d..656fe432a970 100644 --- a/schema-engine/connectors/sql-schema-connector/src/sql_schema_calculator/sql_schema_calculator_flavour/postgres.rs +++ b/schema-engine/connectors/sql-schema-connector/src/sql_schema_calculator/sql_schema_calculator_flavour/postgres.rs @@ -37,7 +37,7 @@ impl SqlSchemaCalculatorFlavour for PostgresFlavour { } } - fn push_connector_data(&self, context: 
&mut super::super::Context<'_>) { + fn push_connector_data(&self, context: &mut crate::sql_schema_calculator::Context<'_>) { let mut postgres_ext = PostgresSchemaExt::default(); let db = &context.datamodel.db; From 28c0ebc360e3cfa178cac67041fd85af2e4865e3 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Wed, 15 Nov 2023 20:53:54 +0100 Subject: [PATCH 031/134] chore: add comments around "query_core::executor::task" --- query-engine/core/Cargo.toml | 3 + query-engine/core/src/executor/mod.rs | 94 +++++++++++++++------------ 2 files changed, 56 insertions(+), 41 deletions(-) diff --git a/query-engine/core/Cargo.toml b/query-engine/core/Cargo.toml index 7ccf1a293411..9e0f03517cb5 100644 --- a/query-engine/core/Cargo.toml +++ b/query-engine/core/Cargo.toml @@ -37,6 +37,9 @@ schema = { path = "../schema" } lru = "0.7.7" enumflags2 = "0.7" +pin-project = "1" +wasm-bindgen-futures = "0.4" + [target.'cfg(target_arch = "wasm32")'.dependencies] pin-project = "1" wasm-bindgen-futures = "0.4" diff --git a/query-engine/core/src/executor/mod.rs b/query-engine/core/src/executor/mod.rs index 5ff9830013d6..43df839e9635 100644 --- a/query-engine/core/src/executor/mod.rs +++ b/query-engine/core/src/executor/mod.rs @@ -12,7 +12,6 @@ mod pipeline; mod request_context; pub use self::{execute_operation::*, interpreting_executor::InterpretingExecutor}; -use futures::Future; pub(crate) use request_context::*; @@ -133,52 +132,65 @@ pub fn get_current_dispatcher() -> Dispatch { tracing::dispatcher::get_default(|current| current.clone()) } -#[cfg(not(target_arch = "wasm32"))] +// The `task` module provides a unified interface for spawning asynchronous tasks, regardless of the target platform. pub(crate) mod task { - use super::*; - - pub type JoinHandle = tokio::task::JoinHandle; - - pub fn spawn(future: T) -> JoinHandle - where - T: Future + Send + 'static, - T::Output: Send + 'static, - { - tokio::spawn(future) + pub use arch::{spawn, JoinHandle}; + use futures::Future; + + // On native targets, `tokio::spawn` spawns a new asynchronous task. + #[cfg(not(target_arch = "wasm32"))] + mod arch { + use super::*; + + pub type JoinHandle = tokio::task::JoinHandle; + + pub fn spawn(future: T) -> JoinHandle + where + T: Future + Send + 'static, + T::Output: Send + 'static, + { + tokio::spawn(future) + } } -} - -#[cfg(target_arch = "wasm32")] -pub(crate) mod task { - use super::*; - - #[pin_project::pin_project] - pub struct JoinHandle(#[pin] tokio::sync::oneshot::Receiver); - impl Future for JoinHandle { - type Output = Result; - fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll { - let this = self.project(); - this.0.poll(cx) + // On Wasm targets, `wasm_bindgen_futures::spawn_local` spawns a new asynchronous task. + #[cfg(target_arch = "wasm32")] + mod arch { + use super::*; + use tokio::sync::oneshot::{self}; + + // Wasm-compatible alternative to `tokio::task::JoinHandle`. + // `pin_project` enables pin-projection and a `Pin`-compatible implementation of the `Future` trait. 
+ #[pin_project::pin_project] + pub struct JoinHandle(#[pin] oneshot::Receiver); + + impl Future for JoinHandle { + type Output = Result; + + fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll { + // the `self.project()` method is provided by the `pin_project` macro + let receiver: std::pin::Pin<&mut oneshot::Receiver> = self.project().0; + receiver.poll(cx) + } } - } - impl JoinHandle { - pub fn abort(&mut self) { - // abort is noop for WASM builds + impl JoinHandle { + pub fn abort(&mut self) { + // abort is noop on Wasm targets + } } - } - pub fn spawn(future: T) -> JoinHandle - where - T: Future + Send + 'static, - T::Output: Send + 'static, - { - let (tx, rx) = tokio::sync::oneshot::channel(); - wasm_bindgen_futures::spawn_local(async move { - let result = future.await; - tx.send(result).ok(); - }); - JoinHandle(rx) + pub fn spawn(future: T) -> JoinHandle + where + T: Future + Send + 'static, + T::Output: Send + 'static, + { + let (sender, receiver) = oneshot::channel(); + wasm_bindgen_futures::spawn_local(async move { + let result = future.await; + sender.send(result).ok(); + }); + JoinHandle(receiver) + } } }
From de39d9e6a46ba8e7612d39989e009d635512bb3c Mon Sep 17 00:00:00 2001 From: jkomyno Date: Wed, 15 Nov 2023 21:02:12 +0100 Subject: [PATCH 033/134] chore: add "request-handlers" to "query-engine-wasm" --- Cargo.lock | 1 + query-engine/query-engine-wasm/Cargo.toml | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/Cargo.lock b/Cargo.lock index 50df863820fd..8544b8ae8134 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3827,6 +3827,7 @@ dependencies = [ "quaint", "query-connector", "query-core", + "request-handlers", "serde", "serde-wasm-bindgen", "serde_json", diff --git a/query-engine/query-engine-wasm/Cargo.toml b/query-engine/query-engine-wasm/Cargo.toml index c8bc6e2b5178..07757fde5d0a 100644 --- a/query-engine/query-engine-wasm/Cargo.toml +++ b/query-engine/query-engine-wasm/Cargo.toml @@ -15,9 +15,10 @@ user-facing-errors = { path = "../../libs/user-facing-errors" } psl.workspace = true prisma-models = { path = "../prisma-models" } quaint = { path = "../../quaint" } -connector = { path = "../connectors/query-connector", package = "query-connector" } +query-connector = { path = "../connectors/query-connector" } sql-query-connector = { path = "../connectors/sql-query-connector" } query-core = { path = "../core" } +request-handlers = { path = "../request-handlers", default-features = false, features = ["sql", "driver-adapters"] } thiserror = "1"
connection-string.workspace = true From 0a705ecb98fee7593f545958cf755ab423cb5d2b Mon Sep 17 00:00:00 2001 From: jkomyno Date: Wed, 15 Nov 2023 21:42:58 +0100 Subject: [PATCH 034/134] chore: add wasm dependencies to Cargo workspace --- Cargo.lock | 2 + Cargo.toml | 5 + prisma-schema-wasm/Cargo.toml | 2 +- query-engine/core/Cargo.toml | 5 +- query-engine/driver-adapters/Cargo.toml | 14 +- .../driver-adapters/src/async_js_function.rs | 70 -- .../driver-adapters/src/conversion.rs | 60 -- .../driver-adapters/src/conversion/mysql.rs | 2 +- query-engine/driver-adapters/src/error.rs | 35 - query-engine/driver-adapters/src/lib.rs | 29 +- query-engine/driver-adapters/src/proxy.rs | 983 ------------------ query-engine/driver-adapters/src/queryable.rs | 303 ------ query-engine/driver-adapters/src/result.rs | 119 --- .../driver-adapters/src/transaction.rs | 136 --- query-engine/query-engine-wasm/Cargo.toml | 15 +- 15 files changed, 45 insertions(+), 1735 deletions(-) delete mode 100644 query-engine/driver-adapters/src/async_js_function.rs delete mode 100644 query-engine/driver-adapters/src/conversion.rs delete mode 100644 query-engine/driver-adapters/src/error.rs delete mode 100644 query-engine/driver-adapters/src/proxy.rs delete mode 100644 query-engine/driver-adapters/src/queryable.rs delete mode 100644 query-engine/driver-adapters/src/result.rs delete mode 100644 query-engine/driver-adapters/src/transaction.rs diff --git a/Cargo.lock b/Cargo.lock index 8544b8ae8134..43a32df4cb92 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1099,6 +1099,7 @@ dependencies = [ "chrono", "expect-test", "futures", + "js-sys", "metrics 0.18.1", "napi", "napi-derive", @@ -1112,6 +1113,7 @@ dependencies = [ "tracing", "tracing-core", "uuid", + "wasm-bindgen", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index b32a1a85cf18..f496e01fd500 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -57,6 +57,11 @@ napi = { version = "2.12.4", default-features = false, features = [ "serde-json", ] } napi-derive = "2.12.4" +js-sys = { version = "0.3" } +serde-wasm-bindgen = { version = "0.5" } +tsify = { version = "0.4.5" } +wasm-bindgen = { version = "0.2.87" } +wasm-bindgen-futures = { version = "0.4" } [workspace.dependencies.quaint] path = "quaint" diff --git a/prisma-schema-wasm/Cargo.toml b/prisma-schema-wasm/Cargo.toml index 248c726c9ba4..38ef9328cb8a 100644 --- a/prisma-schema-wasm/Cargo.toml +++ b/prisma-schema-wasm/Cargo.toml @@ -7,6 +7,6 @@ edition = "2021" crate-type = ["cdylib"] [dependencies] -wasm-bindgen = "=0.2.87" +wasm-bindgen.workspace = true wasm-logger = { version = "0.2.0", optional = true } prisma-fmt = { path = "../prisma-fmt" } diff --git a/query-engine/core/Cargo.toml b/query-engine/core/Cargo.toml index 9e0f03517cb5..272564fedd88 100644 --- a/query-engine/core/Cargo.toml +++ b/query-engine/core/Cargo.toml @@ -37,9 +37,6 @@ schema = { path = "../schema" } lru = "0.7.7" enumflags2 = "0.7" -pin-project = "1" -wasm-bindgen-futures = "0.4" - [target.'cfg(target_arch = "wasm32")'.dependencies] pin-project = "1" -wasm-bindgen-futures = "0.4" +wasm-bindgen-futures.workspace = true diff --git a/query-engine/driver-adapters/Cargo.toml b/query-engine/driver-adapters/Cargo.toml index 4c0b55bb0a92..029c3b5492c3 100644 --- a/query-engine/driver-adapters/Cargo.toml +++ b/query-engine/driver-adapters/Cargo.toml @@ -8,7 +8,6 @@ async-trait = "0.1" once_cell = "1.15" serde.workspace = true serde_json.workspace = true -quaint.workspace = true psl.workspace = true tracing = "0.1" tracing-core = "0.1" @@ -22,9 +21,16 @@ bigdecimal = 
"0.3.0" chrono = "0.4.20" futures = "0.3" -napi.workspace = true -napi-derive.workspace = true - [dev-dependencies] expect-test = "1" tokio.workspace = true + +[target.'cfg(not(target_arch = "wasm32"))'.dependencies] +napi.workspace = true +napi-derive.workspace = true +quaint.workspace = true + +[target.'cfg(target_arch = "wasm32")'.dependencies] +wasm-bindgen.workspace = true +js-sys.workspace = true +quaint = { path = "../../quaint" } diff --git a/query-engine/driver-adapters/src/async_js_function.rs b/query-engine/driver-adapters/src/async_js_function.rs deleted file mode 100644 index 5f535334ffb9..000000000000 --- a/query-engine/driver-adapters/src/async_js_function.rs +++ /dev/null @@ -1,70 +0,0 @@ -use std::marker::PhantomData; - -use napi::{ - bindgen_prelude::*, - threadsafe_function::{ErrorStrategy, ThreadsafeFunction}, -}; - -use crate::{ - error::{async_unwinding_panic, into_quaint_error}, - result::JsResult, -}; - -/// Wrapper for napi-rs's ThreadsafeFunction that is aware of -/// JS drivers conventions. Performs following things: -/// - Automatically unrefs the function so it won't hold off event loop -/// - Awaits for returned Promise -/// - Unpacks JS `Result` type into Rust `Result` type and converts the error -/// into `quaint::Error`. -/// - Catches panics and converts them to `quaint:Error` -pub(crate) struct AsyncJsFunction -where - ArgType: ToNapiValue + 'static, - ReturnType: FromNapiValue + 'static, -{ - threadsafe_fn: ThreadsafeFunction, - _phantom: PhantomData, -} - -impl AsyncJsFunction -where - ArgType: ToNapiValue + 'static, - ReturnType: FromNapiValue + 'static, -{ - fn from_threadsafe_function( - mut threadsafe_fn: ThreadsafeFunction, - env: Env, - ) -> napi::Result { - threadsafe_fn.unref(&env)?; - - Ok(AsyncJsFunction { - threadsafe_fn, - _phantom: PhantomData, - }) - } - - pub(crate) async fn call(&self, arg: ArgType) -> quaint::Result { - let js_result = async_unwinding_panic(async { - let promise = self - .threadsafe_fn - .call_async::>>(arg) - .await?; - promise.await - }) - .await - .map_err(into_quaint_error)?; - js_result.into() - } -} - -impl FromNapiValue for AsyncJsFunction -where - ArgType: ToNapiValue + 'static, - ReturnType: FromNapiValue + 'static, -{ - unsafe fn from_napi_value(napi_env: napi::sys::napi_env, napi_val: napi::sys::napi_value) -> napi::Result { - let env = Env::from_raw(napi_env); - let threadsafe_fn = ThreadsafeFunction::from_napi_value(napi_env, napi_val)?; - Self::from_threadsafe_function(threadsafe_fn, env) - } -} diff --git a/query-engine/driver-adapters/src/conversion.rs b/query-engine/driver-adapters/src/conversion.rs deleted file mode 100644 index 00061d72de44..000000000000 --- a/query-engine/driver-adapters/src/conversion.rs +++ /dev/null @@ -1,60 +0,0 @@ -pub(crate) mod mysql; -pub(crate) mod postgres; -pub(crate) mod sqlite; - -use napi::bindgen_prelude::{FromNapiValue, ToNapiValue}; -use napi::NapiValue; -use serde::Serialize; -use serde_json::value::Value as JsonValue; - -#[derive(Debug, PartialEq, Serialize)] -#[serde(untagged)] -pub enum JSArg { - Value(serde_json::Value), - Buffer(Vec), - Array(Vec), -} - -impl From for JSArg { - fn from(v: JsonValue) -> Self { - JSArg::Value(v) - } -} - -// FromNapiValue is the napi equivalent to serde::Deserialize. -// Note: we can safely leave this unimplemented as we don't need deserialize napi_value back to JSArg. -// However, removing this altogether would cause a compile error. 
-impl FromNapiValue for JSArg { - unsafe fn from_napi_value(_env: napi::sys::napi_env, _napi_value: napi::sys::napi_value) -> napi::Result { - unreachable!() - } -} - -// ToNapiValue is the napi equivalent to serde::Serialize. -impl ToNapiValue for JSArg { - unsafe fn to_napi_value(env: napi::sys::napi_env, value: Self) -> napi::Result { - match value { - JSArg::Value(v) => ToNapiValue::to_napi_value(env, v), - JSArg::Buffer(bytes) => { - ToNapiValue::to_napi_value(env, napi::Env::from_raw(env).create_buffer_with_data(bytes)?.into_raw()) - } - // While arrays are encodable as JSON generally, their element might not be, or may be - // represented in a different way than we need. We use this custom logic for all arrays - // to avoid having separate `JsonArray` and `BytesArray` variants in `JSArg` and - // avoid complicating the logic in `conv_params`. - JSArg::Array(items) => { - let env = napi::Env::from_raw(env); - let mut array = env.create_array(items.len().try_into().expect("JS array length must fit into u32"))?; - - for (index, item) in items.into_iter().enumerate() { - let js_value = ToNapiValue::to_napi_value(env.raw(), item)?; - // TODO: NapiRaw could be implemented for sys::napi_value directly, there should - // be no need for re-wrapping; submit a patch to napi-rs and simplify here. - array.set(index as u32, napi::JsUnknown::from_raw_unchecked(env.raw(), js_value))?; - } - - ToNapiValue::to_napi_value(env.raw(), array) - } - } - } -} diff --git a/query-engine/driver-adapters/src/conversion/mysql.rs b/query-engine/driver-adapters/src/conversion/mysql.rs index aab33213431a..114d7e3dfcfe 100644 --- a/query-engine/driver-adapters/src/conversion/mysql.rs +++ b/query-engine/driver-adapters/src/conversion/mysql.rs @@ -1,4 +1,4 @@ -use crate::conversion::JSArg; +use super::JSArg; use serde_json::value::Value as JsonValue; const DATETIME_FORMAT: &str = "%Y-%m-%d %H:%M:%S%.f"; diff --git a/query-engine/driver-adapters/src/error.rs b/query-engine/driver-adapters/src/error.rs deleted file mode 100644 index 4f4128088f49..000000000000 --- a/query-engine/driver-adapters/src/error.rs +++ /dev/null @@ -1,35 +0,0 @@ -use futures::{Future, FutureExt}; -use napi::Error as NapiError; -use quaint::error::Error as QuaintError; -use std::{any::Any, panic::AssertUnwindSafe}; - -/// transforms a napi error into a quaint error copying the status and reason -/// properties over -pub(crate) fn into_quaint_error(napi_err: NapiError) -> QuaintError { - let status = napi_err.status.as_ref().to_owned(); - let reason = napi_err.reason.clone(); - - QuaintError::raw_connector_error(status, reason) -} - -/// catches a panic thrown during the execution of an asynchronous closure and transforms it into -/// the Error variant of a napi::Result. 
-pub(crate) async fn async_unwinding_panic(fut: F) -> napi::Result -where - F: Future>, -{ - AssertUnwindSafe(fut) - .catch_unwind() - .await - .unwrap_or_else(panic_to_napi_err) -} - -fn panic_to_napi_err(panic_payload: Box) -> napi::Result { - panic_payload - .downcast_ref::<&str>() - .map(|s| -> String { (*s).to_owned() }) - .or_else(|| panic_payload.downcast_ref::().map(|s| s.to_owned())) - .map(|message| Err(napi::Error::from_reason(format!("PANIC: {message}")))) - .ok_or(napi::Error::from_reason("PANIC: unknown panic".to_string())) - .unwrap() -} diff --git a/query-engine/driver-adapters/src/lib.rs b/query-engine/driver-adapters/src/lib.rs index 6e29f9e69609..22b7883180a6 100644 --- a/query-engine/driver-adapters/src/lib.rs +++ b/query-engine/driver-adapters/src/lib.rs @@ -1,17 +1,22 @@ //! Query Engine Driver Adapters -//! This crate is responsible for defining a quaint::Connector implementation that uses functions -//! exposed by client connectors via N-API. +//! This crate is responsible for defining a `quaint::Connector` implementation that uses functions +//! exposed by client connectors via either `napi-rs` (on native targets) or `wasm_bindgen` / `js_sys` (on Wasm targets). //! //! A driver adapter is an object defined in javascript that uses a driver -//! (ex. '@planetscale/database') to provide a similar implementation of that of a quaint Connector. i.e. the ability to query and execute SQL -//! plus some transformation of types to adhere to what a quaint::Value expresses. +//! (ex. '@planetscale/database') to provide a similar implementation of that of a `quaint::Connector`. i.e. the ability to query and execute SQL +//! plus some transformation of types to adhere to what a `quaint::Value` expresses. //! -mod async_js_function; -mod conversion; -mod error; -mod proxy; -mod queryable; -mod result; -mod transaction; -pub use queryable::{from_napi, JsQueryable}; +pub(crate) mod conversion; + +#[cfg(not(target_arch = "wasm32"))] +pub mod napi; + +#[cfg(not(target_arch = "wasm32"))] +pub use napi::*; + +#[cfg(target_arch = "wasm32")] +pub mod wasm; + +#[cfg(target_arch = "wasm32")] +pub use wasm::*; diff --git a/query-engine/driver-adapters/src/proxy.rs b/query-engine/driver-adapters/src/proxy.rs deleted file mode 100644 index a708d75c0e32..000000000000 --- a/query-engine/driver-adapters/src/proxy.rs +++ /dev/null @@ -1,983 +0,0 @@ -use std::borrow::Cow; -use std::str::FromStr; - -use crate::async_js_function::AsyncJsFunction; -use crate::conversion::JSArg; -use crate::transaction::JsTransaction; -use metrics::increment_gauge; -use napi::bindgen_prelude::{FromNapiValue, ToNapiValue}; -use napi::threadsafe_function::{ErrorStrategy, ThreadsafeFunction}; -use napi::{JsObject, JsString}; -use napi_derive::napi; -use quaint::connector::ResultSet as QuaintResultSet; -use quaint::{ - error::{Error as QuaintError, ErrorKind}, - Value as QuaintValue, -}; - -// TODO(jkomyno): import these 3rd-party crates from the `quaint-core` crate. -use bigdecimal::{BigDecimal, FromPrimitive}; -use chrono::{DateTime, Utc}; -use chrono::{NaiveDate, NaiveTime}; - -/// Proxy is a struct wrapping a javascript object that exhibits basic primitives for -/// querying and executing SQL (i.e. a client connector). The Proxy uses NAPI ThreadSafeFunction to -/// invoke the code within the node runtime that implements the client connector. -pub(crate) struct CommonProxy { - /// Execute a query given as SQL, interpolating the given parameters. 
- query_raw: AsyncJsFunction, - - /// Execute a query given as SQL, interpolating the given parameters and - /// returning the number of affected rows. - execute_raw: AsyncJsFunction, - - /// Return the flavour for this driver. - pub(crate) flavour: String, -} - -/// This is a JS proxy for accessing the methods specific to top level -/// JS driver objects -pub(crate) struct DriverProxy { - start_transaction: AsyncJsFunction<(), JsTransaction>, -} -/// This a JS proxy for accessing the methods, specific -/// to JS transaction objects -pub(crate) struct TransactionProxy { - /// transaction options - options: TransactionOptions, - - /// commit transaction - commit: AsyncJsFunction<(), ()>, - - /// rollback transaction - rollback: AsyncJsFunction<(), ()>, - - /// dispose transaction, cleanup logic executed at the end of the transaction lifecycle - /// on drop. - dispose: ThreadsafeFunction<(), ErrorStrategy::Fatal>, -} - -/// This result set is more convenient to be manipulated from both Rust and NodeJS. -/// Quaint's version of ResultSet is: -/// -/// pub struct ResultSet { -/// pub(crate) columns: Arc>, -/// pub(crate) rows: Vec>>, -/// pub(crate) last_insert_id: Option, -/// } -/// -/// If we used this ResultSet would we would have worse ergonomics as quaint::Value is a structured -/// enum and cannot be used directly with the #[napi(Object)] macro. Thus requiring us to implement -/// the FromNapiValue and ToNapiValue traits for quaint::Value, and use a different custom type -/// representing the Value in javascript. -/// -#[napi(object)] -#[derive(Debug)] -pub struct JSResultSet { - pub column_types: Vec, - pub column_names: Vec, - // Note this might be encoded differently for performance reasons - pub rows: Vec>, - pub last_insert_id: Option, -} - -impl JSResultSet { - pub fn len(&self) -> usize { - self.rows.len() - } -} - -#[napi] -#[derive(Debug)] -pub enum ColumnType { - // [PLANETSCALE_TYPE] (MYSQL_TYPE) -> [TypeScript example] - /// The following PlanetScale type IDs are mapped into Int32: - /// - INT8 (TINYINT) -> e.g. `127` - /// - INT16 (SMALLINT) -> e.g. `32767` - /// - INT24 (MEDIUMINT) -> e.g. `8388607` - /// - INT32 (INT) -> e.g. `2147483647` - Int32 = 0, - - /// The following PlanetScale type IDs are mapped into Int64: - /// - INT64 (BIGINT) -> e.g. `"9223372036854775807"` (String-encoded) - Int64 = 1, - - /// The following PlanetScale type IDs are mapped into Float: - /// - FLOAT32 (FLOAT) -> e.g. `3.402823466` - Float = 2, - - /// The following PlanetScale type IDs are mapped into Double: - /// - FLOAT64 (DOUBLE) -> e.g. `1.7976931348623157` - Double = 3, - - /// The following PlanetScale type IDs are mapped into Numeric: - /// - DECIMAL (DECIMAL) -> e.g. `"99999999.99"` (String-encoded) - Numeric = 4, - - /// The following PlanetScale type IDs are mapped into Boolean: - /// - BOOLEAN (BOOLEAN) -> e.g. `1` - Boolean = 5, - - Character = 6, - - /// The following PlanetScale type IDs are mapped into Text: - /// - TEXT (TEXT) -> e.g. `"foo"` (String-encoded) - /// - VARCHAR (VARCHAR) -> e.g. `"foo"` (String-encoded) - Text = 7, - - /// The following PlanetScale type IDs are mapped into Date: - /// - DATE (DATE) -> e.g. `"2023-01-01"` (String-encoded, yyyy-MM-dd) - Date = 8, - - /// The following PlanetScale type IDs are mapped into Time: - /// - TIME (TIME) -> e.g. `"23:59:59"` (String-encoded, HH:mm:ss) - Time = 9, - - /// The following PlanetScale type IDs are mapped into DateTime: - /// - DATETIME (DATETIME) -> e.g. 
`"2023-01-01 23:59:59"` (String-encoded, yyyy-MM-dd HH:mm:ss) - /// - TIMESTAMP (TIMESTAMP) -> e.g. `"2023-01-01 23:59:59"` (String-encoded, yyyy-MM-dd HH:mm:ss) - DateTime = 10, - - /// The following PlanetScale type IDs are mapped into Json: - /// - JSON (JSON) -> e.g. `"{\"key\": \"value\"}"` (String-encoded) - Json = 11, - - /// The following PlanetScale type IDs are mapped into Enum: - /// - ENUM (ENUM) -> e.g. `"foo"` (String-encoded) - Enum = 12, - - /// The following PlanetScale type IDs are mapped into Bytes: - /// - BLOB (BLOB) -> e.g. `"\u0012"` (String-encoded) - /// - VARBINARY (VARBINARY) -> e.g. `"\u0012"` (String-encoded) - /// - BINARY (BINARY) -> e.g. `"\u0012"` (String-encoded) - /// - GEOMETRY (GEOMETRY) -> e.g. `"\u0012"` (String-encoded) - Bytes = 13, - - /// The following PlanetScale type IDs are mapped into Set: - /// - SET (SET) -> e.g. `"foo,bar"` (String-encoded, comma-separated) - /// This is currently unhandled, and will panic if encountered. - Set = 14, - - /// UUID from postgres-flavored driver adapters is mapped to this type. - Uuid = 15, - - /* - * Scalar arrays - */ - /// Int32 array (INT2_ARRAY and INT4_ARRAY in PostgreSQL) - Int32Array = 64, - - /// Int64 array (INT8_ARRAY in PostgreSQL) - Int64Array = 65, - - /// Float array (FLOAT4_ARRAY in PostgreSQL) - FloatArray = 66, - - /// Double array (FLOAT8_ARRAY in PostgreSQL) - DoubleArray = 67, - - /// Numeric array (NUMERIC_ARRAY, MONEY_ARRAY etc in PostgreSQL) - NumericArray = 68, - - /// Boolean array (BOOL_ARRAY in PostgreSQL) - BooleanArray = 69, - - /// Char array (CHAR_ARRAY in PostgreSQL) - CharacterArray = 70, - - /// Text array (TEXT_ARRAY in PostgreSQL) - TextArray = 71, - - /// Date array (DATE_ARRAY in PostgreSQL) - DateArray = 72, - - /// Time array (TIME_ARRAY in PostgreSQL) - TimeArray = 73, - - /// DateTime array (TIMESTAMP_ARRAY in PostgreSQL) - DateTimeArray = 74, - - /// Json array (JSON_ARRAY in PostgreSQL) - JsonArray = 75, - - /// Enum array - EnumArray = 76, - - /// Bytes array (BYTEA_ARRAY in PostgreSQL) - BytesArray = 77, - - /// Uuid array (UUID_ARRAY in PostgreSQL) - UuidArray = 78, - - /* - * Below there are custom types that don't have a 1:1 translation with a quaint::Value. - * enum variant. - */ - /// UnknownNumber is used when the type of the column is a number but of unknown particular type - /// and precision. - /// - /// It's used by some driver adapters, like libsql to return aggregation values like AVG, or - /// COUNT, and it can be mapped to either Int64, or Double - UnknownNumber = 128, -} - -#[napi(object)] -#[derive(Debug)] -pub struct Query { - pub sql: String, - pub args: Vec, -} - -fn conversion_error(args: &std::fmt::Arguments) -> QuaintError { - let msg = match args.as_str() { - Some(s) => Cow::Borrowed(s), - None => Cow::Owned(args.to_string()), - }; - QuaintError::builder(ErrorKind::ConversionError(msg)).build() -} - -macro_rules! conversion_error { - ($($arg:tt)*) => { - conversion_error(&format_args!($($arg)*)) - }; -} - -/// Handle data-type conversion from a JSON value to a Quaint value. -/// This is used for most data types, except those that require connector-specific handling, e.g., `ColumnType::Boolean`. 
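// (Illustrative only; mirrors the unit tests at the bottom of this file.) A JSON number is decoded
// according to the declared column type, and string-encoded 64-bit integers are parsed from their
// decimal representation:
//
//     let v = js_value_to_quaint(serde_json::json!(42), ColumnType::Int32, "id").unwrap();
//     assert_eq!(v, QuaintValue::int32(42));
//
//     let v = js_value_to_quaint(serde_json::json!("9223372036854775807"), ColumnType::Int64, "id").unwrap();
//     assert_eq!(v, QuaintValue::int64(i64::MAX));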
-fn js_value_to_quaint( - json_value: serde_json::Value, - column_type: ColumnType, - column_name: &str, -) -> quaint::Result> { - let parse_number_as_i64 = |n: &serde_json::Number| { - n.as_i64().ok_or(conversion_error!( - "number must be an integer in column '{column_name}', got '{n}'" - )) - }; - - // Note for the future: it may be worth revisiting how much bloat so many panics with different static - // strings add to the compiled artefact, and in case we should come up with a restricted set of panic - // messages, or even find a way of removing them altogether. - match column_type { - ColumnType::Int32 => match json_value { - serde_json::Value::Number(n) => { - // n.as_i32() is not implemented, so we need to downcast from i64 instead - parse_number_as_i64(&n) - .and_then(|n| -> quaint::Result { - n.try_into() - .map_err(|e| conversion_error!("cannot convert {n} to i32 in column '{column_name}': {e}")) - }) - .map(QuaintValue::int32) - } - serde_json::Value::String(s) => s.parse::().map(QuaintValue::int32).map_err(|e| { - conversion_error!("string-encoded number must be an i32 in column '{column_name}', got {s}: {e}") - }), - serde_json::Value::Null => Ok(QuaintValue::null_int32()), - mismatch => Err(conversion_error!( - "expected an i32 number in column '{column_name}', found {mismatch}" - )), - }, - ColumnType::Int64 => match json_value { - serde_json::Value::Number(n) => parse_number_as_i64(&n).map(QuaintValue::int64), - serde_json::Value::String(s) => s.parse::().map(QuaintValue::int64).map_err(|e| { - conversion_error!("string-encoded number must be an i64 in column '{column_name}', got {s}: {e}") - }), - serde_json::Value::Null => Ok(QuaintValue::null_int64()), - mismatch => Err(conversion_error!( - "expected a string or number in column '{column_name}', found {mismatch}" - )), - }, - ColumnType::Float => match json_value { - // n.as_f32() is not implemented, so we need to downcast from f64 instead. - // We assume that the JSON value is a valid f32 number, but we check for overflows anyway. 
- serde_json::Value::Number(n) => n - .as_f64() - .ok_or(conversion_error!( - "number must be a float in column '{column_name}', got {n}" - )) - .and_then(f64_to_f32) - .map(QuaintValue::float), - serde_json::Value::Null => Ok(QuaintValue::null_float()), - mismatch => Err(conversion_error!( - "expected an f32 number in column '{column_name}', found {mismatch}" - )), - }, - ColumnType::Double => match json_value { - serde_json::Value::Number(n) => n.as_f64().map(QuaintValue::double).ok_or(conversion_error!( - "number must be a f64 in column '{column_name}', got {n}" - )), - serde_json::Value::Null => Ok(QuaintValue::null_double()), - mismatch => Err(conversion_error!( - "expected an f64 number in column '{column_name}', found {mismatch}" - )), - }, - ColumnType::Numeric => match json_value { - serde_json::Value::String(s) => BigDecimal::from_str(&s).map(QuaintValue::numeric).map_err(|e| { - conversion_error!("invalid numeric value when parsing {s} in column '{column_name}': {e}") - }), - serde_json::Value::Number(n) => n - .as_f64() - .and_then(BigDecimal::from_f64) - .ok_or(conversion_error!( - "number must be an f64 in column '{column_name}', got {n}" - )) - .map(QuaintValue::numeric), - serde_json::Value::Null => Ok(QuaintValue::null_numeric()), - mismatch => Err(conversion_error!( - "expected a string-encoded number in column '{column_name}', found {mismatch}", - )), - }, - ColumnType::Boolean => match json_value { - serde_json::Value::Bool(b) => Ok(QuaintValue::boolean(b)), - serde_json::Value::Null => Ok(QuaintValue::null_boolean()), - serde_json::Value::Number(n) => match n.as_i64() { - Some(0) => Ok(QuaintValue::boolean(false)), - Some(1) => Ok(QuaintValue::boolean(true)), - _ => Err(conversion_error!( - "expected number-encoded boolean to be 0 or 1 in column '{column_name}', got {n}" - )), - }, - serde_json::Value::String(s) => match s.as_str() { - "false" | "FALSE" | "0" => Ok(QuaintValue::boolean(false)), - "true" | "TRUE" | "1" => Ok(QuaintValue::boolean(true)), - _ => Err(conversion_error!( - "expected string-encoded boolean in column '{column_name}', got {s}" - )), - }, - mismatch => Err(conversion_error!( - "expected a boolean in column '{column_name}', found {mismatch}" - )), - }, - ColumnType::Character => match json_value { - serde_json::Value::String(s) => match s.chars().next() { - Some(c) => Ok(QuaintValue::character(c)), - None => Ok(QuaintValue::null_character()), - }, - serde_json::Value::Null => Ok(QuaintValue::null_character()), - mismatch => Err(conversion_error!( - "expected a string in column '{column_name}', found {mismatch}" - )), - }, - ColumnType::Text => match json_value { - serde_json::Value::String(s) => Ok(QuaintValue::text(s)), - serde_json::Value::Null => Ok(QuaintValue::null_text()), - mismatch => Err(conversion_error!( - "expected a string in column '{column_name}', found {mismatch}" - )), - }, - ColumnType::Date => match json_value { - serde_json::Value::String(s) => NaiveDate::parse_from_str(&s, "%Y-%m-%d") - .map(QuaintValue::date) - .map_err(|_| conversion_error!("expected a date string in column '{column_name}', got {s}")), - serde_json::Value::Null => Ok(QuaintValue::null_date()), - mismatch => Err(conversion_error!( - "expected a string in column '{column_name}', found {mismatch}" - )), - }, - ColumnType::Time => match json_value { - serde_json::Value::String(s) => NaiveTime::parse_from_str(&s, "%H:%M:%S%.f") - .map(QuaintValue::time) - .map_err(|_| conversion_error!("expected a time string in column '{column_name}', got {s}")), - 
serde_json::Value::Null => Ok(QuaintValue::null_time()), - mismatch => Err(conversion_error!( - "expected a string in column '{column_name}', found {mismatch}" - )), - }, - ColumnType::DateTime => match json_value { - // TODO: change parsing order to prefer RFC3339 - serde_json::Value::String(s) => chrono::NaiveDateTime::parse_from_str(&s, "%Y-%m-%d %H:%M:%S%.f") - .map(|dt| DateTime::from_utc(dt, Utc)) - .or_else(|_| DateTime::parse_from_rfc3339(&s).map(DateTime::::from)) - .map(QuaintValue::datetime) - .map_err(|_| conversion_error!("expected a datetime string in column '{column_name}', found {s}")), - serde_json::Value::Null => Ok(QuaintValue::null_datetime()), - mismatch => Err(conversion_error!( - "expected a string in column '{column_name}', found {mismatch}" - )), - }, - ColumnType::Json => { - match json_value { - // DbNull - serde_json::Value::Null => Ok(QuaintValue::null_json()), - // JsonNull - serde_json::Value::String(s) if s == "$__prisma_null" => Ok(QuaintValue::json(serde_json::Value::Null)), - json => Ok(QuaintValue::json(json)), - } - } - ColumnType::Enum => match json_value { - serde_json::Value::String(s) => Ok(QuaintValue::enum_variant(s)), - serde_json::Value::Null => Ok(QuaintValue::null_enum()), - mismatch => Err(conversion_error!( - "expected a string in column '{column_name}', found {mismatch}" - )), - }, - ColumnType::Bytes => match json_value { - serde_json::Value::String(s) => Ok(QuaintValue::bytes(s.into_bytes())), - serde_json::Value::Array(array) => array - .iter() - .map(|value| value.as_i64().and_then(|maybe_byte| maybe_byte.try_into().ok())) - .collect::>>() - .map(QuaintValue::bytes) - .ok_or(conversion_error!( - "elements of the array in column '{column_name}' must be u8" - )), - serde_json::Value::Null => Ok(QuaintValue::null_bytes()), - mismatch => Err(conversion_error!( - "expected a string or an array in column '{column_name}', found {mismatch}", - )), - }, - ColumnType::Uuid => match json_value { - serde_json::Value::String(s) => uuid::Uuid::parse_str(&s) - .map(QuaintValue::uuid) - .map_err(|_| conversion_error!("Expected a UUID string in column '{column_name}'")), - serde_json::Value::Null => Ok(QuaintValue::null_bytes()), - mismatch => Err(conversion_error!( - "Expected a UUID string in column '{column_name}', found {mismatch}" - )), - }, - ColumnType::UnknownNumber => match json_value { - serde_json::Value::Number(n) => n - .as_i64() - .map(QuaintValue::int64) - .or(n.as_f64().map(QuaintValue::double)) - .ok_or(conversion_error!( - "number must be an i64 or f64 in column '{column_name}', got {n}" - )), - mismatch => Err(conversion_error!( - "expected a either an i64 or a f64 in column '{column_name}', found {mismatch}", - )), - }, - - ColumnType::Int32Array => js_array_to_quaint(ColumnType::Int32, json_value, column_name), - ColumnType::Int64Array => js_array_to_quaint(ColumnType::Int64, json_value, column_name), - ColumnType::FloatArray => js_array_to_quaint(ColumnType::Float, json_value, column_name), - ColumnType::DoubleArray => js_array_to_quaint(ColumnType::Double, json_value, column_name), - ColumnType::NumericArray => js_array_to_quaint(ColumnType::Numeric, json_value, column_name), - ColumnType::BooleanArray => js_array_to_quaint(ColumnType::Boolean, json_value, column_name), - ColumnType::CharacterArray => js_array_to_quaint(ColumnType::Character, json_value, column_name), - ColumnType::TextArray => js_array_to_quaint(ColumnType::Text, json_value, column_name), - ColumnType::DateArray => js_array_to_quaint(ColumnType::Date, json_value, 
column_name), - ColumnType::TimeArray => js_array_to_quaint(ColumnType::Time, json_value, column_name), - ColumnType::DateTimeArray => js_array_to_quaint(ColumnType::DateTime, json_value, column_name), - ColumnType::JsonArray => js_array_to_quaint(ColumnType::Json, json_value, column_name), - ColumnType::EnumArray => js_array_to_quaint(ColumnType::Enum, json_value, column_name), - ColumnType::BytesArray => js_array_to_quaint(ColumnType::Bytes, json_value, column_name), - ColumnType::UuidArray => js_array_to_quaint(ColumnType::Uuid, json_value, column_name), - - unimplemented => { - todo!("support column type {:?} in column {}", unimplemented, column_name) - } - } -} - -fn js_array_to_quaint( - base_type: ColumnType, - json_value: serde_json::Value, - column_name: &str, -) -> quaint::Result> { - match json_value { - serde_json::Value::Array(array) => Ok(QuaintValue::array( - array - .into_iter() - .enumerate() - .map(|(index, elem)| js_value_to_quaint(elem, base_type, &format!("{column_name}[{index}]"))) - .collect::>>()?, - )), - serde_json::Value::Null => Ok(QuaintValue::null_array()), - mismatch => Err(conversion_error!( - "expected an array in column '{column_name}', found {mismatch}", - )), - } -} - -impl TryFrom for QuaintResultSet { - type Error = quaint::error::Error; - - fn try_from(js_result_set: JSResultSet) -> Result { - let JSResultSet { - rows, - column_names, - column_types, - last_insert_id, - } = js_result_set; - - let mut quaint_rows = Vec::with_capacity(rows.len()); - - for row in rows { - let mut quaint_row = Vec::with_capacity(column_types.len()); - - for (i, row) in row.into_iter().enumerate() { - let column_type = column_types[i]; - let column_name = column_names[i].as_str(); - - quaint_row.push(js_value_to_quaint(row, column_type, column_name)?); - } - - quaint_rows.push(quaint_row); - } - - let last_insert_id = last_insert_id.and_then(|id| id.parse::().ok()); - let mut quaint_result_set = QuaintResultSet::new(column_names, quaint_rows); - - // Not a fan of this (extracting the `Some` value from an `Option` and pass it to a method that creates a new `Some` value), - // but that's Quaint's ResultSet API and that's how the MySQL connector does it. - // Sqlite, on the other hand, uses a `last_insert_id.unwrap_or(0)` approach. - if let Some(last_insert_id) = last_insert_id { - quaint_result_set.set_last_insert_id(last_insert_id); - } - - Ok(quaint_result_set) - } -} - -impl CommonProxy { - pub fn new(object: &JsObject) -> napi::Result { - let flavour: JsString = object.get_named_property("flavour")?; - - Ok(Self { - query_raw: object.get_named_property("queryRaw")?, - execute_raw: object.get_named_property("executeRaw")?, - flavour: flavour.into_utf8()?.as_str()?.to_owned(), - }) - } - - pub async fn query_raw(&self, params: Query) -> quaint::Result { - self.query_raw.call(params).await - } - - pub async fn execute_raw(&self, params: Query) -> quaint::Result { - self.execute_raw.call(params).await - } -} - -impl DriverProxy { - pub fn new(driver_adapter: &JsObject) -> napi::Result { - Ok(Self { - start_transaction: driver_adapter.get_named_property("startTransaction")?, - }) - } - - pub async fn start_transaction(&self) -> quaint::Result> { - let tx = self.start_transaction.call(()).await?; - - // Decrement for this gauge is done in JsTransaction::commit/JsTransaction::rollback - // Previously, it was done in JsTransaction::new, similar to the native Transaction. 
- // However, correct Dispatcher is lost there and increment does not register, so we moved - // it here instead. - increment_gauge!("prisma_client_queries_active", 1.0); - Ok(Box::new(tx)) - } -} - -#[derive(Debug)] -#[napi(object)] -pub struct TransactionOptions { - /// Whether or not to run a phantom query (i.e., a query that only influences Prisma event logs, but not the database itself) - /// before opening a transaction, committing, or rollbacking. - pub use_phantom_query: bool, -} - -impl TransactionProxy { - pub fn new(js_transaction: &JsObject) -> napi::Result { - let commit = js_transaction.get_named_property("commit")?; - let rollback = js_transaction.get_named_property("rollback")?; - let dispose = js_transaction.get_named_property("dispose")?; - let options = js_transaction.get_named_property("options")?; - - Ok(Self { - commit, - rollback, - dispose, - options, - }) - } - - pub fn options(&self) -> &TransactionOptions { - &self.options - } - - pub async fn commit(&self) -> quaint::Result<()> { - self.commit.call(()).await - } - - pub async fn rollback(&self) -> quaint::Result<()> { - self.rollback.call(()).await - } -} - -impl Drop for TransactionProxy { - fn drop(&mut self) { - _ = self - .dispose - .call((), napi::threadsafe_function::ThreadsafeFunctionCallMode::NonBlocking); - } -} - -/// Coerce a `f64` to a `f32`, asserting that the conversion is lossless. -/// Note that, when overflow occurs during conversion, the result is `infinity`. -fn f64_to_f32(x: f64) -> quaint::Result { - let y = x as f32; - - if x.is_finite() == y.is_finite() { - Ok(y) - } else { - Err(conversion_error!("f32 overflow during conversion")) - } -} -#[cfg(test)] -mod proxy_test { - use num_bigint::BigInt; - use serde_json::json; - - use super::*; - - #[track_caller] - fn test_null<'a, T: Into>>(quaint_none: T, column_type: ColumnType) { - let json_value = serde_json::Value::Null; - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - assert_eq!(quaint_value, quaint_none.into()); - } - - #[test] - fn js_value_int32_to_quaint() { - let column_type = ColumnType::Int32; - - // null - test_null(QuaintValue::null_int32(), column_type); - - // 0 - let n: i32 = 0; - let json_value = serde_json::Value::Number(serde_json::Number::from(n)); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - assert_eq!(quaint_value, QuaintValue::int32(n)); - - // max - let n: i32 = i32::MAX; - let json_value = serde_json::Value::Number(serde_json::Number::from(n)); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - assert_eq!(quaint_value, QuaintValue::int32(n)); - - // min - let n: i32 = i32::MIN; - let json_value = serde_json::Value::Number(serde_json::Number::from(n)); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - assert_eq!(quaint_value, QuaintValue::int32(n)); - - // string-encoded - let n = i32::MAX; - let json_value = serde_json::Value::String(n.to_string()); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - assert_eq!(quaint_value, QuaintValue::int32(n)); - } - - #[test] - fn js_value_int64_to_quaint() { - let column_type = ColumnType::Int64; - - // null - test_null(QuaintValue::null_int64(), column_type); - - // 0 - let n: i64 = 0; - let json_value = serde_json::Value::String(n.to_string()); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - assert_eq!(quaint_value, 
QuaintValue::int64(n)); - - // max - let n: i64 = i64::MAX; - let json_value = serde_json::Value::String(n.to_string()); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - assert_eq!(quaint_value, QuaintValue::int64(n)); - - // min - let n: i64 = i64::MIN; - let json_value = serde_json::Value::String(n.to_string()); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - assert_eq!(quaint_value, QuaintValue::int64(n)); - - // number-encoded - let n: i64 = (1 << 53) - 1; // max JS safe integer - let json_value = serde_json::Value::Number(serde_json::Number::from(n)); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - assert_eq!(quaint_value, QuaintValue::int64(n)); - } - - #[test] - fn js_value_float_to_quaint() { - let column_type = ColumnType::Float; - - // null - test_null(QuaintValue::null_float(), column_type); - - // 0 - let n: f32 = 0.0; - let json_value = serde_json::Value::Number(serde_json::Number::from_f64(n.into()).unwrap()); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - assert_eq!(quaint_value, QuaintValue::float(n)); - - // max - let n: f32 = f32::MAX; - let json_value = serde_json::Value::Number(serde_json::Number::from_f64(n.into()).unwrap()); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - assert_eq!(quaint_value, QuaintValue::float(n)); - - // min - let n: f32 = f32::MIN; - let json_value = serde_json::Value::Number(serde_json::Number::from_f64(n.into()).unwrap()); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - assert_eq!(quaint_value, QuaintValue::float(n)); - } - - #[test] - fn js_value_double_to_quaint() { - let column_type = ColumnType::Double; - - // null - test_null(QuaintValue::null_double(), column_type); - - // 0 - let n: f64 = 0.0; - let json_value = serde_json::Value::Number(serde_json::Number::from_f64(n).unwrap()); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - assert_eq!(quaint_value, QuaintValue::double(n)); - - // max - let n: f64 = f64::MAX; - let json_value = serde_json::Value::Number(serde_json::Number::from_f64(n).unwrap()); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - assert_eq!(quaint_value, QuaintValue::double(n)); - - // min - let n: f64 = f64::MIN; - let json_value = serde_json::Value::Number(serde_json::Number::from_f64(n).unwrap()); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - assert_eq!(quaint_value, QuaintValue::double(n)); - } - - #[test] - fn js_value_numeric_to_quaint() { - let column_type = ColumnType::Numeric; - - // null - test_null(QuaintValue::null_numeric(), column_type); - - let n_as_string = "1234.99"; - let decimal = BigDecimal::new(BigInt::parse_bytes(b"123499", 10).unwrap(), 2); - - let json_value = serde_json::Value::String(n_as_string.into()); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - assert_eq!(quaint_value, QuaintValue::numeric(decimal)); - - let n_as_string = "1234.999999"; - let decimal = BigDecimal::new(BigInt::parse_bytes(b"1234999999", 10).unwrap(), 6); - - let json_value = serde_json::Value::String(n_as_string.into()); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - assert_eq!(quaint_value, QuaintValue::numeric(decimal)); - } - - #[test] - fn 
js_value_boolean_to_quaint() { - let column_type = ColumnType::Boolean; - - // null - test_null(QuaintValue::null_boolean(), column_type); - - // true - for truthy_value in [json!(true), json!(1), json!("true"), json!("TRUE"), json!("1")] { - let quaint_value = js_value_to_quaint(truthy_value, column_type, "column_name").unwrap(); - assert_eq!(quaint_value, QuaintValue::boolean(true)); - } - - // false - for falsy_value in [json!(false), json!(0), json!("false"), json!("FALSE"), json!("0")] { - let quaint_value = js_value_to_quaint(falsy_value, column_type, "column_name").unwrap(); - assert_eq!(quaint_value, QuaintValue::boolean(false)); - } - } - - #[test] - fn js_value_char_to_quaint() { - let column_type = ColumnType::Character; - - // null - test_null(QuaintValue::null_character(), column_type); - - let c = 'c'; - let json_value = serde_json::Value::String(c.to_string()); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - assert_eq!(quaint_value, QuaintValue::character(c)); - } - - #[test] - fn js_value_text_to_quaint() { - let column_type = ColumnType::Text; - - // null - test_null(QuaintValue::null_text(), column_type); - - let s = "some text"; - let json_value = serde_json::Value::String(s.to_string()); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - assert_eq!(quaint_value, QuaintValue::text(s)); - } - - #[test] - fn js_value_date_to_quaint() { - let column_type = ColumnType::Date; - - // null - test_null(QuaintValue::null_date(), column_type); - - let s = "2023-01-01"; - let json_value = serde_json::Value::String(s.to_string()); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - - let date = NaiveDate::from_ymd_opt(2023, 1, 1).unwrap(); - assert_eq!(quaint_value, QuaintValue::date(date)); - } - - #[test] - fn js_value_time_to_quaint() { - let column_type = ColumnType::Time; - - // null - test_null(QuaintValue::null_time(), column_type); - - let s = "23:59:59"; - let json_value = serde_json::Value::String(s.to_string()); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - let time: NaiveTime = NaiveTime::from_hms_opt(23, 59, 59).unwrap(); - assert_eq!(quaint_value, QuaintValue::time(time)); - - let s = "13:02:20.321"; - let json_value = serde_json::Value::String(s.to_string()); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - let time: NaiveTime = NaiveTime::from_hms_milli_opt(13, 02, 20, 321).unwrap(); - assert_eq!(quaint_value, QuaintValue::time(time)); - } - - #[test] - fn js_value_datetime_to_quaint() { - let column_type = ColumnType::DateTime; - - // null - test_null(QuaintValue::null_datetime(), column_type); - - let s = "2023-01-01 23:59:59.415"; - let json_value = serde_json::Value::String(s.to_string()); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - - let datetime = NaiveDate::from_ymd_opt(2023, 1, 1) - .unwrap() - .and_hms_milli_opt(23, 59, 59, 415) - .unwrap(); - let datetime = DateTime::from_utc(datetime, Utc); - assert_eq!(quaint_value, QuaintValue::datetime(datetime)); - - let s = "2023-01-01 23:59:59.123456"; - let json_value = serde_json::Value::String(s.to_string()); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - - let datetime = NaiveDate::from_ymd_opt(2023, 1, 1) - .unwrap() - .and_hms_micro_opt(23, 59, 59, 123_456) - .unwrap(); - let datetime = 
DateTime::from_utc(datetime, Utc); - assert_eq!(quaint_value, QuaintValue::datetime(datetime)); - - let s = "2023-01-01 23:59:59"; - let json_value = serde_json::Value::String(s.to_string()); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - - let datetime = NaiveDate::from_ymd_opt(2023, 1, 1) - .unwrap() - .and_hms_milli_opt(23, 59, 59, 0) - .unwrap(); - let datetime = DateTime::from_utc(datetime, Utc); - assert_eq!(quaint_value, QuaintValue::datetime(datetime)); - } - - #[test] - fn js_value_json_to_quaint() { - let column_type = ColumnType::Json; - - // null - test_null(QuaintValue::null_json(), column_type); - - let json = json!({ - "key": "value", - "nested": [ - true, - false, - 1, - null - ] - }); - let json_value = json.clone(); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - assert_eq!(quaint_value, QuaintValue::json(json.clone())); - } - - #[test] - fn js_value_enum_to_quaint() { - let column_type = ColumnType::Enum; - - // null - test_null(QuaintValue::null_enum(), column_type); - - let s = "some enum variant"; - let json_value = serde_json::Value::String(s.to_string()); - - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - assert_eq!(quaint_value, QuaintValue::enum_variant(s)); - } - - #[test] - fn js_int32_array_to_quaint() { - let column_type = ColumnType::Int32Array; - test_null(QuaintValue::null_array(), column_type); - - let json_value = json!([1, 2, 3]); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - - assert_eq!( - quaint_value, - QuaintValue::array(vec![ - QuaintValue::int32(1), - QuaintValue::int32(2), - QuaintValue::int32(3) - ]) - ); - - let json_value = json!([1, 2, {}]); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name"); - - assert_eq!( - quaint_value.err().unwrap().to_string(), - "Conversion failed: expected an i32 number in column 'column_name[2]', found {}" - ); - } - - #[test] - fn js_text_array_to_quaint() { - let column_type = ColumnType::TextArray; - test_null(QuaintValue::null_array(), column_type); - - let json_value = json!(["hi", "there"]); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - - assert_eq!( - quaint_value, - QuaintValue::array(vec![QuaintValue::text("hi"), QuaintValue::text("there"),]) - ); - - let json_value = json!([10]); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name"); - - assert_eq!( - quaint_value.err().unwrap().to_string(), - "Conversion failed: expected a string in column 'column_name[0]', found 10" - ); - } -} diff --git a/query-engine/driver-adapters/src/queryable.rs b/query-engine/driver-adapters/src/queryable.rs deleted file mode 100644 index ab154eccc139..000000000000 --- a/query-engine/driver-adapters/src/queryable.rs +++ /dev/null @@ -1,303 +0,0 @@ -use crate::{ - conversion, - proxy::{CommonProxy, DriverProxy, Query}, -}; -use async_trait::async_trait; -use napi::JsObject; -use psl::datamodel_connector::Flavour; -use quaint::{ - connector::{metrics, IsolationLevel, Transaction}, - error::{Error, ErrorKind}, - prelude::{Query as QuaintQuery, Queryable as QuaintQueryable, ResultSet, TransactionCapable}, - visitor::{self, Visitor}, -}; -use tracing::{info_span, Instrument}; - -/// A JsQueryable adapts a Proxy to implement quaint's Queryable interface. 
It has the -/// responsibility of transforming inputs and outputs of `query` and `execute` methods from quaint -/// types to types that can be translated into javascript and viceversa. This is to let the rest of -/// the query engine work as if it was using quaint itself. The aforementioned transformations are: -/// -/// Transforming a `quaint::ast::Query` into SQL by visiting it for the specific flavour of SQL -/// expected by the client connector. (eg. using the mysql visitor for the Planetscale client -/// connector) -/// -/// Transforming a `JSResultSet` (what client connectors implemented in javascript provide) -/// into a `quaint::connector::result_set::ResultSet`. A quaint `ResultSet` is basically a vector -/// of `quaint::Value` but said type is a tagged enum, with non-unit variants that cannot be converted to javascript as is. -/// -pub(crate) struct JsBaseQueryable { - pub(crate) proxy: CommonProxy, - pub flavour: Flavour, -} - -impl JsBaseQueryable { - pub(crate) fn new(proxy: CommonProxy) -> Self { - let flavour: Flavour = proxy.flavour.parse().unwrap(); - Self { proxy, flavour } - } - - /// visit a quaint query AST according to the flavour of the JS connector - fn visit_quaint_query<'a>(&self, q: QuaintQuery<'a>) -> quaint::Result<(String, Vec>)> { - match self.flavour { - Flavour::Mysql => visitor::Mysql::build(q), - Flavour::Postgres => visitor::Postgres::build(q), - Flavour::Sqlite => visitor::Sqlite::build(q), - _ => unimplemented!("Unsupported flavour for JS connector {:?}", self.flavour), - } - } - - async fn build_query(&self, sql: &str, values: &[quaint::Value<'_>]) -> quaint::Result { - let sql: String = sql.to_string(); - - let converter = match self.flavour { - Flavour::Postgres => conversion::postgres::value_to_js_arg, - Flavour::Sqlite => conversion::sqlite::value_to_js_arg, - Flavour::Mysql => conversion::mysql::value_to_js_arg, - _ => unreachable!("Unsupported flavour for JS connector {:?}", self.flavour), - }; - - let args = values - .iter() - .map(converter) - .collect::>>()?; - - Ok(Query { sql, args }) - } -} - -#[async_trait] -impl QuaintQueryable for JsBaseQueryable { - async fn query(&self, q: QuaintQuery<'_>) -> quaint::Result { - let (sql, params) = self.visit_quaint_query(q)?; - self.query_raw(&sql, ¶ms).await - } - - async fn query_raw(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { - metrics::query("js.query_raw", sql, params, move || async move { - self.do_query_raw(sql, params).await - }) - .await - } - - async fn query_raw_typed(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { - self.query_raw(sql, params).await - } - - async fn execute(&self, q: QuaintQuery<'_>) -> quaint::Result { - let (sql, params) = self.visit_quaint_query(q)?; - self.execute_raw(&sql, ¶ms).await - } - - async fn execute_raw(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { - metrics::query("js.execute_raw", sql, params, move || async move { - self.do_execute_raw(sql, params).await - }) - .await - } - - async fn execute_raw_typed(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { - self.execute_raw(sql, params).await - } - - async fn raw_cmd(&self, cmd: &str) -> quaint::Result<()> { - let params = &[]; - metrics::query("js.raw_cmd", cmd, params, move || async move { - self.do_execute_raw(cmd, params).await?; - Ok(()) - }) - .await - } - - async fn version(&self) -> quaint::Result> { - // Note: JS Connectors don't use this method. 
- Ok(None) - } - - fn is_healthy(&self) -> bool { - // Note: JS Connectors don't use this method. - true - } - - /// Sets the transaction isolation level to given value. - /// Implementers have to make sure that the passed isolation level is valid for the underlying database. - async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> quaint::Result<()> { - if matches!(isolation_level, IsolationLevel::Snapshot) { - return Err(Error::builder(ErrorKind::invalid_isolation_level(&isolation_level)).build()); - } - - if self.flavour == Flavour::Sqlite { - return match isolation_level { - IsolationLevel::Serializable => Ok(()), - _ => Err(Error::builder(ErrorKind::invalid_isolation_level(&isolation_level)).build()), - }; - } - - self.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation_level}")) - .await - } - - fn requires_isolation_first(&self) -> bool { - match self.flavour { - Flavour::Mysql => true, - Flavour::Postgres | Flavour::Sqlite => false, - _ => unreachable!(), - } - } -} - -impl JsBaseQueryable { - pub fn phantom_query_message(stmt: &str) -> String { - format!(r#"-- Implicit "{}" query via underlying driver"#, stmt) - } - - async fn do_query_raw(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { - let len = params.len(); - let serialization_span = info_span!("js:query:args", user_facing = true, "length" = %len); - let query = self.build_query(sql, params).instrument(serialization_span).await?; - - let sql_span = info_span!("js:query:sql", user_facing = true, "db.statement" = %sql); - let result_set = self.proxy.query_raw(query).instrument(sql_span).await?; - - let len = result_set.len(); - let _deserialization_span = info_span!("js:query:result", user_facing = true, "length" = %len).entered(); - - result_set.try_into() - } - - async fn do_execute_raw(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { - let len = params.len(); - let serialization_span = info_span!("js:query:args", user_facing = true, "length" = %len); - let query = self.build_query(sql, params).instrument(serialization_span).await?; - - let sql_span = info_span!("js:query:sql", user_facing = true, "db.statement" = %sql); - let affected_rows = self.proxy.execute_raw(query).instrument(sql_span).await?; - - Ok(affected_rows as u64) - } -} - -/// A JsQueryable adapts a Proxy to implement quaint's Queryable interface. It has the -/// responsibility of transforming inputs and outputs of `query` and `execute` methods from quaint -/// types to types that can be translated into javascript and viceversa. This is to let the rest of -/// the query engine work as if it was using quaint itself. The aforementioned transformations are: -/// -/// Transforming a `quaint::ast::Query` into SQL by visiting it for the specific flavour of SQL -/// expected by the client connector. (eg. using the mysql visitor for the Planetscale client -/// connector) -/// -/// Transforming a `JSResultSet` (what client connectors implemented in javascript provide) -/// into a `quaint::connector::result_set::ResultSet`. A quaint `ResultSet` is basically a vector -/// of `quaint::Value` but said type is a tagged enum, with non-unit variants that cannot be converted to javascript as is. 
-/// -pub struct JsQueryable { - inner: JsBaseQueryable, - driver_proxy: DriverProxy, -} - -impl std::fmt::Display for JsQueryable { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "JSQueryable(driver)") - } -} - -impl std::fmt::Debug for JsQueryable { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "JSQueryable(driver)") - } -} - -#[async_trait] -impl QuaintQueryable for JsQueryable { - async fn query(&self, q: QuaintQuery<'_>) -> quaint::Result { - self.inner.query(q).await - } - - async fn query_raw(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { - self.inner.query_raw(sql, params).await - } - - async fn query_raw_typed(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { - self.inner.query_raw_typed(sql, params).await - } - - async fn execute(&self, q: QuaintQuery<'_>) -> quaint::Result { - self.inner.execute(q).await - } - - async fn execute_raw(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { - self.inner.execute_raw(sql, params).await - } - - async fn execute_raw_typed(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { - self.inner.execute_raw_typed(sql, params).await - } - - async fn raw_cmd(&self, cmd: &str) -> quaint::Result<()> { - self.inner.raw_cmd(cmd).await - } - - async fn version(&self) -> quaint::Result> { - self.inner.version().await - } - - fn is_healthy(&self) -> bool { - self.inner.is_healthy() - } - - async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> quaint::Result<()> { - self.inner.set_tx_isolation_level(isolation_level).await - } - - fn requires_isolation_first(&self) -> bool { - self.inner.requires_isolation_first() - } -} - -#[async_trait] -impl TransactionCapable for JsQueryable { - async fn start_transaction<'a>( - &'a self, - isolation: Option, - ) -> quaint::Result> { - let tx = self.driver_proxy.start_transaction().await?; - - let isolation_first = tx.requires_isolation_first(); - - if isolation_first { - if let Some(isolation) = isolation { - tx.set_tx_isolation_level(isolation).await?; - } - } - - let begin_stmt = tx.begin_statement(); - - let tx_opts = tx.options(); - if tx_opts.use_phantom_query { - let begin_stmt = JsBaseQueryable::phantom_query_message(begin_stmt); - tx.raw_phantom_cmd(begin_stmt.as_str()).await?; - } else { - tx.raw_cmd(begin_stmt).await?; - } - - if !isolation_first { - if let Some(isolation) = isolation { - tx.set_tx_isolation_level(isolation).await?; - } - } - - self.server_reset_query(tx.as_ref()).await?; - - Ok(tx) - } -} - -pub fn from_napi(driver: JsObject) -> JsQueryable { - let common = CommonProxy::new(&driver).unwrap(); - let driver_proxy = DriverProxy::new(&driver).unwrap(); - - JsQueryable { - inner: JsBaseQueryable::new(common), - driver_proxy, - } -} diff --git a/query-engine/driver-adapters/src/result.rs b/query-engine/driver-adapters/src/result.rs deleted file mode 100644 index ad4ce7cbb546..000000000000 --- a/query-engine/driver-adapters/src/result.rs +++ /dev/null @@ -1,119 +0,0 @@ -use napi::{bindgen_prelude::FromNapiValue, Env, JsUnknown, NapiValue}; -use quaint::error::{Error as QuaintError, ErrorKind, MysqlError, PostgresError, SqliteError}; -use serde::Deserialize; - -#[derive(Deserialize)] -#[serde(remote = "PostgresError")] -pub struct PostgresErrorDef { - code: String, - message: String, - severity: String, - detail: Option, - column: Option, - hint: Option, -} - -#[derive(Deserialize)] -#[serde(remote = "MysqlError")] -pub struct MysqlErrorDef 
{ - pub code: u16, - pub message: String, - pub state: String, -} - -#[derive(Deserialize)] -#[serde(remote = "SqliteError", rename_all = "camelCase")] -pub struct SqliteErrorDef { - pub extended_code: i32, - pub message: Option, -} - -#[derive(Deserialize)] -#[serde(tag = "kind")] -/// Wrapper for JS-side errors -pub(crate) enum DriverAdapterError { - /// Unexpected JS exception - GenericJs { - id: i32, - }, - UnsupportedNativeDataType { - #[serde(rename = "type")] - native_type: String, - }, - Postgres(#[serde(with = "PostgresErrorDef")] PostgresError), - Mysql(#[serde(with = "MysqlErrorDef")] MysqlError), - Sqlite(#[serde(with = "SqliteErrorDef")] SqliteError), -} - -impl FromNapiValue for DriverAdapterError { - unsafe fn from_napi_value(napi_env: napi::sys::napi_env, napi_val: napi::sys::napi_value) -> napi::Result { - let env = Env::from_raw(napi_env); - let value = JsUnknown::from_raw(napi_env, napi_val)?; - env.from_js_value(value) - } -} - -impl From for QuaintError { - fn from(value: DriverAdapterError) -> Self { - match value { - DriverAdapterError::UnsupportedNativeDataType { native_type } => { - QuaintError::builder(ErrorKind::UnsupportedColumnType { - column_type: native_type, - }) - .build() - } - DriverAdapterError::GenericJs { id } => QuaintError::external_error(id), - DriverAdapterError::Postgres(e) => e.into(), - DriverAdapterError::Mysql(e) => e.into(), - DriverAdapterError::Sqlite(e) => e.into(), - // in future, more error types would be added and we'll need to convert them to proper QuaintErrors here - } - } -} - -/// Wrapper for JS-side result type -pub(crate) enum JsResult -where - T: FromNapiValue, -{ - Ok(T), - Err(DriverAdapterError), -} - -impl JsResult -where - T: FromNapiValue, -{ - fn from_js_unknown(unknown: JsUnknown) -> napi::Result { - let object = unknown.coerce_to_object()?; - let ok: bool = object.get_named_property("ok")?; - if ok { - let value: JsUnknown = object.get_named_property("value")?; - return Ok(Self::Ok(T::from_unknown(value)?)); - } - - let error = object.get_named_property("error")?; - Ok(Self::Err(error)) - } -} - -impl FromNapiValue for JsResult -where - T: FromNapiValue, -{ - unsafe fn from_napi_value(napi_env: napi::sys::napi_env, napi_val: napi::sys::napi_value) -> napi::Result { - Self::from_js_unknown(JsUnknown::from_raw(napi_env, napi_val)?) - } -} - -impl From> for quaint::Result -where - T: FromNapiValue, -{ - fn from(value: JsResult) -> Self { - match value { - JsResult::Ok(result) => Ok(result), - JsResult::Err(error) => Err(error.into()), - } - } -} diff --git a/query-engine/driver-adapters/src/transaction.rs b/query-engine/driver-adapters/src/transaction.rs deleted file mode 100644 index d35a9019c6bc..000000000000 --- a/query-engine/driver-adapters/src/transaction.rs +++ /dev/null @@ -1,136 +0,0 @@ -use async_trait::async_trait; -use metrics::decrement_gauge; -use napi::{bindgen_prelude::FromNapiValue, JsObject}; -use quaint::{ - connector::{IsolationLevel, Transaction as QuaintTransaction}, - prelude::{Query as QuaintQuery, Queryable, ResultSet}, - Value, -}; - -use crate::{ - proxy::{CommonProxy, TransactionOptions, TransactionProxy}, - queryable::JsBaseQueryable, -}; - -// Wrapper around JS transaction objects that implements Queryable -// and quaint::Transaction. 
Can be used in place of quaint transaction, -// but delegates most operations to JS -pub(crate) struct JsTransaction { - tx_proxy: TransactionProxy, - inner: JsBaseQueryable, -} - -impl JsTransaction { - pub(crate) fn new(inner: JsBaseQueryable, tx_proxy: TransactionProxy) -> Self { - Self { inner, tx_proxy } - } - - pub fn options(&self) -> &TransactionOptions { - self.tx_proxy.options() - } - - pub async fn raw_phantom_cmd(&self, cmd: &str) -> quaint::Result<()> { - let params = &[]; - quaint::connector::metrics::query("js.raw_phantom_cmd", cmd, params, move || async move { Ok(()) }).await - } -} - -#[async_trait] -impl QuaintTransaction for JsTransaction { - async fn commit(&self) -> quaint::Result<()> { - // increment of this gauge is done in DriverProxy::startTransaction - decrement_gauge!("prisma_client_queries_active", 1.0); - - let commit_stmt = "COMMIT"; - - if self.options().use_phantom_query { - let commit_stmt = JsBaseQueryable::phantom_query_message(commit_stmt); - self.raw_phantom_cmd(commit_stmt.as_str()).await?; - } else { - self.inner.raw_cmd(commit_stmt).await?; - } - - self.tx_proxy.commit().await - } - - async fn rollback(&self) -> quaint::Result<()> { - // increment of this gauge is done in DriverProxy::startTransaction - decrement_gauge!("prisma_client_queries_active", 1.0); - - let rollback_stmt = "ROLLBACK"; - - if self.options().use_phantom_query { - let rollback_stmt = JsBaseQueryable::phantom_query_message(rollback_stmt); - self.raw_phantom_cmd(rollback_stmt.as_str()).await?; - } else { - self.inner.raw_cmd(rollback_stmt).await?; - } - - self.tx_proxy.rollback().await - } - - fn as_queryable(&self) -> &dyn Queryable { - self - } -} - -#[async_trait] -impl Queryable for JsTransaction { - async fn query(&self, q: QuaintQuery<'_>) -> quaint::Result { - self.inner.query(q).await - } - - async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> quaint::Result { - self.inner.query_raw(sql, params).await - } - - async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> quaint::Result { - self.inner.query_raw_typed(sql, params).await - } - - async fn execute(&self, q: QuaintQuery<'_>) -> quaint::Result { - self.inner.execute(q).await - } - - async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> quaint::Result { - self.inner.execute_raw(sql, params).await - } - - async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> quaint::Result { - self.inner.execute_raw_typed(sql, params).await - } - - async fn raw_cmd(&self, cmd: &str) -> quaint::Result<()> { - self.inner.raw_cmd(cmd).await - } - - async fn version(&self) -> quaint::Result> { - self.inner.version().await - } - - fn is_healthy(&self) -> bool { - self.inner.is_healthy() - } - - async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> quaint::Result<()> { - self.inner.set_tx_isolation_level(isolation_level).await - } - - fn requires_isolation_first(&self) -> bool { - self.inner.requires_isolation_first() - } -} - -/// Implementing unsafe `from_napi_value` is only way I managed to get threadsafe -/// JsTransaction value in `DriverProxy`. Going through any intermediate safe napi.rs value, -/// like `JsObject` or `JsUnknown` wrapped inside `JsPromise` makes it impossible to extract the value -/// out of promise while keeping the future `Send`. 
-impl FromNapiValue for JsTransaction { - unsafe fn from_napi_value(env: napi::sys::napi_env, napi_val: napi::sys::napi_value) -> napi::Result { - let object = JsObject::from_napi_value(env, napi_val)?; - let common_proxy = CommonProxy::new(&object)?; - let tx_proxy = TransactionProxy::new(&object)?; - - Ok(Self::new(JsBaseQueryable::new(common_proxy), tx_proxy)) - } -} diff --git a/query-engine/query-engine-wasm/Cargo.toml b/query-engine/query-engine-wasm/Cargo.toml index 07757fde5d0a..bf179102dbde 100644 --- a/query-engine/query-engine-wasm/Cargo.toml +++ b/query-engine/query-engine-wasm/Cargo.toml @@ -20,22 +20,23 @@ sql-query-connector = { path = "../connectors/sql-query-connector" } query-core = { path = "../core" } request-handlers = { path = "../request-handlers", default-features = false, features = ["sql", "driver-adapters"] } -thiserror = "1" connection-string.workspace = true -url = "2" +js-sys.workspace = true +serde-wasm-bindgen.workspace = true serde_json.workspace = true +tsify.workspace = true +wasm-bindgen.workspace = true +wasm-bindgen-futures.workspace = true + +thiserror = "1" +url = "2" serde.workspace = true tokio = { version = "1.25", features = ["macros", "sync", "io-util", "time"] } futures = "0.3" -wasm-bindgen = "=0.2.87" -wasm-bindgen-futures = "0.4" -serde-wasm-bindgen = "0.5" -js-sys = "0.3" log = "0.4.6" wasm-logger = "0.2.0" tracing = "0.1" tracing-subscriber = { version = "0.3" } tracing-futures = "0.2" -tsify = "0.4.5" console_error_panic_hook = "0.1.7" From 9a599d1f12f66db1a08a0be53de46616a35a3ce8 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Wed, 15 Nov 2023 21:43:59 +0100 Subject: [PATCH 035/134] feat(driver-adapters): move napi-specific code into "napi" module, prepare empty "wasm" module --- .../driver-adapters/src/conversion/js_arg.rs | 16 + .../driver-adapters/src/conversion/mod.rs | 7 + .../src/napi/async_js_function.rs | 70 ++ .../driver-adapters/src/napi/conversion.rs | 42 + .../driver-adapters/src/napi/error.rs | 35 + query-engine/driver-adapters/src/napi/mod.rs | 10 + .../driver-adapters/src/napi/proxy.rs | 983 ++++++++++++++++++ .../driver-adapters/src/napi/queryable.rs | 303 ++++++ .../driver-adapters/src/napi/result.rs | 119 +++ .../driver-adapters/src/napi/transaction.rs | 136 +++ query-engine/driver-adapters/src/wasm/mod.rs | 1 + 11 files changed, 1722 insertions(+) create mode 100644 query-engine/driver-adapters/src/conversion/js_arg.rs create mode 100644 query-engine/driver-adapters/src/conversion/mod.rs create mode 100644 query-engine/driver-adapters/src/napi/async_js_function.rs create mode 100644 query-engine/driver-adapters/src/napi/conversion.rs create mode 100644 query-engine/driver-adapters/src/napi/error.rs create mode 100644 query-engine/driver-adapters/src/napi/mod.rs create mode 100644 query-engine/driver-adapters/src/napi/proxy.rs create mode 100644 query-engine/driver-adapters/src/napi/queryable.rs create mode 100644 query-engine/driver-adapters/src/napi/result.rs create mode 100644 query-engine/driver-adapters/src/napi/transaction.rs create mode 100644 query-engine/driver-adapters/src/wasm/mod.rs diff --git a/query-engine/driver-adapters/src/conversion/js_arg.rs b/query-engine/driver-adapters/src/conversion/js_arg.rs new file mode 100644 index 000000000000..c5b65e80882a --- /dev/null +++ b/query-engine/driver-adapters/src/conversion/js_arg.rs @@ -0,0 +1,16 @@ +use serde::Serialize; +use serde_json::value::Value as JsonValue; + +#[derive(Debug, PartialEq, Serialize)] +#[serde(untagged)] +pub enum JSArg { + 
Value(serde_json::Value), + Buffer(Vec), + Array(Vec), +} + +impl From for JSArg { + fn from(v: JsonValue) -> Self { + JSArg::Value(v) + } +} diff --git a/query-engine/driver-adapters/src/conversion/mod.rs b/query-engine/driver-adapters/src/conversion/mod.rs new file mode 100644 index 000000000000..5173b2349bab --- /dev/null +++ b/query-engine/driver-adapters/src/conversion/mod.rs @@ -0,0 +1,7 @@ +pub(crate) mod js_arg; + +pub(crate) mod mysql; +pub(crate) mod postgres; +pub(crate) mod sqlite; + +pub use js_arg::JSArg; diff --git a/query-engine/driver-adapters/src/napi/async_js_function.rs b/query-engine/driver-adapters/src/napi/async_js_function.rs new file mode 100644 index 000000000000..f55c7f89caa8 --- /dev/null +++ b/query-engine/driver-adapters/src/napi/async_js_function.rs @@ -0,0 +1,70 @@ +use std::marker::PhantomData; + +use napi::{ + bindgen_prelude::*, + threadsafe_function::{ErrorStrategy, ThreadsafeFunction}, +}; + +use super::{ + error::{async_unwinding_panic, into_quaint_error}, + result::JsResult, +}; + +/// Wrapper for napi-rs's ThreadsafeFunction that is aware of +/// JS drivers conventions. Performs following things: +/// - Automatically unrefs the function so it won't hold off event loop +/// - Awaits for returned Promise +/// - Unpacks JS `Result` type into Rust `Result` type and converts the error +/// into `quaint::Error`. +/// - Catches panics and converts them to `quaint:Error` +pub(crate) struct AsyncJsFunction +where + ArgType: ToNapiValue + 'static, + ReturnType: FromNapiValue + 'static, +{ + threadsafe_fn: ThreadsafeFunction, + _phantom: PhantomData, +} + +impl AsyncJsFunction +where + ArgType: ToNapiValue + 'static, + ReturnType: FromNapiValue + 'static, +{ + fn from_threadsafe_function( + mut threadsafe_fn: ThreadsafeFunction, + env: Env, + ) -> napi::Result { + threadsafe_fn.unref(&env)?; + + Ok(AsyncJsFunction { + threadsafe_fn, + _phantom: PhantomData, + }) + } + + pub(crate) async fn call(&self, arg: ArgType) -> quaint::Result { + let js_result = async_unwinding_panic(async { + let promise = self + .threadsafe_fn + .call_async::>>(arg) + .await?; + promise.await + }) + .await + .map_err(into_quaint_error)?; + js_result.into() + } +} + +impl FromNapiValue for AsyncJsFunction +where + ArgType: ToNapiValue + 'static, + ReturnType: FromNapiValue + 'static, +{ + unsafe fn from_napi_value(napi_env: napi::sys::napi_env, napi_val: napi::sys::napi_value) -> napi::Result { + let env = Env::from_raw(napi_env); + let threadsafe_fn = ThreadsafeFunction::from_napi_value(napi_env, napi_val)?; + Self::from_threadsafe_function(threadsafe_fn, env) + } +} diff --git a/query-engine/driver-adapters/src/napi/conversion.rs b/query-engine/driver-adapters/src/napi/conversion.rs new file mode 100644 index 000000000000..5ab630998d27 --- /dev/null +++ b/query-engine/driver-adapters/src/napi/conversion.rs @@ -0,0 +1,42 @@ +pub(crate) use crate::conversion::{mysql, postgres, sqlite, JSArg}; + +use napi::bindgen_prelude::{FromNapiValue, ToNapiValue}; +use napi::NapiValue; + +// FromNapiValue is the napi equivalent to serde::Deserialize. +// Note: we can safely leave this unimplemented as we don't need deserialize napi_value back to JSArg. +// However, removing this altogether would cause a compile error. +impl FromNapiValue for JSArg { + unsafe fn from_napi_value(_env: napi::sys::napi_env, _napi_value: napi::sys::napi_value) -> napi::Result { + unreachable!() + } +} + +// ToNapiValue is the napi equivalent to serde::Serialize. 
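// (Illustrative note, not from the patch.) Since `JSArg` derives `Serialize` with `#[serde(untagged)]`,
// `JSArg::Value(json!(1))` serializes as plain `1` and `JSArg::Array(vec![JSArg::Value(json!("a"))])`
// as `["a"]`; the impl below only adds bespoke napi handling for `Buffer` (emitted as a Node.js
// Buffer) and for arrays, whose elements may need the same special treatment.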
+impl ToNapiValue for JSArg { + unsafe fn to_napi_value(env: napi::sys::napi_env, value: Self) -> napi::Result { + match value { + JSArg::Value(v) => ToNapiValue::to_napi_value(env, v), + JSArg::Buffer(bytes) => { + ToNapiValue::to_napi_value(env, napi::Env::from_raw(env).create_buffer_with_data(bytes)?.into_raw()) + } + // While arrays are encodable as JSON generally, their element might not be, or may be + // represented in a different way than we need. We use this custom logic for all arrays + // to avoid having separate `JsonArray` and `BytesArray` variants in `JSArg` and + // avoid complicating the logic in `conv_params`. + JSArg::Array(items) => { + let env = napi::Env::from_raw(env); + let mut array = env.create_array(items.len().try_into().expect("JS array length must fit into u32"))?; + + for (index, item) in items.into_iter().enumerate() { + let js_value = ToNapiValue::to_napi_value(env.raw(), item)?; + // TODO: NapiRaw could be implemented for sys::napi_value directly, there should + // be no need for re-wrapping; submit a patch to napi-rs and simplify here. + array.set(index as u32, napi::JsUnknown::from_raw_unchecked(env.raw(), js_value))?; + } + + ToNapiValue::to_napi_value(env.raw(), array) + } + } + } +} diff --git a/query-engine/driver-adapters/src/napi/error.rs b/query-engine/driver-adapters/src/napi/error.rs new file mode 100644 index 000000000000..4f4128088f49 --- /dev/null +++ b/query-engine/driver-adapters/src/napi/error.rs @@ -0,0 +1,35 @@ +use futures::{Future, FutureExt}; +use napi::Error as NapiError; +use quaint::error::Error as QuaintError; +use std::{any::Any, panic::AssertUnwindSafe}; + +/// transforms a napi error into a quaint error copying the status and reason +/// properties over +pub(crate) fn into_quaint_error(napi_err: NapiError) -> QuaintError { + let status = napi_err.status.as_ref().to_owned(); + let reason = napi_err.reason.clone(); + + QuaintError::raw_connector_error(status, reason) +} + +/// catches a panic thrown during the execution of an asynchronous closure and transforms it into +/// the Error variant of a napi::Result. +pub(crate) async fn async_unwinding_panic(fut: F) -> napi::Result +where + F: Future>, +{ + AssertUnwindSafe(fut) + .catch_unwind() + .await + .unwrap_or_else(panic_to_napi_err) +} + +fn panic_to_napi_err(panic_payload: Box) -> napi::Result { + panic_payload + .downcast_ref::<&str>() + .map(|s| -> String { (*s).to_owned() }) + .or_else(|| panic_payload.downcast_ref::().map(|s| s.to_owned())) + .map(|message| Err(napi::Error::from_reason(format!("PANIC: {message}")))) + .ok_or(napi::Error::from_reason("PANIC: unknown panic".to_string())) + .unwrap() +} diff --git a/query-engine/driver-adapters/src/napi/mod.rs b/query-engine/driver-adapters/src/napi/mod.rs new file mode 100644 index 000000000000..05267dec453b --- /dev/null +++ b/query-engine/driver-adapters/src/napi/mod.rs @@ -0,0 +1,10 @@ +//! Query Engine Driver Adapters: `napi`-specific implementation. 
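As a concrete illustration of the conversion layer above, here is a minimal sketch of how an argument list might be assembled before being handed to the JS driver. It is illustrative only, and assumes the variant payloads elided above are `serde_json::Value`, `Vec<u8>` and `Vec<JSArg>` respectively:

    use serde_json::json;

    // Build a heterogeneous argument list: a plain JSON scalar, raw bytes
    // (delivered to JS as a Node.js Buffer), and a nested array mixing both.
    fn example_args() -> Vec<JSArg> {
        vec![
            JSArg::from(json!(42)),
            JSArg::Buffer(vec![0xde, 0xad, 0xbe, 0xef]),
            JSArg::Array(vec![JSArg::from(json!("hello")), JSArg::Buffer(vec![0x01])]),
        ]
    }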
+ +mod async_js_function; +mod conversion; +mod error; +mod proxy; +mod queryable; +mod result; +mod transaction; +pub use queryable::{from_napi, JsQueryable}; diff --git a/query-engine/driver-adapters/src/napi/proxy.rs b/query-engine/driver-adapters/src/napi/proxy.rs new file mode 100644 index 000000000000..9511e0463770 --- /dev/null +++ b/query-engine/driver-adapters/src/napi/proxy.rs @@ -0,0 +1,983 @@ +use std::borrow::Cow; +use std::str::FromStr; + +use super::async_js_function::AsyncJsFunction; +use super::conversion::JSArg; +use super::transaction::JsTransaction; +use metrics::increment_gauge; +use napi::bindgen_prelude::{FromNapiValue, ToNapiValue}; +use napi::threadsafe_function::{ErrorStrategy, ThreadsafeFunction}; +use napi::{JsObject, JsString}; +use napi_derive::napi; +use quaint::connector::ResultSet as QuaintResultSet; +use quaint::{ + error::{Error as QuaintError, ErrorKind}, + Value as QuaintValue, +}; + +// TODO(jkomyno): import these 3rd-party crates from the `quaint-core` crate. +use bigdecimal::{BigDecimal, FromPrimitive}; +use chrono::{DateTime, Utc}; +use chrono::{NaiveDate, NaiveTime}; + +/// Proxy is a struct wrapping a javascript object that exhibits basic primitives for +/// querying and executing SQL (i.e. a client connector). The Proxy uses NAPI ThreadSafeFunction to +/// invoke the code within the node runtime that implements the client connector. +pub(crate) struct CommonProxy { + /// Execute a query given as SQL, interpolating the given parameters. + query_raw: AsyncJsFunction, + + /// Execute a query given as SQL, interpolating the given parameters and + /// returning the number of affected rows. + execute_raw: AsyncJsFunction, + + /// Return the flavour for this driver. + pub(crate) flavour: String, +} + +/// This is a JS proxy for accessing the methods specific to top level +/// JS driver objects +pub(crate) struct DriverProxy { + start_transaction: AsyncJsFunction<(), JsTransaction>, +} +/// This a JS proxy for accessing the methods, specific +/// to JS transaction objects +pub(crate) struct TransactionProxy { + /// transaction options + options: TransactionOptions, + + /// commit transaction + commit: AsyncJsFunction<(), ()>, + + /// rollback transaction + rollback: AsyncJsFunction<(), ()>, + + /// dispose transaction, cleanup logic executed at the end of the transaction lifecycle + /// on drop. + dispose: ThreadsafeFunction<(), ErrorStrategy::Fatal>, +} + +/// This result set is more convenient to be manipulated from both Rust and NodeJS. +/// Quaint's version of ResultSet is: +/// +/// pub struct ResultSet { +/// pub(crate) columns: Arc>, +/// pub(crate) rows: Vec>>, +/// pub(crate) last_insert_id: Option, +/// } +/// +/// If we used this ResultSet would we would have worse ergonomics as quaint::Value is a structured +/// enum and cannot be used directly with the #[napi(Object)] macro. Thus requiring us to implement +/// the FromNapiValue and ToNapiValue traits for quaint::Value, and use a different custom type +/// representing the Value in javascript. 
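+/// For example, a two-column result could arrive as `column_names: ["id", "name"]`,
+/// `column_types: [ColumnType::Int32, ColumnType::Text]` and
+/// `rows: [[json!(1), json!("Alice")]]`.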
+/// +#[napi(object)] +#[derive(Debug)] +pub struct JSResultSet { + pub column_types: Vec, + pub column_names: Vec, + // Note this might be encoded differently for performance reasons + pub rows: Vec>, + pub last_insert_id: Option, +} + +impl JSResultSet { + pub fn len(&self) -> usize { + self.rows.len() + } +} + +#[napi] +#[derive(Debug)] +pub enum ColumnType { + // [PLANETSCALE_TYPE] (MYSQL_TYPE) -> [TypeScript example] + /// The following PlanetScale type IDs are mapped into Int32: + /// - INT8 (TINYINT) -> e.g. `127` + /// - INT16 (SMALLINT) -> e.g. `32767` + /// - INT24 (MEDIUMINT) -> e.g. `8388607` + /// - INT32 (INT) -> e.g. `2147483647` + Int32 = 0, + + /// The following PlanetScale type IDs are mapped into Int64: + /// - INT64 (BIGINT) -> e.g. `"9223372036854775807"` (String-encoded) + Int64 = 1, + + /// The following PlanetScale type IDs are mapped into Float: + /// - FLOAT32 (FLOAT) -> e.g. `3.402823466` + Float = 2, + + /// The following PlanetScale type IDs are mapped into Double: + /// - FLOAT64 (DOUBLE) -> e.g. `1.7976931348623157` + Double = 3, + + /// The following PlanetScale type IDs are mapped into Numeric: + /// - DECIMAL (DECIMAL) -> e.g. `"99999999.99"` (String-encoded) + Numeric = 4, + + /// The following PlanetScale type IDs are mapped into Boolean: + /// - BOOLEAN (BOOLEAN) -> e.g. `1` + Boolean = 5, + + Character = 6, + + /// The following PlanetScale type IDs are mapped into Text: + /// - TEXT (TEXT) -> e.g. `"foo"` (String-encoded) + /// - VARCHAR (VARCHAR) -> e.g. `"foo"` (String-encoded) + Text = 7, + + /// The following PlanetScale type IDs are mapped into Date: + /// - DATE (DATE) -> e.g. `"2023-01-01"` (String-encoded, yyyy-MM-dd) + Date = 8, + + /// The following PlanetScale type IDs are mapped into Time: + /// - TIME (TIME) -> e.g. `"23:59:59"` (String-encoded, HH:mm:ss) + Time = 9, + + /// The following PlanetScale type IDs are mapped into DateTime: + /// - DATETIME (DATETIME) -> e.g. `"2023-01-01 23:59:59"` (String-encoded, yyyy-MM-dd HH:mm:ss) + /// - TIMESTAMP (TIMESTAMP) -> e.g. `"2023-01-01 23:59:59"` (String-encoded, yyyy-MM-dd HH:mm:ss) + DateTime = 10, + + /// The following PlanetScale type IDs are mapped into Json: + /// - JSON (JSON) -> e.g. `"{\"key\": \"value\"}"` (String-encoded) + Json = 11, + + /// The following PlanetScale type IDs are mapped into Enum: + /// - ENUM (ENUM) -> e.g. `"foo"` (String-encoded) + Enum = 12, + + /// The following PlanetScale type IDs are mapped into Bytes: + /// - BLOB (BLOB) -> e.g. `"\u0012"` (String-encoded) + /// - VARBINARY (VARBINARY) -> e.g. `"\u0012"` (String-encoded) + /// - BINARY (BINARY) -> e.g. `"\u0012"` (String-encoded) + /// - GEOMETRY (GEOMETRY) -> e.g. `"\u0012"` (String-encoded) + Bytes = 13, + + /// The following PlanetScale type IDs are mapped into Set: + /// - SET (SET) -> e.g. `"foo,bar"` (String-encoded, comma-separated) + /// This is currently unhandled, and will panic if encountered. + Set = 14, + + /// UUID from postgres-flavored driver adapters is mapped to this type. 
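+    /// e.g. `"550e8400-e29b-41d4-a716-446655440000"` (String-encoded).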
+ Uuid = 15, + + /* + * Scalar arrays + */ + /// Int32 array (INT2_ARRAY and INT4_ARRAY in PostgreSQL) + Int32Array = 64, + + /// Int64 array (INT8_ARRAY in PostgreSQL) + Int64Array = 65, + + /// Float array (FLOAT4_ARRAY in PostgreSQL) + FloatArray = 66, + + /// Double array (FLOAT8_ARRAY in PostgreSQL) + DoubleArray = 67, + + /// Numeric array (NUMERIC_ARRAY, MONEY_ARRAY etc in PostgreSQL) + NumericArray = 68, + + /// Boolean array (BOOL_ARRAY in PostgreSQL) + BooleanArray = 69, + + /// Char array (CHAR_ARRAY in PostgreSQL) + CharacterArray = 70, + + /// Text array (TEXT_ARRAY in PostgreSQL) + TextArray = 71, + + /// Date array (DATE_ARRAY in PostgreSQL) + DateArray = 72, + + /// Time array (TIME_ARRAY in PostgreSQL) + TimeArray = 73, + + /// DateTime array (TIMESTAMP_ARRAY in PostgreSQL) + DateTimeArray = 74, + + /// Json array (JSON_ARRAY in PostgreSQL) + JsonArray = 75, + + /// Enum array + EnumArray = 76, + + /// Bytes array (BYTEA_ARRAY in PostgreSQL) + BytesArray = 77, + + /// Uuid array (UUID_ARRAY in PostgreSQL) + UuidArray = 78, + + /* + * Below there are custom types that don't have a 1:1 translation with a quaint::Value. + * enum variant. + */ + /// UnknownNumber is used when the type of the column is a number but of unknown particular type + /// and precision. + /// + /// It's used by some driver adapters, like libsql to return aggregation values like AVG, or + /// COUNT, and it can be mapped to either Int64, or Double + UnknownNumber = 128, +} + +#[napi(object)] +#[derive(Debug)] +pub struct Query { + pub sql: String, + pub args: Vec, +} + +fn conversion_error(args: &std::fmt::Arguments) -> QuaintError { + let msg = match args.as_str() { + Some(s) => Cow::Borrowed(s), + None => Cow::Owned(args.to_string()), + }; + QuaintError::builder(ErrorKind::ConversionError(msg)).build() +} + +macro_rules! conversion_error { + ($($arg:tt)*) => { + conversion_error(&format_args!($($arg)*)) + }; +} + +/// Handle data-type conversion from a JSON value to a Quaint value. +/// This is used for most data types, except those that require connector-specific handling, e.g., `ColumnType::Boolean`. +fn js_value_to_quaint( + json_value: serde_json::Value, + column_type: ColumnType, + column_name: &str, +) -> quaint::Result> { + let parse_number_as_i64 = |n: &serde_json::Number| { + n.as_i64().ok_or(conversion_error!( + "number must be an integer in column '{column_name}', got '{n}'" + )) + }; + + // Note for the future: it may be worth revisiting how much bloat so many panics with different static + // strings add to the compiled artefact, and in case we should come up with a restricted set of panic + // messages, or even find a way of removing them altogether. 
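+    // For example, for ColumnType::Int32 both `json!(42)` and `json!("42")` end up
+    // as `QuaintValue::int32(42)`: the numeric path goes through `parse_number_as_i64`
+    // and a checked i64 -> i32 narrowing, while the string path uses `str::parse::<i32>`.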
+ match column_type { + ColumnType::Int32 => match json_value { + serde_json::Value::Number(n) => { + // n.as_i32() is not implemented, so we need to downcast from i64 instead + parse_number_as_i64(&n) + .and_then(|n| -> quaint::Result { + n.try_into() + .map_err(|e| conversion_error!("cannot convert {n} to i32 in column '{column_name}': {e}")) + }) + .map(QuaintValue::int32) + } + serde_json::Value::String(s) => s.parse::().map(QuaintValue::int32).map_err(|e| { + conversion_error!("string-encoded number must be an i32 in column '{column_name}', got {s}: {e}") + }), + serde_json::Value::Null => Ok(QuaintValue::null_int32()), + mismatch => Err(conversion_error!( + "expected an i32 number in column '{column_name}', found {mismatch}" + )), + }, + ColumnType::Int64 => match json_value { + serde_json::Value::Number(n) => parse_number_as_i64(&n).map(QuaintValue::int64), + serde_json::Value::String(s) => s.parse::().map(QuaintValue::int64).map_err(|e| { + conversion_error!("string-encoded number must be an i64 in column '{column_name}', got {s}: {e}") + }), + serde_json::Value::Null => Ok(QuaintValue::null_int64()), + mismatch => Err(conversion_error!( + "expected a string or number in column '{column_name}', found {mismatch}" + )), + }, + ColumnType::Float => match json_value { + // n.as_f32() is not implemented, so we need to downcast from f64 instead. + // We assume that the JSON value is a valid f32 number, but we check for overflows anyway. + serde_json::Value::Number(n) => n + .as_f64() + .ok_or(conversion_error!( + "number must be a float in column '{column_name}', got {n}" + )) + .and_then(f64_to_f32) + .map(QuaintValue::float), + serde_json::Value::Null => Ok(QuaintValue::null_float()), + mismatch => Err(conversion_error!( + "expected an f32 number in column '{column_name}', found {mismatch}" + )), + }, + ColumnType::Double => match json_value { + serde_json::Value::Number(n) => n.as_f64().map(QuaintValue::double).ok_or(conversion_error!( + "number must be a f64 in column '{column_name}', got {n}" + )), + serde_json::Value::Null => Ok(QuaintValue::null_double()), + mismatch => Err(conversion_error!( + "expected an f64 number in column '{column_name}', found {mismatch}" + )), + }, + ColumnType::Numeric => match json_value { + serde_json::Value::String(s) => BigDecimal::from_str(&s).map(QuaintValue::numeric).map_err(|e| { + conversion_error!("invalid numeric value when parsing {s} in column '{column_name}': {e}") + }), + serde_json::Value::Number(n) => n + .as_f64() + .and_then(BigDecimal::from_f64) + .ok_or(conversion_error!( + "number must be an f64 in column '{column_name}', got {n}" + )) + .map(QuaintValue::numeric), + serde_json::Value::Null => Ok(QuaintValue::null_numeric()), + mismatch => Err(conversion_error!( + "expected a string-encoded number in column '{column_name}', found {mismatch}", + )), + }, + ColumnType::Boolean => match json_value { + serde_json::Value::Bool(b) => Ok(QuaintValue::boolean(b)), + serde_json::Value::Null => Ok(QuaintValue::null_boolean()), + serde_json::Value::Number(n) => match n.as_i64() { + Some(0) => Ok(QuaintValue::boolean(false)), + Some(1) => Ok(QuaintValue::boolean(true)), + _ => Err(conversion_error!( + "expected number-encoded boolean to be 0 or 1 in column '{column_name}', got {n}" + )), + }, + serde_json::Value::String(s) => match s.as_str() { + "false" | "FALSE" | "0" => Ok(QuaintValue::boolean(false)), + "true" | "TRUE" | "1" => Ok(QuaintValue::boolean(true)), + _ => Err(conversion_error!( + "expected string-encoded boolean in column 
'{column_name}', got {s}" + )), + }, + mismatch => Err(conversion_error!( + "expected a boolean in column '{column_name}', found {mismatch}" + )), + }, + ColumnType::Character => match json_value { + serde_json::Value::String(s) => match s.chars().next() { + Some(c) => Ok(QuaintValue::character(c)), + None => Ok(QuaintValue::null_character()), + }, + serde_json::Value::Null => Ok(QuaintValue::null_character()), + mismatch => Err(conversion_error!( + "expected a string in column '{column_name}', found {mismatch}" + )), + }, + ColumnType::Text => match json_value { + serde_json::Value::String(s) => Ok(QuaintValue::text(s)), + serde_json::Value::Null => Ok(QuaintValue::null_text()), + mismatch => Err(conversion_error!( + "expected a string in column '{column_name}', found {mismatch}" + )), + }, + ColumnType::Date => match json_value { + serde_json::Value::String(s) => NaiveDate::parse_from_str(&s, "%Y-%m-%d") + .map(QuaintValue::date) + .map_err(|_| conversion_error!("expected a date string in column '{column_name}', got {s}")), + serde_json::Value::Null => Ok(QuaintValue::null_date()), + mismatch => Err(conversion_error!( + "expected a string in column '{column_name}', found {mismatch}" + )), + }, + ColumnType::Time => match json_value { + serde_json::Value::String(s) => NaiveTime::parse_from_str(&s, "%H:%M:%S%.f") + .map(QuaintValue::time) + .map_err(|_| conversion_error!("expected a time string in column '{column_name}', got {s}")), + serde_json::Value::Null => Ok(QuaintValue::null_time()), + mismatch => Err(conversion_error!( + "expected a string in column '{column_name}', found {mismatch}" + )), + }, + ColumnType::DateTime => match json_value { + // TODO: change parsing order to prefer RFC3339 + serde_json::Value::String(s) => chrono::NaiveDateTime::parse_from_str(&s, "%Y-%m-%d %H:%M:%S%.f") + .map(|dt| DateTime::from_utc(dt, Utc)) + .or_else(|_| DateTime::parse_from_rfc3339(&s).map(DateTime::::from)) + .map(QuaintValue::datetime) + .map_err(|_| conversion_error!("expected a datetime string in column '{column_name}', found {s}")), + serde_json::Value::Null => Ok(QuaintValue::null_datetime()), + mismatch => Err(conversion_error!( + "expected a string in column '{column_name}', found {mismatch}" + )), + }, + ColumnType::Json => { + match json_value { + // DbNull + serde_json::Value::Null => Ok(QuaintValue::null_json()), + // JsonNull + serde_json::Value::String(s) if s == "$__prisma_null" => Ok(QuaintValue::json(serde_json::Value::Null)), + json => Ok(QuaintValue::json(json)), + } + } + ColumnType::Enum => match json_value { + serde_json::Value::String(s) => Ok(QuaintValue::enum_variant(s)), + serde_json::Value::Null => Ok(QuaintValue::null_enum()), + mismatch => Err(conversion_error!( + "expected a string in column '{column_name}', found {mismatch}" + )), + }, + ColumnType::Bytes => match json_value { + serde_json::Value::String(s) => Ok(QuaintValue::bytes(s.into_bytes())), + serde_json::Value::Array(array) => array + .iter() + .map(|value| value.as_i64().and_then(|maybe_byte| maybe_byte.try_into().ok())) + .collect::>>() + .map(QuaintValue::bytes) + .ok_or(conversion_error!( + "elements of the array in column '{column_name}' must be u8" + )), + serde_json::Value::Null => Ok(QuaintValue::null_bytes()), + mismatch => Err(conversion_error!( + "expected a string or an array in column '{column_name}', found {mismatch}", + )), + }, + ColumnType::Uuid => match json_value { + serde_json::Value::String(s) => uuid::Uuid::parse_str(&s) + .map(QuaintValue::uuid) + .map_err(|_| 
conversion_error!("Expected a UUID string in column '{column_name}'")), + serde_json::Value::Null => Ok(QuaintValue::null_bytes()), + mismatch => Err(conversion_error!( + "Expected a UUID string in column '{column_name}', found {mismatch}" + )), + }, + ColumnType::UnknownNumber => match json_value { + serde_json::Value::Number(n) => n + .as_i64() + .map(QuaintValue::int64) + .or(n.as_f64().map(QuaintValue::double)) + .ok_or(conversion_error!( + "number must be an i64 or f64 in column '{column_name}', got {n}" + )), + mismatch => Err(conversion_error!( + "expected a either an i64 or a f64 in column '{column_name}', found {mismatch}", + )), + }, + + ColumnType::Int32Array => js_array_to_quaint(ColumnType::Int32, json_value, column_name), + ColumnType::Int64Array => js_array_to_quaint(ColumnType::Int64, json_value, column_name), + ColumnType::FloatArray => js_array_to_quaint(ColumnType::Float, json_value, column_name), + ColumnType::DoubleArray => js_array_to_quaint(ColumnType::Double, json_value, column_name), + ColumnType::NumericArray => js_array_to_quaint(ColumnType::Numeric, json_value, column_name), + ColumnType::BooleanArray => js_array_to_quaint(ColumnType::Boolean, json_value, column_name), + ColumnType::CharacterArray => js_array_to_quaint(ColumnType::Character, json_value, column_name), + ColumnType::TextArray => js_array_to_quaint(ColumnType::Text, json_value, column_name), + ColumnType::DateArray => js_array_to_quaint(ColumnType::Date, json_value, column_name), + ColumnType::TimeArray => js_array_to_quaint(ColumnType::Time, json_value, column_name), + ColumnType::DateTimeArray => js_array_to_quaint(ColumnType::DateTime, json_value, column_name), + ColumnType::JsonArray => js_array_to_quaint(ColumnType::Json, json_value, column_name), + ColumnType::EnumArray => js_array_to_quaint(ColumnType::Enum, json_value, column_name), + ColumnType::BytesArray => js_array_to_quaint(ColumnType::Bytes, json_value, column_name), + ColumnType::UuidArray => js_array_to_quaint(ColumnType::Uuid, json_value, column_name), + + unimplemented => { + todo!("support column type {:?} in column {}", unimplemented, column_name) + } + } +} + +fn js_array_to_quaint( + base_type: ColumnType, + json_value: serde_json::Value, + column_name: &str, +) -> quaint::Result> { + match json_value { + serde_json::Value::Array(array) => Ok(QuaintValue::array( + array + .into_iter() + .enumerate() + .map(|(index, elem)| js_value_to_quaint(elem, base_type, &format!("{column_name}[{index}]"))) + .collect::>>()?, + )), + serde_json::Value::Null => Ok(QuaintValue::null_array()), + mismatch => Err(conversion_error!( + "expected an array in column '{column_name}', found {mismatch}", + )), + } +} + +impl TryFrom for QuaintResultSet { + type Error = quaint::error::Error; + + fn try_from(js_result_set: JSResultSet) -> Result { + let JSResultSet { + rows, + column_names, + column_types, + last_insert_id, + } = js_result_set; + + let mut quaint_rows = Vec::with_capacity(rows.len()); + + for row in rows { + let mut quaint_row = Vec::with_capacity(column_types.len()); + + for (i, row) in row.into_iter().enumerate() { + let column_type = column_types[i]; + let column_name = column_names[i].as_str(); + + quaint_row.push(js_value_to_quaint(row, column_type, column_name)?); + } + + quaint_rows.push(quaint_row); + } + + let last_insert_id = last_insert_id.and_then(|id| id.parse::().ok()); + let mut quaint_result_set = QuaintResultSet::new(column_names, quaint_rows); + + // Not a fan of this (extracting the `Some` value from an `Option` and 
pass it to a method that creates a new `Some` value), + // but that's Quaint's ResultSet API and that's how the MySQL connector does it. + // Sqlite, on the other hand, uses a `last_insert_id.unwrap_or(0)` approach. + if let Some(last_insert_id) = last_insert_id { + quaint_result_set.set_last_insert_id(last_insert_id); + } + + Ok(quaint_result_set) + } +} + +impl CommonProxy { + pub fn new(object: &JsObject) -> napi::Result { + let flavour: JsString = object.get_named_property("flavour")?; + + Ok(Self { + query_raw: object.get_named_property("queryRaw")?, + execute_raw: object.get_named_property("executeRaw")?, + flavour: flavour.into_utf8()?.as_str()?.to_owned(), + }) + } + + pub async fn query_raw(&self, params: Query) -> quaint::Result { + self.query_raw.call(params).await + } + + pub async fn execute_raw(&self, params: Query) -> quaint::Result { + self.execute_raw.call(params).await + } +} + +impl DriverProxy { + pub fn new(driver_adapter: &JsObject) -> napi::Result { + Ok(Self { + start_transaction: driver_adapter.get_named_property("startTransaction")?, + }) + } + + pub async fn start_transaction(&self) -> quaint::Result> { + let tx = self.start_transaction.call(()).await?; + + // Decrement for this gauge is done in JsTransaction::commit/JsTransaction::rollback + // Previously, it was done in JsTransaction::new, similar to the native Transaction. + // However, correct Dispatcher is lost there and increment does not register, so we moved + // it here instead. + increment_gauge!("prisma_client_queries_active", 1.0); + Ok(Box::new(tx)) + } +} + +#[derive(Debug)] +#[napi(object)] +pub struct TransactionOptions { + /// Whether or not to run a phantom query (i.e., a query that only influences Prisma event logs, but not the database itself) + /// before opening a transaction, committing, or rollbacking. + pub use_phantom_query: bool, +} + +impl TransactionProxy { + pub fn new(js_transaction: &JsObject) -> napi::Result { + let commit = js_transaction.get_named_property("commit")?; + let rollback = js_transaction.get_named_property("rollback")?; + let dispose = js_transaction.get_named_property("dispose")?; + let options = js_transaction.get_named_property("options")?; + + Ok(Self { + commit, + rollback, + dispose, + options, + }) + } + + pub fn options(&self) -> &TransactionOptions { + &self.options + } + + pub async fn commit(&self) -> quaint::Result<()> { + self.commit.call(()).await + } + + pub async fn rollback(&self) -> quaint::Result<()> { + self.rollback.call(()).await + } +} + +impl Drop for TransactionProxy { + fn drop(&mut self) { + _ = self + .dispose + .call((), napi::threadsafe_function::ThreadsafeFunctionCallMode::NonBlocking); + } +} + +/// Coerce a `f64` to a `f32`, asserting that the conversion is lossless. +/// Note that, when overflow occurs during conversion, the result is `infinity`. 
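+/// For example, `f64::MAX` overflows to `f32::INFINITY` and is rejected here, while a
+/// merely imprecise value such as `0.1_f64` is accepted (only finiteness is checked).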
+fn f64_to_f32(x: f64) -> quaint::Result { + let y = x as f32; + + if x.is_finite() == y.is_finite() { + Ok(y) + } else { + Err(conversion_error!("f32 overflow during conversion")) + } +} +#[cfg(test)] +mod proxy_test { + use num_bigint::BigInt; + use serde_json::json; + + use super::*; + + #[track_caller] + fn test_null<'a, T: Into>>(quaint_none: T, column_type: ColumnType) { + let json_value = serde_json::Value::Null; + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + assert_eq!(quaint_value, quaint_none.into()); + } + + #[test] + fn js_value_int32_to_quaint() { + let column_type = ColumnType::Int32; + + // null + test_null(QuaintValue::null_int32(), column_type); + + // 0 + let n: i32 = 0; + let json_value = serde_json::Value::Number(serde_json::Number::from(n)); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + assert_eq!(quaint_value, QuaintValue::int32(n)); + + // max + let n: i32 = i32::MAX; + let json_value = serde_json::Value::Number(serde_json::Number::from(n)); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + assert_eq!(quaint_value, QuaintValue::int32(n)); + + // min + let n: i32 = i32::MIN; + let json_value = serde_json::Value::Number(serde_json::Number::from(n)); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + assert_eq!(quaint_value, QuaintValue::int32(n)); + + // string-encoded + let n = i32::MAX; + let json_value = serde_json::Value::String(n.to_string()); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + assert_eq!(quaint_value, QuaintValue::int32(n)); + } + + #[test] + fn js_value_int64_to_quaint() { + let column_type = ColumnType::Int64; + + // null + test_null(QuaintValue::null_int64(), column_type); + + // 0 + let n: i64 = 0; + let json_value = serde_json::Value::String(n.to_string()); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + assert_eq!(quaint_value, QuaintValue::int64(n)); + + // max + let n: i64 = i64::MAX; + let json_value = serde_json::Value::String(n.to_string()); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + assert_eq!(quaint_value, QuaintValue::int64(n)); + + // min + let n: i64 = i64::MIN; + let json_value = serde_json::Value::String(n.to_string()); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + assert_eq!(quaint_value, QuaintValue::int64(n)); + + // number-encoded + let n: i64 = (1 << 53) - 1; // max JS safe integer + let json_value = serde_json::Value::Number(serde_json::Number::from(n)); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + assert_eq!(quaint_value, QuaintValue::int64(n)); + } + + #[test] + fn js_value_float_to_quaint() { + let column_type = ColumnType::Float; + + // null + test_null(QuaintValue::null_float(), column_type); + + // 0 + let n: f32 = 0.0; + let json_value = serde_json::Value::Number(serde_json::Number::from_f64(n.into()).unwrap()); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + assert_eq!(quaint_value, QuaintValue::float(n)); + + // max + let n: f32 = f32::MAX; + let json_value = serde_json::Value::Number(serde_json::Number::from_f64(n.into()).unwrap()); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + assert_eq!(quaint_value, QuaintValue::float(n)); + + // 
min + let n: f32 = f32::MIN; + let json_value = serde_json::Value::Number(serde_json::Number::from_f64(n.into()).unwrap()); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + assert_eq!(quaint_value, QuaintValue::float(n)); + } + + #[test] + fn js_value_double_to_quaint() { + let column_type = ColumnType::Double; + + // null + test_null(QuaintValue::null_double(), column_type); + + // 0 + let n: f64 = 0.0; + let json_value = serde_json::Value::Number(serde_json::Number::from_f64(n).unwrap()); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + assert_eq!(quaint_value, QuaintValue::double(n)); + + // max + let n: f64 = f64::MAX; + let json_value = serde_json::Value::Number(serde_json::Number::from_f64(n).unwrap()); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + assert_eq!(quaint_value, QuaintValue::double(n)); + + // min + let n: f64 = f64::MIN; + let json_value = serde_json::Value::Number(serde_json::Number::from_f64(n).unwrap()); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + assert_eq!(quaint_value, QuaintValue::double(n)); + } + + #[test] + fn js_value_numeric_to_quaint() { + let column_type = ColumnType::Numeric; + + // null + test_null(QuaintValue::null_numeric(), column_type); + + let n_as_string = "1234.99"; + let decimal = BigDecimal::new(BigInt::parse_bytes(b"123499", 10).unwrap(), 2); + + let json_value = serde_json::Value::String(n_as_string.into()); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + assert_eq!(quaint_value, QuaintValue::numeric(decimal)); + + let n_as_string = "1234.999999"; + let decimal = BigDecimal::new(BigInt::parse_bytes(b"1234999999", 10).unwrap(), 6); + + let json_value = serde_json::Value::String(n_as_string.into()); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + assert_eq!(quaint_value, QuaintValue::numeric(decimal)); + } + + #[test] + fn js_value_boolean_to_quaint() { + let column_type = ColumnType::Boolean; + + // null + test_null(QuaintValue::null_boolean(), column_type); + + // true + for truthy_value in [json!(true), json!(1), json!("true"), json!("TRUE"), json!("1")] { + let quaint_value = js_value_to_quaint(truthy_value, column_type, "column_name").unwrap(); + assert_eq!(quaint_value, QuaintValue::boolean(true)); + } + + // false + for falsy_value in [json!(false), json!(0), json!("false"), json!("FALSE"), json!("0")] { + let quaint_value = js_value_to_quaint(falsy_value, column_type, "column_name").unwrap(); + assert_eq!(quaint_value, QuaintValue::boolean(false)); + } + } + + #[test] + fn js_value_char_to_quaint() { + let column_type = ColumnType::Character; + + // null + test_null(QuaintValue::null_character(), column_type); + + let c = 'c'; + let json_value = serde_json::Value::String(c.to_string()); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + assert_eq!(quaint_value, QuaintValue::character(c)); + } + + #[test] + fn js_value_text_to_quaint() { + let column_type = ColumnType::Text; + + // null + test_null(QuaintValue::null_text(), column_type); + + let s = "some text"; + let json_value = serde_json::Value::String(s.to_string()); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + assert_eq!(quaint_value, QuaintValue::text(s)); + } + + #[test] + fn js_value_date_to_quaint() { + let column_type = 
ColumnType::Date; + + // null + test_null(QuaintValue::null_date(), column_type); + + let s = "2023-01-01"; + let json_value = serde_json::Value::String(s.to_string()); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + + let date = NaiveDate::from_ymd_opt(2023, 1, 1).unwrap(); + assert_eq!(quaint_value, QuaintValue::date(date)); + } + + #[test] + fn js_value_time_to_quaint() { + let column_type = ColumnType::Time; + + // null + test_null(QuaintValue::null_time(), column_type); + + let s = "23:59:59"; + let json_value = serde_json::Value::String(s.to_string()); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + let time: NaiveTime = NaiveTime::from_hms_opt(23, 59, 59).unwrap(); + assert_eq!(quaint_value, QuaintValue::time(time)); + + let s = "13:02:20.321"; + let json_value = serde_json::Value::String(s.to_string()); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + let time: NaiveTime = NaiveTime::from_hms_milli_opt(13, 02, 20, 321).unwrap(); + assert_eq!(quaint_value, QuaintValue::time(time)); + } + + #[test] + fn js_value_datetime_to_quaint() { + let column_type = ColumnType::DateTime; + + // null + test_null(QuaintValue::null_datetime(), column_type); + + let s = "2023-01-01 23:59:59.415"; + let json_value = serde_json::Value::String(s.to_string()); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + + let datetime = NaiveDate::from_ymd_opt(2023, 1, 1) + .unwrap() + .and_hms_milli_opt(23, 59, 59, 415) + .unwrap(); + let datetime = DateTime::from_utc(datetime, Utc); + assert_eq!(quaint_value, QuaintValue::datetime(datetime)); + + let s = "2023-01-01 23:59:59.123456"; + let json_value = serde_json::Value::String(s.to_string()); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + + let datetime = NaiveDate::from_ymd_opt(2023, 1, 1) + .unwrap() + .and_hms_micro_opt(23, 59, 59, 123_456) + .unwrap(); + let datetime = DateTime::from_utc(datetime, Utc); + assert_eq!(quaint_value, QuaintValue::datetime(datetime)); + + let s = "2023-01-01 23:59:59"; + let json_value = serde_json::Value::String(s.to_string()); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + + let datetime = NaiveDate::from_ymd_opt(2023, 1, 1) + .unwrap() + .and_hms_milli_opt(23, 59, 59, 0) + .unwrap(); + let datetime = DateTime::from_utc(datetime, Utc); + assert_eq!(quaint_value, QuaintValue::datetime(datetime)); + } + + #[test] + fn js_value_json_to_quaint() { + let column_type = ColumnType::Json; + + // null + test_null(QuaintValue::null_json(), column_type); + + let json = json!({ + "key": "value", + "nested": [ + true, + false, + 1, + null + ] + }); + let json_value = json.clone(); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + assert_eq!(quaint_value, QuaintValue::json(json.clone())); + } + + #[test] + fn js_value_enum_to_quaint() { + let column_type = ColumnType::Enum; + + // null + test_null(QuaintValue::null_enum(), column_type); + + let s = "some enum variant"; + let json_value = serde_json::Value::String(s.to_string()); + + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + assert_eq!(quaint_value, QuaintValue::enum_variant(s)); + } + + #[test] + fn js_int32_array_to_quaint() { + let column_type = ColumnType::Int32Array; + test_null(QuaintValue::null_array(), column_type); + + let json_value = 
json!([1, 2, 3]); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + + assert_eq!( + quaint_value, + QuaintValue::array(vec![ + QuaintValue::int32(1), + QuaintValue::int32(2), + QuaintValue::int32(3) + ]) + ); + + let json_value = json!([1, 2, {}]); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name"); + + assert_eq!( + quaint_value.err().unwrap().to_string(), + "Conversion failed: expected an i32 number in column 'column_name[2]', found {}" + ); + } + + #[test] + fn js_text_array_to_quaint() { + let column_type = ColumnType::TextArray; + test_null(QuaintValue::null_array(), column_type); + + let json_value = json!(["hi", "there"]); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + + assert_eq!( + quaint_value, + QuaintValue::array(vec![QuaintValue::text("hi"), QuaintValue::text("there"),]) + ); + + let json_value = json!([10]); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name"); + + assert_eq!( + quaint_value.err().unwrap().to_string(), + "Conversion failed: expected a string in column 'column_name[0]', found 10" + ); + } +} diff --git a/query-engine/driver-adapters/src/napi/queryable.rs b/query-engine/driver-adapters/src/napi/queryable.rs new file mode 100644 index 000000000000..900ff076b806 --- /dev/null +++ b/query-engine/driver-adapters/src/napi/queryable.rs @@ -0,0 +1,303 @@ +use super::{ + conversion, + proxy::{CommonProxy, DriverProxy, Query}, +}; +use async_trait::async_trait; +use napi::JsObject; +use psl::datamodel_connector::Flavour; +use quaint::{ + connector::{metrics, IsolationLevel, Transaction}, + error::{Error, ErrorKind}, + prelude::{Query as QuaintQuery, Queryable as QuaintQueryable, ResultSet, TransactionCapable}, + visitor::{self, Visitor}, +}; +use tracing::{info_span, Instrument}; + +/// A JsQueryable adapts a Proxy to implement quaint's Queryable interface. It has the +/// responsibility of transforming inputs and outputs of `query` and `execute` methods from quaint +/// types to types that can be translated into javascript and viceversa. This is to let the rest of +/// the query engine work as if it was using quaint itself. The aforementioned transformations are: +/// +/// Transforming a `quaint::ast::Query` into SQL by visiting it for the specific flavour of SQL +/// expected by the client connector. (eg. using the mysql visitor for the Planetscale client +/// connector) +/// +/// Transforming a `JSResultSet` (what client connectors implemented in javascript provide) +/// into a `quaint::connector::result_set::ResultSet`. A quaint `ResultSet` is basically a vector +/// of `quaint::Value` but said type is a tagged enum, with non-unit variants that cannot be converted to javascript as is. 
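+/// For example, a `SELECT` built as a quaint AST is visited into flavour-specific SQL
+/// (MySQL, Postgres or SQLite syntax), its parameters are converted to `JSArg`s, the
+/// resulting `Query` is sent through `CommonProxy::query_raw`, and the returned
+/// `JSResultSet` is mapped back into a quaint `ResultSet` via `TryFrom`.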
+/// +pub(crate) struct JsBaseQueryable { + pub(crate) proxy: CommonProxy, + pub flavour: Flavour, +} + +impl JsBaseQueryable { + pub(crate) fn new(proxy: CommonProxy) -> Self { + let flavour: Flavour = proxy.flavour.parse().unwrap(); + Self { proxy, flavour } + } + + /// visit a quaint query AST according to the flavour of the JS connector + fn visit_quaint_query<'a>(&self, q: QuaintQuery<'a>) -> quaint::Result<(String, Vec>)> { + match self.flavour { + Flavour::Mysql => visitor::Mysql::build(q), + Flavour::Postgres => visitor::Postgres::build(q), + Flavour::Sqlite => visitor::Sqlite::build(q), + _ => unimplemented!("Unsupported flavour for JS connector {:?}", self.flavour), + } + } + + async fn build_query(&self, sql: &str, values: &[quaint::Value<'_>]) -> quaint::Result { + let sql: String = sql.to_string(); + + let converter = match self.flavour { + Flavour::Postgres => conversion::postgres::value_to_js_arg, + Flavour::Sqlite => conversion::sqlite::value_to_js_arg, + Flavour::Mysql => conversion::mysql::value_to_js_arg, + _ => unreachable!("Unsupported flavour for JS connector {:?}", self.flavour), + }; + + let args = values + .iter() + .map(converter) + .collect::>>()?; + + Ok(Query { sql, args }) + } +} + +#[async_trait] +impl QuaintQueryable for JsBaseQueryable { + async fn query(&self, q: QuaintQuery<'_>) -> quaint::Result { + let (sql, params) = self.visit_quaint_query(q)?; + self.query_raw(&sql, ¶ms).await + } + + async fn query_raw(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { + metrics::query("js.query_raw", sql, params, move || async move { + self.do_query_raw(sql, params).await + }) + .await + } + + async fn query_raw_typed(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { + self.query_raw(sql, params).await + } + + async fn execute(&self, q: QuaintQuery<'_>) -> quaint::Result { + let (sql, params) = self.visit_quaint_query(q)?; + self.execute_raw(&sql, ¶ms).await + } + + async fn execute_raw(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { + metrics::query("js.execute_raw", sql, params, move || async move { + self.do_execute_raw(sql, params).await + }) + .await + } + + async fn execute_raw_typed(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { + self.execute_raw(sql, params).await + } + + async fn raw_cmd(&self, cmd: &str) -> quaint::Result<()> { + let params = &[]; + metrics::query("js.raw_cmd", cmd, params, move || async move { + self.do_execute_raw(cmd, params).await?; + Ok(()) + }) + .await + } + + async fn version(&self) -> quaint::Result> { + // Note: JS Connectors don't use this method. + Ok(None) + } + + fn is_healthy(&self) -> bool { + // Note: JS Connectors don't use this method. + true + } + + /// Sets the transaction isolation level to given value. + /// Implementers have to make sure that the passed isolation level is valid for the underlying database. 
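+    /// `Snapshot` is always rejected here, and for SQLite only `Serializable` is accepted
+    /// (as a no-op); other flavours issue a `SET TRANSACTION ISOLATION LEVEL` statement.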
+ async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> quaint::Result<()> { + if matches!(isolation_level, IsolationLevel::Snapshot) { + return Err(Error::builder(ErrorKind::invalid_isolation_level(&isolation_level)).build()); + } + + if self.flavour == Flavour::Sqlite { + return match isolation_level { + IsolationLevel::Serializable => Ok(()), + _ => Err(Error::builder(ErrorKind::invalid_isolation_level(&isolation_level)).build()), + }; + } + + self.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation_level}")) + .await + } + + fn requires_isolation_first(&self) -> bool { + match self.flavour { + Flavour::Mysql => true, + Flavour::Postgres | Flavour::Sqlite => false, + _ => unreachable!(), + } + } +} + +impl JsBaseQueryable { + pub fn phantom_query_message(stmt: &str) -> String { + format!(r#"-- Implicit "{}" query via underlying driver"#, stmt) + } + + async fn do_query_raw(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { + let len = params.len(); + let serialization_span = info_span!("js:query:args", user_facing = true, "length" = %len); + let query = self.build_query(sql, params).instrument(serialization_span).await?; + + let sql_span = info_span!("js:query:sql", user_facing = true, "db.statement" = %sql); + let result_set = self.proxy.query_raw(query).instrument(sql_span).await?; + + let len = result_set.len(); + let _deserialization_span = info_span!("js:query:result", user_facing = true, "length" = %len).entered(); + + result_set.try_into() + } + + async fn do_execute_raw(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { + let len = params.len(); + let serialization_span = info_span!("js:query:args", user_facing = true, "length" = %len); + let query = self.build_query(sql, params).instrument(serialization_span).await?; + + let sql_span = info_span!("js:query:sql", user_facing = true, "db.statement" = %sql); + let affected_rows = self.proxy.execute_raw(query).instrument(sql_span).await?; + + Ok(affected_rows as u64) + } +} + +/// A JsQueryable adapts a Proxy to implement quaint's Queryable interface. It has the +/// responsibility of transforming inputs and outputs of `query` and `execute` methods from quaint +/// types to types that can be translated into javascript and viceversa. This is to let the rest of +/// the query engine work as if it was using quaint itself. The aforementioned transformations are: +/// +/// Transforming a `quaint::ast::Query` into SQL by visiting it for the specific flavour of SQL +/// expected by the client connector. (eg. using the mysql visitor for the Planetscale client +/// connector) +/// +/// Transforming a `JSResultSet` (what client connectors implemented in javascript provide) +/// into a `quaint::connector::result_set::ResultSet`. A quaint `ResultSet` is basically a vector +/// of `quaint::Value` but said type is a tagged enum, with non-unit variants that cannot be converted to javascript as is. 
+/// +pub struct JsQueryable { + inner: JsBaseQueryable, + driver_proxy: DriverProxy, +} + +impl std::fmt::Display for JsQueryable { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "JSQueryable(driver)") + } +} + +impl std::fmt::Debug for JsQueryable { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "JSQueryable(driver)") + } +} + +#[async_trait] +impl QuaintQueryable for JsQueryable { + async fn query(&self, q: QuaintQuery<'_>) -> quaint::Result { + self.inner.query(q).await + } + + async fn query_raw(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { + self.inner.query_raw(sql, params).await + } + + async fn query_raw_typed(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { + self.inner.query_raw_typed(sql, params).await + } + + async fn execute(&self, q: QuaintQuery<'_>) -> quaint::Result { + self.inner.execute(q).await + } + + async fn execute_raw(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { + self.inner.execute_raw(sql, params).await + } + + async fn execute_raw_typed(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { + self.inner.execute_raw_typed(sql, params).await + } + + async fn raw_cmd(&self, cmd: &str) -> quaint::Result<()> { + self.inner.raw_cmd(cmd).await + } + + async fn version(&self) -> quaint::Result> { + self.inner.version().await + } + + fn is_healthy(&self) -> bool { + self.inner.is_healthy() + } + + async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> quaint::Result<()> { + self.inner.set_tx_isolation_level(isolation_level).await + } + + fn requires_isolation_first(&self) -> bool { + self.inner.requires_isolation_first() + } +} + +#[async_trait] +impl TransactionCapable for JsQueryable { + async fn start_transaction<'a>( + &'a self, + isolation: Option, + ) -> quaint::Result> { + let tx = self.driver_proxy.start_transaction().await?; + + let isolation_first = tx.requires_isolation_first(); + + if isolation_first { + if let Some(isolation) = isolation { + tx.set_tx_isolation_level(isolation).await?; + } + } + + let begin_stmt = tx.begin_statement(); + + let tx_opts = tx.options(); + if tx_opts.use_phantom_query { + let begin_stmt = JsBaseQueryable::phantom_query_message(begin_stmt); + tx.raw_phantom_cmd(begin_stmt.as_str()).await?; + } else { + tx.raw_cmd(begin_stmt).await?; + } + + if !isolation_first { + if let Some(isolation) = isolation { + tx.set_tx_isolation_level(isolation).await?; + } + } + + self.server_reset_query(tx.as_ref()).await?; + + Ok(tx) + } +} + +pub fn from_napi(driver: JsObject) -> JsQueryable { + let common = CommonProxy::new(&driver).unwrap(); + let driver_proxy = DriverProxy::new(&driver).unwrap(); + + JsQueryable { + inner: JsBaseQueryable::new(common), + driver_proxy, + } +} diff --git a/query-engine/driver-adapters/src/napi/result.rs b/query-engine/driver-adapters/src/napi/result.rs new file mode 100644 index 000000000000..ad4ce7cbb546 --- /dev/null +++ b/query-engine/driver-adapters/src/napi/result.rs @@ -0,0 +1,119 @@ +use napi::{bindgen_prelude::FromNapiValue, Env, JsUnknown, NapiValue}; +use quaint::error::{Error as QuaintError, ErrorKind, MysqlError, PostgresError, SqliteError}; +use serde::Deserialize; + +#[derive(Deserialize)] +#[serde(remote = "PostgresError")] +pub struct PostgresErrorDef { + code: String, + message: String, + severity: String, + detail: Option, + column: Option, + hint: Option, +} + +#[derive(Deserialize)] +#[serde(remote = "MysqlError")] +pub struct 
MysqlErrorDef { + pub code: u16, + pub message: String, + pub state: String, +} + +#[derive(Deserialize)] +#[serde(remote = "SqliteError", rename_all = "camelCase")] +pub struct SqliteErrorDef { + pub extended_code: i32, + pub message: Option, +} + +#[derive(Deserialize)] +#[serde(tag = "kind")] +/// Wrapper for JS-side errors +pub(crate) enum DriverAdapterError { + /// Unexpected JS exception + GenericJs { + id: i32, + }, + UnsupportedNativeDataType { + #[serde(rename = "type")] + native_type: String, + }, + Postgres(#[serde(with = "PostgresErrorDef")] PostgresError), + Mysql(#[serde(with = "MysqlErrorDef")] MysqlError), + Sqlite(#[serde(with = "SqliteErrorDef")] SqliteError), +} + +impl FromNapiValue for DriverAdapterError { + unsafe fn from_napi_value(napi_env: napi::sys::napi_env, napi_val: napi::sys::napi_value) -> napi::Result { + let env = Env::from_raw(napi_env); + let value = JsUnknown::from_raw(napi_env, napi_val)?; + env.from_js_value(value) + } +} + +impl From for QuaintError { + fn from(value: DriverAdapterError) -> Self { + match value { + DriverAdapterError::UnsupportedNativeDataType { native_type } => { + QuaintError::builder(ErrorKind::UnsupportedColumnType { + column_type: native_type, + }) + .build() + } + DriverAdapterError::GenericJs { id } => QuaintError::external_error(id), + DriverAdapterError::Postgres(e) => e.into(), + DriverAdapterError::Mysql(e) => e.into(), + DriverAdapterError::Sqlite(e) => e.into(), + // in future, more error types would be added and we'll need to convert them to proper QuaintErrors here + } + } +} + +/// Wrapper for JS-side result type +pub(crate) enum JsResult +where + T: FromNapiValue, +{ + Ok(T), + Err(DriverAdapterError), +} + +impl JsResult +where + T: FromNapiValue, +{ + fn from_js_unknown(unknown: JsUnknown) -> napi::Result { + let object = unknown.coerce_to_object()?; + let ok: bool = object.get_named_property("ok")?; + if ok { + let value: JsUnknown = object.get_named_property("value")?; + return Ok(Self::Ok(T::from_unknown(value)?)); + } + + let error = object.get_named_property("error")?; + Ok(Self::Err(error)) + } +} + +impl FromNapiValue for JsResult +where + T: FromNapiValue, +{ + unsafe fn from_napi_value(napi_env: napi::sys::napi_env, napi_val: napi::sys::napi_value) -> napi::Result { + Self::from_js_unknown(JsUnknown::from_raw(napi_env, napi_val)?) + } +} + +impl From> for quaint::Result +where + T: FromNapiValue, +{ + fn from(value: JsResult) -> Self { + match value { + JsResult::Ok(result) => Ok(result), + JsResult::Err(error) => Err(error.into()), + } + } +} diff --git a/query-engine/driver-adapters/src/napi/transaction.rs b/query-engine/driver-adapters/src/napi/transaction.rs new file mode 100644 index 000000000000..16ecbb435ce9 --- /dev/null +++ b/query-engine/driver-adapters/src/napi/transaction.rs @@ -0,0 +1,136 @@ +use async_trait::async_trait; +use metrics::decrement_gauge; +use napi::{bindgen_prelude::FromNapiValue, JsObject}; +use quaint::{ + connector::{IsolationLevel, Transaction as QuaintTransaction}, + prelude::{Query as QuaintQuery, Queryable, ResultSet}, + Value, +}; + +use super::{ + proxy::{CommonProxy, TransactionOptions, TransactionProxy}, + queryable::JsBaseQueryable, +}; + +// Wrapper around JS transaction objects that implements Queryable +// and quaint::Transaction. 
Can be used in place of quaint transaction, +// but delegates most operations to JS +pub(crate) struct JsTransaction { + tx_proxy: TransactionProxy, + inner: JsBaseQueryable, +} + +impl JsTransaction { + pub(crate) fn new(inner: JsBaseQueryable, tx_proxy: TransactionProxy) -> Self { + Self { inner, tx_proxy } + } + + pub fn options(&self) -> &TransactionOptions { + self.tx_proxy.options() + } + + pub async fn raw_phantom_cmd(&self, cmd: &str) -> quaint::Result<()> { + let params = &[]; + quaint::connector::metrics::query("js.raw_phantom_cmd", cmd, params, move || async move { Ok(()) }).await + } +} + +#[async_trait] +impl QuaintTransaction for JsTransaction { + async fn commit(&self) -> quaint::Result<()> { + // increment of this gauge is done in DriverProxy::startTransaction + decrement_gauge!("prisma_client_queries_active", 1.0); + + let commit_stmt = "COMMIT"; + + if self.options().use_phantom_query { + let commit_stmt = JsBaseQueryable::phantom_query_message(commit_stmt); + self.raw_phantom_cmd(commit_stmt.as_str()).await?; + } else { + self.inner.raw_cmd(commit_stmt).await?; + } + + self.tx_proxy.commit().await + } + + async fn rollback(&self) -> quaint::Result<()> { + // increment of this gauge is done in DriverProxy::startTransaction + decrement_gauge!("prisma_client_queries_active", 1.0); + + let rollback_stmt = "ROLLBACK"; + + if self.options().use_phantom_query { + let rollback_stmt = JsBaseQueryable::phantom_query_message(rollback_stmt); + self.raw_phantom_cmd(rollback_stmt.as_str()).await?; + } else { + self.inner.raw_cmd(rollback_stmt).await?; + } + + self.tx_proxy.rollback().await + } + + fn as_queryable(&self) -> &dyn Queryable { + self + } +} + +#[async_trait] +impl Queryable for JsTransaction { + async fn query(&self, q: QuaintQuery<'_>) -> quaint::Result { + self.inner.query(q).await + } + + async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> quaint::Result { + self.inner.query_raw(sql, params).await + } + + async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> quaint::Result { + self.inner.query_raw_typed(sql, params).await + } + + async fn execute(&self, q: QuaintQuery<'_>) -> quaint::Result { + self.inner.execute(q).await + } + + async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> quaint::Result { + self.inner.execute_raw(sql, params).await + } + + async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> quaint::Result { + self.inner.execute_raw_typed(sql, params).await + } + + async fn raw_cmd(&self, cmd: &str) -> quaint::Result<()> { + self.inner.raw_cmd(cmd).await + } + + async fn version(&self) -> quaint::Result> { + self.inner.version().await + } + + fn is_healthy(&self) -> bool { + self.inner.is_healthy() + } + + async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> quaint::Result<()> { + self.inner.set_tx_isolation_level(isolation_level).await + } + + fn requires_isolation_first(&self) -> bool { + self.inner.requires_isolation_first() + } +} + +/// Implementing unsafe `from_napi_value` is only way I managed to get threadsafe +/// JsTransaction value in `DriverProxy`. Going through any intermediate safe napi.rs value, +/// like `JsObject` or `JsUnknown` wrapped inside `JsPromise` makes it impossible to extract the value +/// out of promise while keeping the future `Send`. 
+impl FromNapiValue for JsTransaction { + unsafe fn from_napi_value(env: napi::sys::napi_env, napi_val: napi::sys::napi_value) -> napi::Result { + let object = JsObject::from_napi_value(env, napi_val)?; + let common_proxy = CommonProxy::new(&object)?; + let tx_proxy = TransactionProxy::new(&object)?; + + Ok(Self::new(JsBaseQueryable::new(common_proxy), tx_proxy)) + } +} diff --git a/query-engine/driver-adapters/src/wasm/mod.rs b/query-engine/driver-adapters/src/wasm/mod.rs new file mode 100644 index 000000000000..3854af9dfa6b --- /dev/null +++ b/query-engine/driver-adapters/src/wasm/mod.rs @@ -0,0 +1 @@ +//! Query Engine Driver Adapters: `wasm`-specific implementation. From 3c7a778e796078ed597528f3438c027d1615f15d Mon Sep 17 00:00:00 2001 From: jkomyno Date: Thu, 16 Nov 2023 10:19:31 +0100 Subject: [PATCH 036/134] feat(driver-adapters): extracted platform-agnostic "DriverAdapterError" into "driver_adapters::error" --- query-engine/driver-adapters/src/error.rs | 45 ++++++++++++++++++ query-engine/driver-adapters/src/lib.rs | 1 + .../driver-adapters/src/napi/result.rs | 46 +------------------ 3 files changed, 48 insertions(+), 44 deletions(-) create mode 100644 query-engine/driver-adapters/src/error.rs diff --git a/query-engine/driver-adapters/src/error.rs b/query-engine/driver-adapters/src/error.rs new file mode 100644 index 000000000000..fa01759d9213 --- /dev/null +++ b/query-engine/driver-adapters/src/error.rs @@ -0,0 +1,45 @@ +use quaint::error::{MysqlError, PostgresError, SqliteError}; +use serde::Deserialize; + +#[derive(Deserialize)] +#[serde(remote = "PostgresError")] +pub struct PostgresErrorDef { + code: String, + message: String, + severity: String, + detail: Option, + column: Option, + hint: Option, +} + +#[derive(Deserialize)] +#[serde(remote = "MysqlError")] +pub struct MysqlErrorDef { + pub code: u16, + pub message: String, + pub state: String, +} + +#[derive(Deserialize)] +#[serde(remote = "SqliteError", rename_all = "camelCase")] +pub struct SqliteErrorDef { + pub extended_code: i32, + pub message: Option, +} + +#[derive(Deserialize)] +#[serde(tag = "kind")] +/// Wrapper for JS-side errors +pub(crate) enum DriverAdapterError { + /// Unexpected JS exception + GenericJs { + id: i32, + }, + UnsupportedNativeDataType { + #[serde(rename = "type")] + native_type: String, + }, + Postgres(#[serde(with = "PostgresErrorDef")] PostgresError), + Mysql(#[serde(with = "MysqlErrorDef")] MysqlError), + Sqlite(#[serde(with = "SqliteErrorDef")] SqliteError), +} diff --git a/query-engine/driver-adapters/src/lib.rs b/query-engine/driver-adapters/src/lib.rs index 22b7883180a6..186446cd4b54 100644 --- a/query-engine/driver-adapters/src/lib.rs +++ b/query-engine/driver-adapters/src/lib.rs @@ -8,6 +8,7 @@ //! 
pub(crate) mod conversion; +pub(crate) mod error; #[cfg(not(target_arch = "wasm32"))] pub mod napi; diff --git a/query-engine/driver-adapters/src/napi/result.rs b/query-engine/driver-adapters/src/napi/result.rs index ad4ce7cbb546..d815c9d86dbd 100644 --- a/query-engine/driver-adapters/src/napi/result.rs +++ b/query-engine/driver-adapters/src/napi/result.rs @@ -1,49 +1,7 @@ use napi::{bindgen_prelude::FromNapiValue, Env, JsUnknown, NapiValue}; -use quaint::error::{Error as QuaintError, ErrorKind, MysqlError, PostgresError, SqliteError}; -use serde::Deserialize; +use quaint::error::{Error as QuaintError, ErrorKind}; -#[derive(Deserialize)] -#[serde(remote = "PostgresError")] -pub struct PostgresErrorDef { - code: String, - message: String, - severity: String, - detail: Option, - column: Option, - hint: Option, -} - -#[derive(Deserialize)] -#[serde(remote = "MysqlError")] -pub struct MysqlErrorDef { - pub code: u16, - pub message: String, - pub state: String, -} - -#[derive(Deserialize)] -#[serde(remote = "SqliteError", rename_all = "camelCase")] -pub struct SqliteErrorDef { - pub extended_code: i32, - pub message: Option, -} - -#[derive(Deserialize)] -#[serde(tag = "kind")] -/// Wrapper for JS-side errors -pub(crate) enum DriverAdapterError { - /// Unexpected JS exception - GenericJs { - id: i32, - }, - UnsupportedNativeDataType { - #[serde(rename = "type")] - native_type: String, - }, - Postgres(#[serde(with = "PostgresErrorDef")] PostgresError), - Mysql(#[serde(with = "MysqlErrorDef")] MysqlError), - Sqlite(#[serde(with = "SqliteErrorDef")] SqliteError), -} +use crate::error::DriverAdapterError; impl FromNapiValue for DriverAdapterError { unsafe fn from_napi_value(napi_env: napi::sys::napi_env, napi_val: napi::sys::napi_value) -> napi::Result { From 23bf4870b0119bcde8606a1cc4c3ff41224bcd88 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Thu, 16 Nov 2023 10:30:07 +0100 Subject: [PATCH 037/134] chore(driver-adapters): add "driver-adapters" to "query-engine-wasm" --- Cargo.lock | 3 +++ query-engine/query-engine-wasm/Cargo.toml | 1 + 2 files changed, 4 insertions(+) diff --git a/Cargo.lock b/Cargo.lock index 43a32df4cb92..c36a111771af 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1108,12 +1108,14 @@ dependencies = [ "psl", "quaint", "serde", + "serde-wasm-bindgen", "serde_json", "tokio", "tracing", "tracing-core", "uuid", "wasm-bindgen", + "wasm-bindgen-futures", ] [[package]] @@ -3821,6 +3823,7 @@ dependencies = [ "async-trait", "connection-string", "console_error_panic_hook", + "driver-adapters", "futures", "js-sys", "log", diff --git a/query-engine/query-engine-wasm/Cargo.toml b/query-engine/query-engine-wasm/Cargo.toml index bf179102dbde..f4a9703e741e 100644 --- a/query-engine/query-engine-wasm/Cargo.toml +++ b/query-engine/query-engine-wasm/Cargo.toml @@ -19,6 +19,7 @@ query-connector = { path = "../connectors/query-connector" } sql-query-connector = { path = "../connectors/sql-query-connector" } query-core = { path = "../core" } request-handlers = { path = "../request-handlers", default-features = false, features = ["sql", "driver-adapters"] } +driver-adapters = { path = "../driver-adapters" } connection-string.workspace = true js-sys.workspace = true From 4424d4b27cfa41f5d349374c88c8647bef27e58f Mon Sep 17 00:00:00 2001 From: jkomyno Date: Thu, 16 Nov 2023 11:49:37 +0100 Subject: [PATCH 038/134] feat(driver-adapters): add Wasm-specific "async_js_function" --- query-engine/driver-adapters/Cargo.toml | 4 ++- .../src/wasm/async_js_function.rs | 31 +++++++++++++++++++ 
 query-engine/driver-adapters/src/wasm/mod.rs  |  2 ++
 3 files changed, 36 insertions(+), 1 deletion(-)
 create mode 100644 query-engine/driver-adapters/src/wasm/async_js_function.rs

diff --git a/query-engine/driver-adapters/Cargo.toml b/query-engine/driver-adapters/Cargo.toml
index 029c3b5492c3..254679a01664 100644
--- a/query-engine/driver-adapters/Cargo.toml
+++ b/query-engine/driver-adapters/Cargo.toml
@@ -31,6 +31,8 @@ napi-derive.workspace = true
 quaint.workspace = true
 
 [target.'cfg(target_arch = "wasm32")'.dependencies]
-wasm-bindgen.workspace = true
 js-sys.workspace = true
 quaint = { path = "../../quaint" }
+serde-wasm-bindgen.workspace = true
+wasm-bindgen.workspace = true
+wasm-bindgen-futures.workspace = true
diff --git a/query-engine/driver-adapters/src/wasm/async_js_function.rs b/query-engine/driver-adapters/src/wasm/async_js_function.rs
new file mode 100644
index 000000000000..8e8d6958cce9
--- /dev/null
+++ b/query-engine/driver-adapters/src/wasm/async_js_function.rs
@@ -0,0 +1,31 @@
+use js_sys::{Function as JsFunction, Object as JsObject, Promise as JsPromise};
+use serde::{de::DeserializeOwned, Serialize};
+use std::marker::PhantomData;
+use wasm_bindgen::{prelude::wasm_bindgen, JsError, JsValue};
+use wasm_bindgen_futures::JsFuture;
+
+type JsResult<T> = core::result::Result<T, JsValue>;
+
+pub(crate) struct AsyncJsFunction<ArgType, ReturnType>
+where
+    ArgType: Serialize + 'static,
+    ReturnType: DeserializeOwned + 'static,
+{
+    threadsafe_fn: JsFunction,
+    _phantom_arg: PhantomData<ArgType>,
+    _phantom_return: PhantomData<ReturnType>,
+}
+
+impl<ArgType, ReturnType> AsyncJsFunction<ArgType, ReturnType>
+where
+    ArgType: Serialize + 'static,
+    ReturnType: DeserializeOwned + 'static,
+{
+    async fn call(&self, arg1: ArgType) -> JsResult<ReturnType> {
+        let arg1 = serde_wasm_bindgen::to_value(&arg1).map_err(|err| JsError::from(&err))?;
+        let promise = self.threadsafe_fn.call1(&JsValue::null(), &arg1)?;
+        let future = JsFuture::from(JsPromise::from(promise));
+        let value = future.await?;
+        serde_wasm_bindgen::from_value(value).map_err(|err| JsValue::from(err))
+    }
+}
diff --git a/query-engine/driver-adapters/src/wasm/mod.rs b/query-engine/driver-adapters/src/wasm/mod.rs
index 3854af9dfa6b..92509cb18c3f 100644
--- a/query-engine/driver-adapters/src/wasm/mod.rs
+++ b/query-engine/driver-adapters/src/wasm/mod.rs
@@ -1 +1,3 @@
 //! Query Engine Driver Adapters: `wasm`-specific implementation.
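+//!
+//! A rough usage sketch of the `AsyncJsFunction` defined in this module (illustrative only: the
+//! concrete `Query`/`JSResultSet` types and the way the JS function is obtained are assumptions,
+//! not part of this patch):
+//!
+//! ```ignore
+//! // Suppose the JS driver exposes `queryRaw(query) -> Promise<JSResultSet>`:
+//! let query_raw: AsyncJsFunction<Query, JSResultSet> = /* extracted from the JS driver object */;
+//! // `call` serializes the argument with serde-wasm-bindgen, invokes the JS function,
+//! // awaits the returned Promise through JsFuture, and deserializes the resolved value.
+//! let result_set: JSResultSet = query_raw.call(query).await?;
+//! ```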
+ +mod async_js_function; From 9b60a18eb2cf3224ab1c38b593fbcec65ba845c7 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Thu, 16 Nov 2023 13:38:29 +0100 Subject: [PATCH 039/134] feat(driver-adapters): extracted common types to "driver_adapters::types" --- .cargo/config.toml | 2 + Cargo.lock | 1 + query-engine/driver-adapters/Cargo.toml | 1 + query-engine/driver-adapters/src/lib.rs | 1 + .../driver-adapters/src/napi/proxy.rs | 184 +--------------- query-engine/driver-adapters/src/types.rs | 197 ++++++++++++++++++ 6 files changed, 204 insertions(+), 182 deletions(-) create mode 100644 .cargo/config.toml create mode 100644 query-engine/driver-adapters/src/types.rs diff --git a/.cargo/config.toml b/.cargo/config.toml new file mode 100644 index 000000000000..229dd6ee6b3f --- /dev/null +++ b/.cargo/config.toml @@ -0,0 +1,2 @@ +[build] +# target = "wasm32-unknown-unknown" diff --git a/Cargo.lock b/Cargo.lock index c36a111771af..627ecb5fb08c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1113,6 +1113,7 @@ dependencies = [ "tokio", "tracing", "tracing-core", + "tsify", "uuid", "wasm-bindgen", "wasm-bindgen-futures", diff --git a/query-engine/driver-adapters/Cargo.toml b/query-engine/driver-adapters/Cargo.toml index 254679a01664..cb86a96ee7f9 100644 --- a/query-engine/driver-adapters/Cargo.toml +++ b/query-engine/driver-adapters/Cargo.toml @@ -36,3 +36,4 @@ quaint = { path = "../../quaint" } serde-wasm-bindgen.workspace = true wasm-bindgen.workspace = true wasm-bindgen-futures.workspace = true +tsify.workspace = true diff --git a/query-engine/driver-adapters/src/lib.rs b/query-engine/driver-adapters/src/lib.rs index 186446cd4b54..ca8aa4541bd1 100644 --- a/query-engine/driver-adapters/src/lib.rs +++ b/query-engine/driver-adapters/src/lib.rs @@ -9,6 +9,7 @@ pub(crate) mod conversion; pub(crate) mod error; +pub(crate) mod types; #[cfg(not(target_arch = "wasm32"))] pub mod napi; diff --git a/query-engine/driver-adapters/src/napi/proxy.rs b/query-engine/driver-adapters/src/napi/proxy.rs index 9511e0463770..0677d279a1f4 100644 --- a/query-engine/driver-adapters/src/napi/proxy.rs +++ b/query-engine/driver-adapters/src/napi/proxy.rs @@ -1,14 +1,13 @@ use std::borrow::Cow; use std::str::FromStr; +pub use crate::types::{ColumnType, JSResultSet, Query, TransactionOptions}; + use super::async_js_function::AsyncJsFunction; -use super::conversion::JSArg; use super::transaction::JsTransaction; use metrics::increment_gauge; -use napi::bindgen_prelude::{FromNapiValue, ToNapiValue}; use napi::threadsafe_function::{ErrorStrategy, ThreadsafeFunction}; use napi::{JsObject, JsString}; -use napi_derive::napi; use quaint::connector::ResultSet as QuaintResultSet; use quaint::{ error::{Error as QuaintError, ErrorKind}, @@ -57,177 +56,6 @@ pub(crate) struct TransactionProxy { dispose: ThreadsafeFunction<(), ErrorStrategy::Fatal>, } -/// This result set is more convenient to be manipulated from both Rust and NodeJS. -/// Quaint's version of ResultSet is: -/// -/// pub struct ResultSet { -/// pub(crate) columns: Arc>, -/// pub(crate) rows: Vec>>, -/// pub(crate) last_insert_id: Option, -/// } -/// -/// If we used this ResultSet would we would have worse ergonomics as quaint::Value is a structured -/// enum and cannot be used directly with the #[napi(Object)] macro. Thus requiring us to implement -/// the FromNapiValue and ToNapiValue traits for quaint::Value, and use a different custom type -/// representing the Value in javascript. 
-/// -#[napi(object)] -#[derive(Debug)] -pub struct JSResultSet { - pub column_types: Vec, - pub column_names: Vec, - // Note this might be encoded differently for performance reasons - pub rows: Vec>, - pub last_insert_id: Option, -} - -impl JSResultSet { - pub fn len(&self) -> usize { - self.rows.len() - } -} - -#[napi] -#[derive(Debug)] -pub enum ColumnType { - // [PLANETSCALE_TYPE] (MYSQL_TYPE) -> [TypeScript example] - /// The following PlanetScale type IDs are mapped into Int32: - /// - INT8 (TINYINT) -> e.g. `127` - /// - INT16 (SMALLINT) -> e.g. `32767` - /// - INT24 (MEDIUMINT) -> e.g. `8388607` - /// - INT32 (INT) -> e.g. `2147483647` - Int32 = 0, - - /// The following PlanetScale type IDs are mapped into Int64: - /// - INT64 (BIGINT) -> e.g. `"9223372036854775807"` (String-encoded) - Int64 = 1, - - /// The following PlanetScale type IDs are mapped into Float: - /// - FLOAT32 (FLOAT) -> e.g. `3.402823466` - Float = 2, - - /// The following PlanetScale type IDs are mapped into Double: - /// - FLOAT64 (DOUBLE) -> e.g. `1.7976931348623157` - Double = 3, - - /// The following PlanetScale type IDs are mapped into Numeric: - /// - DECIMAL (DECIMAL) -> e.g. `"99999999.99"` (String-encoded) - Numeric = 4, - - /// The following PlanetScale type IDs are mapped into Boolean: - /// - BOOLEAN (BOOLEAN) -> e.g. `1` - Boolean = 5, - - Character = 6, - - /// The following PlanetScale type IDs are mapped into Text: - /// - TEXT (TEXT) -> e.g. `"foo"` (String-encoded) - /// - VARCHAR (VARCHAR) -> e.g. `"foo"` (String-encoded) - Text = 7, - - /// The following PlanetScale type IDs are mapped into Date: - /// - DATE (DATE) -> e.g. `"2023-01-01"` (String-encoded, yyyy-MM-dd) - Date = 8, - - /// The following PlanetScale type IDs are mapped into Time: - /// - TIME (TIME) -> e.g. `"23:59:59"` (String-encoded, HH:mm:ss) - Time = 9, - - /// The following PlanetScale type IDs are mapped into DateTime: - /// - DATETIME (DATETIME) -> e.g. `"2023-01-01 23:59:59"` (String-encoded, yyyy-MM-dd HH:mm:ss) - /// - TIMESTAMP (TIMESTAMP) -> e.g. `"2023-01-01 23:59:59"` (String-encoded, yyyy-MM-dd HH:mm:ss) - DateTime = 10, - - /// The following PlanetScale type IDs are mapped into Json: - /// - JSON (JSON) -> e.g. `"{\"key\": \"value\"}"` (String-encoded) - Json = 11, - - /// The following PlanetScale type IDs are mapped into Enum: - /// - ENUM (ENUM) -> e.g. `"foo"` (String-encoded) - Enum = 12, - - /// The following PlanetScale type IDs are mapped into Bytes: - /// - BLOB (BLOB) -> e.g. `"\u0012"` (String-encoded) - /// - VARBINARY (VARBINARY) -> e.g. `"\u0012"` (String-encoded) - /// - BINARY (BINARY) -> e.g. `"\u0012"` (String-encoded) - /// - GEOMETRY (GEOMETRY) -> e.g. `"\u0012"` (String-encoded) - Bytes = 13, - - /// The following PlanetScale type IDs are mapped into Set: - /// - SET (SET) -> e.g. `"foo,bar"` (String-encoded, comma-separated) - /// This is currently unhandled, and will panic if encountered. - Set = 14, - - /// UUID from postgres-flavored driver adapters is mapped to this type. 
- Uuid = 15, - - /* - * Scalar arrays - */ - /// Int32 array (INT2_ARRAY and INT4_ARRAY in PostgreSQL) - Int32Array = 64, - - /// Int64 array (INT8_ARRAY in PostgreSQL) - Int64Array = 65, - - /// Float array (FLOAT4_ARRAY in PostgreSQL) - FloatArray = 66, - - /// Double array (FLOAT8_ARRAY in PostgreSQL) - DoubleArray = 67, - - /// Numeric array (NUMERIC_ARRAY, MONEY_ARRAY etc in PostgreSQL) - NumericArray = 68, - - /// Boolean array (BOOL_ARRAY in PostgreSQL) - BooleanArray = 69, - - /// Char array (CHAR_ARRAY in PostgreSQL) - CharacterArray = 70, - - /// Text array (TEXT_ARRAY in PostgreSQL) - TextArray = 71, - - /// Date array (DATE_ARRAY in PostgreSQL) - DateArray = 72, - - /// Time array (TIME_ARRAY in PostgreSQL) - TimeArray = 73, - - /// DateTime array (TIMESTAMP_ARRAY in PostgreSQL) - DateTimeArray = 74, - - /// Json array (JSON_ARRAY in PostgreSQL) - JsonArray = 75, - - /// Enum array - EnumArray = 76, - - /// Bytes array (BYTEA_ARRAY in PostgreSQL) - BytesArray = 77, - - /// Uuid array (UUID_ARRAY in PostgreSQL) - UuidArray = 78, - - /* - * Below there are custom types that don't have a 1:1 translation with a quaint::Value. - * enum variant. - */ - /// UnknownNumber is used when the type of the column is a number but of unknown particular type - /// and precision. - /// - /// It's used by some driver adapters, like libsql to return aggregation values like AVG, or - /// COUNT, and it can be mapped to either Int64, or Double - UnknownNumber = 128, -} - -#[napi(object)] -#[derive(Debug)] -pub struct Query { - pub sql: String, - pub args: Vec, -} - fn conversion_error(args: &std::fmt::Arguments) -> QuaintError { let msg = match args.as_str() { Some(s) => Cow::Borrowed(s), @@ -569,14 +397,6 @@ impl DriverProxy { } } -#[derive(Debug)] -#[napi(object)] -pub struct TransactionOptions { - /// Whether or not to run a phantom query (i.e., a query that only influences Prisma event logs, but not the database itself) - /// before opening a transaction, committing, or rollbacking. - pub use_phantom_query: bool, -} - impl TransactionProxy { pub fn new(js_transaction: &JsObject) -> napi::Result { let commit = js_transaction.get_named_property("commit")?; diff --git a/query-engine/driver-adapters/src/types.rs b/query-engine/driver-adapters/src/types.rs new file mode 100644 index 000000000000..9fa2c63b4ffc --- /dev/null +++ b/query-engine/driver-adapters/src/types.rs @@ -0,0 +1,197 @@ +#![allow(unused_imports)] + +#[cfg(not(target_arch = "wasm32"))] +use napi::bindgen_prelude::{FromNapiValue, ToNapiValue}; + +#[cfg(target_arch = "wasm32")] +use tsify::Tsify; + +use crate::conversion::JSArg; +use serde::{Deserialize, Serialize}; + +/// This result set is more convenient to be manipulated from both Rust and NodeJS. +/// Quaint's version of ResultSet is: +/// +/// pub struct ResultSet { +/// pub(crate) columns: Arc>, +/// pub(crate) rows: Vec>>, +/// pub(crate) last_insert_id: Option, +/// } +/// +/// If we used this ResultSet would we would have worse ergonomics as quaint::Value is a structured +/// enum and cannot be used directly with the #[napi(Object)] macro. Thus requiring us to implement +/// the FromNapiValue and ToNapiValue traits for quaint::Value, and use a different custom type +/// representing the Value in javascript. 
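+///
+/// The struct is defined once and decorated per target: on native builds it is exposed to JS
+/// through `#[napi(object)]`, while on `wasm32` it crosses the boundary via `serde` + `tsify`
+/// (which also emits the matching TypeScript type), hence the `cfg_attr` wrappers below.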
+/// +#[cfg_attr(not(target_arch = "wasm32"), napi_derive::napi(object))] +#[cfg_attr(target_arch = "wasm32", derive(Serialize, Deserialize, Tsify))] +#[cfg_attr(target_arch = "wasm32", tsify(into_wasm_abi, from_wasm_abi))] +#[derive(Debug)] +pub struct JSResultSet { + pub column_types: Vec, + pub column_names: Vec, + // Note this might be encoded differently for performance reasons + pub rows: Vec>, + pub last_insert_id: Option, +} + +impl JSResultSet { + pub fn len(&self) -> usize { + self.rows.len() + } +} + +#[cfg_attr(not(target_arch = "wasm32"), napi_derive::napi(object))] +#[cfg_attr(target_arch = "wasm32", derive(Serialize, Deserialize, Tsify))] +#[cfg_attr(target_arch = "wasm32", tsify(into_wasm_abi, from_wasm_abi))] +#[derive(Debug)] +pub enum ColumnType { + // [PLANETSCALE_TYPE] (MYSQL_TYPE) -> [TypeScript example] + /// The following PlanetScale type IDs are mapped into Int32: + /// - INT8 (TINYINT) -> e.g. `127` + /// - INT16 (SMALLINT) -> e.g. `32767` + /// - INT24 (MEDIUMINT) -> e.g. `8388607` + /// - INT32 (INT) -> e.g. `2147483647` + Int32 = 0, + + /// The following PlanetScale type IDs are mapped into Int64: + /// - INT64 (BIGINT) -> e.g. `"9223372036854775807"` (String-encoded) + Int64 = 1, + + /// The following PlanetScale type IDs are mapped into Float: + /// - FLOAT32 (FLOAT) -> e.g. `3.402823466` + Float = 2, + + /// The following PlanetScale type IDs are mapped into Double: + /// - FLOAT64 (DOUBLE) -> e.g. `1.7976931348623157` + Double = 3, + + /// The following PlanetScale type IDs are mapped into Numeric: + /// - DECIMAL (DECIMAL) -> e.g. `"99999999.99"` (String-encoded) + Numeric = 4, + + /// The following PlanetScale type IDs are mapped into Boolean: + /// - BOOLEAN (BOOLEAN) -> e.g. `1` + Boolean = 5, + + Character = 6, + + /// The following PlanetScale type IDs are mapped into Text: + /// - TEXT (TEXT) -> e.g. `"foo"` (String-encoded) + /// - VARCHAR (VARCHAR) -> e.g. `"foo"` (String-encoded) + Text = 7, + + /// The following PlanetScale type IDs are mapped into Date: + /// - DATE (DATE) -> e.g. `"2023-01-01"` (String-encoded, yyyy-MM-dd) + Date = 8, + + /// The following PlanetScale type IDs are mapped into Time: + /// - TIME (TIME) -> e.g. `"23:59:59"` (String-encoded, HH:mm:ss) + Time = 9, + + /// The following PlanetScale type IDs are mapped into DateTime: + /// - DATETIME (DATETIME) -> e.g. `"2023-01-01 23:59:59"` (String-encoded, yyyy-MM-dd HH:mm:ss) + /// - TIMESTAMP (TIMESTAMP) -> e.g. `"2023-01-01 23:59:59"` (String-encoded, yyyy-MM-dd HH:mm:ss) + DateTime = 10, + + /// The following PlanetScale type IDs are mapped into Json: + /// - JSON (JSON) -> e.g. `"{\"key\": \"value\"}"` (String-encoded) + Json = 11, + + /// The following PlanetScale type IDs are mapped into Enum: + /// - ENUM (ENUM) -> e.g. `"foo"` (String-encoded) + Enum = 12, + + /// The following PlanetScale type IDs are mapped into Bytes: + /// - BLOB (BLOB) -> e.g. `"\u0012"` (String-encoded) + /// - VARBINARY (VARBINARY) -> e.g. `"\u0012"` (String-encoded) + /// - BINARY (BINARY) -> e.g. `"\u0012"` (String-encoded) + /// - GEOMETRY (GEOMETRY) -> e.g. `"\u0012"` (String-encoded) + Bytes = 13, + + /// The following PlanetScale type IDs are mapped into Set: + /// - SET (SET) -> e.g. `"foo,bar"` (String-encoded, comma-separated) + /// This is currently unhandled, and will panic if encountered. + Set = 14, + + /// UUID from postgres-flavored driver adapters is mapped to this type. 
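+    /// - e.g. `"550e8400-e29b-41d4-a716-446655440000"` (String-encoded)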
+ Uuid = 15, + + /* + * Scalar arrays + */ + /// Int32 array (INT2_ARRAY and INT4_ARRAY in PostgreSQL) + Int32Array = 64, + + /// Int64 array (INT8_ARRAY in PostgreSQL) + Int64Array = 65, + + /// Float array (FLOAT4_ARRAY in PostgreSQL) + FloatArray = 66, + + /// Double array (FLOAT8_ARRAY in PostgreSQL) + DoubleArray = 67, + + /// Numeric array (NUMERIC_ARRAY, MONEY_ARRAY etc in PostgreSQL) + NumericArray = 68, + + /// Boolean array (BOOL_ARRAY in PostgreSQL) + BooleanArray = 69, + + /// Char array (CHAR_ARRAY in PostgreSQL) + CharacterArray = 70, + + /// Text array (TEXT_ARRAY in PostgreSQL) + TextArray = 71, + + /// Date array (DATE_ARRAY in PostgreSQL) + DateArray = 72, + + /// Time array (TIME_ARRAY in PostgreSQL) + TimeArray = 73, + + /// DateTime array (TIMESTAMP_ARRAY in PostgreSQL) + DateTimeArray = 74, + + /// Json array (JSON_ARRAY in PostgreSQL) + JsonArray = 75, + + /// Enum array + EnumArray = 76, + + /// Bytes array (BYTEA_ARRAY in PostgreSQL) + BytesArray = 77, + + /// Uuid array (UUID_ARRAY in PostgreSQL) + UuidArray = 78, + + /* + * Below there are custom types that don't have a 1:1 translation with a quaint::Value. + * enum variant. + */ + /// UnknownNumber is used when the type of the column is a number but of unknown particular type + /// and precision. + /// + /// It's used by some driver adapters, like libsql to return aggregation values like AVG, or + /// COUNT, and it can be mapped to either Int64, or Double + UnknownNumber = 128, +} + +#[cfg_attr(not(target_arch = "wasm32"), napi_derive::napi(object))] +#[cfg_attr(target_arch = "wasm32", derive(Serialize, Tsify))] +#[cfg_attr(target_arch = "wasm32", tsify(into_wasm_abi))] +#[derive(Debug)] +pub struct Query { + pub sql: String, + pub args: Vec, +} + +#[cfg_attr(not(target_arch = "wasm32"), napi_derive::napi(object))] +#[cfg_attr(target_arch = "wasm32", derive(Serialize, Deserialize, Tsify))] +#[cfg_attr(target_arch = "wasm32", tsify(into_wasm_abi, from_wasm_abi))] +#[derive(Debug)] +pub struct TransactionOptions { + /// Whether or not to run a phantom query (i.e., a query that only influences Prisma event logs, but not the database itself) + /// before opening a transaction, committing, or rollbacking. + pub use_phantom_query: bool, +} From 1eafb3df72aefaf66fd12f9751fd04c0617a7e2d Mon Sep 17 00:00:00 2001 From: jkomyno Date: Thu, 16 Nov 2023 14:15:33 +0100 Subject: [PATCH 040/134] feat(driver-adapters): extracted "TryFrom for QuaintResultSet" to "driver_adapters::conversion::js_to_quaint" --- .../src/conversion/js_to_quaint.rs | 687 ++++++++++++++++++ .../driver-adapters/src/conversion/mod.rs | 2 + .../driver-adapters/src/napi/proxy.rs | 684 ----------------- query-engine/driver-adapters/src/types.rs | 2 +- 4 files changed, 690 insertions(+), 685 deletions(-) create mode 100644 query-engine/driver-adapters/src/conversion/js_to_quaint.rs diff --git a/query-engine/driver-adapters/src/conversion/js_to_quaint.rs b/query-engine/driver-adapters/src/conversion/js_to_quaint.rs new file mode 100644 index 000000000000..b4d872c3ed2d --- /dev/null +++ b/query-engine/driver-adapters/src/conversion/js_to_quaint.rs @@ -0,0 +1,687 @@ +use std::borrow::Cow; +use std::str::FromStr; + +pub use crate::types::{ColumnType, JSResultSet, Query, TransactionOptions}; + +use quaint::{ + connector::ResultSet as QuaintResultSet, + error::{Error as QuaintError, ErrorKind}, + Value as QuaintValue, +}; + +// TODO(jkomyno): import these 3rd-party crates from the `quaint-core` crate. 
+use bigdecimal::{BigDecimal, FromPrimitive}; +use chrono::{DateTime, Utc}; +use chrono::{NaiveDate, NaiveTime}; + +impl TryFrom for QuaintResultSet { + type Error = quaint::error::Error; + + fn try_from(js_result_set: JSResultSet) -> Result { + let JSResultSet { + rows, + column_names, + column_types, + last_insert_id, + } = js_result_set; + + let mut quaint_rows = Vec::with_capacity(rows.len()); + + for row in rows { + let mut quaint_row = Vec::with_capacity(column_types.len()); + + for (i, row) in row.into_iter().enumerate() { + let column_type = column_types[i]; + let column_name = column_names[i].as_str(); + + quaint_row.push(js_value_to_quaint(row, column_type, column_name)?); + } + + quaint_rows.push(quaint_row); + } + + let last_insert_id = last_insert_id.and_then(|id| id.parse::().ok()); + let mut quaint_result_set = QuaintResultSet::new(column_names, quaint_rows); + + // Not a fan of this (extracting the `Some` value from an `Option` and pass it to a method that creates a new `Some` value), + // but that's Quaint's ResultSet API and that's how the MySQL connector does it. + // Sqlite, on the other hand, uses a `last_insert_id.unwrap_or(0)` approach. + if let Some(last_insert_id) = last_insert_id { + quaint_result_set.set_last_insert_id(last_insert_id); + } + + Ok(quaint_result_set) + } +} + +fn conversion_error(args: &std::fmt::Arguments) -> QuaintError { + let msg = match args.as_str() { + Some(s) => Cow::Borrowed(s), + None => Cow::Owned(args.to_string()), + }; + QuaintError::builder(ErrorKind::ConversionError(msg)).build() +} + +macro_rules! conversion_error { + ($($arg:tt)*) => { + conversion_error(&format_args!($($arg)*)) + }; +} + +/// Handle data-type conversion from a JSON value to a Quaint value. +/// This is used for most data types, except those that require connector-specific handling, e.g., `ColumnType::Boolean`. +pub fn js_value_to_quaint( + json_value: serde_json::Value, + column_type: ColumnType, + column_name: &str, +) -> quaint::Result> { + let parse_number_as_i64 = |n: &serde_json::Number| { + n.as_i64().ok_or(conversion_error!( + "number must be an integer in column '{column_name}', got '{n}'" + )) + }; + + // Note for the future: it may be worth revisiting how much bloat so many panics with different static + // strings add to the compiled artefact, and in case we should come up with a restricted set of panic + // messages, or even find a way of removing them altogether. 
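+    //
+    // As a concrete example of the mapping below: `js_value_to_quaint(json!("2023-01-01"),
+    // ColumnType::Date, "created_at")` yields `QuaintValue::date(NaiveDate::from_ymd_opt(2023, 1, 1).unwrap())`,
+    // while passing a non-string JSON value for the same column type produces a
+    // `ConversionError` that names the offending column.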
+ match column_type { + ColumnType::Int32 => match json_value { + serde_json::Value::Number(n) => { + // n.as_i32() is not implemented, so we need to downcast from i64 instead + parse_number_as_i64(&n) + .and_then(|n| -> quaint::Result { + n.try_into() + .map_err(|e| conversion_error!("cannot convert {n} to i32 in column '{column_name}': {e}")) + }) + .map(QuaintValue::int32) + } + serde_json::Value::String(s) => s.parse::().map(QuaintValue::int32).map_err(|e| { + conversion_error!("string-encoded number must be an i32 in column '{column_name}', got {s}: {e}") + }), + serde_json::Value::Null => Ok(QuaintValue::null_int32()), + mismatch => Err(conversion_error!( + "expected an i32 number in column '{column_name}', found {mismatch}" + )), + }, + ColumnType::Int64 => match json_value { + serde_json::Value::Number(n) => parse_number_as_i64(&n).map(QuaintValue::int64), + serde_json::Value::String(s) => s.parse::().map(QuaintValue::int64).map_err(|e| { + conversion_error!("string-encoded number must be an i64 in column '{column_name}', got {s}: {e}") + }), + serde_json::Value::Null => Ok(QuaintValue::null_int64()), + mismatch => Err(conversion_error!( + "expected a string or number in column '{column_name}', found {mismatch}" + )), + }, + ColumnType::Float => match json_value { + // n.as_f32() is not implemented, so we need to downcast from f64 instead. + // We assume that the JSON value is a valid f32 number, but we check for overflows anyway. + serde_json::Value::Number(n) => n + .as_f64() + .ok_or(conversion_error!( + "number must be a float in column '{column_name}', got {n}" + )) + .and_then(f64_to_f32) + .map(QuaintValue::float), + serde_json::Value::Null => Ok(QuaintValue::null_float()), + mismatch => Err(conversion_error!( + "expected an f32 number in column '{column_name}', found {mismatch}" + )), + }, + ColumnType::Double => match json_value { + serde_json::Value::Number(n) => n.as_f64().map(QuaintValue::double).ok_or(conversion_error!( + "number must be a f64 in column '{column_name}', got {n}" + )), + serde_json::Value::Null => Ok(QuaintValue::null_double()), + mismatch => Err(conversion_error!( + "expected an f64 number in column '{column_name}', found {mismatch}" + )), + }, + ColumnType::Numeric => match json_value { + serde_json::Value::String(s) => BigDecimal::from_str(&s).map(QuaintValue::numeric).map_err(|e| { + conversion_error!("invalid numeric value when parsing {s} in column '{column_name}': {e}") + }), + serde_json::Value::Number(n) => n + .as_f64() + .and_then(BigDecimal::from_f64) + .ok_or(conversion_error!( + "number must be an f64 in column '{column_name}', got {n}" + )) + .map(QuaintValue::numeric), + serde_json::Value::Null => Ok(QuaintValue::null_numeric()), + mismatch => Err(conversion_error!( + "expected a string-encoded number in column '{column_name}', found {mismatch}", + )), + }, + ColumnType::Boolean => match json_value { + serde_json::Value::Bool(b) => Ok(QuaintValue::boolean(b)), + serde_json::Value::Null => Ok(QuaintValue::null_boolean()), + serde_json::Value::Number(n) => match n.as_i64() { + Some(0) => Ok(QuaintValue::boolean(false)), + Some(1) => Ok(QuaintValue::boolean(true)), + _ => Err(conversion_error!( + "expected number-encoded boolean to be 0 or 1 in column '{column_name}', got {n}" + )), + }, + serde_json::Value::String(s) => match s.as_str() { + "false" | "FALSE" | "0" => Ok(QuaintValue::boolean(false)), + "true" | "TRUE" | "1" => Ok(QuaintValue::boolean(true)), + _ => Err(conversion_error!( + "expected string-encoded boolean in column 
'{column_name}', got {s}" + )), + }, + mismatch => Err(conversion_error!( + "expected a boolean in column '{column_name}', found {mismatch}" + )), + }, + ColumnType::Character => match json_value { + serde_json::Value::String(s) => match s.chars().next() { + Some(c) => Ok(QuaintValue::character(c)), + None => Ok(QuaintValue::null_character()), + }, + serde_json::Value::Null => Ok(QuaintValue::null_character()), + mismatch => Err(conversion_error!( + "expected a string in column '{column_name}', found {mismatch}" + )), + }, + ColumnType::Text => match json_value { + serde_json::Value::String(s) => Ok(QuaintValue::text(s)), + serde_json::Value::Null => Ok(QuaintValue::null_text()), + mismatch => Err(conversion_error!( + "expected a string in column '{column_name}', found {mismatch}" + )), + }, + ColumnType::Date => match json_value { + serde_json::Value::String(s) => NaiveDate::parse_from_str(&s, "%Y-%m-%d") + .map(QuaintValue::date) + .map_err(|_| conversion_error!("expected a date string in column '{column_name}', got {s}")), + serde_json::Value::Null => Ok(QuaintValue::null_date()), + mismatch => Err(conversion_error!( + "expected a string in column '{column_name}', found {mismatch}" + )), + }, + ColumnType::Time => match json_value { + serde_json::Value::String(s) => NaiveTime::parse_from_str(&s, "%H:%M:%S%.f") + .map(QuaintValue::time) + .map_err(|_| conversion_error!("expected a time string in column '{column_name}', got {s}")), + serde_json::Value::Null => Ok(QuaintValue::null_time()), + mismatch => Err(conversion_error!( + "expected a string in column '{column_name}', found {mismatch}" + )), + }, + ColumnType::DateTime => match json_value { + // TODO: change parsing order to prefer RFC3339 + serde_json::Value::String(s) => chrono::NaiveDateTime::parse_from_str(&s, "%Y-%m-%d %H:%M:%S%.f") + .map(|dt| DateTime::from_utc(dt, Utc)) + .or_else(|_| DateTime::parse_from_rfc3339(&s).map(DateTime::::from)) + .map(QuaintValue::datetime) + .map_err(|_| conversion_error!("expected a datetime string in column '{column_name}', found {s}")), + serde_json::Value::Null => Ok(QuaintValue::null_datetime()), + mismatch => Err(conversion_error!( + "expected a string in column '{column_name}', found {mismatch}" + )), + }, + ColumnType::Json => { + match json_value { + // DbNull + serde_json::Value::Null => Ok(QuaintValue::null_json()), + // JsonNull + serde_json::Value::String(s) if s == "$__prisma_null" => Ok(QuaintValue::json(serde_json::Value::Null)), + json => Ok(QuaintValue::json(json)), + } + } + ColumnType::Enum => match json_value { + serde_json::Value::String(s) => Ok(QuaintValue::enum_variant(s)), + serde_json::Value::Null => Ok(QuaintValue::null_enum()), + mismatch => Err(conversion_error!( + "expected a string in column '{column_name}', found {mismatch}" + )), + }, + ColumnType::Bytes => match json_value { + serde_json::Value::String(s) => Ok(QuaintValue::bytes(s.into_bytes())), + serde_json::Value::Array(array) => array + .iter() + .map(|value| value.as_i64().and_then(|maybe_byte| maybe_byte.try_into().ok())) + .collect::>>() + .map(QuaintValue::bytes) + .ok_or(conversion_error!( + "elements of the array in column '{column_name}' must be u8" + )), + serde_json::Value::Null => Ok(QuaintValue::null_bytes()), + mismatch => Err(conversion_error!( + "expected a string or an array in column '{column_name}', found {mismatch}", + )), + }, + ColumnType::Uuid => match json_value { + serde_json::Value::String(s) => uuid::Uuid::parse_str(&s) + .map(QuaintValue::uuid) + .map_err(|_| 
conversion_error!("Expected a UUID string in column '{column_name}'")), + serde_json::Value::Null => Ok(QuaintValue::null_bytes()), + mismatch => Err(conversion_error!( + "Expected a UUID string in column '{column_name}', found {mismatch}" + )), + }, + ColumnType::UnknownNumber => match json_value { + serde_json::Value::Number(n) => n + .as_i64() + .map(QuaintValue::int64) + .or(n.as_f64().map(QuaintValue::double)) + .ok_or(conversion_error!( + "number must be an i64 or f64 in column '{column_name}', got {n}" + )), + mismatch => Err(conversion_error!( + "expected a either an i64 or a f64 in column '{column_name}', found {mismatch}", + )), + }, + + ColumnType::Int32Array => js_array_to_quaint(ColumnType::Int32, json_value, column_name), + ColumnType::Int64Array => js_array_to_quaint(ColumnType::Int64, json_value, column_name), + ColumnType::FloatArray => js_array_to_quaint(ColumnType::Float, json_value, column_name), + ColumnType::DoubleArray => js_array_to_quaint(ColumnType::Double, json_value, column_name), + ColumnType::NumericArray => js_array_to_quaint(ColumnType::Numeric, json_value, column_name), + ColumnType::BooleanArray => js_array_to_quaint(ColumnType::Boolean, json_value, column_name), + ColumnType::CharacterArray => js_array_to_quaint(ColumnType::Character, json_value, column_name), + ColumnType::TextArray => js_array_to_quaint(ColumnType::Text, json_value, column_name), + ColumnType::DateArray => js_array_to_quaint(ColumnType::Date, json_value, column_name), + ColumnType::TimeArray => js_array_to_quaint(ColumnType::Time, json_value, column_name), + ColumnType::DateTimeArray => js_array_to_quaint(ColumnType::DateTime, json_value, column_name), + ColumnType::JsonArray => js_array_to_quaint(ColumnType::Json, json_value, column_name), + ColumnType::EnumArray => js_array_to_quaint(ColumnType::Enum, json_value, column_name), + ColumnType::BytesArray => js_array_to_quaint(ColumnType::Bytes, json_value, column_name), + ColumnType::UuidArray => js_array_to_quaint(ColumnType::Uuid, json_value, column_name), + + unimplemented => { + todo!("support column type {:?} in column {}", unimplemented, column_name) + } + } +} + +fn js_array_to_quaint( + base_type: ColumnType, + json_value: serde_json::Value, + column_name: &str, +) -> quaint::Result> { + match json_value { + serde_json::Value::Array(array) => Ok(QuaintValue::array( + array + .into_iter() + .enumerate() + .map(|(index, elem)| js_value_to_quaint(elem, base_type, &format!("{column_name}[{index}]"))) + .collect::>>()?, + )), + serde_json::Value::Null => Ok(QuaintValue::null_array()), + mismatch => Err(conversion_error!( + "expected an array in column '{column_name}', found {mismatch}", + )), + } +} + +/// Coerce a `f64` to a `f32`, asserting that the conversion is lossless. +/// Note that, when overflow occurs during conversion, the result is `infinity`. 
+fn f64_to_f32(x: f64) -> quaint::Result { + let y = x as f32; + + if x.is_finite() == y.is_finite() { + Ok(y) + } else { + Err(conversion_error!("f32 overflow during conversion")) + } +} + +#[cfg(test)] +mod proxy_test { + use num_bigint::BigInt; + use serde_json::json; + + use super::*; + + #[track_caller] + fn test_null<'a, T: Into>>(quaint_none: T, column_type: ColumnType) { + let json_value = serde_json::Value::Null; + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + assert_eq!(quaint_value, quaint_none.into()); + } + + #[test] + fn js_value_int32_to_quaint() { + let column_type = ColumnType::Int32; + + // null + test_null(QuaintValue::null_int32(), column_type); + + // 0 + let n: i32 = 0; + let json_value = serde_json::Value::Number(serde_json::Number::from(n)); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + assert_eq!(quaint_value, QuaintValue::int32(n)); + + // max + let n: i32 = i32::MAX; + let json_value = serde_json::Value::Number(serde_json::Number::from(n)); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + assert_eq!(quaint_value, QuaintValue::int32(n)); + + // min + let n: i32 = i32::MIN; + let json_value = serde_json::Value::Number(serde_json::Number::from(n)); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + assert_eq!(quaint_value, QuaintValue::int32(n)); + + // string-encoded + let n = i32::MAX; + let json_value = serde_json::Value::String(n.to_string()); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + assert_eq!(quaint_value, QuaintValue::int32(n)); + } + + #[test] + fn js_value_int64_to_quaint() { + let column_type = ColumnType::Int64; + + // null + test_null(QuaintValue::null_int64(), column_type); + + // 0 + let n: i64 = 0; + let json_value = serde_json::Value::String(n.to_string()); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + assert_eq!(quaint_value, QuaintValue::int64(n)); + + // max + let n: i64 = i64::MAX; + let json_value = serde_json::Value::String(n.to_string()); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + assert_eq!(quaint_value, QuaintValue::int64(n)); + + // min + let n: i64 = i64::MIN; + let json_value = serde_json::Value::String(n.to_string()); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + assert_eq!(quaint_value, QuaintValue::int64(n)); + + // number-encoded + let n: i64 = (1 << 53) - 1; // max JS safe integer + let json_value = serde_json::Value::Number(serde_json::Number::from(n)); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + assert_eq!(quaint_value, QuaintValue::int64(n)); + } + + #[test] + fn js_value_float_to_quaint() { + let column_type = ColumnType::Float; + + // null + test_null(QuaintValue::null_float(), column_type); + + // 0 + let n: f32 = 0.0; + let json_value = serde_json::Value::Number(serde_json::Number::from_f64(n.into()).unwrap()); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + assert_eq!(quaint_value, QuaintValue::float(n)); + + // max + let n: f32 = f32::MAX; + let json_value = serde_json::Value::Number(serde_json::Number::from_f64(n.into()).unwrap()); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + assert_eq!(quaint_value, QuaintValue::float(n)); + + 
// min + let n: f32 = f32::MIN; + let json_value = serde_json::Value::Number(serde_json::Number::from_f64(n.into()).unwrap()); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + assert_eq!(quaint_value, QuaintValue::float(n)); + } + + #[test] + fn js_value_double_to_quaint() { + let column_type = ColumnType::Double; + + // null + test_null(QuaintValue::null_double(), column_type); + + // 0 + let n: f64 = 0.0; + let json_value = serde_json::Value::Number(serde_json::Number::from_f64(n).unwrap()); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + assert_eq!(quaint_value, QuaintValue::double(n)); + + // max + let n: f64 = f64::MAX; + let json_value = serde_json::Value::Number(serde_json::Number::from_f64(n).unwrap()); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + assert_eq!(quaint_value, QuaintValue::double(n)); + + // min + let n: f64 = f64::MIN; + let json_value = serde_json::Value::Number(serde_json::Number::from_f64(n).unwrap()); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + assert_eq!(quaint_value, QuaintValue::double(n)); + } + + #[test] + fn js_value_numeric_to_quaint() { + let column_type = ColumnType::Numeric; + + // null + test_null(QuaintValue::null_numeric(), column_type); + + let n_as_string = "1234.99"; + let decimal = BigDecimal::new(BigInt::parse_bytes(b"123499", 10).unwrap(), 2); + + let json_value = serde_json::Value::String(n_as_string.into()); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + assert_eq!(quaint_value, QuaintValue::numeric(decimal)); + + let n_as_string = "1234.999999"; + let decimal = BigDecimal::new(BigInt::parse_bytes(b"1234999999", 10).unwrap(), 6); + + let json_value = serde_json::Value::String(n_as_string.into()); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + assert_eq!(quaint_value, QuaintValue::numeric(decimal)); + } + + #[test] + fn js_value_boolean_to_quaint() { + let column_type = ColumnType::Boolean; + + // null + test_null(QuaintValue::null_boolean(), column_type); + + // true + for truthy_value in [json!(true), json!(1), json!("true"), json!("TRUE"), json!("1")] { + let quaint_value = js_value_to_quaint(truthy_value, column_type, "column_name").unwrap(); + assert_eq!(quaint_value, QuaintValue::boolean(true)); + } + + // false + for falsy_value in [json!(false), json!(0), json!("false"), json!("FALSE"), json!("0")] { + let quaint_value = js_value_to_quaint(falsy_value, column_type, "column_name").unwrap(); + assert_eq!(quaint_value, QuaintValue::boolean(false)); + } + } + + #[test] + fn js_value_char_to_quaint() { + let column_type = ColumnType::Character; + + // null + test_null(QuaintValue::null_character(), column_type); + + let c = 'c'; + let json_value = serde_json::Value::String(c.to_string()); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + assert_eq!(quaint_value, QuaintValue::character(c)); + } + + #[test] + fn js_value_text_to_quaint() { + let column_type = ColumnType::Text; + + // null + test_null(QuaintValue::null_text(), column_type); + + let s = "some text"; + let json_value = serde_json::Value::String(s.to_string()); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + assert_eq!(quaint_value, QuaintValue::text(s)); + } + + #[test] + fn js_value_date_to_quaint() { + let column_type = 
ColumnType::Date; + + // null + test_null(QuaintValue::null_date(), column_type); + + let s = "2023-01-01"; + let json_value = serde_json::Value::String(s.to_string()); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + + let date = NaiveDate::from_ymd_opt(2023, 1, 1).unwrap(); + assert_eq!(quaint_value, QuaintValue::date(date)); + } + + #[test] + fn js_value_time_to_quaint() { + let column_type = ColumnType::Time; + + // null + test_null(QuaintValue::null_time(), column_type); + + let s = "23:59:59"; + let json_value = serde_json::Value::String(s.to_string()); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + let time: NaiveTime = NaiveTime::from_hms_opt(23, 59, 59).unwrap(); + assert_eq!(quaint_value, QuaintValue::time(time)); + + let s = "13:02:20.321"; + let json_value = serde_json::Value::String(s.to_string()); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + let time: NaiveTime = NaiveTime::from_hms_milli_opt(13, 02, 20, 321).unwrap(); + assert_eq!(quaint_value, QuaintValue::time(time)); + } + + #[test] + fn js_value_datetime_to_quaint() { + let column_type = ColumnType::DateTime; + + // null + test_null(QuaintValue::null_datetime(), column_type); + + let s = "2023-01-01 23:59:59.415"; + let json_value = serde_json::Value::String(s.to_string()); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + + let datetime = NaiveDate::from_ymd_opt(2023, 1, 1) + .unwrap() + .and_hms_milli_opt(23, 59, 59, 415) + .unwrap(); + let datetime = DateTime::from_utc(datetime, Utc); + assert_eq!(quaint_value, QuaintValue::datetime(datetime)); + + let s = "2023-01-01 23:59:59.123456"; + let json_value = serde_json::Value::String(s.to_string()); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + + let datetime = NaiveDate::from_ymd_opt(2023, 1, 1) + .unwrap() + .and_hms_micro_opt(23, 59, 59, 123_456) + .unwrap(); + let datetime = DateTime::from_utc(datetime, Utc); + assert_eq!(quaint_value, QuaintValue::datetime(datetime)); + + let s = "2023-01-01 23:59:59"; + let json_value = serde_json::Value::String(s.to_string()); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + + let datetime = NaiveDate::from_ymd_opt(2023, 1, 1) + .unwrap() + .and_hms_milli_opt(23, 59, 59, 0) + .unwrap(); + let datetime = DateTime::from_utc(datetime, Utc); + assert_eq!(quaint_value, QuaintValue::datetime(datetime)); + } + + #[test] + fn js_value_json_to_quaint() { + let column_type = ColumnType::Json; + + // null + test_null(QuaintValue::null_json(), column_type); + + let json = json!({ + "key": "value", + "nested": [ + true, + false, + 1, + null + ] + }); + let json_value = json.clone(); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + assert_eq!(quaint_value, QuaintValue::json(json.clone())); + } + + #[test] + fn js_value_enum_to_quaint() { + let column_type = ColumnType::Enum; + + // null + test_null(QuaintValue::null_enum(), column_type); + + let s = "some enum variant"; + let json_value = serde_json::Value::String(s.to_string()); + + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + assert_eq!(quaint_value, QuaintValue::enum_variant(s)); + } + + #[test] + fn js_int32_array_to_quaint() { + let column_type = ColumnType::Int32Array; + test_null(QuaintValue::null_array(), column_type); + + let json_value = 
json!([1, 2, 3]); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + + assert_eq!( + quaint_value, + QuaintValue::array(vec![ + QuaintValue::int32(1), + QuaintValue::int32(2), + QuaintValue::int32(3) + ]) + ); + + let json_value = json!([1, 2, {}]); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name"); + + assert_eq!( + quaint_value.err().unwrap().to_string(), + "Conversion failed: expected an i32 number in column 'column_name[2]', found {}" + ); + } + + #[test] + fn js_text_array_to_quaint() { + let column_type = ColumnType::TextArray; + test_null(QuaintValue::null_array(), column_type); + + let json_value = json!(["hi", "there"]); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + + assert_eq!( + quaint_value, + QuaintValue::array(vec![QuaintValue::text("hi"), QuaintValue::text("there"),]) + ); + + let json_value = json!([10]); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name"); + + assert_eq!( + quaint_value.err().unwrap().to_string(), + "Conversion failed: expected a string in column 'column_name[0]', found 10" + ); + } +} diff --git a/query-engine/driver-adapters/src/conversion/mod.rs b/query-engine/driver-adapters/src/conversion/mod.rs index 5173b2349bab..3ef41fed903e 100644 --- a/query-engine/driver-adapters/src/conversion/mod.rs +++ b/query-engine/driver-adapters/src/conversion/mod.rs @@ -1,7 +1,9 @@ pub(crate) mod js_arg; +pub(crate) mod js_to_quaint; pub(crate) mod mysql; pub(crate) mod postgres; pub(crate) mod sqlite; pub use js_arg::JSArg; +pub use js_to_quaint::*; diff --git a/query-engine/driver-adapters/src/napi/proxy.rs b/query-engine/driver-adapters/src/napi/proxy.rs index 0677d279a1f4..65a86109e338 100644 --- a/query-engine/driver-adapters/src/napi/proxy.rs +++ b/query-engine/driver-adapters/src/napi/proxy.rs @@ -1,6 +1,3 @@ -use std::borrow::Cow; -use std::str::FromStr; - pub use crate::types::{ColumnType, JSResultSet, Query, TransactionOptions}; use super::async_js_function::AsyncJsFunction; @@ -8,16 +5,6 @@ use super::transaction::JsTransaction; use metrics::increment_gauge; use napi::threadsafe_function::{ErrorStrategy, ThreadsafeFunction}; use napi::{JsObject, JsString}; -use quaint::connector::ResultSet as QuaintResultSet; -use quaint::{ - error::{Error as QuaintError, ErrorKind}, - Value as QuaintValue, -}; - -// TODO(jkomyno): import these 3rd-party crates from the `quaint-core` crate. -use bigdecimal::{BigDecimal, FromPrimitive}; -use chrono::{DateTime, Utc}; -use chrono::{NaiveDate, NaiveTime}; /// Proxy is a struct wrapping a javascript object that exhibits basic primitives for /// querying and executing SQL (i.e. a client connector). The Proxy uses NAPI ThreadSafeFunction to @@ -56,308 +43,6 @@ pub(crate) struct TransactionProxy { dispose: ThreadsafeFunction<(), ErrorStrategy::Fatal>, } -fn conversion_error(args: &std::fmt::Arguments) -> QuaintError { - let msg = match args.as_str() { - Some(s) => Cow::Borrowed(s), - None => Cow::Owned(args.to_string()), - }; - QuaintError::builder(ErrorKind::ConversionError(msg)).build() -} - -macro_rules! conversion_error { - ($($arg:tt)*) => { - conversion_error(&format_args!($($arg)*)) - }; -} - -/// Handle data-type conversion from a JSON value to a Quaint value. -/// This is used for most data types, except those that require connector-specific handling, e.g., `ColumnType::Boolean`. 
-fn js_value_to_quaint( - json_value: serde_json::Value, - column_type: ColumnType, - column_name: &str, -) -> quaint::Result> { - let parse_number_as_i64 = |n: &serde_json::Number| { - n.as_i64().ok_or(conversion_error!( - "number must be an integer in column '{column_name}', got '{n}'" - )) - }; - - // Note for the future: it may be worth revisiting how much bloat so many panics with different static - // strings add to the compiled artefact, and in case we should come up with a restricted set of panic - // messages, or even find a way of removing them altogether. - match column_type { - ColumnType::Int32 => match json_value { - serde_json::Value::Number(n) => { - // n.as_i32() is not implemented, so we need to downcast from i64 instead - parse_number_as_i64(&n) - .and_then(|n| -> quaint::Result { - n.try_into() - .map_err(|e| conversion_error!("cannot convert {n} to i32 in column '{column_name}': {e}")) - }) - .map(QuaintValue::int32) - } - serde_json::Value::String(s) => s.parse::().map(QuaintValue::int32).map_err(|e| { - conversion_error!("string-encoded number must be an i32 in column '{column_name}', got {s}: {e}") - }), - serde_json::Value::Null => Ok(QuaintValue::null_int32()), - mismatch => Err(conversion_error!( - "expected an i32 number in column '{column_name}', found {mismatch}" - )), - }, - ColumnType::Int64 => match json_value { - serde_json::Value::Number(n) => parse_number_as_i64(&n).map(QuaintValue::int64), - serde_json::Value::String(s) => s.parse::().map(QuaintValue::int64).map_err(|e| { - conversion_error!("string-encoded number must be an i64 in column '{column_name}', got {s}: {e}") - }), - serde_json::Value::Null => Ok(QuaintValue::null_int64()), - mismatch => Err(conversion_error!( - "expected a string or number in column '{column_name}', found {mismatch}" - )), - }, - ColumnType::Float => match json_value { - // n.as_f32() is not implemented, so we need to downcast from f64 instead. - // We assume that the JSON value is a valid f32 number, but we check for overflows anyway. 
- serde_json::Value::Number(n) => n - .as_f64() - .ok_or(conversion_error!( - "number must be a float in column '{column_name}', got {n}" - )) - .and_then(f64_to_f32) - .map(QuaintValue::float), - serde_json::Value::Null => Ok(QuaintValue::null_float()), - mismatch => Err(conversion_error!( - "expected an f32 number in column '{column_name}', found {mismatch}" - )), - }, - ColumnType::Double => match json_value { - serde_json::Value::Number(n) => n.as_f64().map(QuaintValue::double).ok_or(conversion_error!( - "number must be a f64 in column '{column_name}', got {n}" - )), - serde_json::Value::Null => Ok(QuaintValue::null_double()), - mismatch => Err(conversion_error!( - "expected an f64 number in column '{column_name}', found {mismatch}" - )), - }, - ColumnType::Numeric => match json_value { - serde_json::Value::String(s) => BigDecimal::from_str(&s).map(QuaintValue::numeric).map_err(|e| { - conversion_error!("invalid numeric value when parsing {s} in column '{column_name}': {e}") - }), - serde_json::Value::Number(n) => n - .as_f64() - .and_then(BigDecimal::from_f64) - .ok_or(conversion_error!( - "number must be an f64 in column '{column_name}', got {n}" - )) - .map(QuaintValue::numeric), - serde_json::Value::Null => Ok(QuaintValue::null_numeric()), - mismatch => Err(conversion_error!( - "expected a string-encoded number in column '{column_name}', found {mismatch}", - )), - }, - ColumnType::Boolean => match json_value { - serde_json::Value::Bool(b) => Ok(QuaintValue::boolean(b)), - serde_json::Value::Null => Ok(QuaintValue::null_boolean()), - serde_json::Value::Number(n) => match n.as_i64() { - Some(0) => Ok(QuaintValue::boolean(false)), - Some(1) => Ok(QuaintValue::boolean(true)), - _ => Err(conversion_error!( - "expected number-encoded boolean to be 0 or 1 in column '{column_name}', got {n}" - )), - }, - serde_json::Value::String(s) => match s.as_str() { - "false" | "FALSE" | "0" => Ok(QuaintValue::boolean(false)), - "true" | "TRUE" | "1" => Ok(QuaintValue::boolean(true)), - _ => Err(conversion_error!( - "expected string-encoded boolean in column '{column_name}', got {s}" - )), - }, - mismatch => Err(conversion_error!( - "expected a boolean in column '{column_name}', found {mismatch}" - )), - }, - ColumnType::Character => match json_value { - serde_json::Value::String(s) => match s.chars().next() { - Some(c) => Ok(QuaintValue::character(c)), - None => Ok(QuaintValue::null_character()), - }, - serde_json::Value::Null => Ok(QuaintValue::null_character()), - mismatch => Err(conversion_error!( - "expected a string in column '{column_name}', found {mismatch}" - )), - }, - ColumnType::Text => match json_value { - serde_json::Value::String(s) => Ok(QuaintValue::text(s)), - serde_json::Value::Null => Ok(QuaintValue::null_text()), - mismatch => Err(conversion_error!( - "expected a string in column '{column_name}', found {mismatch}" - )), - }, - ColumnType::Date => match json_value { - serde_json::Value::String(s) => NaiveDate::parse_from_str(&s, "%Y-%m-%d") - .map(QuaintValue::date) - .map_err(|_| conversion_error!("expected a date string in column '{column_name}', got {s}")), - serde_json::Value::Null => Ok(QuaintValue::null_date()), - mismatch => Err(conversion_error!( - "expected a string in column '{column_name}', found {mismatch}" - )), - }, - ColumnType::Time => match json_value { - serde_json::Value::String(s) => NaiveTime::parse_from_str(&s, "%H:%M:%S%.f") - .map(QuaintValue::time) - .map_err(|_| conversion_error!("expected a time string in column '{column_name}', got {s}")), - 
serde_json::Value::Null => Ok(QuaintValue::null_time()), - mismatch => Err(conversion_error!( - "expected a string in column '{column_name}', found {mismatch}" - )), - }, - ColumnType::DateTime => match json_value { - // TODO: change parsing order to prefer RFC3339 - serde_json::Value::String(s) => chrono::NaiveDateTime::parse_from_str(&s, "%Y-%m-%d %H:%M:%S%.f") - .map(|dt| DateTime::from_utc(dt, Utc)) - .or_else(|_| DateTime::parse_from_rfc3339(&s).map(DateTime::::from)) - .map(QuaintValue::datetime) - .map_err(|_| conversion_error!("expected a datetime string in column '{column_name}', found {s}")), - serde_json::Value::Null => Ok(QuaintValue::null_datetime()), - mismatch => Err(conversion_error!( - "expected a string in column '{column_name}', found {mismatch}" - )), - }, - ColumnType::Json => { - match json_value { - // DbNull - serde_json::Value::Null => Ok(QuaintValue::null_json()), - // JsonNull - serde_json::Value::String(s) if s == "$__prisma_null" => Ok(QuaintValue::json(serde_json::Value::Null)), - json => Ok(QuaintValue::json(json)), - } - } - ColumnType::Enum => match json_value { - serde_json::Value::String(s) => Ok(QuaintValue::enum_variant(s)), - serde_json::Value::Null => Ok(QuaintValue::null_enum()), - mismatch => Err(conversion_error!( - "expected a string in column '{column_name}', found {mismatch}" - )), - }, - ColumnType::Bytes => match json_value { - serde_json::Value::String(s) => Ok(QuaintValue::bytes(s.into_bytes())), - serde_json::Value::Array(array) => array - .iter() - .map(|value| value.as_i64().and_then(|maybe_byte| maybe_byte.try_into().ok())) - .collect::>>() - .map(QuaintValue::bytes) - .ok_or(conversion_error!( - "elements of the array in column '{column_name}' must be u8" - )), - serde_json::Value::Null => Ok(QuaintValue::null_bytes()), - mismatch => Err(conversion_error!( - "expected a string or an array in column '{column_name}', found {mismatch}", - )), - }, - ColumnType::Uuid => match json_value { - serde_json::Value::String(s) => uuid::Uuid::parse_str(&s) - .map(QuaintValue::uuid) - .map_err(|_| conversion_error!("Expected a UUID string in column '{column_name}'")), - serde_json::Value::Null => Ok(QuaintValue::null_bytes()), - mismatch => Err(conversion_error!( - "Expected a UUID string in column '{column_name}', found {mismatch}" - )), - }, - ColumnType::UnknownNumber => match json_value { - serde_json::Value::Number(n) => n - .as_i64() - .map(QuaintValue::int64) - .or(n.as_f64().map(QuaintValue::double)) - .ok_or(conversion_error!( - "number must be an i64 or f64 in column '{column_name}', got {n}" - )), - mismatch => Err(conversion_error!( - "expected a either an i64 or a f64 in column '{column_name}', found {mismatch}", - )), - }, - - ColumnType::Int32Array => js_array_to_quaint(ColumnType::Int32, json_value, column_name), - ColumnType::Int64Array => js_array_to_quaint(ColumnType::Int64, json_value, column_name), - ColumnType::FloatArray => js_array_to_quaint(ColumnType::Float, json_value, column_name), - ColumnType::DoubleArray => js_array_to_quaint(ColumnType::Double, json_value, column_name), - ColumnType::NumericArray => js_array_to_quaint(ColumnType::Numeric, json_value, column_name), - ColumnType::BooleanArray => js_array_to_quaint(ColumnType::Boolean, json_value, column_name), - ColumnType::CharacterArray => js_array_to_quaint(ColumnType::Character, json_value, column_name), - ColumnType::TextArray => js_array_to_quaint(ColumnType::Text, json_value, column_name), - ColumnType::DateArray => js_array_to_quaint(ColumnType::Date, json_value, 
column_name), - ColumnType::TimeArray => js_array_to_quaint(ColumnType::Time, json_value, column_name), - ColumnType::DateTimeArray => js_array_to_quaint(ColumnType::DateTime, json_value, column_name), - ColumnType::JsonArray => js_array_to_quaint(ColumnType::Json, json_value, column_name), - ColumnType::EnumArray => js_array_to_quaint(ColumnType::Enum, json_value, column_name), - ColumnType::BytesArray => js_array_to_quaint(ColumnType::Bytes, json_value, column_name), - ColumnType::UuidArray => js_array_to_quaint(ColumnType::Uuid, json_value, column_name), - - unimplemented => { - todo!("support column type {:?} in column {}", unimplemented, column_name) - } - } -} - -fn js_array_to_quaint( - base_type: ColumnType, - json_value: serde_json::Value, - column_name: &str, -) -> quaint::Result> { - match json_value { - serde_json::Value::Array(array) => Ok(QuaintValue::array( - array - .into_iter() - .enumerate() - .map(|(index, elem)| js_value_to_quaint(elem, base_type, &format!("{column_name}[{index}]"))) - .collect::>>()?, - )), - serde_json::Value::Null => Ok(QuaintValue::null_array()), - mismatch => Err(conversion_error!( - "expected an array in column '{column_name}', found {mismatch}", - )), - } -} - -impl TryFrom for QuaintResultSet { - type Error = quaint::error::Error; - - fn try_from(js_result_set: JSResultSet) -> Result { - let JSResultSet { - rows, - column_names, - column_types, - last_insert_id, - } = js_result_set; - - let mut quaint_rows = Vec::with_capacity(rows.len()); - - for row in rows { - let mut quaint_row = Vec::with_capacity(column_types.len()); - - for (i, row) in row.into_iter().enumerate() { - let column_type = column_types[i]; - let column_name = column_names[i].as_str(); - - quaint_row.push(js_value_to_quaint(row, column_type, column_name)?); - } - - quaint_rows.push(quaint_row); - } - - let last_insert_id = last_insert_id.and_then(|id| id.parse::().ok()); - let mut quaint_result_set = QuaintResultSet::new(column_names, quaint_rows); - - // Not a fan of this (extracting the `Some` value from an `Option` and pass it to a method that creates a new `Some` value), - // but that's Quaint's ResultSet API and that's how the MySQL connector does it. - // Sqlite, on the other hand, uses a `last_insert_id.unwrap_or(0)` approach. - if let Some(last_insert_id) = last_insert_id { - quaint_result_set.set_last_insert_id(last_insert_id); - } - - Ok(quaint_result_set) - } -} - impl CommonProxy { pub fn new(object: &JsObject) -> napi::Result { let flavour: JsString = object.get_named_property("flavour")?; @@ -432,372 +117,3 @@ impl Drop for TransactionProxy { .call((), napi::threadsafe_function::ThreadsafeFunctionCallMode::NonBlocking); } } - -/// Coerce a `f64` to a `f32`, asserting that the conversion is lossless. -/// Note that, when overflow occurs during conversion, the result is `infinity`. 
-fn f64_to_f32(x: f64) -> quaint::Result { - let y = x as f32; - - if x.is_finite() == y.is_finite() { - Ok(y) - } else { - Err(conversion_error!("f32 overflow during conversion")) - } -} -#[cfg(test)] -mod proxy_test { - use num_bigint::BigInt; - use serde_json::json; - - use super::*; - - #[track_caller] - fn test_null<'a, T: Into>>(quaint_none: T, column_type: ColumnType) { - let json_value = serde_json::Value::Null; - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - assert_eq!(quaint_value, quaint_none.into()); - } - - #[test] - fn js_value_int32_to_quaint() { - let column_type = ColumnType::Int32; - - // null - test_null(QuaintValue::null_int32(), column_type); - - // 0 - let n: i32 = 0; - let json_value = serde_json::Value::Number(serde_json::Number::from(n)); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - assert_eq!(quaint_value, QuaintValue::int32(n)); - - // max - let n: i32 = i32::MAX; - let json_value = serde_json::Value::Number(serde_json::Number::from(n)); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - assert_eq!(quaint_value, QuaintValue::int32(n)); - - // min - let n: i32 = i32::MIN; - let json_value = serde_json::Value::Number(serde_json::Number::from(n)); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - assert_eq!(quaint_value, QuaintValue::int32(n)); - - // string-encoded - let n = i32::MAX; - let json_value = serde_json::Value::String(n.to_string()); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - assert_eq!(quaint_value, QuaintValue::int32(n)); - } - - #[test] - fn js_value_int64_to_quaint() { - let column_type = ColumnType::Int64; - - // null - test_null(QuaintValue::null_int64(), column_type); - - // 0 - let n: i64 = 0; - let json_value = serde_json::Value::String(n.to_string()); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - assert_eq!(quaint_value, QuaintValue::int64(n)); - - // max - let n: i64 = i64::MAX; - let json_value = serde_json::Value::String(n.to_string()); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - assert_eq!(quaint_value, QuaintValue::int64(n)); - - // min - let n: i64 = i64::MIN; - let json_value = serde_json::Value::String(n.to_string()); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - assert_eq!(quaint_value, QuaintValue::int64(n)); - - // number-encoded - let n: i64 = (1 << 53) - 1; // max JS safe integer - let json_value = serde_json::Value::Number(serde_json::Number::from(n)); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - assert_eq!(quaint_value, QuaintValue::int64(n)); - } - - #[test] - fn js_value_float_to_quaint() { - let column_type = ColumnType::Float; - - // null - test_null(QuaintValue::null_float(), column_type); - - // 0 - let n: f32 = 0.0; - let json_value = serde_json::Value::Number(serde_json::Number::from_f64(n.into()).unwrap()); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - assert_eq!(quaint_value, QuaintValue::float(n)); - - // max - let n: f32 = f32::MAX; - let json_value = serde_json::Value::Number(serde_json::Number::from_f64(n.into()).unwrap()); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - assert_eq!(quaint_value, QuaintValue::float(n)); - - // 
min - let n: f32 = f32::MIN; - let json_value = serde_json::Value::Number(serde_json::Number::from_f64(n.into()).unwrap()); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - assert_eq!(quaint_value, QuaintValue::float(n)); - } - - #[test] - fn js_value_double_to_quaint() { - let column_type = ColumnType::Double; - - // null - test_null(QuaintValue::null_double(), column_type); - - // 0 - let n: f64 = 0.0; - let json_value = serde_json::Value::Number(serde_json::Number::from_f64(n).unwrap()); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - assert_eq!(quaint_value, QuaintValue::double(n)); - - // max - let n: f64 = f64::MAX; - let json_value = serde_json::Value::Number(serde_json::Number::from_f64(n).unwrap()); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - assert_eq!(quaint_value, QuaintValue::double(n)); - - // min - let n: f64 = f64::MIN; - let json_value = serde_json::Value::Number(serde_json::Number::from_f64(n).unwrap()); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - assert_eq!(quaint_value, QuaintValue::double(n)); - } - - #[test] - fn js_value_numeric_to_quaint() { - let column_type = ColumnType::Numeric; - - // null - test_null(QuaintValue::null_numeric(), column_type); - - let n_as_string = "1234.99"; - let decimal = BigDecimal::new(BigInt::parse_bytes(b"123499", 10).unwrap(), 2); - - let json_value = serde_json::Value::String(n_as_string.into()); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - assert_eq!(quaint_value, QuaintValue::numeric(decimal)); - - let n_as_string = "1234.999999"; - let decimal = BigDecimal::new(BigInt::parse_bytes(b"1234999999", 10).unwrap(), 6); - - let json_value = serde_json::Value::String(n_as_string.into()); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - assert_eq!(quaint_value, QuaintValue::numeric(decimal)); - } - - #[test] - fn js_value_boolean_to_quaint() { - let column_type = ColumnType::Boolean; - - // null - test_null(QuaintValue::null_boolean(), column_type); - - // true - for truthy_value in [json!(true), json!(1), json!("true"), json!("TRUE"), json!("1")] { - let quaint_value = js_value_to_quaint(truthy_value, column_type, "column_name").unwrap(); - assert_eq!(quaint_value, QuaintValue::boolean(true)); - } - - // false - for falsy_value in [json!(false), json!(0), json!("false"), json!("FALSE"), json!("0")] { - let quaint_value = js_value_to_quaint(falsy_value, column_type, "column_name").unwrap(); - assert_eq!(quaint_value, QuaintValue::boolean(false)); - } - } - - #[test] - fn js_value_char_to_quaint() { - let column_type = ColumnType::Character; - - // null - test_null(QuaintValue::null_character(), column_type); - - let c = 'c'; - let json_value = serde_json::Value::String(c.to_string()); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - assert_eq!(quaint_value, QuaintValue::character(c)); - } - - #[test] - fn js_value_text_to_quaint() { - let column_type = ColumnType::Text; - - // null - test_null(QuaintValue::null_text(), column_type); - - let s = "some text"; - let json_value = serde_json::Value::String(s.to_string()); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - assert_eq!(quaint_value, QuaintValue::text(s)); - } - - #[test] - fn js_value_date_to_quaint() { - let column_type = 
ColumnType::Date; - - // null - test_null(QuaintValue::null_date(), column_type); - - let s = "2023-01-01"; - let json_value = serde_json::Value::String(s.to_string()); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - - let date = NaiveDate::from_ymd_opt(2023, 1, 1).unwrap(); - assert_eq!(quaint_value, QuaintValue::date(date)); - } - - #[test] - fn js_value_time_to_quaint() { - let column_type = ColumnType::Time; - - // null - test_null(QuaintValue::null_time(), column_type); - - let s = "23:59:59"; - let json_value = serde_json::Value::String(s.to_string()); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - let time: NaiveTime = NaiveTime::from_hms_opt(23, 59, 59).unwrap(); - assert_eq!(quaint_value, QuaintValue::time(time)); - - let s = "13:02:20.321"; - let json_value = serde_json::Value::String(s.to_string()); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - let time: NaiveTime = NaiveTime::from_hms_milli_opt(13, 02, 20, 321).unwrap(); - assert_eq!(quaint_value, QuaintValue::time(time)); - } - - #[test] - fn js_value_datetime_to_quaint() { - let column_type = ColumnType::DateTime; - - // null - test_null(QuaintValue::null_datetime(), column_type); - - let s = "2023-01-01 23:59:59.415"; - let json_value = serde_json::Value::String(s.to_string()); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - - let datetime = NaiveDate::from_ymd_opt(2023, 1, 1) - .unwrap() - .and_hms_milli_opt(23, 59, 59, 415) - .unwrap(); - let datetime = DateTime::from_utc(datetime, Utc); - assert_eq!(quaint_value, QuaintValue::datetime(datetime)); - - let s = "2023-01-01 23:59:59.123456"; - let json_value = serde_json::Value::String(s.to_string()); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - - let datetime = NaiveDate::from_ymd_opt(2023, 1, 1) - .unwrap() - .and_hms_micro_opt(23, 59, 59, 123_456) - .unwrap(); - let datetime = DateTime::from_utc(datetime, Utc); - assert_eq!(quaint_value, QuaintValue::datetime(datetime)); - - let s = "2023-01-01 23:59:59"; - let json_value = serde_json::Value::String(s.to_string()); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - - let datetime = NaiveDate::from_ymd_opt(2023, 1, 1) - .unwrap() - .and_hms_milli_opt(23, 59, 59, 0) - .unwrap(); - let datetime = DateTime::from_utc(datetime, Utc); - assert_eq!(quaint_value, QuaintValue::datetime(datetime)); - } - - #[test] - fn js_value_json_to_quaint() { - let column_type = ColumnType::Json; - - // null - test_null(QuaintValue::null_json(), column_type); - - let json = json!({ - "key": "value", - "nested": [ - true, - false, - 1, - null - ] - }); - let json_value = json.clone(); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - assert_eq!(quaint_value, QuaintValue::json(json.clone())); - } - - #[test] - fn js_value_enum_to_quaint() { - let column_type = ColumnType::Enum; - - // null - test_null(QuaintValue::null_enum(), column_type); - - let s = "some enum variant"; - let json_value = serde_json::Value::String(s.to_string()); - - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - assert_eq!(quaint_value, QuaintValue::enum_variant(s)); - } - - #[test] - fn js_int32_array_to_quaint() { - let column_type = ColumnType::Int32Array; - test_null(QuaintValue::null_array(), column_type); - - let json_value = 
json!([1, 2, 3]); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - - assert_eq!( - quaint_value, - QuaintValue::array(vec![ - QuaintValue::int32(1), - QuaintValue::int32(2), - QuaintValue::int32(3) - ]) - ); - - let json_value = json!([1, 2, {}]); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name"); - - assert_eq!( - quaint_value.err().unwrap().to_string(), - "Conversion failed: expected an i32 number in column 'column_name[2]', found {}" - ); - } - - #[test] - fn js_text_array_to_quaint() { - let column_type = ColumnType::TextArray; - test_null(QuaintValue::null_array(), column_type); - - let json_value = json!(["hi", "there"]); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); - - assert_eq!( - quaint_value, - QuaintValue::array(vec![QuaintValue::text("hi"), QuaintValue::text("there"),]) - ); - - let json_value = json!([10]); - let quaint_value = js_value_to_quaint(json_value, column_type, "column_name"); - - assert_eq!( - quaint_value.err().unwrap().to_string(), - "Conversion failed: expected a string in column 'column_name[0]', found 10" - ); - } -} diff --git a/query-engine/driver-adapters/src/types.rs b/query-engine/driver-adapters/src/types.rs index 9fa2c63b4ffc..ab12bf45f05a 100644 --- a/query-engine/driver-adapters/src/types.rs +++ b/query-engine/driver-adapters/src/types.rs @@ -42,7 +42,7 @@ impl JSResultSet { } #[cfg_attr(not(target_arch = "wasm32"), napi_derive::napi(object))] -#[cfg_attr(target_arch = "wasm32", derive(Serialize, Deserialize, Tsify))] +#[cfg_attr(target_arch = "wasm32", derive(Clone, Copy, Serialize, Deserialize, Tsify))] #[cfg_attr(target_arch = "wasm32", tsify(into_wasm_abi, from_wasm_abi))] #[derive(Debug)] pub enum ColumnType { From cc117e29a55958f223ad4e31fe4576973af97316 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Thu, 16 Nov 2023 23:27:00 +0100 Subject: [PATCH 041/134] feat(driver-adapters): allow feature-complete Wasm compilation of "driver-adapters" --- Cargo.lock | 14 + psl/psl-core/Cargo.toml | 3 + psl/psl-core/src/datamodel_connector.rs | 4 +- query-engine/driver-adapters/Cargo.toml | 4 +- query-engine/driver-adapters/src/types.rs | 6 +- .../src/wasm/async_js_function.rs | 63 +++- .../driver-adapters/src/wasm/conversion.rs | 1 + .../driver-adapters/src/wasm/error.rs | 11 + query-engine/driver-adapters/src/wasm/mod.rs | 7 + .../driver-adapters/src/wasm/proxy.rs | 125 +++++++ .../driver-adapters/src/wasm/queryable.rs | 324 ++++++++++++++++++ .../driver-adapters/src/wasm/send_future.rs | 24 ++ .../driver-adapters/src/wasm/transaction.rs | 132 +++++++ 13 files changed, 701 insertions(+), 17 deletions(-) create mode 100644 query-engine/driver-adapters/src/wasm/conversion.rs create mode 100644 query-engine/driver-adapters/src/wasm/error.rs create mode 100644 query-engine/driver-adapters/src/wasm/proxy.rs create mode 100644 query-engine/driver-adapters/src/wasm/queryable.rs create mode 100644 query-engine/driver-adapters/src/wasm/send_future.rs create mode 100644 query-engine/driver-adapters/src/wasm/transaction.rs diff --git a/Cargo.lock b/Cargo.lock index 627ecb5fb08c..b540e56e6cc5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1097,6 +1097,7 @@ dependencies = [ "async-trait", "bigdecimal", "chrono", + "ducktor", "expect-test", "futures", "js-sys", @@ -1105,6 +1106,7 @@ dependencies = [ "napi-derive", "num-bigint", "once_cell", + "pin-project", "psl", "quaint", "serde", @@ -1119,6 +1121,17 @@ dependencies = [ "wasm-bindgen-futures", ] 
+[[package]] +name = "ducktor" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c421abf6328bda65f53e6a76ee9837fd197b23bdfdbcebc4d7917dfaa1cf88ae" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.28", +] + [[package]] name = "either" version = "1.9.0" @@ -3523,6 +3536,7 @@ dependencies = [ "serde", "serde_json", "url", + "wasm-bindgen", ] [[package]] diff --git a/psl/psl-core/Cargo.toml b/psl/psl-core/Cargo.toml index 0d4bea39b84e..5cc959da9f33 100644 --- a/psl/psl-core/Cargo.toml +++ b/psl/psl-core/Cargo.toml @@ -22,3 +22,6 @@ indoc.workspace = true # For the connector API. lsp-types = "0.91.1" url = "2.2.1" + +[target.'cfg(target_arch = "wasm32")'.dependencies] +wasm-bindgen.workspace = true \ No newline at end of file diff --git a/psl/psl-core/src/datamodel_connector.rs b/psl/psl-core/src/datamodel_connector.rs index 72671e06688f..242f0df20b7c 100644 --- a/psl/psl-core/src/datamodel_connector.rs +++ b/psl/psl-core/src/datamodel_connector.rs @@ -361,8 +361,10 @@ pub trait Connector: Send + Sync { } } -#[derive(Copy, Clone, Debug, PartialEq)] +#[cfg_attr(target_arch = "wasm32", wasm_bindgen::prelude::wasm_bindgen)] +#[derive(Copy, Clone, Debug, PartialEq, Default, serde::Deserialize)] pub enum Flavour { + #[default] Cockroach, Mongo, Sqlserver, diff --git a/query-engine/driver-adapters/Cargo.toml b/query-engine/driver-adapters/Cargo.toml index cb86a96ee7f9..29697cdc95c4 100644 --- a/query-engine/driver-adapters/Cargo.toml +++ b/query-engine/driver-adapters/Cargo.toml @@ -31,9 +31,11 @@ napi-derive.workspace = true quaint.workspace = true [target.'cfg(target_arch = "wasm32")'.dependencies] -js-sys.workspace = true quaint = { path = "../../quaint" } +js-sys.workspace = true serde-wasm-bindgen.workspace = true wasm-bindgen.workspace = true wasm-bindgen-futures.workspace = true tsify.workspace = true +ducktor = "0.1.0" +pin-project = "1" diff --git a/query-engine/driver-adapters/src/types.rs b/query-engine/driver-adapters/src/types.rs index ab12bf45f05a..4f494c1bc092 100644 --- a/query-engine/driver-adapters/src/types.rs +++ b/query-engine/driver-adapters/src/types.rs @@ -26,7 +26,7 @@ use serde::{Deserialize, Serialize}; #[cfg_attr(not(target_arch = "wasm32"), napi_derive::napi(object))] #[cfg_attr(target_arch = "wasm32", derive(Serialize, Deserialize, Tsify))] #[cfg_attr(target_arch = "wasm32", tsify(into_wasm_abi, from_wasm_abi))] -#[derive(Debug)] +#[derive(Debug, Default)] pub struct JSResultSet { pub column_types: Vec, pub column_names: Vec, @@ -180,7 +180,7 @@ pub enum ColumnType { #[cfg_attr(not(target_arch = "wasm32"), napi_derive::napi(object))] #[cfg_attr(target_arch = "wasm32", derive(Serialize, Tsify))] #[cfg_attr(target_arch = "wasm32", tsify(into_wasm_abi))] -#[derive(Debug)] +#[derive(Debug, Default)] pub struct Query { pub sql: String, pub args: Vec, @@ -189,7 +189,7 @@ pub struct Query { #[cfg_attr(not(target_arch = "wasm32"), napi_derive::napi(object))] #[cfg_attr(target_arch = "wasm32", derive(Serialize, Deserialize, Tsify))] #[cfg_attr(target_arch = "wasm32", tsify(into_wasm_abi, from_wasm_abi))] -#[derive(Debug)] +#[derive(Debug, Default)] pub struct TransactionOptions { /// Whether or not to run a phantom query (i.e., a query that only influences Prisma event logs, but not the database itself) /// before opening a transaction, committing, or rollbacking. 
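
For illustration, a minimal self-contained sketch of the phantom-query behaviour documented above: when `use_phantom_query` is set, statements such as BEGIN or COMMIT are rendered as SQL comments so they only surface in Prisma's event logs instead of being sent to the database. The `phantom_query_message` body mirrors the helper introduced later in this series; the surrounding `begin_statement` wrapper and `main` are illustrative only.

    // Renders a statement as a log-only SQL comment (mirrors the helper added in this series).
    fn phantom_query_message(stmt: &str) -> String {
        format!(r#"-- Implicit "{}" query via underlying driver"#, stmt)
    }

    // Illustrative wrapper: pick between a real BEGIN and a phantom one.
    fn begin_statement(use_phantom_query: bool) -> String {
        let begin = "BEGIN";
        if use_phantom_query {
            // Logged, never executed against the database.
            phantom_query_message(begin)
        } else {
            // Executed as a real statement by the underlying driver.
            begin.to_string()
        }
    }

    fn main() {
        assert_eq!(
            begin_statement(true),
            r#"-- Implicit "BEGIN" query via underlying driver"#
        );
        assert_eq!(begin_statement(false), "BEGIN");
    }
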
diff --git a/query-engine/driver-adapters/src/wasm/async_js_function.rs b/query-engine/driver-adapters/src/wasm/async_js_function.rs index 8e8d6958cce9..66cf2a39ae05 100644 --- a/query-engine/driver-adapters/src/wasm/async_js_function.rs +++ b/query-engine/driver-adapters/src/wasm/async_js_function.rs @@ -1,31 +1,70 @@ use js_sys::{Function as JsFunction, Object as JsObject, Promise as JsPromise}; use serde::{de::DeserializeOwned, Serialize}; use std::marker::PhantomData; +use wasm_bindgen::convert::{FromWasmAbi, WasmAbi}; +use wasm_bindgen::describe::WasmDescribe; use wasm_bindgen::{prelude::wasm_bindgen, JsError, JsValue}; use wasm_bindgen_futures::JsFuture; +use super::error::into_quaint_error; + type JsResult = core::result::Result; +#[derive(Clone, Default)] pub(crate) struct AsyncJsFunction where - ArgType: Serialize + 'static, - ReturnType: DeserializeOwned + 'static, + ArgType: Serialize, + ReturnType: DeserializeOwned, { - threadsafe_fn: JsFunction, + pub threadsafe_fn: JsFunction, + _phantom_arg: PhantomData, _phantom_return: PhantomData, } -impl AsyncJsFunction +impl AsyncJsFunction +where + T: Serialize, + R: DeserializeOwned, +{ + pub async fn call(&self, arg1: T) -> quaint::Result { + let call_internal = async { + let arg1 = serde_wasm_bindgen::to_value(&arg1).map_err(|err| JsError::from(&err))?; + let promise = self.threadsafe_fn.call1(&JsValue::null(), &arg1)?; + let future = JsFuture::from(JsPromise::from(promise)); + let value = future.await?; + serde_wasm_bindgen::from_value(value).map_err(|err| JsValue::from(err)) + }; + + match call_internal.await { + Ok(result) => Ok(result), + Err(err) => Err(into_quaint_error(err)), + } + } +} + +impl WasmDescribe for AsyncJsFunction where - ArgType: Serialize + 'static, - ReturnType: DeserializeOwned + 'static, + ArgType: Serialize, + ReturnType: DeserializeOwned, { - async fn call(&self, arg1: ArgType) -> JsResult { - let arg1 = serde_wasm_bindgen::to_value(&arg1).map_err(|err| JsError::from(&err))?; - let promise = self.threadsafe_fn.call1(&JsValue::null(), &arg1)?; - let future = JsFuture::from(JsPromise::from(promise)); - let value = future.await?; - serde_wasm_bindgen::from_value(value).map_err(|err| JsValue::from(err)) + fn describe() { + JsFunction::describe(); + } +} + +impl FromWasmAbi for AsyncJsFunction +where + ArgType: Serialize, + ReturnType: DeserializeOwned, +{ + type Abi = ::Abi; + + unsafe fn from_abi(js: Self::Abi) -> Self { + Self { + threadsafe_fn: JsFunction::from_abi(js), + _phantom_arg: PhantomData:: {}, + _phantom_return: PhantomData:: {}, + } } } diff --git a/query-engine/driver-adapters/src/wasm/conversion.rs b/query-engine/driver-adapters/src/wasm/conversion.rs new file mode 100644 index 000000000000..9cb5202cda45 --- /dev/null +++ b/query-engine/driver-adapters/src/wasm/conversion.rs @@ -0,0 +1 @@ +pub(crate) use crate::conversion::{mysql, postgres, sqlite, JSArg}; diff --git a/query-engine/driver-adapters/src/wasm/error.rs b/query-engine/driver-adapters/src/wasm/error.rs new file mode 100644 index 000000000000..e0b588794302 --- /dev/null +++ b/query-engine/driver-adapters/src/wasm/error.rs @@ -0,0 +1,11 @@ +use quaint::error::Error as QuaintError; +use wasm_bindgen::JsValue; + +type WasmError = JsValue; + +/// transforms a Wasm error into a Quaint error +pub(crate) fn into_quaint_error(wasm_err: WasmError) -> QuaintError { + let status = "WASM_ERROR".to_string(); + let reason = wasm_err.as_string().unwrap_or_else(|| "unknown error".to_string()); + QuaintError::raw_connector_error(status, reason) +} 
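
A minimal sketch of the error path above, for the wasm32 target: any `JsValue` raised while `AsyncJsFunction::call` awaits the driver's promise is flattened into a raw Quaint connector error with a "WASM_ERROR" status. The `into_quaint_error` body is reproduced from `wasm/error.rs` above; the sample value and the `demo` function are assumptions for illustration only.

    use quaint::error::Error as QuaintError;
    use wasm_bindgen::JsValue;

    // Mirrors `into_quaint_error` from wasm/error.rs above.
    fn into_quaint_error(wasm_err: JsValue) -> QuaintError {
        let status = "WASM_ERROR".to_string();
        let reason = wasm_err.as_string().unwrap_or_else(|| "unknown error".to_string());
        QuaintError::raw_connector_error(status, reason)
    }

    fn demo() {
        // A rejection value coming back from the JS driver (illustrative).
        let js_err = JsValue::from_str("connection refused");
        // Flattened into a raw connector error that the query engine
        // already knows how to report.
        let _quaint_err = into_quaint_error(js_err);
    }
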
diff --git a/query-engine/driver-adapters/src/wasm/mod.rs b/query-engine/driver-adapters/src/wasm/mod.rs index 92509cb18c3f..8636204577c9 100644 --- a/query-engine/driver-adapters/src/wasm/mod.rs +++ b/query-engine/driver-adapters/src/wasm/mod.rs @@ -1,3 +1,10 @@ //! Query Engine Driver Adapters: `wasm`-specific implementation. mod async_js_function; +mod conversion; +mod error; +mod proxy; +mod queryable; +mod send_future; +mod transaction; +pub use queryable::{from_wasm, JsQueryable}; diff --git a/query-engine/driver-adapters/src/wasm/proxy.rs b/query-engine/driver-adapters/src/wasm/proxy.rs new file mode 100644 index 000000000000..5d2b58d3e43d --- /dev/null +++ b/query-engine/driver-adapters/src/wasm/proxy.rs @@ -0,0 +1,125 @@ +use ducktor::FromJsValue as DuckType; +use futures::Future; +use js_sys::{Function as JsFunction, Object as JsObject, Promise as JsPromise}; + +use super::{async_js_function::AsyncJsFunction, send_future::SendFuture, transaction::JsTransaction}; +pub use crate::types::{ColumnType, JSResultSet, Query, TransactionOptions}; +use metrics::increment_gauge; +use wasm_bindgen::{prelude::wasm_bindgen, JsValue}; + +type JsResult = core::result::Result; + +/// Proxy is a struct wrapping a javascript object that exhibits basic primitives for +/// querying and executing SQL (i.e. a client connector). The Proxy uses Wasm's JsFunction to +/// invoke the code within the node runtime that implements the client connector. +#[wasm_bindgen(getter_with_clone)] +#[derive(DuckType, Default)] +pub(crate) struct CommonProxy { + /// Execute a query given as SQL, interpolating the given parameters. + query_raw: AsyncJsFunction, + + /// Execute a query given as SQL, interpolating the given parameters and + /// returning the number of affected rows. + execute_raw: AsyncJsFunction, + + /// Return the flavour for this driver. + pub(crate) flavour: String, +} + +/// This is a JS proxy for accessing the methods specific to top level +/// JS driver objects +#[wasm_bindgen(getter_with_clone)] +#[derive(DuckType)] +pub(crate) struct DriverProxy { + start_transaction: AsyncJsFunction<(), JsTransaction>, +} + +/// This a JS proxy for accessing the methods, specific +/// to JS transaction objects +#[wasm_bindgen(getter_with_clone)] +#[derive(DuckType, Default)] +pub(crate) struct TransactionProxy { + /// transaction options + options: TransactionOptions, + + /// commit transaction + commit: AsyncJsFunction<(), ()>, + + /// rollback transaction + rollback: AsyncJsFunction<(), ()>, + + /// dispose transaction, cleanup logic executed at the end of the transaction lifecycle + /// on drop. + dispose: JsFunction, +} + +impl CommonProxy { + pub fn new(object: &JsObject) -> Self { + CommonProxy::from(&object.into()) + } + + pub async fn query_raw(&self, params: Query) -> quaint::Result { + self.query_raw.call(params).await + } + + pub async fn execute_raw(&self, params: Query) -> quaint::Result { + self.execute_raw.call(params).await + } +} + +impl DriverProxy { + pub fn new(object: &JsObject) -> Self { + Self::from(&object.into()) + } + + async fn start_transaction_inner(&self) -> quaint::Result> { + let tx = self.start_transaction.call(()).await?; + + // Decrement for this gauge is done in JsTransaction::commit/JsTransaction::rollback + // Previously, it was done in JsTransaction::new, similar to the native Transaction. + // However, correct Dispatcher is lost there and increment does not register, so we moved + // it here instead. 
+ increment_gauge!("prisma_client_queries_active", 1.0); + Ok(Box::new(tx)) + } + + pub fn start_transaction<'a>( + &'a self, + ) -> SendFuture>> + 'a> { + SendFuture(self.start_transaction_inner()) + } +} + +impl TransactionProxy { + pub fn new(object: &JsObject) -> Self { + Self::from(&object.into()) + } + + pub fn options(&self) -> &TransactionOptions { + &self.options + } + + pub fn commit<'a>(&'a self) -> SendFuture> + 'a> { + SendFuture(self.commit.call(())) + } + + pub fn rollback<'a>(&'a self) -> SendFuture> + 'a> { + SendFuture(self.rollback.call(())) + } +} + +impl Drop for TransactionProxy { + fn drop(&mut self) { + _ = self.dispose.call0(&JsValue::null()); + } +} + +// Assume the proxy object will not be sent to service workers, we can unsafe impl Send + Sync. +unsafe impl Send for TransactionProxy {} +unsafe impl Sync for TransactionProxy {} + +unsafe impl Send for DriverProxy {} +unsafe impl Sync for DriverProxy {} + +unsafe impl Send for CommonProxy {} +unsafe impl Sync for CommonProxy {} diff --git a/query-engine/driver-adapters/src/wasm/queryable.rs b/query-engine/driver-adapters/src/wasm/queryable.rs new file mode 100644 index 000000000000..abb828a443c0 --- /dev/null +++ b/query-engine/driver-adapters/src/wasm/queryable.rs @@ -0,0 +1,324 @@ +use super::{ + conversion, + proxy::{CommonProxy, DriverProxy, Query}, + send_future::SendFuture, +}; +use async_trait::async_trait; +use ducktor::FromJsValue as DuckType; +use futures::Future; +use js_sys::Object as JsObject; +use psl::datamodel_connector::Flavour; +use quaint::{ + connector::{metrics, IsolationLevel, Transaction}, + error::{Error, ErrorKind}, + prelude::{Query as QuaintQuery, Queryable as QuaintQueryable, ResultSet, TransactionCapable}, + visitor::{self, Visitor}, +}; +use tracing::{info_span, Instrument}; +use wasm_bindgen::prelude::wasm_bindgen; + +/// A JsQueryable adapts a Proxy to implement quaint's Queryable interface. It has the +/// responsibility of transforming inputs and outputs of `query` and `execute` methods from quaint +/// types to types that can be translated into javascript and viceversa. This is to let the rest of +/// the query engine work as if it was using quaint itself. The aforementioned transformations are: +/// +/// Transforming a `quaint::ast::Query` into SQL by visiting it for the specific flavour of SQL +/// expected by the client connector. (eg. using the mysql visitor for the Planetscale client +/// connector) +/// +/// Transforming a `JSResultSet` (what client connectors implemented in javascript provide) +/// into a `quaint::connector::result_set::ResultSet`. A quaint `ResultSet` is basically a vector +/// of `quaint::Value` but said type is a tagged enum, with non-unit variants that cannot be converted to javascript as is. 
+#[wasm_bindgen(getter_with_clone)] +#[derive(DuckType, Default)] +pub(crate) struct JsBaseQueryable { + pub(crate) proxy: CommonProxy, + pub flavour: Flavour, +} + +impl JsBaseQueryable { + pub(crate) fn new(proxy: CommonProxy) -> Self { + let flavour: Flavour = proxy.flavour.parse().unwrap(); + Self { proxy, flavour } + } + + /// visit a quaint query AST according to the flavour of the JS connector + fn visit_quaint_query<'a>(&self, q: QuaintQuery<'a>) -> quaint::Result<(String, Vec>)> { + match self.flavour { + Flavour::Mysql => visitor::Mysql::build(q), + Flavour::Postgres => visitor::Postgres::build(q), + Flavour::Sqlite => visitor::Sqlite::build(q), + _ => unimplemented!("Unsupported flavour for JS connector {:?}", self.flavour), + } + } + + async fn build_query(&self, sql: &str, values: &[quaint::Value<'_>]) -> quaint::Result { + let sql: String = sql.to_string(); + + let converter = match self.flavour { + Flavour::Postgres => conversion::postgres::value_to_js_arg, + Flavour::Sqlite => conversion::sqlite::value_to_js_arg, + Flavour::Mysql => conversion::mysql::value_to_js_arg, + _ => unreachable!("Unsupported flavour for JS connector {:?}", self.flavour), + }; + + let args = values + .iter() + .map(converter) + .collect::>>()?; + + Ok(Query { sql, args }) + } +} + +#[async_trait] +impl QuaintQueryable for JsBaseQueryable { + async fn query(&self, q: QuaintQuery<'_>) -> quaint::Result { + let (sql, params) = self.visit_quaint_query(q)?; + self.query_raw(&sql, ¶ms).await + } + + async fn query_raw(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { + metrics::query("js.query_raw", sql, params, move || async move { + self.do_query_raw(sql, params).await + }) + .await + } + + async fn query_raw_typed(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { + self.query_raw(sql, params).await + } + + async fn execute(&self, q: QuaintQuery<'_>) -> quaint::Result { + let (sql, params) = self.visit_quaint_query(q)?; + self.execute_raw(&sql, ¶ms).await + } + + async fn execute_raw(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { + metrics::query("js.execute_raw", sql, params, move || async move { + self.do_execute_raw(sql, params).await + }) + .await + } + + async fn execute_raw_typed(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { + self.execute_raw(sql, params).await + } + + async fn raw_cmd(&self, cmd: &str) -> quaint::Result<()> { + let params = &[]; + metrics::query("js.raw_cmd", cmd, params, move || async move { + self.do_execute_raw(cmd, params).await?; + Ok(()) + }) + .await + } + + async fn version(&self) -> quaint::Result> { + // Note: JS Connectors don't use this method. + Ok(None) + } + + fn is_healthy(&self) -> bool { + // Note: JS Connectors don't use this method. + true + } + + /// Sets the transaction isolation level to given value. + /// Implementers have to make sure that the passed isolation level is valid for the underlying database. 
+ async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> quaint::Result<()> { + if matches!(isolation_level, IsolationLevel::Snapshot) { + return Err(Error::builder(ErrorKind::invalid_isolation_level(&isolation_level)).build()); + } + + if self.flavour == Flavour::Sqlite { + return match isolation_level { + IsolationLevel::Serializable => Ok(()), + _ => Err(Error::builder(ErrorKind::invalid_isolation_level(&isolation_level)).build()), + }; + } + + self.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation_level}")) + .await + } + + fn requires_isolation_first(&self) -> bool { + match self.flavour { + Flavour::Mysql => true, + Flavour::Postgres | Flavour::Sqlite => false, + _ => unreachable!(), + } + } +} + +impl JsBaseQueryable { + pub fn phantom_query_message(stmt: &str) -> String { + format!(r#"-- Implicit "{}" query via underlying driver"#, stmt) + } + + async fn do_query_raw_inner(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { + let len = params.len(); + let serialization_span = info_span!("js:query:args", user_facing = true, "length" = %len); + let query = self.build_query(sql, params).instrument(serialization_span).await?; + + let sql_span = info_span!("js:query:sql", user_facing = true, "db.statement" = %sql); + let result_set = self.proxy.query_raw(query).instrument(sql_span).await?; + + let len = result_set.len(); + let _deserialization_span = info_span!("js:query:result", user_facing = true, "length" = %len).entered(); + + result_set.try_into() + } + + fn do_query_raw<'a>( + &'a self, + sql: &'a str, + params: &'a [quaint::Value<'a>], + ) -> SendFuture> + 'a> { + SendFuture(self.do_query_raw_inner(sql, params)) + } + + async fn do_execute_raw_inner(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { + let len = params.len(); + let serialization_span = info_span!("js:query:args", user_facing = true, "length" = %len); + let query = self.build_query(sql, params).instrument(serialization_span).await?; + + let sql_span = info_span!("js:query:sql", user_facing = true, "db.statement" = %sql); + let affected_rows = self.proxy.execute_raw(query).instrument(sql_span).await?; + + Ok(affected_rows as u64) + } + + fn do_execute_raw<'a>( + &'a self, + sql: &'a str, + params: &'a [quaint::Value<'a>], + ) -> SendFuture> + 'a> { + SendFuture(self.do_execute_raw_inner(sql, params)) + } +} + +/// A JsQueryable adapts a Proxy to implement quaint's Queryable interface. It has the +/// responsibility of transforming inputs and outputs of `query` and `execute` methods from quaint +/// types to types that can be translated into javascript and viceversa. This is to let the rest of +/// the query engine work as if it was using quaint itself. The aforementioned transformations are: +/// +/// Transforming a `quaint::ast::Query` into SQL by visiting it for the specific flavour of SQL +/// expected by the client connector. (eg. using the mysql visitor for the Planetscale client +/// connector) +/// +/// Transforming a `JSResultSet` (what client connectors implemented in javascript provide) +/// into a `quaint::connector::result_set::ResultSet`. A quaint `ResultSet` is basically a vector +/// of `quaint::Value` but said type is a tagged enum, with non-unit variants that cannot be converted to javascript as is. 
+/// +pub struct JsQueryable { + inner: JsBaseQueryable, + driver_proxy: DriverProxy, +} + +impl std::fmt::Display for JsQueryable { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "JSQueryable(driver)") + } +} + +impl std::fmt::Debug for JsQueryable { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "JSQueryable(driver)") + } +} + +#[async_trait] +impl QuaintQueryable for JsQueryable { + async fn query(&self, q: QuaintQuery<'_>) -> quaint::Result { + self.inner.query(q).await + } + + async fn query_raw(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { + self.inner.query_raw(sql, params).await + } + + async fn query_raw_typed(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { + self.inner.query_raw_typed(sql, params).await + } + + async fn execute(&self, q: QuaintQuery<'_>) -> quaint::Result { + self.inner.execute(q).await + } + + async fn execute_raw(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { + self.inner.execute_raw(sql, params).await + } + + async fn execute_raw_typed(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { + self.inner.execute_raw_typed(sql, params).await + } + + async fn raw_cmd(&self, cmd: &str) -> quaint::Result<()> { + self.inner.raw_cmd(cmd).await + } + + async fn version(&self) -> quaint::Result> { + self.inner.version().await + } + + fn is_healthy(&self) -> bool { + self.inner.is_healthy() + } + + async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> quaint::Result<()> { + self.inner.set_tx_isolation_level(isolation_level).await + } + + fn requires_isolation_first(&self) -> bool { + self.inner.requires_isolation_first() + } +} + +#[async_trait] +impl TransactionCapable for JsQueryable { + async fn start_transaction<'a>( + &'a self, + isolation: Option, + ) -> quaint::Result> { + let tx = self.driver_proxy.start_transaction().await?; + + let isolation_first = tx.requires_isolation_first(); + + if isolation_first { + if let Some(isolation) = isolation { + tx.set_tx_isolation_level(isolation).await?; + } + } + + let begin_stmt = tx.begin_statement(); + + let tx_opts = tx.options(); + if tx_opts.use_phantom_query { + let begin_stmt = JsBaseQueryable::phantom_query_message(begin_stmt); + tx.raw_phantom_cmd(begin_stmt.as_str()).await?; + } else { + tx.raw_cmd(begin_stmt).await?; + } + + if !isolation_first { + if let Some(isolation) = isolation { + tx.set_tx_isolation_level(isolation).await?; + } + } + + self.server_reset_query(tx.as_ref()).await?; + + Ok(tx) + } +} + +pub fn from_wasm(driver: JsObject) -> JsQueryable { + let common = CommonProxy::new(&driver); + let driver_proxy = DriverProxy::new(&driver); + + JsQueryable { + inner: JsBaseQueryable::new(common), + driver_proxy, + } +} diff --git a/query-engine/driver-adapters/src/wasm/send_future.rs b/query-engine/driver-adapters/src/wasm/send_future.rs new file mode 100644 index 000000000000..61c64a960450 --- /dev/null +++ b/query-engine/driver-adapters/src/wasm/send_future.rs @@ -0,0 +1,24 @@ +use futures::Future; + +// Allow asynchronous futures to be sent safely across threads, solving the following error: +// +// ```text +// future cannot be sent between threads safely +// the trait `Send` is not implemented for `dyn Future>`. 
+// ``` +// +// See: https://github.com/rustwasm/wasm-bindgen/issues/2409#issuecomment-820750943 +#[pin_project::pin_project] +pub struct SendFuture(#[pin] pub F); + +unsafe impl Send for SendFuture {} + +impl Future for SendFuture { + type Output = F::Output; + + fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll { + // the `self.project()` method is provided by the `pin_project` macro + let future: std::pin::Pin<&mut F> = self.project().0; + future.poll(cx) + } +} diff --git a/query-engine/driver-adapters/src/wasm/transaction.rs b/query-engine/driver-adapters/src/wasm/transaction.rs new file mode 100644 index 000000000000..21ed005c7d1e --- /dev/null +++ b/query-engine/driver-adapters/src/wasm/transaction.rs @@ -0,0 +1,132 @@ +use async_trait::async_trait; +use ducktor::FromJsValue as DuckType; +use metrics::decrement_gauge; +use quaint::{ + connector::{IsolationLevel, Transaction as QuaintTransaction}, + prelude::{Query as QuaintQuery, Queryable, ResultSet}, + Value, +}; +use serde::Deserialize; +use wasm_bindgen::prelude::wasm_bindgen; + +use super::{ + proxy::{CommonProxy, TransactionOptions, TransactionProxy}, + queryable::JsBaseQueryable, + send_future::SendFuture, +}; + +// Wrapper around JS transaction objects that implements Queryable +// and quaint::Transaction. Can be used in place of quaint transaction, +// but delegates most operations to JS +#[derive(Deserialize, Default)] +pub(crate) struct JsTransaction { + #[serde(skip)] + tx_proxy: TransactionProxy, + #[serde(skip)] + inner: JsBaseQueryable, +} + +impl JsTransaction { + pub(crate) fn new(inner: JsBaseQueryable, tx_proxy: TransactionProxy) -> Self { + Self { inner, tx_proxy } + } + + pub fn options(&self) -> &TransactionOptions { + self.tx_proxy.options() + } + + pub async fn raw_phantom_cmd(&self, cmd: &str) -> quaint::Result<()> { + let params = &[]; + quaint::connector::metrics::query("js.raw_phantom_cmd", cmd, params, move || async move { Ok(()) }).await + } +} + +#[async_trait] +impl QuaintTransaction for JsTransaction { + async fn commit(&self) -> quaint::Result<()> { + // increment of this gauge is done in DriverProxy::startTransaction + decrement_gauge!("prisma_client_queries_active", 1.0); + + let commit_stmt = "COMMIT"; + + if self.options().use_phantom_query { + let commit_stmt = JsBaseQueryable::phantom_query_message(commit_stmt); + self.raw_phantom_cmd(commit_stmt.as_str()).await?; + } else { + self.inner.raw_cmd(commit_stmt).await?; + } + + SendFuture(self.tx_proxy.commit()).await + } + + async fn rollback(&self) -> quaint::Result<()> { + // increment of this gauge is done in DriverProxy::startTransaction + decrement_gauge!("prisma_client_queries_active", 1.0); + + let rollback_stmt = "ROLLBACK"; + + if self.options().use_phantom_query { + let rollback_stmt = JsBaseQueryable::phantom_query_message(rollback_stmt); + self.raw_phantom_cmd(rollback_stmt.as_str()).await?; + } else { + self.inner.raw_cmd(rollback_stmt).await?; + } + + SendFuture(self.tx_proxy.rollback()).await + } + + fn as_queryable(&self) -> &dyn Queryable { + self + } +} + +#[async_trait] +impl Queryable for JsTransaction { + async fn query(&self, q: QuaintQuery<'_>) -> quaint::Result { + self.inner.query(q).await + } + + async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> quaint::Result { + self.inner.query_raw(sql, params).await + } + + async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> quaint::Result { + self.inner.query_raw_typed(sql, params).await + } + + async fn 
execute(&self, q: QuaintQuery<'_>) -> quaint::Result { + self.inner.execute(q).await + } + + async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> quaint::Result { + self.inner.execute_raw(sql, params).await + } + + async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> quaint::Result { + self.inner.execute_raw_typed(sql, params).await + } + + async fn raw_cmd(&self, cmd: &str) -> quaint::Result<()> { + self.inner.raw_cmd(cmd).await + } + + async fn version(&self) -> quaint::Result> { + self.inner.version().await + } + + fn is_healthy(&self) -> bool { + self.inner.is_healthy() + } + + async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> quaint::Result<()> { + self.inner.set_tx_isolation_level(isolation_level).await + } + + fn requires_isolation_first(&self) -> bool { + self.inner.requires_isolation_first() + } +} + +// Assume the proxy object will not be sent to service workers, we can unsafe impl Send + Sync. +unsafe impl Send for JsTransaction {} +unsafe impl Sync for JsTransaction {} From d7a799d85f2018475b50585d61e5f53f63e4766a Mon Sep 17 00:00:00 2001 From: jkomyno Date: Fri, 17 Nov 2023 09:47:28 +0100 Subject: [PATCH 042/134] feat(driver-adapters): plug "driver-adapters" to "query-engine-wasm" --- query-engine/query-engine-wasm/Cargo.toml | 2 +- query-engine/query-engine-wasm/src/engine.rs | 12 ++- query-engine/query-engine-wasm/src/lib.rs | 1 - query-engine/query-engine-wasm/src/proxy.rs | 107 ------------------- 4 files changed, 10 insertions(+), 112 deletions(-) delete mode 100644 query-engine/query-engine-wasm/src/proxy.rs diff --git a/query-engine/query-engine-wasm/Cargo.toml b/query-engine/query-engine-wasm/Cargo.toml index f4a9703e741e..73191b1b53a7 100644 --- a/query-engine/query-engine-wasm/Cargo.toml +++ b/query-engine/query-engine-wasm/Cargo.toml @@ -16,7 +16,7 @@ psl.workspace = true prisma-models = { path = "../prisma-models" } quaint = { path = "../../quaint" } query-connector = { path = "../connectors/query-connector" } -sql-query-connector = { path = "../connectors/sql-query-connector" } +sql-connector = { path = "../connectors/sql-query-connector", package = "sql-query-connector" } query-core = { path = "../core" } request-handlers = { path = "../request-handlers", default-features = false, features = ["sql", "driver-adapters"] } driver-adapters = { path = "../driver-adapters" } diff --git a/query-engine/query-engine-wasm/src/engine.rs b/query-engine/query-engine-wasm/src/engine.rs index f9a06fabcf4b..20c67ee24bce 100644 --- a/query-engine/query-engine-wasm/src/engine.rs +++ b/query-engine/query-engine-wasm/src/engine.rs @@ -1,12 +1,12 @@ #![allow(dead_code)] #![allow(unused_variables)] -use crate::proxy; use crate::{ error::ApiError, logger::{LogCallback, Logger}, }; use js_sys::{Function as JsFunction, Object as JsObject}; +use request_handlers::ConnectorMode; use serde::{Deserialize, Serialize}; use std::{ collections::{BTreeMap, HashMap}, @@ -21,6 +21,7 @@ use wasm_bindgen::prelude::wasm_bindgen; /// The main query engine used by JS #[wasm_bindgen] pub struct QueryEngine { + connector_mode: ConnectorMode, inner: RwLock, logger: Logger, } @@ -125,10 +126,12 @@ impl QueryEngine { let mut schema = psl::validate(datamodel.into()); let config = &mut schema.configuration; + let preview_features = config.preview_features(); if let Some(adapter) = maybe_adapter { - let js_queryable = - proxy::from_wasm(adapter).map_err(|e| ApiError::configuration(e.as_string().unwrap_or_default()))?; + let js_queryable = 
driver_adapters::from_wasm(adapter); + + sql_connector::activate_driver_adapter(Arc::new(js_queryable)); let provider_name = schema.connector.provider_name(); log::info!("Received driver adapter for {provider_name}."); @@ -160,9 +163,12 @@ impl QueryEngine { let log_level = log_level.parse::().unwrap(); let logger = Logger::new(log_queries, log_level, log_callback); + let connector_mode = ConnectorMode::Js; + Ok(Self { inner: RwLock::new(Inner::Builder(builder)), logger, + connector_mode, }) } diff --git a/query-engine/query-engine-wasm/src/lib.rs b/query-engine/query-engine-wasm/src/lib.rs index 89b519515517..74f9e93a0de1 100644 --- a/query-engine/query-engine-wasm/src/lib.rs +++ b/query-engine/query-engine-wasm/src/lib.rs @@ -2,7 +2,6 @@ pub mod engine; pub mod error; pub mod functions; pub mod logger; -mod proxy; pub(crate) type Result = std::result::Result; diff --git a/query-engine/query-engine-wasm/src/proxy.rs b/query-engine/query-engine-wasm/src/proxy.rs deleted file mode 100644 index ad028e218236..000000000000 --- a/query-engine/query-engine-wasm/src/proxy.rs +++ /dev/null @@ -1,107 +0,0 @@ -#![allow(dead_code)] -#![allow(unused_variables)] - -// This code will likely live in a separate crate, but for now it's here. - -use async_trait::async_trait; -use js_sys::{Function as JsFunction, JsString, Object as JsObject, Promise as JsPromise, Reflect as JsReflect}; -use serde::{de::DeserializeOwned, Serialize}; -use wasm_bindgen::{JsCast, JsValue}; - -type Result = std::result::Result; - -pub struct CommonProxy { - /// Execute a query given as SQL, interpolating the given parameters. - query_raw: JsFunction, - - /// Execute a query given as SQL, interpolating the given parameters and - /// returning the number of affected rows. - execute_raw: JsFunction, - - /// Return the flavour for this driver. - pub(crate) flavour: String, -} - -impl CommonProxy { - pub(crate) fn new(driver: &JsObject) -> Result { - let query_raw = JsReflect::get(driver, &"queryRaw".into())?.dyn_into::()?; - let execute_raw = JsReflect::get(driver, &"executeRaw".into())?.dyn_into::()?; - let flavour: String = JsReflect::get(driver, &"flavour".into())? - .dyn_into::()? 
- .into(); - - let common_proxy = Self { - query_raw, - execute_raw, - flavour, - }; - Ok(common_proxy) - } -} - -pub struct DriverProxy { - start_transaction: JsFunction, -} - -impl DriverProxy { - pub(crate) fn new(driver: &JsObject) -> Result { - let start_transaction = JsReflect::get(driver, &"startTransaction".into())?.dyn_into::()?; - - let driver_proxy = Self { start_transaction }; - Ok(driver_proxy) - } -} - -pub struct JsQueryable { - inner: CommonProxy, - driver_proxy: DriverProxy, -} - -impl JsQueryable { - pub fn new(inner: CommonProxy, driver_proxy: DriverProxy) -> Self { - Self { inner, driver_proxy } - } -} - -pub fn from_wasm(driver: JsObject) -> Result { - let common_proxy = CommonProxy::new(&driver)?; - let driver_proxy = DriverProxy::new(&driver)?; - - let js_queryable = JsQueryable::new(common_proxy, driver_proxy); - Ok(js_queryable) -} - -#[async_trait(?Send)] -trait JsAsyncFunc { - async fn call1_async(&self, arg1: T) -> Result - where - T: Serialize, - R: DeserializeOwned; - - fn call0_sync(&self) -> Result - where - R: DeserializeOwned; -} - -#[async_trait(?Send)] -impl JsAsyncFunc for JsFunction { - async fn call1_async(&self, arg1: T) -> Result - where - T: Serialize, - R: DeserializeOwned, - { - let arg1 = serde_wasm_bindgen::to_value(&arg1).map_err(|err| js_sys::Error::new(&err.to_string()))?; - let promise = self.call1(&JsValue::null(), &arg1)?; - let future = wasm_bindgen_futures::JsFuture::from(JsPromise::from(promise)); - let value = future.await?; - serde_wasm_bindgen::from_value(value).map_err(|err| js_sys::Error::new(&err.to_string())) - } - - fn call0_sync(&self) -> Result - where - R: DeserializeOwned, - { - let value = self.call0(&JsValue::null())?; - serde_wasm_bindgen::from_value(value).map_err(|err| js_sys::Error::new(&err.to_string())) - } -} From c72c37756e8b4e773a993b09dc4e02150ede6b86 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Fri, 17 Nov 2023 15:33:17 +0100 Subject: [PATCH 043/134] chore: remove .cargo, add it to .gitignore --- .cargo/config.toml | 2 -- 1 file changed, 2 deletions(-) delete mode 100644 .cargo/config.toml diff --git a/.cargo/config.toml b/.cargo/config.toml deleted file mode 100644 index 229dd6ee6b3f..000000000000 --- a/.cargo/config.toml +++ /dev/null @@ -1,2 +0,0 @@ -[build] -# target = "wasm32-unknown-unknown" From 2339b313786c7401ef46b55c174b53ecbff13332 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Fri, 17 Nov 2023 16:19:23 +0100 Subject: [PATCH 044/134] chore: move "task" module into its own file --- query-engine/core/src/executor/mod.rs | 64 +------------------------- query-engine/core/src/executor/task.rs | 59 ++++++++++++++++++++++++ 2 files changed, 60 insertions(+), 63 deletions(-) create mode 100644 query-engine/core/src/executor/task.rs diff --git a/query-engine/core/src/executor/mod.rs b/query-engine/core/src/executor/mod.rs index 43df839e9635..ba2784d3c71a 100644 --- a/query-engine/core/src/executor/mod.rs +++ b/query-engine/core/src/executor/mod.rs @@ -10,6 +10,7 @@ mod execute_operation; mod interpreting_executor; mod pipeline; mod request_context; +pub(crate) mod task; pub use self::{execute_operation::*, interpreting_executor::InterpretingExecutor}; @@ -131,66 +132,3 @@ pub trait TransactionManager { pub fn get_current_dispatcher() -> Dispatch { tracing::dispatcher::get_default(|current| current.clone()) } - -// The `task` module provides a unified interface for spawning asynchronous tasks, regardless of the target platform. 
-pub(crate) mod task { - pub use arch::{spawn, JoinHandle}; - use futures::Future; - - // On native targets, `tokio::spawn` spawns a new asynchronous task. - #[cfg(not(target_arch = "wasm32"))] - mod arch { - use super::*; - - pub type JoinHandle = tokio::task::JoinHandle; - - pub fn spawn(future: T) -> JoinHandle - where - T: Future + Send + 'static, - T::Output: Send + 'static, - { - tokio::spawn(future) - } - } - - // On Wasm targets, `wasm_bindgen_futures::spawn_local` spawns a new asynchronous task. - #[cfg(target_arch = "wasm32")] - mod arch { - use super::*; - use tokio::sync::oneshot::{self}; - - // Wasm-compatible alternative to `tokio::task::JoinHandle`. - // `pin_project` enables pin-projection and a `Pin`-compatible implementation of the `Future` trait. - #[pin_project::pin_project] - pub struct JoinHandle(#[pin] oneshot::Receiver); - - impl Future for JoinHandle { - type Output = Result; - - fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll { - // the `self.project()` method is provided by the `pin_project` macro - let receiver: std::pin::Pin<&mut oneshot::Receiver> = self.project().0; - receiver.poll(cx) - } - } - - impl JoinHandle { - pub fn abort(&mut self) { - // abort is noop on Wasm targets - } - } - - pub fn spawn(future: T) -> JoinHandle - where - T: Future + Send + 'static, - T::Output: Send + 'static, - { - let (sender, receiver) = oneshot::channel(); - wasm_bindgen_futures::spawn_local(async move { - let result = future.await; - sender.send(result).ok(); - }); - JoinHandle(receiver) - } - } -} diff --git a/query-engine/core/src/executor/task.rs b/query-engine/core/src/executor/task.rs new file mode 100644 index 000000000000..8d1c39bbcd06 --- /dev/null +++ b/query-engine/core/src/executor/task.rs @@ -0,0 +1,59 @@ +//! This module provides a unified interface for spawning asynchronous tasks, regardless of the target platform. + +pub use arch::{spawn, JoinHandle}; +use futures::Future; + +// On native targets, `tokio::spawn` spawns a new asynchronous task. +#[cfg(not(target_arch = "wasm32"))] +mod arch { + use super::*; + + pub type JoinHandle = tokio::task::JoinHandle; + + pub fn spawn(future: T) -> JoinHandle + where + T: Future + Send + 'static, + T::Output: Send + 'static, + { + tokio::spawn(future) + } +} + +// On Wasm targets, `wasm_bindgen_futures::spawn_local` spawns a new asynchronous task. +#[cfg(target_arch = "wasm32")] +mod arch { + use super::*; + use tokio::sync::oneshot::{self}; + + // Wasm-compatible alternative to `tokio::task::JoinHandle`. + // `pin_project` enables pin-projection and a `Pin`-compatible implementation of the `Future` trait. 
+ pub struct JoinHandle(oneshot::Receiver); + + impl Future for JoinHandle { + type Output = Result; + + fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll { + // the `self.project()` method is provided by the `pin_project` macro + core::pin::Pin::new(&mut self.0).poll(cx) + } + } + + impl JoinHandle { + pub fn abort(&mut self) { + // abort is noop on Wasm targets + } + } + + pub fn spawn(future: T) -> JoinHandle + where + T: Future + Send + 'static, + T::Output: Send + 'static, + { + let (sender, receiver) = oneshot::channel(); + wasm_bindgen_futures::spawn_local(async move { + let result = future.await; + sender.send(result).ok(); + }); + JoinHandle(receiver) + } +} From 96cd8ca800176eaeb9cf803c2b7e340973eaddc0 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Fri, 17 Nov 2023 16:30:14 +0100 Subject: [PATCH 045/134] fix(driver-adapters): ci for "request-handlers" --- query-engine/connectors/sql-query-connector/Cargo.toml | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/query-engine/connectors/sql-query-connector/Cargo.toml b/query-engine/connectors/sql-query-connector/Cargo.toml index fa9c32ef88e1..9ed0b4070056 100644 --- a/query-engine/connectors/sql-query-connector/Cargo.toml +++ b/query-engine/connectors/sql-query-connector/Cargo.toml @@ -26,9 +26,14 @@ tracing-futures = "0.2" uuid.workspace = true opentelemetry = { version = "0.17", features = ["tokio"] } tracing-opentelemetry = "0.17.3" -quaint = { path = "../../../quaint" } cuid = { git = "https://github.com/prisma/cuid-rust", branch = "wasm32-support" } +[target.'cfg(not(target_arch = "wasm32"))'.dependencies] +quaint.workspace = true + +[target.'cfg(target_arch = "wasm32")'.dependencies] +quaint = { path = "../../../quaint" } + [dependencies.connector-interface] package = "query-connector" path = "../query-connector" From 3541054b2723c17a158e5e1990ca581461756f1b Mon Sep 17 00:00:00 2001 From: jkomyno Date: Fri, 17 Nov 2023 16:30:14 +0100 Subject: [PATCH 046/134] fix(driver-adapters): ci for "request-handlers" --- query-engine/connectors/sql-query-connector/Cargo.toml | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/query-engine/connectors/sql-query-connector/Cargo.toml b/query-engine/connectors/sql-query-connector/Cargo.toml index fa9c32ef88e1..9ed0b4070056 100644 --- a/query-engine/connectors/sql-query-connector/Cargo.toml +++ b/query-engine/connectors/sql-query-connector/Cargo.toml @@ -26,9 +26,14 @@ tracing-futures = "0.2" uuid.workspace = true opentelemetry = { version = "0.17", features = ["tokio"] } tracing-opentelemetry = "0.17.3" -quaint = { path = "../../../quaint" } cuid = { git = "https://github.com/prisma/cuid-rust", branch = "wasm32-support" } +[target.'cfg(not(target_arch = "wasm32"))'.dependencies] +quaint.workspace = true + +[target.'cfg(target_arch = "wasm32")'.dependencies] +quaint = { path = "../../../quaint" } + [dependencies.connector-interface] package = "query-connector" path = "../query-connector" From 0795a6a5593d9204fb8367047d940e14cc937a9b Mon Sep 17 00:00:00 2001 From: jkomyno Date: Fri, 17 Nov 2023 17:54:21 +0100 Subject: [PATCH 047/134] fix(driver-adapters): clippy compile error on "query-engine-wasm" --- query-engine/query-engine-wasm/src/lib.rs | 42 +++++++++++++------ query-engine/query-engine-wasm/src/wasm.rs | 4 ++ .../src/{ => wasm}/engine.rs | 0 .../query-engine-wasm/src/{ => wasm}/error.rs | 0 .../src/{ => wasm}/functions.rs | 0 .../src/{ => wasm}/logger.rs | 0 6 files changed, 33 insertions(+), 13 
deletions(-) create mode 100644 query-engine/query-engine-wasm/src/wasm.rs rename query-engine/query-engine-wasm/src/{ => wasm}/engine.rs (100%) rename query-engine/query-engine-wasm/src/{ => wasm}/error.rs (100%) rename query-engine/query-engine-wasm/src/{ => wasm}/functions.rs (100%) rename query-engine/query-engine-wasm/src/{ => wasm}/logger.rs (100%) diff --git a/query-engine/query-engine-wasm/src/lib.rs b/query-engine/query-engine-wasm/src/lib.rs index 74f9e93a0de1..2d2167f7c22e 100644 --- a/query-engine/query-engine-wasm/src/lib.rs +++ b/query-engine/query-engine-wasm/src/lib.rs @@ -1,18 +1,34 @@ -pub mod engine; -pub mod error; -pub mod functions; -pub mod logger; +#[cfg(not(target_arch = "wasm32"))] +mod arch { + // This crate only works in a Wasm environment. + // This conditional compilation block is here to make commands like + // `cargo clippy --all-features` happy, as `clippy` doesn't support the + // `--exclude` option (see: https://github.com/rust-lang/rust-clippy/issues/9555). + // + // This crate can still be inspected by `clippy` via: + // `cargo clippy --all-features -p query-engine-wasm --target wasm32-unknown-unknown` +} + +#[cfg(target_arch = "wasm32")] +mod wasm; -pub(crate) type Result = std::result::Result; +#[cfg(target_arch = "wasm32")] +mod arch { + pub use super::wasm::*; -use wasm_bindgen::prelude::wasm_bindgen; + pub(crate) type Result = std::result::Result; -/// Function that should be called before any other public function in this module. -#[wasm_bindgen] -pub fn init() { - // Set up temporary logging for the wasm module. - wasm_logger::init(wasm_logger::Config::default()); + use wasm_bindgen::prelude::wasm_bindgen; - // Set up temporary panic hook for the wasm module. - std::panic::set_hook(Box::new(console_error_panic_hook::hook)); + /// Function that should be called before any other public function in this module. + #[wasm_bindgen] + pub fn init() { + // Set up temporary logging for the wasm module. + wasm_logger::init(wasm_logger::Config::default()); + + // Set up temporary panic hook for the wasm module. 
+ std::panic::set_hook(Box::new(console_error_panic_hook::hook)); + } } + +pub use arch::*; diff --git a/query-engine/query-engine-wasm/src/wasm.rs b/query-engine/query-engine-wasm/src/wasm.rs new file mode 100644 index 000000000000..4360e8a70471 --- /dev/null +++ b/query-engine/query-engine-wasm/src/wasm.rs @@ -0,0 +1,4 @@ +pub mod engine; +pub mod error; +pub mod functions; +pub mod logger; diff --git a/query-engine/query-engine-wasm/src/engine.rs b/query-engine/query-engine-wasm/src/wasm/engine.rs similarity index 100% rename from query-engine/query-engine-wasm/src/engine.rs rename to query-engine/query-engine-wasm/src/wasm/engine.rs diff --git a/query-engine/query-engine-wasm/src/error.rs b/query-engine/query-engine-wasm/src/wasm/error.rs similarity index 100% rename from query-engine/query-engine-wasm/src/error.rs rename to query-engine/query-engine-wasm/src/wasm/error.rs diff --git a/query-engine/query-engine-wasm/src/functions.rs b/query-engine/query-engine-wasm/src/wasm/functions.rs similarity index 100% rename from query-engine/query-engine-wasm/src/functions.rs rename to query-engine/query-engine-wasm/src/wasm/functions.rs diff --git a/query-engine/query-engine-wasm/src/logger.rs b/query-engine/query-engine-wasm/src/wasm/logger.rs similarity index 100% rename from query-engine/query-engine-wasm/src/logger.rs rename to query-engine/query-engine-wasm/src/wasm/logger.rs From 22113caa2238fffab225f7c6ef0d753cd9f94562 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Fri, 17 Nov 2023 17:55:54 +0100 Subject: [PATCH 048/134] chore(driver-adapters): fix conflicting library name warning on "cargo build" --- query-engine/query-engine-wasm/Cargo.toml | 2 +- query-engine/query-engine-wasm/example.js | 2 +- query-engine/query-engine-wasm/package.json | 7 +++++-- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/query-engine/query-engine-wasm/Cargo.toml b/query-engine/query-engine-wasm/Cargo.toml index 73191b1b53a7..9cd61041682b 100644 --- a/query-engine/query-engine-wasm/Cargo.toml +++ b/query-engine/query-engine-wasm/Cargo.toml @@ -6,7 +6,7 @@ edition = "2021" [lib] doc = false crate-type = ["cdylib"] -name = "query_engine" +name = "query_engine_wasm" [dependencies] anyhow = "1" diff --git a/query-engine/query-engine-wasm/example.js b/query-engine/query-engine-wasm/example.js index bca6d5ba95d7..6d3a78374bc8 100644 --- a/query-engine/query-engine-wasm/example.js +++ b/query-engine/query-engine-wasm/example.js @@ -6,7 +6,7 @@ import { Pool } from '@neondatabase/serverless' import { PrismaNeon } from '@prisma/adapter-neon' import { bindAdapter } from '@prisma/driver-adapter-utils' -import { init, QueryEngine, getBuildTimeInfo } from './pkg/query_engine.js' +import { init, QueryEngine, getBuildTimeInfo } from './pkg/query_engine_wasm.js' async function main() { // Always initialize the Wasm library before using it. 
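The conditional-compilation workaround from PATCH 047 above boils down to a small pattern that is easy to reproduce in isolation. A minimal sketch follows (module and function names are illustrative, not taken from the patch): the real API is compiled and re-exported only on wasm32, while an intentionally empty module keeps a workspace-wide `cargo clippy --all-features` run green on native targets, since clippy offers no `--exclude` flag.

    // lib.rs of a wasm32-only crate (hypothetical names).
    #[cfg(target_arch = "wasm32")]
    mod wasm_impl {
        // The real implementation (engine, error, functions, logger, ...) would live here.
        pub fn engine_target() -> &'static str {
            "wasm32-unknown-unknown"
        }
    }

    // Intentionally empty on native targets: it exists only so that
    // `cargo clippy --all-features` can type-check the crate without errors.
    #[cfg(not(target_arch = "wasm32"))]
    mod native_stub {}

    #[cfg(target_arch = "wasm32")]
    pub use wasm_impl::*;

Linting the real code then happens with an explicit target, e.g. `cargo clippy --all-features -p query-engine-wasm --target wasm32-unknown-unknown`, exactly as the comment in the patch suggests.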
diff --git a/query-engine/query-engine-wasm/package.json b/query-engine/query-engine-wasm/package.json index 102db2ce14b5..8192656bd56f 100644 --- a/query-engine/query-engine-wasm/package.json +++ b/query-engine/query-engine-wasm/package.json @@ -1,9 +1,12 @@ { "type": "module", "main": "./example.js", + "scripts": { + "dev": "node --experimental-wasm-modules ./example.js" + }, "dependencies": { "@neondatabase/serverless": "0.6.0", - "@prisma/adapter-neon": "5.5.2", - "@prisma/driver-adapter-utils": "5.5.2" + "@prisma/adapter-neon": "5.6.0", + "@prisma/driver-adapter-utils": "5.6.0" } } From 9c8bd20682e0346e181db3721e9023fcdc377b16 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Mon, 20 Nov 2023 03:53:16 +0100 Subject: [PATCH 049/134] chore: fix conflicts --- Cargo.lock | 3 --- 1 file changed, 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 39de1090f823..7f79a3e8a5e7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3846,11 +3846,8 @@ dependencies = [ "psl", "quaint", "query-connector", -<<<<<<< HEAD "query-core", "request-handlers", -======= ->>>>>>> main "serde", "serde-wasm-bindgen", "serde_json", From fb92bbb1a4a9845fe3b85963ff3e82cc219f352c Mon Sep 17 00:00:00 2001 From: jkomyno Date: Mon, 20 Nov 2023 03:56:49 +0100 Subject: [PATCH 050/134] chore: fixed some clippy warnings --- query-engine/driver-adapters/src/wasm/async_js_function.rs | 6 +++--- query-engine/driver-adapters/src/wasm/proxy.rs | 2 +- query-engine/driver-adapters/src/wasm/transaction.rs | 1 - 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/query-engine/driver-adapters/src/wasm/async_js_function.rs b/query-engine/driver-adapters/src/wasm/async_js_function.rs index 66cf2a39ae05..5ae0708054c6 100644 --- a/query-engine/driver-adapters/src/wasm/async_js_function.rs +++ b/query-engine/driver-adapters/src/wasm/async_js_function.rs @@ -1,9 +1,9 @@ -use js_sys::{Function as JsFunction, Object as JsObject, Promise as JsPromise}; +use js_sys::{Function as JsFunction, Promise as JsPromise}; use serde::{de::DeserializeOwned, Serialize}; use std::marker::PhantomData; -use wasm_bindgen::convert::{FromWasmAbi, WasmAbi}; +use wasm_bindgen::convert::FromWasmAbi; use wasm_bindgen::describe::WasmDescribe; -use wasm_bindgen::{prelude::wasm_bindgen, JsError, JsValue}; +use wasm_bindgen::{JsError, JsValue}; use wasm_bindgen_futures::JsFuture; use super::error::into_quaint_error; diff --git a/query-engine/driver-adapters/src/wasm/proxy.rs b/query-engine/driver-adapters/src/wasm/proxy.rs index 5d2b58d3e43d..e4efc059ce58 100644 --- a/query-engine/driver-adapters/src/wasm/proxy.rs +++ b/query-engine/driver-adapters/src/wasm/proxy.rs @@ -1,6 +1,6 @@ use ducktor::FromJsValue as DuckType; use futures::Future; -use js_sys::{Function as JsFunction, Object as JsObject, Promise as JsPromise}; +use js_sys::{Function as JsFunction, Object as JsObject}; use super::{async_js_function::AsyncJsFunction, send_future::SendFuture, transaction::JsTransaction}; pub use crate::types::{ColumnType, JSResultSet, Query, TransactionOptions}; diff --git a/query-engine/driver-adapters/src/wasm/transaction.rs b/query-engine/driver-adapters/src/wasm/transaction.rs index 21ed005c7d1e..e7aba9f418c4 100644 --- a/query-engine/driver-adapters/src/wasm/transaction.rs +++ b/query-engine/driver-adapters/src/wasm/transaction.rs @@ -1,5 +1,4 @@ use async_trait::async_trait; -use ducktor::FromJsValue as DuckType; use metrics::decrement_gauge; use quaint::{ connector::{IsolationLevel, Transaction as QuaintTransaction}, From 3c592e111027bc92b6e69d5ffd8d220b72ca4fa7 Mon 
Sep 17 00:00:00 2001 From: jkomyno Date: Mon, 20 Nov 2023 03:58:16 +0100 Subject: [PATCH 051/134] chore: add .cargo to .gitignore --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index 75c06e9ce68b..a4b51023344f 100644 --- a/.gitignore +++ b/.gitignore @@ -50,3 +50,6 @@ prisma-schema-wasm/nodejs # Ignore pnpm-lock.yaml query-engine/driver-adapters/pnpm-lock.yaml package-lock.json + +# Useful for local wasm32-* development +.cargo/ From 8529b8cb34f33c1f7f68701b145eedf37e09f70f Mon Sep 17 00:00:00 2001 From: jkomyno Date: Mon, 20 Nov 2023 04:14:02 +0100 Subject: [PATCH 052/134] feat(query-engine-wasm): ported some logic from query-engine-node-api in a wasm32-compatible fashion --- query-engine/query-engine-wasm/Cargo.toml | 4 +- query-engine/query-engine-wasm/build.sh | 12 +- query-engine/query-engine-wasm/src/wasm.rs | 3 + .../query-engine-wasm/src/wasm/engine.rs | 161 ++++++++++++++++-- .../query-engine-wasm/src/wasm/error.rs | 51 +++--- .../query-engine-wasm/src/wasm/functions.rs | 10 +- 6 files changed, 196 insertions(+), 45 deletions(-) diff --git a/query-engine/query-engine-wasm/Cargo.toml b/query-engine/query-engine-wasm/Cargo.toml index 9cd61041682b..a9759399783b 100644 --- a/query-engine/query-engine-wasm/Cargo.toml +++ b/query-engine/query-engine-wasm/Cargo.toml @@ -13,10 +13,10 @@ anyhow = "1" async-trait = "0.1" user-facing-errors = { path = "../../libs/user-facing-errors" } psl.workspace = true -prisma-models = { path = "../prisma-models" } quaint = { path = "../../quaint" } query-connector = { path = "../connectors/query-connector" } sql-connector = { path = "../connectors/sql-query-connector", package = "sql-query-connector" } + query-core = { path = "../core" } request-handlers = { path = "../request-handlers", default-features = false, features = ["sql", "driver-adapters"] } driver-adapters = { path = "../driver-adapters" } @@ -29,6 +29,8 @@ tsify.workspace = true wasm-bindgen.workspace = true wasm-bindgen-futures.workspace = true +prisma-models = { path = "../prisma-models" } + thiserror = "1" url = "2" serde.workspace = true diff --git a/query-engine/query-engine-wasm/build.sh b/query-engine/query-engine-wasm/build.sh index 12d8328305ff..dbbf6a720534 100755 --- a/query-engine/query-engine-wasm/build.sh +++ b/query-engine/query-engine-wasm/build.sh @@ -5,10 +5,16 @@ OUT_VERSION="$1" OUT_FOLDER="pkg" OUT_JSON="${OUT_FOLDER}/package.json" -OUT_TARGET="bundler" # Note(jkomyno): I wasn't able to make it work with `web` target +OUT_TARGET="bundler" OUT_NPM_NAME="@prisma/query-engine-wasm" +# The local ./Cargo.toml file uses "name = "query_engine_wasm" as library name +# to avoid conflicts with libquery's `name = "query_engine"` library name declaration. +# This little `sed -i` trick below is a hack to publish "@prisma/query-engine-wasm" +# with the same binding filenames currently expected by the Prisma Client. +sed -i '' 's/name = "query_engine_wasm"/name = "query_engine"/g' Cargo.toml wasm-pack build --release --target $OUT_TARGET +sed -i '' 's/name = "query_engine"/name = "query_engine_wasm"/g' Cargo.toml sleep 1 @@ -21,6 +27,10 @@ printf '%s\n' "$(jq --arg version "$OUT_VERSION" '. + {"version": $version}' $OU # Add the package name printf '%s\n' "$(jq --arg name "$OUT_NPM_NAME" '. + {"name": $name}' $OUT_JSON)" > $OUT_JSON +# Some info: enabling Cloudflare Workers in the bindings generated by wasm-package +# is useful for local experiments, but it's not needed here. 
+# `@prisma/client` has its own `esbuild` plugin for CF-compatible bindings +# and import of `.wasm` files. enable_cf_in_bindings() { # Enable Cloudflare Workers in the generated JS bindings. # The generated bindings are compatible with: diff --git a/query-engine/query-engine-wasm/src/wasm.rs b/query-engine/query-engine-wasm/src/wasm.rs index 4360e8a70471..14edeadf63b6 100644 --- a/query-engine/query-engine-wasm/src/wasm.rs +++ b/query-engine/query-engine-wasm/src/wasm.rs @@ -2,3 +2,6 @@ pub mod engine; pub mod error; pub mod functions; pub mod logger; + +pub(crate) type Result = std::result::Result; +pub(crate) type Executor = Box; diff --git a/query-engine/query-engine-wasm/src/wasm/engine.rs b/query-engine/query-engine-wasm/src/wasm/engine.rs index 20c67ee24bce..615339e64d29 100644 --- a/query-engine/query-engine-wasm/src/wasm/engine.rs +++ b/query-engine/query-engine-wasm/src/wasm/engine.rs @@ -5,17 +5,29 @@ use crate::{ error::ApiError, logger::{LogCallback, Logger}, }; +use futures::FutureExt; use js_sys::{Function as JsFunction, Object as JsObject}; +use query_core::{ + protocol::EngineProtocol, + schema::{self, QuerySchema}, + QueryExecutor, TransactionOptions, TxId, +}; use request_handlers::ConnectorMode; +use request_handlers::{dmmf, load_executor, render_graphql_schema, RequestBody, RequestHandler}; use serde::{Deserialize, Serialize}; +use serde_json::json; use std::{ collections::{BTreeMap, HashMap}, + future::Future, + panic::AssertUnwindSafe, path::PathBuf, sync::Arc, }; use tokio::sync::RwLock; +use tracing::{field, Instrument, Span}; use tracing_subscriber::filter::LevelFilter; use tsify::Tsify; +use user_facing_errors::Error; use wasm_bindgen::prelude::wasm_bindgen; /// The main query engine used by JS @@ -40,13 +52,17 @@ struct EngineBuilder { schema: Arc, config_dir: PathBuf, env: HashMap, + engine_protocol: EngineProtocol, } /// Internal structure for querying and reconnecting with the engine. struct ConnectedEngine { schema: Arc, + query_schema: Arc, + executor: crate::Executor, config_dir: PathBuf, env: HashMap, + engine_protocol: EngineProtocol, } /// Returned from the `serverInfo` method in javascript. @@ -58,6 +74,22 @@ struct ServerInfo { primary_connector: Option, } +impl ConnectedEngine { + /// The schema AST for Query Engine core. + pub fn query_schema(&self) -> &Arc { + &self.query_schema + } + + /// The query executor. + pub fn executor(&self) -> &(dyn QueryExecutor + Send + Sync) { + self.executor.as_ref() + } + + pub fn engine_protocol(&self) -> EngineProtocol { + self.engine_protocol + } +} + /// Parameters defining the construction of an engine. 
#[derive(Debug, Deserialize, Tsify)] #[tsify(from_wasm_abi)] @@ -75,7 +107,7 @@ pub struct ConstructorOptions { #[serde(default)] ignore_env_var_errors: bool, #[serde(default)] - engine_protocol: Option, + engine_protocol: Option, } impl Inner { @@ -154,9 +186,12 @@ impl QueryEngine { .validate_that_one_datasource_is_provided() .map_err(|errors| ApiError::conversion(errors, schema.db.source()))?; + let engine_protocol = engine_protocol.unwrap_or(EngineProtocol::Json); + let builder = EngineBuilder { schema: Arc::new(schema), config_dir, + engine_protocol, env, }; @@ -194,42 +229,122 @@ impl QueryEngine { trace: String, tx_id: Option, ) -> Result { - log::info!("Called `QueryEngine::query()`"); - Err(ApiError::configuration("Can't use `query` until `request_handlers` is Wasm-compatible.").into()) + async_panic_to_js_error(async { + let inner = self.inner.read().await; + let engine = inner.as_engine()?; + + let query = RequestBody::try_from_str(&body, engine.engine_protocol())?; + + async move { + let span = if tx_id.is_none() { + tracing::info_span!("prisma:engine", user_facing = true) + } else { + Span::none() + }; + + let handler = RequestHandler::new(engine.executor(), engine.query_schema(), engine.engine_protocol()); + let response = handler + .handle(query, tx_id.map(TxId::from), None) + .instrument(span) + .await; + + Ok(serde_json::to_string(&response)?) + } + .await + }) + .await } /// If connected, attempts to start a transaction in the core and returns its ID. #[wasm_bindgen(js_name = startTransaction)] pub async fn start_transaction(&self, input: String, trace: String) -> Result { - log::info!("Called `QueryEngine::start_transaction()`"); - Err(ApiError::configuration("Can't use `start_transaction` until `query_core` is Wasm-compatible.").into()) + async_panic_to_js_error(async { + let inner = self.inner.read().await; + let engine = inner.as_engine()?; + + async move { + let span = tracing::info_span!("prisma:engine:itx_runner", user_facing = true, itx_id = field::Empty); + + let tx_opts: TransactionOptions = serde_json::from_str(&input)?; + match engine + .executor() + .start_tx(engine.query_schema().clone(), engine.engine_protocol(), tx_opts) + .instrument(span) + .await + { + Ok(tx_id) => Ok(json!({ "id": tx_id.to_string() }).to_string()), + Err(err) => Ok(map_known_error(err)?), + } + } + .await + }) + .await } /// If connected, attempts to commit a transaction with id `tx_id` in the core. #[wasm_bindgen(js_name = commitTransaction)] pub async fn commit_transaction(&self, tx_id: String, trace: String) -> Result { - log::info!("Called `QueryEngine::commit_transaction()`"); - Err(ApiError::configuration("Can't use `commit_transaction` until `query_core` is Wasm-compatible.").into()) + async_panic_to_js_error(async { + let inner = self.inner.read().await; + let engine = inner.as_engine()?; + + async move { + match engine.executor().commit_tx(TxId::from(tx_id)).await { + Ok(_) => Ok("{}".to_string()), + Err(err) => Ok(map_known_error(err)?), + } + } + .await + }) + .await } #[wasm_bindgen] pub async fn dmmf(&self, trace: String) -> Result { - log::info!("Called `QueryEngine::dmmf()`"); - Err(ApiError::configuration("Can't use `dmmf` until `request_handlers` is Wasm-compatible.").into()) + async_panic_to_js_error(async { + let inner = self.inner.read().await; + let engine = inner.as_engine()?; + + let dmmf = dmmf::render_dmmf(&engine.query_schema); + + let json = { + let _span = tracing::info_span!("prisma:engine:dmmf_to_json").entered(); + serde_json::to_string(&dmmf)? 
+ }; + + Ok(json) + }) + .await } /// If connected, attempts to roll back a transaction with id `tx_id` in the core. #[wasm_bindgen(js_name = rollbackTransaction)] pub async fn rollback_transaction(&self, tx_id: String, trace: String) -> Result { - log::info!("Called `QueryEngine::rollback_transaction()`"); - Ok("{}".to_owned()) + async_panic_to_js_error(async { + let inner = self.inner.read().await; + let engine = inner.as_engine()?; + + async move { + match engine.executor().rollback_tx(TxId::from(tx_id)).await { + Ok(_) => Ok("{}".to_string()), + Err(err) => Ok(map_known_error(err)?), + } + } + .await + }) + .await } /// Loads the query schema. Only available when connected. #[wasm_bindgen(js_name = sdlSchema)] pub async fn sdl_schema(&self) -> Result { - log::info!("Called `QueryEngine::sdl_schema()`"); - Ok("{}".to_owned()) + async_panic_to_js_error(async move { + let inner = self.inner.read().await; + let engine = inner.as_engine()?; + + Ok(render_graphql_schema(engine.query_schema())) + }) + .await } #[wasm_bindgen] @@ -239,6 +354,13 @@ impl QueryEngine { } } +fn map_known_error(err: query_core::CoreError) -> crate::Result { + let user_error: user_facing_errors::Error = err.into(); + let value = serde_json::to_string(&user_error)?; + + Ok(value) +} + fn stringify_env_values(origin: serde_json::Value) -> crate::Result> { use serde_json::Value; @@ -269,3 +391,16 @@ fn stringify_env_values(origin: serde_json::Value) -> crate::Result(fut: F) -> Result +where + F: Future>, +{ + match AssertUnwindSafe(fut).catch_unwind().await { + Ok(result) => result, + Err(err) => match Error::extract_panic_message(err) { + Some(message) => Err(wasm_bindgen::JsError::new(&format!("PANIC: {message}"))), + None => Err(wasm_bindgen::JsError::new("PANIC: unknown panic")), + }, + } +} diff --git a/query-engine/query-engine-wasm/src/wasm/error.rs b/query-engine/query-engine-wasm/src/wasm/error.rs index 619e96564f6a..cfabc92ea0b0 100644 --- a/query-engine/query-engine-wasm/src/wasm/error.rs +++ b/query-engine/query-engine-wasm/src/wasm/error.rs @@ -1,6 +1,6 @@ use psl::diagnostics::Diagnostics; -// use query_connector::error::ConnectorError; -// use query_core::CoreError; +use query_connector::error::ConnectorError; +use query_core::CoreError; use thiserror::Error; #[derive(Debug, Error)] @@ -11,11 +11,12 @@ pub enum ApiError { #[error("{}", _0)] Configuration(String), - // #[error("{}", _0)] - // Core(CoreError), + #[error("{}", _0)] + Core(CoreError), + + #[error("{}", _0)] + Connector(ConnectorError), - // #[error("{}", _0)] - // Connector(ConnectorError), #[error("Can't modify an already connected engine.")] AlreadyConnected, @@ -31,10 +32,10 @@ impl From for user_facing_errors::Error { use std::fmt::Write as _; match err { - // ApiError::Connector(ConnectorError { - // user_facing_error: Some(err), - // .. - // }) => err.into(), + ApiError::Connector(ConnectorError { + user_facing_error: Some(err), + .. 
+ }) => err.into(), ApiError::Conversion(errors, dml_string) => { let mut full_error = errors.to_pretty_string("schema.prisma", &dml_string); write!(full_error, "\nValidation Error Count: {}", errors.errors().len()).unwrap(); @@ -43,7 +44,7 @@ impl From for user_facing_errors::Error { user_facing_errors::common::SchemaParserError { full_error }, )) } - // ApiError::Core(error) => user_facing_errors::Error::from(error), + ApiError::Core(error) => user_facing_errors::Error::from(error), other => user_facing_errors::Error::new_non_panic_with_current_backtrace(other.to_string()), } } @@ -59,20 +60,20 @@ impl ApiError { } } -// impl From for ApiError { -// fn from(e: CoreError) -> Self { -// match e { -// CoreError::ConfigurationError(message) => Self::Configuration(message), -// core_error => Self::Core(core_error), -// } -// } -// } - -// impl From for ApiError { -// fn from(e: ConnectorError) -> Self { -// Self::Connector(e) -// } -// } +impl From for ApiError { + fn from(e: CoreError) -> Self { + match e { + CoreError::ConfigurationError(message) => Self::Configuration(message), + core_error => Self::Core(core_error), + } + } +} + +impl From for ApiError { + fn from(e: ConnectorError) -> Self { + Self::Connector(e) + } +} impl From for ApiError { fn from(e: url::ParseError) -> Self { diff --git a/query-engine/query-engine-wasm/src/wasm/functions.rs b/query-engine/query-engine-wasm/src/wasm/functions.rs index e0f0a93aa5cd..9767b22fb811 100644 --- a/query-engine/query-engine-wasm/src/wasm/functions.rs +++ b/query-engine/query-engine-wasm/src/wasm/functions.rs @@ -1,5 +1,7 @@ use crate::error::ApiError; +use request_handlers::dmmf; use serde::Serialize; +use std::sync::Arc; use tsify::Tsify; use wasm_bindgen::prelude::wasm_bindgen; @@ -28,12 +30,10 @@ pub fn dmmf(datamodel_string: String) -> Result { .to_result() .map_err(|errors| ApiError::conversion(errors, schema.db.source()))?; - Ok("{}".to_string()) + let query_schema = query_core::schema::build(Arc::new(schema), true); + let dmmf = dmmf::render_dmmf(&query_schema); - // let query_schema = query_core::schema::build(Arc::new(schema), true); - // let dmmf = dmmf::render_dmmf(&query_schema); - - // Ok(serde_json::to_string(&dmmf)?) + Ok(serde_json::to_string(&dmmf)?) } #[wasm_bindgen] From d276f3df63a4f54d920634f7d95e27af51b5897c Mon Sep 17 00:00:00 2001 From: Sergey Tatarintsev Date: Mon, 20 Nov 2023 17:33:52 +0100 Subject: [PATCH 053/134] Add connect/disconnect --- .../query-engine-wasm/src/wasm/engine.rs | 100 +++++++++++++++++- 1 file changed, 97 insertions(+), 3 deletions(-) diff --git a/query-engine/query-engine-wasm/src/wasm/engine.rs b/query-engine/query-engine-wasm/src/wasm/engine.rs index 615339e64d29..d94d0cc371d0 100644 --- a/query-engine/query-engine-wasm/src/wasm/engine.rs +++ b/query-engine/query-engine-wasm/src/wasm/engine.rs @@ -210,15 +210,109 @@ impl QueryEngine { /// Connect to the database, allow queries to be run. 
#[wasm_bindgen] pub async fn connect(&self, trace: String) -> Result<(), wasm_bindgen::JsError> { - log::info!("Called `QueryEngine::connect()`"); + async_panic_to_js_error(async { + let span = tracing::info_span!("prisma:engine:connect"); + + let mut inner = self.inner.write().await; + let builder = inner.as_builder()?; + let arced_schema = Arc::clone(&builder.schema); + let arced_schema_2 = Arc::clone(&builder.schema); + + let url = { + let data_source = builder + .schema + .configuration + .datasources + .first() + .ok_or_else(|| ApiError::configuration("No valid data source found"))?; + data_source + .load_url_with_config_dir(&builder.config_dir, |key| builder.env.get(key).map(ToString::to_string)) + .map_err(|err| crate::error::ApiError::Conversion(err, builder.schema.db.source().to_owned()))? + }; + + let engine = async move { + // We only support one data source & generator at the moment, so take the first one (default not exposed yet). + let data_source = arced_schema + .configuration + .datasources + .first() + .ok_or_else(|| ApiError::configuration("No valid data source found"))?; + + let preview_features = arced_schema.configuration.preview_features(); + + let executor_fut = async { + let executor = load_executor(self.connector_mode, data_source, preview_features, &url).await?; + let connector = executor.primary_connector(); + + let conn_span = tracing::info_span!( + "prisma:engine:connection", + user_facing = true, + "db.type" = connector.name(), + ); + + connector.get_connection().instrument(conn_span).await?; + + crate::Result::<_>::Ok(executor) + }; + + let query_schema_span = tracing::info_span!("prisma:engine:schema"); + let query_schema_fut = tokio::runtime::Handle::current() + .spawn_blocking(move || { + let enable_raw_queries = true; + schema::build(arced_schema_2, enable_raw_queries) + }) + .instrument(query_schema_span); + + let (query_schema, executor) = tokio::join!(query_schema_fut, executor_fut); + + Ok(ConnectedEngine { + schema: builder.schema.clone(), + query_schema: Arc::new(query_schema.unwrap()), + executor: executor?, + config_dir: builder.config_dir.clone(), + env: builder.env.clone(), + engine_protocol: builder.engine_protocol, + }) as crate::Result + } + .instrument(span) + .await?; + + *inner = Inner::Connected(engine); + + Ok(()) + }) + .await?; + Ok(()) } /// Disconnect and drop the core. Can be reconnected later with `#connect`. #[wasm_bindgen] pub async fn disconnect(&self, trace: String) -> Result<(), wasm_bindgen::JsError> { - log::info!("Called `QueryEngine::disconnect()`"); - Ok(()) + async_panic_to_js_error(async { + let span = tracing::info_span!("prisma:engine:disconnect"); + + // TODO: when using Node Drivers, we need to call Driver::close() here. + + async { + let mut inner = self.inner.write().await; + let engine = inner.as_engine()?; + + let builder = EngineBuilder { + schema: engine.schema.clone(), + config_dir: engine.config_dir.clone(), + env: engine.env.clone(), + engine_protocol: engine.engine_protocol(), + }; + + *inner = Inner::Builder(builder); + + Ok(()) + } + .instrument(span) + .await + }) + .await } /// If connected, sends a query to the core and returns the response. 
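The connect/disconnect pair added in PATCH 053 above follows a two-state design: the engine starts out as a lightweight builder, `connect` promotes it into a connected engine that owns the executor and query schema, and `disconnect` demotes it back so it can be reconnected later. Below is a minimal sketch of that state machine, with simplified placeholder fields standing in for the real schema/executor/protocol state.

    use tokio::sync::RwLock;

    struct EngineBuilder {
        url: String,
    }

    struct ConnectedEngine {
        url: String,
        // the real engine also keeps the executor, the query schema and the protocol here
    }

    enum Inner {
        Builder(EngineBuilder),
        Connected(ConnectedEngine),
    }

    struct QueryEngine {
        inner: RwLock<Inner>,
    }

    impl QueryEngine {
        async fn connect(&self) -> Result<(), String> {
            let mut inner = self.inner.write().await;
            let next = match &*inner {
                // Promote the builder into a connected engine (executor setup omitted).
                Inner::Builder(builder) => Inner::Connected(ConnectedEngine {
                    url: builder.url.clone(),
                }),
                Inner::Connected(_) => return Err("already connected".into()),
            };
            *inner = next;
            Ok(())
        }

        async fn disconnect(&self) -> Result<(), String> {
            let mut inner = self.inner.write().await;
            let next = match &*inner {
                // Drop the connected state and fall back to a builder.
                Inner::Connected(engine) => Inner::Builder(EngineBuilder {
                    url: engine.url.clone(),
                }),
                Inner::Builder(_) => return Err("engine is not yet connected".into()),
            };
            *inner = next;
            Ok(())
        }
    }

Keeping both states behind a single `RwLock` is what lets `connect`, `disconnect` and `query` all take `&self`, which matches how the `#[wasm_bindgen]` methods above are exposed to JavaScript.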
From d0541140ac8ceae5973124fb1e84e872ab7b0de2 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Wed, 22 Nov 2023 01:12:00 +0100 Subject: [PATCH 054/134] fix: remove tokio-induced panic in "connect" --- .../query-engine-wasm/src/wasm/engine.rs | 20 ++++++++----------- 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/query-engine/query-engine-wasm/src/wasm/engine.rs b/query-engine/query-engine-wasm/src/wasm/engine.rs index d94d0cc371d0..c274f48a5c03 100644 --- a/query-engine/query-engine-wasm/src/wasm/engine.rs +++ b/query-engine/query-engine-wasm/src/wasm/engine.rs @@ -240,7 +240,7 @@ impl QueryEngine { let preview_features = arced_schema.configuration.preview_features(); - let executor_fut = async { + let executor = async { let executor = load_executor(self.connector_mode, data_source, preview_features, &url).await?; let connector = executor.primary_connector(); @@ -253,21 +253,17 @@ impl QueryEngine { connector.get_connection().instrument(conn_span).await?; crate::Result::<_>::Ok(executor) - }; - - let query_schema_span = tracing::info_span!("prisma:engine:schema"); - let query_schema_fut = tokio::runtime::Handle::current() - .spawn_blocking(move || { - let enable_raw_queries = true; - schema::build(arced_schema_2, enable_raw_queries) - }) - .instrument(query_schema_span); + } + .await; - let (query_schema, executor) = tokio::join!(query_schema_fut, executor_fut); + let query_schema = { + let enable_raw_queries = true; + schema::build(arced_schema_2, enable_raw_queries) + }; Ok(ConnectedEngine { schema: builder.schema.clone(), - query_schema: Arc::new(query_schema.unwrap()), + query_schema: Arc::new(query_schema), executor: executor?, config_dir: builder.config_dir.clone(), env: builder.env.clone(), From d6d09d6f47999af237eac47771d69db68d00e2ed Mon Sep 17 00:00:00 2001 From: jkomyno Date: Wed, 22 Nov 2023 02:08:07 +0100 Subject: [PATCH 055/134] feat: remove ducktor --- Cargo.lock | 12 ------ query-engine/driver-adapters/Cargo.toml | 1 - .../src/wasm/async_js_function.rs | 14 +++++++ .../src/wasm/js_object_extern.rs | 10 +++++ query-engine/driver-adapters/src/wasm/mod.rs | 3 ++ .../driver-adapters/src/wasm/proxy.rs | 39 +++++++++++++------ .../driver-adapters/src/wasm/queryable.rs | 11 +++--- .../query-engine-wasm/src/wasm/engine.rs | 5 ++- 8 files changed, 63 insertions(+), 32 deletions(-) create mode 100644 query-engine/driver-adapters/src/wasm/js_object_extern.rs diff --git a/Cargo.lock b/Cargo.lock index 7f79a3e8a5e7..ac4a15e64a61 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1097,7 +1097,6 @@ dependencies = [ "async-trait", "bigdecimal", "chrono", - "ducktor", "expect-test", "futures", "js-sys", @@ -1121,17 +1120,6 @@ dependencies = [ "wasm-bindgen-futures", ] -[[package]] -name = "ducktor" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c421abf6328bda65f53e6a76ee9837fd197b23bdfdbcebc4d7917dfaa1cf88ae" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.28", -] - [[package]] name = "either" version = "1.9.0" diff --git a/query-engine/driver-adapters/Cargo.toml b/query-engine/driver-adapters/Cargo.toml index 29697cdc95c4..8fa27edb5aa0 100644 --- a/query-engine/driver-adapters/Cargo.toml +++ b/query-engine/driver-adapters/Cargo.toml @@ -37,5 +37,4 @@ serde-wasm-bindgen.workspace = true wasm-bindgen.workspace = true wasm-bindgen-futures.workspace = true tsify.workspace = true -ducktor = "0.1.0" pin-project = "1" diff --git a/query-engine/driver-adapters/src/wasm/async_js_function.rs 
b/query-engine/driver-adapters/src/wasm/async_js_function.rs index 5ae0708054c6..29e168e9665e 100644 --- a/query-engine/driver-adapters/src/wasm/async_js_function.rs +++ b/query-engine/driver-adapters/src/wasm/async_js_function.rs @@ -22,6 +22,20 @@ where _phantom_return: PhantomData, } +impl From for AsyncJsFunction +where + T: Serialize, + R: DeserializeOwned, +{ + fn from(js_fn: JsFunction) -> Self { + Self { + threadsafe_fn: js_fn, + _phantom_arg: PhantomData:: {}, + _phantom_return: PhantomData:: {}, + } + } +} + impl AsyncJsFunction where T: Serialize, diff --git a/query-engine/driver-adapters/src/wasm/js_object_extern.rs b/query-engine/driver-adapters/src/wasm/js_object_extern.rs new file mode 100644 index 000000000000..8804e706f67f --- /dev/null +++ b/query-engine/driver-adapters/src/wasm/js_object_extern.rs @@ -0,0 +1,10 @@ +use js_sys::JsString; +use wasm_bindgen::{prelude::wasm_bindgen, JsValue}; + +#[wasm_bindgen] +extern "C" { + pub type JsObjectExtern; + + #[wasm_bindgen(method, catch, structural, indexing_getter)] + pub fn get(this: &JsObjectExtern, key: JsString) -> Result; +} diff --git a/query-engine/driver-adapters/src/wasm/mod.rs b/query-engine/driver-adapters/src/wasm/mod.rs index 8636204577c9..5f817569c31c 100644 --- a/query-engine/driver-adapters/src/wasm/mod.rs +++ b/query-engine/driver-adapters/src/wasm/mod.rs @@ -3,8 +3,11 @@ mod async_js_function; mod conversion; mod error; +mod js_object_extern; mod proxy; mod queryable; mod send_future; mod transaction; + +pub use js_object_extern::JsObjectExtern; pub use queryable::{from_wasm, JsQueryable}; diff --git a/query-engine/driver-adapters/src/wasm/proxy.rs b/query-engine/driver-adapters/src/wasm/proxy.rs index e4efc059ce58..75bc8f6347e2 100644 --- a/query-engine/driver-adapters/src/wasm/proxy.rs +++ b/query-engine/driver-adapters/src/wasm/proxy.rs @@ -1,11 +1,12 @@ -use ducktor::FromJsValue as DuckType; use futures::Future; -use js_sys::{Function as JsFunction, Object as JsObject}; +use js_sys::{Function as JsFunction, JsString, Object as JsObject}; +use tsify::Tsify; use super::{async_js_function::AsyncJsFunction, send_future::SendFuture, transaction::JsTransaction}; pub use crate::types::{ColumnType, JSResultSet, Query, TransactionOptions}; +use crate::JsObjectExtern; use metrics::increment_gauge; -use wasm_bindgen::{prelude::wasm_bindgen, JsValue}; +use wasm_bindgen::{prelude::wasm_bindgen, JsCast, JsValue}; type JsResult = core::result::Result; @@ -13,7 +14,7 @@ type JsResult = core::result::Result; /// querying and executing SQL (i.e. a client connector). The Proxy uses Wasm's JsFunction to /// invoke the code within the node runtime that implements the client connector. #[wasm_bindgen(getter_with_clone)] -#[derive(DuckType, Default)] +#[derive(Default)] pub(crate) struct CommonProxy { /// Execute a query given as SQL, interpolating the given parameters. 
query_raw: AsyncJsFunction, @@ -29,7 +30,6 @@ pub(crate) struct CommonProxy { /// This is a JS proxy for accessing the methods specific to top level /// JS driver objects #[wasm_bindgen(getter_with_clone)] -#[derive(DuckType)] pub(crate) struct DriverProxy { start_transaction: AsyncJsFunction<(), JsTransaction>, } @@ -37,7 +37,7 @@ pub(crate) struct DriverProxy { /// This a JS proxy for accessing the methods, specific /// to JS transaction objects #[wasm_bindgen(getter_with_clone)] -#[derive(DuckType, Default)] +#[derive(Default)] pub(crate) struct TransactionProxy { /// transaction options options: TransactionOptions, @@ -54,8 +54,14 @@ pub(crate) struct TransactionProxy { } impl CommonProxy { - pub fn new(object: &JsObject) -> Self { - CommonProxy::from(&object.into()) + pub fn new(object: &JsObjectExtern) -> JsResult { + let flavour: String = JsString::from(object.get("value".into())?).into(); + + Ok(Self { + query_raw: JsFunction::from(object.get("queryRaw".into())?).into(), + execute_raw: JsFunction::from(object.get("executeRaw".into())?).into(), + flavour, + }) } pub async fn query_raw(&self, params: Query) -> quaint::Result { @@ -68,8 +74,10 @@ impl CommonProxy { } impl DriverProxy { - pub fn new(object: &JsObject) -> Self { - Self::from(&object.into()) + pub fn new(object: &JsObjectExtern) -> JsResult { + Ok(Self { + start_transaction: JsFunction::from(object.get("startTransaction".into())?).into(), + }) } async fn start_transaction_inner(&self) -> quaint::Result> { @@ -91,8 +99,15 @@ impl DriverProxy { } impl TransactionProxy { - pub fn new(object: &JsObject) -> Self { - Self::from(&object.into()) + pub fn new(object: &JsObjectExtern) -> JsResult { + let options = object.get("options".into())?; + + Ok(Self { + options: TransactionOptions::from_js(options).unwrap(), + commit: JsFunction::from(object.get("commit".into())?).into(), + rollback: JsFunction::from(object.get("dispose".into())?).into(), + dispose: object.get("dispose".into())?.into(), + }) } pub fn options(&self) -> &TransactionOptions { diff --git a/query-engine/driver-adapters/src/wasm/queryable.rs b/query-engine/driver-adapters/src/wasm/queryable.rs index abb828a443c0..edb0de4ea493 100644 --- a/query-engine/driver-adapters/src/wasm/queryable.rs +++ b/query-engine/driver-adapters/src/wasm/queryable.rs @@ -1,10 +1,11 @@ +use crate::JsObjectExtern; + use super::{ conversion, proxy::{CommonProxy, DriverProxy, Query}, send_future::SendFuture, }; use async_trait::async_trait; -use ducktor::FromJsValue as DuckType; use futures::Future; use js_sys::Object as JsObject; use psl::datamodel_connector::Flavour; @@ -30,7 +31,7 @@ use wasm_bindgen::prelude::wasm_bindgen; /// into a `quaint::connector::result_set::ResultSet`. A quaint `ResultSet` is basically a vector /// of `quaint::Value` but said type is a tagged enum, with non-unit variants that cannot be converted to javascript as is. 
#[wasm_bindgen(getter_with_clone)] -#[derive(DuckType, Default)] +#[derive(Default)] pub(crate) struct JsBaseQueryable { pub(crate) proxy: CommonProxy, pub flavour: Flavour, @@ -313,9 +314,9 @@ impl TransactionCapable for JsQueryable { } } -pub fn from_wasm(driver: JsObject) -> JsQueryable { - let common = CommonProxy::new(&driver); - let driver_proxy = DriverProxy::new(&driver); +pub fn from_wasm(driver: JsObjectExtern) -> JsQueryable { + let common = CommonProxy::new(&driver).unwrap(); + let driver_proxy = DriverProxy::new(&driver).unwrap(); JsQueryable { inner: JsBaseQueryable::new(common), diff --git a/query-engine/query-engine-wasm/src/wasm/engine.rs b/query-engine/query-engine-wasm/src/wasm/engine.rs index 20c67ee24bce..7bf67c53c2e7 100644 --- a/query-engine/query-engine-wasm/src/wasm/engine.rs +++ b/query-engine/query-engine-wasm/src/wasm/engine.rs @@ -5,7 +5,8 @@ use crate::{ error::ApiError, logger::{LogCallback, Logger}, }; -use js_sys::{Function as JsFunction, Object as JsObject}; +use driver_adapters::JsObjectExtern; +use js_sys::Function as JsFunction; use request_handlers::ConnectorMode; use serde::{Deserialize, Serialize}; use std::{ @@ -103,7 +104,7 @@ impl QueryEngine { pub fn new( options: ConstructorOptions, callback: JsFunction, - maybe_adapter: Option, + maybe_adapter: Option, ) -> Result { log::info!("Called `QueryEngine::new()`"); From 2347cb13044d04d30f40b74b1def849690a3f604 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Wed, 22 Nov 2023 02:46:55 +0100 Subject: [PATCH 056/134] feat(driver-adapters): remove "queryable" into its own module --- query-engine/driver-adapters/Cargo.toml | 2 +- query-engine/driver-adapters/src/lib.rs | 2 + .../driver-adapters/src/napi/conversion.rs | 2 +- query-engine/driver-adapters/src/napi/mod.rs | 6 +- .../driver-adapters/src/napi/transaction.rs | 6 +- .../driver-adapters/src/queryable/mod.rs | 310 ++++++++++++++++++ .../driver-adapters/src/queryable/napi.rs | 32 ++ .../driver-adapters/src/queryable/wasm.rs | 33 ++ .../src/{wasm => }/send_future.rs | 0 query-engine/driver-adapters/src/wasm/mod.rs | 6 +- .../driver-adapters/src/wasm/proxy.rs | 7 +- .../driver-adapters/src/wasm/transaction.rs | 8 +- 12 files changed, 392 insertions(+), 22 deletions(-) create mode 100644 query-engine/driver-adapters/src/queryable/mod.rs create mode 100644 query-engine/driver-adapters/src/queryable/napi.rs create mode 100644 query-engine/driver-adapters/src/queryable/wasm.rs rename query-engine/driver-adapters/src/{wasm => }/send_future.rs (100%) diff --git a/query-engine/driver-adapters/Cargo.toml b/query-engine/driver-adapters/Cargo.toml index 8fa27edb5aa0..ec77df85e142 100644 --- a/query-engine/driver-adapters/Cargo.toml +++ b/query-engine/driver-adapters/Cargo.toml @@ -13,6 +13,7 @@ tracing = "0.1" tracing-core = "0.1" metrics = "0.18" uuid = { version = "1", features = ["v4"] } +pin-project = "1" # Note: these deps are temporarily specified here to avoid importing them from tiberius (the SQL server driver). # They will be imported from quaint-core instead in a future PR. 
@@ -37,4 +38,3 @@ serde-wasm-bindgen.workspace = true wasm-bindgen.workspace = true wasm-bindgen-futures.workspace = true tsify.workspace = true -pin-project = "1" diff --git a/query-engine/driver-adapters/src/lib.rs b/query-engine/driver-adapters/src/lib.rs index ca8aa4541bd1..0e40b814c43f 100644 --- a/query-engine/driver-adapters/src/lib.rs +++ b/query-engine/driver-adapters/src/lib.rs @@ -9,6 +9,8 @@ pub(crate) mod conversion; pub(crate) mod error; +pub(crate) mod queryable; +pub(crate) mod send_future; pub(crate) mod types; #[cfg(not(target_arch = "wasm32"))] diff --git a/query-engine/driver-adapters/src/napi/conversion.rs b/query-engine/driver-adapters/src/napi/conversion.rs index 5ab630998d27..ac2dda60a279 100644 --- a/query-engine/driver-adapters/src/napi/conversion.rs +++ b/query-engine/driver-adapters/src/napi/conversion.rs @@ -1,4 +1,4 @@ -pub(crate) use crate::conversion::{mysql, postgres, sqlite, JSArg}; +pub(crate) use crate::conversion::JSArg; use napi::bindgen_prelude::{FromNapiValue, ToNapiValue}; use napi::NapiValue; diff --git a/query-engine/driver-adapters/src/napi/mod.rs b/query-engine/driver-adapters/src/napi/mod.rs index 05267dec453b..4612cb550553 100644 --- a/query-engine/driver-adapters/src/napi/mod.rs +++ b/query-engine/driver-adapters/src/napi/mod.rs @@ -3,8 +3,8 @@ mod async_js_function; mod conversion; mod error; -mod proxy; -mod queryable; +pub(crate) mod proxy; mod result; mod transaction; -pub use queryable::{from_napi, JsQueryable}; + +pub use crate::queryable::{from_napi, JsQueryable}; diff --git a/query-engine/driver-adapters/src/napi/transaction.rs b/query-engine/driver-adapters/src/napi/transaction.rs index 16ecbb435ce9..69219d06ef1e 100644 --- a/query-engine/driver-adapters/src/napi/transaction.rs +++ b/query-engine/driver-adapters/src/napi/transaction.rs @@ -7,10 +7,8 @@ use quaint::{ Value, }; -use super::{ - proxy::{CommonProxy, TransactionOptions, TransactionProxy}, - queryable::JsBaseQueryable, -}; +use super::proxy::{CommonProxy, TransactionOptions, TransactionProxy}; +use crate::queryable::JsBaseQueryable; // Wrapper around JS transaction objects that implements Queryable // and quaint::Transaction. 
Can be used in place of quaint transaction, diff --git a/query-engine/driver-adapters/src/queryable/mod.rs b/query-engine/driver-adapters/src/queryable/mod.rs new file mode 100644 index 000000000000..ac252bbb011b --- /dev/null +++ b/query-engine/driver-adapters/src/queryable/mod.rs @@ -0,0 +1,310 @@ +#[cfg(not(target_arch = "wasm32"))] +pub(crate) mod napi; + +#[cfg(not(target_arch = "wasm32"))] +pub use napi::from_napi; + +#[cfg(not(target_arch = "wasm32"))] +pub(crate) use napi::JsBaseQueryable; + +#[cfg(target_arch = "wasm32")] +pub(crate) mod wasm; + +#[cfg(target_arch = "wasm32")] +pub use wasm::from_wasm; + +#[cfg(target_arch = "wasm32")] +pub(crate) use wasm::JsBaseQueryable; + +use super::{ + conversion, + proxy::{CommonProxy, DriverProxy, Query}, +}; +use crate::send_future::SendFuture; +use async_trait::async_trait; +use futures::Future; +use psl::datamodel_connector::Flavour; +use quaint::{ + connector::{metrics, IsolationLevel, Transaction}, + error::{Error, ErrorKind}, + prelude::{Query as QuaintQuery, Queryable as QuaintQueryable, ResultSet, TransactionCapable}, + visitor::{self, Visitor}, +}; +use tracing::{info_span, Instrument}; + +impl JsBaseQueryable { + pub(crate) fn new(proxy: CommonProxy) -> Self { + let flavour: Flavour = proxy.flavour.parse().unwrap(); + Self { proxy, flavour } + } + + /// visit a quaint query AST according to the flavour of the JS connector + fn visit_quaint_query<'a>(&self, q: QuaintQuery<'a>) -> quaint::Result<(String, Vec>)> { + match self.flavour { + Flavour::Mysql => visitor::Mysql::build(q), + Flavour::Postgres => visitor::Postgres::build(q), + Flavour::Sqlite => visitor::Sqlite::build(q), + _ => unimplemented!("Unsupported flavour for JS connector {:?}", self.flavour), + } + } + + async fn build_query(&self, sql: &str, values: &[quaint::Value<'_>]) -> quaint::Result { + let sql: String = sql.to_string(); + + let converter = match self.flavour { + Flavour::Postgres => conversion::postgres::value_to_js_arg, + Flavour::Sqlite => conversion::sqlite::value_to_js_arg, + Flavour::Mysql => conversion::mysql::value_to_js_arg, + _ => unreachable!("Unsupported flavour for JS connector {:?}", self.flavour), + }; + + let args = values + .iter() + .map(converter) + .collect::>>()?; + + Ok(Query { sql, args }) + } +} + +#[async_trait] +impl QuaintQueryable for JsBaseQueryable { + async fn query(&self, q: QuaintQuery<'_>) -> quaint::Result { + let (sql, params) = self.visit_quaint_query(q)?; + self.query_raw(&sql, ¶ms).await + } + + async fn query_raw(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { + metrics::query("js.query_raw", sql, params, move || async move { + self.do_query_raw(sql, params).await + }) + .await + } + + async fn query_raw_typed(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { + self.query_raw(sql, params).await + } + + async fn execute(&self, q: QuaintQuery<'_>) -> quaint::Result { + let (sql, params) = self.visit_quaint_query(q)?; + self.execute_raw(&sql, ¶ms).await + } + + async fn execute_raw(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { + metrics::query("js.execute_raw", sql, params, move || async move { + self.do_execute_raw(sql, params).await + }) + .await + } + + async fn execute_raw_typed(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { + self.execute_raw(sql, params).await + } + + async fn raw_cmd(&self, cmd: &str) -> quaint::Result<()> { + let params = &[]; + metrics::query("js.raw_cmd", cmd, params, move || async move { + self.do_execute_raw(cmd, 
params).await?; + Ok(()) + }) + .await + } + + async fn version(&self) -> quaint::Result> { + // Note: JS Connectors don't use this method. + Ok(None) + } + + fn is_healthy(&self) -> bool { + // Note: JS Connectors don't use this method. + true + } + + /// Sets the transaction isolation level to given value. + /// Implementers have to make sure that the passed isolation level is valid for the underlying database. + async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> quaint::Result<()> { + if matches!(isolation_level, IsolationLevel::Snapshot) { + return Err(Error::builder(ErrorKind::invalid_isolation_level(&isolation_level)).build()); + } + + if self.flavour == Flavour::Sqlite { + return match isolation_level { + IsolationLevel::Serializable => Ok(()), + _ => Err(Error::builder(ErrorKind::invalid_isolation_level(&isolation_level)).build()), + }; + } + + self.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation_level}")) + .await + } + + fn requires_isolation_first(&self) -> bool { + match self.flavour { + Flavour::Mysql => true, + Flavour::Postgres | Flavour::Sqlite => false, + _ => unreachable!(), + } + } +} + +impl JsBaseQueryable { + pub fn phantom_query_message(stmt: &str) -> String { + format!(r#"-- Implicit "{}" query via underlying driver"#, stmt) + } + + async fn do_query_raw_inner(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { + let len = params.len(); + let serialization_span = info_span!("js:query:args", user_facing = true, "length" = %len); + let query = self.build_query(sql, params).instrument(serialization_span).await?; + + let sql_span = info_span!("js:query:sql", user_facing = true, "db.statement" = %sql); + let result_set = self.proxy.query_raw(query).instrument(sql_span).await?; + + let len = result_set.len(); + let _deserialization_span = info_span!("js:query:result", user_facing = true, "length" = %len).entered(); + + result_set.try_into() + } + + fn do_query_raw<'a>( + &'a self, + sql: &'a str, + params: &'a [quaint::Value<'a>], + ) -> SendFuture> + 'a> { + SendFuture(self.do_query_raw_inner(sql, params)) + } + + async fn do_execute_raw_inner(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { + let len = params.len(); + let serialization_span = info_span!("js:query:args", user_facing = true, "length" = %len); + let query = self.build_query(sql, params).instrument(serialization_span).await?; + + let sql_span = info_span!("js:query:sql", user_facing = true, "db.statement" = %sql); + let affected_rows = self.proxy.execute_raw(query).instrument(sql_span).await?; + + Ok(affected_rows as u64) + } + + fn do_execute_raw<'a>( + &'a self, + sql: &'a str, + params: &'a [quaint::Value<'a>], + ) -> SendFuture> + 'a> { + SendFuture(self.do_execute_raw_inner(sql, params)) + } +} + +/// A JsQueryable adapts a Proxy to implement quaint's Queryable interface. It has the +/// responsibility of transforming inputs and outputs of `query` and `execute` methods from quaint +/// types to types that can be translated into javascript and viceversa. This is to let the rest of +/// the query engine work as if it was using quaint itself. The aforementioned transformations are: +/// +/// Transforming a `quaint::ast::Query` into SQL by visiting it for the specific flavour of SQL +/// expected by the client connector. (eg. 
using the mysql visitor for the Planetscale client +/// connector) +/// +/// Transforming a `JSResultSet` (what client connectors implemented in javascript provide) +/// into a `quaint::connector::result_set::ResultSet`. A quaint `ResultSet` is basically a vector +/// of `quaint::Value` but said type is a tagged enum, with non-unit variants that cannot be converted to javascript as is. +/// +pub struct JsQueryable { + inner: JsBaseQueryable, + driver_proxy: DriverProxy, +} + +impl std::fmt::Display for JsQueryable { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "JSQueryable(driver)") + } +} + +impl std::fmt::Debug for JsQueryable { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "JSQueryable(driver)") + } +} + +#[async_trait] +impl QuaintQueryable for JsQueryable { + async fn query(&self, q: QuaintQuery<'_>) -> quaint::Result { + self.inner.query(q).await + } + + async fn query_raw(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { + self.inner.query_raw(sql, params).await + } + + async fn query_raw_typed(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { + self.inner.query_raw_typed(sql, params).await + } + + async fn execute(&self, q: QuaintQuery<'_>) -> quaint::Result { + self.inner.execute(q).await + } + + async fn execute_raw(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { + self.inner.execute_raw(sql, params).await + } + + async fn execute_raw_typed(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { + self.inner.execute_raw_typed(sql, params).await + } + + async fn raw_cmd(&self, cmd: &str) -> quaint::Result<()> { + self.inner.raw_cmd(cmd).await + } + + async fn version(&self) -> quaint::Result> { + self.inner.version().await + } + + fn is_healthy(&self) -> bool { + self.inner.is_healthy() + } + + async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> quaint::Result<()> { + self.inner.set_tx_isolation_level(isolation_level).await + } + + fn requires_isolation_first(&self) -> bool { + self.inner.requires_isolation_first() + } +} + +#[async_trait] +impl TransactionCapable for JsQueryable { + async fn start_transaction<'a>( + &'a self, + isolation: Option, + ) -> quaint::Result> { + let tx = self.driver_proxy.start_transaction().await?; + + let isolation_first = tx.requires_isolation_first(); + + if isolation_first { + if let Some(isolation) = isolation { + tx.set_tx_isolation_level(isolation).await?; + } + } + + let begin_stmt = tx.begin_statement(); + + let tx_opts = tx.options(); + if tx_opts.use_phantom_query { + let begin_stmt = JsBaseQueryable::phantom_query_message(begin_stmt); + tx.raw_phantom_cmd(begin_stmt.as_str()).await?; + } else { + tx.raw_cmd(begin_stmt).await?; + } + + if !isolation_first { + if let Some(isolation) = isolation { + tx.set_tx_isolation_level(isolation).await?; + } + } + + self.server_reset_query(tx.as_ref()).await?; + + Ok(tx) + } +} diff --git a/query-engine/driver-adapters/src/queryable/napi.rs b/query-engine/driver-adapters/src/queryable/napi.rs new file mode 100644 index 000000000000..2245802908c2 --- /dev/null +++ b/query-engine/driver-adapters/src/queryable/napi.rs @@ -0,0 +1,32 @@ +use crate::napi::proxy::{CommonProxy, DriverProxy}; +use crate::JsQueryable; +use napi::JsObject; +use psl::datamodel_connector::Flavour; + +/// A JsQueryable adapts a Proxy to implement quaint's Queryable interface. 
It has the +/// responsibility of transforming inputs and outputs of `query` and `execute` methods from quaint +/// types to types that can be translated into javascript and viceversa. This is to let the rest of +/// the query engine work as if it was using quaint itself. The aforementioned transformations are: +/// +/// Transforming a `quaint::ast::Query` into SQL by visiting it for the specific flavour of SQL +/// expected by the client connector. (eg. using the mysql visitor for the Planetscale client +/// connector) +/// +/// Transforming a `JSResultSet` (what client connectors implemented in javascript provide) +/// into a `quaint::connector::result_set::ResultSet`. A quaint `ResultSet` is basically a vector +/// of `quaint::Value` but said type is a tagged enum, with non-unit variants that cannot be converted to javascript as is. +/// +pub(crate) struct JsBaseQueryable { + pub(crate) proxy: CommonProxy, + pub flavour: Flavour, +} + +pub fn from_napi(driver: JsObject) -> JsQueryable { + let common = CommonProxy::new(&driver).unwrap(); + let driver_proxy = DriverProxy::new(&driver).unwrap(); + + JsQueryable { + inner: JsBaseQueryable::new(common), + driver_proxy, + } +} diff --git a/query-engine/driver-adapters/src/queryable/wasm.rs b/query-engine/driver-adapters/src/queryable/wasm.rs new file mode 100644 index 000000000000..867d1fb5081a --- /dev/null +++ b/query-engine/driver-adapters/src/queryable/wasm.rs @@ -0,0 +1,33 @@ +use crate::wasm::proxy::{CommonProxy, DriverProxy}; +use crate::{JsObjectExtern, JsQueryable}; +use psl::datamodel_connector::Flavour; +use wasm_bindgen::prelude::wasm_bindgen; + +/// A JsQueryable adapts a Proxy to implement quaint's Queryable interface. It has the +/// responsibility of transforming inputs and outputs of `query` and `execute` methods from quaint +/// types to types that can be translated into javascript and viceversa. This is to let the rest of +/// the query engine work as if it was using quaint itself. The aforementioned transformations are: +/// +/// Transforming a `quaint::ast::Query` into SQL by visiting it for the specific flavour of SQL +/// expected by the client connector. (eg. using the mysql visitor for the Planetscale client +/// connector) +/// +/// Transforming a `JSResultSet` (what client connectors implemented in javascript provide) +/// into a `quaint::connector::result_set::ResultSet`. A quaint `ResultSet` is basically a vector +/// of `quaint::Value` but said type is a tagged enum, with non-unit variants that cannot be converted to javascript as is. 
+#[wasm_bindgen(getter_with_clone)] +#[derive(Default)] +pub(crate) struct JsBaseQueryable { + pub(crate) proxy: CommonProxy, + pub flavour: Flavour, +} + +pub fn from_wasm(driver: JsObjectExtern) -> JsQueryable { + let common = CommonProxy::new(&driver).unwrap(); + let driver_proxy = DriverProxy::new(&driver).unwrap(); + + JsQueryable { + inner: JsBaseQueryable::new(common), + driver_proxy, + } +} diff --git a/query-engine/driver-adapters/src/wasm/send_future.rs b/query-engine/driver-adapters/src/send_future.rs similarity index 100% rename from query-engine/driver-adapters/src/wasm/send_future.rs rename to query-engine/driver-adapters/src/send_future.rs diff --git a/query-engine/driver-adapters/src/wasm/mod.rs b/query-engine/driver-adapters/src/wasm/mod.rs index 5f817569c31c..416b9db4a787 100644 --- a/query-engine/driver-adapters/src/wasm/mod.rs +++ b/query-engine/driver-adapters/src/wasm/mod.rs @@ -4,10 +4,8 @@ mod async_js_function; mod conversion; mod error; mod js_object_extern; -mod proxy; -mod queryable; -mod send_future; +pub(crate) mod proxy; mod transaction; +pub use crate::queryable::{from_wasm, JsQueryable}; pub use js_object_extern::JsObjectExtern; -pub use queryable::{from_wasm, JsQueryable}; diff --git a/query-engine/driver-adapters/src/wasm/proxy.rs b/query-engine/driver-adapters/src/wasm/proxy.rs index 75bc8f6347e2..7ab578830f42 100644 --- a/query-engine/driver-adapters/src/wasm/proxy.rs +++ b/query-engine/driver-adapters/src/wasm/proxy.rs @@ -1,12 +1,13 @@ use futures::Future; -use js_sys::{Function as JsFunction, JsString, Object as JsObject}; +use js_sys::{Function as JsFunction, JsString}; use tsify::Tsify; -use super::{async_js_function::AsyncJsFunction, send_future::SendFuture, transaction::JsTransaction}; +use super::{async_js_function::AsyncJsFunction, transaction::JsTransaction}; +use crate::send_future::SendFuture; pub use crate::types::{ColumnType, JSResultSet, Query, TransactionOptions}; use crate::JsObjectExtern; use metrics::increment_gauge; -use wasm_bindgen::{prelude::wasm_bindgen, JsCast, JsValue}; +use wasm_bindgen::{prelude::wasm_bindgen, JsValue}; type JsResult = core::result::Result; diff --git a/query-engine/driver-adapters/src/wasm/transaction.rs b/query-engine/driver-adapters/src/wasm/transaction.rs index e7aba9f418c4..43925b488101 100644 --- a/query-engine/driver-adapters/src/wasm/transaction.rs +++ b/query-engine/driver-adapters/src/wasm/transaction.rs @@ -6,13 +6,9 @@ use quaint::{ Value, }; use serde::Deserialize; -use wasm_bindgen::prelude::wasm_bindgen; -use super::{ - proxy::{CommonProxy, TransactionOptions, TransactionProxy}, - queryable::JsBaseQueryable, - send_future::SendFuture, -}; +use super::proxy::{TransactionOptions, TransactionProxy}; +use crate::{queryable::JsBaseQueryable, send_future::SendFuture}; // Wrapper around JS transaction objects that implements Queryable // and quaint::Transaction. Can be used in place of quaint transaction, From a004fe64e69f80e27322c25a35ffd4e296f4b772 Mon Sep 17 00:00:00 2001 From: Sergey Tatarintsev Date: Wed, 22 Nov 2023 09:16:53 +0100 Subject: [PATCH 057/134] Couple of fixes 1. Build schema and connect sequentually 2. Print full stacktrace for WASM error 3. 
Expand example to attempt a query --- .../driver-adapters/src/wasm/error.rs | 10 +-- query-engine/query-engine-wasm/example.js | 17 ++++- .../query-engine-wasm/package-lock.json | 22 +++--- .../query-engine-wasm/src/wasm/engine.rs | 69 +++++++++---------- 4 files changed, 65 insertions(+), 53 deletions(-) diff --git a/query-engine/driver-adapters/src/wasm/error.rs b/query-engine/driver-adapters/src/wasm/error.rs index e0b588794302..0aa4fe7981f2 100644 --- a/query-engine/driver-adapters/src/wasm/error.rs +++ b/query-engine/driver-adapters/src/wasm/error.rs @@ -1,11 +1,13 @@ +use js_sys::Reflect; use quaint::error::Error as QuaintError; use wasm_bindgen::JsValue; -type WasmError = JsValue; - /// transforms a Wasm error into a Quaint error -pub(crate) fn into_quaint_error(wasm_err: WasmError) -> QuaintError { +pub(crate) fn into_quaint_error(wasm_err: JsValue) -> QuaintError { let status = "WASM_ERROR".to_string(); - let reason = wasm_err.as_string().unwrap_or_else(|| "unknown error".to_string()); + let reason = Reflect::get(&wasm_err, &JsValue::from_str("stack")) + .ok() + .and_then(|value| value.as_string()) + .unwrap_or_else(|| "Unknown error".to_string()); QuaintError::raw_connector_error(status, reason) } diff --git a/query-engine/query-engine-wasm/example.js b/query-engine/query-engine-wasm/example.js index 6d3a78374bc8..1fa929e80b2f 100644 --- a/query-engine/query-engine-wasm/example.js +++ b/query-engine/query-engine-wasm/example.js @@ -6,7 +6,7 @@ import { Pool } from '@neondatabase/serverless' import { PrismaNeon } from '@prisma/adapter-neon' import { bindAdapter } from '@prisma/driver-adapter-utils' -import { init, QueryEngine, getBuildTimeInfo } from './pkg/query_engine_wasm.js' +import { init, QueryEngine, getBuildTimeInfo } from './pkg/query_engine.js' async function main() { // Always initialize the Wasm library before using it. 
@@ -48,7 +48,22 @@ async function main() { const queryEngine = new QueryEngine(options, callback, driverAdapter) await queryEngine.connect('trace') + const res = await queryEngine.query(JSON.stringify({ + modelName: 'User', + action: 'findMany', + query: { + arguments: {}, + selection: { + $scalars: true + } + } + }), 'trace') + const parsed = JSON.parse(res); + console.log('query result = ', parsed) await queryEngine.disconnect('trace') + console.log('after disconnect') + queryEngine.free() + await driverAdapter.close() } main() diff --git a/query-engine/query-engine-wasm/package-lock.json b/query-engine/query-engine-wasm/package-lock.json index c2d5a7a1162e..86b53cb0cde2 100644 --- a/query-engine/query-engine-wasm/package-lock.json +++ b/query-engine/query-engine-wasm/package-lock.json @@ -6,8 +6,8 @@ "": { "dependencies": { "@neondatabase/serverless": "0.6.0", - "@prisma/adapter-neon": "5.5.2", - "@prisma/driver-adapter-utils": "5.5.2" + "@prisma/adapter-neon": "5.6.0", + "@prisma/driver-adapter-utils": "5.6.0" } }, "node_modules/@neondatabase/serverless": { @@ -19,12 +19,12 @@ } }, "node_modules/@prisma/adapter-neon": { - "version": "5.5.2", - "resolved": "https://registry.npmjs.org/@prisma/adapter-neon/-/adapter-neon-5.5.2.tgz", - "integrity": "sha512-XcpJ/fgh/sP7mlBFkqjIzEcU/kWnNyiZf19MBP366HF7vXg2UQTbGxmbbeFiohXSJ/rwyu1Qmos7IrKK+QJOgg==", + "version": "5.6.0", + "resolved": "https://registry.npmjs.org/@prisma/adapter-neon/-/adapter-neon-5.6.0.tgz", + "integrity": "sha512-IUkIE5NKyP2wCXMMAByM78fizfaJl7YeWDEajvyqQafXgRwmxl+2HhxsevvHly8jT4RlELdhjK6IP1eciGvXVA==", "dependencies": { - "@prisma/driver-adapter-utils": "5.5.2", - "postgres-array": "^3.0.2" + "@prisma/driver-adapter-utils": "5.6.0", + "postgres-array": "3.0.2" }, "peerDependencies": { "@neondatabase/serverless": "^0.6.0" @@ -39,11 +39,11 @@ } }, "node_modules/@prisma/driver-adapter-utils": { - "version": "5.5.2", - "resolved": "https://registry.npmjs.org/@prisma/driver-adapter-utils/-/driver-adapter-utils-5.5.2.tgz", - "integrity": "sha512-lRkxjboGcIl2VkJNomZQ9b6vc2qGFnVwjaR/o3cTPGmmSxETx71cYRYcG/NHKrhvKxI6oKNZ/xzyuzPpg1+kJQ==", + "version": "5.6.0", + "resolved": "https://registry.npmjs.org/@prisma/driver-adapter-utils/-/driver-adapter-utils-5.6.0.tgz", + "integrity": "sha512-/TSrfCGLAQghNf+bwg5/e8iKAgecCYU/gMN0IyNra3183/VTQJneLFgbacuSK9bBXiIRUmpbuUIrJ6dhENzfjA==", "dependencies": { - "debug": "^4.3.4" + "debug": "4.3.4" } }, "node_modules/@types/node": { diff --git a/query-engine/query-engine-wasm/src/wasm/engine.rs b/query-engine/query-engine-wasm/src/wasm/engine.rs index c274f48a5c03..caadc6efdc17 100644 --- a/query-engine/query-engine-wasm/src/wasm/engine.rs +++ b/query-engine/query-engine-wasm/src/wasm/engine.rs @@ -240,31 +240,24 @@ impl QueryEngine { let preview_features = arced_schema.configuration.preview_features(); - let executor = async { - let executor = load_executor(self.connector_mode, data_source, preview_features, &url).await?; - let connector = executor.primary_connector(); + let executor = load_executor(self.connector_mode, data_source, preview_features, &url).await?; + let connector = executor.primary_connector(); - let conn_span = tracing::info_span!( - "prisma:engine:connection", - user_facing = true, - "db.type" = connector.name(), - ); + let conn_span = tracing::info_span!( + "prisma:engine:connection", + user_facing = true, + "db.type" = connector.name(), + ); - connector.get_connection().instrument(conn_span).await?; + connector.get_connection().instrument(conn_span).await?; - 
crate::Result::<_>::Ok(executor)
- }
- .await;
-
- let query_schema = {
- let enable_raw_queries = true;
- schema::build(arced_schema_2, enable_raw_queries)
- };
+ let query_schema_span = tracing::info_span!("prisma:engine:schema");
+ let query_schema = query_schema_span.in_scope(|| schema::build(arced_schema_2, true));
Ok(ConnectedEngine {
schema: builder.schema.clone(),
query_schema: Arc::new(query_schema),
- executor: executor?,
+ executor,
config_dir: builder.config_dir.clone(),
env: builder.env.clone(),
engine_protocol: builder.engine_protocol,
@@ -285,30 +278,32 @@ impl QueryEngine {
/// Disconnect and drop the core. Can be reconnected later with `#connect`.
#[wasm_bindgen]
pub async fn disconnect(&self, trace: String) -> Result<(), wasm_bindgen::JsError> {
- async_panic_to_js_error(async {
- let span = tracing::info_span!("prisma:engine:disconnect");
+ // async_panic_to_js_error(async {
+ // let span = tracing::info_span!("prisma:engine:disconnect");
- // TODO: when using Node Drivers, we need to call Driver::close() here.
+ // TODO: when using Node Drivers, we need to call Driver::close() here.
- async {
- let mut inner = self.inner.write().await;
- let engine = inner.as_engine()?;
+ // async {
+ let mut inner = self.inner.write().await;
+ let engine = inner.as_engine()?;
- let builder = EngineBuilder {
- schema: engine.schema.clone(),
- config_dir: engine.config_dir.clone(),
- env: engine.env.clone(),
- engine_protocol: engine.engine_protocol(),
- };
+ let builder = EngineBuilder {
+ schema: engine.schema.clone(),
+ config_dir: engine.config_dir.clone(),
+ env: engine.env.clone(),
+ engine_protocol: engine.engine_protocol(),
+ };
- *inner = Inner::Builder(builder);
+ log::info!("Recreated builder");
+ *inner = Inner::Builder(builder);
+ log::info!("Recreated inner builder");
- Ok(())
- }
- .instrument(span)
- .await
- })
- .await
+ Ok(())
+ // }
+ // .instrument(span)
+ // .await
+ // })
+ // .await
}
/// If connected, sends a query to the core and returns the response.

From 49271f687484c4b39b2685c3d2d6a7e18a4de9c2 Mon Sep 17 00:00:00 2001
From: Sergey Tatarintsev
Date: Wed, 22 Nov 2023 11:01:53 +0100
Subject: [PATCH 058/134] Fix `Instant::now` usage

Replace `Instant::now` with a custom library that will use appropriate
time/date functions for each platform. For native, it is `std::time`; for WASM
it should probably be `performance::now()`, but it is a no-op stub for now.
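For illustration only (not part of this patch): a minimal sketch of what a non-stub
wasm32 counter could look like, assuming it settles for millisecond precision via
`js_sys::Date::now()` (js-sys is already a wasm32 dependency of the new `elapsed`
crate) rather than wiring up `performance.now()`:

// Hypothetical sketch, not the code added by this patch.
// Measures elapsed wall-clock time using js_sys::Date::now(), which returns
// milliseconds since the Unix epoch as an f64.
use std::time::Duration;

pub struct ElapsedTimeCounter {
    started_at_ms: f64,
}

impl ElapsedTimeCounter {
    pub fn start() -> Self {
        Self {
            started_at_ms: js_sys::Date::now(),
        }
    }

    pub fn elapsed_time(&self) -> Duration {
        // Clamp to zero in case the system clock is adjusted backwards.
        let elapsed_ms = (js_sys::Date::now() - self.started_at_ms).max(0.0);
        Duration::from_millis(elapsed_ms as u64)
    }
}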
--- Cargo.lock | 17 ++++++++--- libs/elapsed/Cargo.toml | 11 +++++++ libs/elapsed/src/lib.rs | 9 ++++++ libs/elapsed/src/native.rs | 17 +++++++++++ libs/elapsed/src/wasm.rs | 15 ++++++++++ quaint/Cargo.toml | 16 ++-------- quaint/src/connector/metrics.rs | 25 ++++++++-------- query-engine/core/Cargo.toml | 3 +- .../core/src/executor/execute_operation.rs | 30 ++++++++++++------- query-engine/query-engine-wasm/build.sh | 2 +- 10 files changed, 104 insertions(+), 41 deletions(-) create mode 100644 libs/elapsed/Cargo.toml create mode 100644 libs/elapsed/src/lib.rs create mode 100644 libs/elapsed/src/native.rs create mode 100644 libs/elapsed/src/wasm.rs diff --git a/Cargo.lock b/Cargo.lock index 7f79a3e8a5e7..5475f2db4153 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1138,6 +1138,13 @@ version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" +[[package]] +name = "elapsed" +version = "0.1.0" +dependencies = [ + "js-sys", +] + [[package]] name = "encode_unicode" version = "0.3.6" @@ -2001,9 +2008,9 @@ checksum = "af150ab688ff2122fcef229be89cb50dd66af9e01a4ff320cc137eecc9bacc38" [[package]] name = "js-sys" -version = "0.3.61" +version = "0.3.65" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "445dde2150c55e483f3d8416706b97ec8e8237c307e5b7b4b8dd15e6af2a0730" +checksum = "54c0c35952f67de54bb584e9fd912b3023117cbafc0a77d8f3dee1fb5f572fe8" dependencies = [ "wasm-bindgen", ] @@ -3589,6 +3596,7 @@ dependencies = [ "chrono", "connection-string", "either", + "elapsed", "futures", "getrandom 0.2.10", "hex", @@ -3691,6 +3699,7 @@ dependencies = [ "connection-string", "crossbeam-channel", "cuid", + "elapsed", "enumflags2", "futures", "indexmap 1.9.3", @@ -6115,9 +6124,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.61" +version = "0.3.65" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e33b99f4b23ba3eec1a53ac264e35a755f00e966e0065077d6027c0f575b0b97" +checksum = "5db499c5f66323272151db0e666cd34f78617522fb0c1604d31a27c50c206a85" dependencies = [ "js-sys", "wasm-bindgen", diff --git a/libs/elapsed/Cargo.toml b/libs/elapsed/Cargo.toml new file mode 100644 index 000000000000..71f103e73ab4 --- /dev/null +++ b/libs/elapsed/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "elapsed" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] + +[target.'cfg(target_arch = "wasm32")'.dependencies] +js-sys.workspace = true diff --git a/libs/elapsed/src/lib.rs b/libs/elapsed/src/lib.rs new file mode 100644 index 000000000000..d339f75f836d --- /dev/null +++ b/libs/elapsed/src/lib.rs @@ -0,0 +1,9 @@ +#[cfg(target_arch = "wasm32")] +mod wasm; +#[cfg(target_arch = "wasm32")] +pub use crate::wasm::ElapsedTimeCounter; + +#[cfg(not(target_arch = "wasm32"))] +mod native; +#[cfg(not(target_arch = "wasm32"))] +pub use crate::native::ElapsedTimeCounter; diff --git a/libs/elapsed/src/native.rs b/libs/elapsed/src/native.rs new file mode 100644 index 000000000000..93855abbe648 --- /dev/null +++ b/libs/elapsed/src/native.rs @@ -0,0 +1,17 @@ +use std::time::{Duration, Instant}; + +pub struct ElapsedTimeCounter { + instant: Instant, +} + +impl ElapsedTimeCounter { + pub fn start() -> Self { + let instant = Instant::now(); + + Self { instant } + } + + pub fn elapsed_time(&self) -> Duration { + self.instant.elapsed() + } +} diff --git a/libs/elapsed/src/wasm.rs 
b/libs/elapsed/src/wasm.rs new file mode 100644 index 000000000000..cdd83251e4b1 --- /dev/null +++ b/libs/elapsed/src/wasm.rs @@ -0,0 +1,15 @@ +use std::time::Duration; + +/// TODO: this is a stub that always returns 0 as elapsed time +/// In should use performance::now() instead +pub struct ElapsedTimeCounter {} + +impl ElapsedTimeCounter { + pub fn start() -> Self { + Self {} + } + + pub fn elapsed_time(&self) -> Duration { + Duration::from_millis(0u64) + } +} diff --git a/quaint/Cargo.toml b/quaint/Cargo.toml index 52a7edf72aca..b884834277d5 100644 --- a/quaint/Cargo.toml +++ b/quaint/Cargo.toml @@ -29,12 +29,7 @@ docs = [] # way to access database-specific methods when you need extra control. expose-drivers = [] -native = [ - "postgresql-native", - "mysql-native", - "mssql-native", - "sqlite-native", -] +native = ["postgresql-native", "mysql-native", "mssql-native", "sqlite-native"] all = ["native", "pooled"] @@ -57,13 +52,7 @@ postgresql-native = [ ] postgresql = [] -mssql-native = [ - "mssql", - "tiberius", - "tokio-util", - "tokio/time", - "tokio/net", -] +mssql-native = ["mssql", "tiberius", "tokio-util", "tokio/time", "tokio/net"] mssql = [] mysql-native = ["mysql", "mysql_async", "tokio/time", "lru-cache"] @@ -100,6 +89,7 @@ mobc = { version = "0.8", optional = true } serde = { version = "1.0", optional = true } sqlformat = { version = "0.2.0", optional = true } uuid = { version = "1", features = ["v4"] } +elapsed = { path = "../libs/elapsed" } [dev-dependencies] once_cell = "1.3" diff --git a/quaint/src/connector/metrics.rs b/quaint/src/connector/metrics.rs index 2705a40b32b2..628a2e81f7a3 100644 --- a/quaint/src/connector/metrics.rs +++ b/quaint/src/connector/metrics.rs @@ -1,7 +1,8 @@ use tracing::{info_span, Instrument}; use crate::ast::{Params, Value}; -use std::{future::Future, time::Instant}; +use elapsed::ElapsedTimeCounter; +use std::future::Future; pub async fn query<'a, F, T, U>(tag: &'static str, query: &'a str, params: &'a [Value<'_>], f: F) -> crate::Result where @@ -17,7 +18,7 @@ where F: FnOnce() -> U + 'a, U: Future>, { - let start = Instant::now(); + let start = ElapsedTimeCounter::start(); let res = f().await; let result = match res { @@ -34,19 +35,19 @@ where sqlformat::FormatOptions::default(), ); - trace_query(&query_fmt, params, result, start); + trace_query(&query_fmt, params, result, &start); } else { - trace_query(query, params, result, start); + trace_query(query, params, result, &start); }; } #[cfg(not(feature = "fmt-sql"))] { - trace_query(query, params, result, start); + trace_query(query, params, result, &start); } - histogram!(format!("{tag}.query.time"), start.elapsed()); - histogram!("prisma_datasource_queries_duration_histogram_ms", start.elapsed()); + histogram!(format!("{tag}.query.time"), start.elapsed_time()); + histogram!("prisma_datasource_queries_duration_histogram_ms", start.elapsed_time()); increment_counter!("prisma_datasource_queries_total"); res @@ -57,7 +58,7 @@ pub(crate) async fn check_out(f: F) -> std::result::Result>>, { - let start = Instant::now(); + let start = ElapsedTimeCounter::start(); let res = f.await; let result = match res { @@ -67,24 +68,24 @@ where tracing::trace!( message = "Fetched a connection from the pool", - duration_ms = start.elapsed().as_millis() as u64, + duration_ms = start.elapsed_time().as_millis() as u64, item_type = "query", is_query = true, result, ); - histogram!("pool.check_out", start.elapsed()); + histogram!("pool.check_out", start.elapsed_time()); res } -fn trace_query<'a>(query: &'a str, 
params: &'a [Value<'_>], result: &str, start: Instant) { +fn trace_query<'a>(query: &'a str, params: &'a [Value<'_>], result: &str, start: &ElapsedTimeCounter) { tracing::debug!( query = %query, params = %Params(params), result, item_type = "query", is_query = true, - duration_ms = start.elapsed().as_millis() as u64, + duration_ms = start.elapsed_time().as_millis() as u64, ); } diff --git a/query-engine/core/Cargo.toml b/query-engine/core/Cargo.toml index e5e90cb9937b..42e7ee301525 100644 --- a/query-engine/core/Cargo.toml +++ b/query-engine/core/Cargo.toml @@ -10,7 +10,7 @@ metrics = ["query-engine-metrics"] async-trait = "0.1" bigdecimal = "0.3" chrono = "0.4" -connection-string.workspace = true +connection-string.workspace = true connector = { path = "../connectors/query-connector", package = "query-connector" } crossbeam-channel = "0.5.6" psl.workspace = true @@ -34,6 +34,7 @@ user-facing-errors = { path = "../../libs/user-facing-errors" } uuid = "1" cuid = { git = "https://github.com/prisma/cuid-rust", branch = "wasm32-support" } schema = { path = "../schema" } +elapsed = { path = "../../libs/elapsed" } lru = "0.7.7" enumflags2 = "0.7" diff --git a/query-engine/core/src/executor/execute_operation.rs b/query-engine/core/src/executor/execute_operation.rs index 6ba21d37f9ff..c6860eecb29e 100644 --- a/query-engine/core/src/executor/execute_operation.rs +++ b/query-engine/core/src/executor/execute_operation.rs @@ -6,6 +6,7 @@ use crate::{ QueryGraphBuilder, QueryInterpreter, ResponseData, }; use connector::{Connection, ConnectionLike, Connector}; +use elapsed::ElapsedTimeCounter; use futures::future; #[cfg(feature = "metrics")] @@ -14,7 +15,7 @@ use query_engine_metrics::{ }; use schema::{QuerySchema, QuerySchemaRef}; -use std::time::{Duration, Instant}; +use std::time::Duration; use tracing::Instrument; use tracing_futures::WithSubscriber; @@ -24,13 +25,16 @@ pub async fn execute_single_operation( operation: &Operation, trace_id: Option, ) -> crate::Result { - let operation_timer = Instant::now(); + let operation_timer = ElapsedTimeCounter::start(); let (graph, serializer) = build_graph(&query_schema, operation.clone())?; let result = execute_on(conn, graph, serializer, query_schema.as_ref(), trace_id).await; #[cfg(feature = "metrics")] - histogram!(PRISMA_CLIENT_QUERIES_DURATION_HISTOGRAM_MS, operation_timer.elapsed()); + histogram!( + PRISMA_CLIENT_QUERIES_DURATION_HISTOGRAM_MS, + operation_timer.elapsed_time() + ); result } @@ -49,11 +53,14 @@ pub async fn execute_many_operations( let mut results = Vec::with_capacity(queries.len()); for (i, (graph, serializer)) in queries.into_iter().enumerate() { - let operation_timer = Instant::now(); + let operation_timer = ElapsedTimeCounter::start(); let result = execute_on(conn, graph, serializer, query_schema.as_ref(), trace_id.clone()).await; #[cfg(feature = "metrics")] - histogram!(PRISMA_CLIENT_QUERIES_DURATION_HISTOGRAM_MS, operation_timer.elapsed()); + histogram!( + PRISMA_CLIENT_QUERIES_DURATION_HISTOGRAM_MS, + operation_timer.elapsed_time() + ); match result { Ok(result) => results.push(Ok(result)), @@ -150,14 +157,14 @@ async fn execute_self_contained( retry_on_transient_error: bool, trace_id: Option, ) -> crate::Result { - let operation_timer = Instant::now(); + let operation_timer = ElapsedTimeCounter::start(); let result = if retry_on_transient_error { execute_self_contained_with_retry( &mut conn, query_schema, operation, force_transactions, - Instant::now(), + ElapsedTimeCounter::start(), trace_id, ) .await @@ -168,7 +175,10 @@ async 
fn execute_self_contained( }; #[cfg(feature = "metrics")] - histogram!(PRISMA_CLIENT_QUERIES_DURATION_HISTOGRAM_MS, operation_timer.elapsed()); + histogram!( + PRISMA_CLIENT_QUERIES_DURATION_HISTOGRAM_MS, + operation_timer.elapsed_time() + ); result } @@ -200,7 +210,7 @@ async fn execute_self_contained_with_retry( query_schema: QuerySchemaRef, operation: Operation, force_transactions: bool, - retry_timeout: Instant, + retry_timeout: ElapsedTimeCounter, trace_id: Option, ) -> crate::Result { let (graph, serializer) = build_graph(&query_schema, operation.clone())?; @@ -216,7 +226,7 @@ async fn execute_self_contained_with_retry( let (graph, serializer) = build_graph(&query_schema, operation.clone())?; let res = execute_in_tx(conn, graph, serializer, query_schema.as_ref(), trace_id.clone()).await; - if is_transient_error(&res) && retry_timeout.elapsed() < MAX_TX_TIMEOUT_RETRY_LIMIT { + if is_transient_error(&res) && retry_timeout.elapsed_time() < MAX_TX_TIMEOUT_RETRY_LIMIT { tokio::time::sleep(TX_RETRY_BACKOFF).await; continue; } else { diff --git a/query-engine/query-engine-wasm/build.sh b/query-engine/query-engine-wasm/build.sh index dbbf6a720534..784d4e0e2064 100755 --- a/query-engine/query-engine-wasm/build.sh +++ b/query-engine/query-engine-wasm/build.sh @@ -13,7 +13,7 @@ OUT_NPM_NAME="@prisma/query-engine-wasm" # This little `sed -i` trick below is a hack to publish "@prisma/query-engine-wasm" # with the same binding filenames currently expected by the Prisma Client. sed -i '' 's/name = "query_engine_wasm"/name = "query_engine"/g' Cargo.toml -wasm-pack build --release --target $OUT_TARGET +wasm-pack build --dev --target $OUT_TARGET sed -i '' 's/name = "query_engine"/name = "query_engine_wasm"/g' Cargo.toml sleep 1 From 82afce4e44c62e0efea4306ae8a966a65f4131ad Mon Sep 17 00:00:00 2001 From: jkomyno Date: Wed, 22 Nov 2023 13:35:33 +0100 Subject: [PATCH 059/134] fix(driver-adapters): understand "flavour" and adjust casing in "JSResultSet" --- query-engine/driver-adapters/src/types.rs | 1 + query-engine/driver-adapters/src/wasm/proxy.rs | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/query-engine/driver-adapters/src/types.rs b/query-engine/driver-adapters/src/types.rs index 4f494c1bc092..2b2c0b45e50c 100644 --- a/query-engine/driver-adapters/src/types.rs +++ b/query-engine/driver-adapters/src/types.rs @@ -26,6 +26,7 @@ use serde::{Deserialize, Serialize}; #[cfg_attr(not(target_arch = "wasm32"), napi_derive::napi(object))] #[cfg_attr(target_arch = "wasm32", derive(Serialize, Deserialize, Tsify))] #[cfg_attr(target_arch = "wasm32", tsify(into_wasm_abi, from_wasm_abi))] +#[cfg_attr(target_arch = "wasm32", serde(rename_all = "camelCase"))] #[derive(Debug, Default)] pub struct JSResultSet { pub column_types: Vec, diff --git a/query-engine/driver-adapters/src/wasm/proxy.rs b/query-engine/driver-adapters/src/wasm/proxy.rs index 7ab578830f42..bb2e9a855fe7 100644 --- a/query-engine/driver-adapters/src/wasm/proxy.rs +++ b/query-engine/driver-adapters/src/wasm/proxy.rs @@ -56,7 +56,7 @@ pub(crate) struct TransactionProxy { impl CommonProxy { pub fn new(object: &JsObjectExtern) -> JsResult { - let flavour: String = JsString::from(object.get("value".into())?).into(); + let flavour: String = JsString::from(object.get("flavour".into())?).into(); Ok(Self { query_raw: JsFunction::from(object.get("queryRaw".into())?).into(), From bbae4e3ac217b6a6e4d89cd0e3b8fee78ce8f022 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Wed, 22 Nov 2023 13:36:10 +0100 Subject: [PATCH 060/134] 
feat(driver-adapters): add some Into traits for "JsResult" --- query-engine/driver-adapters/src/wasm/mod.rs | 1 + .../driver-adapters/src/wasm/result.rs | 76 +++++++++++++++++++ 2 files changed, 77 insertions(+) create mode 100644 query-engine/driver-adapters/src/wasm/result.rs diff --git a/query-engine/driver-adapters/src/wasm/mod.rs b/query-engine/driver-adapters/src/wasm/mod.rs index 416b9db4a787..8f0e0aeeb3de 100644 --- a/query-engine/driver-adapters/src/wasm/mod.rs +++ b/query-engine/driver-adapters/src/wasm/mod.rs @@ -5,6 +5,7 @@ mod conversion; mod error; mod js_object_extern; pub(crate) mod proxy; +mod result; mod transaction; pub use crate::queryable::{from_wasm, JsQueryable}; diff --git a/query-engine/driver-adapters/src/wasm/result.rs b/query-engine/driver-adapters/src/wasm/result.rs new file mode 100644 index 000000000000..c5fed81ebe24 --- /dev/null +++ b/query-engine/driver-adapters/src/wasm/result.rs @@ -0,0 +1,76 @@ +use js_sys::Boolean as JsBoolean; +use quaint::error::{Error as QuaintError, ErrorKind}; +use wasm_bindgen::{JsCast, JsValue}; + +use crate::{error::DriverAdapterError, JsObjectExtern}; + +impl From for QuaintError { + fn from(value: DriverAdapterError) -> Self { + match value { + DriverAdapterError::UnsupportedNativeDataType { native_type } => { + QuaintError::builder(ErrorKind::UnsupportedColumnType { + column_type: native_type, + }) + .build() + } + DriverAdapterError::GenericJs { id } => QuaintError::external_error(id), + DriverAdapterError::Postgres(e) => e.into(), + DriverAdapterError::Mysql(e) => e.into(), + DriverAdapterError::Sqlite(e) => e.into(), + // in future, more error types would be added and we'll need to convert them to proper QuaintErrors here + } + } +} + +/// Wrapper for JS-side result type +pub(crate) enum JsResult +where + T: From, +{ + Ok(T), + Err(DriverAdapterError), +} + +impl TryFrom for JsResult +where + T: From, +{ + type Error = JsValue; + + fn try_from(value: JsValue) -> Result { + Self::from_js_unknown(value) + } +} + +impl JsResult +where + T: From, +{ + fn from_js_unknown(unknown: JsValue) -> Result { + let object = unknown.unchecked_into::(); + + let ok: JsBoolean = object.get("ok".into())?.unchecked_into(); + let ok = ok.value_of(); + + if ok { + let value: JsValue = object.get("value".into())?; + return Ok(Self::Ok(T::from(value))); + } + + let error = object.get("error".into())?; + let error: DriverAdapterError = serde_wasm_bindgen::from_value(error)?; + Ok(Self::Err(error)) + } +} + +impl From> for quaint::Result +where + T: From, +{ + fn from(value: JsResult) -> Self { + match value { + JsResult::Ok(result) => Ok(result), + JsResult::Err(error) => Err(error.into()), + } + } +} From 77a79fa7f7cb5e6426576e704c68aa80afe67d85 Mon Sep 17 00:00:00 2001 From: Sergey Tatarintsev Date: Wed, 22 Nov 2023 15:09:52 +0100 Subject: [PATCH 061/134] Fix JSResult parsing --- .../src/wasm/async_js_function.rs | 25 +++++++++++-------- .../driver-adapters/src/wasm/result.rs | 14 ++++++----- query-engine/query-engine-wasm/example.js | 13 ++++++++-- .../query-engine-wasm/package-lock.json | 22 +++++++++++++++- query-engine/query-engine-wasm/package.json | 3 ++- 5 files changed, 56 insertions(+), 21 deletions(-) diff --git a/query-engine/driver-adapters/src/wasm/async_js_function.rs b/query-engine/driver-adapters/src/wasm/async_js_function.rs index 29e168e9665e..f4e3771694a2 100644 --- a/query-engine/driver-adapters/src/wasm/async_js_function.rs +++ b/query-engine/driver-adapters/src/wasm/async_js_function.rs @@ -7,8 +7,7 @@ use 
wasm_bindgen::{JsError, JsValue}; use wasm_bindgen_futures::JsFuture; use super::error::into_quaint_error; - -type JsResult = core::result::Result; +use super::result::JsResult; #[derive(Clone, Default)] pub(crate) struct AsyncJsFunction @@ -42,19 +41,23 @@ where R: DeserializeOwned, { pub async fn call(&self, arg1: T) -> quaint::Result { - let call_internal = async { - let arg1 = serde_wasm_bindgen::to_value(&arg1).map_err(|err| JsError::from(&err))?; - let promise = self.threadsafe_fn.call1(&JsValue::null(), &arg1)?; - let future = JsFuture::from(JsPromise::from(promise)); - let value = future.await?; - serde_wasm_bindgen::from_value(value).map_err(|err| JsValue::from(err)) - }; + let result = self.call_internal(arg1).await; - match call_internal.await { - Ok(result) => Ok(result), + match result { + Ok(js_result) => js_result.into(), Err(err) => Err(into_quaint_error(err)), } } + + async fn call_internal(&self, arg1: T) -> Result, JsValue> { + let arg1 = serde_wasm_bindgen::to_value(&arg1).map_err(|err| JsValue::from(JsError::from(&err)))?; + let promise = self.threadsafe_fn.call1(&JsValue::null(), &arg1)?; + let future = JsFuture::from(JsPromise::from(promise)); + let value = future.await?; + let js_result: JsResult = value.try_into()?; + + Ok(js_result) + } } impl WasmDescribe for AsyncJsFunction diff --git a/query-engine/driver-adapters/src/wasm/result.rs b/query-engine/driver-adapters/src/wasm/result.rs index c5fed81ebe24..df4652307469 100644 --- a/query-engine/driver-adapters/src/wasm/result.rs +++ b/query-engine/driver-adapters/src/wasm/result.rs @@ -1,5 +1,6 @@ use js_sys::Boolean as JsBoolean; use quaint::error::{Error as QuaintError, ErrorKind}; +use serde::de::DeserializeOwned; use wasm_bindgen::{JsCast, JsValue}; use crate::{error::DriverAdapterError, JsObjectExtern}; @@ -25,7 +26,7 @@ impl From for QuaintError { /// Wrapper for JS-side result type pub(crate) enum JsResult where - T: From, + T: DeserializeOwned, { Ok(T), Err(DriverAdapterError), @@ -33,7 +34,7 @@ where impl TryFrom for JsResult where - T: From, + T: DeserializeOwned, { type Error = JsValue; @@ -44,7 +45,7 @@ where impl JsResult where - T: From, + T: DeserializeOwned, { fn from_js_unknown(unknown: JsValue) -> Result { let object = unknown.unchecked_into::(); @@ -53,8 +54,9 @@ where let ok = ok.value_of(); if ok { - let value: JsValue = object.get("value".into())?; - return Ok(Self::Ok(T::from(value))); + let js_value: JsValue = object.get("value".into())?; + let deserialized = serde_wasm_bindgen::from_value::(js_value)?; + return Ok(Self::Ok(deserialized)); } let error = object.get("error".into())?; @@ -65,7 +67,7 @@ where impl From> for quaint::Result where - T: From, + T: DeserializeOwned, { fn from(value: JsResult) -> Self { match value { diff --git a/query-engine/query-engine-wasm/example.js b/query-engine/query-engine-wasm/example.js index 1fa929e80b2f..58453f41a85e 100644 --- a/query-engine/query-engine-wasm/example.js +++ b/query-engine/query-engine-wasm/example.js @@ -3,17 +3,20 @@ * on Node.js 18+. */ -import { Pool } from '@neondatabase/serverless' +import { Pool, neonConfig } from '@neondatabase/serverless' import { PrismaNeon } from '@prisma/adapter-neon' import { bindAdapter } from '@prisma/driver-adapter-utils' import { init, QueryEngine, getBuildTimeInfo } from './pkg/query_engine.js' +import { WebSocket } from 'undici' + +neonConfig.webSocketConstructor = WebSocket async function main() { // Always initialize the Wasm library before using it. // This sets up the logging and panic hooks. 
init() - const connectionString = undefined + const connectionString = process.env.DATABASE_URL const pool = new Pool({ connectionString }) const adapter = new PrismaNeon(pool) @@ -60,6 +63,12 @@ async function main() { }), 'trace') const parsed = JSON.parse(res); console.log('query result = ', parsed) + + const error = parsed.errors?.[0]?.user_facing_error + if (error?.error_code === 'P2036') { + console.log('js error:', driverAdapter.errorRegistry.consumeError(error.meta.id)) + } + // if (res.error.user_facing_error.code =) await queryEngine.disconnect('trace') console.log('after disconnect') queryEngine.free() diff --git a/query-engine/query-engine-wasm/package-lock.json b/query-engine/query-engine-wasm/package-lock.json index 86b53cb0cde2..654ea1ef5b57 100644 --- a/query-engine/query-engine-wasm/package-lock.json +++ b/query-engine/query-engine-wasm/package-lock.json @@ -7,7 +7,16 @@ "dependencies": { "@neondatabase/serverless": "0.6.0", "@prisma/adapter-neon": "5.6.0", - "@prisma/driver-adapter-utils": "5.6.0" + "@prisma/driver-adapter-utils": "5.6.0", + "undici": "^5.27.2" + } + }, + "node_modules/@fastify/busboy": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/@fastify/busboy/-/busboy-2.1.0.tgz", + "integrity": "sha512-+KpH+QxZU7O4675t3mnkQKcZZg56u+K/Ct2K+N2AZYNVK8kyeo/bI18tI8aPm3tvNNRyTWfj6s5tnGNlcbQRsA==", + "engines": { + "node": ">=14" } }, "node_modules/@neondatabase/serverless": { @@ -148,6 +157,17 @@ "node": ">=0.10.0" } }, + "node_modules/undici": { + "version": "5.27.2", + "resolved": "https://registry.npmjs.org/undici/-/undici-5.27.2.tgz", + "integrity": "sha512-iS857PdOEy/y3wlM3yRp+6SNQQ6xU0mmZcwRSriqk+et/cwWAtwmIGf6WkoDN2EK/AMdCO/dfXzIwi+rFMrjjQ==", + "dependencies": { + "@fastify/busboy": "^2.0.0" + }, + "engines": { + "node": ">=14.0" + } + }, "node_modules/undici-types": { "version": "5.26.5", "resolved": "https://registry.npmjs.org/undici-types/-/undici-types-5.26.5.tgz", diff --git a/query-engine/query-engine-wasm/package.json b/query-engine/query-engine-wasm/package.json index 8192656bd56f..7359677a2249 100644 --- a/query-engine/query-engine-wasm/package.json +++ b/query-engine/query-engine-wasm/package.json @@ -7,6 +7,7 @@ "dependencies": { "@neondatabase/serverless": "0.6.0", "@prisma/adapter-neon": "5.6.0", - "@prisma/driver-adapter-utils": "5.6.0" + "@prisma/driver-adapter-utils": "5.6.0", + "undici": "^5.27.2" } } From 6b60c8fab90d44fef421511c6ca6376067021f7a Mon Sep 17 00:00:00 2001 From: Sergey Tatarintsev Date: Wed, 22 Nov 2023 15:33:49 +0100 Subject: [PATCH 062/134] Reorganize example Switch to sqlite so cloud services are not required Save schema to external file --- query-engine/query-engine-wasm/README.md | 2 +- .../query-engine-wasm/example/.gitignore | 1 + .../{ => example}/example.js | 32 +- .../{ => example}/package.json | 7 +- .../query-engine-wasm/example/pnpm-lock.yaml | 371 ++++++++++++++++++ .../example/prisma/schema.prisma | 12 + .../query-engine-wasm/package-lock.json | 185 --------- 7 files changed, 398 insertions(+), 212 deletions(-) create mode 100644 query-engine/query-engine-wasm/example/.gitignore rename query-engine/query-engine-wasm/{ => example}/example.js (66%) rename query-engine/query-engine-wasm/{ => example}/package.json (60%) create mode 100644 query-engine/query-engine-wasm/example/pnpm-lock.yaml create mode 100644 query-engine/query-engine-wasm/example/prisma/schema.prisma delete mode 100644 query-engine/query-engine-wasm/package-lock.json diff --git a/query-engine/query-engine-wasm/README.md 
b/query-engine/query-engine-wasm/README.md index f5adc7eb2894..7f294bc997c9 100644 --- a/query-engine/query-engine-wasm/README.md +++ b/query-engine/query-engine-wasm/README.md @@ -37,4 +37,4 @@ From the current folder: To try importing the , you can run: - `nvm use` -- `node --experimental-wasm-modules ./example.js` +- `node --experimental-wasm-modules example/example.js` diff --git a/query-engine/query-engine-wasm/example/.gitignore b/query-engine/query-engine-wasm/example/.gitignore new file mode 100644 index 000000000000..3997beadf829 --- /dev/null +++ b/query-engine/query-engine-wasm/example/.gitignore @@ -0,0 +1 @@ +*.db \ No newline at end of file diff --git a/query-engine/query-engine-wasm/example.js b/query-engine/query-engine-wasm/example/example.js similarity index 66% rename from query-engine/query-engine-wasm/example.js rename to query-engine/query-engine-wasm/example/example.js index 58453f41a85e..52833baab014 100644 --- a/query-engine/query-engine-wasm/example.js +++ b/query-engine/query-engine-wasm/example/example.js @@ -2,43 +2,29 @@ * Run with: `node --experimental-wasm-modules ./example.js` * on Node.js 18+. */ - -import { Pool, neonConfig } from '@neondatabase/serverless' -import { PrismaNeon } from '@prisma/adapter-neon' +import { readFile } from 'fs/promises' +import { PrismaLibSQL } from '@prisma/adapter-libsql' +import { createClient } from '@libsql/client' import { bindAdapter } from '@prisma/driver-adapter-utils' -import { init, QueryEngine, getBuildTimeInfo } from './pkg/query_engine.js' -import { WebSocket } from 'undici' +import { init, QueryEngine, getBuildTimeInfo } from '../pkg/query_engine.js' -neonConfig.webSocketConstructor = WebSocket async function main() { // Always initialize the Wasm library before using it. // This sets up the logging and panic hooks. 
init() - const connectionString = process.env.DATABASE_URL - const pool = new Pool({ connectionString }) - const adapter = new PrismaNeon(pool) + const client = createClient({ url: "file:./prisma/dev.db"}) + const adapter = new PrismaLibSQL(client) const driverAdapter = bindAdapter(adapter) console.log('buildTimeInfo', getBuildTimeInfo()) - const options = { - datamodel: /* prisma */` - datasource db { - provider = "postgres" - url = env("DATABASE_URL") - } + const datamodel = await readFile('prisma/schema.prisma', 'utf8') - generator client { - provider = "prisma-client-js" - } - - model User { - id Int @id @default(autoincrement()) - } - `, + const options = { + datamodel, logLevel: 'info', logQueries: true, datasourceOverrides: {}, diff --git a/query-engine/query-engine-wasm/package.json b/query-engine/query-engine-wasm/example/package.json similarity index 60% rename from query-engine/query-engine-wasm/package.json rename to query-engine/query-engine-wasm/example/package.json index 7359677a2249..bb6d7b868ede 100644 --- a/query-engine/query-engine-wasm/package.json +++ b/query-engine/query-engine-wasm/example/package.json @@ -5,9 +5,10 @@ "dev": "node --experimental-wasm-modules ./example.js" }, "dependencies": { - "@neondatabase/serverless": "0.6.0", - "@prisma/adapter-neon": "5.6.0", + "@libsql/client": "0.4.0-pre.2", + "@prisma/adapter-libsql": "5.6.0", + "@prisma/client": "5.6.0", "@prisma/driver-adapter-utils": "5.6.0", - "undici": "^5.27.2" + "prisma": "5.6.0" } } diff --git a/query-engine/query-engine-wasm/example/pnpm-lock.yaml b/query-engine/query-engine-wasm/example/pnpm-lock.yaml new file mode 100644 index 000000000000..887edea0e8cc --- /dev/null +++ b/query-engine/query-engine-wasm/example/pnpm-lock.yaml @@ -0,0 +1,371 @@ +lockfileVersion: '6.0' + +settings: + autoInstallPeers: true + excludeLinksFromLockfile: false + +dependencies: + '@libsql/client': + specifier: 0.4.0-pre.2 + version: 0.4.0-pre.2 + '@prisma/adapter-libsql': + specifier: 5.6.0 + version: 5.6.0(@libsql/client@0.4.0-pre.2) + '@prisma/client': + specifier: 5.6.0 + version: 5.6.0(prisma@5.6.0) + '@prisma/driver-adapter-utils': + specifier: 5.6.0 + version: 5.6.0 + prisma: + specifier: 5.6.0 + version: 5.6.0 + +packages: + + /@libsql/client@0.4.0-pre.2: + resolution: {integrity: sha512-sKWNPU+RQoki5hEoYhpC+fQ/kj+VuwoSXF2PMYGWB19MYBkMaMc7udn1T0ibNjNkFNmd98HvPIHd48NNC2oWvA==} + dependencies: + '@libsql/hrana-client': 0.5.5 + js-base64: 3.7.5 + libsql: 0.2.0-pre.2 + transitivePeerDependencies: + - bufferutil + - encoding + - utf-8-validate + dev: false + + /@libsql/darwin-arm64@0.2.0-pre.2: + resolution: {integrity: sha512-PKXAKBJF6XwfCT3yU1N/kHyUGcsatf/4rYNzdnc6UGeg+yWf3ZDk7sGnHHj9bDQ9oKLRVJQmc+cNIEsF2GOr9w==} + cpu: [arm64] + os: [darwin] + requiresBuild: true + dev: false + optional: true + + /@libsql/darwin-x64@0.2.0-pre.2: + resolution: {integrity: sha512-e3k4LsAFRf8qFfZqkg/VkoXK/UfDYgoDvLmAJpAGKEFp7d/bTmbF1r0YCjtGaPbheRxARAUXNfekvRhdpXE3mg==} + cpu: [x64] + os: [darwin] + requiresBuild: true + dev: false + optional: true + + /@libsql/hrana-client@0.5.5: + resolution: {integrity: sha512-i+hDBpiV719poqEiHupUUZYKJ9YSbCRFe5Q2PQ0v3mHIftePH6gayLjp2u6TXbqbO/Dv6y8yyvYlBXf/kFfRZA==} + dependencies: + '@libsql/isomorphic-fetch': 0.1.10 + '@libsql/isomorphic-ws': 0.1.5 + js-base64: 3.7.5 + node-fetch: 3.3.2 + transitivePeerDependencies: + - bufferutil + - encoding + - utf-8-validate + dev: false + + /@libsql/isomorphic-fetch@0.1.10: + resolution: {integrity: 
sha512-dH0lMk50gKSvEKD78xWMu60SY1sjp1sY//iFLO0XMmBwfVfG136P9KOk06R4maBdlb8KMXOzJ1D28FR5ZKnHTA==} + dependencies: + '@types/node-fetch': 2.6.9 + node-fetch: 2.7.0 + transitivePeerDependencies: + - encoding + dev: false + + /@libsql/isomorphic-ws@0.1.5: + resolution: {integrity: sha512-DtLWIH29onUYR00i0GlQ3UdcTRC6EP4u9w/h9LxpUZJWRMARk6dQwZ6Jkd+QdwVpuAOrdxt18v0K2uIYR3fwFg==} + dependencies: + '@types/ws': 8.5.10 + ws: 8.14.2 + transitivePeerDependencies: + - bufferutil + - utf-8-validate + dev: false + + /@libsql/linux-arm64-gnu@0.2.0-pre.2: + resolution: {integrity: sha512-ZkN6e129joeUu6cinGMRbCvLTnrM5xV5n9XHs2dRrZfL7yu7utbvrY1l+P6VI1gugs93UhgupqyMsolFjvrPww==} + cpu: [arm64] + os: [linux] + requiresBuild: true + dev: false + optional: true + + /@libsql/linux-arm64-musl@0.2.0-pre.2: + resolution: {integrity: sha512-tEy4UAIzHYtjCBJnZoTcX1LCYy+XGR3hQCsdRYujWJhUtmtU/AqCRZV3q8MyfX7UhKyawJKWoQvwQ6Vs7w9jAA==} + cpu: [arm64] + os: [linux] + requiresBuild: true + dev: false + optional: true + + /@libsql/linux-x64-gnu@0.2.0-pre.2: + resolution: {integrity: sha512-jhHKwz5i9mdlpT4EeaKNUfyW5N9YY8wD5lZ0F5HrrPKhwgufnJY0oPEbvhM4KXDcSJetiIcGJ6K6NQyMSgoJ/Q==} + cpu: [x64] + os: [linux] + requiresBuild: true + dev: false + optional: true + + /@libsql/linux-x64-musl@0.2.0-pre.2: + resolution: {integrity: sha512-HvwZtSQ2eIT968yxAb+htO+wmibdwW1PIyR7iJ5TN7phj7W1gF962l3ZhV1hVYERaMu+liBH1e/cRP1S35q3vQ==} + cpu: [x64] + os: [linux] + requiresBuild: true + dev: false + optional: true + + /@libsql/win32-x64-msvc@0.2.0-pre.2: + resolution: {integrity: sha512-BWjInhsZRF9x+W0T5oJVjqoCCdvh82y74b/T3Ge/irXyLdVhHA9Zb1JWDy5uhu8eBR+d2n9B+IO0YwAvhFRTLw==} + cpu: [x64] + os: [win32] + requiresBuild: true + dev: false + optional: true + + /@neon-rs/load@0.0.4: + resolution: {integrity: sha512-kTPhdZyTQxB+2wpiRcFWrDcejc4JI6tkPuS7UZCG4l6Zvc5kU/gGQ/ozvHTh1XR5tS+UlfAfGuPajjzQjCiHCw==} + dev: false + + /@prisma/adapter-libsql@5.6.0(@libsql/client@0.4.0-pre.2): + resolution: {integrity: sha512-XFDLw9QqEDDVXAe8YdX8TL4mCiolDijjxh8HQRJ33VcuujGnAWWpBKE35MKfIsuONVyNXFthB/Gky/MlmMcE6Q==} + peerDependencies: + '@libsql/client': ^0.3.5 + dependencies: + '@libsql/client': 0.4.0-pre.2 + '@prisma/driver-adapter-utils': 5.6.0 + async-mutex: 0.4.0 + transitivePeerDependencies: + - supports-color + dev: false + + /@prisma/client@5.6.0(prisma@5.6.0): + resolution: {integrity: sha512-mUDefQFa1wWqk4+JhKPYq8BdVoFk9NFMBXUI8jAkBfQTtgx8WPx02U2HB/XbAz3GSUJpeJOKJQtNvaAIDs6sug==} + engines: {node: '>=16.13'} + requiresBuild: true + peerDependencies: + prisma: '*' + peerDependenciesMeta: + prisma: + optional: true + dependencies: + '@prisma/engines-version': 5.6.0-32.e95e739751f42d8ca026f6b910f5a2dc5adeaeee + prisma: 5.6.0 + dev: false + + /@prisma/driver-adapter-utils@5.6.0: + resolution: {integrity: sha512-/TSrfCGLAQghNf+bwg5/e8iKAgecCYU/gMN0IyNra3183/VTQJneLFgbacuSK9bBXiIRUmpbuUIrJ6dhENzfjA==} + dependencies: + debug: 4.3.4 + transitivePeerDependencies: + - supports-color + dev: false + + /@prisma/engines-version@5.6.0-32.e95e739751f42d8ca026f6b910f5a2dc5adeaeee: + resolution: {integrity: sha512-UoFgbV1awGL/3wXuUK3GDaX2SolqczeeJ5b4FVec9tzeGbSWJboPSbT0psSrmgYAKiKnkOPFSLlH6+b+IyOwAw==} + dev: false + + /@prisma/engines@5.6.0: + resolution: {integrity: sha512-Mt2q+GNJpU2vFn6kif24oRSBQv1KOkYaterQsi0k2/lA+dLvhRX6Lm26gon6PYHwUM8/h8KRgXIUMU0PCLB6bw==} + requiresBuild: true + dev: false + + /@types/node-fetch@2.6.9: + resolution: {integrity: sha512-bQVlnMLFJ2d35DkPNjEPmd9ueO/rh5EiaZt2bhqiSarPjZIuIV6bPQVqcrEyvNo+AfTrRGVazle1tl597w3gfA==} + dependencies: + 
'@types/node': 20.9.4 + form-data: 4.0.0 + dev: false + + /@types/node@20.9.4: + resolution: {integrity: sha512-wmyg8HUhcn6ACjsn8oKYjkN/zUzQeNtMy44weTJSM6p4MMzEOuKbA3OjJ267uPCOW7Xex9dyrNTful8XTQYoDA==} + dependencies: + undici-types: 5.26.5 + dev: false + + /@types/ws@8.5.10: + resolution: {integrity: sha512-vmQSUcfalpIq0R9q7uTo2lXs6eGIpt9wtnLdMv9LVpIjCA/+ufZRozlVoVelIYixx1ugCBKDhn89vnsEGOCx9A==} + dependencies: + '@types/node': 20.9.4 + dev: false + + /async-mutex@0.4.0: + resolution: {integrity: sha512-eJFZ1YhRR8UN8eBLoNzcDPcy/jqjsg6I1AP+KvWQX80BqOSW1oJPJXDylPUEeMr2ZQvHgnQ//Lp6f3RQ1zI7HA==} + dependencies: + tslib: 2.6.2 + dev: false + + /asynckit@0.4.0: + resolution: {integrity: sha512-Oei9OH4tRh0YqU3GxhX79dM/mwVgvbZJaSNaRk+bshkj0S5cfHcgYakreBjrHwatXKbz+IoIdYLxrKim2MjW0Q==} + dev: false + + /combined-stream@1.0.8: + resolution: {integrity: sha512-FQN4MRfuJeHf7cBbBMJFXhKSDq+2kAArBlmRBvcvFE5BB1HZKXtSFASDhdlz9zOYwxh8lDdnvmMOe/+5cdoEdg==} + engines: {node: '>= 0.8'} + dependencies: + delayed-stream: 1.0.0 + dev: false + + /data-uri-to-buffer@4.0.1: + resolution: {integrity: sha512-0R9ikRb668HB7QDxT1vkpuUBtqc53YyAwMwGeUFKRojY/NWKvdZ+9UYtRfGmhqNbRkTSVpMbmyhXipFFv2cb/A==} + engines: {node: '>= 12'} + dev: false + + /debug@4.3.4: + resolution: {integrity: sha512-PRWFHuSU3eDtQJPvnNY7Jcket1j0t5OuOsFzPPzsekD52Zl8qUfFIPEiswXqIvHWGVHOgX+7G/vCNNhehwxfkQ==} + engines: {node: '>=6.0'} + peerDependencies: + supports-color: '*' + peerDependenciesMeta: + supports-color: + optional: true + dependencies: + ms: 2.1.2 + dev: false + + /delayed-stream@1.0.0: + resolution: {integrity: sha512-ZySD7Nf91aLB0RxL4KGrKHBXl7Eds1DAmEdcoVawXnLD7SDhpNgtuII2aAkg7a7QS41jxPSZ17p4VdGnMHk3MQ==} + engines: {node: '>=0.4.0'} + dev: false + + /detect-libc@2.0.2: + resolution: {integrity: sha512-UX6sGumvvqSaXgdKGUsgZWqcUyIXZ/vZTrlRT/iobiKhGL0zL4d3osHj3uqllWJK+i+sixDS/3COVEOFbupFyw==} + engines: {node: '>=8'} + dev: false + + /fetch-blob@3.2.0: + resolution: {integrity: sha512-7yAQpD2UMJzLi1Dqv7qFYnPbaPx7ZfFK6PiIxQ4PfkGPyNyl2Ugx+a/umUonmKqjhM4DnfbMvdX6otXq83soQQ==} + engines: {node: ^12.20 || >= 14.13} + dependencies: + node-domexception: 1.0.0 + web-streams-polyfill: 3.2.1 + dev: false + + /form-data@4.0.0: + resolution: {integrity: sha512-ETEklSGi5t0QMZuiXoA/Q6vcnxcLQP5vdugSpuAyi6SVGi2clPPp+xgEhuMaHC+zGgn31Kd235W35f7Hykkaww==} + engines: {node: '>= 6'} + dependencies: + asynckit: 0.4.0 + combined-stream: 1.0.8 + mime-types: 2.1.35 + dev: false + + /formdata-polyfill@4.0.10: + resolution: {integrity: sha512-buewHzMvYL29jdeQTVILecSaZKnt/RJWjoZCF5OW60Z67/GmSLBkOFM7qh1PI3zFNtJbaZL5eQu1vLfazOwj4g==} + engines: {node: '>=12.20.0'} + dependencies: + fetch-blob: 3.2.0 + dev: false + + /js-base64@3.7.5: + resolution: {integrity: sha512-3MEt5DTINKqfScXKfJFrRbxkrnk2AxPWGBL/ycjz4dK8iqiSJ06UxD8jh8xuh6p10TX4t2+7FsBYVxxQbMg+qA==} + dev: false + + /libsql@0.2.0-pre.2: + resolution: {integrity: sha512-ErF11J/Q0Uo1TMceX1f7RKfFvQ/j4FS8TagzJnAZBwhHsPcr7uItkSTchkuRHm5+cE4dJO7lqf+MpmlDjp/qAQ==} + cpu: [x64, arm64] + os: [darwin, linux, win32] + dependencies: + '@neon-rs/load': 0.0.4 + detect-libc: 2.0.2 + optionalDependencies: + '@libsql/darwin-arm64': 0.2.0-pre.2 + '@libsql/darwin-x64': 0.2.0-pre.2 + '@libsql/linux-arm64-gnu': 0.2.0-pre.2 + '@libsql/linux-arm64-musl': 0.2.0-pre.2 + '@libsql/linux-x64-gnu': 0.2.0-pre.2 + '@libsql/linux-x64-musl': 0.2.0-pre.2 + '@libsql/win32-x64-msvc': 0.2.0-pre.2 + dev: false + + /mime-db@1.52.0: + resolution: {integrity: 
sha512-sPU4uV7dYlvtWJxwwxHD0PuihVNiE7TyAbQ5SWxDCB9mUYvOgroQOwYQQOKPJ8CIbE+1ETVlOoK1UC2nU3gYvg==} + engines: {node: '>= 0.6'} + dev: false + + /mime-types@2.1.35: + resolution: {integrity: sha512-ZDY+bPm5zTTF+YpCrAU9nK0UgICYPT0QtT1NZWFv4s++TNkcgVaT0g6+4R2uI4MjQjzysHB1zxuWL50hzaeXiw==} + engines: {node: '>= 0.6'} + dependencies: + mime-db: 1.52.0 + dev: false + + /ms@2.1.2: + resolution: {integrity: sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w==} + dev: false + + /node-domexception@1.0.0: + resolution: {integrity: sha512-/jKZoMpw0F8GRwl4/eLROPA3cfcXtLApP0QzLmUT/HuPCZWyB7IY9ZrMeKw2O/nFIqPQB3PVM9aYm0F312AXDQ==} + engines: {node: '>=10.5.0'} + dev: false + + /node-fetch@2.7.0: + resolution: {integrity: sha512-c4FRfUm/dbcWZ7U+1Wq0AwCyFL+3nt2bEw05wfxSz+DWpWsitgmSgYmy2dQdWyKC1694ELPqMs/YzUSNozLt8A==} + engines: {node: 4.x || >=6.0.0} + peerDependencies: + encoding: ^0.1.0 + peerDependenciesMeta: + encoding: + optional: true + dependencies: + whatwg-url: 5.0.0 + dev: false + + /node-fetch@3.3.2: + resolution: {integrity: sha512-dRB78srN/l6gqWulah9SrxeYnxeddIG30+GOqK/9OlLVyLg3HPnr6SqOWTWOXKRwC2eGYCkZ59NNuSgvSrpgOA==} + engines: {node: ^12.20.0 || ^14.13.1 || >=16.0.0} + dependencies: + data-uri-to-buffer: 4.0.1 + fetch-blob: 3.2.0 + formdata-polyfill: 4.0.10 + dev: false + + /prisma@5.6.0: + resolution: {integrity: sha512-EEaccku4ZGshdr2cthYHhf7iyvCcXqwJDvnoQRAJg5ge2Tzpv0e2BaMCp+CbbDUwoVTzwgOap9Zp+d4jFa2O9A==} + engines: {node: '>=16.13'} + hasBin: true + requiresBuild: true + dependencies: + '@prisma/engines': 5.6.0 + dev: false + + /tr46@0.0.3: + resolution: {integrity: sha512-N3WMsuqV66lT30CrXNbEjx4GEwlow3v6rr4mCcv6prnfwhS01rkgyFdjPNBYd9br7LpXV1+Emh01fHnq2Gdgrw==} + dev: false + + /tslib@2.6.2: + resolution: {integrity: sha512-AEYxH93jGFPn/a2iVAwW87VuUIkR1FVUKB77NwMF7nBTDkDrrT/Hpt/IrCJ0QXhW27jTBDcf5ZY7w6RiqTMw2Q==} + dev: false + + /undici-types@5.26.5: + resolution: {integrity: sha512-JlCMO+ehdEIKqlFxk6IfVoAUVmgz7cU7zD/h9XZ0qzeosSHmUJVOzSQvvYSYWXkFXC+IfLKSIffhv0sVZup6pA==} + dev: false + + /web-streams-polyfill@3.2.1: + resolution: {integrity: sha512-e0MO3wdXWKrLbL0DgGnUV7WHVuw9OUvL4hjgnPkIeEvESk74gAITi5G606JtZPp39cd8HA9VQzCIvA49LpPN5Q==} + engines: {node: '>= 8'} + dev: false + + /webidl-conversions@3.0.1: + resolution: {integrity: sha512-2JAn3z8AR6rjK8Sm8orRC0h/bcl/DqL7tRPdGZ4I1CjdF+EaMLmYxBHyXuKL849eucPFhvBoxMsflfOb8kxaeQ==} + dev: false + + /whatwg-url@5.0.0: + resolution: {integrity: sha512-saE57nupxk6v3HY35+jzBwYa0rKSy0XR8JSxZPwgLr7ys0IBzhGviA1/TUGJLmSVqs8pb9AnvICXEuOHLprYTw==} + dependencies: + tr46: 0.0.3 + webidl-conversions: 3.0.1 + dev: false + + /ws@8.14.2: + resolution: {integrity: sha512-wEBG1ftX4jcglPxgFCMJmZ2PLtSbJ2Peg6TmpJFTbe9GZYOQCDPdMYu/Tm0/bGZkw8paZnJY45J4K2PZrLYq8g==} + engines: {node: '>=10.0.0'} + peerDependencies: + bufferutil: ^4.0.1 + utf-8-validate: '>=5.0.2' + peerDependenciesMeta: + bufferutil: + optional: true + utf-8-validate: + optional: true + dev: false diff --git a/query-engine/query-engine-wasm/example/prisma/schema.prisma b/query-engine/query-engine-wasm/example/prisma/schema.prisma new file mode 100644 index 000000000000..93a7c64a6122 --- /dev/null +++ b/query-engine/query-engine-wasm/example/prisma/schema.prisma @@ -0,0 +1,12 @@ +datasource db { + provider = "sqlite" + url = "file:./dev.db" +} + +generator client { + provider = "prisma-client-js" +} + +model User { + id Int @id @default(autoincrement()) +} diff --git a/query-engine/query-engine-wasm/package-lock.json 
b/query-engine/query-engine-wasm/package-lock.json deleted file mode 100644 index 654ea1ef5b57..000000000000 --- a/query-engine/query-engine-wasm/package-lock.json +++ /dev/null @@ -1,185 +0,0 @@ -{ - "name": "query-engine-wasm", - "lockfileVersion": 3, - "requires": true, - "packages": { - "": { - "dependencies": { - "@neondatabase/serverless": "0.6.0", - "@prisma/adapter-neon": "5.6.0", - "@prisma/driver-adapter-utils": "5.6.0", - "undici": "^5.27.2" - } - }, - "node_modules/@fastify/busboy": { - "version": "2.1.0", - "resolved": "https://registry.npmjs.org/@fastify/busboy/-/busboy-2.1.0.tgz", - "integrity": "sha512-+KpH+QxZU7O4675t3mnkQKcZZg56u+K/Ct2K+N2AZYNVK8kyeo/bI18tI8aPm3tvNNRyTWfj6s5tnGNlcbQRsA==", - "engines": { - "node": ">=14" - } - }, - "node_modules/@neondatabase/serverless": { - "version": "0.6.0", - "resolved": "https://registry.npmjs.org/@neondatabase/serverless/-/serverless-0.6.0.tgz", - "integrity": "sha512-qXxBRYN0m2v8kVQBfMxbzNGn2xFAhTXFibzQlE++NfJ56Shz3m7+MyBBtXDlEH+3Wfa6lToDXf1MElocY4sJ3w==", - "dependencies": { - "@types/pg": "8.6.6" - } - }, - "node_modules/@prisma/adapter-neon": { - "version": "5.6.0", - "resolved": "https://registry.npmjs.org/@prisma/adapter-neon/-/adapter-neon-5.6.0.tgz", - "integrity": "sha512-IUkIE5NKyP2wCXMMAByM78fizfaJl7YeWDEajvyqQafXgRwmxl+2HhxsevvHly8jT4RlELdhjK6IP1eciGvXVA==", - "dependencies": { - "@prisma/driver-adapter-utils": "5.6.0", - "postgres-array": "3.0.2" - }, - "peerDependencies": { - "@neondatabase/serverless": "^0.6.0" - } - }, - "node_modules/@prisma/adapter-neon/node_modules/postgres-array": { - "version": "3.0.2", - "resolved": "https://registry.npmjs.org/postgres-array/-/postgres-array-3.0.2.tgz", - "integrity": "sha512-6faShkdFugNQCLwucjPcY5ARoW1SlbnrZjmGl0IrrqewpvxvhSLHimCVzqeuULCbG0fQv7Dtk1yDbG3xv7Veog==", - "engines": { - "node": ">=12" - } - }, - "node_modules/@prisma/driver-adapter-utils": { - "version": "5.6.0", - "resolved": "https://registry.npmjs.org/@prisma/driver-adapter-utils/-/driver-adapter-utils-5.6.0.tgz", - "integrity": "sha512-/TSrfCGLAQghNf+bwg5/e8iKAgecCYU/gMN0IyNra3183/VTQJneLFgbacuSK9bBXiIRUmpbuUIrJ6dhENzfjA==", - "dependencies": { - "debug": "4.3.4" - } - }, - "node_modules/@types/node": { - "version": "20.8.10", - "resolved": "https://registry.npmjs.org/@types/node/-/node-20.8.10.tgz", - "integrity": "sha512-TlgT8JntpcbmKUFzjhsyhGfP2fsiz1Mv56im6enJ905xG1DAYesxJaeSbGqQmAw8OWPdhyJGhGSQGKRNJ45u9w==", - "dependencies": { - "undici-types": "~5.26.4" - } - }, - "node_modules/@types/pg": { - "version": "8.6.6", - "resolved": "https://registry.npmjs.org/@types/pg/-/pg-8.6.6.tgz", - "integrity": "sha512-O2xNmXebtwVekJDD+02udOncjVcMZQuTEQEMpKJ0ZRf5E7/9JJX3izhKUcUifBkyKpljyUM6BTgy2trmviKlpw==", - "dependencies": { - "@types/node": "*", - "pg-protocol": "*", - "pg-types": "^2.2.0" - } - }, - "node_modules/debug": { - "version": "4.3.4", - "resolved": "https://registry.npmjs.org/debug/-/debug-4.3.4.tgz", - "integrity": "sha512-PRWFHuSU3eDtQJPvnNY7Jcket1j0t5OuOsFzPPzsekD52Zl8qUfFIPEiswXqIvHWGVHOgX+7G/vCNNhehwxfkQ==", - "dependencies": { - "ms": "2.1.2" - }, - "engines": { - "node": ">=6.0" - }, - "peerDependenciesMeta": { - "supports-color": { - "optional": true - } - } - }, - "node_modules/ms": { - "version": "2.1.2", - "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.2.tgz", - "integrity": "sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w==" - }, - "node_modules/pg-int8": { - "version": "1.0.1", - "resolved": 
"https://registry.npmjs.org/pg-int8/-/pg-int8-1.0.1.tgz", - "integrity": "sha512-WCtabS6t3c8SkpDBUlb1kjOs7l66xsGdKpIPZsg4wR+B3+u9UAum2odSsF9tnvxg80h4ZxLWMy4pRjOsFIqQpw==", - "engines": { - "node": ">=4.0.0" - } - }, - "node_modules/pg-protocol": { - "version": "1.6.0", - "resolved": "https://registry.npmjs.org/pg-protocol/-/pg-protocol-1.6.0.tgz", - "integrity": "sha512-M+PDm637OY5WM307051+bsDia5Xej6d9IR4GwJse1qA1DIhiKlksvrneZOYQq42OM+spubpcNYEo2FcKQrDk+Q==" - }, - "node_modules/pg-types": { - "version": "2.2.0", - "resolved": "https://registry.npmjs.org/pg-types/-/pg-types-2.2.0.tgz", - "integrity": "sha512-qTAAlrEsl8s4OiEQY69wDvcMIdQN6wdz5ojQiOy6YRMuynxenON0O5oCpJI6lshc6scgAY8qvJ2On/p+CXY0GA==", - "dependencies": { - "pg-int8": "1.0.1", - "postgres-array": "~2.0.0", - "postgres-bytea": "~1.0.0", - "postgres-date": "~1.0.4", - "postgres-interval": "^1.1.0" - }, - "engines": { - "node": ">=4" - } - }, - "node_modules/postgres-array": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/postgres-array/-/postgres-array-2.0.0.tgz", - "integrity": "sha512-VpZrUqU5A69eQyW2c5CA1jtLecCsN2U/bD6VilrFDWq5+5UIEVO7nazS3TEcHf1zuPYO/sqGvUvW62g86RXZuA==", - "engines": { - "node": ">=4" - } - }, - "node_modules/postgres-bytea": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/postgres-bytea/-/postgres-bytea-1.0.0.tgz", - "integrity": "sha512-xy3pmLuQqRBZBXDULy7KbaitYqLcmxigw14Q5sj8QBVLqEwXfeybIKVWiqAXTlcvdvb0+xkOtDbfQMOf4lST1w==", - "engines": { - "node": ">=0.10.0" - } - }, - "node_modules/postgres-date": { - "version": "1.0.7", - "resolved": "https://registry.npmjs.org/postgres-date/-/postgres-date-1.0.7.tgz", - "integrity": "sha512-suDmjLVQg78nMK2UZ454hAG+OAW+HQPZ6n++TNDUX+L0+uUlLywnoxJKDou51Zm+zTCjrCl0Nq6J9C5hP9vK/Q==", - "engines": { - "node": ">=0.10.0" - } - }, - "node_modules/postgres-interval": { - "version": "1.2.0", - "resolved": "https://registry.npmjs.org/postgres-interval/-/postgres-interval-1.2.0.tgz", - "integrity": "sha512-9ZhXKM/rw350N1ovuWHbGxnGh/SNJ4cnxHiM0rxE4VN41wsg8P8zWn9hv/buK00RP4WvlOyr/RBDiptyxVbkZQ==", - "dependencies": { - "xtend": "^4.0.0" - }, - "engines": { - "node": ">=0.10.0" - } - }, - "node_modules/undici": { - "version": "5.27.2", - "resolved": "https://registry.npmjs.org/undici/-/undici-5.27.2.tgz", - "integrity": "sha512-iS857PdOEy/y3wlM3yRp+6SNQQ6xU0mmZcwRSriqk+et/cwWAtwmIGf6WkoDN2EK/AMdCO/dfXzIwi+rFMrjjQ==", - "dependencies": { - "@fastify/busboy": "^2.0.0" - }, - "engines": { - "node": ">=14.0" - } - }, - "node_modules/undici-types": { - "version": "5.26.5", - "resolved": "https://registry.npmjs.org/undici-types/-/undici-types-5.26.5.tgz", - "integrity": "sha512-JlCMO+ehdEIKqlFxk6IfVoAUVmgz7cU7zD/h9XZ0qzeosSHmUJVOzSQvvYSYWXkFXC+IfLKSIffhv0sVZup6pA==" - }, - "node_modules/xtend": { - "version": "4.0.2", - "resolved": "https://registry.npmjs.org/xtend/-/xtend-4.0.2.tgz", - "integrity": "sha512-LKYU1iAXJXUgAXn9URjiu+MWhyUXHsvfp7mcuYm9dSUKK0/CjtrUwFAxD82/mCWbtLsGjFIad0wIsod4zrTAEQ==", - "engines": { - "node": ">=0.4" - } - } - } -} From 521b3abfa5bc3c92ce1e200ff22222326aaec460 Mon Sep 17 00:00:00 2001 From: Sergey Tatarintsev Date: Wed, 22 Nov 2023 16:16:49 +0100 Subject: [PATCH 063/134] Fix some of the warnings --- query-engine/driver-adapters/src/wasm/conversion.rs | 1 - query-engine/driver-adapters/src/wasm/mod.rs | 1 - query-engine/query-engine-wasm/src/wasm.rs | 1 - 3 files changed, 3 deletions(-) delete mode 100644 query-engine/driver-adapters/src/wasm/conversion.rs diff --git a/query-engine/driver-adapters/src/wasm/conversion.rs 
b/query-engine/driver-adapters/src/wasm/conversion.rs deleted file mode 100644 index 9cb5202cda45..000000000000 --- a/query-engine/driver-adapters/src/wasm/conversion.rs +++ /dev/null @@ -1 +0,0 @@ -pub(crate) use crate::conversion::{mysql, postgres, sqlite, JSArg}; diff --git a/query-engine/driver-adapters/src/wasm/mod.rs b/query-engine/driver-adapters/src/wasm/mod.rs index 8f0e0aeeb3de..9cdc66b177e7 100644 --- a/query-engine/driver-adapters/src/wasm/mod.rs +++ b/query-engine/driver-adapters/src/wasm/mod.rs @@ -1,7 +1,6 @@ //! Query Engine Driver Adapters: `wasm`-specific implementation. mod async_js_function; -mod conversion; mod error; mod js_object_extern; pub(crate) mod proxy; diff --git a/query-engine/query-engine-wasm/src/wasm.rs b/query-engine/query-engine-wasm/src/wasm.rs index 14edeadf63b6..5e83cf3aa2b6 100644 --- a/query-engine/query-engine-wasm/src/wasm.rs +++ b/query-engine/query-engine-wasm/src/wasm.rs @@ -3,5 +3,4 @@ pub mod error; pub mod functions; pub mod logger; -pub(crate) type Result = std::result::Result; pub(crate) type Executor = Box; From 272228e69ec5215ad8eb54768381d28f6c75b19f Mon Sep 17 00:00:00 2001 From: Sergey Tatarintsev Date: Wed, 22 Nov 2023 17:01:58 +0100 Subject: [PATCH 064/134] Remove unused file --- .../driver-adapters/src/napi/queryable.rs | 303 ------------------ 1 file changed, 303 deletions(-) delete mode 100644 query-engine/driver-adapters/src/napi/queryable.rs diff --git a/query-engine/driver-adapters/src/napi/queryable.rs b/query-engine/driver-adapters/src/napi/queryable.rs deleted file mode 100644 index 900ff076b806..000000000000 --- a/query-engine/driver-adapters/src/napi/queryable.rs +++ /dev/null @@ -1,303 +0,0 @@ -use super::{ - conversion, - proxy::{CommonProxy, DriverProxy, Query}, -}; -use async_trait::async_trait; -use napi::JsObject; -use psl::datamodel_connector::Flavour; -use quaint::{ - connector::{metrics, IsolationLevel, Transaction}, - error::{Error, ErrorKind}, - prelude::{Query as QuaintQuery, Queryable as QuaintQueryable, ResultSet, TransactionCapable}, - visitor::{self, Visitor}, -}; -use tracing::{info_span, Instrument}; - -/// A JsQueryable adapts a Proxy to implement quaint's Queryable interface. It has the -/// responsibility of transforming inputs and outputs of `query` and `execute` methods from quaint -/// types to types that can be translated into javascript and viceversa. This is to let the rest of -/// the query engine work as if it was using quaint itself. The aforementioned transformations are: -/// -/// Transforming a `quaint::ast::Query` into SQL by visiting it for the specific flavour of SQL -/// expected by the client connector. (eg. using the mysql visitor for the Planetscale client -/// connector) -/// -/// Transforming a `JSResultSet` (what client connectors implemented in javascript provide) -/// into a `quaint::connector::result_set::ResultSet`. A quaint `ResultSet` is basically a vector -/// of `quaint::Value` but said type is a tagged enum, with non-unit variants that cannot be converted to javascript as is. 
-/// -pub(crate) struct JsBaseQueryable { - pub(crate) proxy: CommonProxy, - pub flavour: Flavour, -} - -impl JsBaseQueryable { - pub(crate) fn new(proxy: CommonProxy) -> Self { - let flavour: Flavour = proxy.flavour.parse().unwrap(); - Self { proxy, flavour } - } - - /// visit a quaint query AST according to the flavour of the JS connector - fn visit_quaint_query<'a>(&self, q: QuaintQuery<'a>) -> quaint::Result<(String, Vec>)> { - match self.flavour { - Flavour::Mysql => visitor::Mysql::build(q), - Flavour::Postgres => visitor::Postgres::build(q), - Flavour::Sqlite => visitor::Sqlite::build(q), - _ => unimplemented!("Unsupported flavour for JS connector {:?}", self.flavour), - } - } - - async fn build_query(&self, sql: &str, values: &[quaint::Value<'_>]) -> quaint::Result { - let sql: String = sql.to_string(); - - let converter = match self.flavour { - Flavour::Postgres => conversion::postgres::value_to_js_arg, - Flavour::Sqlite => conversion::sqlite::value_to_js_arg, - Flavour::Mysql => conversion::mysql::value_to_js_arg, - _ => unreachable!("Unsupported flavour for JS connector {:?}", self.flavour), - }; - - let args = values - .iter() - .map(converter) - .collect::>>()?; - - Ok(Query { sql, args }) - } -} - -#[async_trait] -impl QuaintQueryable for JsBaseQueryable { - async fn query(&self, q: QuaintQuery<'_>) -> quaint::Result { - let (sql, params) = self.visit_quaint_query(q)?; - self.query_raw(&sql, ¶ms).await - } - - async fn query_raw(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { - metrics::query("js.query_raw", sql, params, move || async move { - self.do_query_raw(sql, params).await - }) - .await - } - - async fn query_raw_typed(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { - self.query_raw(sql, params).await - } - - async fn execute(&self, q: QuaintQuery<'_>) -> quaint::Result { - let (sql, params) = self.visit_quaint_query(q)?; - self.execute_raw(&sql, ¶ms).await - } - - async fn execute_raw(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { - metrics::query("js.execute_raw", sql, params, move || async move { - self.do_execute_raw(sql, params).await - }) - .await - } - - async fn execute_raw_typed(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { - self.execute_raw(sql, params).await - } - - async fn raw_cmd(&self, cmd: &str) -> quaint::Result<()> { - let params = &[]; - metrics::query("js.raw_cmd", cmd, params, move || async move { - self.do_execute_raw(cmd, params).await?; - Ok(()) - }) - .await - } - - async fn version(&self) -> quaint::Result> { - // Note: JS Connectors don't use this method. - Ok(None) - } - - fn is_healthy(&self) -> bool { - // Note: JS Connectors don't use this method. - true - } - - /// Sets the transaction isolation level to given value. - /// Implementers have to make sure that the passed isolation level is valid for the underlying database. 
- async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> quaint::Result<()> { - if matches!(isolation_level, IsolationLevel::Snapshot) { - return Err(Error::builder(ErrorKind::invalid_isolation_level(&isolation_level)).build()); - } - - if self.flavour == Flavour::Sqlite { - return match isolation_level { - IsolationLevel::Serializable => Ok(()), - _ => Err(Error::builder(ErrorKind::invalid_isolation_level(&isolation_level)).build()), - }; - } - - self.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation_level}")) - .await - } - - fn requires_isolation_first(&self) -> bool { - match self.flavour { - Flavour::Mysql => true, - Flavour::Postgres | Flavour::Sqlite => false, - _ => unreachable!(), - } - } -} - -impl JsBaseQueryable { - pub fn phantom_query_message(stmt: &str) -> String { - format!(r#"-- Implicit "{}" query via underlying driver"#, stmt) - } - - async fn do_query_raw(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { - let len = params.len(); - let serialization_span = info_span!("js:query:args", user_facing = true, "length" = %len); - let query = self.build_query(sql, params).instrument(serialization_span).await?; - - let sql_span = info_span!("js:query:sql", user_facing = true, "db.statement" = %sql); - let result_set = self.proxy.query_raw(query).instrument(sql_span).await?; - - let len = result_set.len(); - let _deserialization_span = info_span!("js:query:result", user_facing = true, "length" = %len).entered(); - - result_set.try_into() - } - - async fn do_execute_raw(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { - let len = params.len(); - let serialization_span = info_span!("js:query:args", user_facing = true, "length" = %len); - let query = self.build_query(sql, params).instrument(serialization_span).await?; - - let sql_span = info_span!("js:query:sql", user_facing = true, "db.statement" = %sql); - let affected_rows = self.proxy.execute_raw(query).instrument(sql_span).await?; - - Ok(affected_rows as u64) - } -} - -/// A JsQueryable adapts a Proxy to implement quaint's Queryable interface. It has the -/// responsibility of transforming inputs and outputs of `query` and `execute` methods from quaint -/// types to types that can be translated into javascript and viceversa. This is to let the rest of -/// the query engine work as if it was using quaint itself. The aforementioned transformations are: -/// -/// Transforming a `quaint::ast::Query` into SQL by visiting it for the specific flavour of SQL -/// expected by the client connector. (eg. using the mysql visitor for the Planetscale client -/// connector) -/// -/// Transforming a `JSResultSet` (what client connectors implemented in javascript provide) -/// into a `quaint::connector::result_set::ResultSet`. A quaint `ResultSet` is basically a vector -/// of `quaint::Value` but said type is a tagged enum, with non-unit variants that cannot be converted to javascript as is. 
-/// -pub struct JsQueryable { - inner: JsBaseQueryable, - driver_proxy: DriverProxy, -} - -impl std::fmt::Display for JsQueryable { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "JSQueryable(driver)") - } -} - -impl std::fmt::Debug for JsQueryable { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "JSQueryable(driver)") - } -} - -#[async_trait] -impl QuaintQueryable for JsQueryable { - async fn query(&self, q: QuaintQuery<'_>) -> quaint::Result { - self.inner.query(q).await - } - - async fn query_raw(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { - self.inner.query_raw(sql, params).await - } - - async fn query_raw_typed(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { - self.inner.query_raw_typed(sql, params).await - } - - async fn execute(&self, q: QuaintQuery<'_>) -> quaint::Result { - self.inner.execute(q).await - } - - async fn execute_raw(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { - self.inner.execute_raw(sql, params).await - } - - async fn execute_raw_typed(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { - self.inner.execute_raw_typed(sql, params).await - } - - async fn raw_cmd(&self, cmd: &str) -> quaint::Result<()> { - self.inner.raw_cmd(cmd).await - } - - async fn version(&self) -> quaint::Result> { - self.inner.version().await - } - - fn is_healthy(&self) -> bool { - self.inner.is_healthy() - } - - async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> quaint::Result<()> { - self.inner.set_tx_isolation_level(isolation_level).await - } - - fn requires_isolation_first(&self) -> bool { - self.inner.requires_isolation_first() - } -} - -#[async_trait] -impl TransactionCapable for JsQueryable { - async fn start_transaction<'a>( - &'a self, - isolation: Option, - ) -> quaint::Result> { - let tx = self.driver_proxy.start_transaction().await?; - - let isolation_first = tx.requires_isolation_first(); - - if isolation_first { - if let Some(isolation) = isolation { - tx.set_tx_isolation_level(isolation).await?; - } - } - - let begin_stmt = tx.begin_statement(); - - let tx_opts = tx.options(); - if tx_opts.use_phantom_query { - let begin_stmt = JsBaseQueryable::phantom_query_message(begin_stmt); - tx.raw_phantom_cmd(begin_stmt.as_str()).await?; - } else { - tx.raw_cmd(begin_stmt).await?; - } - - if !isolation_first { - if let Some(isolation) = isolation { - tx.set_tx_isolation_level(isolation).await?; - } - } - - self.server_reset_query(tx.as_ref()).await?; - - Ok(tx) - } -} - -pub fn from_napi(driver: JsObject) -> JsQueryable { - let common = CommonProxy::new(&driver).unwrap(); - let driver_proxy = DriverProxy::new(&driver).unwrap(); - - JsQueryable { - inner: JsBaseQueryable::new(common), - driver_proxy, - } -} From 39b2489c4b773617e33392c4bd5adaa9893be5c9 Mon Sep 17 00:00:00 2001 From: Sergey Tatarintsev Date: Thu, 23 Nov 2023 14:29:39 +0100 Subject: [PATCH 065/134] Fix WASM transaction binding Correctly gets to a point of starting the tranasction and executing the query, fails on parsing the results like normal queries do. 
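The central change is a new `FromJsValue` trait (added in `src/wasm/from_js.rs`): values returned from JS are no longer forced through serde deserialization alone, so plain data (result sets, affected-row counts) can keep using `serde_wasm_bindgen`, while live JS handles such as transaction objects are wrapped directly in proxies. A minimal sketch of the pattern, assuming the `serde`, `serde-wasm-bindgen` and `wasm-bindgen` crates; `JsTransactionHandle` is an illustrative stand-in, not the actual `JsTransaction` type changed later in this patch:

    use serde::de::DeserializeOwned;
    use wasm_bindgen::JsValue;

    /// Conversion from a raw JS value, decoupled from serde.
    pub trait FromJsValue: Sized {
        fn from_js_value(value: JsValue) -> Result<Self, JsValue>;
    }

    /// Data-like results (e.g. a result set) still go through serde.
    impl<T: DeserializeOwned> FromJsValue for T {
        fn from_js_value(value: JsValue) -> Result<Self, JsValue> {
            serde_wasm_bindgen::from_value(value).map_err(JsValue::from)
        }
    }

    /// A handle-like value keeps the underlying `JsValue` instead of
    /// deserializing it, so its JS methods stay callable.
    struct JsTransactionHandle {
        raw: JsValue,
    }

    impl FromJsValue for JsTransactionHandle {
        fn from_js_value(value: JsValue) -> Result<Self, JsValue> {
            Ok(Self { raw: value })
        }
    }

This is what makes `startTransaction` usable from Wasm: the returned transaction is a JS object carrying callbacks, which cannot survive a serde round-trip, so it is instead converted by building `CommonProxy`/`TransactionProxy` wrappers around the raw object, as the `JsTransaction` changes below do.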
--- Cargo.lock | 2 +- psl/psl-core/src/datamodel_connector.rs | 3 -- query-engine/driver-adapters/Cargo.toml | 2 +- .../driver-adapters/src/queryable/mod.rs | 25 +++++++-------- .../driver-adapters/src/queryable/wasm.rs | 5 ++- query-engine/driver-adapters/src/types.rs | 27 +++++++++++++++- .../src/wasm/async_js_function.rs | 30 +++++++++++------- .../driver-adapters/src/wasm/error.rs | 1 + .../driver-adapters/src/wasm/from_js.rs | 15 +++++++++ query-engine/driver-adapters/src/wasm/mod.rs | 1 + .../driver-adapters/src/wasm/proxy.rs | 2 -- .../driver-adapters/src/wasm/result.rs | 31 +++++++------------ .../driver-adapters/src/wasm/transaction.rs | 29 +++++++++++++---- 13 files changed, 112 insertions(+), 61 deletions(-) create mode 100644 query-engine/driver-adapters/src/wasm/from_js.rs diff --git a/Cargo.lock b/Cargo.lock index c9dc91badd04..9a80d864b6ca 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1106,7 +1106,6 @@ dependencies = [ "num-bigint", "once_cell", "pin-project", - "psl", "quaint", "serde", "serde-wasm-bindgen", @@ -1118,6 +1117,7 @@ dependencies = [ "uuid", "wasm-bindgen", "wasm-bindgen-futures", + "web-sys", ] [[package]] diff --git a/psl/psl-core/src/datamodel_connector.rs b/psl/psl-core/src/datamodel_connector.rs index 242f0df20b7c..dc3a7e80bd10 100644 --- a/psl/psl-core/src/datamodel_connector.rs +++ b/psl/psl-core/src/datamodel_connector.rs @@ -361,10 +361,7 @@ pub trait Connector: Send + Sync { } } -#[cfg_attr(target_arch = "wasm32", wasm_bindgen::prelude::wasm_bindgen)] -#[derive(Copy, Clone, Debug, PartialEq, Default, serde::Deserialize)] pub enum Flavour { - #[default] Cockroach, Mongo, Sqlserver, diff --git a/query-engine/driver-adapters/Cargo.toml b/query-engine/driver-adapters/Cargo.toml index ec77df85e142..23244ab7dcb4 100644 --- a/query-engine/driver-adapters/Cargo.toml +++ b/query-engine/driver-adapters/Cargo.toml @@ -8,7 +8,6 @@ async-trait = "0.1" once_cell = "1.15" serde.workspace = true serde_json.workspace = true -psl.workspace = true tracing = "0.1" tracing-core = "0.1" metrics = "0.18" @@ -21,6 +20,7 @@ num-bigint = "0.4.3" bigdecimal = "0.3.0" chrono = "0.4.20" futures = "0.3" +web-sys = "0.3.65" [dev-dependencies] expect-test = "1" diff --git a/query-engine/driver-adapters/src/queryable/mod.rs b/query-engine/driver-adapters/src/queryable/mod.rs index ac252bbb011b..9cd2eb1c9b33 100644 --- a/query-engine/driver-adapters/src/queryable/mod.rs +++ b/query-engine/driver-adapters/src/queryable/mod.rs @@ -19,11 +19,11 @@ pub(crate) use wasm::JsBaseQueryable; use super::{ conversion, proxy::{CommonProxy, DriverProxy, Query}, + types::AdapterFlavour, }; use crate::send_future::SendFuture; use async_trait::async_trait; use futures::Future; -use psl::datamodel_connector::Flavour; use quaint::{ connector::{metrics, IsolationLevel, Transaction}, error::{Error, ErrorKind}, @@ -34,17 +34,16 @@ use tracing::{info_span, Instrument}; impl JsBaseQueryable { pub(crate) fn new(proxy: CommonProxy) -> Self { - let flavour: Flavour = proxy.flavour.parse().unwrap(); + let flavour: AdapterFlavour = proxy.flavour.parse().unwrap(); Self { proxy, flavour } } /// visit a quaint query AST according to the flavour of the JS connector fn visit_quaint_query<'a>(&self, q: QuaintQuery<'a>) -> quaint::Result<(String, Vec>)> { match self.flavour { - Flavour::Mysql => visitor::Mysql::build(q), - Flavour::Postgres => visitor::Postgres::build(q), - Flavour::Sqlite => visitor::Sqlite::build(q), - _ => unimplemented!("Unsupported flavour for JS connector {:?}", self.flavour), + 
AdapterFlavour::Mysql => visitor::Mysql::build(q), + AdapterFlavour::Postgres => visitor::Postgres::build(q), + AdapterFlavour::Sqlite => visitor::Sqlite::build(q), } } @@ -52,10 +51,9 @@ impl JsBaseQueryable { let sql: String = sql.to_string(); let converter = match self.flavour { - Flavour::Postgres => conversion::postgres::value_to_js_arg, - Flavour::Sqlite => conversion::sqlite::value_to_js_arg, - Flavour::Mysql => conversion::mysql::value_to_js_arg, - _ => unreachable!("Unsupported flavour for JS connector {:?}", self.flavour), + AdapterFlavour::Postgres => conversion::postgres::value_to_js_arg, + AdapterFlavour::Sqlite => conversion::sqlite::value_to_js_arg, + AdapterFlavour::Mysql => conversion::mysql::value_to_js_arg, }; let args = values @@ -127,7 +125,7 @@ impl QuaintQueryable for JsBaseQueryable { return Err(Error::builder(ErrorKind::invalid_isolation_level(&isolation_level)).build()); } - if self.flavour == Flavour::Sqlite { + if self.flavour == AdapterFlavour::Sqlite { return match isolation_level { IsolationLevel::Serializable => Ok(()), _ => Err(Error::builder(ErrorKind::invalid_isolation_level(&isolation_level)).build()), @@ -140,9 +138,8 @@ impl QuaintQueryable for JsBaseQueryable { fn requires_isolation_first(&self) -> bool { match self.flavour { - Flavour::Mysql => true, - Flavour::Postgres | Flavour::Sqlite => false, - _ => unreachable!(), + AdapterFlavour::Mysql => true, + AdapterFlavour::Postgres | AdapterFlavour::Sqlite => false, } } } diff --git a/query-engine/driver-adapters/src/queryable/wasm.rs b/query-engine/driver-adapters/src/queryable/wasm.rs index 867d1fb5081a..ee1c65a81347 100644 --- a/query-engine/driver-adapters/src/queryable/wasm.rs +++ b/query-engine/driver-adapters/src/queryable/wasm.rs @@ -1,6 +1,6 @@ +use crate::types::AdapterFlavour; use crate::wasm::proxy::{CommonProxy, DriverProxy}; use crate::{JsObjectExtern, JsQueryable}; -use psl::datamodel_connector::Flavour; use wasm_bindgen::prelude::wasm_bindgen; /// A JsQueryable adapts a Proxy to implement quaint's Queryable interface. It has the @@ -16,10 +16,9 @@ use wasm_bindgen::prelude::wasm_bindgen; /// into a `quaint::connector::result_set::ResultSet`. A quaint `ResultSet` is basically a vector /// of `quaint::Value` but said type is a tagged enum, with non-unit variants that cannot be converted to javascript as is. 
#[wasm_bindgen(getter_with_clone)] -#[derive(Default)] pub(crate) struct JsBaseQueryable { pub(crate) proxy: CommonProxy, - pub flavour: Flavour, + pub flavour: AdapterFlavour, } pub fn from_wasm(driver: JsObjectExtern) -> JsQueryable { diff --git a/query-engine/driver-adapters/src/types.rs b/query-engine/driver-adapters/src/types.rs index 2b2c0b45e50c..8975e7cd9044 100644 --- a/query-engine/driver-adapters/src/types.rs +++ b/query-engine/driver-adapters/src/types.rs @@ -1,5 +1,7 @@ #![allow(unused_imports)] +use std::str::FromStr; + #[cfg(not(target_arch = "wasm32"))] use napi::bindgen_prelude::{FromNapiValue, ToNapiValue}; @@ -9,6 +11,28 @@ use tsify::Tsify; use crate::conversion::JSArg; use serde::{Deserialize, Serialize}; +#[cfg_attr(target_arch = "wasm32", derive(Serialize, Deserialize, Tsify))] +#[cfg_attr(target_arch = "wasm32", tsify(into_wasm_abi, from_wasm_abi))] +#[derive(Debug, Eq, PartialEq, Clone)] +pub enum AdapterFlavour { + Mysql, + Postgres, + Sqlite, +} + +impl FromStr for AdapterFlavour { + type Err = String; + + fn from_str(s: &str) -> Result { + match s { + "postgres" => Ok(Self::Postgres), + "mysql" => Ok(Self::Mysql), + "sqlite" => Ok(Self::Sqlite), + _ => Err(format!("Unsupported adapter flavour: {:?}", s)), + } + } +} + /// This result set is more convenient to be manipulated from both Rust and NodeJS. /// Quaint's version of ResultSet is: /// @@ -27,7 +51,7 @@ use serde::{Deserialize, Serialize}; #[cfg_attr(target_arch = "wasm32", derive(Serialize, Deserialize, Tsify))] #[cfg_attr(target_arch = "wasm32", tsify(into_wasm_abi, from_wasm_abi))] #[cfg_attr(target_arch = "wasm32", serde(rename_all = "camelCase"))] -#[derive(Debug, Default)] +#[derive(Debug)] pub struct JSResultSet { pub column_types: Vec, pub column_names: Vec, @@ -190,6 +214,7 @@ pub struct Query { #[cfg_attr(not(target_arch = "wasm32"), napi_derive::napi(object))] #[cfg_attr(target_arch = "wasm32", derive(Serialize, Deserialize, Tsify))] #[cfg_attr(target_arch = "wasm32", tsify(into_wasm_abi, from_wasm_abi))] +#[cfg_attr(target_arch = "wasm32", serde(rename_all = "camelCase"))] #[derive(Debug, Default)] pub struct TransactionOptions { /// Whether or not to run a phantom query (i.e., a query that only influences Prisma event logs, but not the database itself) diff --git a/query-engine/driver-adapters/src/wasm/async_js_function.rs b/query-engine/driver-adapters/src/wasm/async_js_function.rs index f4e3771694a2..e13f288f4a56 100644 --- a/query-engine/driver-adapters/src/wasm/async_js_function.rs +++ b/query-engine/driver-adapters/src/wasm/async_js_function.rs @@ -1,19 +1,21 @@ -use js_sys::{Function as JsFunction, Promise as JsPromise}; -use serde::{de::DeserializeOwned, Serialize}; +use js_sys::{Function as JsFunction, JsString, Promise as JsPromise}; +use serde::Serialize; use std::marker::PhantomData; +use std::str::FromStr; use wasm_bindgen::convert::FromWasmAbi; use wasm_bindgen::describe::WasmDescribe; use wasm_bindgen::{JsError, JsValue}; use wasm_bindgen_futures::JsFuture; use super::error::into_quaint_error; +use super::from_js::FromJsValue; use super::result::JsResult; -#[derive(Clone, Default)] +#[derive(Clone)] pub(crate) struct AsyncJsFunction where ArgType: Serialize, - ReturnType: DeserializeOwned, + ReturnType: FromJsValue, { pub threadsafe_fn: JsFunction, @@ -24,7 +26,7 @@ where impl From for AsyncJsFunction where T: Serialize, - R: DeserializeOwned, + R: FromJsValue, { fn from(js_fn: JsFunction) -> Self { Self { @@ -38,14 +40,20 @@ where impl AsyncJsFunction where T: Serialize, - R: 
DeserializeOwned, + R: FromJsValue, { pub async fn call(&self, arg1: T) -> quaint::Result { let result = self.call_internal(arg1).await; match result { - Ok(js_result) => js_result.into(), - Err(err) => Err(into_quaint_error(err)), + Ok(js_result) => { + web_sys::console::log_1(&JsString::from_str("OK JS").unwrap().into()); + js_result.into() + } + Err(err) => { + web_sys::console::log_1(&JsString::from_str("CALL ERR").unwrap().into()); + Err(into_quaint_error(err)) + } } } @@ -54,7 +62,7 @@ where let promise = self.threadsafe_fn.call1(&JsValue::null(), &arg1)?; let future = JsFuture::from(JsPromise::from(promise)); let value = future.await?; - let js_result: JsResult = value.try_into()?; + let js_result = JsResult::::from_js_value(value)?; Ok(js_result) } @@ -63,7 +71,7 @@ where impl WasmDescribe for AsyncJsFunction where ArgType: Serialize, - ReturnType: DeserializeOwned, + ReturnType: FromJsValue, { fn describe() { JsFunction::describe(); @@ -73,7 +81,7 @@ where impl FromWasmAbi for AsyncJsFunction where ArgType: Serialize, - ReturnType: DeserializeOwned, + ReturnType: FromJsValue, { type Abi = ::Abi; diff --git a/query-engine/driver-adapters/src/wasm/error.rs b/query-engine/driver-adapters/src/wasm/error.rs index 0aa4fe7981f2..49d7ad6cf440 100644 --- a/query-engine/driver-adapters/src/wasm/error.rs +++ b/query-engine/driver-adapters/src/wasm/error.rs @@ -5,6 +5,7 @@ use wasm_bindgen::JsValue; /// transforms a Wasm error into a Quaint error pub(crate) fn into_quaint_error(wasm_err: JsValue) -> QuaintError { let status = "WASM_ERROR".to_string(); + web_sys::console::log_1(&wasm_err); let reason = Reflect::get(&wasm_err, &JsValue::from_str("stack")) .ok() .and_then(|value| value.as_string()) diff --git a/query-engine/driver-adapters/src/wasm/from_js.rs b/query-engine/driver-adapters/src/wasm/from_js.rs new file mode 100644 index 000000000000..aaa0d91223f6 --- /dev/null +++ b/query-engine/driver-adapters/src/wasm/from_js.rs @@ -0,0 +1,15 @@ +use serde::de::DeserializeOwned; +use wasm_bindgen::JsValue; + +pub trait FromJsValue: Sized { + fn from_js_value(value: JsValue) -> Result; +} + +impl FromJsValue for T +where + T: DeserializeOwned, +{ + fn from_js_value(value: JsValue) -> Result { + serde_wasm_bindgen::from_value(value).map_err(|e| JsValue::from(e)) + } +} diff --git a/query-engine/driver-adapters/src/wasm/mod.rs b/query-engine/driver-adapters/src/wasm/mod.rs index 9cdc66b177e7..2afe1987e1a7 100644 --- a/query-engine/driver-adapters/src/wasm/mod.rs +++ b/query-engine/driver-adapters/src/wasm/mod.rs @@ -2,6 +2,7 @@ mod async_js_function; mod error; +mod from_js; mod js_object_extern; pub(crate) mod proxy; mod result; diff --git a/query-engine/driver-adapters/src/wasm/proxy.rs b/query-engine/driver-adapters/src/wasm/proxy.rs index bb2e9a855fe7..16a88b8d1fd3 100644 --- a/query-engine/driver-adapters/src/wasm/proxy.rs +++ b/query-engine/driver-adapters/src/wasm/proxy.rs @@ -15,7 +15,6 @@ type JsResult = core::result::Result; /// querying and executing SQL (i.e. a client connector). The Proxy uses Wasm's JsFunction to /// invoke the code within the node runtime that implements the client connector. #[wasm_bindgen(getter_with_clone)] -#[derive(Default)] pub(crate) struct CommonProxy { /// Execute a query given as SQL, interpolating the given parameters. 
query_raw: AsyncJsFunction, @@ -38,7 +37,6 @@ pub(crate) struct DriverProxy { /// This a JS proxy for accessing the methods, specific /// to JS transaction objects #[wasm_bindgen(getter_with_clone)] -#[derive(Default)] pub(crate) struct TransactionProxy { /// transaction options options: TransactionOptions, diff --git a/query-engine/driver-adapters/src/wasm/result.rs b/query-engine/driver-adapters/src/wasm/result.rs index df4652307469..fc5115e4a500 100644 --- a/query-engine/driver-adapters/src/wasm/result.rs +++ b/query-engine/driver-adapters/src/wasm/result.rs @@ -1,8 +1,10 @@ -use js_sys::Boolean as JsBoolean; +use std::str::FromStr; + +use js_sys::{Boolean as JsBoolean, JsString}; use quaint::error::{Error as QuaintError, ErrorKind}; -use serde::de::DeserializeOwned; use wasm_bindgen::{JsCast, JsValue}; +use super::from_js::FromJsValue; use crate::{error::DriverAdapterError, JsObjectExtern}; impl From for QuaintError { @@ -26,28 +28,17 @@ impl From for QuaintError { /// Wrapper for JS-side result type pub(crate) enum JsResult where - T: DeserializeOwned, + T: FromJsValue, { Ok(T), Err(DriverAdapterError), } -impl TryFrom for JsResult -where - T: DeserializeOwned, -{ - type Error = JsValue; - - fn try_from(value: JsValue) -> Result { - Self::from_js_unknown(value) - } -} - -impl JsResult +impl FromJsValue for JsResult where - T: DeserializeOwned, + T: FromJsValue, { - fn from_js_unknown(unknown: JsValue) -> Result { + fn from_js_value(unknown: JsValue) -> Result { let object = unknown.unchecked_into::(); let ok: JsBoolean = object.get("ok".into())?.unchecked_into(); @@ -55,7 +46,9 @@ where if ok { let js_value: JsValue = object.get("value".into())?; - let deserialized = serde_wasm_bindgen::from_value::(js_value)?; + web_sys::console::log_1(&JsString::from_str("BEFORE DESERIALIZE").unwrap().into()); + let deserialized = T::from_js_value(js_value)?; + web_sys::console::log_1(&JsString::from_str(" DESERIALIZE").unwrap().into()); return Ok(Self::Ok(deserialized)); } @@ -67,7 +60,7 @@ where impl From> for quaint::Result where - T: DeserializeOwned, + T: FromJsValue, { fn from(value: JsResult) -> Self { match value { diff --git a/query-engine/driver-adapters/src/wasm/transaction.rs b/query-engine/driver-adapters/src/wasm/transaction.rs index 43925b488101..d1b93bd4bfd0 100644 --- a/query-engine/driver-adapters/src/wasm/transaction.rs +++ b/query-engine/driver-adapters/src/wasm/transaction.rs @@ -1,23 +1,25 @@ use async_trait::async_trait; +use js_sys::{JsString, Object as JsObject}; use metrics::decrement_gauge; use quaint::{ connector::{IsolationLevel, Transaction as QuaintTransaction}, prelude::{Query as QuaintQuery, Queryable, ResultSet}, Value, }; -use serde::Deserialize; +use std::str::FromStr; +use wasm_bindgen::JsCast; -use super::proxy::{TransactionOptions, TransactionProxy}; -use crate::{queryable::JsBaseQueryable, send_future::SendFuture}; +use super::{ + from_js::FromJsValue, + proxy::{TransactionOptions, TransactionProxy}, +}; +use crate::{proxy::CommonProxy, queryable::JsBaseQueryable, send_future::SendFuture, JsObjectExtern}; // Wrapper around JS transaction objects that implements Queryable // and quaint::Transaction. 
Can be used in place of quaint transaction, // but delegates most operations to JS -#[derive(Deserialize, Default)] pub(crate) struct JsTransaction { - #[serde(skip)] tx_proxy: TransactionProxy, - #[serde(skip)] inner: JsBaseQueryable, } @@ -36,6 +38,21 @@ impl JsTransaction { } } +impl FromJsValue for JsTransaction { + fn from_js_value(value: wasm_bindgen::prelude::JsValue) -> Result { + let object: JsObjectExtern = value.dyn_into::()?.unchecked_into(); + web_sys::console::log_1(&JsString::from_str("OBJECT").unwrap().into()); + let common_proxy = CommonProxy::new(&object)?; + web_sys::console::log_1(&JsString::from_str("PROXY").unwrap().into()); + let base = JsBaseQueryable::new(common_proxy); + web_sys::console::log_1(&JsString::from_str("BASE").unwrap().into()); + let tx_proxy = TransactionProxy::new(&object)?; + web_sys::console::log_1(&JsString::from_str("TX_PROXY").unwrap().into()); + + Ok(Self::new(base, tx_proxy)) + } +} + #[async_trait] impl QuaintTransaction for JsTransaction { async fn commit(&self) -> quaint::Result<()> { From b589b1f00d9993d7779b7479e6aa4f91009d40a0 Mon Sep 17 00:00:00 2001 From: Sergey Tatarintsev Date: Thu, 23 Nov 2023 14:33:33 +0100 Subject: [PATCH 066/134] Cleanup --- Cargo.lock | 1 - query-engine/driver-adapters/Cargo.toml | 1 - .../driver-adapters/src/wasm/async_js_function.rs | 13 +++---------- query-engine/driver-adapters/src/wasm/error.rs | 1 - query-engine/driver-adapters/src/wasm/result.rs | 6 +----- .../driver-adapters/src/wasm/transaction.rs | 7 +------ 6 files changed, 5 insertions(+), 24 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9a80d864b6ca..4d3e4de66ccb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1117,7 +1117,6 @@ dependencies = [ "uuid", "wasm-bindgen", "wasm-bindgen-futures", - "web-sys", ] [[package]] diff --git a/query-engine/driver-adapters/Cargo.toml b/query-engine/driver-adapters/Cargo.toml index 23244ab7dcb4..86d063603ed7 100644 --- a/query-engine/driver-adapters/Cargo.toml +++ b/query-engine/driver-adapters/Cargo.toml @@ -20,7 +20,6 @@ num-bigint = "0.4.3" bigdecimal = "0.3.0" chrono = "0.4.20" futures = "0.3" -web-sys = "0.3.65" [dev-dependencies] expect-test = "1" diff --git a/query-engine/driver-adapters/src/wasm/async_js_function.rs b/query-engine/driver-adapters/src/wasm/async_js_function.rs index e13f288f4a56..fe64847b5b1e 100644 --- a/query-engine/driver-adapters/src/wasm/async_js_function.rs +++ b/query-engine/driver-adapters/src/wasm/async_js_function.rs @@ -1,7 +1,6 @@ -use js_sys::{Function as JsFunction, JsString, Promise as JsPromise}; +use js_sys::{Function as JsFunction, Promise as JsPromise}; use serde::Serialize; use std::marker::PhantomData; -use std::str::FromStr; use wasm_bindgen::convert::FromWasmAbi; use wasm_bindgen::describe::WasmDescribe; use wasm_bindgen::{JsError, JsValue}; @@ -46,14 +45,8 @@ where let result = self.call_internal(arg1).await; match result { - Ok(js_result) => { - web_sys::console::log_1(&JsString::from_str("OK JS").unwrap().into()); - js_result.into() - } - Err(err) => { - web_sys::console::log_1(&JsString::from_str("CALL ERR").unwrap().into()); - Err(into_quaint_error(err)) - } + Ok(js_result) => js_result.into(), + Err(err) => Err(into_quaint_error(err)), } } diff --git a/query-engine/driver-adapters/src/wasm/error.rs b/query-engine/driver-adapters/src/wasm/error.rs index 49d7ad6cf440..0aa4fe7981f2 100644 --- a/query-engine/driver-adapters/src/wasm/error.rs +++ b/query-engine/driver-adapters/src/wasm/error.rs @@ -5,7 +5,6 @@ use wasm_bindgen::JsValue; /// transforms a Wasm 
error into a Quaint error pub(crate) fn into_quaint_error(wasm_err: JsValue) -> QuaintError { let status = "WASM_ERROR".to_string(); - web_sys::console::log_1(&wasm_err); let reason = Reflect::get(&wasm_err, &JsValue::from_str("stack")) .ok() .and_then(|value| value.as_string()) diff --git a/query-engine/driver-adapters/src/wasm/result.rs b/query-engine/driver-adapters/src/wasm/result.rs index fc5115e4a500..2e656a205c41 100644 --- a/query-engine/driver-adapters/src/wasm/result.rs +++ b/query-engine/driver-adapters/src/wasm/result.rs @@ -1,6 +1,4 @@ -use std::str::FromStr; - -use js_sys::{Boolean as JsBoolean, JsString}; +use js_sys::Boolean as JsBoolean; use quaint::error::{Error as QuaintError, ErrorKind}; use wasm_bindgen::{JsCast, JsValue}; @@ -46,9 +44,7 @@ where if ok { let js_value: JsValue = object.get("value".into())?; - web_sys::console::log_1(&JsString::from_str("BEFORE DESERIALIZE").unwrap().into()); let deserialized = T::from_js_value(js_value)?; - web_sys::console::log_1(&JsString::from_str(" DESERIALIZE").unwrap().into()); return Ok(Self::Ok(deserialized)); } diff --git a/query-engine/driver-adapters/src/wasm/transaction.rs b/query-engine/driver-adapters/src/wasm/transaction.rs index d1b93bd4bfd0..b9eac2965e48 100644 --- a/query-engine/driver-adapters/src/wasm/transaction.rs +++ b/query-engine/driver-adapters/src/wasm/transaction.rs @@ -1,12 +1,11 @@ use async_trait::async_trait; -use js_sys::{JsString, Object as JsObject}; +use js_sys::Object as JsObject; use metrics::decrement_gauge; use quaint::{ connector::{IsolationLevel, Transaction as QuaintTransaction}, prelude::{Query as QuaintQuery, Queryable, ResultSet}, Value, }; -use std::str::FromStr; use wasm_bindgen::JsCast; use super::{ @@ -41,13 +40,9 @@ impl JsTransaction { impl FromJsValue for JsTransaction { fn from_js_value(value: wasm_bindgen::prelude::JsValue) -> Result { let object: JsObjectExtern = value.dyn_into::()?.unchecked_into(); - web_sys::console::log_1(&JsString::from_str("OBJECT").unwrap().into()); let common_proxy = CommonProxy::new(&object)?; - web_sys::console::log_1(&JsString::from_str("PROXY").unwrap().into()); let base = JsBaseQueryable::new(common_proxy); - web_sys::console::log_1(&JsString::from_str("BASE").unwrap().into()); let tx_proxy = TransactionProxy::new(&object)?; - web_sys::console::log_1(&JsString::from_str("TX_PROXY").unwrap().into()); Ok(Self::new(base, tx_proxy)) } From e56b723c3d50bbd7046815fb5aa04d0fc6532607 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Thu, 23 Nov 2023 15:01:53 +0100 Subject: [PATCH 067/134] feat(driver-adapters): fix enum parsing, add "wasm-rs-dbg" crate for dev development --- Cargo.lock | 21 +++++++++++++++++---- Cargo.toml | 2 ++ query-engine/driver-adapters/Cargo.toml | 5 ++++- query-engine/driver-adapters/src/types.rs | 4 +++- 4 files changed, 26 insertions(+), 6 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c9dc91badd04..23fabd3c52cf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1111,6 +1111,7 @@ dependencies = [ "serde", "serde-wasm-bindgen", "serde_json", + "serde_repr", "tokio", "tracing", "tracing-core", @@ -1118,6 +1119,7 @@ dependencies = [ "uuid", "wasm-bindgen", "wasm-bindgen-futures", + "wasm-rs-dbg", ] [[package]] @@ -1930,9 +1932,9 @@ dependencies = [ [[package]] name = "insta" -version = "1.21.2" +version = "1.34.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "261bf85ed492cd1c47c9ba675e48649682a9d2d2e77f515c5386d7726fb0ba76" +checksum = "5d64600be34b2fcfc267740a243fa7744441bb4947a619ac4e5bb6507f35fbfc" 
dependencies = [ "console", "lazy_static", @@ -3532,6 +3534,7 @@ dependencies = [ "serde_json", "url", "wasm-bindgen", + "wasm-rs-dbg", ] [[package]] @@ -3848,6 +3851,7 @@ dependencies = [ "serde", "serde-wasm-bindgen", "serde_json", + "serde_repr", "sql-query-connector", "thiserror", "tokio", @@ -4667,9 +4671,9 @@ dependencies = [ [[package]] name = "serde_repr" -version = "0.1.16" +version = "0.1.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8725e1dfadb3a50f7e5ce0b1a540466f6ed3fe7a0fca2ac2b8b831d31316bd00" +checksum = "3081f5ffbb02284dda55132aa26daecedd7372a42417bbbab6f14ab7d6bb9145" dependencies = [ "proc-macro2", "quote", @@ -6110,6 +6114,15 @@ dependencies = [ "web-sys", ] +[[package]] +name = "wasm-rs-dbg" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61e5fe4ac478ca5cf1db842029f41a5881da39e70320deb0006912f226ea63f4" +dependencies = [ + "web-sys", +] + [[package]] name = "web-sys" version = "0.3.65" diff --git a/Cargo.toml b/Cargo.toml index 1e6f732192dd..52b5ddeac055 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -58,10 +58,12 @@ napi = { version = "2.12.4", default-features = false, features = [ ] } napi-derive = "2.12.4" js-sys = { version = "0.3" } +serde_repr = { version = "0.1.17" } serde-wasm-bindgen = { version = "0.5" } tsify = { version = "0.4.5" } wasm-bindgen = { version = "0.2.88" } wasm-bindgen-futures = { version = "0.4" } +wasm-rs-dbg = { version = "0.1.2" } [workspace.dependencies.quaint] path = "quaint" diff --git a/query-engine/driver-adapters/Cargo.toml b/query-engine/driver-adapters/Cargo.toml index ec77df85e142..c32b1b8e1d10 100644 --- a/query-engine/driver-adapters/Cargo.toml +++ b/query-engine/driver-adapters/Cargo.toml @@ -14,6 +14,7 @@ tracing-core = "0.1" metrics = "0.18" uuid = { version = "1", features = ["v4"] } pin-project = "1" +wasm-rs-dbg = "0.1.2" # Note: these deps are temporarily specified here to avoid importing them from tiberius (the SQL server driver). # They will be imported from quaint-core instead in a future PR. @@ -24,7 +25,8 @@ futures = "0.3" [dev-dependencies] expect-test = "1" -tokio.workspace = true +tokio = { version = "1.0", features = ["macros", "time", "sync"] } +wasm-rs-dbg.dev-workspace = true [target.'cfg(not(target_arch = "wasm32"))'.dependencies] napi.workspace = true @@ -35,6 +37,7 @@ quaint.workspace = true quaint = { path = "../../quaint" } js-sys.workspace = true serde-wasm-bindgen.workspace = true +serde_repr.workspace = true wasm-bindgen.workspace = true wasm-bindgen-futures.workspace = true tsify.workspace = true diff --git a/query-engine/driver-adapters/src/types.rs b/query-engine/driver-adapters/src/types.rs index 2b2c0b45e50c..5c140353a744 100644 --- a/query-engine/driver-adapters/src/types.rs +++ b/query-engine/driver-adapters/src/types.rs @@ -8,6 +8,7 @@ use tsify::Tsify; use crate::conversion::JSArg; use serde::{Deserialize, Serialize}; +use serde_repr::{Deserialize_repr, Serialize_repr}; /// This result set is more convenient to be manipulated from both Rust and NodeJS. 
/// Quaint's version of ResultSet is: @@ -43,8 +44,9 @@ impl JSResultSet { } #[cfg_attr(not(target_arch = "wasm32"), napi_derive::napi(object))] -#[cfg_attr(target_arch = "wasm32", derive(Clone, Copy, Serialize, Deserialize, Tsify))] +#[cfg_attr(target_arch = "wasm32", derive(Clone, Copy, Serialize_repr, Deserialize_repr, Tsify))] #[cfg_attr(target_arch = "wasm32", tsify(into_wasm_abi, from_wasm_abi))] +#[repr(u8)] #[derive(Debug)] pub enum ColumnType { // [PLANETSCALE_TYPE] (MYSQL_TYPE) -> [TypeScript example] From 61add9e7bed5e6880a525af97d8083255b9693e6 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Thu, 23 Nov 2023 15:02:15 +0100 Subject: [PATCH 068/134] chore(driver-adapters): remove unused "src/wasm/queryable.rs" --- .../driver-adapters/src/wasm/queryable.rs | 325 ------------------ 1 file changed, 325 deletions(-) delete mode 100644 query-engine/driver-adapters/src/wasm/queryable.rs diff --git a/query-engine/driver-adapters/src/wasm/queryable.rs b/query-engine/driver-adapters/src/wasm/queryable.rs deleted file mode 100644 index edb0de4ea493..000000000000 --- a/query-engine/driver-adapters/src/wasm/queryable.rs +++ /dev/null @@ -1,325 +0,0 @@ -use crate::JsObjectExtern; - -use super::{ - conversion, - proxy::{CommonProxy, DriverProxy, Query}, - send_future::SendFuture, -}; -use async_trait::async_trait; -use futures::Future; -use js_sys::Object as JsObject; -use psl::datamodel_connector::Flavour; -use quaint::{ - connector::{metrics, IsolationLevel, Transaction}, - error::{Error, ErrorKind}, - prelude::{Query as QuaintQuery, Queryable as QuaintQueryable, ResultSet, TransactionCapable}, - visitor::{self, Visitor}, -}; -use tracing::{info_span, Instrument}; -use wasm_bindgen::prelude::wasm_bindgen; - -/// A JsQueryable adapts a Proxy to implement quaint's Queryable interface. It has the -/// responsibility of transforming inputs and outputs of `query` and `execute` methods from quaint -/// types to types that can be translated into javascript and viceversa. This is to let the rest of -/// the query engine work as if it was using quaint itself. The aforementioned transformations are: -/// -/// Transforming a `quaint::ast::Query` into SQL by visiting it for the specific flavour of SQL -/// expected by the client connector. (eg. using the mysql visitor for the Planetscale client -/// connector) -/// -/// Transforming a `JSResultSet` (what client connectors implemented in javascript provide) -/// into a `quaint::connector::result_set::ResultSet`. A quaint `ResultSet` is basically a vector -/// of `quaint::Value` but said type is a tagged enum, with non-unit variants that cannot be converted to javascript as is. 
-#[wasm_bindgen(getter_with_clone)] -#[derive(Default)] -pub(crate) struct JsBaseQueryable { - pub(crate) proxy: CommonProxy, - pub flavour: Flavour, -} - -impl JsBaseQueryable { - pub(crate) fn new(proxy: CommonProxy) -> Self { - let flavour: Flavour = proxy.flavour.parse().unwrap(); - Self { proxy, flavour } - } - - /// visit a quaint query AST according to the flavour of the JS connector - fn visit_quaint_query<'a>(&self, q: QuaintQuery<'a>) -> quaint::Result<(String, Vec>)> { - match self.flavour { - Flavour::Mysql => visitor::Mysql::build(q), - Flavour::Postgres => visitor::Postgres::build(q), - Flavour::Sqlite => visitor::Sqlite::build(q), - _ => unimplemented!("Unsupported flavour for JS connector {:?}", self.flavour), - } - } - - async fn build_query(&self, sql: &str, values: &[quaint::Value<'_>]) -> quaint::Result { - let sql: String = sql.to_string(); - - let converter = match self.flavour { - Flavour::Postgres => conversion::postgres::value_to_js_arg, - Flavour::Sqlite => conversion::sqlite::value_to_js_arg, - Flavour::Mysql => conversion::mysql::value_to_js_arg, - _ => unreachable!("Unsupported flavour for JS connector {:?}", self.flavour), - }; - - let args = values - .iter() - .map(converter) - .collect::>>()?; - - Ok(Query { sql, args }) - } -} - -#[async_trait] -impl QuaintQueryable for JsBaseQueryable { - async fn query(&self, q: QuaintQuery<'_>) -> quaint::Result { - let (sql, params) = self.visit_quaint_query(q)?; - self.query_raw(&sql, ¶ms).await - } - - async fn query_raw(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { - metrics::query("js.query_raw", sql, params, move || async move { - self.do_query_raw(sql, params).await - }) - .await - } - - async fn query_raw_typed(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { - self.query_raw(sql, params).await - } - - async fn execute(&self, q: QuaintQuery<'_>) -> quaint::Result { - let (sql, params) = self.visit_quaint_query(q)?; - self.execute_raw(&sql, ¶ms).await - } - - async fn execute_raw(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { - metrics::query("js.execute_raw", sql, params, move || async move { - self.do_execute_raw(sql, params).await - }) - .await - } - - async fn execute_raw_typed(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { - self.execute_raw(sql, params).await - } - - async fn raw_cmd(&self, cmd: &str) -> quaint::Result<()> { - let params = &[]; - metrics::query("js.raw_cmd", cmd, params, move || async move { - self.do_execute_raw(cmd, params).await?; - Ok(()) - }) - .await - } - - async fn version(&self) -> quaint::Result> { - // Note: JS Connectors don't use this method. - Ok(None) - } - - fn is_healthy(&self) -> bool { - // Note: JS Connectors don't use this method. - true - } - - /// Sets the transaction isolation level to given value. - /// Implementers have to make sure that the passed isolation level is valid for the underlying database. 
- async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> quaint::Result<()> { - if matches!(isolation_level, IsolationLevel::Snapshot) { - return Err(Error::builder(ErrorKind::invalid_isolation_level(&isolation_level)).build()); - } - - if self.flavour == Flavour::Sqlite { - return match isolation_level { - IsolationLevel::Serializable => Ok(()), - _ => Err(Error::builder(ErrorKind::invalid_isolation_level(&isolation_level)).build()), - }; - } - - self.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation_level}")) - .await - } - - fn requires_isolation_first(&self) -> bool { - match self.flavour { - Flavour::Mysql => true, - Flavour::Postgres | Flavour::Sqlite => false, - _ => unreachable!(), - } - } -} - -impl JsBaseQueryable { - pub fn phantom_query_message(stmt: &str) -> String { - format!(r#"-- Implicit "{}" query via underlying driver"#, stmt) - } - - async fn do_query_raw_inner(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { - let len = params.len(); - let serialization_span = info_span!("js:query:args", user_facing = true, "length" = %len); - let query = self.build_query(sql, params).instrument(serialization_span).await?; - - let sql_span = info_span!("js:query:sql", user_facing = true, "db.statement" = %sql); - let result_set = self.proxy.query_raw(query).instrument(sql_span).await?; - - let len = result_set.len(); - let _deserialization_span = info_span!("js:query:result", user_facing = true, "length" = %len).entered(); - - result_set.try_into() - } - - fn do_query_raw<'a>( - &'a self, - sql: &'a str, - params: &'a [quaint::Value<'a>], - ) -> SendFuture> + 'a> { - SendFuture(self.do_query_raw_inner(sql, params)) - } - - async fn do_execute_raw_inner(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { - let len = params.len(); - let serialization_span = info_span!("js:query:args", user_facing = true, "length" = %len); - let query = self.build_query(sql, params).instrument(serialization_span).await?; - - let sql_span = info_span!("js:query:sql", user_facing = true, "db.statement" = %sql); - let affected_rows = self.proxy.execute_raw(query).instrument(sql_span).await?; - - Ok(affected_rows as u64) - } - - fn do_execute_raw<'a>( - &'a self, - sql: &'a str, - params: &'a [quaint::Value<'a>], - ) -> SendFuture> + 'a> { - SendFuture(self.do_execute_raw_inner(sql, params)) - } -} - -/// A JsQueryable adapts a Proxy to implement quaint's Queryable interface. It has the -/// responsibility of transforming inputs and outputs of `query` and `execute` methods from quaint -/// types to types that can be translated into javascript and viceversa. This is to let the rest of -/// the query engine work as if it was using quaint itself. The aforementioned transformations are: -/// -/// Transforming a `quaint::ast::Query` into SQL by visiting it for the specific flavour of SQL -/// expected by the client connector. (eg. using the mysql visitor for the Planetscale client -/// connector) -/// -/// Transforming a `JSResultSet` (what client connectors implemented in javascript provide) -/// into a `quaint::connector::result_set::ResultSet`. A quaint `ResultSet` is basically a vector -/// of `quaint::Value` but said type is a tagged enum, with non-unit variants that cannot be converted to javascript as is. 
-/// -pub struct JsQueryable { - inner: JsBaseQueryable, - driver_proxy: DriverProxy, -} - -impl std::fmt::Display for JsQueryable { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "JSQueryable(driver)") - } -} - -impl std::fmt::Debug for JsQueryable { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "JSQueryable(driver)") - } -} - -#[async_trait] -impl QuaintQueryable for JsQueryable { - async fn query(&self, q: QuaintQuery<'_>) -> quaint::Result { - self.inner.query(q).await - } - - async fn query_raw(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { - self.inner.query_raw(sql, params).await - } - - async fn query_raw_typed(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { - self.inner.query_raw_typed(sql, params).await - } - - async fn execute(&self, q: QuaintQuery<'_>) -> quaint::Result { - self.inner.execute(q).await - } - - async fn execute_raw(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { - self.inner.execute_raw(sql, params).await - } - - async fn execute_raw_typed(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { - self.inner.execute_raw_typed(sql, params).await - } - - async fn raw_cmd(&self, cmd: &str) -> quaint::Result<()> { - self.inner.raw_cmd(cmd).await - } - - async fn version(&self) -> quaint::Result> { - self.inner.version().await - } - - fn is_healthy(&self) -> bool { - self.inner.is_healthy() - } - - async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> quaint::Result<()> { - self.inner.set_tx_isolation_level(isolation_level).await - } - - fn requires_isolation_first(&self) -> bool { - self.inner.requires_isolation_first() - } -} - -#[async_trait] -impl TransactionCapable for JsQueryable { - async fn start_transaction<'a>( - &'a self, - isolation: Option, - ) -> quaint::Result> { - let tx = self.driver_proxy.start_transaction().await?; - - let isolation_first = tx.requires_isolation_first(); - - if isolation_first { - if let Some(isolation) = isolation { - tx.set_tx_isolation_level(isolation).await?; - } - } - - let begin_stmt = tx.begin_statement(); - - let tx_opts = tx.options(); - if tx_opts.use_phantom_query { - let begin_stmt = JsBaseQueryable::phantom_query_message(begin_stmt); - tx.raw_phantom_cmd(begin_stmt.as_str()).await?; - } else { - tx.raw_cmd(begin_stmt).await?; - } - - if !isolation_first { - if let Some(isolation) = isolation { - tx.set_tx_isolation_level(isolation).await?; - } - } - - self.server_reset_query(tx.as_ref()).await?; - - Ok(tx) - } -} - -pub fn from_wasm(driver: JsObjectExtern) -> JsQueryable { - let common = CommonProxy::new(&driver).unwrap(); - let driver_proxy = DriverProxy::new(&driver).unwrap(); - - JsQueryable { - inner: JsBaseQueryable::new(common), - driver_proxy, - } -} From fb365ea33fb977f9b34431091a012636619b1688 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Thu, 23 Nov 2023 15:03:25 +0100 Subject: [PATCH 069/134] chore(driver-adapters): add "createOne" and "driverAdapters" preview feature to example --- .../query-engine-wasm/example/example.js | 16 ++++++++++++++++ .../example/prisma/schema.prisma | 1 + 2 files changed, 17 insertions(+) diff --git a/query-engine/query-engine-wasm/example/example.js b/query-engine/query-engine-wasm/example/example.js index 52833baab014..9c8948a2e2d3 100644 --- a/query-engine/query-engine-wasm/example/example.js +++ b/query-engine/query-engine-wasm/example/example.js @@ -37,6 +37,22 @@ async function main() { const queryEngine = new 
QueryEngine(options, callback, driverAdapter) await queryEngine.connect('trace') + + const created = await queryEngine.query(JSON.stringify({ + modelName: 'User', + action: 'createOne', + query: { + arguments: { + data: { + id: 1234, + }, + }, + selection: { + $scalars: true + } + } + }), 'trace') + const res = await queryEngine.query(JSON.stringify({ modelName: 'User', action: 'findMany', diff --git a/query-engine/query-engine-wasm/example/prisma/schema.prisma b/query-engine/query-engine-wasm/example/prisma/schema.prisma index 93a7c64a6122..8e6b86202536 100644 --- a/query-engine/query-engine-wasm/example/prisma/schema.prisma +++ b/query-engine/query-engine-wasm/example/prisma/schema.prisma @@ -5,6 +5,7 @@ datasource db { generator client { provider = "prisma-client-js" + previewFeatures = ["driverAdapters"] } model User { From 7570e8c4e3db623ad312a72b5d8bc6727ae8fbdf Mon Sep 17 00:00:00 2001 From: jkomyno Date: Thu, 23 Nov 2023 15:04:48 +0100 Subject: [PATCH 070/134] chore(driver-adapters): add wasm-bindgen-test example --- Cargo.lock | 31 +++ Cargo.toml | 1 + query-engine/driver-adapters/Cargo.toml | 3 +- query-engine/driver-adapters/tests/wasm.rs | 275 +++++++++++++++++++++ 4 files changed, 309 insertions(+), 1 deletion(-) create mode 100644 query-engine/driver-adapters/tests/wasm.rs diff --git a/Cargo.lock b/Cargo.lock index 23fabd3c52cf..29927bcbd21f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1119,6 +1119,7 @@ dependencies = [ "uuid", "wasm-bindgen", "wasm-bindgen-futures", + "wasm-bindgen-test", "wasm-rs-dbg", ] @@ -4530,6 +4531,12 @@ dependencies = [ "user-facing-errors", ] +[[package]] +name = "scoped-tls" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1cf6437eb19a8f4a6cc0f7dca544973b0b78843adbfeb3683d1a94a0024a294" + [[package]] name = "scopeguard" version = "1.2.0" @@ -6103,6 +6110,30 @@ version = "0.2.88" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0d046c5d029ba91a1ed14da14dca44b68bf2f124cfbaf741c54151fdb3e0750b" +[[package]] +name = "wasm-bindgen-test" +version = "0.3.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6db36fc0f9fb209e88fb3642590ae0205bb5a56216dabd963ba15879fe53a30b" +dependencies = [ + "console_error_panic_hook", + "js-sys", + "scoped-tls", + "wasm-bindgen", + "wasm-bindgen-futures", + "wasm-bindgen-test-macro", +] + +[[package]] +name = "wasm-bindgen-test-macro" +version = "0.3.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0734759ae6b3b1717d661fe4f016efcfb9828f5edb4520c18eaee05af3b43be9" +dependencies = [ + "proc-macro2", + "quote", +] + [[package]] name = "wasm-logger" version = "0.2.0" diff --git a/Cargo.toml b/Cargo.toml index 52b5ddeac055..5480669eb45b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -64,6 +64,7 @@ tsify = { version = "0.4.5" } wasm-bindgen = { version = "0.2.88" } wasm-bindgen-futures = { version = "0.4" } wasm-rs-dbg = { version = "0.1.2" } +wasm-bindgen-test = { version = "0.3.0" } [workspace.dependencies.quaint] path = "quaint" diff --git a/query-engine/driver-adapters/Cargo.toml b/query-engine/driver-adapters/Cargo.toml index c32b1b8e1d10..51b9e42ce298 100644 --- a/query-engine/driver-adapters/Cargo.toml +++ b/query-engine/driver-adapters/Cargo.toml @@ -26,7 +26,8 @@ futures = "0.3" [dev-dependencies] expect-test = "1" tokio = { version = "1.0", features = ["macros", "time", "sync"] } -wasm-rs-dbg.dev-workspace = true +wasm-rs-dbg.workspace = true +wasm-bindgen-test.workspace = true 
[target.'cfg(not(target_arch = "wasm32"))'.dependencies] napi.workspace = true diff --git a/query-engine/driver-adapters/tests/wasm.rs b/query-engine/driver-adapters/tests/wasm.rs new file mode 100644 index 000000000000..c40529978c9f --- /dev/null +++ b/query-engine/driver-adapters/tests/wasm.rs @@ -0,0 +1,275 @@ +use wasm_bindgen_test::*; + +// use driver_adapters::types::ColumnType; +use serde::{Deserialize, Serialize}; +use serde_repr::{Deserialize_repr, Serialize_repr}; +use tsify::Tsify; +use wasm_bindgen::prelude::*; + +// Recursive expansion of Deserialize macro +// ========================================= +// +// #[doc(hidden)] +// #[allow(non_upper_case_globals, unused_attributes, unused_qualifications)] +// const _: () = { +// #[allow(unused_extern_crates, clippy::useless_attribute)] +// extern crate serde as _serde; +// #[automatically_derived] +// impl<'de> _serde::Deserialize<'de> for ColumnType { +// fn deserialize<__D>(__deserializer: __D) -> _serde::__private::Result +// where +// __D: _serde::Deserializer<'de>, +// { +// #[allow(non_camel_case_types)] +// #[doc(hidden)] +// enum __Field { +// __field0, +// __field1, +// } +// #[doc(hidden)] +// struct __FieldVisitor; + +// impl<'de> _serde::de::Visitor<'de> for __FieldVisitor { +// type Value = __Field; +// fn expecting(&self, __formatter: &mut _serde::__private::Formatter) -> _serde::__private::fmt::Result { +// _serde::__private::Formatter::write_str(__formatter, "variant identifier") +// } +// fn visit_u64<__E>(self, __value: u64) -> _serde::__private::Result +// where +// __E: _serde::de::Error, +// { +// match __value { +// 0u64 => _serde::__private::Ok(__Field::__field0), +// 1u64 => _serde::__private::Ok(__Field::__field1), +// _ => _serde::__private::Err(_serde::de::Error::invalid_value( +// _serde::de::Unexpected::Unsigned(__value), +// &"variant index 0 <= i < 2", +// )), +// } +// } +// fn visit_str<__E>(self, __value: &str) -> _serde::__private::Result +// where +// __E: _serde::de::Error, +// { +// match __value { +// "Int32" => _serde::__private::Ok(__Field::__field0), +// "Int64" => _serde::__private::Ok(__Field::__field1), +// _ => _serde::__private::Err(_serde::de::Error::unknown_variant(__value, VARIANTS)), +// } +// } +// fn visit_bytes<__E>(self, __value: &[u8]) -> _serde::__private::Result +// where +// __E: _serde::de::Error, +// { +// match __value { +// b"Int32" => _serde::__private::Ok(__Field::__field0), +// b"Int64" => _serde::__private::Ok(__Field::__field1), +// _ => { +// let __value = &_serde::__private::from_utf8_lossy(__value); +// _serde::__private::Err(_serde::de::Error::unknown_variant(__value, VARIANTS)) +// } +// } +// } +// } +// impl<'de> _serde::Deserialize<'de> for __Field { +// #[inline] +// fn deserialize<__D>(__deserializer: __D) -> _serde::__private::Result +// where +// __D: _serde::Deserializer<'de>, +// { +// _serde::Deserializer::deserialize_identifier(__deserializer, __FieldVisitor) +// } +// } +// #[doc(hidden)] +// struct __Visitor<'de> { +// marker: _serde::__private::PhantomData, +// lifetime: _serde::__private::PhantomData<&'de ()>, +// } +// impl<'de> _serde::de::Visitor<'de> for __Visitor<'de> { +// type Value = ColumnType; +// fn expecting(&self, __formatter: &mut _serde::__private::Formatter) -> _serde::__private::fmt::Result { +// _serde::__private::Formatter::write_str(__formatter, "enum ColumnType") +// } +// fn visit_enum<__A>(self, __data: __A) -> _serde::__private::Result +// where +// __A: _serde::de::EnumAccess<'de>, +// { +// match 
_serde::de::EnumAccess::variant(__data)? { +// (__Field::__field0, __variant) => { +// _serde::de::VariantAccess::unit_variant(__variant)?; +// _serde::__private::Ok(ColumnType::Int32) +// } +// (__Field::__field1, __variant) => { +// _serde::de::VariantAccess::unit_variant(__variant)?; +// _serde::__private::Ok(ColumnType::Int64) +// } +// } +// } +// } +// #[doc(hidden)] +// const VARIANTS: &'static [&'static str] = &["Int32", "Int64"]; +// _serde::Deserializer::deserialize_enum( +// __deserializer, +// "ColumnType", +// VARIANTS, +// __Visitor { +// marker: _serde::__private::PhantomData::, +// lifetime: _serde::__private::PhantomData, +// }, +// ) +// } +// } +// }; +// +// +// Recursive expansion of Tsify macro +// =================================== +// +// #[automatically_derived] +// const _: () = { +// extern crate serde as _serde; +// use tsify::Tsify; +// use wasm_bindgen::{ +// convert::{FromWasmAbi, IntoWasmAbi, OptionFromWasmAbi, OptionIntoWasmAbi}, +// describe::WasmDescribe, +// prelude::*, +// }; +// #[wasm_bindgen] +// extern "C" { +// #[wasm_bindgen(typescript_type = "ColumnType")] +// pub type JsType; +// } +// impl Tsify for ColumnType { +// type JsType = JsType; +// const DECL: &'static str = "export type ColumnType = \"Int32\" | \"Int64\";"; +// } +// #[wasm_bindgen(typescript_custom_section)] +// const TS_APPEND_CONTENT: &'static str = "export type ColumnType = \"Int32\" | \"Int64\";"; +// impl WasmDescribe for ColumnType { +// #[inline] +// fn describe() { +// ::JsType::describe() +// } +// } +// impl IntoWasmAbi for ColumnType +// where +// Self: _serde::Serialize, +// { +// type Abi = ::Abi; +// #[inline] +// fn into_abi(self) -> Self::Abi { +// self.into_js().unwrap_throw().into_abi() +// } +// } +// impl OptionIntoWasmAbi for ColumnType +// where +// Self: _serde::Serialize, +// { +// #[inline] +// fn none() -> Self::Abi { +// ::none() +// } +// } +// impl FromWasmAbi for ColumnType +// where +// Self: _serde::de::DeserializeOwned, +// { +// type Abi = ::Abi; +// #[inline] +// unsafe fn from_abi(js: Self::Abi) -> Self { +// let result = Self::from_js(&JsType::from_abi(js)); +// if let Err(err) = result { +// wasm_bindgen::throw_str(err.to_string().as_ref()); +// } +// result.unwrap_throw() +// } +// } +// impl OptionFromWasmAbi for ColumnType +// where +// Self: _serde::de::DeserializeOwned, +// { +// #[inline] +// fn is_none(js: &Self::Abi) -> bool { +// ::is_none(js) +// } +// } +// }; +#[derive(Clone, Copy, Debug, Deserialize, Tsify)] +#[tsify(from_wasm_abi)] +pub enum ColumnType { + Int32 = 0, + Int64 = 1, +} + +#[derive(Debug, Deserialize, Tsify)] +#[tsify(from_wasm_abi)] +#[serde(rename_all = "camelCase")] +struct ColumnTypeWrapper { + column_type: ColumnType, +} + +// Recursive expansion of Deserialize_repr macro +// ============================================== +// +// impl<'de> serde::Deserialize<'de> for ColumnTypeWasmBindgen { +// #[allow(clippy::use_self)] +// fn deserialize(deserializer: D) -> ::core::result::Result +// where +// D: serde::Deserializer<'de>, +// { +// #[allow(non_camel_case_types)] +// struct discriminant; + +// #[allow(non_upper_case_globals)] +// impl discriminant { +// const Int32: u8 = ColumnTypeWasmBindgen::Int32 as u8; +// const Int64: u8 = ColumnTypeWasmBindgen::Int64 as u8; +// } +// match ::deserialize(deserializer)? 
{ +// discriminant::Int32 => ::core::result::Result::Ok(ColumnTypeWasmBindgen::Int32), +// discriminant::Int64 => ::core::result::Result::Ok(ColumnTypeWasmBindgen::Int64), +// other => ::core::result::Result::Err(serde::de::Error::custom(format_args!( +// "invalid value: {}, expected {} or {}", +// other, +// discriminant::Int32, +// discriminant::Int64 +// ))), +// } +// } +// } +#[derive(Debug, Deserialize_repr, Tsify)] +#[tsify(from_wasm_abi)] +#[repr(u8)] +pub enum ColumnTypeWasmBindgen { + // #[serde(rename = "0")] + Int32 = 0, + + // #[serde(rename = "1")] + Int64 = 1, +} + +#[wasm_bindgen_test] +fn column_type_test() { + // Example deserialization code + let json_data = r#"0"#; + let column_type = serde_json::from_str::(&json_data).unwrap(); + let column_type = serde_json::from_str::(&json_data).unwrap(); + let column_type = serde_json::from_str::(&json_data).unwrap(); + let column_type = serde_json::from_str::(&json_data).unwrap(); + let column_type = serde_json::from_str::(&json_data).unwrap(); + let column_type = serde_json::from_str::(&json_data).unwrap(); + let column_type = serde_json::from_str::(&json_data).unwrap(); + let column_type = serde_json::from_str::(&json_data).unwrap(); + + // let json_data = "\"0\""; + let column_type = serde_json::from_str::(&json_data).unwrap(); +} + +// #[wasm_bindgen_test] +// fn column_type_test() { +// // Example deserialization code +// let json_data = r#"{ "columnType": 0 }"#; +// let column_type_wrapper = serde_json::from_str::(json_data); + +// panic!("{:?}", column_type_wrapper); +// } From d3b1ae9317aad504981d11093889292cee979291 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Thu, 23 Nov 2023 15:10:32 +0100 Subject: [PATCH 071/134] chore(driver-adapters): update Cargo.lock --- Cargo.lock | 2 -- 1 file changed, 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 896305d796c7..0b241152f274 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3534,7 +3534,6 @@ dependencies = [ "serde_json", "url", "wasm-bindgen", - "wasm-rs-dbg", ] [[package]] @@ -3851,7 +3850,6 @@ dependencies = [ "serde", "serde-wasm-bindgen", "serde_json", - "serde_repr", "sql-query-connector", "thiserror", "tokio", From 11526ce498638607db535e3ec57ba30b94a4fc96 Mon Sep 17 00:00:00 2001 From: Sergey Tatarintsev Date: Thu, 23 Nov 2023 15:31:48 +0100 Subject: [PATCH 072/134] Handle non-promise return values --- .../driver-adapters/src/wasm/async_js_function.rs | 13 +++++++++---- query-engine/query-engine-wasm/example/example.js | 2 ++ 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/query-engine/driver-adapters/src/wasm/async_js_function.rs b/query-engine/driver-adapters/src/wasm/async_js_function.rs index fe64847b5b1e..0f8fd310c75f 100644 --- a/query-engine/driver-adapters/src/wasm/async_js_function.rs +++ b/query-engine/driver-adapters/src/wasm/async_js_function.rs @@ -3,7 +3,7 @@ use serde::Serialize; use std::marker::PhantomData; use wasm_bindgen::convert::FromWasmAbi; use wasm_bindgen::describe::WasmDescribe; -use wasm_bindgen::{JsError, JsValue}; +use wasm_bindgen::{JsCast, JsError, JsValue}; use wasm_bindgen_futures::JsFuture; use super::error::into_quaint_error; @@ -52,9 +52,14 @@ where async fn call_internal(&self, arg1: T) -> Result, JsValue> { let arg1 = serde_wasm_bindgen::to_value(&arg1).map_err(|err| JsValue::from(JsError::from(&err)))?; - let promise = self.threadsafe_fn.call1(&JsValue::null(), &arg1)?; - let future = JsFuture::from(JsPromise::from(promise)); - let value = future.await?; + let return_value = 
self.threadsafe_fn.call1(&JsValue::null(), &arg1)?; + + let value = if let Some(promise) = return_value.dyn_ref::() { + JsFuture::from(promise.to_owned()).await? + } else { + return_value + }; + let js_result = JsResult::::from_js_value(value)?; Ok(js_result) diff --git a/query-engine/query-engine-wasm/example/example.js b/query-engine/query-engine-wasm/example/example.js index 9c8948a2e2d3..57093a0222bb 100644 --- a/query-engine/query-engine-wasm/example/example.js +++ b/query-engine/query-engine-wasm/example/example.js @@ -53,6 +53,8 @@ async function main() { } }), 'trace') + console.log({ created }) + const res = await queryEngine.query(JSON.stringify({ modelName: 'User', action: 'findMany', From ab2937a86e8401bf8b4db632b87cb06c5a7aeba0 Mon Sep 17 00:00:00 2001 From: Sergey Tatarintsev Date: Fri, 24 Nov 2023 10:36:31 +0100 Subject: [PATCH 073/134] Run tests on WASM --- .../test-query-engine-driver-adapters.yml | 25 +++++++++---------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/.github/workflows/test-query-engine-driver-adapters.yml b/.github/workflows/test-query-engine-driver-adapters.yml index 0a9c933c9b58..08bd3b192eac 100644 --- a/.github/workflows/test-query-engine-driver-adapters.yml +++ b/.github/workflows/test-query-engine-driver-adapters.yml @@ -33,19 +33,18 @@ jobs: setup_task: 'dev-neon-js' - name: '@prisma/adapter-libsql (Turso) (napi)' setup_task: 'dev-libsql-js' - # TODO: uncomment when WASM engine is functional - # - name: '@prisma/adapter-planetscale' - # setup_task: 'dev-planetscale-wasm' - # needs_wasm_pack: true - # - name: '@prisma/adapter-pg (wasm)' - # setup_task: 'dev-pg-wasm' - # needs_wasm_pack: true - # - name: '@prisma/adapter-neon (ws) (wasm)' - # setup_task: 'dev-neon-wasm' - # needs_wasm_pack: true - # - name: '@prisma/adapter-libsql (Turso) (wasm)' - # setup_task: 'dev-libsql-wasm' - # needs_wasm_pack: true + - name: '@prisma/adapter-planetscale' + setup_task: 'dev-planetscale-wasm' + needs_wasm_pack: true + - name: '@prisma/adapter-pg (wasm)' + setup_task: 'dev-pg-wasm' + needs_wasm_pack: true + - name: '@prisma/adapter-neon (ws) (wasm)' + setup_task: 'dev-neon-wasm' + needs_wasm_pack: true + - name: '@prisma/adapter-libsql (Turso) (wasm)' + setup_task: 'dev-libsql-wasm' + needs_wasm_pack: true node_version: ['18'] env: LOG_LEVEL: 'info' # Set to "debug" to trace the query engine and node process running the driver adapter From 586c5a40696c4b2ea4264623546e0b3d54671c16 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Fri, 24 Nov 2023 13:35:02 +0100 Subject: [PATCH 074/134] feat(core): allow Drop'ing futures running in loop in wasm32-* via controlled spawns --- query-engine/core/Cargo.toml | 1 + query-engine/core/src/executor/task.rs | 87 +++++++++++++++++-- .../interactive_transactions/actor_manager.rs | 1 - .../src/interactive_transactions/actors.rs | 41 ++++++--- 4 files changed, 110 insertions(+), 20 deletions(-) diff --git a/query-engine/core/Cargo.toml b/query-engine/core/Cargo.toml index 42e7ee301525..b7d2d971b225 100644 --- a/query-engine/core/Cargo.toml +++ b/query-engine/core/Cargo.toml @@ -37,6 +37,7 @@ schema = { path = "../schema" } elapsed = { path = "../../libs/elapsed" } lru = "0.7.7" enumflags2 = "0.7" +wasm-rs-dbg.workspace = true pin-project = "1" wasm-bindgen-futures = "0.4" diff --git a/query-engine/core/src/executor/task.rs b/query-engine/core/src/executor/task.rs index 8d1c39bbcd06..1fe69c240d2c 100644 --- a/query-engine/core/src/executor/task.rs +++ b/query-engine/core/src/executor/task.rs @@ -1,21 +1,56 @@ //! 
This module provides a unified interface for spawning asynchronous tasks, regardless of the target platform. -pub use arch::{spawn, JoinHandle}; +pub use arch::{spawn, spawn_controlled, JoinHandle}; use futures::Future; // On native targets, `tokio::spawn` spawns a new asynchronous task. #[cfg(not(target_arch = "wasm32"))] mod arch { use super::*; + use tokio::sync::broadcast::{self}; - pub type JoinHandle = tokio::task::JoinHandle; + pub struct JoinHandle { + handle: tokio::task::JoinHandle, + + sx_exit: Option>, + } + + impl JoinHandle { + pub fn abort(&mut self) { + if let Some(sx_exit) = self.sx_exit.as_ref() { + sx_exit.send(()).ok(); + } + + self.handle.abort(); + } + } pub fn spawn(future: T) -> JoinHandle where T: Future + Send + 'static, T::Output: Send + 'static, { - tokio::spawn(future) + spawn_with_sx_exit::(future, None) + } + + pub fn spawn_controlled(future_fn: Box) -> T>) -> JoinHandle + where + T: Future + Send + 'static, + T::Output: Send + 'static, + { + let (sx_exit, rx_exit) = tokio::sync::broadcast::channel::<()>(1); + let future = future_fn(rx_exit); + + spawn_with_sx_exit::(future, Some(sx_exit)) + } + + fn spawn_with_sx_exit(future: T, sx_exit: Option>) -> JoinHandle + where + T: Future + Send + 'static, + T::Output: Send + 'static, + { + let handle = tokio::spawn(future); + JoinHandle { handle, sx_exit } } } @@ -23,28 +58,63 @@ mod arch { #[cfg(target_arch = "wasm32")] mod arch { use super::*; - use tokio::sync::oneshot::{self}; + use tokio::sync::{ + broadcast::{self}, + oneshot::{self}, + }; + use wasm_rs_dbg::dbg; // Wasm-compatible alternative to `tokio::task::JoinHandle`. // `pin_project` enables pin-projection and a `Pin`-compatible implementation of the `Future` trait. - pub struct JoinHandle(oneshot::Receiver); + #[pin_project::pin_project] + pub struct JoinHandle { + #[pin] + receiver: oneshot::Receiver, + + sx_exit: Option>, + } impl Future for JoinHandle { type Output = Result; fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll { + dbg!("JoinHandle::poll"); + // the `self.project()` method is provided by the `pin_project` macro - core::pin::Pin::new(&mut self.0).poll(cx) + core::pin::Pin::new(&mut self.receiver).poll(cx) } } impl JoinHandle { pub fn abort(&mut self) { - // abort is noop on Wasm targets + dbg!("JoinHandle::abort"); + + if let Some(sx_exit) = self.sx_exit.as_ref() { + dbg!("JoinHandle::abort - Send sx_exit"); + sx_exit.send(()).ok(); + } } } pub fn spawn(future: T) -> JoinHandle + where + T: Future + Send + 'static, + T::Output: Send + 'static, + { + spawn_with_sx_exit::(future, None) + } + + pub fn spawn_controlled(future_fn: Box) -> T>) -> JoinHandle + where + T: Future + Send + 'static, + T::Output: Send + 'static, + { + let (sx_exit, rx_exit) = tokio::sync::broadcast::channel::<()>(1); + let future = future_fn(rx_exit); + spawn_with_sx_exit::(future, Some(sx_exit)) + } + + fn spawn_with_sx_exit(future: T, sx_exit: Option>) -> JoinHandle where T: Future + Send + 'static, T::Output: Send + 'static, @@ -54,6 +124,7 @@ mod arch { let result = future.await; sender.send(result).ok(); }); - JoinHandle(receiver) + + JoinHandle { receiver, sx_exit } } } diff --git a/query-engine/core/src/interactive_transactions/actor_manager.rs b/query-engine/core/src/interactive_transactions/actor_manager.rs index 105733be4166..e7c3c770c7e5 100644 --- a/query-engine/core/src/interactive_transactions/actor_manager.rs +++ b/query-engine/core/src/interactive_transactions/actor_manager.rs @@ -37,7 +37,6 @@ pub struct 
TransactionActorManager { impl Drop for TransactionActorManager { fn drop(&mut self) { - debug!("DROPPING TPM"); self.bg_reader_clear.abort(); } } diff --git a/query-engine/core/src/interactive_transactions/actors.rs b/query-engine/core/src/interactive_transactions/actors.rs index 104ffc26812f..0aac2b341bcd 100644 --- a/query-engine/core/src/interactive_transactions/actors.rs +++ b/query-engine/core/src/interactive_transactions/actors.rs @@ -1,5 +1,5 @@ use super::{CachedTx, TransactionError, TxOpRequest, TxOpRequestMsg, TxOpResponse}; -use crate::executor::task::{spawn, JoinHandle}; +use crate::executor::task::{spawn, spawn_controlled, JoinHandle}; use crate::{ execute_many_operations, execute_single_operation, protocol::EngineProtocol, ClosedTx, Operation, ResponseData, TxId, @@ -17,6 +17,7 @@ use tokio::{ use tracing::Span; use tracing_futures::Instrument; use tracing_futures::WithSubscriber; +use wasm_rs_dbg::dbg; #[cfg(feature = "metrics")] use crate::telemetry::helpers::set_span_link_from_traceparent; @@ -385,17 +386,35 @@ pub(crate) fn spawn_client_list_clear_actor( closed_txs: Arc>>>, mut rx: Receiver<(TxId, Option)>, ) -> JoinHandle<()> { - spawn(async move { - loop { - if let Some((id, closed_tx)) = rx.recv().await { - trace!("removing {} from client list", id); + spawn_controlled(Box::new( + |mut rx_exit: tokio::sync::broadcast::Receiver<()>| async move { + loop { + tokio::select! { + result = rx.recv() => { + dbg!("spawn_controlled - AFTER rx.recv(): {:?}", result.is_some()); + match result { + Some((id, closed_tx)) => { + trace!("removing {} from client list", id); - let mut clients_guard = clients.write().await; - clients_guard.remove(&id); - drop(clients_guard); + let mut clients_guard = clients.write().await; - closed_txs.write().await.put(id, closed_tx); + clients_guard.remove(&id); + drop(clients_guard); + + closed_txs.write().await.put(id, closed_tx); + } + None => { + // the `rx` channel is closed. + break; + } + } + }, + _ = rx_exit.recv() => { + dbg!("spawn_controlled - AFTER rx_exit.recv()"); + break; + }, + } } - } - }) + }, + )) } From 68719a370e9c643182013f5a479f3b5f65a18420 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Fri, 24 Nov 2023 13:35:23 +0100 Subject: [PATCH 075/134] feat(chore): add Arc comment --- query-engine/connectors/sql-query-connector/src/database/js.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/query-engine/connectors/sql-query-connector/src/database/js.rs b/query-engine/connectors/sql-query-connector/src/database/js.rs index 0d4714871e59..449b9053bef1 100644 --- a/query-engine/connectors/sql-query-connector/src/database/js.rs +++ b/query-engine/connectors/sql-query-connector/src/database/js.rs @@ -13,6 +13,8 @@ use quaint::{ }; use std::sync::{Arc, Mutex}; +// TODO: evaluate turning this into `Lazy>>>` to avoid +// a clone+drop on the adapter passed via `Js::from_source`. 
static ACTIVE_DRIVER_ADAPTER: Lazy>> = Lazy::new(|| Mutex::new(None)); fn active_driver_adapter(provider: &str) -> connector::Result { From 8173e61596b80353a62fe0e72b4d442ecce02fab Mon Sep 17 00:00:00 2001 From: jkomyno Date: Fri, 24 Nov 2023 13:35:54 +0100 Subject: [PATCH 076/134] feat(chore): re-enable previous "disconnect()" --- .../query-engine-wasm/src/wasm/engine.rs | 41 ++++++++----------- 1 file changed, 18 insertions(+), 23 deletions(-) diff --git a/query-engine/query-engine-wasm/src/wasm/engine.rs b/query-engine/query-engine-wasm/src/wasm/engine.rs index 8a455b6f4aa9..183d4940d1c8 100644 --- a/query-engine/query-engine-wasm/src/wasm/engine.rs +++ b/query-engine/query-engine-wasm/src/wasm/engine.rs @@ -30,7 +30,6 @@ use tracing_subscriber::filter::LevelFilter; use tsify::Tsify; use user_facing_errors::Error; use wasm_bindgen::prelude::wasm_bindgen; - /// The main query engine used by JS #[wasm_bindgen] pub struct QueryEngine { @@ -279,32 +278,28 @@ impl QueryEngine { /// Disconnect and drop the core. Can be reconnected later with `#connect`. #[wasm_bindgen] pub async fn disconnect(&self, trace: String) -> Result<(), wasm_bindgen::JsError> { - // async_panic_to_js_error(async { - // let span = tracing::info_span!("prisma:engine:disconnect"); - - // TODO: when using Node Drivers, we need to call Driver::close() here. + async_panic_to_js_error(async { + let span = tracing::info_span!("prisma:engine:disconnect"); - // async { - let mut inner = self.inner.write().await; - let engine = inner.as_engine()?; + async { + let mut inner = self.inner.write().await; + let engine = inner.as_engine()?; - let builder = EngineBuilder { - schema: engine.schema.clone(), - config_dir: engine.config_dir.clone(), - env: engine.env.clone(), - engine_protocol: engine.engine_protocol(), - }; + let builder = EngineBuilder { + schema: engine.schema.clone(), + config_dir: engine.config_dir.clone(), + env: engine.env.clone(), + engine_protocol: engine.engine_protocol(), + }; - log::info!("Recreated builder"); - *inner = Inner::Builder(builder); - log::info!("Recreated inner builder"); + *inner = Inner::Builder(builder); - Ok(()) - // } - // .instrument(span) - // .await - // }) - // .await + Ok(()) + } + .instrument(span) + .await + }) + .await } /// If connected, sends a query to the core and returns the response. 
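A note on the `Lazy`-based TODO introduced in patch 075 above: wrapping the registered adapter in an `Arc` would make registration and lookup move a reference-counted pointer instead of cloning (and later dropping) the adapter value itself. The sketch below shows that shape in outline only; `Adapter`, `register_adapter` and `active_adapter` are illustrative placeholders rather than the crate's actual names, and `once_cell` is assumed to be the source of `Lazy` as elsewhere in the codebase.

use once_cell::sync::Lazy;
use std::sync::{Arc, Mutex};

// Placeholder for the driver-adapter type handed over via `Js::from_source`.
struct Adapter;

// Global slot for the currently active adapter. Storing an `Arc` means that
// handing the adapter out is a pointer copy, not a clone of the adapter itself.
static ACTIVE_ADAPTER: Lazy<Mutex<Option<Arc<Adapter>>>> = Lazy::new(|| Mutex::new(None));

fn register_adapter(adapter: Adapter) {
    *ACTIVE_ADAPTER.lock().unwrap() = Some(Arc::new(adapter));
}

fn active_adapter() -> Option<Arc<Adapter>> {
    // Clones the `Arc` (cheap), not the adapter it points to.
    ACTIVE_ADAPTER.lock().unwrap().clone()
}

The trade-off is that callers receive a shared handle, so any mutation has to go through interior mutability inside the adapter itself.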
From fc2bd7f00040e537289e2266da4cc8eb5f8d6bc2 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Fri, 24 Nov 2023 13:38:22 +0100 Subject: [PATCH 077/134] feat(chore): update example.js --- .../query-engine-wasm/example/example.js | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/query-engine/query-engine-wasm/example/example.js b/query-engine/query-engine-wasm/example/example.js index 57093a0222bb..154f901a6e0a 100644 --- a/query-engine/query-engine-wasm/example/example.js +++ b/query-engine/query-engine-wasm/example/example.js @@ -44,7 +44,7 @@ async function main() { query: { arguments: { data: { - id: 1234, + id: 1235, }, }, selection: { @@ -66,17 +66,25 @@ async function main() { } }), 'trace') const parsed = JSON.parse(res); - console.log('query result = ', parsed) + console.log('query result = ') + console.dir(parsed, { depth: null }) const error = parsed.errors?.[0]?.user_facing_error if (error?.error_code === 'P2036') { console.log('js error:', driverAdapter.errorRegistry.consumeError(error.meta.id)) } - // if (res.error.user_facing_error.code =) + + // console.log('before disconnect') await queryEngine.disconnect('trace') - console.log('after disconnect') - queryEngine.free() + // console.log('after disconnect') + + // console.log('before close') await driverAdapter.close() + // console.log('after close') + + // console.log('before free') + queryEngine.free() + // console.log('after free') } main() From 20d8a40b2a8249998ca74dd90281a1a9e268c8a2 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Fri, 24 Nov 2023 13:43:37 +0100 Subject: [PATCH 078/134] fix(query-engine-node-api): fix compilation errors --- Cargo.lock | 2 ++ query-engine/driver-adapters/Cargo.toml | 2 +- query-engine/driver-adapters/src/queryable/napi.rs | 4 ++-- query-engine/query-engine-wasm/Cargo.toml | 1 + 4 files changed, 6 insertions(+), 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0b241152f274..0ff54b6d6e71 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3715,6 +3715,7 @@ dependencies = [ "user-facing-errors", "uuid", "wasm-bindgen-futures", + "wasm-rs-dbg", ] [[package]] @@ -3862,6 +3863,7 @@ dependencies = [ "wasm-bindgen", "wasm-bindgen-futures", "wasm-logger", + "wasm-rs-dbg", ] [[package]] diff --git a/query-engine/driver-adapters/Cargo.toml b/query-engine/driver-adapters/Cargo.toml index 081b8a402f5c..befd68125880 100644 --- a/query-engine/driver-adapters/Cargo.toml +++ b/query-engine/driver-adapters/Cargo.toml @@ -14,6 +14,7 @@ metrics = "0.18" uuid = { version = "1", features = ["v4"] } pin-project = "1" wasm-rs-dbg = "0.1.2" +serde_repr.workspace = true # Note: these deps are temporarily specified here to avoid importing them from tiberius (the SQL server driver). # They will be imported from quaint-core instead in a future PR. 
@@ -37,7 +38,6 @@ quaint.workspace = true quaint = { path = "../../quaint" } js-sys.workspace = true serde-wasm-bindgen.workspace = true -serde_repr.workspace = true wasm-bindgen.workspace = true wasm-bindgen-futures.workspace = true tsify.workspace = true diff --git a/query-engine/driver-adapters/src/queryable/napi.rs b/query-engine/driver-adapters/src/queryable/napi.rs index 2245802908c2..b7f4cf49028b 100644 --- a/query-engine/driver-adapters/src/queryable/napi.rs +++ b/query-engine/driver-adapters/src/queryable/napi.rs @@ -1,7 +1,7 @@ use crate::napi::proxy::{CommonProxy, DriverProxy}; use crate::JsQueryable; use napi::JsObject; -use psl::datamodel_connector::Flavour; +use crate::types::AdapterFlavour; /// A JsQueryable adapts a Proxy to implement quaint's Queryable interface. It has the /// responsibility of transforming inputs and outputs of `query` and `execute` methods from quaint @@ -18,7 +18,7 @@ use psl::datamodel_connector::Flavour; /// pub(crate) struct JsBaseQueryable { pub(crate) proxy: CommonProxy, - pub flavour: Flavour, + pub flavour: AdapterFlavour, } pub fn from_napi(driver: JsObject) -> JsQueryable { diff --git a/query-engine/query-engine-wasm/Cargo.toml b/query-engine/query-engine-wasm/Cargo.toml index 1f9f9e65475c..23e1f9d0acbb 100644 --- a/query-engine/query-engine-wasm/Cargo.toml +++ b/query-engine/query-engine-wasm/Cargo.toml @@ -32,6 +32,7 @@ serde_json.workspace = true tsify.workspace = true wasm-bindgen.workspace = true wasm-bindgen-futures.workspace = true +wasm-rs-dbg.workspace = true prisma-models = { path = "../prisma-models" } From 18a81276fd2fb4f9d33a547d30c15d5bd5e8e90d Mon Sep 17 00:00:00 2001 From: Sergey Tatarintsev Date: Fri, 24 Nov 2023 16:08:36 +0100 Subject: [PATCH 079/134] Update cuid --- Cargo.lock | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0ff54b6d6e71..fe9937b45e4f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -23,7 +23,7 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fcb51a0695d8f838b1ee009b3fbf66bda078cd64590202a864a8f3e8c4315c47" dependencies = [ - "getrandom 0.2.10", + "getrandom 0.2.11", "once_cell", "version_check", ] @@ -871,11 +871,13 @@ dependencies = [ [[package]] name = "cuid" version = "1.3.2" -source = "git+https://github.com/prisma/cuid-rust?branch=wasm32-support#81309f9a11f70d178bb545971d51ceb7da692c52" +source = "git+https://github.com/prisma/cuid-rust?branch=wasm32-support#ccfd958c224c79758c2527a0bca9efcd71790a19" dependencies = [ "base36", "cuid-util", "cuid2", + "getrandom 0.2.11", + "js-sys", "num", "once_cell", "rand 0.8.5", @@ -885,12 +887,12 @@ dependencies = [ [[package]] name = "cuid-util" version = "0.1.0" -source = "git+https://github.com/prisma/cuid-rust?branch=wasm32-support#81309f9a11f70d178bb545971d51ceb7da692c52" +source = "git+https://github.com/prisma/cuid-rust?branch=wasm32-support#ccfd958c224c79758c2527a0bca9efcd71790a19" [[package]] name = "cuid2" version = "0.1.2" -source = "git+https://github.com/prisma/cuid-rust?branch=wasm32-support#81309f9a11f70d178bb545971d51ceb7da692c52" +source = "git+https://github.com/prisma/cuid-rust?branch=wasm32-support#ccfd958c224c79758c2527a0bca9efcd71790a19" dependencies = [ "cuid-util", "num", @@ -1565,9 +1567,9 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.10" +version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be4136b2a15dd319360be1c07d9933517ccf0be8f16bf62a3bee4f0d618df427" 
+checksum = "fe9006bed769170c11f845cf00c7c1e9092aeb3f268e007c3e760ac68008070f" dependencies = [ "cfg-if", "js-sys", @@ -2126,9 +2128,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.147" +version = "0.2.150" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3" +checksum = "89d92a4743f9a61002fae18374ed11e7973f530cb3a3255fb354818118b2203c" [[package]] name = "libloading" @@ -3368,7 +3370,7 @@ dependencies = [ "bigdecimal", "chrono", "cuid", - "getrandom 0.2.10", + "getrandom 0.2.11", "itertools", "nanoid", "prisma-value", @@ -3588,7 +3590,7 @@ dependencies = [ "either", "elapsed", "futures", - "getrandom 0.2.10", + "getrandom 0.2.11", "hex", "indoc 0.3.6", "lru-cache", @@ -4041,7 +4043,7 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "getrandom 0.2.10", + "getrandom 0.2.11", ] [[package]] @@ -5805,7 +5807,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "97fee6b57c6a41524a810daee9286c02d7752c4253064d0b05472833a438f675" dependencies = [ "cfg-if", - "rand 0.8.5", + "rand 0.3.23", "static_assertions", ] @@ -5951,7 +5953,7 @@ version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "79daa5ed5740825c40b389c5e50312b9c86df53fccd33f281df655642b43869d" dependencies = [ - "getrandom 0.2.10", + "getrandom 0.2.11", "serde", ] From d04804d838eedc282541579ed41d183acd647231 Mon Sep 17 00:00:00 2001 From: Sergey Tatarintsev Date: Fri, 24 Nov 2023 16:52:06 +0100 Subject: [PATCH 080/134] Remove one more isntant usage --- .../core/src/interactive_transactions/actor_manager.rs | 2 +- .../core/src/interactive_transactions/actors.rs | 5 +++-- query-engine/core/src/interactive_transactions/mod.rs | 10 +++++++--- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/query-engine/core/src/interactive_transactions/actor_manager.rs b/query-engine/core/src/interactive_transactions/actor_manager.rs index e7c3c770c7e5..f2d1f539ebbf 100644 --- a/query-engine/core/src/interactive_transactions/actor_manager.rs +++ b/query-engine/core/src/interactive_transactions/actor_manager.rs @@ -107,7 +107,7 @@ impl TransactionActorManager { of the transaction. 
Consider increasing the interactive transaction timeout \ or doing less work in the transaction", timeout.as_millis(), - start_time.elapsed().as_millis(), + start_time.elapsed_time().as_millis(), ) } None => { diff --git a/query-engine/core/src/interactive_transactions/actors.rs b/query-engine/core/src/interactive_transactions/actors.rs index 0aac2b341bcd..58d24c528261 100644 --- a/query-engine/core/src/interactive_transactions/actors.rs +++ b/query-engine/core/src/interactive_transactions/actors.rs @@ -5,6 +5,7 @@ use crate::{ TxId, }; use connector::Connection; +use elapsed::ElapsedTimeCounter; use schema::QuerySchemaRef; use std::{collections::HashMap, sync::Arc}; use tokio::{ @@ -12,7 +13,7 @@ use tokio::{ mpsc::{channel, Receiver, Sender}, oneshot, RwLock, }, - time::{self, Duration, Instant}, + time::{self, Duration}, }; use tracing::Span; use tracing_futures::Instrument; @@ -297,7 +298,7 @@ pub(crate) async fn spawn_itx_actor( query_schema, ); - let start_time = Instant::now(); + let start_time = ElapsedTimeCounter::start(); let sleep = time::sleep(timeout); tokio::pin!(sleep); diff --git a/query-engine/core/src/interactive_transactions/mod.rs b/query-engine/core/src/interactive_transactions/mod.rs index ce125e8fa17e..ac92d52efcf2 100644 --- a/query-engine/core/src/interactive_transactions/mod.rs +++ b/query-engine/core/src/interactive_transactions/mod.rs @@ -1,8 +1,9 @@ use crate::CoreError; use connector::Transaction; +use elapsed::ElapsedTimeCounter; use serde::Deserialize; use std::fmt::Display; -use tokio::time::{Duration, Instant}; +use tokio::time::Duration; mod actor_manager; mod actors; @@ -104,7 +105,7 @@ impl<'a> CachedTx<'a> { } } - pub(crate) fn to_closed(&self, start_time: Instant, timeout: Duration) -> Option { + pub(crate) fn to_closed(&self, start_time: ElapsedTimeCounter, timeout: Duration) -> Option { match self { CachedTx::Open(_) => None, CachedTx::Committed => Some(ClosedTx::Committed), @@ -117,5 +118,8 @@ impl<'a> CachedTx<'a> { pub(crate) enum ClosedTx { Committed, RolledBack, - Expired { start_time: Instant, timeout: Duration }, + Expired { + start_time: ElapsedTimeCounter, + timeout: Duration, + }, } From 782504867e493dad6530be11370106940f547267 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Mon, 27 Nov 2023 15:05:37 +0100 Subject: [PATCH 081/134] chore(driver-adapters): simplify async_js_function API --- .../src/conversion/js_to_quaint.rs | 1 - .../src/napi/async_js_function.rs | 6 +++--- query-engine/driver-adapters/src/napi/proxy.rs | 6 +----- .../src/wasm/async_js_function.rs | 16 +++++++--------- query-engine/driver-adapters/src/wasm/proxy.rs | 2 +- 5 files changed, 12 insertions(+), 19 deletions(-) diff --git a/query-engine/driver-adapters/src/conversion/js_to_quaint.rs b/query-engine/driver-adapters/src/conversion/js_to_quaint.rs index 2e7dd355bec2..f0b7de772c5e 100644 --- a/query-engine/driver-adapters/src/conversion/js_to_quaint.rs +++ b/query-engine/driver-adapters/src/conversion/js_to_quaint.rs @@ -1,6 +1,5 @@ use std::borrow::Cow; use std::str::FromStr; -use std::sync::atomic::{AtomicBool, Ordering}; pub use crate::types::{ColumnType, JSResultSet, Query, TransactionOptions}; use quaint::{ diff --git a/query-engine/driver-adapters/src/napi/async_js_function.rs b/query-engine/driver-adapters/src/napi/async_js_function.rs index e0848f1f9f5f..d62931e2c767 100644 --- a/query-engine/driver-adapters/src/napi/async_js_function.rs +++ b/query-engine/driver-adapters/src/napi/async_js_function.rs @@ -2,7 +2,7 @@ use std::marker::PhantomData; use napi::{ 
bindgen_prelude::*, - threadsafe_function::{ErrorStrategy, ThreadsafeFunction}, + threadsafe_function::{ErrorStrategy, ThreadsafeFunction, ThreadsafeFunctionCallMode}, }; use super::{ @@ -56,8 +56,8 @@ where js_result.into() } - pub(crate) fn as_raw(&self) -> &ThreadsafeFunction { - &self.threadsafe_fn + pub(crate) fn call_non_blocking(&self, arg: ArgType) { + _ = self.threadsafe_fn.call(arg, ThreadsafeFunctionCallMode::NonBlocking); } } diff --git a/query-engine/driver-adapters/src/napi/proxy.rs b/query-engine/driver-adapters/src/napi/proxy.rs index 34bfd32351cf..7a9c760fa676 100644 --- a/query-engine/driver-adapters/src/napi/proxy.rs +++ b/query-engine/driver-adapters/src/napi/proxy.rs @@ -3,7 +3,6 @@ pub use crate::types::{ColumnType, JSResultSet, Query, TransactionOptions}; use super::async_js_function::AsyncJsFunction; use super::transaction::JsTransaction; use metrics::increment_gauge; -use napi::threadsafe_function::{ErrorStrategy, ThreadsafeFunction}; use napi::{JsObject, JsString}; use std::sync::atomic::{AtomicBool, Ordering}; @@ -148,9 +147,6 @@ impl Drop for TransactionProxy { return; } - _ = self - .rollback - .as_raw() - .call((), napi::threadsafe_function::ThreadsafeFunctionCallMode::NonBlocking); + _ = self.rollback.call_non_blocking(()); } } diff --git a/query-engine/driver-adapters/src/wasm/async_js_function.rs b/query-engine/driver-adapters/src/wasm/async_js_function.rs index 876ec441df55..26217fa8e6e1 100644 --- a/query-engine/driver-adapters/src/wasm/async_js_function.rs +++ b/query-engine/driver-adapters/src/wasm/async_js_function.rs @@ -16,7 +16,7 @@ where ArgType: Serialize, ReturnType: FromJsValue, { - pub threadsafe_fn: JsFunction, + threadsafe_fn: JsFunction, _phantom_arg: PhantomData, _phantom_return: PhantomData, @@ -41,7 +41,7 @@ where T: Serialize, R: FromJsValue, { - pub async fn call(&self, arg1: T) -> quaint::Result { + pub(crate) async fn call(&self, arg1: T) -> quaint::Result { let result = self.call_internal(arg1).await; match result { @@ -65,8 +65,10 @@ where Ok(js_result) } - pub(crate) fn as_raw(&self) -> &JsFunction { - &self.threadsafe_fn + pub(crate) fn call_non_blocking(&self, arg: T) { + if let Ok(arg) = serde_wasm_bindgen::to_value(&arg) { + _ = self.threadsafe_fn.call1(&JsValue::null(), &arg); + } } } @@ -88,10 +90,6 @@ where type Abi = ::Abi; unsafe fn from_abi(js: Self::Abi) -> Self { - Self { - threadsafe_fn: JsFunction::from_abi(js), - _phantom_arg: PhantomData:: {}, - _phantom_return: PhantomData:: {}, - } + JsFunction::from_abi(js).into() } } diff --git a/query-engine/driver-adapters/src/wasm/proxy.rs b/query-engine/driver-adapters/src/wasm/proxy.rs index 30f45b238a71..84c06e867619 100644 --- a/query-engine/driver-adapters/src/wasm/proxy.rs +++ b/query-engine/driver-adapters/src/wasm/proxy.rs @@ -131,7 +131,7 @@ impl Drop for TransactionProxy { return; } - _ = self.rollback.as_raw().call0(&JsValue::null()); + _ = self.rollback.call_non_blocking(()); } } From 6bedf793e883be5a790d389a3154c7397bb8c957 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Mon, 27 Nov 2023 15:17:30 +0100 Subject: [PATCH 082/134] chore(driver-adapters): unify napi/wasm errors into "crate::JsResult" --- query-engine/driver-adapters/src/lib.rs | 14 ++++++++++++++ query-engine/driver-adapters/src/napi/proxy.rs | 7 ++++--- query-engine/driver-adapters/src/wasm/proxy.rs | 15 +++++++-------- 3 files changed, 25 insertions(+), 11 deletions(-) diff --git a/query-engine/driver-adapters/src/lib.rs b/query-engine/driver-adapters/src/lib.rs index 0e40b814c43f..625d8f6bcbd5 100644 
--- a/query-engine/driver-adapters/src/lib.rs +++ b/query-engine/driver-adapters/src/lib.rs @@ -24,3 +24,17 @@ pub mod wasm; #[cfg(target_arch = "wasm32")] pub use wasm::*; + +#[cfg(target_arch = "wasm32")] +mod arch { + use wasm_bindgen::JsValue; + + pub(crate) type JsResult = core::result::Result; +} + +#[cfg(not(target_arch = "wasm32"))] +mod arch { + pub(crate) type JsResult = napi::Result; +} + +pub(crate) use arch::JsResult; diff --git a/query-engine/driver-adapters/src/napi/proxy.rs b/query-engine/driver-adapters/src/napi/proxy.rs index 7a9c760fa676..fd61f87847be 100644 --- a/query-engine/driver-adapters/src/napi/proxy.rs +++ b/query-engine/driver-adapters/src/napi/proxy.rs @@ -1,4 +1,5 @@ pub use crate::types::{ColumnType, JSResultSet, Query, TransactionOptions}; +use crate::JsResult; use super::async_js_function::AsyncJsFunction; use super::transaction::JsTransaction; @@ -43,7 +44,7 @@ pub(crate) struct TransactionProxy { } impl CommonProxy { - pub fn new(object: &JsObject) -> napi::Result { + pub fn new(object: &JsObject) -> JsResult { let flavour: JsString = object.get_named_property("flavour")?; Ok(Self { @@ -63,7 +64,7 @@ impl CommonProxy { } impl DriverProxy { - pub fn new(driver_adapter: &JsObject) -> napi::Result { + pub fn new(driver_adapter: &JsObject) -> JsResult { Ok(Self { start_transaction: driver_adapter.get_named_property("startTransaction")?, }) @@ -82,7 +83,7 @@ impl DriverProxy { } impl TransactionProxy { - pub fn new(js_transaction: &JsObject) -> napi::Result { + pub fn new(js_transaction: &JsObject) -> JsResult { let commit = js_transaction.get_named_property("commit")?; let rollback = js_transaction.get_named_property("rollback")?; let options = js_transaction.get_named_property("options")?; diff --git a/query-engine/driver-adapters/src/wasm/proxy.rs b/query-engine/driver-adapters/src/wasm/proxy.rs index 84c06e867619..b9d98739d698 100644 --- a/query-engine/driver-adapters/src/wasm/proxy.rs +++ b/query-engine/driver-adapters/src/wasm/proxy.rs @@ -1,16 +1,15 @@ -use futures::Future; -use js_sys::{Function as JsFunction, JsString}; -use tsify::Tsify; - -use super::{async_js_function::AsyncJsFunction, transaction::JsTransaction}; use crate::send_future::SendFuture; pub use crate::types::{ColumnType, JSResultSet, Query, TransactionOptions}; use crate::JsObjectExtern; +use crate::JsResult; + +use super::{async_js_function::AsyncJsFunction, transaction::JsTransaction}; +use futures::Future; +use js_sys::{Function as JsFunction, JsString}; use metrics::increment_gauge; use std::sync::atomic::{AtomicBool, Ordering}; -use wasm_bindgen::{prelude::wasm_bindgen, JsValue}; - -type JsResult = core::result::Result; +use tsify::Tsify; +use wasm_bindgen::prelude::wasm_bindgen; /// Proxy is a struct wrapping a javascript object that exhibits basic primitives for /// querying and executing SQL (i.e. a client connector). 
The Proxy uses Wasm's JsFunction to From 56fc689e22abee3d5e08bb14de1eb0510155c301 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Mon, 27 Nov 2023 15:59:58 +0100 Subject: [PATCH 083/134] chore(driver-adapters): continue unifying napi/wasm functions --- query-engine/driver-adapters/src/lib.rs | 30 ++++++++++++++++++- .../driver-adapters/src/napi/proxy.rs | 19 ++++++------ .../src/wasm/js_object_extern.rs | 3 +- .../driver-adapters/src/wasm/proxy.rs | 7 +++-- 4 files changed, 44 insertions(+), 15 deletions(-) diff --git a/query-engine/driver-adapters/src/lib.rs b/query-engine/driver-adapters/src/lib.rs index 625d8f6bcbd5..774391268770 100644 --- a/query-engine/driver-adapters/src/lib.rs +++ b/query-engine/driver-adapters/src/lib.rs @@ -27,14 +27,42 @@ pub use wasm::*; #[cfg(target_arch = "wasm32")] mod arch { + use crate::JsObjectExtern as JsObject; + pub(crate) use js_sys::JsString; use wasm_bindgen::JsValue; + pub(crate) fn get_named_property(object: &JsObject, name: &str) -> JsResult + where + T: From, + { + // object.get("queryRaw".into())? + Ok(object.get(name.into())?.into()) + } + + pub(crate) fn to_rust_str(value: JsString) -> JsResult { + Ok(value.into()) + } + pub(crate) type JsResult = core::result::Result; } #[cfg(not(target_arch = "wasm32"))] mod arch { + use napi::bindgen_prelude::FromNapiValue; + pub(crate) use napi::{JsObject, JsString}; + + pub(crate) fn get_named_property(object: &JsObject, name: &str) -> JsResult + where + T: FromNapiValue, + { + object.get_named_property(name).into() + } + + pub(crate) fn to_rust_str(value: JsString) -> JsResult { + Ok(value.into_utf8()?.as_str()?.to_string()) + } + pub(crate) type JsResult = napi::Result; } -pub(crate) use arch::JsResult; +pub(crate) use arch::*; diff --git a/query-engine/driver-adapters/src/napi/proxy.rs b/query-engine/driver-adapters/src/napi/proxy.rs index fd61f87847be..753fdfb56616 100644 --- a/query-engine/driver-adapters/src/napi/proxy.rs +++ b/query-engine/driver-adapters/src/napi/proxy.rs @@ -1,10 +1,9 @@ pub use crate::types::{ColumnType, JSResultSet, Query, TransactionOptions}; -use crate::JsResult; +use crate::{get_named_property, to_rust_str, JsObject, JsResult, JsString}; use super::async_js_function::AsyncJsFunction; use super::transaction::JsTransaction; use metrics::increment_gauge; -use napi::{JsObject, JsString}; use std::sync::atomic::{AtomicBool, Ordering}; /// Proxy is a struct wrapping a javascript object that exhibits basic primitives for @@ -45,12 +44,12 @@ pub(crate) struct TransactionProxy { impl CommonProxy { pub fn new(object: &JsObject) -> JsResult { - let flavour: JsString = object.get_named_property("flavour")?; + let flavour: JsString = get_named_property(object, "flavour")?; Ok(Self { - query_raw: object.get_named_property("queryRaw")?, - execute_raw: object.get_named_property("executeRaw")?, - flavour: flavour.into_utf8()?.as_str()?.to_owned(), + query_raw: get_named_property(object, "queryRaw")?, + execute_raw: get_named_property(object, "executeRaw")?, + flavour: to_rust_str(flavour)?, }) } @@ -66,7 +65,7 @@ impl CommonProxy { impl DriverProxy { pub fn new(driver_adapter: &JsObject) -> JsResult { Ok(Self { - start_transaction: driver_adapter.get_named_property("startTransaction")?, + start_transaction: get_named_property(driver_adapter, "startTransaction")?, }) } @@ -84,9 +83,9 @@ impl DriverProxy { impl TransactionProxy { pub fn new(js_transaction: &JsObject) -> JsResult { - let commit = js_transaction.get_named_property("commit")?; - let rollback = 
js_transaction.get_named_property("rollback")?; - let options = js_transaction.get_named_property("options")?; + let commit = get_named_property(js_transaction, "commit")?; + let rollback = get_named_property(js_transaction, "rollback")?; + let options = get_named_property(js_transaction, "options")?; let closed = AtomicBool::new(false); Ok(Self { diff --git a/query-engine/driver-adapters/src/wasm/js_object_extern.rs b/query-engine/driver-adapters/src/wasm/js_object_extern.rs index 8804e706f67f..29abc1d6ef6f 100644 --- a/query-engine/driver-adapters/src/wasm/js_object_extern.rs +++ b/query-engine/driver-adapters/src/wasm/js_object_extern.rs @@ -1,8 +1,9 @@ -use js_sys::JsString; +use js_sys::{JsString, Object as JsObject}; use wasm_bindgen::{prelude::wasm_bindgen, JsValue}; #[wasm_bindgen] extern "C" { + #[wasm_bindgen(js_name = String, extends = JsObject, is_type_of = JsValue::is_object, typescript_type = "object")] pub type JsObjectExtern; #[wasm_bindgen(method, catch, structural, indexing_getter)] diff --git a/query-engine/driver-adapters/src/wasm/proxy.rs b/query-engine/driver-adapters/src/wasm/proxy.rs index b9d98739d698..dc64a0c24cba 100644 --- a/query-engine/driver-adapters/src/wasm/proxy.rs +++ b/query-engine/driver-adapters/src/wasm/proxy.rs @@ -1,7 +1,7 @@ use crate::send_future::SendFuture; pub use crate::types::{ColumnType, JSResultSet, Query, TransactionOptions}; use crate::JsObjectExtern; -use crate::JsResult; +use crate::{get_named_property, to_rust_str, JsResult}; use super::{async_js_function::AsyncJsFunction, transaction::JsTransaction}; use futures::Future; @@ -53,12 +53,13 @@ pub(crate) struct TransactionProxy { impl CommonProxy { pub fn new(object: &JsObjectExtern) -> JsResult { - let flavour: String = JsString::from(object.get("flavour".into())?).into(); + let flavour: JsString = get_named_property(object, "flavour")?; Ok(Self { + // TODO: remove the need for `JsFunction::from` (?) 
query_raw: JsFunction::from(object.get("queryRaw".into())?).into(), execute_raw: JsFunction::from(object.get("executeRaw".into())?).into(), - flavour, + flavour: to_rust_str(flavour)?, }) } From 1474f7b3d89a83108ec1c2ce0a759ae62805bbd5 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Mon, 27 Nov 2023 19:13:45 +0100 Subject: [PATCH 084/134] chore(driver-adapters): unify napi/wasm logic for proxy --- query-engine/driver-adapters/src/lib.rs | 14 +++++++- .../driver-adapters/src/napi/proxy.rs | 6 ++-- .../src/wasm/async_js_function.rs | 10 ++++++ .../driver-adapters/src/wasm/proxy.rs | 32 +++++++++---------- query-engine/query-engine-wasm/build.sh | 11 ++++++- 5 files changed, 51 insertions(+), 22 deletions(-) diff --git a/query-engine/driver-adapters/src/lib.rs b/query-engine/driver-adapters/src/lib.rs index 774391268770..7cf7f928cb5a 100644 --- a/query-engine/driver-adapters/src/lib.rs +++ b/query-engine/driver-adapters/src/lib.rs @@ -27,8 +27,9 @@ pub use wasm::*; #[cfg(target_arch = "wasm32")] mod arch { - use crate::JsObjectExtern as JsObject; + pub(crate) use crate::JsObjectExtern as JsObject; pub(crate) use js_sys::JsString; + use tsify::Tsify; use wasm_bindgen::JsValue; pub(crate) fn get_named_property(object: &JsObject, name: &str) -> JsResult @@ -43,6 +44,13 @@ mod arch { Ok(value.into()) } + pub(crate) fn from_js(value: JsValue) -> C + where + C: Tsify + serde::de::DeserializeOwned, + { + C::from_js(value).unwrap() + } + pub(crate) type JsResult = core::result::Result; } @@ -62,6 +70,10 @@ mod arch { Ok(value.into_utf8()?.as_str()?.to_string()) } + pub(crate) fn from_js(value: C) -> C { + value + } + pub(crate) type JsResult = napi::Result; } diff --git a/query-engine/driver-adapters/src/napi/proxy.rs b/query-engine/driver-adapters/src/napi/proxy.rs index 753fdfb56616..298c4a2ba35b 100644 --- a/query-engine/driver-adapters/src/napi/proxy.rs +++ b/query-engine/driver-adapters/src/napi/proxy.rs @@ -1,5 +1,5 @@ pub use crate::types::{ColumnType, JSResultSet, Query, TransactionOptions}; -use crate::{get_named_property, to_rust_str, JsObject, JsResult, JsString}; +use crate::{from_js, get_named_property, to_rust_str, JsObject, JsResult, JsString}; use super::async_js_function::AsyncJsFunction; use super::transaction::JsTransaction; @@ -86,13 +86,13 @@ impl TransactionProxy { let commit = get_named_property(js_transaction, "commit")?; let rollback = get_named_property(js_transaction, "rollback")?; let options = get_named_property(js_transaction, "options")?; - let closed = AtomicBool::new(false); + let options = from_js::(options); Ok(Self { commit, rollback, options, - closed, + closed: AtomicBool::new(false), }) } diff --git a/query-engine/driver-adapters/src/wasm/async_js_function.rs b/query-engine/driver-adapters/src/wasm/async_js_function.rs index 26217fa8e6e1..289de651ff64 100644 --- a/query-engine/driver-adapters/src/wasm/async_js_function.rs +++ b/query-engine/driver-adapters/src/wasm/async_js_function.rs @@ -22,6 +22,16 @@ where _phantom_return: PhantomData, } +impl From for AsyncJsFunction +where + T: Serialize, + R: FromJsValue, +{ + fn from(js_value: JsValue) -> Self { + JsFunction::from(js_value).into() + } +} + impl From for AsyncJsFunction where T: Serialize, diff --git a/query-engine/driver-adapters/src/wasm/proxy.rs b/query-engine/driver-adapters/src/wasm/proxy.rs index dc64a0c24cba..607e4440bbf7 100644 --- a/query-engine/driver-adapters/src/wasm/proxy.rs +++ b/query-engine/driver-adapters/src/wasm/proxy.rs @@ -1,14 +1,11 @@ use crate::send_future::SendFuture; pub use 
crate::types::{ColumnType, JSResultSet, Query, TransactionOptions}; -use crate::JsObjectExtern; -use crate::{get_named_property, to_rust_str, JsResult}; +use crate::{from_js, get_named_property, to_rust_str, JsObject, JsResult, JsString}; use super::{async_js_function::AsyncJsFunction, transaction::JsTransaction}; use futures::Future; -use js_sys::{Function as JsFunction, JsString}; use metrics::increment_gauge; use std::sync::atomic::{AtomicBool, Ordering}; -use tsify::Tsify; use wasm_bindgen::prelude::wasm_bindgen; /// Proxy is a struct wrapping a javascript object that exhibits basic primitives for @@ -52,13 +49,12 @@ pub(crate) struct TransactionProxy { } impl CommonProxy { - pub fn new(object: &JsObjectExtern) -> JsResult { + pub fn new(object: &JsObject) -> JsResult { let flavour: JsString = get_named_property(object, "flavour")?; Ok(Self { - // TODO: remove the need for `JsFunction::from` (?) - query_raw: JsFunction::from(object.get("queryRaw".into())?).into(), - execute_raw: JsFunction::from(object.get("executeRaw".into())?).into(), + query_raw: get_named_property(object, "queryRaw")?, + execute_raw: get_named_property(object, "executeRaw")?, flavour: to_rust_str(flavour)?, }) } @@ -73,9 +69,9 @@ impl CommonProxy { } impl DriverProxy { - pub fn new(object: &JsObjectExtern) -> JsResult { + pub fn new(object: &JsObject) -> JsResult { Ok(Self { - start_transaction: JsFunction::from(object.get("startTransaction".into())?).into(), + start_transaction: get_named_property(object, "startTransaction")?, }) } @@ -98,15 +94,17 @@ impl DriverProxy { } impl TransactionProxy { - pub fn new(object: &JsObjectExtern) -> JsResult { - let options = object.get("options".into())?; - let closed = AtomicBool::new(false); + pub fn new(js_transaction: &JsObject) -> JsResult { + let commit = get_named_property(js_transaction, "commit")?; + let rollback = get_named_property(js_transaction, "rollback")?; + let options = get_named_property(js_transaction, "options")?; + let options = from_js::(options); Ok(Self { - options: TransactionOptions::from_js(options).unwrap(), - commit: JsFunction::from(object.get("commit".into())?).into(), - rollback: JsFunction::from(object.get("rollback".into())?).into(), - closed, + commit, + rollback, + options, + closed: AtomicBool::new(false), }) } diff --git a/query-engine/query-engine-wasm/build.sh b/query-engine/query-engine-wasm/build.sh index 784d4e0e2064..13a7c13e89ec 100755 --- a/query-engine/query-engine-wasm/build.sh +++ b/query-engine/query-engine-wasm/build.sh @@ -13,7 +13,16 @@ OUT_NPM_NAME="@prisma/query-engine-wasm" # This little `sed -i` trick below is a hack to publish "@prisma/query-engine-wasm" # with the same binding filenames currently expected by the Prisma Client. 
sed -i '' 's/name = "query_engine_wasm"/name = "query_engine"/g' Cargo.toml -wasm-pack build --dev --target $OUT_TARGET + +# use `wasm-pack build --release` on CI only +if [[ -z "$BUILDKITE" ]] || [[ -z "$GITHUB_ACTIONS" ]]; then + BUILD_PROFILE="--release" +else + BUILD_PROFILE="--dev" +fi + +wasm-pack build $BUILD_PROFILE --target $OUT_TARGET + sed -i '' 's/name = "query_engine"/name = "query_engine_wasm"/g' Cargo.toml sleep 1 From d9b8e5737f38e53758e9af82cd8dc03c8c72ab06 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Tue, 28 Nov 2023 12:35:16 +0100 Subject: [PATCH 085/134] chore(driver-adapters): unify napi/wasm logic queryable.rs and proxy.rs --- query-engine/driver-adapters/src/lib.rs | 23 ++- query-engine/driver-adapters/src/napi/mod.rs | 4 +- .../driver-adapters/src/napi/transaction.rs | 2 +- .../driver-adapters/src/{napi => }/proxy.rs | 54 +++++-- .../src/{queryable/mod.rs => queryable.rs} | 57 ++++--- .../driver-adapters/src/queryable/napi.rs | 32 ---- .../driver-adapters/src/queryable/wasm.rs | 32 ---- query-engine/driver-adapters/src/wasm/mod.rs | 4 +- .../driver-adapters/src/wasm/proxy.rs | 144 ------------------ .../driver-adapters/src/wasm/transaction.rs | 6 +- .../query-engine-node-api/src/engine.rs | 2 +- .../query-engine-wasm/example/package.json | 8 +- .../query-engine-wasm/example/pnpm-lock.yaml | 70 ++++++--- .../query-engine-wasm/src/wasm/engine.rs | 6 +- 14 files changed, 153 insertions(+), 291 deletions(-) rename query-engine/driver-adapters/src/{napi => }/proxy.rs (74%) rename query-engine/driver-adapters/src/{queryable/mod.rs => queryable.rs} (87%) delete mode 100644 query-engine/driver-adapters/src/queryable/napi.rs delete mode 100644 query-engine/driver-adapters/src/queryable/wasm.rs delete mode 100644 query-engine/driver-adapters/src/wasm/proxy.rs diff --git a/query-engine/driver-adapters/src/lib.rs b/query-engine/driver-adapters/src/lib.rs index 7cf7f928cb5a..230dc1af83d7 100644 --- a/query-engine/driver-adapters/src/lib.rs +++ b/query-engine/driver-adapters/src/lib.rs @@ -9,25 +9,34 @@ pub(crate) mod conversion; pub(crate) mod error; +pub(crate) mod proxy; pub(crate) mod queryable; pub(crate) mod send_future; pub(crate) mod types; +pub use queryable::from_js; + +#[cfg(target_arch = "wasm32")] +pub use wasm::JsObjectExtern as JsObject; + +#[cfg(not(target_arch = "wasm32"))] +pub use ::napi::JsObject; + #[cfg(not(target_arch = "wasm32"))] pub mod napi; #[cfg(not(target_arch = "wasm32"))] -pub use napi::*; +pub(crate) use napi::*; #[cfg(target_arch = "wasm32")] pub mod wasm; #[cfg(target_arch = "wasm32")] -pub use wasm::*; +pub(crate) use wasm::*; #[cfg(target_arch = "wasm32")] mod arch { - pub(crate) use crate::JsObjectExtern as JsObject; + pub(crate) use super::JsObject; pub(crate) use js_sys::JsString; use tsify::Tsify; use wasm_bindgen::JsValue; @@ -36,7 +45,6 @@ mod arch { where T: From, { - // object.get("queryRaw".into())? 
Ok(object.get(name.into())?.into()) } @@ -44,7 +52,7 @@ mod arch { Ok(value.into()) } - pub(crate) fn from_js(value: JsValue) -> C + pub(crate) fn from_js_value(value: JsValue) -> C where C: Tsify + serde::de::DeserializeOwned, { @@ -56,8 +64,9 @@ mod arch { #[cfg(not(target_arch = "wasm32"))] mod arch { + pub(crate) use super::JsObject; use napi::bindgen_prelude::FromNapiValue; - pub(crate) use napi::{JsObject, JsString}; + pub(crate) use napi::JsString; pub(crate) fn get_named_property(object: &JsObject, name: &str) -> JsResult where @@ -70,7 +79,7 @@ mod arch { Ok(value.into_utf8()?.as_str()?.to_string()) } - pub(crate) fn from_js(value: C) -> C { + pub(crate) fn from_js_value(value: C) -> C { value } diff --git a/query-engine/driver-adapters/src/napi/mod.rs b/query-engine/driver-adapters/src/napi/mod.rs index 4612cb550553..69dd2caa6582 100644 --- a/query-engine/driver-adapters/src/napi/mod.rs +++ b/query-engine/driver-adapters/src/napi/mod.rs @@ -3,8 +3,8 @@ mod async_js_function; mod conversion; mod error; -pub(crate) mod proxy; mod result; mod transaction; -pub use crate::queryable::{from_napi, JsQueryable}; +pub(crate) use async_js_function::AsyncJsFunction; +pub(crate) use transaction::JsTransaction; diff --git a/query-engine/driver-adapters/src/napi/transaction.rs b/query-engine/driver-adapters/src/napi/transaction.rs index 69219d06ef1e..b32c408641bc 100644 --- a/query-engine/driver-adapters/src/napi/transaction.rs +++ b/query-engine/driver-adapters/src/napi/transaction.rs @@ -7,7 +7,7 @@ use quaint::{ Value, }; -use super::proxy::{CommonProxy, TransactionOptions, TransactionProxy}; +use crate::proxy::{CommonProxy, TransactionOptions, TransactionProxy}; use crate::queryable::JsBaseQueryable; // Wrapper around JS transaction objects that implements Queryable diff --git a/query-engine/driver-adapters/src/napi/proxy.rs b/query-engine/driver-adapters/src/proxy.rs similarity index 74% rename from query-engine/driver-adapters/src/napi/proxy.rs rename to query-engine/driver-adapters/src/proxy.rs index 298c4a2ba35b..7d66b798e20d 100644 --- a/query-engine/driver-adapters/src/napi/proxy.rs +++ b/query-engine/driver-adapters/src/proxy.rs @@ -1,14 +1,19 @@ +use crate::send_future::SendFuture; pub use crate::types::{ColumnType, JSResultSet, Query, TransactionOptions}; -use crate::{from_js, get_named_property, to_rust_str, JsObject, JsResult, JsString}; +use crate::{from_js_value, get_named_property, to_rust_str, JsObject, JsResult, JsString}; -use super::async_js_function::AsyncJsFunction; -use super::transaction::JsTransaction; +use crate::{AsyncJsFunction, JsTransaction}; +use futures::Future; use metrics::increment_gauge; use std::sync::atomic::{AtomicBool, Ordering}; +#[cfg(target_arch = "wasm32")] +use wasm_bindgen::prelude::wasm_bindgen; + /// Proxy is a struct wrapping a javascript object that exhibits basic primitives for -/// querying and executing SQL (i.e. a client connector). The Proxy uses NAPI ThreadSafeFunction to -/// invoke the code within the node runtime that implements the client connector. +/// querying and executing SQL (i.e. a client connector). The Proxy uses Napi/Wasm's JsFunction +/// to invoke the code within the node runtime that implements the client connector. +#[cfg_attr(target_arch = "wasm32", wasm_bindgen(getter_with_clone))] pub(crate) struct CommonProxy { /// Execute a query given as SQL, interpolating the given parameters. 
query_raw: AsyncJsFunction, @@ -23,11 +28,14 @@ pub(crate) struct CommonProxy { /// This is a JS proxy for accessing the methods specific to top level /// JS driver objects +#[cfg_attr(target_arch = "wasm32", wasm_bindgen(getter_with_clone))] pub(crate) struct DriverProxy { start_transaction: AsyncJsFunction<(), JsTransaction>, } + /// This a JS proxy for accessing the methods, specific /// to JS transaction objects +#[cfg_attr(target_arch = "wasm32", wasm_bindgen(getter_with_clone))] pub(crate) struct TransactionProxy { /// transaction options options: TransactionOptions, @@ -63,13 +71,13 @@ impl CommonProxy { } impl DriverProxy { - pub fn new(driver_adapter: &JsObject) -> JsResult { + pub fn new(object: &JsObject) -> JsResult { Ok(Self { - start_transaction: get_named_property(driver_adapter, "startTransaction")?, + start_transaction: get_named_property(object, "startTransaction")?, }) } - pub async fn start_transaction(&self) -> quaint::Result> { + async fn start_transaction_inner(&self) -> quaint::Result> { let tx = self.start_transaction.call(()).await?; // Decrement for this gauge is done in JsTransaction::commit/JsTransaction::rollback @@ -79,6 +87,12 @@ impl DriverProxy { increment_gauge!("prisma_client_queries_active", 1.0); Ok(Box::new(tx)) } + + pub fn start_transaction<'a>( + &'a self, + ) -> SendFuture>> + 'a> { + SendFuture(self.start_transaction_inner()) + } } impl TransactionProxy { @@ -86,7 +100,7 @@ impl TransactionProxy { let commit = get_named_property(js_transaction, "commit")?; let rollback = get_named_property(js_transaction, "rollback")?; let options = get_named_property(js_transaction, "options")?; - let options = from_js::(options); + let options = from_js_value::(options); Ok(Self { commit, @@ -115,9 +129,9 @@ impl TransactionProxy { /// the underlying FFI call will be delivered to JavaScript side in lockstep, so the destructor /// will not attempt rolling the transaction back even if the `commit` future was dropped while /// waiting on the JavaScript call to complete and deliver response. - pub async fn commit(&self) -> quaint::Result<()> { + pub fn commit<'a>(&'a self) -> SendFuture> + 'a> { self.closed.store(true, Ordering::Relaxed); - self.commit.call(()).await + SendFuture(self.commit.call(())) } /// Rolls back the transaction via the driver adapter. @@ -135,9 +149,9 @@ impl TransactionProxy { /// the underlying FFI call will be delivered to JavaScript side in lockstep, so the destructor /// will not attempt rolling back again even if the `rollback` future was dropped while waiting /// on the JavaScript call to complete and deliver response. - pub async fn rollback(&self) -> quaint::Result<()> { + pub fn rollback<'a>(&'a self) -> SendFuture> + 'a> { self.closed.store(true, Ordering::Relaxed); - self.rollback.call(()).await + SendFuture(self.rollback.call(())) } } @@ -150,3 +164,17 @@ impl Drop for TransactionProxy { _ = self.rollback.call_non_blocking(()); } } + +macro_rules! impl_send_sync_on_wasm { + ($struct:ident) => { + #[cfg(target_arch = "wasm32")] + unsafe impl Send for $struct {} + #[cfg(target_arch = "wasm32")] + unsafe impl Sync for $struct {} + }; +} + +// Assume the proxy object will not be sent to service workers, we can unsafe impl Send + Sync. 
+impl_send_sync_on_wasm!(TransactionProxy); +impl_send_sync_on_wasm!(DriverProxy); +impl_send_sync_on_wasm!(CommonProxy); diff --git a/query-engine/driver-adapters/src/queryable/mod.rs b/query-engine/driver-adapters/src/queryable.rs similarity index 87% rename from query-engine/driver-adapters/src/queryable/mod.rs rename to query-engine/driver-adapters/src/queryable.rs index 9cd2eb1c9b33..c3262902771b 100644 --- a/query-engine/driver-adapters/src/queryable/mod.rs +++ b/query-engine/driver-adapters/src/queryable.rs @@ -1,26 +1,8 @@ -#[cfg(not(target_arch = "wasm32"))] -pub(crate) mod napi; +use crate::proxy::{CommonProxy, DriverProxy}; +use crate::types::{AdapterFlavour, Query}; +use crate::JsObject; -#[cfg(not(target_arch = "wasm32"))] -pub use napi::from_napi; - -#[cfg(not(target_arch = "wasm32"))] -pub(crate) use napi::JsBaseQueryable; - -#[cfg(target_arch = "wasm32")] -pub(crate) mod wasm; - -#[cfg(target_arch = "wasm32")] -pub use wasm::from_wasm; - -#[cfg(target_arch = "wasm32")] -pub(crate) use wasm::JsBaseQueryable; - -use super::{ - conversion, - proxy::{CommonProxy, DriverProxy, Query}, - types::AdapterFlavour, -}; +use super::conversion; use crate::send_future::SendFuture; use async_trait::async_trait; use futures::Future; @@ -32,6 +14,27 @@ use quaint::{ }; use tracing::{info_span, Instrument}; +#[cfg(target_arch = "wasm32")] +use wasm_bindgen::prelude::wasm_bindgen; + +/// A JsQueryable adapts a Proxy to implement quaint's Queryable interface. It has the +/// responsibility of transforming inputs and outputs of `query` and `execute` methods from quaint +/// types to types that can be translated into javascript and viceversa. This is to let the rest of +/// the query engine work as if it was using quaint itself. The aforementioned transformations are: +/// +/// Transforming a `quaint::ast::Query` into SQL by visiting it for the specific flavour of SQL +/// expected by the client connector. (eg. using the mysql visitor for the Planetscale client +/// connector) +/// +/// Transforming a `JSResultSet` (what client connectors implemented in javascript provide) +/// into a `quaint::connector::result_set::ResultSet`. A quaint `ResultSet` is basically a vector +/// of `quaint::Value` but said type is a tagged enum, with non-unit variants that cannot be converted to javascript as is. +#[cfg_attr(target_arch = "wasm32", wasm_bindgen(getter_with_clone))] +pub(crate) struct JsBaseQueryable { + pub(crate) proxy: CommonProxy, + pub flavour: AdapterFlavour, +} + impl JsBaseQueryable { pub(crate) fn new(proxy: CommonProxy) -> Self { let flavour: AdapterFlavour = proxy.flavour.parse().unwrap(); @@ -305,3 +308,13 @@ impl TransactionCapable for JsQueryable { Ok(tx) } } + +pub fn from_js(driver: JsObject) -> JsQueryable { + let common = CommonProxy::new(&driver).unwrap(); + let driver_proxy = DriverProxy::new(&driver).unwrap(); + + JsQueryable { + inner: JsBaseQueryable::new(common), + driver_proxy, + } +} diff --git a/query-engine/driver-adapters/src/queryable/napi.rs b/query-engine/driver-adapters/src/queryable/napi.rs deleted file mode 100644 index b7f4cf49028b..000000000000 --- a/query-engine/driver-adapters/src/queryable/napi.rs +++ /dev/null @@ -1,32 +0,0 @@ -use crate::napi::proxy::{CommonProxy, DriverProxy}; -use crate::JsQueryable; -use napi::JsObject; -use crate::types::AdapterFlavour; - -/// A JsQueryable adapts a Proxy to implement quaint's Queryable interface. 
It has the -/// responsibility of transforming inputs and outputs of `query` and `execute` methods from quaint -/// types to types that can be translated into javascript and viceversa. This is to let the rest of -/// the query engine work as if it was using quaint itself. The aforementioned transformations are: -/// -/// Transforming a `quaint::ast::Query` into SQL by visiting it for the specific flavour of SQL -/// expected by the client connector. (eg. using the mysql visitor for the Planetscale client -/// connector) -/// -/// Transforming a `JSResultSet` (what client connectors implemented in javascript provide) -/// into a `quaint::connector::result_set::ResultSet`. A quaint `ResultSet` is basically a vector -/// of `quaint::Value` but said type is a tagged enum, with non-unit variants that cannot be converted to javascript as is. -/// -pub(crate) struct JsBaseQueryable { - pub(crate) proxy: CommonProxy, - pub flavour: AdapterFlavour, -} - -pub fn from_napi(driver: JsObject) -> JsQueryable { - let common = CommonProxy::new(&driver).unwrap(); - let driver_proxy = DriverProxy::new(&driver).unwrap(); - - JsQueryable { - inner: JsBaseQueryable::new(common), - driver_proxy, - } -} diff --git a/query-engine/driver-adapters/src/queryable/wasm.rs b/query-engine/driver-adapters/src/queryable/wasm.rs deleted file mode 100644 index ee1c65a81347..000000000000 --- a/query-engine/driver-adapters/src/queryable/wasm.rs +++ /dev/null @@ -1,32 +0,0 @@ -use crate::types::AdapterFlavour; -use crate::wasm::proxy::{CommonProxy, DriverProxy}; -use crate::{JsObjectExtern, JsQueryable}; -use wasm_bindgen::prelude::wasm_bindgen; - -/// A JsQueryable adapts a Proxy to implement quaint's Queryable interface. It has the -/// responsibility of transforming inputs and outputs of `query` and `execute` methods from quaint -/// types to types that can be translated into javascript and viceversa. This is to let the rest of -/// the query engine work as if it was using quaint itself. The aforementioned transformations are: -/// -/// Transforming a `quaint::ast::Query` into SQL by visiting it for the specific flavour of SQL -/// expected by the client connector. (eg. using the mysql visitor for the Planetscale client -/// connector) -/// -/// Transforming a `JSResultSet` (what client connectors implemented in javascript provide) -/// into a `quaint::connector::result_set::ResultSet`. A quaint `ResultSet` is basically a vector -/// of `quaint::Value` but said type is a tagged enum, with non-unit variants that cannot be converted to javascript as is. 
-#[wasm_bindgen(getter_with_clone)] -pub(crate) struct JsBaseQueryable { - pub(crate) proxy: CommonProxy, - pub flavour: AdapterFlavour, -} - -pub fn from_wasm(driver: JsObjectExtern) -> JsQueryable { - let common = CommonProxy::new(&driver).unwrap(); - let driver_proxy = DriverProxy::new(&driver).unwrap(); - - JsQueryable { - inner: JsBaseQueryable::new(common), - driver_proxy, - } -} diff --git a/query-engine/driver-adapters/src/wasm/mod.rs b/query-engine/driver-adapters/src/wasm/mod.rs index 2afe1987e1a7..e60d23ff2445 100644 --- a/query-engine/driver-adapters/src/wasm/mod.rs +++ b/query-engine/driver-adapters/src/wasm/mod.rs @@ -4,9 +4,9 @@ mod async_js_function; mod error; mod from_js; mod js_object_extern; -pub(crate) mod proxy; mod result; mod transaction; -pub use crate::queryable::{from_wasm, JsQueryable}; +pub(crate) use async_js_function::AsyncJsFunction; pub use js_object_extern::JsObjectExtern; +pub(crate) use transaction::JsTransaction; diff --git a/query-engine/driver-adapters/src/wasm/proxy.rs b/query-engine/driver-adapters/src/wasm/proxy.rs deleted file mode 100644 index 607e4440bbf7..000000000000 --- a/query-engine/driver-adapters/src/wasm/proxy.rs +++ /dev/null @@ -1,144 +0,0 @@ -use crate::send_future::SendFuture; -pub use crate::types::{ColumnType, JSResultSet, Query, TransactionOptions}; -use crate::{from_js, get_named_property, to_rust_str, JsObject, JsResult, JsString}; - -use super::{async_js_function::AsyncJsFunction, transaction::JsTransaction}; -use futures::Future; -use metrics::increment_gauge; -use std::sync::atomic::{AtomicBool, Ordering}; -use wasm_bindgen::prelude::wasm_bindgen; - -/// Proxy is a struct wrapping a javascript object that exhibits basic primitives for -/// querying and executing SQL (i.e. a client connector). The Proxy uses Wasm's JsFunction to -/// invoke the code within the node runtime that implements the client connector. -#[wasm_bindgen(getter_with_clone)] -pub(crate) struct CommonProxy { - /// Execute a query given as SQL, interpolating the given parameters. - query_raw: AsyncJsFunction, - - /// Execute a query given as SQL, interpolating the given parameters and - /// returning the number of affected rows. - execute_raw: AsyncJsFunction, - - /// Return the flavour for this driver. 
- pub(crate) flavour: String, -} - -/// This is a JS proxy for accessing the methods specific to top level -/// JS driver objects -#[wasm_bindgen(getter_with_clone)] -pub(crate) struct DriverProxy { - start_transaction: AsyncJsFunction<(), JsTransaction>, -} - -/// This a JS proxy for accessing the methods, specific -/// to JS transaction objects -#[wasm_bindgen(getter_with_clone)] -pub(crate) struct TransactionProxy { - /// transaction options - options: TransactionOptions, - - /// commit transaction - commit: AsyncJsFunction<(), ()>, - - /// rollback transaction - rollback: AsyncJsFunction<(), ()>, - - /// whether the transaction has already been committed or rolled back - closed: AtomicBool, -} - -impl CommonProxy { - pub fn new(object: &JsObject) -> JsResult { - let flavour: JsString = get_named_property(object, "flavour")?; - - Ok(Self { - query_raw: get_named_property(object, "queryRaw")?, - execute_raw: get_named_property(object, "executeRaw")?, - flavour: to_rust_str(flavour)?, - }) - } - - pub async fn query_raw(&self, params: Query) -> quaint::Result { - self.query_raw.call(params).await - } - - pub async fn execute_raw(&self, params: Query) -> quaint::Result { - self.execute_raw.call(params).await - } -} - -impl DriverProxy { - pub fn new(object: &JsObject) -> JsResult { - Ok(Self { - start_transaction: get_named_property(object, "startTransaction")?, - }) - } - - async fn start_transaction_inner(&self) -> quaint::Result> { - let tx = self.start_transaction.call(()).await?; - - // Decrement for this gauge is done in JsTransaction::commit/JsTransaction::rollback - // Previously, it was done in JsTransaction::new, similar to the native Transaction. - // However, correct Dispatcher is lost there and increment does not register, so we moved - // it here instead. - increment_gauge!("prisma_client_queries_active", 1.0); - Ok(Box::new(tx)) - } - - pub fn start_transaction<'a>( - &'a self, - ) -> SendFuture>> + 'a> { - SendFuture(self.start_transaction_inner()) - } -} - -impl TransactionProxy { - pub fn new(js_transaction: &JsObject) -> JsResult { - let commit = get_named_property(js_transaction, "commit")?; - let rollback = get_named_property(js_transaction, "rollback")?; - let options = get_named_property(js_transaction, "options")?; - let options = from_js::(options); - - Ok(Self { - commit, - rollback, - options, - closed: AtomicBool::new(false), - }) - } - - pub fn options(&self) -> &TransactionOptions { - &self.options - } - - pub fn commit<'a>(&'a self) -> SendFuture> + 'a> { - self.closed.store(true, Ordering::Relaxed); - SendFuture(self.commit.call(())) - } - - pub fn rollback<'a>(&'a self) -> SendFuture> + 'a> { - self.closed.store(true, Ordering::Relaxed); - SendFuture(self.rollback.call(())) - } -} - -impl Drop for TransactionProxy { - fn drop(&mut self) { - if self.closed.swap(true, Ordering::Relaxed) { - return; - } - - _ = self.rollback.call_non_blocking(()); - } -} - -// Assume the proxy object will not be sent to service workers, we can unsafe impl Send + Sync. 
-unsafe impl Send for TransactionProxy {} -unsafe impl Sync for TransactionProxy {} - -unsafe impl Send for DriverProxy {} -unsafe impl Sync for DriverProxy {} - -unsafe impl Send for CommonProxy {} -unsafe impl Sync for CommonProxy {} diff --git a/query-engine/driver-adapters/src/wasm/transaction.rs b/query-engine/driver-adapters/src/wasm/transaction.rs index b9eac2965e48..1d543ce6def3 100644 --- a/query-engine/driver-adapters/src/wasm/transaction.rs +++ b/query-engine/driver-adapters/src/wasm/transaction.rs @@ -8,10 +8,8 @@ use quaint::{ }; use wasm_bindgen::JsCast; -use super::{ - from_js::FromJsValue, - proxy::{TransactionOptions, TransactionProxy}, -}; +use super::from_js::FromJsValue; +use crate::proxy::{TransactionOptions, TransactionProxy}; use crate::{proxy::CommonProxy, queryable::JsBaseQueryable, send_future::SendFuture, JsObjectExtern}; // Wrapper around JS transaction objects that implements Queryable diff --git a/query-engine/query-engine-node-api/src/engine.rs b/query-engine/query-engine-node-api/src/engine.rs index 23782af1776a..1d56239ecf6d 100644 --- a/query-engine/query-engine-node-api/src/engine.rs +++ b/query-engine/query-engine-node-api/src/engine.rs @@ -192,7 +192,7 @@ impl QueryEngine { } else { #[cfg(feature = "driver-adapters")] if let Some(adapter) = maybe_adapter { - let js_queryable = driver_adapters::from_napi(adapter); + let js_queryable = driver_adapters::from_js(adapter); sql_connector::activate_driver_adapter(Arc::new(js_queryable)); connector_mode = ConnectorMode::Js; diff --git a/query-engine/query-engine-wasm/example/package.json b/query-engine/query-engine-wasm/example/package.json index bb6d7b868ede..3b0c4c91c9f9 100644 --- a/query-engine/query-engine-wasm/example/package.json +++ b/query-engine/query-engine-wasm/example/package.json @@ -6,9 +6,9 @@ }, "dependencies": { "@libsql/client": "0.4.0-pre.2", - "@prisma/adapter-libsql": "5.6.0", - "@prisma/client": "5.6.0", - "@prisma/driver-adapter-utils": "5.6.0", - "prisma": "5.6.0" + "@prisma/adapter-libsql": "5.7.0-dev.54", + "@prisma/client": "5.7.0-dev.54", + "@prisma/driver-adapter-utils": "5.7.0-dev.54", + "prisma": "5.7.0-dev.54" } } diff --git a/query-engine/query-engine-wasm/example/pnpm-lock.yaml b/query-engine/query-engine-wasm/example/pnpm-lock.yaml index 887edea0e8cc..beb050a5398a 100644 --- a/query-engine/query-engine-wasm/example/pnpm-lock.yaml +++ b/query-engine/query-engine-wasm/example/pnpm-lock.yaml @@ -9,17 +9,17 @@ dependencies: specifier: 0.4.0-pre.2 version: 0.4.0-pre.2 '@prisma/adapter-libsql': - specifier: 5.6.0 - version: 5.6.0(@libsql/client@0.4.0-pre.2) + specifier: 5.7.0-dev.54 + version: 5.7.0-dev.54(@libsql/client@0.4.0-pre.2) '@prisma/client': - specifier: 5.6.0 - version: 5.6.0(prisma@5.6.0) + specifier: 5.7.0-dev.54 + version: 5.7.0-dev.54(prisma@5.7.0-dev.54) '@prisma/driver-adapter-utils': - specifier: 5.6.0 - version: 5.6.0 + specifier: 5.7.0-dev.54 + version: 5.7.0-dev.54 prisma: - specifier: 5.6.0 - version: 5.6.0 + specifier: 5.7.0-dev.54 + version: 5.7.0-dev.54 packages: @@ -127,20 +127,20 @@ packages: resolution: {integrity: sha512-kTPhdZyTQxB+2wpiRcFWrDcejc4JI6tkPuS7UZCG4l6Zvc5kU/gGQ/ozvHTh1XR5tS+UlfAfGuPajjzQjCiHCw==} dev: false - /@prisma/adapter-libsql@5.6.0(@libsql/client@0.4.0-pre.2): - resolution: {integrity: sha512-XFDLw9QqEDDVXAe8YdX8TL4mCiolDijjxh8HQRJ33VcuujGnAWWpBKE35MKfIsuONVyNXFthB/Gky/MlmMcE6Q==} + /@prisma/adapter-libsql@5.7.0-dev.54(@libsql/client@0.4.0-pre.2): + resolution: {integrity: 
sha512-P+npdjsKYGv3bW4XWDEruLFAaih9ECZI7vH90DeWY3AOAQY9Siy9bKecsTmCTeKYscrqKVgP1uK3MRHacvWhyQ==} peerDependencies: '@libsql/client': ^0.3.5 dependencies: '@libsql/client': 0.4.0-pre.2 - '@prisma/driver-adapter-utils': 5.6.0 + '@prisma/driver-adapter-utils': 5.7.0-dev.54 async-mutex: 0.4.0 transitivePeerDependencies: - supports-color dev: false - /@prisma/client@5.6.0(prisma@5.6.0): - resolution: {integrity: sha512-mUDefQFa1wWqk4+JhKPYq8BdVoFk9NFMBXUI8jAkBfQTtgx8WPx02U2HB/XbAz3GSUJpeJOKJQtNvaAIDs6sug==} + /@prisma/client@5.7.0-dev.54(prisma@5.7.0-dev.54): + resolution: {integrity: sha512-WjR+Cpfssce60M6FSXHjFpH+hFLUAfsRxAbRJTw2+W2HdJyZcVXF4FTqCZLZVaBF5NW/60fK3K3aRnMuEvsDtA==} engines: {node: '>=16.13'} requiresBuild: true peerDependencies: @@ -149,25 +149,47 @@ packages: prisma: optional: true dependencies: - '@prisma/engines-version': 5.6.0-32.e95e739751f42d8ca026f6b910f5a2dc5adeaeee - prisma: 5.6.0 + prisma: 5.7.0-dev.54 dev: false - /@prisma/driver-adapter-utils@5.6.0: - resolution: {integrity: sha512-/TSrfCGLAQghNf+bwg5/e8iKAgecCYU/gMN0IyNra3183/VTQJneLFgbacuSK9bBXiIRUmpbuUIrJ6dhENzfjA==} + /@prisma/debug@5.7.0-dev.54: + resolution: {integrity: sha512-5KodpKA1Th05sREvQoQ4U8oJa8QFXPjxzE5AduzYLHjXibgd18p2//c0wtU9erP7jgLFC9vrvlSsWhjsAyc0fA==} + dev: false + + /@prisma/driver-adapter-utils@5.7.0-dev.54: + resolution: {integrity: sha512-5wGFzahzgIPgDjuVpU8hisB71RYDVtIeYord920PAW//ZnHPvS6yHg1+O+z/PMndV5iL9UP5EJDx19LpmH+sDg==} dependencies: debug: 4.3.4 transitivePeerDependencies: - supports-color dev: false - /@prisma/engines-version@5.6.0-32.e95e739751f42d8ca026f6b910f5a2dc5adeaeee: - resolution: {integrity: sha512-UoFgbV1awGL/3wXuUK3GDaX2SolqczeeJ5b4FVec9tzeGbSWJboPSbT0psSrmgYAKiKnkOPFSLlH6+b+IyOwAw==} + /@prisma/engines-version@5.7.0-20.01aad9b63c8d574cc270d2b09461e920d19986e6: + resolution: {integrity: sha512-aKw2Ge9kZQrU5DRxqQ9xwyksH6aFtJR4BIuUDSevkFbrq3PFy/SBhLE4RWVfJmYqWs5/BBat7ZP3T5xK178liQ==} dev: false - /@prisma/engines@5.6.0: - resolution: {integrity: sha512-Mt2q+GNJpU2vFn6kif24oRSBQv1KOkYaterQsi0k2/lA+dLvhRX6Lm26gon6PYHwUM8/h8KRgXIUMU0PCLB6bw==} + /@prisma/engines@5.7.0-dev.54: + resolution: {integrity: sha512-qeV4+hbQFaVqUw3CRpRhFt0/W+7BzLZ8RFhuVF9tOqdcZ0Mu5ktdX0pevbWtJHMnbqt9nrcxVv42Ok9mqJ2mFA==} requiresBuild: true + dependencies: + '@prisma/debug': 5.7.0-dev.54 + '@prisma/engines-version': 5.7.0-20.01aad9b63c8d574cc270d2b09461e920d19986e6 + '@prisma/fetch-engine': 5.7.0-dev.54 + '@prisma/get-platform': 5.7.0-dev.54 + dev: false + + /@prisma/fetch-engine@5.7.0-dev.54: + resolution: {integrity: sha512-pFm+hWMS3zSrjyvlTY8JQWYL9jRCVyEOc/qt1sIe/EILQsQfKjweOk6yyQm4+wLId+otjc738A6+wLVjBfoiNw==} + dependencies: + '@prisma/debug': 5.7.0-dev.54 + '@prisma/engines-version': 5.7.0-20.01aad9b63c8d574cc270d2b09461e920d19986e6 + '@prisma/get-platform': 5.7.0-dev.54 + dev: false + + /@prisma/get-platform@5.7.0-dev.54: + resolution: {integrity: sha512-5vbvS2qo1QtWam4oQKbrVo9kC5YVODTlF3p3GqlrPACy8B4wEvGd2MLEtFb9UQl3gCOcvZNZmMx+hm2aV/f2Fw==} + dependencies: + '@prisma/debug': 5.7.0-dev.54 dev: false /@types/node-fetch@2.6.9: @@ -320,13 +342,13 @@ packages: formdata-polyfill: 4.0.10 dev: false - /prisma@5.6.0: - resolution: {integrity: sha512-EEaccku4ZGshdr2cthYHhf7iyvCcXqwJDvnoQRAJg5ge2Tzpv0e2BaMCp+CbbDUwoVTzwgOap9Zp+d4jFa2O9A==} + /prisma@5.7.0-dev.54: + resolution: {integrity: sha512-+dpJABpFg6l4DTSSCGBIxgrRPJ3QMDtiB3SB56UOk5vX2guG96+yU46N1WWwSCgZirwhy3IR1zVuhmRZFmatSA==} engines: {node: '>=16.13'} hasBin: true requiresBuild: true dependencies: - '@prisma/engines': 5.6.0 + '@prisma/engines': 
5.7.0-dev.54 dev: false /tr46@0.0.3: diff --git a/query-engine/query-engine-wasm/src/wasm/engine.rs b/query-engine/query-engine-wasm/src/wasm/engine.rs index 183d4940d1c8..1b3b51653c7a 100644 --- a/query-engine/query-engine-wasm/src/wasm/engine.rs +++ b/query-engine/query-engine-wasm/src/wasm/engine.rs @@ -5,7 +5,7 @@ use crate::{ error::ApiError, logger::{LogCallback, Logger}, }; -use driver_adapters::JsObjectExtern; +use driver_adapters::JsObject; use futures::FutureExt; use js_sys::Function as JsFunction; use query_core::{ @@ -135,7 +135,7 @@ impl QueryEngine { pub fn new( options: ConstructorOptions, callback: JsFunction, - maybe_adapter: Option, + maybe_adapter: Option, ) -> Result { log::info!("Called `QueryEngine::new()`"); @@ -161,7 +161,7 @@ impl QueryEngine { let preview_features = config.preview_features(); if let Some(adapter) = maybe_adapter { - let js_queryable = driver_adapters::from_wasm(adapter); + let js_queryable = driver_adapters::from_js(adapter); sql_connector::activate_driver_adapter(Arc::new(js_queryable)); From 7a77d975f89a1ca375e8ac3e8dd1279138535708 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Tue, 28 Nov 2023 12:50:22 +0100 Subject: [PATCH 086/134] chore(driver-adapters): unify napi/wasm logic for transaction.rs --- query-engine/driver-adapters/src/lib.rs | 2 + query-engine/driver-adapters/src/napi/mod.rs | 2 - .../driver-adapters/src/napi/transaction.rs | 134 ------------------ query-engine/driver-adapters/src/proxy.rs | 1 + .../driver-adapters/src/send_future.rs | 7 +- .../src/{wasm => }/transaction.rs | 46 +++--- .../driver-adapters/src/wasm/from_js.rs | 2 +- query-engine/driver-adapters/src/wasm/mod.rs | 3 +- 8 files changed, 38 insertions(+), 159 deletions(-) delete mode 100644 query-engine/driver-adapters/src/napi/transaction.rs rename query-engine/driver-adapters/src/{wasm => }/transaction.rs (81%) diff --git a/query-engine/driver-adapters/src/lib.rs b/query-engine/driver-adapters/src/lib.rs index 230dc1af83d7..9873ac0d994a 100644 --- a/query-engine/driver-adapters/src/lib.rs +++ b/query-engine/driver-adapters/src/lib.rs @@ -12,9 +12,11 @@ pub(crate) mod error; pub(crate) mod proxy; pub(crate) mod queryable; pub(crate) mod send_future; +pub(crate) mod transaction; pub(crate) mod types; pub use queryable::from_js; +pub(crate) use transaction::JsTransaction; #[cfg(target_arch = "wasm32")] pub use wasm::JsObjectExtern as JsObject; diff --git a/query-engine/driver-adapters/src/napi/mod.rs b/query-engine/driver-adapters/src/napi/mod.rs index 69dd2caa6582..c9bb8d24ac33 100644 --- a/query-engine/driver-adapters/src/napi/mod.rs +++ b/query-engine/driver-adapters/src/napi/mod.rs @@ -4,7 +4,5 @@ mod async_js_function; mod conversion; mod error; mod result; -mod transaction; pub(crate) use async_js_function::AsyncJsFunction; -pub(crate) use transaction::JsTransaction; diff --git a/query-engine/driver-adapters/src/napi/transaction.rs b/query-engine/driver-adapters/src/napi/transaction.rs deleted file mode 100644 index b32c408641bc..000000000000 --- a/query-engine/driver-adapters/src/napi/transaction.rs +++ /dev/null @@ -1,134 +0,0 @@ -use async_trait::async_trait; -use metrics::decrement_gauge; -use napi::{bindgen_prelude::FromNapiValue, JsObject}; -use quaint::{ - connector::{IsolationLevel, Transaction as QuaintTransaction}, - prelude::{Query as QuaintQuery, Queryable, ResultSet}, - Value, -}; - -use crate::proxy::{CommonProxy, TransactionOptions, TransactionProxy}; -use crate::queryable::JsBaseQueryable; - -// Wrapper around JS transaction objects that implements 
Queryable -// and quaint::Transaction. Can be used in place of quaint transaction, -// but delegates most operations to JS -pub(crate) struct JsTransaction { - tx_proxy: TransactionProxy, - inner: JsBaseQueryable, -} - -impl JsTransaction { - pub(crate) fn new(inner: JsBaseQueryable, tx_proxy: TransactionProxy) -> Self { - Self { inner, tx_proxy } - } - - pub fn options(&self) -> &TransactionOptions { - self.tx_proxy.options() - } - - pub async fn raw_phantom_cmd(&self, cmd: &str) -> quaint::Result<()> { - let params = &[]; - quaint::connector::metrics::query("js.raw_phantom_cmd", cmd, params, move || async move { Ok(()) }).await - } -} - -#[async_trait] -impl QuaintTransaction for JsTransaction { - async fn commit(&self) -> quaint::Result<()> { - // increment of this gauge is done in DriverProxy::startTransaction - decrement_gauge!("prisma_client_queries_active", 1.0); - - let commit_stmt = "COMMIT"; - - if self.options().use_phantom_query { - let commit_stmt = JsBaseQueryable::phantom_query_message(commit_stmt); - self.raw_phantom_cmd(commit_stmt.as_str()).await?; - } else { - self.inner.raw_cmd(commit_stmt).await?; - } - - self.tx_proxy.commit().await - } - - async fn rollback(&self) -> quaint::Result<()> { - // increment of this gauge is done in DriverProxy::startTransaction - decrement_gauge!("prisma_client_queries_active", 1.0); - - let rollback_stmt = "ROLLBACK"; - - if self.options().use_phantom_query { - let rollback_stmt = JsBaseQueryable::phantom_query_message(rollback_stmt); - self.raw_phantom_cmd(rollback_stmt.as_str()).await?; - } else { - self.inner.raw_cmd(rollback_stmt).await?; - } - - self.tx_proxy.rollback().await - } - - fn as_queryable(&self) -> &dyn Queryable { - self - } -} - -#[async_trait] -impl Queryable for JsTransaction { - async fn query(&self, q: QuaintQuery<'_>) -> quaint::Result { - self.inner.query(q).await - } - - async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> quaint::Result { - self.inner.query_raw(sql, params).await - } - - async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> quaint::Result { - self.inner.query_raw_typed(sql, params).await - } - - async fn execute(&self, q: QuaintQuery<'_>) -> quaint::Result { - self.inner.execute(q).await - } - - async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> quaint::Result { - self.inner.execute_raw(sql, params).await - } - - async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> quaint::Result { - self.inner.execute_raw_typed(sql, params).await - } - - async fn raw_cmd(&self, cmd: &str) -> quaint::Result<()> { - self.inner.raw_cmd(cmd).await - } - - async fn version(&self) -> quaint::Result> { - self.inner.version().await - } - - fn is_healthy(&self) -> bool { - self.inner.is_healthy() - } - - async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> quaint::Result<()> { - self.inner.set_tx_isolation_level(isolation_level).await - } - - fn requires_isolation_first(&self) -> bool { - self.inner.requires_isolation_first() - } -} - -/// Implementing unsafe `from_napi_value` is only way I managed to get threadsafe -/// JsTransaction value in `DriverProxy`. Going through any intermediate safe napi.rs value, -/// like `JsObject` or `JsUnknown` wrapped inside `JsPromise` makes it impossible to extract the value -/// out of promise while keeping the future `Send`. 
-impl FromNapiValue for JsTransaction { - unsafe fn from_napi_value(env: napi::sys::napi_env, napi_val: napi::sys::napi_value) -> napi::Result { - let object = JsObject::from_napi_value(env, napi_val)?; - let common_proxy = CommonProxy::new(&object)?; - let tx_proxy = TransactionProxy::new(&object)?; - - Ok(Self::new(JsBaseQueryable::new(common_proxy), tx_proxy)) - } -} diff --git a/query-engine/driver-adapters/src/proxy.rs b/query-engine/driver-adapters/src/proxy.rs index 7d66b798e20d..7e01176be12f 100644 --- a/query-engine/driver-adapters/src/proxy.rs +++ b/query-engine/driver-adapters/src/proxy.rs @@ -178,3 +178,4 @@ macro_rules! impl_send_sync_on_wasm { impl_send_sync_on_wasm!(TransactionProxy); impl_send_sync_on_wasm!(DriverProxy); impl_send_sync_on_wasm!(CommonProxy); +impl_send_sync_on_wasm!(JsTransaction); diff --git a/query-engine/driver-adapters/src/send_future.rs b/query-engine/driver-adapters/src/send_future.rs index 61c64a960450..ed5e78345afd 100644 --- a/query-engine/driver-adapters/src/send_future.rs +++ b/query-engine/driver-adapters/src/send_future.rs @@ -11,8 +11,6 @@ use futures::Future; #[pin_project::pin_project] pub struct SendFuture(#[pin] pub F); -unsafe impl Send for SendFuture {} - impl Future for SendFuture { type Output = F::Output; @@ -22,3 +20,8 @@ impl Future for SendFuture { future.poll(cx) } } + +// Note: on Napi.rs, we require the underlying future to be `Send`. +// On Wasm, that's currently not possible. +#[cfg(target_arch = "wasm32")] +unsafe impl Send for SendFuture {} diff --git a/query-engine/driver-adapters/src/wasm/transaction.rs b/query-engine/driver-adapters/src/transaction.rs similarity index 81% rename from query-engine/driver-adapters/src/wasm/transaction.rs rename to query-engine/driver-adapters/src/transaction.rs index 1d543ce6def3..7f7180ae6d30 100644 --- a/query-engine/driver-adapters/src/wasm/transaction.rs +++ b/query-engine/driver-adapters/src/transaction.rs @@ -1,16 +1,14 @@ use async_trait::async_trait; -use js_sys::Object as JsObject; use metrics::decrement_gauge; use quaint::{ connector::{IsolationLevel, Transaction as QuaintTransaction}, prelude::{Query as QuaintQuery, Queryable, ResultSet}, Value, }; -use wasm_bindgen::JsCast; -use super::from_js::FromJsValue; use crate::proxy::{TransactionOptions, TransactionProxy}; -use crate::{proxy::CommonProxy, queryable::JsBaseQueryable, send_future::SendFuture, JsObjectExtern}; +use crate::{proxy::CommonProxy, queryable::JsBaseQueryable, send_future::SendFuture}; +use crate::{JsObject, JsResult}; // Wrapper around JS transaction objects that implements Queryable // and quaint::Transaction. Can be used in place of quaint transaction, @@ -35,17 +33,6 @@ impl JsTransaction { } } -impl FromJsValue for JsTransaction { - fn from_js_value(value: wasm_bindgen::prelude::JsValue) -> Result { - let object: JsObjectExtern = value.dyn_into::()?.unchecked_into(); - let common_proxy = CommonProxy::new(&object)?; - let base = JsBaseQueryable::new(common_proxy); - let tx_proxy = TransactionProxy::new(&object)?; - - Ok(Self::new(base, tx_proxy)) - } -} - #[async_trait] impl QuaintTransaction for JsTransaction { async fn commit(&self) -> quaint::Result<()> { @@ -132,6 +119,29 @@ impl Queryable for JsTransaction { } } -// Assume the proxy object will not be sent to service workers, we can unsafe impl Send + Sync. 
-unsafe impl Send for JsTransaction {} -unsafe impl Sync for JsTransaction {} +#[cfg(target_arch = "wasm32")] +impl super::wasm::FromJsValue for JsTransaction { + fn from_js_value(value: wasm_bindgen::prelude::JsValue) -> JsResult { + use wasm_bindgen::JsCast; + + let object = value.dyn_into::()?; + let common_proxy = CommonProxy::new(&object)?; + let base = JsBaseQueryable::new(common_proxy); + let tx_proxy = TransactionProxy::new(&object)?; + + Ok(Self::new(base, tx_proxy)) + } +} + +/// Implementing unsafe `from_napi_value` allows retrieving a threadsafe `JsTransaction` in `DriverProxy` +/// while keeping derived futures `Send`. +#[cfg(not(target_arch = "wasm32"))] +impl ::napi::bindgen_prelude::FromNapiValue for JsTransaction { + unsafe fn from_napi_value(env: napi::sys::napi_env, napi_val: napi::sys::napi_value) -> JsResult { + let object = JsObject::from_napi_value(env, napi_val)?; + let common_proxy = CommonProxy::new(&object)?; + let tx_proxy = TransactionProxy::new(&object)?; + + Ok(Self::new(JsBaseQueryable::new(common_proxy), tx_proxy)) + } +} diff --git a/query-engine/driver-adapters/src/wasm/from_js.rs b/query-engine/driver-adapters/src/wasm/from_js.rs index aaa0d91223f6..9195ea4dabef 100644 --- a/query-engine/driver-adapters/src/wasm/from_js.rs +++ b/query-engine/driver-adapters/src/wasm/from_js.rs @@ -1,7 +1,7 @@ use serde::de::DeserializeOwned; use wasm_bindgen::JsValue; -pub trait FromJsValue: Sized { +pub(crate) trait FromJsValue: Sized { fn from_js_value(value: JsValue) -> Result; } diff --git a/query-engine/driver-adapters/src/wasm/mod.rs b/query-engine/driver-adapters/src/wasm/mod.rs index e60d23ff2445..655ea1a6080d 100644 --- a/query-engine/driver-adapters/src/wasm/mod.rs +++ b/query-engine/driver-adapters/src/wasm/mod.rs @@ -5,8 +5,7 @@ mod error; mod from_js; mod js_object_extern; mod result; -mod transaction; pub(crate) use async_js_function::AsyncJsFunction; +pub(crate) use from_js::FromJsValue; pub use js_object_extern::JsObjectExtern; -pub(crate) use transaction::JsTransaction; From 91bd6964c00931f06298a13b462626d09f793119 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Tue, 28 Nov 2023 12:55:37 +0100 Subject: [PATCH 087/134] chore(driver-adapters): clippy fixes --- query-engine/driver-adapters/src/proxy.rs | 10 ++++------ query-engine/driver-adapters/src/wasm/from_js.rs | 2 +- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/query-engine/driver-adapters/src/proxy.rs b/query-engine/driver-adapters/src/proxy.rs index 7e01176be12f..001aa3587e8d 100644 --- a/query-engine/driver-adapters/src/proxy.rs +++ b/query-engine/driver-adapters/src/proxy.rs @@ -88,9 +88,7 @@ impl DriverProxy { Ok(Box::new(tx)) } - pub fn start_transaction<'a>( - &'a self, - ) -> SendFuture>> + 'a> { + pub fn start_transaction(&self) -> SendFuture>> + '_> { SendFuture(self.start_transaction_inner()) } } @@ -129,7 +127,7 @@ impl TransactionProxy { /// the underlying FFI call will be delivered to JavaScript side in lockstep, so the destructor /// will not attempt rolling the transaction back even if the `commit` future was dropped while /// waiting on the JavaScript call to complete and deliver response. 
-    pub fn commit<'a>(&'a self) -> SendFuture> + 'a> {
+    pub fn commit(&self) -> SendFuture> + '_> {
         self.closed.store(true, Ordering::Relaxed);
         SendFuture(self.commit.call(()))
     }
@@ -149,7 +147,7 @@ impl TransactionProxy {
     /// the underlying FFI call will be delivered to JavaScript side in lockstep, so the destructor
     /// will not attempt rolling back again even if the `rollback` future was dropped while waiting
     /// on the JavaScript call to complete and deliver response.
-    pub fn rollback<'a>(&'a self) -> SendFuture> + 'a> {
+    pub fn rollback(&self) -> SendFuture> + '_> {
         self.closed.store(true, Ordering::Relaxed);
         SendFuture(self.rollback.call(()))
     }
@@ -161,7 +159,7 @@ impl Drop for TransactionProxy {
             return;
         }

-        _ = self.rollback.call_non_blocking(());
+        self.rollback.call_non_blocking(());
     }
 }
diff --git a/query-engine/driver-adapters/src/wasm/from_js.rs b/query-engine/driver-adapters/src/wasm/from_js.rs
index 9195ea4dabef..a49095ddbff1 100644
--- a/query-engine/driver-adapters/src/wasm/from_js.rs
+++ b/query-engine/driver-adapters/src/wasm/from_js.rs
@@ -10,6 +10,6 @@ where
     T: DeserializeOwned,
 {
     fn from_js_value(value: JsValue) -> Result {
-        serde_wasm_bindgen::from_value(value).map_err(|e| JsValue::from(e))
+        serde_wasm_bindgen::from_value(value).map_err(JsValue::from)
     }
 }

From 50f8601b7d69faaffa5ec47044acf69c3e5ba178 Mon Sep 17 00:00:00 2001
From: jkomyno
Date: Tue, 28 Nov 2023 14:06:28 +0100
Subject: [PATCH 088/134] chore(driver-adapters): clippy fixes

---
 query-engine/driver-adapters/src/lib.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/query-engine/driver-adapters/src/lib.rs b/query-engine/driver-adapters/src/lib.rs
index 9873ac0d994a..720a951da60a 100644
--- a/query-engine/driver-adapters/src/lib.rs
+++ b/query-engine/driver-adapters/src/lib.rs
@@ -74,7 +74,7 @@ mod arch {
     where
         T: FromNapiValue,
     {
-        object.get_named_property(name).into()
+        object.get_named_property(name)
     }

     pub(crate) fn to_rust_str(value: JsString) -> JsResult {

From e1867df360a02926a408dec0cb5c976afc9a9c0b Mon Sep 17 00:00:00 2001
From: jkomyno
Date: Tue, 28 Nov 2023 14:11:17 +0100
Subject: [PATCH 089/134] chore: remove dbg! output

---
 Cargo.lock                                                | 1 -
 query-engine/core/Cargo.toml                              | 1 -
 query-engine/core/src/executor/task.rs                    | 6 ------
 query-engine/core/src/interactive_transactions/actors.rs  | 2 --
 4 files changed, 10 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 290a1e4bf28f..893b1b9a345e 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3701,7 +3701,6 @@ dependencies = [
 "user-facing-errors",
 "uuid",
 "wasm-bindgen-futures",
- "wasm-rs-dbg",
 ]

 [[package]]
diff --git a/query-engine/core/Cargo.toml b/query-engine/core/Cargo.toml
index a43006b89160..370ce6b81ec5 100644
--- a/query-engine/core/Cargo.toml
+++ b/query-engine/core/Cargo.toml
@@ -37,7 +37,6 @@ schema = { path = "../schema" }
 elapsed = { path = "../../libs/elapsed" }
 lru = "0.7.7"
 enumflags2 = "0.7"
-wasm-rs-dbg.workspace = true
 pin-project = "1"
 wasm-bindgen-futures = "0.4"

diff --git a/query-engine/core/src/executor/task.rs b/query-engine/core/src/executor/task.rs
index 1fe69c240d2c..f127da55fc00 100644
--- a/query-engine/core/src/executor/task.rs
+++ b/query-engine/core/src/executor/task.rs
@@ -62,7 +62,6 @@ mod arch {
         broadcast::{self},
         oneshot::{self},
     };
-    use wasm_rs_dbg::dbg;

     // Wasm-compatible alternative to `tokio::task::JoinHandle`.
     // `pin_project` enables pin-projection and a `Pin`-compatible implementation of the `Future` trait.
@@ -78,8 +77,6 @@ mod arch { type Output = Result; fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll { - dbg!("JoinHandle::poll"); - // the `self.project()` method is provided by the `pin_project` macro core::pin::Pin::new(&mut self.receiver).poll(cx) } @@ -87,10 +84,7 @@ mod arch { impl JoinHandle { pub fn abort(&mut self) { - dbg!("JoinHandle::abort"); - if let Some(sx_exit) = self.sx_exit.as_ref() { - dbg!("JoinHandle::abort - Send sx_exit"); sx_exit.send(()).ok(); } } diff --git a/query-engine/core/src/interactive_transactions/actors.rs b/query-engine/core/src/interactive_transactions/actors.rs index 58d24c528261..5af01260a921 100644 --- a/query-engine/core/src/interactive_transactions/actors.rs +++ b/query-engine/core/src/interactive_transactions/actors.rs @@ -392,7 +392,6 @@ pub(crate) fn spawn_client_list_clear_actor( loop { tokio::select! { result = rx.recv() => { - dbg!("spawn_controlled - AFTER rx.recv(): {:?}", result.is_some()); match result { Some((id, closed_tx)) => { trace!("removing {} from client list", id); @@ -411,7 +410,6 @@ pub(crate) fn spawn_client_list_clear_actor( } }, _ = rx_exit.recv() => { - dbg!("spawn_controlled - AFTER rx_exit.recv()"); break; }, } From 2dd815942d6e0074fdffd8cb56736904dd5fcb16 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Tue, 28 Nov 2023 14:13:24 +0100 Subject: [PATCH 090/134] chore: remove dbg! output --- query-engine/core/src/interactive_transactions/actors.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/query-engine/core/src/interactive_transactions/actors.rs b/query-engine/core/src/interactive_transactions/actors.rs index 5af01260a921..2ff910d60702 100644 --- a/query-engine/core/src/interactive_transactions/actors.rs +++ b/query-engine/core/src/interactive_transactions/actors.rs @@ -18,7 +18,6 @@ use tokio::{ use tracing::Span; use tracing_futures::Instrument; use tracing_futures::WithSubscriber; -use wasm_rs_dbg::dbg; #[cfg(feature = "metrics")] use crate::telemetry::helpers::set_span_link_from_traceparent; From 3d82368ecd0145a18612334793f5d3b146154d7b Mon Sep 17 00:00:00 2001 From: Sergey Tatarintsev Date: Tue, 28 Nov 2023 14:54:45 +0100 Subject: [PATCH 091/134] Fix itx panic --- Cargo.lock | 21 ++++--- libs/crosstarget-utils/Cargo.toml | 17 ++++++ libs/crosstarget-utils/src/common.rs | 35 +++++++++++ .../{elapsed => crosstarget-utils}/src/lib.rs | 7 ++- libs/crosstarget-utils/src/native/mod.rs | 2 + libs/crosstarget-utils/src/native/spawn.rs | 11 ++++ libs/crosstarget-utils/src/native/time.rs | 35 +++++++++++ libs/crosstarget-utils/src/wasm/mod.rs | 2 + libs/crosstarget-utils/src/wasm/spawn.rs | 10 ++++ libs/crosstarget-utils/src/wasm/time.rs | 59 +++++++++++++++++++ libs/elapsed/Cargo.toml | 11 ---- libs/elapsed/src/native.rs | 17 ------ libs/elapsed/src/wasm.rs | 15 ----- quaint/Cargo.toml | 2 +- quaint/src/connector/metrics.rs | 2 +- query-engine/core/Cargo.toml | 2 +- .../core/src/executor/execute_operation.rs | 6 +- .../src/executor/interpreting_executor.rs | 10 ++-- query-engine/core/src/executor/mod.rs | 6 +- query-engine/core/src/executor/task.rs | 6 +- .../src/interactive_transactions/actors.rs | 6 +- .../core/src/interactive_transactions/mod.rs | 2 +- 22 files changed, 211 insertions(+), 73 deletions(-) create mode 100644 libs/crosstarget-utils/Cargo.toml create mode 100644 libs/crosstarget-utils/src/common.rs rename libs/{elapsed => crosstarget-utils}/src/lib.rs (63%) create mode 100644 libs/crosstarget-utils/src/native/mod.rs create mode 100644 
libs/crosstarget-utils/src/native/spawn.rs create mode 100644 libs/crosstarget-utils/src/native/time.rs create mode 100644 libs/crosstarget-utils/src/wasm/mod.rs create mode 100644 libs/crosstarget-utils/src/wasm/spawn.rs create mode 100644 libs/crosstarget-utils/src/wasm/time.rs delete mode 100644 libs/elapsed/Cargo.toml delete mode 100644 libs/elapsed/src/native.rs delete mode 100644 libs/elapsed/src/wasm.rs diff --git a/Cargo.lock b/Cargo.lock index 893b1b9a345e..f9c731261f94 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -848,6 +848,16 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "crosstarget-utils" +version = "0.1.0" +dependencies = [ + "js-sys", + "tokio", + "wasm-bindgen", + "wasm-bindgen-futures", +] + [[package]] name = "crypto-common" version = "0.1.6" @@ -1130,13 +1140,6 @@ version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" -[[package]] -name = "elapsed" -version = "0.1.0" -dependencies = [ - "js-sys", -] - [[package]] name = "encode_unicode" version = "0.3.6" @@ -3571,8 +3574,8 @@ dependencies = [ "bytes", "chrono", "connection-string", + "crosstarget-utils", "either", - "elapsed", "futures", "getrandom 0.2.11", "hex", @@ -3674,8 +3677,8 @@ dependencies = [ "chrono", "connection-string", "crossbeam-channel", + "crosstarget-utils", "cuid", - "elapsed", "enumflags2", "futures", "indexmap 1.9.3", diff --git a/libs/crosstarget-utils/Cargo.toml b/libs/crosstarget-utils/Cargo.toml new file mode 100644 index 000000000000..6fd110652afe --- /dev/null +++ b/libs/crosstarget-utils/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "crosstarget-utils" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + + +[target.'cfg(target_arch = "wasm32")'.dependencies] +js-sys.workspace = true +wasm-bindgen.workspace = true +wasm-bindgen-futures.workspace = true +tokio = { version = "1.25", features = ["macros"] } + + +[target.'cfg(not(target_arch = "wasm32"))'.dependencies] +tokio.workspace = true diff --git a/libs/crosstarget-utils/src/common.rs b/libs/crosstarget-utils/src/common.rs new file mode 100644 index 000000000000..3afce64c6714 --- /dev/null +++ b/libs/crosstarget-utils/src/common.rs @@ -0,0 +1,35 @@ +use std::fmt::Display; + +#[derive(Debug)] +pub struct SpawnError {} + +impl SpawnError { + pub fn new() -> Self { + SpawnError {} + } +} + +impl Display for SpawnError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Failed to spawn a future") + } +} + +impl std::error::Error for SpawnError {} + +#[derive(Debug)] +pub struct TimeoutError {} + +impl TimeoutError { + pub fn new() -> Self { + TimeoutError {} + } +} + +impl Display for TimeoutError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Operation timed out") + } +} + +impl std::error::Error for TimeoutError {} diff --git a/libs/elapsed/src/lib.rs b/libs/crosstarget-utils/src/lib.rs similarity index 63% rename from libs/elapsed/src/lib.rs rename to libs/crosstarget-utils/src/lib.rs index d339f75f836d..a41d8dd0f9a6 100644 --- a/libs/elapsed/src/lib.rs +++ b/libs/crosstarget-utils/src/lib.rs @@ -1,9 +1,12 @@ +mod common; #[cfg(target_arch = "wasm32")] mod wasm; #[cfg(target_arch = "wasm32")] -pub use crate::wasm::ElapsedTimeCounter; +pub use crate::wasm::*; #[cfg(not(target_arch = "wasm32"))] mod native; #[cfg(not(target_arch = "wasm32"))] -pub use 
crate::native::ElapsedTimeCounter; +pub use crate::native::*; + +pub use common::SpawnError; diff --git a/libs/crosstarget-utils/src/native/mod.rs b/libs/crosstarget-utils/src/native/mod.rs new file mode 100644 index 000000000000..b801d82e3118 --- /dev/null +++ b/libs/crosstarget-utils/src/native/mod.rs @@ -0,0 +1,2 @@ +pub mod spawn; +pub mod time; diff --git a/libs/crosstarget-utils/src/native/spawn.rs b/libs/crosstarget-utils/src/native/spawn.rs new file mode 100644 index 000000000000..b0d541258c2a --- /dev/null +++ b/libs/crosstarget-utils/src/native/spawn.rs @@ -0,0 +1,11 @@ +use std::future::Future; + +use crate::common::SpawnError; + +pub async fn spawn_if_possible(future: F) -> Result +where + F: Future + 'static + Send, + F::Output: Send + 'static, +{ + tokio::spawn(future).await.map_err(|_| SpawnError::new()) +} diff --git a/libs/crosstarget-utils/src/native/time.rs b/libs/crosstarget-utils/src/native/time.rs new file mode 100644 index 000000000000..e222e08cf628 --- /dev/null +++ b/libs/crosstarget-utils/src/native/time.rs @@ -0,0 +1,35 @@ +use std::{ + future::Future, + time::{Duration, Instant}, +}; + +use crate::common::TimeoutError; + +pub struct ElapsedTimeCounter { + instant: Instant, +} + +impl ElapsedTimeCounter { + pub fn start() -> Self { + let instant = Instant::now(); + + Self { instant } + } + + pub fn elapsed_time(&self) -> Duration { + self.instant.elapsed() + } +} + +pub async fn sleep(duration: Duration) -> () { + tokio::time::sleep(duration).await +} + +pub async fn timeout(duration: Duration, future: F) -> Result +where + F: Future + Send, +{ + let result = tokio::time::timeout(duration, future).await; + + result.map_err(|_| TimeoutError::new()) +} diff --git a/libs/crosstarget-utils/src/wasm/mod.rs b/libs/crosstarget-utils/src/wasm/mod.rs new file mode 100644 index 000000000000..b801d82e3118 --- /dev/null +++ b/libs/crosstarget-utils/src/wasm/mod.rs @@ -0,0 +1,2 @@ +pub mod spawn; +pub mod time; diff --git a/libs/crosstarget-utils/src/wasm/spawn.rs b/libs/crosstarget-utils/src/wasm/spawn.rs new file mode 100644 index 000000000000..33ed1d21b3b7 --- /dev/null +++ b/libs/crosstarget-utils/src/wasm/spawn.rs @@ -0,0 +1,10 @@ +use std::future::Future; + +use crate::common::SpawnError; + +pub async fn spawn_if_possible(future: F) -> Result +where + F: Future + 'static, +{ + Ok(future.await) +} diff --git a/libs/crosstarget-utils/src/wasm/time.rs b/libs/crosstarget-utils/src/wasm/time.rs new file mode 100644 index 000000000000..e983aa5678a6 --- /dev/null +++ b/libs/crosstarget-utils/src/wasm/time.rs @@ -0,0 +1,59 @@ +use js_sys::{Date, Function, Promise}; +use std::future::Future; +use std::time::Duration; +use wasm_bindgen::prelude::*; +use wasm_bindgen_futures::JsFuture; + +use crate::common::TimeoutError; + +#[wasm_bindgen] +extern "C" { + + type Performance; + #[wasm_bindgen(js_name = "performance")] + static PERFORMANCE: Option; + + #[wasm_bindgen(method)] + fn now(this: &Performance) -> f64; + + #[wasm_bindgen(js_name = setTimeout)] + fn set_timeout(closure: &Function, millis: u32) -> f64; + +} + +pub struct ElapsedTimeCounter { + start_time: f64, +} + +impl ElapsedTimeCounter { + pub fn start() -> Self { + Self { start_time: now() } + } + + pub fn elapsed_time(&self) -> Duration { + Duration::from_millis((self.start_time - now()) as u64) + } +} + +pub async fn sleep(duration: Duration) -> () { + JsFuture::from(Promise::new(&mut |resolve, _reject| { + set_timeout(&resolve, duration.as_millis() as u32); + })) + .await + // TODO: + .unwrap(); +} + +pub async fn 
timeout(duration: Duration, future: F) -> Result +where + F: Future, +{ + tokio::select! { + result = future => Ok(result), + _ = sleep(duration) => Err(TimeoutError::new()) + } +} + +fn now() -> f64 { + PERFORMANCE.as_ref().map(|p| p.now()).unwrap_or_else(|| Date::now()) +} diff --git a/libs/elapsed/Cargo.toml b/libs/elapsed/Cargo.toml deleted file mode 100644 index 71f103e73ab4..000000000000 --- a/libs/elapsed/Cargo.toml +++ /dev/null @@ -1,11 +0,0 @@ -[package] -name = "elapsed" -version = "0.1.0" -edition = "2021" - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] - -[target.'cfg(target_arch = "wasm32")'.dependencies] -js-sys.workspace = true diff --git a/libs/elapsed/src/native.rs b/libs/elapsed/src/native.rs deleted file mode 100644 index 93855abbe648..000000000000 --- a/libs/elapsed/src/native.rs +++ /dev/null @@ -1,17 +0,0 @@ -use std::time::{Duration, Instant}; - -pub struct ElapsedTimeCounter { - instant: Instant, -} - -impl ElapsedTimeCounter { - pub fn start() -> Self { - let instant = Instant::now(); - - Self { instant } - } - - pub fn elapsed_time(&self) -> Duration { - self.instant.elapsed() - } -} diff --git a/libs/elapsed/src/wasm.rs b/libs/elapsed/src/wasm.rs deleted file mode 100644 index cdd83251e4b1..000000000000 --- a/libs/elapsed/src/wasm.rs +++ /dev/null @@ -1,15 +0,0 @@ -use std::time::Duration; - -/// TODO: this is a stub that always returns 0 as elapsed time -/// In should use performance::now() instead -pub struct ElapsedTimeCounter {} - -impl ElapsedTimeCounter { - pub fn start() -> Self { - Self {} - } - - pub fn elapsed_time(&self) -> Duration { - Duration::from_millis(0u64) - } -} diff --git a/quaint/Cargo.toml b/quaint/Cargo.toml index b884834277d5..254c27446c9b 100644 --- a/quaint/Cargo.toml +++ b/quaint/Cargo.toml @@ -89,7 +89,7 @@ mobc = { version = "0.8", optional = true } serde = { version = "1.0", optional = true } sqlformat = { version = "0.2.0", optional = true } uuid = { version = "1", features = ["v4"] } -elapsed = { path = "../libs/elapsed" } +crosstarget-utils = { path = "../libs/crosstarget-utils" } [dev-dependencies] once_cell = "1.3" diff --git a/quaint/src/connector/metrics.rs b/quaint/src/connector/metrics.rs index 628a2e81f7a3..a0c4ef426988 100644 --- a/quaint/src/connector/metrics.rs +++ b/quaint/src/connector/metrics.rs @@ -1,7 +1,7 @@ use tracing::{info_span, Instrument}; use crate::ast::{Params, Value}; -use elapsed::ElapsedTimeCounter; +use crosstarget_utils::time::ElapsedTimeCounter; use std::future::Future; pub async fn query<'a, F, T, U>(tag: &'static str, query: &'a str, params: &'a [Value<'_>], f: F) -> crate::Result diff --git a/query-engine/core/Cargo.toml b/query-engine/core/Cargo.toml index 370ce6b81ec5..1b7c52e59de9 100644 --- a/query-engine/core/Cargo.toml +++ b/query-engine/core/Cargo.toml @@ -34,7 +34,7 @@ user-facing-errors = { path = "../../libs/user-facing-errors" } uuid = "1" cuid = { git = "https://github.com/prisma/cuid-rust", branch = "wasm32-support" } schema = { path = "../schema" } -elapsed = { path = "../../libs/elapsed" } +crosstarget-utils = { path = "../../libs/crosstarget-utils" } lru = "0.7.7" enumflags2 = "0.7" diff --git a/query-engine/core/src/executor/execute_operation.rs b/query-engine/core/src/executor/execute_operation.rs index c6860eecb29e..63555187fb7b 100644 --- a/query-engine/core/src/executor/execute_operation.rs +++ b/query-engine/core/src/executor/execute_operation.rs @@ -6,7 +6,7 @@ use crate::{ QueryGraphBuilder, 
QueryInterpreter, ResponseData, }; use connector::{Connection, ConnectionLike, Connector}; -use elapsed::ElapsedTimeCounter; +use crosstarget_utils::time::ElapsedTimeCounter; use futures::future; #[cfg(feature = "metrics")] @@ -123,7 +123,7 @@ pub async fn execute_many_self_contained( ); let conn = connector.get_connection().instrument(conn_span).await?; - futures.push(tokio::spawn( + futures.push(crosstarget_utils::spawn::spawn_if_possible( request_context::with_request_context( engine_protocol, execute_self_contained( @@ -227,7 +227,7 @@ async fn execute_self_contained_with_retry( let res = execute_in_tx(conn, graph, serializer, query_schema.as_ref(), trace_id.clone()).await; if is_transient_error(&res) && retry_timeout.elapsed_time() < MAX_TX_TIMEOUT_RETRY_LIMIT { - tokio::time::sleep(TX_RETRY_BACKOFF).await; + crosstarget_utils::time::sleep(TX_RETRY_BACKOFF).await; continue; } else { return res; diff --git a/query-engine/core/src/executor/interpreting_executor.rs b/query-engine/core/src/executor/interpreting_executor.rs index fb2b13938378..0408361b766d 100644 --- a/query-engine/core/src/executor/interpreting_executor.rs +++ b/query-engine/core/src/executor/interpreting_executor.rs @@ -8,7 +8,7 @@ use crate::{ use async_trait::async_trait; use connector::Connector; use schema::QuerySchemaRef; -use tokio::time::{self, Duration}; +use tokio::time::Duration; use tracing_futures::Instrument; /// Central query executor and main entry point into the query core. @@ -36,7 +36,8 @@ where } } -#[async_trait] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] impl QueryExecutor for InterpretingExecutor where C: Connector + Send + Sync + 'static, @@ -140,7 +141,8 @@ where } } -#[async_trait] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] impl TransactionManager for InterpretingExecutor where C: Connector + Send + Sync, @@ -162,7 +164,7 @@ where user_facing = true, "db.type" = self.connector.name() ); - let conn = time::timeout( + let conn = crosstarget_utils::time::timeout( Duration::from_millis(tx_opts.max_acquisition_millis), self.connector.get_connection(), ) diff --git a/query-engine/core/src/executor/mod.rs b/query-engine/core/src/executor/mod.rs index ba2784d3c71a..01a9e09674db 100644 --- a/query-engine/core/src/executor/mod.rs +++ b/query-engine/core/src/executor/mod.rs @@ -25,7 +25,8 @@ use connector::Connector; use serde::{Deserialize, Serialize}; use tracing::Dispatch; -#[async_trait] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] pub trait QueryExecutor: TransactionManager { /// Executes a single operation and returns its result. /// Implementers must honor the passed transaction ID and execute the operation on the transaction identified @@ -95,7 +96,8 @@ impl TransactionOptions { tx_id } } -#[async_trait] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] pub trait TransactionManager { /// Starts a new transaction. /// Returns ID of newly opened transaction. 
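
Note on the `cfg_attr` pattern in the hunks above: as the `send_future.rs` change earlier in this series points out, the underlying JS futures cannot be `Send` on Wasm, so traits that return futures switch to the `?Send` flavour of `async_trait` on `wasm32` while native builds keep the default `Send` bound. A minimal, self-contained sketch of that pattern follows; the `Ping`/`Pong` names are illustrative only and do not appear in this patch.

use async_trait::async_trait;

// Native targets keep the default `Send` bound so the returned futures can be
// spawned on a multi-threaded runtime; on wasm32 the futures may hold JS
// values, which are not `Send`, so the `?Send` flavour is used instead.
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
pub trait Ping {
    async fn ping(&self) -> u32;
}

pub struct Pong;

#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl Ping for Pong {
    async fn ping(&self) -> u32 {
        42
    }
}
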
diff --git a/query-engine/core/src/executor/task.rs b/query-engine/core/src/executor/task.rs index f127da55fc00..3113ecd28f88 100644 --- a/query-engine/core/src/executor/task.rs +++ b/query-engine/core/src/executor/task.rs @@ -92,7 +92,7 @@ mod arch { pub fn spawn(future: T) -> JoinHandle where - T: Future + Send + 'static, + T: Future + 'static, T::Output: Send + 'static, { spawn_with_sx_exit::(future, None) @@ -100,7 +100,7 @@ mod arch { pub fn spawn_controlled(future_fn: Box) -> T>) -> JoinHandle where - T: Future + Send + 'static, + T: Future + 'static, T::Output: Send + 'static, { let (sx_exit, rx_exit) = tokio::sync::broadcast::channel::<()>(1); @@ -110,7 +110,7 @@ mod arch { fn spawn_with_sx_exit(future: T, sx_exit: Option>) -> JoinHandle where - T: Future + Send + 'static, + T: Future + 'static, T::Output: Send + 'static, { let (sender, receiver) = oneshot::channel(); diff --git a/query-engine/core/src/interactive_transactions/actors.rs b/query-engine/core/src/interactive_transactions/actors.rs index 5af01260a921..2c801ba3a5b0 100644 --- a/query-engine/core/src/interactive_transactions/actors.rs +++ b/query-engine/core/src/interactive_transactions/actors.rs @@ -5,7 +5,7 @@ use crate::{ TxId, }; use connector::Connection; -use elapsed::ElapsedTimeCounter; +use crosstarget_utils::time::ElapsedTimeCounter; use schema::QuerySchemaRef; use std::{collections::HashMap, sync::Arc}; use tokio::{ @@ -13,7 +13,7 @@ use tokio::{ mpsc::{channel, Receiver, Sender}, oneshot, RwLock, }, - time::{self, Duration}, + time::Duration, }; use tracing::Span; use tracing_futures::Instrument; @@ -299,7 +299,7 @@ pub(crate) async fn spawn_itx_actor( ); let start_time = ElapsedTimeCounter::start(); - let sleep = time::sleep(timeout); + let sleep = crosstarget_utils::time::sleep(timeout); tokio::pin!(sleep); loop { diff --git a/query-engine/core/src/interactive_transactions/mod.rs b/query-engine/core/src/interactive_transactions/mod.rs index ac92d52efcf2..c3ee76703a06 100644 --- a/query-engine/core/src/interactive_transactions/mod.rs +++ b/query-engine/core/src/interactive_transactions/mod.rs @@ -1,6 +1,6 @@ use crate::CoreError; use connector::Transaction; -use elapsed::ElapsedTimeCounter; +use crosstarget_utils::time::ElapsedTimeCounter; use serde::Deserialize; use std::fmt::Display; use tokio::time::Duration; From e5c9716cde9c668dcd2dac8cfb67b8eae0ce8b0b Mon Sep 17 00:00:00 2001 From: Sergey Tatarintsev Date: Tue, 28 Nov 2023 15:24:35 +0100 Subject: [PATCH 092/134] Remove unused import --- query-engine/core/src/interactive_transactions/actors.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/query-engine/core/src/interactive_transactions/actors.rs b/query-engine/core/src/interactive_transactions/actors.rs index 2c801ba3a5b0..c5568d117d7c 100644 --- a/query-engine/core/src/interactive_transactions/actors.rs +++ b/query-engine/core/src/interactive_transactions/actors.rs @@ -18,7 +18,6 @@ use tokio::{ use tracing::Span; use tracing_futures::Instrument; use tracing_futures::WithSubscriber; -use wasm_rs_dbg::dbg; #[cfg(feature = "metrics")] use crate::telemetry::helpers::set_span_link_from_traceparent; From 08accaed3896677c981c7de59607edb632a4be2f Mon Sep 17 00:00:00 2001 From: Sergey Tatarintsev Date: Tue, 28 Nov 2023 19:05:30 +0100 Subject: [PATCH 093/134] Fix hanging itx --- query-engine/core/src/interactive_transactions/actors.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/query-engine/core/src/interactive_transactions/actors.rs b/query-engine/core/src/interactive_transactions/actors.rs index 
c5568d117d7c..071512cee06d 100644 --- a/query-engine/core/src/interactive_transactions/actors.rs +++ b/query-engine/core/src/interactive_transactions/actors.rs @@ -315,6 +315,8 @@ pub(crate) async fn spawn_itx_actor( if run_state == RunState::Finished { break } + } else { + break; } } } From 269998d4cff60fc2159b5336cbcb985b0b1dff31 Mon Sep 17 00:00:00 2001 From: Sergey Tatarintsev Date: Wed, 29 Nov 2023 12:51:53 +0100 Subject: [PATCH 094/134] fix insta tests Strictly speaking, not related to wasm engine at all - we bumped `insta` at some point and that required adding `allow_duplicates` macro around the loop. Close prisma/team-orm#651 --- Cargo.lock | 1 + .../query-tests-setup/Cargo.toml | 5 +- .../query-tests-setup/src/lib.rs | 74 ++++++++++--------- 3 files changed, 42 insertions(+), 38 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f9c731261f94..8c1f1e22583f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3890,6 +3890,7 @@ dependencies = [ "hyper", "indexmap 1.9.3", "indoc 2.0.3", + "insta", "itertools", "jsonrpc-core", "nom", diff --git a/query-engine/connector-test-kit-rs/query-tests-setup/Cargo.toml b/query-engine/connector-test-kit-rs/query-tests-setup/Cargo.toml index cf1b98b25adb..aa9f5957fb89 100644 --- a/query-engine/connector-test-kit-rs/query-tests-setup/Cargo.toml +++ b/query-engine/connector-test-kit-rs/query-tests-setup/Cargo.toml @@ -12,7 +12,7 @@ request-handlers = { path = "../../request-handlers" } tokio.workspace = true query-core = { path = "../../core", features = ["metrics"] } sql-query-connector = { path = "../../connectors/sql-query-connector" } -query-engine = { path = "../../query-engine"} +query-engine = { path = "../../query-engine" } psl.workspace = true user-facing-errors = { path = "../../../libs/user-facing-errors" } thiserror = "1.0" @@ -30,9 +30,10 @@ indoc.workspace = true enumflags2 = "0.7" hyper = { version = "0.14", features = ["full"] } indexmap = { version = "1.0", features = ["serde-1"] } -query-engine-metrics = {path = "../../metrics"} +query-engine-metrics = { path = "../../metrics" } quaint.workspace = true jsonrpc-core = "17" +insta = "1.7.1" # Only this version is vetted, upgrade only after going through the code, # as this is a small crate with little user base. 
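
The `insta` dependency added just above is what drives the `allow_duplicates!` wrapper in the next diff: as the commit message explains, the bumped `insta` rejects a snapshot assertion that runs more than once within a test, which is exactly what the relation-link test loop does. A minimal illustration of the macro follows; this test is a hypothetical sketch, not code from this repository.

#[test]
fn renders_the_same_snapshot_on_every_iteration() {
    // Newer `insta` treats a repeated assertion at the same call site as an
    // error unless it is wrapped like this; inside the wrapper each repetition
    // must still produce an identical result.
    insta::allow_duplicates! {
        for _ in 0..3 {
            let rendered = format!("value = {}", 21 * 2);
            insta::assert_snapshot!(rendered, @"value = 42");
        }
    }
}
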
diff --git a/query-engine/connector-test-kit-rs/query-tests-setup/src/lib.rs b/query-engine/connector-test-kit-rs/query-tests-setup/src/lib.rs index af99d9a7a7d3..b216e44b9d12 100644 --- a/query-engine/connector-test-kit-rs/query-tests-setup/src/lib.rs +++ b/query-engine/connector-test-kit-rs/query-tests-setup/src/lib.rs @@ -144,45 +144,47 @@ fn run_relation_link_test_impl( let (dms, capabilities) = schema_with_relation(on_parent, on_child, id_only); - for (i, (dm, caps)) in dms.into_iter().zip(capabilities.into_iter()).enumerate() { - if RELATION_TEST_IDX.map(|idx| idx != i).unwrap_or(false) { - continue; - } - - let required_capabilities_for_test = required_capabilities | caps; - let test_db_name = format!("{suite_name}_{test_name}_{i}"); - let template = dm.datamodel().to_owned(); - let (connector, version) = CONFIG.test_connector().unwrap(); - - if !should_run(&connector, &version, only, exclude, required_capabilities_for_test) { - continue; - } - - let datamodel = render_test_datamodel(&test_db_name, template, &[], None, Default::default(), None); - let (connector_tag, version) = CONFIG.test_connector().unwrap(); - let metrics = setup_metrics(); - let metrics_for_subscriber = metrics.clone(); - let (log_capture, log_tx) = TestLogCapture::new(); - - run_with_tokio( - async move { - println!("Used datamodel:\n {}", datamodel.yellow()); - let runner = Runner::load(datamodel.clone(), &[], version, connector_tag, metrics, log_capture) - .await - .unwrap(); + insta::allow_duplicates! { + for (i, (dm, caps)) in dms.into_iter().zip(capabilities.into_iter()).enumerate() { + if RELATION_TEST_IDX.map(|idx| idx != i).unwrap_or(false) { + continue; + } - test_fn(&runner, &dm).await.unwrap(); + let required_capabilities_for_test = required_capabilities | caps; + let test_db_name = format!("{suite_name}_{test_name}_{i}"); + let template = dm.datamodel().to_owned(); + let (connector, version) = CONFIG.test_connector().unwrap(); - teardown_project(&datamodel, Default::default(), runner.schema_id()) - .await - .unwrap(); + if !should_run(&connector, &version, only, exclude, required_capabilities_for_test) { + continue; } - .with_subscriber(test_tracing_subscriber( - ENV_LOG_LEVEL.to_string(), - metrics_for_subscriber, - log_tx, - )), - ); + + let datamodel = render_test_datamodel(&test_db_name, template, &[], None, Default::default(), None); + let (connector_tag, version) = CONFIG.test_connector().unwrap(); + let metrics = setup_metrics(); + let metrics_for_subscriber = metrics.clone(); + let (log_capture, log_tx) = TestLogCapture::new(); + + run_with_tokio( + async move { + println!("Used datamodel:\n {}", datamodel.yellow()); + let runner = Runner::load(datamodel.clone(), &[], version, connector_tag, metrics, log_capture) + .await + .unwrap(); + + test_fn(&runner, &dm).await.unwrap(); + + teardown_project(&datamodel, Default::default(), runner.schema_id()) + .await + .unwrap(); + } + .with_subscriber(test_tracing_subscriber( + ENV_LOG_LEVEL.to_string(), + metrics_for_subscriber, + log_tx, + )), + ); + } } } From a063244e92b5b75579e04b8686f3f5af66effc2b Mon Sep 17 00:00:00 2001 From: jkomyno Date: Wed, 29 Nov 2023 13:20:27 +0100 Subject: [PATCH 095/134] feat(query-engine-wasm): enable tracing and bits of telemetry --- Cargo.lock | 2 + query-engine/core/src/lib.rs | 6 +- .../core/src/telemetry/capturing/mod.rs | 10 +- query-engine/query-engine-wasm/Cargo.toml | 2 + query-engine/query-engine-wasm/src/wasm.rs | 1 + .../query-engine-wasm/src/wasm/engine.rs | 56 ++++++++--- 
.../query-engine-wasm/src/wasm/logger.rs | 33 +++++-- .../query-engine-wasm/src/wasm/tracer.rs | 93 +++++++++++++++++++ 8 files changed, 174 insertions(+), 29 deletions(-) create mode 100644 query-engine/query-engine-wasm/src/wasm/tracer.rs diff --git a/Cargo.lock b/Cargo.lock index 893b1b9a345e..eca2bfc07255 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3827,6 +3827,7 @@ dependencies = [ "futures", "js-sys", "log", + "opentelemetry", "psl", "quaint", "query-connector", @@ -3841,6 +3842,7 @@ dependencies = [ "tokio", "tracing", "tracing-futures", + "tracing-opentelemetry", "tracing-subscriber", "tsify", "url", diff --git a/query-engine/core/src/lib.rs b/query-engine/core/src/lib.rs index 38f39e9fb5d9..219b78753277 100644 --- a/query-engine/core/src/lib.rs +++ b/query-engine/core/src/lib.rs @@ -9,10 +9,9 @@ pub mod protocol; pub mod query_document; pub mod query_graph_builder; pub mod response_ir; - -#[cfg(feature = "metrics")] pub mod telemetry; +pub use self::telemetry::*; pub use self::{ error::{CoreError, FieldConversionError}, executor::{QueryExecutor, TransactionOptions}, @@ -20,9 +19,6 @@ pub use self::{ query_document::*, }; -#[cfg(feature = "metrics")] -pub use self::telemetry::*; - pub use connector::{ error::{ConnectorError, ErrorKind as ConnectorErrorKind}, Connector, diff --git a/query-engine/core/src/telemetry/capturing/mod.rs b/query-engine/core/src/telemetry/capturing/mod.rs index 73a5c318697d..bbdc6ae9a083 100644 --- a/query-engine/core/src/telemetry/capturing/mod.rs +++ b/query-engine/core/src/telemetry/capturing/mod.rs @@ -134,7 +134,8 @@ //! - Finally, the server sets the `logs` and `traces` extensions in the `PrismaResponse`**[12]**, //! it serializes the extended response in json format and returns it as an HTTP Response //! blob **[13]**. -//! +//! 
+#![allow(unused_imports, dead_code)] pub use self::capturer::Capturer; pub use self::settings::Settings; pub use tx_ext::TxTraceExt; @@ -142,7 +143,6 @@ pub use tx_ext::TxTraceExt; use self::capturer::Processor; use once_cell::sync::Lazy; use opentelemetry::{global, sdk, trace}; -use query_engine_metrics::MetricRegistry; use tracing::subscriber; use tracing_subscriber::{ filter::filter_fn, layer::Layered, prelude::__tracing_subscriber_SubscriberExt, Layer, Registry, @@ -158,9 +158,13 @@ pub fn capturer(trace_id: trace::TraceId, settings: Settings) -> Capturer { /// Adds a capturing layer to the given subscriber and installs the transformed subscriber as the /// global, default subscriber +#[cfg(feature = "metrics")] #[allow(clippy::type_complexity)] pub fn install_capturing_layer( - subscriber: Layered, Layered + Send + Sync>, Registry>>, + subscriber: Layered< + Option, + Layered + Send + Sync>, Registry>, + >, log_queries: bool, ) { // set a trace context propagator, so that the trace context is propagated via the diff --git a/query-engine/query-engine-wasm/Cargo.toml b/query-engine/query-engine-wasm/Cargo.toml index 06738c456709..e08d412d5f97 100644 --- a/query-engine/query-engine-wasm/Cargo.toml +++ b/query-engine/query-engine-wasm/Cargo.toml @@ -46,4 +46,6 @@ wasm-logger = "0.2.0" tracing = "0.1" tracing-subscriber = { version = "0.3" } tracing-futures = "0.2" +tracing-opentelemetry = "0.17.3" +opentelemetry = { version = "0.17"} console_error_panic_hook = "0.1.7" diff --git a/query-engine/query-engine-wasm/src/wasm.rs b/query-engine/query-engine-wasm/src/wasm.rs index 5e83cf3aa2b6..8174dc8738c4 100644 --- a/query-engine/query-engine-wasm/src/wasm.rs +++ b/query-engine/query-engine-wasm/src/wasm.rs @@ -2,5 +2,6 @@ pub mod engine; pub mod error; pub mod functions; pub mod logger; +mod tracer; pub(crate) type Executor = Box; diff --git a/query-engine/query-engine-wasm/src/wasm/engine.rs b/query-engine/query-engine-wasm/src/wasm/engine.rs index 1b3b51653c7a..939c93726c60 100644 --- a/query-engine/query-engine-wasm/src/wasm/engine.rs +++ b/query-engine/query-engine-wasm/src/wasm/engine.rs @@ -8,10 +8,11 @@ use crate::{ use driver_adapters::JsObject; use futures::FutureExt; use js_sys::Function as JsFunction; +use psl::PreviewFeature; use query_core::{ protocol::EngineProtocol, schema::{self, QuerySchema}, - QueryExecutor, TransactionOptions, TxId, + telemetry, QueryExecutor, TransactionOptions, TxId, }; use request_handlers::ConnectorMode; use request_handlers::{dmmf, load_executor, render_graphql_schema, RequestBody, RequestHandler}; @@ -25,7 +26,7 @@ use std::{ sync::Arc, }; use tokio::sync::RwLock; -use tracing::{field, Instrument, Span}; +use tracing::{field, instrument::WithSubscriber, Instrument, Span}; use tracing_subscriber::filter::LevelFilter; use tsify::Tsify; use user_facing_errors::Error; @@ -137,10 +138,7 @@ impl QueryEngine { callback: JsFunction, maybe_adapter: Option, ) -> Result { - log::info!("Called `QueryEngine::new()`"); - let log_callback = LogCallback(callback); - log::info!("Parsed `log_callback`"); let ConstructorOptions { datamodel, @@ -166,7 +164,7 @@ impl QueryEngine { sql_connector::activate_driver_adapter(Arc::new(js_queryable)); let provider_name = schema.connector.provider_name(); - log::info!("Received driver adapter for {provider_name}."); + tracing::info!("Received driver adapter for {provider_name}."); } schema @@ -186,6 +184,7 @@ impl QueryEngine { .validate_that_one_datasource_is_provided() .map_err(|errors| ApiError::conversion(errors, 
schema.db.source()))?; + let enable_tracing = config.preview_features().contains(PreviewFeature::Tracing); let engine_protocol = engine_protocol.unwrap_or(EngineProtocol::Json); let builder = EngineBuilder { @@ -196,7 +195,7 @@ impl QueryEngine { }; let log_level = log_level.parse::().unwrap(); - let logger = Logger::new(log_queries, log_level, log_callback); + let logger = Logger::new(log_queries, log_level, log_callback, enable_tracing); let connector_mode = ConnectorMode::Js; @@ -210,8 +209,11 @@ impl QueryEngine { /// Connect to the database, allow queries to be run. #[wasm_bindgen] pub async fn connect(&self, trace: String) -> Result<(), wasm_bindgen::JsError> { + let dispatcher = self.logger.dispatcher(); + async_panic_to_js_error(async { let span = tracing::info_span!("prisma:engine:connect"); + let _ = telemetry::helpers::set_parent_context_from_json_str(&span, &trace); let mut inner = self.inner.write().await; let builder = inner.as_builder()?; @@ -270,6 +272,7 @@ impl QueryEngine { Ok(()) }) + .with_subscriber(dispatcher) .await?; Ok(()) @@ -278,8 +281,11 @@ impl QueryEngine { /// Disconnect and drop the core. Can be reconnected later with `#connect`. #[wasm_bindgen] pub async fn disconnect(&self, trace: String) -> Result<(), wasm_bindgen::JsError> { + let dispatcher = self.logger.dispatcher(); + async_panic_to_js_error(async { let span = tracing::info_span!("prisma:engine:disconnect"); + let _ = telemetry::helpers::set_parent_context_from_json_str(&span, &trace); async { let mut inner = self.inner.write().await; @@ -299,6 +305,7 @@ impl QueryEngine { .instrument(span) .await }) + .with_subscriber(dispatcher) .await } @@ -310,6 +317,8 @@ impl QueryEngine { trace: String, tx_id: Option, ) -> Result { + let dispatcher = self.logger.dispatcher(); + async_panic_to_js_error(async { let inner = self.inner.read().await; let engine = inner.as_engine()?; @@ -323,9 +332,11 @@ impl QueryEngine { Span::none() }; + let trace_id = telemetry::helpers::set_parent_context_from_json_str(&span, &trace); + let handler = RequestHandler::new(engine.executor(), engine.query_schema(), engine.engine_protocol()); let response = handler - .handle(query, tx_id.map(TxId::from), None) + .handle(query, tx_id.map(TxId::from), trace_id) .instrument(span) .await; @@ -333,6 +344,7 @@ impl QueryEngine { } .await }) + .with_subscriber(dispatcher) .await } @@ -343,6 +355,8 @@ impl QueryEngine { let inner = self.inner.read().await; let engine = inner.as_engine()?; + let dispatcher = self.logger.dispatcher(); + async move { let span = tracing::info_span!("prisma:engine:itx_runner", user_facing = true, itx_id = field::Empty); @@ -357,6 +371,7 @@ impl QueryEngine { Err(err) => Ok(map_known_error(err)?), } } + .with_subscriber(dispatcher) .await }) .await @@ -369,12 +384,15 @@ impl QueryEngine { let inner = self.inner.read().await; let engine = inner.as_engine()?; + let dispatcher = self.logger.dispatcher(); + async move { match engine.executor().commit_tx(TxId::from(tx_id)).await { Ok(_) => Ok("{}".to_string()), Err(err) => Ok(map_known_error(err)?), } } + .with_subscriber(dispatcher) .await }) .await @@ -386,14 +404,21 @@ impl QueryEngine { let inner = self.inner.read().await; let engine = inner.as_engine()?; - let dmmf = dmmf::render_dmmf(&engine.query_schema); + let dispatcher = self.logger.dispatcher(); - let json = { - let _span = tracing::info_span!("prisma:engine:dmmf_to_json").entered(); - serde_json::to_string(&dmmf)? 
- }; + tracing::dispatcher::with_default(&dispatcher, || { + let span = tracing::info_span!("prisma:engine:dmmf"); + let _ = telemetry::helpers::set_parent_context_from_json_str(&span, &trace); + let _guard = span.enter(); + let dmmf = dmmf::render_dmmf(&engine.query_schema); - Ok(json) + let json = { + let _span = tracing::info_span!("prisma:engine:dmmf_to_json").entered(); + serde_json::to_string(&dmmf)? + }; + + Ok(json) + }) }) .await } @@ -405,12 +430,15 @@ impl QueryEngine { let inner = self.inner.read().await; let engine = inner.as_engine()?; + let dispatcher = self.logger.dispatcher(); + async move { match engine.executor().rollback_tx(TxId::from(tx_id)).await { Ok(_) => Ok("{}".to_string()), Err(err) => Ok(map_known_error(err)?), } } + .with_subscriber(dispatcher) .await }) .await diff --git a/query-engine/query-engine-wasm/src/wasm/logger.rs b/query-engine/query-engine-wasm/src/wasm/logger.rs index 561c48271b77..a4d03a83e82d 100644 --- a/query-engine/query-engine-wasm/src/wasm/logger.rs +++ b/query-engine/query-engine-wasm/src/wasm/logger.rs @@ -2,6 +2,7 @@ use core::fmt; use js_sys::Function as JsFunction; +use query_core::telemetry; use serde_json::Value; use std::collections::BTreeMap; use tracing::{ @@ -16,7 +17,17 @@ use tracing_subscriber::{ }; use wasm_bindgen::JsValue; -pub(crate) struct LogCallback(pub JsFunction); +#[derive(Clone)] +pub struct LogCallback(pub JsFunction); + +impl LogCallback { + pub fn call>(&self, arg1: T) -> Result<(), String> { + self.0 + .call1(&JsValue::NULL, &arg1.into()) + .map(|_| ()) + .map_err(|err| err.as_string().unwrap_or_default()) + } +} unsafe impl Send for LogCallback {} unsafe impl Sync for LogCallback {} @@ -27,7 +38,7 @@ pub(crate) struct Logger { impl Logger { /// Creates a new logger using a call layer - pub fn new(log_queries: bool, log_level: LevelFilter, log_callback: LogCallback) -> Self { + pub fn new(log_queries: bool, log_level: LevelFilter, log_callback: LogCallback, enable_tracing: bool) -> Self { let is_sql_query = filter_fn(|meta| { meta.target() == "quaint::connector::metrics" && meta.fields().iter().any(|f| f.name() == "query") }); @@ -44,10 +55,21 @@ impl Logger { FilterExt::boxed(log_level) }; + let is_user_trace = filter_fn(telemetry::helpers::user_facing_span_only_filter); + let tracer = super::tracer::new_pipeline().install_simple(log_callback.clone()); + let telemetry = if enable_tracing { + let telemetry = tracing_opentelemetry::layer() + .with_tracer(tracer) + .with_filter(is_user_trace); + Some(telemetry) + } else { + None + }; + let layer = CallbackLayer::new(log_callback).with_filter(filters); Self { - dispatcher: Dispatch::new(Registry::default().with(layer)), + dispatcher: Dispatch::new(Registry::default().with(telemetry).with(layer)), } } @@ -124,9 +146,6 @@ impl Layer for CallbackLayer { let mut visitor = JsonVisitor::new(event.metadata().level(), event.metadata().target()); event.record(&mut visitor); - let _ = self - .callback - .0 - .call1(&JsValue::NULL, &JsValue::from_str(&visitor.to_string())); + let _ = self.callback.call(&visitor.to_string()); } } diff --git a/query-engine/query-engine-wasm/src/wasm/tracer.rs b/query-engine/query-engine-wasm/src/wasm/tracer.rs new file mode 100644 index 000000000000..7bcd1ab81043 --- /dev/null +++ b/query-engine/query-engine-wasm/src/wasm/tracer.rs @@ -0,0 +1,93 @@ +use async_trait::async_trait; +use opentelemetry::{ + global, sdk, + sdk::{ + export::trace::{ExportResult, SpanData, SpanExporter}, + propagation::TraceContextPropagator, + }, + trace::{TraceError, 
TracerProvider},
+};
+use query_core::telemetry;
+use std::fmt::{self, Debug};
+
+use crate::logger::LogCallback;
+
+/// Pipeline builder
+#[derive(Debug)]
+pub struct PipelineBuilder {
+    trace_config: Option<sdk::trace::Config>,
+}
+
+/// Create a new pipeline builder for an exporter that forwards spans to the JS callback.
+pub fn new_pipeline() -> PipelineBuilder {
+    PipelineBuilder::default()
+}
+
+impl Default for PipelineBuilder {
+    /// Return the default pipeline builder.
+    fn default() -> Self {
+        Self { trace_config: None }
+    }
+}
+
+impl PipelineBuilder {
+    /// Assign the SDK trace configuration.
+    #[allow(dead_code)]
+    pub fn with_trace_config(mut self, config: sdk::trace::Config) -> Self {
+        self.trace_config = Some(config);
+        self
+    }
+}
+
+impl PipelineBuilder {
+    pub fn install_simple(mut self, log_callback: LogCallback) -> sdk::trace::Tracer {
+        global::set_text_map_propagator(TraceContextPropagator::new());
+        let exporter = ClientSpanExporter::new(log_callback);
+
+        let mut provider_builder = sdk::trace::TracerProvider::builder().with_simple_exporter(exporter);
+        // This doesn't work at the moment because we create the logger outside of an async runtime;
+        // we could later move the creation of the logger into the `connect` function.
+        // let mut provider_builder = sdk::trace::TracerProvider::builder().with_batch_exporter(exporter, runtime::Tokio);
+        // Remember to add features = ["rt-tokio"] to Cargo.toml for that.
+        if let Some(config) = self.trace_config.take() {
+            provider_builder = provider_builder.with_config(config);
+        }
+        let provider = provider_builder.build();
+        let tracer = provider.tracer("opentelemetry");
+        global::set_tracer_provider(provider);
+
+        tracer
+    }
+}
+
+/// A [`ClientSpanExporter`] that sends spans to the JS callback.
+pub struct ClientSpanExporter {
+    callback: LogCallback,
+}
+
+impl ClientSpanExporter {
+    pub fn new(callback: LogCallback) -> Self {
+        Self { callback }
+    }
+}
+
+impl Debug for ClientSpanExporter {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("ClientSpanExporter").finish()
+    }
+}
+
+#[async_trait]
+impl SpanExporter for ClientSpanExporter {
+    /// Export spans by serializing them to JSON and passing them to the JS callback.
+    async fn export(&mut self, batch: Vec<SpanData>) -> ExportResult {
+        let result = telemetry::helpers::spans_to_json(batch);
+        let status = self.callback.call(result);
+
+        if let Err(err) = status {
+            return Err(TraceError::from(format!("Could not call JS callback: {}", err)));
+        }
+
+        Ok(())
+    }
+}

From bec8bc3b7c773fbfa8505275efee100cb60ee04d Mon Sep 17 00:00:00 2001
From: jkomyno
Date: Wed, 29 Nov 2023 13:23:19 +0100
Subject: [PATCH 096/134] chore: clippy

---
 libs/crosstarget-utils/src/common.rs | 4 ++--
 libs/crosstarget-utils/src/native/time.rs | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/libs/crosstarget-utils/src/common.rs b/libs/crosstarget-utils/src/common.rs
index 3afce64c6714..be5e3742ee50 100644
--- a/libs/crosstarget-utils/src/common.rs
+++ b/libs/crosstarget-utils/src/common.rs
@@ -1,6 +1,6 @@
 use std::fmt::Display;
 
-#[derive(Debug)]
+#[derive(Debug, Default)]
 pub struct SpawnError {}
 
 impl SpawnError {
@@ -17,7 +17,7 @@ impl Display for SpawnError {
 }
 
 impl std::error::Error for SpawnError {}
 
-#[derive(Debug)]
+#[derive(Debug, Default)]
 pub struct TimeoutError {}
 
 impl TimeoutError {
diff --git a/libs/crosstarget-utils/src/native/time.rs b/libs/crosstarget-utils/src/native/time.rs
index e222e08cf628..273ef6a4f364 100644
--- a/libs/crosstarget-utils/src/native/time.rs
+++ b/libs/crosstarget-utils/src/native/time.rs
@@ -21,7 +21,7 @@ impl ElapsedTimeCounter {
     }
 }
 
-pub async fn sleep(duration: Duration) -> () {
+pub async fn sleep(duration: Duration) {
     tokio::time::sleep(duration).await
 }

From 715c87fac5f8e46185d1fb56177cf636fcf48eee Mon Sep 17 00:00:00 2001
From: Miguel Fernández
Date: Wed, 29 Nov 2023 14:21:17 +0100
Subject: [PATCH 097/134] Fix panic that kills running engine in query-engine-tests (#4499)

* I didn't have wasm-pack installed, let's fix that

* Update wasm-bindgen-futures to 0.4.39

  This includes https://github.com/rustwasm/wasm-bindgen/issues/3203, which uses
  queueMicrotask to translate spawn_local Rust code. That fixed
  https://github.com/rustwasm/wasm-bindgen/issues/2392, an issue about not being
  able to catch async wasm traps. This might (or might not) have an effect on the
  issue we are trying to solve here.

* Revert "Update wasm-bindgen-futures to 0.4.39"

  This reverts commit 9a494dc1f5f2250ddf2c3f71b52ec2221e6c93cb.

* Restart executor when it dies
* Document Restartable
* Remove async_panic_to_js_error in WASM query engine
* Rename p -> process
* Use tokio::sync::RwLock rather than futures::lock::Mutex
* Better error messaging
* Fixing clippy
* Exclude unit tests for wasm32 when compiling the binary for other architectures
---
 libs/crosstarget-utils/src/common.rs | 16 +-
 libs/crosstarget-utils/src/native/spawn.rs | 2 +-
 libs/crosstarget-utils/src/native/time.rs | 4 +-
 libs/crosstarget-utils/src/wasm/time.rs | 2 +-
 .../query-tests-setup/src/connector_tag/js.rs | 2 +-
 .../src/connector_tag/js/external_process.rs | 89 +++++-
 .../query-tests-setup/src/lib.rs | 4 +-
 query-engine/driver-adapters/tests/wasm.rs | 2 +-
 query-engine/query-engine-wasm/Cargo.toml | 2 +-
 query-engine/query-engine-wasm/build.sh | 7 +
 .../query-engine-wasm/src/wasm/engine.rs | 293 ++++++++----------
 11 files changed, 218 insertions(+), 205 deletions(-)

diff --git a/libs/crosstarget-utils/src/common.rs b/libs/crosstarget-utils/src/common.rs
index 3afce64c6714..92a1d5094e89 100644
--- a/libs/crosstarget-utils/src/common.rs
+++ b/libs/crosstarget-utils/src/common.rs
@@ -1,13 +1,7 @@
 use std::fmt::Display;
 
 #[derive(Debug)]
-pub struct SpawnError {}
-
-impl SpawnError {
-    pub fn new() -> Self {
-        SpawnError {}
-    }
-}
+pub struct SpawnError;
 
 impl Display for SpawnError {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
@@ -18,13 +12,7 @@ impl std::error::Error for SpawnError {}
 
 #[derive(Debug)]
-pub struct TimeoutError {}
-
-impl TimeoutError {
-    pub fn new() -> Self {
-        TimeoutError {}
-    }
-}
+pub struct TimeoutError;
 
 impl Display for TimeoutError {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
diff --git a/libs/crosstarget-utils/src/native/spawn.rs b/libs/crosstarget-utils/src/native/spawn.rs
index b0d541258c2a..8a8360c580fa 100644
--- a/libs/crosstarget-utils/src/native/spawn.rs
+++ b/libs/crosstarget-utils/src/native/spawn.rs
@@ -7,5 +7,5 @@ where
     F: Future + 'static + Send,
     F::Output: Send + 'static,
 {
-    tokio::spawn(future).await.map_err(|_| SpawnError::new())
+    tokio::spawn(future).await.map_err(|_| SpawnError)
 }
diff --git a/libs/crosstarget-utils/src/native/time.rs b/libs/crosstarget-utils/src/native/time.rs
index e222e08cf628..3b154a27565c 100644
--- a/libs/crosstarget-utils/src/native/time.rs
+++ b/libs/crosstarget-utils/src/native/time.rs
@@ -21,7 +21,7 @@ impl ElapsedTimeCounter {
     }
 }
 
-pub async fn sleep(duration: Duration) -> () {
+pub async fn sleep(duration: Duration) {
     tokio::time::sleep(duration).await
 }
 
@@ -31,5 +31,5 @@ where
 {
     let result = tokio::time::timeout(duration, future).await;
 
-    result.map_err(|_| TimeoutError::new())
+    result.map_err(|_| TimeoutError)
 }
diff --git a/libs/crosstarget-utils/src/wasm/time.rs b/libs/crosstarget-utils/src/wasm/time.rs
index e983aa5678a6..6f14ac001ee8 100644
--- a/libs/crosstarget-utils/src/wasm/time.rs
+++ b/libs/crosstarget-utils/src/wasm/time.rs
@@ -50,7 +50,7 @@ where
 {
     tokio::select! {
         result = future => Ok(result),
-        _ = sleep(duration) => Err(TimeoutError::new())
+        _ = sleep(duration) => Err(TimeoutError)
     }
 }
diff --git a/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/js.rs b/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/js.rs
index 2ec8513baeda..c852924bbf69 100644
--- a/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/js.rs
+++ b/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/js.rs
@@ -3,7 +3,7 @@ mod external_process;
 use super::*;
 use external_process::*;
 use serde::de::DeserializeOwned;
-use std::{collections::HashMap, sync::atomic::AtomicU64};
+use std::sync::atomic::AtomicU64;
 use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
 
 pub(crate) async fn executor_process_request(
diff --git a/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/js/external_process.rs b/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/js/external_process.rs
index 1abfedbaf8ee..912a5e6d8abf 100644
--- a/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/js/external_process.rs
+++ b/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/js/external_process.rs
@@ -1,8 +1,13 @@
 use super::*;
 use once_cell::sync::Lazy;
 use serde::de::DeserializeOwned;
-use std::{fmt::Display, io::Write as _, sync::atomic::Ordering};
-use tokio::sync::{mpsc, oneshot};
+use std::{
+    error::Error as StdError,
+    fmt::Display,
+    io::Write as _,
+    sync::{atomic::Ordering, Arc},
+};
+use tokio::sync::{mpsc, oneshot, RwLock};
 
 type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;
 
@@ -29,6 +34,17 @@ fn exit_with_message(status_code: i32, message: &str) -> ! {
 }
 
 impl ExecutorProcess {
+    fn spawn() -> ExecutorProcess {
+        match std::thread::spawn(ExecutorProcess::new).join() {
+            Ok(Ok(process)) => process,
+            Ok(Err(err)) => exit_with_message(1, &format!("Failed to start node process. Details: {err}")),
+            Err(err) => {
+                let err = err.downcast_ref::<&str>().map(ToOwned::to_owned).unwrap_or_default();
+                exit_with_message(1, &format!("Panic while trying to start node process.\nDetails: {err}"))
+            }
+        }
+    }
+
     fn new() -> Result<ExecutorProcess> {
         let (sender, receiver) = mpsc::channel::<ReqImpl>(300);
 
@@ -81,15 +97,50 @@ impl ExecutorProcess {
     }
 }
 
-pub(super) static EXTERNAL_PROCESS: Lazy<ExecutorProcess> =
-    Lazy::new(|| match std::thread::spawn(ExecutorProcess::new).join() {
-        Ok(Ok(process)) => process,
-        Ok(Err(err)) => exit_with_message(1, &format!("Failed to start node process. Details: {err}")),
-        Err(err) => {
-            let err = err.downcast_ref::<&str>().map(ToOwned::to_owned).unwrap_or_default();
-            exit_with_message(1, &format!("Panic while trying to start node process.\nDetails: {err}"))
+/// Wraps an `ExecutorProcess`, allowing it to be restarted.
+///
+/// A node process can die for a number of reasons; one of them is that any `panic!` occurring in Rust
+/// asynchronous code is translated into an abort trap by wasm-bindgen, which kills the node process.
+#[derive(Clone)] +pub(crate) struct RestartableExecutorProcess { + process: Arc>, +} + +impl RestartableExecutorProcess { + fn new() -> Self { + Self { + process: Arc::new(RwLock::new(ExecutorProcess::spawn())), } - }); + } + + async fn restart(&self) { + let mut process = self.process.write().await; + *process = ExecutorProcess::spawn(); + } + + pub(crate) async fn request(&self, method: &str, params: serde_json::Value) -> Result { + let p = self.process.read().await; + p.request(method, params).await + } +} + +struct ExecutorProcessDiedError; + +impl fmt::Debug for ExecutorProcessDiedError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "The external test executor process died") + } +} + +impl Display for ExecutorProcessDiedError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(self, f) + } +} + +impl StdError for ExecutorProcessDiedError {} + +pub(super) static EXTERNAL_PROCESS: Lazy = Lazy::new(RestartableExecutorProcess::new); type ReqImpl = ( jsonrpc_core::MethodCall, @@ -122,8 +173,7 @@ fn start_rpc_thread(mut receiver: mpsc::Receiver) -> Result<()> { let mut stdout = BufReader::new(process.stdout.unwrap()).lines(); let mut stdin = process.stdin.unwrap(); - let mut pending_requests: HashMap>> = - HashMap::new(); + let mut last_pending_request: Option<(jsonrpc_core::Id, oneshot::Sender>)> = None; loop { tokio::select! { @@ -137,7 +187,11 @@ fn start_rpc_thread(mut receiver: mpsc::Receiver) -> Result<()> { { match serde_json::from_str::(&line) { Ok(response) => { - let sender = pending_requests.remove(response.id()).unwrap(); + let (id, sender) = last_pending_request.take().expect("got a response from the external process, but there was no pending request"); + if &id != response.id() { + unreachable!("got a response from the external process, but the id didn't match. Are you running with cargo tests with `--test-threads=1`"); + } + match response { jsonrpc_core::Output::Success(success) => { // The other end may be dropped if the whole @@ -159,7 +213,12 @@ fn start_rpc_thread(mut receiver: mpsc::Receiver) -> Result<()> { } Ok(None) => // end of the stream { - exit_with_message(1, "child node process stdout closed") + tracing::error!("Error when reading from child node process. Process might have exited. 
Restarting..."); + if let Some((_, sender)) = last_pending_request.take() { + sender.send(Err(Box::new(ExecutorProcessDiedError))).unwrap(); + } + EXTERNAL_PROCESS.restart().await; + break; } Err(err) => // log it { @@ -174,7 +233,7 @@ fn start_rpc_thread(mut receiver: mpsc::Receiver) -> Result<()> { exit_with_message(1, "The json-rpc client channel was closed"); } Some((request, response_sender)) => { - pending_requests.insert(request.id.clone(), response_sender); + last_pending_request = Some((request.id.clone(), response_sender)); let mut req = serde_json::to_vec(&request).unwrap(); req.push(b'\n'); stdin.write_all(&req).await.unwrap(); diff --git a/query-engine/connector-test-kit-rs/query-tests-setup/src/lib.rs b/query-engine/connector-test-kit-rs/query-tests-setup/src/lib.rs index b216e44b9d12..d7dbd0f53897 100644 --- a/query-engine/connector-test-kit-rs/query-tests-setup/src/lib.rs +++ b/query-engine/connector-test-kit-rs/query-tests-setup/src/lib.rs @@ -289,7 +289,9 @@ fn run_connector_test_impl( .unwrap(); let schema_id = runner.schema_id(); - test_fn(runner).await.unwrap(); + if let Err(err) = test_fn(runner).await { + panic!("💥 Test failed due to an error: {err:?}"); + } crate::teardown_project(&datamodel, db_schemas, schema_id) .await diff --git a/query-engine/driver-adapters/tests/wasm.rs b/query-engine/driver-adapters/tests/wasm.rs index c40529978c9f..8f3aa30f7335 100644 --- a/query-engine/driver-adapters/tests/wasm.rs +++ b/query-engine/driver-adapters/tests/wasm.rs @@ -1,6 +1,6 @@ +#![cfg(target_os = "wasm32")] use wasm_bindgen_test::*; -// use driver_adapters::types::ColumnType; use serde::{Deserialize, Serialize}; use serde_repr::{Deserialize_repr, Serialize_repr}; use tsify::Tsify; diff --git a/query-engine/query-engine-wasm/Cargo.toml b/query-engine/query-engine-wasm/Cargo.toml index 06738c456709..175f47a01ae5 100644 --- a/query-engine/query-engine-wasm/Cargo.toml +++ b/query-engine/query-engine-wasm/Cargo.toml @@ -6,7 +6,7 @@ edition = "2021" [lib] doc = false crate-type = ["cdylib"] -name = "query_engine_wasm" +name = "query_engine" [dependencies] diff --git a/query-engine/query-engine-wasm/build.sh b/query-engine/query-engine-wasm/build.sh index 13a7c13e89ec..10e3008912f0 100755 --- a/query-engine/query-engine-wasm/build.sh +++ b/query-engine/query-engine-wasm/build.sh @@ -21,6 +21,13 @@ else BUILD_PROFILE="--dev" fi +# Check if wasm-pack is installed +if ! command -v wasm-pack &> /dev/null +then + echo "wasm-pack could not be found, installing now..." 
+ # Install wasm-pack + curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh +fi wasm-pack build $BUILD_PROFILE --target $OUT_TARGET sed -i '' 's/name = "query_engine"/name = "query_engine_wasm"/g' Cargo.toml diff --git a/query-engine/query-engine-wasm/src/wasm/engine.rs b/query-engine/query-engine-wasm/src/wasm/engine.rs index 1b3b51653c7a..9a326e6547a9 100644 --- a/query-engine/query-engine-wasm/src/wasm/engine.rs +++ b/query-engine/query-engine-wasm/src/wasm/engine.rs @@ -6,7 +6,6 @@ use crate::{ logger::{LogCallback, Logger}, }; use driver_adapters::JsObject; -use futures::FutureExt; use js_sys::Function as JsFunction; use query_core::{ protocol::EngineProtocol, @@ -19,8 +18,6 @@ use serde::{Deserialize, Serialize}; use serde_json::json; use std::{ collections::{BTreeMap, HashMap}, - future::Future, - panic::AssertUnwindSafe, path::PathBuf, sync::Arc, }; @@ -28,7 +25,6 @@ use tokio::sync::RwLock; use tracing::{field, Instrument, Span}; use tracing_subscriber::filter::LevelFilter; use tsify::Tsify; -use user_facing_errors::Error; use wasm_bindgen::prelude::wasm_bindgen; /// The main query engine used by JS #[wasm_bindgen] @@ -210,95 +206,87 @@ impl QueryEngine { /// Connect to the database, allow queries to be run. #[wasm_bindgen] pub async fn connect(&self, trace: String) -> Result<(), wasm_bindgen::JsError> { - async_panic_to_js_error(async { - let span = tracing::info_span!("prisma:engine:connect"); - - let mut inner = self.inner.write().await; - let builder = inner.as_builder()?; - let arced_schema = Arc::clone(&builder.schema); - let arced_schema_2 = Arc::clone(&builder.schema); - - let url = { - let data_source = builder - .schema - .configuration - .datasources - .first() - .ok_or_else(|| ApiError::configuration("No valid data source found"))?; - data_source - .load_url_with_config_dir(&builder.config_dir, |key| builder.env.get(key).map(ToString::to_string)) - .map_err(|err| crate::error::ApiError::Conversion(err, builder.schema.db.source().to_owned()))? - }; - - let engine = async move { - // We only support one data source & generator at the moment, so take the first one (default not exposed yet). 
- let data_source = arced_schema - .configuration - .datasources - .first() - .ok_or_else(|| ApiError::configuration("No valid data source found"))?; - - let preview_features = arced_schema.configuration.preview_features(); - - let executor = load_executor(self.connector_mode, data_source, preview_features, &url).await?; - let connector = executor.primary_connector(); - - let conn_span = tracing::info_span!( - "prisma:engine:connection", - user_facing = true, - "db.type" = connector.name(), - ); - - connector.get_connection().instrument(conn_span).await?; - - let query_schema_span = tracing::info_span!("prisma:engine:schema"); - let query_schema = query_schema_span.in_scope(|| schema::build(arced_schema_2, true)); - - Ok(ConnectedEngine { - schema: builder.schema.clone(), - query_schema: Arc::new(query_schema), - executor, - config_dir: builder.config_dir.clone(), - env: builder.env.clone(), - engine_protocol: builder.engine_protocol, - }) as crate::Result - } - .instrument(span) - .await?; - - *inner = Inner::Connected(engine); + let span = tracing::info_span!("prisma:engine:connect"); + + let mut inner = self.inner.write().await; + let builder = inner.as_builder()?; + let arced_schema = Arc::clone(&builder.schema); + let arced_schema_2 = Arc::clone(&builder.schema); + + let url = { + let data_source = builder + .schema + .configuration + .datasources + .first() + .ok_or_else(|| ApiError::configuration("No valid data source found"))?; + data_source + .load_url_with_config_dir(&builder.config_dir, |key| builder.env.get(key).map(ToString::to_string)) + .map_err(|err| crate::error::ApiError::Conversion(err, builder.schema.db.source().to_owned()))? + }; - Ok(()) - }) + let engine = async move { + // We only support one data source & generator at the moment, so take the first one (default not exposed yet). + let data_source = arced_schema + .configuration + .datasources + .first() + .ok_or_else(|| ApiError::configuration("No valid data source found"))?; + + let preview_features = arced_schema.configuration.preview_features(); + + let executor = load_executor(self.connector_mode, data_source, preview_features, &url).await?; + let connector = executor.primary_connector(); + + let conn_span = tracing::info_span!( + "prisma:engine:connection", + user_facing = true, + "db.type" = connector.name(), + ); + + connector.get_connection().instrument(conn_span).await?; + + let query_schema_span = tracing::info_span!("prisma:engine:schema"); + let query_schema = query_schema_span.in_scope(|| schema::build(arced_schema_2, true)); + + Ok(ConnectedEngine { + schema: builder.schema.clone(), + query_schema: Arc::new(query_schema), + executor, + config_dir: builder.config_dir.clone(), + env: builder.env.clone(), + engine_protocol: builder.engine_protocol, + }) as crate::Result + } + .instrument(span) .await?; + *inner = Inner::Connected(engine); + Ok(()) } /// Disconnect and drop the core. Can be reconnected later with `#connect`. 
#[wasm_bindgen] pub async fn disconnect(&self, trace: String) -> Result<(), wasm_bindgen::JsError> { - async_panic_to_js_error(async { - let span = tracing::info_span!("prisma:engine:disconnect"); + let span = tracing::info_span!("prisma:engine:disconnect"); - async { - let mut inner = self.inner.write().await; - let engine = inner.as_engine()?; + async { + let mut inner = self.inner.write().await; + let engine = inner.as_engine()?; - let builder = EngineBuilder { - schema: engine.schema.clone(), - config_dir: engine.config_dir.clone(), - env: engine.env.clone(), - engine_protocol: engine.engine_protocol(), - }; + let builder = EngineBuilder { + schema: engine.schema.clone(), + config_dir: engine.config_dir.clone(), + env: engine.env.clone(), + engine_protocol: engine.engine_protocol(), + }; - *inner = Inner::Builder(builder); + *inner = Inner::Builder(builder); - Ok(()) - } - .instrument(span) - .await - }) + Ok(()) + } + .instrument(span) .await } @@ -310,122 +298,104 @@ impl QueryEngine { trace: String, tx_id: Option, ) -> Result { - async_panic_to_js_error(async { - let inner = self.inner.read().await; - let engine = inner.as_engine()?; + let inner = self.inner.read().await; + let engine = inner.as_engine()?; - let query = RequestBody::try_from_str(&body, engine.engine_protocol())?; + let query = RequestBody::try_from_str(&body, engine.engine_protocol())?; - async move { - let span = if tx_id.is_none() { - tracing::info_span!("prisma:engine", user_facing = true) - } else { - Span::none() - }; + async move { + let span = if tx_id.is_none() { + tracing::info_span!("prisma:engine", user_facing = true) + } else { + Span::none() + }; - let handler = RequestHandler::new(engine.executor(), engine.query_schema(), engine.engine_protocol()); - let response = handler - .handle(query, tx_id.map(TxId::from), None) - .instrument(span) - .await; + let handler = RequestHandler::new(engine.executor(), engine.query_schema(), engine.engine_protocol()); + let response = handler + .handle(query, tx_id.map(TxId::from), None) + .instrument(span) + .await; - Ok(serde_json::to_string(&response)?) - } - .await - }) + Ok(serde_json::to_string(&response)?) + } .await } /// If connected, attempts to start a transaction in the core and returns its ID. 
#[wasm_bindgen(js_name = startTransaction)] pub async fn start_transaction(&self, input: String, trace: String) -> Result { - async_panic_to_js_error(async { - let inner = self.inner.read().await; - let engine = inner.as_engine()?; - - async move { - let span = tracing::info_span!("prisma:engine:itx_runner", user_facing = true, itx_id = field::Empty); - - let tx_opts: TransactionOptions = serde_json::from_str(&input)?; - match engine - .executor() - .start_tx(engine.query_schema().clone(), engine.engine_protocol(), tx_opts) - .instrument(span) - .await - { - Ok(tx_id) => Ok(json!({ "id": tx_id.to_string() }).to_string()), - Err(err) => Ok(map_known_error(err)?), - } + let inner = self.inner.read().await; + let engine = inner.as_engine()?; + + async move { + let span = tracing::info_span!("prisma:engine:itx_runner", user_facing = true, itx_id = field::Empty); + + let tx_opts: TransactionOptions = serde_json::from_str(&input)?; + match engine + .executor() + .start_tx(engine.query_schema().clone(), engine.engine_protocol(), tx_opts) + .instrument(span) + .await + { + Ok(tx_id) => Ok(json!({ "id": tx_id.to_string() }).to_string()), + Err(err) => Ok(map_known_error(err)?), } - .await - }) + } .await } /// If connected, attempts to commit a transaction with id `tx_id` in the core. #[wasm_bindgen(js_name = commitTransaction)] pub async fn commit_transaction(&self, tx_id: String, trace: String) -> Result { - async_panic_to_js_error(async { - let inner = self.inner.read().await; - let engine = inner.as_engine()?; + let inner = self.inner.read().await; + let engine = inner.as_engine()?; - async move { - match engine.executor().commit_tx(TxId::from(tx_id)).await { - Ok(_) => Ok("{}".to_string()), - Err(err) => Ok(map_known_error(err)?), - } + async move { + match engine.executor().commit_tx(TxId::from(tx_id)).await { + Ok(_) => Ok("{}".to_string()), + Err(err) => Ok(map_known_error(err)?), } - .await - }) + } .await } #[wasm_bindgen] pub async fn dmmf(&self, trace: String) -> Result { - async_panic_to_js_error(async { - let inner = self.inner.read().await; - let engine = inner.as_engine()?; + let inner = self.inner.read().await; + let engine = inner.as_engine()?; - let dmmf = dmmf::render_dmmf(&engine.query_schema); + let dmmf = dmmf::render_dmmf(&engine.query_schema); - let json = { - let _span = tracing::info_span!("prisma:engine:dmmf_to_json").entered(); - serde_json::to_string(&dmmf)? - }; + let json = { + let _span = tracing::info_span!("prisma:engine:dmmf_to_json").entered(); + serde_json::to_string(&dmmf)? + }; - Ok(json) - }) - .await + Ok(json) } /// If connected, attempts to roll back a transaction with id `tx_id` in the core. #[wasm_bindgen(js_name = rollbackTransaction)] pub async fn rollback_transaction(&self, tx_id: String, trace: String) -> Result { - async_panic_to_js_error(async { - let inner = self.inner.read().await; - let engine = inner.as_engine()?; + let inner = self.inner.read().await; + let engine = inner.as_engine()?; - async move { - match engine.executor().rollback_tx(TxId::from(tx_id)).await { - Ok(_) => Ok("{}".to_string()), - Err(err) => Ok(map_known_error(err)?), - } + async move { + match engine.executor().rollback_tx(TxId::from(tx_id)).await { + Ok(_) => Ok("{}".to_string()), + Err(err) => Ok(map_known_error(err)?), } - .await - }) + } .await } /// Loads the query schema. Only available when connected. 
#[wasm_bindgen(js_name = sdlSchema)] pub async fn sdl_schema(&self) -> Result { - async_panic_to_js_error(async move { - let inner = self.inner.read().await; - let engine = inner.as_engine()?; + let inner = self.inner.read().await; + let engine = inner.as_engine()?; - Ok(render_graphql_schema(engine.query_schema())) - }) - .await + Ok(render_graphql_schema(engine.query_schema())) } #[wasm_bindgen] @@ -472,16 +442,3 @@ fn stringify_env_values(origin: serde_json::Value) -> crate::Result(fut: F) -> Result -where - F: Future>, -{ - match AssertUnwindSafe(fut).catch_unwind().await { - Ok(result) => result, - Err(err) => match Error::extract_panic_message(err) { - Some(message) => Err(wasm_bindgen::JsError::new(&format!("PANIC: {message}"))), - None => Err(wasm_bindgen::JsError::new("PANIC: unknown panic")), - }, - } -} From 77e10271d301010a903de9d9a9aad241737178cc Mon Sep 17 00:00:00 2001 From: Sergey Tatarintsev Date: Wed, 29 Nov 2023 16:34:17 +0100 Subject: [PATCH 098/134] Fix duplicate snapshots in json_filters test --- .../tests/queries/filters/json_filters.rs | 865 +++++++++--------- 1 file changed, 455 insertions(+), 410 deletions(-) diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/json_filters.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/json_filters.rs index e2ab83cfd62f..f3e4026a8678 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/json_filters.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/json_filters.rs @@ -46,53 +46,68 @@ mod json_filters { create_row(&runner, 4, r#"{ \"a\": { \"b\": [null] } }"#, false).await?; create_row(&runner, 5, r#"{ }"#, false).await?; - insta::assert_snapshot!( - run_query!( - runner, - jsonq(&runner, r#"path: ["a", "b"], equals: "\"c\"" "#, Some("")) - ), - @r###"{"data":{"findManyTestModel":[{"id":1}]}}"### + let res = run_query!( + runner, + jsonq(&runner, r#"path: ["a", "b"], equals: "\"c\"" "#, Some("")) ); + insta::allow_duplicates! { + insta::assert_snapshot!( + res, + @r###"{"data":{"findManyTestModel":[{"id":1}]}}"### + ); + } - insta::assert_snapshot!( - run_query!( - runner, - jsonq(&runner, r#"path: ["a", "b", "0"], equals: "1" "#, Some("")) - ), - @r###"{"data":{"findManyTestModel":[{"id":2}]}}"### + let res = run_query!( + runner, + jsonq(&runner, r#"path: ["a", "b", "0"], equals: "1" "#, Some("")) ); + insta::allow_duplicates! { + insta::assert_snapshot!( + res, + @r###"{"data":{"findManyTestModel":[{"id":2}]}}"### + ); + } - insta::assert_snapshot!( - run_query!( - runner, - jsonq(&runner, r#"path: ["a", "b", "0"], equals: JsonNull "#, Some("")) - ), - @r###"{"data":{"findManyTestModel":[{"id":4}]}}"### + let res = run_query!( + runner, + jsonq(&runner, r#"path: ["a", "b", "0"], equals: JsonNull "#, Some("")) ); + insta::allow_duplicates! { + insta::assert_snapshot!( + res, + @r###"{"data":{"findManyTestModel":[{"id":4}]}}"### + ); + } - insta::assert_snapshot!( - run_query!( - runner, - jsonq(&runner, r#"path: ["a", "b"], equals: JsonNull "#, Some("")) - ), - @r###"{"data":{"findManyTestModel":[{"id":3}]}}"### + let res = run_query!( + runner, + jsonq(&runner, r#"path: ["a", "b"], equals: JsonNull "#, Some("")) ); + insta::allow_duplicates! 
{ + insta::assert_snapshot!( + res, + @r###"{"data":{"findManyTestModel":[{"id":3}]}}"### + ); + } - insta::assert_snapshot!( - run_query!( - runner, - jsonq(&runner, r#"path: ["a", "b"], equals: DbNull "#, Some("")) - ), - @r###"{"data":{"findManyTestModel":[{"id":5}]}}"### - ); + let res = run_query!(runner, jsonq(&runner, r#"path: ["a", "b"], equals: DbNull "#, Some(""))); + insta::allow_duplicates! { + insta::assert_snapshot!( + res, + @r###"{"data":{"findManyTestModel":[{"id":5}]}}"### + ); + } - insta::assert_snapshot!( - run_query!( - runner, - jsonq(&runner, r#"path: ["a", "b"], equals: AnyNull "#, Some("")) - ), - @r###"{"data":{"findManyTestModel":[{"id":3},{"id":5}]}}"### + let res = run_query!( + runner, + jsonq(&runner, r#"path: ["a", "b"], equals: AnyNull "#, Some("")) ); + insta::allow_duplicates! { + insta::assert_snapshot!( + res, + @r###"{"data":{"findManyTestModel":[{"id":3},{"id":5}]}}"### + ); + } Ok(()) } @@ -120,13 +135,13 @@ mod json_filters { create_row(&runner, 5, r#"{ \"a\": { \"b\": [null] } }"#, false).await?; create_row(&runner, 6, r#"{ }"#, false).await?; - insta::assert_snapshot!( - run_query!( - runner, - jsonq(&runner, r#"path: "$.a.b", equals: "\"c\"" "#, Some("")) - ), - @r###"{"data":{"findManyTestModel":[{"id":1}]}}"### - ); + let res = run_query!(runner, jsonq(&runner, r#"path: "$.a.b", equals: "\"c\"" "#, Some(""))); + insta::allow_duplicates! { + insta::assert_snapshot!( + res, + @r###"{"data":{"findManyTestModel":[{"id":1}]}}"### + ); + } insta::assert_snapshot!( run_query!( @@ -136,29 +151,32 @@ mod json_filters { @r###"{"data":{"findManyTestModel":[{"id":2},{"id":3}]}}"### ); - insta::assert_snapshot!( - run_query!( - runner, - jsonq(&runner, r#"path: "$.a.b[0]", equals: JsonNull "#, Some("")) - ), - @r###"{"data":{"findManyTestModel":[{"id":4},{"id":5}]}}"### + let res = run_query!( + runner, + jsonq(&runner, r#"path: "$.a.b[0]", equals: JsonNull "#, Some("")) ); + insta::allow_duplicates! { + insta::assert_snapshot!( + res, + @r###"{"data":{"findManyTestModel":[{"id":4},{"id":5}]}}"### + ); + } - insta::assert_snapshot!( - run_query!( - runner, - jsonq(&runner, r#"path: "$.a.b", equals: DbNull "#, Some("")) - ), - @r###"{"data":{"findManyTestModel":[{"id":6}]}}"### - ); + let res = run_query!(runner, jsonq(&runner, r#"path: "$.a.b", equals: DbNull "#, Some(""))); + insta::allow_duplicates! { + insta::assert_snapshot!( + res, + @r###"{"data":{"findManyTestModel":[{"id":6}]}}"### + ); + } - insta::assert_snapshot!( - run_query!( - runner, - jsonq(&runner, r#"path: "$.a.b", equals: AnyNull "#, Some("")) - ), - @r###"{"data":{"findManyTestModel":[{"id":4},{"id":6}]}}"### - ); + let res = run_query!(runner, jsonq(&runner, r#"path: "$.a.b", equals: AnyNull "#, Some(""))); + insta::allow_duplicates! { + insta::assert_snapshot!( + res, + @r###"{"data":{"findManyTestModel":[{"id":4},{"id":6}]}}"### + ); + } Ok(()) } @@ -174,36 +192,36 @@ mod json_filters { create_row(&runner, 8, r#"[1, [null], 2]"#, true).await?; // array_contains - insta::assert_snapshot!( - run_query!( - runner, - jsonq(&runner, r#"array_contains: "[3]""#, None) - ), - @r###"{"data":{"findManyTestModel":[{"id":1},{"id":2}]}}"### - ); - insta::assert_snapshot!( - run_query!( - runner, - jsonq(&runner, r#"array_contains: "[\"a\"]""#, None) - ), - @r###"{"data":{"findManyTestModel":[{"id":4}]}}"### - ); + let res = run_query!(runner, jsonq(&runner, r#"array_contains: "[3]""#, None)); + insta::allow_duplicates! 
{ + insta::assert_snapshot!( + res, + @r###"{"data":{"findManyTestModel":[{"id":1},{"id":2}]}}"### + ); + } + let res = run_query!(runner, jsonq(&runner, r#"array_contains: "[\"a\"]""#, None)); + insta::allow_duplicates! { + insta::assert_snapshot!( + res, + @r###"{"data":{"findManyTestModel":[{"id":4}]}}"### + ); + } // NOT array_contains - insta::assert_snapshot!( - run_query!( - runner, - not_jsonq(&runner, r#"array_contains: "[3]""#, None) - ), - @r###"{"data":{"findManyTestModel":[{"id":4},{"id":6},{"id":7},{"id":8}]}}"### - ); - insta::assert_snapshot!( - run_query!( - runner, - not_jsonq(&runner, r#"array_contains: "[\"a\"]""#, None) - ), - @r###"{"data":{"findManyTestModel":[{"id":1},{"id":2},{"id":6},{"id":7},{"id":8}]}}"### - ); + let res = run_query!(runner, not_jsonq(&runner, r#"array_contains: "[3]""#, None)); + insta::allow_duplicates! { + insta::assert_snapshot!( + res, + @r###"{"data":{"findManyTestModel":[{"id":4},{"id":6},{"id":7},{"id":8}]}}"### + ); + } + let res = run_query!(runner, not_jsonq(&runner, r#"array_contains: "[\"a\"]""#, None)); + insta::allow_duplicates! { + insta::assert_snapshot!( + res, + @r###"{"data":{"findManyTestModel":[{"id":1},{"id":2},{"id":6},{"id":7},{"id":8}]}}"### + ); + } // MySQL has slightly different semantics and also coerces null to [null]. is_one_of!( @@ -225,30 +243,30 @@ mod json_filters { match runner.connector_version() { // MariaDB does not support finding arrays in arrays, unlike MySQL ConnectorVersion::MySql(Some(MySqlVersion::MariaDb)) => { - insta::assert_snapshot!( - run_query!( - runner, - jsonq(&runner, r#"array_contains: "[[1, 2]]" "#, None) - ), - @r###"{"data":{"findManyTestModel":[{"id":1},{"id":6},{"id":7},{"id":8}]}}"### - ); + let res = run_query!(runner, jsonq(&runner, r#"array_contains: "[[1, 2]]" "#, None)); + insta::allow_duplicates! { + insta::assert_snapshot!( + res, + @r###"{"data":{"findManyTestModel":[{"id":1},{"id":6},{"id":7},{"id":8}]}}"### + ); + } } _ => { - insta::assert_snapshot!( - run_query!( - runner, - jsonq(&runner, r#"array_contains: "[[1, 2]]" "#, None) - ), - @r###"{"data":{"findManyTestModel":[{"id":6}]}}"### - ); - - insta::assert_snapshot!( - run_query!( - runner, - jsonq(&runner, r#"array_contains: "[[null]]" "#, None) - ), - @r###"{"data":{"findManyTestModel":[{"id":8}]}}"### - ); + let res = run_query!(runner, jsonq(&runner, r#"array_contains: "[[1, 2]]" "#, None)); + insta::allow_duplicates! { + insta::assert_snapshot!( + res, + @r###"{"data":{"findManyTestModel":[{"id":6}]}}"### + ); + } + + let res = run_query!(runner, jsonq(&runner, r#"array_contains: "[[null]]" "#, None)); + insta::allow_duplicates! { + insta::assert_snapshot!( + res, + @r###"{"data":{"findManyTestModel":[{"id":8}]}}"### + ); + } } } @@ -280,86 +298,86 @@ mod json_filters { create_row(&runner, 8, r#"[null, \"test\"]"#, true).await?; create_row(&runner, 9, r#"[[null], \"test\"]"#, true).await?; - insta::assert_snapshot!( - run_query!( - runner, - jsonq(&runner, r#"array_starts_with: "3" "#, None) - ), - @r###"{"data":{"findManyTestModel":[{"id":2}]}}"### - ); + let res = run_query!(runner, jsonq(&runner, r#"array_starts_with: "3" "#, None)); + insta::allow_duplicates! 
{ + insta::assert_snapshot!( + res, + @r###"{"data":{"findManyTestModel":[{"id":2}]}}"### + ); + } - insta::assert_snapshot!( - run_query!( - runner, - jsonq(&runner, r#"array_starts_with: "\"a\"" "#, None) - ), - @r###"{"data":{"findManyTestModel":[{"id":4}]}}"### - ); + let res = run_query!(runner, jsonq(&runner, r#"array_starts_with: "\"a\"" "#, None)); + insta::allow_duplicates! { + insta::assert_snapshot!( + res, + @r###"{"data":{"findManyTestModel":[{"id":4}]}}"### + ); + } - insta::assert_snapshot!( - run_query!( - runner, - jsonq(&runner, r#"array_starts_with: "[1, 2]" "#, None) - ), - @r###"{"data":{"findManyTestModel":[{"id":6}]}}"### - ); + let res = run_query!(runner, jsonq(&runner, r#"array_starts_with: "[1, 2]" "#, None)); + insta::allow_duplicates! { + insta::assert_snapshot!( + res, + @r###"{"data":{"findManyTestModel":[{"id":6}]}}"### + ); + } - insta::assert_snapshot!( - run_query!( - runner, - jsonq(&runner, r#"array_starts_with: "null" "#, None) - ), - @r###"{"data":{"findManyTestModel":[{"id":8}]}}"### - ); + let res = run_query!(runner, jsonq(&runner, r#"array_starts_with: "null" "#, None)); + insta::allow_duplicates! { + insta::assert_snapshot!( + res, + @r###"{"data":{"findManyTestModel":[{"id":8}]}}"### + ); + } - insta::assert_snapshot!( - run_query!( - runner, - jsonq(&runner, r#"array_starts_with: "[null]" "#, None) - ), - @r###"{"data":{"findManyTestModel":[{"id":9}]}}"### - ); + let res = run_query!(runner, jsonq(&runner, r#"array_starts_with: "[null]" "#, None)); + insta::allow_duplicates! { + insta::assert_snapshot!( + res, + @r###"{"data":{"findManyTestModel":[{"id":9}]}}"### + ); + } // NOT - insta::assert_snapshot!( - run_query!( - runner, - not_jsonq(&runner, r#"array_starts_with: "3" "#, None) - ), - @r###"{"data":{"findManyTestModel":[{"id":1},{"id":4},{"id":6},{"id":8},{"id":9}]}}"### - ); + let res = run_query!(runner, not_jsonq(&runner, r#"array_starts_with: "3" "#, None)); + insta::allow_duplicates! { + insta::assert_snapshot!( + res, + @r###"{"data":{"findManyTestModel":[{"id":1},{"id":4},{"id":6},{"id":8},{"id":9}]}}"### + ); + } - insta::assert_snapshot!( - run_query!( - runner, - not_jsonq(&runner, r#"array_starts_with: "\"a\"" "#, None) - ), - @r###"{"data":{"findManyTestModel":[{"id":1},{"id":2},{"id":6},{"id":8},{"id":9}]}}"### - ); + let res = run_query!(runner, not_jsonq(&runner, r#"array_starts_with: "\"a\"" "#, None)); + insta::allow_duplicates! { + insta::assert_snapshot!( + res, + @r###"{"data":{"findManyTestModel":[{"id":1},{"id":2},{"id":6},{"id":8},{"id":9}]}}"### + ); + } - insta::assert_snapshot!( - run_query!( - runner, - not_jsonq(&runner, r#"array_starts_with: "[1, 2]" "#, None) - ), - @r###"{"data":{"findManyTestModel":[{"id":1},{"id":2},{"id":4},{"id":8},{"id":9}]}}"### - ); + let res = run_query!(runner, not_jsonq(&runner, r#"array_starts_with: "[1, 2]" "#, None)); + insta::allow_duplicates! { + insta::assert_snapshot!( + res, + @r###"{"data":{"findManyTestModel":[{"id":1},{"id":2},{"id":4},{"id":8},{"id":9}]}}"### + ); + } - insta::assert_snapshot!( - run_query!( - runner, - not_jsonq(&runner, r#"array_starts_with: "null" "#, None) - ), - @r###"{"data":{"findManyTestModel":[{"id":1},{"id":2},{"id":4},{"id":6},{"id":9}]}}"### - ); + let res = run_query!(runner, not_jsonq(&runner, r#"array_starts_with: "null" "#, None)); + insta::allow_duplicates! 
{ + insta::assert_snapshot!( + res, + @r###"{"data":{"findManyTestModel":[{"id":1},{"id":2},{"id":4},{"id":6},{"id":9}]}}"### + ); + } - insta::assert_snapshot!( - run_query!( - runner, - not_jsonq(&runner, r#"array_starts_with: "[null]" "#, None) - ), - @r###"{"data":{"findManyTestModel":[{"id":1},{"id":2},{"id":4},{"id":6},{"id":8}]}}"### - ); + let res = run_query!(runner, not_jsonq(&runner, r#"array_starts_with: "[null]" "#, None)); + insta::allow_duplicates! { + insta::assert_snapshot!( + res, + @r###"{"data":{"findManyTestModel":[{"id":1},{"id":2},{"id":4},{"id":6},{"id":8}]}}"### + ); + } Ok(()) } @@ -387,86 +405,86 @@ mod json_filters { create_row(&runner, 8, r#"[\"test\", null]"#, true).await?; create_row(&runner, 9, r#"[\"test\", [null]]"#, true).await?; - insta::assert_snapshot!( - run_query!( - runner, - jsonq(&runner, r#"array_ends_with: "3" "#, None) - ), - @r###"{"data":{"findManyTestModel":[{"id":1}]}}"### - ); + let res = run_query!(runner, jsonq(&runner, r#"array_ends_with: "3" "#, None)); + insta::allow_duplicates! { + insta::assert_snapshot!( + res, + @r###"{"data":{"findManyTestModel":[{"id":1}]}}"### + ); + } - insta::assert_snapshot!( - run_query!( - runner, - jsonq(&runner, r#"array_ends_with: "\"b\"" "#, None) - ), - @r###"{"data":{"findManyTestModel":[{"id":3}]}}"### - ); + let res = run_query!(runner, jsonq(&runner, r#"array_ends_with: "\"b\"" "#, None)); + insta::allow_duplicates! { + insta::assert_snapshot!( + res, + @r###"{"data":{"findManyTestModel":[{"id":3}]}}"### + ); + } - insta::assert_snapshot!( - run_query!( - runner, - jsonq(&runner, r#"array_ends_with: "[3, 4]" "#, None) - ), - @r###"{"data":{"findManyTestModel":[{"id":4}]}}"### - ); + let res = run_query!(runner, jsonq(&runner, r#"array_ends_with: "[3, 4]" "#, None)); + insta::allow_duplicates! { + insta::assert_snapshot!( + res, + @r###"{"data":{"findManyTestModel":[{"id":4}]}}"### + ); + } - insta::assert_snapshot!( - run_query!( - runner, - jsonq(&runner, r#"array_ends_with: "null" "#, None) - ), - @r###"{"data":{"findManyTestModel":[{"id":8}]}}"### - ); + let res = run_query!(runner, jsonq(&runner, r#"array_ends_with: "null" "#, None)); + insta::allow_duplicates! { + insta::assert_snapshot!( + res, + @r###"{"data":{"findManyTestModel":[{"id":8}]}}"### + ); + } - insta::assert_snapshot!( - run_query!( - runner, - jsonq(&runner, r#"array_ends_with: "[null]" "#, None) - ), - @r###"{"data":{"findManyTestModel":[{"id":9}]}}"### - ); + let res = run_query!(runner, jsonq(&runner, r#"array_ends_with: "[null]" "#, None)); + insta::allow_duplicates! { + insta::assert_snapshot!( + res, + @r###"{"data":{"findManyTestModel":[{"id":9}]}}"### + ); + } // NOT - insta::assert_snapshot!( - run_query!( - runner, - not_jsonq(&runner, r#"array_ends_with: "3" "#, None) - ), - @r###"{"data":{"findManyTestModel":[{"id":2},{"id":3},{"id":4},{"id":8},{"id":9}]}}"### - ); + let res = run_query!(runner, not_jsonq(&runner, r#"array_ends_with: "3" "#, None)); + insta::allow_duplicates! { + insta::assert_snapshot!( + res, + @r###"{"data":{"findManyTestModel":[{"id":2},{"id":3},{"id":4},{"id":8},{"id":9}]}}"### + ); + } - insta::assert_snapshot!( - run_query!( - runner, - not_jsonq(&runner, r#"array_ends_with: "\"b\"" "#, None) - ), - @r###"{"data":{"findManyTestModel":[{"id":1},{"id":2},{"id":4},{"id":8},{"id":9}]}}"### - ); + let res = run_query!(runner, not_jsonq(&runner, r#"array_ends_with: "\"b\"" "#, None)); + insta::allow_duplicates! 
{ + insta::assert_snapshot!( + res, + @r###"{"data":{"findManyTestModel":[{"id":1},{"id":2},{"id":4},{"id":8},{"id":9}]}}"### + ); + } - insta::assert_snapshot!( - run_query!( - runner, - not_jsonq(&runner, r#"array_ends_with: "[3, 4]" "#, None) - ), - @r###"{"data":{"findManyTestModel":[{"id":1},{"id":2},{"id":3},{"id":8},{"id":9}]}}"### - ); + let res = run_query!(runner, not_jsonq(&runner, r#"array_ends_with: "[3, 4]" "#, None)); + insta::allow_duplicates! { + insta::assert_snapshot!( + res, + @r###"{"data":{"findManyTestModel":[{"id":1},{"id":2},{"id":3},{"id":8},{"id":9}]}}"### + ); + } - insta::assert_snapshot!( - run_query!( - runner, - not_jsonq(&runner, r#"array_ends_with: "null" "#, None) - ), - @r###"{"data":{"findManyTestModel":[{"id":1},{"id":2},{"id":3},{"id":4},{"id":9}]}}"### - ); + let res = run_query!(runner, not_jsonq(&runner, r#"array_ends_with: "null" "#, None)); + insta::allow_duplicates! { + insta::assert_snapshot!( + res, + @r###"{"data":{"findManyTestModel":[{"id":1},{"id":2},{"id":3},{"id":4},{"id":9}]}}"### + ); + } - insta::assert_snapshot!( - run_query!( - runner, - not_jsonq(&runner, r#"array_ends_with: "[null]" "#, None) - ), - @r###"{"data":{"findManyTestModel":[{"id":1},{"id":2},{"id":3},{"id":4},{"id":8}]}}"### - ); + let res = run_query!(runner, not_jsonq(&runner, r#"array_ends_with: "[null]" "#, None)); + insta::allow_duplicates! { + insta::assert_snapshot!( + res, + @r###"{"data":{"findManyTestModel":[{"id":1},{"id":2},{"id":3},{"id":4},{"id":8}]}}"### + ); + } Ok(()) } @@ -490,22 +508,22 @@ mod json_filters { create_row(&runner, 2, r#"\"fool\""#, true).await?; create_row(&runner, 3, r#"[\"foo\"]"#, true).await?; - insta::assert_snapshot!( - run_query!( - runner, - jsonq(&runner, r#"string_contains: "oo" "#, None) - ), - @r###"{"data":{"findManyTestModel":[{"id":1},{"id":2}]}}"### - ); + let res = run_query!(runner, jsonq(&runner, r#"string_contains: "oo" "#, None)); + insta::allow_duplicates! { + insta::assert_snapshot!( + res, + @r###"{"data":{"findManyTestModel":[{"id":1},{"id":2}]}}"### + ); + } // NOT - insta::assert_snapshot!( - run_query!( - runner, - not_jsonq(&runner, r#"string_contains: "ab" "#, None) - ), - @r###"{"data":{"findManyTestModel":[{"id":1},{"id":2}]}}"### - ); + let res = run_query!(runner, not_jsonq(&runner, r#"string_contains: "ab" "#, None)); + insta::allow_duplicates! { + insta::assert_snapshot!( + res, + @r###"{"data":{"findManyTestModel":[{"id":1},{"id":2}]}}"### + ); + } Ok(()) } @@ -530,22 +548,22 @@ mod json_filters { create_row(&runner, 3, r#"[\"foo\"]"#, true).await?; // string_starts_with - insta::assert_snapshot!( - run_query!( - runner, - jsonq(&runner, r#"string_starts_with: "foo" "#, None) - ), - @r###"{"data":{"findManyTestModel":[{"id":1},{"id":2}]}}"### - ); + let res = run_query!(runner, jsonq(&runner, r#"string_starts_with: "foo" "#, None)); + insta::allow_duplicates! { + insta::assert_snapshot!( + res, + @r###"{"data":{"findManyTestModel":[{"id":1},{"id":2}]}}"### + ); + } // NOT string_starts_with - insta::assert_snapshot!( - run_query!( - runner, - not_jsonq(&runner, r#"string_starts_with: "ab" "#, None) - ), - @r###"{"data":{"findManyTestModel":[{"id":1},{"id":2}]}}"### - ); + let res = run_query!(runner, not_jsonq(&runner, r#"string_starts_with: "ab" "#, None)); + insta::allow_duplicates! 
{ + insta::assert_snapshot!( + res, + @r###"{"data":{"findManyTestModel":[{"id":1},{"id":2}]}}"### + ); + } Ok(()) } @@ -569,22 +587,22 @@ mod json_filters { create_row(&runner, 2, r#"\"fool\""#, true).await?; create_row(&runner, 3, r#"[\"foo\"]"#, true).await?; - insta::assert_snapshot!( - run_query!( - runner, - jsonq(&runner, r#"string_ends_with: "oo" "#, None) - ), - @r###"{"data":{"findManyTestModel":[{"id":1}]}}"### - ); + let res = run_query!(runner, jsonq(&runner, r#"string_ends_with: "oo" "#, None)); + insta::allow_duplicates! { + insta::assert_snapshot!( + res, + @r###"{"data":{"findManyTestModel":[{"id":1}]}}"### + ); + } // NOT - insta::assert_snapshot!( - run_query!( - runner, - not_jsonq(&runner, r#"string_ends_with: "oo" "#, None) - ), - @r###"{"data":{"findManyTestModel":[{"id":2}]}}"### - ); + let res = run_query!(runner, not_jsonq(&runner, r#"string_ends_with: "oo" "#, None)); + insta::allow_duplicates! { + insta::assert_snapshot!( + res, + @r###"{"data":{"findManyTestModel":[{"id":2}]}}"### + ); + } Ok(()) } @@ -612,37 +630,37 @@ mod json_filters { create_row(&runner, 6, r#"100"#, true).await?; create_row(&runner, 7, r#"[\"foo\"]"#, true).await?; - insta::assert_snapshot!( - run_query!( - runner, - jsonq(&runner, r#"gt: "\"b\"" "#, None) - ), - @r###"{"data":{"findManyTestModel":[{"id":1},{"id":2}]}}"### - ); + let res = run_query!(runner, jsonq(&runner, r#"gt: "\"b\"" "#, None)); + insta::allow_duplicates! { + insta::assert_snapshot!( + res, + @r###"{"data":{"findManyTestModel":[{"id":1},{"id":2}]}}"### + ); + } - insta::assert_snapshot!( - run_query!( - runner, - jsonq(&runner, r#"gte: "\"b\"" "#, None) - ), - @r###"{"data":{"findManyTestModel":[{"id":1},{"id":2}]}}"### - ); + let res = run_query!(runner, jsonq(&runner, r#"gte: "\"b\"" "#, None)); + insta::allow_duplicates! { + insta::assert_snapshot!( + res, + @r###"{"data":{"findManyTestModel":[{"id":1},{"id":2}]}}"### + ); + } - insta::assert_snapshot!( - run_query!( - runner, - jsonq(&runner, r#"gt: "1" "#, None) - ), - @r###"{"data":{"findManyTestModel":[{"id":4},{"id":5},{"id":6}]}}"### - ); + let res = run_query!(runner, jsonq(&runner, r#"gt: "1" "#, None)); + insta::allow_duplicates! { + insta::assert_snapshot!( + res, + @r###"{"data":{"findManyTestModel":[{"id":4},{"id":5},{"id":6}]}}"### + ); + } - insta::assert_snapshot!( - run_query!( - runner, - jsonq(&runner, r#"gte: "1" "#, None) - ), - @r###"{"data":{"findManyTestModel":[{"id":3},{"id":4},{"id":5},{"id":6}]}}"### - ); + let res = run_query!(runner, jsonq(&runner, r#"gte: "1" "#, None)); + insta::allow_duplicates! { + insta::assert_snapshot!( + res, + @r###"{"data":{"findManyTestModel":[{"id":3},{"id":4},{"id":5},{"id":6}]}}"### + ); + } Ok(()) } @@ -693,37 +711,37 @@ mod json_filters { create_row(&runner, 6, r#"100"#, true).await?; create_row(&runner, 7, r#"[\"foo\"]"#, true).await?; - insta::assert_snapshot!( - run_query!( - runner, - jsonq(&runner, r#"lt: "\"f\"" "#, None) - ), - @r###"{"data":{"findManyTestModel":[{"id":2}]}}"### - ); + let res = run_query!(runner, jsonq(&runner, r#"lt: "\"f\"" "#, None)); + insta::allow_duplicates! { + insta::assert_snapshot!( + res, + @r###"{"data":{"findManyTestModel":[{"id":2}]}}"### + ); + } - insta::assert_snapshot!( - run_query!( - runner, - jsonq(&runner, r#"lte: "\"foo\"" "#, None) - ), - @r###"{"data":{"findManyTestModel":[{"id":1},{"id":2}]}}"### - ); + let res = run_query!(runner, jsonq(&runner, r#"lte: "\"foo\"" "#, None)); + insta::allow_duplicates! 
{ + insta::assert_snapshot!( + res, + @r###"{"data":{"findManyTestModel":[{"id":1},{"id":2}]}}"### + ); + } - insta::assert_snapshot!( - run_query!( - runner, - jsonq(&runner, r#"lt: "100" "#, None) - ), - @r###"{"data":{"findManyTestModel":[{"id":3},{"id":4},{"id":5}]}}"### - ); + let res = run_query!(runner, jsonq(&runner, r#"lt: "100" "#, None)); + insta::allow_duplicates! { + insta::assert_snapshot!( + res, + @r###"{"data":{"findManyTestModel":[{"id":3},{"id":4},{"id":5}]}}"### + ); + } - insta::assert_snapshot!( - run_query!( - runner, - jsonq(&runner, r#"lte: "100" "#, None) - ), - @r###"{"data":{"findManyTestModel":[{"id":3},{"id":4},{"id":5},{"id":6}]}}"### - ); + let res = run_query!(runner, jsonq(&runner, r#"lte: "100" "#, None)); + insta::allow_duplicates! { + insta::assert_snapshot!( + res, + @r###"{"data":{"findManyTestModel":[{"id":3},{"id":4},{"id":5},{"id":6}]}}"### + ); + } Ok(()) } @@ -753,70 +771,97 @@ mod json_filters { create_row(&runner, 6, r#"2.4"#, true).await?; create_row(&runner, 7, r#"3"#, true).await?; - insta::assert_snapshot!( - run_query!( - runner, - format!(r#"query {{ - findManyTestModel( - where: {{ json: {{ {}, array_contains: "3", array_starts_with: "3" }} }}, - cursor: {{ id: 2 }}, - take: 2 - ) {{ json }} - }}"#, json_path(&runner)) - ), - @r###"{"data":{"findManyTestModel":[{"json":"{\"a\":{\"b\":[3,4,5]}}"},{"json":"{\"a\":{\"b\":[3,4,6]}}"}]}}"### - ); - insta::assert_snapshot!( - run_query!( - runner, - format!(r#"query {{ - findManyTestModel( - where: {{ - AND: [ - {{ json: {{ {}, gte: "1" }} }}, - {{ json: {{ {}, lt: "3" }} }}, - ] - }} - ) {{ json }} - }}"#, json_path(&runner), json_path(&runner)) - ), - @r###"{"data":{"findManyTestModel":[{"json":"{\"a\":{\"b\":1}}"},{"json":"{\"a\":{\"b\":2.4}}"}]}}"### - ); + let res = run_query!( + runner, + format!( + r#"query {{ + findManyTestModel( + where: {{ json: {{ {}, array_contains: "3", array_starts_with: "3" }} }}, + cursor: {{ id: 2 }}, + take: 2 + ) {{ json }} + }}"#, + json_path(&runner) + ) + ); + insta::allow_duplicates! { + insta::assert_snapshot!( + res, + @r###"{"data":{"findManyTestModel":[{"json":"{\"a\":{\"b\":[3,4,5]}}"},{"json":"{\"a\":{\"b\":[3,4,6]}}"}]}}"### + ); + } + + let res = run_query!( + runner, + format!( + r#"query {{ + findManyTestModel( + where: {{ + AND: [ + {{ json: {{ {}, gte: "1" }} }}, + {{ json: {{ {}, lt: "3" }} }}, + ] + }} + ) {{ json }} + }}"#, + json_path(&runner), + json_path(&runner) + ) + ); + insta::allow_duplicates! { + insta::assert_snapshot!( + res, + @r###"{"data":{"findManyTestModel":[{"json":"{\"a\":{\"b\":1}}"},{"json":"{\"a\":{\"b\":2.4}}"}]}}"### + ); + } // NOT - insta::assert_snapshot!( - run_query!( - runner, - format!(r#"query {{ - findManyTestModel( - where: {{ NOT: {{ json: {{ {}, array_contains: "3", array_starts_with: "3" }} }} }}, - cursor: {{ id: 2 }}, - take: 2 - ) {{ json }} - }}"#, json_path(&runner)) - ), - @r###"{"data":{"findManyTestModel":[{"json":"{\"a\":{\"b\":[5,6,7]}}"}]}}"### - ); + let res = run_query!( + runner, + format!( + r#"query {{ + findManyTestModel( + where: {{ NOT: {{ json: {{ {}, array_contains: "3", array_starts_with: "3" }} }} }}, + cursor: {{ id: 2 }}, + take: 2 + ) {{ json }} + }}"#, + json_path(&runner) + ) + ); + insta::allow_duplicates! 
{ + insta::assert_snapshot!( + res, + @r###"{"data":{"findManyTestModel":[{"json":"{\"a\":{\"b\":[5,6,7]}}"}]}}"### + ); + } // 1, 2.4, 3 // filter: false, true, false // negated: true, false, true // result: 1, 3 - insta::assert_snapshot!( - run_query!( - runner, - format!(r#"query {{ - findManyTestModel( - where: {{ - NOT: {{ AND: [ - {{ json: {{ {}, gt: "1" }} }}, - {{ json: {{ {}, lt: "3" }} }}, - ]}} - }} - ) {{ json }} - }}"#, json_path(&runner), json_path(&runner)) - ), - @r###"{"data":{"findManyTestModel":[{"json":"{\"a\":{\"b\":1}}"},{"json":"{\"a\":{\"b\":3}}"}]}}"### - ); + let res = run_query!( + runner, + format!( + r#"query {{ + findManyTestModel( + where: {{ + NOT: {{ AND: [ + {{ json: {{ {}, gt: "1" }} }}, + {{ json: {{ {}, lt: "3" }} }}, + ]}} + }} + ) {{ json }} + }}"#, + json_path(&runner), + json_path(&runner) + ) + ); + insta::allow_duplicates! { + insta::assert_snapshot!( + res, + @r###"{"data":{"findManyTestModel":[{"json":"{\"a\":{\"b\":1}}"},{"json":"{\"a\":{\"b\":3}}"}]}}"### + ); + } Ok(()) } From f63661b544c9b037a71a9d62b25690a8b6dbb53a Mon Sep 17 00:00:00 2001 From: Sergey Tatarintsev Date: Thu, 30 Nov 2023 11:17:58 +0100 Subject: [PATCH 099/134] Size low hanging fruits Removes following functionality from WASM engine: - GraphQL protocol - DMMF - SDL Schema Neither of the features are used by the client runtimes and thrid party clients don't and can not use WASM engine, so it is safe to remove. --- query-engine/core/Cargo.toml | 5 ++- query-engine/core/src/protocol.rs | 3 ++ query-engine/core/src/response_ir/internal.rs | 2 ++ .../query-engine-node-api/src/functions.rs | 1 - query-engine/query-engine-wasm/Cargo.toml | 2 +- .../query-engine-wasm/src/wasm/engine.rs | 33 +------------------ .../query-engine-wasm/src/wasm/functions.rs | 18 ---------- query-engine/request-handlers/Cargo.toml | 12 +++++-- query-engine/request-handlers/src/lib.rs | 4 ++- .../request-handlers/src/protocols/mod.rs | 6 ++++ 10 files changed, 29 insertions(+), 57 deletions(-) diff --git a/query-engine/core/Cargo.toml b/query-engine/core/Cargo.toml index 1b7c52e59de9..da9e8331dfdf 100644 --- a/query-engine/core/Cargo.toml +++ b/query-engine/core/Cargo.toml @@ -5,6 +5,7 @@ version = "0.1.0" [features] metrics = ["query-engine-metrics"] +graphql-protocol = [] [dependencies] async-trait = "0.1" @@ -19,7 +20,9 @@ indexmap = { version = "1.7", features = ["serde-1"] } itertools = "0.10" once_cell = "1" petgraph = "0.4" -query-structure = { path = "../query-structure", features = ["default_generators"] } +query-structure = { path = "../query-structure", features = [ + "default_generators", +] } opentelemetry = { version = "0.17.0", features = ["rt-tokio", "serialize"] } query-engine-metrics = { path = "../metrics", optional = true } serde.workspace = true diff --git a/query-engine/core/src/protocol.rs b/query-engine/core/src/protocol.rs index 75e8dbc0fd70..e92438d5e92d 100644 --- a/query-engine/core/src/protocol.rs +++ b/query-engine/core/src/protocol.rs @@ -3,6 +3,7 @@ use serde::Deserialize; #[derive(Debug, Clone, Copy, Deserialize)] #[serde(rename_all = "camelCase")] pub enum EngineProtocol { + #[cfg(feature = "graphql-protocol")] Graphql, Json, } @@ -14,6 +15,7 @@ impl EngineProtocol { } /// Returns `true` if the engine protocol is [`Graphql`]. 
+ #[cfg(feature = "graphql-protocol")] pub fn is_graphql(&self) -> bool { matches!(self, Self::Graphql) } @@ -22,6 +24,7 @@ impl EngineProtocol { impl From<&String> for EngineProtocol { fn from(s: &String) -> Self { match s.as_str() { + #[cfg(feature = "graphql-protocol")] "graphql" => EngineProtocol::Graphql, "json" => EngineProtocol::Json, x => panic!("Unknown engine protocol '{x}'. Must be 'graphql' or 'json'."), diff --git a/query-engine/core/src/response_ir/internal.rs b/query-engine/core/src/response_ir/internal.rs index 7becb19e768b..bbda8d7bd05d 100644 --- a/query-engine/core/src/response_ir/internal.rs +++ b/query-engine/core/src/response_ir/internal.rs @@ -552,11 +552,13 @@ fn serialize_scalar(field: &OutputField<'_>, value: PrismaValue) -> crate::Resul fn convert_prisma_value(field: &OutputField<'_>, value: PrismaValue, st: &ScalarType) -> crate::Result { match crate::executor::get_engine_protocol() { + #[cfg(feature = "graphql-protocol")] EngineProtocol::Graphql => convert_prisma_value_graphql_protocol(field, value, st), EngineProtocol::Json => convert_prisma_value_json_protocol(field, value, st), } } +#[cfg(feature = "graphql-protocol")] fn convert_prisma_value_graphql_protocol( field: &OutputField<'_>, value: PrismaValue, diff --git a/query-engine/query-engine-node-api/src/functions.rs b/query-engine/query-engine-node-api/src/functions.rs index 868178f7361d..bcb64e240cae 100644 --- a/query-engine/query-engine-node-api/src/functions.rs +++ b/query-engine/query-engine-node-api/src/functions.rs @@ -1,4 +1,3 @@ -use crate::error::ApiError; use napi_derive::napi; use request_handlers::dmmf; use std::sync::Arc; diff --git a/query-engine/query-engine-wasm/Cargo.toml b/query-engine/query-engine-wasm/Cargo.toml index 6d3dc33050b7..e08d412d5f97 100644 --- a/query-engine/query-engine-wasm/Cargo.toml +++ b/query-engine/query-engine-wasm/Cargo.toml @@ -6,7 +6,7 @@ edition = "2021" [lib] doc = false crate-type = ["cdylib"] -name = "query_engine" +name = "query_engine_wasm" [dependencies] diff --git a/query-engine/query-engine-wasm/src/wasm/engine.rs b/query-engine/query-engine-wasm/src/wasm/engine.rs index 19f442106b97..a40dfbd4ff94 100644 --- a/query-engine/query-engine-wasm/src/wasm/engine.rs +++ b/query-engine/query-engine-wasm/src/wasm/engine.rs @@ -14,7 +14,7 @@ use query_core::{ telemetry, QueryExecutor, TransactionOptions, TxId, }; use request_handlers::ConnectorMode; -use request_handlers::{dmmf, load_executor, render_graphql_schema, RequestBody, RequestHandler}; +use request_handlers::{load_executor, RequestBody, RequestHandler}; use serde::{Deserialize, Serialize}; use serde_json::json; use std::{ @@ -385,28 +385,6 @@ impl QueryEngine { .await } - #[wasm_bindgen] - pub async fn dmmf(&self, trace: String) -> Result { - let inner = self.inner.read().await; - let engine = inner.as_engine()?; - - let dispatcher = self.logger.dispatcher(); - - tracing::dispatcher::with_default(&dispatcher, || { - let span = tracing::info_span!("prisma:engine:dmmf"); - let _ = telemetry::helpers::set_parent_context_from_json_str(&span, &trace); - let _guard = span.enter(); - let dmmf = dmmf::render_dmmf(&engine.query_schema); - - let json = { - let _span = tracing::info_span!("prisma:engine:dmmf_to_json").entered(); - serde_json::to_string(&dmmf)? - }; - - Ok(json) - }) - } - /// If connected, attempts to roll back a transaction with id `tx_id` in the core. 
#[wasm_bindgen(js_name = rollbackTransaction)] pub async fn rollback_transaction(&self, tx_id: String, trace: String) -> Result { @@ -425,15 +403,6 @@ impl QueryEngine { .await } - /// Loads the query schema. Only available when connected. - #[wasm_bindgen(js_name = sdlSchema)] - pub async fn sdl_schema(&self) -> Result { - let inner = self.inner.read().await; - let engine = inner.as_engine()?; - - Ok(render_graphql_schema(engine.query_schema())) - } - #[wasm_bindgen] pub async fn metrics(&self, json_options: String) -> Result<(), wasm_bindgen::JsError> { log::info!("Called `QueryEngine::metrics()`"); diff --git a/query-engine/query-engine-wasm/src/wasm/functions.rs b/query-engine/query-engine-wasm/src/wasm/functions.rs index 9767b22fb811..5aa2a8d6ba2a 100644 --- a/query-engine/query-engine-wasm/src/wasm/functions.rs +++ b/query-engine/query-engine-wasm/src/wasm/functions.rs @@ -1,7 +1,4 @@ -use crate::error::ApiError; -use request_handlers::dmmf; use serde::Serialize; -use std::sync::Arc; use tsify::Tsify; use wasm_bindgen::prelude::wasm_bindgen; @@ -21,21 +18,6 @@ pub fn version() -> Version { } } -#[wasm_bindgen] -pub fn dmmf(datamodel_string: String) -> Result { - let mut schema = psl::validate(datamodel_string.into()); - - schema - .diagnostics - .to_result() - .map_err(|errors| ApiError::conversion(errors, schema.db.source()))?; - - let query_schema = query_core::schema::build(Arc::new(schema), true); - let dmmf = dmmf::render_dmmf(&query_schema); - - Ok(serde_json::to_string(&dmmf)?) -} - #[wasm_bindgen] pub fn debug_panic(panic_message: Option) -> Result<(), wasm_bindgen::JsError> { let user_facing = user_facing_errors::Error::from_panic_payload(Box::new( diff --git a/query-engine/request-handlers/Cargo.toml b/query-engine/request-handlers/Cargo.toml index 51ed4bd8b5ad..3686f14154af 100644 --- a/query-engine/request-handlers/Cargo.toml +++ b/query-engine/request-handlers/Cargo.toml @@ -20,7 +20,7 @@ bigdecimal = "0.3" thiserror = "1" tracing = "0.1" url = "2" -connection-string.workspace = true +connection-string.workspace = true once_cell = "1.15" mongodb-query-connector = { path = "../connectors/mongodb-query-connector", optional = true } @@ -32,11 +32,17 @@ schema = { path = "../schema" } codspeed-criterion-compat = "1.1.0" [features] -default = ["sql", "mongodb", "native"] +default = ["sql", "mongodb", "native", "graphql-protocol"] mongodb = ["mongodb-query-connector"] sql = ["sql-query-connector"] driver-adapters = ["sql-query-connector/driver-adapters"] -native = ["mongodb", "sql-query-connector", "quaint/native", "query-core/metrics"] +native = [ + "mongodb", + "sql-query-connector", + "quaint/native", + "query-core/metrics", +] +graphql-protocol = ["query-core/graphql-protocol"] [[bench]] name = "query_planning_bench" diff --git a/query-engine/request-handlers/src/lib.rs b/query-engine/request-handlers/src/lib.rs index 361e5c628bdf..949c26b302f3 100644 --- a/query-engine/request-handlers/src/lib.rs +++ b/query-engine/request-handlers/src/lib.rs @@ -12,7 +12,9 @@ mod response; pub use self::{error::HandlerError, load_executor::load as load_executor}; pub use connector_mode::ConnectorMode; pub use handler::*; -pub use protocols::{graphql::*, json::*, RequestBody}; +#[cfg(feature = "graphql-protocol")] +pub use protocols::graphql::*; +pub use protocols::{json::*, RequestBody}; pub use response::*; pub type Result = std::result::Result; diff --git a/query-engine/request-handlers/src/protocols/mod.rs b/query-engine/request-handlers/src/protocols/mod.rs index 
e2c50c2e7f1f..93bac460fecb 100644 --- a/query-engine/request-handlers/src/protocols/mod.rs +++ b/query-engine/request-handlers/src/protocols/mod.rs @@ -1,3 +1,4 @@ +#[cfg(feature = "graphql-protocol")] pub mod graphql; pub mod json; @@ -5,6 +6,7 @@ use query_core::{protocol::EngineProtocol, schema::QuerySchemaRef, QueryDocument #[derive(Debug)] pub enum RequestBody { + #[cfg(feature = "graphql-protocol")] Graphql(graphql::GraphqlBody), Json(json::JsonBody), } @@ -12,6 +14,7 @@ pub enum RequestBody { impl RequestBody { pub fn into_doc(self, query_schema: &QuerySchemaRef) -> crate::Result { match self { + #[cfg(feature = "graphql-protocol")] RequestBody::Graphql(body) => body.into_doc(), RequestBody::Json(body) => body.into_doc(query_schema), } @@ -19,6 +22,7 @@ impl RequestBody { pub fn try_from_str(val: &str, engine_protocol: EngineProtocol) -> Result { match engine_protocol { + #[cfg(feature = "graphql-protocol")] EngineProtocol::Graphql => serde_json::from_str::(val).map(Self::from), EngineProtocol::Json => serde_json::from_str::(val).map(Self::from), } @@ -26,12 +30,14 @@ impl RequestBody { pub fn try_from_slice(val: &[u8], engine_protocol: EngineProtocol) -> Result { match engine_protocol { + #[cfg(feature = "graphql-protocol")] EngineProtocol::Graphql => serde_json::from_slice::(val).map(Self::from), EngineProtocol::Json => serde_json::from_slice::(val).map(Self::from), } } } +#[cfg(feature = "graphql-protocol")] impl From for RequestBody { fn from(body: graphql::GraphqlBody) -> Self { Self::Graphql(body) From c592c88b7fe67cee5bd9d53114121047513ab866 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Thu, 30 Nov 2023 11:40:20 +0100 Subject: [PATCH 100/134] feat(driver-adapters): serialize empty values as "null" rather than "undefined" --- .../driver-adapters/src/wasm/async_js_function.rs | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/query-engine/driver-adapters/src/wasm/async_js_function.rs b/query-engine/driver-adapters/src/wasm/async_js_function.rs index 289de651ff64..8a176976886f 100644 --- a/query-engine/driver-adapters/src/wasm/async_js_function.rs +++ b/query-engine/driver-adapters/src/wasm/async_js_function.rs @@ -1,5 +1,6 @@ use js_sys::{Function as JsFunction, Promise as JsPromise}; use serde::Serialize; +use serde_wasm_bindgen::Serializer; use std::marker::PhantomData; use wasm_bindgen::convert::FromWasmAbi; use wasm_bindgen::describe::WasmDescribe; @@ -10,6 +11,11 @@ use super::error::into_quaint_error; use super::from_js::FromJsValue; use super::result::JsResult; +// `serialize_missing_as_null` is required to make sure that "empty" values (e.g., `None` and `()`) +// are serialized as `null` and not `undefined`. +// This is due to certain drivers (e.g., LibSQL) not supporting `undefined` values. 
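+// A rough sketch of the difference, assuming serde-wasm-bindgen 0.5: the default
+// `to_value` path maps `None`/`()` to `undefined`, while the serializer configured
+// below maps them to `null`:
+//
+//     let plain = serde_wasm_bindgen::to_value(&None::<i32>)?;          // JsValue::undefined()
+//     let null_ser = Serializer::new().serialize_missing_as_null(true);
+//     let patched = None::<i32>.serialize(&null_ser)?;                  // JsValue::null()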
+static SERIALIZER: Serializer = Serializer::new().serialize_missing_as_null(true); + #[derive(Clone)] pub(crate) struct AsyncJsFunction where @@ -61,7 +67,9 @@ where } async fn call_internal(&self, arg1: T) -> Result, JsValue> { - let arg1 = serde_wasm_bindgen::to_value(&arg1).map_err(|err| JsValue::from(JsError::from(&err)))?; + let arg1 = arg1 + .serialize(&SERIALIZER) + .map_err(|err| JsValue::from(JsError::from(&err)))?; let return_value = self.threadsafe_fn.call1(&JsValue::null(), &arg1)?; let value = if let Some(promise) = return_value.dyn_ref::() { From 8b920f586df43c992d230589240ffe190471817e Mon Sep 17 00:00:00 2001 From: jkomyno Date: Thu, 30 Nov 2023 12:08:09 +0100 Subject: [PATCH 101/134] chore: fixed query-engine-node-api build --- query-engine/query-engine-node-api/src/functions.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/query-engine/query-engine-node-api/src/functions.rs b/query-engine/query-engine-node-api/src/functions.rs index bcb64e240cae..5178d82d6120 100644 --- a/query-engine/query-engine-node-api/src/functions.rs +++ b/query-engine/query-engine-node-api/src/functions.rs @@ -2,6 +2,8 @@ use napi_derive::napi; use request_handlers::dmmf; use std::sync::Arc; +use crate::error::ApiError; + #[derive(serde::Serialize, Clone, Copy)] #[napi(object)] pub struct Version { From a34f74e9731813de6aeefc9ead312d28d51b68e1 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Thu, 30 Nov 2023 12:10:14 +0100 Subject: [PATCH 102/134] chore: bumped wasm-bindgen version to 0.2.89 --- Cargo.lock | 22 +++++++++++----------- Cargo.toml | 2 +- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9f7ca06607f0..6b0d781666ab 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5812,7 +5812,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "97fee6b57c6a41524a810daee9286c02d7752c4253064d0b05472833a438f675" dependencies = [ "cfg-if", - "rand 0.3.23", + "rand 0.8.5", "static_assertions", ] @@ -6052,9 +6052,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.88" +version = "0.2.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7daec296f25a1bae309c0cd5c29c4b260e510e6d813c286b19eaadf409d40fce" +checksum = "0ed0d4f68a3015cc185aff4db9506a015f4b96f95303897bfa23f846db54064e" dependencies = [ "cfg-if", "wasm-bindgen-macro", @@ -6062,9 +6062,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.88" +version = "0.2.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e397f4664c0e4e428e8313a469aaa58310d302159845980fd23b0f22a847f217" +checksum = "1b56f625e64f3a1084ded111c4d5f477df9f8c92df113852fa5a374dbda78826" dependencies = [ "bumpalo", "log", @@ -6089,9 +6089,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.88" +version = "0.2.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5961017b3b08ad5f3fe39f1e79877f8ee7c23c5e5fd5eb80de95abc41f1f16b2" +checksum = "0162dbf37223cd2afce98f3d0785506dcb8d266223983e4b5b525859e6e182b2" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -6099,9 +6099,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.88" +version = "0.2.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c5353b8dab669f5e10f5bd76df26a9360c748f054f862ff5f3f8aae0c7fb3907" +checksum = "f0eb82fcb7930ae6219a7ecfd55b217f5f0893484b7a13022ebb2b2bf20b5283" 
dependencies = [ "proc-macro2", "quote", @@ -6112,9 +6112,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.88" +version = "0.2.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d046c5d029ba91a1ed14da14dca44b68bf2f124cfbaf741c54151fdb3e0750b" +checksum = "7ab9b36309365056cd639da3134bf87fa8f3d86008abf99e612384a6eecd459f" [[package]] name = "wasm-bindgen-test" diff --git a/Cargo.toml b/Cargo.toml index 90504a422332..444a48f012f2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -61,7 +61,7 @@ js-sys = { version = "0.3" } serde_repr = { version = "0.1.17" } serde-wasm-bindgen = { version = "0.5" } tsify = { version = "0.4.5" } -wasm-bindgen = { version = "0.2.88" } +wasm-bindgen = { version = "0.2.89" } wasm-bindgen-futures = { version = "0.4" } wasm-rs-dbg = { version = "0.1.2" } wasm-bindgen-test = { version = "0.3.0" } From c87bbb2f10ef55754f2bc9b12b3f8eae30fed10f Mon Sep 17 00:00:00 2001 From: jkomyno Date: Thu, 30 Nov 2023 12:32:10 +0100 Subject: [PATCH 103/134] chore: clean up quaint transitive dependencies --- Cargo.lock | 3 --- quaint/src/lib.rs | 3 ++- query-engine/driver-adapters/Cargo.toml | 5 ----- .../driver-adapters/src/conversion/js_to_quaint.rs | 9 +++------ query-engine/driver-adapters/src/conversion/postgres.rs | 2 +- 5 files changed, 6 insertions(+), 16 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6b0d781666ab..0ae9b7b97da7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1107,15 +1107,12 @@ name = "driver-adapters" version = "0.1.0" dependencies = [ "async-trait", - "bigdecimal", - "chrono", "expect-test", "futures", "js-sys", "metrics 0.18.1", "napi", "napi-derive", - "num-bigint", "once_cell", "pin-project", "quaint", diff --git a/quaint/src/lib.rs b/quaint/src/lib.rs index 1458a6ae1615..45c2a10a1698 100644 --- a/quaint/src/lib.rs +++ b/quaint/src/lib.rs @@ -113,7 +113,8 @@ mod macros; #[macro_use] extern crate metrics; -extern crate bigdecimal; +pub extern crate bigdecimal; +pub extern crate chrono; pub mod ast; pub mod connector; diff --git a/query-engine/driver-adapters/Cargo.toml b/query-engine/driver-adapters/Cargo.toml index befd68125880..ef2de701312d 100644 --- a/query-engine/driver-adapters/Cargo.toml +++ b/query-engine/driver-adapters/Cargo.toml @@ -16,11 +16,6 @@ pin-project = "1" wasm-rs-dbg = "0.1.2" serde_repr.workspace = true -# Note: these deps are temporarily specified here to avoid importing them from tiberius (the SQL server driver). -# They will be imported from quaint-core instead in a future PR. -num-bigint = "0.4.3" -bigdecimal = "0.3.0" -chrono = "0.4.20" futures = "0.3" [dev-dependencies] diff --git a/query-engine/driver-adapters/src/conversion/js_to_quaint.rs b/query-engine/driver-adapters/src/conversion/js_to_quaint.rs index f0b7de772c5e..02c06e122c36 100644 --- a/query-engine/driver-adapters/src/conversion/js_to_quaint.rs +++ b/query-engine/driver-adapters/src/conversion/js_to_quaint.rs @@ -2,17 +2,14 @@ use std::borrow::Cow; use std::str::FromStr; pub use crate::types::{ColumnType, JSResultSet, Query, TransactionOptions}; +use quaint::bigdecimal::{BigDecimal, FromPrimitive}; +use quaint::chrono::{DateTime, NaiveDate, NaiveTime, Utc}; use quaint::{ connector::ResultSet as QuaintResultSet, error::{Error as QuaintError, ErrorKind}, Value as QuaintValue, }; -// TODO(jkomyno): import these 3rd-party crates from the `quaint-core` crate. 
-use bigdecimal::{BigDecimal, FromPrimitive}; -use chrono::{DateTime, Utc}; -use chrono::{NaiveDate, NaiveTime}; - impl TryFrom for QuaintResultSet { type Error = quaint::error::Error; @@ -210,7 +207,7 @@ pub fn js_value_to_quaint( }, ColumnType::DateTime => match json_value { // TODO: change parsing order to prefer RFC3339 - serde_json::Value::String(s) => chrono::NaiveDateTime::parse_from_str(&s, "%Y-%m-%d %H:%M:%S%.f") + serde_json::Value::String(s) => quaint::chrono::NaiveDateTime::parse_from_str(&s, "%Y-%m-%d %H:%M:%S%.f") .map(|dt| DateTime::from_utc(dt, Utc)) .or_else(|_| DateTime::parse_from_rfc3339(&s).map(DateTime::::from)) .map(QuaintValue::datetime) diff --git a/query-engine/driver-adapters/src/conversion/postgres.rs b/query-engine/driver-adapters/src/conversion/postgres.rs index 113be5170a84..aaad219e9349 100644 --- a/query-engine/driver-adapters/src/conversion/postgres.rs +++ b/query-engine/driver-adapters/src/conversion/postgres.rs @@ -1,6 +1,6 @@ use crate::conversion::JSArg; -use chrono::format::StrftimeItems; use once_cell::sync::Lazy; +use quaint::chrono::format::StrftimeItems; use serde_json::value::Value as JsonValue; static TIME_FMT: Lazy = Lazy::new(|| StrftimeItems::new("%H:%M:%S%.f")); From 60b0d8c13f3bfd3df1492a647f6cd193defb23d2 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Thu, 30 Nov 2023 12:34:14 +0100 Subject: [PATCH 104/134] chore: removed wasm.rs test --- Cargo.lock | 31 --- query-engine/driver-adapters/Cargo.toml | 1 - query-engine/driver-adapters/tests/wasm.rs | 275 --------------------- 3 files changed, 307 deletions(-) delete mode 100644 query-engine/driver-adapters/tests/wasm.rs diff --git a/Cargo.lock b/Cargo.lock index 0ae9b7b97da7..2a0ec128a8c9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1127,7 +1127,6 @@ dependencies = [ "uuid", "wasm-bindgen", "wasm-bindgen-futures", - "wasm-bindgen-test", "wasm-rs-dbg", ] @@ -4534,12 +4533,6 @@ dependencies = [ "user-facing-errors", ] -[[package]] -name = "scoped-tls" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1cf6437eb19a8f4a6cc0f7dca544973b0b78843adbfeb3683d1a94a0024a294" - [[package]] name = "scopeguard" version = "1.2.0" @@ -6113,30 +6106,6 @@ version = "0.2.89" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7ab9b36309365056cd639da3134bf87fa8f3d86008abf99e612384a6eecd459f" -[[package]] -name = "wasm-bindgen-test" -version = "0.3.34" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6db36fc0f9fb209e88fb3642590ae0205bb5a56216dabd963ba15879fe53a30b" -dependencies = [ - "console_error_panic_hook", - "js-sys", - "scoped-tls", - "wasm-bindgen", - "wasm-bindgen-futures", - "wasm-bindgen-test-macro", -] - -[[package]] -name = "wasm-bindgen-test-macro" -version = "0.3.34" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0734759ae6b3b1717d661fe4f016efcfb9828f5edb4520c18eaee05af3b43be9" -dependencies = [ - "proc-macro2", - "quote", -] - [[package]] name = "wasm-logger" version = "0.2.0" diff --git a/query-engine/driver-adapters/Cargo.toml b/query-engine/driver-adapters/Cargo.toml index ef2de701312d..24f69164f138 100644 --- a/query-engine/driver-adapters/Cargo.toml +++ b/query-engine/driver-adapters/Cargo.toml @@ -22,7 +22,6 @@ futures = "0.3" expect-test = "1" tokio = { version = "1.0", features = ["macros", "time", "sync"] } wasm-rs-dbg.workspace = true -wasm-bindgen-test.workspace = true [target.'cfg(not(target_arch = "wasm32"))'.dependencies] napi.workspace = true diff --git 
a/query-engine/driver-adapters/tests/wasm.rs b/query-engine/driver-adapters/tests/wasm.rs deleted file mode 100644 index 8f3aa30f7335..000000000000 --- a/query-engine/driver-adapters/tests/wasm.rs +++ /dev/null @@ -1,275 +0,0 @@ -#![cfg(target_os = "wasm32")] -use wasm_bindgen_test::*; - -use serde::{Deserialize, Serialize}; -use serde_repr::{Deserialize_repr, Serialize_repr}; -use tsify::Tsify; -use wasm_bindgen::prelude::*; - -// Recursive expansion of Deserialize macro -// ========================================= -// -// #[doc(hidden)] -// #[allow(non_upper_case_globals, unused_attributes, unused_qualifications)] -// const _: () = { -// #[allow(unused_extern_crates, clippy::useless_attribute)] -// extern crate serde as _serde; -// #[automatically_derived] -// impl<'de> _serde::Deserialize<'de> for ColumnType { -// fn deserialize<__D>(__deserializer: __D) -> _serde::__private::Result -// where -// __D: _serde::Deserializer<'de>, -// { -// #[allow(non_camel_case_types)] -// #[doc(hidden)] -// enum __Field { -// __field0, -// __field1, -// } -// #[doc(hidden)] -// struct __FieldVisitor; - -// impl<'de> _serde::de::Visitor<'de> for __FieldVisitor { -// type Value = __Field; -// fn expecting(&self, __formatter: &mut _serde::__private::Formatter) -> _serde::__private::fmt::Result { -// _serde::__private::Formatter::write_str(__formatter, "variant identifier") -// } -// fn visit_u64<__E>(self, __value: u64) -> _serde::__private::Result -// where -// __E: _serde::de::Error, -// { -// match __value { -// 0u64 => _serde::__private::Ok(__Field::__field0), -// 1u64 => _serde::__private::Ok(__Field::__field1), -// _ => _serde::__private::Err(_serde::de::Error::invalid_value( -// _serde::de::Unexpected::Unsigned(__value), -// &"variant index 0 <= i < 2", -// )), -// } -// } -// fn visit_str<__E>(self, __value: &str) -> _serde::__private::Result -// where -// __E: _serde::de::Error, -// { -// match __value { -// "Int32" => _serde::__private::Ok(__Field::__field0), -// "Int64" => _serde::__private::Ok(__Field::__field1), -// _ => _serde::__private::Err(_serde::de::Error::unknown_variant(__value, VARIANTS)), -// } -// } -// fn visit_bytes<__E>(self, __value: &[u8]) -> _serde::__private::Result -// where -// __E: _serde::de::Error, -// { -// match __value { -// b"Int32" => _serde::__private::Ok(__Field::__field0), -// b"Int64" => _serde::__private::Ok(__Field::__field1), -// _ => { -// let __value = &_serde::__private::from_utf8_lossy(__value); -// _serde::__private::Err(_serde::de::Error::unknown_variant(__value, VARIANTS)) -// } -// } -// } -// } -// impl<'de> _serde::Deserialize<'de> for __Field { -// #[inline] -// fn deserialize<__D>(__deserializer: __D) -> _serde::__private::Result -// where -// __D: _serde::Deserializer<'de>, -// { -// _serde::Deserializer::deserialize_identifier(__deserializer, __FieldVisitor) -// } -// } -// #[doc(hidden)] -// struct __Visitor<'de> { -// marker: _serde::__private::PhantomData, -// lifetime: _serde::__private::PhantomData<&'de ()>, -// } -// impl<'de> _serde::de::Visitor<'de> for __Visitor<'de> { -// type Value = ColumnType; -// fn expecting(&self, __formatter: &mut _serde::__private::Formatter) -> _serde::__private::fmt::Result { -// _serde::__private::Formatter::write_str(__formatter, "enum ColumnType") -// } -// fn visit_enum<__A>(self, __data: __A) -> _serde::__private::Result -// where -// __A: _serde::de::EnumAccess<'de>, -// { -// match _serde::de::EnumAccess::variant(__data)? 
{ -// (__Field::__field0, __variant) => { -// _serde::de::VariantAccess::unit_variant(__variant)?; -// _serde::__private::Ok(ColumnType::Int32) -// } -// (__Field::__field1, __variant) => { -// _serde::de::VariantAccess::unit_variant(__variant)?; -// _serde::__private::Ok(ColumnType::Int64) -// } -// } -// } -// } -// #[doc(hidden)] -// const VARIANTS: &'static [&'static str] = &["Int32", "Int64"]; -// _serde::Deserializer::deserialize_enum( -// __deserializer, -// "ColumnType", -// VARIANTS, -// __Visitor { -// marker: _serde::__private::PhantomData::, -// lifetime: _serde::__private::PhantomData, -// }, -// ) -// } -// } -// }; -// -// -// Recursive expansion of Tsify macro -// =================================== -// -// #[automatically_derived] -// const _: () = { -// extern crate serde as _serde; -// use tsify::Tsify; -// use wasm_bindgen::{ -// convert::{FromWasmAbi, IntoWasmAbi, OptionFromWasmAbi, OptionIntoWasmAbi}, -// describe::WasmDescribe, -// prelude::*, -// }; -// #[wasm_bindgen] -// extern "C" { -// #[wasm_bindgen(typescript_type = "ColumnType")] -// pub type JsType; -// } -// impl Tsify for ColumnType { -// type JsType = JsType; -// const DECL: &'static str = "export type ColumnType = \"Int32\" | \"Int64\";"; -// } -// #[wasm_bindgen(typescript_custom_section)] -// const TS_APPEND_CONTENT: &'static str = "export type ColumnType = \"Int32\" | \"Int64\";"; -// impl WasmDescribe for ColumnType { -// #[inline] -// fn describe() { -// ::JsType::describe() -// } -// } -// impl IntoWasmAbi for ColumnType -// where -// Self: _serde::Serialize, -// { -// type Abi = ::Abi; -// #[inline] -// fn into_abi(self) -> Self::Abi { -// self.into_js().unwrap_throw().into_abi() -// } -// } -// impl OptionIntoWasmAbi for ColumnType -// where -// Self: _serde::Serialize, -// { -// #[inline] -// fn none() -> Self::Abi { -// ::none() -// } -// } -// impl FromWasmAbi for ColumnType -// where -// Self: _serde::de::DeserializeOwned, -// { -// type Abi = ::Abi; -// #[inline] -// unsafe fn from_abi(js: Self::Abi) -> Self { -// let result = Self::from_js(&JsType::from_abi(js)); -// if let Err(err) = result { -// wasm_bindgen::throw_str(err.to_string().as_ref()); -// } -// result.unwrap_throw() -// } -// } -// impl OptionFromWasmAbi for ColumnType -// where -// Self: _serde::de::DeserializeOwned, -// { -// #[inline] -// fn is_none(js: &Self::Abi) -> bool { -// ::is_none(js) -// } -// } -// }; -#[derive(Clone, Copy, Debug, Deserialize, Tsify)] -#[tsify(from_wasm_abi)] -pub enum ColumnType { - Int32 = 0, - Int64 = 1, -} - -#[derive(Debug, Deserialize, Tsify)] -#[tsify(from_wasm_abi)] -#[serde(rename_all = "camelCase")] -struct ColumnTypeWrapper { - column_type: ColumnType, -} - -// Recursive expansion of Deserialize_repr macro -// ============================================== -// -// impl<'de> serde::Deserialize<'de> for ColumnTypeWasmBindgen { -// #[allow(clippy::use_self)] -// fn deserialize(deserializer: D) -> ::core::result::Result -// where -// D: serde::Deserializer<'de>, -// { -// #[allow(non_camel_case_types)] -// struct discriminant; - -// #[allow(non_upper_case_globals)] -// impl discriminant { -// const Int32: u8 = ColumnTypeWasmBindgen::Int32 as u8; -// const Int64: u8 = ColumnTypeWasmBindgen::Int64 as u8; -// } -// match ::deserialize(deserializer)? 
{ -// discriminant::Int32 => ::core::result::Result::Ok(ColumnTypeWasmBindgen::Int32), -// discriminant::Int64 => ::core::result::Result::Ok(ColumnTypeWasmBindgen::Int64), -// other => ::core::result::Result::Err(serde::de::Error::custom(format_args!( -// "invalid value: {}, expected {} or {}", -// other, -// discriminant::Int32, -// discriminant::Int64 -// ))), -// } -// } -// } -#[derive(Debug, Deserialize_repr, Tsify)] -#[tsify(from_wasm_abi)] -#[repr(u8)] -pub enum ColumnTypeWasmBindgen { - // #[serde(rename = "0")] - Int32 = 0, - - // #[serde(rename = "1")] - Int64 = 1, -} - -#[wasm_bindgen_test] -fn column_type_test() { - // Example deserialization code - let json_data = r#"0"#; - let column_type = serde_json::from_str::(&json_data).unwrap(); - let column_type = serde_json::from_str::(&json_data).unwrap(); - let column_type = serde_json::from_str::(&json_data).unwrap(); - let column_type = serde_json::from_str::(&json_data).unwrap(); - let column_type = serde_json::from_str::(&json_data).unwrap(); - let column_type = serde_json::from_str::(&json_data).unwrap(); - let column_type = serde_json::from_str::(&json_data).unwrap(); - let column_type = serde_json::from_str::(&json_data).unwrap(); - - // let json_data = "\"0\""; - let column_type = serde_json::from_str::(&json_data).unwrap(); -} - -// #[wasm_bindgen_test] -// fn column_type_test() { -// // Example deserialization code -// let json_data = r#"{ "columnType": 0 }"#; -// let column_type_wrapper = serde_json::from_str::(json_data); - -// panic!("{:?}", column_type_wrapper); -// } From 48770da119af66a68cf8ca6180b6859c27dc4740 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Thu, 30 Nov 2023 12:37:16 +0100 Subject: [PATCH 105/134] chore: removed temporary wasm machinery --- Cargo.lock | 13 ------------- query-engine/query-engine-wasm/Cargo.toml | 3 --- query-engine/query-engine-wasm/src/lib.rs | 12 ------------ query-engine/query-engine-wasm/src/wasm/engine.rs | 1 - 4 files changed, 29 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 2a0ec128a8c9..93bf18b7be68 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -673,16 +673,6 @@ dependencies = [ "windows-sys 0.45.0", ] -[[package]] -name = "console_error_panic_hook" -version = "0.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a06aeb73f470f66dcdbf7223caeebb85984942f22f1adb2a088cf9668146bbbc" -dependencies = [ - "cfg-if", - "wasm-bindgen", -] - [[package]] name = "convert_case" version = "0.4.0" @@ -3821,11 +3811,9 @@ dependencies = [ "anyhow", "async-trait", "connection-string", - "console_error_panic_hook", "driver-adapters", "futures", "js-sys", - "log", "opentelemetry", "psl", "quaint", @@ -3848,7 +3836,6 @@ dependencies = [ "user-facing-errors", "wasm-bindgen", "wasm-bindgen-futures", - "wasm-logger", "wasm-rs-dbg", ] diff --git a/query-engine/query-engine-wasm/Cargo.toml b/query-engine/query-engine-wasm/Cargo.toml index e08d412d5f97..171610f2831a 100644 --- a/query-engine/query-engine-wasm/Cargo.toml +++ b/query-engine/query-engine-wasm/Cargo.toml @@ -40,12 +40,9 @@ url = "2" serde.workspace = true tokio = { version = "1.25", features = ["macros", "sync", "io-util", "time"] } futures = "0.3" -log = "0.4.6" -wasm-logger = "0.2.0" tracing = "0.1" tracing-subscriber = { version = "0.3" } tracing-futures = "0.2" tracing-opentelemetry = "0.17.3" opentelemetry = { version = "0.17"} -console_error_panic_hook = "0.1.7" diff --git a/query-engine/query-engine-wasm/src/lib.rs b/query-engine/query-engine-wasm/src/lib.rs index 2d2167f7c22e..bc22931513e8 100644 
--- a/query-engine/query-engine-wasm/src/lib.rs +++ b/query-engine/query-engine-wasm/src/lib.rs @@ -17,18 +17,6 @@ mod arch { pub use super::wasm::*; pub(crate) type Result = std::result::Result; - - use wasm_bindgen::prelude::wasm_bindgen; - - /// Function that should be called before any other public function in this module. - #[wasm_bindgen] - pub fn init() { - // Set up temporary logging for the wasm module. - wasm_logger::init(wasm_logger::Config::default()); - - // Set up temporary panic hook for the wasm module. - std::panic::set_hook(Box::new(console_error_panic_hook::hook)); - } } pub use arch::*; diff --git a/query-engine/query-engine-wasm/src/wasm/engine.rs b/query-engine/query-engine-wasm/src/wasm/engine.rs index a40dfbd4ff94..3413a8af8d76 100644 --- a/query-engine/query-engine-wasm/src/wasm/engine.rs +++ b/query-engine/query-engine-wasm/src/wasm/engine.rs @@ -405,7 +405,6 @@ impl QueryEngine { #[wasm_bindgen] pub async fn metrics(&self, json_options: String) -> Result<(), wasm_bindgen::JsError> { - log::info!("Called `QueryEngine::metrics()`"); Err(ApiError::configuration("Metrics is not enabled in Wasm.").into()) } } From 57ce97710bc2495234ad514318acc2008ab1ba89 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Thu, 30 Nov 2023 12:39:06 +0100 Subject: [PATCH 106/134] chore: fix clippy --- libs/crosstarget-utils/src/wasm/time.rs | 4 ++-- query-engine/core/src/executor/execute_operation.rs | 1 + query-engine/query-engine-wasm/src/wasm/logger.rs | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/libs/crosstarget-utils/src/wasm/time.rs b/libs/crosstarget-utils/src/wasm/time.rs index 6f14ac001ee8..1a8833b80134 100644 --- a/libs/crosstarget-utils/src/wasm/time.rs +++ b/libs/crosstarget-utils/src/wasm/time.rs @@ -35,7 +35,7 @@ impl ElapsedTimeCounter { } } -pub async fn sleep(duration: Duration) -> () { +pub async fn sleep(duration: Duration) { JsFuture::from(Promise::new(&mut |resolve, _reject| { set_timeout(&resolve, duration.as_millis() as u32); })) @@ -55,5 +55,5 @@ where } fn now() -> f64 { - PERFORMANCE.as_ref().map(|p| p.now()).unwrap_or_else(|| Date::now()) + PERFORMANCE.as_ref().map(|p| p.now()).unwrap_or_else(Date::now) } diff --git a/query-engine/core/src/executor/execute_operation.rs b/query-engine/core/src/executor/execute_operation.rs index 63555187fb7b..dabe071cd688 100644 --- a/query-engine/core/src/executor/execute_operation.rs +++ b/query-engine/core/src/executor/execute_operation.rs @@ -1,4 +1,5 @@ #![cfg_attr(target_arch = "wasm32", allow(unused_variables))] +#![cfg_attr(not(feature = "metrics"), allow(clippy::let_and_return))] use super::pipeline::QueryPipeline; use crate::{ diff --git a/query-engine/query-engine-wasm/src/wasm/logger.rs b/query-engine/query-engine-wasm/src/wasm/logger.rs index a4d03a83e82d..c0ccbf7f2a3e 100644 --- a/query-engine/query-engine-wasm/src/wasm/logger.rs +++ b/query-engine/query-engine-wasm/src/wasm/logger.rs @@ -146,6 +146,6 @@ impl Layer for CallbackLayer { let mut visitor = JsonVisitor::new(event.metadata().level(), event.metadata().target()); event.record(&mut visitor); - let _ = self.callback.call(&visitor.to_string()); + let _ = self.callback.call(visitor.to_string()); } } From 4d92a88290d1adf62efb551c5a850a2291bfb97b Mon Sep 17 00:00:00 2001 From: jkomyno Date: Thu, 30 Nov 2023 12:40:08 +0100 Subject: [PATCH 107/134] chore: remove unwrap from "sleep" --- libs/crosstarget-utils/src/wasm/time.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/libs/crosstarget-utils/src/wasm/time.rs 
b/libs/crosstarget-utils/src/wasm/time.rs index 1a8833b80134..1c230ba1eecc 100644 --- a/libs/crosstarget-utils/src/wasm/time.rs +++ b/libs/crosstarget-utils/src/wasm/time.rs @@ -36,12 +36,10 @@ impl ElapsedTimeCounter { } pub async fn sleep(duration: Duration) { - JsFuture::from(Promise::new(&mut |resolve, _reject| { + let _ = JsFuture::from(Promise::new(&mut |resolve, _reject| { set_timeout(&resolve, duration.as_millis() as u32); })) - .await - // TODO: - .unwrap(); + .await; } pub async fn timeout(duration: Duration, future: F) -> Result From ff7046af7279611c106a93ccf81718de8adc82c8 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Thu, 30 Nov 2023 12:43:14 +0100 Subject: [PATCH 108/134] chore: revert unnecessary psl change --- psl/psl-core/src/datamodel_connector.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/psl/psl-core/src/datamodel_connector.rs b/psl/psl-core/src/datamodel_connector.rs index dc3a7e80bd10..72671e06688f 100644 --- a/psl/psl-core/src/datamodel_connector.rs +++ b/psl/psl-core/src/datamodel_connector.rs @@ -361,6 +361,7 @@ pub trait Connector: Send + Sync { } } +#[derive(Copy, Clone, Debug, PartialEq)] pub enum Flavour { Cockroach, Mongo, From 0cf9b4a0ddb23e6c0f58d2383a8737b3662d6b33 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Thu, 30 Nov 2023 12:54:04 +0100 Subject: [PATCH 109/134] chore(driver-adapters): fix unit tests --- query-engine/driver-adapters/src/conversion/js_to_quaint.rs | 2 +- query-engine/driver-adapters/src/conversion/mysql.rs | 4 ++-- query-engine/driver-adapters/src/conversion/postgres.rs | 4 ++-- query-engine/driver-adapters/src/conversion/sqlite.rs | 4 ++-- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/query-engine/driver-adapters/src/conversion/js_to_quaint.rs b/query-engine/driver-adapters/src/conversion/js_to_quaint.rs index 02c06e122c36..7d11b13b7303 100644 --- a/query-engine/driver-adapters/src/conversion/js_to_quaint.rs +++ b/query-engine/driver-adapters/src/conversion/js_to_quaint.rs @@ -326,7 +326,7 @@ fn f64_to_f32(x: f64) -> quaint::Result { #[cfg(test)] mod proxy_test { - use num_bigint::BigInt; + use quaint::bigdecimal::num_bigint::BigInt; use serde_json::json; use super::*; diff --git a/query-engine/driver-adapters/src/conversion/mysql.rs b/query-engine/driver-adapters/src/conversion/mysql.rs index 114d7e3dfcfe..0f4c4bd8eec8 100644 --- a/query-engine/driver-adapters/src/conversion/mysql.rs +++ b/query-engine/driver-adapters/src/conversion/mysql.rs @@ -28,8 +28,8 @@ pub fn value_to_js_arg(value: &quaint::Value) -> serde_json::Result { #[cfg(test)] mod test { use super::*; - use bigdecimal::BigDecimal; - use chrono::*; + use quaint::bigdecimal::BigDecimal; + use quaint::chrono::*; use quaint::ValueType; use std::str::FromStr; diff --git a/query-engine/driver-adapters/src/conversion/postgres.rs b/query-engine/driver-adapters/src/conversion/postgres.rs index aaad219e9349..14e143e2ca8b 100644 --- a/query-engine/driver-adapters/src/conversion/postgres.rs +++ b/query-engine/driver-adapters/src/conversion/postgres.rs @@ -30,8 +30,8 @@ pub fn value_to_js_arg(value: &quaint::Value) -> serde_json::Result { #[cfg(test)] mod test { use super::*; - use bigdecimal::BigDecimal; - use chrono::*; + use quaint::bigdecimal::BigDecimal; + use quaint::chrono::*; use quaint::ValueType; use std::str::FromStr; diff --git a/query-engine/driver-adapters/src/conversion/sqlite.rs b/query-engine/driver-adapters/src/conversion/sqlite.rs index 032c16923256..785930fb9c30 100644 --- a/query-engine/driver-adapters/src/conversion/sqlite.rs +++ 
b/query-engine/driver-adapters/src/conversion/sqlite.rs @@ -25,8 +25,8 @@ pub fn value_to_js_arg(value: &quaint::Value) -> serde_json::Result { #[cfg(test)] mod test { use super::*; - use bigdecimal::BigDecimal; - use chrono::*; + use quaint::bigdecimal::BigDecimal; + use quaint::chrono::*; use quaint::ValueType; use serde_json::Value; use std::str::FromStr; From dabe0e2f7ec5560108f6aad4cf5eea4bd4da07da Mon Sep 17 00:00:00 2001 From: jkomyno Date: Thu, 30 Nov 2023 14:19:38 +0100 Subject: [PATCH 110/134] chore(driver-adapters): add clippy check for wasm32 --- .github/workflows/formatting.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/formatting.yml b/.github/workflows/formatting.yml index 50b635544b91..cc4f6192ca78 100644 --- a/.github/workflows/formatting.yml +++ b/.github/workflows/formatting.yml @@ -27,7 +27,9 @@ jobs: - uses: dtolnay/rust-toolchain@stable with: components: clippy - - run: cargo clippy --all-features + - run: | + cargo clippy --all-features + cargo clippy --all-features -p query-engine-wasm --target wasm32-unknown-unknown format: runs-on: ubuntu-latest From adf671d5c20bc2776be7f37185f792f226a4beaf Mon Sep 17 00:00:00 2001 From: jkomyno Date: Thu, 30 Nov 2023 14:28:14 +0100 Subject: [PATCH 111/134] chore(driver-adapters): add clippy check for wasm32 --- .github/workflows/formatting.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/formatting.yml b/.github/workflows/formatting.yml index cc4f6192ca78..02955837fe53 100644 --- a/.github/workflows/formatting.yml +++ b/.github/workflows/formatting.yml @@ -27,6 +27,7 @@ jobs: - uses: dtolnay/rust-toolchain@stable with: components: clippy + targets: wasm32-unknown-unknown - run: | cargo clippy --all-features cargo clippy --all-features -p query-engine-wasm --target wasm32-unknown-unknown From 6c710ad3b7a83ab8db3bd29fd2ccac6edeb986c3 Mon Sep 17 00:00:00 2001 From: Alberto Schiabel Date: Thu, 30 Nov 2023 14:30:29 +0100 Subject: [PATCH 112/134] chore: fix clippy --- .github/workflows/formatting.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/formatting.yml b/.github/workflows/formatting.yml index cc4f6192ca78..02955837fe53 100644 --- a/.github/workflows/formatting.yml +++ b/.github/workflows/formatting.yml @@ -27,6 +27,7 @@ jobs: - uses: dtolnay/rust-toolchain@stable with: components: clippy + targets: wasm32-unknown-unknown - run: | cargo clippy --all-features cargo clippy --all-features -p query-engine-wasm --target wasm32-unknown-unknown From 481ba654658a00adc23e21d9dfc57875b899f4f2 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Thu, 30 Nov 2023 14:41:49 +0100 Subject: [PATCH 113/134] Revert "chore: removed wasm.rs test" This reverts commit 60b0d8c13f3bfd3df1492a647f6cd193defb23d2. 
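The restored test below exercises how a `#[repr(u8)]` enum such as `ColumnType` is rebuilt from a bare JSON discriminant. A minimal sketch of that mechanism, assuming only serde_repr 0.1 and serde_json (the enum name mirrors the test; everything else here is illustrative):

    use serde_repr::Deserialize_repr;

    // `Deserialize_repr` matches the numeric discriminant (0/1) rather than the variant
    // name, which is the behaviour the commented-out macro expansions in the test
    // contrast against the stock `Deserialize` derive.
    #[derive(Debug, PartialEq, Deserialize_repr)]
    #[repr(u8)]
    enum ColumnType {
        Int32 = 0,
        Int64 = 1,
    }

    fn main() {
        let parsed: ColumnType = serde_json::from_str("0").unwrap();
        assert_eq!(parsed, ColumnType::Int32);
    }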
--- Cargo.lock | 31 +++ query-engine/driver-adapters/Cargo.toml | 1 + query-engine/driver-adapters/tests/wasm.rs | 275 +++++++++++++++++++++ 3 files changed, 307 insertions(+) create mode 100644 query-engine/driver-adapters/tests/wasm.rs diff --git a/Cargo.lock b/Cargo.lock index 93bf18b7be68..f31ab1efcf99 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1117,6 +1117,7 @@ dependencies = [ "uuid", "wasm-bindgen", "wasm-bindgen-futures", + "wasm-bindgen-test", "wasm-rs-dbg", ] @@ -4520,6 +4521,12 @@ dependencies = [ "user-facing-errors", ] +[[package]] +name = "scoped-tls" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1cf6437eb19a8f4a6cc0f7dca544973b0b78843adbfeb3683d1a94a0024a294" + [[package]] name = "scopeguard" version = "1.2.0" @@ -6093,6 +6100,30 @@ version = "0.2.89" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7ab9b36309365056cd639da3134bf87fa8f3d86008abf99e612384a6eecd459f" +[[package]] +name = "wasm-bindgen-test" +version = "0.3.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6db36fc0f9fb209e88fb3642590ae0205bb5a56216dabd963ba15879fe53a30b" +dependencies = [ + "console_error_panic_hook", + "js-sys", + "scoped-tls", + "wasm-bindgen", + "wasm-bindgen-futures", + "wasm-bindgen-test-macro", +] + +[[package]] +name = "wasm-bindgen-test-macro" +version = "0.3.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0734759ae6b3b1717d661fe4f016efcfb9828f5edb4520c18eaee05af3b43be9" +dependencies = [ + "proc-macro2", + "quote", +] + [[package]] name = "wasm-logger" version = "0.2.0" diff --git a/query-engine/driver-adapters/Cargo.toml b/query-engine/driver-adapters/Cargo.toml index 24f69164f138..ef2de701312d 100644 --- a/query-engine/driver-adapters/Cargo.toml +++ b/query-engine/driver-adapters/Cargo.toml @@ -22,6 +22,7 @@ futures = "0.3" expect-test = "1" tokio = { version = "1.0", features = ["macros", "time", "sync"] } wasm-rs-dbg.workspace = true +wasm-bindgen-test.workspace = true [target.'cfg(not(target_arch = "wasm32"))'.dependencies] napi.workspace = true diff --git a/query-engine/driver-adapters/tests/wasm.rs b/query-engine/driver-adapters/tests/wasm.rs new file mode 100644 index 000000000000..8f3aa30f7335 --- /dev/null +++ b/query-engine/driver-adapters/tests/wasm.rs @@ -0,0 +1,275 @@ +#![cfg(target_os = "wasm32")] +use wasm_bindgen_test::*; + +use serde::{Deserialize, Serialize}; +use serde_repr::{Deserialize_repr, Serialize_repr}; +use tsify::Tsify; +use wasm_bindgen::prelude::*; + +// Recursive expansion of Deserialize macro +// ========================================= +// +// #[doc(hidden)] +// #[allow(non_upper_case_globals, unused_attributes, unused_qualifications)] +// const _: () = { +// #[allow(unused_extern_crates, clippy::useless_attribute)] +// extern crate serde as _serde; +// #[automatically_derived] +// impl<'de> _serde::Deserialize<'de> for ColumnType { +// fn deserialize<__D>(__deserializer: __D) -> _serde::__private::Result +// where +// __D: _serde::Deserializer<'de>, +// { +// #[allow(non_camel_case_types)] +// #[doc(hidden)] +// enum __Field { +// __field0, +// __field1, +// } +// #[doc(hidden)] +// struct __FieldVisitor; + +// impl<'de> _serde::de::Visitor<'de> for __FieldVisitor { +// type Value = __Field; +// fn expecting(&self, __formatter: &mut _serde::__private::Formatter) -> _serde::__private::fmt::Result { +// _serde::__private::Formatter::write_str(__formatter, "variant identifier") +// } +// fn 
visit_u64<__E>(self, __value: u64) -> _serde::__private::Result +// where +// __E: _serde::de::Error, +// { +// match __value { +// 0u64 => _serde::__private::Ok(__Field::__field0), +// 1u64 => _serde::__private::Ok(__Field::__field1), +// _ => _serde::__private::Err(_serde::de::Error::invalid_value( +// _serde::de::Unexpected::Unsigned(__value), +// &"variant index 0 <= i < 2", +// )), +// } +// } +// fn visit_str<__E>(self, __value: &str) -> _serde::__private::Result +// where +// __E: _serde::de::Error, +// { +// match __value { +// "Int32" => _serde::__private::Ok(__Field::__field0), +// "Int64" => _serde::__private::Ok(__Field::__field1), +// _ => _serde::__private::Err(_serde::de::Error::unknown_variant(__value, VARIANTS)), +// } +// } +// fn visit_bytes<__E>(self, __value: &[u8]) -> _serde::__private::Result +// where +// __E: _serde::de::Error, +// { +// match __value { +// b"Int32" => _serde::__private::Ok(__Field::__field0), +// b"Int64" => _serde::__private::Ok(__Field::__field1), +// _ => { +// let __value = &_serde::__private::from_utf8_lossy(__value); +// _serde::__private::Err(_serde::de::Error::unknown_variant(__value, VARIANTS)) +// } +// } +// } +// } +// impl<'de> _serde::Deserialize<'de> for __Field { +// #[inline] +// fn deserialize<__D>(__deserializer: __D) -> _serde::__private::Result +// where +// __D: _serde::Deserializer<'de>, +// { +// _serde::Deserializer::deserialize_identifier(__deserializer, __FieldVisitor) +// } +// } +// #[doc(hidden)] +// struct __Visitor<'de> { +// marker: _serde::__private::PhantomData, +// lifetime: _serde::__private::PhantomData<&'de ()>, +// } +// impl<'de> _serde::de::Visitor<'de> for __Visitor<'de> { +// type Value = ColumnType; +// fn expecting(&self, __formatter: &mut _serde::__private::Formatter) -> _serde::__private::fmt::Result { +// _serde::__private::Formatter::write_str(__formatter, "enum ColumnType") +// } +// fn visit_enum<__A>(self, __data: __A) -> _serde::__private::Result +// where +// __A: _serde::de::EnumAccess<'de>, +// { +// match _serde::de::EnumAccess::variant(__data)? 
{ +// (__Field::__field0, __variant) => { +// _serde::de::VariantAccess::unit_variant(__variant)?; +// _serde::__private::Ok(ColumnType::Int32) +// } +// (__Field::__field1, __variant) => { +// _serde::de::VariantAccess::unit_variant(__variant)?; +// _serde::__private::Ok(ColumnType::Int64) +// } +// } +// } +// } +// #[doc(hidden)] +// const VARIANTS: &'static [&'static str] = &["Int32", "Int64"]; +// _serde::Deserializer::deserialize_enum( +// __deserializer, +// "ColumnType", +// VARIANTS, +// __Visitor { +// marker: _serde::__private::PhantomData::, +// lifetime: _serde::__private::PhantomData, +// }, +// ) +// } +// } +// }; +// +// +// Recursive expansion of Tsify macro +// =================================== +// +// #[automatically_derived] +// const _: () = { +// extern crate serde as _serde; +// use tsify::Tsify; +// use wasm_bindgen::{ +// convert::{FromWasmAbi, IntoWasmAbi, OptionFromWasmAbi, OptionIntoWasmAbi}, +// describe::WasmDescribe, +// prelude::*, +// }; +// #[wasm_bindgen] +// extern "C" { +// #[wasm_bindgen(typescript_type = "ColumnType")] +// pub type JsType; +// } +// impl Tsify for ColumnType { +// type JsType = JsType; +// const DECL: &'static str = "export type ColumnType = \"Int32\" | \"Int64\";"; +// } +// #[wasm_bindgen(typescript_custom_section)] +// const TS_APPEND_CONTENT: &'static str = "export type ColumnType = \"Int32\" | \"Int64\";"; +// impl WasmDescribe for ColumnType { +// #[inline] +// fn describe() { +// ::JsType::describe() +// } +// } +// impl IntoWasmAbi for ColumnType +// where +// Self: _serde::Serialize, +// { +// type Abi = ::Abi; +// #[inline] +// fn into_abi(self) -> Self::Abi { +// self.into_js().unwrap_throw().into_abi() +// } +// } +// impl OptionIntoWasmAbi for ColumnType +// where +// Self: _serde::Serialize, +// { +// #[inline] +// fn none() -> Self::Abi { +// ::none() +// } +// } +// impl FromWasmAbi for ColumnType +// where +// Self: _serde::de::DeserializeOwned, +// { +// type Abi = ::Abi; +// #[inline] +// unsafe fn from_abi(js: Self::Abi) -> Self { +// let result = Self::from_js(&JsType::from_abi(js)); +// if let Err(err) = result { +// wasm_bindgen::throw_str(err.to_string().as_ref()); +// } +// result.unwrap_throw() +// } +// } +// impl OptionFromWasmAbi for ColumnType +// where +// Self: _serde::de::DeserializeOwned, +// { +// #[inline] +// fn is_none(js: &Self::Abi) -> bool { +// ::is_none(js) +// } +// } +// }; +#[derive(Clone, Copy, Debug, Deserialize, Tsify)] +#[tsify(from_wasm_abi)] +pub enum ColumnType { + Int32 = 0, + Int64 = 1, +} + +#[derive(Debug, Deserialize, Tsify)] +#[tsify(from_wasm_abi)] +#[serde(rename_all = "camelCase")] +struct ColumnTypeWrapper { + column_type: ColumnType, +} + +// Recursive expansion of Deserialize_repr macro +// ============================================== +// +// impl<'de> serde::Deserialize<'de> for ColumnTypeWasmBindgen { +// #[allow(clippy::use_self)] +// fn deserialize(deserializer: D) -> ::core::result::Result +// where +// D: serde::Deserializer<'de>, +// { +// #[allow(non_camel_case_types)] +// struct discriminant; + +// #[allow(non_upper_case_globals)] +// impl discriminant { +// const Int32: u8 = ColumnTypeWasmBindgen::Int32 as u8; +// const Int64: u8 = ColumnTypeWasmBindgen::Int64 as u8; +// } +// match ::deserialize(deserializer)? 
{ +// discriminant::Int32 => ::core::result::Result::Ok(ColumnTypeWasmBindgen::Int32), +// discriminant::Int64 => ::core::result::Result::Ok(ColumnTypeWasmBindgen::Int64), +// other => ::core::result::Result::Err(serde::de::Error::custom(format_args!( +// "invalid value: {}, expected {} or {}", +// other, +// discriminant::Int32, +// discriminant::Int64 +// ))), +// } +// } +// } +#[derive(Debug, Deserialize_repr, Tsify)] +#[tsify(from_wasm_abi)] +#[repr(u8)] +pub enum ColumnTypeWasmBindgen { + // #[serde(rename = "0")] + Int32 = 0, + + // #[serde(rename = "1")] + Int64 = 1, +} + +#[wasm_bindgen_test] +fn column_type_test() { + // Example deserialization code + let json_data = r#"0"#; + let column_type = serde_json::from_str::(&json_data).unwrap(); + let column_type = serde_json::from_str::(&json_data).unwrap(); + let column_type = serde_json::from_str::(&json_data).unwrap(); + let column_type = serde_json::from_str::(&json_data).unwrap(); + let column_type = serde_json::from_str::(&json_data).unwrap(); + let column_type = serde_json::from_str::(&json_data).unwrap(); + let column_type = serde_json::from_str::(&json_data).unwrap(); + let column_type = serde_json::from_str::(&json_data).unwrap(); + + // let json_data = "\"0\""; + let column_type = serde_json::from_str::(&json_data).unwrap(); +} + +// #[wasm_bindgen_test] +// fn column_type_test() { +// // Example deserialization code +// let json_data = r#"{ "columnType": 0 }"#; +// let column_type_wrapper = serde_json::from_str::(json_data); + +// panic!("{:?}", column_type_wrapper); +// } From 998f5c132bd4410809fc3c00bca7d5590f927b7e Mon Sep 17 00:00:00 2001 From: jkomyno Date: Thu, 30 Nov 2023 21:53:46 +0100 Subject: [PATCH 114/134] test(driver-adapters): add byte tests for conversion --- .../src/conversion/js_to_quaint.rs | 18 ++++++++++++++++++ .../driver-adapters/src/conversion/mysql.rs | 5 +++++ .../driver-adapters/src/conversion/postgres.rs | 4 ++++ .../driver-adapters/src/conversion/sqlite.rs | 4 ++++ 4 files changed, 31 insertions(+) diff --git a/query-engine/driver-adapters/src/conversion/js_to_quaint.rs b/query-engine/driver-adapters/src/conversion/js_to_quaint.rs index 7d11b13b7303..3f5a65395896 100644 --- a/query-engine/driver-adapters/src/conversion/js_to_quaint.rs +++ b/query-engine/driver-adapters/src/conversion/js_to_quaint.rs @@ -338,6 +338,24 @@ mod proxy_test { assert_eq!(quaint_value, quaint_none.into()); } + #[test] + fn js_value_binary_to_quaint() { + let column_type = ColumnType::Bytes; + + // null + test_null(QuaintValue::null_bytes(), column_type); + + // "" + let json_value = serde_json::Value::String("".to_string()); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + assert_eq!(quaint_value, QuaintValue::bytes(vec![])); + + // "hello" + let json_value = serde_json::Value::String("hello".to_string()); + let quaint_value = js_value_to_quaint(json_value, column_type, "column_name").unwrap(); + assert_eq!(quaint_value, QuaintValue::bytes(vec![104, 101, 108, 108, 111])); + } + #[test] fn js_value_int32_to_quaint() { let column_type = ColumnType::Int32; diff --git a/query-engine/driver-adapters/src/conversion/mysql.rs b/query-engine/driver-adapters/src/conversion/mysql.rs index 0f4c4bd8eec8..bd59d3b94ed0 100644 --- a/query-engine/driver-adapters/src/conversion/mysql.rs +++ b/query-engine/driver-adapters/src/conversion/mysql.rs @@ -93,6 +93,11 @@ mod test { JSArg::Value(JsonValue::String("23:13:01".to_string())) )) ), + ( + ValueType::Bytes(Some("hello".as_bytes().into())), + 
JSArg::Buffer("hello".as_bytes().to_vec()) + ), + ]; let mut errors: Vec = vec![]; diff --git a/query-engine/driver-adapters/src/conversion/postgres.rs b/query-engine/driver-adapters/src/conversion/postgres.rs index 14e143e2ca8b..949cc17e9eba 100644 --- a/query-engine/driver-adapters/src/conversion/postgres.rs +++ b/query-engine/driver-adapters/src/conversion/postgres.rs @@ -105,6 +105,10 @@ mod test { JSArg::Value(JsonValue::Null), )) ), + ( + ValueType::Bytes(Some("hello".as_bytes().into())).into_value(), + JSArg::Buffer("hello".as_bytes().to_vec()) + ), ]; let mut errors: Vec = vec![]; diff --git a/query-engine/driver-adapters/src/conversion/sqlite.rs b/query-engine/driver-adapters/src/conversion/sqlite.rs index 785930fb9c30..b11acdca0d7f 100644 --- a/query-engine/driver-adapters/src/conversion/sqlite.rs +++ b/query-engine/driver-adapters/src/conversion/sqlite.rs @@ -94,6 +94,10 @@ mod test { JSArg::Value(Value::Null), )) ), + ( + ValueType::Bytes(Some("hello".as_bytes().into())), + JSArg::Buffer("hello".as_bytes().to_vec()) + ), ]; let mut errors: Vec = vec![]; From 48b9381e6b4a6108c8c872f078ccf7cf40511849 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Thu, 30 Nov 2023 21:55:09 +0100 Subject: [PATCH 115/134] Revert "Revert "chore: removed wasm.rs test"" This reverts commit 481ba654658a00adc23e21d9dfc57875b899f4f2. --- Cargo.lock | 31 --- query-engine/driver-adapters/Cargo.toml | 1 - query-engine/driver-adapters/tests/wasm.rs | 275 --------------------- 3 files changed, 307 deletions(-) delete mode 100644 query-engine/driver-adapters/tests/wasm.rs diff --git a/Cargo.lock b/Cargo.lock index f31ab1efcf99..93bf18b7be68 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1117,7 +1117,6 @@ dependencies = [ "uuid", "wasm-bindgen", "wasm-bindgen-futures", - "wasm-bindgen-test", "wasm-rs-dbg", ] @@ -4521,12 +4520,6 @@ dependencies = [ "user-facing-errors", ] -[[package]] -name = "scoped-tls" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1cf6437eb19a8f4a6cc0f7dca544973b0b78843adbfeb3683d1a94a0024a294" - [[package]] name = "scopeguard" version = "1.2.0" @@ -6100,30 +6093,6 @@ version = "0.2.89" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7ab9b36309365056cd639da3134bf87fa8f3d86008abf99e612384a6eecd459f" -[[package]] -name = "wasm-bindgen-test" -version = "0.3.34" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6db36fc0f9fb209e88fb3642590ae0205bb5a56216dabd963ba15879fe53a30b" -dependencies = [ - "console_error_panic_hook", - "js-sys", - "scoped-tls", - "wasm-bindgen", - "wasm-bindgen-futures", - "wasm-bindgen-test-macro", -] - -[[package]] -name = "wasm-bindgen-test-macro" -version = "0.3.34" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0734759ae6b3b1717d661fe4f016efcfb9828f5edb4520c18eaee05af3b43be9" -dependencies = [ - "proc-macro2", - "quote", -] - [[package]] name = "wasm-logger" version = "0.2.0" diff --git a/query-engine/driver-adapters/Cargo.toml b/query-engine/driver-adapters/Cargo.toml index ef2de701312d..24f69164f138 100644 --- a/query-engine/driver-adapters/Cargo.toml +++ b/query-engine/driver-adapters/Cargo.toml @@ -22,7 +22,6 @@ futures = "0.3" expect-test = "1" tokio = { version = "1.0", features = ["macros", "time", "sync"] } wasm-rs-dbg.workspace = true -wasm-bindgen-test.workspace = true [target.'cfg(not(target_arch = "wasm32"))'.dependencies] napi.workspace = true diff --git a/query-engine/driver-adapters/tests/wasm.rs 
b/query-engine/driver-adapters/tests/wasm.rs deleted file mode 100644 index 8f3aa30f7335..000000000000 --- a/query-engine/driver-adapters/tests/wasm.rs +++ /dev/null @@ -1,275 +0,0 @@ -#![cfg(target_os = "wasm32")] -use wasm_bindgen_test::*; - -use serde::{Deserialize, Serialize}; -use serde_repr::{Deserialize_repr, Serialize_repr}; -use tsify::Tsify; -use wasm_bindgen::prelude::*; - -// Recursive expansion of Deserialize macro -// ========================================= -// -// #[doc(hidden)] -// #[allow(non_upper_case_globals, unused_attributes, unused_qualifications)] -// const _: () = { -// #[allow(unused_extern_crates, clippy::useless_attribute)] -// extern crate serde as _serde; -// #[automatically_derived] -// impl<'de> _serde::Deserialize<'de> for ColumnType { -// fn deserialize<__D>(__deserializer: __D) -> _serde::__private::Result -// where -// __D: _serde::Deserializer<'de>, -// { -// #[allow(non_camel_case_types)] -// #[doc(hidden)] -// enum __Field { -// __field0, -// __field1, -// } -// #[doc(hidden)] -// struct __FieldVisitor; - -// impl<'de> _serde::de::Visitor<'de> for __FieldVisitor { -// type Value = __Field; -// fn expecting(&self, __formatter: &mut _serde::__private::Formatter) -> _serde::__private::fmt::Result { -// _serde::__private::Formatter::write_str(__formatter, "variant identifier") -// } -// fn visit_u64<__E>(self, __value: u64) -> _serde::__private::Result -// where -// __E: _serde::de::Error, -// { -// match __value { -// 0u64 => _serde::__private::Ok(__Field::__field0), -// 1u64 => _serde::__private::Ok(__Field::__field1), -// _ => _serde::__private::Err(_serde::de::Error::invalid_value( -// _serde::de::Unexpected::Unsigned(__value), -// &"variant index 0 <= i < 2", -// )), -// } -// } -// fn visit_str<__E>(self, __value: &str) -> _serde::__private::Result -// where -// __E: _serde::de::Error, -// { -// match __value { -// "Int32" => _serde::__private::Ok(__Field::__field0), -// "Int64" => _serde::__private::Ok(__Field::__field1), -// _ => _serde::__private::Err(_serde::de::Error::unknown_variant(__value, VARIANTS)), -// } -// } -// fn visit_bytes<__E>(self, __value: &[u8]) -> _serde::__private::Result -// where -// __E: _serde::de::Error, -// { -// match __value { -// b"Int32" => _serde::__private::Ok(__Field::__field0), -// b"Int64" => _serde::__private::Ok(__Field::__field1), -// _ => { -// let __value = &_serde::__private::from_utf8_lossy(__value); -// _serde::__private::Err(_serde::de::Error::unknown_variant(__value, VARIANTS)) -// } -// } -// } -// } -// impl<'de> _serde::Deserialize<'de> for __Field { -// #[inline] -// fn deserialize<__D>(__deserializer: __D) -> _serde::__private::Result -// where -// __D: _serde::Deserializer<'de>, -// { -// _serde::Deserializer::deserialize_identifier(__deserializer, __FieldVisitor) -// } -// } -// #[doc(hidden)] -// struct __Visitor<'de> { -// marker: _serde::__private::PhantomData, -// lifetime: _serde::__private::PhantomData<&'de ()>, -// } -// impl<'de> _serde::de::Visitor<'de> for __Visitor<'de> { -// type Value = ColumnType; -// fn expecting(&self, __formatter: &mut _serde::__private::Formatter) -> _serde::__private::fmt::Result { -// _serde::__private::Formatter::write_str(__formatter, "enum ColumnType") -// } -// fn visit_enum<__A>(self, __data: __A) -> _serde::__private::Result -// where -// __A: _serde::de::EnumAccess<'de>, -// { -// match _serde::de::EnumAccess::variant(__data)? 
{ -// (__Field::__field0, __variant) => { -// _serde::de::VariantAccess::unit_variant(__variant)?; -// _serde::__private::Ok(ColumnType::Int32) -// } -// (__Field::__field1, __variant) => { -// _serde::de::VariantAccess::unit_variant(__variant)?; -// _serde::__private::Ok(ColumnType::Int64) -// } -// } -// } -// } -// #[doc(hidden)] -// const VARIANTS: &'static [&'static str] = &["Int32", "Int64"]; -// _serde::Deserializer::deserialize_enum( -// __deserializer, -// "ColumnType", -// VARIANTS, -// __Visitor { -// marker: _serde::__private::PhantomData::, -// lifetime: _serde::__private::PhantomData, -// }, -// ) -// } -// } -// }; -// -// -// Recursive expansion of Tsify macro -// =================================== -// -// #[automatically_derived] -// const _: () = { -// extern crate serde as _serde; -// use tsify::Tsify; -// use wasm_bindgen::{ -// convert::{FromWasmAbi, IntoWasmAbi, OptionFromWasmAbi, OptionIntoWasmAbi}, -// describe::WasmDescribe, -// prelude::*, -// }; -// #[wasm_bindgen] -// extern "C" { -// #[wasm_bindgen(typescript_type = "ColumnType")] -// pub type JsType; -// } -// impl Tsify for ColumnType { -// type JsType = JsType; -// const DECL: &'static str = "export type ColumnType = \"Int32\" | \"Int64\";"; -// } -// #[wasm_bindgen(typescript_custom_section)] -// const TS_APPEND_CONTENT: &'static str = "export type ColumnType = \"Int32\" | \"Int64\";"; -// impl WasmDescribe for ColumnType { -// #[inline] -// fn describe() { -// ::JsType::describe() -// } -// } -// impl IntoWasmAbi for ColumnType -// where -// Self: _serde::Serialize, -// { -// type Abi = ::Abi; -// #[inline] -// fn into_abi(self) -> Self::Abi { -// self.into_js().unwrap_throw().into_abi() -// } -// } -// impl OptionIntoWasmAbi for ColumnType -// where -// Self: _serde::Serialize, -// { -// #[inline] -// fn none() -> Self::Abi { -// ::none() -// } -// } -// impl FromWasmAbi for ColumnType -// where -// Self: _serde::de::DeserializeOwned, -// { -// type Abi = ::Abi; -// #[inline] -// unsafe fn from_abi(js: Self::Abi) -> Self { -// let result = Self::from_js(&JsType::from_abi(js)); -// if let Err(err) = result { -// wasm_bindgen::throw_str(err.to_string().as_ref()); -// } -// result.unwrap_throw() -// } -// } -// impl OptionFromWasmAbi for ColumnType -// where -// Self: _serde::de::DeserializeOwned, -// { -// #[inline] -// fn is_none(js: &Self::Abi) -> bool { -// ::is_none(js) -// } -// } -// }; -#[derive(Clone, Copy, Debug, Deserialize, Tsify)] -#[tsify(from_wasm_abi)] -pub enum ColumnType { - Int32 = 0, - Int64 = 1, -} - -#[derive(Debug, Deserialize, Tsify)] -#[tsify(from_wasm_abi)] -#[serde(rename_all = "camelCase")] -struct ColumnTypeWrapper { - column_type: ColumnType, -} - -// Recursive expansion of Deserialize_repr macro -// ============================================== -// -// impl<'de> serde::Deserialize<'de> for ColumnTypeWasmBindgen { -// #[allow(clippy::use_self)] -// fn deserialize(deserializer: D) -> ::core::result::Result -// where -// D: serde::Deserializer<'de>, -// { -// #[allow(non_camel_case_types)] -// struct discriminant; - -// #[allow(non_upper_case_globals)] -// impl discriminant { -// const Int32: u8 = ColumnTypeWasmBindgen::Int32 as u8; -// const Int64: u8 = ColumnTypeWasmBindgen::Int64 as u8; -// } -// match ::deserialize(deserializer)? 
{ -// discriminant::Int32 => ::core::result::Result::Ok(ColumnTypeWasmBindgen::Int32), -// discriminant::Int64 => ::core::result::Result::Ok(ColumnTypeWasmBindgen::Int64), -// other => ::core::result::Result::Err(serde::de::Error::custom(format_args!( -// "invalid value: {}, expected {} or {}", -// other, -// discriminant::Int32, -// discriminant::Int64 -// ))), -// } -// } -// } -#[derive(Debug, Deserialize_repr, Tsify)] -#[tsify(from_wasm_abi)] -#[repr(u8)] -pub enum ColumnTypeWasmBindgen { - // #[serde(rename = "0")] - Int32 = 0, - - // #[serde(rename = "1")] - Int64 = 1, -} - -#[wasm_bindgen_test] -fn column_type_test() { - // Example deserialization code - let json_data = r#"0"#; - let column_type = serde_json::from_str::(&json_data).unwrap(); - let column_type = serde_json::from_str::(&json_data).unwrap(); - let column_type = serde_json::from_str::(&json_data).unwrap(); - let column_type = serde_json::from_str::(&json_data).unwrap(); - let column_type = serde_json::from_str::(&json_data).unwrap(); - let column_type = serde_json::from_str::(&json_data).unwrap(); - let column_type = serde_json::from_str::(&json_data).unwrap(); - let column_type = serde_json::from_str::(&json_data).unwrap(); - - // let json_data = "\"0\""; - let column_type = serde_json::from_str::(&json_data).unwrap(); -} - -// #[wasm_bindgen_test] -// fn column_type_test() { -// // Example deserialization code -// let json_data = r#"{ "columnType": 0 }"#; -// let column_type_wrapper = serde_json::from_str::(json_data); - -// panic!("{:?}", column_type_wrapper); -// } From 28c6ff92ff32605baceba8903419d6e0e1bacb4d Mon Sep 17 00:00:00 2001 From: jkomyno Date: Mon, 4 Dec 2023 17:21:11 +0100 Subject: [PATCH 116/134] =?UTF-8?q?[skip=20ci]=C2=A0chore:=20fix=20build?= =?UTF-8?q?=20CI=20check=20logic?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- query-engine/query-engine-wasm/build.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/query-engine/query-engine-wasm/build.sh b/query-engine/query-engine-wasm/build.sh index 10e3008912f0..29086ca40183 100755 --- a/query-engine/query-engine-wasm/build.sh +++ b/query-engine/query-engine-wasm/build.sh @@ -15,10 +15,10 @@ OUT_NPM_NAME="@prisma/query-engine-wasm" sed -i '' 's/name = "query_engine_wasm"/name = "query_engine"/g' Cargo.toml # use `wasm-pack build --release` on CI only -if [[ -z "$BUILDKITE" ]] || [[ -z "$GITHUB_ACTIONS" ]]; then - BUILD_PROFILE="--release" -else +if [[ -z "$BUILDKITE" ]] && [[ -z "$GITHUB_ACTIONS" ]]; then BUILD_PROFILE="--dev" +else + BUILD_PROFILE="--release" fi # Check if wasm-pack is installed From 7370d1fd0daab2482d2b0e0ec84e0cc3a0fee36b Mon Sep 17 00:00:00 2001 From: jkomyno Date: Mon, 4 Dec 2023 17:30:21 +0100 Subject: [PATCH 117/134] =?UTF-8?q?Revert=20"[skip=20ci]=C2=A0chore:=20fix?= =?UTF-8?q?=20build=20CI=20check=20logic"?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit 28c6ff92ff32605baceba8903419d6e0e1bacb4d. 
--- query-engine/query-engine-wasm/build.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/query-engine/query-engine-wasm/build.sh b/query-engine/query-engine-wasm/build.sh index 29086ca40183..10e3008912f0 100755 --- a/query-engine/query-engine-wasm/build.sh +++ b/query-engine/query-engine-wasm/build.sh @@ -15,10 +15,10 @@ OUT_NPM_NAME="@prisma/query-engine-wasm" sed -i '' 's/name = "query_engine_wasm"/name = "query_engine"/g' Cargo.toml # use `wasm-pack build --release` on CI only -if [[ -z "$BUILDKITE" ]] && [[ -z "$GITHUB_ACTIONS" ]]; then - BUILD_PROFILE="--dev" -else +if [[ -z "$BUILDKITE" ]] || [[ -z "$GITHUB_ACTIONS" ]]; then BUILD_PROFILE="--release" +else + BUILD_PROFILE="--dev" fi # Check if wasm-pack is installed From 0bc5b6c129df1999a498e927b744189bd6a95763 Mon Sep 17 00:00:00 2001 From: Sergey Tatarintsev Date: Mon, 4 Dec 2023 18:02:21 +0100 Subject: [PATCH 118/134] Stop using removed method --- .../driver-adapters/connector-test-kit-executor/src/wasm.ts | 1 - query-engine/query-engine-wasm/example/example.js | 4 +--- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/query-engine/driver-adapters/connector-test-kit-executor/src/wasm.ts b/query-engine/driver-adapters/connector-test-kit-executor/src/wasm.ts index 439fd0c3f94f..6eea2ee36cef 100644 --- a/query-engine/driver-adapters/connector-test-kit-executor/src/wasm.ts +++ b/query-engine/driver-adapters/connector-test-kit-executor/src/wasm.ts @@ -9,6 +9,5 @@ const bytes = await fs.readFile(path.resolve(dirname, '..', '..', '..', 'query-e const module = new WebAssembly.Module(bytes) const instance = new WebAssembly.Instance(module, { './query_engine_bg.js': wasm }) wasm.__wbg_set_wasm(instance.exports); -wasm.init() export const WasmQueryEngine = wasm.QueryEngine \ No newline at end of file diff --git a/query-engine/query-engine-wasm/example/example.js b/query-engine/query-engine-wasm/example/example.js index 154f901a6e0a..5d3449010865 100644 --- a/query-engine/query-engine-wasm/example/example.js +++ b/query-engine/query-engine-wasm/example/example.js @@ -6,14 +6,12 @@ import { readFile } from 'fs/promises' import { PrismaLibSQL } from '@prisma/adapter-libsql' import { createClient } from '@libsql/client' import { bindAdapter } from '@prisma/driver-adapter-utils' -import { init, QueryEngine, getBuildTimeInfo } from '../pkg/query_engine.js' +import { QueryEngine, getBuildTimeInfo } from '../pkg/query_engine.js' async function main() { // Always initialize the Wasm library before using it. // This sets up the logging and panic hooks. 
- init() - const client = createClient({ url: "file:./prisma/dev.db"}) const adapter = new PrismaLibSQL(client) From 2a2565a5e1cab273fd2a2fc3f89a4197c2de9278 Mon Sep 17 00:00:00 2001 From: Sergey Tatarintsev Date: Mon, 4 Dec 2023 18:14:09 +0100 Subject: [PATCH 119/134] Fix broken JS --- query-engine/driver-adapters/src/lib.rs | 5 ++++- query-engine/driver-adapters/src/wasm/js_object_extern.rs | 3 --- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/query-engine/driver-adapters/src/lib.rs b/query-engine/driver-adapters/src/lib.rs index 318291efa46a..6be1270b3268 100644 --- a/query-engine/driver-adapters/src/lib.rs +++ b/query-engine/driver-adapters/src/lib.rs @@ -38,8 +38,11 @@ pub(crate) use wasm::*; #[cfg(target_arch = "wasm32")] mod arch { + use std::str::FromStr; + pub(crate) use super::JsObject; pub(crate) use js_sys::JsString; + use js_sys::Reflect; use tsify::Tsify; use wasm_bindgen::JsValue; @@ -51,7 +54,7 @@ mod arch { } pub(crate) fn has_named_property(object: &JsObject, name: &str) -> JsResult { - Ok(JsObject::has_own(object, name.into())) + Ok(Reflect::has(&object, &JsString::from_str(name).unwrap().into())?) } pub(crate) fn to_rust_str(value: JsString) -> JsResult { diff --git a/query-engine/driver-adapters/src/wasm/js_object_extern.rs b/query-engine/driver-adapters/src/wasm/js_object_extern.rs index 4cb6996e67a2..29abc1d6ef6f 100644 --- a/query-engine/driver-adapters/src/wasm/js_object_extern.rs +++ b/query-engine/driver-adapters/src/wasm/js_object_extern.rs @@ -8,7 +8,4 @@ extern "C" { #[wasm_bindgen(method, catch, structural, indexing_getter)] pub fn get(this: &JsObjectExtern, key: JsString) -> Result; - - #[wasm_bindgen(static_method_of = JsObjectExtern, js_name = hasOwn)] - pub fn has_own(this: &JsObjectExtern, key: JsString) -> bool; } From 534e757f3cbb0807bca447c033d4deabe517e69d Mon Sep 17 00:00:00 2001 From: jkomyno Date: Tue, 5 Dec 2023 14:28:06 +0100 Subject: [PATCH 120/134] chore(review): rename threadsafe_fn to fn_ --- .../driver-adapters/src/wasm/async_js_function.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/query-engine/driver-adapters/src/wasm/async_js_function.rs b/query-engine/driver-adapters/src/wasm/async_js_function.rs index 8a176976886f..3e67c622d647 100644 --- a/query-engine/driver-adapters/src/wasm/async_js_function.rs +++ b/query-engine/driver-adapters/src/wasm/async_js_function.rs @@ -22,7 +22,7 @@ where ArgType: Serialize, ReturnType: FromJsValue, { - threadsafe_fn: JsFunction, + fn_: JsFunction, _phantom_arg: PhantomData, _phantom_return: PhantomData, @@ -45,7 +45,7 @@ where { fn from(js_fn: JsFunction) -> Self { Self { - threadsafe_fn: js_fn, + fn_: js_fn, _phantom_arg: PhantomData:: {}, _phantom_return: PhantomData:: {}, } @@ -70,7 +70,7 @@ where let arg1 = arg1 .serialize(&SERIALIZER) .map_err(|err| JsValue::from(JsError::from(&err)))?; - let return_value = self.threadsafe_fn.call1(&JsValue::null(), &arg1)?; + let return_value = self.fn_.call1(&JsValue::null(), &arg1)?; let value = if let Some(promise) = return_value.dyn_ref::() { JsFuture::from(promise.to_owned()).await? 
@@ -85,7 +85,7 @@ where pub(crate) fn call_non_blocking(&self, arg: T) { if let Ok(arg) = serde_wasm_bindgen::to_value(&arg) { - _ = self.threadsafe_fn.call1(&JsValue::null(), &arg); + _ = self.fn_.call1(&JsValue::null(), &arg); } } } From 05bd6244e43849b09da4552c215957cc9bdc79ec Mon Sep 17 00:00:00 2001 From: jkomyno Date: Tue, 5 Dec 2023 14:28:24 +0100 Subject: [PATCH 121/134] chore(review): add comment related to js_sys::Reflect in JsObjectExtern --- query-engine/driver-adapters/src/wasm/js_object_extern.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/query-engine/driver-adapters/src/wasm/js_object_extern.rs b/query-engine/driver-adapters/src/wasm/js_object_extern.rs index 29abc1d6ef6f..ac9f72619eac 100644 --- a/query-engine/driver-adapters/src/wasm/js_object_extern.rs +++ b/query-engine/driver-adapters/src/wasm/js_object_extern.rs @@ -6,6 +6,7 @@ extern "C" { #[wasm_bindgen(js_name = String, extends = JsObject, is_type_of = JsValue::is_object, typescript_type = "object")] pub type JsObjectExtern; + // Note: this custom getter allows us to avoid runtime reflection via `js_sys::Reflect`. #[wasm_bindgen(method, catch, structural, indexing_getter)] pub fn get(this: &JsObjectExtern, key: JsString) -> Result; } From 3e889573eddf4de5758a4a0b096e9b62e8b5246b Mon Sep 17 00:00:00 2001 From: jkomyno Date: Tue, 5 Dec 2023 14:34:00 +0100 Subject: [PATCH 122/134] chore(review): rename SendFuture to UnsafeFuture, improving comments --- query-engine/driver-adapters/src/proxy.rs | 14 ++++----- query-engine/driver-adapters/src/queryable.rs | 10 +++--- .../driver-adapters/src/send_future.rs | 31 +++++++++++-------- .../driver-adapters/src/transaction.rs | 6 ++-- 4 files changed, 33 insertions(+), 28 deletions(-) diff --git a/query-engine/driver-adapters/src/proxy.rs b/query-engine/driver-adapters/src/proxy.rs index d5ad59887526..4e44caf3db75 100644 --- a/query-engine/driver-adapters/src/proxy.rs +++ b/query-engine/driver-adapters/src/proxy.rs @@ -1,4 +1,4 @@ -use crate::send_future::SendFuture; +use crate::send_future::UnsafeFuture; pub use crate::types::{ColumnType, JSResultSet, Query, TransactionOptions}; use crate::{from_js_value, get_named_property, has_named_property, to_rust_str, JsObject, JsResult, JsString}; @@ -96,8 +96,8 @@ impl DriverProxy { Ok(Box::new(tx)) } - pub fn start_transaction(&self) -> SendFuture>> + '_> { - SendFuture(self.start_transaction_inner()) + pub fn start_transaction(&self) -> UnsafeFuture>> + '_> { + UnsafeFuture(self.start_transaction_inner()) } } @@ -135,9 +135,9 @@ impl TransactionProxy { /// the underlying FFI call will be delivered to JavaScript side in lockstep, so the destructor /// will not attempt rolling the transaction back even if the `commit` future was dropped while /// waiting on the JavaScript call to complete and deliver response. - pub fn commit(&self) -> SendFuture> + '_> { + pub fn commit(&self) -> UnsafeFuture> + '_> { self.closed.store(true, Ordering::Relaxed); - SendFuture(self.commit.call(())) + UnsafeFuture(self.commit.call(())) } /// Rolls back the transaction via the driver adapter. @@ -155,9 +155,9 @@ impl TransactionProxy { /// the underlying FFI call will be delivered to JavaScript side in lockstep, so the destructor /// will not attempt rolling back again even if the `rollback` future was dropped while waiting /// on the JavaScript call to complete and deliver response. 
- pub fn rollback(&self) -> SendFuture> + '_> { + pub fn rollback(&self) -> UnsafeFuture> + '_> { self.closed.store(true, Ordering::Relaxed); - SendFuture(self.rollback.call(())) + UnsafeFuture(self.rollback.call(())) } } diff --git a/query-engine/driver-adapters/src/queryable.rs b/query-engine/driver-adapters/src/queryable.rs index 2c03c8633310..3afa9ecd2180 100644 --- a/query-engine/driver-adapters/src/queryable.rs +++ b/query-engine/driver-adapters/src/queryable.rs @@ -3,7 +3,7 @@ use crate::types::{AdapterFlavour, Query}; use crate::JsObject; use super::conversion; -use crate::send_future::SendFuture; +use crate::send_future::UnsafeFuture; use async_trait::async_trait; use futures::Future; use quaint::{ @@ -170,8 +170,8 @@ impl JsBaseQueryable { &'a self, sql: &'a str, params: &'a [quaint::Value<'a>], - ) -> SendFuture> + 'a> { - SendFuture(self.do_query_raw_inner(sql, params)) + ) -> UnsafeFuture> + 'a> { + UnsafeFuture(self.do_query_raw_inner(sql, params)) } async fn do_execute_raw_inner(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { @@ -189,8 +189,8 @@ impl JsBaseQueryable { &'a self, sql: &'a str, params: &'a [quaint::Value<'a>], - ) -> SendFuture> + 'a> { - SendFuture(self.do_execute_raw_inner(sql, params)) + ) -> UnsafeFuture> + 'a> { + UnsafeFuture(self.do_execute_raw_inner(sql, params)) } } diff --git a/query-engine/driver-adapters/src/send_future.rs b/query-engine/driver-adapters/src/send_future.rs index ed5e78345afd..52a59d764708 100644 --- a/query-engine/driver-adapters/src/send_future.rs +++ b/query-engine/driver-adapters/src/send_future.rs @@ -1,17 +1,24 @@ use futures::Future; -// Allow asynchronous futures to be sent safely across threads, solving the following error: -// -// ```text -// future cannot be sent between threads safely -// the trait `Send` is not implemented for `dyn Future>`. -// ``` -// -// See: https://github.com/rustwasm/wasm-bindgen/issues/2409#issuecomment-820750943 +/// Allow asynchronous futures to be sent across threads, solving the following error on `wasm32-*` targets: +/// +/// ```text +/// future cannot be sent between threads safely +/// the trait `Send` is not implemented for `dyn Future>`. +/// ``` +/// +/// This wrapper is used by both the Napi.rs and Wasm implementation of `driver-adapters`, but is only really +/// needed because `wasm-bindgen` does not implement `Send` for `Future`, and most of the codebase +/// uses `#[async_trait]`, which requires `Send` on the future returned by `async fn` declarations. +/// +/// In fact, `UnsafeFuture` safely implements `Send` if `F` implements `Future + Send`, which is the case +/// with Napi.rs, but not with Wasm. +/// +/// See: https://github.com/rustwasm/wasm-bindgen/issues/2409#issuecomment-820750943 #[pin_project::pin_project] -pub struct SendFuture(#[pin] pub F); +pub struct UnsafeFuture(#[pin] pub F); -impl Future for SendFuture { +impl Future for UnsafeFuture { type Output = F::Output; fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll { @@ -21,7 +28,5 @@ impl Future for SendFuture { } } -// Note: on Napi.rs, we require the underlying future to be `Send`. -// On Wasm, that's currently not possible. 
#[cfg(target_arch = "wasm32")] -unsafe impl Send for SendFuture {} +unsafe impl Send for UnsafeFuture {} diff --git a/query-engine/driver-adapters/src/transaction.rs b/query-engine/driver-adapters/src/transaction.rs index 7f7180ae6d30..264c363ea608 100644 --- a/query-engine/driver-adapters/src/transaction.rs +++ b/query-engine/driver-adapters/src/transaction.rs @@ -7,7 +7,7 @@ use quaint::{ }; use crate::proxy::{TransactionOptions, TransactionProxy}; -use crate::{proxy::CommonProxy, queryable::JsBaseQueryable, send_future::SendFuture}; +use crate::{proxy::CommonProxy, queryable::JsBaseQueryable, send_future::UnsafeFuture}; use crate::{JsObject, JsResult}; // Wrapper around JS transaction objects that implements Queryable @@ -48,7 +48,7 @@ impl QuaintTransaction for JsTransaction { self.inner.raw_cmd(commit_stmt).await?; } - SendFuture(self.tx_proxy.commit()).await + UnsafeFuture(self.tx_proxy.commit()).await } async fn rollback(&self) -> quaint::Result<()> { @@ -64,7 +64,7 @@ impl QuaintTransaction for JsTransaction { self.inner.raw_cmd(rollback_stmt).await?; } - SendFuture(self.tx_proxy.rollback()).await + UnsafeFuture(self.tx_proxy.rollback()).await } fn as_queryable(&self) -> &dyn Queryable { From 687c23c5abc7d98ae91178967ec963b77fcc85ad Mon Sep 17 00:00:00 2001 From: jkomyno Date: Tue, 5 Dec 2023 14:42:48 +0100 Subject: [PATCH 123/134] chore(review): use fully-specialized types for wasm/napi-specific logic, when possible; apply clippy fixes --- query-engine/driver-adapters/src/lib.rs | 30 ++++++++++--------------- 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/query-engine/driver-adapters/src/lib.rs b/query-engine/driver-adapters/src/lib.rs index 6be1270b3268..0fcd956d32fc 100644 --- a/query-engine/driver-adapters/src/lib.rs +++ b/query-engine/driver-adapters/src/lib.rs @@ -38,53 +38,47 @@ pub(crate) use wasm::*; #[cfg(target_arch = "wasm32")] mod arch { - use std::str::FromStr; - - pub(crate) use super::JsObject; pub(crate) use js_sys::JsString; - use js_sys::Reflect; + use std::str::FromStr; use tsify::Tsify; - use wasm_bindgen::JsValue; - pub(crate) fn get_named_property(object: &JsObject, name: &str) -> JsResult + pub(crate) fn get_named_property(object: &super::wasm::JsObjectExtern, name: &str) -> JsResult where - T: From, + T: From, { Ok(object.get(name.into())?.into()) } - pub(crate) fn has_named_property(object: &JsObject, name: &str) -> JsResult { - Ok(Reflect::has(&object, &JsString::from_str(name).unwrap().into())?) 
+ pub(crate) fn has_named_property(object: &super::wasm::JsObjectExtern, name: &str) -> JsResult { + js_sys::Reflect::has(object, &JsString::from_str(name).unwrap().into()) } pub(crate) fn to_rust_str(value: JsString) -> JsResult { Ok(value.into()) } - pub(crate) fn from_js_value(value: JsValue) -> C + pub(crate) fn from_js_value(value: wasm_bindgen::JsValue) -> C where C: Tsify + serde::de::DeserializeOwned, { C::from_js(value).unwrap() } - pub(crate) type JsResult = core::result::Result; + pub(crate) type JsResult = core::result::Result; } #[cfg(not(target_arch = "wasm32"))] mod arch { - pub(crate) use super::JsObject; - use napi::bindgen_prelude::FromNapiValue; - pub(crate) use napi::JsString; + pub(crate) use ::napi::JsString; - pub(crate) fn get_named_property(object: &JsObject, name: &str) -> JsResult + pub(crate) fn get_named_property(object: &::napi::JsObject, name: &str) -> JsResult where - T: FromNapiValue, + T: ::napi::bindgen_prelude::FromNapiValue, { object.get_named_property(name) } - pub(crate) fn has_named_property(object: &JsObject, name: &str) -> JsResult { + pub(crate) fn has_named_property(object: &::napi::JsObject, name: &str) -> JsResult { object.has_named_property(name) } @@ -96,7 +90,7 @@ mod arch { value } - pub(crate) type JsResult = napi::Result; + pub(crate) type JsResult = ::napi::Result; } pub(crate) use arch::*; From 753e0655c87f5762c9ec339d1f7b922ed8a6f821 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Tue, 5 Dec 2023 14:53:57 +0100 Subject: [PATCH 124/134] chore(review): rename JsResult into AdapterResult, reduce duplication, improve comments --- query-engine/driver-adapters/src/lib.rs | 27 ++++++++++++ .../src/napi/async_js_function.rs | 8 ++-- query-engine/driver-adapters/src/napi/mod.rs | 2 +- .../driver-adapters/src/napi/result.rs | 41 ++++++------------- .../src/wasm/async_js_function.rs | 6 +-- query-engine/driver-adapters/src/wasm/mod.rs | 2 +- .../driver-adapters/src/wasm/result.rs | 36 +++++----------- 7 files changed, 57 insertions(+), 65 deletions(-) diff --git a/query-engine/driver-adapters/src/lib.rs b/query-engine/driver-adapters/src/lib.rs index 0fcd956d32fc..8c9dc58ba573 100644 --- a/query-engine/driver-adapters/src/lib.rs +++ b/query-engine/driver-adapters/src/lib.rs @@ -15,6 +15,33 @@ pub(crate) mod send_future; pub(crate) mod transaction; pub(crate) mod types; +use crate::error::DriverAdapterError; +use quaint::error::{Error as QuaintError, ErrorKind}; + +#[cfg(target_arch = "wasm32")] +pub(crate) use wasm::result::AdapterResult; + +#[cfg(not(target_arch = "wasm32"))] +pub(crate) use napi::result::AdapterResult; + +impl From for QuaintError { + fn from(value: DriverAdapterError) -> Self { + match value { + DriverAdapterError::UnsupportedNativeDataType { native_type } => { + QuaintError::builder(ErrorKind::UnsupportedColumnType { + column_type: native_type, + }) + .build() + } + DriverAdapterError::GenericJs { id } => QuaintError::external_error(id), + DriverAdapterError::Postgres(e) => e.into(), + DriverAdapterError::Mysql(e) => e.into(), + DriverAdapterError::Sqlite(e) => e.into(), + // in future, more error types would be added and we'll need to convert them to proper QuaintErrors here + } + } +} + pub use queryable::from_js; pub(crate) use transaction::JsTransaction; diff --git a/query-engine/driver-adapters/src/napi/async_js_function.rs b/query-engine/driver-adapters/src/napi/async_js_function.rs index d62931e2c767..5b53ecbadc65 100644 --- a/query-engine/driver-adapters/src/napi/async_js_function.rs +++ 
b/query-engine/driver-adapters/src/napi/async_js_function.rs @@ -5,10 +5,8 @@ use napi::{ threadsafe_function::{ErrorStrategy, ThreadsafeFunction, ThreadsafeFunctionCallMode}, }; -use super::{ - error::{async_unwinding_panic, into_quaint_error}, - result::JsResult, -}; +use super::error::{async_unwinding_panic, into_quaint_error}; +use crate::AdapterResult; /// Wrapper for napi-rs's ThreadsafeFunction that is aware of /// JS drivers conventions. Performs following things: @@ -47,7 +45,7 @@ where let js_result = async_unwinding_panic(async { let promise = self .threadsafe_fn - .call_async::>>(arg) + .call_async::>>(arg) .await?; promise.await }) diff --git a/query-engine/driver-adapters/src/napi/mod.rs b/query-engine/driver-adapters/src/napi/mod.rs index c9bb8d24ac33..c53414c78c85 100644 --- a/query-engine/driver-adapters/src/napi/mod.rs +++ b/query-engine/driver-adapters/src/napi/mod.rs @@ -3,6 +3,6 @@ mod async_js_function; mod conversion; mod error; -mod result; +pub(crate) mod result; pub(crate) use async_js_function::AsyncJsFunction; diff --git a/query-engine/driver-adapters/src/napi/result.rs b/query-engine/driver-adapters/src/napi/result.rs index d815c9d86dbd..529455bf9a0b 100644 --- a/query-engine/driver-adapters/src/napi/result.rs +++ b/query-engine/driver-adapters/src/napi/result.rs @@ -1,7 +1,5 @@ -use napi::{bindgen_prelude::FromNapiValue, Env, JsUnknown, NapiValue}; -use quaint::error::{Error as QuaintError, ErrorKind}; - use crate::error::DriverAdapterError; +use napi::{bindgen_prelude::FromNapiValue, Env, JsUnknown, NapiValue}; impl FromNapiValue for DriverAdapterError { unsafe fn from_napi_value(napi_env: napi::sys::napi_env, napi_val: napi::sys::napi_value) -> napi::Result { @@ -11,26 +9,11 @@ impl FromNapiValue for DriverAdapterError { } } -impl From for QuaintError { - fn from(value: DriverAdapterError) -> Self { - match value { - DriverAdapterError::UnsupportedNativeDataType { native_type } => { - QuaintError::builder(ErrorKind::UnsupportedColumnType { - column_type: native_type, - }) - .build() - } - DriverAdapterError::GenericJs { id } => QuaintError::external_error(id), - DriverAdapterError::Postgres(e) => e.into(), - DriverAdapterError::Mysql(e) => e.into(), - DriverAdapterError::Sqlite(e) => e.into(), - // in future, more error types would be added and we'll need to convert them to proper QuaintErrors here - } - } -} - -/// Wrapper for JS-side result type -pub(crate) enum JsResult +/// Wrapper for JS-side result type. +/// This Napi-specific implementation has the same shape and API as the Wasm implementation, +/// but it asks for a `FromNapiValue` bound on the generic type. +/// The duplication is needed as it's currently impossible to have target-specific generic bounds in Rust. 
+pub(crate) enum AdapterResult where T: FromNapiValue, { @@ -38,7 +21,7 @@ where Err(DriverAdapterError), } -impl JsResult +impl AdapterResult where T: FromNapiValue, { @@ -55,7 +38,7 @@ where } } -impl FromNapiValue for JsResult +impl FromNapiValue for AdapterResult where T: FromNapiValue, { @@ -64,14 +47,14 @@ where } } -impl From> for quaint::Result +impl From> for quaint::Result where T: FromNapiValue, { - fn from(value: JsResult) -> Self { + fn from(value: AdapterResult) -> Self { match value { - JsResult::Ok(result) => Ok(result), - JsResult::Err(error) => Err(error.into()), + AdapterResult::Ok(result) => Ok(result), + AdapterResult::Err(error) => Err(error.into()), } } } diff --git a/query-engine/driver-adapters/src/wasm/async_js_function.rs b/query-engine/driver-adapters/src/wasm/async_js_function.rs index 3e67c622d647..bda40cc87a58 100644 --- a/query-engine/driver-adapters/src/wasm/async_js_function.rs +++ b/query-engine/driver-adapters/src/wasm/async_js_function.rs @@ -9,7 +9,7 @@ use wasm_bindgen_futures::JsFuture; use super::error::into_quaint_error; use super::from_js::FromJsValue; -use super::result::JsResult; +use crate::AdapterResult; // `serialize_missing_as_null` is required to make sure that "empty" values (e.g., `None` and `()`) // are serialized as `null` and not `undefined`. @@ -66,7 +66,7 @@ where } } - async fn call_internal(&self, arg1: T) -> Result, JsValue> { + async fn call_internal(&self, arg1: T) -> Result, JsValue> { let arg1 = arg1 .serialize(&SERIALIZER) .map_err(|err| JsValue::from(JsError::from(&err)))?; @@ -78,7 +78,7 @@ where return_value }; - let js_result = JsResult::::from_js_value(value)?; + let js_result = AdapterResult::::from_js_value(value)?; Ok(js_result) } diff --git a/query-engine/driver-adapters/src/wasm/mod.rs b/query-engine/driver-adapters/src/wasm/mod.rs index 655ea1a6080d..a71e6f5d21c9 100644 --- a/query-engine/driver-adapters/src/wasm/mod.rs +++ b/query-engine/driver-adapters/src/wasm/mod.rs @@ -4,7 +4,7 @@ mod async_js_function; mod error; mod from_js; mod js_object_extern; -mod result; +pub(crate) mod result; pub(crate) use async_js_function::AsyncJsFunction; pub(crate) use from_js::FromJsValue; diff --git a/query-engine/driver-adapters/src/wasm/result.rs b/query-engine/driver-adapters/src/wasm/result.rs index 2e656a205c41..18a9c4b26443 100644 --- a/query-engine/driver-adapters/src/wasm/result.rs +++ b/query-engine/driver-adapters/src/wasm/result.rs @@ -1,30 +1,14 @@ use js_sys::Boolean as JsBoolean; -use quaint::error::{Error as QuaintError, ErrorKind}; use wasm_bindgen::{JsCast, JsValue}; use super::from_js::FromJsValue; use crate::{error::DriverAdapterError, JsObjectExtern}; -impl From for QuaintError { - fn from(value: DriverAdapterError) -> Self { - match value { - DriverAdapterError::UnsupportedNativeDataType { native_type } => { - QuaintError::builder(ErrorKind::UnsupportedColumnType { - column_type: native_type, - }) - .build() - } - DriverAdapterError::GenericJs { id } => QuaintError::external_error(id), - DriverAdapterError::Postgres(e) => e.into(), - DriverAdapterError::Mysql(e) => e.into(), - DriverAdapterError::Sqlite(e) => e.into(), - // in future, more error types would be added and we'll need to convert them to proper QuaintErrors here - } - } -} - -/// Wrapper for JS-side result type -pub(crate) enum JsResult +/// Wrapper for JS-side result type. +/// This Wasm-specific implementation has the same shape and API as the Napi implementation, +/// but it asks for a `FromJsValue` bound on the generic type. 
+/// The duplication is needed as it's currently impossible to have target-specific generic bounds in Rust. +pub(crate) enum AdapterResult where T: FromJsValue, { @@ -32,7 +16,7 @@ where Err(DriverAdapterError), } -impl FromJsValue for JsResult +impl FromJsValue for AdapterResult where T: FromJsValue, { @@ -54,14 +38,14 @@ where } } -impl From> for quaint::Result +impl From> for quaint::Result where T: FromJsValue, { - fn from(value: JsResult) -> Self { + fn from(value: AdapterResult) -> Self { match value { - JsResult::Ok(result) => Ok(result), - JsResult::Err(error) => Err(error.into()), + AdapterResult::Ok(result) => Ok(result), + AdapterResult::Err(error) => Err(error.into()), } } } From f902ded0ca55a341728c035d91851797cf803846 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Tue, 5 Dec 2023 14:54:52 +0100 Subject: [PATCH 125/134] chore(review): remove redudant full type qualifier --- libs/crosstarget-utils/src/native/spawn.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/crosstarget-utils/src/native/spawn.rs b/libs/crosstarget-utils/src/native/spawn.rs index 8a8360c580fa..cd1d5246d123 100644 --- a/libs/crosstarget-utils/src/native/spawn.rs +++ b/libs/crosstarget-utils/src/native/spawn.rs @@ -2,7 +2,7 @@ use std::future::Future; use crate::common::SpawnError; -pub async fn spawn_if_possible(future: F) -> Result +pub async fn spawn_if_possible(future: F) -> Result where F: Future + 'static + Send, F::Output: Send + 'static, From 43b5146b03304b24c75211c567ade0db17dbe40e Mon Sep 17 00:00:00 2001 From: jkomyno Date: Tue, 5 Dec 2023 14:55:38 +0100 Subject: [PATCH 126/134] chore(review): revert changes to psl-core --- Cargo.lock | 1 - psl/psl-core/Cargo.toml | 3 --- 2 files changed, 4 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 93bf18b7be68..a66cf4fcb9b9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3508,7 +3508,6 @@ dependencies = [ "serde", "serde_json", "url", - "wasm-bindgen", ] [[package]] diff --git a/psl/psl-core/Cargo.toml b/psl/psl-core/Cargo.toml index 5cc959da9f33..0d4bea39b84e 100644 --- a/psl/psl-core/Cargo.toml +++ b/psl/psl-core/Cargo.toml @@ -22,6 +22,3 @@ indoc.workspace = true # For the connector API. lsp-types = "0.91.1" url = "2.2.1" - -[target.'cfg(target_arch = "wasm32")'.dependencies] -wasm-bindgen.workspace = true \ No newline at end of file From 02572928c1bf7361f7e100876c5d4cf55248fbab Mon Sep 17 00:00:00 2001 From: jkomyno Date: Tue, 5 Dec 2023 14:59:08 +0100 Subject: [PATCH 127/134] chore(review): improve comments on js.rs --- .../sql-query-connector/src/database/js.rs | 38 +++++++++---------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/query-engine/connectors/sql-query-connector/src/database/js.rs b/query-engine/connectors/sql-query-connector/src/database/js.rs index 449b9053bef1..16181755b04e 100644 --- a/query-engine/connectors/sql-query-connector/src/database/js.rs +++ b/query-engine/connectors/sql-query-connector/src/database/js.rs @@ -13,8 +13,10 @@ use quaint::{ }; use std::sync::{Arc, Mutex}; -// TODO: evaluate turning this into `Lazy>>>` to avoid -// a clone+drop on the adapter passed via `Js::from_source`. +/// TODO: evaluate turning this into `Lazy>>>` to avoid +/// a clone+drop on the adapter passed via `Js::from_source`. +/// Note: this is currently blocked by Napi causing linking errors when building test binaries, +/// as commented in [`DriverAdapter`]. 
static ACTIVE_DRIVER_ADAPTER: Lazy>> = Lazy::new(|| Mutex::new(None)); fn active_driver_adapter(provider: &str) -> connector::Result { @@ -86,23 +88,21 @@ impl Connector for Js { } } -// TODO: miguelff: I haven´t found a better way to do this, yet... please continue reading. -// -// There is a bug in NAPI-rs by wich compiling a binary crate that links code using napi-rs -// bindings breaks. We could have used a JsQueryable from the `driver-adapters` crate directly, as the -// `connection` field of a driver adapter, but that will imply using napi-rs transitively, and break -// the tests (which are compiled as binary creates) -// -// To avoid the problem above I separated interface from implementation, making DriverAdapter -// independent on napi-rs. Initially, I tried having a field Arc<&dyn TransactionCabable> to hold -// JsQueryable at runtime. I did this, because TransactionCapable is the trait bounds required to -// create a value of `SqlConnection` (see [SqlConnection::new])) to actually performt the queries. -// using JSQueryable. However, this didn't work because TransactionCapable is not object safe. -// (has Sized as a supertrait) -// -// The thing is that TransactionCapable is not object safe and cannot be used in a dynamic type -// declaration, so finally I couldn't come up with anything better then wrapping a QuaintQueryable -// in this object, and implementing TransactionCapable (and quaint::Queryable) explicitly for it. +/// There is a bug in NAPI-rs by wich compiling a binary crate that links code using napi-rs +/// bindings breaks. We could have used a JsQueryable from the `driver-adapters` crate directly, as the +/// `connection` field of a driver adapter, but that will imply using napi-rs transitively, and break +/// the tests (which are compiled as binary creates) +/// +/// To avoid the problem above I separated interface from implementation, making DriverAdapter +/// independent on napi-rs. Initially, I tried having a field Arc<&dyn TransactionCabable> to hold +/// JsQueryable at runtime. I did this, because TransactionCapable is the trait bounds required to +/// create a value of `SqlConnection` (see [SqlConnection::new])) to actually performt the queries. +/// using JSQueryable. However, this didn't work because TransactionCapable is not object safe. +/// (has Sized as a supertrait) +/// +/// The thing is that TransactionCapable is not object safe and cannot be used in a dynamic type +/// declaration, so finally I couldn't come up with anything better then wrapping a QuaintQueryable +/// in this object, and implementing TransactionCapable (and quaint::Queryable) explicitly for it. 
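
The rewritten comment above hinges on one constraint: `TransactionCapable` has a `Sized` supertrait, so it cannot be used as a trait object, and the concrete `DriverAdapter` wrapper instead delegates to an inner object-safe queryable. Below is a minimal, standalone sketch of that workaround; `ObjectSafe` and `NotObjectSafe` are hypothetical stand-ins for the quaint traits, not their real definitions.

    use std::sync::Arc;

    // Object-safe part: methods take `&self`, no `Self` return, no generics.
    trait ObjectSafe {
        fn query(&self, sql: &str) -> String;
    }

    // The `Sized` supertrait makes this trait non-object-safe:
    // `Arc<dyn NotObjectSafe>` is rejected by the compiler.
    trait NotObjectSafe: ObjectSafe + Sized {
        fn begin(&self) -> Self;
    }

    // Workaround: hold only the object-safe part behind a trait object...
    #[derive(Clone)]
    struct Adapter {
        inner: Arc<dyn ObjectSafe + Send + Sync>,
    }

    // ...and implement both traits explicitly on the concrete wrapper,
    // delegating the object-safe methods to the inner trait object.
    impl ObjectSafe for Adapter {
        fn query(&self, sql: &str) -> String {
            self.inner.query(sql)
        }
    }

    impl NotObjectSafe for Adapter {
        fn begin(&self) -> Self {
            self.clone()
        }
    }

    fn main() {
        struct Pg;
        impl ObjectSafe for Pg {
            fn query(&self, sql: &str) -> String {
                format!("pg: {sql}")
            }
        }

        let adapter = Adapter { inner: Arc::new(Pg) };
        println!("{}", adapter.begin().query("SELECT 1"));
    }

Because the wrapper is a concrete, `Sized` type, it can satisfy the non-object-safe trait that the boxed/`Arc`-ed trait object cannot, which is the situation described for `TransactionCapable` and `SqlConnection::new`.
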
#[derive(Clone)] pub struct DriverAdapter { connector: Arc, From 9a99fd231774440deb5df88a477e1480ee0ab5fb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20Fern=C3=A1ndez?= Date: Tue, 5 Dec 2023 15:09:59 +0100 Subject: [PATCH 128/134] Test fixes for NAPI tests (#4515) --- .../query-tests-setup/src/connector_tag/js.rs | 2 +- .../src/connector_tag/js/external_process.rs | 62 ++++++++++++++----- query-engine/driver-adapters/src/lib.rs | 2 +- 3 files changed, 49 insertions(+), 17 deletions(-) diff --git a/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/js.rs b/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/js.rs index c852924bbf69..2ec8513baeda 100644 --- a/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/js.rs +++ b/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/js.rs @@ -3,7 +3,7 @@ mod external_process; use super::*; use external_process::*; use serde::de::DeserializeOwned; -use std::sync::atomic::AtomicU64; +use std::{collections::HashMap, sync::atomic::AtomicU64}; use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; pub(crate) async fn executor_process_request( diff --git a/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/js/external_process.rs b/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/js/external_process.rs index 912a5e6d8abf..06d1551f9405 100644 --- a/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/js/external_process.rs +++ b/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/js/external_process.rs @@ -140,6 +140,42 @@ impl Display for ExecutorProcessDiedError { impl StdError for ExecutorProcessDiedError {} +struct PendingRequests { + map: HashMap>>, + last_id: Option, +} + +impl PendingRequests { + fn new() -> Self { + Self { + map: HashMap::new(), + last_id: None, + } + } + + fn insert(&mut self, id: jsonrpc_core::Id, sender: oneshot::Sender>) { + self.map.insert(id.clone(), sender); + self.last_id = Some(id); + } + + fn respond(&mut self, id: &jsonrpc_core::Id, response: Result) { + self.map + .remove(id) + .expect("no sender for response") + .send(response) + .unwrap(); + } + + fn respond_to_last(&mut self, response: Result) { + let last_id = self + .last_id + .as_ref() + .expect("Expected last response to exist") + .to_owned(); + self.respond(&last_id, response); + } +} + pub(super) static EXTERNAL_PROCESS: Lazy = Lazy::new(RestartableExecutorProcess::new); type ReqImpl = ( @@ -173,7 +209,7 @@ fn start_rpc_thread(mut receiver: mpsc::Receiver) -> Result<()> { let mut stdout = BufReader::new(process.stdout.unwrap()).lines(); let mut stdin = process.stdin.unwrap(); - let mut last_pending_request: Option<(jsonrpc_core::Id, oneshot::Sender>)> = None; + let mut pending_requests = PendingRequests::new(); loop { tokio::select! { @@ -186,24 +222,20 @@ fn start_rpc_thread(mut receiver: mpsc::Receiver) -> Result<()> { Ok(Some(line)) => // new response { match serde_json::from_str::(&line) { - Ok(response) => { - let (id, sender) = last_pending_request.take().expect("got a response from the external process, but there was no pending request"); - if &id != response.id() { - unreachable!("got a response from the external process, but the id didn't match. 
Are you running with cargo tests with `--test-threads=1`"); - } - - match response { + Ok(ref response) => { + let res: Result = match response { jsonrpc_core::Output::Success(success) => { // The other end may be dropped if the whole // request future was dropped and not polled to // completion, so we ignore send errors here. - _ = sender.send(Ok(success.result)); + Ok(success.result.clone()) } jsonrpc_core::Output::Failure(err) => { tracing::error!("error response from jsonrpc: {err:?}"); - _ = sender.send(Err(Box::new(err.error))); + Err(Box::new(err.error.clone())) } - } + }; + pending_requests.respond(response.id(), res) } Err(err) => { tracing::error!(%err, "error when decoding response from child node process. Response was: `{}`", &line); @@ -214,9 +246,8 @@ fn start_rpc_thread(mut receiver: mpsc::Receiver) -> Result<()> { Ok(None) => // end of the stream { tracing::error!("Error when reading from child node process. Process might have exited. Restarting..."); - if let Some((_, sender)) = last_pending_request.take() { - sender.send(Err(Box::new(ExecutorProcessDiedError))).unwrap(); - } + + pending_requests.respond_to_last(Err(Box::new(ExecutorProcessDiedError))); EXTERNAL_PROCESS.restart().await; break; } @@ -233,7 +264,8 @@ fn start_rpc_thread(mut receiver: mpsc::Receiver) -> Result<()> { exit_with_message(1, "The json-rpc client channel was closed"); } Some((request, response_sender)) => { - last_pending_request = Some((request.id.clone(), response_sender)); + pending_requests.insert(request.id.clone(), response_sender); + let mut req = serde_json::to_vec(&request).unwrap(); req.push(b'\n'); stdin.write_all(&req).await.unwrap(); diff --git a/query-engine/driver-adapters/src/lib.rs b/query-engine/driver-adapters/src/lib.rs index 6be1270b3268..c175942dc04c 100644 --- a/query-engine/driver-adapters/src/lib.rs +++ b/query-engine/driver-adapters/src/lib.rs @@ -54,7 +54,7 @@ mod arch { } pub(crate) fn has_named_property(object: &JsObject, name: &str) -> JsResult { - Ok(Reflect::has(&object, &JsString::from_str(name).unwrap().into())?) + Reflect::has(object, &JsString::from_str(name).unwrap().into()) } pub(crate) fn to_rust_str(value: JsString) -> JsResult { From e6dc74fbfbe0a6db61d8c3e649d5f69a4cdecab9 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Tue, 5 Dec 2023 15:17:32 +0100 Subject: [PATCH 129/134] chore(review): comment on spawn_controlled actors, add error tracing to unexpected case --- query-engine/core/src/interactive_transactions/actors.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/query-engine/core/src/interactive_transactions/actors.rs b/query-engine/core/src/interactive_transactions/actors.rs index 071512cee06d..cc6de7ce9036 100644 --- a/query-engine/core/src/interactive_transactions/actors.rs +++ b/query-engine/core/src/interactive_transactions/actors.rs @@ -388,6 +388,10 @@ pub(crate) fn spawn_client_list_clear_actor( closed_txs: Arc>>>, mut rx: Receiver<(TxId, Option)>, ) -> JoinHandle<()> { + // Note: tasks implemented via loops cannot be cancelled implicitly, so we need to spawn them in a + // "controlled" way, via `spawn_controlled`. + // The `rx_exit` receiver is used to signal the loop to exit, and that signal is emitted whenever + // the task is aborted (likely, due to the engine shutting down and cleaning up the allocated resources). spawn_controlled(Box::new( |mut rx_exit: tokio::sync::broadcast::Receiver<()>| async move { loop { @@ -406,6 +410,7 @@ pub(crate) fn spawn_client_list_clear_actor( } None => { // the `rx` channel is closed. 
+ tracing::error!("rx channel is closed!"); break; } } From bd0612cb5f88084604edd57bd5500365a0d42a64 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Tue, 5 Dec 2023 15:18:49 +0100 Subject: [PATCH 130/134] chore(review): remove unused dependency --- query-engine/driver-adapters/Cargo.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/query-engine/driver-adapters/Cargo.toml b/query-engine/driver-adapters/Cargo.toml index 24f69164f138..9f9db91287f0 100644 --- a/query-engine/driver-adapters/Cargo.toml +++ b/query-engine/driver-adapters/Cargo.toml @@ -13,7 +13,6 @@ tracing-core = "0.1" metrics = "0.18" uuid = { version = "1", features = ["v4"] } pin-project = "1" -wasm-rs-dbg = "0.1.2" serde_repr.workspace = true futures = "0.3" From 4100e5b638f5a7570fef203290741bb58bef791e Mon Sep 17 00:00:00 2001 From: jkomyno Date: Tue, 5 Dec 2023 15:32:07 +0100 Subject: [PATCH 131/134] chore(review): move wasm/napi-specific task JoinHandle stuff to crosstarget-utils --- Cargo.lock | 4 +- libs/crosstarget-utils/Cargo.toml | 6 +- libs/crosstarget-utils/src/native/mod.rs | 1 + libs/crosstarget-utils/src/native/task.rs | 46 +++++++ libs/crosstarget-utils/src/wasm/mod.rs | 1 + libs/crosstarget-utils/src/wasm/task.rs | 64 +++++++++ query-engine/core/Cargo.toml | 7 - query-engine/core/src/executor/mod.rs | 1 - query-engine/core/src/executor/task.rs | 124 ------------------ .../interactive_transactions/actor_manager.rs | 2 +- .../src/interactive_transactions/actors.rs | 2 +- 11 files changed, 120 insertions(+), 138 deletions(-) create mode 100644 libs/crosstarget-utils/src/native/task.rs create mode 100644 libs/crosstarget-utils/src/wasm/task.rs delete mode 100644 query-engine/core/src/executor/task.rs diff --git a/Cargo.lock b/Cargo.lock index a66cf4fcb9b9..4733dbc42f4d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -842,7 +842,9 @@ dependencies = [ name = "crosstarget-utils" version = "0.1.0" dependencies = [ + "futures", "js-sys", + "pin-project", "tokio", "wasm-bindgen", "wasm-bindgen-futures", @@ -3672,7 +3674,6 @@ dependencies = [ "once_cell", "opentelemetry", "petgraph 0.4.13", - "pin-project", "psl", "query-connector", "query-engine-metrics", @@ -3688,7 +3689,6 @@ dependencies = [ "tracing-subscriber", "user-facing-errors", "uuid", - "wasm-bindgen-futures", ] [[package]] diff --git a/libs/crosstarget-utils/Cargo.toml b/libs/crosstarget-utils/Cargo.toml index 6fd110652afe..627efbf23c36 100644 --- a/libs/crosstarget-utils/Cargo.toml +++ b/libs/crosstarget-utils/Cargo.toml @@ -5,13 +5,15 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[dependencies] +futures = "0.3" [target.'cfg(target_arch = "wasm32")'.dependencies] js-sys.workspace = true wasm-bindgen.workspace = true wasm-bindgen-futures.workspace = true -tokio = { version = "1.25", features = ["macros"] } - +tokio = { version = "1.25", features = ["macros", "sync"] } +pin-project = "1" [target.'cfg(not(target_arch = "wasm32"))'.dependencies] tokio.workspace = true diff --git a/libs/crosstarget-utils/src/native/mod.rs b/libs/crosstarget-utils/src/native/mod.rs index b801d82e3118..b19a356ff8ff 100644 --- a/libs/crosstarget-utils/src/native/mod.rs +++ b/libs/crosstarget-utils/src/native/mod.rs @@ -1,2 +1,3 @@ pub mod spawn; +pub mod task; pub mod time; diff --git a/libs/crosstarget-utils/src/native/task.rs b/libs/crosstarget-utils/src/native/task.rs new file mode 100644 index 000000000000..017d6866eb2d --- /dev/null +++ b/libs/crosstarget-utils/src/native/task.rs @@ -0,0 +1,46 @@ +use 
futures::Future; +use tokio::sync::broadcast::{self}; + +pub struct JoinHandle { + handle: tokio::task::JoinHandle, + + sx_exit: Option>, +} + +impl JoinHandle { + pub fn abort(&mut self) { + if let Some(sx_exit) = self.sx_exit.as_ref() { + sx_exit.send(()).ok(); + } + + self.handle.abort(); + } +} + +pub fn spawn(future: T) -> JoinHandle +where + T: Future + Send + 'static, + T::Output: Send + 'static, +{ + spawn_with_sx_exit::(future, None) +} + +pub fn spawn_controlled(future_fn: Box) -> T>) -> JoinHandle +where + T: Future + Send + 'static, + T::Output: Send + 'static, +{ + let (sx_exit, rx_exit) = tokio::sync::broadcast::channel::<()>(1); + let future = future_fn(rx_exit); + + spawn_with_sx_exit::(future, Some(sx_exit)) +} + +fn spawn_with_sx_exit(future: T, sx_exit: Option>) -> JoinHandle +where + T: Future + Send + 'static, + T::Output: Send + 'static, +{ + let handle = tokio::spawn(future); + JoinHandle { handle, sx_exit } +} diff --git a/libs/crosstarget-utils/src/wasm/mod.rs b/libs/crosstarget-utils/src/wasm/mod.rs index b801d82e3118..b19a356ff8ff 100644 --- a/libs/crosstarget-utils/src/wasm/mod.rs +++ b/libs/crosstarget-utils/src/wasm/mod.rs @@ -1,2 +1,3 @@ pub mod spawn; +pub mod task; pub mod time; diff --git a/libs/crosstarget-utils/src/wasm/task.rs b/libs/crosstarget-utils/src/wasm/task.rs new file mode 100644 index 000000000000..80bbc6991c89 --- /dev/null +++ b/libs/crosstarget-utils/src/wasm/task.rs @@ -0,0 +1,64 @@ +use futures::Future; +use tokio::sync::{ + broadcast::{self}, + oneshot::{self}, +}; + +// Wasm-compatible alternative to `tokio::task::JoinHandle`. +// `pin_project` enables pin-projection and a `Pin`-compatible implementation of the `Future` trait. +#[pin_project::pin_project] +pub struct JoinHandle { + #[pin] + receiver: oneshot::Receiver, + + sx_exit: Option>, +} + +impl Future for JoinHandle { + type Output = Result; + + fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll { + // the `self.project()` method is provided by the `pin_project` macro + core::pin::Pin::new(&mut self.receiver).poll(cx) + } +} + +impl JoinHandle { + pub fn abort(&mut self) { + if let Some(sx_exit) = self.sx_exit.as_ref() { + sx_exit.send(()).ok(); + } + } +} + +pub fn spawn(future: T) -> JoinHandle +where + T: Future + 'static, + T::Output: Send + 'static, +{ + spawn_with_sx_exit::(future, None) +} + +pub fn spawn_controlled(future_fn: Box) -> T>) -> JoinHandle +where + T: Future + 'static, + T::Output: Send + 'static, +{ + let (sx_exit, rx_exit) = tokio::sync::broadcast::channel::<()>(1); + let future = future_fn(rx_exit); + spawn_with_sx_exit::(future, Some(sx_exit)) +} + +fn spawn_with_sx_exit(future: T, sx_exit: Option>) -> JoinHandle +where + T: Future + 'static, + T::Output: Send + 'static, +{ + let (sender, receiver) = oneshot::channel(); + wasm_bindgen_futures::spawn_local(async move { + let result = future.await; + sender.send(result).ok(); + }); + + JoinHandle { receiver, sx_exit } +} diff --git a/query-engine/core/Cargo.toml b/query-engine/core/Cargo.toml index da9e8331dfdf..192f32b217ad 100644 --- a/query-engine/core/Cargo.toml +++ b/query-engine/core/Cargo.toml @@ -40,10 +40,3 @@ schema = { path = "../schema" } crosstarget-utils = { path = "../../libs/crosstarget-utils" } lru = "0.7.7" enumflags2 = "0.7" - -pin-project = "1" -wasm-bindgen-futures = "0.4" - -[target.'cfg(target_arch = "wasm32")'.dependencies] -pin-project = "1" -wasm-bindgen-futures.workspace = true diff --git a/query-engine/core/src/executor/mod.rs 
b/query-engine/core/src/executor/mod.rs index 01a9e09674db..fee7bc68fe7b 100644 --- a/query-engine/core/src/executor/mod.rs +++ b/query-engine/core/src/executor/mod.rs @@ -10,7 +10,6 @@ mod execute_operation; mod interpreting_executor; mod pipeline; mod request_context; -pub(crate) mod task; pub use self::{execute_operation::*, interpreting_executor::InterpretingExecutor}; diff --git a/query-engine/core/src/executor/task.rs b/query-engine/core/src/executor/task.rs deleted file mode 100644 index 3113ecd28f88..000000000000 --- a/query-engine/core/src/executor/task.rs +++ /dev/null @@ -1,124 +0,0 @@ -//! This module provides a unified interface for spawning asynchronous tasks, regardless of the target platform. - -pub use arch::{spawn, spawn_controlled, JoinHandle}; -use futures::Future; - -// On native targets, `tokio::spawn` spawns a new asynchronous task. -#[cfg(not(target_arch = "wasm32"))] -mod arch { - use super::*; - use tokio::sync::broadcast::{self}; - - pub struct JoinHandle { - handle: tokio::task::JoinHandle, - - sx_exit: Option>, - } - - impl JoinHandle { - pub fn abort(&mut self) { - if let Some(sx_exit) = self.sx_exit.as_ref() { - sx_exit.send(()).ok(); - } - - self.handle.abort(); - } - } - - pub fn spawn(future: T) -> JoinHandle - where - T: Future + Send + 'static, - T::Output: Send + 'static, - { - spawn_with_sx_exit::(future, None) - } - - pub fn spawn_controlled(future_fn: Box) -> T>) -> JoinHandle - where - T: Future + Send + 'static, - T::Output: Send + 'static, - { - let (sx_exit, rx_exit) = tokio::sync::broadcast::channel::<()>(1); - let future = future_fn(rx_exit); - - spawn_with_sx_exit::(future, Some(sx_exit)) - } - - fn spawn_with_sx_exit(future: T, sx_exit: Option>) -> JoinHandle - where - T: Future + Send + 'static, - T::Output: Send + 'static, - { - let handle = tokio::spawn(future); - JoinHandle { handle, sx_exit } - } -} - -// On Wasm targets, `wasm_bindgen_futures::spawn_local` spawns a new asynchronous task. -#[cfg(target_arch = "wasm32")] -mod arch { - use super::*; - use tokio::sync::{ - broadcast::{self}, - oneshot::{self}, - }; - - // Wasm-compatible alternative to `tokio::task::JoinHandle`. - // `pin_project` enables pin-projection and a `Pin`-compatible implementation of the `Future` trait. 
- #[pin_project::pin_project] - pub struct JoinHandle { - #[pin] - receiver: oneshot::Receiver, - - sx_exit: Option>, - } - - impl Future for JoinHandle { - type Output = Result; - - fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll { - // the `self.project()` method is provided by the `pin_project` macro - core::pin::Pin::new(&mut self.receiver).poll(cx) - } - } - - impl JoinHandle { - pub fn abort(&mut self) { - if let Some(sx_exit) = self.sx_exit.as_ref() { - sx_exit.send(()).ok(); - } - } - } - - pub fn spawn(future: T) -> JoinHandle - where - T: Future + 'static, - T::Output: Send + 'static, - { - spawn_with_sx_exit::(future, None) - } - - pub fn spawn_controlled(future_fn: Box) -> T>) -> JoinHandle - where - T: Future + 'static, - T::Output: Send + 'static, - { - let (sx_exit, rx_exit) = tokio::sync::broadcast::channel::<()>(1); - let future = future_fn(rx_exit); - spawn_with_sx_exit::(future, Some(sx_exit)) - } - - fn spawn_with_sx_exit(future: T, sx_exit: Option>) -> JoinHandle - where - T: Future + 'static, - T::Output: Send + 'static, - { - let (sender, receiver) = oneshot::channel(); - wasm_bindgen_futures::spawn_local(async move { - let result = future.await; - sender.send(result).ok(); - }); - - JoinHandle { receiver, sx_exit } - } -} diff --git a/query-engine/core/src/interactive_transactions/actor_manager.rs b/query-engine/core/src/interactive_transactions/actor_manager.rs index f2d1f539ebbf..e6c1c7fbd1dc 100644 --- a/query-engine/core/src/interactive_transactions/actor_manager.rs +++ b/query-engine/core/src/interactive_transactions/actor_manager.rs @@ -1,6 +1,6 @@ -use crate::executor::task::JoinHandle; use crate::{protocol::EngineProtocol, ClosedTx, Operation, ResponseData}; use connector::Connection; +use crosstarget_utils::task::JoinHandle; use lru::LruCache; use once_cell::sync::Lazy; use schema::QuerySchemaRef; diff --git a/query-engine/core/src/interactive_transactions/actors.rs b/query-engine/core/src/interactive_transactions/actors.rs index cc6de7ce9036..86ebd5c13b84 100644 --- a/query-engine/core/src/interactive_transactions/actors.rs +++ b/query-engine/core/src/interactive_transactions/actors.rs @@ -1,10 +1,10 @@ use super::{CachedTx, TransactionError, TxOpRequest, TxOpRequestMsg, TxOpResponse}; -use crate::executor::task::{spawn, spawn_controlled, JoinHandle}; use crate::{ execute_many_operations, execute_single_operation, protocol::EngineProtocol, ClosedTx, Operation, ResponseData, TxId, }; use connector::Connection; +use crosstarget_utils::task::{spawn, spawn_controlled, JoinHandle}; use crosstarget_utils::time::ElapsedTimeCounter; use schema::QuerySchemaRef; use std::{collections::HashMap, sync::Arc}; From c7c858647664d57121f462cce18cee47308bfd56 Mon Sep 17 00:00:00 2001 From: Serhii Tatarintsev Date: Wed, 6 Dec 2023 10:44:56 +0100 Subject: [PATCH 132/134] qe-wasm: Partially fix tests (#4517) * qe-wasm: Partially fix the test suite 1. The Bash build script did not abort on error, hence the sed error on Linux went unnoticed. 2. We don't actually need sed trickery since `wasm-pack` has a flag for changing the binary name. 3. The `tracing` feature does not actually work on WASM even partially: it panics on `Instant` invocation as soon as the first span is created. Since we are running tests with all preview features enabled, that means that practically any test panics now. Disabled it again.
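For reference, the panic in (3) can be reproduced in isolation. A minimal sketch, assuming the span timings ultimately call `std::time::Instant::now()` (which is what the panic message suggests):

    fn main() {
        // On native targets this returns a monotonic timestamp. On
        // wasm32-unknown-unknown the standard library has no clock, so this call
        // panics with "time not implemented on this platform", which is why the
        // first span that records timings takes the whole test run down.
        let started = std::time::Instant::now();
        println!("elapsed: {:?}", started.elapsed());
    }

The series already carries a wasm-safe helper for elapsed-time measurements (`crosstarget_utils::time`, used by the transaction actors above), so that is presumably where the tracing timings would need to go before the feature can be re-enabled.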
A lot of tests are still failing on ThreadRng invocation and stacktrace is not really helpful, but I still think it's better if we get it working. * Update query-engine/query-engine-wasm/build.sh --- query-engine/query-engine-wasm/build.sh | 21 ++++--------------- .../example/prisma/schema.prisma | 4 ++-- .../query-engine-wasm/src/wasm/engine.rs | 4 ++-- 3 files changed, 8 insertions(+), 21 deletions(-) diff --git a/query-engine/query-engine-wasm/build.sh b/query-engine/query-engine-wasm/build.sh index 3a23babc5e30..e4db9fbad6da 100755 --- a/query-engine/query-engine-wasm/build.sh +++ b/query-engine/query-engine-wasm/build.sh @@ -1,23 +1,17 @@ #!/bin/bash -set -e # Call this script as `./build.sh ` +set -euo pipefail -OUT_VERSION="$1" +OUT_VERSION="${1:-}" OUT_FOLDER="pkg" OUT_JSON="${OUT_FOLDER}/package.json" OUT_TARGET="bundler" OUT_NPM_NAME="@prisma/query-engine-wasm" -# The local ./Cargo.toml file uses "name = "query_engine_wasm" as library name -# to avoid conflicts with libquery's `name = "query_engine"` library name declaration. -# This little `sed -i` trick below is a hack to publish "@prisma/query-engine-wasm" -# with the same binding filenames currently expected by the Prisma Client. -sed -i.bak 's/name = "query_engine_wasm"/name = "query_engine"/g' Cargo.toml - # use `wasm-pack build --release` on CI only -if [[ -z "$BUILDKITE" ]] && [[ -z "$GITHUB_ACTIONS" ]]; then +if [[ -z "${BUILDKITE:-}" ]] && [[ -z "${GITHUB_ACTIONS:-}" ]]; then BUILD_PROFILE="--dev" else BUILD_PROFILE="--release" @@ -31,14 +25,7 @@ then curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh fi -wasm-pack build $BUILD_PROFILE --target $OUT_TARGET - -sed -i.bak 's/name = "query_engine"/name = "query_engine_wasm"/g' Cargo.toml - -# Remove the backup file created by sed. We only created it because there's no -# cross-platform way to specify we don't need one (it's just `-i` in GNU sed -# but `-i ""` in BSD sed). 
-rm Cargo.toml.bak +wasm-pack build $BUILD_PROFILE --target $OUT_TARGET --out-name query_engine sleep 1 diff --git a/query-engine/query-engine-wasm/example/prisma/schema.prisma b/query-engine/query-engine-wasm/example/prisma/schema.prisma index 8e6b86202536..c6432a4a671f 100644 --- a/query-engine/query-engine-wasm/example/prisma/schema.prisma +++ b/query-engine/query-engine-wasm/example/prisma/schema.prisma @@ -4,8 +4,8 @@ datasource db { } generator client { - provider = "prisma-client-js" - previewFeatures = ["driverAdapters"] + provider = "prisma-client-js" + previewFeatures = ["driverAdapters", "tracing"] } model User { diff --git a/query-engine/query-engine-wasm/src/wasm/engine.rs b/query-engine/query-engine-wasm/src/wasm/engine.rs index 3413a8af8d76..92b352d76df5 100644 --- a/query-engine/query-engine-wasm/src/wasm/engine.rs +++ b/query-engine/query-engine-wasm/src/wasm/engine.rs @@ -7,7 +7,6 @@ use crate::{ }; use driver_adapters::JsObject; use js_sys::Function as JsFunction; -use psl::PreviewFeature; use query_core::{ protocol::EngineProtocol, schema::{self, QuerySchema}, @@ -180,7 +179,8 @@ impl QueryEngine { .validate_that_one_datasource_is_provided() .map_err(|errors| ApiError::conversion(errors, schema.db.source()))?; - let enable_tracing = config.preview_features().contains(PreviewFeature::Tracing); + // Telemetry panics on timings if preview feature is enabled + let enable_tracing = false; // config.preview_features().contains(PreviewFeature::Tracing); let engine_protocol = engine_protocol.unwrap_or(EngineProtocol::Json); let builder = EngineBuilder { From 14ceb7c2ade402ab7fb44958b6b83279912db725 Mon Sep 17 00:00:00 2001 From: Serhii Tatarintsev Date: Wed, 6 Dec 2023 11:57:10 +0100 Subject: [PATCH 133/134] qe-wasm: Fix RNG on Node 18 in a test runner (#4526) --- .../driver-adapters/connector-test-kit-executor/src/index.ts | 3 +++ 1 file changed, 3 insertions(+) diff --git a/query-engine/driver-adapters/connector-test-kit-executor/src/index.ts b/query-engine/driver-adapters/connector-test-kit-executor/src/index.ts index 4e847742e51b..632a01c89eab 100644 --- a/query-engine/driver-adapters/connector-test-kit-executor/src/index.ts +++ b/query-engine/driver-adapters/connector-test-kit-executor/src/index.ts @@ -22,6 +22,9 @@ import { PrismaPlanetScale } from '@prisma/adapter-planetscale' import {bindAdapter, DriverAdapter, ErrorCapturingDriverAdapter} from "@prisma/driver-adapter-utils"; +import { webcrypto } from 'node:crypto'; + +(global as any).crypto = webcrypto const SUPPORTED_ADAPTERS: Record Promise> From a75899457fd3ca24f4d79f04e038dc697562df11 Mon Sep 17 00:00:00 2001 From: Serhii Tatarintsev Date: Fri, 8 Dec 2023 09:01:15 +0100 Subject: [PATCH 134/134] qe: Skipping failing tests on WASM (#4527) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add a way to skip WASM tests * Temporarily skip failing wasm tests on pg and neon * Skip timestamp test * Fix incorrect skips * Parse libsql.js.wasm version * Skip libsql wasm tests * Fix vitess version parsing * Skip wasm planetscale tests that are also skipped for napi * Rename jobs * Skip planetscale wasm tests * Fix some incorrect skips * And more fixes * Be careful when copy-pasting, please * One more fix * Fix null lists * Again: be careful when copy-pasting * Skip timestamp test on libsql and planetscale * Update .github/workflows/test-query-engine-driver-adapters.yml Co-authored-by: Joël Galeran --------- Co-authored-by: Joël Galeran --- .../test-query-engine-driver-adapters.yml | 16
++++---- .../tests/new/assertion_violation_error.rs | 10 ++++- .../tests/new/interactive_tx.rs | 6 +-- .../query-engine-tests/tests/new/metrics.rs | 7 ++-- .../tests/new/multi_schema.rs | 5 ++- .../query-engine-tests/tests/new/occ.rs | 12 ++++-- .../new/ref_actions/on_delete/set_default.rs | 40 +++++++++++++++---- .../tests/new/regressions/max_integer.rs | 11 ++++- .../tests/new/regressions/prisma_12572.rs | 6 ++- .../tests/new/regressions/prisma_15204.rs | 16 ++++++-- .../tests/new/regressions/prisma_17103.rs | 2 +- .../tests/new/regressions/prisma_7434.rs | 8 +++- .../tests/queries/aggregation/avg.rs | 4 +- .../queries/aggregation/combination_spec.rs | 4 +- .../tests/queries/aggregation/count.rs | 2 +- .../tests/queries/aggregation/max.rs | 4 +- .../tests/queries/aggregation/min.rs | 4 +- .../tests/queries/aggregation/sum.rs | 4 +- .../queries/batch/in_selection_batching.rs | 19 ++++++++- .../queries/batch/transactional_batch.rs | 6 +-- .../tests/queries/data_types/bytes.rs | 10 ++++- .../queries/data_types/through_relation.rs | 8 +++- .../tests/queries/filters/bigint_filter.rs | 5 ++- .../tests/queries/filters/bytes_filter.rs | 5 ++- .../filters/field_reference/bigint_filter.rs | 16 ++++++-- .../filters/field_reference/bytes_filter.rs | 17 ++++++-- .../field_reference/datetime_filter.rs | 16 ++++++-- .../filters/field_reference/float_filter.rs | 16 ++++++-- .../filters/field_reference/int_filter.rs | 16 ++++++-- .../filters/field_reference/json_filter.rs | 4 +- .../filters/field_reference/string_filter.rs | 12 ++++-- .../tests/queries/filters/json.rs | 5 +-- .../tests/queries/filters/json_filters.rs | 14 +++---- .../tests/queries/filters/list_filters.rs | 8 +++- .../tests/queries/filters/search_filter.rs | 2 +- .../order_and_pagination/nested_pagination.rs | 6 +-- .../order_by_dependent.rs | 4 +- .../order_by_dependent_pagination.rs | 6 +-- .../order_and_pagination/pagination.rs | 6 +-- .../query-engine-tests/tests/raw/sql/casts.rs | 2 +- .../tests/raw/sql/errors.rs | 2 +- .../tests/raw/sql/input_coercion.rs | 2 +- .../tests/raw/sql/null_list.rs | 6 ++- .../tests/raw/sql/typed_output.rs | 4 +- .../tests/writes/data_types/bigint.rs | 6 ++- .../tests/writes/data_types/bytes.rs | 17 +++++++- .../data_types/native_types/postgres.rs | 8 +++- .../writes/data_types/scalar_list/base.rs | 6 +-- .../writes/data_types/scalar_list/defaults.rs | 2 +- .../tests/writes/ids/byoid.rs | 8 ++-- .../nested_update_many_inside_update.rs | 14 +++---- .../nested_create_many.rs | 2 +- .../compound_fks_mixed_requiredness.rs | 2 +- .../writes/top_level_mutations/create.rs | 2 +- .../writes/top_level_mutations/create_many.rs | 2 +- .../writes/top_level_mutations/update_many.rs | 4 +- .../writes/top_level_mutations/upsert.rs | 2 +- .../src/connector_tag/mod.rs | 20 ++++++---- .../src/connector_tag/postgres.rs | 18 ++++++--- .../src/connector_tag/sqlite.rs | 9 +++-- .../src/connector_tag/vitess.rs | 9 +++-- .../test-configs/libsql-wasm | 1 + .../test-configs/neon-wasm | 2 +- .../test-configs/pg-wasm | 2 +- .../test-configs/planetscale-wasm | 2 +- 65 files changed, 359 insertions(+), 157 deletions(-) diff --git a/.github/workflows/test-query-engine-driver-adapters.yml b/.github/workflows/test-query-engine-driver-adapters.yml index 08bd3b192eac..d9af2a375a98 100644 --- a/.github/workflows/test-query-engine-driver-adapters.yml +++ b/.github/workflows/test-query-engine-driver-adapters.yml @@ -25,24 +25,24 @@ jobs: fail-fast: false matrix: adapter: - - name: '@prisma/adapter-planetscale' + - name: 'planetscale (napi)' 
setup_task: 'dev-planetscale-js' - - name: '@prisma/adapter-pg (napi)' + - name: 'pg (napi)' setup_task: 'dev-pg-js' - - name: '@prisma/adapter-neon (ws) (napi)' + - name: 'neon (ws) (napi)' setup_task: 'dev-neon-js' - - name: '@prisma/adapter-libsql (Turso) (napi)' + - name: 'libsql (Turso) (napi)' setup_task: 'dev-libsql-js' - - name: '@prisma/adapter-planetscale' + - name: 'planetscale (wasm)' setup_task: 'dev-planetscale-wasm' needs_wasm_pack: true - - name: '@prisma/adapter-pg (wasm)' + - name: 'pg (wasm)' setup_task: 'dev-pg-wasm' needs_wasm_pack: true - - name: '@prisma/adapter-neon (ws) (wasm)' + - name: 'neon (ws) (wasm)' setup_task: 'dev-neon-wasm' needs_wasm_pack: true - - name: '@prisma/adapter-libsql (Turso) (wasm)' + - name: 'libsql (Turso) (wasm)' setup_task: 'dev-libsql-wasm' needs_wasm_pack: true node_version: ['18'] diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/assertion_violation_error.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/assertion_violation_error.rs index 73455011d04e..62add25c3e72 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/assertion_violation_error.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/assertion_violation_error.rs @@ -2,7 +2,15 @@ use query_engine_tests::*; #[test_suite(schema(generic))] mod raw_params { - #[connector_test(only(Postgres), exclude(Postgres("neon.js"), Postgres("pg.js")))] + #[connector_test( + only(Postgres), + exclude( + Postgres("neon.js"), + Postgres("pg.js"), + Postgres("neon.js.wasm"), + Postgres("pg.js.wasm") + ) + )] async fn value_too_many_bind_variables(runner: Runner) -> TestResult<()> { let n = 32768; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/interactive_tx.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/interactive_tx.rs index 33908a9e079e..4372b23c282d 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/interactive_tx.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/interactive_tx.rs @@ -1,7 +1,7 @@ use query_engine_tests::test_suite; use std::borrow::Cow; -#[test_suite(schema(generic), exclude(Vitess("planetscale.js")))] +#[test_suite(schema(generic), exclude(Vitess("planetscale.js", "planetscale.js.wasm")))] mod interactive_tx { use query_engine_tests::*; use tokio::time; @@ -573,7 +573,7 @@ mod itx_isolation { use query_engine_tests::*; // All (SQL) connectors support serializable. 
- #[connector_test(exclude(MongoDb, Vitess("planetscale.js")))] + #[connector_test(exclude(MongoDb, Vitess("planetscale.js", "planetscale.js.wasm")))] async fn basic_serializable(mut runner: Runner) -> TestResult<()> { let tx_id = runner.start_tx(5000, 5000, Some("Serializable".to_owned())).await?; runner.set_active_tx(tx_id.clone()); @@ -595,7 +595,7 @@ mod itx_isolation { Ok(()) } - #[connector_test(exclude(MongoDb, Vitess("planetscale.js")))] + #[connector_test(exclude(MongoDb, Vitess("planetscale.js", "planetscale.js.wasm")))] async fn casing_doesnt_matter(mut runner: Runner) -> TestResult<()> { let tx_id = runner.start_tx(5000, 5000, Some("sErIaLiZaBlE".to_owned())).await?; runner.set_active_tx(tx_id.clone()); diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/metrics.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/metrics.rs index cd270bb334c6..dff1ecdb03a5 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/metrics.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/metrics.rs @@ -3,10 +3,9 @@ use query_engine_tests::test_suite; #[test_suite( schema(generic), exclude( - Vitess("planetscale.js"), - Postgres("neon.js"), - Postgres("pg.js"), - Sqlite("libsql.js") + Vitess("planetscale.js", "planetscale.js.wasm"), + Postgres("neon.js", "pg.js", "neon.js.wasm", "pg.js.wasm"), + Sqlite("libsql.js", "libsql.js.wasm") ) )] mod metrics { diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/multi_schema.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/multi_schema.rs index 29c93689f542..40f646277f2c 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/multi_schema.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/multi_schema.rs @@ -1,6 +1,9 @@ use query_engine_tests::test_suite; -#[test_suite(capabilities(MultiSchema), exclude(Mysql, Vitess("planetscale.js")))] +#[test_suite( + capabilities(MultiSchema), + exclude(Mysql, Vitess("planetscale.js", "planetscale.js.wasm")) +)] mod multi_schema { use query_engine_tests::*; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/occ.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/occ.rs index b495c8627e5a..d074a223531e 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/occ.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/occ.rs @@ -112,7 +112,10 @@ mod occ { assert_eq!(booked_user_id, found_booked_user_id); } - #[connector_test(schema(occ_simple), exclude(MongoDB, CockroachDb, Vitess("planetscale.js")))] + #[connector_test( + schema(occ_simple), + exclude(MongoDB, CockroachDb, Vitess("planetscale.js", "planetscale.js.wasm")) + )] async fn occ_update_many_test(runner: Runner) -> TestResult<()> { let runner = Arc::new(runner); @@ -127,7 +130,10 @@ mod occ { Ok(()) } - #[connector_test(schema(occ_simple), exclude(CockroachDb, Vitess("planetscale.js")))] + #[connector_test( + schema(occ_simple), + exclude(CockroachDb, Vitess("planetscale.js", "planetscale.js.wasm")) + )] async fn occ_update_test(runner: Runner) -> TestResult<()> { let runner = Arc::new(runner); @@ -158,7 +164,7 @@ mod occ { Ok(()) } - #[connector_test(schema(occ_simple), exclude(Vitess("planetscale.js")))] + #[connector_test(schema(occ_simple), exclude(Vitess("planetscale.js", "planetscale.js.wasm")))] async fn occ_delete_test(runner: Runner) -> TestResult<()> { let runner = Arc::new(runner); diff --git 
a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/ref_actions/on_delete/set_default.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/ref_actions/on_delete/set_default.rs index 40ef54ed11f1..d96c3d3576ff 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/ref_actions/on_delete/set_default.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/ref_actions/on_delete/set_default.rs @@ -2,7 +2,10 @@ use indoc::indoc; use query_engine_tests::*; -#[test_suite(suite = "setdefault_onD_1to1_req", exclude(MongoDb, MySQL, Vitess("planetscale.js")))] +#[test_suite( + suite = "setdefault_onD_1to1_req", + exclude(MongoDb, MySQL, Vitess("planetscale.js", "planetscale.js.wasm")) +)] mod one2one_req { fn required_with_default() -> String { let schema = indoc! { @@ -66,7 +69,10 @@ mod one2one_req { } /// Deleting the parent reconnects the child to the default and fails (the default doesn't exist). - #[connector_test(schema(required_with_default), exclude(MongoDb, MySQL, Vitess("planetscale.js")))] + #[connector_test( + schema(required_with_default), + exclude(MongoDb, MySQL, Vitess("planetscale.js", "planetscale.js.wasm")) + )] async fn delete_parent_no_exist_fail(runner: Runner) -> TestResult<()> { insta::assert_snapshot!( run_query!(&runner, r#"mutation { createOneParent(data: { id: 1, child: { create: { id: 1 }}}) { id }}"#), @@ -103,7 +109,10 @@ mod one2one_req { } } -#[test_suite(suite = "setdefault_onD_1to1_opt", exclude(MongoDb, MySQL, Vitess("planetscale.js")))] +#[test_suite( + suite = "setdefault_onD_1to1_opt", + exclude(MongoDb, MySQL, Vitess("planetscale.js", "planetscale.js.wasm")) +)] mod one2one_opt { fn optional_with_default() -> String { let schema = indoc! { @@ -167,7 +176,10 @@ mod one2one_opt { } /// Deleting the parent reconnects the child to the default and fails (the default doesn't exist). - #[connector_test(schema(optional_with_default), exclude(MongoDb, MySQL, Vitess("planetscale.js")))] + #[connector_test( + schema(optional_with_default), + exclude(MongoDb, MySQL, Vitess("planetscale.js", "planetscale.js.wasm")) + )] async fn delete_parent_no_exist_fail(runner: Runner) -> TestResult<()> { insta::assert_snapshot!( run_query!(&runner, r#"mutation { createOneParent(data: { id: 1, child: { create: { id: 1 }}}) { id }}"#), @@ -206,7 +218,10 @@ mod one2one_opt { } } -#[test_suite(suite = "setdefault_onD_1toM_req", exclude(MongoDb, MySQL, Vitess("planetscale.js")))] +#[test_suite( + suite = "setdefault_onD_1toM_req", + exclude(MongoDb, MySQL, Vitess("planetscale.js", "planetscale.js.wasm")) +)] mod one2many_req { fn required_with_default() -> String { let schema = indoc! { @@ -270,7 +285,10 @@ mod one2many_req { } /// Deleting the parent reconnects the child to the default and fails (the default doesn't exist). 
- #[connector_test(schema(required_with_default), exclude(MongoDb, MySQL, Vitess("planetscale.js")))] + #[connector_test( + schema(required_with_default), + exclude(MongoDb, MySQL, Vitess("planetscale.js", "planetscale.js.wasm")) + )] async fn delete_parent_no_exist_fail(runner: Runner) -> TestResult<()> { insta::assert_snapshot!( run_query!(&runner, r#"mutation { createOneParent(data: { id: 1, children: { create: { id: 1 }}}) { id }}"#), @@ -307,7 +325,10 @@ mod one2many_req { } } -#[test_suite(suite = "setdefault_onD_1toM_opt", exclude(MongoDb, MySQL, Vitess("planetscale.js")))] +#[test_suite( + suite = "setdefault_onD_1toM_opt", + exclude(MongoDb, MySQL, Vitess("planetscale.js", "planetscale.js.wasm")) +)] mod one2many_opt { fn optional_with_default() -> String { let schema = indoc! { @@ -371,7 +392,10 @@ mod one2many_opt { } /// Deleting the parent reconnects the child to the default and fails (the default doesn't exist). - #[connector_test(schema(optional_with_default), exclude(MongoDb, MySQL, Vitess("planetscale.js")))] + #[connector_test( + schema(optional_with_default), + exclude(MongoDb, MySQL, Vitess("planetscale.js", "planetscale.js.wasm")) + )] async fn delete_parent_no_exist_fail(runner: Runner) -> TestResult<()> { insta::assert_snapshot!( run_query!(&runner, r#"mutation { createOneParent(data: { id: 1, children: { create: { id: 1 }}}) { id }}"#), diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/max_integer.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/max_integer.rs index 78206f6394a6..e00b2d22e198 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/max_integer.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/max_integer.rs @@ -187,7 +187,16 @@ mod max_integer { schema.to_owned() } - #[connector_test(schema(overflow_pg), only(Postgres), exclude(Postgres("neon.js"), Postgres("pg.js")))] + #[connector_test( + schema(overflow_pg), + only(Postgres), + exclude( + Postgres("neon.js"), + Postgres("pg.js"), + Postgres("neon.js.wasm"), + Postgres("pg.js.wasm") + ) + )] async fn unfitted_int_should_fail_pg_quaint(runner: Runner) -> TestResult<()> { // int assert_error!( diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_12572.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_12572.rs index 35f056f8fa80..a107b354d159 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_12572.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_12572.rs @@ -26,7 +26,11 @@ mod prisma_12572 { .to_owned() } - #[connector_test] + #[connector_test(exclude( + Postgres("pg.js.wasm", "neon.js.wasm"), + Sqlite("libsql.js.wasm"), + Vitess("planetscale.js.wasm") + ))] async fn all_generated_timestamps_are_the_same(runner: Runner) -> TestResult<()> { runner .query(r#"mutation { createOneTest1(data: {id:"one", test2s: { create: {id: "two"}}}) { id }}"#) diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_15204.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_15204.rs index 8582c14d0bc0..9f4ccdcb3b11 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_15204.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_15204.rs @@ -24,7 +24,11 @@ mod 
conversion_error { schema.to_owned() } - #[connector_test(schema(schema_int), only(Sqlite), exclude(Sqlite("libsql.js")))] + #[connector_test( + schema(schema_int), + only(Sqlite), + exclude(Sqlite("libsql.js"), Sqlite("libsql.js.wasm")) + )] async fn convert_to_int_sqlite_quaint(runner: Runner) -> TestResult<()> { create_test_data(&runner).await?; @@ -38,7 +42,7 @@ mod conversion_error { Ok(()) } - #[connector_test(schema(schema_int), only(Sqlite("libsql.js")))] + #[connector_test(schema(schema_int), only(Sqlite("libsql.js"), Sqlite("libsql.js.wasm")))] async fn convert_to_int_sqlite_js(runner: Runner) -> TestResult<()> { create_test_data(&runner).await?; @@ -52,7 +56,11 @@ mod conversion_error { Ok(()) } - #[connector_test(schema(schema_bigint), only(Sqlite), exclude(Sqlite("libsql.js")))] + #[connector_test( + schema(schema_bigint), + only(Sqlite), + exclude(Sqlite("libsql.js"), Sqlite("libsql.js.wasm")) + )] async fn convert_to_bigint_sqlite_quaint(runner: Runner) -> TestResult<()> { create_test_data(&runner).await?; @@ -66,7 +74,7 @@ mod conversion_error { Ok(()) } - #[connector_test(schema(schema_bigint), only(Sqlite("libsql.js")))] + #[connector_test(schema(schema_bigint), only(Sqlite("libsql.js"), Sqlite("libsql.js.wasm")))] async fn convert_to_bigint_sqlite_js(runner: Runner) -> TestResult<()> { create_test_data(&runner).await?; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_17103.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_17103.rs index c9065ec54c58..8168b66a3a0f 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_17103.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_17103.rs @@ -21,7 +21,7 @@ mod prisma_17103 { schema.to_owned() } - #[connector_test(exclude(Vitess("planetscale.js")))] + #[connector_test(exclude(Vitess("planetscale.js", "planetscale.js.wasm")))] async fn regression(runner: Runner) -> TestResult<()> { run_query!( &runner, diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_7434.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_7434.rs index e5fa8388d66e..8e5fb2457b15 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_7434.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_7434.rs @@ -4,7 +4,13 @@ use query_engine_tests::*; mod not_in_batching { use query_engine_tests::Runner; - #[connector_test] + #[connector_test(exclude( + CockroachDb, + Postgres("pg.js.wasm"), + Postgres("neon.js.wasm"), + Sqlite("libsql.js.wasm"), + Vitess("planetscale.js.wasm") + ))] async fn not_in_batch_filter(runner: Runner) -> TestResult<()> { runner.query(r#"mutation { createManyTestModel(data: [{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{}]) { count }}"#).await?.assert_success(); diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/avg.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/avg.rs index 4793fa24ae2a..387d05dc5e21 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/avg.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/avg.rs @@ -33,7 +33,7 @@ mod aggregation_avg { Ok(()) } - #[connector_test(exclude(MongoDb, Vitess("planetscale.js")))] + 
#[connector_test(exclude(MongoDb, Vitess("planetscale.js", "planetscale.js.wasm")))] async fn avg_with_all_sorts_of_query_args(runner: Runner) -> TestResult<()> { create_row(&runner, r#"{ id: 1, float: 5.5, int: 5, bInt: "5" }"#).await?; create_row(&runner, r#"{ id: 2, float: 4.5, int: 10, bInt: "10" }"#).await?; @@ -126,7 +126,7 @@ mod decimal_aggregation_avg { Ok(()) } - #[connector_test(exclude(MongoDb, Vitess("planetscale.js")))] + #[connector_test(exclude(MongoDb, Vitess("planetscale.js", "planetscale.js.wasm")))] async fn avg_with_all_sorts_of_query_args(runner: Runner) -> TestResult<()> { create_row(&runner, r#"{ id: 1, decimal: "5.5" }"#).await?; create_row(&runner, r#"{ id: 2, decimal: "4.5" }"#).await?; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/combination_spec.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/combination_spec.rs index 3c1f1b092690..e7116894cffe 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/combination_spec.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/combination_spec.rs @@ -87,7 +87,7 @@ mod combinations { } // Mongo precision issue. - #[connector_test(exclude(MongoDB, Vitess("planetscale.js")))] + #[connector_test(exclude(MongoDB, Vitess("planetscale.js", "planetscale.js.wasm")))] async fn with_query_args(runner: Runner) -> TestResult<()> { create_row(&runner, r#"{ id: "1", float: 5.5, int: 5 }"#).await?; create_row(&runner, r#"{ id: "2", float: 4.5, int: 10 }"#).await?; @@ -369,7 +369,7 @@ mod decimal_combinations { } // Mongo precision issue. - #[connector_test(exclude(MongoDB, Vitess("planetscale.js")))] + #[connector_test(exclude(MongoDB, Vitess("planetscale.js", "planetscale.js.wasm")))] async fn with_query_args(runner: Runner) -> TestResult<()> { create_row(&runner, r#"{ id: "1", dec: "5.5" }"#).await?; create_row(&runner, r#"{ id: "2", dec: "4.5" }"#).await?; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/count.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/count.rs index 78ab88fd59c6..043419a58b2d 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/count.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/count.rs @@ -27,7 +27,7 @@ mod aggregation_count { Ok(()) } - #[connector_test(exclude(Vitess("planetscale.js")))] + #[connector_test(exclude(Vitess("planetscale.js", "planetscale.js.wasm")))] async fn count_with_all_sorts_of_query_args(runner: Runner) -> TestResult<()> { create_row(&runner, r#"{ id: 1, string: "1" }"#).await?; create_row(&runner, r#"{ id: 2, string: "2" }"#).await?; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/max.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/max.rs index 12f9b6861892..9c6c055e939d 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/max.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/max.rs @@ -30,7 +30,7 @@ mod aggregation_max { Ok(()) } - #[connector_test(exclude(Vitess("planetscale.js")))] + #[connector_test(exclude(Vitess("planetscale.js", "planetscale.js.wasm")))] async fn max_with_all_sorts_of_query_args(runner: Runner) -> TestResult<()> { create_row(&runner, r#"{ id: 1, float: 5.5, int: 5, bInt: "5", string: "2" 
}"#).await?; create_row(&runner, r#"{ id: 2, float: 4.5, int: 10, bInt: "10", string: "f" }"#).await?; @@ -120,7 +120,7 @@ mod decimal_aggregation_max { Ok(()) } - #[connector_test(exclude(Vitess("planetscale.js")))] + #[connector_test(exclude(Vitess("planetscale.js", "planetscale.js.wasm")))] async fn max_with_all_sorts_of_query_args(runner: Runner) -> TestResult<()> { create_row(&runner, r#"{ id: 1, decimal: "5.5" }"#).await?; create_row(&runner, r#"{ id: 2, decimal: "4.5" }"#).await?; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/min.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/min.rs index 332a5e10707f..c5ce60653d8f 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/min.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/min.rs @@ -30,7 +30,7 @@ mod aggregation_min { Ok(()) } - #[connector_test(exclude(Vitess("planetscale.js")))] + #[connector_test(exclude(Vitess("planetscale.js", "planetscale.js.wasm")))] async fn min_with_all_sorts_of_query_args(runner: Runner) -> TestResult<()> { create_row(&runner, r#"{ id: 1, float: 5.5, int: 5, bInt: "5", string: "2" }"#).await?; create_row(&runner, r#"{ id: 2, float: 4.5, int: 10, bInt: "10", string: "f" }"#).await?; @@ -120,7 +120,7 @@ mod decimal_aggregation_min { Ok(()) } - #[connector_test(exclude(Vitess("planetscale.js")))] + #[connector_test(exclude(Vitess("planetscale.js", "planetscale.js.wasm")))] async fn min_with_all_sorts_of_query_args(runner: Runner) -> TestResult<()> { create_row(&runner, r#"{ id: 1, decimal: "5.5" }"#).await?; create_row(&runner, r#"{ id: 2, decimal: "4.5" }"#).await?; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/sum.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/sum.rs index 14d194a1a4f4..b713d216edb7 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/sum.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/sum.rs @@ -30,7 +30,7 @@ mod aggregation_sum { Ok(()) } - #[connector_test(exclude(Vitess("planetscale.js")))] + #[connector_test(exclude(Vitess("planetscale.js", "planetscale.js.wasm")))] async fn sum_with_all_sorts_of_query_args(runner: Runner) -> TestResult<()> { create_row(&runner, r#"{ id: 1, float: 5.5, int: 5, bInt: "5" }"#).await?; create_row(&runner, r#"{ id: 2, float: 4.5, int: 10, bInt: "10" }"#).await?; @@ -120,7 +120,7 @@ mod decimal_aggregation_sum { Ok(()) } - #[connector_test(exclude(Vitess("planetscale.js")))] + #[connector_test(exclude(Vitess("planetscale.js", "planetscale.js.wasm")))] async fn sum_with_all_sorts_of_query_args(runner: Runner) -> TestResult<()> { create_row(&runner, r#"{ id: 1, decimal: "5.5" }"#).await?; create_row(&runner, r#"{ id: 2, decimal: "4.5" }"#).await?; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/batch/in_selection_batching.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/batch/in_selection_batching.rs index f5e7face6768..aacdb50f687c 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/batch/in_selection_batching.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/batch/in_selection_batching.rs @@ -88,7 +88,13 @@ mod isb { Ok(()) } - #[connector_test(exclude(MongoDb))] + #[connector_test(exclude( + MongoDb, + 
Postgres("pg.js.wasm"), + Postgres("neon.js.wasm"), + Sqlite("libsql.js.wasm"), + Vitess("planetscale.js.wasm") + ))] async fn order_by_aggregation_should_fail(runner: Runner) -> TestResult<()> { create_test_data(&runner).await?; @@ -103,7 +109,16 @@ mod isb { Ok(()) } - #[connector_test(exclude(MongoDb), capabilities(FullTextSearchWithoutIndex))] + #[connector_test( + capabilities(FullTextSearchWithoutIndex), + exclude( + MongoDb, + Postgres("pg.js.wasm"), + Postgres("neon.js.wasm"), + Sqlite("libsql.js.wasm"), + Vitess("planetscale.js.wasm") + ) + )] async fn order_by_relevance_should_fail(runner: Runner) -> TestResult<()> { create_test_data(&runner).await?; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/batch/transactional_batch.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/batch/transactional_batch.rs index 2c332f95f29a..f4ad29cf0584 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/batch/transactional_batch.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/batch/transactional_batch.rs @@ -44,7 +44,7 @@ mod transactional { Ok(()) } - #[connector_test(exclude(Vitess("planetscale.js")))] + #[connector_test(exclude(Vitess("planetscale.js", "planetscale.js.wasm")))] async fn one_success_one_fail(runner: Runner) -> TestResult<()> { let queries = vec![ r#"mutation { createOneModelA(data: { id: 1 }) { id }}"#.to_string(), @@ -77,7 +77,7 @@ mod transactional { Ok(()) } - #[connector_test(exclude(Vitess("planetscale.js")))] + #[connector_test(exclude(Vitess("planetscale.js", "planetscale.js.wasm")))] async fn one_query(runner: Runner) -> TestResult<()> { // Existing ModelA in the DB will prevent the nested ModelA creation in the batch. insta::assert_snapshot!( @@ -104,7 +104,7 @@ mod transactional { Ok(()) } - #[connector_test(exclude(MongoDb, Vitess("planetscale.js")))] + #[connector_test(exclude(MongoDb, Vitess("planetscale.js", "planetscale.js.wasm")))] async fn valid_isolation_level(runner: Runner) -> TestResult<()> { let queries = vec![r#"mutation { createOneModelB(data: { id: 1 }) { id }}"#.to_string()]; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/data_types/bytes.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/data_types/bytes.rs index a4957d75e1ab..265a75763794 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/data_types/bytes.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/data_types/bytes.rs @@ -1,6 +1,14 @@ use query_engine_tests::*; -#[test_suite(schema(common_nullable_types))] +#[test_suite( + schema(common_nullable_types), + exclude( + Postgres("pg.js.wasm"), + Postgres("neon.js.wasm"), + Sqlite("libsql.js.wasm"), + Vitess("planetscale.js.wasm") + ) +)] mod bytes { use query_engine_tests::{run_query, EngineProtocol, Runner}; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/data_types/through_relation.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/data_types/through_relation.rs index b2af72ab955e..8baceb69e98b 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/data_types/through_relation.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/data_types/through_relation.rs @@ -34,7 +34,11 @@ mod scalar_relations { // TODO: fix https://github.com/prisma/team-orm/issues/684, https://github.com/prisma/team-orm/issues/685 and unexclude DAs 
#[connector_test( schema(schema_common), - exclude(Postgres("pg.js", "neon.js"), Vitess("planetscale.js")) + exclude( + Postgres("pg.js", "neon.js", "pg.js.wasm", "neon.js.wasm"), + Vitess("planetscale.js", "planetscale.js.wasm"), + Sqlite("libsql.js.wasm") + ) )] async fn common_types(runner: Runner) -> TestResult<()> { create_common_children(&runner).await?; @@ -236,7 +240,7 @@ mod scalar_relations { #[connector_test( schema(schema_scalar_lists), capabilities(ScalarLists), - exclude(Postgres("pg.js", "neon.js")) + exclude(Postgres("pg.js", "neon.js", "pg.js.wasm", "neon.js.wasm")) )] async fn scalar_lists(runner: Runner) -> TestResult<()> { create_child( diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/bigint_filter.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/bigint_filter.rs index 8230c7e2f04b..16e5804cd7f6 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/bigint_filter.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/bigint_filter.rs @@ -1,7 +1,10 @@ use super::common_test_data; use query_engine_tests::*; -#[test_suite(schema(schemas::common_nullable_types))] +#[test_suite( + schema(schemas::common_nullable_types), + exclude(Sqlite("libsql.js.wasm"), Vitess("planetscale.js.wasm")) +)] mod bigint_filter_spec { use query_engine_tests::run_query; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/bytes_filter.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/bytes_filter.rs index dd8963dca6e8..58ec7e08f8c8 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/bytes_filter.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/bytes_filter.rs @@ -1,7 +1,10 @@ use super::common_test_data; use query_engine_tests::*; -#[test_suite(schema(schemas::common_nullable_types))] +#[test_suite( + schema(schemas::common_nullable_types), + exclude(Sqlite("libsql.js.wasm"), Vitess("planetscale.js.wasm")) +)] mod bytes_filter_spec { use query_engine_tests::run_query; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/field_reference/bigint_filter.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/field_reference/bigint_filter.rs index bcd12fb1b5b7..0ef65c7af43a 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/field_reference/bigint_filter.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/field_reference/bigint_filter.rs @@ -6,7 +6,10 @@ mod bigint_filter { use super::setup; use query_engine_tests::run_query; - #[connector_test(schema(setup::common_types))] + #[connector_test( + schema(setup::common_types), + exclude(Sqlite("libsql.js.wasm"), Vitess("planetscale.js.wasm")) + )] async fn basic_where(runner: Runner) -> TestResult<()> { setup::test_data_common_types(&runner).await?; @@ -28,7 +31,10 @@ mod bigint_filter { Ok(()) } - #[connector_test(schema(setup::common_types))] + #[connector_test( + schema(setup::common_types), + exclude(Sqlite("libsql.js.wasm"), Vitess("planetscale.js.wasm")) + )] async fn numeric_comparison_filters(runner: Runner) -> TestResult<()> { setup::test_data_common_types(&runner).await?; @@ -137,7 +143,11 @@ mod bigint_filter { Ok(()) } - #[connector_test(schema(setup::common_list_types), capabilities(ScalarLists))] + #[connector_test( + 
schema(setup::common_list_types), + exclude(Postgres("pg.js.wasm", "neon.js.wasm")), + capabilities(ScalarLists) + )] async fn scalar_list_filters(runner: Runner) -> TestResult<()> { setup::test_data_list_common(&runner).await?; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/field_reference/bytes_filter.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/field_reference/bytes_filter.rs index bcb4a76c6158..a77bf6e765b2 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/field_reference/bytes_filter.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/field_reference/bytes_filter.rs @@ -6,7 +6,10 @@ mod bytes_filter { use super::setup; use query_engine_tests::run_query; - #[connector_test(schema(setup::common_types))] + #[connector_test( + schema(setup::common_types), + exclude(Sqlite("libsql.js.wasm"), Vitess("planetscale.js.wasm")) + )] async fn basic_where(runner: Runner) -> TestResult<()> { setup::test_data_common_types(&runner).await?; @@ -28,7 +31,11 @@ mod bytes_filter { Ok(()) } - #[connector_test(schema(setup::common_mixed_types), capabilities(ScalarLists))] + #[connector_test( + schema(setup::common_mixed_types), + exclude(Postgres("pg.js.wasm", "neon.js.wasm")), + capabilities(ScalarLists) + )] async fn inclusion_filter(runner: Runner) -> TestResult<()> { setup::test_data_common_mixed_types(&runner).await?; @@ -50,7 +57,11 @@ mod bytes_filter { Ok(()) } - #[connector_test(schema(setup::common_list_types), capabilities(ScalarLists))] + #[connector_test( + schema(setup::common_list_types), + exclude(Postgres("pg.js.wasm", "neon.js.wasm")), + capabilities(ScalarLists) + )] async fn scalar_list_filters(runner: Runner) -> TestResult<()> { setup::test_data_list_common(&runner).await?; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/field_reference/datetime_filter.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/field_reference/datetime_filter.rs index 327379bd4903..2753471bc635 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/field_reference/datetime_filter.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/field_reference/datetime_filter.rs @@ -6,7 +6,10 @@ mod datetime_filter { use super::setup; use query_engine_tests::run_query; - #[connector_test(schema(setup::common_types))] + #[connector_test( + schema(setup::common_types), + exclude(Sqlite("libsql.js.wasm"), Vitess("planetscale.js.wasm")) + )] async fn basic_where(runner: Runner) -> TestResult<()> { setup::test_data_common_types(&runner).await?; @@ -28,7 +31,10 @@ mod datetime_filter { Ok(()) } - #[connector_test(schema(setup::common_types))] + #[connector_test( + schema(setup::common_types), + exclude(Sqlite("libsql.js.wasm"), Vitess("planetscale.js.wasm")) + )] async fn numeric_comparison_filters(runner: Runner) -> TestResult<()> { setup::test_data_common_types(&runner).await?; @@ -137,7 +143,11 @@ mod datetime_filter { Ok(()) } - #[connector_test(schema(setup::common_list_types), capabilities(ScalarLists))] + #[connector_test( + schema(setup::common_list_types), + capabilities(ScalarLists), + exclude(Postgres("pg.js.wasm", "neon.js.wasm")) + )] async fn scalar_list_filters(runner: Runner) -> TestResult<()> { setup::test_data_list_common(&runner).await?; diff --git 
a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/field_reference/float_filter.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/field_reference/float_filter.rs index 5dfae5f09c36..f40f73bbc180 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/field_reference/float_filter.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/field_reference/float_filter.rs @@ -6,7 +6,10 @@ mod float_filter { use super::setup; use query_engine_tests::run_query; - #[connector_test(schema(setup::common_types))] + #[connector_test( + schema(setup::common_types), + exclude(Sqlite("libsql.js.wasm"), Vitess("planetscale.js.wasm")) + )] async fn basic_where(runner: Runner) -> TestResult<()> { setup::test_data_common_types(&runner).await?; @@ -28,7 +31,10 @@ mod float_filter { Ok(()) } - #[connector_test(schema(setup::common_types))] + #[connector_test( + schema(setup::common_types), + exclude(Sqlite("libsql.js.wasm"), Vitess("planetscale.js.wasm")) + )] async fn numeric_comparison_filters(runner: Runner) -> TestResult<()> { setup::test_data_common_types(&runner).await?; @@ -137,7 +143,11 @@ mod float_filter { Ok(()) } - #[connector_test(schema(setup::common_list_types), capabilities(ScalarLists))] + #[connector_test( + schema(setup::common_list_types), + exclude(Postgres("pg.js.wasm", "neon.js.wasm")), + capabilities(ScalarLists) + )] async fn scalar_list_filters(runner: Runner) -> TestResult<()> { setup::test_data_list_common(&runner).await?; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/field_reference/int_filter.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/field_reference/int_filter.rs index 972539ec1f15..cedbb81c3a1f 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/field_reference/int_filter.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/field_reference/int_filter.rs @@ -6,7 +6,10 @@ mod int_filter { use super::setup; use query_engine_tests::run_query; - #[connector_test(schema(setup::common_types))] + #[connector_test( + schema(setup::common_types), + exclude(Sqlite("libsql.js.wasm"), Vitess("planetscale.js.wasm")) + )] async fn basic_where(runner: Runner) -> TestResult<()> { setup::test_data_common_types(&runner).await?; @@ -28,7 +31,10 @@ mod int_filter { Ok(()) } - #[connector_test(schema(setup::common_types))] + #[connector_test( + schema(setup::common_types), + exclude(Sqlite("libsql.js.wasm"), Vitess("planetscale.js.wasm")) + )] async fn numeric_comparison_filters(runner: Runner) -> TestResult<()> { setup::test_data_common_types(&runner).await?; @@ -137,7 +143,11 @@ mod int_filter { Ok(()) } - #[connector_test(schema(setup::common_list_types), capabilities(ScalarLists))] + #[connector_test( + schema(setup::common_list_types), + exclude(Postgres("pg.js.wasm", "neon.js.wasm")), + capabilities(ScalarLists) + )] async fn scalar_list_filters(runner: Runner) -> TestResult<()> { setup::test_data_list_common(&runner).await?; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/field_reference/json_filter.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/field_reference/json_filter.rs index b865731161c2..2666e8c80900 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/field_reference/json_filter.rs +++ 
b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/field_reference/json_filter.rs @@ -126,7 +126,7 @@ mod json_filter { Ok(()) } - #[connector_test(schema(schema), exclude(MySQL(5.6), Vitess("planetscale.js")))] + #[connector_test(schema(schema), exclude(MySQL(5.6), Vitess("planetscale.js", "planetscale.js.wasm")))] async fn string_comparison_filters(runner: Runner) -> TestResult<()> { test_string_data(&runner).await?; @@ -169,7 +169,7 @@ mod json_filter { Ok(()) } - #[connector_test(schema(schema), exclude(MySQL(5.6), Vitess("planetscale.js")))] + #[connector_test(schema(schema), exclude(MySQL(5.6), Vitess("planetscale.js", "planetscale.js.wasm")))] async fn array_comparison_filters(runner: Runner) -> TestResult<()> { test_array_data(&runner).await?; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/field_reference/string_filter.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/field_reference/string_filter.rs index f9c2e6e06acc..c62821ef4604 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/field_reference/string_filter.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/field_reference/string_filter.rs @@ -6,7 +6,7 @@ mod string_filter { use super::setup; use query_engine_tests::run_query; - #[connector_test] + #[connector_test(exclude(Sqlite("libsql.js.wasm"), Vitess("planetscale.js.wasm")))] async fn basic_where_sensitive(runner: Runner) -> TestResult<()> { setup::test_data_common_types(&runner).await?; @@ -50,7 +50,7 @@ mod string_filter { Ok(()) } - #[connector_test] + #[connector_test(exclude(Sqlite("libsql.js.wasm"), Vitess("planetscale.js.wasm")))] async fn numeric_comparison_filters_sensitive(runner: Runner) -> TestResult<()> { setup::test_data_common_types(&runner).await?; @@ -225,7 +225,7 @@ mod string_filter { Ok(()) } - #[connector_test] + #[connector_test(exclude(Sqlite("libsql.js.wasm"), Vitess("planetscale.js.wasm")))] async fn string_comparison_filters_sensitive(runner: Runner) -> TestResult<()> { setup::test_data_common_types(&runner).await?; run_query!( @@ -435,7 +435,11 @@ mod string_filter { Ok(()) } - #[connector_test(schema(setup::common_list_types), capabilities(ScalarLists))] + #[connector_test( + schema(setup::common_list_types), + exclude(Postgres("pg.js.wasm", "neon.js.wasm")), + capabilities(ScalarLists) + )] async fn scalar_list_filters_sensitive(runner: Runner) -> TestResult<()> { setup::test_data_list_common(&runner).await?; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/json.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/json.rs index d1b62a086153..ca8cc885798a 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/json.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/json.rs @@ -212,9 +212,8 @@ mod json { #[connector_test( schema(json_opt), exclude( - Vitess("planetscale.js"), - Postgres("neon.js"), - Postgres("pg.js"), + Vitess("planetscale.js", "planetscale.js.wasm"), + Postgres("neon.js", "pg.js", "neon.js.wasm", "pg.js.wasm"), Sqlite("libsql.js"), MySQL(5.6) ) diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/json_filters.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/json_filters.rs index f3e4026a8678..a1a3072e242c 100644 --- 
a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/json_filters.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/json_filters.rs @@ -27,7 +27,7 @@ mod json_filters { schema.to_owned() } - #[connector_test(exclude(MySQL(5.6), Vitess("planetscale.js")))] + #[connector_test(exclude(MySQL(5.6), Vitess("planetscale.js", "planetscale.js.wasm")))] async fn no_path_without_filter(runner: Runner) -> TestResult<()> { assert_error!( runner, @@ -280,7 +280,7 @@ mod json_filters { Ok(()) } - #[connector_test(exclude(MySQL(5.6), Vitess("planetscale.js")))] + #[connector_test(exclude(MySQL(5.6), Vitess("planetscale.js", "planetscale.js.wasm")))] async fn array_contains(runner: Runner) -> TestResult<()> { array_contains_runner(runner).await?; @@ -389,7 +389,7 @@ mod json_filters { Ok(()) } - #[connector_test(exclude(MySQL(5.6), Vitess("planetscale.js")))] + #[connector_test(exclude(MySQL(5.6), Vitess("planetscale.js", "planetscale.js.wasm")))] async fn array_starts_with(runner: Runner) -> TestResult<()> { array_starts_with_runner(runner).await?; @@ -496,7 +496,7 @@ mod json_filters { Ok(()) } - #[connector_test(exclude(MySQL(5.6), Vitess("planetscale.js")))] + #[connector_test(exclude(MySQL(5.6), Vitess("planetscale.js", "planetscale.js.wasm")))] async fn array_ends_with(runner: Runner) -> TestResult<()> { array_ends_with_runner(runner).await?; @@ -535,7 +535,7 @@ mod json_filters { Ok(()) } - #[connector_test(exclude(MySQL(5.6), Vitess("planetscale.js")))] + #[connector_test(exclude(MySQL(5.6), Vitess("planetscale.js", "planetscale.js.wasm")))] async fn string_contains(runner: Runner) -> TestResult<()> { string_contains_runner(runner).await?; @@ -575,7 +575,7 @@ mod json_filters { Ok(()) } - #[connector_test(exclude(MySQL(5.6), Vitess("planetscale.js")))] + #[connector_test(exclude(MySQL(5.6), Vitess("planetscale.js", "planetscale.js.wasm")))] async fn string_starts_with(runner: Runner) -> TestResult<()> { string_starts_with_runner(runner).await?; @@ -614,7 +614,7 @@ mod json_filters { Ok(()) } - #[connector_test(exclude(MySQL(5.6), Vitess("planetscale.js")))] + #[connector_test(exclude(MySQL(5.6), Vitess("planetscale.js", "planetscale.js.wasm")))] async fn string_ends_with(runner: Runner) -> TestResult<()> { string_ends_with_runner(runner).await?; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/list_filters.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/list_filters.rs index 16b9a0ab0437..f34675ba3ff1 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/list_filters.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/list_filters.rs @@ -1,6 +1,10 @@ use query_engine_tests::*; -#[test_suite(schema(common_list_types), capabilities(ScalarLists))] +#[test_suite( + schema(common_list_types), + exclude(Postgres("pg.js.wasm", "neon.js.wasm")), + capabilities(ScalarLists) +)] mod lists { use indoc::indoc; use query_engine_tests::run_query; @@ -623,7 +627,7 @@ mod lists { } // Cockroachdb does not like the bytes empty array check in v21 but this will be fixed in 22. 
- #[connector_test(exclude(CockroachDB))] + #[connector_test(exclude(CockroachDB), exclude(Postgres("pg.js.wasm", "neon.js.wasm")))] async fn is_empty_bytes(runner: Runner) -> TestResult<()> { test_data(&runner).await?; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/search_filter.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/search_filter.rs index 51637d3bbcb8..abf7f04efdf3 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/search_filter.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/search_filter.rs @@ -229,7 +229,7 @@ mod search_filter_with_index { super::ensure_filter_tree_shake_works(runner).await } - #[connector_test(exclude(Vitess("planetscale.js")))] + #[connector_test(exclude(Vitess("planetscale.js", "planetscale.js.wasm")))] async fn throws_error_on_missing_index(runner: Runner) -> TestResult<()> { super::create_test_data(&runner).await?; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/order_and_pagination/nested_pagination.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/order_and_pagination/nested_pagination.rs index 6a67b87d56b1..34af3fc21ed9 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/order_and_pagination/nested_pagination.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/order_and_pagination/nested_pagination.rs @@ -80,7 +80,7 @@ mod nested_pagination { ***************/ // should skip the first item - #[connector_test(exclude(Vitess("planetscale.js")))] + #[connector_test(exclude(Vitess("planetscale.js", "planetscale.js.wasm")))] async fn mid_lvl_skip_1(runner: Runner) -> TestResult<()> { create_test_data(&runner).await?; @@ -102,7 +102,7 @@ mod nested_pagination { } // should "skip all items" - #[connector_test(exclude(Vitess("planetscale.js")))] + #[connector_test(exclude(Vitess("planetscale.js", "planetscale.js.wasm")))] async fn mid_lvl_skip_3(runner: Runner) -> TestResult<()> { create_test_data(&runner).await?; @@ -124,7 +124,7 @@ mod nested_pagination { } // should "skip all items" - #[connector_test(exclude(Vitess("planetscale.js")))] + #[connector_test(exclude(Vitess("planetscale.js", "planetscale.js.wasm")))] async fn mid_lvl_skip_4(runner: Runner) -> TestResult<()> { create_test_data(&runner).await?; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/order_and_pagination/order_by_dependent.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/order_and_pagination/order_by_dependent.rs index c8f7429451a7..d12f7fcfed65 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/order_and_pagination/order_by_dependent.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/order_and_pagination/order_by_dependent.rs @@ -223,7 +223,7 @@ mod order_by_dependent { } // "[Circular with differing records] Ordering by related record field ascending" should "work" - #[connector_test(exclude(SqlServer, Vitess("planetscale.js")))] + #[connector_test(exclude(SqlServer, Vitess("planetscale.js", "planetscale.js.wasm")))] async fn circular_diff_related_record_asc(runner: Runner) -> TestResult<()> { // Records form circles with their relations create_row(&runner, 1, Some(1), Some(1), Some(3)).await?; @@ -258,7 +258,7 @@ mod order_by_dependent { } // "[Circular with differing records] Ordering by related record field 
descending" should "work" - #[connector_test(exclude(SqlServer, Vitess("planetscale.js")))] + #[connector_test(exclude(SqlServer, Vitess("planetscale.js", "planetscale.js.wasm")))] async fn circular_diff_related_record_desc(runner: Runner) -> TestResult<()> { // Records form circles with their relations create_row(&runner, 1, Some(1), Some(1), Some(3)).await?; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/order_and_pagination/order_by_dependent_pagination.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/order_and_pagination/order_by_dependent_pagination.rs index f8e5e831971b..323192be180d 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/order_and_pagination/order_by_dependent_pagination.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/order_and_pagination/order_by_dependent_pagination.rs @@ -79,7 +79,7 @@ mod order_by_dependent_pag { // "[Hops: 1] Ordering by related record field ascending with nulls" should "work" // TODO(julius): should enable for SQL Server when partial indices are in the PSL - #[connector_test(exclude(SqlServer, Vitess("planetscale.js")))] + #[connector_test(exclude(SqlServer, Vitess("planetscale.js", "planetscale.js.wasm")))] async fn hop_1_related_record_asc_nulls(runner: Runner) -> TestResult<()> { // 1 record has the "full chain", one half, one none create_row(&runner, 1, Some(1), Some(1), None).await?; @@ -146,7 +146,7 @@ mod order_by_dependent_pag { // "[Hops: 2] Ordering by related record field ascending with nulls" should "work" // TODO(garren): should enable for SQL Server when partial indices are in the PSL - #[connector_test(exclude(SqlServer, Vitess("planetscale.js")))] + #[connector_test(exclude(SqlServer, Vitess("planetscale.js", "planetscale.js.wasm")))] async fn hop_2_related_record_asc_null(runner: Runner) -> TestResult<()> { // 1 record has the "full chain", one half, one none create_row(&runner, 1, Some(1), Some(1), None).await?; @@ -227,7 +227,7 @@ mod order_by_dependent_pag { // "[Circular with differing records] Ordering by related record field ascending" should "work" // TODO(julius): should enable for SQL Server when partial indices are in the PSL - #[connector_test(exclude(SqlServer, Vitess("planetscale.js")))] + #[connector_test(exclude(SqlServer, Vitess("planetscale.js", "planetscale.js.wasm")))] async fn circular_diff_related_record_asc(runner: Runner) -> TestResult<()> { // Records form circles with their relations create_row(&runner, 1, Some(1), Some(1), Some(3)).await?; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/order_and_pagination/pagination.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/order_and_pagination/pagination.rs index 83c472a064e7..e6cbee21d9b7 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/order_and_pagination/pagination.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/order_and_pagination/pagination.rs @@ -277,7 +277,7 @@ mod pagination { ********************/ // "A skip" should "return all records after the offset specified" - #[connector_test(exclude(Vitess("planetscale.js")))] + #[connector_test(exclude(Vitess("planetscale.js", "planetscale.js.wasm")))] async fn skip_returns_all_after_offset(runner: Runner) -> TestResult<()> { create_test_data(&runner).await?; @@ -296,7 +296,7 @@ mod pagination { } // "A skip with order reversed" should "return all records after the offset 
specified" - #[connector_test(exclude(Vitess("planetscale.js")))] + #[connector_test(exclude(Vitess("planetscale.js", "planetscale.js.wasm")))] async fn skip_reversed_order(runner: Runner) -> TestResult<()> { create_test_data(&runner).await?; @@ -315,7 +315,7 @@ mod pagination { } // "A skipping beyond all records" should "return no records" - #[connector_test(exclude(Vitess("planetscale.js")))] + #[connector_test(exclude(Vitess("planetscale.js", "planetscale.js.wasm")))] async fn skipping_beyond_all_records(runner: Runner) -> TestResult<()> { create_test_data(&runner).await?; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/raw/sql/casts.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/raw/sql/casts.rs index c03067eed818..146892889beb 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/raw/sql/casts.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/raw/sql/casts.rs @@ -18,7 +18,7 @@ mod casts { // // Bails with: ERROR: invalid input syntax for type integer: "42.51" // - #[connector_test(only(Postgres), exclude(Postgres("neon.js"), Postgres("pg.js")))] + #[connector_test(only(Postgres), exclude(Postgres("neon.js", "pg.js", "neon.js.wasm", "pg.js.wasm")))] async fn query_numeric_casts(runner: Runner) -> TestResult<()> { insta::assert_snapshot!( run_query_pretty!(&runner, fmt_query_raw(r#" diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/raw/sql/errors.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/raw/sql/errors.rs index cb44a2285ff2..4d38c60b5b75 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/raw/sql/errors.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/raw/sql/errors.rs @@ -37,7 +37,7 @@ mod raw_errors { #[connector_test( schema(common_nullable_types), only(Postgres), - exclude(Postgres("neon.js"), Postgres("pg.js")) + exclude(Postgres("neon.js", "pg.js", "neon.js.wasm", "pg.js.wasm")) )] async fn list_param_for_scalar_column_should_not_panic_quaint(runner: Runner) -> TestResult<()> { assert_error!( diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/raw/sql/input_coercion.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/raw/sql/input_coercion.rs index eac2bc42b4cd..215cd539af3c 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/raw/sql/input_coercion.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/raw/sql/input_coercion.rs @@ -5,7 +5,7 @@ mod input_coercion { use query_engine_tests::fmt_execute_raw; // Checks that query raw inputs are coerced to the correct types - #[connector_test] + #[connector_test(only(Postgres), exclude(Postgres("pg.js.wasm", "neon.js.wasm"),))] async fn scalar_input_correctly_coerced(runner: Runner) -> TestResult<()> { run_query!( &runner, diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/raw/sql/null_list.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/raw/sql/null_list.rs index 32a8a8ef281d..4ae2a2b6b57c 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/raw/sql/null_list.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/raw/sql/null_list.rs @@ -5,7 +5,11 @@ use query_engine_tests::*; mod null_list { use query_engine_tests::{fmt_query_raw, run_query, run_query_pretty}; - #[connector_test(schema(common_list_types))] + #[connector_test( + schema(common_list_types), + only(Postgres), + exclude(Postgres("pg.js.wasm", "neon.js.wasm"),) + )] 
async fn null_scalar_lists(runner: Runner) -> TestResult<()> { run_query!( &runner, diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/raw/sql/typed_output.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/raw/sql/typed_output.rs index c3687ddd9f3e..8434da64073e 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/raw/sql/typed_output.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/raw/sql/typed_output.rs @@ -26,7 +26,7 @@ mod typed_output { schema.to_owned() } - #[connector_test(schema(schema_pg), only(Postgres))] + #[connector_test(schema(schema_pg), only(Postgres), exclude(Postgres("pg.js.wasm", "neon.js.wasm")))] async fn all_scalars_pg(runner: Runner) -> TestResult<()> { create_row( &runner, @@ -483,7 +483,7 @@ mod typed_output { schema.to_owned() } - #[connector_test(schema(schema_sqlite), only(Sqlite))] + #[connector_test(schema(schema_sqlite), only(Sqlite), exclude(Sqlite("libsql.js.wasm")))] async fn all_scalars_sqlite(runner: Runner) -> TestResult<()> { create_row( &runner, diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/data_types/bigint.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/data_types/bigint.rs index 1cb9adf534a8..c78b522f4994 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/data_types/bigint.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/data_types/bigint.rs @@ -17,7 +17,11 @@ mod bigint { } // "Using a BigInt field" should "work" - #[connector_test] + #[connector_test(exclude( + Postgres("pg.js.wasm", "neon.js.wasm"), + Sqlite("libsql.js.wasm"), + Vitess("planetscale.js.wasm") + ))] async fn using_bigint_field(runner: Runner) -> TestResult<()> { insta::assert_snapshot!( run_query!(&runner, r#"mutation { diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/data_types/bytes.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/data_types/bytes.rs index 791b0a2137fb..654463f491f7 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/data_types/bytes.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/data_types/bytes.rs @@ -1,6 +1,10 @@ use query_engine_tests::*; -#[test_suite] +#[test_suite(exclude( + Postgres("pg.js.wasm", "neon.js.wasm"), + Sqlite("libsql.js.wasm"), + Vitess("planetscale.js.wasm") +))] mod bytes { use indoc::indoc; use query_engine_tests::run_query; @@ -77,7 +81,16 @@ mod bytes { Ok(()) } - #[connector_test(schema(bytes_id), exclude(MySQL, Vitess, SqlServer))] + #[connector_test( + schema(bytes_id), + exclude( + MySQL, + Vitess, + SqlServer, + Postgres("pg.js.wasm", "neon.js.wasm"), + Sqlite("libsql.js.wasm") + ) + )] async fn byte_id_coercion(runner: Runner) -> TestResult<()> { insta::assert_snapshot!( run_query!(runner, r#" diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/data_types/native_types/postgres.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/data_types/native_types/postgres.rs index 2d487ec4f137..2a83d17f6fb7 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/data_types/native_types/postgres.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/data_types/native_types/postgres.rs @@ -23,7 +23,7 @@ mod postgres { } //"Postgres native int types" should "work" - #[connector_test(schema(schema_int))] + #[connector_test(schema(schema_int), only(Postgres), 
exclude(Postgres("pg.js.wasm", "neon.js.wasm")))] async fn native_int_types(runner: Runner) -> TestResult<()> { insta::assert_snapshot!( run_query!(&runner, r#"mutation { @@ -191,7 +191,11 @@ mod postgres { } // "Other Postgres native types" should "work" - #[connector_test(schema(schema_other_types), only(Postgres), exclude(CockroachDb))] + #[connector_test( + schema(schema_other_types), + only(Postgres), + exclude(CockroachDb, Postgres("pg.js.wasm", "neon.js.wasm")) + )] async fn native_other_types(runner: Runner) -> TestResult<()> { insta::assert_snapshot!( run_query!(&runner, r#"mutation { diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/data_types/scalar_list/base.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/data_types/scalar_list/base.rs index 9a5e74dd8547..2bd989573da8 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/data_types/scalar_list/base.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/data_types/scalar_list/base.rs @@ -28,7 +28,7 @@ mod basic_types { schema.to_owned() } - #[connector_test] + #[connector_test(exclude(Postgres("pg.js.wasm", "neon.js.wasm")))] async fn set_base(runner: Runner) -> TestResult<()> { insta::assert_snapshot!( run_query!(&runner, format!(r#"mutation {{ @@ -59,7 +59,7 @@ mod basic_types { // "Scalar lists" should "be behave like regular values for create and update operations" // Skipped for CockroachDB as enum array concatenation is not supported (https://github.com/cockroachdb/cockroach/issues/71388). - #[connector_test(exclude(CockroachDb))] + #[connector_test(exclude(CockroachDb, Postgres("pg.js.wasm", "neon.js.wasm")))] async fn behave_like_regular_val_for_create_and_update(runner: Runner) -> TestResult<()> { insta::assert_snapshot!( run_query!(&runner, format!(r#"mutation {{ @@ -158,7 +158,7 @@ mod basic_types { } // "A Create Mutation" should "create and return items with list values with shorthand notation" - #[connector_test] + #[connector_test(exclude(Postgres("pg.js.wasm", "neon.js.wasm")))] async fn create_mut_work_with_list_vals(runner: Runner) -> TestResult<()> { insta::assert_snapshot!( run_query!(&runner, format!(r#"mutation {{ diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/data_types/scalar_list/defaults.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/data_types/scalar_list/defaults.rs index 39370e62c572..c216b36ae458 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/data_types/scalar_list/defaults.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/data_types/scalar_list/defaults.rs @@ -29,7 +29,7 @@ mod basic { schema.to_owned() } - #[connector_test] + #[connector_test(exclude(Postgres("pg.js.wasm", "neon.js.wasm")))] async fn basic_write(runner: Runner) -> TestResult<()> { insta::assert_snapshot!( run_query!(&runner, r#"mutation { diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/ids/byoid.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/ids/byoid.rs index 5493ff7f2778..5d46b75a98fa 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/ids/byoid.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/ids/byoid.rs @@ -48,7 +48,7 @@ mod byoid { #[connector_test( schema(schema_1), only(MySql, Postgres, Sqlite, Vitess), - exclude(Vitess("planetscale.js")) + exclude(Vitess("planetscale.js", "planetscale.js.wasm")) )] 
async fn create_and_return_item_woi_1(runner: Runner) -> TestResult<()> { insta::assert_snapshot!( @@ -80,7 +80,7 @@ mod byoid { #[connector_test( schema(schema_2), only(MySql, Postgres, Sqlite, Vitess), - exclude(Vitess("planetscale.js")) + exclude(Vitess("planetscale.js", "planetscale.js.wasm")) )] async fn create_and_return_item_woi_2(runner: Runner) -> TestResult<()> { insta::assert_snapshot!( @@ -142,7 +142,7 @@ mod byoid { #[connector_test( schema(schema_1), only(MySql, Postgres, Sqlite, Vitess), - exclude(Vitess("planetscale.js")) + exclude(Vitess("planetscale.js", "planetscale.js.wasm")) )] async fn nested_create_return_item_woi_1(runner: Runner) -> TestResult<()> { insta::assert_snapshot!( @@ -174,7 +174,7 @@ mod byoid { #[connector_test( schema(schema_2), only(MySql, Postgres, Sqlite, Vitess), - exclude(Vitess("planetscale.js")) + exclude(Vitess("planetscale.js", "planetscale.js.wasm")) )] async fn nested_create_return_item_woi_2(runner: Runner) -> TestResult<()> { insta::assert_snapshot!( diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/nested_mutations/already_converted/nested_update_many_inside_update.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/nested_mutations/already_converted/nested_update_many_inside_update.rs index 05931d16084b..c6b48405f8c9 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/nested_mutations/already_converted/nested_update_many_inside_update.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/nested_mutations/already_converted/nested_update_many_inside_update.rs @@ -59,7 +59,7 @@ mod um_inside_update { #[relation_link_test( on_parent = "ToMany", on_child = "ToOneReq", - exclude(Postgres("pg.js", "neon.js"), Vitess("planetscale.js")) + exclude(Postgres("pg.js", "neon.js"), Vitess("planetscale.js", "planetscale.js.wasm")) )] async fn pm_c1_req_should_work(runner: &Runner, t: &DatamodelWithParams) -> TestResult<()> { let parent = setup_data(runner, t).await?; @@ -98,7 +98,7 @@ mod um_inside_update { #[relation_link_test( on_parent = "ToMany", on_child = "ToOneOpt", - exclude(Postgres("pg.js", "neon.js"), Vitess("planetscale.js")) + exclude(Postgres("pg.js", "neon.js"), Vitess("planetscale.js", "planetscale.js.wasm")) )] async fn pm_c1_should_work(runner: &Runner, t: &DatamodelWithParams) -> TestResult<()> { let parent = setup_data(runner, t).await?; @@ -137,7 +137,7 @@ mod um_inside_update { #[relation_link_test( on_parent = "ToMany", on_child = "ToMany", - exclude(Postgres("pg.js", "neon.js"), Vitess("planetscale.js")) + exclude(Postgres("pg.js", "neon.js"), Vitess("planetscale.js", "planetscale.js.wasm")) )] async fn pm_cm_should_work(runner: &Runner, t: &DatamodelWithParams) -> TestResult<()> { let parent = setup_data(runner, t).await?; @@ -176,7 +176,7 @@ mod um_inside_update { #[relation_link_test( on_parent = "ToMany", on_child = "ToOneReq", - exclude(Postgres("pg.js", "neon.js"), Vitess("planetscale.js")) + exclude(Postgres("pg.js", "neon.js"), Vitess("planetscale.js", "planetscale.js.wasm")) )] async fn pm_c1_req_many_ums(runner: &Runner, t: &DatamodelWithParams) -> TestResult<()> { let parent = setup_data(runner, t).await?; @@ -221,7 +221,7 @@ mod um_inside_update { #[relation_link_test( on_parent = "ToMany", on_child = "ToOneReq", - exclude(Postgres("pg.js", "neon.js"), Vitess("planetscale.js")) + exclude(Postgres("pg.js", "neon.js"), Vitess("planetscale.js", "planetscale.js.wasm")) )] async fn pm_c1_req_empty_filter(runner: &Runner, t: 
&DatamodelWithParams) -> TestResult<()> { let parent = setup_data(runner, t).await?; @@ -262,7 +262,7 @@ mod um_inside_update { #[relation_link_test( on_parent = "ToMany", on_child = "ToOneReq", - exclude(Postgres("pg.js", "neon.js"), Vitess("planetscale.js")) + exclude(Postgres("pg.js", "neon.js"), Vitess("planetscale.js", "planetscale.js.wasm")) )] async fn pm_c1_req_noop_no_hit(runner: &Runner, t: &DatamodelWithParams) -> TestResult<()> { let parent = setup_data(runner, t).await?; @@ -309,7 +309,7 @@ mod um_inside_update { #[relation_link_test( on_parent = "ToMany", on_child = "ToOneReq", - exclude(Postgres("pg.js", "neon.js"), Vitess("planetscale.js")) + exclude(Postgres("pg.js", "neon.js"), Vitess("planetscale.js", "planetscale.js.wasm")) )] async fn pm_c1_req_many_filters(runner: &Runner, t: &DatamodelWithParams) -> TestResult<()> { let parent = setup_data(runner, t).await?; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/nested_mutations/not_using_schema_base/nested_create_many.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/nested_mutations/not_using_schema_base/nested_create_many.rs index 45562b5f6be8..cd71df429ea3 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/nested_mutations/not_using_schema_base/nested_create_many.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/nested_mutations/not_using_schema_base/nested_create_many.rs @@ -78,7 +78,7 @@ mod nested_create_many { // "Nested createMany" should "error on duplicates by default" // TODO(dom): Not working for mongo - #[connector_test(exclude(Sqlite, MongoDb, Vitess("planetscale.js")))] + #[connector_test(exclude(Sqlite, MongoDb, Vitess("planetscale.js", "planetscale.js.wasm")))] async fn nested_createmany_fail_dups(runner: Runner) -> TestResult<()> { assert_error!( &runner, diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/relations/compound_fks_mixed_requiredness.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/relations/compound_fks_mixed_requiredness.rs index 808af82deec4..8f91a6039de4 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/relations/compound_fks_mixed_requiredness.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/relations/compound_fks_mixed_requiredness.rs @@ -26,7 +26,7 @@ mod compound_fks { } // "A One to Many relation with mixed requiredness" should "be writable and readable" - #[connector_test(exclude(MySql(5.6), MongoDb, Vitess("planetscale.js")))] + #[connector_test(exclude(MySql(5.6), MongoDb, Vitess("planetscale.js", "planetscale.js.wasm")))] async fn one2m_mix_required_writable_readable(runner: Runner) -> TestResult<()> { // Setup user insta::assert_snapshot!( diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/create.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/create.rs index 1507ea0c082b..5c91f1c7f18a 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/create.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/create.rs @@ -205,7 +205,7 @@ mod create { // TODO(dom): Not working on mongo // TODO(dom): 'Expected result to return an error, but found success: {"data":{"createOneScalarModel":{"optUnique":"test"}}}' // Comment(dom): Expected, we're not enforcing uniqueness for the test setup yet. 
- #[connector_test(exclude(MongoDb, Vitess("planetscale.js")))] + #[connector_test(exclude(MongoDb, Vitess("planetscale.js", "planetscale.js.wasm")))] async fn gracefully_fails_when_uniq_violation(runner: Runner) -> TestResult<()> { run_query!( &runner, diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/create_many.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/create_many.rs index 94118b669c1b..832205e66c60 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/create_many.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/create_many.rs @@ -165,7 +165,7 @@ mod create_many { } // "createMany" should "error on duplicates by default" - #[connector_test(schema(schema_4), exclude(Vitess("planetscale.js")))] + #[connector_test(schema(schema_4), exclude(Vitess("planetscale.js", "planetscale.js.wasm")))] async fn create_many_error_dups(runner: Runner) -> TestResult<()> { assert_error!( &runner, diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/update_many.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/update_many.rs index 749048fd3edc..80c59a1a65f4 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/update_many.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/update_many.rs @@ -123,7 +123,7 @@ mod update_many { } // "An updateMany mutation" should "correctly apply all number operations for Int" - #[connector_test(exclude(Vitess("planetscale.js"), CockroachDb))] + #[connector_test(exclude(Vitess("planetscale.js", "planetscale.js.wasm"), CockroachDb))] async fn apply_number_ops_for_int(runner: Runner) -> TestResult<()> { create_row(&runner, r#"{ id: 1, optStr: "str1" }"#).await?; create_row(&runner, r#"{ id: 2, optStr: "str2", optInt: 2 }"#).await?; @@ -240,7 +240,7 @@ mod update_many { } // "An updateMany mutation" should "correctly apply all number operations for Float" - #[connector_test(exclude(Vitess("planetscale.js")))] + #[connector_test(exclude(Vitess("planetscale.js", "planetscale.js.wasm")))] async fn apply_number_ops_for_float(runner: Runner) -> TestResult<()> { create_row(&runner, r#"{ id: 1, optStr: "str1" }"#).await?; create_row(&runner, r#"{ id: 2, optStr: "str2", optFloat: 2 }"#).await?; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/upsert.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/upsert.rs index f4f43eda05ac..e876bac06211 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/upsert.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/upsert.rs @@ -674,7 +674,7 @@ mod upsert { Ok(()) } - #[connector_test(schema(generic), exclude(Vitess("planetscale.js")))] + #[connector_test(schema(generic), exclude(Vitess("planetscale.js", "planetscale.js.wasm")))] async fn upsert_fails_if_filter_dont_match(runner: Runner) -> TestResult<()> { run_query!( &runner, diff --git a/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/mod.rs b/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/mod.rs index 6cc6120f71c8..5a0dff3b49a2 100644 --- 
a/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/mod.rs +++ b/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/mod.rs @@ -99,7 +99,9 @@ pub(crate) fn connection_string( Some(PostgresVersion::V12) if is_ci => { format!("postgresql://postgres:prisma@test-db-postgres-12:5432/{database}") } - Some(PostgresVersion::V13) | Some(PostgresVersion::NeonJs) | Some(PostgresVersion::PgJs) if is_ci => { + Some(PostgresVersion::V13) | Some(PostgresVersion::NeonJsNapi) | Some(PostgresVersion::PgJsNapi) + if is_ci => + { format!("postgresql://postgres:prisma@test-db-postgres-13:5432/{database}") } Some(PostgresVersion::V14) if is_ci => { @@ -116,7 +118,11 @@ pub(crate) fn connection_string( Some(PostgresVersion::V10) => format!("postgresql://postgres:prisma@127.0.0.1:5432/{database}"), Some(PostgresVersion::V11) => format!("postgresql://postgres:prisma@127.0.0.1:5433/{database}"), Some(PostgresVersion::V12) => format!("postgresql://postgres:prisma@127.0.0.1:5434/{database}"), - Some(PostgresVersion::V13) | Some(PostgresVersion::NeonJs) | Some(PostgresVersion::PgJs) => { + Some(PostgresVersion::V13) + | Some(PostgresVersion::NeonJsNapi) + | Some(PostgresVersion::PgJsNapi) + | Some(PostgresVersion::PgJsWasm) + | Some(PostgresVersion::NeonJsWasm) => { format!("postgresql://postgres:prisma@127.0.0.1:5435/{database}") } Some(PostgresVersion::V14) => format!("postgresql://postgres:prisma@127.0.0.1:5437/{database}"), @@ -201,7 +207,7 @@ pub(crate) fn connection_string( } ConnectorVersion::Vitess(Some(VitessVersion::V8_0)) => "mysql://root@localhost:33807/test".into(), - ConnectorVersion::Vitess(Some(VitessVersion::PlanetscaleJs)) => { + ConnectorVersion::Vitess(Some(VitessVersion::PlanetscaleJsNapi | VitessVersion::PlanetscaleJsWasm)) => { format!("mysql://root@127.0.0.1:3310/{database}") } @@ -380,8 +386,8 @@ mod tests { let only = vec![("postgres", None)]; let exclude = vec![("postgres", Some("neon.js"))]; let postgres = &PostgresConnectorTag as ConnectorTag; - let neon = ConnectorVersion::Postgres(Some(PostgresVersion::NeonJs)); - let pg = ConnectorVersion::Postgres(Some(PostgresVersion::PgJs)); + let neon = ConnectorVersion::Postgres(Some(PostgresVersion::NeonJsNapi)); + let pg = ConnectorVersion::Postgres(Some(PostgresVersion::PgJsNapi)); assert!(!super::should_run(&postgres, &neon, &only, &exclude, Default::default())); assert!(super::should_run(&postgres, &pg, &only, &exclude, Default::default())); @@ -393,7 +399,7 @@ mod tests { let only = vec![("postgres", None)]; let exclude = vec![("postgres", None)]; let postgres = &PostgresConnectorTag as ConnectorTag; - let neon = ConnectorVersion::Postgres(Some(PostgresVersion::NeonJs)); + let neon = ConnectorVersion::Postgres(Some(PostgresVersion::NeonJsNapi)); super::should_run(&postgres, &neon, &only, &exclude, Default::default()); } @@ -404,7 +410,7 @@ mod tests { let only = vec![("postgres", Some("neon.js"))]; let exclude = vec![("postgres", None)]; let postgres = &PostgresConnectorTag as ConnectorTag; - let neon = ConnectorVersion::Postgres(Some(PostgresVersion::NeonJs)); + let neon = ConnectorVersion::Postgres(Some(PostgresVersion::NeonJsNapi)); super::should_run(&postgres, &neon, &only, &exclude, Default::default()); } diff --git a/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/postgres.rs b/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/postgres.rs index 42d0a8c7afdc..2a839ab22584 100644 --- 
a/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/postgres.rs +++ b/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/postgres.rs @@ -36,8 +36,10 @@ pub enum PostgresVersion { V14, V15, PgBouncer, - NeonJs, - PgJs, + NeonJsNapi, + PgJsNapi, + NeonJsWasm, + PgJsWasm, } impl TryFrom<&str> for PostgresVersion { @@ -53,8 +55,10 @@ impl TryFrom<&str> for PostgresVersion { "14" => Self::V14, "15" => Self::V15, "pgbouncer" => Self::PgBouncer, - "neon.js" => Self::NeonJs, - "pg.js" => Self::PgJs, + "neon.js" => Self::NeonJsNapi, + "pg.js" => Self::PgJsNapi, + "pg.js.wasm" => Self::PgJsWasm, + "neon.js.wasm" => Self::NeonJsWasm, _ => return Err(TestError::parse_error(format!("Unknown Postgres version `{s}`"))), }; @@ -73,8 +77,10 @@ impl ToString for PostgresVersion { PostgresVersion::V14 => "14", PostgresVersion::V15 => "15", PostgresVersion::PgBouncer => "pgbouncer", - PostgresVersion::NeonJs => "neon.js", - PostgresVersion::PgJs => "pg.js", + PostgresVersion::NeonJsNapi => "neon.js", + PostgresVersion::PgJsNapi => "pg.js", + PostgresVersion::PgJsWasm => "pg.js.wasm", + PostgresVersion::NeonJsWasm => "neon.js.wasm", } .to_owned() } diff --git a/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/sqlite.rs b/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/sqlite.rs index 5f4dab56784a..2173bbdd38f2 100644 --- a/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/sqlite.rs +++ b/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/sqlite.rs @@ -29,14 +29,16 @@ impl ConnectorTagInterface for SqliteConnectorTag { #[derive(Clone, Debug, PartialEq, Eq)] pub enum SqliteVersion { V3, - LibsqlJS, + LibsqlJsNapi, + LibsqlJsWasm, } impl ToString for SqliteVersion { fn to_string(&self) -> String { match self { SqliteVersion::V3 => "3".to_string(), - SqliteVersion::LibsqlJS => "libsql.js".to_string(), + SqliteVersion::LibsqlJsNapi => "libsql.js".to_string(), + SqliteVersion::LibsqlJsWasm => "libsql.js.wasm".to_string(), } } } @@ -47,7 +49,8 @@ impl TryFrom<&str> for SqliteVersion { fn try_from(s: &str) -> Result<Self, Self::Error> { let version = match s { "3" => Self::V3, - "libsql.js" => Self::LibsqlJS, + "libsql.js" => Self::LibsqlJsNapi, + "libsql.js.wasm" => Self::LibsqlJsWasm, _ => return Err(TestError::parse_error(format!("Unknown SQLite version `{s}`"))), }; Ok(version) diff --git a/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/vitess.rs b/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/vitess.rs index ce827927b403..ba0f4249cd7c 100644 --- a/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/vitess.rs +++ b/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/vitess.rs @@ -34,7 +34,8 @@ impl ConnectorTagInterface for VitessConnectorTag { #[derive(Debug, Clone, Copy, PartialEq)] pub enum VitessVersion { V8_0, - PlanetscaleJs, + PlanetscaleJsNapi, + PlanetscaleJsWasm, } impl FromStr for VitessVersion { @@ -43,7 +44,8 @@ impl FromStr for VitessVersion { fn from_str(s: &str) -> Result<Self, Self::Err> { let version = match s { "8.0" => Self::V8_0, - "planetscale.js" => Self::PlanetscaleJs, + "planetscale.js" => Self::PlanetscaleJsNapi, + "planetscale.js.wasm" => Self::PlanetscaleJsWasm, _ => return Err(TestError::parse_error(format!("Unknown Vitess version `{s}`"))), }; @@ -55,7 +57,8 @@ impl Display for VitessVersion { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::V8_0 => write!(f, "8.0"), -
Self::PlanetscaleJs => write!(f, "planetscale.js"), + Self::PlanetscaleJsNapi => write!(f, "planetscale.js"), + Self::PlanetscaleJsWasm => write!(f, "planetscale.js.wasm"), } } } diff --git a/query-engine/connector-test-kit-rs/test-configs/libsql-wasm b/query-engine/connector-test-kit-rs/test-configs/libsql-wasm index b93966875dea..96ca6a4d7f13 100644 --- a/query-engine/connector-test-kit-rs/test-configs/libsql-wasm +++ b/query-engine/connector-test-kit-rs/test-configs/libsql-wasm @@ -1,5 +1,6 @@ { "connector": "sqlite", + "version": "libsql.js.wasm", "driver_adapter": "libsql", "external_test_executor": "Wasm" } \ No newline at end of file diff --git a/query-engine/connector-test-kit-rs/test-configs/neon-wasm b/query-engine/connector-test-kit-rs/test-configs/neon-wasm index 2697c5227399..132796d62ee7 100644 --- a/query-engine/connector-test-kit-rs/test-configs/neon-wasm +++ b/query-engine/connector-test-kit-rs/test-configs/neon-wasm @@ -1,6 +1,6 @@ { "connector": "postgres", - "version": "13", + "version": "neon.js.wasm", "driver_adapter": "neon:ws", "driver_adapter_config": { "proxy_url": "127.0.0.1:5488/v1" }, "external_test_executor": "Wasm" diff --git a/query-engine/connector-test-kit-rs/test-configs/pg-wasm b/query-engine/connector-test-kit-rs/test-configs/pg-wasm index b5d8ac3c7b15..a71ea4ece7bb 100644 --- a/query-engine/connector-test-kit-rs/test-configs/pg-wasm +++ b/query-engine/connector-test-kit-rs/test-configs/pg-wasm @@ -1,6 +1,6 @@ { "connector": "postgres", - "version": "13", + "version": "pg.js.wasm", "driver_adapter": "pg", "external_test_executor": "Wasm" } \ No newline at end of file diff --git a/query-engine/connector-test-kit-rs/test-configs/planetscale-wasm b/query-engine/connector-test-kit-rs/test-configs/planetscale-wasm index 62dd895e970c..b9f190e064c6 100644 --- a/query-engine/connector-test-kit-rs/test-configs/planetscale-wasm +++ b/query-engine/connector-test-kit-rs/test-configs/planetscale-wasm @@ -1,6 +1,6 @@ { "connector": "vitess", - "version": "planetscale.js", + "version": "planetscale.js.wasm", "driver_adapter": "planetscale", "driver_adapter_config": { "proxy_url": "http://root:root@127.0.0.1:8085"