Resolve potential race condition in mo.wait_actor_pool_recovered (m…
keyile authored Aug 14, 2021
1 parent 322c187 commit 2fb00d8
Showing 3 changed files with 125 additions and 13 deletions.
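
What changed, as read from the diffs below: previously the recovery Event for a sub pool could be created in three places -- the stop control-command handler, the wait_pool_recovered handler (and there only after it had itself observed the pool dead), and monitor_sub_pools. Because the monitor created and then popped its own event, a waiter that registered an event after the monitor's pop was never signalled, so mo.wait_actor_pool_recovered could block forever. After this commit, only the wait_pool_recovered handler creates events, and the monitor records whether an event already existed before its aliveness check, signalling only that pre-existing event; an event registered mid-recovery is picked up and set on the next ~0.5 s monitor tick.

A minimal sketch of the corrected pattern (illustrative only, not Mars code; all names here are invented):

    import asyncio

    recover_events = {}  # address -> asyncio.Event

    async def wait_recovered(address):
        # Waiter side: unconditionally ensure an event exists, then block.
        event = recover_events.setdefault(address, asyncio.Event())
        await event.wait()

    async def monitor(address, is_alive, recover):
        while True:
            # Snapshot BEFORE the aliveness check: only events that already
            # exist at this point are signalled in this iteration.
            discovered = address in recover_events
            if not await is_alive(address):
                await recover(address)
            if discovered:
                # Signal pre-existing waiters; an event created mid-recovery
                # survives untouched and is set on the next tick.
                recover_events.pop(address).set()
            await asyncio.sleep(.5)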
20 changes: 10 additions & 10 deletions mars/oscar/backends/pool.py
@@ -864,15 +864,11 @@ async def handle_control_command(self,
                     self.sub_processes[message.address],
                     timeout=timeout,
                     force=force)
-                if self._auto_recover:
-                    self._recover_events[message.address] = asyncio.Event()
                 processor.result = ResultMessage(message.message_id, True,
                                                  protocol=message.protocol)
             elif message.control_message_type == ControlMessageType.wait_pool_recovered:
-                # check the aliveness of sub pool first, in case monitor task haven't found it.
-                if not await self.is_sub_pool_alive(self.sub_processes[message.address]):
-                    if self._auto_recover and message.address not in self._recover_events:
-                        self._recover_events[message.address] = asyncio.Event()
+                if self._auto_recover and message.address not in self._recover_events:
+                    self._recover_events[message.address] = asyncio.Event()
 
                 event = self._recover_events.get(message.address, None)
                 if event is not None:
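
For orientation, the client-side call this protects (exercised by the new tests below) looks roughly like the following sketch, using the context API from test_ray_pool.py with placeholder addresses:

    ctx = get_context()
    # Blocks until the sub pool at sub_address has been recovered by the
    # main pool at main_address; returns quickly if it is already healthy.
    await ctx.wait_actor_pool_recovered(sub_address, main_address)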
@@ -1019,6 +1015,10 @@ async def is_sub_pool_alive(self, process: SubProcessHandle):
         bool
         """
 
+    @abstractmethod
+    def recover_sub_pool(self, address):
+        """Recover a sub actor pool"""
+
     def process_sub_pool_lost(self, address: str):
         if self._auto_recover in (False, 'process'):
             # process down, when not auto_recover
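
The new abstract method makes the recovery hook an explicit part of the main-pool contract; note that monitor_sub_pools awaits it, so concrete implementations are coroutines. A hypothetical subclass might look like this sketch (the restart helper is invented for illustration):

    class MyMainActorPool(MainActorPoolBase):
        async def recover_sub_pool(self, address):
            # Restart the sub process, then re-create the actors that
            # lived at this address -- details are backend-specific.
            await self._restart_sub_process(address)  # hypothetical helper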
@@ -1030,18 +1030,18 @@ async def monitor_sub_pools(self):
         while not self._stopped.is_set():
             for address in self.sub_processes:
                 process = self.sub_processes[address]
+                recover_events_discovered = (address in self._recover_events)
                 if not await self.is_sub_pool_alive(process):  # pragma: no cover
                     if self._on_process_down is not None:
                         self._on_process_down(self, address)
                     self.process_sub_pool_lost(address)
                     if self._auto_recover:
-                        if address not in self._recover_events:
-                            self._recover_events[address] = asyncio.Event()
                         await self.recover_sub_pool(address)
                         if self._on_process_recover is not None:
                             self._on_process_recover(self, address)
-                        event = self._recover_events.pop(address)
-                        event.set()
+                if recover_events_discovered:
+                    event = self._recover_events.pop(address)
+                    event.set()
 
             # check every half second
             await asyncio.sleep(.5)
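
Why the snapshot is taken before the aliveness check -- an illustrative timeline of the old race:

    # t0  monitor: address not in _recover_events -> creates its own Event A
    # t1  monitor: recovers the sub pool, pops and sets Event A
    # t2  waiter:  still observes the pool as dead, creates Event B, waits
    # t3  Event B is never popped or set -> the waiter hangs forever
    # With the snapshot, the monitor no longer creates events, and Event B
    # is discovered and set on the next ~0.5 s iteration.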
3 changes: 3 additions & 0 deletions mars/oscar/backends/ray/pool.py
@@ -29,6 +29,7 @@
 from ..message import CreateActorMessage
 from ..pool import AbstractActorPool, MainActorPoolBase, SubActorPoolBase, create_actor_pool, _register_message_handler
 from ..router import Router
+from ... import ServerClosed
 from ....serialization.ray import register_ray_serializers
 from ....utils import lazy_import
 
@@ -174,6 +175,8 @@ def _set_ray_server(self, actor_pool: AbstractActorPool):
     async def __on_ray_recv__(self, channel_id: ChannelID, message):
         """Method for communication based on ray actors"""
         try:
+            if self._ray_server is None:
+                raise ServerClosed(f'Remote server {channel_id.dest_address} closed')
             return await self._ray_server.__on_ray_recv__(channel_id, message)
         except Exception:  # pragma: no cover
             return RayChannelException(*sys.exc_info())
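
With this guard, a message arriving after the pool's server has been torn down fails fast with ServerClosed (which the except branch ships back as a RayChannelException) instead of surfacing as an AttributeError on None. The new tests assert exactly this from the caller's side, e.g.:

    with pytest.raises(ServerClosed):
        await task  # a call into the crashed sub pool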
115 changes: 112 additions & 3 deletions mars/oscar/backends/ray/tests/test_ray_pool.py
@@ -11,17 +11,33 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import asyncio
+import os
 
 import pytest
 
+import mars.oscar as mo
+from mars.oscar.errors import ServerClosed
+from mars.oscar.backends.allocate_strategy import ProcessIndex, MainPool
+from mars.oscar.backends.ray.pool import RayMainPool, RayMainActorPool, create_actor_pool, PoolStatus
+from mars.oscar.backends.ray.utils import process_placement_to_address
+from mars.oscar.context import get_context
 from mars.tests.core import require_ray
-from .....utils import lazy_import
-from ..pool import RayMainPool, RayMainActorPool, create_actor_pool
-from ..utils import process_placement_to_address
+from mars.utils import lazy_import
 
 ray = lazy_import('ray')
+
+
+class TestActor(mo.Actor):
+    async def kill(self, address, uid):
+        actor_ref = await mo.actor_ref(address, uid)
+        task = asyncio.create_task(actor_ref.crash())
+        return await task
+
+    async def crash(self):
+        os._exit(0)
 
 
 @require_ray
 @pytest.mark.asyncio
 async def test_main_pool(ray_start_regular):
@@ -71,3 +87,96 @@ async def test_shutdown_sub_pool(ray_start_regular):
         await sub_pool_handle1.actor_pool.remote('health_check')
     with pytest.raises(AttributeError, match='NoneType'):
         await sub_pool_handle2.actor_pool.remote('health_check')
+
+
+@require_ray
+@pytest.mark.asyncio
+async def test_server_closed(ray_start_regular):
+    pg_name, n_process = 'ray_cluster', 1
+    pg = ray.util.placement_group(name=pg_name, bundles=[{'CPU': n_process}])
+    ray.get(pg.ready())
+    address = process_placement_to_address(pg_name, 0, process_index=0)
+    # start the actor pool
+    actor_handle = await mo.create_actor_pool(address, n_process=n_process)
+    await actor_handle.actor_pool.remote('start')
+
+    ctx = get_context()
+    actor_main = await ctx.create_actor(
+        TestActor, address=address, uid='Test-main',
+        allocate_strategy=ProcessIndex(0))
+
+    actor_sub = await ctx.create_actor(
+        TestActor, address=address, uid='Test-sub',
+        allocate_strategy=ProcessIndex(1))
+
+    # test calling from ray driver to ray actor
+    task = asyncio.create_task(actor_sub.crash())
+
+    with pytest.raises(ServerClosed):
+        # process already died,
+        # ServerClosed will be raised
+        await task
+
+    # wait for recover of sub pool
+    await ctx.wait_actor_pool_recovered(actor_sub.address, address)
+
+    # test calling from ray actor to ray actor
+    task = asyncio.create_task(actor_main.kill(actor_sub.address, 'Test-sub'))
+
+    with pytest.raises(ServerClosed):
+        await task
+
+
+@require_ray
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    'auto_recover',
+    [False, True, 'actor', 'process']
+)
+async def test_auto_recover(ray_start_regular, auto_recover):
+    pg_name, n_process = 'ray_cluster', 1
+    pg = ray.util.placement_group(name=pg_name, bundles=[{'CPU': n_process}])
+    assert pg.wait(timeout_seconds=20)
+    address = process_placement_to_address(pg_name, 0, process_index=0)
+    actor_handle = await mo.create_actor_pool(address, n_process=n_process, auto_recover=auto_recover)
+    await actor_handle.actor_pool.remote('start')
+
+    ctx = get_context()
+
+    # wait for recover of main pool always returned immediately
+    await ctx.wait_actor_pool_recovered(address, address)
+
+    # create actor on main
+    actor_ref = await ctx.create_actor(
+        TestActor, address=address,
+        allocate_strategy=MainPool())
+
+    with pytest.raises(ValueError):
+        # cannot kill actors on main pool
+        await mo.kill_actor(actor_ref)
+
+    # create actor
+    actor_ref = await ctx.create_actor(
+        TestActor, address=address,
+        allocate_strategy=ProcessIndex(1))
+    # kill_actor will cause kill corresponding process
+    await ctx.kill_actor(actor_ref)
+
+    if auto_recover:
+        await ctx.wait_actor_pool_recovered(actor_ref.address, address)
+        sub_pool_address = process_placement_to_address(pg_name, 0, process_index=1)
+        sub_pool_handle = ray.get_actor(sub_pool_address)
+        assert await sub_pool_handle.actor_pool.remote('health_check') == PoolStatus.HEALTHY
+
+        expect_has_actor = True if auto_recover in ['actor', True] else False
+        assert await ctx.has_actor(actor_ref) is expect_has_actor
+    else:
+        with pytest.raises((ServerClosed, ConnectionError)):
+            await ctx.has_actor(actor_ref)
+
+    if 'COV_CORE_SOURCE' in os.environ:
+        for addr in [process_placement_to_address(pg_name, 0, process_index=i) for i in range(2)]:
+            # must save the local reference until this is fixed:
+            # https://github.com/ray-project/ray/issues/7815
+            ray_actor = ray.get_actor(addr)
+            ray.get(ray_actor.cleanup.remote())
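
To run the new tests locally -- a sketch assuming ray and pytest are installed, using pytest's standard -k filter:

    pytest mars/oscar/backends/ray/tests/test_ray_pool.py -k 'server_closed or auto_recover'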
