Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type annotations #101

Merged
merged 6 commits into from
May 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion katsdptelstate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,5 @@
import time as _time
__version__ = "0.0+unknown.{}".format(_time.strftime('%Y%m%d%H%M'))
else:
__version__ = _katversion.get_version(__path__[0])
__version__ = _katversion.get_version(__path__[0]) # type: ignore
# END VERSION CHECK
42 changes: 25 additions & 17 deletions katsdptelstate/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@

from abc import ABC, abstractmethod
import time
from typing import List, Tuple, Dict, Generator, BinaryIO, Iterable, Optional, Union

from .utils import KeyType
from .utils import KeyType, _PathType


class KeyUpdateBase:
Expand All @@ -36,7 +37,7 @@ class KeyUpdate(ABC, KeyUpdateBase):
be instantiated directly.
"""

def __init__(self, key: bytes, value: bytes):
def __init__(self, key: bytes, value: bytes) -> None:
self.key = key
self.value = value

Expand All @@ -49,7 +50,7 @@ def key_type(self) -> KeyType:
class MutableKeyUpdate(KeyUpdate):
"""Update notification for a mutable key."""

def __init__(self, key: bytes, value: bytes, timestamp: float):
def __init__(self, key: bytes, value: bytes, timestamp: float) -> None:
super().__init__(key, value)
self.timestamp = timestamp

Expand Down Expand Up @@ -89,35 +90,35 @@ class Backend(ABC):
"""

@abstractmethod
def load_from_file(self, file):
def load_from_file(self, file: Union[_PathType, BinaryIO]) -> int:
"""Implements :meth:`TelescopeState.load_from_file`."""

@abstractmethod
def __contains__(self, key):
def __contains__(self, key: bytes) -> bool:
"""Return if `key` is in the backend."""

@abstractmethod
def keys(self, filter):
def keys(self, filter: bytes) -> List[bytes]:
"""Return all keys matching `filter`.

The filter is a redis pattern. Backends might only support ``b'*'`` as
a filter.
"""

@abstractmethod
def delete(self, key):
def delete(self, key: bytes) -> None:
"""Delete a key (no-op if it does not exist)"""

@abstractmethod
def clear(self):
def clear(self) -> None:
"""Remove all keys"""

@abstractmethod
def key_type(self, key):
def key_type(self, key: bytes) -> Optional[KeyType]:
"""Get type of `key`, or ``None`` if it does not exist."""

@abstractmethod
def set_immutable(self, key, value):
def set_immutable(self, key: bytes, value: bytes) -> Optional[bytes]:
"""Set the value of an immutable key.

If the key already exists (and is immutable), returns the existing
Expand All @@ -130,7 +131,11 @@ def set_immutable(self, key, value):
"""

@abstractmethod
def get(self, key):
def get(self, key: bytes) -> Union[
Tuple[None, None],
Tuple[bytes, None],
Tuple[bytes, float],
Tuple[Dict[bytes, bytes], None]]:
"""Get the value and timestamp of a key.

The return value depends on the key type:
Expand All @@ -148,7 +153,7 @@ def get(self, key):
"""

@abstractmethod
def add_mutable(self, key, value, timestamp):
def add_mutable(self, key: bytes, value: bytes, timestamp: float) -> None:
"""Set a (value, timestamp) pair in a mutable key.

The `timestamp` will be a non-negative float value.
Expand All @@ -160,7 +165,7 @@ def add_mutable(self, key, value, timestamp):
"""

@abstractmethod
def set_indexed(self, key, sub_key, value):
def set_indexed(self, key: bytes, sub_key: bytes, value: bytes) -> Optional[bytes]:
"""Add value in an indexed immutable key.

If the sub-key already exists, returns the existing value and does not
Expand All @@ -173,7 +178,7 @@ def set_indexed(self, key, sub_key, value):
"""

@abstractmethod
def get_indexed(self, key, sub_key):
def get_indexed(self, key: bytes, sub_key: bytes) -> Optional[bytes]:
"""Get the value of an indexed immutable key.

Returns ``None`` if the key exists but the sub-key does not exist.
Expand All @@ -187,7 +192,8 @@ def get_indexed(self, key, sub_key):
"""

@abstractmethod
def get_range(self, key, start_time, end_time, include_previous, include_end):
def get_range(self, key: bytes, start_time: float, end_time: float,
include_previous: bool, include_end: bool) -> Optional[List[Tuple[bytes, float]]]:
"""Obtain a range of values from a mutable key.

If the key does not exist, returns None.
Expand All @@ -212,10 +218,11 @@ def get_range(self, key, start_time, end_time, include_previous, include_end):
"""

@abstractmethod
def dump(self, key):
def dump(self, key: bytes) -> Optional[bytes]:
"""Return a key in the same format as the Redis DUMP command, or None if not present."""

def monitor_keys(self, keys):
def monitor_keys(self, keys: Iterable[bytes]) \
-> Generator[Optional[KeyUpdateBase], Optional[float], None]:
"""Report changes to keys in `keys`.

Returns a generator. The first yield from the generator is a no-op.
Expand All @@ -233,5 +240,6 @@ def monitor_keys(self, keys):
# This is a valid but usually suboptimal implementation
timeout = yield None
while True:
assert timeout is not None
time.sleep(timeout)
timeout = yield KeyUpdateBase()
27 changes: 14 additions & 13 deletions katsdptelstate/encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import threading
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought this file looked new... great idea (wrong PR, I know)!

import warnings
import pickle
from typing import Optional, Any

import msgpack
import numpy as np
Expand Down Expand Up @@ -72,7 +73,7 @@
_pickle_loads = functools.partial(pickle.loads, encoding='latin1')


def set_allow_pickle(allow, warn=False):
def set_allow_pickle(allow: bool, warn: bool = False) -> None:
"""Control whether pickles are allowed.

This overrides the defaults which are determined from the environment.
Expand All @@ -92,7 +93,7 @@ def set_allow_pickle(allow, warn=False):
_warn_on_pickle = warn


def _init_allow_pickle():
def _init_allow_pickle() -> None:
env = os.environ.get('KATSDPTELSTATE_ALLOW_PICKLE')
allow = False
if env == '1':
Expand All @@ -107,7 +108,7 @@ def _init_allow_pickle():
_init_allow_pickle()


def _encode_ndarray(value):
def _encode_ndarray(value: np.ndarray) -> bytes:
fp = io.BytesIO()
try:
np.save(fp, value, allow_pickle=False)
Expand All @@ -117,23 +118,23 @@ def _encode_ndarray(value):
return fp.getvalue()


def _decode_ndarray(data):
def _decode_ndarray(data: bytes) -> np.ndarray:
fp = io.BytesIO(data)
try:
return np.load(fp, allow_pickle=False)
except ValueError as error:
raise DecodeError(str(error))


def _encode_numpy_scalar(value):
def _encode_numpy_scalar(value: np.generic) -> bytes:
if value.dtype.hasobject:
raise EncodeError('cannot encode dtype {} as it contains objects'
.format(value.dtype))
descr = np.lib.format.dtype_to_descr(value.dtype)
return _msgpack_encode(descr) + value.tobytes()


def _decode_numpy_scalar(data):
def _decode_numpy_scalar(data: bytes) -> np.generic:
try:
descr = _msgpack_decode(data)
raw = b''
Expand All @@ -148,7 +149,7 @@ def _decode_numpy_scalar(data):
return value[0]


def _msgpack_default(value):
def _msgpack_default(value: Any) -> msgpack.ExtType:
if isinstance(value, tuple):
return msgpack.ExtType(MSGPACK_EXT_TUPLE, _msgpack_encode(list(value)))
elif isinstance(value, np.ndarray):
Expand All @@ -163,7 +164,7 @@ def _msgpack_default(value):
.format(value.__class__.__name__))


def _msgpack_ext_hook(code, data):
def _msgpack_ext_hook(code: int, data: bytes) -> Any:
if code == MSGPACK_EXT_TUPLE:
content = _msgpack_decode(data)
if not isinstance(content, list):
Expand All @@ -181,12 +182,12 @@ def _msgpack_ext_hook(code, data):
raise DecodeError('unknown extension type {}'.format(code))


def _msgpack_encode(value):
def _msgpack_encode(value: Any) -> bytes:
return msgpack.packb(value, use_bin_type=True, strict_types=True,
default=_msgpack_default)


def _msgpack_decode(value):
def _msgpack_decode(value: bytes) -> Any:
# The max_*_len prevent a corrupted or malicious input from consuming
# memory significantly in excess of the input size before it determines
# that there isn't actually enough data to back it.
Expand All @@ -199,7 +200,7 @@ def _msgpack_decode(value):
max_ext_len=max_len)


def encode_value(value, encoding=ENCODING_DEFAULT):
def encode_value(value: Any, encoding: bytes = ENCODING_DEFAULT) -> bytes:
"""Encode a value to a byte array for storage in redis.

Parameters
Expand All @@ -224,7 +225,7 @@ def encode_value(value, encoding=ENCODING_DEFAULT):
raise ValueError('Unknown encoding {:#x}'.format(ord(encoding)))


def decode_value(value, allow_pickle=None):
def decode_value(value: bytes, allow_pickle: Optional[bool] = None) -> Any:
"""Decode a value encoded with :func:`encode_value`.

The encoded value is self-describing, so it is not necessary to specify
Expand Down Expand Up @@ -277,7 +278,7 @@ def decode_value(value, allow_pickle=None):
'(katsdptelstate may need to be updated)'.format(value[:1]))


def equal_encoded_values(a, b):
def equal_encoded_values(a: bytes, b: bytes) -> bool:
"""Test whether two encoded values represent the same/equivalent objects.

This is not a complete implementation. Mostly, it just checks that the
Expand Down