Skip to content

Commit

Permalink
Add new run_client and run_server methods for already-connected sockets
Browse files Browse the repository at this point in the history
This commit adds top-level functions run_client() and run_server() for
running an AsyncSSH client or sevrer on an already-connected socket.
This can also be done via the new `sock` argument to connect() and
connect_reverse(), but these new method names read a little better
and automatically enforce restrictions about what arguments are mutually
exclusive with the `sock` argument.

This commit also fixes a couple of copy/paste errors in comments, and
adds some additional unit tests.
  • Loading branch information
ronf committed Aug 10, 2022
1 parent 3be488d commit e21acc3
Show file tree
Hide file tree
Showing 5 changed files with 190 additions and 23 deletions.
5 changes: 3 additions & 2 deletions asyncssh/__init__.py
Expand Up @@ -46,7 +46,7 @@
from .connection import SSHClientConnectionOptions, SSHServerConnectionOptions
from .connection import create_connection, create_server, connect, listen
from .connection import connect_reverse, listen_reverse, get_server_host_key
from .connection import get_server_auth_methods
from .connection import get_server_auth_methods, run_client, run_server

from .editor import SSHLineEditorChannel

Expand Down Expand Up @@ -163,5 +163,6 @@
'match_known_hosts', 'read_authorized_keys', 'read_certificate',
'read_certificate_list', 'read_known_hosts', 'read_private_key',
'read_private_key_list', 'read_public_key', 'read_public_key_list',
'scp', 'set_debug_level', 'set_log_level', 'set_sftp_log_level',
'run_client', 'run_server', 'scp', 'set_debug_level', 'set_log_level',
'set_sftp_log_level',
]
146 changes: 126 additions & 20 deletions asyncssh/connection.py
Expand Up @@ -399,7 +399,13 @@ async def _connect(options: 'SSHConnectionOptions',
tunnel: _TunnelConnectorProtocol

try:
if new_tunnel:
if sock:
logger.info('%s already-connected socket', msg)

_, session = await loop.create_connection(conn_factory, sock=sock)

conn = cast(_Conn, session)
elif new_tunnel:
new_tunnel.logger.info('%s %s via %s', msg, (host, port), tunnel)

# pylint: disable=broad-except
Expand Down Expand Up @@ -428,13 +434,9 @@ async def _connect(options: 'SSHConnectionOptions',
else:
logger.info('%s %s', msg, (host, port))

if sock:
_, session = await loop.create_connection(
conn_factory, sock=sock)
else:
_, session = await loop.create_connection(
conn_factory, host, port, family=family,
flags=flags, local_addr=local_addr)
_, session = await loop.create_connection(
conn_factory, host, port, family=family,
flags=flags, local_addr=local_addr)

conn = cast(_Conn, session)
except asyncio.CancelledError:
Expand Down Expand Up @@ -476,7 +478,13 @@ def tunnel_factory(_orig_host: str, _orig_port: int) -> SSHTCPSession:
new_tunnel = await _open_tunnel(tunnel, options.passphrase)
tunnel: _TunnelListenerProtocol

if new_tunnel:
if sock:
logger.info('%s already-connected socket', msg)

server: asyncio.AbstractServer = await loop.create_server(
conn_factory, sock=sock, backlog=backlog,
reuse_address=reuse_address, reuse_port=reuse_port)
elif new_tunnel:
new_tunnel.logger.info('%s %s via %s', msg, (host, port), tunnel)

# pylint: disable=broad-except
Expand All @@ -499,15 +507,10 @@ def tunnel_factory(_orig_host: str, _orig_port: int) -> SSHTCPSession:
else:
logger.info('%s %s', msg, (host, port))

if sock:
server = await loop.create_server(
conn_factory, sock=sock, backlog=backlog,
reuse_address=reuse_address, reuse_port=reuse_port)
else:
server = await loop.create_server(
conn_factory, host, port, family=family, flags=flags,
backlog=backlog, reuse_address=reuse_address,
reuse_port=reuse_port)
server = await loop.create_server(
conn_factory, host, port, family=family, flags=flags,
backlog=backlog, reuse_address=reuse_address,
reuse_port=reuse_port)

return SSHAcceptor(server, options)

Expand Down Expand Up @@ -7623,6 +7626,109 @@ def prepare(self, last_config: Optional[SSHConfig] = None, # type: ignore
self.max_pktsize = max_pktsize


@async_context_manager
async def run_client(sock: socket.socket, config: DefTuple[ConfigPaths] = (),
options: Optional[SSHClientConnectionOptions] = None,
**kwargs: object) -> SSHClientConnection:
"""Start an SSH client connection on an already-connected socket
This function is a coroutine which starts an SSH client on an
existing already-connected socket. It can be used instead of
:func:`connect` when a socket is connected outside of asyncio.
:param sock:
An existing already-connected socket to run an SSH client on,
instead of opening up a new connection.
:param config: (optional)
Paths to OpenSSH client configuration files to load. This
configuration will be used as a fallback to override the
defaults for settings which are not explcitly specified using
AsyncSSH's configuration options. If no paths are specified and
no config paths were set when constructing the `options`
argument (if any), an attempt will be made to load the
configuration from the file :file:`.ssh/config`. If this
argument is explicitly set to `None`, no new configuration
files will be loaded, but any configuration loaded when
constructing the `options` argument will still apply. See
:ref:`SupportedClientConfigOptions` for details on what
configuration options are currently supported.
:param options: (optional)
Options to use when establishing the SSH client connection. These
options can be specified either through this parameter or as direct
keyword arguments to this function.
:type sock: :class:`socket.socket`
:type config: `list` of `str`
:type options: :class:`SSHClientConnectionOptions`
:returns: :class:`SSHClientConnection`
"""

def conn_factory() -> SSHClientConnection:
"""Return an SSH client connection factory"""

return SSHClientConnection(loop, new_options, wait='auth')

loop = asyncio.get_event_loop()

new_options = cast(SSHClientConnectionOptions, await _run_in_executor(
loop, SSHClientConnectionOptions, options, config=config, **kwargs))

return await asyncio.wait_for(
_connect(new_options, loop, 0, sock, conn_factory,
'Starting SSH client on'),
timeout=new_options.connect_timeout)


@async_context_manager
async def run_server(sock: socket.socket, config: DefTuple[ConfigPaths] = (),
options: Optional[SSHServerConnectionOptions] = None,
**kwargs: object) -> SSHServerConnection:
"""Start an SSH server connection on an already-connected socket
This function is a coroutine which starts an SSH server on an
existing already-connected TCP socket. It can be used instead of
:func:`listen` when connections are accepted outside of asyncio.
:param sock:
An existing already-connected socket to run SSH over, instead of
opening up a new connection.
:param config: (optional)
Paths to OpenSSH server configuration files to load. This
configuration will be used as a fallback to override the
defaults for settings which are not explcitly specified using
AsyncSSH's configuration options. By default, no OpenSSH
configuration files will be loaded. See
:ref:`SupportedServerConfigOptions` for details on what
configuration options are currently supported.
:param options: (optional)
Options to use when starting the reverse-direction SSH server.
These options can be specified either through this parameter
or as direct keyword arguments to this function.
:type sock: :class:`socket.socket`
:type config: `list` of `str`
:type options: :class:`SSHServerConnectionOptions`
:returns: :class:`SSHServerConnection`
"""

def conn_factory() -> SSHServerConnection:
"""Return an SSH server connection factory"""

return SSHServerConnection(loop, new_options, wait='auth')

loop = asyncio.get_event_loop()

new_options = cast(SSHServerConnectionOptions, await _run_in_executor(
loop, SSHServerConnectionOptions, options, config=config, **kwargs))

return await asyncio.wait_for(
_connect(new_options, loop, 0, sock, conn_factory,
'Starting SSH server on'),
timeout=new_options.connect_timeout)


@async_context_manager
async def connect(host = '', port: DefTuple[int] = (), *,
tunnel: DefTuple[_TunnelConnector] = (),
Expand Down Expand Up @@ -7804,7 +7910,7 @@ async def connect_reverse(
"""

def conn_factory() -> SSHServerConnection:
"""Return an SSH client connection factory"""
"""Return an SSH server connection factory"""

return SSHServerConnection(loop, new_options, wait='auth')

Expand Down Expand Up @@ -7912,7 +8018,7 @@ async def listen(host = '', port: DefTuple[int] = (), *,
"""

def conn_factory() -> SSHServerConnection:
"""Return an SSH client connection factory"""
"""Return an SSH server connection factory"""

return SSHServerConnection(loop, new_options, acceptor, error_handler)

Expand Down
10 changes: 10 additions & 0 deletions docs/api.rst
Expand Up @@ -145,6 +145,16 @@ listen_reverse

.. autofunction:: listen_reverse

run_client
----------

.. autofunction:: run_client

run_server
----------

.. autofunction:: run_server

create_connection
-----------------

Expand Down
16 changes: 16 additions & 0 deletions tests/server.py
Expand Up @@ -326,6 +326,22 @@ async def connect_reverse(self, options=None, gss_host=None, **kwargs):
self._server_port,
options=options, **kwargs)

@async_context_manager
async def run_client(self, sock, config=(), options=None, **kwargs):
"""Run an SSH client on an already-connected socket"""

return await asyncssh.run_client(sock, config, options, **kwargs)

@async_context_manager
async def run_server(self, sock, config=(), options=None, **kwargs):
"""Run an SSH server on an already-connected socket"""

options = asyncssh.SSHServerConnectionOptions(options,
server_factory=Server,
server_host_keys=['skey'])

return await asyncssh.run_server(sock, config, options, **kwargs)

async def create_connection(self, client_factory, **kwargs):
"""Create a connection to the test server"""

Expand Down
36 changes: 35 additions & 1 deletion tests/test_connection.py
Expand Up @@ -399,7 +399,7 @@ async def test_connect(self):

@asynctest
async def test_connect_sock(self):
"""Test connecting using already-connected socket"""
"""Test connecting using an already-connected socket"""

sock = socket.socket()
await self.loop.sock_connect(sock, (self._server_addr,
Expand All @@ -408,6 +408,17 @@ async def test_connect_sock(self):
async with asyncssh.connect(sock=sock):
pass

@asynctest
async def test_run_client(self):
"""Test running an SSH client on an already-connected socket"""

sock = socket.socket()
await self.loop.sock_connect(sock, (self._server_addr,
self._server_port))

async with self.run_client(sock):
pass

@asynctest
async def test_connect_encrypted_key(self):
"""Test connecting with encrytped client key and no passphrase"""
Expand Down Expand Up @@ -1522,6 +1533,7 @@ async def test_connect(self):
async with self.connect():
pass


class _TestConnectionAsyncAcceptor(ServerTestCase):
"""Unit test for async acceptor"""

Expand Down Expand Up @@ -1590,6 +1602,28 @@ async def test_connect_reverse(self):
async with self.connect_reverse():
pass

@asynctest
async def test_connect_reverse_sock(self):
"""Test reverse connection using an already-connected socket"""

sock = socket.socket()
await self.loop.sock_connect(sock, (self._server_addr,
self._server_port))

async with self.connect_reverse(sock=sock):
pass

@asynctest
async def test_run_server(self):
"""Test running an SSH server on an already-connected socket"""

sock = socket.socket()
await self.loop.sock_connect(sock, (self._server_addr,
self._server_port))

async with self.run_server(sock):
pass

@unittest.skipUnless(_nc_available, 'Netcat not available')
@asynctest
async def test_connect_reverse_proxy(self):
Expand Down

0 comments on commit e21acc3

Please sign in to comment.