Skip to content

Commit

Permalink
feat: initial version of peer processor connection and discovery
Browse files Browse the repository at this point in the history
  • Loading branch information
nokome committed Dec 13, 2018
1 parent 022f477 commit b6a2659
Show file tree
Hide file tree
Showing 14 changed files with 367 additions and 51 deletions.
58 changes: 56 additions & 2 deletions src/Processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
"""

import json
from typing import Any, Dict, Union, List, Optional
from typing import Any, Dict, Union, List, Optional, Type

from .types.Thing import Thing
from .types.utils import cast, hydrate, dehydrate


class Processor:
"""
The base class for document processors.
Expand All @@ -18,11 +19,64 @@ class Processor:
methods in derived classes.
"""

def __init__(self, client_types: List[Type['Client']] = None, server_types: List[Type['Server']] = None):
self.client_types = client_types
self.server_types = server_types

self.clients = {}
self.servers = {}
self.logger = None


async def start(self):
for server_type in self.server_types:
server = server_type(self)
await server.start()
self.servers[server.url] = server

async def stop(self):
for url in list(self.servers):
server = self.servers[url]
await server.stop()
del self.servers[url]

async def connect(self, url):
"""
Connect to a peer processor.
"""
if url not in self.clients:
success = False
for client_type in self.client_types:
if client_type.connectable(url):
client = client_type(url)
await client.start()
self.clients[url] = client
success = True
break
if not success:
raise RuntimeError(f'No client types able to connect to {url}')

async def disconnect(self, url: str = None):
if url is None:
for url in list(self.clients):
await self.disconnect(url)
if url in self.clients:
await self.clients[url].stop()
del self.clients[url]

async def discover(self):
for client_type in self.client_types:
for client in await client_type.discover():
# Ensure that not connecting to one of my own servers
# or a server that I'm already connected to
if client.url not in self.servers and client.url not in self.clients:
self.clients[client.url] = client

async def hello(self, version: str) -> Dict:
return {}

async def goodbye(self) -> Dict:
pass
return {}

async def import_(self, thing: Union[str, dict, Thing],
format: str = 'application/json', type: Optional[str] = None) -> Thing:
Expand Down
18 changes: 17 additions & 1 deletion src/comms/Client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,17 @@

class Client(Logger):

_url: Optional[str]

encoders: List[Encoder]

encoder: Encoder

futures: Dict[int, Future] = {}

def __init__(self, encoders: Optional[List[Encoder]] = None):
def __init__(self, url: str = None, encoders: Optional[List[Encoder]] = None):
self._url = url

if encoders is None:
encoders = [JsonEncoder()]
else:
Expand All @@ -27,6 +31,18 @@ def __init__(self, encoders: Optional[List[Encoder]] = None):
# request based on the response from the server
self.encoder = JsonEncoder()

@staticmethod
def connectable(url: str) -> bool:
return False

@staticmethod
async def discover() -> List['Client']:
return []

@property
def url(self):
return self._url

async def start(self) -> None:
"""
Start this client.
Expand Down
6 changes: 5 additions & 1 deletion src/comms/Server.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,18 @@ def __init__(self, processor: Processor, encoders: List[Encoder] = None):
assert encoders
self.encoders = encoders

@property
def url(self):
return None

async def start(self) -> None:
"""
Start this server.
Starts listening for requests.
"""
self.log(starting=True)
await self.open()
self.log(started=True, url=self.url)

async def open(self) -> None:
raise NotImplementedError()
Expand Down
4 changes: 2 additions & 2 deletions src/comms/StreamClient.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@

class StreamClient(StreamMixin, Client):

def __init__(self, connection: Optional[StreamConnection] = None, encoders=None, ):
def __init__(self, connection: Optional[StreamConnection] = None, url: str = None, encoders=None, ):
StreamMixin.__init__(self, connection)
Client.__init__(self, encoders)
Client.__init__(self, url=url, encoders=encoders)

async def open(self) -> None:
assert self.connection
Expand Down
6 changes: 0 additions & 6 deletions src/comms/StreamMixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,6 @@ class StreamMixin:
def __init__(self, connection: Optional[StreamConnection] = None):
self.connection = connection

@property
def url(self) -> str:
# Currently we only use standard I/O pipes, but in the future
# may provide for named pipes
return 'pipe://stdio'

async def write(self, message: bytes) -> None:
assert self.connection
await self.connection.write(message)
Expand Down
38 changes: 33 additions & 5 deletions src/comms/TcpClient.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import List
import asyncio
import re

Expand All @@ -9,18 +10,45 @@
class TcpClient(StreamClient):

def __init__(self, url: str = 'tcp://127.0.0.1', encoders=None):
StreamClient.__init__(self, encoders=encoders)
StreamClient.__init__(self, url=url, encoders=encoders)
match = TCP_URL_REGEX.match(url)
if match:
self._url = url
self._host = match.group(1)
self._port = int(match.group(2)) if match.group(2) else 2000
else:
raise RuntimeError(f'Invalid URL for TCP: {url}')

@property
def url(self):
return self._url
@staticmethod
def connectable(url: str) -> bool:
return url[:6] == 'tcp://'

@staticmethod
async def discover() -> List['Client']:
"""
Discover `TcpServers`.
Currently this is a naive implementation which scan a limited number
of ports on localhost. It is nonetheless, useful for testings.
Future implementations, may use a service discovery approach e.g mDNS, Consul
:raises exc: Any unhandled exception when attepting to scan a port
:return: List of ``TcpClients``
"""

clients = []
for port in range(2000, 2010):
client = TcpClient(f'tcp://127.0.0.1:{port}')
try:
await client.start()
except OSError as exc:
if exc.errno not in (
111, # "Connect call failed"
):
raise exc
else:
clients.append(client)
return clients

async def open(self) -> None:
reader, writer = await asyncio.open_connection(self._host, self._port)
Expand Down
16 changes: 12 additions & 4 deletions src/comms/TcpServer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,20 @@ class TcpServer(StreamMultiServer):

def __init__(self, processor: Processor, host: str = '127.0.0.1', port: int = 2000, encoders=None):
StreamMultiServer.__init__(self, processor, encoders)
self.host = host
self.port = port
self._host = host
self._port = port

@property
def url(self):
return f'tcp://{self.host}:{self.port}'
return f'tcp://{self._host}:{self._port}'

async def open(self) -> None:
await asyncio.start_server(self.on_client_connected, self.host, self.port)
try:
await asyncio.start_server(self.on_client_connected, self._host, self._port)
except OSError as exc:
if 'address already in use' in str(exc):
# Port is already being used, try again with the next port
self._port += 1
await self.open()
else:
raise exc
27 changes: 22 additions & 5 deletions src/comms/UnixSocketClient.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from typing import List
import asyncio
import os
import re

from .stencilaFiles import get_tempdir, list_tempfiles
from .StreamConnection import StreamConnection
from .StreamClient import StreamClient

Expand All @@ -9,17 +12,31 @@
class UnixSocketClient(StreamClient):

def __init__(self, url: str, encoders=None):
StreamClient.__init__(self, encoders=None)
StreamClient.__init__(self, url=url, encoders=None)
match = UNIX_URL_REGEX.match(url)
if match:
self._url = url
self._path = match.group(1)
else:
raise RuntimeError(f'Invalid URL for UNIX: {url}')

@property
def url(self):
return self._url
@staticmethod
def connectable(url: str) -> bool:
return url[:7] == 'unix://'

@staticmethod
async def discover() -> List['Client']:
clients = []
tempdir = get_tempdir()
for filename in list_tempfiles():
if filename.startswith('unix-'):
client = UnixSocketClient('unix://' + os.path.join(tempdir, filename))
try:
await client.start()
except ConnectionRefusedError as exc:
pass # TODO log(message=str(exc))
else:
clients.append(client)
return clients

async def open(self) -> None:
reader, writer = await asyncio.open_unix_connection(self._path)
Expand Down
25 changes: 21 additions & 4 deletions src/comms/UnixSocketServer.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,41 @@
from typing import Optional
import asyncio
import os
import random
import string
import tempfile

from ..Processor import Processor
from .stencilaFiles import get_tempfile, delete_tempfile
from .StreamMultiServer import StreamMultiServer

class UnixSocketServer(StreamMultiServer):
"""
A Server communicating over UNIX domain sockets
"""

def __init__(self, processor: Processor, path: str):
_path: Optional[str]

def __init__(self, processor: Processor):
StreamMultiServer.__init__(self, processor)
self.path = path
self._id = 'unix-py-' + ''.join(random.choices(string.ascii_lowercase + string.digits, k=32))
self._path = None

@property
def url(self):
return f'unix://{self.path}'
return f'unix://{self._path}'

async def open(self) -> None:
"""
Start the UNIX socket server and create an
async connections when a client connects.
"""
await asyncio.start_unix_server(self.on_client_connected, self.path)
self._path = get_tempfile(self._id)
await asyncio.start_unix_server(self.on_client_connected, self._path)
# Change the permissions on the file so that no other user can read/write it
# This needs to be done after the server starts
os.chmod(self._path, 0o600)

async def close(self) -> None:
await StreamMultiServer.close(self)
delete_tempfile(self._id)
51 changes: 51 additions & 0 deletions src/comms/stencilaFiles.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from typing import List
import getpass
import os
import stat
import tempfile

def get_tempdir() -> str:
return os.path.join(tempfile.gettempdir(), 'stencila', getpass.getuser())

def create_tempdir() -> str:
tempdir = get_tempdir()
if not os.path.exists(tempdir):
os.makedirs(tempdir, mode=0o700)
return tempdir

def get_tempfile(name: str) -> str:
return os.path.join(create_tempdir(), name)

def list_tempfiles() -> List[str]:
return os.listdir(get_tempdir())

def create_tempfile(name: str, content: str = None) -> str:
path = get_tempfile(name)

# Write content to a secure file only readable by current user
# Based on https://stackoverflow.com/a/15015748/4625911

# Remove any existing file with potentially elevated mode
if os.path.isfile(path):
os.remove(path)

# Create a file handle
mode = 0o600
umask = 0o777 ^ mode # Prevents always downgrading umask to 0.
umask_original = os.umask(umask)
try:
fd = os.open(path, os.O_WRONLY | os.O_CREAT, mode)
finally:
os.umask(umask_original)

if content:
# Open file fd and write to file
with os.fdopen(fd, 'w') as file:
file.write(content)

return path

def delete_tempfile(name: str):
path = get_tempfile(name)
if os.path.exists(path):
os.remove(path)
Loading

0 comments on commit b6a2659

Please sign in to comment.