Skip to content
Permalink
 
 
Cannot retrieve contributors at this time
import sys
import enum
import socket
import datetime
import anyio
import async_exit_stack
from async_generator import async_generator, yield_, yield_from_
from purerpc.utils import is_darwin, is_windows
from purerpc.grpclib.exceptions import ProtocolError
from .grpclib.connection import GRPCConfiguration, GRPCConnection
from .grpclib.events import RequestReceived, RequestEnded, ResponseEnded, MessageReceived, WindowUpdated
from .grpclib.buffers import MessageWriteBuffer, MessageReadBuffer
from .grpclib.exceptions import StreamClosedError
class SocketWrapper(async_exit_stack.AsyncExitStack):
    """Socket stream plus a background writer task for a GRPCConnection.

    Outgoing bytes are whatever ``grpc_connection.data_to_send()`` returns; callers
    wake the writer task with :meth:`flush`. Incoming bytes are read with
    :meth:`recv`. The writer task runs inside this exit stack's task group and is
    stopped (after draining pending data) when the stack is exited.
    """

    def __init__(self, grpc_connection: GRPCConnection, sock: anyio.SocketStream):
        super().__init__()
        self._set_socket_options(sock)
        self._socket = sock
        self._grpc_connection = grpc_connection
        # Wakes _writer_thread when new data may be queued or on shutdown.
        self._flush_event = anyio.create_event()
        # Cleared on exit; tells _writer_thread to return once drained.
        self._running = True

    async def __aenter__(self):
        await super().__aenter__()
        task_group = await self.enter_async_context(anyio.create_task_group())
        await task_group.spawn(self._writer_thread)

        async def callback():
            # Exit callbacks run in LIFO order, so this runs before the task
            # group waits for _writer_thread: stop it and wake it so it can drain
            # remaining data and return.
            self._running = False
            await self._flush_event.set()

        self.push_async_callback(callback)
        return self

    @staticmethod
    def _set_socket_options(sock: anyio.SocketStream):
        """Enable TCP keepalive (300s idle, 30s interval, 5 probes) and TCP_NODELAY."""
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
        if hasattr(socket, "TCP_KEEPIDLE"):
            sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 300)
        elif is_darwin():
            # Darwin specific option: TCP_KEEPALIVE plays the role of TCP_KEEPIDLE.
            TCP_KEEPALIVE = 16
            sock.setsockopt(socket.IPPROTO_TCP, TCP_KEEPALIVE, 300)
        if not is_windows():
            # Not every POSIX build's socket module exposes these constants;
            # skip the option rather than fail with AttributeError.
            if hasattr(socket, "TCP_KEEPINTVL"):
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 30)
            if hasattr(socket, "TCP_KEEPCNT"):
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 5)
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

    async def _writer_thread(self):
        """Background task writing queued connection data to the socket.

        Sends whatever is queued; when nothing is queued, sleeps until the flush
        event is set. Returns once shutdown was requested and nothing is left.
        """
        while True:
            data = self._grpc_connection.data_to_send()
            if data:
                await self._socket.send_all(data)
            elif self._running:
                await self._flush_event.wait()
                self._flush_event.clear()
            else:
                return

    async def flush(self):
        """Wake the writer task so queued data gets sent.

        This may be called from several tasks; each call only sets the flush event.
        """
        await self._flush_event.set()

    async def recv(self, buffer_size: int):
        """Read up to ``buffer_size`` bytes from the socket.

        This may only be called from a single task at a time. Callers treat an
        empty result as end of stream.
        """
        return await self._socket.receive_some(buffer_size)
class GRPCStreamState(enum.Enum):
    """Half-closure state of a single GRPCStream.

    A stream starts OPEN, becomes half-closed when either side finishes, and is
    CLOSED once both sides have finished.
    """

    OPEN = 1  # neither side has finished
    HALF_CLOSED_REMOTE = 2  # the peer has finished, we may still send
    HALF_CLOSED_LOCAL = 3  # we have finished, the peer may still send
    CLOSED = 4  # both sides have finished
class GRPCStream:
    """One gRPC call multiplexed over a GRPCSocket.

    The owning GRPCSocket pushes this stream's events into ``_incoming_events``
    and they are consumed through :meth:`_receive`; outgoing messages go through
    :meth:`_send`. Half-closure is tracked in ``state``, and the stream removes
    itself from ``grpc_socket._streams`` once both sides are closed.
    """

    def __init__(self, grpc_connection: GRPCConnection, stream_id: int, socket: SocketWrapper,
                 grpc_socket: "GRPCSocket"):
        self._stream_id = stream_id
        self._grpc_connection = grpc_connection
        self._grpc_socket = grpc_socket
        self._socket = socket
        # Set by the owning GRPCSocket on WindowUpdated events; awaited by _send().
        self._flow_control_update_event = anyio.create_event()
        # Capacity sys.maxsize, i.e. effectively unbounded, so the reader's put()
        # does not wait on a slow consumer.
        self._incoming_events = anyio.create_queue(sys.maxsize)
        # Server side only: whether start_response() has been called.
        self._response_started = False
        self._state = GRPCStreamState.OPEN
        # First event that is neither a message nor an end event (e.g. RequestReceived).
        self._start_stream_event = None
        # RequestEnded / ResponseEnded event, once received.
        self._end_stream_event = None

    @property
    def state(self):
        """Current GRPCStreamState of this stream."""
        return self._state

    @property
    def start_stream_event(self):
        """Start event recorded by _receive(), or None if not received yet."""
        return self._start_stream_event

    @property
    def end_stream_event(self):
        """RequestEnded/ResponseEnded event recorded by _receive(), or None."""
        return self._end_stream_event

    @property
    def stream_id(self):
        return self._stream_id

    @property
    def client_side(self):
        """True when the underlying connection is configured as a client."""
        return self._grpc_connection.config.client_side

    @property
    def debug_prefix(self):
        return "[CLIENT] " if self.client_side else "[SERVER] "

    def _close_remote(self):
        """Record that the peer finished; unregister the stream if now fully closed."""
        if self._state == GRPCStreamState.OPEN:
            self._state = GRPCStreamState.HALF_CLOSED_REMOTE
        elif self._state == GRPCStreamState.HALF_CLOSED_LOCAL:
            self._state = GRPCStreamState.CLOSED
            del self._grpc_socket._streams[self._stream_id]

    def _close_local(self):
        """Record that we finished; unregister the stream if now fully closed."""
        if self._state == GRPCStreamState.OPEN:
            self._state = GRPCStreamState.HALF_CLOSED_LOCAL
        elif self._state == GRPCStreamState.HALF_CLOSED_REMOTE:
            self._state = GRPCStreamState.CLOSED
            del self._grpc_socket._streams[self._stream_id]

    async def _set_flow_control_update(self):
        """Wake a _send() that is waiting for more send window."""
        await self._flow_control_update_event.set()

    async def _wait_flow_control_update(self):
        """Block until the next flow-control wake-up, then re-arm the event."""
        await self._flow_control_update_event.wait()
        self._flow_control_update_event.clear()

    async def _send(self, message: bytes, compress=False):
        """Frame ``message`` and send it while respecting flow control.

        The framed bytes are handed to the connection in chunks no larger than
        ``flow_control_window(stream_id)``. When that window is exhausted, this
        waits for the owning socket to signal a window update.
        """
        message_write_buffer = MessageWriteBuffer(self._grpc_connection.config.message_encoding,
                                                  self._grpc_connection.config.max_message_length)
        message_write_buffer.write_message(message, compress)
        while message_write_buffer:
            window_size = self._grpc_connection.flow_control_window(self._stream_id)
            if window_size <= 0:
                await self._wait_flow_control_update()
                continue
            num_data_to_send = min(window_size, len(message_write_buffer))
            data = message_write_buffer.data_to_send(num_data_to_send)
            self._grpc_connection.send_data(self._stream_id, data)
            # Flush every chunk, not only at the end. The peer grants more window
            # only after it receives data, so chunks left queued while we wait
            # for a window update would stall this send forever.
            await self._socket.flush()

    async def _receive(self):
        """Return the next event for this stream, updating bookkeeping.

        MessageReceived events are acknowledged to the connection (presumably
        so it can grant the peer more window) and flushed. End events are stored
        as end_stream_event; any other event is stored as start_stream_event.
        """
        event = await self._incoming_events.get()
        if isinstance(event, MessageReceived):
            self._grpc_connection.acknowledge_received_data(self._stream_id,
                                                            event.flow_controlled_length)
            await self._socket.flush()
        elif isinstance(event, (RequestEnded, ResponseEnded)):
            assert self._end_stream_event is None
            self._end_stream_event = event
        else:
            assert self._start_stream_event is None
            self._start_stream_event = event
        return event

    async def close(self, status=None, content_type_suffix="", custom_metadata=()):
        """Finish our side of the stream.

        Client side: ends the request. Server side: ends the response if one was
        started, otherwise responds with ``status`` only.

        Raises:
            ValueError: on the client side when ``status`` or ``custom_metadata``
                is given.
            TypeError: if our side of the stream is already closed.
        """
        if self.client_side and (status or custom_metadata):
            raise ValueError("Client side streams cannot be closed with non-default arguments")
        if self._state in (GRPCStreamState.HALF_CLOSED_LOCAL, GRPCStreamState.CLOSED):
            raise TypeError("Closing already closed stream")
        self._close_local()
        if self.client_side:
            try:
                self._grpc_connection.end_request(self._stream_id)
            except StreamClosedError:
                # Remote end already closed connection, do nothing here
                pass
        elif self._response_started:
            self._grpc_connection.end_response(self._stream_id, status, custom_metadata)
        else:
            self._grpc_connection.respond_status(self._stream_id, status,
                                                 content_type_suffix, custom_metadata)
        await self._socket.flush()

    async def start_response(self, content_type_suffix="", custom_metadata=()):
        """Begin the response on a server-side stream and flush it.

        Raises:
            ValueError: when called on a client-side stream.
        """
        if self.client_side:
            raise ValueError("Cannot start response on client-side socket")
        self._grpc_connection.start_response(self._stream_id, content_type_suffix, custom_metadata)
        self._response_started = True
        await self._socket.flush()
# TODO: this name is misleading; it should be something like GRPCConnection,
# but that name is already taken by .grpclib.connection.GRPCConnection.
class GRPCSocket(async_exit_stack.AsyncExitStack):
    """gRPC connection over one socket, multiplexing GRPCStream objects.

    Wraps ``sock`` in a SocketWrapper, which runs the writer task. Received bytes
    are fed to a GRPCConnection, and the resulting events are routed to
    per-stream queues. On the client side a background reader task is started
    on entry; on the server side the caller drives reading through listen().
    """

    # Stream factory used by _stream_ctor (looked up via self, so overridable).
    StreamClass = GRPCStream

    def __init__(self, config: GRPCConfiguration, sock,
                 receive_buffer_size=1024*1024):
        super().__init__()
        self._grpc_connection = GRPCConnection(config=config)
        self._socket = SocketWrapper(self._grpc_connection, sock)
        # Maximum number of bytes requested from the socket per read.
        self._receive_buffer_size = receive_buffer_size
        # Live streams by stream id; a GRPCStream deletes its entry once fully closed.
        self._streams = {}  # type: Dict[int, GRPCStream]

    async def __aenter__(self):
        await super().__aenter__()
        self._socket = await self.enter_async_context(self._socket)
        # Queue the connection's initial frames and push them out right away.
        self._grpc_connection.initiate_connection()
        await self._socket.flush()
        if self.client_side:
            task_group = await self.enter_async_context(anyio.create_task_group())
            # Exit callbacks run LIFO: cancel the reader before the task group
            # waits for it, since _reader_thread otherwise loops on recv().
            self.push_async_callback(task_group.cancel_scope.cancel)
            await task_group.spawn(self._reader_thread)
        return self

    @property
    def client_side(self):
        """True when the connection is configured as a client."""
        return self._grpc_connection.config.client_side

    def _stream_ctor(self, stream_id):
        return self.StreamClass(self._grpc_connection, stream_id, self._socket, self)

    def _allocate_stream(self, stream_id):
        """Create, register and return a stream for ``stream_id``."""
        self._streams[stream_id] = self._stream_ctor(stream_id)
        return self._streams[stream_id]

    @async_generator
    async def _listen(self):
        """Read from the socket and dispatch events until EOF.

        Yields each newly received request's stream. Window updates wake the
        affected stream(s). Every other event is queued on its stream; end
        events also mark the stream as closed by the remote side.
        """
        while True:
            data = await self._socket.recv(self._receive_buffer_size)
            if not data:
                return
            events = self._grpc_connection.receive_data(data)
            await self._socket.flush()
            for event in events:
                if isinstance(event, WindowUpdated):
                    if event.stream_id == 0:
                        # Connection-level update: wake every stream. Iterate a
                        # snapshot, because other tasks may run at each await and
                        # close streams, mutating self._streams mid-iteration.
                        for waiting_stream in list(self._streams.values()):
                            await waiting_stream._set_flow_control_update()
                    elif event.stream_id in self._streams:
                        await self._streams[event.stream_id]._set_flow_control_update()
                    continue
                if isinstance(event, RequestReceived):
                    self._allocate_stream(event.stream_id)
                target = self._streams[event.stream_id]
                await target._incoming_events.put(event)
                if isinstance(event, RequestReceived):
                    await yield_(target)
                elif isinstance(event, (ResponseEnded, RequestEnded)):
                    target._close_remote()

    async def _reader_thread(self):
        """Client-side reader task; any incoming request is a protocol violation."""
        async for _ in self._listen():
            raise ProtocolError("Received request on client end")

    @async_generator
    async def listen(self):
        """Yield a GRPCStream for each incoming request (server side only).

        Raises ValueError, once iteration starts, on a client-side socket.
        """
        if self.client_side:
            raise ValueError("Cannot listen client-side socket")
        await yield_from_(self._listen())

    async def start_request(self, scheme: str, service_name: str, method_name: str,
                            message_type=None, authority=None, timeout: datetime.timedelta = None,
                            content_type_suffix="", custom_metadata=()):
        """Open a new client-side stream, send its request headers, and return it.

        Raises:
            ValueError: when called on a server-side socket.
        """
        if not self.client_side:
            raise ValueError("Cannot start request on server-side socket")
        stream_id = self._grpc_connection.get_next_available_stream_id()
        stream = self._allocate_stream(stream_id)
        self._grpc_connection.start_request(stream_id, scheme, service_name, method_name,
                                            message_type, authority, timeout,
                                            content_type_suffix, custom_metadata)
        await self._socket.flush()
        return stream