Passed meta-information for MeasurableOps #6685
Conversation
Codecov Report
Additional details and impacted files:

```diff
@@            Coverage Diff             @@
##             main    #6685      +/-   ##
==========================================
- Coverage   91.99%   91.92%   -0.08%
==========================================
  Files          95       95
  Lines       16016    16036      +20
==========================================
+ Hits        14734    14741       +7
- Misses       1282     1295      +13
```
pymc/logprob/binary.py
Outdated
def __init__(self, scalar, ndim_supp=0, support_axis=None, d_type="mixed"): | ||
super().__init__(scalar) | ||
self.ndim_supp = ndim_supp | ||
self.support_axis = support_axis | ||
self.d_type = d_type |
Instead of adding these attributes to each subclass, we can add them to the base MeasurableVariable class. Also, we shouldn't give defaults unless we are sure they always apply; since most rewrites work with both discrete and continuous RVs, that's not the case.

The rewrites that introduce the MeasurableOps should always define them. We can have a helper that works for the default cases, where `ndim_supp`, `supp_axes`, and `measure_type` are the same as in the base RV class.

Also, for `measure_type`, let's use an Enum to avoid typing bugs.
It might look something like:
```python
from enum import Enum, auto
from typing import Tuple

from pytensor.graph.op import Op


class MeasureType(Enum):
    Discrete = auto()
    Continuous = auto()
    Mixed = auto()


def get_default_measurable_metainfo(base_op: Op) -> Tuple[int, Tuple[int, ...], MeasureType]:
    # MeasurableVariable is the base class defined in pymc/logprob/abstract.py
    if not isinstance(base_op, MeasurableVariable):
        raise TypeError("base_op must be a RandomVariable or MeasurableVariable")

    ndim_supp = base_op.ndim_supp

    supp_axes = getattr(base_op, "supp_axes", None)
    if supp_axes is None:
        supp_axes = tuple(range(-base_op.ndim_supp, 0))

    measure_type = getattr(base_op, "measure_type", None)
    if measure_type is None:
        measure_type = MeasureType.Discrete if base_op.dtype.startswith("int") else MeasureType.Continuous

    return ndim_supp, supp_axes, measure_type
```
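As a rough, self-contained illustration of how the fallback defaults would behave, here is a sketch using stub op classes (`FakeDiscreteRV` and `FakeMultivariateRV` are hypothetical stand-ins, not real pymc classes, and the `isinstance` guard is omitted):

```python
from enum import Enum, auto


class MeasureType(Enum):
    Discrete = auto()
    Continuous = auto()
    Mixed = auto()


# Hypothetical stand-ins for ops that only define the base attributes,
# so the helper falls back to the defaults for supp_axes and measure_type.
class FakeDiscreteRV:
    ndim_supp = 0
    dtype = "int64"


class FakeMultivariateRV:
    ndim_supp = 1
    dtype = "float64"


def get_default_metainfo(base_op):
    # Same fallback logic as the suggested helper, minus the isinstance check.
    ndim_supp = base_op.ndim_supp
    supp_axes = getattr(base_op, "supp_axes", None)
    if supp_axes is None:
        supp_axes = tuple(range(-ndim_supp, 0))
    measure_type = getattr(base_op, "measure_type", None)
    if measure_type is None:
        measure_type = MeasureType.Discrete if base_op.dtype.startswith("int") else MeasureType.Continuous
    return ndim_supp, supp_axes, measure_type


# Scalar int op -> no support axes, Discrete; vector float op -> last axis, Continuous.
assert get_default_metainfo(FakeDiscreteRV()) == (0, (), MeasureType.Discrete)
assert get_default_metainfo(FakeMultivariateRV()) == (1, (-1,), MeasureType.Continuous)
```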
```diff
@@ -131,6 +131,11 @@ def _icdf_helper(rv, value, **kwargs):
 class MeasurableVariable(abc.ABC):
```
Oh, we are not subclassing from it directly, so it's not being used (otherwise almost all of our logprob rewrites would have failed).

We usually subclass from the Ops... so maybe we should add a class decorator that adds these properties to `__init__` and passes the remaining ones to the original Op class when it's initialized.
So from what I understand, we should add the meta-information as parameters to `__init__` (via decorators) for the pytensor ops we use, namely: Elemwise, CumOp, CheckAndRaise, SpecifyShape, IfElse, Scan, Join, MakeVector, and DimShuffle. We can then use `get_default_measurable_metainfo` on `node.op` (an instance of the pytensor op) in the `find_measurable_operation` function to get the default values wherever required. Am I correct?
The decorator would be applied on the Measurable subclasses of those Ops, e.g.

```python
@measurable_op
class MeasurableElemwise(Elemwise):
    pass
```

so that we can, in effect, subclass without multiple inheritance. The decorator would override `__init__` to capture the new parameters we need, and call the original `__init__` with the rest.
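A minimal sketch of what such a `measurable_op` decorator could look like, assuming the metainfo parameters are `ndim_supp`, `supp_axes`, and `measure_type` as suggested earlier (the `Elemwise` class here is a stub standing in for the real pytensor Op):

```python
import functools


def measurable_op(cls):
    """Hypothetical class decorator: capture measurable metainfo in __init__,
    forwarding all other arguments to the original Op __init__."""
    original_init = cls.__init__

    @functools.wraps(original_init)
    def __init__(self, *args, ndim_supp, supp_axes, measure_type, **kwargs):
        # Store the measurable metainfo on the instance...
        self.ndim_supp = ndim_supp
        self.supp_axes = supp_axes
        self.measure_type = measure_type
        # ...and call the original Op __init__ with the rest.
        original_init(self, *args, **kwargs)

    cls.__init__ = __init__
    return cls


class Elemwise:  # stand-in for pytensor.tensor.elemwise.Elemwise
    def __init__(self, scalar_op):
        self.scalar_op = scalar_op


@measurable_op
class MeasurableElemwise(Elemwise):
    pass


op = MeasurableElemwise("exp", ndim_supp=0, supp_axes=(), measure_type="continuous")
print(op.scalar_op, op.ndim_supp)  # exp 0
```

The metainfo is keyword-only here, so existing positional arguments of the wrapped Op keep working unchanged.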
But maybe multiple inheritance is the best option here. We might need to think about/try both to figure it out.
EDIT: I am more convinced of this one now :/
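For comparison, a rough sketch of the multiple-inheritance alternative, with `MeasurableVariable` as a mixin that owns the metainfo (both classes below are stubs, not the real pymc/pytensor classes):

```python
class MeasurableVariable:
    """Stub mixin that owns the measurable metainfo."""

    def __init__(self, ndim_supp, supp_axes, measure_type):
        self.ndim_supp = ndim_supp
        self.supp_axes = supp_axes
        self.measure_type = measure_type


class Elemwise:  # stand-in for pytensor.tensor.elemwise.Elemwise
    def __init__(self, scalar_op):
        self.scalar_op = scalar_op


class MeasurableElemwise(MeasurableVariable, Elemwise):
    def __init__(self, ndim_supp, supp_axes, measure_type, scalar_op):
        # Initialize each base explicitly: metainfo first, then the Op proper.
        MeasurableVariable.__init__(self, ndim_supp, supp_axes, measure_type)
        Elemwise.__init__(self, scalar_op)


op = MeasurableElemwise(0, (), "continuous", "exp")
print(isinstance(op, MeasurableVariable), op.scalar_op)  # True exp
```

One upside of this route is that `isinstance(op, MeasurableVariable)` checks in the rewrites work without any decorator machinery.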
I tried the decorator method and it works fine. There is one problem: whenever the `MeasurableOp` class is called in the `find_measurable_op` method, in many cases (almost all) it requires default values for all the parameters, including the metainfo parameters and the original `__init__` parameters. `get_default_measurable_metainfo` does not work on `node.op`, as `node.op` is an instance of the `Op` and not of the `MeasurableOp` where the parameters are added.
For example:
Lines 44 to 97 in a617bf2
```python
def find_measurable_comparisons(
    fgraph: FunctionGraph, node: Node
) -> Optional[List[MeasurableComparison]]:
    rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
    if rv_map_feature is None:
        return None  # pragma: no cover

    if isinstance(node.op, MeasurableComparison):
        return None  # pragma: no cover

    measurable_inputs = [
        (inp, idx)
        for idx, inp in enumerate(node.inputs)
        if inp.owner
        and isinstance(inp.owner.op, MeasurableVariable)
        and inp not in rv_map_feature.rv_values
    ]

    if len(measurable_inputs) != 1:
        return None

    # Make the measurable base_var always be the first input to the MeasurableComparison node
    base_var: TensorVariable = measurable_inputs[0][0]

    # Check that the other input is not potentially measurable, in which case this rewrite
    # would be invalid
    const = tuple(inp for inp in node.inputs if inp is not base_var)

    # check for potential measurability of const
    if not check_potential_measurability(const, rv_map_feature):
        return None

    const = const[0]

    # Make base_var unmeasurable
    unmeasurable_base_var = ignore_logprob(base_var)

    node_scalar_op = node.op.scalar_op

    # Change the Op if the base_var is the second input in node.inputs. e.g. pt.lt(const, dist) -> pt.gt(dist, const)
    if measurable_inputs[0][1] == 1:
        if isinstance(node_scalar_op, LT):
            node_scalar_op = GT()
        elif isinstance(node_scalar_op, GT):
            node_scalar_op = LT()
        elif isinstance(node_scalar_op, GE):
            node_scalar_op = LE()
        elif isinstance(node_scalar_op, LE):
            node_scalar_op = GE()

    compared_op = MeasurableComparison(node_scalar_op)
    compared_rv = compared_op.make_node(unmeasurable_base_var, const).default_output()
    compared_rv.name = node.outputs[0].name

    return [compared_rv]
```
Replacing lines 93 and 94 with

```python
ndim_supp, supp_axes, measure_type = get_default_measurable_metainfo(node.op)
compared_op = MeasurableComparison(ndim_supp, supp_axes, measure_type, node_scalar_op)
```

produces an attribute error:

```
AttributeError: 'Elemwise' object has no attribute 'ndim_supp'
```
I guess multiple inheritance will also lead to this error, as the metainfo is passed to our `MeasurableOps` and not to the original pytensor `Ops`.
I think that's exactly why multiple inheritance would work better, as those attributes would now be part of the Op through the MeasurableClass path.
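To illustrate the point with stub classes (not the real pytensor/pymc ones): a plain Op instance lacks the metainfo attribute, which is the `AttributeError` above, while an instance built through the Measurable subclass carries it and is still an instance of the Op:

```python
class Elemwise:  # stub for the pytensor Op
    pass


class MeasurableVariable:  # stub mixin holding the metainfo
    def __init__(self, ndim_supp):
        self.ndim_supp = ndim_supp


class MeasurableElemwise(MeasurableVariable, Elemwise):
    pass


plain_op = Elemwise()
measurable_op = MeasurableElemwise(ndim_supp=0)

print(hasattr(plain_op, "ndim_supp"))       # False: this is the AttributeError case
print(measurable_op.ndim_supp)              # 0: available through the MeasurableClass path
print(isinstance(measurable_op, Elemwise))  # True: still usable wherever an Elemwise is expected
```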
I have made the attributes part of the Op through the MeasurableClass path by including them in `__init__`, as done in the first commit, which, I believe, is the approach that multiple inheritance follows. The attributes are accessible through an object of the MeasurableClass and not the Op class. However, I am still confused about how to proceed with default values. As mentioned earlier, `node.op` is derived from the MeasurableClass itself and not the Op. Hence, in the case of MeasurableElemwise and subclasses derived from it, `scalar_op`, being an attribute of the op itself, gives no error for `node.op.scalar_op`, but `node.op.ndim_supp` does.
continued in #6754
What is this PR about?
This PR aims to solve issue #6360 by incorporating RV meta-information in intermediate MeasurableVariables. The Measurable ops covered are MeasurableComparison, MeasurableClip, MeasurableRound, MeasurableSpecifyShape, MeasurableCheckAndRaise, MeasurableIfElse, MeasurableScan, MeasurableMakeVector, MeasurableJoin, MeasurableDimShuffle, MeasurableTransforms, and DiracDelta.
Checklist
Major / Breaking Changes
New features
Bugfixes
Documentation
Maintenance
📚 Documentation preview 📚: https://pymc--6685.org.readthedocs.build/en/6685/