Skip to content

Commit

Permalink
feat: Improve UI messages during build and sample
Browse files Browse the repository at this point in the history
Improve user interface messages during build and sample.  Report build
progress. Sampling progress reports now resemble those you would see
when doing `git clone` (for a large repository).

These changes make use of poorly-documented features in `clikit`.
Since `clikit` is used by `poetry`, relying on these features seems
safe.

Closes #190
  • Loading branch information
riddell-stan committed Apr 26, 2021
1 parent a138499 commit 163af4f
Showing 1 changed file with 41 additions and 23 deletions.
64 changes: 41 additions & 23 deletions stan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import dataclasses
import json
import re
import time
from typing import Dict, List, Optional, Sequence, Tuple, Union

import httpstan.models
Expand All @@ -11,7 +12,6 @@
import numpy as np
import simdjson
from clikit.io import ConsoleIO
from clikit.ui.components import ProgressBar

import stan.common
import stan.fit
Expand Down Expand Up @@ -158,7 +158,7 @@ def _create_fit(self, *, function, num_chains, **kwargs) -> stan.fit.Fit:
payload["random_seed"] = self.random_seed # type: ignore

# fit needs to know num_samples, num_warmup, num_thin, save_warmup
# progress bar needs to know some of these
# progress reporting needs to know some of these
num_warmup = payload.get("num_warmup", arguments.lookup_default(arguments.Method["SAMPLE"], "num_warmup"))
num_samples = payload.get(
"num_samples",
Expand All @@ -173,9 +173,9 @@ def _create_fit(self, *, function, num_chains, **kwargs) -> stan.fit.Fit:

async def go():
io = ConsoleIO()
io.error_line("<info>Sampling...</info>")
progress_bar = ProgressBar(io)
progress_bar.set_format("very_verbose")
sampling_output = io.section().error_output
percent_complete = 0
sampling_output.write_line(f"<comment>Sampling:</comment> {percent_complete:3.0f}%")

current_and_max_iterations_re = re.compile(r"Iteration:\s+(\d+)\s+/\s+(\d+)")
async with stan.common.HttpstanClient() as client:
Expand Down Expand Up @@ -204,14 +204,23 @@ async def go():
iteration, iteration_max = map(
int, current_and_max_iterations_re.findall(progress_message).pop(0)
)
if not progress_bar.get_max_steps(): # i.e., has not started
progress_bar.start(max=iteration_max * num_chains)
current_iterations[operation["name"]] = iteration
progress_bar.set_progress(sum(current_iterations.values()))
iterations_count = sum(current_iterations.values())
total_iterations = iteration_max * num_chains
percent_complete = 100 * iterations_count / total_iterations
sampling_output.clear()
sampling_output.write_line(
f"<comment>Sampling:</comment> {round(percent_complete):3.0f}% ({iterations_count}/{total_iterations})"
)
await asyncio.sleep(0.01)
# Sampling has finished. But we do not call `progress_bar.finish()` right
# now. First we write informational messages to the screen, then we
# redraw the (complete) progress bar. Only after that do we call `finish`.

sampling_output.clear()
fit_in_cache = len(current_iterations) < num_chains
sampling_output.write_line(
"<info>Sampling:</info> 100%, done."
if fit_in_cache
else f"<info>Sampling:</info> {percent_complete:3.0f}% ({iterations_count}/{total_iterations}), done."
)

stan_outputs = []
for operation in operations:
Expand Down Expand Up @@ -258,17 +267,12 @@ def is_iteration_or_elapsed_time_logger_message(msg: simdjson.Object):
nonstandard_logger_messages.append(msg.as_dict())
del parser # simdjson.Parser is no longer used at this point.

progress_bar.clear()
io.error("\x08" * progress_bar._last_messages_length) # move left to start of line
if nonstandard_logger_messages:
io.error_line("<comment>Messages received during sampling:</comment>")
for msg in nonstandard_logger_messages:
text = msg["values"][0].replace("info:", " ").replace("error:", " ")
if text.strip():
io.error_line(f"{text}")
progress_bar.display() # re-draw the (complete) progress bar
progress_bar.finish()
io.error_line("\n<info>Done.</info>")

fit = stan.fit.Fit(
stan_outputs,
Expand Down Expand Up @@ -437,18 +441,35 @@ def build(program_code: str, data: Data = frozendict(), random_seed: Optional[in

async def go():
io = ConsoleIO()
io.error("<info>Building...</info>")
# hack: use stdout instead of stderr because httpstan silences stderr during compilation
building_output = io.section().output
building_output.write("<comment>Building:</comment>")
async with stan.common.HttpstanClient() as client:
# Check to see if model is in cache.
model_name = httpstan.models.calculate_model_name(program_code)
resp = await client.post(f"/{model_name}/params", json={"data": data})
model_in_cache = resp.status != 404
io.error("\n" if model_in_cache else " This may take some time.\n")

# Note: during compilation `httpstan` redirects stderr to /dev/null, making `print` impossible.
resp = await client.post("/models", json={"program_code": program_code})
task = asyncio.create_task(client.post("/models", json={"program_code": program_code}))
start = time.time()
while True:
done, pending = await asyncio.wait({task}, timeout=0.1)
if done:
break
building_output.clear()
building_output.write(f"<comment>Building:</comment> {time.time() - start:0.1f}s")
building_output.clear()
# now that httpstan has released stderr, we can use error_output again
building_output = io.section().error_output
resp = task.result()

if resp.status != 201:
raise RuntimeError(resp.json()["message"])
building_output.clear()
if model_in_cache:
building_output.write("<info>Building:</info> found in cache, done.")
else:
building_output.write(f"<info>Building:</info> {time.time() - start:0.1f}s, done.")
assert model_name == resp.json()["name"]
if resp.json().get("stanc_warnings"):
io.error_line("<comment>Messages from <fg=cyan;options=bold>stanc</>:</comment>")
Expand All @@ -461,9 +482,6 @@ async def go():
assert len({param["name"] for param in params_list}) == len(params_list)
param_names, dims = zip(*((param["name"], param["dims"]) for param in params_list))
constrained_param_names = sum((tuple(param["constrained_names"]) for param in params_list), ())
if model_in_cache:
io.error("<comment>Found model in cache.</comment> ")
io.error_line("<info>Done.</info>")
return Model(model_name, program_code, data, param_names, constrained_param_names, dims, random_seed)

try:
Expand Down

0 comments on commit 163af4f

Please sign in to comment.