diff --git a/OpenSSL/SSL.py b/OpenSSL/SSL.py index 1215526e1..87492af6c 100644 --- a/OpenSSL/SSL.py +++ b/OpenSSL/SSL.py @@ -347,6 +347,21 @@ def SSLeay_version(type): return _ffi.string(_lib.SSLeay_version(type)) +def _requires_npn(func): + """ + Wraps any function that requires NPN support in OpenSSL, ensuring that + NotImplementedError is raised if NPN is not present. + """ + @wraps(func) + def wrapper(*args, **kwargs): + if not _lib.Cryptography_HAS_NEXTPROTONEG: + raise NotImplementedError("NPN not available.") + + return func(*args, **kwargs) + + return wrapper + + class Session(object): pass @@ -924,6 +939,7 @@ def wrapper(ssl, alert, arg): self._context, self._tlsext_servername_callback) + @_requires_npn def set_npn_advertise_callback(self, callback): """ Specify a callback function that will be called when offering `Next @@ -941,6 +957,7 @@ def set_npn_advertise_callback(self, callback): self._context, self._npn_advertise_callback, _ffi.NULL) + @_requires_npn def set_npn_select_callback(self, callback): """ Specify a callback function that will be called when a server offers @@ -1746,6 +1763,8 @@ def get_cipher_version(self): version =_ffi.string(_lib.SSL_CIPHER_get_version(cipher)) return version.decode("utf-8") + + @_requires_npn def get_next_proto_negotiated(self): """ Get the protocol that was negotiated by NPN. diff --git a/OpenSSL/test/test_ssl.py b/OpenSSL/test/test_ssl.py index c82dea6a5..caa9d0215 100644 --- a/OpenSSL/test/test_ssl.py +++ b/OpenSSL/test/test_ssl.py @@ -42,6 +42,8 @@ from OpenSSL.SSL import ( Context, ContextType, Session, Connection, ConnectionType, SSLeay_version) +from OpenSSL._util import lib as _lib + from OpenSSL.test.util import NON_ASCII, TestCase, b from OpenSSL.test.test_crypto import ( cleartextCertificatePEM, cleartextPrivateKeyPEM) @@ -1626,159 +1628,186 @@ class NextProtoNegotiationTests(TestCase, _LoopbackMixin): """ Test for Next Protocol Negotiation in PyOpenSSL. """ - def test_npn_success(self): - """ - Tests that clients and servers that agree on the negotiated next - protocol can correct establish a connection, and that the agreed - protocol is reported by the connections. - """ - advertise_args = [] - select_args = [] - def advertise(conn): - advertise_args.append((conn,)) - return [b'http/1.1', b'spdy/2'] - def select(conn, options): - select_args.append((conn, options)) - return b'spdy/2' - - server_context = Context(TLSv1_METHOD) - server_context.set_npn_advertise_callback(advertise) - - client_context = Context(TLSv1_METHOD) - client_context.set_npn_select_callback(select) - - # Necessary to actually accept the connection - server_context.use_privatekey( - load_privatekey(FILETYPE_PEM, server_key_pem)) - server_context.use_certificate( - load_certificate(FILETYPE_PEM, server_cert_pem)) - - # Do a little connection to trigger the logic - server = Connection(server_context, None) - server.set_accept_state() - - client = Connection(client_context, None) - client.set_connect_state() - - self._interactInMemory(server, client) - - self.assertEqual([(server,)], advertise_args) - self.assertEqual([(client, [b'http/1.1', b'spdy/2'])], select_args) - - self.assertEqual(server.get_next_proto_negotiated(), b'spdy/2') - self.assertEqual(client.get_next_proto_negotiated(), b'spdy/2') - - - def test_npn_client_fail(self): - """ - Tests that when clients and servers cannot agree on what protocol to - use next that the TLS connection does not get established. - """ - advertise_args = [] - select_args = [] - def advertise(conn): - advertise_args.append((conn,)) - return [b'http/1.1', b'spdy/2'] - def select(conn, options): - select_args.append((conn, options)) - return b'' - - server_context = Context(TLSv1_METHOD) - server_context.set_npn_advertise_callback(advertise) - - client_context = Context(TLSv1_METHOD) - client_context.set_npn_select_callback(select) - - # Necessary to actually accept the connection - server_context.use_privatekey( - load_privatekey(FILETYPE_PEM, server_key_pem)) - server_context.use_certificate( - load_certificate(FILETYPE_PEM, server_cert_pem)) - - # Do a little connection to trigger the logic - server = Connection(server_context, None) - server.set_accept_state() - - client = Connection(client_context, None) - client.set_connect_state() + if _lib.Cryptography_HAS_NEXTPROTONEG: + def test_npn_success(self): + """ + Tests that clients and servers that agree on the negotiated next + protocol can correct establish a connection, and that the agreed + protocol is reported by the connections. + """ + advertise_args = [] + select_args = [] + def advertise(conn): + advertise_args.append((conn,)) + return [b'http/1.1', b'spdy/2'] + def select(conn, options): + select_args.append((conn, options)) + return b'spdy/2' + + server_context = Context(TLSv1_METHOD) + server_context.set_npn_advertise_callback(advertise) + + client_context = Context(TLSv1_METHOD) + client_context.set_npn_select_callback(select) + + # Necessary to actually accept the connection + server_context.use_privatekey( + load_privatekey(FILETYPE_PEM, server_key_pem)) + server_context.use_certificate( + load_certificate(FILETYPE_PEM, server_cert_pem)) + + # Do a little connection to trigger the logic + server = Connection(server_context, None) + server.set_accept_state() - # If the client doesn't return anything, the connection will fail. - self.assertRaises(Error, self._interactInMemory, server, client) + client = Connection(client_context, None) + client.set_connect_state() - self.assertEqual([(server,)], advertise_args) - self.assertEqual([(client, [b'http/1.1', b'spdy/2'])], select_args) + self._interactInMemory(server, client) + self.assertEqual([(server,)], advertise_args) + self.assertEqual([(client, [b'http/1.1', b'spdy/2'])], select_args) - def test_npn_select_error(self): - """ - Test that we can handle exceptions in the select callback. If select - fails it should be fatal to the connection. - """ - advertise_args = [] - def advertise(conn): - advertise_args.append((conn,)) - return [b'http/1.1', b'spdy/2'] - def select(conn, options): - raise TypeError + self.assertEqual(server.get_next_proto_negotiated(), b'spdy/2') + self.assertEqual(client.get_next_proto_negotiated(), b'spdy/2') - server_context = Context(TLSv1_METHOD) - server_context.set_npn_advertise_callback(advertise) - client_context = Context(TLSv1_METHOD) - client_context.set_npn_select_callback(select) + def test_npn_client_fail(self): + """ + Tests that when clients and servers cannot agree on what protocol + to use next that the TLS connection does not get established. + """ + advertise_args = [] + select_args = [] + def advertise(conn): + advertise_args.append((conn,)) + return [b'http/1.1', b'spdy/2'] + def select(conn, options): + select_args.append((conn, options)) + return b'' + + server_context = Context(TLSv1_METHOD) + server_context.set_npn_advertise_callback(advertise) + + client_context = Context(TLSv1_METHOD) + client_context.set_npn_select_callback(select) + + # Necessary to actually accept the connection + server_context.use_privatekey( + load_privatekey(FILETYPE_PEM, server_key_pem)) + server_context.use_certificate( + load_certificate(FILETYPE_PEM, server_cert_pem)) + + # Do a little connection to trigger the logic + server = Connection(server_context, None) + server.set_accept_state() - # Necessary to actually accept the connection - server_context.use_privatekey( - load_privatekey(FILETYPE_PEM, server_key_pem)) - server_context.use_certificate( - load_certificate(FILETYPE_PEM, server_cert_pem)) + client = Connection(client_context, None) + client.set_connect_state() - # Do a little connection to trigger the logic - server = Connection(server_context, None) - server.set_accept_state() + # If the client doesn't return anything, the connection will fail. + self.assertRaises(Error, self._interactInMemory, server, client) - client = Connection(client_context, None) - client.set_connect_state() + self.assertEqual([(server,)], advertise_args) + self.assertEqual([(client, [b'http/1.1', b'spdy/2'])], select_args) - # If the callback throws an exception it should be raised here. - self.assertRaises(TypeError, self._interactInMemory, server, client) - self.assertEqual([(server,)], advertise_args) + def test_npn_select_error(self): + """ + Test that we can handle exceptions in the select callback. If + select fails it should be fatal to the connection. + """ + advertise_args = [] + def advertise(conn): + advertise_args.append((conn,)) + return [b'http/1.1', b'spdy/2'] + def select(conn, options): + raise TypeError + + server_context = Context(TLSv1_METHOD) + server_context.set_npn_advertise_callback(advertise) + + client_context = Context(TLSv1_METHOD) + client_context.set_npn_select_callback(select) + + # Necessary to actually accept the connection + server_context.use_privatekey( + load_privatekey(FILETYPE_PEM, server_key_pem)) + server_context.use_certificate( + load_certificate(FILETYPE_PEM, server_cert_pem)) + + # Do a little connection to trigger the logic + server = Connection(server_context, None) + server.set_accept_state() - def test_npn_advertise_error(self): - """ - Test that we can handle exceptions in the advertise callback. If - advertise fails no NPN is advertised to the client. - """ - select_args = [] - def advertise(conn): - raise TypeError - def select(conn, options): - select_args.append((conn, options)) - return b'' + client = Connection(client_context, None) + client.set_connect_state() - server_context = Context(TLSv1_METHOD) - server_context.set_npn_advertise_callback(advertise) + # If the callback throws an exception it should be raised here. + self.assertRaises( + TypeError, self._interactInMemory, server, client + ) + self.assertEqual([(server,)], advertise_args) - client_context = Context(TLSv1_METHOD) - client_context.set_npn_select_callback(select) - # Necessary to actually accept the connection - server_context.use_privatekey( - load_privatekey(FILETYPE_PEM, server_key_pem)) - server_context.use_certificate( - load_certificate(FILETYPE_PEM, server_cert_pem)) + def test_npn_advertise_error(self): + """ + Test that we can handle exceptions in the advertise callback. If + advertise fails no NPN is advertised to the client. + """ + select_args = [] + def advertise(conn): + raise TypeError + def select(conn, options): + select_args.append((conn, options)) + return b'' + + server_context = Context(TLSv1_METHOD) + server_context.set_npn_advertise_callback(advertise) + + client_context = Context(TLSv1_METHOD) + client_context.set_npn_select_callback(select) + + # Necessary to actually accept the connection + server_context.use_privatekey( + load_privatekey(FILETYPE_PEM, server_key_pem)) + server_context.use_certificate( + load_certificate(FILETYPE_PEM, server_cert_pem)) + + # Do a little connection to trigger the logic + server = Connection(server_context, None) + server.set_accept_state() - # Do a little connection to trigger the logic - server = Connection(server_context, None) - server.set_accept_state() + client = Connection(client_context, None) + client.set_connect_state() - client = Connection(client_context, None) - client.set_connect_state() + # If the client doesn't return anything, the connection will fail. + self.assertRaises( + TypeError, self._interactInMemory, server, client + ) + self.assertEqual([], select_args) - # If the client doesn't return anything, the connection will fail. - self.assertRaises(TypeError, self._interactInMemory, server, client) - self.assertEqual([], select_args) + else: + # No NPN. + def test_npn_not_implemented(self): + # Test the context methods first. + context = Context(TLSv1_METHOD) + fail_methods = [ + context.set_npn_advertise_callback, + context.set_npn_select_callback, + ] + for method in fail_methods: + self.assertRaises( + NotImplementedError, method, None + ) + + # Now test a connection. + conn = Connection(context) + fail_methods = [ + conn.get_next_proto_negotiated, + ] + for method in fail_methods: + self.assertRaises(NotImplementedError, method)