From ee1a6f2051ce1a45dff7917ef7393223824f656e Mon Sep 17 00:00:00 2001 From: kryanbeane Date: Wed, 5 Nov 2025 12:12:51 +0000 Subject: [PATCH] fix(RHOAIENG-37842): fix heterogeneous oauth test --- .../e2e/heterogeneous_clusters_oauth_test.py | 16 +-- tests/e2e/local_interactive_sdk_oauth_test.py | 1 + tests/e2e/support.py | 117 +++++++++++++++++- 3 files changed, 120 insertions(+), 14 deletions(-) diff --git a/tests/e2e/heterogeneous_clusters_oauth_test.py b/tests/e2e/heterogeneous_clusters_oauth_test.py index 0fbe4df3..8b8ef340 100644 --- a/tests/e2e/heterogeneous_clusters_oauth_test.py +++ b/tests/e2e/heterogeneous_clusters_oauth_test.py @@ -1,5 +1,3 @@ -from time import sleep -import time from codeflare_sdk import ( Cluster, ClusterConfiguration, @@ -55,11 +53,13 @@ def run_heterogeneous_clusters( namespace=self.namespace, name=cluster_name, num_workers=1, - head_cpu_requests=1, - head_cpu_limits=1, - worker_cpu_requests=1, + head_cpu_requests="500m", + head_cpu_limits="500m", + head_memory_requests=2, + head_memory_limits=4, + worker_cpu_requests="500m", worker_cpu_limits=1, - worker_memory_requests=1, + worker_memory_requests=2, worker_memory_limits=4, image=ray_image, verify_tls=False, @@ -67,10 +67,10 @@ def run_heterogeneous_clusters( ) ) cluster.apply() - sleep(5) + # Wait for the cluster to be scheduled and ready, we don't need the dashboard for this check + cluster.wait_ready(dashboard_check=False) node_name = get_pod_node(self, self.namespace, cluster_name) print(f"Cluster {cluster_name}-{flavor} is running on node: {node_name}") - sleep(5) assert ( node_name in expected_nodes ), f"Node {node_name} is not in the expected nodes for flavor {flavor}." diff --git a/tests/e2e/local_interactive_sdk_oauth_test.py b/tests/e2e/local_interactive_sdk_oauth_test.py index 8be0bf9c..a5faad2a 100644 --- a/tests/e2e/local_interactive_sdk_oauth_test.py +++ b/tests/e2e/local_interactive_sdk_oauth_test.py @@ -12,6 +12,7 @@ from support import * +@pytest.mark.skip(reason="Remote ray.init() is temporarily unsupported") @pytest.mark.openshift class TestRayLocalInteractiveOauth: def setup_method(self): diff --git a/tests/e2e/support.py b/tests/e2e/support.py index 85b3dd35..ea8f6c45 100644 --- a/tests/e2e/support.py +++ b/tests/e2e/support.py @@ -5,6 +5,7 @@ from time import sleep from codeflare_sdk import get_cluster from kubernetes import client, config +from kubernetes.client import V1Toleration from codeflare_sdk.common.kubernetes_cluster.kube_api_helpers import ( _kube_api_error_handling, ) @@ -146,6 +147,92 @@ def random_choice(): return "".join(random.choices(alphabet, k=5)) +def _parse_label_env(env_var, default): + """Parse label from environment variable (format: 'key=value').""" + label_str = os.getenv(env_var, default) + return label_str.split("=") + + +def get_master_taint_key(self): + """ + Detect the actual master/control-plane taint key from nodes. + Returns the taint key if found, or defaults to control-plane. + """ + # Check env var first (most efficient) + if os.getenv("TOLERATION_KEY"): + return os.getenv("TOLERATION_KEY") + + # Try to detect from cluster nodes + try: + nodes = self.api_instance.list_node() + taint_key = next( + ( + taint.key + for node in nodes.items + if node.spec.taints + for taint in node.spec.taints + if taint.key + in [ + "node-role.kubernetes.io/master", + "node-role.kubernetes.io/control-plane", + ] + ), + None, + ) + if taint_key: + return taint_key + except Exception as e: + print(f"Warning: Could not detect master taint key: {e}") + + # Default fallback + return "node-role.kubernetes.io/control-plane" + + +def ensure_nodes_labeled_for_flavors(self, num_flavors, with_labels): + """ + Check if required node labels exist for ResourceFlavor targeting. + This handles both default (worker-1=true) and non-default (ingress-ready=true) flavors. + + NOTE: This function does NOT modify cluster nodes. It only checks if required labels exist. + If labels don't exist, the test will use whatever labels are available on the cluster. + For shared clusters, set WORKER_LABEL and CONTROL_LABEL env vars to match existing labels. + """ + if not with_labels: + return + + worker_label, worker_value = _parse_label_env("WORKER_LABEL", "worker-1=true") + control_label, control_value = _parse_label_env( + "CONTROL_LABEL", "ingress-ready=true" + ) + + try: + worker_nodes = self.api_instance.list_node( + label_selector="node-role.kubernetes.io/worker" + ) + + if not worker_nodes.items: + print("Warning: No worker nodes found") + return + + # Check labels based on num_flavors + labels_to_check = [("WORKER_LABEL", worker_label, worker_value)] + if num_flavors > 1: + labels_to_check.append(("CONTROL_LABEL", control_label, control_value)) + + for env_var, label, value in labels_to_check: + has_label = any( + node.metadata.labels and node.metadata.labels.get(label) == value + for node in worker_nodes.items + ) + if not has_label: + print( + f"Warning: Label {label}={value} not found (set {env_var} env var to match existing labels)" + ) + + except Exception as e: + print(f"Warning: Could not check existing labels: {e}") + + def create_namespace(self): try: self.namespace = f"test-ns-{random_choice()}" @@ -280,14 +367,13 @@ def create_cluster_queue(self, cluster_queue, flavor): def create_resource_flavor( self, flavor, default=True, with_labels=False, with_tolerations=False ): - worker_label, worker_value = os.getenv("WORKER_LABEL", "worker-1=true").split("=") - control_label, control_value = os.getenv( + worker_label, worker_value = _parse_label_env("WORKER_LABEL", "worker-1=true") + control_label, control_value = _parse_label_env( "CONTROL_LABEL", "ingress-ready=true" - ).split("=") - toleration_key = os.getenv( - "TOLERATION_KEY", "node-role.kubernetes.io/control-plane" ) + toleration_key = os.getenv("TOLERATION_KEY") or get_master_taint_key(self) + node_labels = {} if with_labels: node_labels = ( @@ -451,6 +537,25 @@ def get_nodes_by_label(self, node_labels): return [node.metadata.name for node in nodes.items] +def get_tolerations_from_flavor(self, flavor_name): + """ + Extract tolerations from a ResourceFlavor and convert them to V1Toleration objects. + Returns a list of V1Toleration objects, or empty list if no tolerations found. + """ + flavor_spec = get_flavor_spec(self, flavor_name) + tolerations_spec = flavor_spec.get("spec", {}).get("tolerations", []) + + return [ + V1Toleration( + key=tol_spec.get("key"), + operator=tol_spec.get("operator", "Equal"), + value=tol_spec.get("value"), + effect=tol_spec.get("effect"), + ) + for tol_spec in tolerations_spec + ] + + def assert_get_cluster_and_jobsubmit( self, cluster_name, accelerator=None, number_of_gpus=None ): @@ -514,7 +619,7 @@ def wait_for_kueue_admission(self, job_api, job_name, namespace, timeout=120): workload = get_kueue_workload_for_job(self, job_name, namespace) if workload: conditions = workload.get("status", {}).get("conditions", []) - print(f" DEBUG: Workload conditions for '{job_name}':") + print(f"DEBUG: Workload conditions for '{job_name}':") for condition in conditions: print( f" - {condition.get('type')}: {condition.get('status')} - {condition.get('reason', '')} - {condition.get('message', '')}"