Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add restrict_to_localhost flag for the nameserver #22

Merged
merged 11 commits into from
Mar 4, 2020
2 changes: 2 additions & 0 deletions .stickler.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
linters:
flake8:
fixer: true
python: 3
config: setup.cfg
fixers:
enable: true
4 changes: 2 additions & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
language: python
python:
- '3.4'
- '3.5'
- '3.6'
- '3.7'
- '3.8'
install:
- pip install .
- pip install coveralls
Expand Down
5 changes: 4 additions & 1 deletion bin/nameserver
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ if __name__ == '__main__':
action="store_true")
parser.add_argument("--no-multicast", help="disable multicasting",
action="store_true")
parser.add_argument("-r", "--restrict-to-localhost", help="restrict incomming processes to localhost",
mraspaud marked this conversation as resolved.
Show resolved Hide resolved
action="store_true")
opts = parser.parse_args()

if opts.log:
Expand All @@ -67,8 +69,9 @@ if __name__ == '__main__':
logger = logging.getLogger("nameserver")

multicast_enabled = (opts.no_multicast == False)
local_only = (opts.restrict_to_localhost)

ns = NameServer(multicast_enabled=multicast_enabled)
ns = NameServer(multicast_enabled=multicast_enabled, restrict_to_localhost=local_only)

if opts.daemon is None:
try:
Expand Down
80 changes: 47 additions & 33 deletions posttroll/address_receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@

from datetime import datetime, timedelta

from zmq import REQ, REP, LINGER, POLLIN, NOBLOCK
import netifaces
from zmq import REP, LINGER

from posttroll.bbmcast import MulticastReceiver, SocketTimeout
from posttroll.message import Message
Expand All @@ -51,20 +52,32 @@

default_publish_port = 16543

#-----------------------------------------------------------------------------
ten_minutes = timedelta(minutes=10)
zero_seconds = timedelta(seconds=0)


def get_local_ips():
inet_addrs = [netifaces.ifaddresses(iface).get(netifaces.AF_INET)
for iface in netifaces.interfaces()]
ips = []
for addr in inet_addrs:
if addr is not None:
for add in addr:
ips.append(add['addr'])
return ips

# -----------------------------------------------------------------------------
#
# General thread to receive broadcast addresses.
#
#-----------------------------------------------------------------------------
# -----------------------------------------------------------------------------


class AddressReceiver(object):
"""General thread to receive broadcast addresses."""

"""General thread to receive broadcast addresses.
"""

def __init__(self, max_age=timedelta(minutes=10), port=None,
do_heartbeat=True, multicast_enabled=True):
def __init__(self, max_age=ten_minutes, port=None,
do_heartbeat=True, multicast_enabled=True, restrict_to_localhost=False):
self._max_age = max_age
self._port = port or default_publish_port
self._address_lock = threading.Lock()
Expand All @@ -76,29 +89,27 @@ def __init__(self, max_age=timedelta(minutes=10), port=None,
self._do_run = False
self._is_running = False
self._thread = threading.Thread(target=self._run)
self._restrict_to_localhost = restrict_to_localhost
self._local_ips = get_local_ips()

def start(self):
"""Start the receiver.
"""
"""Start the receiver."""
if not self._is_running:
self._do_run = True
self._thread.start()
return self

def stop(self):
"""Stop the receiver.
"""
"""Stop the receiver."""
self._do_run = False
return self

def is_running(self):
"""Check if the receiver is alive.
"""
"""Check if the receiver is alive."""
return self._is_running

def get(self, name=""):
"""Get the address(es).
"""
"""Get the address(es)."""
addrs = []

with self._address_lock:
Expand All @@ -111,9 +122,8 @@ def get(self, name=""):
LOGGER.debug('return address %s', str(addrs))
return addrs

def _check_age(self, pub, min_interval=timedelta(seconds=0)):
"""Check the age of the receiver.
"""
def _check_age(self, pub, min_interval=zero_seconds):
"""Check the age of the receiver."""
now = datetime.utcnow()
if (now - self._last_age_check) <= min_interval:
return
Expand All @@ -136,17 +146,13 @@ def _check_age(self, pub, min_interval=timedelta(seconds=0)):
del self._addresses[addr]

def _run(self):
"""Run the receiver.
"""
"""Run the receiver."""
port = broadcast_port
nameservers = []
if self._multicast_enabled:
recv = MulticastReceiver(port).settimeout(2.)
while True:
try:
recv = MulticastReceiver(port).settimeout(2.)
LOGGER.info("Receiver initialized.")
break
recv = MulticastReceiver(port)
except IOError as err:
if err.errno == errno.ENODEV:
LOGGER.error("Receiver initialization failed "
Expand All @@ -156,6 +162,11 @@ def _run(self):
time.sleep(10)
else:
raise
else:
recv.settimeout(tout=2.0)
LOGGER.info("Receiver initialized.")
break

else:
recv = _SimpleReceiver(port)
nameservers = ["localhost"]
Expand All @@ -167,8 +178,13 @@ def _run(self):
while self._do_run:
try:
data, fromaddr = recv()
if self._multicast_enabled:
ip_, port = fromaddr
if self._restrict_to_localhost and ip_ not in self._local_ips:
# discard external message
LOGGER.debug('Discard external message')
continue
LOGGER.debug("data %s", data)
del fromaddr
except SocketTimeout:
if self._multicast_enabled:
LOGGER.debug("Multicast socket timed out on recv!")
Expand Down Expand Up @@ -198,17 +214,15 @@ def _run(self):
recv.close()

def _add(self, adr, metadata):
"""Add an address.
"""
"""Add an address."""
with self._address_lock:
metadata["receive_time"] = datetime.utcnow()
self._addresses[adr] = metadata


class _SimpleReceiver(object):

""" Simple listing on port for address messages.
"""
""" Simple listing on port for address messages."""

def __init__(self, port=None):
self._port = port or default_publish_port
Expand All @@ -221,11 +235,11 @@ def __call__(self):
return message, None

def close(self):
"""Close the receiver.
"""
"""Close the receiver."""
self._socket.setsockopt(LINGER, 1)
self._socket.close()

#-----------------------------------------------------------------------------

# -----------------------------------------------------------------------------
# default
getaddress = AddressReceiver
114 changes: 58 additions & 56 deletions posttroll/bbmcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
# posttroll. If not, see <http://www.gnu.org/licenses/>.

"""Send/receive UDP multicast packets.

Requires that your OS kernel supports IP multicast.

This is based on python-examples Demo/sockets/mcast.py
Expand All @@ -48,17 +49,15 @@

SocketTimeout = timeout # for easy access to socket.timeout

#-----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
#
# Sender.
#
#-----------------------------------------------------------------------------
# -----------------------------------------------------------------------------


class MulticastSender(object):

"""Multicast sender on *port* and *mcgroup*.
"""
"""Multicast sender on *port* and *mcgroup*."""

def __init__(self, port, mcgroup=MC_GROUP):
self.port = port
Expand All @@ -70,8 +69,7 @@ def __call__(self, data):
self.socket.sendto(data.encode(), (self.group, self.port))

def close(self):
"""Close the sender.
"""
"""Close the sender."""
self.socket.close()

# Allow non-object interface
Expand All @@ -81,30 +79,32 @@ def mcast_sender(mcgroup=MC_GROUP):
"""Non-object interface for sending multicast messages.
"""
sock = socket(AF_INET, SOCK_DGRAM)
sock.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
if _is_broadcast_group(mcgroup):
group = '<broadcast>'
sock.setsockopt(SOL_SOCKET, SO_BROADCAST, 1)
elif((int(mcgroup.split(".")[0]) > 239) or
(int(mcgroup.split(".")[0]) < 224)):
raise IOError("Invalid multicast address.")
else:
group = mcgroup
ttl = struct.pack('b', TTL_LOCALNET) # Time-to-live
sock.setsockopt(IPPROTO_IP, IP_MULTICAST_TTL, ttl)
try:
sock.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
if _is_broadcast_group(mcgroup):
group = '<broadcast>'
sock.setsockopt(SOL_SOCKET, SO_BROADCAST, 1)
elif((int(mcgroup.split(".")[0]) > 239) or
(int(mcgroup.split(".")[0]) < 224)):
raise IOError("Invalid multicast address.")
else:
group = mcgroup
ttl = struct.pack('b', TTL_LOCALNET) # Time-to-live
sock.setsockopt(IPPROTO_IP, IP_MULTICAST_TTL, ttl)
except Exception:
sock.close()
raise
return sock, group

#-----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
#
# Receiver.
#
#-----------------------------------------------------------------------------
# -----------------------------------------------------------------------------


class MulticastReceiver(object):

"""Multicast receiver on *port* for an *mcgroup*.
"""
"""Multicast receiver on *port* for an *mcgroup*."""
BUFSIZE = 1024

def __init__(self, port, mcgroup=MC_GROUP):
Expand All @@ -123,10 +123,8 @@ def __call__(self):
return data.decode(), sender

def close(self):
"""Close the receiver.
"""
self.socket.setsockopt(SOL_SOCKET, SO_LINGER,
struct.pack('ii', 1, 1))
"""Close the receiver."""
self.socket.setsockopt(SOL_SOCKET, SO_LINGER, struct.pack('ii', 1, 1))
self.socket.close()

# Allow non-object interface
Expand All @@ -144,41 +142,45 @@ def mcast_receiver(port, mcgroup=MC_GROUP):
# Create a socket
sock = socket(AF_INET, SOCK_DGRAM)

# Allow multiple copies of this program on one machine
# (not strictly needed)
sock.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
if group:
sock.setsockopt(SOL_IP, IP_MULTICAST_TTL, TTL_LOCALNET) # default
sock.setsockopt(SOL_IP, IP_MULTICAST_LOOP, 1) # default

# Bind it to the port
sock.bind(('', port))

# Look up multicast group address in name server
# (doesn't hurt if it is already in ddd.ddd.ddd.ddd format)
if group:
group = gethostbyname(group)

# Construct binary group address
bytes_ = [int(b) for b in group.split(".")]
grpaddr = 0
for byte in bytes_:
grpaddr = (grpaddr << 8) | byte

# Construct struct mreq from grpaddr and ifaddr
ifaddr = INADDR_ANY
mreq = struct.pack('!LL', grpaddr, ifaddr)

# Add group membership
sock.setsockopt(IPPROTO_IP, IP_ADD_MEMBERSHIP, mreq)
try:
# Allow multiple copies of this program on one machine
# (not strictly needed)
sock.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
if group:
sock.setsockopt(SOL_IP, IP_MULTICAST_TTL, TTL_LOCALNET) # default
sock.setsockopt(SOL_IP, IP_MULTICAST_LOOP, 1) # default

# Bind it to the port
sock.bind(('', port))

# Look up multicast group address in name server
# (doesn't hurt if it is already in ddd.ddd.ddd.ddd format)
if group:
group = gethostbyname(group)

# Construct binary group address
bytes_ = [int(b) for b in group.split(".")]
grpaddr = 0
for byte in bytes_:
grpaddr = (grpaddr << 8) | byte

# Construct struct mreq from grpaddr and ifaddr
ifaddr = INADDR_ANY
mreq = struct.pack('!LL', grpaddr, ifaddr)

# Add group membership
sock.setsockopt(IPPROTO_IP, IP_ADD_MEMBERSHIP, mreq)
except Exception:
sock.close()
raise

return sock, group or '<broadcast>'

#-----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
#
# Small helpers.
#
#-----------------------------------------------------------------------------
# -----------------------------------------------------------------------------


def _is_broadcast_group(group):
Expand Down