Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added a helper function to print intermediate values #7146

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
10 changes: 10 additions & 0 deletions pymc/pytensorf.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
)
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op
from pytensor.printing import Print
from pytensor.scalar.basic import Cast
from pytensor.scan.op import Scan
from pytensor.tensor.basic import _as_tensor_variable
Expand Down Expand Up @@ -275,6 +276,15 @@ def floatX(X):
return np.asarray(X, dtype=pytensor.config.floatX)


def print_value(var, name=None):
"""Print value of variable when it is computed during sampling.
This is likely to affect sampling performance.
"""
if name is None:
name = var.name
return Print(name)(var)


_conversion_map = {"float64": "int32", "float32": "int16", "float16": "int8", "float8": "int8"}


Expand Down
39 changes: 38 additions & 1 deletion tests/test_pytensorf.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import warnings

from io import StringIO

import numpy as np
import numpy.ma as ma
import numpy.testing as npt
Expand All @@ -22,10 +25,11 @@
import pytest
import scipy.sparse as sps

from pytensor import scan, shared
from pytensor import function, scan, shared
from pytensor.compile import UnusedInputError
from pytensor.compile.builders import OpFromGraph
from pytensor.graph.basic import Variable
from pytensor.tensor import dvector
from pytensor.tensor.random.basic import normal, uniform
from pytensor.tensor.random.var import RandomStateSharedVariable
from pytensor.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1
Expand All @@ -43,6 +47,7 @@
constant_fold,
convert_observed_data,
extract_obs_data,
print_value,
replace_rng_nodes,
replace_vars_in_graphs,
reseed_rngs,
Expand Down Expand Up @@ -732,3 +737,35 @@ def test_replace_vars_in_graphs_nested_reference():
assert np.abs(x.eval()) < 1
# Confirm the original `y` variable is not changed in place
assert np.abs(y.eval()) < 1


def test_print_value(capsys):
"""Ensure that print_value correctly prints the variable name and value."""
# Define input vectors
x_values = np.array([1, 2, 3])
y_values = np.array([1, 0, -1])

# Define tensor variables
x = dvector("x")
y = dvector("y")

# Redirect sys.stdout to a StringIO object
original_stdout = sys.stdout
sys.stdout = StringIO()

# Evaluate expression with print statement
z_with_print = print_value(x - y, "x - y")
func_with_print = function([x, y], 1 / z_with_print)
func_with_print(x_values, y_values)

# Get the printed output
printed_output = sys.stdout.getvalue().strip()

# Restore sys.stdout
sys.stdout = original_stdout

# Expected output (adjust according to the actual output)
expected_output = "x - y __str__ = [0. 2. 4.]"

# Check if the expected output was printed
assert expected_output in printed_output