Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bpo-33062: Add SSL renegotiation and key update #8620

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
108 changes: 98 additions & 10 deletions Lib/ssl.py
Expand Up @@ -149,6 +149,11 @@
lambda name: name.startswith('CERT_'),
source=_ssl)

_IntEnum._convert_(
'KeyUpdateTypes', __name__,
lambda name: name.startswith('KEY_UPDATE_'),
source=_ssl)

PROTOCOL_SSLv23 = _SSLMethod.PROTOCOL_SSLv23 = _SSLMethod.PROTOCOL_TLS
_PROTOCOL_NAMES = {value: name for name, value in _SSLMethod.__members__.items()}

Expand Down Expand Up @@ -780,6 +785,23 @@ def version(self):
def verify_client_post_handshake(self):
return self._sslobj.verify_client_post_handshake()

def key_update(self, updatetype):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With TLS 1.3, you can also force an immediate rekey. From the documentation, https://www.openssl.org/docs/man1.1.1/man3/SSL_key_update.html

Alternatively SSL_do_handshake() can be called to force the update to take place immediately.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I also found that during writing docs,

Alternatively do_handshake() can be called to force the update to take place immediately.

Do you think it a better approach to wrap OpenSSL API into high level API like this:

    def key_update(self, updatetype, *, deferred=False):
        self._sslobj.key_update(updatetype)
        if not deferred:
            self._sslobj.do_handshake()

Or stay with OpenSSL-style API and document the details? Now I prefer the high-level one, leaving the low-level API with the _ssl module.

self._sslobj.key_update(updatetype)

@property
def key_update_type(self):
return KeyUpdateTypes(self._sslobj.get_key_update_type())

def renegotiate(self, abbreviated=False):
if abbreviated:
self._sslobj.renegotiate_abbreviated()
else:
self._sslobj.renegotiate()

@property
def renegotiate_pending(self):
return self._sslobj.renegotiate_pending()


class SSLSocket(socket):
"""This class implements a subtype of socket.socket that wraps
Expand Down Expand Up @@ -1090,18 +1112,14 @@ def shutdown(self, how):
super().shutdown(how)

def unwrap(self):
if self._sslobj:
s = self._sslobj.shutdown()
self._sslobj = None
return s
else:
raise ValueError("No SSL wrapper around " + str(self))
self._ensure_wrapper()
s = self._sslobj.shutdown()
self._sslobj = None
return s

def verify_client_post_handshake(self):
if self._sslobj:
return self._sslobj.verify_client_post_handshake()
else:
raise ValueError("No SSL wrapper around " + str(self))
self._ensure_wrapper()
return self._sslobj.verify_client_post_handshake()

def _real_close(self):
self._sslobj = None
Expand Down Expand Up @@ -1190,6 +1208,76 @@ def version(self):
else:
return None

def _ensure_wrapper(self):
if not self._sslobj:
raise ValueError("No SSL wrapper around " + str(self))

def key_update(self, updatetype):
self._ensure_wrapper()
self._sslobj.key_update(updatetype)

@property
def key_update_type(self):
self._ensure_wrapper()
return KeyUpdateTypes(self._sslobj.get_key_update_type())

def renegotiate(self, abbreviated=False):
self._ensure_wrapper()
if abbreviated:
self._sslobj.renegotiate_abbreviated()
else:
self._sslobj.renegotiate()

@property
def renegotiate_pending(self):
self._ensure_wrapper()
return self._sslobj.renegotiate_pending()


for name, docstr in (
('key_update', """\
Schedule an update of the keys for the current TLS connection.

If the updatetype parameter is set to KEY_UPDATE_NOT_REQUESTED then the
sending keys for this connection will be updated and the peer will be
informed of the change. If the updatetype parameter is set to
KEY_UPDATE_REQUESTED then the sending keys for this connection will be
updated and the peer will be informed of the change along with a
request for the peer to additionally update its sending keys. It is an
error if updatetype is set to KEY_UPDATE_NONE.

key_update() must only be called after the initial handshake has been
completed and TLSv1.3 has been negotiated. The key update will not take
place until the next time an IO operation such as read() or write()
takes place on the connection. Alternatively do_handshake() can be
called to force the update to take place immediately.

Raises NotImplementedError if the TLS implementation doesn't support
TLS 1.3.)

:param updatetype: KeyUpdateTypes
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In Python we don't use sphinx markup in docstrings, documentation is not authogenerated but written manually.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, thanks for this one! I'll update docs too.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This approach looks strange and uncommon to me. I opened #9972, which I believe is a better approach to handle common doc strings.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes it is! I'll rebase up to #9972 once it is merged. Thanks!

"""),
('key_update_type', """\
Determine whether a key update operation has been scheduled but
not yet performed.

The type of the pending key update operation will be returned if there
is one, or KEY_UPDATE_NONE otherwise.

Raises NotImplementedError if the TLS implementation doesn't support
TLS 1.3.
"""),
('renegotiate', """\
Start the SSL/TLS renegotiation, requires TLS <= 1.2."""),
('renegotiate_pending', """\
Return True if a renegotiation or renegotiation request has been
scheduled but not yet acted on, or False otherwise."""),
):
for cls in (SSLObject, SSLSocket):
getattr(cls, name).__doc__ = docstr

del name, docstr, cls


# Python does not support forward declaration of types.
SSLContext.sslsocket_class = SSLSocket
Expand Down
190 changes: 190 additions & 0 deletions Lib/test/test_ssl.py
Expand Up @@ -1645,6 +1645,39 @@ def test_bad_server_hostname(self):
ctx.wrap_bio(ssl.MemoryBIO(), ssl.MemoryBIO(),
server_hostname="example.org\x00evil.com")

@unittest.skipUnless(ssl.HAS_TLSv1_3,
"test requires TLSv1.3 enabled OpenSSL")
def test_invalid_key_update_type_and_still_in_init(self):
b1 = ssl.MemoryBIO()
b2 = ssl.MemoryBIO()
ctx = ssl.SSLContext()
ctx.load_cert_chain(SIGNED_CERTFILE)
server = ctx.wrap_bio(b1, b2, server_side=True)
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS)
ctx.options |= (
ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1 | ssl.OP_NO_TLSv1_2
)
sslobj = ctx.wrap_bio(b2, b1, server_side=False)

handshaking = True
while handshaking:
try:
sslobj.do_handshake()
handshaking = False
except ssl.SSLWantReadError:
handshaking = True

try:
server.do_handshake()
except ssl.SSLWantReadError:
handshaking = True

with self.assertRaisesRegex(ssl.SSLError, 'invalid key update type'):
sslobj.key_update(ssl.KEY_UPDATE_NONE)
sslobj.key_update(ssl.KEY_UPDATE_REQUESTED)
with self.assertRaisesRegex(ssl.SSLError, 'still in init'):
sslobj.key_update(ssl.KEY_UPDATE_REQUESTED)


class MemoryBIOTests(unittest.TestCase):

Expand Down Expand Up @@ -2076,6 +2109,91 @@ def test_bio_read_write_data(self):
self.assertEqual(buf, b'foo\n')
self.ssl_io_loop(sock, incoming, outgoing, sslobj.unwrap)

def test_bio_renegotiation(self):
sock = socket.socket(socket.AF_INET)
self.addCleanup(sock.close)
sock.connect(self.server_addr)
incoming = ssl.MemoryBIO()
outgoing = ssl.MemoryBIO()
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS)
ctx.verify_mode = ssl.CERT_NONE
ctx.options |= ssl.OP_NO_TLSv1_3
sslobj = ctx.wrap_bio(incoming, outgoing, False)
self.ssl_io_loop(sock, incoming, outgoing, sslobj.do_handshake)

self.assertEqual(outgoing.pending, 0)
sslobj.renegotiate()
self.assertEqual(outgoing.pending, 0)
self.assertTrue(sslobj.renegotiate_pending)
req = b'FOO\n'
self.ssl_io_loop(sock, incoming, outgoing, sslobj.write, req)
self.assertFalse(sslobj.renegotiate_pending)
buf = self.ssl_io_loop(sock, incoming, outgoing, sslobj.read, 1024)
self.assertEqual(buf, b'foo\n')

self.assertEqual(outgoing.pending, 0)
sslobj.renegotiate(abbreviated=True)
self.assertEqual(outgoing.pending, 0)
self.assertTrue(sslobj.renegotiate_pending)
req = b'BAR\n'
self.ssl_io_loop(sock, incoming, outgoing, sslobj.write, req)
self.assertFalse(sslobj.renegotiate_pending)
buf = self.ssl_io_loop(sock, incoming, outgoing, sslobj.read, 1024)
self.assertEqual(buf, b'bar\n')

if IS_OPENSSL_1_1_1 and ssl.HAS_TLSv1_3:
with self.assertRaises(ssl.SSLError,
msg='wrong ssl version'):
sslobj.key_update(ssl.KEY_UPDATE_NOT_REQUESTED)
with self.assertRaises(ssl.SSLError,
msg='wrong ssl version'):
sslobj.key_update(ssl.KEY_UPDATE_REQUESTED)

self.ssl_io_loop(sock, incoming, outgoing, sslobj.unwrap)

@unittest.skipUnless(ssl.HAS_TLSv1_3,
"test requires TLSv1.3 enabled OpenSSL")
def test_bio_key_update(self):
sock = socket.socket(socket.AF_INET)
self.addCleanup(sock.close)
sock.connect(self.server_addr)
incoming = ssl.MemoryBIO()
outgoing = ssl.MemoryBIO()
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS)
ctx.verify_mode = ssl.CERT_NONE
ctx.options |= (
ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1 | ssl.OP_NO_TLSv1_2
)
sslobj = ctx.wrap_bio(incoming, outgoing, False)
self.ssl_io_loop(sock, incoming, outgoing, sslobj.do_handshake)

self.assertEqual(outgoing.pending, 0)
sslobj.key_update(ssl.KEY_UPDATE_REQUESTED)
self.assertEqual(outgoing.pending, 0)
self.assertEqual(sslobj.key_update_type, ssl.KEY_UPDATE_REQUESTED)
req = b'FOO\n'
self.ssl_io_loop(sock, incoming, outgoing, sslobj.write, req)
self.assertEqual(sslobj.key_update_type, ssl.KEY_UPDATE_NONE)
buf = self.ssl_io_loop(sock, incoming, outgoing, sslobj.read, 1024)
self.assertEqual(buf, b'foo\n')

self.assertEqual(outgoing.pending, 0)
sslobj.key_update(ssl.KEY_UPDATE_NOT_REQUESTED)
self.assertEqual(outgoing.pending, 0)
self.assertEqual(sslobj.key_update_type, ssl.KEY_UPDATE_NOT_REQUESTED)
req = b'BAR\n'
self.ssl_io_loop(sock, incoming, outgoing, sslobj.write, req)
self.assertEqual(sslobj.key_update_type, ssl.KEY_UPDATE_NONE)
buf = self.ssl_io_loop(sock, incoming, outgoing, sslobj.read, 1024)
self.assertEqual(buf, b'bar\n')

with self.assertRaises(ssl.SSLError, msg='wrong ssl version'):
sslobj.renegotiate()
with self.assertRaises(ssl.SSLError, msg='wrong ssl version'):
sslobj.renegotiate(abbreviated=True)

self.ssl_io_loop(sock, incoming, outgoing, sslobj.unwrap)


class NetworkedTests(unittest.TestCase):

Expand Down Expand Up @@ -4164,6 +4282,78 @@ def test_session_handling(self):
self.assertEqual(str(e.exception),
'Session refers to a different SSLContext.')

def test_renegotiation(self):
context = ssl.SSLContext(ssl.PROTOCOL_TLS)
context.load_cert_chain(CERTFILE)
context.options |= ssl.OP_NO_TLSv1_3
with ThreadedEchoServer(context=context) as server:
with context.wrap_socket(socket.socket()) as s:
s.connect((HOST, server.port))
self.assertFalse(s.renegotiate_pending)
s.renegotiate()
self.assertTrue(s.renegotiate_pending)
s.send(b'HELLO')
self.assertEqual(s.recv(1024), b'hello')
self.assertFalse(s.renegotiate_pending)
s.send(b'WORLD')
self.assertEqual(s.recv(1024), b'world')

self.assertFalse(s.renegotiate_pending)
s.renegotiate(abbreviated=True)
self.assertTrue(s.renegotiate_pending)
s.send(b'HELLO')
self.assertEqual(s.recv(1024), b'hello')
self.assertFalse(s.renegotiate_pending)
s.send(b'WORLD')
self.assertEqual(s.recv(1024), b'world')

if IS_OPENSSL_1_1_1 and ssl.HAS_TLSv1_3:
with self.assertRaises(ssl.SSLError,
msg='wrong ssl version'):
s.key_update(ssl.KEY_UPDATE_NOT_REQUESTED)
with self.assertRaises(ssl.SSLError,
msg='wrong ssl version'):
s.key_update(ssl.KEY_UPDATE_REQUESTED)

@unittest.skipUnless(ssl.HAS_TLSv1_3,
"test requires TLSv1.3 enabled OpenSSL")
def test_key_update(self):
context = ssl.SSLContext(ssl.PROTOCOL_TLS)
context.load_cert_chain(CERTFILE)
context.options |= (
ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1 | ssl.OP_NO_TLSv1_2
)
with ThreadedEchoServer(context=context) as server:
with context.wrap_socket(socket.socket()) as s:
s.connect((HOST, server.port))
self.assertEqual(s.version(), 'TLSv1.3')

self.assertEqual(s.key_update_type, ssl.KEY_UPDATE_NONE)
s.key_update(ssl.KEY_UPDATE_NOT_REQUESTED)
self.assertEqual(s.key_update_type,
ssl.KEY_UPDATE_NOT_REQUESTED)
s.send(b'HELLO')
self.assertEqual(s.key_update_type, ssl.KEY_UPDATE_NONE)
self.assertEqual(s.recv(1024), b'hello')
self.assertEqual(s.key_update_type, ssl.KEY_UPDATE_NONE)
s.send(b'WORLD')
self.assertEqual(s.recv(1024), b'world')

self.assertEqual(s.key_update_type, ssl.KEY_UPDATE_NONE)
s.key_update(ssl.KEY_UPDATE_REQUESTED)
self.assertEqual(s.key_update_type, ssl.KEY_UPDATE_REQUESTED)
s.send(b'HELLO')
self.assertEqual(s.key_update_type, ssl.KEY_UPDATE_NONE)
self.assertEqual(s.recv(1024), b'hello')
self.assertEqual(s.key_update_type, ssl.KEY_UPDATE_NONE)
s.send(b'WORLD')
self.assertEqual(s.recv(1024), b'world')

with self.assertRaises(ssl.SSLError, msg='wrong ssl version'):
s.renegotiate()
with self.assertRaises(ssl.SSLError, msg='wrong ssl version'):
s.renegotiate(abbreviated=True)


@unittest.skipUnless(ssl.HAS_TLSv1_3, "Test needs TLS 1.3")
class TestPostHandshakeAuth(unittest.TestCase):
Expand Down
@@ -0,0 +1 @@
Added ``renegotiate()`` and ``key_update()`` in :mod:`ssl`.