Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions src/kepler_model/estimate/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from kepler_model.estimate.model.model import load_downloaded_model
from kepler_model.estimate.model_server_connector import is_model_server_enabled, make_request
from kepler_model.train.profiler.node_type_index import NodeTypeSpec, discover_spec_values, get_machine_spec
from kepler_model.util.config import SERVE_SOCKET, download_path, set_env_from_model_config
from kepler_model.util.config import CONFIG_PATH, SERVE_SOCKET, download_path, set_env_from_model_config, set_config_dir
from kepler_model.util.loader import get_download_output_path, load_metadata
from kepler_model.util.train_types import ModelOutputType, convert_enery_source, is_output_type_supported

Expand Down Expand Up @@ -185,7 +185,14 @@ def sig_handler(signum, frame) -> None:
type=click.Path(exists=True),
required=False,
)
def run(log_level: str, machine_spec: str):
@click.option(
"--config-dir",
"-c",
type=click.Path(exists=False, dir_okay=True, file_okay=False),
default=CONFIG_PATH,
required=False,
)
def run(log_level: str, machine_spec: str, config_dir: str) -> int:
level = getattr(logging, log_level.upper())
logging.basicConfig(
level=level,
Expand All @@ -194,6 +201,8 @@ def run(log_level: str, machine_spec: str):
)

logger.info("starting estimator")
set_config_dir(config_dir)

set_env_from_model_config()
clean_socket()
signal.signal(signal.SIGTERM, sig_handler)
Expand Down
20 changes: 18 additions & 2 deletions src/kepler_model/server/model_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@

from kepler_model.train import NodeTypeIndexCollection, NodeTypeSpec
from kepler_model.util.config import (
CONFIG_PATH,
ERROR_KEY,
MODEL_SERVER_MODEL_LIST_PATH,
MODEL_SERVER_MODEL_REQ_PATH,
download_path,
getConfig,
initial_pipeline_urls,
model_toppath,
set_config_dir,
)
from kepler_model.util.loader import (
CHECKPOINT_FOLDERNAME,
Expand Down Expand Up @@ -430,11 +432,25 @@ def fill_machine_spec():
default="info",
required=False,
)
def run(log_level: str):
@click.option(
"--config-dir",
"-c",
type=click.Path(exists=False, dir_okay=True, file_okay=False),
default=CONFIG_PATH,
required=False,
)
def run(log_level: str, config_dir: str) -> int:
level = getattr(logging, log_level.upper())
logging.basicConfig(level=level)
logging.basicConfig(
level=level,
format="%(asctime)s %(levelname)s %(filename)s:%(lineno)s: %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)

set_config_dir(config_dir)
load_init_pipeline()
app.run(host="0.0.0.0", port=MODEL_SERVER_PORT)
return 0


if __name__ == "__main__":
Expand Down
17 changes: 13 additions & 4 deletions src/kepler_model/util/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@
SERVE_SOCKET = "/tmp/estimator.sock"


def set_config_dir(config_dir: str):
global CONFIG_PATH
CONFIG_PATH = config_dir


def getConfig(key: str, default):
# check configmap path
file = os.path.join(CONFIG_PATH, key)
Expand All @@ -74,8 +79,6 @@ def getPath(subpath):
# use local path if not exists or cannot write
MNT_PATH = os.path.join(os.path.dirname(__file__), "..")

CONFIG_PATH = getConfig("CONFIG_PATH", CONFIG_PATH)

model_topurl = getConfig("MODEL_TOPURL", base_model_url)
initial_pipeline_urls = getConfig("INITIAL_PIPELINE_URL", "")
if initial_pipeline_urls == "":
Expand Down Expand Up @@ -123,10 +126,16 @@ def set_env_from_model_config():
return

for line in model_config.splitlines():
splits = line.split("=")
line = line.strip()
# ignore comments and blanks
if not line or line.startswith("#"):
continue

# pick only the first part until # and ignore the rest
splits = line.split("#")[0].strip().split("=")
if len(splits) > 1:
os.environ[splits[0].strip()] = splits[1].strip()
logging.info(f"set {splits[0]} to {splits[1]}.")
logging.info(f"set env {splits[0]} to '{splits[1]}'.")


def is_estimator_enable(prefix):
Expand Down
25 changes: 0 additions & 25 deletions tests/query_test.py

This file was deleted.

Loading