Skip to content

Commit 936ca30

Browse files
committed
add support for custom accelerators
Signed-off-by: Kevin <kpostlet@redhat.com>
1 parent 47381fb commit 936ca30

File tree

4 files changed

+149
-58
lines changed

4 files changed

+149
-58
lines changed

src/codeflare_sdk/cluster/cluster.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -197,12 +197,12 @@ def create_app_wrapper(self):
197197
namespace=namespace,
198198
head_cpus=head_cpus,
199199
head_memory=head_memory,
200-
head_gpus=head_gpus,
200+
head_custom_resources=self.config.head_custom_resources,
201201
min_cpu=min_cpu,
202202
max_cpu=max_cpu,
203203
min_memory=min_memory,
204204
max_memory=max_memory,
205-
gpu=gpu,
205+
worker_custom_resources=self.config.worker_custom_resources,
206206
workers=workers,
207207
template=template,
208208
image=image,
@@ -217,6 +217,7 @@ def create_app_wrapper(self):
217217
openshift_oauth=self.config.openshift_oauth,
218218
ingress_domain=ingress_domain,
219219
ingress_options=ingress_options,
220+
custom_resource_mapping=self.config.custom_resource_mapping,
220221
)
221222

222223
# creates a new cluster with the provided or default spec

src/codeflare_sdk/cluster/config.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,18 @@
2020

2121
from dataclasses import dataclass, field
2222
import pathlib
23+
import warnings
2324

2425
dir = pathlib.Path(__file__).parent.parent.resolve()
2526

27+
# NOTE: this isn't an enum because the values for ray custom resources can be arbitrary strings
28+
DEFAULT_CUSTOM_RESOURCE_MAPPING = {
29+
"nvidia.com/gpu": "GPU",
30+
"gpu.intel.com/i915": "GPU",
31+
"habana.ai/gaudi": "HPU",
32+
"google.com/tpu": "TPU",
33+
}
34+
2635

2736
@dataclass
2837
class ClusterConfiguration:
@@ -55,3 +64,32 @@ class ClusterConfiguration:
5564
openshift_oauth: bool = False # NOTE: to use the user must have permission to create a RoleBinding for system:auth-delegator
5665
ingress_options: dict = field(default_factory=dict)
5766
ingress_domain: str = None
67+
head_custom_resources: dict = field(default_factory=dict)
68+
worker_custom_resources: dict = field(default_factory=dict)
69+
custom_resource_mapping: dict = field(
70+
default_factory=dict
71+
) # for custom resources not in the default mapping
72+
73+
def __post_init__(self):
74+
if (
75+
self.head_gpus
76+
and self.head_custom_resources
77+
or self.num_gpus
78+
and self.worker_custom_resources
79+
):
80+
raise ValueError(
81+
"Cannot set both head_gpus and head_custom_resources or num_gpus and worker_custom_resources"
82+
)
83+
if self.head_gpus != 0:
84+
warnings.warn(
85+
"head_gpus being deprecated, use gpu_custom_resources with resource 'nvidia.com/gpu'",
86+
PendingDeprecationWarning,
87+
)
88+
self.head_custom_resources["nvidia.com/gpu"] = self.head_gpus
89+
if self.num_gpus != 0:
90+
warnings.warn(
91+
"num_gpus being deprecated use worker_custom_resources with resource 'nvidia.com/gpu'",
92+
PendingDeprecationWarning,
93+
)
94+
self.worker_custom_resources["nvidia.com/gpu"] = self.num_gpus
95+
pass

src/codeflare_sdk/templates/base-template.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,6 @@ spec:
226226
# the following params are used to complete the ray start: ray start --block ...
227227
rayStartParams:
228228
block: 'true'
229-
num-gpus: 1
230229
#pod template
231230
template:
232231
metadata:

src/codeflare_sdk/utils/generate_yaml.py

Lines changed: 108 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@
2929
from os import urandom
3030
from base64 import b64encode
3131
from urllib3.util import parse_url
32+
import json
33+
34+
from codeflare_sdk.cluster.config import DEFAULT_CUSTOM_RESOURCE_MAPPING
3235

3336

3437
def read_template(template):
@@ -264,50 +267,36 @@ def update_priority(yaml, item, dispatch_priority, priority_val):
264267

265268
def update_custompodresources(
266269
item,
267-
min_cpu,
268-
max_cpu,
269-
min_memory,
270-
max_memory,
271-
gpu,
270+
min_cpu: int,
271+
max_cpu: int,
272+
min_memory: int,
273+
max_memory: int,
274+
worker_custom_resources: typing.Dict[str, int],
272275
workers,
273-
head_cpus,
274-
head_memory,
275-
head_gpus,
276+
head_cpus: int,
277+
head_memory: int,
278+
head_custom_resources: typing.Dict[str, int],
276279
):
277280
if "custompodresources" in item.keys():
278281
custompodresources = item.get("custompodresources")
279-
for i in range(len(custompodresources)):
280-
resource = custompodresources[i]
281-
if i == 0:
282-
# Leave head node resources as template default
283-
resource["requests"]["cpu"] = head_cpus
284-
resource["limits"]["cpu"] = head_cpus
285-
resource["requests"]["memory"] = str(head_memory) + "G"
286-
resource["limits"]["memory"] = str(head_memory) + "G"
287-
resource["requests"]["nvidia.com/gpu"] = head_gpus
288-
resource["limits"]["nvidia.com/gpu"] = head_gpus
289-
290-
else:
291-
for k, v in resource.items():
292-
if k == "replicas" and i == 1:
293-
resource[k] = workers
294-
if k == "requests" or k == "limits":
295-
for spec, _ in v.items():
296-
if spec == "cpu":
297-
if k == "limits":
298-
resource[k][spec] = max_cpu
299-
else:
300-
resource[k][spec] = min_cpu
301-
if spec == "memory":
302-
if k == "limits":
303-
resource[k][spec] = str(max_memory) + "G"
304-
else:
305-
resource[k][spec] = str(min_memory) + "G"
306-
if spec == "nvidia.com/gpu":
307-
if i == 0:
308-
resource[k][spec] = 0
309-
else:
310-
resource[k][spec] = gpu
282+
head_resources = custompodresources[0]
283+
head_resources["requests"]["cpu"] = head_cpus
284+
head_resources["limits"]["cpu"] = head_cpus
285+
head_resources["requests"]["memory"] = str(head_memory) + "G"
286+
head_resources["limits"]["memory"] = str(head_memory) + "G"
287+
for r, val in head_custom_resources.items():
288+
head_resources["requests"][r] = val
289+
head_resources["limits"][r] = val
290+
291+
worker_resources = custompodresources[1]
292+
worker_resources["replicas"] = workers
293+
worker_resources["requests"]["cpu"] = min_cpu
294+
worker_resources["limits"]["cpu"] = max_cpu
295+
worker_resources["requests"]["memory"] = f"{min_memory}G"
296+
worker_resources["limits"]["memory"] = f"{max_memory}G"
297+
for r, value in worker_custom_resources.items():
298+
worker_resources["requests"][r] = value
299+
worker_resources["limits"][r] = value
311300
else:
312301
sys.exit("Error: malformed template")
313302

@@ -349,19 +338,60 @@ def update_env(spec, env):
349338
container["env"] = env
350339

351340

352-
def update_resources(spec, min_cpu, max_cpu, min_memory, max_memory, gpu):
341+
def update_resources(
342+
spec: dict,
343+
min_cpu: int,
344+
max_cpu: int,
345+
min_memory: int,
346+
max_memory: int,
347+
worker_custom_resources: typing.Dict[str, int],
348+
):
353349
container = spec.get("containers")
354350
for resource in container:
355351
requests = resource.get("resources").get("requests")
356352
if requests is not None:
357353
requests["cpu"] = min_cpu
358354
requests["memory"] = str(min_memory) + "G"
359-
requests["nvidia.com/gpu"] = gpu
355+
for r, value in worker_custom_resources.items():
356+
requests[r] = value
360357
limits = resource.get("resources").get("limits")
361358
if limits is not None:
362359
limits["cpu"] = max_cpu
363360
limits["memory"] = str(max_memory) + "G"
364-
limits["nvidia.com/gpu"] = gpu
361+
for r, value in worker_custom_resources.items():
362+
limits[r] = value
363+
364+
365+
def _get_resource_mapping(
366+
resource: str, custom_mapping: typing.Optional[typing.Dict[str, str]]
367+
):
368+
# raises KeyError if the resource is in neither the custom nor the default mapping
369+
mapping = custom_mapping or {}
370+
return mapping.get(resource, None) or DEFAULT_CUSTOM_RESOURCE_MAPPING[resource]
371+
372+
373+
def _get_ray_start_params_from_resources(
374+
start_params: typing.Dict,
375+
resources: typing.Dict[str, int],
376+
custom_mapping: typing.Optional[typing.Dict[str, str]],
377+
):
378+
ray_resources = {}
379+
for r, value in resources.items():
380+
ray_resource = _get_resource_mapping(r, custom_mapping)
381+
if ray_resource == "GPU":
382+
start_params["num-gpus"] = start_params.get("num-gpus", 0) + value
383+
else:
384+
ray_resources[ray_resource] = ray_resources.get(ray_resource, 0) + value
385+
386+
# this looks ugly, but it's to get the string into the same form as it appears here
387+
# https://docs.ray.io/en/latest/cluster/kubernetes/user-guides/config.html#id1
388+
if ray_resources:
389+
start_params["resources"] = (
390+
'"' + json.dumps(ray_resources).replace('"', '\\"') + '"'
391+
)
392+
if start_params.get("num-gpus") is not None:
393+
start_params["num-gpus"] = str(start_params["num-gpus"])
394+
return start_params
365395

366396

367397
def update_nodes(
@@ -371,27 +401,36 @@ def update_nodes(
371401
max_cpu,
372402
min_memory,
373403
max_memory,
374-
gpu,
404+
worker_custom_resources: typing.Dict[str, int],
375405
workers,
376406
image,
377407
instascale,
378408
env,
379409
image_pull_secrets,
380410
head_cpus,
381411
head_memory,
382-
head_gpus,
412+
head_custom_resources: typing.Dict[str, int],
413+
custom_resource_mapping: typing.Optional[typing.Dict[str, str]] = None,
383414
):
384415
if "generictemplate" in item.keys():
385416
head = item.get("generictemplate").get("spec").get("headGroupSpec")
386-
head["rayStartParams"]["num-gpus"] = str(int(head_gpus))
417+
418+
# TODO: should get custom resources too
419+
head["rayStartParams"] = _get_ray_start_params_from_resources(
420+
head["rayStartParams"], head_custom_resources, custom_resource_mapping
421+
)
387422

388423
worker = item.get("generictemplate").get("spec").get("workerGroupSpecs")[0]
389424
# Head counts as first worker
390425
worker["replicas"] = workers
391426
worker["minReplicas"] = workers
392427
worker["maxReplicas"] = workers
393428
worker["groupName"] = "small-group-" + appwrapper_name
394-
worker["rayStartParams"]["num-gpus"] = str(int(gpu))
429+
430+
# TODO: should get custom resources too
431+
worker["rayStartParams"] = _get_ray_start_params_from_resources(
432+
worker["rayStartParams"], worker_custom_resources, custom_resource_mapping
433+
)
395434

396435
for comp in [head, worker]:
397436
spec = comp.get("template").get("spec")
@@ -402,10 +441,22 @@ def update_nodes(
402441
if comp == head:
403442
# TODO: Eventually add head node configuration outside of template
404443
update_resources(
405-
spec, head_cpus, head_cpus, head_memory, head_memory, head_gpus
444+
spec,
445+
head_cpus,
446+
head_cpus,
447+
head_memory,
448+
head_memory,
449+
head_custom_resources,
406450
)
407451
else:
408-
update_resources(spec, min_cpu, max_cpu, min_memory, max_memory, gpu)
452+
update_resources(
453+
spec,
454+
min_cpu,
455+
max_cpu,
456+
min_memory,
457+
max_memory,
458+
worker_custom_resources,
459+
)
409460

410461

411462
def update_ca_secret(ca_secret_item, cluster_name, namespace):
@@ -645,12 +696,12 @@ def generate_appwrapper(
645696
namespace: str,
646697
head_cpus: int,
647698
head_memory: int,
648-
head_gpus: int,
699+
head_custom_resources: typing.Dict[str, int],
649700
min_cpu: int,
650701
max_cpu: int,
651702
min_memory: int,
652703
max_memory: int,
653-
gpu: int,
704+
worker_custom_resources: typing.Dict[str, int],
654705
workers: int,
655706
template: str,
656707
image: str,
@@ -665,6 +716,7 @@ def generate_appwrapper(
665716
openshift_oauth: bool,
666717
ingress_domain: str,
667718
ingress_options: dict,
719+
custom_resource_mapping: typing.Dict[str, str],
668720
):
669721
user_yaml = read_template(template)
670722
appwrapper_name, cluster_name = gen_names(name)
@@ -681,11 +733,11 @@ def generate_appwrapper(
681733
max_cpu,
682734
min_memory,
683735
max_memory,
684-
gpu,
736+
worker_custom_resources,
685737
workers,
686738
head_cpus,
687739
head_memory,
688-
head_gpus,
740+
head_custom_resources,
689741
)
690742
update_nodes(
691743
item,
@@ -694,15 +746,16 @@ def generate_appwrapper(
694746
max_cpu,
695747
min_memory,
696748
max_memory,
697-
gpu,
749+
worker_custom_resources,
698750
workers,
699751
image,
700752
instascale,
701753
env,
702754
image_pull_secrets,
703755
head_cpus,
704756
head_memory,
705-
head_gpus,
757+
head_custom_resources,
758+
custom_resource_mapping=custom_resource_mapping,
706759
)
707760
update_dashboard_exposure(
708761
ingress_item,

0 commit comments

Comments
 (0)