Skip to content

Commit 5e50ded

Browse files
committed
RHOAIENG-39073: Add priority class support
1 parent 8eac545 commit 5e50ded

File tree

5 files changed

+405
-26
lines changed

5 files changed

+405
-26
lines changed

.github/workflows/rayjob_e2e_tests.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,8 @@ jobs:
119119
kubectl create clusterrolebinding sdk-user-secret-manager --clusterrole=secret-manager --user=sdk-user
120120
kubectl create clusterrole workload-reader --verb=get,list,watch --resource=workloads
121121
kubectl create clusterrolebinding sdk-user-workload-reader --clusterrole=workload-reader --user=sdk-user
122+
kubectl create clusterrole workloadpriorityclass-reader --verb=get,list --resource=workloadpriorityclasses
123+
kubectl create clusterrolebinding sdk-user-workloadpriorityclass-reader --clusterrole=workloadpriorityclass-reader --user=sdk-user
122124
kubectl config use-context sdk-user
123125
124126
- name: Run RayJob E2E tests

src/codeflare_sdk/common/kueue/kueue.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,16 @@
1313
# limitations under the License.
1414

1515
from typing import Optional, List
16+
import logging
1617
from codeflare_sdk.common import _kube_api_error_handling
1718
from codeflare_sdk.common.kubernetes_cluster.auth import config_check, get_api_client
1819
from kubernetes import client
1920
from kubernetes.client.exceptions import ApiException
2021

2122
from ...common.utils import get_current_namespace
2223

24+
logger = logging.getLogger(__name__)
25+
2326

2427
def get_default_kueue_name(namespace: str) -> Optional[str]:
2528
"""
@@ -144,6 +147,50 @@ def local_queue_exists(namespace: str, local_queue_name: str) -> bool:
144147
return False
145148

146149

150+
def priority_class_exists(priority_class_name: str) -> Optional[bool]:
    """
    Report whether a Kueue WorkloadPriorityClass with the given name can be found.

    The lookup goes through the cluster-scoped custom-object API (no namespace
    is passed) and fetches the object by name from the
    ``kueue.x-k8s.io/v1beta1`` ``workloadpriorityclasses`` resource.

    Args:
        priority_class_name (str):
            Name of the WorkloadPriorityClass to look up.

    Returns:
        Optional[bool]:
            True when the lookup succeeds, False when the API answers with a
            404 status, and None when any other error occurs (e.g. permission
            denied). In the None case a warning is logged because existence
            could not be determined either way.
    """
    try:
        config_check()
        custom_objects = client.CustomObjectsApi(get_api_client())
        # Only success or failure of the call matters; the returned object is discarded.
        custom_objects.get_cluster_custom_object(
            name=priority_class_name,
            plural="workloadpriorityclasses",
            version="v1beta1",
            group="kueue.x-k8s.io",
        )
    except client.ApiException as e:
        if e.status != 404:
            # Any non-404 API failure leaves existence unknown.
            logger.warning(
                f"Error checking WorkloadPriorityClass '{priority_class_name}': {e.reason}. "
                f"Cannot verify if it exists."
            )
            return None
        # 404 is the one status treated as a definitive "does not exist".
        return False
    except Exception as e:
        # Best effort: a non-API failure is reported as "unknown", not raised.
        logger.warning(
            f"Unexpected error checking WorkloadPriorityClass '{priority_class_name}': {str(e)}. "
            f"Cannot verify if it exists."
        )
        return None
    else:
        return True
192+
193+
147194
def add_queue_label(item: dict, namespace: str, local_queue: Optional[str]):
148195
"""
149196
Adds a local queue name label to the provided item.

src/codeflare_sdk/common/kueue/test_kueue.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,12 @@
2323
import os
2424
import filecmp
2525
from pathlib import Path
26-
from .kueue import list_local_queues, local_queue_exists, add_queue_label
26+
from .kueue import (
27+
list_local_queues,
28+
local_queue_exists,
29+
add_queue_label,
30+
priority_class_exists,
31+
)
2732

2833
parent = Path(__file__).resolve().parents[4] # project directory
2934
aw_dir = os.path.expanduser("~/.codeflare/resources/")
@@ -292,6 +297,52 @@ def test_add_queue_label_with_invalid_local_queue(mocker):
292297
add_queue_label(item, namespace, local_queue)
293298

294299

300+
def test_priority_class_exists_found(mocker):
    """A successful lookup of the named class is reported as True."""
    mocker.patch("kubernetes.config.load_kube_config", return_value="ignore")
    custom_objects_cls = mocker.patch("kubernetes.client.CustomObjectsApi")
    found_object = {"metadata": {"name": "high-priority"}}
    lookup = custom_objects_cls.return_value.get_cluster_custom_object
    lookup.return_value = found_object

    result = priority_class_exists("high-priority")

    assert result is True
308+
309+
310+
def test_priority_class_exists_not_found(mocker):
    """A 404 from the API is reported as False (the class does not exist)."""
    from kubernetes.client import ApiException

    mocker.patch("kubernetes.config.load_kube_config", return_value="ignore")
    custom_objects_cls = mocker.patch("kubernetes.client.CustomObjectsApi")
    lookup = custom_objects_cls.return_value.get_cluster_custom_object
    lookup.side_effect = ApiException(status=404)

    result = priority_class_exists("missing-priority")

    assert result is False
320+
321+
322+
def test_priority_class_exists_permission_denied(mocker):
    """A 403 from the API yields None: existence cannot be verified."""
    from kubernetes.client import ApiException

    mocker.patch("kubernetes.config.load_kube_config", return_value="ignore")
    custom_objects_cls = mocker.patch("kubernetes.client.CustomObjectsApi")
    lookup = custom_objects_cls.return_value.get_cluster_custom_object
    lookup.side_effect = ApiException(status=403)

    result = priority_class_exists("some-priority")

    assert result is None
332+
333+
334+
def test_priority_class_exists_other_error(mocker):
    """A 500 from the API yields None, the same as any other non-404 failure."""
    from kubernetes.client import ApiException

    mocker.patch("kubernetes.config.load_kube_config", return_value="ignore")
    custom_objects_cls = mocker.patch("kubernetes.client.CustomObjectsApi")
    lookup = custom_objects_cls.return_value.get_cluster_custom_object
    lookup.side_effect = ApiException(status=500)

    result = priority_class_exists("some-priority")

    assert result is None
344+
345+
295346
# Make sure to always keep this function last
296347
def test_cleanup():
297348
os.remove(f"{aw_dir}unit-test-cluster-kueue.yaml")

src/codeflare_sdk/ray/rayjobs/rayjob.py

Lines changed: 63 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@
2323
from typing import Dict, Any, Optional, Tuple, Union
2424

2525
from ray.runtime_env import RuntimeEnv
26-
from codeflare_sdk.common.kueue.kueue import get_default_kueue_name
26+
from codeflare_sdk.common.kueue.kueue import (
27+
get_default_kueue_name,
28+
priority_class_exists,
29+
)
2730
from codeflare_sdk.common.utils.constants import MOUNT_PATH
2831

2932
from codeflare_sdk.common.utils.utils import get_ray_image_for_python_version
@@ -69,6 +72,7 @@ def __init__(
6972
ttl_seconds_after_finished: int = 0,
7073
active_deadline_seconds: Optional[int] = None,
7174
local_queue: Optional[str] = None,
75+
priority_class: Optional[str] = None,
7276
):
7377
"""
7478
Initialize a RayJob instance.
@@ -86,11 +90,13 @@ def __init__(
8690
ttl_seconds_after_finished: Seconds to wait before cleanup after job finishes (default: 0)
8791
active_deadline_seconds: Maximum time the job can run before being terminated (optional)
8892
local_queue: The Kueue LocalQueue to submit the job to (optional)
93+
priority_class: The Kueue WorkloadPriorityClass name for preemption control (optional).
8994
9095
Note:
9196
- True if cluster_config is provided (new cluster will be cleaned up)
9297
- False if cluster_name is provided (existing cluster will not be shut down)
9398
- User can explicitly set this value to override auto-detection
99+
- Kueue labels (queue and priority) can be applied to both new and existing clusters
94100
"""
95101
if cluster_name is None and cluster_config is None:
96102
raise ValueError(
@@ -124,6 +130,7 @@ def __init__(
124130
self.ttl_seconds_after_finished = ttl_seconds_after_finished
125131
self.active_deadline_seconds = active_deadline_seconds
126132
self.local_queue = local_queue
133+
self.priority_class = priority_class
127134

128135
if namespace is None:
129136
detected_namespace = get_current_namespace()
@@ -165,6 +172,7 @@ def submit(self) -> str:
165172
# Validate configuration before submitting
166173
self._validate_ray_version_compatibility()
167174
self._validate_working_dir_entrypoint()
175+
self._validate_priority_class()
168176

169177
# Extract files from entrypoint and runtime_env working_dir
170178
files = extract_all_local_files(self)
@@ -243,30 +251,39 @@ def _build_rayjob_cr(self) -> Dict[str, Any]:
243251
# Extract files once and use for both runtime_env and submitter pod
244252
files = extract_all_local_files(self)
245253

254+
# Build Kueue labels and annotations for all jobs (new and existing clusters)
246255
labels = {}
247-
# If cluster_config is provided, use the local_queue from the cluster_config
248-
if self._cluster_config is not None:
249-
if self.local_queue:
250-
labels["kueue.x-k8s.io/queue-name"] = self.local_queue
256+
257+
# Queue name label - apply to all jobs when explicitly specified
258+
# For new clusters, also auto-detect default queue if not specified
259+
if self.local_queue:
260+
labels["kueue.x-k8s.io/queue-name"] = self.local_queue
261+
elif self._cluster_config is not None:
262+
# Only auto-detect default queue for new clusters
263+
default_queue = get_default_kueue_name(self.namespace)
264+
if default_queue:
265+
labels["kueue.x-k8s.io/queue-name"] = default_queue
251266
else:
252-
default_queue = get_default_kueue_name(self.namespace)
253-
if default_queue:
254-
labels["kueue.x-k8s.io/queue-name"] = default_queue
255-
else:
256-
# No default queue found, use "default" as fallback
257-
labels["kueue.x-k8s.io/queue-name"] = "default"
258-
logger.warning(
259-
f"No default Kueue LocalQueue found in namespace '{self.namespace}'. "
260-
f"Using 'default' as the queue name. If a LocalQueue named 'default' "
261-
f"does not exist, the RayJob submission will fail. "
262-
f"To fix this, please explicitly specify the 'local_queue' parameter."
263-
)
267+
# No default queue found, use "default" as fallback
268+
labels["kueue.x-k8s.io/queue-name"] = "default"
269+
logger.warning(
270+
f"No default Kueue LocalQueue found in namespace '{self.namespace}'. "
271+
f"Using 'default' as the queue name. If a LocalQueue named 'default' "
272+
f"does not exist, the RayJob submission will fail. "
273+
f"To fix this, please explicitly specify the 'local_queue' parameter."
274+
)
275+
276+
# Priority class label - apply when specified
277+
if self.priority_class:
278+
labels["kueue.x-k8s.io/priority-class"] = self.priority_class
264279

265-
rayjob_cr["metadata"]["labels"] = labels
280+
# Apply labels to metadata
281+
if labels:
282+
rayjob_cr["metadata"]["labels"] = labels
266283

267-
# When using Kueue (queue label present), start with suspend=true
284+
# When using Kueue with lifecycled clusters (queue label present), start with suspend=true
268285
# Kueue will unsuspend the job once the workload is admitted
269-
if labels.get("kueue.x-k8s.io/queue-name"):
286+
if labels.get("kueue.x-k8s.io/queue-name") and self._cluster_config is not None:
270287
rayjob_cr["spec"]["suspend"] = True
271288

272289
# Add active deadline if specified
@@ -450,6 +467,32 @@ def _validate_cluster_config_image(self):
450467
elif is_warning:
451468
warnings.warn(f"Cluster config image: {message}")
452469

470+
def _validate_priority_class(self):
    """
    Best-effort check that the configured priority class exists in the cluster.

    Does nothing when ``self.priority_class`` is unset or empty. Otherwise the
    result of ``priority_class_exists`` decides the outcome:

    - ``False``: the class is known to be missing, so validation fails.
    - ``None``: existence could not be verified; a warning is logged and
      submission is allowed to continue.
    - anything else: the class was found; a debug message is logged.

    Raises:
        ValueError: If the lookup reports that the priority class does not exist.
    """
    if not self.priority_class:
        return

    logger.debug(f"Validating priority class '{self.priority_class}'...")
    lookup_result = priority_class_exists(self.priority_class)

    # Identity checks on purpose: False and None carry different meanings.
    if lookup_result is False:
        raise ValueError(
            f"Priority class '{self.priority_class}' does not exist"
        )

    if lookup_result is None:
        # Unknown is not treated as failure; submission proceeds.
        logger.warning(
            f"Could not verify if priority class '{self.priority_class}' exists. "
            f"Proceeding with submission - Kueue will validate on admission."
        )
        return

    logger.debug(f"Priority class '{self.priority_class}' verified.")
495+
453496
def _validate_working_dir_entrypoint(self):
454497
"""
455498
Validate entrypoint file configuration.

0 commit comments

Comments
 (0)