Skip to content

Commit

Permalink
Reset sequence numbers on rekey
Browse files Browse the repository at this point in the history
  • Loading branch information
bitprophet committed Dec 17, 2023
1 parent 75e311d commit fa46de7
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 4 deletions.
6 changes: 6 additions & 0 deletions paramiko/packet.py
Expand Up @@ -130,6 +130,12 @@ def __init__(self, socket):
def closed(self):
return self.__closed

def reset_seqno_out(self):
self.__sequence_number_out = 0

def reset_seqno_in(self):
self.__sequence_number_in = 0

def set_log(self, log):
"""
Set the Python log object to use for logging.
Expand Down
22 changes: 20 additions & 2 deletions paramiko/transport.py
Expand Up @@ -2499,9 +2499,13 @@ def _parse_kex_init(self, m):

# CVE mitigation: expect zeroed-out seqno anytime we are performing kex
# init phase, if strict mode was negotiated.
if self.agreed_on_strict_kex and m.seqno != 0:
if (
self.agreed_on_strict_kex
and not self.initial_kex_done
and m.seqno != 0
):
raise MessageOrderError(
f"Got nonzero seqno ({m.seqno}) during strict KEXINIT!"
"In strict-kex mode, but KEXINIT was not the first packet!"
)

# as a server, we pick the first item in the client's list that we
Expand Down Expand Up @@ -2703,13 +2707,27 @@ def _activate_inbound(self):
):
self._log(DEBUG, "Switching on inbound compression ...")
self.packetizer.set_inbound_compressor(compress_in())
# Reset inbound sequence number if strict mode.
if self.agreed_on_strict_kex:
self._log(
DEBUG,
f"Resetting inbound seqno after NEWKEYS due to strict mode",
)
self.packetizer.reset_seqno_in()

def _activate_outbound(self):
"""switch on newly negotiated encryption parameters for
outbound traffic"""
m = Message()
m.add_byte(cMSG_NEWKEYS)
self._send_message(m)
# Reset outbound sequence number if strict mode.
if self.agreed_on_strict_kex:
self._log(
DEBUG,
f"Resetting outbound sequence number after NEWKEYS due to strict mode",
)
self.packetizer.reset_seqno_out()
block_size = self._cipher_info[self.local_cipher]["block-size"]
if self.server_mode:
IV_out = self._compute_key("B", block_size)
Expand Down
25 changes: 23 additions & 2 deletions tests/test_transport.py
Expand Up @@ -1345,5 +1345,26 @@ def test_MessageOrderError_raised_when_kexinit_not_seq_0_and_strict(self):
):
pass # kexinit happens at connect...

def test_sequence_numbers_reset_on_newkeys(self):
skip()
def test_sequence_numbers_reset_on_newkeys_when_strict(self):
with server(defer=True) as (tc, ts):
# When in strict mode, these should all be zero or close to it
# (post-kexinit, pre-auth).
# Server->client will be 1 (EXT_INFO got sent after NEWKEYS)
assert tc.packetizer._Packetizer__sequence_number_in == 1
assert ts.packetizer._Packetizer__sequence_number_out == 1
# Client->server will be 0
assert tc.packetizer._Packetizer__sequence_number_out == 0
assert ts.packetizer._Packetizer__sequence_number_in == 0

def test_sequence_numbers_not_reset_on_newkeys_when_not_strict(self):
with server(defer=True, client_init=dict(strict_kex=False)) as (
tc,
ts,
):
# When not in strict mode, these will all be ~3-4 or so
# (post-kexinit, pre-auth). Not encoding exact values as it will
# change anytime we mess with the test harness...
assert tc.packetizer._Packetizer__sequence_number_in != 0
assert tc.packetizer._Packetizer__sequence_number_out != 0
assert ts.packetizer._Packetizer__sequence_number_in != 0
assert ts.packetizer._Packetizer__sequence_number_out != 0

0 comments on commit fa46de7

Please sign in to comment.