@@ -69,6 +69,7 @@ def __init__(
6969 scheduler_args : Optional [Dict [str , str ]] = None ,
7070 image : Optional [str ] = None ,
7171 workspace : Optional [str ] = f"file://{ Path .cwd ()} " ,
72+ mcad : Optional [str ] = False ,
7273 ):
7374 if bool (script ) == bool (m ): # logical XOR
7475 raise ValueError (
@@ -93,6 +94,7 @@ def __init__(
9394 )
9495 self .image = image
9596 self .workspace = workspace
97+ self .mcad = mcad
9698
9799 def _dry_run (self , cluster : "Cluster" ):
98100 j = f"{ cluster .config .num_workers } x{ max (cluster .config .num_gpus , 1 )} " # # of proc. = # of gpus
@@ -136,6 +138,9 @@ def _dry_run_no_cluster(self):
136138 if self .scheduler_args is not None :
137139 if self .scheduler_args .get ("namespace" ) is None :
138140 self .scheduler_args ["namespace" ] = get_current_namespace ()
141+ scheduler = "kueue_job"
142+ if self .mcad == True :
143+ scheduler = "kubernetes_mcad"
139144 runner = get_runner ()
140145 return (
141146 runner .dryrun (
@@ -172,7 +177,7 @@ def _dry_run_no_cluster(self):
172177 if self .image is not None
173178 else self ._missing_spec ("image" ),
174179 ),
175- scheduler = "kubernetes_mcad" ,
180+ scheduler = scheduler ,
176181 cfg = self .scheduler_args ,
177182 workspace = "" ,
178183 ),
0 commit comments