Skip to content

Commit

Permalink
fix(rpc module): fail subscription calls with bad params (#728)
Browse files Browse the repository at this point in the history
* fix(rpc module): fail subscription with bad params

* draft; show my point

* fix tests

* fix build

* add tests for proc macros too

* add tests for bad params in proc macros

* fix nits

* commit all files

* add ugly fix for proc macro code

* add more user friendly API

* make SubscriptionSink::close take mut self

* fix grumbles

* show james some code

* Update core/src/server/rpc_module.rs

Co-authored-by: James Wilson <james@jsdw.me>

* remove needless clone

* fix build

* client fix docs + error type

* simplify code: merge connect reset and unsubscribe close reason

* remove unknown close reason

* refactor: remove Error::SubscriptionClosed

* add some nice APIs to ErrorObjectOwned

* unify api

* address grumbles

* remove redundant methods for close and reject

* proc macro: compile err when subscription -> Result

* rpc module: fix test subscription test

* Update core/src/server/rpc_module.rs

Co-authored-by: James Wilson <james@jsdw.me>

* Update core/src/server/rpc_module.rs

Co-authored-by: James Wilson <james@jsdw.me>

* Update core/src/server/rpc_module.rs

Co-authored-by: James Wilson <james@jsdw.me>

* Update core/src/server/rpc_module.rs

Co-authored-by: James Wilson <james@jsdw.me>

* Update core/src/server/rpc_module.rs

Co-authored-by: James Wilson <james@jsdw.me>

* Update proc-macros/src/lib.rs

Co-authored-by: James Wilson <james@jsdw.me>

* address grumbles

* remove faulty comment

* Update core/src/server/rpc_module.rs

Co-authored-by: David <dvdplm@gmail.com>

* Update core/src/server/rpc_module.rs

Co-authored-by: David <dvdplm@gmail.com>

* Update core/src/server/rpc_module.rs

Co-authored-by: David <dvdplm@gmail.com>

* Update core/src/server/rpc_module.rs

Co-authored-by: David <dvdplm@gmail.com>

* Update core/src/server/rpc_module.rs

Co-authored-by: David <dvdplm@gmail.com>

* fix: don't send `RPC Call failed: error`.

* remove debug assert

Co-authored-by: James Wilson <james@jsdw.me>
Co-authored-by: David <dvdplm@gmail.com>
  • Loading branch information
3 people committed Apr 20, 2022
1 parent b96a54b commit 9fa817d
Show file tree
Hide file tree
Showing 42 changed files with 920 additions and 537 deletions.
7 changes: 5 additions & 2 deletions benches/helpers.rs
Expand Up @@ -141,10 +141,13 @@ pub async fn ws_server(handle: tokio::runtime::Handle) -> (String, jsonrpsee::ws
let mut module = gen_rpc_module();

module
.register_subscription(SUB_METHOD_NAME, SUB_METHOD_NAME, UNSUB_METHOD_NAME, |_params, mut sink, _ctx| {
.register_subscription(SUB_METHOD_NAME, SUB_METHOD_NAME, UNSUB_METHOD_NAME, |_params, pending, _ctx| {
let mut sink = match pending.accept() {
Some(sink) => sink,
_ => return,
};
let x = "Hello";
tokio::spawn(async move { sink.send(&x) });
Ok(())
})
.unwrap();

Expand Down
5 changes: 3 additions & 2 deletions client/http-client/src/client.rs
Expand Up @@ -32,6 +32,7 @@ use crate::types::{ErrorResponse, Id, NotificationSer, ParamsSer, RequestSer, Re
use async_trait::async_trait;
use jsonrpsee_core::client::{CertificateStore, ClientT, IdKind, RequestIdManager, Subscription, SubscriptionClientT};
use jsonrpsee_core::{Error, TEN_MB_SIZE_BYTES};
use jsonrpsee_types::error::CallError;
use rustc_hash::FxHashMap;
use serde::de::DeserializeOwned;

Expand Down Expand Up @@ -147,7 +148,7 @@ impl ClientT for HttpClient {
Ok(response) => response,
Err(_) => {
let err: ErrorResponse = serde_json::from_slice(&body).map_err(Error::ParseError)?;
return Err(Error::Call(err.error.to_call_error()));
return Err(Error::Call(CallError::Custom(err.error_object().clone().into_owned())));
}
};

Expand Down Expand Up @@ -186,7 +187,7 @@ impl ClientT for HttpClient {

let rps: Vec<Response<_>> =
serde_json::from_slice(&body).map_err(|_| match serde_json::from_slice::<ErrorResponse>(&body) {
Ok(e) => Error::Call(e.error.to_call_error()),
Ok(e) => Error::Call(CallError::Custom(e.error_object().clone().into_owned())),
Err(e) => Error::ParseError(e),
})?;

Expand Down
19 changes: 10 additions & 9 deletions client/http-client/src/tests.rs
Expand Up @@ -33,7 +33,7 @@ use jsonrpsee_core::Error;
use jsonrpsee_test_utils::helpers::*;
use jsonrpsee_test_utils::mocks::Id;
use jsonrpsee_test_utils::TimeoutFutureExt;
use jsonrpsee_types::error::CallError;
use jsonrpsee_types::error::{CallError, ErrorObjectOwned};

#[tokio::test]
async fn method_call_works() {
Expand Down Expand Up @@ -98,34 +98,34 @@ async fn response_with_wrong_id() {
async fn response_method_not_found() {
let err =
run_request_with_response(method_not_found(Id::Num(0))).with_default_timeout().await.unwrap().unwrap_err();
assert_jsonrpc_error_response(err, ErrorObject::from(ErrorCode::MethodNotFound).to_call_error());
assert_jsonrpc_error_response(err, ErrorObject::from(ErrorCode::MethodNotFound).into_owned());
}

#[tokio::test]
async fn response_parse_error() {
let err = run_request_with_response(parse_error(Id::Num(0))).with_default_timeout().await.unwrap().unwrap_err();
assert_jsonrpc_error_response(err, ErrorObject::from(ErrorCode::ParseError).to_call_error());
assert_jsonrpc_error_response(err, ErrorObject::from(ErrorCode::ParseError).into_owned());
}

#[tokio::test]
async fn invalid_request_works() {
let err =
run_request_with_response(invalid_request(Id::Num(0_u64))).with_default_timeout().await.unwrap().unwrap_err();
assert_jsonrpc_error_response(err, ErrorObject::from(ErrorCode::InvalidRequest).to_call_error());
assert_jsonrpc_error_response(err, ErrorObject::from(ErrorCode::InvalidRequest).into_owned());
}

#[tokio::test]
async fn invalid_params_works() {
let err =
run_request_with_response(invalid_params(Id::Num(0_u64))).with_default_timeout().await.unwrap().unwrap_err();
assert_jsonrpc_error_response(err, ErrorObject::from(ErrorCode::InvalidParams).to_call_error());
assert_jsonrpc_error_response(err, ErrorObject::from(ErrorCode::InvalidParams).into_owned());
}

#[tokio::test]
async fn internal_error_works() {
let err =
run_request_with_response(internal_error(Id::Num(0_u64))).with_default_timeout().await.unwrap().unwrap_err();
assert_jsonrpc_error_response(err, ErrorObject::from(ErrorCode::InternalError).to_call_error());
assert_jsonrpc_error_response(err, ErrorObject::from(ErrorCode::InternalError).into_owned());
}

#[tokio::test]
Expand Down Expand Up @@ -172,10 +172,11 @@ async fn run_request_with_response(response: String) -> Result<String, Error> {
client.request("say_hello", None).with_default_timeout().await.unwrap()
}

fn assert_jsonrpc_error_response(err: Error, exp: CallError) {
fn assert_jsonrpc_error_response(err: Error, exp: ErrorObjectOwned) {
let exp = CallError::Custom(exp);
match &err {
Error::Call(e) => {
assert_eq!(e.to_string(), exp.to_string());
Error::Call(err) => {
assert_eq!(err.to_string(), exp.to_string());
}
e => panic!("Expected error: \"{}\", got: {:?}", err, e),
};
Expand Down
15 changes: 8 additions & 7 deletions client/ws-client/src/tests.rs
Expand Up @@ -35,7 +35,7 @@ use jsonrpsee_core::Error;
use jsonrpsee_test_utils::helpers::*;
use jsonrpsee_test_utils::mocks::{Id, WebSocketTestServer};
use jsonrpsee_test_utils::TimeoutFutureExt;
use jsonrpsee_types::error::CallError;
use jsonrpsee_types::error::{CallError, ErrorObjectOwned};
use serde_json::Value as JsonValue;

#[tokio::test]
Expand Down Expand Up @@ -109,34 +109,34 @@ async fn response_with_wrong_id() {
async fn response_method_not_found() {
let err =
run_request_with_response(method_not_found(Id::Num(0))).with_default_timeout().await.unwrap().unwrap_err();
assert_error_response(err, ErrorObject::from(ErrorCode::MethodNotFound).to_call_error());
assert_error_response(err, ErrorObject::from(ErrorCode::MethodNotFound).into_owned());
}

#[tokio::test]
async fn parse_error_works() {
let err = run_request_with_response(parse_error(Id::Num(0))).with_default_timeout().await.unwrap().unwrap_err();
assert_error_response(err, ErrorObject::from(ErrorCode::ParseError).to_call_error());
assert_error_response(err, ErrorObject::from(ErrorCode::ParseError).into_owned());
}

#[tokio::test]
async fn invalid_request_works() {
let err =
run_request_with_response(invalid_request(Id::Num(0_u64))).with_default_timeout().await.unwrap().unwrap_err();
assert_error_response(err, ErrorObject::from(ErrorCode::InvalidRequest).to_call_error());
assert_error_response(err, ErrorObject::from(ErrorCode::InvalidRequest).into_owned());
}

#[tokio::test]
async fn invalid_params_works() {
let err =
run_request_with_response(invalid_params(Id::Num(0_u64))).with_default_timeout().await.unwrap().unwrap_err();
assert_error_response(err, ErrorObject::from(ErrorCode::InvalidParams).to_call_error());
assert_error_response(err, ErrorObject::from(ErrorCode::InvalidParams).into_owned());
}

#[tokio::test]
async fn internal_error_works() {
let err =
run_request_with_response(internal_error(Id::Num(0_u64))).with_default_timeout().await.unwrap().unwrap_err();
assert_error_response(err, ErrorObject::from(ErrorCode::InternalError).to_call_error());
assert_error_response(err, ErrorObject::from(ErrorCode::InternalError).into_owned());
}

#[tokio::test]
Expand Down Expand Up @@ -283,7 +283,8 @@ async fn run_request_with_response(response: String) -> Result<String, Error> {
client.request("say_hello", None).with_default_timeout().await.unwrap()
}

fn assert_error_response(err: Error, exp: CallError) {
fn assert_error_response(err: Error, exp: ErrorObjectOwned) {
let exp = CallError::Custom(exp);
match &err {
Error::Call(e) => {
assert_eq!(e.to_string(), exp.to_string());
Expand Down
47 changes: 32 additions & 15 deletions core/src/client/async_client/helpers.rs
Expand Up @@ -28,10 +28,11 @@ use std::time::Duration;

use crate::client::async_client::manager::{RequestManager, RequestStatus};
use crate::client::{RequestMessage, TransportSenderT};
use crate::error::SubscriptionClosed;
use crate::Error;

use futures_channel::{mpsc, oneshot};
use jsonrpsee_types::error::CallError;
use jsonrpsee_types::response::SubscriptionError;
use jsonrpsee_types::{
ErrorResponse, Id, Notification, ParamsSer, RequestSer, Response, SubscriptionId, SubscriptionResponse,
};
Expand Down Expand Up @@ -81,20 +82,14 @@ pub(crate) fn process_subscription_response(
let sub_id = response.params.subscription.into_owned();
let request_id = match manager.get_request_id_by_subscription_id(&sub_id) {
Some(request_id) => request_id,
None => return Err(None),
None => {
tracing::error!("Subscription ID: {:?} is not an active subscription", sub_id);
return Err(None);
}
};

match manager.as_subscription_mut(&request_id) {
Some(send_back_sink) => match send_back_sink.try_send(response.params.result.clone()) {
// The server sent a subscription closed notification, then close down the subscription.
Ok(()) if serde_json::from_value::<SubscriptionClosed>(response.params.result).is_ok() => {
if manager.remove_subscription(request_id, sub_id.clone()).is_some() {
Ok(())
} else {
tracing::error!("The server tried to close down an invalid subscription: {:?}", sub_id);
Err(None)
}
}
Some(send_back_sink) => match send_back_sink.try_send(response.params.result) {
Ok(()) => Ok(()),
Err(err) => {
tracing::error!("Dropping subscription {:?} error: {:?}", sub_id, err);
Expand All @@ -110,6 +105,27 @@ pub(crate) fn process_subscription_response(
}
}

/// Attempts to close a subscription when a [`SubscriptionError`] is received.
///
/// Returns `Ok(())` if the subscription was removed
/// Return `Err(e)` if the subscription was not found.
pub(crate) fn process_subscription_close_response(
manager: &mut RequestManager,
response: SubscriptionError<JsonValue>,
) -> Result<(), Error> {
let sub_id = response.params.subscription.into_owned();
let request_id = match manager.get_request_id_by_subscription_id(&sub_id) {
Some(request_id) => request_id,
None => {
tracing::error!("The server tried to close down an invalid subscription: {:?}", sub_id);
return Err(Error::InvalidSubscriptionId);
}
};

manager.remove_subscription(request_id, sub_id).expect("Both request ID and sub ID in RequestManager; qed");
Ok(())
}

/// Attempts to process an incoming notification
///
/// Returns Ok() if the response was successfully handled
Expand Down Expand Up @@ -217,17 +233,18 @@ pub(crate) fn build_unsubscribe_message(
/// Returns `Ok` if the response was successfully sent.
/// Returns `Err(_)` if the response ID was not found.
pub(crate) fn process_error_response(manager: &mut RequestManager, err: ErrorResponse) -> Result<(), Error> {
let id = err.id.clone().into_owned();
let id = err.id().clone().into_owned();

match manager.request_status(&id) {
RequestStatus::PendingMethodCall => {
let send_back = manager.complete_pending_call(id).expect("State checked above; qed");
let _ = send_back.map(|s| s.send(Err(Error::Call(err.error.to_call_error()))));
let _ =
send_back.map(|s| s.send(Err(Error::Call(CallError::Custom(err.error_object().clone().into_owned())))));
Ok(())
}
RequestStatus::PendingSubscription => {
let (_, send_back, _) = manager.complete_pending_subscription(id).expect("State checked above; qed");
let _ = send_back.send(Err(Error::Call(err.error.to_call_error())));
let _ = send_back.send(Err(Error::Call(CallError::Custom(err.error_object().clone().into_owned()))));
Ok(())
}
_ => Err(Error::InvalidRequestId),
Expand Down
13 changes: 10 additions & 3 deletions core/src/client/async_client/mod.rs
Expand Up @@ -4,8 +4,9 @@ mod manager;
use std::time::Duration;

use crate::client::{
BatchMessage, ClientT, RegisterNotificationMessage, RequestMessage, Subscription, SubscriptionClientT,
SubscriptionKind, SubscriptionMessage, TransportReceiverT, TransportSenderT,
async_client::helpers::process_subscription_close_response, BatchMessage, ClientT, RegisterNotificationMessage,
RequestMessage, Subscription, SubscriptionClientT, SubscriptionKind, SubscriptionMessage, TransportReceiverT,
TransportSenderT,
};
use helpers::{
build_unsubscribe_message, call_with_timeout, process_batch_response, process_error_response, process_notification,
Expand All @@ -20,7 +21,8 @@ use futures_util::future::Either;
use futures_util::sink::SinkExt;
use futures_util::stream::StreamExt;
use jsonrpsee_types::{
ErrorResponse, Id, Notification, NotificationSer, ParamsSer, RequestSer, Response, SubscriptionResponse,
response::SubscriptionError, ErrorResponse, Id, Notification, NotificationSer, ParamsSer, RequestSer, Response,
SubscriptionResponse,
};
use serde::de::DeserializeOwned;
use tokio::sync::Mutex;
Expand Down Expand Up @@ -489,6 +491,11 @@ async fn background_task<S: TransportSenderT, R: TransportReceiverT>(
let _ = stop_subscription(&mut sender, &mut manager, unsub).await;
}
}
// Subscription error response.
else if let Ok(response) = serde_json::from_str::<SubscriptionError<_>>(&raw) {
tracing::debug!("[backend]: recv subscription closed {:?}", response);
let _ = process_subscription_close_response(&mut manager, response);
}
// Incoming Notification
else if let Ok(notif) = serde_json::from_str::<Notification<_>>(&raw) {
tracing::debug!("[backend]: recv notification {:?}", notif);
Expand Down
20 changes: 4 additions & 16 deletions core/src/client/mod.rs
Expand Up @@ -30,15 +30,15 @@ use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::task;

use crate::error::{Error, SubscriptionClosed};
use crate::error::Error;
use async_trait::async_trait;
use core::marker::PhantomData;
use futures_channel::{mpsc, oneshot};
use futures_util::future::FutureExt;
use futures_util::sink::SinkExt;
use futures_util::stream::{Stream, StreamExt};
use jsonrpsee_types::{Id, ParamsSer, SubscriptionId};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde::de::DeserializeOwned;
use serde_json::Value as JsonValue;

#[doc(hidden)]
Expand Down Expand Up @@ -162,17 +162,6 @@ pub enum SubscriptionKind {
Method(String),
}

/// Internal type to detect whether a subscription response from
/// the server was a valid notification or should be treated as an error.
#[derive(Debug, Deserialize, Serialize)]
#[serde(untagged)]
pub enum NotifResponse<Notif> {
/// Successful response.
Ok(Notif),
/// Subscription was closed.
Err(SubscriptionClosed),
}

/// Active subscription on the client.
///
/// It will automatically unsubscribe in the [`Subscription::drop`] so no need to explicitly call
Expand Down Expand Up @@ -301,9 +290,8 @@ where
type Item = Result<Notif, Error>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Option<Self::Item>> {
let n = futures_util::ready!(self.notifs_rx.poll_next_unpin(cx));
let res = n.map(|n| match serde_json::from_value::<NotifResponse<Notif>>(n) {
Ok(NotifResponse::Ok(parsed)) => Ok(parsed),
Ok(NotifResponse::Err(e)) => Err(Error::SubscriptionClosed(e)),
let res = n.map(|n| match serde_json::from_value::<Notif>(n) {
Ok(parsed) => Ok(parsed),
Err(e) => Err(Error::ParseError(e)),
});
task::Poll::Ready(res)
Expand Down

0 comments on commit 9fa817d

Please sign in to comment.