diff --git a/Cargo.lock b/Cargo.lock index 5975de103fc..1bf449f7fb2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5761,6 +5761,25 @@ version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" +[[package]] +name = "strum" +version = "0.26.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d8cec3501a5194c432b2b7976db6b7d10ec95c253208b45f83f7136aa985e29" + +[[package]] +name = "strum_macros" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" +dependencies = [ + "heck 0.5.0", + "proc-macro2", + "quote", + "rustversion", + "syn 2.0.58", +] + [[package]] name = "subtle" version = "2.5.0" @@ -5954,6 +5973,8 @@ dependencies = [ "snap", "speedb", "storekey", + "strum", + "strum_macros", "surrealdb-derive", "surrealdb-jsonwebtoken", "surrealdb-tikv-client", diff --git a/core/Cargo.toml b/core/Cargo.toml index 57263e0a97b..c0c10cc1032 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -153,6 +153,8 @@ criterion = { version = "0.5.1", features = ["async_tokio"] } env_logger = "0.10.1" pprof = { version = "0.13.0", features = ["flamegraph", "criterion"] } serial_test = "2.0.0" +strum = "0.26.2" +strum_macros = "0.26.2" temp-dir = "0.1.11" test-log = { version = "0.2.13", features = ["trace"] } time = { version = "0.3.30", features = ["serde"] } diff --git a/core/src/rpc/method.rs b/core/src/rpc/method.rs index 89de5bb743a..13ec5f2dfd7 100644 --- a/core/src/rpc/method.rs +++ b/core/src/rpc/method.rs @@ -1,4 +1,10 @@ +#[cfg(test)] +use strum::IntoEnumIterator; +#[cfg(test)] +use strum_macros::EnumIter; + #[non_exhaustive] +#[cfg_attr(test, derive(Debug, Copy, Clone, PartialEq, EnumIter))] pub enum Method { Unknown, Ping, @@ -88,6 +94,36 @@ impl Method { } } +impl From for Method { + fn from(n: u8) -> Self { + match n { + 1 => Self::Ping, + 2 => Self::Info, + 3 => Self::Use, + 4 => Self::Signup, + 5 => Self::Signin, + 6 => Self::Invalidate, + 7 => Self::Authenticate, + 8 => Self::Kill, + 9 => Self::Live, + 10 => Self::Set, + 11 => Self::Unset, + 12 => Self::Select, + 13 => Self::Insert, + 14 => Self::Create, + 15 => Self::Update, + 16 => Self::Merge, + 17 => Self::Patch, + 18 => Self::Delete, + 19 => Self::Version, + 20 => Self::Query, + 21 => Self::Relate, + 22 => Self::Run, + _ => Self::Unknown, + } + } +} + impl Method { pub fn is_valid(&self) -> bool { !matches!(self, Self::Unknown) @@ -112,3 +148,20 @@ impl Method { ) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn all_variants_from_u8() { + for method in Method::iter() { + assert_eq!(method.clone(), Method::from(method as u8)); + } + } + + #[test] + fn unknown_from_out_of_range_u8() { + assert_eq!(Method::Unknown, Method::from(182)); + } +} \ No newline at end of file diff --git a/core/src/rpc/request.rs b/core/src/rpc/request.rs index 4a742508900..b89ef69dc57 100644 --- a/core/src/rpc/request.rs +++ b/core/src/rpc/request.rs @@ -1,17 +1,18 @@ use crate::rpc::format::cbor::Cbor; use crate::rpc::format::msgpack::Pack; use crate::rpc::RpcError; -use crate::sql::Part; -use crate::sql::{Array, Value}; +use crate::sql::{Array, Number, Part, Value}; use once_cell::sync::Lazy; +use super::method::Method; + pub static ID: Lazy<[Part; 1]> = Lazy::new(|| [Part::from("id")]); pub static METHOD: Lazy<[Part; 1]> = Lazy::new(|| [Part::from("method")]); pub static PARAMS: Lazy<[Part; 1]> = Lazy::new(|| [Part::from("params")]); pub struct Request { pub id: Option, - pub method: String, + pub method: Method, pub params: Array, } @@ -44,7 +45,11 @@ impl TryFrom for Request { }; // Fetch the 'method' argument let method = match val.pick(&*METHOD) { - Value::Strand(v) => v.to_raw(), + Value::Strand(v) => Method::parse(v.to_raw()), + Value::Number(Number::Int(v)) => match u8::try_from(v) { + Ok(v) => Method::from(v), + _ => return Err(RpcError::InvalidRequest), + }, _ => return Err(RpcError::InvalidRequest), }; // Fetch the 'params' argument diff --git a/core/src/rpc/rpc_context.rs b/core/src/rpc/rpc_context.rs index f1ddbaa5814..c9c09b8b226 100644 --- a/core/src/rpc/rpc_context.rs +++ b/core/src/rpc/rpc_context.rs @@ -35,7 +35,7 @@ pub trait RpcContext { async { unreachable!() } } - async fn execute(&mut self, method: Method, params: Array) -> Result { + async fn execute(&mut self, method: &Method, params: Array) -> Result { match method { Method::Ping => Ok(Value::None.into()), Method::Info => self.info().await.map(Into::into).map_err(Into::into), @@ -65,7 +65,7 @@ pub trait RpcContext { } } - async fn execute_immut(&self, method: Method, params: Array) -> Result { + async fn execute_immut(&self, method: &Method, params: Array) -> Result { match method { Method::Ping => Ok(Value::None.into()), Method::Info => self.info().await.map(Into::into).map_err(Into::into), diff --git a/src/net/rpc.rs b/src/net/rpc.rs index ddf8f86ffac..e20561bdfe5 100644 --- a/src/net/rpc.rs +++ b/src/net/rpc.rs @@ -23,7 +23,6 @@ use http_body::Body as HttpBody; use surrealdb::dbs::Session; use surrealdb::rpc::format::Format; use surrealdb::rpc::format::PROTOCOLS; -use surrealdb::rpc::method::Method; use tower_http::request_id::RequestId; use uuid::Uuid; @@ -116,7 +115,7 @@ async fn post_handler( match fmt.req_http(body) { Ok(req) => { - let res = rpc_ctx.execute(Method::parse(req.method), req.params).await; + let res = rpc_ctx.execute(&req.method, req.params).await; fmt.res_http(res.into_response(None)).map_err(Error::from) } Err(err) => Err(Error::from(err)), diff --git a/src/rpc/connection.rs b/src/rpc/connection.rs index c3f145a5e3f..39774f10919 100644 --- a/src/rpc/connection.rs +++ b/src/rpc/connection.rs @@ -297,15 +297,17 @@ impl Connection { // Parse the RPC request structure match fmt.req_ws(msg) { Ok(req) => { + let method_str = req.method.to_str(); + // Now that we know the method, we can update the span and create otel context - span.record("rpc.method", &req.method); - span.record("otel.name", format!("surrealdb.rpc/{}", req.method)); + span.record("rpc.method", method_str); + span.record("otel.name", format!("surrealdb.rpc/{}", method_str)); span.record( "rpc.request_id", req.id.clone().map(Value::as_string).unwrap_or_default(), ); let otel_cx = Arc::new(TelemetryContext::current_with_value( - req_cx.with_method(&req.method).with_size(len), + req_cx.with_method(method_str).with_size(len), )); // Process the message let res = @@ -333,11 +335,10 @@ impl Connection { pub async fn process_message( rpc: Arc>, - method: &str, + method: &Method, params: Array, ) -> Result { debug!("Process RPC request"); - let method = Method::parse(method); if !method.is_valid() { return Err(Failure::METHOD_NOT_FOUND); } diff --git a/tests/http_integration.rs b/tests/http_integration.rs index df9336368b6..805011fc482 100644 --- a/tests/http_integration.rs +++ b/tests/http_integration.rs @@ -746,6 +746,54 @@ mod http_integration { Ok(()) } + #[test(tokio::test)] + async fn rpc_endpoint_with_method_as_string() -> Result<(), Box> { + let (addr, _server) = common::start_server_with_defaults().await.unwrap(); + let url = &format!("http://{addr}/rpc"); + + // Prepare HTTP client + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert("NS", Ulid::new().to_string().parse()?); + headers.insert("DB", Ulid::new().to_string().parse()?); + headers.insert(header::ACCEPT, "application/json".parse()?); + let client = reqwest::Client::builder() + .connect_timeout(Duration::from_millis(10)) + .default_headers(headers) + .build()?; + + let res = client.post(url).json(&json!({"method": "ping"})).send().await?; + assert!(res.status().is_success()); + + let result = res.text().await?; + assert_eq!(result, "{\"result\":null}"); + + Ok(()) + } + + #[test(tokio::test)] + async fn rpc_endpoint_with_method_as_number() -> Result<(), Box> { + let (addr, _server) = common::start_server_with_defaults().await.unwrap(); + let url = &format!("http://{addr}/rpc"); + + // Prepare HTTP client + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert("NS", Ulid::new().to_string().parse()?); + headers.insert("DB", Ulid::new().to_string().parse()?); + headers.insert(header::ACCEPT, "application/json".parse()?); + let client = reqwest::Client::builder() + .connect_timeout(Duration::from_millis(10)) + .default_headers(headers) + .build()?; + + let res = client.post(url).json(&json!({"method": 1})).send().await?; + assert!(res.status().is_success()); + + let result = res.text().await?; + assert_eq!(result, "{\"result\":null}"); + + Ok(()) + } + #[test(tokio::test)] async fn signin_endpoint() -> Result<(), Box> { let (addr, _server) = common::start_server_with_auth_level().await.unwrap();