Skip to content

Commit

Permalink
Merge pull request #252 from ddelange/patch-1
Browse files Browse the repository at this point in the history
Allow passing kwds to ProcessPool
  • Loading branch information
mmckerns committed Dec 22, 2022
2 parents d7e9e0b + ab32720 commit dfd15d0
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 15 deletions.
18 changes: 11 additions & 7 deletions pathos/multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,15 @@ def __init__(self, *args, **kwds):
kwds['ncpus'] = kwds.pop('nodes')
elif arglen:
kwds['ncpus'] = args[0]
self.__nodes = kwds.get('ncpus', cpu_count())
self.__nodes = kwds.pop('ncpus', cpu_count())

# Create an identifier for the pool
self._id = kwds.get('id', None) #'pool'
self._id = kwds.pop('id', None) #'pool'
if self._id is None:
self._id = self.__nodes

self._kwds = kwds

# Create a new server if one isn't already initialized
self._serve()
return
Expand All @@ -114,16 +116,17 @@ def _serve(self, nodes=None): #XXX: should be STATE method; use id
"""Create a new server if one isn't already initialized"""
if nodes is None: nodes = self.__nodes
_pool = __STATE.get(self._id, None)
if not _pool or nodes != _pool.__nodes:
if not _pool or nodes != _pool.__nodes or self._kwds != _pool._kwds:
self._clear()
_pool = Pool(nodes)
_pool = Pool(nodes, **self._kwds)
_pool.__nodes = nodes
_pool._kwds = self._kwds
__STATE[self._id] = _pool
return _pool
def _clear(self): #XXX: should be STATE method; use id
"""Remove server with matching state"""
_pool = __STATE.get(self._id, None)
if _pool and self.__nodes == _pool.__nodes:
if _pool and self.__nodes == _pool.__nodes and self._kwds == _pool._kwds:
_pool.close()
_pool.join()
__STATE.pop(self._id, None)
Expand Down Expand Up @@ -177,14 +180,15 @@ def __set_nodes(self, nodes):
def restart(self, force=False):
"restart a closed pool"
_pool = __STATE.get(self._id, None)
if _pool and self.__nodes == _pool.__nodes:
if _pool and self.__nodes == _pool.__nodes and self._kwds == _pool._kwds:
RUN = 0
if not force:
assert _pool._state != RUN
# essentially, 'clear' and 'serve'
self._clear()
_pool = Pool(self.__nodes)
_pool = Pool(self.__nodes, **self._kwds)
_pool.__nodes = self.__nodes
_pool._kwds = self._kwds
__STATE[self._id] = _pool
return _pool
def close(self):
Expand Down
43 changes: 42 additions & 1 deletion pathos/tests/test_mp.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
# Copyright (c) 2016-2022 The Uncertainty Quantification Foundation.
# License: 3-clause BSD. The full license text is available at:
# - https://github.com/uqfoundation/pathos/blob/master/LICENSE
import time

def test_mp():
# instantiate and configure the worker pool
from pathos.pools import ProcessPool
pool = ProcessPool(nodes=4)

_result = list(map(pow, [1,2,3,4], [5,6,7,8]))
_result = list(map(pow, [1,2,3,4], [5,6,7,8]))

# do a blocking map on the chosen function
result = pool.map(pow, [1,2,3,4], [5,6,7,8])
Expand All @@ -27,6 +28,45 @@ def test_mp():
result = result_queue.get()
assert result == _result

# test ProcessPool keyword argument propagation
pool.clear()
pool = ProcessPool(nodes=4, initializer=lambda: time.sleep(0.6))
start = time.monotonic()
result = pool.map(pow, [1,2,3,4], [5,6,7,8])
end = time.monotonic()
assert result == _result
assert end - start > 0.5

def test_tp():
# instantiate and configure the worker pool
from pathos.pools import ThreadPool
pool = ThreadPool(nodes=4)

_result = list(map(pow, [1,2,3,4], [5,6,7,8]))

# do a blocking map on the chosen function
result = pool.map(pow, [1,2,3,4], [5,6,7,8])
assert result == _result

# do a non-blocking map, then extract the result from the iterator
result_iter = pool.imap(pow, [1,2,3,4], [5,6,7,8])
result = list(result_iter)
assert result == _result

# do an asynchronous map, then get the results
result_queue = pool.amap(pow, [1,2,3,4], [5,6,7,8])
result = result_queue.get()
assert result == _result

# test ThreadPool keyword argument propagation
pool.clear()
pool = ThreadPool(nodes=4, initializer=lambda: time.sleep(0.6))
start = time.monotonic()
result = pool.map(pow, [1,2,3,4], [5,6,7,8])
end = time.monotonic()
assert result == _result
assert end - start > 0.5


def test_chunksize():
# instantiate and configure the worker pool
Expand Down Expand Up @@ -115,5 +155,6 @@ def test_chunksize():
from pathos.helpers import freeze_support, shutdown
freeze_support()
test_mp()
test_tp()
test_chunksize()
shutdown()
18 changes: 11 additions & 7 deletions pathos/threading.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,15 @@ def __init__(self, *args, **kwds):
kwds['nthreads'] = kwds.pop('nodes')
elif arglen:
kwds['nthreads'] = args[0]
self.__nodes = kwds.get('nthreads', cpu_count())
self.__nodes = kwds.pop('nthreads', cpu_count())

# Create an identifier for the pool
self._id = kwds.get('id', None) #'threads'
self._id = kwds.pop('id', None) #'threads'
if self._id is None:
self._id = self.__nodes

self._kwds = kwds

# Create a new server if one isn't already initialized
self._serve()
return
Expand All @@ -111,16 +113,17 @@ def _serve(self, nodes=None): #XXX: should be STATE method; use id
"""Create a new server if one isn't already initialized"""
if nodes is None: nodes = self.__nodes
_pool = __STATE.get(self._id, None)
if not _pool or nodes != _pool.__nodes:
if not _pool or nodes != _pool.__nodes or self._kwds != _pool._kwds:
self._clear()
_pool = _ThreadPool(nodes)
_pool = _ThreadPool(nodes, **self._kwds)
_pool.__nodes = nodes
_pool._kwds = self._kwds
__STATE[self._id] = _pool
return _pool
def _clear(self): #XXX: should be STATE method; use id
"""Remove server with matching state"""
_pool = __STATE.get(self._id, None)
if _pool and self.__nodes == _pool.__nodes:
if _pool and self.__nodes == _pool.__nodes and self._kwds == _pool._kwds:
_pool.close()
_pool.join()
__STATE.pop(self._id, None)
Expand Down Expand Up @@ -174,14 +177,15 @@ def __set_nodes(self, nodes):
def restart(self, force=False):
"restart a closed pool"
_pool = __STATE.get(self._id, None)
if _pool and self.__nodes == _pool.__nodes:
if _pool and self.__nodes == _pool.__nodes and self._kwds == _pool._kwds:
RUN = 0
if not force:
assert _pool._state != RUN
# essentially, 'clear' and 'serve'
self._clear()
_pool = _ThreadPool(self.__nodes)
_pool = _ThreadPool(self.__nodes, **self._kwds)
_pool.__nodes = self.__nodes
_pool._kwds = self._kwds
__STATE[self._id] = _pool
return _pool
def close(self):
Expand Down

0 comments on commit dfd15d0

Please sign in to comment.