Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
274 changes: 71 additions & 203 deletions src/hpc_connect/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,139 +5,44 @@
import logging
import math
import os
import shlex
import shutil
import sys
from collections.abc import ValuesView
from functools import cached_property
from typing import IO
from typing import Any

import pluggy
import psutil
import schema
import yaml
from schema import Optional
from schema import Or
from schema import Schema
from schema import Use

from .discover import default_resource_set
from .pluginmanager import HPCConnectPluginManager
from .schemas import config_schema
from .schemas import environment_variable_schema
from .schemas import launch_schema
from .schemas import machine_schema
from .schemas import submit_schema
from .util import collections
from .util import safe_loads
from .util.string import strip_quotes

logger = logging.getLogger("hpc_connect")


def flag_splitter(arg: list[str] | str) -> list[str]:
if isinstance(arg, str):
return shlex.split(arg)
elif not isinstance(arg, list) and not all(isinstance(str, _) for _ in arg):
raise ValueError("expected list[str]")
return arg


def dict_str_str(arg: Any) -> bool:
f = isinstance
return f(arg, dict) and all([f(_, str) for k, v in arg.items() for _ in (k, v)])


class choose_from:
def __init__(self, *choices: str | None):
self.choices = set(choices)

def __call__(self, arg: str | None) -> str | None:
if arg not in self.choices:
raise ValueError(f"Invalid choice {arg!r}, choose from {self.choices!r}")
return arg


def which(arg: str) -> str:
if path := shutil.which(arg):
return path
logger.debug(f"{arg} not found on PATH")
return arg


# Resource spec have the following form:
# machine:
# resources:
# - type: node
# count: node_count
# resources:
# - type: socket
# count: sockets_per_node
# resources:
# - type: resource_name (like cpus)
# count: type_per_socket
# additional_properties: (optional)
# - type: slots
# count: 1

resource_spec = {
"type": "node",
"count": int,
Optional("additional_properties"): Or(dict, None),
"resources": [
{
"type": str,
"count": int,
Optional("additional_properties"): Or(dict, None),
Optional("resources"): [
{
"type": str,
"count": int,
Optional("additional_properties"): Or(dict, None),
},
],
},
],
section_schemas: dict[str, schema.Schema] = {
"config": config_schema,
"machine": machine_schema,
"submit": submit_schema,
"launch": launch_schema,
}


launch_spec = {
Optional("numproc_flag"): str,
Optional("default_options"): Use(flag_splitter),
Optional("local_options"): Use(flag_splitter),
Optional("pre_options"): Use(flag_splitter),
Optional("mappings"): dict_str_str,
}

schema = Schema(
{
"hpc_connect": {
Optional("config"): {
Optional("debug"): bool,
},
Optional("submit"): {
Optional("backend"): Use(
choose_from(None, "shell", "slurm", "sbatch", "pbs", "qsub", "flux")
),
Optional("default_options"): Use(flag_splitter),
Optional(str): {
Optional("default_options"): Use(flag_splitter),
},
},
Optional("machine"): {
Optional("resources"): Or([resource_spec], None),
},
Optional("launch"): {
Optional("exec"): Use(which),
**launch_spec,
Optional(str): launch_spec,
},
}
},
ignore_extra_keys=True,
description="HPC connect configuration schema",
)


class ConfigScope:
def __init__(self, name: str, file: str | None, data: dict[str, Any]) -> None:
self.name = name
self.file = file
self.data = schema.validate({"hpc_connect": data})["hpc_connect"]
self.data: dict[str, Any] = {}
for section, data in data.items():
schema = section_schemas[section]
self.data[section] = schema.validate(data)

def __repr__(self):
file = self.file or "<none>"
Expand All @@ -151,44 +56,50 @@ def __eq__(self, other):
def __iter__(self):
return iter(self.data)

def __contains__(self, section: str) -> bool:
return section in self.data

def get_section(self, section: str) -> Any:
return self.data.get(section)

def pop_section(self, section: str) -> Any:
return self.data.pop(section, None)

def dump(self) -> None:
if self.file is None:
return
with open(self.file, "w") as fh:
yaml.dump({"hpc_connect": self.data}, fh, default_flow_style=False)


config_defaults = {
"config": {
"debug": False,
},
"machine": {
"resources": None,
},
"submit": {
"backend": None,
"default_options": [],
},
"launch": {
"exec": "mpiexec",
"numproc_flag": "-n",
"default_options": [],
"local_options": [],
"pre_options": [],
"mappings": {},
},
}


class Config:
def __init__(self) -> None:
self.pluginmanager: pluggy.PluginManager = HPCConnectPluginManager()
self.scopes: dict[str, ConfigScope] = {
"defaults": ConfigScope("defaults", None, config_defaults)
self.pluginmanager: HPCConnectPluginManager = HPCConnectPluginManager()
rspec = self.pluginmanager.hook.hpc_connect_discover_resources()
defaults = {
"config": {
"debug": False,
"plugins": [],
},
"machine": {
"resources": rspec,
},
"submit": {
"backend": None,
"default_options": [],
},
"launch": {
"exec": "mpiexec",
"numproc_flag": "-n",
"default_options": [],
"local_options": [],
"pre_options": [],
"mappings": {},
},
}
self.scopes: dict[str, ConfigScope] = {}
default_scope = ConfigScope("defaults", None, defaults)
self.push_scope(default_scope)
for scope in ("site", "global", "local"):
config_scope = read_config_scope(scope)
self.push_scope(config_scope)
Expand All @@ -202,6 +113,10 @@ def read_only_scope(self, scope: str) -> bool:

def push_scope(self, scope: ConfigScope) -> None:
self.scopes[scope.name] = scope
if cfg := scope.get_section("config"):
if plugins := cfg.get("plugins"):
for f in plugins:
self.pluginmanager.consider_plugin(f)

def pop_scope(self, scope: ConfigScope) -> ConfigScope | None:
return self.scopes.pop(scope.name, None)
Expand Down Expand Up @@ -235,6 +150,14 @@ def get(self, path: str, default: Any = None, scope: str | None = None) -> Any:
value = value[key]
return value

def get_highest_priority(self, path: str, default: Any = None) -> tuple[Any, str]:
sentinel = object()
for scope in reversed(self.scopes.keys()):
value = self.get(path, default=sentinel, scope=scope)
if value is not sentinel:
return value, scope
return default, "none"

def set(self, path: str, value: Any, scope: str | None = None) -> None:
parts = process_config_path(path)
section = parts.pop(0)
Expand Down Expand Up @@ -336,18 +259,12 @@ def set_main_options(self, args: argparse.Namespace) -> None:

@property
def resource_specs(self) -> list[dict]:
from .submit import factory

if resource_specs := self.get("machine:resources"):
return resource_specs
if self.get("submit:backend"):
# backend may set resources
factory(config=self)
if resource_specs := self.get("machine:resources"):
return resource_specs
resource_specs = default_resource_spec()
self.set("machine:resources", resource_specs, scope="defaults")
return resource_specs
specs, _ = self.get_highest_priority("machine:resources")
if specs is not None:
return specs
resources = default_resource_set()
self.set("machine:resources", specs, scope="defaults")
return resources

def resource_types(self) -> list[str]:
"""Return the types of resources available"""
Expand Down Expand Up @@ -486,20 +403,15 @@ def compute_required_resources(
return reqd_resources

def dump(self, stream: IO[Any], scope: str | None = None, **kwargs: Any) -> None:
from .submit import factory

# initialize the resource spec
if self.get("machine:resources") is None:
if self.get("submit:backend"):
factory(self)
if not self.get("machine:resources"):
self.set("machine:resources", default_resource_spec(), scope="defaults")
data: dict[str, Any] = {}
for section in self.scopes["defaults"]:
if section == "machine":
continue
section_data = self.get_config(section, scope=scope)
if not section_data and scope is not None:
continue
data[section] = section_data
data.setdefault("machine", {})["resources"] = self.resource_specs
yaml.dump({"hpc_connect": data}, stream, **kwargs)


Expand Down Expand Up @@ -532,32 +444,10 @@ def get_scope_filename(scope: str) -> str | None:


def read_env_config() -> ConfigScope | None:
def load_mappings(arg: str) -> dict[str, str]:
mappings: dict[str, str] = {}
for kv in arg.split(","):
k, v = [_.strip() for _ in kv.split(":") if _.split()]
mappings[k] = v
return mappings

data: dict[str, Any] = {}
for var in os.environ:
if not var.startswith("HPCC_"):
continue
try:
section, *parts = var[5:].lower().split("_")
key = "_".join(parts)
except ValueError:
continue
if section not in config_defaults:
continue
value: Any
if key == "mappings":
value = load_mappings(os.environ[var])
else:
value = safe_loads(os.environ[var])
data.setdefault(section, {}).update({key: value})
if not data:
variables = {key: var for key, var in os.environ.items() if key.startswith("HPC_CONNECT_")}
if not variables:
return None
data = environment_variable_schema.validate(variables)
return ConfigScope("environment", None, data)


Expand Down Expand Up @@ -609,25 +499,3 @@ def set_logging_level(levelname: str) -> None:
for h in logger.handlers:
h.setLevel(level)
logger.setLevel(level)


def default_resource_spec() -> list[dict]:
resource_spec: list[dict] = [
{
"type": "node",
"count": 1,
"resources": [
{
"type": "socket",
"count": 1,
"resources": [
{
"type": "cpu",
"count": psutil.cpu_count(),
},
],
},
],
}
]
return resource_spec
30 changes: 30 additions & 0 deletions src/hpc_connect/discover.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright NTESS. See COPYRIGHT file for details.
#
# SPDX-License-Identifier: MIT

import fnmatch
import json
import os
from typing import Any

import psutil

from .hookspec import hookimpl


def default_resource_set() -> list[dict[str, Any]]:
local_resource = {"type": "cpu", "count": psutil.cpu_count()}
socket_resource = {"type": "socket", "count": 1, "resources": [local_resource]}
return [{"type": "node", "count": 1, "resources": [socket_resource]}]


@hookimpl(tryfirst=True, specname="hpc_connect_discover_resources")
def read_resources_from_hostfile() -> dict[str, list] | None:
if file := os.getenv("HPC_CONNECT_HOSTFILE"):
with open(file) as fh:
data = json.load(fh)
host: str = os.getenv("HPC_CONNECT_HOSTNAME") or os.uname().nodename
for pattern, rspec in data.items():
if fnmatch.fnmatch(host, pattern):
return rspec
return None
Loading