Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions OpenSSL/SSL.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,21 @@ def SSLeay_version(type):
return _ffi.string(_lib.SSLeay_version(type))


def _requires_npn(func):
"""
Wraps any function that requires NPN support in OpenSSL, ensuring that
NotImplementedError is raised if NPN is not present.
"""
@wraps(func)
def wrapper(*args, **kwargs):
if not _lib.Cryptography_HAS_NEXTPROTONEG:
raise NotImplementedError("NPN not available.")

return func(*args, **kwargs)

return wrapper



def _requires_alpn(func):
"""
Expand Down Expand Up @@ -993,6 +1008,8 @@ def wrapper(ssl, alert, arg):
_lib.SSL_CTX_set_tlsext_servername_callback(
self._context, self._tlsext_servername_callback)


@_requires_npn
def set_npn_advertise_callback(self, callback):
"""
Specify a callback function that will be called when offering `Next
Expand All @@ -1010,6 +1027,7 @@ def set_npn_advertise_callback(self, callback):
self._context, self._npn_advertise_callback, _ffi.NULL)


@_requires_npn
def set_npn_select_callback(self, callback):
"""
Specify a callback function that will be called when a server offers
Expand Down Expand Up @@ -1868,6 +1886,8 @@ def get_cipher_version(self):
version =_ffi.string(_lib.SSL_CIPHER_get_version(cipher))
return version.decode("utf-8")


@_requires_npn
def get_next_proto_negotiated(self):
"""
Get the protocol that was negotiated by NPN.
Expand Down
268 changes: 148 additions & 120 deletions OpenSSL/test/test_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
Context, ContextType, Session, Connection, ConnectionType, SSLeay_version)

from OpenSSL._util import lib as _lib

from OpenSSL.test.util import WARNING_TYPE_EXPECTED, NON_ASCII, TestCase, b
from OpenSSL.test.test_crypto import (
cleartextCertificatePEM, cleartextPrivateKeyPEM,
Expand Down Expand Up @@ -1627,159 +1628,186 @@ class NextProtoNegotiationTests(TestCase, _LoopbackMixin):
"""
Test for Next Protocol Negotiation in PyOpenSSL.
"""
def test_npn_success(self):
"""
Tests that clients and servers that agree on the negotiated next
protocol can correct establish a connection, and that the agreed
protocol is reported by the connections.
"""
advertise_args = []
select_args = []
def advertise(conn):
advertise_args.append((conn,))
return [b'http/1.1', b'spdy/2']
def select(conn, options):
select_args.append((conn, options))
return b'spdy/2'
if _lib.Cryptography_HAS_NEXTPROTONEG:
def test_npn_success(self):
"""
Tests that clients and servers that agree on the negotiated next
protocol can correct establish a connection, and that the agreed
protocol is reported by the connections.
"""
advertise_args = []
select_args = []
def advertise(conn):
advertise_args.append((conn,))
return [b'http/1.1', b'spdy/2']
def select(conn, options):
select_args.append((conn, options))
return b'spdy/2'

server_context = Context(TLSv1_METHOD)
server_context.set_npn_advertise_callback(advertise)
server_context = Context(TLSv1_METHOD)
server_context.set_npn_advertise_callback(advertise)

client_context = Context(TLSv1_METHOD)
client_context.set_npn_select_callback(select)
client_context = Context(TLSv1_METHOD)
client_context.set_npn_select_callback(select)

# Necessary to actually accept the connection
server_context.use_privatekey(
load_privatekey(FILETYPE_PEM, server_key_pem))
server_context.use_certificate(
load_certificate(FILETYPE_PEM, server_cert_pem))
# Necessary to actually accept the connection
server_context.use_privatekey(
load_privatekey(FILETYPE_PEM, server_key_pem))
server_context.use_certificate(
load_certificate(FILETYPE_PEM, server_cert_pem))

# Do a little connection to trigger the logic
server = Connection(server_context, None)
server.set_accept_state()
# Do a little connection to trigger the logic
server = Connection(server_context, None)
server.set_accept_state()

client = Connection(client_context, None)
client.set_connect_state()
client = Connection(client_context, None)
client.set_connect_state()

self._interactInMemory(server, client)
self._interactInMemory(server, client)

self.assertEqual([(server,)], advertise_args)
self.assertEqual([(client, [b'http/1.1', b'spdy/2'])], select_args)
self.assertEqual([(server,)], advertise_args)
self.assertEqual([(client, [b'http/1.1', b'spdy/2'])], select_args)

self.assertEqual(server.get_next_proto_negotiated(), b'spdy/2')
self.assertEqual(client.get_next_proto_negotiated(), b'spdy/2')
self.assertEqual(server.get_next_proto_negotiated(), b'spdy/2')
self.assertEqual(client.get_next_proto_negotiated(), b'spdy/2')


def test_npn_client_fail(self):
"""
Tests that when clients and servers cannot agree on what protocol to
use next that the TLS connection does not get established.
"""
advertise_args = []
select_args = []
def advertise(conn):
advertise_args.append((conn,))
return [b'http/1.1', b'spdy/2']
def select(conn, options):
select_args.append((conn, options))
return b''
def test_npn_client_fail(self):
"""
Tests that when clients and servers cannot agree on what protocol
to use next that the TLS connection does not get established.
"""
advertise_args = []
select_args = []
def advertise(conn):
advertise_args.append((conn,))
return [b'http/1.1', b'spdy/2']
def select(conn, options):
select_args.append((conn, options))
return b''

server_context = Context(TLSv1_METHOD)
server_context.set_npn_advertise_callback(advertise)
server_context = Context(TLSv1_METHOD)
server_context.set_npn_advertise_callback(advertise)

client_context = Context(TLSv1_METHOD)
client_context.set_npn_select_callback(select)
client_context = Context(TLSv1_METHOD)
client_context.set_npn_select_callback(select)

# Necessary to actually accept the connection
server_context.use_privatekey(
load_privatekey(FILETYPE_PEM, server_key_pem))
server_context.use_certificate(
load_certificate(FILETYPE_PEM, server_cert_pem))
# Necessary to actually accept the connection
server_context.use_privatekey(
load_privatekey(FILETYPE_PEM, server_key_pem))
server_context.use_certificate(
load_certificate(FILETYPE_PEM, server_cert_pem))

# Do a little connection to trigger the logic
server = Connection(server_context, None)
server.set_accept_state()
# Do a little connection to trigger the logic
server = Connection(server_context, None)
server.set_accept_state()

client = Connection(client_context, None)
client.set_connect_state()
client = Connection(client_context, None)
client.set_connect_state()

# If the client doesn't return anything, the connection will fail.
self.assertRaises(Error, self._interactInMemory, server, client)
# If the client doesn't return anything, the connection will fail.
self.assertRaises(Error, self._interactInMemory, server, client)

self.assertEqual([(server,)], advertise_args)
self.assertEqual([(client, [b'http/1.1', b'spdy/2'])], select_args)
self.assertEqual([(server,)], advertise_args)
self.assertEqual([(client, [b'http/1.1', b'spdy/2'])], select_args)


def test_npn_select_error(self):
"""
Test that we can handle exceptions in the select callback. If select
fails it should be fatal to the connection.
"""
advertise_args = []
def advertise(conn):
advertise_args.append((conn,))
return [b'http/1.1', b'spdy/2']
def select(conn, options):
raise TypeError
def test_npn_select_error(self):
"""
Test that we can handle exceptions in the select callback. If
select fails it should be fatal to the connection.
"""
advertise_args = []
def advertise(conn):
advertise_args.append((conn,))
return [b'http/1.1', b'spdy/2']
def select(conn, options):
raise TypeError

server_context = Context(TLSv1_METHOD)
server_context.set_npn_advertise_callback(advertise)
server_context = Context(TLSv1_METHOD)
server_context.set_npn_advertise_callback(advertise)

client_context = Context(TLSv1_METHOD)
client_context.set_npn_select_callback(select)
client_context = Context(TLSv1_METHOD)
client_context.set_npn_select_callback(select)

# Necessary to actually accept the connection
server_context.use_privatekey(
load_privatekey(FILETYPE_PEM, server_key_pem))
server_context.use_certificate(
load_certificate(FILETYPE_PEM, server_cert_pem))
# Necessary to actually accept the connection
server_context.use_privatekey(
load_privatekey(FILETYPE_PEM, server_key_pem))
server_context.use_certificate(
load_certificate(FILETYPE_PEM, server_cert_pem))

# Do a little connection to trigger the logic
server = Connection(server_context, None)
server.set_accept_state()
# Do a little connection to trigger the logic
server = Connection(server_context, None)
server.set_accept_state()

client = Connection(client_context, None)
client.set_connect_state()
client = Connection(client_context, None)
client.set_connect_state()

# If the callback throws an exception it should be raised here.
self.assertRaises(TypeError, self._interactInMemory, server, client)
self.assertEqual([(server,)], advertise_args)
# If the callback throws an exception it should be raised here.
self.assertRaises(
TypeError, self._interactInMemory, server, client
)
self.assertEqual([(server,)], advertise_args)


def test_npn_advertise_error(self):
"""
Test that we can handle exceptions in the advertise callback. If
advertise fails no NPN is advertised to the client.
"""
select_args = []
def advertise(conn):
raise TypeError
def select(conn, options):
select_args.append((conn, options))
return b''
def test_npn_advertise_error(self):
"""
Test that we can handle exceptions in the advertise callback. If
advertise fails no NPN is advertised to the client.
"""
select_args = []
def advertise(conn):
raise TypeError
def select(conn, options):
select_args.append((conn, options))
return b''

server_context = Context(TLSv1_METHOD)
server_context.set_npn_advertise_callback(advertise)
server_context = Context(TLSv1_METHOD)
server_context.set_npn_advertise_callback(advertise)

client_context = Context(TLSv1_METHOD)
client_context.set_npn_select_callback(select)
client_context = Context(TLSv1_METHOD)
client_context.set_npn_select_callback(select)

# Necessary to actually accept the connection
server_context.use_privatekey(
load_privatekey(FILETYPE_PEM, server_key_pem))
server_context.use_certificate(
load_certificate(FILETYPE_PEM, server_cert_pem))
# Necessary to actually accept the connection
server_context.use_privatekey(
load_privatekey(FILETYPE_PEM, server_key_pem))
server_context.use_certificate(
load_certificate(FILETYPE_PEM, server_cert_pem))

# Do a little connection to trigger the logic
server = Connection(server_context, None)
server.set_accept_state()
# Do a little connection to trigger the logic
server = Connection(server_context, None)
server.set_accept_state()

client = Connection(client_context, None)
client.set_connect_state()
client = Connection(client_context, None)
client.set_connect_state()

# If the client doesn't return anything, the connection will fail.
self.assertRaises(
TypeError, self._interactInMemory, server, client
)
self.assertEqual([], select_args)

# If the client doesn't return anything, the connection will fail.
self.assertRaises(TypeError, self._interactInMemory, server, client)
self.assertEqual([], select_args)
else:
# No NPN.
def test_npn_not_implemented(self):
# Test the context methods first.
context = Context(TLSv1_METHOD)
fail_methods = [
context.set_npn_advertise_callback,
context.set_npn_select_callback,
]
for method in fail_methods:
self.assertRaises(
NotImplementedError, method, None
)

# Now test a connection.
conn = Connection(context)
fail_methods = [
conn.get_next_proto_negotiated,
]
for method in fail_methods:
self.assertRaises(NotImplementedError, method)



Expand Down