From 9175ad11fda550331b6f5e71f3bd4527345833d7 Mon Sep 17 00:00:00 2001 From: Jan Janssen Date: Wed, 29 May 2024 10:03:07 +0200 Subject: [PATCH] Add type hints for input check --- pympipool/shared/inputcheck.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/pympipool/shared/inputcheck.py b/pympipool/shared/inputcheck.py index fded2483..11595dfe 100644 --- a/pympipool/shared/inputcheck.py +++ b/pympipool/shared/inputcheck.py @@ -1,7 +1,9 @@ import inspect +from typing import List +from concurrent.futures import Executor -def check_oversubscribe(oversubscribe): +def check_oversubscribe(oversubscribe: bool): if oversubscribe: raise ValueError( "Oversubscribing is not supported for the pympipool.flux.PyFLuxExecutor backend." @@ -9,14 +11,14 @@ def check_oversubscribe(oversubscribe): ) -def check_command_line_argument_lst(command_line_argument_lst): +def check_command_line_argument_lst(command_line_argument_lst: List[str]): if len(command_line_argument_lst) > 0: raise ValueError( "The command_line_argument_lst parameter is not supported for the SLURM backend." ) -def check_gpus_per_worker(gpus_per_worker): +def check_gpus_per_worker(gpus_per_worker: int): if gpus_per_worker != 0: raise TypeError( "GPU assignment is not supported for the pympipool.mpi.PyMPIExecutor backend." @@ -26,7 +28,7 @@ def check_gpus_per_worker(gpus_per_worker): ) -def check_threads_per_core(threads_per_core): +def check_threads_per_core(threads_per_core: int): if threads_per_core != 1: raise TypeError( "Thread based parallelism is not supported for the pympipool.mpi.PyMPIExecutor backend." @@ -36,14 +38,14 @@ def check_threads_per_core(threads_per_core): ) -def check_executor(executor): +def check_executor(executor: Executor): if executor is not None: raise ValueError( "The executor parameter is only supported for the flux framework backend." ) -def check_resource_dict(function): +def check_resource_dict(function: callable): if "resource_dict" in inspect.signature(function).parameters.keys(): raise ValueError( "The parameter resource_dict is used internally in pympipool, " @@ -51,21 +53,21 @@ def check_resource_dict(function): ) -def check_resource_dict_is_empty(resource_dict): +def check_resource_dict_is_empty(resource_dict: dict): if len(resource_dict) > 0: raise ValueError( "When block_allocation is enabled, the resource requirements have to be defined on the executor level." ) -def check_refresh_rate(refresh_rate): +def check_refresh_rate(refresh_rate: float): if refresh_rate != 0.01: raise ValueError( "The sleep_interval parameter is only used when disable_dependencies=False." ) -def check_backend(backend): +def check_backend(backend: str): if backend not in ["auto", "mpi", "slurm", "flux"]: raise ValueError( 'The currently implemented backends are ["flux", "mpi", "slurm"]. ' @@ -75,12 +77,12 @@ def check_backend(backend): ) -def check_init_function(block_allocation, init_function): +def check_init_function(block_allocation: bool, init_function: callable): if not block_allocation and init_function is not None: raise ValueError("") -def validate_number_of_cores(max_cores, max_workers): +def validate_number_of_cores(max_cores: int, max_workers: int) -> int: # only overwrite max_cores when it is set to 1 if max_workers != 1 and max_cores == 1: return max_workers