Skip to content

Commit

Permalink
Merge pull request #293 from tomato42/0-rtt-tolerance
Browse files Browse the repository at this point in the history
0-RTT tolerance
  • Loading branch information
tomato42 authored Jul 19, 2018
2 parents 21b0b4b + 10b8e26 commit e1a037e
Show file tree
Hide file tree
Showing 6 changed files with 334 additions and 47 deletions.
11 changes: 11 additions & 0 deletions tlslite/handshakesettings.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,11 @@ class HandshakeSettings(object):
:vartype psk_modes: list
:ivar psk_modes: acceptable modes for the PSK key exchange in TLS 1.3
:ivar int max_early_data: maximum number of bytes acceptable for 0-RTT
early_data processing. In other words, how many bytes will the server
try to process, but ignore, in case the Client Hello includes
early_data extension.
"""

def _init_key_settings(self):
Expand Down Expand Up @@ -231,6 +236,7 @@ def _init_misc_extensions(self):
self.ticketKeys = []
self.ticketCipher = "aes256gcm"
self.ticketLifetime = 24 * 60 * 60
self.max_early_data = 2 ** 14 + 16 # full record + tag

def __init__(self):
"""Initialise default values for settings."""
Expand Down Expand Up @@ -435,6 +441,10 @@ def _sanityCheckTicketSettings(other):
raise ValueError("Ticket lifetime must be a positive integer "
"smaller or equal 604800 (7 days)")

# while not ticket setting per-se, it is related to session tickets
if not 0 < other.max_early_data <= 2**64:
raise ValueError("max_early_data must be between 0 and 2GiB")

def _copy_cipher_settings(self, other):
"""Copy values related to cipher selection."""
other.cipherNames = self.cipherNames
Expand All @@ -457,6 +467,7 @@ def _copy_extension_settings(self, other):
other.ticketKeys = self.ticketKeys
other.ticketCipher = self.ticketCipher
other.ticketLifetime = self.ticketLifetime
other.max_early_data = self.max_early_data

@staticmethod
def _remove_all_matches(values, needle):
Expand Down
154 changes: 113 additions & 41 deletions tlslite/recordlayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import socket
import errno
import copy
try:
# in python 3 the native zip() returns iterator
from itertools import izip
Expand Down Expand Up @@ -242,6 +243,17 @@ def getSeqNumBytes(self):
self.seqnum += 1
return writer.bytes

def __copy__(self):
"""Return a copy of the object."""
ret = ConnectionState()
ret.macContext = copy.copy(self.macContext)
ret.encContext = copy.copy(self.encContext)
ret.fixedNonce = self.fixedNonce
ret.seqnum = self.seqnum
ret.encryptThenMAC = self.encryptThenMAC
return ret


class RecordLayer(object):

"""
Expand All @@ -253,6 +265,11 @@ class RecordLayer(object):
:ivar handshake_finished: used in SSL2, True if handshake protocol is over
:ivar tls13record: if True, the record layer will use the TLS 1.3 version
and content type hiding
:ivar bool early_data_ok: if True, it's ok to ignore undecryptable records
up to the size of max_early_data (sum of payloads)
:ivar int max_early_data: maximum number of bytes that will be processed
before aborting the connection on data that can not be validated,
works only if early_data_ok is set to True
"""

def __init__(self, sock):
Expand All @@ -273,6 +290,30 @@ def __init__(self, sock):

self.padding_cb = None

self._early_data_ok = False
self.max_early_data = 0
self._early_data_processed = 0

@property
def early_data_ok(self):
"""
Set or get the state of early data acceptability.
If processing of the early_data records is to suceed, even if the
encryption is not correct, set this property to True. It will be
automatically reset to False as soon as a decryptable record is
processed.
Use max_early_data to set the limit of the total size of records
that will be processed like this.
"""
return self._early_data_ok

@early_data_ok.setter
def early_data_ok(self, val):
self._early_data_processed = 0
self._early_data_ok = val

@property
def encryptThenMAC(self):
"""
Expand Down Expand Up @@ -814,49 +855,80 @@ def recvRecord(self):
:raises TLSBadRecordMAC: when record has bad MAC or padding
:raises socket.error: when reading from socket was unsuccessful
"""
result = None
for result in self._recordSocket.recv():
if result in (0, 1):
yield result
else: break
assert result is not None

(header, data) = result

if isinstance(header, RecordHeader2):
data = self._decryptSSL2(data, header.padding)
if self.handshake_finished:
header.type = ContentType.application_data
# in TLS 1.3, the other party may send an unprotected CCS message
# at any point in connection
elif self._is_tls13_plus() and \
header.type == ContentType.change_cipher_spec:
pass
elif self._readState and \
self._readState.encContext and \
self._readState.encContext.isAEAD:
data = self._decryptAndUnseal(header, data)
elif self._readState and self._readState.encryptThenMAC:
data = self._macThenDecrypt(header.type, data)
elif self._readState and \
self._readState.encContext and \
self._readState.encContext.isBlockCipher:
data = self._decryptThenMAC(header.type, data)
else:
data = self._decryptStreamThenMAC(header.type, data)

# TLS 1.3 encrypts the type, CCS is not encrypted
if self._is_tls13_plus() and self._readState and \
self._readState.encContext and\
header.type != ContentType.change_cipher_spec:
data, contentType = self._tls13_de_pad(data)
header = RecordHeader3().create((3, 4), contentType, len(data))
while True:
result = None
for result in self._recordSocket.recv():
if result in (0, 1):
yield result
else: break
assert result is not None

# RFC 5246, section 6.2.1
if len(data) > 2**14:
raise TLSRecordOverflow()
(header, data) = result
# as trying decryption increments sequence number, we need to
# keep the old one (we do copy of the whole object in case
# some cipher has an internal state itself)
read_state_copy = None
if self.early_data_ok:
# do the copy only when needed
read_state_copy = copy.copy(self._readState)

yield (header, Parser(data))
try:
if isinstance(header, RecordHeader2):
data = self._decryptSSL2(data, header.padding)
if self.handshake_finished:
header.type = ContentType.application_data
# in TLS 1.3, the other party may send an unprotected CCS
# message at any point in connection
elif self._is_tls13_plus() and \
header.type == ContentType.change_cipher_spec:
pass
elif self._readState and \
self._readState.encContext and \
self._readState.encContext.isAEAD:
data = self._decryptAndUnseal(header, data)
elif self._readState and self._readState.encryptThenMAC:
data = self._macThenDecrypt(header.type, data)
elif self._readState and \
self._readState.encContext and \
self._readState.encContext.isBlockCipher:
data = self._decryptThenMAC(header.type, data)
else:
data = self._decryptStreamThenMAC(header.type, data)
# if we don't have an encryption context established
# and early data is ok, that means we have received
# encrypted record in case the type of record is
# application_data (from TLS 1.3)
if not self._readState.encContext \
and not self._readState.macContext \
and self.early_data_ok and \
header.type == ContentType.application_data:
raise TLSBadRecordMAC("early data received")
except TLSBadRecordMAC:
if self.early_data_ok and (
self._early_data_processed + len(data)
< self.max_early_data):
# ignore exception, retry reading
self._early_data_processed += len(data)
# reload state for decryption
self._readState = read_state_copy
continue
raise
# as soon as we're able to decrypt messages again, we must
# start checking the MACs
self.early_data_ok = False

# TLS 1.3 encrypts the type, CCS is not encrypted
if self._is_tls13_plus() and self._readState and \
self._readState.encContext and\
header.type != ContentType.change_cipher_spec:
data, contentType = self._tls13_de_pad(data)
header = RecordHeader3().create((3, 4), contentType, len(data))

# RFC 5246, section 6.2.1
if len(data) > 2**14:
raise TLSRecordOverflow()

yield (header, Parser(data))

#
# cryptography state methods
Expand Down
32 changes: 26 additions & 6 deletions tlslite/tlsconnection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2483,8 +2483,8 @@ def _serverGetClientHello(self, settings, cert_chain, verifierDB,
yield result

# sanity check the TLS 1.3 extensions
ext = clientHello.getExtension(ExtensionType.supported_versions)
if ext and TLS_1_3_DRAFT in ext.versions:
ver_ext = clientHello.getExtension(ExtensionType.supported_versions)
if ver_ext and TLS_1_3_DRAFT in ver_ext.versions:
psk = clientHello.getExtension(ExtensionType.pre_shared_key)
if psk:
psk_modes = clientHello.getExtension(
Expand Down Expand Up @@ -2577,12 +2577,28 @@ def _serverGetClientHello(self, settings, cert_chain, verifierDB,
.format(GroupName.toStr(mismatch))):
yield result

versionsExt = clientHello.getExtension(ExtensionType
.supported_versions)
early_data = clientHello.getExtension(ExtensionType.early_data)
if early_data:
if early_data.extData:
for result in self._sendError(
AlertDescription.decode_error,
"malformed early_data extension"):
yield result
if not psk:
for result in self._sendError(
AlertDescription.illegal_parameter,
"early_data without PSK extension"):
yield result
# if early data comes from version we don't support, client
# MUST (section D.3 draft 28) abort the connection so we
# enable early data tolerance only when versions match
self._recordLayer.max_early_data = settings.max_early_data
self._recordLayer.early_data_ok = True

high_ver = None
if versionsExt:
if ver_ext:
high_ver = getFirstMatching(settings.versions,
versionsExt.versions)
ver_ext.versions)
if not high_ver:
for result in self._sendError(
AlertDescription.protocol_version,
Expand Down Expand Up @@ -3034,6 +3050,10 @@ def _serverGetClientHello(self, settings, cert_chain, verifierDB,
AlertDescription.illegal_parameter,
"PSK extension not last in client hello"):
yield result
# early_data extension MUST be dropped
old_ext = clientHello1.getExtension(ExtensionType.early_data)
if old_ext:
clientHello1.extensions.remove(old_ext)

if clientHello1 != clientHello:
for result in self._sendError(AlertDescription
Expand Down
8 changes: 8 additions & 0 deletions tlslite/tlsrecordlayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,6 +871,10 @@ def _getNextRecord(self):
header = RecordHeader3().create(self.version, ret[0], 0)
yield header, Parser(ret[1])

# CCS can be sent before early_data but processing it will
# remove the flag from record layer, so reset it
early_data_ok = self._recordLayer.early_data_ok

# when the message buffer is empty, read next record from socket
for result in self._getNextRecordFromSocket():
if result in (0, 1):
Expand All @@ -885,6 +889,10 @@ def _getNextRecord(self):
if header.type == ContentType.application_data or \
(self.version > (3, 3) and
header.type == ContentType.change_cipher_spec):
# CCS doesn't change the status of undecryptable
# records
if header.type == ContentType.change_cipher_spec:
self._recordLayer.early_data_ok = early_data_ok
yield (header, parser)
# If it's an SSLv2 ClientHello, we can return it as well, since
# it's the only ssl2 type we support
Expand Down
9 changes: 9 additions & 0 deletions unit_tests/test_tlslite_handshakesettings.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,15 @@ def test_invalid_psk_mode(self):

self.assertIn("psk_pqe_ke", str(e.exception))

def test_invalid_max_early_data(self):
hs = HandshakeSettings()
hs.max_early_data = -1

with self.assertRaises(ValueError) as e:
hs.validate()

self.assertIn("max_early_data", str(e.exception))


if __name__ == '__main__':
unittest.main()
Loading

0 comments on commit e1a037e

Please sign in to comment.