Skip to content

Commit

Permalink
Cache fields initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
Guillaume Valadon authored and Guillaume Valadon committed Oct 24, 2018
1 parent e64b261 commit de07974
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 6 deletions.
91 changes: 85 additions & 6 deletions scapy/packet.py
Expand Up @@ -18,12 +18,12 @@
import types

from scapy.fields import StrField, ConditionalField, Emph, PacketListField, \
BitField, MultiEnumField, EnumField, FlagsField
BitField, MultiEnumField, EnumField, FlagsField, MultipleTypeField
from scapy.config import conf, _version_checker
from scapy.consts import WINDOWS
from scapy.compat import raw, orb
from scapy.base_classes import BasePacket, Gen, SetGen, Packet_metaclass
from scapy.volatile import VolatileValue
from scapy.volatile import VolatileValue, RandField
from scapy.utils import import_hexcap, tex_escape, colgen, get_temp_file, \
issubtype, ContextManagerSubprocess, pretty_list
from scapy.error import Scapy_Exception, log_runtime
Expand Down Expand Up @@ -75,6 +75,11 @@ class Packet(six.with_metaclass(Packet_metaclass, BasePacket)):
payload_guess = []
show_indent = 1
show_summary = True
class_dont_cache = dict()
class_packetfields = dict()
class_default_fields = dict()
class_default_fields_ref = dict()
class_fieldtype = dict()

@classmethod
def from_hexcap(cls):
Expand Down Expand Up @@ -163,7 +168,11 @@ def init_fields(self):
"""
Initialize each fields of the fields_desc dict
"""
self.do_init_fields(self.fields_desc)

if self.class_dont_cache.get(self.__class__, False):
self.do_init_fields(self.fields_desc)
else:
self.do_init_cached_fields()

def do_init_fields(self, flist):
"""
Expand All @@ -175,6 +184,65 @@ def do_init_fields(self, flist):
if f.holds_packets:
self.packetfields.append(f)

def do_init_cached_fields(self):
"""
Initialize each fields of the fields_desc dict, or use the cached
fields information
"""

cls_name = self.__class__

# Build the fields information
if Packet.class_default_fields.get(cls_name, None) is None:
self.prepare_cached_fields(self.fields_desc)

# Use fields information from cache
if not Packet.class_default_fields.get(cls_name, None) is None:
self.default_fields = Packet.class_default_fields[cls_name]
self.fieldtype = Packet.class_fieldtype[cls_name]
self.packetfields = Packet.class_packetfields[cls_name]

# Deepcopy default references
for fname in Packet.class_default_fields_ref[cls_name]:
value = copy.deepcopy(self.default_fields[fname])
setattr(self, fname, value)

def prepare_cached_fields(self, flist):
"""
Prepare the cached fields of the fields_desc dict
"""

cls_name = self.__class__

# Fields cache initialization
if flist:
Packet.class_default_fields[cls_name] = dict()
Packet.class_default_fields_ref[cls_name] = list()
Packet.class_fieldtype[cls_name] = dict()
Packet.class_packetfields[cls_name] = list()

# Fields initialization
for f in flist:
if isinstance(f, MultipleTypeField):
del Packet.class_default_fields[cls_name]
del Packet.class_default_fields_ref[cls_name]
del Packet.class_fieldtype[cls_name]
del Packet.class_packetfields[cls_name]
self.class_dont_cache[cls_name] = True
self.do_init_fields(self.fields_desc)
break

tmp_copy = copy.deepcopy(f.default)
Packet.class_default_fields[cls_name][f.name] = tmp_copy
Packet.class_fieldtype[cls_name][f.name] = f
if f.holds_packets:
Packet.class_packetfields[cls_name].append(f)

# Remember references
if isinstance(f.default, (list, dict, set)) or \
isinstance(f.default, RandField):
Packet.class_default_fields_ref[cls_name].append(f.name)

def dissection_done(self, pkt):
"""DEV: will be called after a dissection is completed"""
self.post_dissection(pkt)
Expand Down Expand Up @@ -340,9 +408,15 @@ def __repr__(self):
if isinstance(f, ConditionalField) and not f._evalcond(self):
continue
if f.name in self.fields:
val = f.i2repr(self, self.fields[f.name])
fval = self.fields[f.name]
if isinstance(fval, (list, dict, set)) and len(fval) == 0:
continue
val = f.i2repr(self, fval)
elif f.name in self.overloaded_fields:
val = f.i2repr(self, self.overloaded_fields[f.name])
fover = self.overloaded_fields[f.name]
if isinstance(fover, (list, dict, set)) and len(fover) == 0:
continue
val = f.i2repr(self, fover)
else:
continue
if isinstance(f, Emph) or f in conf.emph:
Expand Down Expand Up @@ -1323,10 +1397,15 @@ def decode_payload_as(self, cls):
self.payload.dissection_done(pp)

def command(self):
"""Returns a string representing the command you have to type to obtain the same packet""" # noqa: E501
"""
Returns a string representing the command you have to type to
obtain the same packet
"""
f = []
for fn, fv in self.fields.items():
fld = self.get_field(fn)
if isinstance(fv, (list, dict, set)) and len(fv) == 0:
continue
if isinstance(fv, Packet):
fv = fv.command()
elif fld.islist and fld.holds_packets and isinstance(fv, list):
Expand Down
13 changes: 13 additions & 0 deletions test/regression.uts
Expand Up @@ -10506,6 +10506,19 @@ pkt = Test(raw(Test(Values=[0, 0, 0, 0, 1, 1, 1, 1])))
assert(pkt.BitCount == 8)
assert(pkt.ByteCount == 1)

= PacketListField

class TestPacket(Packet):
name = 'TestPacket'
fields_desc = [ PacketListField('list', [], 0) ]

a = TestPacket()
a.list.append(1)
assert(len(a.list) == 1)

b = TestPacket()
assert(len(b.list) == 0)

############
############
+ MPLS tests
Expand Down

0 comments on commit de07974

Please sign in to comment.