From 866151728744d8e7e61f6e8c49caf9607aca78d7 Mon Sep 17 00:00:00 2001 From: zz <517669936@qq.com> Date: Fri, 29 Apr 2022 10:50:51 +0800 Subject: [PATCH] add ws reconnection --- .gitignore | 1 + .../contract_log_filter_loss_connection.rs | 85 ++++++++++++ examples/readme.md | 2 +- examples/transport_ws_loss_connection.rs | 28 ++++ src/api/eth_subscribe.rs | 4 +- src/lib.rs | 2 +- src/transports/either.rs | 6 +- src/transports/ipc.rs | 13 +- src/transports/ws.rs | 126 ++++++++++++++++-- 9 files changed, 244 insertions(+), 23 deletions(-) create mode 100644 examples/contract_log_filter_loss_connection.rs create mode 100644 examples/transport_ws_loss_connection.rs diff --git a/.gitignore b/.gitignore index ded91645..03459db3 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ target Cargo.lock *.swp .idea/ +db diff --git a/examples/contract_log_filter_loss_connection.rs b/examples/contract_log_filter_loss_connection.rs new file mode 100644 index 00000000..38a3f105 --- /dev/null +++ b/examples/contract_log_filter_loss_connection.rs @@ -0,0 +1,85 @@ +use std::time::Duration; + +use ethabi::Address; +use futures::{future, TryStreamExt}; +use hex_literal::hex; +use web3::{ + api::BaseFilter, + contract::{Contract, Options}, + transports::WebSocket, + types::{Filter, FilterBuilder, Log}, + Web3, +}; + +#[tokio::main] +async fn main() -> web3::contract::Result<()> { + let _ = env_logger::try_init(); + let transport = web3::transports::WebSocket::new("ws://localhost:8545").await?; + let web3 = web3::Web3::new(transport); + + println!("Calling accounts."); + let accounts = web3.eth().accounts().await?; + + let bytecode = include_str!("./res/SimpleEvent.bin"); + let contract = Contract::deploy(web3.eth(), include_bytes!("./res/SimpleEvent.abi"))? + .confirmations(1) + .poll_interval(Duration::from_secs(10)) + .options(Options::with(|opt| opt.gas = Some(3_000_000u64.into()))) + .execute(bytecode, (), accounts[0]) + .await + .unwrap(); + + println!("contract deployed at: {}", contract.address()); + + tokio::spawn(interval_contract_call(contract.clone(), accounts[0])); + + // Filter for Hello event in our contract + let filter = FilterBuilder::default() + .address(vec![contract.address()]) + .topics( + Some(vec![hex!( + "d282f389399565f3671145f5916e51652b60eee8e5c759293a2f5771b8ddfd2e" + ) + .into()]), + None, + None, + None, + ) + .build(); + + loop { + let filter = get_filter(web3.clone(), &filter).await; + let logs_stream = filter.stream(Duration::from_secs(2)); + let res = logs_stream + .try_for_each(|log| { + println!("Get log: {:?}", log); + future::ready(Ok(())) + }) + .await; + + if let Err(e) = res { + println!("Log Filter Error: {}", e); + } + } +} + +async fn interval_contract_call(contract: Contract, account: Address) { + loop { + match contract.call("hello", (), account, Options::default()).await { + Ok(tx) => println!("got tx: {:?}", tx), + Err(e) => println!("get tx failed: {}", e), + } + + tokio::time::sleep(Duration::from_secs(1)).await; + } +} + +pub async fn get_filter(web3: Web3, filter: &Filter) -> BaseFilter { + loop { + match web3.eth_filter().create_logs_filter(filter.clone()).await { + Err(e) => println!("get filter failed: {}", e), + Ok(filter) => return filter, + } + tokio::time::sleep(Duration::from_secs(1)).await; + } +} diff --git a/examples/readme.md b/examples/readme.md index 5a86535b..f746de6f 100644 --- a/examples/readme.md +++ b/examples/readme.md @@ -1,6 +1,6 @@ First, run ganache - ganache-cli -b 3 -m "hamster coin cup brief quote trick stove draft hobby strong caught unable" + ganache-cli -b 3 -m "hamster coin cup brief quote trick stove draft hobby strong caught unable" --db ./db Using this mnemonic makes the static account addresses in the example line up diff --git a/examples/transport_ws_loss_connection.rs b/examples/transport_ws_loss_connection.rs new file mode 100644 index 00000000..1b778bcd --- /dev/null +++ b/examples/transport_ws_loss_connection.rs @@ -0,0 +1,28 @@ +use std::time::Duration; + +use ethabi::Address; +use web3::{api::Eth, transports::WebSocket}; + +#[tokio::main] +async fn main() -> web3::Result<()> { + let _ = env_logger::try_init(); + let transport = web3::transports::WebSocket::new("ws://localhost:8545").await?; + let web3 = web3::Web3::new(transport); + + println!("Calling accounts."); + let accounts = web3.eth().accounts().await?; + + interval_balance(&web3.eth(), accounts[0]).await; + + Ok(()) +} + +async fn interval_balance(eth: &Eth, account: Address) { + loop { + match eth.balance(account, None).await { + Ok(balance) => println!("Balance of {:?}: {}", account, balance), + Err(e) => println!("Get balance failed: {}", e), + } + tokio::time::sleep(Duration::from_secs(2)).await; + } +} diff --git a/src/api/eth_subscribe.rs b/src/api/eth_subscribe.rs index f5edae0a..69410657 100644 --- a/src/api/eth_subscribe.rs +++ b/src/api/eth_subscribe.rs @@ -52,7 +52,7 @@ pub struct SubscriptionStream { id: SubscriptionId, #[pin] rx: T::NotificationStream, - _marker: PhantomData, + _marker: PhantomData>, } impl SubscriptionStream { @@ -90,7 +90,7 @@ where fn poll_next(self: Pin<&mut Self>, ctx: &mut Context) -> Poll> { let this = self.project(); let x = ready!(this.rx.poll_next(ctx)); - Poll::Ready(x.map(|result| serde_json::from_value(result).map_err(Into::into))) + Poll::Ready(x.map(|result| result.and_then(|v| serde_json::from_value(v).map_err(Into::into)))) } } diff --git a/src/lib.rs b/src/lib.rs index fd1eb5c3..44362930 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -75,7 +75,7 @@ pub trait BatchTransport: Transport { /// A transport implementation supporting pub sub subscriptions. pub trait DuplexTransport: Transport { /// The type of stream this transport returns - type NotificationStream: futures::Stream; + type NotificationStream: futures::Stream>; /// Add a subscription to this transport fn subscribe(&self, id: api::SubscriptionId) -> error::Result; diff --git a/src/transports/either.rs b/src/transports/either.rs index d86c2fa8..e262d4c8 100644 --- a/src/transports/either.rs +++ b/src/transports/either.rs @@ -72,10 +72,10 @@ where B: DuplexTransport, A::Out: 'static + Send, B::Out: 'static + Send, - AStream: futures::Stream + 'static + Send, - BStream: futures::Stream + 'static + Send, + AStream: futures::Stream> + 'static + Send, + BStream: futures::Stream> + 'static + Send, { - type NotificationStream = BoxStream<'static, rpc::Value>; + type NotificationStream = BoxStream<'static, error::Result>; fn subscribe(&self, id: api::SubscriptionId) -> error::Result { Ok(match *self { diff --git a/src/transports/ipc.rs b/src/transports/ipc.rs index 866cd13f..b5fa4a28 100644 --- a/src/transports/ipc.rs +++ b/src/transports/ipc.rs @@ -1,8 +1,9 @@ //! IPC transport use crate::{ - api::SubscriptionId, error::TransportError, helpers, BatchTransport, DuplexTransport, Error, RequestId, Result, - Transport, + api::SubscriptionId, + error::{self, TransportError}, + helpers, BatchTransport, DuplexTransport, Error, RequestId, Result, Transport, }; use futures::{ future::{join_all, JoinAll}, @@ -99,7 +100,7 @@ impl BatchTransport for Ipc { } impl DuplexTransport for Ipc { - type NotificationStream = UnboundedReceiverStream; + type NotificationStream = UnboundedReceiverStream>; fn subscribe(&self, id: SubscriptionId) -> Result { let (tx, rx) = mpsc::unbounded_channel(); @@ -158,7 +159,7 @@ type TransportRequest = (RequestId, rpc::Call, oneshot::Sender); enum TransportMessage { Single(TransportRequest), Batch(Vec), - Subscribe(SubscriptionId, mpsc::UnboundedSender), + Subscribe(SubscriptionId, mpsc::UnboundedSender>), Unsubscribe(SubscriptionId), } @@ -262,7 +263,7 @@ async fn run_server(unix_stream: UnixStream, messages_rx: UnboundedReceiverStrea } fn notify( - subscription_txs: &mut BTreeMap>, + subscription_txs: &mut BTreeMap>>, notification: rpc::Notification, ) -> std::result::Result<(), ()> { if let rpc::Params::Map(params) = notification.params { @@ -272,7 +273,7 @@ fn notify( if let (Some(&rpc::Value::String(ref id)), Some(result)) = (id, result) { let id: SubscriptionId = id.clone().into(); if let Some(tx) = subscription_txs.get(&id) { - if let Err(e) = tx.send(result.clone()) { + if let Err(e) = tx.send(Ok(result.clone())) { log::error!("Error sending notification: {:?} (id: {:?}", e, id); } } else { diff --git a/src/transports/ws.rs b/src/transports/ws.rs index 429bdfb8..9b0f72b4 100644 --- a/src/transports/ws.rs +++ b/src/transports/ws.rs @@ -14,6 +14,7 @@ use futures::{ use soketto::{ connection, handshake::{Client, ServerResponse}, + Receiver, Sender, }; use std::{ collections::BTreeMap, @@ -22,6 +23,7 @@ use std::{ pin::Pin, sync::{atomic, Arc}, }; +use tokio_util::compat::Compat; use url::Url; impl From for Error { @@ -39,7 +41,7 @@ impl From for Error { type SingleResult = error::Result; type BatchResult = error::Result>; type Pending = oneshot::Sender; -type Subscription = mpsc::UnboundedSender; +type Subscription = mpsc::UnboundedSender>; /// Stream, either plain TCP or TLS. enum MaybeTlsStream { @@ -91,12 +93,83 @@ where } struct WsServerTask { + connect_info: ConnectInfo, pending: BTreeMap, subscriptions: BTreeMap, sender: connection::Sender>, receiver: connection::Receiver>, } +struct ConnectInfo { + scheme: String, + host: String, + port: u16, + addrs: String, + resource: String, + username: String, + password: Option, +} + +impl ConnectInfo { + pub async fn get_connection( + &self, + ) -> error::Result<( + Sender< + MaybeTlsStream, Compat>>, + >, + Receiver< + MaybeTlsStream, Compat>>, + >, + )> { + let stream = compat::raw_tcp_stream(self.addrs.clone()).await?; + stream.set_nodelay(true)?; + let socket = if self.scheme == "wss" { + #[cfg(any(feature = "ws-tls-tokio", feature = "ws-tls-async-std"))] + { + let stream = async_native_tls::connect(&self.host, stream).await?; + MaybeTlsStream::Tls(compat::compat(stream)) + } + #[cfg(not(any(feature = "ws-tls-tokio", feature = "ws-tls-async-std")))] + panic!("The library was compiled without TLS support. Enable ws-tls-tokio or ws-tls-async-std feature."); + } else { + let stream = compat::compat(stream); + MaybeTlsStream::Plain(stream) + }; + + let mut client = Client::new(socket, &self.host, &self.resource); + let maybe_encoded = self.password.as_ref().map(|password| { + use headers::authorization::{Authorization, Credentials}; + Authorization::basic(&self.username, &password) + .0 + .encode() + .as_bytes() + .to_vec() + }); + + let headers = maybe_encoded.as_ref().map(|head| { + [soketto::handshake::client::Header { + name: "Authorization", + value: head, + }] + }); + + if let Some(ref head) = headers { + client.set_headers(head); + } + let handshake = client.handshake(); + let (sender, receiver) = match handshake.await? { + ServerResponse::Accepted { .. } => client.into_builder().finish(), + ServerResponse::Redirect { status_code, .. } => { + return Err(error::Error::Transport(TransportError::Code(status_code))) + } + ServerResponse::Rejected { status_code } => { + return Err(error::Error::Transport(TransportError::Code(status_code))) + } + }; + Ok((sender, receiver)) + } +} + impl WsServerTask { /// Create new WebSocket transport. pub async fn new(url: &str) -> error::Result { @@ -126,7 +199,7 @@ impl WsServerTask { let addrs = format!("{}:{}", host, port); log::trace!("Connecting TcpStream with address: {}", addrs); - let stream = compat::raw_tcp_stream(addrs).await?; + let stream = compat::raw_tcp_stream(addrs.clone()).await?; stream.set_nodelay(true)?; let socket = if scheme == "wss" { #[cfg(any(feature = "ws-tls-tokio", feature = "ws-tls-async-std"))] @@ -183,6 +256,15 @@ impl WsServerTask { }; Ok(Self { + connect_info: ConnectInfo { + scheme: scheme.to_owned(), + resource: resource, + username: url.username().to_owned(), + password: url.password().map(|v| v.to_owned()), + host: host.to_owned(), + port, + addrs, + }, pending: Default::default(), subscriptions: Default::default(), sender, @@ -192,6 +274,7 @@ impl WsServerTask { async fn into_task(self, requests: mpsc::UnboundedReceiver) { let Self { + connect_info, receiver, mut sender, mut pending, @@ -200,8 +283,8 @@ impl WsServerTask { let receiver = as_data_stream(receiver).fuse(); let requests = requests.fuse(); - pin_mut!(receiver); - pin_mut!(requests); + let mut receiver = Box::pin(receiver); + let mut requests = Box::pin(requests); loop { select! { msg = requests.next() => match msg { @@ -213,7 +296,7 @@ impl WsServerTask { let res2 = sender.flush().await; if let Err(e) = res.and(res2) { // TODO [ToDr] Re-connect. - log::error!("WS connection error: {:?}", e); + log::warn!("WS connection error: {:?}", e); pending.remove(&id); } } @@ -234,8 +317,31 @@ impl WsServerTask { handle_message(&data, &subscriptions, &mut pending); }, Some(Err(e)) => { - log::error!("WS connection error: {:?}", e); - break; + log::warn!("WS connection error: {:?}", e); + + for (_id, request) in pending{ + if let Err(err) = request.send(Err(Error::Transport(TransportError::Message(format!("WS connection error: {:?}",e))))) { + log::warn!("Sending a response to deallocated channel: {:?}", err); + } + } + + pending = BTreeMap::new(); + + for (id,stream) in subscriptions{ + if let Err(e) = stream.unbounded_send(Err(Error::Transport(TransportError::Message(format!("WS connection error: {:?}",e))))) { + log::error!("Error sending notification: {:?} (id: {:?}", e, id); + } + } + + subscriptions = BTreeMap::new(); + + if let Ok((new_sender,new_receiver)) = connect_info.get_connection().await{ + let new_receiver = as_data_stream(new_receiver).fuse(); + let new_receiver = Box::pin(new_receiver); + + sender = new_sender; + receiver = new_receiver; + } }, None => break, }, @@ -271,7 +377,7 @@ fn handle_message( if let (Some(&rpc::Value::String(ref id)), Some(result)) = (id, result) { let id: SubscriptionId = id.clone().into(); if let Some(stream) = subscriptions.get(&id) { - if let Err(e) = stream.unbounded_send(result.clone()) { + if let Err(e) = stream.unbounded_send(Ok(result.clone())) { log::error!("Error sending notification: {:?} (id: {:?}", e, id); } } else { @@ -318,7 +424,7 @@ enum TransportMessage { }, Subscribe { id: SubscriptionId, - sink: mpsc::UnboundedSender, + sink: mpsc::UnboundedSender>, }, Unsubscribe { id: SubscriptionId, @@ -460,7 +566,7 @@ impl BatchTransport for WebSocket { } impl DuplexTransport for WebSocket { - type NotificationStream = mpsc::UnboundedReceiver; + type NotificationStream = mpsc::UnboundedReceiver>; fn subscribe(&self, id: SubscriptionId) -> error::Result { // TODO [ToDr] Not unbounded?