Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions OpenSSL/SSL.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,21 @@ def SSLeay_version(type):
return _ffi.string(_lib.SSLeay_version(type))


def _requires_npn(func):
"""
Wraps any function that requires NPN support in OpenSSL, ensuring that
NotImplementedError is raised if NPN is not present.
"""
@wraps(func)
def wrapper(*args, **kwargs):
if not _lib.Cryptography_HAS_NEXTPROTONEG:
raise NotImplementedError("NPN not available.")

return func(*args, **kwargs)

return wrapper



class Session(object):
pass
Expand Down Expand Up @@ -924,6 +939,7 @@ def wrapper(ssl, alert, arg):
self._context, self._tlsext_servername_callback)


@_requires_npn
def set_npn_advertise_callback(self, callback):
"""
Specify a callback function that will be called when offering `Next
Expand All @@ -941,6 +957,7 @@ def set_npn_advertise_callback(self, callback):
self._context, self._npn_advertise_callback, _ffi.NULL)


@_requires_npn
def set_npn_select_callback(self, callback):
"""
Specify a callback function that will be called when a server offers
Expand Down Expand Up @@ -1746,6 +1763,8 @@ def get_cipher_version(self):
version =_ffi.string(_lib.SSL_CIPHER_get_version(cipher))
return version.decode("utf-8")


@_requires_npn
def get_next_proto_negotiated(self):
"""
Get the protocol that was negotiated by NPN.
Expand Down
299 changes: 164 additions & 135 deletions OpenSSL/test/test_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
from OpenSSL.SSL import (
Context, ContextType, Session, Connection, ConnectionType, SSLeay_version)

from OpenSSL._util import lib as _lib

from OpenSSL.test.util import NON_ASCII, TestCase, b
from OpenSSL.test.test_crypto import (
cleartextCertificatePEM, cleartextPrivateKeyPEM)
Expand Down Expand Up @@ -1626,159 +1628,186 @@ class NextProtoNegotiationTests(TestCase, _LoopbackMixin):
"""
Test for Next Protocol Negotiation in PyOpenSSL.
"""
def test_npn_success(self):
"""
Tests that clients and servers that agree on the negotiated next
protocol can correct establish a connection, and that the agreed
protocol is reported by the connections.
"""
advertise_args = []
select_args = []
def advertise(conn):
advertise_args.append((conn,))
return [b'http/1.1', b'spdy/2']
def select(conn, options):
select_args.append((conn, options))
return b'spdy/2'

server_context = Context(TLSv1_METHOD)
server_context.set_npn_advertise_callback(advertise)

client_context = Context(TLSv1_METHOD)
client_context.set_npn_select_callback(select)

# Necessary to actually accept the connection
server_context.use_privatekey(
load_privatekey(FILETYPE_PEM, server_key_pem))
server_context.use_certificate(
load_certificate(FILETYPE_PEM, server_cert_pem))

# Do a little connection to trigger the logic
server = Connection(server_context, None)
server.set_accept_state()

client = Connection(client_context, None)
client.set_connect_state()

self._interactInMemory(server, client)

self.assertEqual([(server,)], advertise_args)
self.assertEqual([(client, [b'http/1.1', b'spdy/2'])], select_args)

self.assertEqual(server.get_next_proto_negotiated(), b'spdy/2')
self.assertEqual(client.get_next_proto_negotiated(), b'spdy/2')


def test_npn_client_fail(self):
"""
Tests that when clients and servers cannot agree on what protocol to
use next that the TLS connection does not get established.
"""
advertise_args = []
select_args = []
def advertise(conn):
advertise_args.append((conn,))
return [b'http/1.1', b'spdy/2']
def select(conn, options):
select_args.append((conn, options))
return b''

server_context = Context(TLSv1_METHOD)
server_context.set_npn_advertise_callback(advertise)

client_context = Context(TLSv1_METHOD)
client_context.set_npn_select_callback(select)

# Necessary to actually accept the connection
server_context.use_privatekey(
load_privatekey(FILETYPE_PEM, server_key_pem))
server_context.use_certificate(
load_certificate(FILETYPE_PEM, server_cert_pem))

# Do a little connection to trigger the logic
server = Connection(server_context, None)
server.set_accept_state()

client = Connection(client_context, None)
client.set_connect_state()
if _lib.Cryptography_HAS_NEXTPROTONEG:
def test_npn_success(self):
"""
Tests that clients and servers that agree on the negotiated next
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This docstring is not quite in the right format, have a look at https://plus.google.com/+JonathanLange/posts/YA3ThKWhSAj

protocol can correct establish a connection, and that the agreed
protocol is reported by the connections.
"""
advertise_args = []
select_args = []
def advertise(conn):
advertise_args.append((conn,))
return [b'http/1.1', b'spdy/2']
def select(conn, options):
select_args.append((conn, options))
return b'spdy/2'

server_context = Context(TLSv1_METHOD)
server_context.set_npn_advertise_callback(advertise)

client_context = Context(TLSv1_METHOD)
client_context.set_npn_select_callback(select)

# Necessary to actually accept the connection
server_context.use_privatekey(
load_privatekey(FILETYPE_PEM, server_key_pem))
server_context.use_certificate(
load_certificate(FILETYPE_PEM, server_cert_pem))

# Do a little connection to trigger the logic
server = Connection(server_context, None)
server.set_accept_state()

# If the client doesn't return anything, the connection will fail.
self.assertRaises(Error, self._interactInMemory, server, client)
client = Connection(client_context, None)
client.set_connect_state()

self.assertEqual([(server,)], advertise_args)
self.assertEqual([(client, [b'http/1.1', b'spdy/2'])], select_args)
self._interactInMemory(server, client)

self.assertEqual([(server,)], advertise_args)
self.assertEqual([(client, [b'http/1.1', b'spdy/2'])], select_args)

def test_npn_select_error(self):
"""
Test that we can handle exceptions in the select callback. If select
fails it should be fatal to the connection.
"""
advertise_args = []
def advertise(conn):
advertise_args.append((conn,))
return [b'http/1.1', b'spdy/2']
def select(conn, options):
raise TypeError
self.assertEqual(server.get_next_proto_negotiated(), b'spdy/2')
self.assertEqual(client.get_next_proto_negotiated(), b'spdy/2')

server_context = Context(TLSv1_METHOD)
server_context.set_npn_advertise_callback(advertise)

client_context = Context(TLSv1_METHOD)
client_context.set_npn_select_callback(select)
def test_npn_client_fail(self):
"""
Tests that when clients and servers cannot agree on what protocol
to use next that the TLS connection does not get established.
"""
advertise_args = []
select_args = []
def advertise(conn):
advertise_args.append((conn,))
return [b'http/1.1', b'spdy/2']
def select(conn, options):
select_args.append((conn, options))
return b''

server_context = Context(TLSv1_METHOD)
server_context.set_npn_advertise_callback(advertise)

client_context = Context(TLSv1_METHOD)
client_context.set_npn_select_callback(select)

# Necessary to actually accept the connection
server_context.use_privatekey(
load_privatekey(FILETYPE_PEM, server_key_pem))
server_context.use_certificate(
load_certificate(FILETYPE_PEM, server_cert_pem))

# Do a little connection to trigger the logic
server = Connection(server_context, None)
server.set_accept_state()

# Necessary to actually accept the connection
server_context.use_privatekey(
load_privatekey(FILETYPE_PEM, server_key_pem))
server_context.use_certificate(
load_certificate(FILETYPE_PEM, server_cert_pem))
client = Connection(client_context, None)
client.set_connect_state()

# Do a little connection to trigger the logic
server = Connection(server_context, None)
server.set_accept_state()
# If the client doesn't return anything, the connection will fail.
self.assertRaises(Error, self._interactInMemory, server, client)

client = Connection(client_context, None)
client.set_connect_state()
self.assertEqual([(server,)], advertise_args)
self.assertEqual([(client, [b'http/1.1', b'spdy/2'])], select_args)

# If the callback throws an exception it should be raised here.
self.assertRaises(TypeError, self._interactInMemory, server, client)
self.assertEqual([(server,)], advertise_args)

def test_npn_select_error(self):
"""
Test that we can handle exceptions in the select callback. If
select fails it should be fatal to the connection.
"""
advertise_args = []
def advertise(conn):
advertise_args.append((conn,))
return [b'http/1.1', b'spdy/2']
def select(conn, options):
raise TypeError

server_context = Context(TLSv1_METHOD)
server_context.set_npn_advertise_callback(advertise)

client_context = Context(TLSv1_METHOD)
client_context.set_npn_select_callback(select)

# Necessary to actually accept the connection
server_context.use_privatekey(
load_privatekey(FILETYPE_PEM, server_key_pem))
server_context.use_certificate(
load_certificate(FILETYPE_PEM, server_cert_pem))

# Do a little connection to trigger the logic
server = Connection(server_context, None)
server.set_accept_state()

def test_npn_advertise_error(self):
"""
Test that we can handle exceptions in the advertise callback. If
advertise fails no NPN is advertised to the client.
"""
select_args = []
def advertise(conn):
raise TypeError
def select(conn, options):
select_args.append((conn, options))
return b''
client = Connection(client_context, None)
client.set_connect_state()

server_context = Context(TLSv1_METHOD)
server_context.set_npn_advertise_callback(advertise)
# If the callback throws an exception it should be raised here.
self.assertRaises(
TypeError, self._interactInMemory, server, client
)
self.assertEqual([(server,)], advertise_args)

client_context = Context(TLSv1_METHOD)
client_context.set_npn_select_callback(select)

# Necessary to actually accept the connection
server_context.use_privatekey(
load_privatekey(FILETYPE_PEM, server_key_pem))
server_context.use_certificate(
load_certificate(FILETYPE_PEM, server_cert_pem))
def test_npn_advertise_error(self):
"""
Test that we can handle exceptions in the advertise callback. If
advertise fails no NPN is advertised to the client.
"""
select_args = []
def advertise(conn):
raise TypeError
def select(conn, options):
select_args.append((conn, options))
return b''

server_context = Context(TLSv1_METHOD)
server_context.set_npn_advertise_callback(advertise)

client_context = Context(TLSv1_METHOD)
client_context.set_npn_select_callback(select)

# Necessary to actually accept the connection
server_context.use_privatekey(
load_privatekey(FILETYPE_PEM, server_key_pem))
server_context.use_certificate(
load_certificate(FILETYPE_PEM, server_cert_pem))

# Do a little connection to trigger the logic
server = Connection(server_context, None)
server.set_accept_state()

# Do a little connection to trigger the logic
server = Connection(server_context, None)
server.set_accept_state()
client = Connection(client_context, None)
client.set_connect_state()

client = Connection(client_context, None)
client.set_connect_state()
# If the client doesn't return anything, the connection will fail.
self.assertRaises(
TypeError, self._interactInMemory, server, client
)
self.assertEqual([], select_args)

# If the client doesn't return anything, the connection will fail.
self.assertRaises(TypeError, self._interactInMemory, server, client)
self.assertEqual([], select_args)
else:
# No NPN.
def test_npn_not_implemented(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this needs a docstring

# Test the context methods first.
context = Context(TLSv1_METHOD)
fail_methods = [
context.set_npn_advertise_callback,
context.set_npn_select_callback,
]
for method in fail_methods:
self.assertRaises(
NotImplementedError, method, None
)

# Now test a connection.
conn = Connection(context)
fail_methods = [
conn.get_next_proto_negotiated,
]
for method in fail_methods:
self.assertRaises(NotImplementedError, method)



Expand Down