From 84a121ee4b1cbc2daea1f50f57eb2cef7d0b8413 Mon Sep 17 00:00:00 2001 From: Cory Benfield Date: Mon, 31 Mar 2014 20:30:25 +0100 Subject: [PATCH 1/5] Add support for Next Protocol Negotiation. --- OpenSSL/SSL.py | 75 ++++++++++++++++++++++++++++++++++++++ OpenSSL/test/test_ssl.py | 78 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 153 insertions(+) diff --git a/OpenSSL/SSL.py b/OpenSSL/SSL.py index 7b1cbc1b4..e754a7e5d 100644 --- a/OpenSSL/SSL.py +++ b/OpenSSL/SSL.py @@ -293,6 +293,10 @@ def __init__(self, method): self._info_callback = None self._tlsext_servername_callback = None self._app_data = None + self._npn_advertise_callback = None + self._npn_advertise_callback_args = None + self._npn_select_callback = None + self._npn_select_callback_args = None # SSL_CTX_set_app_data(self->ctx, self); # SSL_CTX_set_mode(self->ctx, SSL_MODE_ENABLE_PARTIAL_WRITE | @@ -809,6 +813,64 @@ def wrapper(ssl, alert, arg): _lib.SSL_CTX_set_tlsext_servername_callback( self._context, self._tlsext_servername_callback) + + def set_npn_advertise_callback(self, callback): + """ + Specify a callback function that will be called when offering Next + Protocol Negotiation. + + :param callback: The callback function. It will be invoked with one + argument, the Connection instance. It should return a Python + bytestring, like b'\\x08http/1.1\\x06spdy/2'. + """ + @wraps(callback) + def wrapper(ssl, out, outlen, arg): + outstr = callback(Connection._reverse_mapping[ssl]) + self._npn_advertise_callback_args = [ + _ffi.new("unsigned int *", len(outstr)), + _ffi.new("unsigned char[]", outstr), + ] + outlen[0] = self._npn_advertise_callback_args[0][0] + out[0] = self._npn_advertise_callback_args[1] + return 0 + + self._npn_advertise_callback = _ffi.callback( + "int (*)(SSL *, const unsigned char **, unsigned int *, void *)", + wrapper) + _lib.SSL_CTX_set_next_protos_advertised_cb( + self._context, self._npn_advertise_callback, _ffi.NULL) + + + def set_npn_select_callback(self, callback): + """ + Specify a callback function that will be called when a server offers + Next Protocol Negotiation options. + + :param callback: The callback function. It will be invoked with two + arguments: the Connection, and a list of offered protocols as + length-prefixed strings in a bytestring, e.g. + b'\\x08http/1.1\\x06spdy/2'. It should return one of those + bytestrings, the chosen protocol. + """ + @wraps(callback) + def wrapper(ssl, out, outlen, in_, inlen, arg): + outstr = callback( + Connection._reverse_mapping[ssl], _ffi.string(in_)) + self._npn_select_callback_args = [ + _ffi.new("unsigned char *", len(outstr)), + _ffi.new("unsigned char[]", outstr), + ] + outlen[0] = self._npn_select_callback_args[0][0] + out[0] = self._npn_select_callback_args[1] + return 0 + + self._npn_select_callback = _ffi.callback( + "int (*)(SSL *, unsigned char **, unsigned char *, " + "const unsigned char *, unsigned int, void *)", + wrapper) + _lib.SSL_CTX_set_next_proto_select_cb( + self._context, self._npn_select_callback, _ffi.NULL) + ContextType = Context @@ -1550,6 +1612,19 @@ def get_cipher_version(self): version =_ffi.string(_lib.SSL_CIPHER_get_version(cipher)) return version.decode("utf-8") + def get_next_proto_negotiated(self): + """ + Get the protocol that was negotiated by NPN. + """ + data = _ffi.new("unsigned char **") + data_len = _ffi.new("unsigned int *") + + _lib.SSL_get0_next_proto_negotiated(self._ssl, data, data_len) + + if not data_len[0]: + return "" + else: + return _ffi.string(data[0]) ConnectionType = Connection diff --git a/OpenSSL/test/test_ssl.py b/OpenSSL/test/test_ssl.py index 6409b8ee1..404f8b9c7 100644 --- a/OpenSSL/test/test_ssl.py +++ b/OpenSSL/test/test_ssl.py @@ -1434,6 +1434,84 @@ def servername(conn): self.assertEqual([(server, b("foo1.example.com"))], args) +class NextProtoNegotiationTests(TestCase, _LoopbackMixin): + """ + Test for Next Protocol Negotiation in PyOpenSSL. + """ + def test_npn_success(self): + advertise_args =[] + select_args = [] + def advertise(conn): + advertise_args.append((conn,)) + return b('\x08http/1.1\x06spdy/2') + def select(conn, options): + select_args.append((conn, options)) + return b('spdy/2') + + server_context = Context(TLSv1_METHOD) + server_context.set_npn_advertise_callback(advertise) + + client_context = Context(TLSv1_METHOD) + client_context.set_npn_select_callback(select) + + # Necessary to actually accept the connection + server_context.use_privatekey( + load_privatekey(FILETYPE_PEM, server_key_pem)) + server_context.use_certificate( + load_certificate(FILETYPE_PEM, server_cert_pem)) + + # Do a little connection to trigger the logic + server = Connection(server_context, None) + server.set_accept_state() + + client = Connection(client_context, None) + client.set_connect_state() + + self._interactInMemory(server, client) + + self.assertEqual([(server,)], advertise_args) + self.assertEqual([(client, b('\x08http/1.1\x06spdy/2'))], select_args) + + self.assertEqual(server.get_next_proto_negotiated(), b('spdy/2')) + self.assertEqual(client.get_next_proto_negotiated(), b('spdy/2')) + + + def test_npn_client_fail(self): + advertise_args =[] + select_args = [] + def advertise(conn): + advertise_args.append((conn,)) + return b('\x08http/1.1\x06spdy/2') + def select(conn, options): + select_args.append((conn, options)) + return b('') + + server_context = Context(TLSv1_METHOD) + server_context.set_npn_advertise_callback(advertise) + + client_context = Context(TLSv1_METHOD) + client_context.set_npn_select_callback(select) + + # Necessary to actually accept the connection + server_context.use_privatekey( + load_privatekey(FILETYPE_PEM, server_key_pem)) + server_context.use_certificate( + load_certificate(FILETYPE_PEM, server_cert_pem)) + + # Do a little connection to trigger the logic + server = Connection(server_context, None) + server.set_accept_state() + + client = Connection(client_context, None) + client.set_connect_state() + + # If the client doesn't return anything, the connection will fail. + self.assertRaises(Error, self._interactInMemory, server, client) + + self.assertEqual([(server,)], advertise_args) + self.assertEqual([(client, b('\x08http/1.1\x06spdy/2'))], select_args) + + class SessionTests(TestCase): """ From be3e7b81e278cf0668c0f250a628926533969640 Mon Sep 17 00:00:00 2001 From: Cory Benfield Date: Sat, 10 May 2014 09:48:55 +0100 Subject: [PATCH 2/5] Make NPN markups. --- OpenSSL/SSL.py | 81 +++++++++++++++++++++++++++++----------- OpenSSL/test/test_ssl.py | 29 +++++++++----- doc/api/ssl.rst | 34 +++++++++++++++++ 3 files changed, 112 insertions(+), 32 deletions(-) diff --git a/OpenSSL/SSL.py b/OpenSSL/SSL.py index e754a7e5d..bd6876743 100644 --- a/OpenSSL/SSL.py +++ b/OpenSSL/SSL.py @@ -1,11 +1,12 @@ from sys import platform from functools import wraps, partial -from itertools import count +from itertools import count, chain from weakref import WeakValueDictionary from errno import errorcode from six import text_type as _text_type from six import integer_types as integer_types +from six import int2byte, byte2int from OpenSSL._util import ( ffi as _ffi, @@ -294,9 +295,7 @@ def __init__(self, method): self._tlsext_servername_callback = None self._app_data = None self._npn_advertise_callback = None - self._npn_advertise_callback_args = None self._npn_select_callback = None - self._npn_select_callback_args = None # SSL_CTX_set_app_data(self->ctx, self); # SSL_CTX_set_mode(self->ctx, SSL_MODE_ENABLE_PARTIAL_WRITE | @@ -816,22 +815,35 @@ def wrapper(ssl, alert, arg): def set_npn_advertise_callback(self, callback): """ - Specify a callback function that will be called when offering Next - Protocol Negotiation. + Specify a callback function that will be called when offering `Next + Protocol Negotiation + `_ as a server. :param callback: The callback function. It will be invoked with one - argument, the Connection instance. It should return a Python - bytestring, like b'\\x08http/1.1\\x06spdy/2'. + argument, the Connection instance. It should return a list of + bytestrings representing the advertised protocols, like + ``[b'http/1.1', b'spdy/2']``. """ @wraps(callback) def wrapper(ssl, out, outlen, arg): - outstr = callback(Connection._reverse_mapping[ssl]) - self._npn_advertise_callback_args = [ - _ffi.new("unsigned int *", len(outstr)), - _ffi.new("unsigned char[]", outstr), + conn = Connection._reverse_mapping[ssl] + protos = callback(conn) + + # Join the protocols into a Python bytestring, length-prefixing + # each element. + protostr = b''.join( + chain.from_iterable((int2byte(len(p)), p) for p in protos) + ) + + # Save our callback arguments on the connection object. This is + # done to make sure that they don't get freed before OpenSSL uses + # them. Then, return them appropriately in the output parameters. + conn._npn_advertise_callback_args = [ + _ffi.new("unsigned int *", len(protostr)), + _ffi.new("unsigned char[]", protostr), ] - outlen[0] = self._npn_advertise_callback_args[0][0] - out[0] = self._npn_advertise_callback_args[1] + outlen[0] = conn._npn_advertise_callback_args[0][0] + out[0] = conn._npn_advertise_callback_args[1] return 0 self._npn_advertise_callback = _ffi.callback( @@ -848,20 +860,38 @@ def set_npn_select_callback(self, callback): :param callback: The callback function. It will be invoked with two arguments: the Connection, and a list of offered protocols as - length-prefixed strings in a bytestring, e.g. - b'\\x08http/1.1\\x06spdy/2'. It should return one of those - bytestrings, the chosen protocol. + bytestrings, e.g. ``[b'http/1.1', b'spdy/2']``. It should return + one of those bytestrings, the chosen protocol. """ @wraps(callback) def wrapper(ssl, out, outlen, in_, inlen, arg): - outstr = callback( - Connection._reverse_mapping[ssl], _ffi.string(in_)) - self._npn_select_callback_args = [ + conn = Connection._reverse_mapping[ssl] + + # The string passed to us is actually made up of multiple + # length-prefixed bytestrings. We need to split that into a list. + instr = _ffi.buffer(in_, inlen) + protolist = [] + while instr: + # This slightly insane syntax is to make sure we get a + # bytestring: on Python 3, instr[0] would return an int and + # this call would fail. + l = byte2int(instr[0:1]) + proto = instr[1:l+1] + protolist.append(proto) + instr = instr[l+1:] + + # Call the callback + outstr = callback(conn, protolist) + + # Save our callback arguments on the connection object. This is + # done to make sure that they don't get freed before OpenSSL uses + # them. Then, return them appropriately in the output parameters. + conn._npn_select_callback_args = [ _ffi.new("unsigned char *", len(outstr)), _ffi.new("unsigned char[]", outstr), ] - outlen[0] = self._npn_select_callback_args[0][0] - out[0] = self._npn_select_callback_args[1] + outlen[0] = conn._npn_select_callback_args[0][0] + out[0] = conn._npn_select_callback_args[1] return 0 self._npn_select_callback = _ffi.callback( @@ -895,6 +925,13 @@ def __init__(self, context, socket=None): self._ssl = _ffi.gc(ssl, _lib.SSL_free) self._context = context + # References to strings used for Next Protocol Negotiation. OpenSSL's + # header files suggest that these might get copied at some point, but + # doesn't specify when, so we store them here to make sure they don't + # get freed before OpenSSL uses them. + self._npn_advertise_callback_args = None + self._npn_select_callback_args = None + self._reverse_mapping[self._ssl] = self if socket is None: @@ -1622,7 +1659,7 @@ def get_next_proto_negotiated(self): _lib.SSL_get0_next_proto_negotiated(self._ssl, data, data_len) if not data_len[0]: - return "" + return b"" else: return _ffi.string(data[0]) diff --git a/OpenSSL/test/test_ssl.py b/OpenSSL/test/test_ssl.py index 404f8b9c7..9fc846646 100644 --- a/OpenSSL/test/test_ssl.py +++ b/OpenSSL/test/test_ssl.py @@ -1439,14 +1439,19 @@ class NextProtoNegotiationTests(TestCase, _LoopbackMixin): Test for Next Protocol Negotiation in PyOpenSSL. """ def test_npn_success(self): - advertise_args =[] + """ + Tests that clients and servers that agree on the negotiated next + protocol can correct establish a connection, and that the agreed + protocol is reported by the connections. + """ + advertise_args = [] select_args = [] def advertise(conn): advertise_args.append((conn,)) - return b('\x08http/1.1\x06spdy/2') + return [b'http/1.1', b'spdy/2'] def select(conn, options): select_args.append((conn, options)) - return b('spdy/2') + return b'spdy/2' server_context = Context(TLSv1_METHOD) server_context.set_npn_advertise_callback(advertise) @@ -1470,21 +1475,25 @@ def select(conn, options): self._interactInMemory(server, client) self.assertEqual([(server,)], advertise_args) - self.assertEqual([(client, b('\x08http/1.1\x06spdy/2'))], select_args) + self.assertEqual([(client, [b'http/1.1', b'spdy/2'])], select_args) - self.assertEqual(server.get_next_proto_negotiated(), b('spdy/2')) - self.assertEqual(client.get_next_proto_negotiated(), b('spdy/2')) + self.assertEqual(server.get_next_proto_negotiated(), b'spdy/2') + self.assertEqual(client.get_next_proto_negotiated(), b'spdy/2') def test_npn_client_fail(self): - advertise_args =[] + """ + Tests that when clients and servers cannot agree on what protocol to + use next that the TLS connection does not get established. + """ + advertise_args = [] select_args = [] def advertise(conn): advertise_args.append((conn,)) - return b('\x08http/1.1\x06spdy/2') + return [b'http/1.1', b'spdy/2'] def select(conn, options): select_args.append((conn, options)) - return b('') + return b'' server_context = Context(TLSv1_METHOD) server_context.set_npn_advertise_callback(advertise) @@ -1509,7 +1518,7 @@ def select(conn, options): self.assertRaises(Error, self._interactInMemory, server, client) self.assertEqual([(server,)], advertise_args) - self.assertEqual([(client, b('\x08http/1.1\x06spdy/2'))], select_args) + self.assertEqual([(client, [b'http/1.1', b'spdy/2'])], select_args) diff --git a/doc/api/ssl.rst b/doc/api/ssl.rst index a75af1f7d..fbee1fe89 100644 --- a/doc/api/ssl.rst +++ b/doc/api/ssl.rst @@ -472,6 +472,33 @@ Context objects have the following methods: .. versionadded:: 0.13 +.. py:method:: Context.set_npn_advertise_callback(callback) + + Specify a callback function that will be called when offering `Next + Protocol Negotiation + `_ as a server. + + *callback* should be the callback function. It will be invoked with one + argument, the :py:class:`Connection` instance. It should return a list of + bytestrings representing the advertised protocols, like + ``[b'http/1.1', b'spdy/2']``. + + .. versionadded:: 0.15 + + +.. py:method:: Context.set_npn_select_callback(callback): + + Specify a callback function that will be called when a server offers Next + Protocol Negotiation options. + + *callback* should be the callback function. It will be invoked with two + arguments: the :py:class:`Connection`, and a list of offered protocols as + bytestrings, e.g. ``[b'http/1.1', b'spdy/2']``. It should return one of + those bytestrings, the chosen protocol. + + .. versionadded:: 0.15 + + .. _openssl-session: Session objects @@ -806,6 +833,13 @@ Connection objects have the following methods: .. versionadded:: 0.15 +.. py:method:: Connection.get_next_proto_negotiated(): + + Get the protocol that was negotiated by Next Protocol Negotiation. + + .. versionadded:: 0.15 + + .. Rubric:: Footnotes .. [#connection-context-socket] Actually, all that is required is an object that From cd010f60e56683a5f093979abbe7e57ac766c21f Mon Sep 17 00:00:00 2001 From: Cory Benfield Date: Thu, 15 May 2014 19:00:27 +0100 Subject: [PATCH 3/5] Implement @alex's code review. --- OpenSSL/SSL.py | 12 +++--------- doc/api/ssl.rst | 4 +++- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/OpenSSL/SSL.py b/OpenSSL/SSL.py index bd6876743..48e379055 100644 --- a/OpenSSL/SSL.py +++ b/OpenSSL/SSL.py @@ -6,7 +6,7 @@ from six import text_type as _text_type from six import integer_types as integer_types -from six import int2byte, byte2int +from six import int2byte, indexbytes from OpenSSL._util import ( ffi as _ffi, @@ -872,10 +872,7 @@ def wrapper(ssl, out, outlen, in_, inlen, arg): instr = _ffi.buffer(in_, inlen) protolist = [] while instr: - # This slightly insane syntax is to make sure we get a - # bytestring: on Python 3, instr[0] would return an int and - # this call would fail. - l = byte2int(instr[0:1]) + l = indexbytes(instr, 0) proto = instr[1:l+1] protolist.append(proto) instr = instr[l+1:] @@ -1658,10 +1655,7 @@ def get_next_proto_negotiated(self): _lib.SSL_get0_next_proto_negotiated(self._ssl, data, data_len) - if not data_len[0]: - return b"" - else: - return _ffi.string(data[0]) + return _ffi.buffer(data[0], data_len[0])[:] ConnectionType = Connection diff --git a/doc/api/ssl.rst b/doc/api/ssl.rst index fbee1fe89..4b57ac5e0 100644 --- a/doc/api/ssl.rst +++ b/doc/api/ssl.rst @@ -835,7 +835,9 @@ Connection objects have the following methods: .. py:method:: Connection.get_next_proto_negotiated(): - Get the protocol that was negotiated by Next Protocol Negotiation. + Get the protocol that was negotiated by Next Protocol Negotiation. Returns + a bytestring of the protocol name. If no protocol has been negotiated yet, + returns an empty string. .. versionadded:: 0.15 From 4969c22cf65f91f5bc87bf25ce45875dae32f576 Mon Sep 17 00:00:00 2001 From: Cory Benfield Date: Tue, 27 May 2014 14:19:34 +0100 Subject: [PATCH 4/5] Copy buffer into bytestring. --- OpenSSL/SSL.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/OpenSSL/SSL.py b/OpenSSL/SSL.py index 48e379055..e97df8be6 100644 --- a/OpenSSL/SSL.py +++ b/OpenSSL/SSL.py @@ -869,7 +869,7 @@ def wrapper(ssl, out, outlen, in_, inlen, arg): # The string passed to us is actually made up of multiple # length-prefixed bytestrings. We need to split that into a list. - instr = _ffi.buffer(in_, inlen) + instr = _ffi.buffer(in_, inlen)[:] protolist = [] while instr: l = indexbytes(instr, 0) From 0ea76e7d977b19f2bb4ce4dee9bee8aa179eaff0 Mon Sep 17 00:00:00 2001 From: Cory Benfield Date: Sun, 22 Mar 2015 09:05:28 +0000 Subject: [PATCH 5/5] Handle exceptions in NPN callbacks. --- OpenSSL/SSL.py | 169 ++++++++++++++++++++++++--------------- OpenSSL/test/test_ssl.py | 72 +++++++++++++++++ 2 files changed, 177 insertions(+), 64 deletions(-) diff --git a/OpenSSL/SSL.py b/OpenSSL/SSL.py index e97df8be6..5af496905 100644 --- a/OpenSSL/SSL.py +++ b/OpenSSL/SSL.py @@ -165,8 +165,24 @@ class SysCallError(Error): pass +class _CallbackExceptionHelper(object): + """ + A base class for wrapper classes that allow for intelligent exception + handling in OpenSSL callbacks. + """ + def __init__(self, callback): + pass + + def raise_if_problem(self): + if self._problems: + try: + _raise_current_error() + except Error: + pass + raise self._problems.pop(0) + -class _VerifyHelper(object): +class _VerifyHelper(_CallbackExceptionHelper): def __init__(self, callback): self._problems = [] @@ -197,15 +213,87 @@ def wrapper(ok, store_ctx): "int (*)(int, X509_STORE_CTX *)", wrapper) - def raise_if_problem(self): - if self._problems: +class _NpnAdvertiseHelper(_CallbackExceptionHelper): + def __init__(self, callback): + self._problems = [] + + @wraps(callback) + def wrapper(ssl, out, outlen, arg): try: - _raise_current_error() - except Error: - pass - raise self._problems.pop(0) + conn = Connection._reverse_mapping[ssl] + protos = callback(conn) + + # Join the protocols into a Python bytestring, length-prefixing + # each element. + protostr = b''.join( + chain.from_iterable((int2byte(len(p)), p) for p in protos) + ) + + # Save our callback arguments on the connection object. This is + # done to make sure that they don't get freed before OpenSSL + # uses them. Then, return them appropriately in the output + # parameters. + conn._npn_advertise_callback_args = [ + _ffi.new("unsigned int *", len(protostr)), + _ffi.new("unsigned char[]", protostr), + ] + outlen[0] = conn._npn_advertise_callback_args[0][0] + out[0] = conn._npn_advertise_callback_args[1] + return 0 + except Exception as e: + self._problems.append(e) + return 2 # SSL_TLSEXT_ERR_ALERT_FATAL + + self.callback = _ffi.callback( + "int (*)(SSL *, const unsigned char **, unsigned int *, void *)", + wrapper + ) +class _NpnSelectHelper(_CallbackExceptionHelper): + def __init__(self, callback): + self._problems = [] + + @wraps(callback) + def wrapper(ssl, out, outlen, in_, inlen, arg): + try: + conn = Connection._reverse_mapping[ssl] + + # The string passed to us is actually made up of multiple + # length-prefixed bytestrings. We need to split that into a + # list. + instr = _ffi.buffer(in_, inlen)[:] + protolist = [] + while instr: + l = indexbytes(instr, 0) + proto = instr[1:l+1] + protolist.append(proto) + instr = instr[l+1:] + + # Call the callback + outstr = callback(conn, protolist) + + # Save our callback arguments on the connection object. This is + # done to make sure that they don't get freed before OpenSSL + # uses them. Then, return them appropriately in the output + # parameters. + conn._npn_select_callback_args = [ + _ffi.new("unsigned char *", len(outstr)), + _ffi.new("unsigned char[]", outstr), + ] + outlen[0] = conn._npn_select_callback_args[0][0] + out[0] = conn._npn_select_callback_args[1] + return 0 + except Exception as e: + self._problems.append(e) + return 2 # SSL_TLSEXT_ERR_ALERT_FATAL + + self.callback = _ffi.callback( + "int (*)(SSL *, unsigned char **, unsigned char *, " + "const unsigned char *, unsigned int, void *)", + wrapper + ) + def _asFileDescriptor(obj): fd = None @@ -294,7 +382,9 @@ def __init__(self, method): self._info_callback = None self._tlsext_servername_callback = None self._app_data = None + self._npn_advertise_helper = None self._npn_advertise_callback = None + self._npn_select_helper = None self._npn_select_callback = None # SSL_CTX_set_app_data(self->ctx, self); @@ -824,31 +914,8 @@ def set_npn_advertise_callback(self, callback): bytestrings representing the advertised protocols, like ``[b'http/1.1', b'spdy/2']``. """ - @wraps(callback) - def wrapper(ssl, out, outlen, arg): - conn = Connection._reverse_mapping[ssl] - protos = callback(conn) - - # Join the protocols into a Python bytestring, length-prefixing - # each element. - protostr = b''.join( - chain.from_iterable((int2byte(len(p)), p) for p in protos) - ) - - # Save our callback arguments on the connection object. This is - # done to make sure that they don't get freed before OpenSSL uses - # them. Then, return them appropriately in the output parameters. - conn._npn_advertise_callback_args = [ - _ffi.new("unsigned int *", len(protostr)), - _ffi.new("unsigned char[]", protostr), - ] - outlen[0] = conn._npn_advertise_callback_args[0][0] - out[0] = conn._npn_advertise_callback_args[1] - return 0 - - self._npn_advertise_callback = _ffi.callback( - "int (*)(SSL *, const unsigned char **, unsigned int *, void *)", - wrapper) + self._npn_advertise_helper = _NpnAdvertiseHelper(callback) + self._npn_advertise_callback = self._npn_advertise_helper.callback _lib.SSL_CTX_set_next_protos_advertised_cb( self._context, self._npn_advertise_callback, _ffi.NULL) @@ -863,38 +930,8 @@ def set_npn_select_callback(self, callback): bytestrings, e.g. ``[b'http/1.1', b'spdy/2']``. It should return one of those bytestrings, the chosen protocol. """ - @wraps(callback) - def wrapper(ssl, out, outlen, in_, inlen, arg): - conn = Connection._reverse_mapping[ssl] - - # The string passed to us is actually made up of multiple - # length-prefixed bytestrings. We need to split that into a list. - instr = _ffi.buffer(in_, inlen)[:] - protolist = [] - while instr: - l = indexbytes(instr, 0) - proto = instr[1:l+1] - protolist.append(proto) - instr = instr[l+1:] - - # Call the callback - outstr = callback(conn, protolist) - - # Save our callback arguments on the connection object. This is - # done to make sure that they don't get freed before OpenSSL uses - # them. Then, return them appropriately in the output parameters. - conn._npn_select_callback_args = [ - _ffi.new("unsigned char *", len(outstr)), - _ffi.new("unsigned char[]", outstr), - ] - outlen[0] = conn._npn_select_callback_args[0][0] - out[0] = conn._npn_select_callback_args[1] - return 0 - - self._npn_select_callback = _ffi.callback( - "int (*)(SSL *, unsigned char **, unsigned char *, " - "const unsigned char *, unsigned int, void *)", - wrapper) + self._npn_select_helper = _NpnSelectHelper(callback) + self._npn_select_callback = self._npn_select_helper.callback _lib.SSL_CTX_set_next_proto_select_cb( self._context, self._npn_select_callback, _ffi.NULL) @@ -963,6 +1000,10 @@ def __getattr__(self, name): def _raise_ssl_error(self, ssl, result): if self._context._verify_helper is not None: self._context._verify_helper.raise_if_problem() + if self._context._npn_advertise_helper is not None: + self._context._npn_advertise_helper.raise_if_problem() + if self._context._npn_select_helper is not None: + self._context._npn_select_helper.raise_if_problem() error = _lib.SSL_get_error(ssl, result) if error == _lib.SSL_ERROR_WANT_READ: diff --git a/OpenSSL/test/test_ssl.py b/OpenSSL/test/test_ssl.py index 9fc846646..a45edc26b 100644 --- a/OpenSSL/test/test_ssl.py +++ b/OpenSSL/test/test_ssl.py @@ -1521,6 +1521,78 @@ def select(conn, options): self.assertEqual([(client, [b'http/1.1', b'spdy/2'])], select_args) + def test_npn_select_error(self): + """ + Test that we can handle exceptions in the select callback. If select + fails it should be fatal to the connection. + """ + advertise_args = [] + def advertise(conn): + advertise_args.append((conn,)) + return [b'http/1.1', b'spdy/2'] + def select(conn, options): + raise TypeError + + server_context = Context(TLSv1_METHOD) + server_context.set_npn_advertise_callback(advertise) + + client_context = Context(TLSv1_METHOD) + client_context.set_npn_select_callback(select) + + # Necessary to actually accept the connection + server_context.use_privatekey( + load_privatekey(FILETYPE_PEM, server_key_pem)) + server_context.use_certificate( + load_certificate(FILETYPE_PEM, server_cert_pem)) + + # Do a little connection to trigger the logic + server = Connection(server_context, None) + server.set_accept_state() + + client = Connection(client_context, None) + client.set_connect_state() + + # If the callback throws an exception it should be raised here. + self.assertRaises(TypeError, self._interactInMemory, server, client) + self.assertEqual([(server,)], advertise_args) + + + def test_npn_advertise_error(self): + """ + Test that we can handle exceptions in the advertise callback. If + advertise fails no NPN is advertised to the client. + """ + select_args = [] + def advertise(conn): + raise TypeError + def select(conn, options): + select_args.append((conn, options)) + return b'' + + server_context = Context(TLSv1_METHOD) + server_context.set_npn_advertise_callback(advertise) + + client_context = Context(TLSv1_METHOD) + client_context.set_npn_select_callback(select) + + # Necessary to actually accept the connection + server_context.use_privatekey( + load_privatekey(FILETYPE_PEM, server_key_pem)) + server_context.use_certificate( + load_certificate(FILETYPE_PEM, server_cert_pem)) + + # Do a little connection to trigger the logic + server = Connection(server_context, None) + server.set_accept_state() + + client = Connection(client_context, None) + client.set_connect_state() + + # If the client doesn't return anything, the connection will fail. + self.assertRaises(TypeError, self._interactInMemory, server, client) + self.assertEqual([], select_args) + + class SessionTests(TestCase): """