Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions src/codeflare_sdk/common/kueue/kueue.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,16 @@
# limitations under the License.

from typing import Optional, List
import logging
from codeflare_sdk.common import _kube_api_error_handling
from codeflare_sdk.common.kubernetes_cluster.auth import config_check, get_api_client
from kubernetes import client
from kubernetes.client.exceptions import ApiException

from ...common.utils import get_current_namespace

logger = logging.getLogger(__name__)


def get_default_kueue_name(namespace: str) -> Optional[str]:
"""
Expand Down Expand Up @@ -144,6 +147,59 @@ def local_queue_exists(namespace: str, local_queue_name: str) -> bool:
return False


def priority_class_exists(priority_class_name: str) -> Optional[bool]:
"""
Checks if a WorkloadPriorityClass with the provided name exists in the cluster.

WorkloadPriorityClass is a cluster-scoped resource.

Args:
priority_class_name (str):
The name of the WorkloadPriorityClass to check for existence.

Returns:
Optional[bool]:
True if the WorkloadPriorityClass exists, False if it doesn't exist,
None if we cannot verify (e.g., permission denied).
"""
try:
config_check()
api_instance = client.CustomObjectsApi(get_api_client())
# Try to get the specific WorkloadPriorityClass by name
api_instance.get_cluster_custom_object(
group="kueue.x-k8s.io",
version="v1beta1",
plural="workloadpriorityclasses",
name=priority_class_name,
)
return True
except client.ApiException as e:
if e.status == 404:
# Not found - doesn't exist
return False
elif e.status == 403:
# Permission denied - can't verify, return None
logger.warning(
f"Permission denied when checking WorkloadPriorityClass '{priority_class_name}'. "
f"Cannot verify if it exists."
)
return None
else:
# Other API errors - log and return None (best effort)
logger.warning(
f"Error checking WorkloadPriorityClass '{priority_class_name}': {e.reason}. "
f"Cannot verify if it exists."
)
return None
except Exception as e: # pragma: no cover
# Unexpected errors - log and return None (best effort)
logger.warning(
f"Unexpected error checking WorkloadPriorityClass '{priority_class_name}': {str(e)}. "
f"Cannot verify if it exists."
)
return None


def add_queue_label(item: dict, namespace: str, local_queue: Optional[str]):
"""
Adds a local queue name label to the provided item.
Expand Down
83 changes: 65 additions & 18 deletions src/codeflare_sdk/ray/rayjobs/rayjob.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
from typing import Dict, Any, Optional, Tuple, Union

from ray.runtime_env import RuntimeEnv
from codeflare_sdk.common.kueue.kueue import get_default_kueue_name
from codeflare_sdk.common.kueue.kueue import (
get_default_kueue_name,
priority_class_exists,
)
from codeflare_sdk.common.utils.constants import MOUNT_PATH

from codeflare_sdk.common.utils.utils import get_ray_image_for_python_version
Expand Down Expand Up @@ -69,6 +72,7 @@ def __init__(
ttl_seconds_after_finished: int = 0,
active_deadline_seconds: Optional[int] = None,
local_queue: Optional[str] = None,
priority_class: Optional[str] = None,
):
"""
Initialize a RayJob instance.
Expand All @@ -86,11 +90,13 @@ def __init__(
ttl_seconds_after_finished: Seconds to wait before cleanup after job finishes (default: 0)
active_deadline_seconds: Maximum time the job can run before being terminated (optional)
local_queue: The Kueue LocalQueue to submit the job to (optional)
priority_class: The Kueue WorkloadPriorityClass name for preemption control (optional).

Note:
- True if cluster_config is provided (new cluster will be cleaned up)
- False if cluster_name is provided (existing cluster will not be shut down)
- User can explicitly set this value to override auto-detection
- Kueue labels (queue and priority) can be applied to both new and existing clusters
"""
if cluster_name is None and cluster_config is None:
raise ValueError(
Expand Down Expand Up @@ -124,6 +130,7 @@ def __init__(
self.ttl_seconds_after_finished = ttl_seconds_after_finished
self.active_deadline_seconds = active_deadline_seconds
self.local_queue = local_queue
self.priority_class = priority_class

if namespace is None:
detected_namespace = get_current_namespace()
Expand Down Expand Up @@ -165,6 +172,7 @@ def submit(self) -> str:
# Validate configuration before submitting
self._validate_ray_version_compatibility()
self._validate_working_dir_entrypoint()
self._validate_priority_class()

# Extract files from entrypoint and runtime_env working_dir
files = extract_all_local_files(self)
Expand Down Expand Up @@ -243,26 +251,35 @@ def _build_rayjob_cr(self) -> Dict[str, Any]:
# Extract files once and use for both runtime_env and submitter pod
files = extract_all_local_files(self)

# Build Kueue labels and annotations for all jobs (new and existing clusters)
labels = {}
# If cluster_config is provided, use the local_queue from the cluster_config
if self._cluster_config is not None:
if self.local_queue:
labels["kueue.x-k8s.io/queue-name"] = self.local_queue

# Queue name label - apply to all jobs when explicitly specified
# For new clusters, also auto-detect default queue if not specified
if self.local_queue:
labels["kueue.x-k8s.io/queue-name"] = self.local_queue
elif self._cluster_config is not None:
# Only auto-detect default queue for new clusters
default_queue = get_default_kueue_name(self.namespace)
if default_queue:
labels["kueue.x-k8s.io/queue-name"] = default_queue
else:
default_queue = get_default_kueue_name(self.namespace)
if default_queue:
labels["kueue.x-k8s.io/queue-name"] = default_queue
else:
# No default queue found, use "default" as fallback
labels["kueue.x-k8s.io/queue-name"] = "default"
logger.warning(
f"No default Kueue LocalQueue found in namespace '{self.namespace}'. "
f"Using 'default' as the queue name. If a LocalQueue named 'default' "
f"does not exist, the RayJob submission will fail. "
f"To fix this, please explicitly specify the 'local_queue' parameter."
)
# No default queue found, use "default" as fallback
labels["kueue.x-k8s.io/queue-name"] = "default"
logger.warning(
f"No default Kueue LocalQueue found in namespace '{self.namespace}'. "
f"Using 'default' as the queue name. If a LocalQueue named 'default' "
f"does not exist, the RayJob submission will fail. "
f"To fix this, please explicitly specify the 'local_queue' parameter."
)

# Priority class label - apply when specified
if self.priority_class:
labels["kueue.x-k8s.io/priority-class"] = self.priority_class

rayjob_cr["metadata"]["labels"] = labels
# Apply labels to metadata
if labels:
rayjob_cr["metadata"]["labels"] = labels

# When using Kueue (queue label present), start with suspend=true
# Kueue will unsuspend the job once the workload is admitted
Expand Down Expand Up @@ -450,6 +467,36 @@ def _validate_cluster_config_image(self):
elif is_warning:
warnings.warn(f"Cluster config image: {message}")

def _validate_priority_class(self):
"""
Validate that the priority class exists in the cluster (best effort).

Raises ValueError if the priority class is definitively known not to exist.
If we cannot verify (e.g., permission denied), logs a warning and allows submission.
"""
if self.priority_class:
logger.debug(f"Validating priority class '{self.priority_class}'...")
exists = priority_class_exists(self.priority_class)

if exists is False:
# Definitively doesn't exist - fail validation
print(
f"❌ Priority class '{self.priority_class}' does not exist in the cluster. "
f"Submission cancelled."
)
raise ValueError(
f"Priority class '{self.priority_class}' does not exist"
)
elif exists is None:
# Cannot verify - log warning and allow submission
logger.warning(
f"Could not verify if priority class '{self.priority_class}' exists. "
f"Proceeding with submission - Kueue will validate on admission."
)
else:
# exists is True - validation passed
logger.debug(f"Priority class '{self.priority_class}' verified.")

def _validate_working_dir_entrypoint(self):
"""
Validate entrypoint file configuration.
Expand Down
Loading
Loading