def priority_class_exists(priority_class_name: str) -> Optional[bool]:
    """
    Checks if a WorkloadPriorityClass with the provided name exists in the cluster.

    WorkloadPriorityClass is a cluster-scoped resource, so the lookup takes no
    namespace. The check is best effort: API failures other than "not found"
    are logged as warnings and reported as "unknown" instead of being raised.

    Args:
        priority_class_name (str):
            The name of the WorkloadPriorityClass to check for existence.

    Returns:
        Optional[bool]:
            True if the lookup succeeded, False if the name is empty or the
            API answered 404, None if existence could not be verified
            (e.g., permission denied or any other error).
    """
    # A blank name can never identify a resource, so answer directly instead
    # of sending a lookup with an empty name to the API server.
    if not priority_class_name:
        return False

    try:
        config_check()
        api_instance = client.CustomObjectsApi(get_api_client())
        # Fetch the specific WorkloadPriorityClass by name; only success or
        # failure of the call matters, the returned object is not used.
        api_instance.get_cluster_custom_object(
            group="kueue.x-k8s.io",
            version="v1beta1",
            plural="workloadpriorityclasses",
            name=priority_class_name,
        )
    except client.ApiException as e:
        if e.status == 404:
            # Not found - definitively doesn't exist
            return False
        if e.status == 403:
            # Permission denied - can't verify
            logger.warning(
                f"Permission denied when checking WorkloadPriorityClass '{priority_class_name}'. "
                f"Cannot verify if it exists."
            )
        else:
            # Other API errors - log and report as unknown (best effort)
            logger.warning(
                f"Error checking WorkloadPriorityClass '{priority_class_name}': {e.reason}. "
                f"Cannot verify if it exists."
            )
        return None
    except Exception as e:  # pragma: no cover
        # Unexpected errors - log and report as unknown (best effort)
        logger.warning(
            f"Unexpected error checking WorkloadPriorityClass '{priority_class_name}': {str(e)}. "
            f"Cannot verify if it exists."
        )
        return None
    else:
        return True
Note: - True if cluster_config is provided (new cluster will be cleaned up) - False if cluster_name is provided (existing cluster will not be shut down) - User can explicitly set this value to override auto-detection + - Kueue labels (queue and priority) can be applied to both new and existing clusters """ if cluster_name is None and cluster_config is None: raise ValueError( @@ -124,6 +130,7 @@ def __init__( self.ttl_seconds_after_finished = ttl_seconds_after_finished self.active_deadline_seconds = active_deadline_seconds self.local_queue = local_queue + self.priority_class = priority_class if namespace is None: detected_namespace = get_current_namespace() @@ -165,6 +172,7 @@ def submit(self) -> str: # Validate configuration before submitting self._validate_ray_version_compatibility() self._validate_working_dir_entrypoint() + self._validate_priority_class() # Extract files from entrypoint and runtime_env working_dir files = extract_all_local_files(self) @@ -243,26 +251,35 @@ def _build_rayjob_cr(self) -> Dict[str, Any]: # Extract files once and use for both runtime_env and submitter pod files = extract_all_local_files(self) + # Build Kueue labels and annotations for all jobs (new and existing clusters) labels = {} - # If cluster_config is provided, use the local_queue from the cluster_config - if self._cluster_config is not None: - if self.local_queue: - labels["kueue.x-k8s.io/queue-name"] = self.local_queue + + # Queue name label - apply to all jobs when explicitly specified + # For new clusters, also auto-detect default queue if not specified + if self.local_queue: + labels["kueue.x-k8s.io/queue-name"] = self.local_queue + elif self._cluster_config is not None: + # Only auto-detect default queue for new clusters + default_queue = get_default_kueue_name(self.namespace) + if default_queue: + labels["kueue.x-k8s.io/queue-name"] = default_queue else: - default_queue = get_default_kueue_name(self.namespace) - if default_queue: - labels["kueue.x-k8s.io/queue-name"] = 
def _validate_priority_class(self):
    """
    Best-effort check that the configured priority class is present.

    Does nothing when no priority class is set. Otherwise asks
    ``priority_class_exists`` and reacts to its three possible answers:
    a definite "no" aborts, "unknown" only warns, anything else passes.

    Raises:
        ValueError: If the priority class is definitively known not to exist.
    """
    if not self.priority_class:
        return

    logger.debug(f"Validating priority class '{self.priority_class}'...")
    lookup_result = priority_class_exists(self.priority_class)

    if lookup_result is None:
        # Could not verify - warn and let the submission go ahead
        logger.warning(
            f"Could not verify if priority class '{self.priority_class}' exists. "
            f"Proceeding with submission - Kueue will validate on admission."
        )
        return

    if lookup_result is False:
        # Definitively missing - tell the user and fail validation
        print(
            f"❌ Priority class '{self.priority_class}' does not exist in the cluster. "
            f"Submission cancelled."
        )
        raise ValueError(f"Priority class '{self.priority_class}' does not exist")

    # Lookup succeeded - validation passed
    logger.debug(f"Priority class '{self.priority_class}' verified.")
""" + # Mock default queue detection + mock_get_default = mocker.patch( + "codeflare_sdk.ray.rayjobs.rayjob.get_default_kueue_name", + return_value="default-queue", + ) + + config = ManagedClusterConfig(num_workers=1) + + # Test 1: Explicit queue should be used (no default queue lookup) + mock_api_instance1 = auto_mock_setup["rayjob_api"] + mock_api_instance1.submit_job.return_value = {"metadata": {"name": "test-job-1"}} + rayjob1 = RayJob( + job_name="test-job-1", + entrypoint="python -c 'print()'", + cluster_config=config, + local_queue="explicit-queue", + ) + rayjob1.submit() + call_args1 = mock_api_instance1.submit_job.call_args + submitted_job1 = call_args1.kwargs["job"] + assert ( + submitted_job1["metadata"]["labels"]["kueue.x-k8s.io/queue-name"] + == "explicit-queue" + ) + # Should not call get_default_kueue_name when explicit queue is provided + mock_get_default.assert_not_called() + + # Reset mock for next test + mock_get_default.reset_mock() + mock_get_default.return_value = "default-queue" + + # Test 2: Default queue should be auto-detected for new clusters + mock_api_instance2 = auto_mock_setup["rayjob_api"] + mock_api_instance2.submit_job.return_value = {"metadata": {"name": "test-job-2"}} + rayjob2 = RayJob( + job_name="test-job-2", + entrypoint="python -c 'print()'", + cluster_config=config, + # No local_queue specified + ) + rayjob2.submit() + call_args2 = mock_api_instance2.submit_job.call_args + submitted_job2 = call_args2.kwargs["job"] + assert ( + submitted_job2["metadata"]["labels"]["kueue.x-k8s.io/queue-name"] + == "default-queue" + ) + # Should call get_default_kueue_name when no explicit queue + mock_get_default.assert_called_once() + + # Test 3: Existing cluster without explicit queue should not have queue label + mock_api_instance3 = auto_mock_setup["rayjob_api"] + mock_api_instance3.submit_job.return_value = {"metadata": {"name": "test-job-3"}} + mock_get_default.reset_mock() + rayjob3 = RayJob( + job_name="test-job-3", + 
cluster_name="existing-cluster", + entrypoint="python -c 'print()'", + # No local_queue specified + ) + rayjob3.submit() + call_args3 = mock_api_instance3.submit_job.call_args + submitted_job3 = call_args3.kwargs["job"] + assert "kueue.x-k8s.io/queue-name" not in submitted_job3["metadata"].get( + "labels", {} + ) + # Should not call get_default_kueue_name for existing clusters + mock_get_default.assert_not_called() + + +def test_rayjob_priority_class(auto_mock_setup, mocker): + """ + Test RayJob adds priority class label when specified. + """ + # Mock priority_class_exists to return True (priority class exists) + mocker.patch( + "codeflare_sdk.ray.rayjobs.rayjob.priority_class_exists", + return_value=True, + ) + + mock_api_instance = auto_mock_setup["rayjob_api"] + mock_api_instance.submit_job.return_value = {"metadata": {"name": "test-job"}} + + config = ManagedClusterConfig(num_workers=1) + rayjob = RayJob( + job_name="test-job", + entrypoint="python -c 'print()'", + cluster_config=config, + priority_class="high-priority", + ) + + rayjob.submit() + + call_args = mock_api_instance.submit_job.call_args + submitted_job = call_args.kwargs["job"] + assert ( + submitted_job["metadata"]["labels"]["kueue.x-k8s.io/priority-class"] + == "high-priority" + ) + + +def test_rayjob_priority_class_not_added_when_none(auto_mock_setup): + """ + Test RayJob doesn't add priority class label when not specified. 
def test_rayjob_priority_class_not_added_when_none(auto_mock_setup):
    """
    A RayJob created without a priority class submits no priority label.
    """
    rayjob_api = auto_mock_setup["rayjob_api"]
    rayjob_api.submit_job.return_value = {"metadata": {"name": "test-job"}}

    cluster_cfg = ManagedClusterConfig(num_workers=1)
    job = RayJob(
        job_name="test-job",
        entrypoint="python -c 'print()'",
        cluster_config=cluster_cfg,
    )
    job.submit()

    sent_job = rayjob_api.submit_job.call_args.kwargs["job"]
    # Labels may be absent entirely, so fall back to an empty mapping
    sent_labels = sent_job["metadata"].get("labels", {})
    assert "kueue.x-k8s.io/priority-class" not in sent_labels


def test_rayjob_priority_class_validation_invalid(auto_mock_setup, mocker):
    """
    Submitting with a priority class reported as missing raises ValueError.
    """
    # Pretend the priority class is definitively absent from the cluster
    mocker.patch(
        "codeflare_sdk.ray.rayjobs.rayjob.priority_class_exists",
        return_value=False,
    )

    cluster_cfg = ManagedClusterConfig(num_workers=1)
    job = RayJob(
        job_name="test-job",
        entrypoint="python -c 'print()'",
        cluster_config=cluster_cfg,
        priority_class="invalid-priority",
    )

    expected_message = "Priority class 'invalid-priority' does not exist"
    with pytest.raises(ValueError, match=expected_message):
        job.submit()
def test_rayjob_priority_class_validation_cannot_verify(auto_mock_setup, mocker):
    """
    When the priority class check is inconclusive, submission still goes
    through and the priority label is kept (best effort).
    """
    # None from the lookup means "cannot verify", e.g. permission denied
    mocker.patch(
        "codeflare_sdk.ray.rayjobs.rayjob.priority_class_exists",
        return_value=None,
    )

    rayjob_api = auto_mock_setup["rayjob_api"]
    rayjob_api.submit_job.return_value = {"metadata": {"name": "test-job"}}

    cluster_cfg = ManagedClusterConfig(num_workers=1)
    job = RayJob(
        job_name="test-job",
        entrypoint="python -c 'print()'",
        cluster_config=cluster_cfg,
        priority_class="unknown-priority",
    )

    # Must not raise even though the class could not be verified
    job.submit()

    sent_job = rayjob_api.submit_job.call_args.kwargs["job"]
    sent_labels = sent_job["metadata"]["labels"]
    assert sent_labels["kueue.x-k8s.io/priority-class"] == "unknown-priority"


def test_rayjob_kueue_labels_with_existing_cluster(auto_mock_setup, mocker):
    """
    An existing-cluster RayJob with explicit queue and priority class gets
    both Kueue labels and is submitted suspended.
    """
    # Pretend the priority class is present in the cluster
    mocker.patch(
        "codeflare_sdk.ray.rayjobs.rayjob.priority_class_exists",
        return_value=True,
    )

    rayjob_api = auto_mock_setup["rayjob_api"]
    rayjob_api.submit_job.return_value = {"metadata": {"name": "test-job"}}

    job = RayJob(
        job_name="test-job",
        cluster_name="existing-cluster",
        entrypoint="python -c 'print()'",
        local_queue="my-queue",
        priority_class="medium-priority",
    )
    job.submit()

    sent_job = rayjob_api.submit_job.call_args.kwargs["job"]
    sent_labels = sent_job["metadata"]["labels"]

    # Both Kueue labels must be present with the requested values
    assert "kueue.x-k8s.io/queue-name" in sent_labels
    assert sent_labels["kueue.x-k8s.io/queue-name"] == "my-queue"
    assert sent_labels["kueue.x-k8s.io/priority-class"] == "medium-priority"

    # The job is submitted suspended when Kueue is used
    assert sent_job["spec"]["suspend"] is True
def test_rayjob_no_kueue_label_for_existing_cluster_without_queue(auto_mock_setup):
    """
    An existing-cluster RayJob with no queue specified is submitted without
    a Kueue queue label.
    """
    rayjob_api = auto_mock_setup["rayjob_api"]
    rayjob_api.submit_job.return_value = {"metadata": {"name": "test-job"}}

    # Existing cluster (no cluster_config) and no explicit queue
    job = RayJob(
        job_name="test-job",
        cluster_name="existing-cluster",
        entrypoint="python -c 'print()'",
    )
    job.submit()

    sent_job = rayjob_api.submit_job.call_args.kwargs["job"]
    # Labels may be absent entirely, so fall back to an empty mapping
    sent_labels = sent_job["metadata"].get("labels", {})
    assert "kueue.x-k8s.io/queue-name" not in sent_labels