Skip to content

Commit

Permalink
Harden AsyncSSH state machine against message injection during handshake
Browse files Browse the repository at this point in the history
This commit puts additional restrictions on when messages are accepted
during the SSH handshake to avoid message injection attacks from a
rogue client or server.

More detailed information will be available in CVE-2023-46445 and
CVE-2023-46446, to be published shortly.

Thanks go to Fabian Bäumer, Marcus Brinkmann, and Jörg Schwenk for
identifying and reporting these vulnerabilities and providing
detailed analysis and suggestions for how to protect against them,
as well as review comments on the proposed fix.
  • Loading branch information
ronf committed Nov 9, 2023
1 parent f67234f commit 83e43f5
Show file tree
Hide file tree
Showing 2 changed files with 207 additions and 76 deletions.
132 changes: 83 additions & 49 deletions asyncssh/connection.py
Expand Up @@ -899,6 +899,8 @@ def __init__(self, loop: asyncio.AbstractEventLoop,
self._can_send_ext_info = False
self._extensions_to_send: 'OrderedDict[bytes, bytes]' = OrderedDict()

self._can_recv_ext_info = False

self._server_sig_algs: Set[bytes] = set()

self._next_service: Optional[bytes] = None
Expand All @@ -908,6 +910,7 @@ def __init__(self, loop: asyncio.AbstractEventLoop,
self._auth: Optional[Auth] = None
self._auth_in_progress = False
self._auth_complete = False
self._auth_final = False
self._auth_methods = [b'none']
self._auth_was_trivial = True
self._username = ''
Expand Down Expand Up @@ -1538,15 +1541,25 @@ def _recv_packet(self) -> bool:
skip_reason = ''
exc_reason = ''

if self._kex and MSG_KEX_FIRST <= pkttype <= MSG_KEX_LAST:
if self._ignore_first_kex: # pragma: no cover
skip_reason = 'ignored first kex'
self._ignore_first_kex = False
if MSG_KEX_FIRST <= pkttype <= MSG_KEX_LAST:
if self._kex:
if self._ignore_first_kex: # pragma: no cover
skip_reason = 'ignored first kex'
self._ignore_first_kex = False
else:
handler = self._kex
else:
handler = self._kex
elif (self._auth and
MSG_USERAUTH_FIRST <= pkttype <= MSG_USERAUTH_LAST):
handler = self._auth
skip_reason = 'kex not in progress'
exc_reason = 'Key exchange not in progress'
elif MSG_USERAUTH_FIRST <= pkttype <= MSG_USERAUTH_LAST:
if self._auth:
handler = self._auth
else:
skip_reason = 'auth not in progress'
exc_reason = 'Authentication not in progress'
elif pkttype > MSG_KEX_LAST and not self._recv_encryption:
skip_reason = 'invalid request before kex complete'
exc_reason = 'Invalid request before key exchange was complete'
elif pkttype > MSG_USERAUTH_LAST and not self._auth_complete:
skip_reason = 'invalid request before auth complete'
exc_reason = 'Invalid request before authentication was complete'
Expand Down Expand Up @@ -1579,6 +1592,9 @@ def _recv_packet(self) -> bool:
if exc_reason:
raise ProtocolError(exc_reason)

if pkttype > MSG_USERAUTH_LAST:
self._auth_final = True

if self._transport:
self._recv_seq = (seq + 1) & 0xffffffff
self._recv_handler = self._recv_pkthdr
Expand All @@ -1596,9 +1612,7 @@ def send_packet(self, pkttype: int, *args: bytes,
self._send_kexinit()
self._kexinit_sent = True

if (((pkttype in {MSG_SERVICE_REQUEST, MSG_SERVICE_ACCEPT} or
pkttype > MSG_KEX_LAST) and not self._kex_complete) or
(pkttype == MSG_USERAUTH_BANNER and
if ((pkttype == MSG_USERAUTH_BANNER and
not (self._auth_in_progress or self._auth_complete)) or
(pkttype > MSG_USERAUTH_LAST and not self._auth_complete)):
self._deferred_packets.append((pkttype, args))
Expand Down Expand Up @@ -1810,9 +1824,11 @@ def send_newkeys(self, k: bytes, h: bytes) -> None:
not self._waiter.cancelled():
self._waiter.set_result(None)
self._wait = None
else:
self.send_service_request(_USERAUTH_SERVICE)
return
else:
self._extensions_to_send[b'server-sig-algs'] = \
b','.join(self._sig_algs)

self._send_encryption = next_enc_sc
self._send_enchdrlen = 1 if etm_sc else 5
self._send_blocksize = max(8, enc_blocksize_sc)
Expand All @@ -1833,17 +1849,18 @@ def send_newkeys(self, k: bytes, h: bytes) -> None:
recv_mac=self._mac_alg_cs.decode('ascii'),
recv_compression=self._cmp_alg_cs.decode('ascii'))

if first_kex:
self._next_service = _USERAUTH_SERVICE

self._extensions_to_send[b'server-sig-algs'] = \
b','.join(self._sig_algs)

if self._can_send_ext_info:
self._send_ext_info()
self._can_send_ext_info = False

self._kex_complete = True

if first_kex:
if self.is_client():
self.send_service_request(_USERAUTH_SERVICE)
else:
self._next_service = _USERAUTH_SERVICE

self._send_deferred_packets()

def send_service_request(self, service: bytes) -> None:
Expand Down Expand Up @@ -2080,18 +2097,25 @@ def _process_service_request(self, _pkttype: int, _pktid: int,
service = packet.get_string()
packet.check_end()

if service == self._next_service:
self.logger.debug2('Accepting request for service %s', service)
if self.is_client():
raise ProtocolError('Unexpected service request received')

self.send_packet(MSG_SERVICE_ACCEPT, String(service))
if not self._recv_encryption:
raise ProtocolError('Service request received before kex complete')

if (self.is_server() and # pragma: no branch
not self._auth_in_progress and
service == _USERAUTH_SERVICE):
self._auth_in_progress = True
self._send_deferred_packets()
else:
raise ServiceNotAvailable('Unexpected service request received')
if service != self._next_service:
raise ServiceNotAvailable('Unexpected service in service request')

self.logger.debug2('Accepting request for service %s', service)

self.send_packet(MSG_SERVICE_ACCEPT, String(service))

self._next_service = None

if service == _USERAUTH_SERVICE: # pragma: no branch
self._auth_in_progress = True
self._can_recv_ext_info = False
self._send_deferred_packets()

def _process_service_accept(self, _pkttype: int, _pktid: int,
packet: SSHPacket) -> None:
Expand All @@ -2100,27 +2124,35 @@ def _process_service_accept(self, _pkttype: int, _pktid: int,
service = packet.get_string()
packet.check_end()

if service == self._next_service:
self.logger.debug2('Request for service %s accepted', service)
if self.is_server():
raise ProtocolError('Unexpected service accept received')

self._next_service = None
if not self._recv_encryption:
raise ProtocolError('Service accept received before kex complete')

if (self.is_client() and # pragma: no branch
service == _USERAUTH_SERVICE):
self.logger.info('Beginning auth for user %s', self._username)
if service != self._next_service:
raise ServiceNotAvailable('Unexpected service in service accept')

self._auth_in_progress = True
self.logger.debug2('Request for service %s accepted', service)

# This method is only in SSHClientConnection
# pylint: disable=no-member
cast('SSHClientConnection', self).try_next_auth()
else:
raise ServiceNotAvailable('Unexpected service accept received')
self._next_service = None

if service == _USERAUTH_SERVICE: # pragma: no branch
self.logger.info('Beginning auth for user %s', self._username)

self._auth_in_progress = True

# This method is only in SSHClientConnection
# pylint: disable=no-member
cast('SSHClientConnection', self).try_next_auth()

def _process_ext_info(self, _pkttype: int, _pktid: int,
packet: SSHPacket) -> None:
"""Process extension information"""

if not self._can_recv_ext_info:
raise ProtocolError('Unexpected ext_info received')

extensions: Dict[bytes, bytes] = {}

self.logger.debug2('Received extension info')
Expand Down Expand Up @@ -2246,6 +2278,7 @@ def _process_newkeys(self, _pkttype: int, _pktid: int,
self._decompress_after_auth = self._next_decompress_after_auth

self._next_recv_encryption = None
self._can_recv_ext_info = True
else:
raise ProtocolError('New keys not negotiated')

Expand Down Expand Up @@ -2273,8 +2306,10 @@ def _process_userauth_request(self, _pkttype: int, _pktid: int,
if self.is_client():
raise ProtocolError('Unexpected userauth request')
elif self._auth_complete:
# Silently ignore requests if we're already authenticated
pass
# Silently ignore additional auth requests after auth succeeds,
# until the client sends a non-auth message
if self._auth_final:
raise ProtocolError('Unexpected userauth request')
else:
if username != self._username:
self.logger.info('Beginning auth for user %s', username)
Expand Down Expand Up @@ -2316,7 +2351,7 @@ async def _finish_userauth(self, begin_auth: bool, method: bytes,
self._auth = lookup_server_auth(cast(SSHServerConnection, self),
self._username, method, packet)

def _process_userauth_failure(self, _pkttype: int, pktid: int,
def _process_userauth_failure(self, _pkttype: int, _pktid: int,
packet: SSHPacket) -> None:
"""Process a user authentication failure response"""

Expand Down Expand Up @@ -2356,10 +2391,9 @@ def _process_userauth_failure(self, _pkttype: int, pktid: int,
# pylint: disable=no-member
cast(SSHClientConnection, self).try_next_auth()
else:
self.logger.debug2('Unexpected userauth failure response')
self.send_packet(MSG_UNIMPLEMENTED, UInt32(pktid))
raise ProtocolError('Unexpected userauth failure response')

def _process_userauth_success(self, _pkttype: int, pktid: int,
def _process_userauth_success(self, _pkttype: int, _pktid: int,
packet: SSHPacket) -> None:
"""Process a user authentication success response"""

Expand All @@ -2385,6 +2419,7 @@ def _process_userauth_success(self, _pkttype: int, pktid: int,
self._auth = None
self._auth_in_progress = False
self._auth_complete = True
self._can_recv_ext_info = False

if self._agent:
self._agent.close()
Expand Down Expand Up @@ -2412,8 +2447,7 @@ def _process_userauth_success(self, _pkttype: int, pktid: int,
self._waiter.set_result(None)
self._wait = None
else:
self.logger.debug2('Unexpected userauth success response')
self.send_packet(MSG_UNIMPLEMENTED, UInt32(pktid))
raise ProtocolError('Unexpected userauth success response')

def _process_userauth_banner(self, _pkttype: int, _pktid: int,
packet: SSHPacket) -> None:
Expand Down

0 comments on commit 83e43f5

Please sign in to comment.