diff --git a/OpenSSL/SSL.py b/OpenSSL/SSL.py index 39a2bcab1..bba59ac43 100644 --- a/OpenSSL/SSL.py +++ b/OpenSSL/SSL.py @@ -399,6 +399,21 @@ def SSLeay_version(type): return _ffi.string(_lib.SSLeay_version(type)) +def _requires_npn(func): + """ + Wraps any function that requires NPN support in OpenSSL, ensuring that + NotImplementedError is raised if NPN is not present. + """ + @wraps(func) + def wrapper(*args, **kwargs): + if not _lib.Cryptography_HAS_NEXTPROTONEG: + raise NotImplementedError("NPN not available.") + + return func(*args, **kwargs) + + return wrapper + + def _requires_alpn(func): """ @@ -993,6 +1008,8 @@ def wrapper(ssl, alert, arg): _lib.SSL_CTX_set_tlsext_servername_callback( self._context, self._tlsext_servername_callback) + + @_requires_npn def set_npn_advertise_callback(self, callback): """ Specify a callback function that will be called when offering `Next @@ -1010,6 +1027,7 @@ def set_npn_advertise_callback(self, callback): self._context, self._npn_advertise_callback, _ffi.NULL) + @_requires_npn def set_npn_select_callback(self, callback): """ Specify a callback function that will be called when a server offers @@ -1868,6 +1886,8 @@ def get_cipher_version(self): version =_ffi.string(_lib.SSL_CIPHER_get_version(cipher)) return version.decode("utf-8") + + @_requires_npn def get_next_proto_negotiated(self): """ Get the protocol that was negotiated by NPN. diff --git a/OpenSSL/test/test_ssl.py b/OpenSSL/test/test_ssl.py index 4dedb6b10..2055b6dae 100644 --- a/OpenSSL/test/test_ssl.py +++ b/OpenSSL/test/test_ssl.py @@ -44,6 +44,7 @@ Context, ContextType, Session, Connection, ConnectionType, SSLeay_version) from OpenSSL._util import lib as _lib + from OpenSSL.test.util import WARNING_TYPE_EXPECTED, NON_ASCII, TestCase, b from OpenSSL.test.test_crypto import ( cleartextCertificatePEM, cleartextPrivateKeyPEM, @@ -1627,159 +1628,186 @@ class NextProtoNegotiationTests(TestCase, _LoopbackMixin): """ Test for Next Protocol Negotiation in PyOpenSSL. """ - def test_npn_success(self): - """ - Tests that clients and servers that agree on the negotiated next - protocol can correct establish a connection, and that the agreed - protocol is reported by the connections. - """ - advertise_args = [] - select_args = [] - def advertise(conn): - advertise_args.append((conn,)) - return [b'http/1.1', b'spdy/2'] - def select(conn, options): - select_args.append((conn, options)) - return b'spdy/2' + if _lib.Cryptography_HAS_NEXTPROTONEG: + def test_npn_success(self): + """ + Tests that clients and servers that agree on the negotiated next + protocol can correct establish a connection, and that the agreed + protocol is reported by the connections. + """ + advertise_args = [] + select_args = [] + def advertise(conn): + advertise_args.append((conn,)) + return [b'http/1.1', b'spdy/2'] + def select(conn, options): + select_args.append((conn, options)) + return b'spdy/2' - server_context = Context(TLSv1_METHOD) - server_context.set_npn_advertise_callback(advertise) + server_context = Context(TLSv1_METHOD) + server_context.set_npn_advertise_callback(advertise) - client_context = Context(TLSv1_METHOD) - client_context.set_npn_select_callback(select) + client_context = Context(TLSv1_METHOD) + client_context.set_npn_select_callback(select) - # Necessary to actually accept the connection - server_context.use_privatekey( - load_privatekey(FILETYPE_PEM, server_key_pem)) - server_context.use_certificate( - load_certificate(FILETYPE_PEM, server_cert_pem)) + # Necessary to actually accept the connection + server_context.use_privatekey( + load_privatekey(FILETYPE_PEM, server_key_pem)) + server_context.use_certificate( + load_certificate(FILETYPE_PEM, server_cert_pem)) - # Do a little connection to trigger the logic - server = Connection(server_context, None) - server.set_accept_state() + # Do a little connection to trigger the logic + server = Connection(server_context, None) + server.set_accept_state() - client = Connection(client_context, None) - client.set_connect_state() + client = Connection(client_context, None) + client.set_connect_state() - self._interactInMemory(server, client) + self._interactInMemory(server, client) - self.assertEqual([(server,)], advertise_args) - self.assertEqual([(client, [b'http/1.1', b'spdy/2'])], select_args) + self.assertEqual([(server,)], advertise_args) + self.assertEqual([(client, [b'http/1.1', b'spdy/2'])], select_args) - self.assertEqual(server.get_next_proto_negotiated(), b'spdy/2') - self.assertEqual(client.get_next_proto_negotiated(), b'spdy/2') + self.assertEqual(server.get_next_proto_negotiated(), b'spdy/2') + self.assertEqual(client.get_next_proto_negotiated(), b'spdy/2') - def test_npn_client_fail(self): - """ - Tests that when clients and servers cannot agree on what protocol to - use next that the TLS connection does not get established. - """ - advertise_args = [] - select_args = [] - def advertise(conn): - advertise_args.append((conn,)) - return [b'http/1.1', b'spdy/2'] - def select(conn, options): - select_args.append((conn, options)) - return b'' + def test_npn_client_fail(self): + """ + Tests that when clients and servers cannot agree on what protocol + to use next that the TLS connection does not get established. + """ + advertise_args = [] + select_args = [] + def advertise(conn): + advertise_args.append((conn,)) + return [b'http/1.1', b'spdy/2'] + def select(conn, options): + select_args.append((conn, options)) + return b'' - server_context = Context(TLSv1_METHOD) - server_context.set_npn_advertise_callback(advertise) + server_context = Context(TLSv1_METHOD) + server_context.set_npn_advertise_callback(advertise) - client_context = Context(TLSv1_METHOD) - client_context.set_npn_select_callback(select) + client_context = Context(TLSv1_METHOD) + client_context.set_npn_select_callback(select) - # Necessary to actually accept the connection - server_context.use_privatekey( - load_privatekey(FILETYPE_PEM, server_key_pem)) - server_context.use_certificate( - load_certificate(FILETYPE_PEM, server_cert_pem)) + # Necessary to actually accept the connection + server_context.use_privatekey( + load_privatekey(FILETYPE_PEM, server_key_pem)) + server_context.use_certificate( + load_certificate(FILETYPE_PEM, server_cert_pem)) - # Do a little connection to trigger the logic - server = Connection(server_context, None) - server.set_accept_state() + # Do a little connection to trigger the logic + server = Connection(server_context, None) + server.set_accept_state() - client = Connection(client_context, None) - client.set_connect_state() + client = Connection(client_context, None) + client.set_connect_state() - # If the client doesn't return anything, the connection will fail. - self.assertRaises(Error, self._interactInMemory, server, client) + # If the client doesn't return anything, the connection will fail. + self.assertRaises(Error, self._interactInMemory, server, client) - self.assertEqual([(server,)], advertise_args) - self.assertEqual([(client, [b'http/1.1', b'spdy/2'])], select_args) + self.assertEqual([(server,)], advertise_args) + self.assertEqual([(client, [b'http/1.1', b'spdy/2'])], select_args) - def test_npn_select_error(self): - """ - Test that we can handle exceptions in the select callback. If select - fails it should be fatal to the connection. - """ - advertise_args = [] - def advertise(conn): - advertise_args.append((conn,)) - return [b'http/1.1', b'spdy/2'] - def select(conn, options): - raise TypeError + def test_npn_select_error(self): + """ + Test that we can handle exceptions in the select callback. If + select fails it should be fatal to the connection. + """ + advertise_args = [] + def advertise(conn): + advertise_args.append((conn,)) + return [b'http/1.1', b'spdy/2'] + def select(conn, options): + raise TypeError - server_context = Context(TLSv1_METHOD) - server_context.set_npn_advertise_callback(advertise) + server_context = Context(TLSv1_METHOD) + server_context.set_npn_advertise_callback(advertise) - client_context = Context(TLSv1_METHOD) - client_context.set_npn_select_callback(select) + client_context = Context(TLSv1_METHOD) + client_context.set_npn_select_callback(select) - # Necessary to actually accept the connection - server_context.use_privatekey( - load_privatekey(FILETYPE_PEM, server_key_pem)) - server_context.use_certificate( - load_certificate(FILETYPE_PEM, server_cert_pem)) + # Necessary to actually accept the connection + server_context.use_privatekey( + load_privatekey(FILETYPE_PEM, server_key_pem)) + server_context.use_certificate( + load_certificate(FILETYPE_PEM, server_cert_pem)) - # Do a little connection to trigger the logic - server = Connection(server_context, None) - server.set_accept_state() + # Do a little connection to trigger the logic + server = Connection(server_context, None) + server.set_accept_state() - client = Connection(client_context, None) - client.set_connect_state() + client = Connection(client_context, None) + client.set_connect_state() - # If the callback throws an exception it should be raised here. - self.assertRaises(TypeError, self._interactInMemory, server, client) - self.assertEqual([(server,)], advertise_args) + # If the callback throws an exception it should be raised here. + self.assertRaises( + TypeError, self._interactInMemory, server, client + ) + self.assertEqual([(server,)], advertise_args) - def test_npn_advertise_error(self): - """ - Test that we can handle exceptions in the advertise callback. If - advertise fails no NPN is advertised to the client. - """ - select_args = [] - def advertise(conn): - raise TypeError - def select(conn, options): - select_args.append((conn, options)) - return b'' + def test_npn_advertise_error(self): + """ + Test that we can handle exceptions in the advertise callback. If + advertise fails no NPN is advertised to the client. + """ + select_args = [] + def advertise(conn): + raise TypeError + def select(conn, options): + select_args.append((conn, options)) + return b'' - server_context = Context(TLSv1_METHOD) - server_context.set_npn_advertise_callback(advertise) + server_context = Context(TLSv1_METHOD) + server_context.set_npn_advertise_callback(advertise) - client_context = Context(TLSv1_METHOD) - client_context.set_npn_select_callback(select) + client_context = Context(TLSv1_METHOD) + client_context.set_npn_select_callback(select) - # Necessary to actually accept the connection - server_context.use_privatekey( - load_privatekey(FILETYPE_PEM, server_key_pem)) - server_context.use_certificate( - load_certificate(FILETYPE_PEM, server_cert_pem)) + # Necessary to actually accept the connection + server_context.use_privatekey( + load_privatekey(FILETYPE_PEM, server_key_pem)) + server_context.use_certificate( + load_certificate(FILETYPE_PEM, server_cert_pem)) - # Do a little connection to trigger the logic - server = Connection(server_context, None) - server.set_accept_state() + # Do a little connection to trigger the logic + server = Connection(server_context, None) + server.set_accept_state() - client = Connection(client_context, None) - client.set_connect_state() + client = Connection(client_context, None) + client.set_connect_state() + + # If the client doesn't return anything, the connection will fail. + self.assertRaises( + TypeError, self._interactInMemory, server, client + ) + self.assertEqual([], select_args) - # If the client doesn't return anything, the connection will fail. - self.assertRaises(TypeError, self._interactInMemory, server, client) - self.assertEqual([], select_args) + else: + # No NPN. + def test_npn_not_implemented(self): + # Test the context methods first. + context = Context(TLSv1_METHOD) + fail_methods = [ + context.set_npn_advertise_callback, + context.set_npn_select_callback, + ] + for method in fail_methods: + self.assertRaises( + NotImplementedError, method, None + ) + + # Now test a connection. + conn = Connection(context) + fail_methods = [ + conn.get_next_proto_negotiated, + ] + for method in fail_methods: + self.assertRaises(NotImplementedError, method)