Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding module lookup for building trackers #803

Merged
merged 1 commit into from
Jan 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion torchx/schedulers/kubernetes_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
kubectl apply -f https://raw.githubusercontent.com/volcano-sh/volcano/v1.6.0/installer/volcano-development.yaml

See the
`Volcano Quickstart <https://github.com/volcano-sh/volcano#user-content-quick-start-guide>`_
`Volcano Quickstart <https://github.com/volcano-sh/volcano#quick-start-guide>`_
for more information.
"""

Expand Down
12 changes: 8 additions & 4 deletions torchx/tracker/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
-------------
Enabling tracking requires:

1. Defining tracker backends (entrypoints and configuration) on launcher side using :doc:`runner.config`
1. Defining tracker backends (entrypoints/modules and configuration) on launcher side using :doc:`runner.config`
2. Adding entrypoints within a user job using entry_points (`specification`_)

.. _specification: https://packaging.python.org/en/latest/specifications/entry-points/
Expand All @@ -49,13 +49,13 @@
Users can define any number of tracker backends under the **torchx:tracker** section in :doc:`runner.config`, where:
* Key: is an arbitrary name for the tracker, where the name will be used to configure its properties
under [tracker:<TRACKER_NAME>]
* Value: is *entrypoint/factory method* that must be available within user job. The value will be injected into a
* Value: is an *entrypoint* or *module* factory method that must be available within the user job. The value will be injected into the
user job and used to construct the tracker implementation.

.. code-block:: ini

[torchx:tracker]
tracker_name=<entry_point>
tracker_name=<entry_point_or_module_factory_method>


Each tracker can be additionally configured (currently limited to the `config` parameter) under the `[tracker:<TRACKER_NAME>]` section:
Expand All @@ -71,11 +71,15 @@

[torchx:tracker]
tracker1=tracker1
tracker12=backend_2_entry_point
tracker2=backend_2_entry_point
tracker3=torchx.tracker.mlflow:create_tracker

[tracker:tracker1]
config=s3://my_bucket/config.json

[tracker:tracker3]
config=my_config.json


2. User job configuration (Advanced)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Expand Down
25 changes: 11 additions & 14 deletions torchx/tracker/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import Iterable, Mapping, Optional

from torchx.util.entrypoints import load_group
from torchx.util.modules import load_module

logger: logging.Logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -177,30 +178,26 @@ def _extract_tracker_name_and_config_from_environ() -> Mapping[str, Optional[str


def build_trackers(
    factory_and_config: Mapping[str, Optional[str]]
) -> Iterable[TrackerBase]:
    """
    Builds one tracker per entry of ``factory_and_config``.

    Each key is resolved to a factory by first looking it up among the
    ``torchx.tracker`` entry points and, if it is not found there, by treating
    it as a ``module.path:factory`` reference passed to ``load_module``. The
    factory is then called with the associated config value (possibly ``None``).

    Keys that resolve to nothing, or to an object that cannot be called (for
    example a bare module path without ``:factory``), are logged and skipped.

    Args:
        factory_and_config: maps a factory name (entry point key or
            ``module:factory`` path) to its optional config string.

    Returns:
        The trackers that could be built, in the mapping's iteration order.
    """
    trackers = []

    entrypoint_factories = load_group("torchx.tracker") or {}
    if not entrypoint_factories:
        # Not fatal: names may still resolve as module paths below.
        logger.warning("No 'torchx.tracker' entry_points are defined.")

    for factory_name, config in factory_and_config.items():
        factory = entrypoint_factories.get(factory_name) or load_module(factory_name)
        # load_module returns a (truthy) module object when no ``:attr`` is
        # given; calling it would raise TypeError, so treat it as not found.
        if not factory or not callable(factory):
            logger.warning(
                f"No tracker factory `{factory_name}` found in entry_points or modules. See https://pytorch.org/torchx/main/tracker.html#module-torchx.tracker"
            )
            continue
        if config:
            logger.info(f"Tracker config found for `{factory_name}` as `{config}`")
        else:
            logger.info(f"No tracker config specified for `{factory_name}`")
        tracker = factory(config)
        trackers.append(tracker)
    return trackers

Expand Down
24 changes: 23 additions & 1 deletion torchx/tracker/test/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from collections import defaultdict
from typing import cast, DefaultDict, Dict, Iterable, Mapping, Optional, Tuple
from unittest import mock, TestCase
from unittest.mock import patch
from unittest.mock import MagicMock, patch

from torchx.tracker import app_run_from_env
from torchx.tracker.api import (
Expand All @@ -27,6 +27,8 @@
TrackerSource,
)

from torchx.tracker.mlflow import MLflowTracker

RunId = str

DEFAULT_SOURCE: str = "__parent__"
Expand Down Expand Up @@ -271,6 +273,26 @@ def test_build_trackers_with_no_entrypoints_group_defined(self) -> None:
trackers = build_trackers(tracker_names)
self.assertEqual(0, len(list(trackers)))

def test_build_trackers_with_module(self) -> None:
    """A tracker name that is not an entry point is resolved via ``load_module``."""
    config = "myconfig.txt"
    module = MagicMock()
    module.return_value = MagicMock(spec=MLflowTracker)
    # Both patches must be entered. ``patch(a) and patch(b)`` is a boolean
    # expression that evaluates to ``patch(b)`` alone, which would leave
    # ``load_group`` unpatched; comma-separated managers enter both.
    with patch(
        "torchx.tracker.api.load_group",
        return_value=None,
    ), patch(
        "torchx.tracker.api.load_module",
        return_value=module,
    ):
        tracker_names = {"torchx.tracker.mlflow:create_tracker": config}
        trackers = list(build_trackers(tracker_names))
        self.assertEqual(1, len(trackers))
        tracker = trackers[0]
        self.assertIsInstance(tracker, MLflowTracker)
        # The factory resolved from the module path receives the config as-is.
        module.assert_called_once_with(config)

def test_build_trackers(self) -> None:
with patch(
"torchx.tracker.api.load_group",
Expand Down
33 changes: 33 additions & 0 deletions torchx/util/modules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import importlib
from types import ModuleType
from typing import Callable, Optional, Union


def load_module(path: str) -> Union[ModuleType, Optional[Callable[..., object]]]:
    """
    Loads and returns the module/module attr represented by the ``path``: ``full.module.path:optional_attr``

    ::

     * ``load_module("this.is.a_module:fn")`` -> equivalent to ``this.is.a_module.fn``
     * ``load_module("this.is.a_module:Cls.fn")`` -> equivalent to ``this.is.a_module.Cls.fn``
     * ``load_module("this.is.a_module")`` -> equivalent to ``this.is.a_module``

    Returns ``None`` when the module cannot be imported, the attribute does not
    exist, or ``path`` is malformed (e.g. empty, or containing more than one ``:``).
    """
    # Split on the first ":" only; a second ":" ends up inside ``attr_path``
    # and fails the attribute lookup below instead of being silently dropped.
    module_path, _, attr_path = path.partition(":")
    try:
        # import_module imports every parent package on the way, so there is
        # no need to import each dotted prefix by hand.
        target = importlib.import_module(module_path)
        if attr_path:
            for attr_name in attr_path.split("."):
                target = getattr(target, attr_name)
        return target
    except Exception:
        # Best-effort lookup: any import-time or lookup failure means "not found".
        return None
23 changes: 23 additions & 0 deletions torchx/util/test/modules_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

from torchx.util.modules import load_module


class ModulesTest(unittest.TestCase):
    """Exercises ``load_module`` against a standard-library module."""

    def test_load_module(self) -> None:
        # A plain dotted path resolves to the module object itself.
        import os

        loaded = load_module("os.path")
        self.assertEqual(loaded, os.path)

    def test_load_module_method(self) -> None:
        # A ``module:attr`` path resolves to the named attribute of the module.
        import os

        loaded = load_module("os.path:join")
        self.assertEqual(loaded, os.path.join)
Loading