Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 155 additions & 8 deletions OpenSSL/SSL.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from sys import platform
from functools import wraps, partial
from itertools import count
from itertools import count, chain
from weakref import WeakValueDictionary
from errno import errorcode

from six import text_type as _text_type
from six import integer_types as integer_types
from six import int2byte, indexbytes

from OpenSSL._util import (
ffi as _ffi,
Expand Down Expand Up @@ -164,8 +165,24 @@ class SysCallError(Error):
pass


class _CallbackExceptionHelper(object):
"""
A base class for wrapper classes that allow for intelligent exception
handling in OpenSSL callbacks.
"""
def __init__(self, callback):
pass

def raise_if_problem(self):
if self._problems:
try:
_raise_current_error()
except Error:
pass
raise self._problems.pop(0)


class _VerifyHelper(object):
class _VerifyHelper(_CallbackExceptionHelper):
def __init__(self, callback):
self._problems = []

Expand Down Expand Up @@ -196,14 +213,86 @@ def wrapper(ok, store_ctx):
"int (*)(int, X509_STORE_CTX *)", wrapper)


def raise_if_problem(self):
if self._problems:
class _NpnAdvertiseHelper(_CallbackExceptionHelper):
def __init__(self, callback):
self._problems = []

@wraps(callback)
def wrapper(ssl, out, outlen, arg):
try:
_raise_current_error()
except Error:
pass
raise self._problems.pop(0)
conn = Connection._reverse_mapping[ssl]
protos = callback(conn)

# Join the protocols into a Python bytestring, length-prefixing
# each element.
protostr = b''.join(
chain.from_iterable((int2byte(len(p)), p) for p in protos)
)

# Save our callback arguments on the connection object. This is
# done to make sure that they don't get freed before OpenSSL
# uses them. Then, return them appropriately in the output
# parameters.
conn._npn_advertise_callback_args = [
_ffi.new("unsigned int *", len(protostr)),
_ffi.new("unsigned char[]", protostr),
]
outlen[0] = conn._npn_advertise_callback_args[0][0]
out[0] = conn._npn_advertise_callback_args[1]
return 0
except Exception as e:
self._problems.append(e)
return 2 # SSL_TLSEXT_ERR_ALERT_FATAL

self.callback = _ffi.callback(
"int (*)(SSL *, const unsigned char **, unsigned int *, void *)",
wrapper
)


class _NpnSelectHelper(_CallbackExceptionHelper):
def __init__(self, callback):
self._problems = []

@wraps(callback)
def wrapper(ssl, out, outlen, in_, inlen, arg):
try:
conn = Connection._reverse_mapping[ssl]

# The string passed to us is actually made up of multiple
# length-prefixed bytestrings. We need to split that into a
# list.
instr = _ffi.buffer(in_, inlen)[:]
protolist = []
while instr:
l = indexbytes(instr, 0)
proto = instr[1:l+1]
protolist.append(proto)
instr = instr[l+1:]

# Call the callback
outstr = callback(conn, protolist)

# Save our callback arguments on the connection object. This is
# done to make sure that they don't get freed before OpenSSL
# uses them. Then, return them appropriately in the output
# parameters.
conn._npn_select_callback_args = [
_ffi.new("unsigned char *", len(outstr)),
_ffi.new("unsigned char[]", outstr),
]
outlen[0] = conn._npn_select_callback_args[0][0]
out[0] = conn._npn_select_callback_args[1]
return 0
except Exception as e:
self._problems.append(e)
return 2 # SSL_TLSEXT_ERR_ALERT_FATAL

self.callback = _ffi.callback(
"int (*)(SSL *, unsigned char **, unsigned char *, "
"const unsigned char *, unsigned int, void *)",
wrapper
)


def _asFileDescriptor(obj):
Expand Down Expand Up @@ -293,6 +382,10 @@ def __init__(self, method):
self._info_callback = None
self._tlsext_servername_callback = None
self._app_data = None
self._npn_advertise_helper = None
self._npn_advertise_callback = None
self._npn_select_helper = None
self._npn_select_callback = None

# SSL_CTX_set_app_data(self->ctx, self);
# SSL_CTX_set_mode(self->ctx, SSL_MODE_ENABLE_PARTIAL_WRITE |
Expand Down Expand Up @@ -809,6 +902,39 @@ def wrapper(ssl, alert, arg):
_lib.SSL_CTX_set_tlsext_servername_callback(
self._context, self._tlsext_servername_callback)


def set_npn_advertise_callback(self, callback):
"""
Specify a callback function that will be called when offering `Next
Protocol Negotiation
<https://technotes.googlecode.com/git/nextprotoneg.html>`_ as a server.

:param callback: The callback function. It will be invoked with one
argument, the Connection instance. It should return a list of
bytestrings representing the advertised protocols, like
``[b'http/1.1', b'spdy/2']``.
"""
self._npn_advertise_helper = _NpnAdvertiseHelper(callback)
self._npn_advertise_callback = self._npn_advertise_helper.callback
_lib.SSL_CTX_set_next_protos_advertised_cb(
self._context, self._npn_advertise_callback, _ffi.NULL)


def set_npn_select_callback(self, callback):
"""
Specify a callback function that will be called when a server offers
Next Protocol Negotiation options.

:param callback: The callback function. It will be invoked with two
arguments: the Connection, and a list of offered protocols as
bytestrings, e.g. ``[b'http/1.1', b'spdy/2']``. It should return
one of those bytestrings, the chosen protocol.
"""
self._npn_select_helper = _NpnSelectHelper(callback)
self._npn_select_callback = self._npn_select_helper.callback
_lib.SSL_CTX_set_next_proto_select_cb(
self._context, self._npn_select_callback, _ffi.NULL)

ContextType = Context


Expand All @@ -833,6 +959,13 @@ def __init__(self, context, socket=None):
self._ssl = _ffi.gc(ssl, _lib.SSL_free)
self._context = context

# References to strings used for Next Protocol Negotiation. OpenSSL's
# header files suggest that these might get copied at some point, but
# doesn't specify when, so we store them here to make sure they don't
# get freed before OpenSSL uses them.
self._npn_advertise_callback_args = None
self._npn_select_callback_args = None

self._reverse_mapping[self._ssl] = self

if socket is None:
Expand Down Expand Up @@ -867,6 +1000,10 @@ def __getattr__(self, name):
def _raise_ssl_error(self, ssl, result):
if self._context._verify_helper is not None:
self._context._verify_helper.raise_if_problem()
if self._context._npn_advertise_helper is not None:
self._context._npn_advertise_helper.raise_if_problem()
if self._context._npn_select_helper is not None:
self._context._npn_select_helper.raise_if_problem()

error = _lib.SSL_get_error(ssl, result)
if error == _lib.SSL_ERROR_WANT_READ:
Expand Down Expand Up @@ -1550,6 +1687,16 @@ def get_cipher_version(self):
version =_ffi.string(_lib.SSL_CIPHER_get_version(cipher))
return version.decode("utf-8")

def get_next_proto_negotiated(self):
"""
Get the protocol that was negotiated by NPN.
"""
data = _ffi.new("unsigned char **")
data_len = _ffi.new("unsigned int *")

_lib.SSL_get0_next_proto_negotiated(self._ssl, data, data_len)

return _ffi.buffer(data[0], data_len[0])[:]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Who is responsible for the memory (in particular, releasing it) used by the protocol name in data?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are (we allocated it, we have to free it). However, it should be freed up automatically because data will go out of scope at the bottom of the function (CPython will free it immediately, PyPy will get around to it). That's fine, because _ffi.buffer() will have copied what we care about out already.



ConnectionType = Connection
Expand Down
159 changes: 159 additions & 0 deletions OpenSSL/test/test_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1434,6 +1434,165 @@ def servername(conn):
self.assertEqual([(server, b("foo1.example.com"))], args)


class NextProtoNegotiationTests(TestCase, _LoopbackMixin):
"""
Test for Next Protocol Negotiation in PyOpenSSL.
"""
def test_npn_success(self):
"""
Tests that clients and servers that agree on the negotiated next
protocol can correct establish a connection, and that the agreed
protocol is reported by the connections.
"""
advertise_args = []
select_args = []
def advertise(conn):
advertise_args.append((conn,))
return [b'http/1.1', b'spdy/2']
def select(conn, options):
select_args.append((conn, options))
return b'spdy/2'

server_context = Context(TLSv1_METHOD)
server_context.set_npn_advertise_callback(advertise)

client_context = Context(TLSv1_METHOD)
client_context.set_npn_select_callback(select)

# Necessary to actually accept the connection
server_context.use_privatekey(
load_privatekey(FILETYPE_PEM, server_key_pem))
server_context.use_certificate(
load_certificate(FILETYPE_PEM, server_cert_pem))

# Do a little connection to trigger the logic
server = Connection(server_context, None)
server.set_accept_state()

client = Connection(client_context, None)
client.set_connect_state()

self._interactInMemory(server, client)

self.assertEqual([(server,)], advertise_args)
self.assertEqual([(client, [b'http/1.1', b'spdy/2'])], select_args)

self.assertEqual(server.get_next_proto_negotiated(), b'spdy/2')
self.assertEqual(client.get_next_proto_negotiated(), b'spdy/2')


def test_npn_client_fail(self):
"""
Tests that when clients and servers cannot agree on what protocol to
use next that the TLS connection does not get established.
"""
advertise_args = []
select_args = []
def advertise(conn):
advertise_args.append((conn,))
return [b'http/1.1', b'spdy/2']
def select(conn, options):
select_args.append((conn, options))
return b''

server_context = Context(TLSv1_METHOD)
server_context.set_npn_advertise_callback(advertise)

client_context = Context(TLSv1_METHOD)
client_context.set_npn_select_callback(select)

# Necessary to actually accept the connection
server_context.use_privatekey(
load_privatekey(FILETYPE_PEM, server_key_pem))
server_context.use_certificate(
load_certificate(FILETYPE_PEM, server_cert_pem))

# Do a little connection to trigger the logic
server = Connection(server_context, None)
server.set_accept_state()

client = Connection(client_context, None)
client.set_connect_state()

# If the client doesn't return anything, the connection will fail.
self.assertRaises(Error, self._interactInMemory, server, client)

self.assertEqual([(server,)], advertise_args)
self.assertEqual([(client, [b'http/1.1', b'spdy/2'])], select_args)


def test_npn_select_error(self):
"""
Test that we can handle exceptions in the select callback. If select
fails it should be fatal to the connection.
"""
advertise_args = []
def advertise(conn):
advertise_args.append((conn,))
return [b'http/1.1', b'spdy/2']
def select(conn, options):
raise TypeError

server_context = Context(TLSv1_METHOD)
server_context.set_npn_advertise_callback(advertise)

client_context = Context(TLSv1_METHOD)
client_context.set_npn_select_callback(select)

# Necessary to actually accept the connection
server_context.use_privatekey(
load_privatekey(FILETYPE_PEM, server_key_pem))
server_context.use_certificate(
load_certificate(FILETYPE_PEM, server_cert_pem))

# Do a little connection to trigger the logic
server = Connection(server_context, None)
server.set_accept_state()

client = Connection(client_context, None)
client.set_connect_state()

# If the callback throws an exception it should be raised here.
self.assertRaises(TypeError, self._interactInMemory, server, client)
self.assertEqual([(server,)], advertise_args)


def test_npn_advertise_error(self):
"""
Test that we can handle exceptions in the advertise callback. If
advertise fails no NPN is advertised to the client.
"""
select_args = []
def advertise(conn):
raise TypeError
def select(conn, options):
select_args.append((conn, options))
return b''

server_context = Context(TLSv1_METHOD)
server_context.set_npn_advertise_callback(advertise)

client_context = Context(TLSv1_METHOD)
client_context.set_npn_select_callback(select)

# Necessary to actually accept the connection
server_context.use_privatekey(
load_privatekey(FILETYPE_PEM, server_key_pem))
server_context.use_certificate(
load_certificate(FILETYPE_PEM, server_cert_pem))

# Do a little connection to trigger the logic
server = Connection(server_context, None)
server.set_accept_state()

client = Connection(client_context, None)
client.set_connect_state()

# If the client doesn't return anything, the connection will fail.
self.assertRaises(TypeError, self._interactInMemory, server, client)
self.assertEqual([], select_args)



class SessionTests(TestCase):
"""
Expand Down
Loading