From 0bc477dd5fa0536b48b06faf348a0ff2c35ee31b Mon Sep 17 00:00:00 2001
From: Antonin Stefanutti <antonin@stefanutti.fr>
Date: Fri, 21 Jun 2024 17:26:11 +0200
Subject: [PATCH] Add arguments to pass Ray cluster head and worker templates

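The new head_template and worker_template arguments on ClusterConfiguration
(threaded through Cluster.create_app_wrapper and generate_appwrapper) accept a
kubernetes.client.V1PodTemplateSpec. When set, the template is deep-merged
additively (using mergedeep) into the head group spec, respectively the first
worker group spec, of the generated RayCluster before the rest of the YAML is
filled in.

Illustrative usage, a sketch mirroring the added unit test (the cluster name,
namespace and image below are placeholder values):

    from codeflare_sdk.cluster.cluster import Cluster
    from codeflare_sdk.cluster.config import ClusterConfiguration
    from kubernetes.client import V1PodTemplateSpec, V1PodSpec, V1Toleration

    cluster = Cluster(ClusterConfiguration(
        name="my-cluster",  # placeholder name
        namespace="ns",
        num_workers=2,
        image="test/ray:2.20.0-py39-cu118",
        # Merged additively into the worker group Pod template of the
        # generated RayCluster spec.
        worker_template=V1PodTemplateSpec(
            spec=V1PodSpec(
                containers=[],
                tolerations=[
                    V1Toleration(
                        key="nvidia.com/gpu",
                        operator="Exists",
                        effect="NoSchedule",
                    )
                ],
                node_selector={"nvidia.com/gpu.present": "true"},
            )
        ),
    ))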
---
 poetry.lock                              | 13 +++++++-
 pyproject.toml                           |  1 +
 src/codeflare_sdk/cluster/cluster.py     | 10 +++---
 src/codeflare_sdk/cluster/config.py      |  4 +++
 src/codeflare_sdk/utils/generate_yaml.py | 24 +++++++++++---
 tests/unit_test.py                       | 42 +++++++++++++++++++++---
 6 files changed, 79 insertions(+), 15 deletions(-)

diff --git a/poetry.lock b/poetry.lock
index 1868163e..938b69e3 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1273,6 +1273,17 @@ docs = ["IPython", "bump2version", "furo", "sphinx", "sphinx-argparse", "towncri
 lint = ["black", "check-manifest", "flake8", "isort", "mypy"]
 test = ["Cython", "greenlet", "ipython", "pytest", "pytest-cov", "setuptools"]
 
+[[package]]
+name = "mergedeep"
+version = "1.3.4"
+description = "A deep merge function for 🐍."
+optional = false
+python-versions = ">=3.6"
+files = [
+    {file = "mergedeep-1.3.4-py3-none-any.whl", hash = "sha256:70775750742b25c0d8f36c55aed03d24c3384d17c951b3175d898bd778ef0307"},
+    {file = "mergedeep-1.3.4.tar.gz", hash = "sha256:0096d52e9dad9939c3d975a774666af186eda617e6ca84df4c94dec30004f2a8"},
+]
+
 [[package]]
 name = "msgpack"
 version = "1.0.8"
@@ -2795,4 +2806,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools",
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.9"
-content-hash = "d656bab99c2e5a911ee1003db9e0682141328ae3ef1e1620945f8479451425bf"
+content-hash = "641a3685dcb9a044e49d903bdb0d8911d410f7a65e7835dda087a194c47e3c64"
diff --git a/pyproject.toml b/pyproject.toml
index af7dd1ca..5841ebd5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,6 +29,7 @@ cryptography = "40.0.2"
 executing = "1.2.0"
 pydantic = "< 2"
 ipywidgets = "8.1.2"
+mergedeep = "1.3.4"
 
 [tool.poetry.group.docs]
 optional = true
diff --git a/src/codeflare_sdk/cluster/cluster.py b/src/codeflare_sdk/cluster/cluster.py
index 015f15ed..149a5d48 100644
--- a/src/codeflare_sdk/cluster/cluster.py
+++ b/src/codeflare_sdk/cluster/cluster.py
@@ -18,11 +18,9 @@
 cluster setup queue, a list of all existing clusters, and the user's working namespace.
 """
 
-import re
 from time import sleep
 from typing import List, Optional, Tuple, Dict
 
-from kubernetes import config
 from ray.job_submission import JobSubmissionClient
 
 from .auth import config_check, api_config_handler
@@ -41,13 +39,11 @@
     RayCluster,
     RayClusterStatus,
 )
-from kubernetes import client, config
-from kubernetes.utils import parse_quantity
 import yaml
 import os
 import requests
 
-from kubernetes import config
+from kubernetes import client, config
 from kubernetes.client.rest import ApiException
 
 
@@ -145,6 +141,8 @@ def create_app_wrapper(self):
         gpu = self.config.num_gpus
         workers = self.config.num_workers
         template = self.config.template
+        head_template = self.config.head_template
+        worker_template = self.config.worker_template
         image = self.config.image
         appwrapper = self.config.appwrapper
         env = self.config.envs
@@ -167,6 +165,8 @@ def create_app_wrapper(self):
             gpu=gpu,
             workers=workers,
             template=template,
+            head_template=head_template,
+            worker_template=worker_template,
             image=image,
             appwrapper=appwrapper,
             env=env,
diff --git a/src/codeflare_sdk/cluster/config.py b/src/codeflare_sdk/cluster/config.py
index 97067365..66bf7a13 100644
--- a/src/codeflare_sdk/cluster/config.py
+++ b/src/codeflare_sdk/cluster/config.py
@@ -22,6 +22,8 @@
 import pathlib
 import typing
 
+import kubernetes
+
 dir = pathlib.Path(__file__).parent.parent.resolve()
 
 
@@ -46,6 +48,8 @@ class ClusterConfiguration:
     max_memory: typing.Union[int, str] = 2
     num_gpus: int = 0
     template: str = f"{dir}/templates/base-template.yaml"
+    head_template: kubernetes.client.V1PodTemplateSpec = None
+    worker_template: kubernetes.client.V1PodTemplateSpec = None
     appwrapper: bool = False
     envs: dict = field(default_factory=dict)
     image: str = ""
diff --git a/src/codeflare_sdk/utils/generate_yaml.py b/src/codeflare_sdk/utils/generate_yaml.py
index 3192ae1b..56b79352 100755
--- a/src/codeflare_sdk/utils/generate_yaml.py
+++ b/src/codeflare_sdk/utils/generate_yaml.py
@@ -20,16 +20,12 @@
 from typing import Optional
 import typing
 import yaml
-import sys
 import os
-import argparse
 import uuid
 from kubernetes import client, config
 from .kube_api_helpers import _kube_api_error_handling
 from ..cluster.auth import api_config_handler, config_check
-from os import urandom
-from base64 import b64encode
-from urllib3.util import parse_url
+from mergedeep import merge, Strategy
 
 
 def read_template(template):
@@ -278,6 +274,16 @@ def write_user_yaml(user_yaml, output_file_name):
     print(f"Written to: {output_file_name}")
 
 
+def apply_head_template(cluster_yaml: dict, head_template: client.V1PodTemplateSpec):
+    head = cluster_yaml.get("spec").get("headGroupSpec")
+    merge(head["template"], head_template.to_dict(), strategy=Strategy.ADDITIVE)
+
+
+def apply_worker_template(cluster_yaml: dict, worker_template: client.V1PodTemplateSpec):
+    worker = cluster_yaml.get("spec").get("workerGroupSpecs")[0]
+    merge(worker["template"], worker_template.to_dict(), strategy=Strategy.ADDITIVE)
+
+
 def generate_appwrapper(
     name: str,
     namespace: str,
@@ -291,6 +297,8 @@ def generate_appwrapper(
     gpu: int,
     workers: int,
     template: str,
+    head_template: client.V1PodTemplateSpec,
+    worker_template: client.V1PodTemplateSpec,
     image: str,
     appwrapper: bool,
     env,
@@ -302,6 +310,12 @@ def generate_appwrapper(
     volume_mounts: list[client.V1VolumeMount],
 ):
     cluster_yaml = read_template(template)
+
+    if head_template:
+        apply_head_template(cluster_yaml, head_template)
+    if worker_template:
+        apply_worker_template(cluster_yaml, worker_template)
+
     appwrapper_name, cluster_name = gen_names(name)
     update_names(cluster_yaml, cluster_name, namespace)
     update_nodes(
diff --git a/tests/unit_test.py b/tests/unit_test.py
index db908df6..31c25a06 100644
--- a/tests/unit_test.py
+++ b/tests/unit_test.py
@@ -20,8 +20,6 @@
 import re
 import uuid
 
-from codeflare_sdk.cluster import cluster
-
 parent = Path(__file__).resolve().parents[1]
 aw_dir = os.path.expanduser("~/.codeflare/resources/")
 sys.path.append(str(parent) + "/src")
@@ -69,17 +67,18 @@
     createClusterConfig,
 )
 
-import codeflare_sdk.utils.kube_api_helpers
 from codeflare_sdk.utils.generate_yaml import (
     gen_names,
     is_openshift_cluster,
 )
 
 import openshift
-from openshift.selector import Selector
 import ray
 import pytest
 import yaml
+
+from kubernetes.client import V1PodTemplateSpec, V1PodSpec, V1Toleration
+
 from unittest.mock import MagicMock
 from pytest_mock import MockerFixture
 from ray.job_submission import JobSubmissionClient
@@ -268,6 +267,41 @@ def test_config_creation():
     assert config.appwrapper == True
 
 
+def test_cluster_config_with_worker_template(mocker):
+    mocker.patch("kubernetes.client.ApisApi.get_api_versions")
+    mocker.patch(
+        "kubernetes.client.CustomObjectsApi.list_namespaced_custom_object",
+        return_value=get_local_queue("kueue.x-k8s.io", "v1beta1", "ns", "localqueues"),
+    )
+
+    cluster = Cluster(ClusterConfiguration(
+        name="unit-test-cluster",
+        namespace="ns",
+        num_workers=2,
+        min_cpus=3,
+        max_cpus=4,
+        min_memory=5,
+        max_memory=6,
+        num_gpus=7,
+        image="test/ray:2.20.0-py39-cu118",
+        worker_template=V1PodTemplateSpec(
+            spec=V1PodSpec(
+                containers=[],
+                tolerations=[V1Toleration(
+                    key="nvidia.com/gpu",
+                    operator="Exists",
+                    effect="NoSchedule",
+                )],
+                node_selector={
+                    "nvidia.com/gpu.present": "true",
+                },
+            )
+        ),
+    ))
+
+    assert cluster
+
+
 def test_cluster_creation(mocker):
     # Create AppWrapper containing a Ray Cluster with no local queue specified
     mocker.patch("kubernetes.client.ApisApi.get_api_versions")