diff --git a/Cargo.toml b/Cargo.toml index 2bb37a7..e9c308b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,23 +12,22 @@ repository = "https://github.com/wiktor-k/ssh-agent.rs" edition = "2021" [dependencies] -byteorder = "1.2.7" -serde = {version = "1.0.87", features = ["derive"]} +byteorder = "1.4" +serde = {version = "1", features = ["derive"]} -bytes = { version = "0.4.11", optional = true } -futures = { version = "0.1.25", optional = true } -log = { version = "0.4.6", optional = true } -tokio = { version = "0.1.15", optional = true } -tokio-uds = { version = "0.2.5", optional = true } +bytes = { version = "1.1", optional = true } +#futures = { version = "0.1.25", optional = true } +futures = { version = "0.3.30", optional = true } +log = { version = "0.4.16", optional = true } +tokio = { version = "1", optional = true, features = ["rt", "net", "rt-multi-thread"] } +tokio-util = { version = "0.7", optional = true, features = ["codec"] } +tokio-stream = { version = "0.1.14", optional = true, features = ["net"] } +#tokio-uds = { version = "0.2.5", optional = true } [features] default = ["agent"] -agent = ["futures", "log", "tokio", "tokio-uds", "bytes"] +agent = ["log", "tokio", "tokio-util", "bytes", "futures", "tokio-stream"] [[example]] name = "key_storage" required-features = ["agent"] - -[dev-dependencies] -env_logger = "0.6.0" -openssl = "0.10.16" diff --git a/src/agent.rs b/src/agent.rs index 57d3de7..e902802 100644 --- a/src/agent.rs +++ b/src/agent.rs @@ -1,18 +1,24 @@ use byteorder::{BigEndian, ReadBytesExt}; -use bytes::{BufMut, BytesMut}; -use futures::future::FutureResult; +use bytes::BytesMut; use log::{error, info}; -use tokio::codec::{Decoder, Encoder, Framed}; use tokio::net::TcpListener; -use tokio::prelude::*; -use tokio_uds::UnixListener; +use tokio::net::UnixListener; +use tokio_util::codec::{Decoder, Encoder, Framed}; use std::error::Error; use std::fmt::Debug; +use std::future::Future; use std::mem::size_of; use std::net::SocketAddr; use std::path::Path; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; + +use bytes::Buf; +use bytes::BufMut; +use futures::SinkExt; +use futures::StreamExt; +use futures::TryFutureExt; +use futures::TryStreamExt; use super::error::AgentError; use super::proto::message::Message; @@ -24,7 +30,7 @@ impl Decoder for MessageCodec { type Item = Message; type Error = AgentError; - fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { let mut bytes = &src[..]; if bytes.len() < size_of::() { @@ -43,39 +49,41 @@ impl Decoder for MessageCodec { } } -impl Encoder for MessageCodec { - type Item = Message; +impl Encoder for MessageCodec { type Error = AgentError; - fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> { + fn encode(&mut self, item: Message, dst: &mut BytesMut) -> Result<(), Self::Error> { let bytes = to_bytes(&to_bytes(&item)?)?; - dst.put(bytes); + dst.put(&bytes[..]); Ok(()) } } macro_rules! handle_clients { - ($self:ident, $socket:ident) => {{ + ($self:ident, $wrapper:ident, $socket:ident) => {{ + use futures::FutureExt; + use futures::TryFutureExt; info!("Listening; socket = {:?}", $socket); - let arc_self = Arc::new($self); - $socket - .incoming() + let arc_self = Arc::new(Mutex::new($self)); + tokio_stream::wrappers::$wrapper::new($socket) .map_err(|e| error!("Failed to accept socket; error = {:?}", e)) .for_each(move |socket| { - let (write, read) = Framed::new(socket, MessageCodec).split(); - let arc_self = arc_self.clone(); - let connection = write - .send_all(read.and_then(move |message| { - arc_self.handle_async(message).map_err(|e| { - error!("Error handling message; error = {:?}", e); - AgentError::User - }) - })) - .map(|_| ()) - .map_err(|e| error!("Error while handling message; error = {:?}", e)); - tokio::spawn(connection) + let socket = socket.unwrap(); //FIXME + let (mut write, read) = Framed::new(socket, MessageCodec).split(); + let arc_self = Arc::clone(&arc_self); + let arc_self = arc_self.lock().unwrap(); + let connection = write.send_all(&mut std::pin::pin!(read.and_then(|message| { + arc_self.handle_async(message).map_err(move |e| { + error!("Error handling message; error = {:?}", e); + AgentError::User + }) + }))); + //.map(move |_| ()) + //.map_err(|e| error!("Error while handling message; error = {:?}", e)); + tokio::task::spawn_local(connection).map(move |_| ()) + //async { tokio::task::spawn_local(connection); } }) - .map_err(|e| e.into()) + //.map_err(|e| e.into()) }}; } @@ -84,16 +92,15 @@ pub trait Agent: 'static + Sync + Send + Sized { fn handle(&self, message: Message) -> Result; - fn handle_async( - &self, - message: Message, - ) -> Box + Send + Sync> { - Box::new(FutureResult::from(self.handle(message))) + async fn handle_async(&self, message: Message) -> Result { + self.handle(message) } #[allow(clippy::unit_arg)] fn run_listener(self, socket: UnixListener) -> Result<(), Box> { - Ok(tokio::run(handle_clients!(self, socket))) + let mut rt = tokio::runtime::Runtime::new().unwrap(); + let res = rt.block_on(handle_clients!(self, UnixListenerStream, socket)); + Ok(res) } fn run_unix(self, path: impl AsRef) -> Result<(), Box> { @@ -102,7 +109,13 @@ pub trait Agent: 'static + Sync + Send + Sized { #[allow(clippy::unit_arg)] fn run_tcp(self, addr: &str) -> Result<(), Box> { - let socket = TcpListener::bind(&addr.parse::()?)?; - Ok(tokio::run(handle_clients!(self, socket))) + let mut rt = tokio::runtime::Runtime::new().unwrap(); + let res = rt.block_on(async { + let socket = TcpListener::bind(&addr.parse::().unwrap()) + .await + .unwrap(); // FIXMEx2 + handle_clients!(self, TcpListenerStream, socket); + }); + Ok(res) } }