Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Fix/issue 509 #539

Merged
merged 13 commits into from
Feb 1, 2021
1 change: 1 addition & 0 deletions .codecov.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ coverage:
threshold: 1%
paths:
- "src"
patch: off
46 changes: 33 additions & 13 deletions src/radical/entk/execman/base/task_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def _tmgr(self, uid, rmgr, pending_queue, completed_queue,

# --------------------------------------------------------------------------
#
def _sync_with_master(self, obj, obj_type, channel, queue):
def _sync_with_master(self, obj, obj_type, channel, conn_params, queue):

corr_id = str(uuid.uuid4())
body = json.dumps({'object': obj.to_dict(),
Expand All @@ -148,9 +148,16 @@ def _sync_with_master(self, obj, obj_type, channel, queue):

self._prof.prof('pub_sync', state=obj.state, uid=obj.uid, msg=msg)
self._log.debug('%s (%s) to sync with amgr', obj.uid, obj.state)

channel.basic_publish(exchange='', routing_key=queue, body=body,
try:
channel.basic_publish(exchange='', routing_key=queue, body=body,
properties=pika.BasicProperties(correlation_id=corr_id))
except (pika.exceptions.ConnectionClosed,
pika.exceptions.ChannelClosed):
connection = pika.BlockingConnection(conn_params)
channel = connection.channel()
channel.basic_publish(exchange='', routing_key=queue, body=body,
properties=pika.BasicProperties(correlation_id=corr_id))


# all queue name parts up to the last three are used as sid, the last
# three parts are channel specifiers which need to be inversed to obtain
Expand Down Expand Up @@ -190,7 +197,7 @@ def _sync_with_master(self, obj, obj_type, channel, queue):

# --------------------------------------------------------------------------
#
def _advance(self, obj, obj_type, new_state, channel, queue):
def _advance(self, obj, obj_type, new_state, channel, conn_params, queue):

try:
old_state = obj.state
Expand All @@ -203,14 +210,14 @@ def _advance(self, obj, obj_type, new_state, channel, queue):
self._prof.prof('advance', uid=obj.uid, state=obj.state, msg=msg)
self._log.info('Transition %s to %s', obj.uid, new_state)

self._sync_with_master(obj, obj_type, channel, queue)
self._sync_with_master(obj, obj_type, channel, conn_params, queue)


except Exception as ex:
self._log.exception('Transition %s to state %s failed, error: %s',
obj.uid, new_state, ex)
obj.state = old_state
self._sync_with_master(obj, obj_type, channel, queue)
self._sync_with_master(obj, obj_type, channel, conn_params, queue)
raise EnTKError(ex) from ex


Expand Down Expand Up @@ -240,14 +247,25 @@ def _heartbeat(self):
while not self._hb_terminate.is_set():

corr_id = str(uuid.uuid4())
try:
# Heartbeat request signal sent to task manager via rpc-queue
props = pika.BasicProperties(reply_to=self._hb_response_q,
correlation_id=corr_id)
mq_channel.basic_publish(exchange='',
routing_key=self._hb_request_q,
properties=props,
body='request')
except (pika.exceptions.ConnectionClosed,
pika.exceptions.ChannelClosed):
mq_connection = pika.BlockingConnection(self._rmq_conn_params)
mq_channel = mq_connection.channel()
props = pika.BasicProperties(reply_to=self._hb_response_q,
correlation_id=corr_id)
mq_channel.basic_publish(exchange='',
routing_key=self._hb_request_q,
properties=props,
body='request')

# Heartbeat request signal sent to task manager via rpc-queue
props = pika.BasicProperties(reply_to=self._hb_response_q,
correlation_id=corr_id)
mq_channel.basic_publish(exchange='',
routing_key=self._hb_request_q,
properties=props,
body='request')
self._log.info('Sent heartbeat request')

# Sleep for hb_interval and then check if tmgr responded
Expand All @@ -257,11 +275,13 @@ def _heartbeat(self):
queue=self._hb_response_q)
if not body:
# no usable response
self._log.error('Heartbeat response no body')
return
# raise EnTKError('heartbeat timeout')

if corr_id != props.correlation_id:
# incorrect response
self._log.error('Heartbeat response wrong correlation')
return
# raise EnTKError('heartbeat timeout')

Expand Down
61 changes: 41 additions & 20 deletions src/radical/entk/execman/rp/task_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,30 +104,40 @@ def _tmgr(self, uid, rmgr, pending_queue, completed_queue,
try:

# ------------------------------------------------------------------
def heartbeat_response(mq_channel):
def heartbeat_response(mq_channel, conn_params):

channel = mq_channel
try:

# Get request from heartbeat-req for heartbeat response
method_frame, props, body = \
mq_channel.basic_get(queue=self._hb_request_q)
channel.basic_get(queue=self._hb_request_q)

if not body:
return

self._log.info('Received heartbeat request')

nprops = pika.BasicProperties(
try:
nprops = pika.BasicProperties(
correlation_id=props.correlation_id)
channel.basic_publish(exchange='',
routing_key=self._hb_response_q,
properties=nprops,
body='response')
except (pika.exceptions.ConnectionClosed,
pika.exceptions.ChannelClosed):
connection = pika.BlockingConnection(conn_params)
channel = connection.channel()
nprops = pika.BasicProperties(
correlation_id=props.correlation_id)
mq_channel.basic_publish(
exchange='',
routing_key=self._hb_response_q,
properties=nprops,
body='response')
channel.basic_publish(exchange='',
routing_key=self._hb_response_q,
properties=nprops,
body='response')

self._log.info('Sent heartbeat response')

mq_channel.basic_ack(delivery_tag=method_frame.delivery_tag)
channel.basic_ack(delivery_tag=method_frame.delivery_tag)

except Exception as ex:
self._log.exception('Failed to respond to heartbeat, ' +
Expand Down Expand Up @@ -173,7 +183,7 @@ def heartbeat_response(mq_channel):
mq_channel.basic_ack(
delivery_tag=method_frame.delivery_tag)

heartbeat_response(mq_channel)
heartbeat_response(mq_channel, rmq_conn_params)

except Exception as ex:
self._log.exception('Error in task execution: %s', ex)
Expand Down Expand Up @@ -236,10 +246,11 @@ def load_placeholder(task, rts_uid):
'rts_uid': rts_uid}

# ----------------------------------------------------------------------
def unit_state_cb(unit, state):
def unit_state_cb(unit, state, cb_data):

try:

channel = cb_data['channel']
conn_params = cb_data['params']
self._log.debug('Unit %s in state %s' % (unit.uid, unit.state))

if unit.state in rp.FINAL:
Expand All @@ -248,16 +259,24 @@ def unit_state_cb(unit, state):
task = create_task_from_cu(unit, self._prof)

self._advance(task, 'Task', states.COMPLETED,
mq_channel, '%s-cb-to-sync' % self._sid)
channel, conn_params,
'%s-cb-to-sync' % self._sid)

load_placeholder(task, unit.uid)

task_as_dict = json.dumps(task.to_dict())
try:
channel.basic_publish(exchange='',
routing_key='%s-completedq-1' % self._sid,
body=task_as_dict)
except (pika.exceptions.ConnectionClosed,
pika.exceptions.ChannelClosed):
connection = pika.BlockingConnection(conn_params)
channel = connection.channel()
channel.basic_publish(exchange='',
routing_key='%s-completedq-1' % self._sid,
body=task_as_dict)

mq_channel.basic_publish(
exchange='',
routing_key='%s-completedq-1' % self._sid,
body=task_as_dict)

self._log.info('Pushed task %s with state %s to completed '
'queue %s-completedq-1',
Expand All @@ -279,7 +298,8 @@ def unit_state_cb(unit, state):

umgr = rp.UnitManager(session=rmgr._session)
umgr.add_pilots(rmgr.pilot)
umgr.register_callback(unit_state_cb)
umgr.register_callback(unit_state_cb, cb_data={'channel': mq_channel,
'params': rmq_conn_params})

try:

Expand Down Expand Up @@ -311,7 +331,8 @@ def unit_state_cb(unit, state):
task, placeholders, self._prof))

self._advance(task, 'Task', states.SUBMITTING,
mq_channel, '%s-tmgr-to-sync' % self._sid)
mq_channel, rmq_conn_params,
'%s-tmgr-to-sync' % self._sid)

umgr.submit_units(bulk_cuds)
mq_connection.close()
Expand Down
13 changes: 7 additions & 6 deletions tests/test_component/test_tmgr_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,9 @@ def test_advance(self, mocked_init, mocked_Logger, mocked_Profiler,

global_syncs = []

def _sync_side_effect(log_entry, uid, state, msg):
def _sync_side_effect(obj, obj_type, channel, conn_params, queue):
nonlocal global_syncs
global_syncs.append([log_entry, uid, state, msg])
global_syncs.append([obj, obj_type, channel, conn_params, queue])

tmgr._log = mocked_Logger
tmgr._prof = mocked_Profiler
Expand All @@ -120,12 +120,13 @@ def _sync_side_effect(log_entry, uid, state, msg):
obj.parent_pipeline = {'uid': 'test_pipe'}
obj.uid = 'test_object'
obj.state = 'test_state'
tmgr._advance(obj, 'Task', None, 'channel','queue')
self.assertEqual(global_syncs[0],[obj, 'Task', 'channel','queue'])
tmgr._advance(obj, 'Task', None, 'channel','params','queue')
self.assertEqual(global_syncs[0],[obj, 'Task', 'channel', 'params', 'queue'])
self.assertIsNone(obj.state)
global_syncs = []
tmgr._advance(obj, 'Stage', 'new_state', 'channel','queue')
self.assertEqual(global_syncs[0],[obj, 'Stage', 'channel','queue'])
tmgr._advance(obj, 'Stage', 'new_state', 'channel', 'params', 'queue')
self.assertEqual(global_syncs[0],[obj, 'Stage', 'channel', 'params',
'queue'])
self.assertEqual(obj.state, 'new_state')

# ------------------------------------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion tests/test_component/test_tmgr_rp.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def test_start_manager(self, mocked_init, mocked_Logger, mocked_Profiler):
tmgr._prof = mocked_Profiler
tmgr._uid = 'tmgr.0000'
tmgr._rmgr = 'test_rmgr'
tmgr._rmq_conn_params = 'test_params'
tmgr._rmq_conn_params = rmq_params
tmgr._pending_queue = ['pending_queues']
tmgr._completed_queue = ['completed_queues']
tmgr._tmgr = _tmgr_side_effect
Expand Down
99 changes: 99 additions & 0 deletions tests/test_integration/test_tmgr_base/test_heartbeat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# pylint: disable=protected-access, unused-argument
# pylint: disable=no-value-for-parameter
import os
import pika
import time

from unittest import TestCase, mock

import threading as mt
import radical.utils as ru

from radical.entk.execman.base import Base_TaskManager as BaseTmgr


# ------------------------------------------------------------------------------
#
class TestTask(TestCase):
    """Integration test for ``Base_TaskManager._heartbeat``.

    The test talks to a real RabbitMQ broker.  Connection settings are read
    from the ``RMQ_HOSTNAME``, ``RMQ_PORT``, ``RMQ_USERNAME`` and
    ``RMQ_PASSWORD`` environment variables and default to
    ``localhost:5672`` with ``guest``/``guest`` credentials.
    """

    # --------------------------------------------------------------------------
    #
    @mock.patch.object(BaseTmgr, '__init__', return_value=None)
    @mock.patch('radical.utils.Profiler')
    def test_heartbeat(self, mocked_Profiler, mocked_init):
        """Exercise the heartbeat thread against a live broker.

        Phase 1: answer five heartbeat requests with the matching
        correlation id and check the thread stays alive; then stop
        answering and check the thread exits.

        Phase 2: answer one request *without* a correlation id and check
        the thread exits as well.

        ``mock.patch`` decorators are applied bottom-up, so the Profiler
        mock is injected first and the ``__init__`` mock second.
        """

        request_q  = 'tmgr-hb-request'
        response_q = 'tmgr-hb-response'

        hostname = os.environ.get('RMQ_HOSTNAME', 'localhost')
        port     = int(os.environ.get('RMQ_PORT', '5672'))
        username = os.environ.get('RMQ_USERNAME', 'guest')
        password = os.environ.get('RMQ_PASSWORD', 'guest')

        credentials     = pika.PlainCredentials(username, password)
        rmq_conn_params = pika.ConnectionParameters(host=hostname, port=port,
                                                    credentials=credentials)

        # `__init__` is patched out, so every attribute `_heartbeat` reads
        # has to be set by hand.
        tmgr = BaseTmgr(None, None, None, None, None, None)
        tmgr._uid             = 'tmgr.0000'
        tmgr._log             = ru.Logger('radical.entk.manager.base',
                                          level='DEBUG')
        tmgr._prof            = mocked_Profiler
        tmgr._hb_interval     = 0.1
        tmgr._hb_terminate    = mt.Event()
        tmgr._hb_request_q    = request_q
        tmgr._hb_response_q   = response_q
        tmgr._rmq_conn_params = rmq_conn_params

        mq_connection = pika.BlockingConnection(rmq_conn_params)
        mq_channel    = mq_connection.channel()
        mq_channel.queue_declare(queue=request_q)
        mq_channel.queue_declare(queue=response_q)

        # ----------------------------------------------------------------------
        def next_request(timeout=5.0):
            # Poll the request queue until the heartbeat thread publishes.
            # The deadline keeps a dead heartbeat thread from hanging the
            # whole test run.
            deadline = time.monotonic() + timeout
            while time.monotonic() < deadline:
                _, props, body = mq_channel.basic_get(queue=request_q)
                if body is not None:
                    return props, body
            self.fail('no heartbeat request within %.1f seconds' % timeout)

        tmgr._log.info('Starting test')
        master_thread = mt.Thread(target=tmgr._heartbeat,
                                  name='tmgr_heartbeat')
        master_thread.start()
        time.sleep(0.1)

        try:
            # Phase 1: well-formed responses keep the heartbeat alive.
            for _ in range(5):
                props, body = next_request()
                self.assertEqual(body, b'request')
                nprops = pika.BasicProperties(
                                        correlation_id=props.correlation_id)
                mq_channel.basic_publish(exchange='',
                                         routing_key=response_q,
                                         properties=nprops,
                                         body='response')
                self.assertTrue(master_thread.is_alive())

            # No further responses are sent: the thread is expected to stop.
            time.sleep(0.5)
            self.assertFalse(master_thread.is_alive())
            master_thread.join()

            # Recreate both queues so phase 2 starts without stale messages.
            mq_channel.queue_delete(queue=request_q)
            mq_channel.queue_delete(queue=response_q)
            mq_channel.queue_declare(queue=request_q)
            mq_channel.queue_declare(queue=response_q)

            # Phase 2: a response lacking the correlation id must also stop
            # the heartbeat thread.
            master_thread = mt.Thread(target=tmgr._heartbeat,
                                      name='tmgr_heartbeat')
            master_thread.start()
            next_request()
            mq_channel.basic_publish(exchange='',
                                     routing_key=response_q,
                                     body='response')
            time.sleep(0.2)
            self.assertFalse(master_thread.is_alive())

        finally:
            # Single cleanup path for success and failure; any exception
            # raised above propagates after the broker state is cleaned up.
            tmgr._hb_terminate.set()
            master_thread.join()
            mq_channel.queue_delete(queue=request_q)
            mq_channel.queue_delete(queue=response_q)
            mq_channel.close()
            mq_connection.close()