diff --git a/tlslite/errors.py b/tlslite/errors.py index 2c523176..35878987 100644 --- a/tlslite/errors.py +++ b/tlslite/errors.py @@ -14,7 +14,15 @@ from .constants import AlertDescription, AlertLevel -class TLSError(Exception): +class BaseTLSException(Exception): + """Metaclass for TLS Lite exceptions. + + Look to L{TLSError} for exceptions that should be caught by tlslite + consumers + """ + pass + +class TLSError(BaseTLSException): """Base class for all TLS Lite exceptions.""" def __str__(self): @@ -173,5 +181,20 @@ class TLSUnsupportedError(TLSError): pass class TLSInternalError(TLSError): - """The internal state of object is unexpected or invalid""" + """The internal state of object is unexpected or invalid. + + Caused by incorrect use of API. + """ + pass + +class TLSProtocolException(BaseTLSException): + """Exceptions used internally for handling errors in received messages""" + pass + +class TLSIllegalParameterException(TLSProtocolException): + """Parameters specified in message were incorrect or invalid""" + pass + +class TLSRecordOverflow(TLSProtocolException): + """The received record size was too big""" pass diff --git a/tlslite/messages.py b/tlslite/messages.py index 4bfb6b41..b9040d24 100644 --- a/tlslite/messages.py +++ b/tlslite/messages.py @@ -20,38 +20,53 @@ from .utils.tackwrapper import * from .extensions import * -class RecordHeader3(object): - def __init__(self): +class RecordHeader(object): + + """Generic interface to SSLv2 and SSLv3 (and later) record headers""" + + def __init__(self, ssl2): + """define instance variables""" self.type = 0 - self.version = (0,0) + self.version = (0, 0) self.length = 0 - self.ssl2 = False + self.ssl2 = ssl2 + +class RecordHeader3(RecordHeader): + + """SSLv3 (and later) TLS record header""" + + def __init__(self): + """Define a SSLv3 style class""" + super(RecordHeader3, self).__init__(ssl2=False) def create(self, version, type, length): + """Set object values for writing (serialisation)""" self.type = type self.version = version self.length = length return self def write(self): - w = Writer() - w.add(self.type, 1) - w.add(self.version[0], 1) - w.add(self.version[1], 1) - w.add(self.length, 2) - return w.bytes - - def parse(self, p): - self.type = p.get(1) - self.version = (p.get(1), p.get(1)) - self.length = p.get(2) + """Serialise object to bytearray""" + writer = Writer() + writer.add(self.type, 1) + writer.add(self.version[0], 1) + writer.add(self.version[1], 1) + writer.add(self.length, 2) + return writer.bytes + + def parse(self, parser): + """Deserialise object from Parser""" + self.type = parser.get(1) + self.version = (parser.get(1), parser.get(1)) + self.length = parser.get(2) self.ssl2 = False return self @property def type_name(self): matching = [x[0] for x in ContentType.__dict__.items() - if x[1] == self.type] + if x[1] == self.type] if len(matching) == 0: return "unknown(" + str(self.type) + ")" else: @@ -66,22 +81,41 @@ def __repr__(self): return "RecordHeader3(type={0}, version=({1[0]}.{1[1]}), length={2})".\ format(self.type, self.version, self.length) -class RecordHeader2(object): +class RecordHeader2(RecordHeader): + """SSLv2 record header (just reading)""" def __init__(self): - self.type = 0 - self.version = (0,0) - self.length = 0 - self.ssl2 = True + """Define a SSLv2 style class""" + super(RecordHeader2, self).__init__(ssl2=True) - def parse(self, p): - if p.get(1)!=128: + def parse(self, parser): + """Deserialise object from Parser""" + if parser.get(1) != 128: raise SyntaxError() self.type = ContentType.handshake - self.version = (2,0) - #We don't support 2-byte-length-headers; could be a problem - self.length = p.get(1) + self.version = (2, 0) + #XXX We don't support 2-byte-length-headers; could be a problem + self.length = parser.get(1) return self +class Message(object): + + """Generic TLS message""" + + def __init__(self, contentType, data): + """ + Initialize object with specified contentType and data + + @type contentType: int + @param contentType: TLS record layer content type of associated data + @type data: bytearray + @param data: data + """ + self.contentType = contentType + self.data = data + + def write(self): + """Return serialised object data""" + return self.data class Alert(object): def __init__(self): diff --git a/tlslite/recordlayer.py b/tlslite/recordlayer.py new file mode 100644 index 00000000..38ff0859 --- /dev/null +++ b/tlslite/recordlayer.py @@ -0,0 +1,196 @@ +# Copyright (c) 2014, Hubert Kario +# +# See the LICENSE file for legal information regarding use of this file. + +"""Implementation of the TLS Record Layer protocol""" + +import socket +import errno +from tlslite.constants import ContentType +from .messages import RecordHeader3, RecordHeader2 +from .utils.codec import Parser +from .errors import TLSRecordOverflow, TLSIllegalParameterException,\ + TLSAbruptCloseError + +class RecordSocket(object): + + """Socket wrapper for reading and writing TLS Records""" + + def __init__(self, sock): + """ + Assign socket to wrapper + + @type sock: socket.socket + """ + self.sock = sock + self.version = (0, 0) + + def _sockSendAll(self, data): + """ + Send all data through socket + + @type data: bytearray + @param data: data to send + @raise socket.error: when write to socket failed + """ + while 1: + try: + bytesSent = self.sock.send(data) + except socket.error as why: + if why.args[0] in (errno.EWOULDBLOCK, errno.EAGAIN): + yield 1 + continue + raise + + if bytesSent == len(data): + return + data = data[bytesSent:] + yield 1 + + def send(self, msg): + """ + Send the message through socket. + + @type msg: bytearray + @param msg: TLS message to send + @raise socket.error: when write to socket failed + """ + + data = msg.write() + + header = RecordHeader3().create(self.version, + msg.contentType, + len(data)) + + data = header.write() + data + + for result in self._sockSendAll(data): + yield result + + def _sockRecvAll(self, length): + """ + Read exactly the amount of bytes specified in L{length} from raw socket. + + @rtype: generator + @return: generator that will return 0 or 1 in case the socket is non + blocking and would block and bytearray in case the read finished + @raise TLSAbruptCloseError: when the socket closed + """ + + buf = bytearray(0) + + if length == 0: + yield buf + + while True: + try: + socketBytes = self.sock.recv(length - len(buf)) + except socket.error as why: + if why.args[0] in (errno.EWOULDBLOCK, errno.EAGAIN): + yield 0 + continue + else: + raise + + #if the connection closed, raise socket error + if len(socketBytes) == 0: + raise TLSAbruptCloseError() + + buf += bytearray(socketBytes) + if len(buf) == length: + yield buf + + def _recvHeader(self): + """Read a single record header from socket""" + #Read the next record header + buf = bytearray(0) + ssl2 = False + + result = None + for result in self._sockRecvAll(1): + if result in (0, 1): + yield result + else: break + assert result is not None + + buf += result + + if buf[0] in ContentType.all: + ssl2 = False + # SSLv3 record layer header is 5 bytes long, we already read 1 + result = None + for result in self._sockRecvAll(4): + if result in (0, 1): + yield result + else: break + assert result is not None + buf += result + # XXX this should be 'buf[0] & 128', otherwise hello messages longer + # than 127 bytes won't be properly parsed + elif buf[0] == 128: + ssl2 = True + # in SSLv2 we need to read 2 bytes in total to know the size of + # header, we already read 1 + result = None + for result in self._sockRecvAll(1): + if result in (0, 1): + yield result + else: break + assert result is not None + buf += result + else: + raise TLSIllegalParameterException( + "Record header type doesn't specify known type") + + #Parse the record header + if ssl2: + record = RecordHeader2().parse(Parser(buf)) + else: + record = RecordHeader3().parse(Parser(buf)) + + yield record + + def recv(self): + """ + Read a single record from socket, handles both SSLv2 and SSLv3 record + layer + + @rtype: generator + @return: generator that returns 0 or 1 in case the read would be + blocking or a tuple containing record header (object) and record + data (bytearray) read from socket + @raise socket.error: In case of network error + @raise TLSAbruptCloseError: When the socket was closed on the other + side in middle of record receiving + @raise TLSRecordOverflow: When the received record was longer than + allowed by TLS + @raise TLSIllegalParameterException: When the record header was + malformed + """ + + record = None + for record in self._recvHeader(): + if record in (0, 1): + yield record + else: break + assert record is not None + + #Check the record header fields + # 18432 = 2**14 (basic record size limit) + 1024 (maximum compression + # overhead) + 1024 (maximum encryption overhead) + if record.length > 18432: + raise TLSRecordOverflow() + + #Read the record contents + buf = bytearray(0) + + result = None + for result in self._sockRecvAll(record.length): + if result in (0, 1): + yield result + else: break + assert result is not None + + buf += result + + yield (record, buf) diff --git a/tlslite/tlsrecordlayer.py b/tlslite/tlsrecordlayer.py index ecabd561..7c188c19 100644 --- a/tlslite/tlsrecordlayer.py +++ b/tlslite/tlsrecordlayer.py @@ -19,9 +19,9 @@ from .mathtls import * from .constants import * from .utils.cryptomath import getRandomBytes +from .recordlayer import RecordSocket import socket -import errno import traceback class _ConnectionState(object): @@ -102,6 +102,7 @@ class TLSRecordLayer(object): def __init__(self, sock): self.sock = sock + self._recordSocket = RecordSocket(sock) #My session object (Session instance; read-only) self.session = None @@ -120,7 +121,7 @@ def __init__(self, sock): self._handshake_sha256 = hashlib.sha256() #TLS Protocol Version - self.version = (0,0) #read-only + self._version = (0, 0) #read-only self._versionCheck = False #Once we choose a version, this is True #Current and Pending connection states @@ -149,6 +150,23 @@ def __init__(self, sock): #Fault we will induce, for testing purposes self.fault = None + @property + def version(self): + """Get the SSL protocol version of connection""" + return self._version + + @version.setter + def version(self, value): + """ + Set the SSL protocol version of connection + + The setter is a public method only for backwards compatibility. + Don't use it! See at HandshakeSettings for options to set desired + protocol version. + """ + self._version = value + self._recordSocket.version = value + def clearReadBuffer(self): self._readBuffer = b'' @@ -610,53 +628,43 @@ def _sendMsg(self, msg, randomizeFirstBlock = True): b += macBytes b = self._writeState.encContext.encrypt(b) - #Add record header and send - r = RecordHeader3().create(self.version, contentType, len(b)) - s = r.write() + b - while 1: - try: - bytesSent = self.sock.send(s) #Might raise socket.error - except socket.error as why: - if why.args[0] in (errno.EWOULDBLOCK, errno.EAGAIN): - yield 1 - continue - else: - # The socket was unexpectedly closed. The tricky part - # is that there may be an alert sent by the other party - # sitting in the read buffer. So, if we get here after - # handshaking, we will just raise the error and let the - # caller read more data if it would like, thus stumbling - # upon the error. - # - # However, if we get here DURING handshaking, we take - # it upon ourselves to see if the next message is an - # Alert. - if contentType == ContentType.handshake: - - # See if there's an alert record - # Could raise socket.error or TLSAbruptCloseError - for result in self._getNextRecord(): - if result in (0,1): - yield result - - # Closes the socket - self._shutdown(False) - - # If we got an alert, raise it - recordHeader, p = result - if recordHeader.type == ContentType.alert: - alert = Alert().parse(p) - raise TLSRemoteAlert(alert) - else: - # If we got some other message who know what - # the remote side is doing, just go ahead and - # raise the socket.error - raise - if bytesSent == len(s): - return - s = s[bytesSent:] - yield 1 + msg = Message(contentType, b) + try: + for result in self._recordSocket.send(msg): + if result in (0, 1): + yield result + except socket.error: + # The socket was unexpectedly closed. The tricky part + # is that there may be an alert sent by the other party + # sitting in the read buffer. So, if we get here after + # handshaking, we will just raise the error and let the + # caller read more data if it would like, thus stumbling + # upon the error. + # + # However, if we get here DURING handshaking, we take + # it upon ourselves to see if the next message is an + # Alert. + if contentType == ContentType.handshake: + + # See if there's an alert record + # Could raise socket.error or TLSAbruptCloseError + for result in self._getNextRecord(): + if result in (0, 1): + yield result + + # Closes the socket + self._shutdown(False) + # If we got an alert, raise it + recordHeader, p = result + if recordHeader.type == ContentType.alert: + alert = Alert().parse(p) + raise TLSRemoteAlert(alert) + else: + # If we got some other message who know what + # the remote side is doing, just go ahead and + # raise the socket.error + raise def _getMsg(self, expectedType, secondaryType=None, constructorType=None): try: @@ -827,68 +835,18 @@ def _getNextRecord(self): yield (recordHeader, Parser(b)) return - #Otherwise... - #Read the next record header - b = bytearray(0) - recordHeaderLength = 1 - ssl2 = False - while 1: - try: - s = self.sock.recv(recordHeaderLength-len(b)) - except socket.error as why: - if why.args[0] in (errno.EWOULDBLOCK, errno.EAGAIN): - yield 0 - continue - else: - raise - - #If the connection was abruptly closed, raise an error - if len(s)==0: - raise TLSAbruptCloseError() - - b += bytearray(s) - if len(b)==1: - if b[0] in ContentType.all: - ssl2 = False - recordHeaderLength = 5 - elif b[0] == 128: - ssl2 = True - recordHeaderLength = 2 - else: - raise SyntaxError() - if len(b) == recordHeaderLength: - break - - #Parse the record header - if ssl2: - r = RecordHeader2().parse(Parser(b)) - else: - r = RecordHeader3().parse(Parser(b)) - - #Check the record header fields - if r.length > 18432: + try: + for result in self._recordSocket.recv(): + if result in (0, 1): + yield result + else: break + (r, b) = result + except TLSRecordOverflow: for result in self._sendError(AlertDescription.record_overflow): yield result - - #Read the record contents - b = bytearray(0) - while 1: - try: - s = self.sock.recv(r.length - len(b)) - except socket.error as why: - if why.args[0] in (errno.EWOULDBLOCK, errno.EAGAIN): - yield 0 - continue - else: - raise - - #If the connection is closed, raise a socket error - if len(s)==0: - raise TLSAbruptCloseError() - - b += bytearray(s) - if len(b) == r.length: - break + except TLSIllegalParameterException: + for result in self._sendError(AlertDescription.illegal_parameter): + yield result #Check the record header fields (2) #We do this after reading the contents from the socket, so that diff --git a/unit_tests/mocksock.py b/unit_tests/mocksock.py new file mode 100644 index 00000000..5e3d282c --- /dev/null +++ b/unit_tests/mocksock.py @@ -0,0 +1,93 @@ +# Copyright (c) 2015, Hubert Kario +# +# See the LICENSE file for legal information regarding use of this file. + +import socket +import errno +class MockSocket(socket.socket): + def __init__(self, buf, maxRet=None, maxWrite=None, blockEveryOther=False): + # current position in read buffer (buf) + self.index = 0 + # read buffer + self.buf = buf + # write buffer (data sent from application, to be asserted by test) + self.sent = [] + self.closed = False + # maximum number of bytes that socket will read/return at a time + self.maxRet = maxRet + # maximum number of bytes that socket will write at a time + self.maxWrite = maxWrite + # make socket rise errno.EWOULDBLOCK every other read or write + self.blockEveryOther = blockEveryOther + # if next read will be blocked + self.blockRead = False + # if next write will be blocked + self.blockWrite = False + + def __repr__(self): + return "MockSocket(index={0}, buf={1!r}, sent={2!r})".format( + self.index, self.buf, self.sent) + + def recv(self, size): + if self.closed: + raise ValueError("Read from closed socket") + + # simulate a socket with full buffers, make it rise "Would block" + # every other call + if self.blockEveryOther: + if self.blockRead: + self.blockRead = False + raise socket.error(errno.EWOULDBLOCK) + else: + self.blockRead = True + + # return empty array if the caller asked for no data + if size == 0: + return bytearray(0) + + # limit returned data (if set) + # this will cause the socket to return just maxRet bytes, even if it + # has more in buf or was asked to return more in this call + if self.maxRet is not None and self.maxRet < size: + size = self.maxRet + + # don't allow reading past array end + if len(self.buf[self.index:]) == 0: + raise socket.error(errno.EWOULDBLOCK) + # if asked for more than we have prepared, return just as much as we + # have + elif len(self.buf[self.index:]) < size: + ret = self.buf[self.index:] + self.index = len(self.buf) + return ret + # regular call, return as much as was asked for + else: + ret = self.buf[self.index:self.index+size] + self.index+=size + return ret + + def send(self, data): + if self.closed: + raise ValueError("Write to closed socket") + + # simulate a socket with full buffer, raise "Would Block" every other + # call + if self.blockEveryOther: + if self.blockWrite: + self.blockWrite = False + raise socket.error(errno.EWOULDBLOCK) + else: + self.blockWrite = True + + # regular write, just append to list of performed writes + if self.maxWrite is None or len(data) < self.maxWrite: + self.sent.append(data) + return len(data) + + # simulate a socket that won't write more data that it can + # (e.g. because the simulated buffers are mostly full) + self.sent.append(data[:self.maxWrite]) + return self.maxWrite + + def close(self): + self.closed = True diff --git a/unit_tests/test_tlslite_messages.py b/unit_tests/test_tlslite_messages.py index d7587601..ee4f6754 100644 --- a/unit_tests/test_tlslite_messages.py +++ b/unit_tests/test_tlslite_messages.py @@ -8,7 +8,7 @@ except ImportError: import unittest from tlslite.messages import ClientHello, ServerHello, RecordHeader3, Alert, \ - RecordHeader2 + RecordHeader2, Message from tlslite.utils.codec import Parser from tlslite.constants import CipherSuite, CertificateType, ContentType, \ AlertLevel, AlertDescription, ExtensionType @@ -16,6 +16,18 @@ SRPExtension, TLSExtension from tlslite.errors import TLSInternalError +class TestMessage(unittest.TestCase): + def test___init__(self): + msg = Message(ContentType.application_data, bytearray(0)) + + self.assertEqual(ContentType.application_data, msg.contentType) + self.assertEqual(bytearray(0), msg.data) + + def test_write(self): + msg = Message(0, bytearray(10)) + + self.assertEqual(bytearray(10), msg.write()) + class TestClientHello(unittest.TestCase): def test___init__(self): client_hello = ClientHello() diff --git a/unit_tests/test_tlslite_recordlayer.py b/unit_tests/test_tlslite_recordlayer.py new file mode 100644 index 00000000..34e9c9be --- /dev/null +++ b/unit_tests/test_tlslite_recordlayer.py @@ -0,0 +1,299 @@ +# Copyright (c) 2014, Hubert Kario +# +# See the LICENSE file for legal information regarding use of this file. + +# compatibility with Python 2.6, for that we need unittest2 package, +# which is not available on 3.3 or 3.4 +try: + import unittest2 as unittest +except ImportError: + import unittest +try: + import mock + from mock import call +except ImportError: + import unittest.mock as mock + from unittest.mock import call + +import socket +import errno + +from tlslite.messages import Message +from tlslite.recordlayer import RecordSocket +from tlslite.constants import ContentType +from unit_tests.mocksock import MockSocket +from tlslite.errors import TLSRecordOverflow, TLSIllegalParameterException,\ + TLSAbruptCloseError + +class TestRecordSocket(unittest.TestCase): + def test___init__(self): + sock = RecordSocket(-42) + + self.assertIsNotNone(sock) + self.assertEqual(sock.sock, -42) + self.assertEqual(sock.version, (0, 0)) + + def test_send(self): + mockSock = MockSocket(bytearray(0)) + sock = RecordSocket(mockSock) + sock.version = (3, 3) + + msg = Message(ContentType.handshake, bytearray(10)) + + for result in sock.send(msg): + if result in (0, 1): + self.assertTrue(False, "Blocking socket") + else: break + + self.assertEqual(len(mockSock.sent), 1) + self.assertEqual(bytearray( + b'\x16' + # handshake message + b'\x03\x03' + # version + b'\x00\x0a' + # payload length + b'\x00'*10 # payload + ), mockSock.sent[0]) + + def test_send_with_very_slow_socket(self): + mockSock = MockSocket(bytearray(0), maxWrite=1, blockEveryOther=True) + sock = RecordSocket(mockSock) + + msg = Message(ContentType.handshake, bytearray(b'\x32'*2)) + + gotRetry = False + for result in sock.send(msg): + if result in (0, 1): + gotRetry = True + else: break + + self.assertTrue(gotRetry) + self.assertEqual([ + bytearray(b'\x16'), # handshake message + bytearray(b'\x00'), bytearray(b'\x00'), # version (unset) + bytearray(b'\x00'), bytearray(b'\x02'), # payload length + bytearray(b'\x32'), bytearray(b'\x32')], + mockSock.sent) + + def test_send_with_errored_out_socket(self): + mockSock = mock.MagicMock() + mockSock.send.side_effect = socket.error(errno.ETIMEDOUT) + + sock = RecordSocket(mockSock) + + msg = Message(ContentType.handshake, bytearray(10)) + + gen = sock.send(msg) + + with self.assertRaises(socket.error): + next(gen) + + def test_recv(self): + mockSock = MockSocket(bytearray( + b'\x16' + # type - handshake + b'\x03\x03' + # TLSv1.2 + b'\x00\x04' + # length + b'\x00'*4 + )) + sock = RecordSocket(mockSock) + + for result in sock.recv(): + if result in (0, 1): + self.assertTrue(False, "blocking socket") + else: break + + header, data = result + + self.assertEqual(data, bytearray(4)) + self.assertEqual(header.type, ContentType.handshake) + self.assertEqual(header.version, (3, 3)) + self.assertEqual(header.length, 4) + + def test_recv_stops_itelf(self): + mockSock = MockSocket(bytearray( + b'\x16' + # type - handshake + b'\x03\x03' + # TLSv1.2 + b'\x00\x04' + # length + b'\x00'*4 + )) + sock = RecordSocket(mockSock) + + for result in sock.recv(): + if result in (0, 1): + self.assertTrue(False, "blocking socket") + + header, data = result + + self.assertEqual(data, bytearray(4)) + self.assertEqual(header.type, ContentType.handshake) + self.assertEqual(header.version, (3, 3)) + self.assertEqual(header.length, 4) + + def test_recv_with_trickling_socket(self): + mockSock = MockSocket(bytearray( + b'\x16' + # type - handshake + b'\x03\x03' + # TLSv1.2 + b'\x00\x04' + # length + b'\x00'*4 + ), maxRet=1) + + sock = RecordSocket(mockSock) + + for result in sock.recv(): + if result in (0, 1): + self.assertTrue(False, "blocking socket") + else: break + + header, data = result + + self.assertEqual(bytearray(4), data) + + def test_recv_with_blocking_socket(self): + mockSock = mock.MagicMock() + mockSock.recv.side_effect = socket.error(errno.EWOULDBLOCK) + + sock = RecordSocket(mockSock) + + gen = sock.recv() + + self.assertEqual(0, next(gen)) + + def test_recv_with_errored_out_socket(self): + mockSock = mock.MagicMock() + mockSock.recv.side_effect = socket.error(errno.ETIMEDOUT) + + sock = RecordSocket(mockSock) + + gen = sock.recv() + + with self.assertRaises(socket.error): + next(gen) + + def test_recv_with_empty_socket(self): + mockSock = mock.MagicMock() + mockSock.recv.side_effect = [bytearray(0)] + + sock = RecordSocket(mockSock) + + gen = sock.recv() + + with self.assertRaises(TLSAbruptCloseError): + next(gen) + + def test_recv_with_slow_socket(self): + mockSock = MockSocket(bytearray( + b'\x16' + # type - handshake + b'\x03\x03' + # TLSv1.2 + b'\x00\x04' + # length + b'\x00'*4 + ), maxRet=1, blockEveryOther=True) + + sock = RecordSocket(mockSock) + + gotRetry = False + for result in sock.recv(): + if result in (0, 1): + gotRetry = True + else: break + + header, data = result + + self.assertTrue(gotRetry) + self.assertEqual(bytearray(4), data) + + def test_recv_with_malformed_record(self): + mockSock = MockSocket(bytearray( + b'\x01' + # wrong type + b'\x03\x03' + # TLSv1.2 + b'\x00\x01' + # length + b'\x00')) + + sock = RecordSocket(mockSock) + + gen = sock.recv() + + with self.assertRaises(TLSIllegalParameterException): + next(gen) + + def test_recv_with_too_big_record(self): + mockSock = MockSocket(bytearray( + b'\x16' + # type - handshake + b'\x03\x03' + # TLSv1.2 + b'\xff\xff' + # length + b'\x00'*65536)) + + sock = RecordSocket(mockSock) + + gen = sock.recv() + + with self.assertRaises(TLSRecordOverflow): + next(gen) + + + def test_recv_with_empty_data(self): + mockSock = MockSocket(bytearray( + b'\x16' + # type - handshake + b'\x03\x03' + # TLSv1.2 + b'\x00\x00')) # length + + sock = RecordSocket(mockSock) + + gen = sock.recv() + + for result in sock.recv(): + if result in (0, 1): + self.assertTrue(False, "blocking socket") + else: break + + header, data = result + + self.assertEqual(ContentType.handshake, header.type) + self.assertEqual((3, 3), header.version) + self.assertEqual(0, header.length) + + self.assertEqual(bytearray(0), data) + + def test_recv_with_SSL2_record(self): + mockSock = MockSocket(bytearray( + b'\x80' + # tag + b'\x04' + # length + b'\x00'*4)) + + sock = RecordSocket(mockSock) + + for result in sock.recv(): + if result in (0, 1): + self.assertTrue(False, "blocking socket") + else: break + + header, data = result + + self.assertTrue(header.ssl2) + self.assertEqual(ContentType.handshake, header.type) + self.assertEqual(4, header.length) + self.assertEqual((2, 0), header.version) + + self.assertEqual(bytearray(4), data) + + def test_recv_with_not_complete_SSL2_record(self): + mockSock = MockSocket(bytearray( + b'\x80' + # tag + b'\x04' + # length + b'\x00'*3)) + + sock = RecordSocket(mockSock) + + for result in sock.recv(): + break + + self.assertEqual(0, result) + + def test_recv_with_SSL2_record_with_incomplete_header(self): + mockSock = MockSocket(bytearray( + b'\x80' # tag + )) + + sock = RecordSocket(mockSock) + + for result in sock.recv(): + break + + self.assertEqual(0, result) diff --git a/unit_tests/test_tlslite_tlsrecordlayer.py b/unit_tests/test_tlslite_tlsrecordlayer.py new file mode 100644 index 00000000..e311808c --- /dev/null +++ b/unit_tests/test_tlslite_tlsrecordlayer.py @@ -0,0 +1,295 @@ +# Copyright (c) 2014, Hubert Kario +# +# See the LICENSE file for legal information regarding use of this file. + +# compatibility with Python 2.6, for that we need unittest2 package, +# which is not available on 3.3 or 3.4 +try: + import unittest2 as unittest +except ImportError: + import unittest +try: + import mock + from mock import call +except ImportError: + import unittest.mock as mock + from unittest.mock import call + +import socket +import errno +from tlslite.tlsrecordlayer import TLSRecordLayer +from tlslite.constants import ContentType +from tlslite.errors import TLSAbruptCloseError, TLSLocalAlert +from tlslite.messages import Message +from unit_tests.mocksock import MockSocket + +class TestTLSRecordLayer(unittest.TestCase): + def test___init__(self): + record_layer = TLSRecordLayer(None) + + self.assertIsNotNone(record_layer) + self.assertIsInstance(record_layer, TLSRecordLayer) + + def test__getNextRecord(self): + mockSock = MockSocket(bytearray( + b'\x16' + # type - handshake + b'\x03\x03' + # TLSv1.2 + b'\x00\x04' + # length + b'\x00'*4 + )) + sock = TLSRecordLayer(mockSock) + + # XXX using private method! + for result in sock._getNextRecord(): + if result in (0, 1): + self.assertTrue(False, "blocking socket") + else: break + + header, data = result + data = data.bytes + + self.assertEqual(data, bytearray(4)) + self.assertEqual(header.type, ContentType.handshake) + self.assertEqual(header.version, (3, 3)) + self.assertEqual(header.length, 4) + + def test__getNextRecord_stops_itelf(self): + mockSock = MockSocket(bytearray( + b'\x16' + # type - handshake + b'\x03\x03' + # TLSv1.2 + b'\x00\x04' + # length + b'\x00'*4 + )) + sock = TLSRecordLayer(mockSock) + + # XXX using private method! + for result in sock._getNextRecord(): + if result in (0, 1): + self.assertTrue(False, "blocking socket") + + header, data = result + data = data.bytes + + self.assertEqual(data, bytearray(4)) + self.assertEqual(header.type, ContentType.handshake) + self.assertEqual(header.version, (3, 3)) + self.assertEqual(header.length, 4) + + def test__getNextRecord_with_trickling_socket(self): + mockSock = MockSocket(bytearray( + b'\x16' + # type - handshake + b'\x03\x03' + # TLSv1.2 + b'\x00\x04' + # length + b'\x00'*4 + ), maxRet=1) + + sock = TLSRecordLayer(mockSock) + + # XXX using private method! + for result in sock._getNextRecord(): + if result in (0, 1): + self.assertTrue(False, "blocking socket") + else: break + + header, data = result + data = data.bytes + + self.assertEqual(bytearray(4), data) + + def test__getNextRecord_with_blocking_socket(self): + mockSock = mock.MagicMock() + mockSock.recv.side_effect = socket.error(errno.EWOULDBLOCK) + + sock = TLSRecordLayer(mockSock) + + # XXX using private method! + gen = sock._getNextRecord() + + self.assertEqual(0, next(gen)) + + def test__getNextRecord_with_errored_out_socket(self): + mockSock = mock.MagicMock() + mockSock.recv.side_effect = socket.error(errno.ETIMEDOUT) + + sock = TLSRecordLayer(mockSock) + + # XXX using private method! + gen = sock._getNextRecord() + + with self.assertRaises(socket.error): + next(gen) + + def test__getNextRecord_with_empty_socket(self): + mockSock = mock.MagicMock() + mockSock.recv.side_effect = [bytearray(0)] + + sock = TLSRecordLayer(mockSock) + + # XXX using private method! + gen = sock._getNextRecord() + + with self.assertRaises(TLSAbruptCloseError): + next(gen) + + def test__getNextRecord_with_slow_socket(self): + mockSock = MockSocket(bytearray( + b'\x16' + # type - handshake + b'\x03\x03' + # TLSv1.2 + b'\x00\x04' + # length + b'\x00'*4 + ), maxRet=1, blockEveryOther=True) + + sock = TLSRecordLayer(mockSock) + + gotRetry = False + # XXX using private method! + for result in sock._getNextRecord(): + if result in (0, 1): + gotRetry = True + else: break + + header, data = result + data = data.bytes + + self.assertTrue(gotRetry) + self.assertEqual(bytearray(4), data) + + def test__getNextRecord_with_malformed_record(self): + mockSock = MockSocket(bytearray( + b'\x01' + # wrong type + b'\x03\x03' + # TLSv1.2 + b'\x00\x01' + # length + b'\x00')) + + sock = TLSRecordLayer(mockSock) + + # XXX using private method! + gen = sock._getNextRecord() + + with self.assertRaises(TLSLocalAlert) as context: + next(gen) + + self.assertEqual(str(context.exception), "illegal_parameter") + + def test__getNextRecord_with_too_big_record(self): + mockSock = MockSocket(bytearray( + b'\x16' + # type - handshake + b'\x03\x03' + # TLSv1.2 + b'\xff\xff' + # length + b'\x00'*65536)) + + sock = TLSRecordLayer(mockSock) + + # XXX using private method! + gen = sock._getNextRecord() + + with self.assertRaises(TLSLocalAlert) as context: + next(gen) + + self.assertEqual(str(context.exception), "record_overflow") + + def test__getNextRecord_with_SSL2_record(self): + mockSock = MockSocket(bytearray( + b'\x80' + # tag + b'\x04' + # length + b'\x00'*4)) + + sock = TLSRecordLayer(mockSock) + + # XXX using private method! + for result in sock._getNextRecord(): + if result in (0, 1): + self.assertTrue(False, "blocking socket") + else: break + + header, data = result + data = data.bytes + + self.assertTrue(header.ssl2) + self.assertEqual(ContentType.handshake, header.type) + self.assertEqual(4, header.length) + self.assertEqual((2, 0), header.version) + + self.assertEqual(bytearray(4), data) + + def test__getNextRecord_with_not_complete_SSL2_record(self): + mockSock = MockSocket(bytearray( + b'\x80' + # tag + b'\x04' + # length + b'\x00'*3)) + + sock = TLSRecordLayer(mockSock) + + # XXX using private method! + for result in sock._getNextRecord(): + break + + self.assertEqual(0, result) + + def test__getNextRecord_with_SSL2_record_with_incomplete_header(self): + mockSock = MockSocket(bytearray( + b'\x80' # tag + )) + + sock = TLSRecordLayer(mockSock) + + # XXX using private method + for result in sock._getNextRecord(): + break + + self.assertEqual(0, result) + + def test__sendMsg(self): + mockSock = MockSocket(bytearray(0)) + sock = TLSRecordLayer(mockSock) + sock.version = (3, 3) + + msg = Message(ContentType.handshake, bytearray(10)) + + # XXX using private method + for result in sock._sendMsg(msg, False): + if result in (0, 1): + self.assertTrue(False, "Blocking socket") + else: break + + self.assertEqual(len(mockSock.sent), 1) + self.assertEqual(bytearray( + b'\x16' + # handshake message + b'\x03\x03' + # version + b'\x00\x0a' + # payload length + b'\x00'*10 # payload + ), mockSock.sent[0]) + + def test__sendMsg_with_very_slow_socket(self): + mockSock = MockSocket(bytearray(0), maxWrite=1, blockEveryOther=True) + sock = TLSRecordLayer(mockSock) + + msg = Message(ContentType.handshake, bytearray(b'\x32'*2)) + + gotRetry = False + # XXX using private method! + for result in sock._sendMsg(msg, False): + if result in (0, 1): + gotRetry = True + else: break + + self.assertTrue(gotRetry) + self.assertEqual([ + bytearray(b'\x16'), # handshake message + bytearray(b'\x00'), bytearray(b'\x00'), # version (unset) + bytearray(b'\x00'), bytearray(b'\x02'), # payload length + bytearray(b'\x32'), bytearray(b'\x32')], + mockSock.sent) + + def test__sendMsg_with_errored_out_socket(self): + mockSock = mock.MagicMock() + mockSock.send.side_effect = socket.error(errno.ETIMEDOUT) + + sock = TLSRecordLayer(mockSock) + + msg = Message(ContentType.handshake, bytearray(10)) + + gen = sock._sendMsg(msg, False) + + with self.assertRaises(TLSAbruptCloseError): + next(gen)