diff --git a/wsproto/frame_protocol.py b/wsproto/frame_protocol.py index 7406eac..081f6ed 100644 --- a/wsproto/frame_protocol.py +++ b/wsproto/frame_protocol.py @@ -32,7 +32,25 @@ def process(self, data): # RFC6455, Section 5.2 - Base Framing Protocol -MAX_FRAME_PAYLOAD = 2 ** 64 + +# Payload length constants +PAYLOAD_LENGTH_TWO_BYTE = 126 +PAYLOAD_LENGTH_EIGHT_BYTE = 127 +MAX_PAYLOAD_NORMAL = 125 +MAX_PAYLOAD_TWO_BYTE = 2 ** 16 +MAX_PAYLOAD_EIGHT_BYTE = 2 ** 64 +MAX_FRAME_PAYLOAD = MAX_PAYLOAD_EIGHT_BYTE + +# MASK and PAYLOAD LEN are packed into a byte +MASK_MASK = 0x80 +PAYLOAD_LEN_MASK = 0x7f + +# FIN, RSV[123] and OPCODE are packed into a single byte +FIN_MASK = 0x80 +RSV1_MASK = 0x40 +RSV2_MASK = 0x20 +RSV3_MASK = 0x10 +OPCODE_MASK = 0x0f class Opcode(IntEnum): @@ -78,6 +96,17 @@ class CloseReason(IntEnum): ) +# RFC 6455, Section 7.4.2 - Status Code Ranges +MIN_CLOSE_REASON = 1000 +MIN_PROTOCOL_CLOSE_REASON = 1000 +MAX_PROTOCOL_CLOSE_REASON = 2999 +MIN_LIBRARY_CLOSE_REASON = 3000 +MAX_LIBRARY_CLOSE_REASON = 3999 +MIN_PRIVATE_CLOSE_REASON = 4000 +MAX_PRIVATE_CLOSE_REASON = 4999 +MAX_CLOSE_REASON = 4999 + + NULL_MASK = struct.pack("!I", 0) @@ -143,18 +172,18 @@ def _consume_exactly(self, nbytes): return (yield from self._consume_at_most(nbytes)) def _parse_extended_payload_length(self, opcode, payload_len): - if opcode.iscontrol() and payload_len > 125: + if opcode.iscontrol() and payload_len > MAX_PAYLOAD_NORMAL: raise ParseFailed("Control frame with payload len > 125") - if payload_len == 126: + if payload_len == PAYLOAD_LENGTH_TWO_BYTE: data = yield from self._consume_exactly(2) (payload_len,) = struct.unpack("!H", data) - if payload_len <= 125: + if payload_len <= MAX_PAYLOAD_NORMAL: raise ParseFailed( "Payload length used 2 bytes when 1 would have sufficed") - elif payload_len == 127: + elif payload_len == PAYLOAD_LENGTH_EIGHT_BYTE: data = yield from self._consume_exactly(8) (payload_len,) = struct.unpack("!Q", data) - if payload_len < 2 ** 16: + if payload_len < MAX_PAYLOAD_TWO_BYTE: raise ParseFailed( "Payload length used 8 bytes when 2 would have sufficed") if payload_len >> 63: @@ -167,11 +196,11 @@ def _parse_extended_payload_length(self, opcode, payload_len): def _parse_header(self): # returns a Header object (fin_rsv_opcode,) = yield from self._consume_exactly(1) - fin = bool(fin_rsv_opcode & 0x80) - rsv = (bool(fin_rsv_opcode & 0x40), - bool(fin_rsv_opcode & 0x20), - bool(fin_rsv_opcode & 0x10)) - opcode = fin_rsv_opcode & 0x0f + fin = bool(fin_rsv_opcode & FIN_MASK) + rsv = (bool(fin_rsv_opcode & RSV1_MASK), + bool(fin_rsv_opcode & RSV2_MASK), + bool(fin_rsv_opcode & RSV3_MASK)) + opcode = fin_rsv_opcode & OPCODE_MASK try: opcode = Opcode(opcode) except ValueError: @@ -181,8 +210,8 @@ def _parse_header(self): raise ParseFailed("Invalid attempt to fragment control frame") (mask_len,) = yield from self._consume_exactly(1) - has_mask = bool(mask_len & 0x80) - payload_len = mask_len & 0x7f + has_mask = bool(mask_len & MASK_MASK) + payload_len = mask_len & PAYLOAD_LEN_MASK payload_len = yield from self._parse_extended_payload_length( opcode, payload_len ) @@ -233,7 +262,7 @@ def _process_CLOSE_payload(self, data): raise ParseFailed("CLOSE with 1 byte payload") else: (code,) = struct.unpack("!H", data[:2]) - if code < 1000: + if code < MIN_CLOSE_REASON or code > MAX_CLOSE_REASON: raise ParseFailed("CLOSE with invalid code") try: code = CloseReason(code) @@ -242,7 +271,8 @@ def _process_CLOSE_payload(self, data): if code in LOCAL_ONLY_CLOSE_REASONS: raise ParseFailed( "remote CLOSE with local-only reason") - if not isinstance(code, CloseReason) and code < 3000: + if not isinstance(code, CloseReason) and \ + code <= MAX_PROTOCOL_CLOSE_REASON: raise ParseFailed( "CLOSE with unknown reserved code") try: @@ -366,7 +396,8 @@ def close(self, code=None, reason=None): if code is not None: payload += struct.pack('!H', code) if reason is not None: - payload += _truncate_utf8(reason.encode('utf-8'), 123) + payload += _truncate_utf8(reason.encode('utf-8'), + MAX_PAYLOAD_NORMAL - 2) return self._serialize_frame(Opcode.CLOSE, payload) @@ -410,14 +441,14 @@ def _serialize_frame(self, opcode, payload=b'', fin=True): payload_length = len(payload) quad_payload = False - if payload_length <= 125: + if payload_length <= MAX_PAYLOAD_NORMAL: first_payload = payload_length second_payload = None - elif payload_length <= 65535: - first_payload = 126 + elif payload_length <= MAX_PAYLOAD_TWO_BYTE: + first_payload = PAYLOAD_LENGTH_TWO_BYTE second_payload = payload_length else: - first_payload = 127 + first_payload = PAYLOAD_LENGTH_EIGHT_BYTE second_payload = payload_length quad_payload = True