Skip to content

Commit

Permalink
Merge pull request #5 from tomato42/record-socket-refactor-5
Browse files Browse the repository at this point in the history
[v5] Small TLSRecordLayer refactor

Prepare TLSRecordLayer for moving encryption and decryption to external class.

unit test coverage for sending and receiving records through socket, minimal test coverage for TLSRecordLayer (using private methods).
  • Loading branch information
tomato42 committed Jun 2, 2015
2 parents d3a2493 + 7e4b64e commit 7395199
Show file tree
Hide file tree
Showing 8 changed files with 1,047 additions and 137 deletions.
27 changes: 25 additions & 2 deletions tlslite/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,15 @@

from .constants import AlertDescription, AlertLevel

class TLSError(Exception):
class BaseTLSException(Exception):
"""Metaclass for TLS Lite exceptions.
Look to L{TLSError} for exceptions that should be caught by tlslite
consumers
"""
pass

class TLSError(BaseTLSException):
"""Base class for all TLS Lite exceptions."""

def __str__(self):
Expand Down Expand Up @@ -173,5 +181,20 @@ class TLSUnsupportedError(TLSError):
pass

class TLSInternalError(TLSError):
"""The internal state of object is unexpected or invalid"""
"""The internal state of object is unexpected or invalid.
Caused by incorrect use of API.
"""
pass

class TLSProtocolException(BaseTLSException):
"""Exceptions used internally for handling errors in received messages"""
pass

class TLSIllegalParameterException(TLSProtocolException):
"""Parameters specified in message were incorrect or invalid"""
pass

class TLSRecordOverflow(TLSProtocolException):
"""The received record size was too big"""
pass
86 changes: 60 additions & 26 deletions tlslite/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,38 +20,53 @@
from .utils.tackwrapper import *
from .extensions import *

class RecordHeader3(object):
def __init__(self):
class RecordHeader(object):

"""Generic interface to SSLv2 and SSLv3 (and later) record headers"""

def __init__(self, ssl2):
"""define instance variables"""
self.type = 0
self.version = (0,0)
self.version = (0, 0)
self.length = 0
self.ssl2 = False
self.ssl2 = ssl2

class RecordHeader3(RecordHeader):

"""SSLv3 (and later) TLS record header"""

def __init__(self):
"""Define a SSLv3 style class"""
super(RecordHeader3, self).__init__(ssl2=False)

def create(self, version, type, length):
"""Set object values for writing (serialisation)"""
self.type = type
self.version = version
self.length = length
return self

def write(self):
w = Writer()
w.add(self.type, 1)
w.add(self.version[0], 1)
w.add(self.version[1], 1)
w.add(self.length, 2)
return w.bytes

def parse(self, p):
self.type = p.get(1)
self.version = (p.get(1), p.get(1))
self.length = p.get(2)
"""Serialise object to bytearray"""
writer = Writer()
writer.add(self.type, 1)
writer.add(self.version[0], 1)
writer.add(self.version[1], 1)
writer.add(self.length, 2)
return writer.bytes

def parse(self, parser):
"""Deserialise object from Parser"""
self.type = parser.get(1)
self.version = (parser.get(1), parser.get(1))
self.length = parser.get(2)
self.ssl2 = False
return self

@property
def type_name(self):
matching = [x[0] for x in ContentType.__dict__.items()
if x[1] == self.type]
if x[1] == self.type]
if len(matching) == 0:
return "unknown(" + str(self.type) + ")"
else:
Expand All @@ -66,22 +81,41 @@ def __repr__(self):
return "RecordHeader3(type={0}, version=({1[0]}.{1[1]}), length={2})".\
format(self.type, self.version, self.length)

class RecordHeader2(object):
class RecordHeader2(RecordHeader):
"""SSLv2 record header (just reading)"""
def __init__(self):
self.type = 0
self.version = (0,0)
self.length = 0
self.ssl2 = True
"""Define a SSLv2 style class"""
super(RecordHeader2, self).__init__(ssl2=True)

def parse(self, p):
if p.get(1)!=128:
def parse(self, parser):
"""Deserialise object from Parser"""
if parser.get(1) != 128:
raise SyntaxError()
self.type = ContentType.handshake
self.version = (2,0)
#We don't support 2-byte-length-headers; could be a problem
self.length = p.get(1)
self.version = (2, 0)
#XXX We don't support 2-byte-length-headers; could be a problem
self.length = parser.get(1)
return self

class Message(object):

"""Generic TLS message"""

def __init__(self, contentType, data):
"""
Initialize object with specified contentType and data
@type contentType: int
@param contentType: TLS record layer content type of associated data
@type data: bytearray
@param data: data
"""
self.contentType = contentType
self.data = data

def write(self):
"""Return serialised object data"""
return self.data

class Alert(object):
def __init__(self):
Expand Down
196 changes: 196 additions & 0 deletions tlslite/recordlayer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
# Copyright (c) 2014, Hubert Kario
#
# See the LICENSE file for legal information regarding use of this file.

"""Implementation of the TLS Record Layer protocol"""

import socket
import errno
from tlslite.constants import ContentType
from .messages import RecordHeader3, RecordHeader2
from .utils.codec import Parser
from .errors import TLSRecordOverflow, TLSIllegalParameterException,\
TLSAbruptCloseError

class RecordSocket(object):

"""Socket wrapper for reading and writing TLS Records"""

def __init__(self, sock):
"""
Assign socket to wrapper
@type sock: socket.socket
"""
self.sock = sock
self.version = (0, 0)

def _sockSendAll(self, data):
"""
Send all data through socket
@type data: bytearray
@param data: data to send
@raise socket.error: when write to socket failed
"""
while 1:
try:
bytesSent = self.sock.send(data)
except socket.error as why:
if why.args[0] in (errno.EWOULDBLOCK, errno.EAGAIN):
yield 1
continue
raise

if bytesSent == len(data):
return
data = data[bytesSent:]
yield 1

def send(self, msg):
"""
Send the message through socket.
@type msg: bytearray
@param msg: TLS message to send
@raise socket.error: when write to socket failed
"""

data = msg.write()

header = RecordHeader3().create(self.version,
msg.contentType,
len(data))

data = header.write() + data

for result in self._sockSendAll(data):
yield result

def _sockRecvAll(self, length):
"""
Read exactly the amount of bytes specified in L{length} from raw socket.
@rtype: generator
@return: generator that will return 0 or 1 in case the socket is non
blocking and would block and bytearray in case the read finished
@raise TLSAbruptCloseError: when the socket closed
"""

buf = bytearray(0)

if length == 0:
yield buf

while True:
try:
socketBytes = self.sock.recv(length - len(buf))
except socket.error as why:
if why.args[0] in (errno.EWOULDBLOCK, errno.EAGAIN):
yield 0
continue
else:
raise

#if the connection closed, raise socket error
if len(socketBytes) == 0:
raise TLSAbruptCloseError()

buf += bytearray(socketBytes)
if len(buf) == length:
yield buf

def _recvHeader(self):
"""Read a single record header from socket"""
#Read the next record header
buf = bytearray(0)
ssl2 = False

result = None
for result in self._sockRecvAll(1):
if result in (0, 1):
yield result
else: break
assert result is not None

buf += result

if buf[0] in ContentType.all:
ssl2 = False
# SSLv3 record layer header is 5 bytes long, we already read 1
result = None
for result in self._sockRecvAll(4):
if result in (0, 1):
yield result
else: break
assert result is not None
buf += result
# XXX this should be 'buf[0] & 128', otherwise hello messages longer
# than 127 bytes won't be properly parsed
elif buf[0] == 128:
ssl2 = True
# in SSLv2 we need to read 2 bytes in total to know the size of
# header, we already read 1
result = None
for result in self._sockRecvAll(1):
if result in (0, 1):
yield result
else: break
assert result is not None
buf += result
else:
raise TLSIllegalParameterException(
"Record header type doesn't specify known type")

#Parse the record header
if ssl2:
record = RecordHeader2().parse(Parser(buf))
else:
record = RecordHeader3().parse(Parser(buf))

yield record

def recv(self):
"""
Read a single record from socket, handles both SSLv2 and SSLv3 record
layer
@rtype: generator
@return: generator that returns 0 or 1 in case the read would be
blocking or a tuple containing record header (object) and record
data (bytearray) read from socket
@raise socket.error: In case of network error
@raise TLSAbruptCloseError: When the socket was closed on the other
side in middle of record receiving
@raise TLSRecordOverflow: When the received record was longer than
allowed by TLS
@raise TLSIllegalParameterException: When the record header was
malformed
"""

record = None
for record in self._recvHeader():
if record in (0, 1):
yield record
else: break
assert record is not None

#Check the record header fields
# 18432 = 2**14 (basic record size limit) + 1024 (maximum compression
# overhead) + 1024 (maximum encryption overhead)
if record.length > 18432:
raise TLSRecordOverflow()

#Read the record contents
buf = bytearray(0)

result = None
for result in self._sockRecvAll(record.length):
if result in (0, 1):
yield result
else: break
assert result is not None

buf += result

yield (record, buf)
Loading

0 comments on commit 7395199

Please sign in to comment.