In [None]:
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

# @noautodeps
# pyre-ignore-all-errors
import json
import logging
import socket
import sys

import cloudpickle
from example_actors.compute_world_size_actor import ComputeWorldSizeActor
from monarch.actor import Actor, endpoint
from monarch.job import SlurmJob


# Configure the root logger for this example. ``force=True`` removes any
# handlers already attached to the root logger (e.g. ones installed by the
# notebook kernel) so that this format and level actually take effect.
logging.basicConfig(
    force=True,
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(name)s %(asctime)s %(levelname)s %(message)s",
)


# Logger for this module, named after ``__name__``.
logger: logging.Logger = logging.getLogger(name=__name__)


class _HostnameActor(Actor):
    """Actor that reports the hostname of the process it runs in.

    The actor itself has no notion of rank; ``main`` slices the spawned mesh
    down to rank 0 and uses that single hostname as ``master_addr``.
    """

    @endpoint
    def get_hostname(self) -> str:
        """Return ``socket.gethostname()`` as seen by this actor's process."""
        host = socket.gethostname()
        return host


async def main(
    num_nodes: int = 2,
    gpus_per_node: int = 8,
    mesh_name: str = "mesh0",
    master_port: int = 29500,
) -> None:
    """Reserve a SLURM job, run ``ComputeWorldSizeActor`` on it, and log results.

    A single mesh named ``mesh_name`` spanning ``num_nodes`` hosts is requested.
    Processes are spawned per host along a ``"gpus"`` dimension of size
    ``gpus_per_node``. The hostname of rank 0 and ``master_port`` are passed to
    every ``ComputeWorldSizeActor`` as ``master_addr`` / ``master_port``. The
    SLURM job is killed in ``finally``, even if spawning or the call fails.

    Args:
        num_nodes: Number of hosts to reserve for the mesh.
        gpus_per_node: GPUs requested per host; also the size of the
            per-host ``"gpus"`` process dimension.
        mesh_name: Name of the mesh in the SLURM job. Used both when
            requesting the job and when looking the mesh up in its state.
        master_port: Port handed to each actor together with rank 0's hostname.
    """
    # Create SLURM job
    slurm_job = SlurmJob(
        meshes={mesh_name: num_nodes},
        job_name="monarch_example",
        gpus_per_node=gpus_per_node,
        time_limit="06:00:00",
    )

    try:
        # Look the mesh up under the same name it was requested with; a
        # hard-coded attribute (``job_state.mesh0``) breaks as soon as
        # mesh_name changes.
        job_state = slurm_job.state()
        host_mesh = getattr(job_state, mesh_name)
        proc_mesh = host_mesh.spawn_procs({"gpus": gpus_per_node})

        # Ask only rank 0 for its hostname; it becomes master_addr for all actors.
        hostname_actor = proc_mesh.spawn("hostname_actor", _HostnameActor)
        hostname_values = (
            await hostname_actor.flatten("rank").slice(rank=0).get_hostname.call()
        )
        master_addr = hostname_values.item()

        actor = proc_mesh.spawn("compute_world_size_actor", ComputeWorldSizeActor)

        logger.info("computing world size...")
        # The original author notes this call is redundant and kept only as an
        # example. NOTE(review): presumably because the world size is already
        # num_nodes * gpus_per_node; confirm against ComputeWorldSizeActor.
        values = await actor.compute_world_size.call(
            master_addr=master_addr,
            master_port=master_port,
        )

        values_by_rank = {f"rank_{p.rank}": v for p, v in values.flatten("rank")}

        separator = "-" * 40
        logger.info(
            "computed world_sizes:\n    %s\n    %s\n    %s",
            separator,
            json.dumps(values_by_rank, indent=2),
            separator,
        )
    finally:
        # Cancel the SLURM job, releasing all reserved nodes back to the cluster
        slurm_job.kill()
        logger.info("Job terminated successfully")


if __name__ == "__main__":
    cloudpickle.register_pickle_by_value(sys.modules[ComputeWorldSizeActor.__module__])

    await main()