2929from os import urandom
3030from base64 import b64encode
3131from urllib3 .util import parse_url
32+ import json
33+
34+ from codeflare_sdk .cluster .config import DEFAULT_CUSTOM_RESOURCE_MAPPING
3235
3336
3437def read_template (template ):
@@ -264,50 +267,36 @@ def update_priority(yaml, item, dispatch_priority, priority_val):
264267
def update_custompodresources(
    item,
    min_cpu: int,
    max_cpu: int,
    min_memory: int,
    max_memory: int,
    worker_custom_resources: typing.Dict[str, int],
    workers,
    head_cpus: int,
    head_memory: int,
    head_custom_resources: typing.Dict[str, int],
):
    """Populate the AppWrapper item's ``custompodresources`` section.

    Entry 0 of the list describes the head pod, entry 1 the worker pods.
    CPU counts are written as-is; memory values are integers interpreted as
    gigabytes and rendered as e.g. ``"8G"``. Custom resources (e.g.
    ``{"nvidia.com/gpu": 1}``) are written to both requests and limits under
    their Kubernetes resource-name keys.

    Exits the process with an error when the template lacks a
    ``custompodresources`` list with at least the head and worker entries.
    """
    if "custompodresources" not in item.keys():
        sys.exit("Error: malformed template")
    custompodresources = item.get("custompodresources")
    # Guard the [0]/[1] indexing below: a template with fewer than two
    # entries would otherwise fail with an opaque IndexError instead of the
    # intended malformed-template message.
    if len(custompodresources) < 2:
        sys.exit("Error: malformed template")

    # Head pod: requests and limits are pinned to the same values.
    head_resources = custompodresources[0]
    head_resources["requests"]["cpu"] = head_cpus
    head_resources["limits"]["cpu"] = head_cpus
    head_resources["requests"]["memory"] = f"{head_memory}G"
    head_resources["limits"]["memory"] = f"{head_memory}G"
    for resource, value in head_custom_resources.items():
        head_resources["requests"][resource] = value
        head_resources["limits"][resource] = value

    # Worker pods: min_* feed the requests, max_* feed the limits.
    worker_resources = custompodresources[1]
    worker_resources["replicas"] = workers
    worker_resources["requests"]["cpu"] = min_cpu
    worker_resources["limits"]["cpu"] = max_cpu
    worker_resources["requests"]["memory"] = f"{min_memory}G"
    worker_resources["limits"]["memory"] = f"{max_memory}G"
    for resource, value in worker_custom_resources.items():
        worker_resources["requests"][resource] = value
        worker_resources["limits"][resource] = value
313302
@@ -349,19 +338,60 @@ def update_env(spec, env):
349338 container ["env" ] = env
350339
351340
def update_resources(
    spec: dict,
    min_cpu: int,
    max_cpu: int,
    min_memory: int,
    max_memory: int,
    worker_custom_resources: typing.Dict[str, int],
):
    """Apply CPU, memory and custom-resource requests/limits to every
    container in the given pod spec.

    Memory integers are rendered as ``"<n>G"``. Custom resources are copied
    verbatim into both requests and limits. Sections that are ``None`` in the
    template are left untouched.
    """
    for container in spec.get("containers"):
        resources = container.get("resources")
        requests = resources.get("requests")
        if requests is not None:
            requests["cpu"] = min_cpu
            requests["memory"] = f"{min_memory}G"
            requests.update(worker_custom_resources)
        limits = resources.get("limits")
        if limits is not None:
            limits["cpu"] = max_cpu
            limits["memory"] = f"{max_memory}G"
            limits.update(worker_custom_resources)
364+
365+ def _get_resource_mapping (
366+ resource : str , custom_mapping : typing .Optional [typing .Dict [str , str ]]
367+ ):
368+ # throws value error if no mapping exists
369+ mapping = custom_mapping or {}
370+ return mapping .get (resource , None ) or DEFAULT_CUSTOM_RESOURCE_MAPPING [resource ]
371+
372+
373+ def _get_ray_start_params_from_resources (
374+ start_params : typing .Dict ,
375+ resources : typing .Dict [str , int ],
376+ custom_mapping : typing .Optional [typing .Dict [str , str ]],
377+ ):
378+ ray_resources = {}
379+ for r , value in resources .items ():
380+ ray_resource = _get_resource_mapping (r , custom_mapping )
381+ if ray_resource == "GPU" :
382+ start_params ["num-gpus" ] = start_params .get ("num-gpus" , 0 ) + value
383+ else :
384+ ray_resources [ray_resource ] = ray_resources .get (ray_resource , 0 ) + value
385+
386+ # this looks ugly, but it's to get the string into the same form as it appeears here
387+ # https://docs.ray.io/en/latest/cluster/kubernetes/user-guides/config.html#id1
388+ if ray_resources :
389+ start_params ["resources" ] = (
390+ '"' + json .dumps (ray_resources ).replace ('"' , '\\ "' ) + '"'
391+ )
392+ if start_params .get ("num-gpus" ) is not None :
393+ start_params ["num-gpus" ] = str (start_params ["num-gpus" ])
394+ return start_params
365395
366396
367397def update_nodes (
@@ -371,27 +401,36 @@ def update_nodes(
371401 max_cpu ,
372402 min_memory ,
373403 max_memory ,
374- gpu ,
404+ worker_custom_resources : typing . Dict [ str , int ] ,
375405 workers ,
376406 image ,
377407 instascale ,
378408 env ,
379409 image_pull_secrets ,
380410 head_cpus ,
381411 head_memory ,
382- head_gpus ,
412+ head_custom_resources : typing .Dict [str , int ],
413+ custom_resource_mapping : typing .Optional [typing .Dict [str , str ]] = None ,
383414):
384415 if "generictemplate" in item .keys ():
385416 head = item .get ("generictemplate" ).get ("spec" ).get ("headGroupSpec" )
386- head ["rayStartParams" ]["num-gpus" ] = str (int (head_gpus ))
417+
418+ # TODO: should get custom resources too
419+ head ["rayStartParams" ] = _get_ray_start_params_from_resources (
420+ head ["rayStartParams" ], head_custom_resources , custom_resource_mapping
421+ )
387422
388423 worker = item .get ("generictemplate" ).get ("spec" ).get ("workerGroupSpecs" )[0 ]
389424 # Head counts as first worker
390425 worker ["replicas" ] = workers
391426 worker ["minReplicas" ] = workers
392427 worker ["maxReplicas" ] = workers
393428 worker ["groupName" ] = "small-group-" + appwrapper_name
394- worker ["rayStartParams" ]["num-gpus" ] = str (int (gpu ))
429+
430+ # TODO: should get custom resources too
431+ worker ["rayStartParams" ] = _get_ray_start_params_from_resources (
432+ worker ["rayStartParams" ], worker_custom_resources , custom_resource_mapping
433+ )
395434
396435 for comp in [head , worker ]:
397436 spec = comp .get ("template" ).get ("spec" )
@@ -402,10 +441,22 @@ def update_nodes(
402441 if comp == head :
403442 # TODO: Eventually add head node configuration outside of template
404443 update_resources (
405- spec , head_cpus , head_cpus , head_memory , head_memory , head_gpus
444+ spec ,
445+ head_cpus ,
446+ head_cpus ,
447+ head_memory ,
448+ head_memory ,
449+ head_custom_resources ,
406450 )
407451 else :
408- update_resources (spec , min_cpu , max_cpu , min_memory , max_memory , gpu )
452+ update_resources (
453+ spec ,
454+ min_cpu ,
455+ max_cpu ,
456+ min_memory ,
457+ max_memory ,
458+ worker_custom_resources ,
459+ )
409460
410461
411462def update_ca_secret (ca_secret_item , cluster_name , namespace ):
@@ -645,12 +696,12 @@ def generate_appwrapper(
645696 namespace : str ,
646697 head_cpus : int ,
647698 head_memory : int ,
648- head_gpus : int ,
699+ head_custom_resources : typing . Dict [ str , int ] ,
649700 min_cpu : int ,
650701 max_cpu : int ,
651702 min_memory : int ,
652703 max_memory : int ,
653- gpu : int ,
704+ worker_custom_resources : typing . Dict [ str , int ] ,
654705 workers : int ,
655706 template : str ,
656707 image : str ,
@@ -665,6 +716,7 @@ def generate_appwrapper(
665716 openshift_oauth : bool ,
666717 ingress_domain : str ,
667718 ingress_options : dict ,
719+ custom_resource_mapping : typing .Dict [str , str ],
668720):
669721 user_yaml = read_template (template )
670722 appwrapper_name , cluster_name = gen_names (name )
@@ -681,11 +733,11 @@ def generate_appwrapper(
681733 max_cpu ,
682734 min_memory ,
683735 max_memory ,
684- gpu ,
736+ worker_custom_resources ,
685737 workers ,
686738 head_cpus ,
687739 head_memory ,
688- head_gpus ,
740+ head_custom_resources ,
689741 )
690742 update_nodes (
691743 item ,
@@ -694,15 +746,16 @@ def generate_appwrapper(
694746 max_cpu ,
695747 min_memory ,
696748 max_memory ,
697- gpu ,
749+ worker_custom_resources ,
698750 workers ,
699751 image ,
700752 instascale ,
701753 env ,
702754 image_pull_secrets ,
703755 head_cpus ,
704756 head_memory ,
705- head_gpus ,
757+ head_custom_resources ,
758+ custom_resource_mapping = custom_resource_mapping ,
706759 )
707760 update_dashboard_exposure (
708761 ingress_item ,
0 commit comments