Skip to content

Commit 230ac6b

Browse files
committed
improve type handling for bytearray vs. bytes
fixes #185
1 parent 379bf4c commit 230ac6b

File tree

5 files changed

+30
-30
lines changed

5 files changed

+30
-30
lines changed

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ Release History
1010
- Add docs for `wsproto.Connection`
1111
- Add support for Python 3.12, 3.13, and 3.14.
1212
- Drop support for Python 3.7, 3.8, and 3.9.
13+
- Improve Python typing, specifically bytes vs. bytearray.
1314
- Various linting, styling, and packaging improvements.
1415

1516
1.2.0 (2022-08-23)

src/wsproto/connection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def events(self) -> Generator[Event, None, None]:
191191
elif frame.opcode is Opcode.BINARY:
192192
assert isinstance(frame.payload, (bytes, bytearray))
193193
yield BytesMessage(
194-
data=frame.payload,
194+
data=bytearray(frame.payload),
195195
frame_finished=frame.frame_finished,
196196
message_finished=frame.message_finished,
197197
)

src/wsproto/events.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def response(self) -> CloseConnection:
193193
return CloseConnection(code=self.code, reason=self.reason)
194194

195195

196-
T = TypeVar("T", bytes, str)
196+
T = TypeVar("T", bytes, bytearray, str)
197197

198198

199199
@dataclass(frozen=True)
@@ -231,7 +231,7 @@ class Message(Event, Generic[T]):
231231
@dataclass(frozen=True)
232232
class TextMessage(Message[str]): # pylint: disable=unsubscriptable-object
233233
"""
234-
This event is fired when a data frame with TEXT payload is received.
234+
Fired when a data frame with TEXT payload is received.
235235
236236
Fields:
237237
@@ -240,31 +240,27 @@ class TextMessage(Message[str]): # pylint: disable=unsubscriptable-object
240240
The message data as string, This only represents a single chunk
241241
of data and not a full WebSocket message. You need to buffer
242242
and reassemble these chunks to get the full message.
243-
244243
"""
245244

246-
# https://github.com/python/mypy/issues/5744
247245
data: str
248246

249247

250248
@dataclass(frozen=True)
251-
class BytesMessage(Message[bytes]): # pylint: disable=unsubscriptable-object
249+
class BytesMessage(Message[bytearray]): # pylint: disable=unsubscriptable-object
252250
"""
253-
This event is fired when a data frame with BINARY payload is
254-
received.
251+
Fired when a data frame with BINARY payload is received.
255252
256253
Fields:
257254
258255
.. attribute:: data
259256
260-
The message data as byte string, can be decoded as UTF-8 for
257+
The message data as bytearray, can be decoded as UTF-8 for
261258
TEXT messages. This only represents a single chunk of data and
262259
not a full WebSocket message. You need to buffer and
263260
reassemble these chunks to get the full message.
264261
"""
265262

266-
# https://github.com/python/mypy/issues/5744
267-
data: bytes
263+
data: bytearray
268264

269265

270266
@dataclass(frozen=True)

src/wsproto/frame_protocol.py

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,13 @@
2323

2424

2525
class XorMaskerSimple:
26-
def __init__(self, masking_key: bytes) -> None:
26+
def __init__(self, masking_key: bytearray | bytes) -> None:
2727
self._masking_key = masking_key
2828

29-
def process(self, data: bytes) -> bytes:
29+
def process(self, data: bytearray) -> bytearray:
30+
data = bytearray(data)
3031
if data:
31-
data_array = bytearray(data)
32+
data_array = data
3233
a, b, c, d = (_XOR_TABLE[n] for n in self._masking_key)
3334
data_array[::4] = data_array[::4].translate(a)
3435
data_array[1::4] = data_array[1::4].translate(b)
@@ -42,12 +43,12 @@ def process(self, data: bytes) -> bytes:
4243
self._masking_key[key_rotation:] + self._masking_key[:key_rotation]
4344
)
4445

45-
return bytes(data_array)
46+
return data_array
4647
return data
4748

4849

4950
class XorMaskerNull:
50-
def process(self, data: bytes) -> bytes:
51+
def process(self, data: bytearray) -> bytearray:
5152
return data
5253

5354

@@ -271,7 +272,7 @@ def consume_at_most(self, nbytes: int) -> bytearray:
271272
self.bytes_used += len(data)
272273
return data
273274

274-
def consume_exactly(self, nbytes: int) -> bytes | None:
275+
def consume_exactly(self, nbytes: int) -> bytearray | None:
275276
if len(self.buffer) - self.bytes_used < nbytes:
276277
return None
277278

@@ -370,11 +371,11 @@ def process_buffer(self) -> Frame | None:
370371
payload = self.masker.process(payload)
371372

372373
for extension in self.extensions:
373-
payload_ = extension.frame_inbound_payload_data(self, payload)
374+
payload_ = extension.frame_inbound_payload_data(self, bytes(payload))
374375
if isinstance(payload_, CloseReason):
375376
msg = "error in extension"
376377
raise ParseFailed(msg, payload_)
377-
payload = payload_
378+
payload = bytearray(payload_)
378379

379380
if finished:
380381
final = bytearray()
@@ -387,7 +388,7 @@ def process_buffer(self) -> Frame | None:
387388
final += result
388389
payload += final
389390

390-
frame = Frame(self.effective_opcode, payload, finished, self.header.fin)
391+
frame = Frame(self.effective_opcode, bytes(payload), finished, self.header.fin)
391392

392393
if finished:
393394
self.header = None
@@ -585,7 +586,7 @@ def received_frames(self) -> Generator[Frame, None, None]:
585586
else:
586587
yield event
587588

588-
def close(self, code: int | None = None, reason: str | None = None) -> bytes:
589+
def close(self, code: int | None = None, reason: str | None = None) -> bytearray:
589590
payload = bytearray()
590591
if code is CloseReason.NO_STATUS_RCVD:
591592
code = None
@@ -603,23 +604,23 @@ def close(self, code: int | None = None, reason: str | None = None) -> bytes:
603604

604605
return self._serialize_frame(Opcode.CLOSE, payload)
605606

606-
def ping(self, payload: bytes = b"") -> bytes:
607+
def ping(self, payload: bytes = b"") -> bytearray:
607608
return self._serialize_frame(Opcode.PING, payload)
608609

609-
def pong(self, payload: bytes = b"") -> bytes:
610+
def pong(self, payload: bytes = b"") -> bytearray:
610611
return self._serialize_frame(Opcode.PONG, payload)
611612

612613
def send_data(
613614
self, payload: bytes | bytearray | str = b"", fin: bool = True,
614-
) -> bytes:
615+
) -> bytearray:
615616
if isinstance(payload, (bytes, bytearray, memoryview)):
616617
opcode = Opcode.BINARY
617618
elif isinstance(payload, str):
618619
opcode = Opcode.TEXT
619620
payload = payload.encode("utf-8")
620621
else:
621622
msg = "Must provide bytes or text"
622-
raise ValueError(msg)
623+
raise TypeError(msg)
623624

624625
if self._outbound_opcode is None:
625626
self._outbound_opcode = opcode
@@ -642,11 +643,13 @@ def _make_fin_rsv_opcode(self, fin: bool, rsv: RsvBits, opcode: Opcode) -> int:
642643
return fin_bits | rsv_bits | opcode_bits
643644

644645
def _serialize_frame(
645-
self, opcode: Opcode, payload: bytes = b"", fin: bool = True,
646+
self, opcode: Opcode, payload: bytes | bytearray = b"", fin: bool = True,
646647
) -> bytearray:
648+
payload = bytearray(payload)
649+
647650
rsv = RsvBits(False, False, False)
648651
for extension in reversed(self.extensions):
649-
rsv, payload = extension.frame_outbound(self, opcode, rsv, payload, fin)
652+
rsv, payload = extension.frame_outbound(self, opcode, rsv, bytes(payload), fin)
650653

651654
fin_rsv_opcode = self._make_fin_rsv_opcode(fin, rsv, opcode)
652655

@@ -688,8 +691,8 @@ def _serialize_frame(
688691
# authors of malicious applications from selecting the bytes that
689692
# appear on the wire."
690693
# -- https://tools.ietf.org/html/rfc6455#section-5.3
691-
masking_key = os.urandom(4)
694+
masking_key = bytearray(os.urandom(4))
692695
masker = XorMaskerSimple(masking_key)
693-
return bytearray(header + masking_key + masker.process(payload))
696+
return bytearray(header + masking_key + masker.process(bytearray(payload)))
694697

695698
return bytearray(header + payload)

tests/test_frame_protocol.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1197,7 +1197,7 @@ def test_data_we_have_no_idea_what_to_do_with(self) -> None:
11971197
proto = fp.FrameProtocol(client=False, extensions=[])
11981198
payload: Dict[str, str] = dict()
11991199

1200-
with pytest.raises(ValueError):
1200+
with pytest.raises(TypeError):
12011201
# Intentionally passing illegal type.
12021202
proto.send_data(payload) # type: ignore
12031203

0 commit comments

Comments
 (0)