diff --git a/src/codeflare_sdk/job/jobs.py b/src/codeflare_sdk/job/jobs.py index 655107df..73a95e98 100644 --- a/src/codeflare_sdk/job/jobs.py +++ b/src/codeflare_sdk/job/jobs.py @@ -69,6 +69,7 @@ def __init__( scheduler_args: Optional[Dict[str, str]] = None, image: Optional[str] = None, workspace: Optional[str] = f"file://{Path.cwd()}", + mcad: Optional[str] = False, ): if bool(script) == bool(m): # logical XOR raise ValueError( @@ -93,6 +94,7 @@ def __init__( ) self.image = image self.workspace = workspace + self.mcad = mcad def _dry_run(self, cluster: "Cluster"): j = f"{cluster.config.num_workers}x{max(cluster.config.num_gpus, 1)}" # # of proc. = # of gpus @@ -136,6 +138,9 @@ def _dry_run_no_cluster(self): if self.scheduler_args is not None: if self.scheduler_args.get("namespace") is None: self.scheduler_args["namespace"] = get_current_namespace() + scheduler = "kueue" + if self.mcad == True: + scheduler = "kubernetes_mcad" runner = get_runner() return ( runner.dryrun( @@ -172,7 +177,7 @@ def _dry_run_no_cluster(self): if self.image is not None else self._missing_spec("image"), ), - scheduler="kubernetes_mcad", + scheduler=scheduler, cfg=self.scheduler_args, workspace="", ),