Skip to content

Commit

Permalink
feat: Switch to use pathlib.Path
Browse files Browse the repository at this point in the history
Switch to use pathlib.Path as the main way to handle path information.
  • Loading branch information
Ari Hartikainen authored and ahartikainen committed Feb 16, 2021
1 parent eea1033 commit c8734a8
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 55 deletions.
62 changes: 33 additions & 29 deletions httpstan/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@
Functions in this module manage the Stan model cache and related caches.
"""
import logging
import os
import pathlib
import shutil
import typing
from importlib.machinery import EXTENSION_SUFFIXES
from pathlib import Path

import appdirs

Expand All @@ -16,11 +15,23 @@
logger = logging.getLogger("httpstan")


def model_directory(model_name: str) -> str:
def cache_directory() -> Path:
    """Return the root directory of the httpstan cache.

    The location is the user cache directory reported by ``appdirs`` for the
    application ``httpstan``, versioned by ``httpstan.__version__``. The
    directory is not created by this function.
    """
    cache_root = appdirs.user_cache_dir("httpstan", version=httpstan.__version__)
    return Path(cache_root)


def model_directory(model_name: str) -> Path:
"""Get the path to a model's directory. Directory may not exist."""
cache_path = appdirs.user_cache_dir("httpstan", version=httpstan.__version__)
model_id = model_name.split("/")[1]
return os.path.join(cache_path, "models", model_id)
return cache_directory() / "models" / model_id


def fit_path(fit_name: str) -> Path:
    """Get the path to a fit file. File may not exist.

    Arguments:
        fit_name: Stan fit name, e.g. ``models/abcdef/ghijklmn``.

    Returns:
        Path of the fit's ``.jsonlines.lz4`` file inside the cache directory.
    """
    # fit_name structure: models / model_id / fit_id (relative to the cache
    # directory). ``rpartition`` never raises: a name without "/" yields an
    # empty directory part, so the file resolves directly under the cache
    # directory instead of failing with ValueError on tuple unpacking.
    fit_directory, _, fit_id = fit_name.rpartition("/")
    fit_filename = f"{fit_id}.jsonlines.lz4"
    return cache_directory() / fit_directory / fit_filename


def delete_model_directory(model_name: str) -> None:
Expand All @@ -30,30 +41,29 @@ def delete_model_directory(model_name: str) -> None:

def dump_services_extension_module_compiler_output(compiler_output: str, model_name: str) -> None:
"""Dump compiler output from building a model-specific stan::services extension module."""
model_directory_ = pathlib.Path(model_directory(model_name))
model_directory_ = model_directory(model_name)
model_directory_.mkdir(parents=True, exist_ok=True)
with open(model_directory_ / "stderr.log", "w") as fh:
with (model_directory_ / "stderr.log").open("w") as fh:
fh.write(compiler_output)


def load_services_extension_module_compiler_output(model_name: str) -> str:
    """Load compiler output from building a model-specific stan::services extension module.

    Arguments:
        model_name: Stan model name, e.g. ``models/dyeicfn2``.

    Returns:
        Contents of the model's ``stderr.log``.

    Raises:
        KeyError: The model's directory does not exist.
    """
    model_directory_ = model_directory(model_name)
    if not model_directory_.exists():
        # Interpolate the computed path (`model_directory_`), not the
        # `model_directory` function object.
        raise KeyError(f"Directory for `{model_name}` at `{model_directory_}` does not exist.")
    with (model_directory_ / "stderr.log").open() as fh:
        return fh.read()


def list_model_names() -> typing.List[str]:
"""Return model names (e.g., `models/dyeicfn2`) for models in cache."""
cache_path = appdirs.user_cache_dir("httpstan", version=httpstan.__version__)
models_directory = pathlib.Path(os.path.join(cache_path, "models"))
models_directory = cache_directory() / "models"
if not models_directory.exists():
return []

def has_extension_suffix(path: pathlib.Path) -> bool:
def has_extension_suffix(path: Path) -> bool:
return path.suffix in EXTENSION_SUFFIXES

model_names = []
Expand All @@ -68,19 +78,19 @@ def has_extension_suffix(path: pathlib.Path) -> bool:

def dump_stanc_warnings(stanc_warnings: str, model_name: str) -> None:
"""Dump stanc warnings associated with a model."""
model_directory_ = pathlib.Path(model_directory(model_name))
model_directory_ = model_directory(model_name)
model_directory_.mkdir(parents=True, exist_ok=True)
with open(model_directory_ / "stanc.log", "w") as fh:
with (model_directory_ / "stanc.log").open("w") as fh:
fh.write(stanc_warnings)


def load_stanc_warnings(model_name: str) -> str:
    """Load stanc output associated with a model.

    Arguments:
        model_name: Stan model name, e.g. ``models/dyeicfn2``.

    Returns:
        Contents of the model's ``stanc.log``.

    Raises:
        KeyError: The model's directory does not exist.
    """
    model_directory_ = model_directory(model_name)
    if not model_directory_.exists():
        # Interpolate the computed path (`model_directory_`), not the
        # `model_directory` function object.
        raise KeyError(f"Directory for `{model_name}` at `{model_directory_}` does not exist.")
    with (model_directory_ / "stanc.log").open() as fh:
        return fh.read()


Expand All @@ -94,12 +104,10 @@ def dump_fit(fit_bytes: bytes, name: str) -> None:
name: Stan fit name
fit_bytes: LZ4-compressed messages associated with Stan fit.
"""
cache_path = appdirs.user_cache_dir("httpstan", version=httpstan.__version__)
# fits are stored under their "parent" models
fits_path = os.path.join(*([cache_path] + name.split("/")[:-1]))
fit_filename = os.path.join(fits_path, f'{name.split("/")[-1]}.jsonlines.lz4')
os.makedirs(fits_path, exist_ok=True)
with open(fit_filename, mode="wb") as fh:
path = fit_path(name)
path.parent.mkdir(parents=True, exist_ok=True)
with path.open("wb") as fh:
fh.write(fit_bytes)


Expand All @@ -113,12 +121,10 @@ def load_fit(name: str) -> bytes:
Returns
LZ4-compressed messages associated with Stan fit.
"""
cache_path = appdirs.user_cache_dir("httpstan", version=httpstan.__version__)
# fits are stored under their "parent" models
fits_path = os.path.join(*([cache_path] + name.split("/")[:-1]))
fit_filename = os.path.join(fits_path, f'{name.split("/")[-1]}.jsonlines.lz4')
path = fit_path(name)
try:
with open(fit_filename, mode="rb") as fh:
with path.open("rb") as fh:
return fh.read()
except FileNotFoundError:
raise KeyError(f"Fit `{name}` not found.")
Expand All @@ -130,7 +136,5 @@ def delete_fit(name: str) -> None:
Arguments:
name: Stan fit name
"""
cache_path = appdirs.user_cache_dir("httpstan", version=httpstan.__version__)
fits_path = os.path.join(*([cache_path] + name.split("/")[:-1]))
fit_id = name.split("/")[-1]
pathlib.Path(os.path.join(fits_path, f"{fit_id}.jsonlines.lz4")).unlink()
path = fit_path(name)
path.unlink()
7 changes: 4 additions & 3 deletions httpstan/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import subprocess
import tempfile
from pathlib import Path
from typing import List, Tuple, Union


Expand All @@ -22,16 +23,16 @@ def compile(program_code: str, stan_model_name: str) -> Tuple[str, str]:
"""
with importlib.resources.path(__package__, "stanc") as stanc_binary:
with tempfile.TemporaryDirectory(prefix="httpstan_") as tmpdir:
filepath = os.path.join(tmpdir, f"{stan_model_name}.stan")
with open(filepath, "w") as fh:
filepath = Path(tmpdir) / f"{stan_model_name}.stan"
with filepath.open("w") as fh:
fh.write(program_code)
run_args: List[Union[os.PathLike, str]] = [
stanc_binary,
"--name",
stan_model_name,
"--warn-pedantic",
"--print-cpp",
filepath,
str(filepath),
]
completed_process = subprocess.run(run_args, capture_output=True, timeout=1)
stderr = completed_process.stderr.decode().strip()
Expand Down
36 changes: 17 additions & 19 deletions httpstan/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,10 @@
import importlib
import importlib.resources
import logging
import os
import pathlib
import platform
import sys
from importlib.machinery import EXTENSION_SUFFIXES
from pathlib import Path
from types import ModuleType
from typing import List, Optional, Tuple

Expand All @@ -24,7 +23,7 @@
import httpstan.cache
import httpstan.compile

PACKAGE_DIR = pathlib.Path(__file__).resolve(strict=True).parents[0]
PACKAGE_DIR = Path(__file__).parent.resolve(strict=True)
logger = logging.getLogger("httpstan")


Expand Down Expand Up @@ -78,7 +77,7 @@ def import_services_extension_module(model_name: str) -> ModuleType:
KeyError: Model not found.
"""
model_directory = pathlib.Path(httpstan.cache.model_directory(model_name))
model_directory = httpstan.cache.model_directory(model_name)
try:
module_path = next(filter(lambda p: p.suffix in EXTENSION_SUFFIXES, model_directory.iterdir()))
except (FileNotFoundError, StopIteration):
Expand Down Expand Up @@ -114,25 +113,24 @@ async def build_services_extension_module(program_code: str, extra_compile_args:
"""
model_name = calculate_model_name(program_code)
model_directory_path = pathlib.Path(httpstan.cache.model_directory(model_name))
model_directory_path = httpstan.cache.model_directory(model_name)

os.makedirs(model_directory_path, exist_ok=True)
model_directory_path.mkdir(parents=True, exist_ok=True)

stan_model_name = f"model_{model_name.split('/')[1]}"
cpp_code, _ = httpstan.compile.compile(program_code, stan_model_name)
cpp_code_path = model_directory_path / f"{stan_model_name}.cpp"
with open(cpp_code_path, "w") as fh:
with cpp_code_path.open("w") as fh:
fh.write(cpp_code)

httpstan_dir = os.path.dirname(__file__)
include_dirs = [
httpstan_dir, # for socket_writer.hpp and socket_logger.hpp
model_directory_path.as_posix(),
os.path.join(httpstan_dir, "include"),
os.path.join(httpstan_dir, "include", "lib", "eigen_3.3.9"),
os.path.join(httpstan_dir, "include", "lib", "boost_1.72.0"),
os.path.join(httpstan_dir, "include", "lib", "sundials_5.6.1", "include"),
os.path.join(httpstan_dir, "include", "lib", "tbb_2019_U8", "include"),
str(PACKAGE_DIR), # for socket_writer.hpp and socket_logger.hpp
str(model_directory_path),
str(PACKAGE_DIR / "include"),
str(PACKAGE_DIR / "include" / "lib" / "eigen_3.3.9"),
str(PACKAGE_DIR / "include" / "lib" / "boost_1.72.0"),
str(PACKAGE_DIR / "include" / "lib" / "sundials_5.6.1" / "include"),
str(PACKAGE_DIR / "include" / "lib" / "tbb_2019_U8" / "include"),
]

stan_macros: List[Tuple[str, Optional[str]]] = [
Expand All @@ -156,20 +154,20 @@ async def build_services_extension_module(program_code: str, extra_compile_args:
extension = setuptools.Extension(
f"stan_services_{stan_model_name}", # filename only. Module name is "stan_services"
language="c++",
sources=[cpp_code_path.as_posix()],
sources=[str(cpp_code_path)],
define_macros=stan_macros,
include_dirs=include_dirs,
library_dirs=[f"{PACKAGE_DIR / 'lib'}"],
library_dirs=[str(PACKAGE_DIR / "lib")],
libraries=libraries,
extra_compile_args=extra_compile_args,
extra_link_args=[f"-Wl,-rpath,{PACKAGE_DIR / 'lib'}"],
extra_objects=[
(PACKAGE_DIR / "stan_services.cpp").with_suffix(".o").as_posix(),
str((PACKAGE_DIR / "stan_services.cpp").with_suffix(".o")),
],
)

extensions = [extension]
build_lib = model_directory_path.as_posix()
build_lib = str(model_directory_path)

# Building the model takes a long time. Run in a different thread.
compiler_output = await asyncio.get_event_loop().run_in_executor(
Expand Down
11 changes: 7 additions & 4 deletions tests/test_cache.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
"""Test services function argument lookups."""
import pathlib

import pytest

import httpstan.app
Expand All @@ -13,8 +11,13 @@
def test_model_directory() -> None:
model_name = "models/abcdef"
model_directory = httpstan.cache.model_directory(model_name)
model_dirpath = pathlib.Path(model_directory)
assert model_dirpath.name == "abcdef"
assert model_directory.name == "abcdef"


def test_fit_path() -> None:
    """The fit file is named after the last component of the fit name."""
    fit_file = httpstan.cache.fit_path("models/abcdef/ghijklmn")
    expected_filename = "ghijklmn.jsonlines.lz4"
    assert fit_file.name == expected_filename


@pytest.mark.asyncio
Expand Down

0 comments on commit c8734a8

Please sign in to comment.