Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[core] Add MPI support on Ray cluster #40917

Merged
merged 32 commits into from
Nov 9, 2023
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ci/docker/base.build.Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ FROM $DOCKER_IMAGE_BASE_TEST

ENV RAY_INSTALL_JAVA=1

RUN apt-get install -y -qq maven openjdk-8-jre openjdk-8-jdk
RUN apt-get install -y -qq maven openjdk-8-jre openjdk-8-jdk

COPY . .

Expand Down
2 changes: 1 addition & 1 deletion ci/docker/base.gpu.Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ RUN apt-get install -y -qq \
language-pack-en tmux cmake gdb vim htop \
libgtk2.0-dev zlib1g-dev libgl1-mesa-dev \
clang-format-12 jq \
clang-tidy-12 clang-12
clang-tidy-12 clang-12 libmpich-dev
# Make using GCC 9 explicit.
RUN update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-9 90 --slave /usr/bin/g++ g++ /usr/bin/g++-9 \
--slave /usr/bin/gcov gcov /usr/bin/gcov-9
Expand Down
2 changes: 1 addition & 1 deletion ci/docker/base.test.Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ apt-get install -y -qq \
libgtk2.0-dev zlib1g-dev libgl1-mesa-dev \
liblz4-dev libunwind-dev libncurses5 \
clang-format-12 jq \
clang-tidy-12 clang-12
clang-tidy-12 clang-12 libmpich-dev

# Make using GCC 9 explicit.
update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-9 90 --slave /usr/bin/g++ g++ /usr/bin/g++-9 \
Expand Down
3 changes: 3 additions & 0 deletions python/ray/_private/runtime_env/agent/runtime_env_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from ray._private.runtime_env.py_modules import PyModulesPlugin
from ray._private.runtime_env.working_dir import WorkingDirPlugin
from ray._private.runtime_env.nsight import NsightPlugin
from ray._private.runtime_env.mpi import MPIPlugin
from ray.core.generated import (
runtime_env_agent_pb2,
agent_manager_pb2,
Expand Down Expand Up @@ -203,6 +204,7 @@ def __init__(
# and unify with nsight and other profilers.
self._nsight_plugin = NsightPlugin(self._runtime_env_dir)
self._container_manager = ContainerManager(temp_dir)
self._mpi_plugin = MPIPlugin()

# TODO(architkulkarni): "base plugins" and third-party plugins should all go
# through the same code path. We should never need to refer to
Expand All @@ -214,6 +216,7 @@ def __init__(
self._py_modules_plugin,
self._java_jars_plugin,
self._nsight_plugin,
self._mpi_plugin,
]
self._plugin_manager = RuntimeEnvPluginManager()
for plugin in self._base_plugins:
Expand Down
81 changes: 81 additions & 0 deletions python/ray/_private/runtime_env/mpi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import sys
import logging
import argparse
from typing import List, Optional
from ray._private.runtime_env.context import RuntimeEnvContext
from ray._private.runtime_env.plugin import RuntimeEnvPlugin
import subprocess

default_logger = logging.getLogger(__name__)


class MPIPlugin(RuntimeEnvPlugin):
    """Runtime env plugin that starts the worker process through ``mpirun``.

    When the runtime env has an ``mpi`` section, the worker's Python
    executable is replaced by
    ``mpirun <args> <python> <this file> <worker_entry>`` so that the
    ``__main__`` block of this file becomes the MPI entry point.
    """

    priority = 90
    name = "mpi"

    def modify_context(
        self,
        uris: List[str],  # noqa: ARG002
        runtime_env: "RuntimeEnv",  # noqa: F821 ARG002
        context: RuntimeEnvContext,
        logger: Optional[logging.Logger] = default_logger,  # noqa: ARG002
    ) -> None:
        """Rewrite ``context.py_executable`` to launch the worker via mpirun.

        Does nothing when ``runtime_env.mpi()`` returns ``None``.

        Args:
            uris: Unused by this plugin.
            runtime_env: Runtime env whose ``mpi()`` config is read. The keys
                used are ``worker_entry`` (required, path to an existing
                file) and ``args`` (optional list of extra ``mpirun``
                arguments).
            context: Context whose ``py_executable`` is overwritten with the
                full ``mpirun`` command line.
            logger: Logger used to report the detected MPI version or the
                failure to run ``mpirun``.

        Raises:
            subprocess.CalledProcessError: ``mpirun --version`` exited with a
                non-zero status.
            OSError: ``mpirun`` could not be started at all (for example it
                is not installed).
            AssertionError: ``worker_entry`` is missing or is not a file.
        """
        from pathlib import Path

        mpi_config = runtime_env.mpi()
        if mpi_config is None:
            return
        try:
            proc = subprocess.run(
                ["mpirun", "--version"], capture_output=True, check=True
            )
        except (subprocess.CalledProcessError, OSError):
            # OSError covers the case where the `mpirun` binary is missing
            # (FileNotFoundError), so that case is logged with the same hint
            # as a failing `mpirun --version`.
            logger.exception(
                "Failed to run mpi run. Please make sure mpi has been installed"
            )
            # No child process needs to be killed here: subprocess.run()
            # only returns or raises CalledProcessError after the command has
            # completed, and on OSError the process was never started.
            # The worker will fail to run and exception will be thrown in runtime
            # env agent.
            raise

        logger.info(f"Running MPI plugin\n {proc.stdout.decode()}")

        # Resulting command shape:
        #   mpirun -n 10 python mpi.py worker_entry_func

        worker_entry = mpi_config.get("worker_entry", None)
        assert (
            worker_entry is not None
        ), "`worker_entry` must be setup in the runtime env."
        # A relative `worker_entry` is resolved against the current working
        # directory of the process running this plugin.
        assert Path(worker_entry).is_file(), "`worker_entry` must be a file."

        cmds = (
            ["mpirun"]
            + mpi_config.get("args", [])
            + [
                context.py_executable,
                str(Path(__file__).absolute()),
                str(Path(worker_entry).absolute()),
            ]
        )
        # Construct the start cmd
        context.py_executable = " ".join(cmds)


if __name__ == "__main__":
    # NOTE: This block is NOT executed as part of the normal plugin and never
    # runs on import (then `__name__` is not "__main__"). It only runs when
    # this file itself is the script started under `mpirun` by the command
    # that `MPIPlugin.modify_context` builds, which makes it the MPI entry
    # point of every rank.
    #
    # Positional arguments: `worker_entry`, then `main_entry`. Everything
    # after them is forwarded untouched to the selected entry script.
    # NOTE(review): `main_entry` is presumably the regular worker script that
    # the caller appends after `py_executable` - confirm against the caller.

    # `import importlib` alone does not guarantee that the `importlib.util`
    # submodule is loaded, so import the submodule explicitly.
    import importlib.util

    parser = argparse.ArgumentParser(description="Setup MPI worker")
    parser.add_argument("worker_entry")
    parser.add_argument("main_entry")

    args, remaining_args = parser.parse_known_args()

    from mpi4py import MPI

    comm = MPI.COMM_WORLD

    rank = comm.Get_rank()

    # Rank 0 runs `main_entry`; every other rank runs `worker_entry`.
    entry_file = args.main_entry if rank == 0 else args.worker_entry

    # Hide this launcher's own positional arguments from the entry script,
    # then execute that script under the module name "__main__" so its own
    # `if __name__ == "__main__":` guard fires.
    sys.argv[1:] = remaining_args
    spec = importlib.util.spec_from_file_location("__main__", entry_file)
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
8 changes: 7 additions & 1 deletion python/ray/runtime_env/runtime_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ class MyClass:
"docker",
"worker_process_setup_hook",
"_nsight",
"mpi",
}

extensions_fields: Set[str] = {
Expand All @@ -287,6 +288,7 @@ def __init__(
nsight: Optional[Union[str, Dict[str, str]]] = None,
config: Optional[Union[Dict, RuntimeEnvConfig]] = None,
_validate: bool = True,
mpi: Optional[Dict] = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should update the docs, or consider naming it `_mpi`!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We'll update the docs. I think it's a feature.

**kwargs,
):
super().__init__()
Expand All @@ -310,7 +312,8 @@ def __init__(
runtime_env["config"] = config
if worker_process_setup_hook is not None:
runtime_env["worker_process_setup_hook"] = worker_process_setup_hook

if mpi is not None:
runtime_env["mpi"] = mpi
if runtime_env.get("java_jars"):
runtime_env["java_jars"] = runtime_env.get("java_jars")

Expand Down Expand Up @@ -444,6 +447,9 @@ def java_jars(self) -> List[str]:
return list(self["java_jars"])
return []

def mpi(self) -> Optional[Dict]:
    """Return the ``mpi`` section of this runtime env, or ``None`` if unset.

    The return annotation mirrors the ``mpi: Optional[Dict]`` constructor
    parameter: the section is a dict whose values are not all strings
    (``"args"`` is a list of ``mpirun`` arguments).
    """
    return self.get("mpi", None)

def nsight(self) -> Optional[Union[str, Dict[str, str]]]:
    """Return the value stored under the ``"_nsight"`` key, or ``None``."""
    nsight_config = self.get("_nsight")
    return nsight_config

Expand Down
39 changes: 39 additions & 0 deletions python/ray/tests/mpi/mpi_worker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# copied from here
# https://github.com/alshedivat/ACM-Python-Tutorials-KAUST-2014/blob/master/mpi4py-tutorial/examples/compute_pi-montecarlo-mpi.py
from mpi4py import MPI
import numpy


def compute_pi(samples):
    """Estimate pi from a sequence of ``(x, y)`` points.

    Counts the points that satisfy ``x**2 + y**2 <= 1`` and returns
    ``4 * hits / len(samples)`` as a float.
    """
    hits = sum(1 for x, y in samples if x**2 + y**2 <= 1)
    return 4 * float(hits) / len(samples)


def run():
    """Estimate pi cooperatively across all ranks of ``MPI.COMM_WORLD``.

    Rank 0 seeds numpy's RNG with 1 and draws ``100000 // nprocs`` random
    2-D points per rank. The points are scattered from rank 0, each rank
    computes ``compute_pi(...) / nprocs`` on its share, and the partial
    values are reduced back to rank 0.

    Returns:
        The reduced estimate on rank 0, ``None`` on every other rank.
    """
    world = MPI.COMM_WORLD
    world_size = world.Get_size()
    rank = world.Get_rank()

    # Only the root generates data; the other ranks pass None to scatter.
    all_samples = None
    if rank == 0:
        numpy.random.seed(1)
        per_rank = 100000 // world_size
        all_samples = numpy.random.random((world_size, per_rank, 2))

    local_samples = world.scatter(all_samples, root=0)

    local_share = compute_pi(local_samples) / world_size

    combined = world.reduce(local_share, root=0)

    return combined if rank == 0 else None


if __name__ == "__main__":
    # Script entry point: run the estimate when this file is executed
    # directly; the return value is discarded.
    run()
45 changes: 45 additions & 0 deletions python/ray/tests/mpi/test_mpi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import pytest
import ray


def test_mpi_func_pi(ray_start_regular):
    """A Ray task started through the ``mpi`` runtime env computes pi ~ 3.14."""
    import os

    # The MPI plugin validates `worker_entry` with `Path(...).is_file()`, so
    # a relative path is resolved against the current working directory.
    # Anchor it to this test file's directory so the test does not depend on
    # where pytest was invoked from.
    worker_entry = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "mpi_worker.py"
    )

    @ray.remote(
        runtime_env={
            "mpi": {
                "args": ["-n", "4"],
                "worker_entry": worker_entry,
            }
        }
    )
    def calc_pi():
        from mpi_worker import run

        return run()

    assert "3.14" == "%.2f" % (ray.get(calc_pi.remote()))


def test_mpi_actor_pi(ray_start_regular):
    """A Ray actor started through the ``mpi`` runtime env computes pi ~ 3.14."""
    import os

    # The MPI plugin validates `worker_entry` with `Path(...).is_file()`, so
    # a relative path is resolved against the current working directory.
    # Anchor it to this test file's directory so the test does not depend on
    # where pytest was invoked from.
    worker_entry = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "mpi_worker.py"
    )

    @ray.remote(
        runtime_env={
            "mpi": {
                "args": ["-n", "4"],
                "worker_entry": worker_entry,
            }
        }
    )
    class Actor:
        def calc_pi(self):
            from mpi_worker import run

            return run()

    actor = Actor.remote()

    assert "3.14" == "%.2f" % (ray.get(actor.calc_pi.remote()))


if __name__ == "__main__":
    # Allow running this test file directly; the process exit status is
    # whatever pytest.main() returns.
    import sys

    exit_code = pytest.main(["-sv", __file__])
    sys.exit(exit_code)
4 changes: 2 additions & 2 deletions python/ray/widgets/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
from functools import wraps
from typing import Any, Callable, Iterable, Optional, TypeVar, Union

from packaging.version import Version

from ray._private.thirdparty.tabulate.tabulate import tabulate
from ray.util.annotations import DeveloperAPI
from ray.widgets import Template
Expand Down Expand Up @@ -102,6 +100,8 @@ def _has_outdated(
outdated = []
for (lib, version) in deps:
try:
from packaging.version import Version
fishbone marked this conversation as resolved.
Show resolved Hide resolved

module = importlib.import_module(lib)
if version and Version(module.__version__) < Version(version):
outdated.append([lib, version, module.__version__])
Expand Down
3 changes: 3 additions & 0 deletions python/requirements/test-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,6 @@ numexpr==2.8.4

# For `serve run --reload` CLI.
watchfiles==0.19.0

# For mpi test
mpi4py
Loading