Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions torchx/schedulers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import torchx.schedulers.docker_scheduler as docker_scheduler
import torchx.schedulers.kubernetes_scheduler as kubernetes_scheduler
import torchx.schedulers.local_scheduler as local_scheduler
import torchx.schedulers.ray_scheduler as ray_scheduler
import torchx.schedulers.slurm_scheduler as slurm_scheduler
from torchx.schedulers.api import Scheduler
from torchx.specs.api import SchedulerBackend
Expand All @@ -37,9 +36,6 @@ def get_scheduler_factories() -> Dict[str, SchedulerFactory]:
"kubernetes": kubernetes_scheduler.create_scheduler,
}

if ray_scheduler.has_ray():
default_schedulers["ray"] = ray_scheduler.create_scheduler

return load_group(
"torchx.schedulers",
default=default_schedulers,
Expand Down
22 changes: 11 additions & 11 deletions torchx/schedulers/test/ray_scheduler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,31 @@

from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Dict, Iterator, Type, cast
from typing import Any, Dict, Iterator, Type
from unittest import TestCase
from unittest.mock import patch

from torchx.schedulers import get_schedulers
from torchx.schedulers.ray_scheduler import RayScheduler, _logger, has_ray
from torchx.specs import AppDef, CfgVal, Resource, Role, runopts


if has_ray():

class RaySchedulerRegistryTest(TestCase):
def test_get_schedulers_returns_ray_scheduler(self) -> None:
schedulers = get_schedulers("test_session")
# TODO(aivanou): enable after 0.1.1 release
# class RaySchedulerRegistryTest(TestCase):
# def test_get_schedulers_returns_ray_scheduler(self) -> None:
# schedulers = get_schedulers("test_session")

self.assertIn("ray", schedulers)
# self.assertIn("ray", schedulers)

scheduler = schedulers["ray"]
# scheduler = schedulers["ray"]

self.assertIsInstance(scheduler, RayScheduler)
# self.assertIsInstance(scheduler, RayScheduler)

ray_scheduler = cast(RayScheduler, scheduler)
# ray_scheduler = cast(RayScheduler, scheduler)

self.assertEqual(ray_scheduler.backend, "ray")
self.assertEqual(ray_scheduler.session_name, "test_session")
# self.assertEqual(ray_scheduler.backend, "ray")
# self.assertEqual(ray_scheduler.session_name, "test_session")

class RaySchedulerTest(TestCase):
def setUp(self) -> None:
Expand Down