Skip to content

Commit

Permalink
[hailctl] Autocomplete for hailctl config {get,set,unset} (hail-is#13224
Browse files Browse the repository at this point in the history
)

To get this to work, I had to run the following command:
```
hailctl --install-completion zsh
```

And then add this line to the end of the ~/.zshrc file: `autoload -Uz
compinit && compinit`
  • Loading branch information
jigold authored and Sophie Parsa committed Aug 15, 2023
1 parent 0622d35 commit dafc8c8
Show file tree
Hide file tree
Showing 15 changed files with 324 additions and 72 deletions.
4 changes: 2 additions & 2 deletions hail/python/hail/backend/backend.py
Expand Up @@ -4,7 +4,7 @@
import pkg_resources
import zipfile

from hailtop.config.user_config import configuration_of
from hailtop.config.user_config import unchecked_configuration_of
from hailtop.fs.fs import FS

from ..builtin_references import BUILTIN_REFERENCE_RESOURCE_PATHS
Expand Down Expand Up @@ -197,7 +197,7 @@ def persist_expression(self, expr: Expression) -> Expression:

def _initialize_flags(self, initial_flags: Dict[str, str]) -> None:
self.set_flags(**{
k: configuration_of('query', k, None, default, deprecated_envvar=deprecated_envvar)
k: unchecked_configuration_of('query', k, None, default, deprecated_envvar=deprecated_envvar)
for k, (deprecated_envvar, default) in Backend._flags_env_vars_and_defaults.items()
if k not in initial_flags
}, **initial_flags)
Expand Down
20 changes: 10 additions & 10 deletions hail/python/hail/backend/service_backend.py
Expand Up @@ -19,7 +19,7 @@
from hail.ir.renderer import CSERenderer

from hailtop import yamlx
from hailtop.config import (configuration_of, get_remote_tmpdir)
from hailtop.config import (ConfigVariable, configuration_of, get_remote_tmpdir)
from hailtop.utils import async_to_blocking, secret_alnum_string, TransientError, Timings, am_i_interactive, retry_transient_errors
from hailtop.utils.rich_progress_bar import BatchProgressBar
from hailtop.batch_client import client as hb
Expand Down Expand Up @@ -206,7 +206,7 @@ async def create(*,
token: Optional[str] = None,
regions: Optional[List[str]] = None,
gcs_requester_pays_configuration: Optional[GCSRequesterPaysConfiguration] = None):
billing_project = configuration_of('batch', 'billing_project', billing_project, None)
billing_project = configuration_of(ConfigVariable.BATCH_BILLING_PROJECT, billing_project, None)
if billing_project is None:
raise ValueError(
"No billing project. Call 'init_batch' with the billing "
Expand All @@ -224,17 +224,17 @@ async def create(*,
batch_attributes: Dict[str, str] = dict()
remote_tmpdir = get_remote_tmpdir('ServiceBackend', remote_tmpdir=remote_tmpdir)

jar_url = configuration_of('query', 'jar_url', jar_url, None)
jar_url = configuration_of(ConfigVariable.QUERY_JAR_URL, jar_url, None)
jar_spec = GitRevision(revision()) if jar_url is None else JarUrl(jar_url)

driver_cores = configuration_of('query', 'batch_driver_cores', driver_cores, None)
driver_memory = configuration_of('query', 'batch_driver_memory', driver_memory, None)
worker_cores = configuration_of('query', 'batch_worker_cores', worker_cores, None)
worker_memory = configuration_of('query', 'batch_worker_memory', worker_memory, None)
name_prefix = configuration_of('query', 'name_prefix', name_prefix, '')
driver_cores = configuration_of(ConfigVariable.QUERY_BATCH_DRIVER_CORES, driver_cores, None)
driver_memory = configuration_of(ConfigVariable.QUERY_BATCH_DRIVER_MEMORY, driver_memory, None)
worker_cores = configuration_of(ConfigVariable.QUERY_BATCH_WORKER_CORES, worker_cores, None)
worker_memory = configuration_of(ConfigVariable.QUERY_BATCH_WORKER_MEMORY, worker_memory, None)
name_prefix = configuration_of(ConfigVariable.QUERY_NAME_PREFIX, name_prefix, '')

if regions is None:
regions_from_conf = configuration_of('batch', 'regions', regions, None)
regions_from_conf = configuration_of(ConfigVariable.BATCH_REGIONS, regions, None)
if regions_from_conf is not None:
assert isinstance(regions_from_conf, str)
regions = regions_from_conf.split(',')
Expand All @@ -245,7 +245,7 @@ async def create(*,
assert len(regions) > 0, regions

if disable_progress_bar is None:
disable_progress_bar_str = configuration_of('query', 'disable_progress_bar', None, None)
disable_progress_bar_str = configuration_of(ConfigVariable.QUERY_DISABLE_PROGRESS_BAR, None, None)
if disable_progress_bar_str is None:
disable_progress_bar = not am_i_interactive()
else:
Expand Down
11 changes: 11 additions & 0 deletions hail/python/hail/docs/install/macosx.rst
Expand Up @@ -14,3 +14,14 @@ Install Hail on Mac OS X
- Install Python 3.9 or later. We recommend `Miniconda <https://docs.conda.io/en/latest/miniconda.html#macosx-installers>`__.
- Open Terminal.app and execute ``pip install hail``. If this command fails with a message about "Rust", please try this instead: ``pip install hail --only-binary=:all:``.
- `Run your first Hail query! <try.rst>`__

^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
hailctl Autocompletion (Optional)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

- Install autocompletion with ``hailctl --install-completion zsh``
- Ensure this line is in your zsh config file (~/.zshrc) and then reload your terminal.

.. code-block::
autoload -Uz compinit && compinit
4 changes: 2 additions & 2 deletions hail/python/hail/utils/java.py
Expand Up @@ -2,11 +2,11 @@
import sys
import re

from hailtop.config import configuration_of
from hailtop.config import ConfigVariable, configuration_of


def choose_backend(backend: Optional[str] = None) -> str:
return configuration_of('query', 'backend', backend, 'spark')
return configuration_of(ConfigVariable.QUERY_BACKEND, backend, 'spark')


class HailUserError(Exception):
Expand Down
6 changes: 3 additions & 3 deletions hail/python/hailtop/aiocloud/aiogoogle/user_config.py
Expand Up @@ -6,7 +6,7 @@
from dataclasses import dataclass


from hailtop.config.user_config import configuration_of
from hailtop.config.user_config import ConfigVariable, configuration_of


GCSRequesterPaysConfiguration = Union[str, Tuple[str, List[str]]]
Expand All @@ -19,8 +19,8 @@ def get_gcs_requester_pays_configuration(
if gcs_requester_pays_configuration:
return gcs_requester_pays_configuration

project = configuration_of('gcs_requester_pays', 'project', None, None)
buckets = configuration_of('gcs_requester_pays', 'buckets', None, None)
project = configuration_of(ConfigVariable.GCS_REQUESTER_PAYS_PROJECT, None, None)
buckets = configuration_of(ConfigVariable.GCS_REQUESTER_PAYS_BUCKETS, None, None)

spark_conf = get_spark_conf_gcs_requester_pays_configuration()

Expand Down
6 changes: 3 additions & 3 deletions hail/python/hailtop/batch/backend.py
Expand Up @@ -15,7 +15,7 @@
from rich.progress import track

from hailtop import pip_version
from hailtop.config import configuration_of, get_deploy_config, get_remote_tmpdir
from hailtop.config import ConfigVariable, configuration_of, get_deploy_config, get_remote_tmpdir
from hailtop.utils.rich_progress_bar import SimpleRichProgressBar
from hailtop.utils import parse_docker_image_reference, async_to_blocking, bounded_gather, url_scheme
from hailtop.batch.hail_genetics_images import HAIL_GENETICS_IMAGES, hailgenetics_python_dill_image_for_current_python_version
Expand Down Expand Up @@ -474,7 +474,7 @@ def __init__(self,
warnings.warn('Use of deprecated positional argument \'bucket\' in ServiceBackend(). Specify \'bucket\' as a keyword argument instead.')
bucket = args[1]

billing_project = configuration_of('batch', 'billing_project', billing_project, None)
billing_project = configuration_of(ConfigVariable.BATCH_BILLING_PROJECT, billing_project, None)
if billing_project is None:
raise ValueError(
'the billing_project parameter of ServiceBackend must be set '
Expand All @@ -501,7 +501,7 @@ def __init__(self,
self.__fs: RouterAsyncFS = RouterAsyncFS(gcs_kwargs=gcs_kwargs)

if regions is None:
regions_from_conf = configuration_of('batch', 'regions', None, None)
regions_from_conf = configuration_of(ConfigVariable.BATCH_REGIONS, None, None)
if regions_from_conf is not None:
assert isinstance(regions_from_conf, str)
regions = regions_from_conf.split(',')
Expand Down
4 changes: 2 additions & 2 deletions hail/python/hailtop/batch/batch.py
Expand Up @@ -10,7 +10,7 @@
from hailtop.aiocloud.aioazure.fs import AzureAsyncFS
from hailtop.aiotools.router_fs import RouterAsyncFS
import hailtop.batch_client.client as _bc
from hailtop.config import configuration_of
from hailtop.config import ConfigVariable, configuration_of

from . import backend as _backend, job, resource as _resource # pylint: disable=cyclic-import
from .exceptions import BatchException
Expand Down Expand Up @@ -167,7 +167,7 @@ def __init__(self,
if backend:
self._backend = backend
else:
backend_config = configuration_of('batch', 'backend', None, 'local')
backend_config = configuration_of(ConfigVariable.BATCH_BACKEND, None, 'local')
if backend_config == 'service':
self._backend = _backend.ServiceBackend()
else:
Expand Down
4 changes: 3 additions & 1 deletion hail/python/hailtop/config/__init__.py
@@ -1,12 +1,14 @@
from .user_config import (get_user_config, get_user_config_path,
get_remote_tmpdir, configuration_of)
from .deploy_config import get_deploy_config, DeployConfig
from .variables import ConfigVariable

__all__ = [
'get_deploy_config',
'get_user_config',
'get_user_config_path',
'get_remote_tmpdir',
'DeployConfig',
'configuration_of'
'ConfigVariable',
'configuration_of',
]
32 changes: 22 additions & 10 deletions hail/python/hailtop/config/user_config.py
Expand Up @@ -5,6 +5,8 @@
import warnings
from pathlib import Path

from .variables import ConfigVariable

user_config = None


Expand Down Expand Up @@ -36,15 +38,12 @@ def get_user_config() -> configparser.ConfigParser:
T = TypeVar('T')


def configuration_of(section: str,
option: str,
explicit_argument: Optional[T],
fallback: T,
*,
deprecated_envvar: Optional[str] = None) -> Union[str, T]:
assert VALID_SECTION_AND_OPTION_RE.fullmatch(section), (section, option)
assert VALID_SECTION_AND_OPTION_RE.fullmatch(option), (section, option)

def unchecked_configuration_of(section: str,
option: str,
explicit_argument: Optional[T],
fallback: T,
*,
deprecated_envvar: Optional[str] = None) -> Union[str, T]:
if explicit_argument is not None:
return explicit_argument

Expand All @@ -69,6 +68,19 @@ def configuration_of(section: str,
return fallback


def configuration_of(config_variable: ConfigVariable,
explicit_argument: Optional[T],
fallback: T,
*,
deprecated_envvar: Optional[str] = None) -> Union[str, T]:
if '/' in config_variable.value:
section, option = config_variable.value.split('/')
else:
section = 'global'
option = config_variable.value
return unchecked_configuration_of(section, option, explicit_argument, fallback, deprecated_envvar=deprecated_envvar)


def get_remote_tmpdir(caller_name: str,
*,
bucket: Optional[str] = None,
Expand All @@ -87,7 +99,7 @@ def get_remote_tmpdir(caller_name: str,
raise ValueError(f'Cannot specify both \'remote_tmpdir\' and \'bucket\' in {caller_name}(...). Specify \'remote_tmpdir\' as a keyword argument instead.')

if bucket is None and remote_tmpdir is None:
remote_tmpdir = configuration_of('batch', 'remote_tmpdir', None, None)
remote_tmpdir = configuration_of(ConfigVariable.BATCH_REMOTE_TMPDIR, None, None)

if remote_tmpdir is None:
if bucket is None:
Expand Down
20 changes: 20 additions & 0 deletions hail/python/hailtop/config/variables.py
@@ -0,0 +1,20 @@
from enum import Enum


class ConfigVariable(str, Enum):
DOMAIN = 'domain'
GCS_REQUESTER_PAYS_PROJECT = 'gcs_requester_pays/project'
GCS_REQUESTER_PAYS_BUCKETS = 'gcs_requester_pays/buckets'
BATCH_BUCKET = 'batch/bucket'
BATCH_REMOTE_TMPDIR = 'batch/remote_tmpdir'
BATCH_REGIONS = 'batch/regions'
BATCH_BILLING_PROJECT = 'batch/billing_project'
BATCH_BACKEND = 'batch/backend'
QUERY_BACKEND = 'query/backend'
QUERY_JAR_URL = 'query/jar_url'
QUERY_BATCH_DRIVER_CORES = 'query/batch_driver_cores'
QUERY_BATCH_WORKER_CORES = 'query/batch_worker_cores'
QUERY_BATCH_DRIVER_MEMORY = 'query/batch_driver_memory'
QUERY_BATCH_WORKER_MEMORY = 'query/batch_worker_memory'
QUERY_NAME_PREFIX = 'query/name_prefix'
QUERY_DISABLE_PROGRESS_BAR = 'query/disable_progress_bar'
73 changes: 48 additions & 25 deletions hail/python/hailtop/hailctl/config/cli.py
@@ -1,13 +1,15 @@
import os
import sys
import re
import warnings

from typing import Optional, Tuple, Annotated as Ann
from rich import print

import typer
from typer import Argument as Arg

from hailtop.config.variables import ConfigVariable
from .config_variables import config_variables


app = typer.Typer(
name='config',
Expand All @@ -29,7 +31,7 @@ def get_section_key_path(parameter: str) -> Tuple[str, str, Tuple[str, ...]]:
from the configuration parameter, for example: "batch/billing_project".
Parameters may also have no slashes, indicating the parameter is a global
parameter, for example: "email".
parameter, for example: "domain".
A parameter with more than one slash is invalid, for example:
"batch/billing/project".
Expand All @@ -41,35 +43,32 @@ def get_section_key_path(parameter: str) -> Tuple[str, str, Tuple[str, ...]]:
sys.exit(1)


def complete_config_variable(incomplete: str):
for var, var_info in config_variables().items():
if var.value.startswith(incomplete):
yield (var.value, var_info.help_msg)


@app.command()
def set(parameter: str, value: str):
def set(parameter: Ann[ConfigVariable, Arg(help="Configuration variable to set", autocompletion=complete_config_variable)], value: str):
'''Set a Hail configuration parameter.'''
from hailtop.aiotools.router_fs import RouterAsyncFS # pylint: disable=import-outside-toplevel
from hailtop.config import get_user_config, get_user_config_path # pylint: disable=import-outside-toplevel

config = get_user_config()
config_file = get_user_config_path()
section, key, path = get_section_key_path(parameter)
if parameter not in config_variables():
print(f"Error: unknown parameter {parameter!r}", file=sys.stderr)
sys.exit(1)

validations = {
('batch', 'bucket'): (
lambda x: re.fullmatch(r'[^:/\s]+', x) is not None,
'should be valid Google Bucket identifier, with no gs:// prefix',
),
('batch', 'remote_tmpdir'): (
RouterAsyncFS.valid_url,
'should be valid cloud storage URI such as gs://my-bucket/batch-tmp/',
),
('email',): (lambda x: re.fullmatch(r'.+@.+', x) is not None, 'should be valid email address'),
}
section, key, _ = get_section_key_path(parameter.value)

config_variable_info = config_variables()[parameter]
validation_func, error_msg = config_variable_info.validation

validation_func, msg = validations.get(path, (lambda _: True, '')) # type: ignore
if not validation_func(value):
print(f"Error: bad value {value!r} for parameter {parameter!r} {msg}", file=sys.stderr)
print(f"Error: bad value {value!r} for parameter {parameter!r} {error_msg}", file=sys.stderr)
sys.exit(1)

if path == ('batch', 'bucket'):
warnings.warn("'batch/bucket' has been deprecated. Use 'batch/remote_tmpdir' instead.")
config = get_user_config()
config_file = get_user_config_path()

if section not in config:
config[section] = {}
Expand All @@ -84,8 +83,30 @@ def set(parameter: str, value: str):
config.write(f)


def get_config_variable(incomplete: str):
from hailtop.config import get_user_config # pylint: disable=import-outside-toplevel

config = get_user_config()

elements = []
for section_name, section in config.items():
for item_name, value in section.items():
if section_name == 'global':
path = item_name
else:
path = f'{section_name}/{item_name}'
elements.append((path, value))

config_items = {var.name: var_info.help_msg for var, var_info in config_variables().items()}

for name, _ in elements:
if name.startswith(incomplete):
help_msg = config_items.get(name)
yield (name, help_msg)


@app.command()
def unset(parameter: str):
def unset(parameter: Ann[str, Arg(help="Configuration variable to unset", autocompletion=get_config_variable)]):
'''Unset a Hail configuration parameter (restore to default behavior).'''
from hailtop.config import get_user_config, get_user_config_path # pylint: disable=import-outside-toplevel

Expand All @@ -96,10 +117,12 @@ def unset(parameter: str):
del config[section][key]
with open(config_file, 'w', encoding='utf-8') as f:
config.write(f)
else:
print(f"WARNING: Unknown parameter {parameter!r}", file=sys.stderr)


@app.command()
def get(parameter: str):
def get(parameter: Ann[str, Arg(help="Configuration variable to get", autocompletion=get_config_variable)]):
'''Get the value of a Hail configuration parameter.'''
from hailtop.config import get_user_config # pylint: disable=import-outside-toplevel

Expand Down

0 comments on commit dafc8c8

Please sign in to comment.