Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle both number and string for rpc request method #3949

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
30 changes: 30 additions & 0 deletions core/src/rpc/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,36 @@ impl Method {
}
}

impl From<u8> for Method {
fn from(n: u8) -> Self {
match n {
1 => Self::Ping,
2 => Self::Info,
3 => Self::Use,
4 => Self::Signup,
5 => Self::Signin,
6 => Self::Invalidate,
7 => Self::Authenticate,
8 => Self::Kill,
9 => Self::Live,
10 => Self::Set,
11 => Self::Unset,
12 => Self::Select,
13 => Self::Insert,
14 => Self::Create,
15 => Self::Update,
16 => Self::Merge,
17 => Self::Patch,
18 => Self::Delete,
19 => Self::Version,
20 => Self::Query,
21 => Self::Relate,
22 => Self::Run,
_ => Self::Unknown,
}
}
}

impl Method {
pub fn is_valid(&self) -> bool {
!matches!(self, Self::Unknown)
Expand Down
13 changes: 9 additions & 4 deletions core/src/rpc/request.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
use crate::rpc::format::cbor::Cbor;
use crate::rpc::format::msgpack::Pack;
use crate::rpc::RpcError;
use crate::sql::Part;
use crate::sql::{Array, Value};
use crate::sql::{Array, Number, Part, Value};
use once_cell::sync::Lazy;

use super::method::Method;

pub static ID: Lazy<[Part; 1]> = Lazy::new(|| [Part::from("id")]);
pub static METHOD: Lazy<[Part; 1]> = Lazy::new(|| [Part::from("method")]);
pub static PARAMS: Lazy<[Part; 1]> = Lazy::new(|| [Part::from("params")]);

pub struct Request {
pub id: Option<Value>,
pub method: String,
pub method: Method,
pub params: Array,
}

Expand Down Expand Up @@ -44,7 +45,11 @@ impl TryFrom<Value> for Request {
};
// Fetch the 'method' argument
let method = match val.pick(&*METHOD) {
Value::Strand(v) => v.to_raw(),
Value::Strand(v) => Method::parse(v.to_raw()),
Value::Number(Number::Int(v)) => match u8::try_from(v) {
Ok(v) => Method::from(v),
_ => return Err(RpcError::InvalidRequest),
},
_ => return Err(RpcError::InvalidRequest),
};
// Fetch the 'params' argument
Expand Down
4 changes: 2 additions & 2 deletions core/src/rpc/rpc_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ pub trait RpcContext {
async { unreachable!() }
}

async fn execute(&mut self, method: Method, params: Array) -> Result<Data, RpcError> {
async fn execute(&mut self, method: &Method, params: Array) -> Result<Data, RpcError> {
match method {
Method::Ping => Ok(Value::None.into()),
Method::Info => self.info().await.map(Into::into).map_err(Into::into),
Expand Down Expand Up @@ -65,7 +65,7 @@ pub trait RpcContext {
}
}

async fn execute_immut(&self, method: Method, params: Array) -> Result<Data, RpcError> {
async fn execute_immut(&self, method: &Method, params: Array) -> Result<Data, RpcError> {
match method {
Method::Ping => Ok(Value::None.into()),
Method::Info => self.info().await.map(Into::into).map_err(Into::into),
Expand Down
3 changes: 1 addition & 2 deletions src/net/rpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ use http_body::Body as HttpBody;
use surrealdb::dbs::Session;
use surrealdb::rpc::format::Format;
use surrealdb::rpc::format::PROTOCOLS;
use surrealdb::rpc::method::Method;
use tower_http::request_id::RequestId;
use uuid::Uuid;

Expand Down Expand Up @@ -116,7 +115,7 @@ async fn post_handler(

match fmt.req_http(body) {
Ok(req) => {
let res = rpc_ctx.execute(Method::parse(req.method), req.params).await;
let res = rpc_ctx.execute(&req.method, req.params).await;
fmt.res_http(res.into_response(None)).map_err(Error::from)
}
Err(err) => Err(Error::from(err)),
Expand Down
11 changes: 6 additions & 5 deletions src/rpc/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -297,15 +297,17 @@ impl Connection {
// Parse the RPC request structure
match fmt.req_ws(msg) {
Ok(req) => {
let method_str = req.method.to_str();

// Now that we know the method, we can update the span and create otel context
span.record("rpc.method", &req.method);
span.record("otel.name", format!("surrealdb.rpc/{}", req.method));
span.record("rpc.method", method_str);
span.record("otel.name", format!("surrealdb.rpc/{}", method_str));
span.record(
"rpc.request_id",
req.id.clone().map(Value::as_string).unwrap_or_default(),
);
let otel_cx = Arc::new(TelemetryContext::current_with_value(
req_cx.with_method(&req.method).with_size(len),
req_cx.with_method(method_str).with_size(len),
));
// Process the message
let res =
Expand Down Expand Up @@ -333,11 +335,10 @@ impl Connection {

pub async fn process_message(
rpc: Arc<RwLock<Connection>>,
method: &str,
method: &Method,
params: Array,
) -> Result<Data, Failure> {
debug!("Process RPC request");
let method = Method::parse(method);
if !method.is_valid() {
return Err(Failure::METHOD_NOT_FOUND);
}
Expand Down
48 changes: 48 additions & 0 deletions tests/http_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,54 @@ mod http_integration {
Ok(())
}

#[test(tokio::test)]
async fn rpc_endpoint_with_method_as_string() -> Result<(), Box<dyn std::error::Error>> {
let (addr, _server) = common::start_server_with_defaults().await.unwrap();
let url = &format!("http://{addr}/rpc");

// Prepare HTTP client
let mut headers = reqwest::header::HeaderMap::new();
headers.insert("NS", Ulid::new().to_string().parse()?);
headers.insert("DB", Ulid::new().to_string().parse()?);
headers.insert(header::ACCEPT, "application/json".parse()?);
let client = reqwest::Client::builder()
.connect_timeout(Duration::from_millis(10))
.default_headers(headers)
.build()?;

let res = client.post(url).json(&json!({"method": "ping"})).send().await?;
assert!(res.status().is_success());

let result = res.text().await?;
assert_eq!(result, "{\"result\":null}");

Ok(())
}

#[test(tokio::test)]
async fn rpc_endpoint_with_method_as_number() -> Result<(), Box<dyn std::error::Error>> {
let (addr, _server) = common::start_server_with_defaults().await.unwrap();
let url = &format!("http://{addr}/rpc");

// Prepare HTTP client
let mut headers = reqwest::header::HeaderMap::new();
headers.insert("NS", Ulid::new().to_string().parse()?);
headers.insert("DB", Ulid::new().to_string().parse()?);
headers.insert(header::ACCEPT, "application/json".parse()?);
let client = reqwest::Client::builder()
.connect_timeout(Duration::from_millis(10))
.default_headers(headers)
.build()?;

let res = client.post(url).json(&json!({"method": 1})).send().await?;
assert!(res.status().is_success());

let result = res.text().await?;
assert_eq!(result, "{\"result\":null}");

Ok(())
}

#[test(tokio::test)]
async fn signin_endpoint() -> Result<(), Box<dyn std::error::Error>> {
let (addr, _server) = common::start_server_with_auth_level().await.unwrap();
Expand Down