Skip to content

Commit

Permalink
feat: Provide user feedback during compilation
Browse files Browse the repository at this point in the history
Compiling the Stan program takes time. Give the user some indication of
what is happening.

Note that we cannot use a progress indicator here because httpstan
redirects everything sent to stderr to /dev/null to avoid overwhelming
the user with a flood of compiler info and warning messages. Progress
indicators can be used elsewhere.

Closes #111
  • Loading branch information
riddell-stan committed Jul 22, 2020
1 parent bc77d72 commit cdf278e
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 8 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ python = "^3.7"
aiohttp = "^3.6"
httpstan = "^2.0.2"
numpy = "^1.7"
tqdm = "^4.14"
clikit = "^0.6.2"

# docs
sphinx = { version = "^3.1.1", optional = true }
Expand Down
23 changes: 16 additions & 7 deletions stan/model.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
import asyncio
import aiohttp
import collections.abc
import json
import typing

import aiohttp
import google.protobuf.internal.decoder
import httpstan.callbacks_writer_pb2 as callbacks_writer_pb2
import httpstan.models
import httpstan.schemas
import httpstan.services.arguments as arguments
import httpstan.utils
import numpy as np
from clikit.io import ConsoleIO

import stan.common
import stan.fit

import google.protobuf.internal.decoder
import httpstan.callbacks_writer_pb2 as callbacks_writer_pb2
import numpy as np


def _make_json_serializable(data: dict) -> dict:
"""Convert `data` with numpy.ndarray-like values to JSON-serializable form.
Expand Down Expand Up @@ -207,12 +208,21 @@ def build(program_code, data=None, random_seed=None):
assert all(not isinstance(value, np.ndarray) for value in data.values())

async def go():
io = ConsoleIO()
io.error("<info>Building...</info>")
async with stan.common.httpstan_server() as (host, port):
# Check to see if model is in cache.
model_name = httpstan.models.calculate_model_name(program_code)
path, payload = f"/v1/{model_name}/params", {"data": data}
async with aiohttp.request("POST", f"http://{host}:{port}{path}", json=payload) as resp:
model_in_cache = resp.status != 404
io.error_line(" Found model in cache." if model_in_cache else " This may take some time.")
# Note: during compilation `httpstan` redirects stderr to /dev/null, making `print` impossible.
path, payload = "/v1/models", {"program_code": program_code}
async with aiohttp.request("POST", f"http://{host}:{port}{path}", json=payload) as resp:
if resp.status != 201:
raise RuntimeError((await resp.json())["message"])
model_name = (await resp.json())["name"]
assert model_name == (await resp.json())["name"]

path, payload = f"/v1/{model_name}/params", {"data": data}
async with aiohttp.request("POST", f"http://{host}:{port}{path}", json=payload) as resp:
Expand All @@ -222,7 +232,6 @@ async def go():
assert len({param["name"] for param in params_list}) == len(params_list)
param_names, dims = zip(*((param["name"], param["dims"]) for param in params_list))
constrained_param_names = sum((tuple(param["constrained_names"]) for param in params_list), ())

return Model(model_name, program_code, data, param_names, constrained_param_names, dims, random_seed)

return asyncio.run(go())

0 comments on commit cdf278e

Please sign in to comment.