diff --git a/ChangeLog b/ChangeLog index 252f13c21..44a69442b 100644 --- a/ChangeLog +++ b/ChangeLog @@ -1,3 +1,9 @@ +2015-03-15 Jean-Paul Calderone + + * OpenSSL/SSL.py: Add ``Connection.recv_into``, mirroring the + builtin ``socket.recv_into``. Based on work from Cory Benfield. + * OpenSSL/test/test_ssl.py: Add tests for ``recv_into``. + 2015-01-30 Stephen Holsapple * OpenSSL/crypto.py: Expose ``X509StoreContext`` for verifying certificates. diff --git a/OpenSSL/SSL.py b/OpenSSL/SSL.py index 2731d64a8..e86d855a7 100644 --- a/OpenSSL/SSL.py +++ b/OpenSSL/SSL.py @@ -1030,6 +1030,45 @@ def recv(self, bufsiz, flags=None): read = recv + def recv_into(self, buffer, nbytes=None, flags=None): + """ + Receive data on the connection and store the data into a buffer rather + than creating a new string. + + :param buffer: The buffer to copy into. + :param nbytes: (optional) The maximum number of bytes to read into the + buffer. If not present, defaults to the size of the buffer. If + larger than the size of the buffer, is reduced to the size of the + buffer. + :param flags: (optional) Included for compatibility with the socket + API, the value is ignored. + :return: The number of bytes read into the buffer. + """ + if nbytes is None: + nbytes = len(buffer) + else: + nbytes = min(nbytes, len(buffer)) + + # We need to create a temporary buffer. This is annoying, it would be + # better if we could pass memoryviews straight into the SSL_read call, + # but right now we can't. Revisit this if CFFI gets that ability. + buf = _ffi.new("char[]", nbytes) + result = _lib.SSL_read(self._ssl, buf, nbytes) + self._raise_ssl_error(self._ssl, result) + + # This strange line is all to avoid a memory copy. The buffer protocol + # should allow us to assign a CFFI buffer to the LHS of this line, but + # on CPython 3.3+ that segfaults. As a workaround, we can temporarily + # wrap it in a memoryview, except on Python 2.6 which doesn't have a + # memoryview type. + try: + buffer[:result] = memoryview(_ffi.buffer(buf, result)) + except NameError: + buffer[:result] = _ffi.buffer(buf, result) + + return result + + def _handle_bio_errors(self, bio, result): if _lib.BIO_should_retry(bio): if _lib.BIO_should_read(bio): diff --git a/OpenSSL/test/test_ssl.py b/OpenSSL/test/test_ssl.py index f098327c7..aa07e1bfb 100644 --- a/OpenSSL/test/test_ssl.py +++ b/OpenSSL/test/test_ssl.py @@ -2268,6 +2268,164 @@ def test_short_buffer(self): +def _make_memoryview(size): + """ + Create a new ``memoryview`` wrapped around a ``bytearray`` of the given + size. + """ + return memoryview(bytearray(size)) + + + +class ConnectionRecvIntoTests(TestCase, _LoopbackMixin): + """ + Tests for :py:obj:`Connection.recv_into` + """ + def _no_length_test(self, factory): + """ + Assert that when the given buffer is passed to + ``Connection.recv_into``, whatever bytes are available to be received + that fit into that buffer are written into that buffer. + """ + output_buffer = factory(5) + + server, client = self._loopback() + server.send(b('xy')) + + self.assertEqual(client.recv_into(output_buffer), 2) + self.assertEqual(output_buffer, bytearray(b('xy\x00\x00\x00'))) + + + def test_bytearray_no_length(self): + """ + :py:obj:`Connection.recv_into` can be passed a ``bytearray`` instance + and data in the receive buffer is written to it. + """ + self._no_length_test(bytearray) + + + def _respects_length_test(self, factory): + """ + Assert that when the given buffer is passed to ``Connection.recv_into`` + along with a value for ``nbytes`` that is less than the size of that + buffer, only ``nbytes`` bytes are written into the buffer. + """ + output_buffer = factory(10) + + server, client = self._loopback() + server.send(b('abcdefghij')) + + self.assertEqual(client.recv_into(output_buffer, 5), 5) + self.assertEqual( + output_buffer, bytearray(b('abcde\x00\x00\x00\x00\x00')) + ) + + + def test_bytearray_respects_length(self): + """ + When called with a ``bytearray`` instance, + :py:obj:`Connection.recv_into` respects the ``nbytes`` parameter and + doesn't copy in more than that number of bytes. + """ + self._respects_length_test(bytearray) + + + def _doesnt_overfill_test(self, factory): + """ + Assert that if there are more bytes available to be read from the + receive buffer than would fit into the buffer passed to + :py:obj:`Connection.recv_into`, only as many as fit are written into + it. + """ + output_buffer = factory(5) + + server, client = self._loopback() + server.send(b('abcdefghij')) + + self.assertEqual(client.recv_into(output_buffer), 5) + self.assertEqual(output_buffer, bytearray(b('abcde'))) + rest = client.recv(5) + self.assertEqual(b('fghij'), rest) + + + def test_bytearray_doesnt_overfill(self): + """ + When called with a ``bytearray`` instance, + :py:obj:`Connection.recv_into` respects the size of the array and + doesn't write more bytes into it than will fit. + """ + self._doesnt_overfill_test(bytearray) + + + def _really_doesnt_overfill_test(self, factory): + """ + Assert that if the value given by ``nbytes`` is greater than the actual + size of the output buffer passed to :py:obj:`Connection.recv_into`, the + behavior is as if no value was given for ``nbytes`` at all. + """ + output_buffer = factory(5) + + server, client = self._loopback() + server.send(b('abcdefghij')) + + self.assertEqual(client.recv_into(output_buffer, 50), 5) + self.assertEqual(output_buffer, bytearray(b('abcde'))) + rest = client.recv(5) + self.assertEqual(b('fghij'), rest) + + + def test_bytearray_really_doesnt_overfill(self): + """ + When called with a ``bytearray`` instance and an ``nbytes`` value that + is too large, :py:obj:`Connection.recv_into` respects the size of the + array and not the ``nbytes`` value and doesn't write more bytes into + the buffer than will fit. + """ + self._doesnt_overfill_test(bytearray) + + + try: + memoryview + except NameError: + "cannot test recv_into memoryview without memoryview" + else: + def test_memoryview_no_length(self): + """ + :py:obj:`Connection.recv_into` can be passed a ``memoryview`` + instance and data in the receive buffer is written to it. + """ + self._no_length_test(_make_memoryview) + + + def test_memoryview_respects_length(self): + """ + When called with a ``memoryview`` instance, + :py:obj:`Connection.recv_into` respects the ``nbytes`` parameter + and doesn't copy more than that number of bytes in. + """ + self._respects_length_test(_make_memoryview) + + + def test_memoryview_doesnt_overfill(self): + """ + When called with a ``memoryview`` instance, + :py:obj:`Connection.recv_into` respects the size of the array and + doesn't write more bytes into it than will fit. + """ + self._doesnt_overfill_test(_make_memoryview) + + + def test_memoryview_really_doesnt_overfill(self): + """ + When called with a ``memoryview`` instance and an ``nbytes`` value + that is too large, :py:obj:`Connection.recv_into` respects the size + of the array and not the ``nbytes`` value and doesn't write more + bytes into the buffer than will fit. + """ + self._doesnt_overfill_test(_make_memoryview) + + + class ConnectionSendallTests(TestCase, _LoopbackMixin): """ Tests for :py:obj:`Connection.sendall`. diff --git a/doc/api/ssl.rst b/doc/api/ssl.rst index a75af1f7d..a3265b07f 100644 --- a/doc/api/ssl.rst +++ b/doc/api/ssl.rst @@ -614,6 +614,14 @@ Connection objects have the following methods: by *bufsize*. +.. py:method:: Connection.recv_into(buffer[, nbytes[, flags]]) + + Receive data from the Connection and copy it directly into the provided + buffer. The return value is the number of bytes read from the connection. + The maximum amount of data to be received at once is specified by *nbytes*. + *flags* is accepted for compatibility with ``socket.recv_into`` but its + value is ignored. + .. py:method:: Connection.bio_write(bytes) If the Connection was created with a memory BIO, this method can be used to add