Skip to content

Commit

Permalink
resolve static typing
Browse files Browse the repository at this point in the history
  • Loading branch information
devxpy committed Nov 27, 2018
1 parent a829b38 commit 159fb20
Show file tree
Hide file tree
Showing 20 changed files with 118 additions and 85 deletions.
1 change: 1 addition & 0 deletions Pipfile
Expand Up @@ -18,6 +18,7 @@ twine = "*"
pytest = "*"
yappi = "*"
pytest-repeat = "*"
mypy = "*"

[requires]
python_version = "3.6"
2 changes: 1 addition & 1 deletion benchmarks/mapper.py
@@ -1,7 +1,7 @@
import multiprocessing
from time import perf_counter

import zproc
import multiprocessing

ctx = zproc.Context()
ctx.workers.start(2)
Expand Down
2 changes: 0 additions & 2 deletions benchmarks/webpage_downloader.py
Expand Up @@ -7,8 +7,6 @@
import asyncio
from time import time

import grequests

import zproc

SAMPLES = 1
Expand Down
3 changes: 1 addition & 2 deletions docs/conf.py
Expand Up @@ -8,13 +8,12 @@

# -- Path setup --------------------------------------------------------------

import sys

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import os
import sys

sys.path.insert(0, os.path.abspath(".."))

Expand Down
1 change: 1 addition & 0 deletions examples/peanut_processor.py
Expand Up @@ -112,6 +112,7 @@
...
"""
from time import sleep

import zproc

ctx = zproc.Context()
Expand Down
2 changes: 2 additions & 0 deletions mypy.ini
@@ -0,0 +1,2 @@
[mypy]
ignore_missing_imports = True
5 changes: 3 additions & 2 deletions tests/test_atomic_contract.py
Expand Up @@ -31,7 +31,7 @@ def mutator(snap):

def test_signal_contract(ctx, state):
@zproc.atomic
def mutator(snap):
def atomic_fn(snap):
snap["x"] = 5
time.sleep(0.1)

Expand All @@ -45,6 +45,7 @@ def p():
zproc.signal_to_exception(signal.SIGINT)

with pytest.raises(zproc.SignalException):
mutator(state)
atomic_fn(state)

print(state.copy())
assert state == {"x": 5}
9 changes: 1 addition & 8 deletions tests/test_liveness.py
Expand Up @@ -14,8 +14,8 @@ def mutator(ctx: zproc.Context):
state = ctx.create_state()

for n in range(10):
state["counter"] = n
sleep(0.1)
state["counter"] = n

return ctx.create_state()

Expand All @@ -30,10 +30,3 @@ def test_live(state: zproc.State):
it = state.get_when_change("counter", live=True)
sleep(0.25)
assert next(it)["counter"] > 0


def test_go_live(state: zproc.State):
it = state.get_when_change("counter")
sleep(0.25)
state.go_live()
assert next(it)["counter"] > 0
3 changes: 2 additions & 1 deletion tests/test_process_wait.py
@@ -1,6 +1,7 @@
import pytest
import time

import pytest

import zproc

TOLERANCE = 0.1
Expand Down
2 changes: 1 addition & 1 deletion zproc/__init__.py
Expand Up @@ -12,6 +12,6 @@
from .process import Process
from .server.tools import start_server, ping
from .state.state import State, atomic
from .task.swarm import Swarm
from .task.result import SequenceTaskResult, SimpleTaskResult
from .task.swarm import Swarm
from .util import clean_process_tree, consume
3 changes: 1 addition & 2 deletions zproc/context.py
@@ -1,11 +1,10 @@
import atexit
import functools
import multiprocessing
import pprint
import signal
import time
from contextlib import suppress
from typing import Callable, Union, Hashable, Any, List, Mapping, Sequence, Tuple, cast
from typing import Callable, Union, Any, List, Mapping, Sequence, Tuple, cast

from . import util
from .consts import DEFAULT_NAMESPACE
Expand Down
3 changes: 1 addition & 2 deletions zproc/server/main.py
@@ -1,5 +1,4 @@
import atexit
from multiprocessing.connection import Connection

import zmq

Expand All @@ -11,7 +10,7 @@
from zproc.task.server import start_task_server, start_task_proxy


def main(server_address: str, send_conn: Connection):
def main(server_address: str, send_conn):
with util.socket_factory(zmq.ROUTER, zmq.ROUTER) as (
zmq_ctx,
state_router,
Expand Down
2 changes: 1 addition & 1 deletion zproc/server/tools.py
@@ -1,7 +1,7 @@
import multiprocessing
import os
from collections import Callable
from typing import Union, Optional, Tuple
from typing import Union, Tuple

import zmq

Expand Down
6 changes: 3 additions & 3 deletions zproc/state/server.py
Expand Up @@ -99,18 +99,18 @@ def recv_request(self):
self.dispatch_dict[request[Msgs.cmd]](request)

def reply(self, response):
# print("server rep:", self._active_ident, rep, time.time())
# print("server rep:", self.identity, response, time.time())
self.state_router.send_multipart([self.identity, serializer.dumps(response)])

@contextmanager
def mutate_safely(self):
stamp = time.time()
old = deepcopy(self.state)
stamp = time.time()

try:
yield
except Exception:
self.state_map[self.namespace] = old
self.state = self.state_map[self.namespace] = old
raise

slot = self.history[self.namespace]
Expand Down
71 changes: 42 additions & 29 deletions zproc/state/state.py
Expand Up @@ -6,7 +6,7 @@
from functools import wraps
from pprint import pformat
from textwrap import indent
from typing import Hashable, Any, Callable, Dict, List, Iterator
from typing import Hashable, Any, Callable, Dict, List, Generator

import zmq

Expand Down Expand Up @@ -145,8 +145,10 @@ def _create_state_dealer(self) -> zmq.Socket:

def _request_reply(self, request: Dict[int, Any]):
request[Msgs.namespace] = self._namespace_bytes
self._s_dealer.send(serializer.dumps(request))
return serializer.loads(self._s_dealer.recv())
msg = serializer.dumps(request)
return serializer.loads(
util.strict_request_reply(msg, self._s_dealer.send, self._s_dealer.recv)
)

def set(self, value: dict):
"""
Expand Down Expand Up @@ -201,8 +203,9 @@ def _watcher_request_reply(
self, request: List[bytes], only_after: float
) -> List[bytes]:
request[-1] = struct.pack("d", only_after)
self._w_dealer.send_multipart(request)
return self._w_dealer.recv_multipart()
return util.strict_request_reply(
request, self._w_dealer.send_multipart, self._w_dealer.recv_multipart
)

def get_raw_update(
self,
Expand All @@ -212,7 +215,7 @@ def get_raw_update(
identical_okay: bool = False,
start_time: bool = None,
count: int = None,
) -> Iterator[StateUpdate]:
) -> Generator[StateUpdate, None, None]:
"""
A low-level hook that emits each and every state update.
All other state watchers are built upon this only.
Expand All @@ -229,9 +232,9 @@ def get_raw_update(
only_after = self._request_reply({Msgs.cmd: Cmds.time})

if count is None:
count = itertools.count()
counter = itertools.count()
else:
count = range(count)
counter = iter(range(count))

request_msg = [
self._identity,
Expand All @@ -241,12 +244,14 @@ def get_raw_update(
]

def _(only_after):
for _ in count:
for _ in counter:
if time_limit is None:
self._w_dealer.setsockopt(zmq.RCVTIMEO, DEFAULT_ZMQ_RECVTIMEO)
else:
if time_limit < time.time():
raise TimeoutError("Timed-out while waiting for a state update.")
raise TimeoutError(
"Timed-out while waiting for a state update."
)

self._w_dealer.setsockopt(
zmq.RCVTIMEO, int((time_limit - time.time()) * 1000)
Expand All @@ -272,7 +277,7 @@ def _(only_after):

def get_when_change(
self, *keys: Hashable, exclude: bool = False, **watcher_kwargs
) -> Iterator[dict]:
) -> Generator[dict, None, None]:
"""
Block until a change is observed, and then return a copy of the state.
Expand All @@ -287,7 +292,7 @@ def get_when_change(
else:
return (next(it).after for _ in range(count))

keys = set(keys)
key_set = set(keys)
identical_okay = watcher_kwargs.get("identical_okay", False)
if identical_okay:
raise ValueError(
Expand All @@ -302,9 +307,9 @@ def _():
def select():
selected = {*before.keys(), *after.keys()}
if exclude:
return selected - keys
return selected - key_set
else:
return selected & keys
return selected & key_set

i = 0
while i < count:
Expand All @@ -321,7 +326,7 @@ def select():

return _()

def get_when(self, test_fn, **watcher_kwargs) -> Iterator[dict]:
def get_when(self, test_fn, **watcher_kwargs) -> Generator[dict, None, None]:
"""
Block until ``test_fn(snap)`` returns a "truthy" value,
and then return a copy of the state.
Expand All @@ -334,24 +339,28 @@ def get_when(self, test_fn, **watcher_kwargs) -> Iterator[dict]:
"""
snap = self.copy()
if test_fn(snap):
return iter([snap])

count = watcher_kwargs.pop("count", math.inf)
it = self.get_raw_update(**watcher_kwargs)
def _():
yield snap

def _():
i = 0
while i < count:
snap = next(it).after
if test_fn(snap):
i += 1
yield snap
else:
count = watcher_kwargs.pop("count", math.inf)
it = self.get_raw_update(**watcher_kwargs)

def _():
i = 0
while i < count:
snap = next(it).after

if test_fn(snap):
i += 1
yield snap

return _()

def get_when_equal(
self, key: Hashable, value: Any, **watcher_kwargs
) -> Iterator[dict]:
) -> Generator[dict, None, None]:
"""
Block until ``state[key] == value``, and then return a copy of the state.
Expand All @@ -368,7 +377,7 @@ def _(snap):

def get_when_not_equal(
self, key: Hashable, value: Any, **watcher_kwargs
) -> Iterator[dict]:
) -> Generator[dict, None, None]:
"""
Block until ``state[key] != value``, and then return a copy of the state.
Expand All @@ -383,7 +392,9 @@ def _(snap):

return self.get_when(_, **watcher_kwargs)

def get_when_none(self, key: Hashable, **watcher_kwargs) -> Iterator[dict]:
def get_when_none(
self, key: Hashable, **watcher_kwargs
) -> Generator[dict, None, None]:
"""
Block until ``state[key] is None``, and then return a copy of the state.
Expand All @@ -398,7 +409,9 @@ def _(snap):

return self.get_when(_, **watcher_kwargs)

def get_when_not_none(self, key: Hashable, **watcher_kwargs) -> Iterator[dict]:
def get_when_not_none(
self, key: Hashable, **watcher_kwargs
) -> Generator[dict, None, None]:
"""
Block until ``state[key] is not None``, and then return a copy of the state.
Expand Down
5 changes: 3 additions & 2 deletions zproc/task/result.py
Expand Up @@ -19,8 +19,9 @@ def _create_dealer(self) -> zmq.Socket:

def _get_chunk(self, index: int):
chunk_id = util.encode_chunk_id(self.task_id, index)
self._dealer.send(chunk_id)
return serializer.loads(self._dealer.recv())
return serializer.loads(
util.strict_request_reply(chunk_id, self._dealer.send, self._dealer.recv)
)

def __del__(self):
try:
Expand Down
7 changes: 3 additions & 4 deletions zproc/task/server.py
@@ -1,7 +1,6 @@
import multiprocessing
from collections import defaultdict, Callable, deque
from multiprocessing.connection import Connection
from typing import Any, Dict, List
from typing import Dict, List

import zmq

Expand Down Expand Up @@ -70,7 +69,7 @@ def tick(self):
self.recv_chunk_result()


def _task_server(send_conn: Connection, _bind: Callable):
def _task_server(send_conn, _bind: Callable):
with util.socket_factory(zmq.ROUTER, zmq.PULL) as (zmq_ctx, router, result_pull):
with send_conn:
try:
Expand Down Expand Up @@ -99,7 +98,7 @@ def _task_server(send_conn: Connection, _bind: Callable):
# Clients never need to talk to a worker directly.


def _task_proxy(send_conn: Connection, _bind: Callable):
def _task_proxy(send_conn, _bind: Callable):
with util.socket_factory(zmq.PULL, zmq.PUSH) as (zmq_ctx, proxy_in, proxy_out):
with send_conn:
try:
Expand Down
2 changes: 1 addition & 1 deletion zproc/task/swarm.py
@@ -1,5 +1,5 @@
import multiprocessing
from typing import List, Mapping, Sequence, Any, Callable, Union
from typing import List, Mapping, Sequence, Any, Callable

import zmq

Expand Down

0 comments on commit 159fb20

Please sign in to comment.