Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Record fragmentation handling #46

Merged
merged 5 commits into from
Jul 27, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 45 additions & 2 deletions scapy_ssl_tls/ssl_tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,13 +263,18 @@ class TLSKexNames(object):
DHE = "DHE"
ECDHE = "ECDHE"

class TLSFragmentationError(Exception):
pass

class TLSRecord(StackedLenPacket):
MAX_LEN = 2**16 - 1
name = "TLS Record"
fields_desc = [ByteEnumField("content_type", TLSContentType.APPLICATION_DATA, TLS_CONTENT_TYPES),
XShortEnumField("version", TLSVersion.TLS_1_0, TLS_VERSIONS),
XLenField("length", None, fmt="!H"), ]

def __init__(self, *args, **fields):
self.fragments = []
try:
self.tls_ctx = fields["ctx"]
del(fields["ctx"])
Expand All @@ -291,6 +296,25 @@ def guess_payload_class(self, payload):
pass
return cls

def do_build(self):
"""
Taken as is from superclass. Just raises exception when payload can't fit in a TLSRecord
"""
if not self.explicit:
self = self.__iter__().next()
if len(self.payload) > TLSRecord.MAX_LEN:
raise TLSFragmentationError()
pkt = self.self_build()
for t in self.post_transforms:
pkt = t(pkt)
pay = self.do_build_payload()
p = self.post_build(pkt,pay)
return p

def fragment(self, size=2**14):
return tls_fragment_payload(self.payload, self, size)


class TLSServerName(PacketNoPadding):
name = "TLS Servername"
fields_desc = [ByteEnumField("type", 0x00, {0x00:"host"}),
Expand Down Expand Up @@ -725,7 +749,8 @@ def _is_listening(self, socket):
except socket.error as se:
# OSX and BSDs do not support ENOPROTOOPT. Linux and Windows seem to
if se.errno == errno.ENOPROTOOPT:
raise RuntimeError("OS does not support SO_ACCEPTCONN, cannot determine socket state. Please supply an explicit client value (True for client, False for server)")
raise RuntimeError("OS does not support SO_ACCEPTCONN, cannot determine socket state. Please supply an"
"explicit client value (True for client, False for server)")
else:
raise
return True if is_listening != 0 else False
Expand Down Expand Up @@ -759,6 +784,7 @@ def recvall(self, size=8192, timeout=0.5):
records = TLS("".join(resp), ctx=self.tls_ctx)
return records


# entry class
class SSL(Packet):
'''
Expand Down Expand Up @@ -926,6 +952,24 @@ def tls_do_handshake(tls_socket, version, ciphers):
tls_socket.sendall(to_raw(TLSFinished(), tls_socket.tls_ctx))
tls_socket.recvall()

def tls_fragment_payload(pkt, record=None, size=2**14):
if size <= 0:
raise ValueError("Fragment size must be strictly positive")
payload = str(pkt)
payloads = [payload[i: i+size] for i in range(0, len(payload), size)]
if record is None:
return payloads
else:
fragments = []
for payload in payloads:
fragments.append(TLSRecord(content_type=record.content_type, version=record.version, length=len(payload)) /
payload)
try:
stack = TLS.from_records(fragments, ctx=record.tls_ctx)
except struct.error as se:
raise TLSFragmentationError("Fragment size must be a power of 2: %s" % se)
return stack

# bind magic
bind_layers(TCP, SSL, dport=443)
bind_layers(TCP, SSL, sport=443)
Expand All @@ -937,7 +981,6 @@ def tls_do_handshake(tls_socket, version, ciphers):
bind_layers(TLSRecord, TLSCiphertext, {"content_type":TLSContentType.APPLICATION_DATA})
bind_layers(TLSRecord, TLSHeartBeat, {'content_type':TLSContentType.HEARTBEAT})
bind_layers(TLSRecord, TLSAlert, {'content_type':TLSContentType.ALERT})

bind_layers(TLSRecord, TLSHandshake, {'content_type':TLSContentType.HANDSHAKE})

# --> handshake proto
Expand Down
57 changes: 51 additions & 6 deletions tests/test_ssl_tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,48 @@ def test_pkt_built_from_stacked_tls_handshakes_is_identical(self):
# check TLS layers one by one
self.assertEqual(re.findall(r'<(TLS[\w]+)',str(repr(self.stacked_handshake))), ['TLSRecord', 'TLSHandshake', 'TLSServerHello',
'TLSHandshake', 'TLSCertificateList', 'TLSCertificate',
'TLSHandshake', 'TLSServerHelloDone'])

'TLSHandshake', 'TLSServerHelloDone'])

def test_fragmentation_fails_on_non_aligned_boundary_for_handshakes(self):
pkt = tls.TLSRecord()/tls.TLSHandshake()/tls.TLSClientHello()
with self.assertRaises(tls.TLSFragmentationError):
pkt.fragment(7)
self.assertIsInstance(pkt.fragment(8), tls.TLS)

def test_fragmenting_a_record_returns_a_list_of_records_when_fragment_size_is_smaller_than_record(self):
frag_size = 3
app_data = "A"*7
pkt = tls.TLSRecord(version=tls.TLSVersion.TLS_1_1, content_type=tls.TLSContentType.APPLICATION_DATA)/app_data
fragments = pkt.fragment(frag_size)
self.assertEqual(len(fragments.records), len(app_data) / frag_size + len(app_data) % frag_size)
record_length = len(tls.TLSRecord())
self.assertTrue(all(list(map(lambda x: x.haslayer(tls.TLSRecord), fragments.records))))
self.assertEqual(len(fragments.records[0]), record_length + frag_size)
self.assertEqual(len(fragments.records[1]), record_length + frag_size)
self.assertEqual(len(fragments.records[2]), record_length + (len(app_data) % frag_size))

def test_fragmenting_a_record_does_nothing_when_fragment_size_is_larger_than_record(self):
app_data = "A"*7
frag_size = len(app_data)
pkt = tls.TLSRecord(version=tls.TLSVersion.TLS_1_1, content_type=tls.TLSContentType.APPLICATION_DATA)/app_data
self.assertEqual(str(pkt), str(pkt.fragment(frag_size)))
frag_size = len(app_data) * 2
self.assertEqual(str(pkt), str(pkt.fragment(frag_size)))

def test_large_record_payload_is_not_fragmented_when_smaller_then_max_ushort(self):
app_data = "A"*tls.TLSRecord.MAX_LEN
pkt = tls.TLSRecord(version=tls.TLSVersion.TLS_1_1, content_type=tls.TLSContentType.APPLICATION_DATA)/app_data
try:
str(pkt)
except tls.TLSFragmentationError:
self.fail()

def test_large_record_payload_is_fragmented_when_above_max_ushort(self):
app_data = "A"*(tls.TLSRecord.MAX_LEN + 1)
pkt = tls.TLSRecord(version=tls.TLSVersion.TLS_1_1, content_type=tls.TLSContentType.APPLICATION_DATA)/app_data
with self.assertRaises(tls.TLSFragmentationError):
str(pkt)

class TestTLSDissector(unittest.TestCase):

def setUp(self):
Expand Down Expand Up @@ -232,7 +272,6 @@ def setUp(self):
],)
unittest.TestCase.setUp(self)


def test_dissect_contains_client_hello(self):
p = tls.SSL(str(self.pkt))
self.assertEqual(len(p.records),1)
Expand Down Expand Up @@ -407,7 +446,8 @@ def test_pcap_record_order(self):
# check if there are any more pakets?
with self.assertRaises(IndexError):
record = pkts.pop()



class TestToRaw(unittest.TestCase):

def setUp(self):
Expand Down Expand Up @@ -652,8 +692,13 @@ def test_tls_certificate_x509_pubkey(self):
self.assertTrue(len(ciphertext))
self.assertEqual(ciphertext,ciphertext_2)

class TestTLSKeyExchange(unittest.TestCase):
pass

class TestTLSTopLevelFunctions(unittest.TestCase):

def test_tls_payload_fragmentation_raises_error_with_negative_size(self):
with self.assertRaises(ValueError):
tls.tls_fragment_payload("AAAA", size=-1)


if __name__ == "__main__":
unittest.main()