saving render_model() output to the desired file path #1831

Closed
znwang25 opened this issue Jul 8, 2024 · 5 comments · Fixed by #1857
Labels
bug (Something isn't working), help wanted (Extra attention is needed)

Comments

@znwang25

znwang25 commented Jul 8, 2024

When calling numpyro.render_model(model, filename=my_path), I expected the generated file to be saved at my_path, but it is not.

Reading the source code, it seems to deliberately use filename.stem instead of the user-provided file path. Is this the intended behavior? Being unable to save the file to the requested path makes the filename option useless.

def render_model(
    model,
    model_args=None,
    model_kwargs=None,
    filename=None,
    render_distributions=False,
    render_params=False,
):
    """
    Wrap all functions needed to automatically render a model.

    .. warning:: This utility does not support the
        :func:`~numpyro.contrib.control_flow.scan` primitive.
        If you want to render a time-series model, you can try
        to rewrite the code using Python for loop.

    :param model: Model to render.
    :param model_args: Positional arguments to pass to the model.
    :param model_kwargs: Keyword arguments to pass to the model.
    :param str filename: File to save rendered model in.
    :param bool render_distributions: Whether to include RV distribution annotations in the plot.
    :param bool render_params: Whether to show params in the plot.
    """
    relations = get_model_relations(
        model,
        model_args=model_args,
        model_kwargs=model_kwargs,
    )
    graph_spec = generate_graph_specification(relations, render_params=render_params)
    graph = render_graph(graph_spec, render_distributions=render_distributions)

    if filename is not None:
        filename = Path(filename)
        graph.render(
            filename.stem, view=False, cleanup=True, format=filename.suffix[1:]
        )  # remove leading period from suffix

    return graph
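
The effect of `filename.stem` in the snippet above can be checked with `pathlib` alone, without running numpyro or graphviz. This is a minimal sketch; the path `outputs/figures/model.png` is a made-up example, and `PurePosixPath` is used only so the result does not depend on the operating system.

```python
from pathlib import PurePosixPath

# Hypothetical path a user might pass as `filename` (not from the issue).
requested = PurePosixPath("outputs/figures/model.png")

# These are the two values the quoted code hands to graph.render().
name_passed = requested.stem          # file name without directory or suffix
format_passed = requested.suffix[1:]  # suffix without the leading period

# The directory part of the requested path is never used.
dropped_directory = requested.parent

print(name_passed, format_passed, dropped_directory)
```

Because `.stem` keeps only the final path component, the directory the user asked for is discarded before rendering, which is consistent with the file not appearing at `my_path`.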
@fehiepsi
Member

fehiepsi commented Jul 8, 2024

Nice catch. I think we can pass the original filename argument through and get the suffix with Path(filename).suffix. Do you want to submit the fix?

@fehiepsi added the bug label Jul 8, 2024
@znwang25
Author

znwang25 commented Jul 8, 2024

I think this will get it fixed.

    if filename is not None:
        filename = Path(filename)
        graph.render(
            filename.with_suffix(''), view=False, cleanup=True, format=filename.suffix[1:]
        )  # remove leading period from suffix
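
To illustrate why `with_suffix('')` differs from `.stem`, here is a small sketch using only `pathlib`. The helper name `split_render_target` and the example path are hypothetical and are not part of numpyro; the sketch only shows how the path is split, not how graphviz writes the file.

```python
from pathlib import PurePosixPath


def split_render_target(filename):
    """Hypothetical helper: split a requested path into (base path, format).

    Unlike `.stem`, `with_suffix("")` keeps the directory part of the path.
    """
    path = PurePosixPath(filename)
    base = path.with_suffix("")  # e.g. outputs/figures/model
    fmt = path.suffix[1:]        # e.g. png (leading period removed)
    return base, fmt


base, fmt = split_render_target("outputs/figures/model.png")
print(base, fmt)
```

With this split, the base path still points into the requested directory, so a renderer that appends the format as an extension would be expected to reproduce the original path.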

@fehiepsi
Member

fehiepsi commented Jul 8, 2024

Is it necessary to remove the suffix from the filename?

@znwang25
Author

znwang25 commented Jul 8, 2024 via email

@fehiepsi
Member

Thanks! Sorry for the slow response. Do you want to submit the fix?
