Skip to content

Commit

Permalink
actually refactor ResponseMessage. still could use a typing person
Browse files Browse the repository at this point in the history
  • Loading branch information
toppk committed Nov 20, 2019
1 parent 724b165 commit 135ebfc
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 31 deletions.
48 changes: 33 additions & 15 deletions pysandra/protocol.py
Expand Up @@ -4,6 +4,7 @@

from .constants import (
Consitency,
ErrorCode,
Kind,
Opcode,
OptionID,
Expand Down Expand Up @@ -204,37 +205,50 @@ def encode_header(self, body_length: int) -> bytes:
class ResponseMessage(BaseMessage):
opcode: int

def __init__(
self, version: int = None, flags: int = None, stream_id: int = None
) -> None:
def __init__(self, version: int, flags: int, stream_id: int) -> None:
self.version = version
self.flags = flags
self.stream_id = stream_id

@staticmethod
def build(
version: int, flags: int, stream_id: int, body: "SBytes"
) -> "ResponseMessage":
raise InternalDriverError("subclass should implement build() method")


class ReadyMessage(ResponseMessage):
opcode = Opcode.READY

@staticmethod
def build(version: int, flags: int, body: "SBytes") -> "ReadyMessage":
def build(
version: int, flags: int, stream_id: int, body: "SBytes"
) -> "ReadyMessage":
logger.debug(f"ReadyResponse body={body!r}")
return ReadyMessage(flags=flags)
return ReadyMessage(version, flags, stream_id)


class ErrorMessage(ResponseMessage):
opcode = Opcode.ERROR

def __init__(self, error_code: int, error_text: str, **kwargs: Any) -> None:
super().__init__(**kwargs)
def __init__(self, error_code: "ErrorCode", error_text: str, *args: Any) -> None:
super().__init__(*args)
self.error_code = error_code
self.error_text = error_text

@staticmethod
def build(version: int, flags: int, body: "SBytes") -> "ErrorMessage":
def build(
version: int, flags: int, stream_id: int, body: "SBytes"
) -> "ErrorMessage":
logger.debug(f"ErrorResponse body={body!r}")
error_code = decode_int(body)
code = decode_int(body)
try:
error_code = ErrorCode(code)
except ValueError:
raise InternalDriverError(f"unknown error code={code}")
logger.debug(f"ErrorMessage error_code={error_code:x}")
error_text = decode_string(body)
return ErrorMessage(flags=flags, error_code=error_code, error_text=error_text)
return ErrorMessage(error_code, error_text, version, flags, stream_id)


class ResultMessage(ResponseMessage):
Expand All @@ -245,12 +259,14 @@ def __init__(self, kind: int, *args: Any) -> None:
self.kind = kind

@staticmethod
def build(version: int, flags: int, body: "SBytes",) -> "ResultMessage":
def build(
version: int, flags: int, stream_id: int, body: "SBytes",
) -> "ResultMessage":
msg: Optional["ResultMessage"] = None
kind = decode_int(body)
logger.debug(f"ResultResponse kind={kind} body={body!r}")
if kind == Kind.VOID:
msg = VoidResultMessage(version, flags, kind)
msg = VoidResultMessage(kind, version, flags, stream_id)
elif kind == Kind.ROWS:
result_flags = decode_int(body)
column_count = decode_int(body)
Expand Down Expand Up @@ -278,7 +294,7 @@ def build(version: int, flags: int, body: "SBytes",) -> "ResultMessage":
elif length == 0:
cell = b""
rows.add(cell)
msg = RowsResultMessage(rows, version, flags, kind)
msg = RowsResultMessage(rows, kind, version, flags, stream_id)

elif kind == Kind.SET_KEYSPACE:
pass
Expand Down Expand Up @@ -349,7 +365,9 @@ def build(version: int, flags: int, body: "SBytes",) -> "ResultMessage":
# parse col_spec_i
for _col in range(result_columns_count):
pass
msg = PreparedResultMessage(query_id, col_specs, kind, version, flags)
msg = PreparedResultMessage(
query_id, col_specs, kind, version, flags, stream_id
)
elif kind == Kind.SCHEMA_CHANGE:
# <change_type>
try:
Expand Down Expand Up @@ -385,7 +403,7 @@ def build(version: int, flags: int, body: "SBytes",) -> "ResultMessage":
f"SCHEMA_CHANGE change_type={change_type} target={target} options={options}"
)
schema_change = SchemaChange(change_type, target, options)
msg = SchemaResultMessage(schema_change, kind, version, flags)
msg = SchemaResultMessage(schema_change, kind, version, flags, stream_id)
else:
raise UnknownPayloadException(f"RESULT message has unknown kind={kind}")
if msg is None:
Expand Down
32 changes: 16 additions & 16 deletions pysandra/v4protocol.py
@@ -1,4 +1,4 @@
from typing import Dict, Optional, Tuple
from typing import Callable, Dict, Optional, Tuple

from .constants import CQL_VERSION, SERVER_SENT, Opcode, Options
from .exceptions import (
Expand Down Expand Up @@ -85,46 +85,46 @@ def build_response(
request: "RequestMessage",
version: int,
flags: int,
stream: int,
stream_id: int,
opcode: int,
length: int,
body: bytes,
) -> "ExpectedResponses":
sbytes_body = SBytes(body)
response: Optional["ResponseMessage"] = None
factory: Optional[Callable] = None
if opcode == Opcode.ERROR:
response = ErrorMessage.build(version, flags, sbytes_body)
if not sbytes_body.at_end():
raise InternalDriverError(
f"ErrorMessage still data left remains={sbytes_body.show()!r}"
)
raise ServerError(
f'got error_code={response.error_code:x} with description="{response.error_text}"',
msg=response,
)
factory = ErrorMessage.build
elif opcode == Opcode.READY:
response = ReadyMessage.build(version, flags, sbytes_body)
factory = ReadyMessage.build
elif opcode == Opcode.AUTHENTICATE:
pass
elif opcode == Opcode.SUPPORTED:
pass
elif opcode == Opcode.RESULT:
response = ResultMessage.build(version, flags, sbytes_body)
factory = ResultMessage.build
elif opcode == Opcode.EVENT:
pass
elif opcode == Opcode.AUTH_CHALLENGE:
pass
elif opcode == Opcode.AUTH_SUCCESS:
pass
else:
raise UnknownPayloadException(f"unknown message opcode={opcode}")
if factory is None:
raise UnknownPayloadException(f"unhandled message opcode={opcode}")
response = factory(version, flags, stream_id, sbytes_body)
if response is None:
raise InternalDriverError(
f"didn't generate a response message for opcode={opcode}"
)

if not sbytes_body.at_end():
raise InternalDriverError(f"still data left remains={sbytes_body.show()!r}")
if opcode == Opcode.ERROR:
assert isinstance(response, ErrorMessage)
raise ServerError(
f'got error_code={response.error_code:x} with description="{response.error_text}"',
msg=response,
)

return self.respond(request, response)

def respond(
Expand Down

0 comments on commit 135ebfc

Please sign in to comment.