Skip to content

Commit

Permalink
Support bytearray and memoryview as raw payload
Browse files Browse the repository at this point in the history
  • Loading branch information
alexforencich authored and gpotter2 committed Dec 20, 2020
1 parent 1ffa744 commit 2a8733a
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 7 deletions.
14 changes: 7 additions & 7 deletions scapy/packet.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,10 +356,10 @@ def add_payload(self, payload):
if t in payload.overload_fields:
self.overloaded_fields = payload.overload_fields[t]
break
elif isinstance(payload, bytes):
self.payload = conf.raw_layer(load=payload)
elif isinstance(payload, (bytes, str, bytearray, memoryview)):
self.payload = conf.raw_layer(load=bytes_encode(payload))
else:
raise TypeError("payload must be either 'Packet' or 'bytes', not [%s]" % repr(payload)) # noqa: E501
raise TypeError("payload must be 'Packet', 'bytes', 'str', 'bytearray', or 'memoryview', not [%s]" % repr(payload)) # noqa: E501

def remove_payload(self):
# type: () -> None
Expand Down Expand Up @@ -577,16 +577,16 @@ def __div__(self, other):
cloneB = other.copy()
cloneA.add_payload(cloneB)
return cloneA
elif isinstance(other, (bytes, str, bytearray)):
return self / conf.raw_layer(load=other)
elif isinstance(other, (bytes, str, bytearray, memoryview)):
return self / conf.raw_layer(load=bytes_encode(other))
else:
return other.__rdiv__(self) # type: ignore
__truediv__ = __div__

def __rdiv__(self, other):
# type: (Any) -> Packet
if isinstance(other, (bytes, str, bytearray)):
return conf.raw_layer(load=other) / self
if isinstance(other, (bytes, str, bytearray, memoryview)):
return conf.raw_layer(load=bytes_encode(other)) / self
else:
raise TypeError
__rtruediv__ = __rdiv__
Expand Down
12 changes: 12 additions & 0 deletions test/regression.uts
Original file line number Diff line number Diff line change
Expand Up @@ -1000,6 +1000,18 @@ a=3
assert bytes(Raw("sca")/"py") == b"scapy"
assert bytes(Raw("sca")/b"py") == b"scapy"
assert bytes(Raw("sca")/bytearray(b"py")) == b"scapy"
assert bytes("sca"/Raw("py")) == b"scapy"
assert bytes(b"sca"/Raw("py")) == b"scapy"
assert bytes(bytearray(b"sca")/Raw("py")) == b"scapy"
a=Raw("sca")
a.add_payload("py")
assert bytes(a) == b"scapy"
a=Raw("sca")
a.add_payload(b"py")
assert bytes(a) == b"scapy"
a=Raw("sca")
a.add_payload(bytearray(b"py"))
assert bytes(a) == b"scapy"

= Checking overloads
~ basic IP TCP Ether
Expand Down

0 comments on commit 2a8733a

Please sign in to comment.