Change how the number of workers is tracked in preparation for remote workers
Bo Peng committed Aug 5, 2019
1 parent a754f20 commit e3db6ea
Showing 1 changed file with 47 additions and 33 deletions.
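The change replaces the single integer self._max_workers with per-host lists parsed from one or more [host:]nproc specifications given to option -j. A minimal sketch of the new parsing (standalone Python; the host names node1 and node2 are hypothetical, not from the commit):

    # Hypothetical -j values: a bare count for localhost plus two remote hosts.
    worker_procs = ['4', 'node1:8', 'node2:16']

    # rsplit(':', 1)[-1] takes the text after the last colon as the process
    # count, so a bare '4' (no host part) parses the same way as 'node1:8'.
    max_workers = [int(x.rsplit(':', 1)[-1]) for x in worker_procs]
    num_workers = [0 for x in worker_procs]

    print(max_workers)  # [4, 8, 16]
    print(num_workers)  # [0, 0, 0]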
src/sos/workers.py: 47 additions & 33 deletions
@@ -371,16 +371,29 @@ class WorkerManager(object):
     # manager worker processes
 
     def __init__(self, worker_procs, backend_socket):
-        self._worker_procs = worker_procs
-        self._max_workers = int(self._worker_procs[0])
+        if isinstance(worker_procs, (int, str)):
+            self._worker_procs = [str(worker_procs)]
+        else:
+            # should be a sequence
+            self._worker_procs = worker_procs
 
+        # the first item in self._worker_procs is always considered to be the localhost, which is where
+        # the router lives. The rest of the hosts will be considered as remote workers.
+        try:
+            self._max_workers = [int(x.rsplit(':', 1)[-1]) for x in self._worker_procs]
+            self._num_workers = [0 for x in self._worker_procs]
+        except:
+            raise RuntimeError(f'Incorrect format for option -j ({self._worker_procs}), which should be one or more [host:]nproc')
+
         self._local_workers = []
-        self._num_local_workers = 0
 
+        self._num_remote_workers = {}
+
         self._n_requested = 0
         self._n_processed = 0
 
         self._local_worker_alive_time = time.time()
-        self._last_pending_time = {}
+        # self._last_pending_time = {}
 
         self._substep_requests = []
         self._step_requests = {}
@@ -397,14 +410,14 @@ def __init__(self, worker_procs, backend_socket):
 
         # start a worker, note that we do not start all workers for performance
        # considerations
-        self.start_local_worker()
+        self.start_worker()
 
     def report(self, msg):
         if 'WORKER' in env.config['SOS_DEBUG'] or 'ALL' in env.config[
                 'SOS_DEBUG']:
             env.log_to_file(
                 'WORKER',
-                f'{msg.upper()}: {self._num_local_workers} workers (of which {len(self._blocking_ports)} is blocking), {self._n_requested} requested, {self._n_processed} processed'
+                f'{msg.upper()}: {self._num_workers} workers (of which {len(self._blocking_ports)} is blocking), {self._n_requested} requested, {self._n_processed} processed'
             )
 
     def add_request(self, msg_type, msg):
@@ -419,8 +432,8 @@ def add_request(self, msg_type, msg):
 
         # start a worker is necessary (max_procs could be incorrectly set to be 0 or less)
         # if we are just starting, so do not start two workers
-        if self._n_processed > 0 and not self._available_ports and self._num_local_workers < self._max_workers:
-            self.start_local_worker()
+        if self._n_processed > 0 and not self._available_ports and sum(self._num_workers) < sum(self._max_workers):
+            self.start_worker()
 
     def worker_available(self, blocking, excluded):
         if self._available_ports:
@@ -433,12 +446,12 @@ def worker_available(self, blocking, excluded):
 
         if not blocking:
             # no available port, can we start a new worker?
-            if self._num_local_workers < self._max_workers:
-                self.start_local_worker()
+            if sum(self._num_workers) < sum(self._max_workers):
+                self.start_worker()
             return None
 
         # we start a worker right now.
-        self.start_local_worker()
+        self.start_worker()
         while True:
             if not self._worker_backend_socket.poll(5000):
                 raise RuntimeError('No worker is started after 5 seconds')
@@ -447,10 +460,10 @@
             if port is None or port in excluded:
                 continue
             self._claimed_ports.add(port)
-            self._max_workers += 1
+            self._max_workers[0] += 1
             self._blocking_ports.add(port)
             env.logger.debug(
-                f'Increasing maximum number of workers to {self._max_workers} to accommodate a blocking subworkflow.'
+                f'Increasing maximum number of local workers to {self._max_workers[0]} to accommodate a blocking subworkflow.'
             )
             return port

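With the capacity kept as a list, the blocking path above bumps only the first entry (index 0 is localhost, where blocking subworkflow workers run); remote capacities are untouched. A toy illustration with hypothetical values:

    # Hypothetical capacities for -j 4 node1:8 (node1 is a made-up host).
    max_workers = [4, 8]

    # a blocking subworkflow claims a dedicated worker, so only the local
    # cap (index 0) is raised; the remote cap stays at 8
    max_workers[0] += 1
    print(max_workers)  # [5, 8]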
@@ -474,20 +487,20 @@ def process_request(self, num_pending, ports, request_blocking=False):
             # the port is claimed, but the real message is not yet available
             self._worker_backend_socket.send(encode_msg({}))
             self.report(f'pending with claimed {ports}')
-        elif any(port in self._blocking_ports for port in ports):
-            # in block list but appear to be idle, kill it
-            self._max_workers -= 1
-            env.logger.debug(
-                f'Reduce maximum number of workers to {self._max_workers} after completion of a blocking subworkflow.'
-            )
-            for port in ports:
-                if port in self._blocking_ports:
-                    self._blocking_ports.remove(port)
-                if port in self._available_ports:
-                    self._available_ports.remove(port)
-            self._worker_backend_socket.send(encode_msg(None))
-            self._num_local_workers -= 1
-            self.report(f'Blocking worker {ports} killed')
+        # elif any(port in self._blocking_ports for port in ports):
+        #     # in block list but appear to be idle, kill it
+        #     self._max_workers -= 1
+        #     env.logger.debug(
+        #         f'Reduce maximum number of workers to {self._max_workers} after completion of a blocking subworkflow.'
+        #     )
+        #     for port in ports:
+        #         if port in self._blocking_ports:
+        #             self._blocking_ports.remove(port)
+        #         if port in self._available_ports:
+        #             self._available_ports.remove(port)
+        #     self._worker_backend_socket.send(encode_msg(None))
+        #     self._num_local_workers -= 1
+        #     self.report(f'Blocking worker {ports} killed')
         elif self._substep_requests:
             # port is not claimed, free to use for substep worker
             msg = self._substep_requests.pop()
@@ -526,12 +539,12 @@ def process_request(self, num_pending, ports, request_blocking=False):
                     f'pending with port {ports} at num_pending {num_pending}')
                 self._last_pending_msg[(ports, num_pending)] = time.time()
 
-    def start_local_worker(self):
+    def start_worker(self):
         worker = SoS_Worker(env.config)
         worker.start()
         self._local_worker_alive_time = time.time()
         self._local_workers.append(worker)
-        self._num_local_workers += 1
+        self._num_workers[0] += 1
         self.report('start worker')
 
     def check_workers(self):
@@ -545,15 +558,16 @@ def check_workers(self):
         self._local_workers = [
             worker for worker in self._local_workers if worker.is_alive()
         ]
-        if len(self._local_workers) < self._num_local_workers:
+        if len(self._local_workers) < self._num_workers[0]:
             raise ProcessKilled('One of the local workers has been killed.')
 
     def kill_all(self):
         '''Kill all workers'''
-        while self._num_local_workers > 0 and self._worker_backend_socket.poll(1000):
+        total_num_workers = sum(self._num_workers)
+        while total_num_workers > 0 and self._worker_backend_socket.poll(1000):
             msg = decode_msg(self._worker_backend_socket.recv())
             self._worker_backend_socket.send(encode_msg(None))
-            self._num_local_workers -= 1
+            total_num_workers -= 1
             self.report(f'Kill {msg[1:]}')
-        # join all processes
+        # join all local processes
         [worker.join() for worker in self._local_workers]

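Because counts are now tracked per host, the capacity tests in add_request and worker_available compare totals across all hosts rather than a single local counter. A small illustration of the sum-based check introduced above (hypothetical numbers):

    # Hypothetical state for -j 4 node1:8 (node1 is a made-up remote host).
    max_workers = [4, 8]   # capacity: 4 local slots, 8 on node1
    num_workers = [3, 0]   # three local workers running, none remote yet

    # a new worker may start only while the total running count is below the
    # total capacity; workers are still launched locally, so index 0 grows
    if sum(num_workers) < sum(max_workers):
        num_workers[0] += 1

    print(num_workers)  # [4, 0]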