Skip to content

Commit

Permalink
Preserve path to the file when rendering model (#3186)
Browse files Browse the repository at this point in the history
* Issue #3184

* reverted signature

* os import

* black formatting

* split line
  • Loading branch information
LysSanzMoreta committed Mar 3, 2023
1 parent 685c7ad commit 04fc486
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions pyro/infer/inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

import itertools
import os
from collections import defaultdict
from pathlib import Path
from types import SimpleNamespace
Expand Down Expand Up @@ -572,7 +573,7 @@ def render_model(
list of tuples for semisupervised models.
:param model_kwargs: Dict of keyword arguments to pass to the model, or
list of dicts for semisupervised models.
:param str filename: File to save rendered model in.
:param str filename: Name of file or path to file to save rendered model in.
:param bool render_distributions: Whether to include RV distribution
annotations (and param constraints) in the plot.
:param bool render_params: Whether to show params inthe plot.
Expand Down Expand Up @@ -604,9 +605,9 @@ def render_model(
graph = render_graph(graph_spec, render_distributions=render_distributions)

if filename is not None:
filename = Path(filename)
suffix = filename.suffix[1:] # remove leading period from suffix
graph.render(filename.stem, view=False, cleanup=True, format=suffix)
suffix = Path(filename).suffix[1:] # remove leading period from suffix
filepath = os.path.splitext(filename)[0]
graph.render(filepath, view=False, cleanup=True, format=suffix)

return graph

Expand Down

0 comments on commit 04fc486

Please sign in to comment.