Skip to content

Commit

Permalink
new ChannelFile subclass for stdin calls shutdown_write on close()
Browse files Browse the repository at this point in the history
implement paramiko#322

SSHClient.exec_command() previously returned a naive
ChannelFile object for its stdin value; such objects
don't know to properly shut down the remote end's stdin when they
`.close()` - this leads to issues when running remote commands that read
from stdin.

A new subclass, ChannelStdinFile, has been created which
closes remote stdin when it itself is closed.
SSHClient.exec_command() has been updated to use that class
for its stdin return value.

Thanks to Brandon Rhodes for the report & steps to reproduce.
  • Loading branch information
bitprophet authored and ploxiln committed Jun 15, 2019
1 parent ea456dd commit 66432e2
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 9 deletions.
2 changes: 1 addition & 1 deletion paramiko/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
)
from paramiko.auth_handler import AuthHandler
from paramiko.ssh_gss import GSSAuth, GSS_AUTH_AVAILABLE, GSS_EXCEPTIONS
from paramiko.channel import Channel, ChannelFile, ChannelStderrFile
from paramiko.channel import Channel, ChannelFile, ChannelStderrFile, ChannelStdinFile
from paramiko.ssh_exception import (
SSHException, PasswordRequiredException, BadAuthenticationType,
ChannelException, BadHostKeyException, AuthenticationException,
Expand Down
23 changes: 23 additions & 0 deletions paramiko/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -882,6 +882,23 @@ def makefile_stderr(self, *params):
"""
return ChannelStderrFile(*([self] + list(params)))

def makefile_stdin(self, *params):
"""
Return a file-like object associated with this channel's stdin
stream.
The optional ``mode`` and ``bufsize`` arguments are interpreted the
same way as by the built-in ``file()`` function in Python. For a
client, it only makes sense to open this file for writing. For a
server, it only makes sense to open this file for reading.
:returns:
`.ChannelStdinFile` object which can be used for Python file I/O.
.. versionadded:: 2.6
"""
return ChannelStdinFile(*([self] + list(params)))

def fileno(self):
"""
Returns an OS-level file descriptor which can be used for polling, but
Expand Down Expand Up @@ -1345,3 +1362,9 @@ def _read(self, size):
def _write(self, data):
self.channel.sendall_stderr(data)
return len(data)


class ChannelStdinFile(ChannelFile):
def close(self):
super(ChannelStdinFile, self).close()
self.channel.shutdown_write()
2 changes: 1 addition & 1 deletion paramiko/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ def exec_command(
if environment:
chan.update_environment(environment)
chan.exec_command(command)
stdin = chan.makefile('wb', bufsize)
stdin = chan.makefile_stdin('wb', bufsize)
stdout = chan.makefile('r', bufsize)
stderr = chan.makefile_stderr('r', bufsize)
return stdin, stdout, stderr
Expand Down
31 changes: 25 additions & 6 deletions tests/test_channelfile.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,36 @@
from mock import patch, MagicMock

from paramiko import Channel, ChannelFile, ChannelStderrFile
from paramiko import Channel, ChannelFile, ChannelStderrFile, ChannelStdinFile


class TestChannelFile(object):
class ChannelFileBase(object):
@patch("paramiko.channel.ChannelFile._set_mode")
def test_defaults_to_unbuffered_reading(self, setmode):
ChannelFile(Channel(None))
self.klass(Channel(None))
setmode.assert_called_once_with("r", -1)

@patch("paramiko.channel.ChannelFile._set_mode")
def test_can_override_mode_and_bufsize(self, setmode):
ChannelFile(Channel(None), mode="w", bufsize=25)
self.klass(Channel(None), mode="w", bufsize=25)
setmode.assert_called_once_with("w", 25)

def test_read_recvs_from_channel(self):
chan = MagicMock()
cf = ChannelFile(chan)
cf = self.klass(chan)
cf.read(100)
chan.recv.assert_called_once_with(100)

def test_write_calls_channel_sendall(self):
chan = MagicMock()
cf = ChannelFile(chan, mode="w")
cf = self.klass(chan, mode="w")
cf.write("ohai")
chan.sendall.assert_called_once_with(b"ohai")


class TestChannelFile(ChannelFileBase):
klass = ChannelFile


class TestChannelStderrFile(object):
def test_read_calls_channel_recv_stderr(self):
chan = MagicMock()
Expand All @@ -39,3 +43,18 @@ def test_write_calls_channel_sendall(self):
cf = ChannelStderrFile(chan, mode="w")
cf.write("ohai")
chan.sendall_stderr.assert_called_once_with(b"ohai")


class TestChannelStdinFile(ChannelFileBase):
klass = ChannelStdinFile

def test_close_calls_channel_shutdown_write(self):
chan = MagicMock()
cf = ChannelStdinFile(chan, mode="wb")
cf.flush = MagicMock()
cf.close()
# Sanity check that we still call BufferedFile.close()
cf.flush.assert_called_once_with()
assert cf._closed is True
# Actual point of test
chan.shutdown_write.assert_called_once_with()
2 changes: 1 addition & 1 deletion tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def _test_connection(self, **kwargs):
schan = self.ts.accept(1.0)

# Nobody else tests the API of exec_command so let's do it here for now
assert isinstance(stdin, paramiko.ChannelFile)
assert isinstance(stdin, paramiko.ChannelStdinFile)
assert isinstance(stdout, paramiko.ChannelFile)
assert isinstance(stderr, paramiko.ChannelStderrFile)

Expand Down

0 comments on commit 66432e2

Please sign in to comment.