Skip to content

Commit

Permalink
feat: Adapt to new protocol buffer message schema
Browse files Browse the repository at this point in the history
Adapt to httpstan's new protobuf message schema.
See commit 91964cb073635cd07d38617fdeb6d3f3815fe3eb in
httpstan.

The protocol buffer message now uses protocol buffers `bytes` types
where previously protocol buffers `string` types were used.
  • Loading branch information
riddell-stan committed Nov 7, 2020
1 parent 0adea1e commit c4a6f68
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 10 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ classifiers = [
[tool.poetry.dependencies]
python = "^3.7"
aiohttp = "^3.6"
httpstan = "^2.3.0"
httpstan = "^3.0.0"
numpy = "^1.7"
clikit = "^0.6.2"

Expand Down
4 changes: 2 additions & 2 deletions stan/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,14 @@ def __init__(
for msg in stan_output:
if msg.topic == callbacks_writer_pb2.WriterMessage.Topic.Value("SAMPLE"):
# Ignore sample message which is mixed together with proper draws.
if msg.feature and msg.feature[0].name == "":
if msg.feature and msg.feature[0].name == b"":
continue

draw_row = [] # a "row" of values from a single draw from Stan C++

# for the first draw: collect sample and sampler parameter names.
if not hasattr(self, "_draws"):
feature_names = tuple(fea.name for fea in msg.feature)
feature_names = tuple(fea.name.decode() for fea in msg.feature)
self.sample_and_sampler_param_names = tuple(
name for name in feature_names if name.endswith("__")
)
Expand Down
14 changes: 7 additions & 7 deletions stan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,17 +192,17 @@ async def go():
def is_nonempty_logger_message(msg):
return (
msg.topic == callbacks_writer_pb2.WriterMessage.Topic.LOGGER
and msg.feature[0].string_list.value[0].strip() != "info:"
and msg.feature[0].bytes_list.value[0].strip() != b"info:"
)

def is_iteration_or_elapsed_time_logger_message(msg):
# Assumes `msg` is a message with topic `LOGGER`.
text = msg.feature[0].string_list.value[0]
text = msg.feature[0].bytes_list.value[0]
return (
text.startswith("info:Iteration:")
or text.startswith("info: Elapsed Time:")
text.startswith(b"info:Iteration:")
or text.startswith(b"info: Elapsed Time:")
# this detects lines following "Elapsed Time:", part of a multi-line Stan message
or text.startswith("info:" + " " * 15)
or text.startswith(b"info:" + b" " * 15)
)

logger_messages = []
Expand All @@ -216,8 +216,8 @@ def is_iteration_or_elapsed_time_logger_message(msg):
if non_standard_logger_messages:
io.error("\n<info>Messages received during sampling:</info>\n")
for msg in non_standard_logger_messages:
text = msg.feature[0].string_list.value[0].replace("info:", " ")
io.error(f"<info>{text}</info>\n")
text_bytes = msg.feature[0].bytes_list.value[0].replace(b"info:", b" ")
io.error(f"<info>{text_bytes.decode()}</info>\n")

# clean up after ourselves when fit is uncacheable (no random seed)
if self.random_seed is None:
Expand Down

0 comments on commit c4a6f68

Please sign in to comment.