Skip to content

Commit

Permalink
Merge pull request #263 from tomato42/cookie-parser
Browse files Browse the repository at this point in the history
Cookie extension parser
  • Loading branch information
tomato42 committed Jun 1, 2018
2 parents 5f6110b + c1d10e9 commit 1ef5b0a
Show file tree
Hide file tree
Showing 4 changed files with 298 additions and 86 deletions.
225 changes: 140 additions & 85 deletions tlslite/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,28 +284,25 @@ def __repr__(self):
self.encExtType)


class ListExtension(TLSExtension):
class CustomNameExtension(TLSExtension):
"""
Abstract class for extensions that deal with single list in payload.
Abstract class for handling custom name for payload variable.
Extension for handling arbitrary extensions comprising of just a list
of same-sized elementes inside an array
"""
Used for handling arbitrary extensions that deal with a single value
in payload (either an opaque data, single value or single array).
def __init__(self, fieldName, extType, item_enum=None):
Must be subclassed.
"""
def __init__(self, field_name, extType):
"""
Create instance of the class.
:param str fieldName: name of the field to store the list that is
the payload
:type int extType: numerical ID of the extension
:param class item_enum: TLSEnum class that defines the enum of the
items in the list
:param str field_name: name of the field to store the extension payload
:param int ext_type: numerical ID of the extension
"""
super(ListExtension, self).__init__(extType=extType)
self._fieldName = fieldName
self._internalList = None
self._item_enum = item_enum
super(CustomNameExtension, self).__init__(extType=extType)
self._field_name = field_name
self._internal_value = None

@property
def extData(self):
Expand All @@ -322,7 +319,7 @@ def create(self, values):
:param list values: list of values to save
"""
self._internalList = values
self._internal_value = values
return self

def parse(self, parser):
Expand All @@ -336,36 +333,119 @@ def parse(self, parser):

def __getattr__(self, name):
"""Return the special field name value."""
if name == '_fieldName':
raise AttributeError("type object '{0}' has no attribute '{1}'"\
.format(self.__class__.__name__, name))
if name == self._fieldName:
return self._internalList
raise AttributeError("type object '{0}' has no attribute '{1}'"\
if name == '_field_name':
raise AttributeError(
"type object '{0}' has no attribute '{1}'"
.format(self.__class__.__name__, name))
if name == self._field_name:
return self._internal_value
raise AttributeError(
"type object '{0}' has no attribute '{1}'"
.format(self.__class__.__name__, name))

def __setattr__(self, name, value):
"""Set the special field value."""
if name == '_fieldName':
super(ListExtension, self).__setattr__(name, value)
return
if hasattr(self, '_fieldName') and name == self._fieldName:
self._internalList = value
if hasattr(self, '_field_name') and name == self._field_name:
self._internal_value = value
return
super(ListExtension, self).__setattr__(name, value)
super(CustomNameExtension, self).__setattr__(name, value)


class VarBytesExtension(CustomNameExtension):
"""
Abstract class for extension that deal with single byte array payload.
Extension for handling arbitrary extensions that comprise of just
a single bytearray of variable size.
"""

def __init__(self, field_name, length_length, ext_type):
"""
Crate instance of the class.
:param str field_name: name of the field to store the bytearray
that is the payload
:param int length_length: number of bytes needed to encode the length
field of the string
:param int ext_type: numerical ID of the extension
"""
super(VarBytesExtension, self).__init__(field_name, extType=ext_type)
self._length_length = length_length

@property
def extData(self):
"""
Return raw data encoding of the extension.
:rtype: bytearray
"""
if self._internal_value is None:
return bytearray(0)

writer = Writer()
writer.add_var_bytes(self._internal_value, self._length_length)
return writer.bytes

def parse(self, parser):
"""Deserialise extension from on-the-wire data.
:param tlslite.utils.codec.Parser parser: data
:rtype TLSExtension
"""
if not parser.getRemainingLength():
self._internal_value = None
return self

self._internal_value = parser.getVarBytes(self._length_length)

if parser.getRemainingLength():
raise SyntaxError("Extra data after extension payload")

return self

def __repr__(self):
"""Return human readable representation of the extension."""
if self._internal_value is not None:
return "{0}(len({1})={2})".format(self.__class__.__name__,
self._field_name,
len(self._internal_value))
return "{0}({1}=None)".format(self.__class__.__name__,
self._field_name)


class ListExtension(CustomNameExtension):
"""
Abstract class for extensions that deal with single list in payload.
Extension for handling arbitrary extensions comprising of just a list
of same-sized elementes inside an array
"""

def __init__(self, fieldName, extType, item_enum=None):
"""
Create instance of the class.
:param str fieldName: name of the field to store the list that is
the payload
:param int extType: numerical ID of the extension
:param class item_enum: TLSEnum class that defines the enum of the
items in the list
"""
super(ListExtension, self).__init__(fieldName, extType=extType)
self._item_enum = item_enum

def _list_to_repr(self):
"""Return human redable representation of the item list"""
if not self._internalList or not self._item_enum:
return "{0!r}".format(self._internalList)
if not self._internal_value or not self._item_enum:
return "{0!r}".format(self._internal_value)

return "[{0}]".format(
", ".join(self._item_enum.toStr(i) for i in self._internalList))
", ".join(self._item_enum.toStr(i) for i in self._internal_value))

def __repr__(self):
"""Return human readable representation of the extension."""
return "{0}({1}={2})".format(self.__class__.__name__,
self._fieldName,
self._field_name,
self._list_to_repr())


Expand Down Expand Up @@ -399,11 +479,11 @@ def extData(self):
:rtype: bytearray
"""
if self._internalList is None:
if self._internal_value is None:
return bytearray(0)

writer = Writer()
writer.addVarSeq(self._internalList,
writer.addVarSeq(self._internal_value,
self._elemLength,
self._lengthLength)
return writer.bytes
Expand All @@ -416,11 +496,11 @@ def parse(self, parser):
:rtype: Extension
"""
if parser.getRemainingLength() == 0:
self._internalList = None
self._internal_value = None
return self

self._internalList = parser.getVarList(self._elemLength,
self._lengthLength)
self._internal_value = parser.getVarList(self._elemLength,
self._lengthLength)

if parser.getRemainingLength():
raise SyntaxError()
Expand Down Expand Up @@ -463,11 +543,11 @@ def extData(self):
:rtype: bytearray
"""
if self._internalList is None:
if self._internal_value is None:
return bytearray(0)

writer = Writer()
writer.addVarTupleSeq(self._internalList,
writer.addVarTupleSeq(self._internal_value,
self._elemLength,
self._lengthLength)
return writer.bytes
Expand All @@ -480,12 +560,12 @@ def parse(self, parser):
:rtype: Extension
"""
if parser.getRemainingLength() == 0:
self._internalList = None
self._internal_value = None
return self

self._internalList = parser.getVarTupleList(self._elemLength,
self._elemNum,
self._lengthLength)
self._internal_value = parser.getVarTupleList(self._elemLength,
self._elemNum,
self._lengthLength)
if parser.getRemainingLength():
raise SyntaxError()

Expand Down Expand Up @@ -1389,7 +1469,8 @@ def parse(self, p):
self.paddingData = p.getFixBytes(p.getRemainingLength())
return self

class RenegotiationInfoExtension(TLSExtension):

class RenegotiationInfoExtension(VarBytesExtension):
"""
Client and Server Hello secure renegotiation extension from RFC 5746
Expand All @@ -1400,46 +1481,10 @@ class RenegotiationInfoExtension(TLSExtension):
def __init__(self):
"""Create instance"""
extType = ExtensionType.renegotiation_info
super(RenegotiationInfoExtension, self).__init__(extType=extType)
self.renegotiated_connection = None

@property
def extData(self):
"""
Return raw encoding of the extension.
:rtype: bytearray
"""
if self.renegotiated_connection is None:
return bytearray(0)
writer = Writer()
writer.add(len(self.renegotiated_connection), 1)
writer.bytes += self.renegotiated_connection
return writer.bytes

def create(self, renegotiated_connection):
"""
Set the finished message payload from previous connection.
:param bytearray renegotiated_connection: data
"""
self.renegotiated_connection = renegotiated_connection
return self

def parse(self, parser):
"""
Deserialise extension from on the wire data.
:param Parser parser: data to be parsed
:rtype: RenegotiationInfoExtension
"""
if parser.getRemainingLength() == 0:
self.renegotiated_connection = None
else:
self.renegotiated_connection = parser.getVarBytes(1)

return self
super(RenegotiationInfoExtension, self).__init__(
'renegotiated_connection',
1,
extType)


class ALPNExtension(TLSExtension):
Expand Down Expand Up @@ -2013,6 +2058,15 @@ def __init__(self):
PskKeyExchangeMode)


class CookieExtension(VarBytesExtension):
"""Handling of the TLS 1.3 cookie extension."""

def __init__(self):
"""Create instance."""
ext_type = ExtensionType.cookie
super(CookieExtension, self).__init__('cookie', 2, ext_type)


TLSExtension._universalExtensions = \
{
ExtensionType.server_name: SNIExtension,
Expand All @@ -2031,7 +2085,8 @@ def __init__(self):
ExtensionType.signature_algorithms_cert:
SignatureAlgorithmsCertExtension,
ExtensionType.pre_shared_key: PreSharedKeyExtension,
ExtensionType.psk_key_exchange_modes: PskKeyExchangeModesExtension}
ExtensionType.psk_key_exchange_modes: PskKeyExchangeModesExtension,
ExtensionType.cookie: CookieExtension}

TLSExtension._serverExtensions = \
{
Expand Down
18 changes: 18 additions & 0 deletions tlslite/utils/codec.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,22 @@ def addVarTupleSeq(self, seq, length, lengthLength):
if startPos + dataLength + lengthLength != len(self.bytes):
raise ValueError("Tuples of different lengths")

def add_var_bytes(self, data, length_length):
"""
Add a variable length array of bytes.
Inverse of Parser.getVarBytes()
:type data: bytes
:param data: bytes to add to the buffer
:param int length_length: size of the field to represent the length
of the data string
"""
length = len(data)
self.add(length, length_length)
self.bytes += data


class Parser(object):
"""
Expand Down Expand Up @@ -307,6 +323,8 @@ def getVarBytes(self, lengthLength):
"""
Read a variable length string with a fixed length.
see Writer.add_var_bytes() for an inverse of this method
:type lengthLength: int
:param lengthLength: number of bytes in which the length of the string
is encoded in
Expand Down
Loading

0 comments on commit 1ef5b0a

Please sign in to comment.