Skip to content

Commit

Permalink
Fix async unix implementation
Browse files Browse the repository at this point in the history
Signed-off-by: Maxim Zhiburt <zhiburt@gmail.com>
  • Loading branch information
zhiburt committed Feb 17, 2022
1 parent 4b18a1c commit 83b0e9a
Show file tree
Hide file tree
Showing 8 changed files with 109 additions and 195 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ keywords = ["expect", "pty", "testing", "terminal", "automation"]
readme = "README.md"

[features]
default = ["async"]
async = ["futures-lite", "futures-timer", "async-io", "blocking"]

[dependencies]
Expand Down
118 changes: 30 additions & 88 deletions src/interact.rs
Original file line number Diff line number Diff line change
Expand Up @@ -302,32 +302,22 @@ impl<C> InteractOptions<Proc, Stream, crate::stream::stdin::Stdin, std::io::Stdo
}

#[cfg(feature = "async")]
impl<R, W, C> InteractOptions<R, W, C>
where
R: futures_lite::AsyncRead + std::marker::Unpin,
W: Write,
{
/// Runs interact interactively.
/// See [Session::interact]
pub async fn interact(self, session: &mut S) -> Result<WaitStatus, Error> {
interact(session, self)
}
}

#[cfg(feature = "async")]
impl<R, W, C> InteractOptions<R, W, C>
impl<P, S, R, W, C> InteractOptions<P, S, R, W, C>
where
R: Read + std::marker::Unpin,
P: Healthcheck + Unpin,
S: futures_lite::AsyncRead + futures_lite::AsyncWrite + Unpin,
R: futures_lite::AsyncRead + Unpin,
W: Write,
{
/// Runs interact interactively.
/// See [Session::interact]
#[cfg(windows)]
pub async fn interact(mut self, session: &mut Session) -> Result<(), Error> {
match self.input_from {
InputFrom::Terminal => interact_in_terminal(session, &mut self).await,
InputFrom::Other => interact(session, &mut self).await,
}
pub async fn interact(
&mut self,
session: &mut Session<P, S>,
mut input: R,
mut output: W,
) -> Result<(), Error> {
interact(self, session, &mut input, &mut output).await
}
}

Expand Down Expand Up @@ -453,63 +443,15 @@ where

// copy paste of sync version with async await syntax
#[cfg(all(unix, feature = "async"))]
async fn interact_in_terminal<R, W, C>(
session: &mut Session,
options: InteractOptions<R, W, C>,
) -> Result<WaitStatus, Error>
where
R: futures_lite::AsyncRead + std::marker::Unpin,
W: Write,
{
use futures_lite::AsyncWriteExt;

// flush buffers
session.flush().await?;

let origin_pty_echo = session.get_echo().map_err(to_io_error)?;
// tcgetattr issues error if a provided fd is not a tty,
// but we can work with such input as it may be redirected.
let origin_stdin_flags = termios::tcgetattr(STDIN_FILENO);

// verify: possible controlling fd can be stdout and stderr as well?
// https://stackoverflow.com/questions/35873843/when-setting-terminal-attributes-via-tcsetattrfd-can-fd-be-either-stdout
let isatty_terminal = isatty(STDIN_FILENO).map_err(to_io_error)?;

if isatty_terminal {
set_raw(STDIN_FILENO).map_err(to_io_error)?;
}

session.set_echo(true, None).map_err(to_io_error)?;

let result = interact(session, options).await;

if isatty_terminal {
// it's suppose to be always OK.
// but we don't use unwrap just in case.
let origin_stdin_flags = origin_stdin_flags.map_err(to_io_error)?;

termios::tcsetattr(
STDIN_FILENO,
termios::SetArg::TCSAFLUSH,
&origin_stdin_flags,
)
.map_err(to_io_error)?;
}

session
.set_echo(origin_pty_echo, None)
.map_err(to_io_error)?;

result
}

// copy paste of sync version with async await syntax
#[cfg(all(unix, feature = "async"))]
async fn interact<R, W, C>(
session: &mut Session,
mut options: InteractOptions<R, W, C>,
) -> Result<WaitStatus, Error>
async fn interact<P, S, R, W, C>(
options: &mut InteractOptions<P, S, R, W, C>,
session: &mut Session<P, S>,
input: &mut R,
output: &mut W,
) -> Result<(), Error>
where
P: Healthcheck + Unpin,
S: futures_lite::AsyncRead + futures_lite::AsyncWrite + Unpin,
R: futures_lite::AsyncRead + Unpin,
W: Write,
{
Expand All @@ -530,8 +472,8 @@ where
// fill buffer to run callbacks if there was something in.
//
// We ignore errors because there might be errors like EOCHILD etc.
let status = session.status().map_err(to_io_error).map_err(|e| e.into());
if !matches!(status, Ok(WaitStatus::StillAlive)) {
let status = session.is_alive();
if matches!(status, Ok(false)) {
exited = true;
}

Expand All @@ -543,29 +485,29 @@ where
}

output_buffer.extend_from_slice(&buf[..n]);
options.check_output(session, &mut output_buffer, eof)?;
options.check_output(input, output, session, &mut output_buffer, eof)?;

let bytes = if let Some(filter) = options.output_filter.as_mut() {
(filter)(&buf[..n])?
} else {
Cow::Borrowed(&buf[..n])
};

options.output.write_all(&bytes)?;
options.output.flush()?;
output.write_all(&bytes)?;
output.flush()?;
}

if exited {
return status;
return Ok(());
}

// We dont't print user input back to the screen.
// In terminal mode it will be ECHOed back automatically.
// This way we preserve terminal seetings for example when user inputs password.
// The terminal must have been prepared before.
match options.input.read(&mut buf).await {
match input.read(&mut buf).await {
Ok(0) => {
return status;
return Ok(());
}
Ok(n) => {
let bytes = &buf[..n];
Expand All @@ -578,7 +520,7 @@ where
let buffer = if let Some(check_buffer) = input_buffer.as_mut() {
check_buffer.extend_from_slice(&bytes);
loop {
match options.check_input(session, check_buffer)? {
match options.check_input(input, output, session, check_buffer)? {
Match::Yes(n) => {
check_buffer.drain(..n);
if check_buffer.is_empty() {
Expand All @@ -602,7 +544,7 @@ where
match escape_char_position {
Some(pos) => {
session.write_all(&buffer[..pos]).await?;
return status;
return Ok(());
}
None => {
session.write_all(&buffer[..]).await?;
Expand All @@ -613,7 +555,7 @@ where
Err(err) => return Err(err.into()),
}

options.call_idle_handler(session)?;
options.call_idle_handler(input, output, session)?;
}
}

Expand Down
6 changes: 6 additions & 0 deletions src/process/unix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,12 @@ impl IntoAsyncStream for PtyStream {
}
}

impl AsRawFd for PtyStream {
fn as_raw_fd(&self) -> RawFd {
self.handle.as_raw_fd()
}
}

#[cfg(feature = "async")]
pub struct AsyncPtyStream {
stream: async_io::Async<PtyStream>,
Expand Down
120 changes: 19 additions & 101 deletions src/session/async_session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,114 +10,32 @@ use std::{
use futures_lite::{AsyncBufRead, AsyncRead, AsyncWrite, AsyncWriteExt};

use super::async_stream::Stream;
use crate::{
process::{self, IntoAsyncStream, Process},
stream::log::LoggedStream,
ControlCode, Error, Found, Needle,
};

#[cfg(unix)]
pub type Session = PtySession<process::unix::UnixProcess, process::unix::AsyncPtyStream>;

#[cfg(windows)]
pub type Session =
PtySession<process::windows::WinProcess, blocking::Unblock<process::windows::ProcessStream>>;
use crate::{stream::log::LoggedStream, ControlCode, Error, Found, Needle};

impl<P> PtySession<P, <P::Stream as IntoAsyncStream>::AsyncsStream>
where
P: Process,
P::Stream: IntoAsyncStream,
{
pub fn spawn(command: P::Command) -> Result<Self, Error> {
let mut process = P::spawn_command(command)?;
let stream = process.open_stream()?;
let stream = stream.into_async_stream()?;
let session = Self::new(process, stream)?;

Ok(session)
}

pub fn spawn_cmd(cmd: impl AsRef<str>) -> Result<Self, Error> {
let mut process = P::spawn(cmd.as_ref())?;
let stream = process.open_stream()?;
let stream = stream.into_async_stream()?;
let session = Self::new(process, stream)?;

Ok(session)
}
/// Session represents a spawned process and its streams.
/// It controlls process and communication with it.
#[derive(Debug)]
pub struct Session<P, S> {
process: P,
stream: Stream<S>,
}

impl<P, S> Session<P, S> {
/// Set logger.
pub async fn with_log<W: io::Write>(
self,
logger: W,
) -> Result<PtySession<P, LoggedStream<<P::Stream as IntoAsyncStream>::AsyncsStream, W>>, Error>
{
) -> Result<Session<P, LoggedStream<S, W>>, Error> {
let stream = self.stream.into_inner();
let stream = LoggedStream::new(stream, logger);
let session = PtySession::new(self.process, stream)?;
let session = Session::new(self.process, stream)?;
Ok(session)
}
}

impl Session {
/// Interact gives control of the child process to the interactive user (the
/// human at the keyboard).
///
/// Returns a status of a process ater interactions.
/// Why it's crusial to return a status is after check of is_alive the actuall
/// status might be gone.
///
/// Keystrokes are sent to the child process, and
/// the `stdout` and `stderr` output of the child process is printed.
///
/// When the user types the `escape_character` this method will return control to a running process.
/// The escape_character will not be transmitted.
/// The default for escape_character is entered as `Ctrl-]`, the very same as BSD telnet.
///
/// This simply echos the child `stdout` and `stderr` to the real `stdout` and
/// it echos the real `stdin` to the child `stdin`.
#[cfg(unix)]
pub async fn interact(&mut self) -> Result<crate::WaitStatus, Error> {
crate::interact::InteractOptions::terminal()?
.interact(self)
.await
}

/// Interact gives control of the child process to the interactive user (the
/// human at the keyboard).
///
/// Returns a status of a process ater interactions.
/// Why it's crusial to return a status is after check of is_alive the actuall
/// status might be gone.
///
/// Keystrokes are sent to the child process, and
/// the `stdout` and `stderr` output of the child process is printed.
///
/// When the user types the `escape_character` this method will return control to a running process.
/// The escape_character will not be transmitted.
/// The default for escape_character is entered as `Ctrl-]`, the very same as BSD telnet.
///
/// This simply echos the child `stdout` and `stderr` to the real `stdout` and
/// it echos the real `stdin` to the child `stdin`.
#[cfg(windows)]
pub async fn interact(&mut self) -> Result<(), Error> {
crate::interact::InteractOptions::terminal()?
.interact(self)
.await
}
}

/// Session represents a spawned process and its streams.
/// It controlls process and communication with it.
#[derive(Debug)]
pub struct PtySession<P, S> {
process: P,
stream: Stream<S>,
}

// GEt back to the solution where Logger is just dyn Write instead of all these magic with type system.....

impl<P, S> PtySession<P, S> {
impl<P, S> Session<P, S> {
pub fn new(process: P, stream: S) -> io::Result<Self> {
Ok(Self {
process,
Expand All @@ -131,7 +49,7 @@ impl<P, S> PtySession<P, S> {
}
}

impl<P, S: AsyncRead + Unpin> PtySession<P, S> {
impl<P, S: AsyncRead + Unpin> Session<P, S> {
pub async fn expect<N: Needle>(&mut self, needle: N) -> Result<Found, Error> {
self.stream.expect(needle).await
}
Expand Down Expand Up @@ -168,7 +86,7 @@ impl<P, S: AsyncRead + Unpin> PtySession<P, S> {
}
}

impl<P, S: AsyncWrite + Unpin> PtySession<P, S> {
impl<P, S: AsyncWrite + Unpin> Session<P, S> {
/// Send text to child's `STDIN`.
///
/// To write bytes you can use a [std::io::Write] operations instead.
Expand Down Expand Up @@ -216,21 +134,21 @@ impl<P, S: AsyncWrite + Unpin> PtySession<P, S> {
}
}

impl<P, S> Deref for PtySession<P, S> {
impl<P, S> Deref for Session<P, S> {
type Target = P;

fn deref(&self) -> &Self::Target {
&self.process
}
}

impl<P, S> DerefMut for PtySession<P, S> {
impl<P, S> DerefMut for Session<P, S> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.process
}
}

impl<P: Unpin, S: AsyncWrite + Unpin> AsyncWrite for PtySession<P, S> {
impl<P: Unpin, S: AsyncWrite + Unpin> AsyncWrite for Session<P, S> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
Expand All @@ -256,7 +174,7 @@ impl<P: Unpin, S: AsyncWrite + Unpin> AsyncWrite for PtySession<P, S> {
}
}

impl<P: Unpin, S: AsyncRead + Unpin> AsyncRead for PtySession<P, S> {
impl<P: Unpin, S: AsyncRead + Unpin> AsyncRead for Session<P, S> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
Expand All @@ -266,7 +184,7 @@ impl<P: Unpin, S: AsyncRead + Unpin> AsyncRead for PtySession<P, S> {
}
}

impl<P: Unpin, S: AsyncRead + Unpin> AsyncBufRead for PtySession<P, S> {
impl<P: Unpin, S: AsyncRead + Unpin> AsyncBufRead for Session<P, S> {
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
Pin::new(&mut self.get_mut().stream).poll_fill_buf(cx)
}
Expand Down

0 comments on commit 83b0e9a

Please sign in to comment.