diff --git a/posttroll/address_receiver.py b/posttroll/address_receiver.py index 27106ef..32cf5bf 100644 --- a/posttroll/address_receiver.py +++ b/posttroll/address_receiver.py @@ -146,6 +146,42 @@ def _check_age(self, pub, min_interval=zero_seconds): for addr in to_del: del self._addresses[addr] + def _recv_loop_and_parse_data(self, recv, pub): + while self._do_run: + try: + data, fromaddr = recv() + if self._multicast_enabled: + ip_, port = fromaddr + if self._restrict_to_localhost and ip_ not in self._local_ips: + # discard external message + LOGGER.debug('Discard external message') + continue + LOGGER.debug("data %s", data) + except SocketTimeout: + if self._multicast_enabled: + LOGGER.debug("Multicast socket timed out on recv!") + continue + finally: + self._check_age(pub, min_interval=self._max_age / 20) + if self._do_heartbeat: + pub.heartbeat(min_interval=29) + msg = Message.decode(data) + name = msg.subject.split("/")[1] + if (msg.type == 'info' and + msg.subject.lower().startswith(self._subject)): + addr = msg.data["URI"] + msg.data['status'] = True + metadata = copy.copy(msg.data) + metadata["name"] = name + + LOGGER.debug('receiving address %s %s %s', str(addr), + str(name), str(metadata)) + if addr not in self._addresses: + LOGGER.info("nameserver: publish add '%s'", + str(msg)) + pub.send(msg.encode()) + self._add(addr, metadata) + def _run(self): """Run the receiver.""" port = broadcast_port @@ -173,46 +209,17 @@ def _run(self): nameservers = ["localhost"] self._is_running = True - with Publish("address_receiver", self._port, ["addresses"], - nameservers=nameservers) as pub: - try: - while self._do_run: - try: - data, fromaddr = recv() - if self._multicast_enabled: - ip_, port = fromaddr - if self._restrict_to_localhost and ip_ not in self._local_ips: - # discard external message - LOGGER.debug('Discard external message') - continue - LOGGER.debug("data %s", data) - except SocketTimeout: - if self._multicast_enabled: - LOGGER.debug("Multicast socket timed out on recv!") - continue - finally: - self._check_age(pub, min_interval=self._max_age / 20) - if self._do_heartbeat: - pub.heartbeat(min_interval=29) - msg = Message.decode(data) - name = msg.subject.split("/")[1] - if(msg.type == 'info' and - msg.subject.lower().startswith(self._subject)): - addr = msg.data["URI"] - msg.data['status'] = True - metadata = copy.copy(msg.data) - metadata["name"] = name - - LOGGER.debug('receiving address %s %s %s', str(addr), - str(name), str(metadata)) - if addr not in self._addresses: - LOGGER.info("nameserver: publish add '%s'", - str(msg)) - pub.send(msg.encode()) - self._add(addr, metadata) - finally: - self._is_running = False - recv.close() + try: + with Publish("address_receiver", self._port, ["addresses"], + nameservers=nameservers) as pub: + try: + self._recv_loop_and_parse_data(recv, pub) + finally: + self._is_running = False + recv.close() + except OSError: + LOGGER.exception("Fails to start address receiver run loop.") + self._is_running = False def _add(self, adr, metadata): """Add an address.""" diff --git a/posttroll/ns.py b/posttroll/ns.py index 5baf54a..83ef232 100644 --- a/posttroll/ns.py +++ b/posttroll/ns.py @@ -132,6 +132,9 @@ def run(self, *args): multicast_enabled=self._multicast_enabled, restrict_to_localhost=self._restrict_to_localhost) arec.start() + if not arec.is_running(): + logger.error("Address Receiver fails to start.") + return port = PORT try: diff --git a/posttroll/tests/test_pubsub.py b/posttroll/tests/test_pubsub.py index 67bdd47..47a33fc 100644 --- a/posttroll/tests/test_pubsub.py +++ b/posttroll/tests/test_pubsub.py @@ -446,6 +446,16 @@ def test_localhost_restriction(self, mcrec, pub, msg): msg.decode.assert_not_called() adr.stop() + @mock.patch("posttroll.address_receiver.Publish") + def test_publish_oserror(self, pub): + """Test address receiver handle oserror in publish.""" + pub.side_effect = OSError + from posttroll.address_receiver import AddressReceiver + adr = AddressReceiver() + adr.start() + time.sleep(3) + self.assertFalse(adr.is_running()) + adr.stop() class TestPublisherDictConfig(unittest.TestCase): """Test configuring publishers with a dictionary.""" @@ -593,6 +603,18 @@ def test_dict_config_subscriber(NSSubscriber, Subscriber): NSSubscriber.assert_not_called() +@mock.patch('posttroll.ns.AddressReceiver') +def test_nameserver_addressreceiver_fails_to_start(arec): + from posttroll.ns import NameServer + arec_instance = mock.Mock() + arec.return_value = arec_instance + arec_instance.is_running.return_value = False + ns = NameServer(max_age=timedelta(seconds=3), + multicast_enabled=False) + ns_run_ret = ns.run() + assert ns_run_ret is None + + @mock.patch('posttroll.subscriber.NSSubscriber.start') def test_dict_config_full_nssubscriber(NSSubscriber_start): """Test that all NSSubscriber options are passed."""