From dafc8c895f58cd228256fc58b96d5813ede063ad Mon Sep 17 00:00:00 2001 From: jigold Date: Mon, 14 Aug 2023 16:31:20 -0400 Subject: [PATCH] [hailctl] Autocomplete for hailctl config {get,set,unset} (#13224) To get this to work, I had to run the following command: ``` hailctl --install-completion zsh ``` And then add this line to the end of the ~/.zshrc file: `autoload -Uz compinit && compinit` --- hail/python/hail/backend/backend.py | 4 +- hail/python/hail/backend/service_backend.py | 20 ++-- hail/python/hail/docs/install/macosx.rst | 11 +++ hail/python/hail/utils/java.py | 4 +- .../hailtop/aiocloud/aiogoogle/user_config.py | 6 +- hail/python/hailtop/batch/backend.py | 6 +- hail/python/hailtop/batch/batch.py | 4 +- hail/python/hailtop/config/__init__.py | 4 +- hail/python/hailtop/config/user_config.py | 32 +++++-- hail/python/hailtop/config/variables.py | 20 ++++ hail/python/hailtop/hailctl/config/cli.py | 73 +++++++++----- .../hailctl/config/config_variables.py | 94 +++++++++++++++++++ .../test/hailtop/hailctl/config/conftest.py | 32 +++++++ .../test/hailtop/hailctl/config/test_cli.py | 77 +++++++++++++-- hail/scripts/test_requester_pays_parsing.py | 9 +- 15 files changed, 324 insertions(+), 72 deletions(-) create mode 100644 hail/python/hailtop/config/variables.py create mode 100644 hail/python/hailtop/hailctl/config/config_variables.py diff --git a/hail/python/hail/backend/backend.py b/hail/python/hail/backend/backend.py index 2a93ad631b16..d193b2fb8774 100644 --- a/hail/python/hail/backend/backend.py +++ b/hail/python/hail/backend/backend.py @@ -4,7 +4,7 @@ import pkg_resources import zipfile -from hailtop.config.user_config import configuration_of +from hailtop.config.user_config import unchecked_configuration_of from hailtop.fs.fs import FS from ..builtin_references import BUILTIN_REFERENCE_RESOURCE_PATHS @@ -197,7 +197,7 @@ def persist_expression(self, expr: Expression) -> Expression: def _initialize_flags(self, initial_flags: Dict[str, str]) -> None: self.set_flags(**{ - k: configuration_of('query', k, None, default, deprecated_envvar=deprecated_envvar) + k: unchecked_configuration_of('query', k, None, default, deprecated_envvar=deprecated_envvar) for k, (deprecated_envvar, default) in Backend._flags_env_vars_and_defaults.items() if k not in initial_flags }, **initial_flags) diff --git a/hail/python/hail/backend/service_backend.py b/hail/python/hail/backend/service_backend.py index 4c7d9aa8c8c7..66a7823c0397 100644 --- a/hail/python/hail/backend/service_backend.py +++ b/hail/python/hail/backend/service_backend.py @@ -19,7 +19,7 @@ from hail.ir.renderer import CSERenderer from hailtop import yamlx -from hailtop.config import (configuration_of, get_remote_tmpdir) +from hailtop.config import (ConfigVariable, configuration_of, get_remote_tmpdir) from hailtop.utils import async_to_blocking, secret_alnum_string, TransientError, Timings, am_i_interactive, retry_transient_errors from hailtop.utils.rich_progress_bar import BatchProgressBar from hailtop.batch_client import client as hb @@ -206,7 +206,7 @@ async def create(*, token: Optional[str] = None, regions: Optional[List[str]] = None, gcs_requester_pays_configuration: Optional[GCSRequesterPaysConfiguration] = None): - billing_project = configuration_of('batch', 'billing_project', billing_project, None) + billing_project = configuration_of(ConfigVariable.BATCH_BILLING_PROJECT, billing_project, None) if billing_project is None: raise ValueError( "No billing project. Call 'init_batch' with the billing " @@ -224,17 +224,17 @@ async def create(*, batch_attributes: Dict[str, str] = dict() remote_tmpdir = get_remote_tmpdir('ServiceBackend', remote_tmpdir=remote_tmpdir) - jar_url = configuration_of('query', 'jar_url', jar_url, None) + jar_url = configuration_of(ConfigVariable.QUERY_JAR_URL, jar_url, None) jar_spec = GitRevision(revision()) if jar_url is None else JarUrl(jar_url) - driver_cores = configuration_of('query', 'batch_driver_cores', driver_cores, None) - driver_memory = configuration_of('query', 'batch_driver_memory', driver_memory, None) - worker_cores = configuration_of('query', 'batch_worker_cores', worker_cores, None) - worker_memory = configuration_of('query', 'batch_worker_memory', worker_memory, None) - name_prefix = configuration_of('query', 'name_prefix', name_prefix, '') + driver_cores = configuration_of(ConfigVariable.QUERY_BATCH_DRIVER_CORES, driver_cores, None) + driver_memory = configuration_of(ConfigVariable.QUERY_BATCH_DRIVER_MEMORY, driver_memory, None) + worker_cores = configuration_of(ConfigVariable.QUERY_BATCH_WORKER_CORES, worker_cores, None) + worker_memory = configuration_of(ConfigVariable.QUERY_BATCH_WORKER_MEMORY, worker_memory, None) + name_prefix = configuration_of(ConfigVariable.QUERY_NAME_PREFIX, name_prefix, '') if regions is None: - regions_from_conf = configuration_of('batch', 'regions', regions, None) + regions_from_conf = configuration_of(ConfigVariable.BATCH_REGIONS, regions, None) if regions_from_conf is not None: assert isinstance(regions_from_conf, str) regions = regions_from_conf.split(',') @@ -245,7 +245,7 @@ async def create(*, assert len(regions) > 0, regions if disable_progress_bar is None: - disable_progress_bar_str = configuration_of('query', 'disable_progress_bar', None, None) + disable_progress_bar_str = configuration_of(ConfigVariable.QUERY_DISABLE_PROGRESS_BAR, None, None) if disable_progress_bar_str is None: disable_progress_bar = not am_i_interactive() else: diff --git a/hail/python/hail/docs/install/macosx.rst b/hail/python/hail/docs/install/macosx.rst index 5eb3f009e6b6..8b9c8d6b9dca 100644 --- a/hail/python/hail/docs/install/macosx.rst +++ b/hail/python/hail/docs/install/macosx.rst @@ -14,3 +14,14 @@ Install Hail on Mac OS X - Install Python 3.9 or later. We recommend `Miniconda `__. - Open Terminal.app and execute ``pip install hail``. If this command fails with a message about "Rust", please try this instead: ``pip install hail --only-binary=:all:``. - `Run your first Hail query! `__ + +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +hailctl Autocompletion (Optional) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +- Install autocompletion with ``hailctl --install-completion zsh`` +- Ensure this line is in your zsh config file (~/.zshrc) and then reload your terminal. + + .. code-block:: + + autoload -Uz compinit && compinit diff --git a/hail/python/hail/utils/java.py b/hail/python/hail/utils/java.py index a311884f99c4..3f19d0d6e771 100644 --- a/hail/python/hail/utils/java.py +++ b/hail/python/hail/utils/java.py @@ -2,11 +2,11 @@ import sys import re -from hailtop.config import configuration_of +from hailtop.config import ConfigVariable, configuration_of def choose_backend(backend: Optional[str] = None) -> str: - return configuration_of('query', 'backend', backend, 'spark') + return configuration_of(ConfigVariable.QUERY_BACKEND, backend, 'spark') class HailUserError(Exception): diff --git a/hail/python/hailtop/aiocloud/aiogoogle/user_config.py b/hail/python/hailtop/aiocloud/aiogoogle/user_config.py index b6a713a91e53..53c4e832de37 100644 --- a/hail/python/hailtop/aiocloud/aiogoogle/user_config.py +++ b/hail/python/hailtop/aiocloud/aiogoogle/user_config.py @@ -6,7 +6,7 @@ from dataclasses import dataclass -from hailtop.config.user_config import configuration_of +from hailtop.config.user_config import ConfigVariable, configuration_of GCSRequesterPaysConfiguration = Union[str, Tuple[str, List[str]]] @@ -19,8 +19,8 @@ def get_gcs_requester_pays_configuration( if gcs_requester_pays_configuration: return gcs_requester_pays_configuration - project = configuration_of('gcs_requester_pays', 'project', None, None) - buckets = configuration_of('gcs_requester_pays', 'buckets', None, None) + project = configuration_of(ConfigVariable.GCS_REQUESTER_PAYS_PROJECT, None, None) + buckets = configuration_of(ConfigVariable.GCS_REQUESTER_PAYS_BUCKETS, None, None) spark_conf = get_spark_conf_gcs_requester_pays_configuration() diff --git a/hail/python/hailtop/batch/backend.py b/hail/python/hailtop/batch/backend.py index 175aea1eb2c8..ce066f00a9f9 100644 --- a/hail/python/hailtop/batch/backend.py +++ b/hail/python/hailtop/batch/backend.py @@ -15,7 +15,7 @@ from rich.progress import track from hailtop import pip_version -from hailtop.config import configuration_of, get_deploy_config, get_remote_tmpdir +from hailtop.config import ConfigVariable, configuration_of, get_deploy_config, get_remote_tmpdir from hailtop.utils.rich_progress_bar import SimpleRichProgressBar from hailtop.utils import parse_docker_image_reference, async_to_blocking, bounded_gather, url_scheme from hailtop.batch.hail_genetics_images import HAIL_GENETICS_IMAGES, hailgenetics_python_dill_image_for_current_python_version @@ -474,7 +474,7 @@ def __init__(self, warnings.warn('Use of deprecated positional argument \'bucket\' in ServiceBackend(). Specify \'bucket\' as a keyword argument instead.') bucket = args[1] - billing_project = configuration_of('batch', 'billing_project', billing_project, None) + billing_project = configuration_of(ConfigVariable.BATCH_BILLING_PROJECT, billing_project, None) if billing_project is None: raise ValueError( 'the billing_project parameter of ServiceBackend must be set ' @@ -501,7 +501,7 @@ def __init__(self, self.__fs: RouterAsyncFS = RouterAsyncFS(gcs_kwargs=gcs_kwargs) if regions is None: - regions_from_conf = configuration_of('batch', 'regions', None, None) + regions_from_conf = configuration_of(ConfigVariable.BATCH_REGIONS, None, None) if regions_from_conf is not None: assert isinstance(regions_from_conf, str) regions = regions_from_conf.split(',') diff --git a/hail/python/hailtop/batch/batch.py b/hail/python/hailtop/batch/batch.py index e4d63eac114e..8f5dba433b47 100644 --- a/hail/python/hailtop/batch/batch.py +++ b/hail/python/hailtop/batch/batch.py @@ -10,7 +10,7 @@ from hailtop.aiocloud.aioazure.fs import AzureAsyncFS from hailtop.aiotools.router_fs import RouterAsyncFS import hailtop.batch_client.client as _bc -from hailtop.config import configuration_of +from hailtop.config import ConfigVariable, configuration_of from . import backend as _backend, job, resource as _resource # pylint: disable=cyclic-import from .exceptions import BatchException @@ -167,7 +167,7 @@ def __init__(self, if backend: self._backend = backend else: - backend_config = configuration_of('batch', 'backend', None, 'local') + backend_config = configuration_of(ConfigVariable.BATCH_BACKEND, None, 'local') if backend_config == 'service': self._backend = _backend.ServiceBackend() else: diff --git a/hail/python/hailtop/config/__init__.py b/hail/python/hailtop/config/__init__.py index 4f5bc19a168f..e81036748452 100644 --- a/hail/python/hailtop/config/__init__.py +++ b/hail/python/hailtop/config/__init__.py @@ -1,6 +1,7 @@ from .user_config import (get_user_config, get_user_config_path, get_remote_tmpdir, configuration_of) from .deploy_config import get_deploy_config, DeployConfig +from .variables import ConfigVariable __all__ = [ 'get_deploy_config', @@ -8,5 +9,6 @@ 'get_user_config_path', 'get_remote_tmpdir', 'DeployConfig', - 'configuration_of' + 'ConfigVariable', + 'configuration_of', ] diff --git a/hail/python/hailtop/config/user_config.py b/hail/python/hailtop/config/user_config.py index b54c075931ba..fcd8c072faa3 100644 --- a/hail/python/hailtop/config/user_config.py +++ b/hail/python/hailtop/config/user_config.py @@ -5,6 +5,8 @@ import warnings from pathlib import Path +from .variables import ConfigVariable + user_config = None @@ -36,15 +38,12 @@ def get_user_config() -> configparser.ConfigParser: T = TypeVar('T') -def configuration_of(section: str, - option: str, - explicit_argument: Optional[T], - fallback: T, - *, - deprecated_envvar: Optional[str] = None) -> Union[str, T]: - assert VALID_SECTION_AND_OPTION_RE.fullmatch(section), (section, option) - assert VALID_SECTION_AND_OPTION_RE.fullmatch(option), (section, option) - +def unchecked_configuration_of(section: str, + option: str, + explicit_argument: Optional[T], + fallback: T, + *, + deprecated_envvar: Optional[str] = None) -> Union[str, T]: if explicit_argument is not None: return explicit_argument @@ -69,6 +68,19 @@ def configuration_of(section: str, return fallback +def configuration_of(config_variable: ConfigVariable, + explicit_argument: Optional[T], + fallback: T, + *, + deprecated_envvar: Optional[str] = None) -> Union[str, T]: + if '/' in config_variable.value: + section, option = config_variable.value.split('/') + else: + section = 'global' + option = config_variable.value + return unchecked_configuration_of(section, option, explicit_argument, fallback, deprecated_envvar=deprecated_envvar) + + def get_remote_tmpdir(caller_name: str, *, bucket: Optional[str] = None, @@ -87,7 +99,7 @@ def get_remote_tmpdir(caller_name: str, raise ValueError(f'Cannot specify both \'remote_tmpdir\' and \'bucket\' in {caller_name}(...). Specify \'remote_tmpdir\' as a keyword argument instead.') if bucket is None and remote_tmpdir is None: - remote_tmpdir = configuration_of('batch', 'remote_tmpdir', None, None) + remote_tmpdir = configuration_of(ConfigVariable.BATCH_REMOTE_TMPDIR, None, None) if remote_tmpdir is None: if bucket is None: diff --git a/hail/python/hailtop/config/variables.py b/hail/python/hailtop/config/variables.py new file mode 100644 index 000000000000..3797679b13c0 --- /dev/null +++ b/hail/python/hailtop/config/variables.py @@ -0,0 +1,20 @@ +from enum import Enum + + +class ConfigVariable(str, Enum): + DOMAIN = 'domain' + GCS_REQUESTER_PAYS_PROJECT = 'gcs_requester_pays/project' + GCS_REQUESTER_PAYS_BUCKETS = 'gcs_requester_pays/buckets' + BATCH_BUCKET = 'batch/bucket' + BATCH_REMOTE_TMPDIR = 'batch/remote_tmpdir' + BATCH_REGIONS = 'batch/regions' + BATCH_BILLING_PROJECT = 'batch/billing_project' + BATCH_BACKEND = 'batch/backend' + QUERY_BACKEND = 'query/backend' + QUERY_JAR_URL = 'query/jar_url' + QUERY_BATCH_DRIVER_CORES = 'query/batch_driver_cores' + QUERY_BATCH_WORKER_CORES = 'query/batch_worker_cores' + QUERY_BATCH_DRIVER_MEMORY = 'query/batch_driver_memory' + QUERY_BATCH_WORKER_MEMORY = 'query/batch_worker_memory' + QUERY_NAME_PREFIX = 'query/name_prefix' + QUERY_DISABLE_PROGRESS_BAR = 'query/disable_progress_bar' diff --git a/hail/python/hailtop/hailctl/config/cli.py b/hail/python/hailtop/hailctl/config/cli.py index 273a078fcd78..b55039f4f310 100644 --- a/hail/python/hailtop/hailctl/config/cli.py +++ b/hail/python/hailtop/hailctl/config/cli.py @@ -1,13 +1,15 @@ import os import sys -import re -import warnings from typing import Optional, Tuple, Annotated as Ann +from rich import print import typer from typer import Argument as Arg +from hailtop.config.variables import ConfigVariable +from .config_variables import config_variables + app = typer.Typer( name='config', @@ -29,7 +31,7 @@ def get_section_key_path(parameter: str) -> Tuple[str, str, Tuple[str, ...]]: from the configuration parameter, for example: "batch/billing_project". Parameters may also have no slashes, indicating the parameter is a global -parameter, for example: "email". +parameter, for example: "domain". A parameter with more than one slash is invalid, for example: "batch/billing/project". @@ -41,35 +43,32 @@ def get_section_key_path(parameter: str) -> Tuple[str, str, Tuple[str, ...]]: sys.exit(1) +def complete_config_variable(incomplete: str): + for var, var_info in config_variables().items(): + if var.value.startswith(incomplete): + yield (var.value, var_info.help_msg) + + @app.command() -def set(parameter: str, value: str): +def set(parameter: Ann[ConfigVariable, Arg(help="Configuration variable to set", autocompletion=complete_config_variable)], value: str): '''Set a Hail configuration parameter.''' - from hailtop.aiotools.router_fs import RouterAsyncFS # pylint: disable=import-outside-toplevel from hailtop.config import get_user_config, get_user_config_path # pylint: disable=import-outside-toplevel - config = get_user_config() - config_file = get_user_config_path() - section, key, path = get_section_key_path(parameter) + if parameter not in config_variables(): + print(f"Error: unknown parameter {parameter!r}", file=sys.stderr) + sys.exit(1) - validations = { - ('batch', 'bucket'): ( - lambda x: re.fullmatch(r'[^:/\s]+', x) is not None, - 'should be valid Google Bucket identifier, with no gs:// prefix', - ), - ('batch', 'remote_tmpdir'): ( - RouterAsyncFS.valid_url, - 'should be valid cloud storage URI such as gs://my-bucket/batch-tmp/', - ), - ('email',): (lambda x: re.fullmatch(r'.+@.+', x) is not None, 'should be valid email address'), - } + section, key, _ = get_section_key_path(parameter.value) + + config_variable_info = config_variables()[parameter] + validation_func, error_msg = config_variable_info.validation - validation_func, msg = validations.get(path, (lambda _: True, '')) # type: ignore if not validation_func(value): - print(f"Error: bad value {value!r} for parameter {parameter!r} {msg}", file=sys.stderr) + print(f"Error: bad value {value!r} for parameter {parameter!r} {error_msg}", file=sys.stderr) sys.exit(1) - if path == ('batch', 'bucket'): - warnings.warn("'batch/bucket' has been deprecated. Use 'batch/remote_tmpdir' instead.") + config = get_user_config() + config_file = get_user_config_path() if section not in config: config[section] = {} @@ -84,8 +83,30 @@ def set(parameter: str, value: str): config.write(f) +def get_config_variable(incomplete: str): + from hailtop.config import get_user_config # pylint: disable=import-outside-toplevel + + config = get_user_config() + + elements = [] + for section_name, section in config.items(): + for item_name, value in section.items(): + if section_name == 'global': + path = item_name + else: + path = f'{section_name}/{item_name}' + elements.append((path, value)) + + config_items = {var.name: var_info.help_msg for var, var_info in config_variables().items()} + + for name, _ in elements: + if name.startswith(incomplete): + help_msg = config_items.get(name) + yield (name, help_msg) + + @app.command() -def unset(parameter: str): +def unset(parameter: Ann[str, Arg(help="Configuration variable to unset", autocompletion=get_config_variable)]): '''Unset a Hail configuration parameter (restore to default behavior).''' from hailtop.config import get_user_config, get_user_config_path # pylint: disable=import-outside-toplevel @@ -96,10 +117,12 @@ def unset(parameter: str): del config[section][key] with open(config_file, 'w', encoding='utf-8') as f: config.write(f) + else: + print(f"WARNING: Unknown parameter {parameter!r}", file=sys.stderr) @app.command() -def get(parameter: str): +def get(parameter: Ann[str, Arg(help="Configuration variable to get", autocompletion=get_config_variable)]): '''Get the value of a Hail configuration parameter.''' from hailtop.config import get_user_config # pylint: disable=import-outside-toplevel diff --git a/hail/python/hailtop/hailctl/config/config_variables.py b/hail/python/hailtop/hailctl/config/config_variables.py new file mode 100644 index 000000000000..8273f8b126a5 --- /dev/null +++ b/hail/python/hailtop/hailctl/config/config_variables.py @@ -0,0 +1,94 @@ +from collections import namedtuple +import re + +from hailtop.config import ConfigVariable + + +_config_variables = None + +ConfigVariableInfo = namedtuple('ConfigVariableInfo', ['help_msg', 'validation']) + + +def config_variables(): + from hailtop.batch_client.parse import CPU_REGEXPAT, MEMORY_REGEXPAT # pylint: disable=import-outside-toplevel + from hailtop.fs.router_fs import RouterAsyncFS # pylint: disable=import-outside-toplevel + + global _config_variables + + if _config_variables is None: + _config_variables = { + ConfigVariable.DOMAIN: ConfigVariableInfo( + help_msg='Domain of the Batch service', + validation=(lambda x: re.fullmatch(r'.+\..+', x) is not None, 'should be valid domain'), + ), + ConfigVariable.GCS_REQUESTER_PAYS_PROJECT: ConfigVariableInfo( + help_msg='Project when using requester pays buckets in GCS', + validation=(lambda x: re.fullmatch(r'[^:/\s]+', x) is not None, 'should be valid GCS project name'), + ), + ConfigVariable.GCS_REQUESTER_PAYS_BUCKETS: ConfigVariableInfo( + help_msg='Allowed buckets when using requester pays in GCS', + validation=( + lambda x: re.fullmatch(r'[^:/\s]+(,[^:/\s]+)*', x) is not None, + 'should be comma separated list of bucket names'), + ), + ConfigVariable.BATCH_BUCKET: ConfigVariableInfo( + help_msg='Deprecated - Name of GCS bucket to use as a temporary scratch directory', + validation=(lambda x: re.fullmatch(r'[^:/\s]+', x) is not None, + 'should be valid Google Bucket identifier, with no gs:// prefix'), + ), + ConfigVariable.BATCH_REMOTE_TMPDIR: ConfigVariableInfo( + help_msg='Cloud storage URI to use as a temporary scratch directory', + validation=(RouterAsyncFS.valid_url, 'should be valid cloud storage URI such as gs://my-bucket/batch-tmp/'), + ), + ConfigVariable.BATCH_REGIONS: ConfigVariableInfo( + help_msg='Comma-separated list of regions to run jobs in', + validation=( + lambda x: re.fullmatch(r'[^\s]+(,[^\s]+)*', x) is not None, 'should be comma separated list of regions'), + ), + ConfigVariable.BATCH_BILLING_PROJECT: ConfigVariableInfo( + help_msg='Batch billing project', + validation=(lambda x: re.fullmatch(r'[^:/\s]+', x) is not None, 'should be valid Batch billing project name'), + ), + ConfigVariable.BATCH_BACKEND: ConfigVariableInfo( + help_msg='Backend to use. One of local or service.', + validation=(lambda x: x in ('local', 'service'), 'should be one of "local" or "service"'), + ), + ConfigVariable.QUERY_BACKEND: ConfigVariableInfo( + help_msg='Backend to use for Hail Query. One of spark, local, batch.', + validation=(lambda x: x in ('local', 'spark', 'batch'), 'should be one of "local", "spark", or "batch"'), + ), + ConfigVariable.QUERY_JAR_URL: ConfigVariableInfo( + help_msg='Cloud storage URI to a Query JAR', + validation=(RouterAsyncFS.valid_url, 'should be valid cloud storage URI such as gs://my-bucket/jars/sha.jar') + ), + ConfigVariable.QUERY_BATCH_DRIVER_CORES: ConfigVariableInfo( + help_msg='Cores specification for the query driver', + validation=(lambda x: re.fullmatch(CPU_REGEXPAT, x) is not None, + 'should be an integer which is a power of two from 1 to 16 inclusive'), + ), + ConfigVariable.QUERY_BATCH_WORKER_CORES: ConfigVariableInfo( + help_msg='Cores specification for the query worker', + validation=(lambda x: re.fullmatch(CPU_REGEXPAT, x) is not None, + 'should be an integer which is a power of two from 1 to 16 inclusive'), + ), + ConfigVariable.QUERY_BATCH_DRIVER_MEMORY: ConfigVariableInfo( + help_msg='Memory specification for the query driver', + validation=(lambda x: re.fullmatch(MEMORY_REGEXPAT, x) is not None or x in ('standard', 'lowmem', 'highmem'), + 'should be a valid string specifying memory "[+]?((?:[0-9]*[.])?[0-9]+)([KMGTP][i]?)?B?" or one of standard, lowmem, highmem'), + ), + ConfigVariable.QUERY_BATCH_WORKER_MEMORY: ConfigVariableInfo( + help_msg='Memory specification for the query worker', + validation=(lambda x: re.fullmatch(MEMORY_REGEXPAT, x) is not None or x in ('standard', 'lowmem', 'highmem'), + 'should be a valid string specifying memory "[+]?((?:[0-9]*[.])?[0-9]+)([KMGTP][i]?)?B?" or one of standard, lowmem, highmem'), + ), + ConfigVariable.QUERY_NAME_PREFIX: ConfigVariableInfo( + help_msg='Name used when displaying query progress in a progress bar', + validation=(lambda x: re.fullmatch(r'[^\s]+', x) is not None, 'should be single word without spaces'), + ), + ConfigVariable.QUERY_DISABLE_PROGRESS_BAR: ConfigVariableInfo( + help_msg='Disable the progress bar with a value of 1. Enable the progress bar with a value of 0', + validation=(lambda x: x in ('0', '1'), 'should be a value of 0 or 1'), + ), + } + + return _config_variables diff --git a/hail/python/test/hailtop/hailctl/config/conftest.py b/hail/python/test/hailtop/hailctl/config/conftest.py index 95ca40e0f4dd..8818cfa3e0c8 100644 --- a/hail/python/test/hailtop/hailctl/config/conftest.py +++ b/hail/python/test/hailtop/hailctl/config/conftest.py @@ -1,3 +1,4 @@ +import os import pytest import tempfile @@ -13,3 +14,34 @@ def config_dir(): @pytest.fixture def runner(config_dir): yield CliRunner(mix_stderr=False, env={'XDG_CONFIG_HOME': config_dir}) + + +@pytest.fixture +def bc_runner(config_dir): + from hailtop.config import get_user_config, get_user_config_path # pylint: disable=import-outside-toplevel + + # necessary for backwards compatibility test + os.environ['XDG_CONFIG_HOME'] = config_dir + + config = get_user_config() + config_file = get_user_config_path() + + items = [ + ('global', 'email', 'johndoe@gmail.com'), + ('batch', 'foo', '5') + ] + + for section, key, value in items: + if section not in config: + config[section] = {} + config[section][key] = value + + try: + f = open(config_file, 'w', encoding='utf-8') + except FileNotFoundError: + os.makedirs(config_file.parent, exist_ok=True) + f = open(config_file, 'w', encoding='utf-8') + with f: + config.write(f) + + yield CliRunner(mix_stderr=False, env={'XDG_CONFIG_HOME': config_dir}) diff --git a/hail/python/test/hailtop/hailctl/config/test_cli.py b/hail/python/test/hailtop/hailctl/config/test_cli.py index d5f07937d615..3a595aeb81cd 100644 --- a/hail/python/test/hailtop/hailctl/config/test_cli.py +++ b/hail/python/test/hailtop/hailctl/config/test_cli.py @@ -2,7 +2,8 @@ from typer.testing import CliRunner -from hailtop.hailctl.config import cli +from hailtop.config.variables import ConfigVariable +from hailtop.hailctl.config import cli, config_variables def test_config_location(runner: CliRunner, config_dir: str): @@ -20,42 +21,98 @@ def test_config_list_empty_config(runner: CliRunner): @pytest.mark.parametrize( 'name,value', [ - ('batch/backend', 'batch'), + ('domain', 'azure.hail.is'), + ('gcs_requester_pays/project', 'hail-vdc'), + ('gcs_requester_pays/buckets', 'hail,foo'), + ('batch/backend', 'service'), ('batch/billing_project', 'test'), + ('batch/regions', 'us-central1,us-east1'), ('batch/remote_tmpdir', 'gs://foo/bar'), ('query/backend', 'spark'), - - # hailctl currently accepts arbitrary settings - ('foo/bar', 'baz'), + ('query/jar_url', 'gs://foo/bar.jar'), + ('query/batch_driver_cores', '1'), + ('query/batch_worker_cores', '1'), + ('query/batch_driver_memory', '1Gi'), + ('query/batch_worker_memory', 'standard'), + ('query/name_prefix', 'foo'), + ('query/disable_progress_bar', '1'), ], ) -def test_config_set(name: str, value: str, runner: CliRunner): +def test_config_set_get_list_unset(name: str, value: str, runner: CliRunner): runner.invoke(cli.app, ['set', name, value], catch_exceptions=False) res = runner.invoke(cli.app, 'list', catch_exceptions=False) assert res.exit_code == 0 + if '/' not in name: + name = f'global/{name}' assert res.stdout.strip() == f'{name}={value}' res = runner.invoke(cli.app, ['get', name], catch_exceptions=False) assert res.exit_code == 0 assert res.stdout.strip() == value + res = runner.invoke(cli.app, ['unset', name], catch_exceptions=False) + assert res.exit_code == 0 -def test_config_get_bad_names(runner: CliRunner): - res = runner.invoke(cli.app, ['get', 'foo'], catch_exceptions=False) + res = runner.invoke(cli.app, 'list', catch_exceptions=False) assert res.exit_code == 0 assert res.stdout.strip() == '' - res = runner.invoke(cli.app, ['get', '/a/b/c'], catch_exceptions=False) - assert res.exit_code == 1 + +# backwards compatibility +def test_config_get_unknown_names(bc_runner: CliRunner): + res = bc_runner.invoke(cli.app, ['get', 'email'], catch_exceptions=False) + assert res.exit_code == 0 + assert res.stdout.strip() == 'johndoe@gmail.com' + + res = bc_runner.invoke(cli.app, ['get', 'batch/foo'], catch_exceptions=False) + assert res.exit_code == 0 + assert res.stdout.strip() == '5' @pytest.mark.parametrize( 'name,value', [ + ('foo/bar', 'baz'), + ], +) +def test_config_set_unknown_name(name: str, value: str, runner: CliRunner): + res = runner.invoke(cli.app, ['set', name, value], catch_exceptions=False) + assert res.exit_code == 2 + + +def test_config_unset_unknown_name(runner: CliRunner): + # backwards compatibility + res = runner.invoke(cli.app, ['unset', 'foo'], catch_exceptions=False) + assert res.exit_code == 0 + + res = runner.invoke(cli.app, ['unset', 'foo/bar'], catch_exceptions=False) + assert res.exit_code == 0 + + +@pytest.mark.parametrize( + 'name,value', + [ + ('domain', 'foo'), + ('gcs_requester_pays/project', 'gs://foo/bar'), + ('gcs_requester_pays/buckets', 'gs://foo/bar'), + ('batch/backend', 'foo'), + ('batch/billing_project', 'gs://foo/bar'), ('batch/remote_tmpdir', 'asdf://foo/bar'), + ('query/backend', 'random_backend'), + ('query/jar_url', 'bar://foo/bar.jar'), + ('query/batch_driver_cores', 'a'), + ('query/batch_worker_cores', 'b'), + ('query/batch_driver_memory', '1bar'), + ('query/batch_worker_memory', 'random'), + ('query/disable_progress_bar', '2'), ], ) def test_config_set_bad_value(name: str, value: str, runner: CliRunner): res = runner.invoke(cli.app, ['set', name, value], catch_exceptions=False) assert res.exit_code == 1 + + +def test_all_config_variables_in_map(): + for variable in ConfigVariable: + assert variable in config_variables.config_variables() diff --git a/hail/scripts/test_requester_pays_parsing.py b/hail/scripts/test_requester_pays_parsing.py index 4bdc9c2a6a25..d4d11e7699e7 100644 --- a/hail/scripts/test_requester_pays_parsing.py +++ b/hail/scripts/test_requester_pays_parsing.py @@ -4,8 +4,8 @@ from hailtop.aiocloud.aiogoogle import get_gcs_requester_pays_configuration from hailtop.aiocloud.aiogoogle.user_config import spark_conf_path, get_spark_conf_gcs_requester_pays_configuration -from hailtop.utils.process import check_exec_output -from hailtop.config.user_config import configuration_of +from hailtop.utils.process import CalledProcessError, check_exec_output +from hailtop.config.user_config import ConfigVariable, configuration_of if 'YOU_MAY_OVERWRITE_MY_SPARK_DEFAULTS_CONF_AND_HAILCTL_SETTINGS' not in os.environ: @@ -15,6 +15,7 @@ SPARK_CONF_PATH = spark_conf_path() + async def unset_hailctl(): await check_exec_output( 'hailctl', @@ -141,8 +142,8 @@ async def test_hailctl_takes_precedence_1(): actual = get_gcs_requester_pays_configuration() assert actual == 'hailctl_project', str(( - configuration_of('gcs_requester_pays', 'project', None, None), - configuration_of('gcs_requester_pays', 'buckets', None, None), + configuration_of(ConfigVariable.GCS_REQUESTER_PAYS_PROJECT, None, None), + configuration_of(ConfigVariable.GCS_REQUESTER_PAYS_BUCKETS, None, None), get_spark_conf_gcs_requester_pays_configuration(), open('/Users/dking/.config/hail/config.ini', 'r').readlines() ))