Skip to content

Commit

Permalink
create core to store some utility classes. cleans up dispatcher a tad.
Browse files Browse the repository at this point in the history
  • Loading branch information
toppk committed Nov 23, 2019
1 parent cd491cf commit 81ac73e
Show file tree
Hide file tree
Showing 8 changed files with 197 additions and 147 deletions.
82 changes: 82 additions & 0 deletions pysandra/core.py
@@ -0,0 +1,82 @@
from typing import Dict, Generic, List, Optional, Type, TypeVar, cast

from .exceptions import InternalDriverError, MaximumStreamsException

# from .protocol import RequestMessage # noqa: F401


class SBytes(bytes):
_index: int = 0

def __new__(cls: Type[bytes], val: bytes) -> "SBytes":
return cast(SBytes, super().__new__(cls, val)) # type: ignore # https://github.com/python/typeshed/issues/2630

def hex(self) -> str:
return "0x" + super().hex()

def grab(self, count: int) -> bytes:
assert self._index is not None
if self._index + count > len(self):
raise InternalDriverError(
f"cannot go beyond {len(self)} count={count} index={self._index} sbytes={self!r}"
)
curindex = self._index
self._index += count
return self[curindex : curindex + count]

def at_end(self) -> bool:
return self._index == len(self)

@property
def remaining(self) -> bytes:
return self[self._index :]


T = TypeVar("T")


class Streams(Generic[T]):
def __init__(self) -> None:
self._last_stream_id: Optional[int] = None
self._streams: Dict[int, Optional[T]] = {}

def items(self) -> List[int]:
return list(self._streams.keys())

def remove(self, stream_id: int) -> T:
try:
store = self._streams.pop(stream_id)
assert store is not None
return store
except KeyError:
raise InternalDriverError(
f"stream_id={stream_id} is not open", stream_id=stream_id
)

def create(self) -> int:
maxstream = 2 ** 15
last_id = self._last_stream_id
if len(self._streams) > maxstream:
raise MaximumStreamsException(
f"too many streams last_id={last_id} length={len(self._streams)}"
)
next_id = 0x00
if last_id is not None:
next_id = last_id + 1
while True:
if next_id > maxstream:
next_id = 0x00
if next_id not in self._streams:
break
next_id = next_id + 1
# store will come in later
self._streams[next_id] = None
self._last_stream_id = next_id
return next_id

def update(self, stream_id: int, store: T) -> None:
if stream_id not in self._streams:
raise InternalDriverError(f"stream_id={stream_id} not being tracked")
if store is None:
raise InternalDriverError("cannot store empty request")
self._streams[stream_id] = store
76 changes: 11 additions & 65 deletions pysandra/dispatcher.py
Expand Up @@ -2,15 +2,11 @@
import ssl
import sys
import traceback
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Callable, Dict, Optional, Tuple, Union

from .constants import EVENT_STREAM_ID, Flags
from .exceptions import (
ConnectionDroppedError,
InternalDriverError,
MaximumStreamsException,
ServerError,
)
from .core import Streams
from .exceptions import ConnectionDroppedError, InternalDriverError, ServerError
from .protocol import ErrorMessage, Protocol, RequestMessage # noqa: F401
from .types import ExpectedResponses # noqa: F401
from .utils import get_logger
Expand All @@ -36,77 +32,27 @@ def __init__(
"asyncio.Event",
Union["ExpectedResponses", "InternalDriverError", "ServerError"],
] = {}
self._streams: Dict[
int, Optional[Tuple["RequestMessage", Callable, asyncio.Event]]
] = {}
self._streams: Streams[
Tuple["RequestMessage", Callable, asyncio.Event]
] = Streams()
self.decompress: Optional[Callable] = None
self._connected = False
self._running = False
self._last_stream_id: Optional[int] = None
self._writer: Optional["asyncio.StreamWriter"] = None
self._reader: Optional["asyncio.StreamReader"] = None
self._read_task: Optional["asyncio.Future"] = None

def _list_stream_ids(self) -> List[int]:
return list(self._streams.keys())

def _rm_stream_id(
self, stream_id: int
) -> Tuple["RequestMessage", Callable, asyncio.Event]:
try:
store = self._streams.pop(stream_id)
assert store is not None
return store
except KeyError:
raise InternalDriverError(
f"stream_id={stream_id} is not open", stream_id=stream_id
)

def _new_stream_id(self) -> int:
maxstream = 2 ** 15
last_id = self._last_stream_id
if last_id is None:
next_id = 0x00
elif len(self._streams) > maxstream:
raise MaximumStreamsException(
f"too many streams last_id={last_id} length={len(self._streams)}"
)
else:
next_id = last_id + 1
while True:
if next_id > maxstream:
next_id = 0x00
if next_id not in self._streams:
break
# print("cannot use %s" % next_id)
next_id = next_id + 1
if next_id is None:
raise InternalDriverError("next_id cannot be None")
# store will come in later
self._streams[next_id] = None
self._last_stream_id = next_id
return next_id

def _update_stream_id(
self, stream_id: int, store: Tuple["RequestMessage", Callable, asyncio.Event]
) -> None:
if stream_id not in self._streams:
raise InternalDriverError(f"stream_id={stream_id} not being tracked")
if store is None:
raise InternalDriverError("cannot store empty request")
self._streams[stream_id] = store

async def send(
self, request_handler: Callable, response_handler: Callable, params: dict = None
) -> "asyncio.Event":
if not self._connected:
await self._connect()

stream_id = self._new_stream_id()
stream_id = self._streams.create()
# should order compression
request = request_handler(stream_id, params)
event = asyncio.Event()
self._update_stream_id(stream_id, (request, response_handler, event))
self._streams.update(stream_id, (request, response_handler, event))
assert self._writer is not None
self._writer.write(bytes(request))
return event
Expand All @@ -131,7 +77,7 @@ async def _receive(self) -> None:
version, flags, stream_id, opcode, length, body
)
else:
request, response_handler, event = self._rm_stream_id(stream_id)
request, response_handler, event = self._streams.remove(stream_id)
# exceptions are stashed here (in the wrong task)
try:
self._data[event] = response_handler(
Expand Down Expand Up @@ -165,8 +111,8 @@ async def _listener(self) -> None:
# logger.warning(f" connection dropped, going to close")
self._running = False
# close all requests
for stream_id in self._list_stream_ids():
_req, _resp_handler, event = self._rm_stream_id(stream_id)
for stream_id in self._streams.items():
_req, _resp_handler, event = self._streams.remove(stream_id)
self._data[event] = e
event.set()
self._reader = None
Expand Down
3 changes: 2 additions & 1 deletion pysandra/protocol.py
Expand Up @@ -17,6 +17,7 @@
ResultFlags,
SchemaChangeTarget,
)
from .core import SBytes
from .exceptions import (
BadInputException,
InternalDriverError,
Expand All @@ -25,7 +26,7 @@
VersionMismatchException,
)
from .types import ChangeEvent, ExpectedResponses, Rows, SchemaChange, SchemaChangeType
from .utils import SBytes, get_logger
from .utils import get_logger

logger = get_logger(__name__)

Expand Down
29 changes: 1 addition & 28 deletions pysandra/utils.py
Expand Up @@ -2,7 +2,7 @@
import os
import sys
from struct import Struct
from typing import List, Type, cast
from typing import List

from .exceptions import InternalDriverError

Expand Down Expand Up @@ -76,30 +76,3 @@ def get_logger(name: str) -> logging.Logger:
logger.addHandler(handler)

return logging.getLogger(name)


class SBytes(bytes):
_index: int = 0

def __new__(cls: Type[bytes], val: bytes) -> "SBytes":
return cast(SBytes, super().__new__(cls, val)) # type: ignore # https://github.com/python/typeshed/issues/2630

def hex(self) -> str:
return "0x" + super().hex()

def grab(self, count: int) -> bytes:
assert self._index is not None
if self._index + count > len(self):
raise InternalDriverError(
f"cannot go beyond {len(self)} count={count} index={self._index} sbytes={self!r}"
)
curindex = self._index
self._index += count
return self[curindex : curindex + count]

def at_end(self) -> bool:
return self._index == len(self)

@property
def remaining(self) -> bytes:
return self[self._index :]
3 changes: 2 additions & 1 deletion pysandra/v4protocol.py
Expand Up @@ -2,6 +2,7 @@
from typing import Callable, Dict, Optional

from .constants import Opcode
from .core import SBytes
from .exceptions import ServerError # noqa: F401
from .exceptions import InternalDriverError, UnknownPayloadException
from .protocol import (
Expand All @@ -26,7 +27,7 @@
VoidResultMessage,
)
from .types import ExpectedResponses # noqa: F401
from .utils import SBytes, get_logger
from .utils import get_logger

logger = get_logger(__name__)

Expand Down
91 changes: 91 additions & 0 deletions tests/test_core.py
@@ -0,0 +1,91 @@
import pytest

from pysandra.core import SBytes, Streams
from pysandra.exceptions import InternalDriverError, MaximumStreamsException


def test_max_streams():
with pytest.raises(
MaximumStreamsException, match=r"too many streams last_id=31159 length=32769"
):
streams: Streams[int] = Streams()
move = 0
while True:
move += 1
stream_id = streams.create()
streams.update(stream_id, 3)
if (move % 19) == 0:
streams.remove(stream_id)


def test_streams_list():
streams: Streams[int] = Streams()
streams.create()
streams.create()
assert streams.items() == [0, 1]


def test_streams_update():
streams: Streams[int] = Streams()
stream_id = streams.create()
stream_id2 = streams.create()
streams.update(stream_id, "FOO")
streams.update(stream_id2, "BAR")
streams.remove(stream_id2)
assert streams.remove(stream_id) == "FOO"


def test_streams_update_fail_found():
with pytest.raises(InternalDriverError, match=r"not being tracked"):
streams: Streams[int] = Streams()
stream_id = streams.create()
streams.update(stream_id + 1, "FOO")


def test_streams_update_fail_null():
with pytest.raises(InternalDriverError, match=r"empty request"):
streams: Streams[int] = Streams()
stream_id = streams.create()
streams.update(stream_id, None)


def test_streams_error():
with pytest.raises(InternalDriverError, match=r"is not open"):
streams: Streams[int] = Streams()
stream_id = streams.create()
streams.remove(stream_id + 1)


def test_sbytes_at_end():
t = SBytes(b"12345")
print(f"{t.grab(1)!r}{t.at_end()}")
print(f"{t.grab(3)!r}{t.at_end()}")
print(f"{t.grab(1)!r}{t.at_end()}")
assert t.at_end()


def test_sbytes_hex():
t = SBytes(b"\x03\13\45")
assert t.hex() == "0x030b25"


def test_sbytes_remaining():
t = SBytes(b"\x03\13\45")
t.grab(2)

assert t.remaining == b"%"


def test_sbytes_not_end():
t = SBytes(b"12345")
print(f"{t.grab(1)!r}{t.at_end()}")
print(f"{t.grab(3)!r}{t.at_end()}")
assert not t.at_end()


def test_sbytes_overflow():
with pytest.raises(InternalDriverError, match=r"cannot go beyond"):
t = SBytes(b"12345")
print(f"{t.grab(1)!r}{t.at_end()}")
print(f"{t.grab(3)!r}{t.at_end()}")
print(f"{t.grab(2)!r}{t.at_end()}")
18 changes: 2 additions & 16 deletions tests/test_dispatcher.py
@@ -1,20 +1,6 @@
import pytest

from pysandra.dispatcher import Dispatcher
from pysandra.exceptions import MaximumStreamsException


def test_max_streams():
with pytest.raises(
MaximumStreamsException, match=r"too many streams last_id=31159 length=32769"
):
client = Dispatcher("blank", "", False, 0)
move = 0
while True:
move += 1
streamid = client._new_stream_id()
client._update_stream_id(
streamid, ("something", "else", "entirely"),
)
if (move % 19) == 0:
client._rm_stream_id(streamid)
d = Dispatcher("blank", "", False, 0)
print(f"({d})")

0 comments on commit 81ac73e

Please sign in to comment.