2323
2424
2525class XorMaskerSimple :
26- def __init__ (self , masking_key : bytes ) -> None :
26+ def __init__ (self , masking_key : bytearray | bytes ) -> None :
2727 self ._masking_key = masking_key
2828
29- def process (self , data : bytes ) -> bytes :
29+ def process (self , data : bytearray ) -> bytearray :
30+ data = bytearray (data )
3031 if data :
31- data_array = bytearray ( data )
32+ data_array = data
3233 a , b , c , d = (_XOR_TABLE [n ] for n in self ._masking_key )
3334 data_array [::4 ] = data_array [::4 ].translate (a )
3435 data_array [1 ::4 ] = data_array [1 ::4 ].translate (b )
@@ -42,12 +43,12 @@ def process(self, data: bytes) -> bytes:
4243 self ._masking_key [key_rotation :] + self ._masking_key [:key_rotation ]
4344 )
4445
45- return bytes ( data_array )
46+ return data_array
4647 return data
4748
4849
4950class XorMaskerNull :
50- def process (self , data : bytes ) -> bytes :
51+ def process (self , data : bytearray ) -> bytearray :
5152 return data
5253
5354
@@ -271,7 +272,7 @@ def consume_at_most(self, nbytes: int) -> bytearray:
271272 self .bytes_used += len (data )
272273 return data
273274
274- def consume_exactly (self , nbytes : int ) -> bytes | None :
275+ def consume_exactly (self , nbytes : int ) -> bytearray | None :
275276 if len (self .buffer ) - self .bytes_used < nbytes :
276277 return None
277278
@@ -370,11 +371,11 @@ def process_buffer(self) -> Frame | None:
370371 payload = self .masker .process (payload )
371372
372373 for extension in self .extensions :
373- payload_ = extension .frame_inbound_payload_data (self , payload )
374+ payload_ = extension .frame_inbound_payload_data (self , bytes ( payload ) )
374375 if isinstance (payload_ , CloseReason ):
375376 msg = "error in extension"
376377 raise ParseFailed (msg , payload_ )
377- payload = payload_
378+ payload = bytearray ( payload_ )
378379
379380 if finished :
380381 final = bytearray ()
@@ -387,7 +388,7 @@ def process_buffer(self) -> Frame | None:
387388 final += result
388389 payload += final
389390
390- frame = Frame (self .effective_opcode , payload , finished , self .header .fin )
391+ frame = Frame (self .effective_opcode , bytes ( payload ) , finished , self .header .fin )
391392
392393 if finished :
393394 self .header = None
@@ -585,7 +586,7 @@ def received_frames(self) -> Generator[Frame, None, None]:
585586 else :
586587 yield event
587588
588- def close (self , code : int | None = None , reason : str | None = None ) -> bytes :
589+ def close (self , code : int | None = None , reason : str | None = None ) -> bytearray :
589590 payload = bytearray ()
590591 if code is CloseReason .NO_STATUS_RCVD :
591592 code = None
@@ -603,23 +604,23 @@ def close(self, code: int | None = None, reason: str | None = None) -> bytes:
603604
604605 return self ._serialize_frame (Opcode .CLOSE , payload )
605606
606- def ping (self , payload : bytes = b"" ) -> bytes :
607+ def ping (self , payload : bytes = b"" ) -> bytearray :
607608 return self ._serialize_frame (Opcode .PING , payload )
608609
609- def pong (self , payload : bytes = b"" ) -> bytes :
610+ def pong (self , payload : bytes = b"" ) -> bytearray :
610611 return self ._serialize_frame (Opcode .PONG , payload )
611612
612613 def send_data (
613614 self , payload : bytes | bytearray | str = b"" , fin : bool = True ,
614- ) -> bytes :
615+ ) -> bytearray :
615616 if isinstance (payload , (bytes , bytearray , memoryview )):
616617 opcode = Opcode .BINARY
617618 elif isinstance (payload , str ):
618619 opcode = Opcode .TEXT
619620 payload = payload .encode ("utf-8" )
620621 else :
621622 msg = "Must provide bytes or text"
622- raise ValueError (msg )
623+ raise TypeError (msg )
623624
624625 if self ._outbound_opcode is None :
625626 self ._outbound_opcode = opcode
@@ -642,11 +643,13 @@ def _make_fin_rsv_opcode(self, fin: bool, rsv: RsvBits, opcode: Opcode) -> int:
642643 return fin_bits | rsv_bits | opcode_bits
643644
644645 def _serialize_frame (
645- self , opcode : Opcode , payload : bytes = b"" , fin : bool = True ,
646+ self , opcode : Opcode , payload : bytes | bytearray = b"" , fin : bool = True ,
646647 ) -> bytearray :
648+ payload = bytearray (payload )
649+
647650 rsv = RsvBits (False , False , False )
648651 for extension in reversed (self .extensions ):
649- rsv , payload = extension .frame_outbound (self , opcode , rsv , payload , fin )
652+ rsv , payload = extension .frame_outbound (self , opcode , rsv , bytes ( payload ) , fin )
650653
651654 fin_rsv_opcode = self ._make_fin_rsv_opcode (fin , rsv , opcode )
652655
@@ -688,8 +691,8 @@ def _serialize_frame(
688691 # authors of malicious applications from selecting the bytes that
689692 # appear on the wire."
690693 # -- https://tools.ietf.org/html/rfc6455#section-5.3
691- masking_key = os .urandom (4 )
694+ masking_key = bytearray ( os .urandom (4 ) )
692695 masker = XorMaskerSimple (masking_key )
693- return bytearray (header + masking_key + masker .process (payload ))
696+ return bytearray (header + masking_key + masker .process (bytearray ( payload ) ))
694697
695698 return bytearray (header + payload )
0 commit comments