Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions pytensor/compile/function/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1106,6 +1106,18 @@ def sync_shared(self):
# NOTE: sync was needed on old gpu backend
pass

def dprint(self, **kwargs):
"""Debug print itself

Parameters
----------
kwargs:
Optional keyword arguments to pass to debugprint function.
"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In general do you know of an elegant way to copy over docstrings? It'd be nice if we could see the full dprint docstring, but only if we don't have to re-copy it in every helper method.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know a way but I suspect there might be one.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Internet shows stuff with assigning .__doc__ variables. Feels a bit hacky to me but maybe it's conventional. Checking if @OriolAbril has any input here

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have seen and used copying things over and setting the .__doc__ attribute manually, I don't think there is anything wrong with it

from pytensor.printing import debugprint

return debugprint(self, **kwargs)

Comment on lines +1109 to +1120
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have mixed feelings about this solution since this paradigm is quite rare, but perhaps it shouldn't be.

Suggested change
def dprint(self, **kwargs):
"""Debug print itself
Parameters
----------
kwargs:
Optional keyword arguments to pass to debugprint function.
"""
from pytensor.printing import debugprint
return debugprint(self, **kwargs)
from pytensor.printing import debugprint

I think this does everything cleanly, but I haven't tested it.

Copy link
Member Author

@ricardoV94 ricardoV94 Jun 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this call the function with self as the first argument?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On class instances, yes.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thats 🤯 haha


# pickling/deepcopy support for Function
def _pickle_Function(f):
Expand Down
12 changes: 12 additions & 0 deletions pytensor/graph/fg.py
Original file line number Diff line number Diff line change
Expand Up @@ -927,3 +927,15 @@ def __contains__(self, item: Variable | Apply) -> bool:
return item in self.apply_nodes
else:
raise TypeError()

def dprint(self, **kwargs):
"""Debug print itself

Parameters
----------
kwargs:
Optional keyword arguments to pass to debugprint function.
"""
from pytensor.printing import debugprint

return debugprint(self, **kwargs)
7 changes: 7 additions & 0 deletions tests/compile/function/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from pytensor.graph.rewriting.basic import OpKeyGraphRewriter, PatternNodeRewriter
from pytensor.graph.utils import MissingInputError
from pytensor.link.vm import VMLinker
from pytensor.printing import debugprint
from pytensor.tensor.math import dot, tanh
from pytensor.tensor.math import sum as pt_sum
from pytensor.tensor.type import (
Expand Down Expand Up @@ -862,6 +863,12 @@ def test_key_string_requirement(self):
with pytest.raises(AssertionError):
function([x], outputs={(1, "b"): x, 1.0: x**2})

def test_dprint(self):
x = pt.scalar("x")
out = x + 1
f = function([x], out)
assert f.dprint(file="str") == debugprint(f, file="str")


class TestPicklefunction:
def test_deepcopy(self):
Expand Down
7 changes: 7 additions & 0 deletions tests/graph/test_fg.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pytensor.graph.basic import NominalVariable
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.utils import MissingInputError
from pytensor.printing import debugprint
from tests.graph.utils import (
MyConstant,
MyOp,
Expand Down Expand Up @@ -706,3 +707,9 @@ def test_nominals(self):
assert nm2 not in fg.inputs
assert nm in fg.variables
assert nm2 in fg.variables

def test_dprint(self):
r1, r2 = MyVariable("x"), MyVariable("y")
o1 = op1(r1, r2)
fg = FunctionGraph([r1, r2], [o1], clone=False)
assert fg.dprint(file="str") == debugprint(fg, file="str")