diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py index ef4aa2a157..9330eeddb6 100644 --- a/pymc/pytensorf.py +++ b/pymc/pytensorf.py @@ -40,6 +40,7 @@ ) from pytensor.graph.fg import FunctionGraph from pytensor.graph.op import Op +from pytensor.printing import Print from pytensor.scalar.basic import Cast from pytensor.scan.op import Scan from pytensor.tensor.basic import _as_tensor_variable @@ -275,6 +276,15 @@ def floatX(X): return np.asarray(X, dtype=pytensor.config.floatX) +def print_value(var, name=None): + """Print value of variable when it is computed during sampling. + This is likely to affect sampling performance. + """ + if name is None: + name = var.name + return Print(name)(var) + + _conversion_map = {"float64": "int32", "float32": "int16", "float16": "int8", "float8": "int8"} diff --git a/tests/test_pytensorf.py b/tests/test_pytensorf.py index 29f53808e8..0869ab04a8 100644 --- a/tests/test_pytensorf.py +++ b/tests/test_pytensorf.py @@ -11,8 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import sys import warnings +from io import StringIO + import numpy as np import numpy.ma as ma import numpy.testing as npt @@ -22,10 +25,11 @@ import pytest import scipy.sparse as sps -from pytensor import scan, shared +from pytensor import function, scan, shared from pytensor.compile import UnusedInputError from pytensor.compile.builders import OpFromGraph from pytensor.graph.basic import Variable +from pytensor.tensor import dvector from pytensor.tensor.random.basic import normal, uniform from pytensor.tensor.random.var import RandomStateSharedVariable from pytensor.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1 @@ -43,6 +47,7 @@ constant_fold, convert_observed_data, extract_obs_data, + print_value, replace_rng_nodes, replace_vars_in_graphs, reseed_rngs, @@ -732,3 +737,35 @@ def test_replace_vars_in_graphs_nested_reference(): assert np.abs(x.eval()) < 1 # Confirm the original `y` variable is not changed in place assert np.abs(y.eval()) < 1 + + +def test_print_value(capsys): + """Ensure that print_value correctly prints the variable name and value.""" + # Define input vectors + x_values = np.array([1, 2, 3]) + y_values = np.array([1, 0, -1]) + + # Define tensor variables + x = dvector("x") + y = dvector("y") + + # Redirect sys.stdout to a StringIO object + original_stdout = sys.stdout + sys.stdout = StringIO() + + # Evaluate expression with print statement + z_with_print = print_value(x - y, "x - y") + func_with_print = function([x, y], 1 / z_with_print) + func_with_print(x_values, y_values) + + # Get the printed output + printed_output = sys.stdout.getvalue().strip() + + # Restore sys.stdout + sys.stdout = original_stdout + + # Expected output (adjust according to the actual output) + expected_output = "x - y __str__ = [0. 2. 4.]" + + # Check if the expected output was printed + assert expected_output in printed_output