Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Passed meta-information for MeasurableOps #6685

Closed

Conversation

Dhruvanshu-Joshi
Copy link
Member

@Dhruvanshu-Joshi Dhruvanshu-Joshi commented Apr 22, 2023

What is this PR about?
This PR aims to solve issue 6360 by incorporating RV meta information in intermediate MeasurableVariables. The Measurable ops covered are MeasurableComparison, MeasurableClip, MeasurableRound, MeasurableSpecifyShape, MeasurableCheckAndRaise, MeasurableIfElse, MeasurableScan, MeasurableMakeVector, MeasurableJoin, MeasurableDimShuffle, MeasurableTransforms and DiracDelta.

Checklist

Major / Breaking Changes

  • ...

New features

  • ...

Bugfixes

  • ...

Documentation

  • ...

Maintenance

  • ...

📚 Documentation preview 📚: https://pymc--6685.org.readthedocs.build/en/6685/

@codecov
Copy link

codecov bot commented Apr 22, 2023

Codecov Report

Merging #6685 (f2f6a55) into main (55d915c) will decrease coverage by 0.08%.
The diff coverage is 38.09%.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #6685      +/-   ##
==========================================
- Coverage   91.99%   91.92%   -0.08%     
==========================================
  Files          95       95              
  Lines       16016    16036      +20     
==========================================
+ Hits        14734    14741       +7     
- Misses       1282     1295      +13     
Impacted Files Coverage Δ
pymc/logprob/abstract.py 86.53% <38.09%> (-12.28%) ⬇️

Comment on lines 41 to 45
def __init__(self, scalar, ndim_supp=0, support_axis=None, d_type="mixed"):
super().__init__(scalar)
self.ndim_supp = ndim_supp
self.support_axis = support_axis
self.d_type = d_type
Copy link
Member

@ricardoV94 ricardoV94 Apr 25, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of adding these attributes to each subclass, we can add it to the base MeasurableVariable class. Also we shouldn't give defaults, unless we are sure they always apply. Since most rewrites work with both discrete and continuous RVs, that's not the case.

The rewrites that introduce the MeasurableOps should always define them. We can have a helper that works for the default cases where ndim_supp, support_axes and measure_type are the same as the base RV class.

Also, for measure_dtype let's use Enum to avoid typing bugs.

It might look something like:

from enum import Enum, auto

class MeasureType(Enum):
  Discrete = auto()
  Continuous = auto()
  Mixed = auto()


def get_default_measurable_metainfo(base_op: Op) -> Tuple[int, Tuple[int], MeasureType]:
  if not isinstance(base_op, MeasurableVariable):
      raise TypeError("base_op must be a RandomVariable or MeasurableVariable")
  
  ndim_supp = base_op.ndim_supp

  supp_axes = getattr(base_op, "supp_axes", None)
  if supp_axes is None:
    supp_axes = tuple(range(-base_op.ndim_supp, 0))

  measure_type = getattr(base_op, "measure_type", None):
  if measure_type is None:
    measure_type = MeasureType.Discrete if base_op.dtype.startswith("int") else MeasureType.Continuous
  
  return ndim_supp, supp_axes, measure_type

@@ -131,6 +131,11 @@ def _icdf_helper(rv, value, **kwargs):
class MeasurableVariable(abc.ABC):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh we are not subclassing from it directly so its not being used (almost all of our logprob rewrites should have failed).

We usually subclass from the Ops... so maybe we should add a class decorator to add these properties to init and pass the remaining ones to the original Op class when it's initialized.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So from what I understand, we should add meta-information as parameters to the init using decorators for the pytensor ops which we will be using namely: Elemwise, CumOp, CheckAndRaise, SpecifyShape, IfElse, Scan , Join, MakeVector and Dimshuffle. We can then use the get_default_measurable_metainfo on the node.op which is an object of the pytensor op in the find_measurable_operation function to get the default values wherever required. Am I correct?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The decorator would be applied on the Measurable subclasses of those Ops, e.g.

@measurable_op
class MeasurableElemwise(Elemwise):
  pass

So that we can kind of subclass without multiple inheritance. The decorator would override __init__ to capture the new parameters we need, and call the original __init__ with the rest.

Copy link
Member

@ricardoV94 ricardoV94 Apr 28, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But maybe Multiple Inheritance is the best here. We might need to think/try both to figure it out.

EDIT: I am more convinced of this one now :/

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried the decorator method and it works fine. There is one problem that whenever the MeasurableOp class is called in the find_measurable_op method, in many cases(almost all), it requires some default values for all the parameters including the metainfo parameters and the original init parameters. The get_default_measurable_metainfo does not work on node.op as node.op is an object of the Op and not MeasurableOp where the parameters are added.
For eg:

def find_measurable_comparisons(
fgraph: FunctionGraph, node: Node
) -> Optional[List[MeasurableComparison]]:
rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
if rv_map_feature is None:
return None # pragma: no cover
if isinstance(node.op, MeasurableComparison):
return None # pragma: no cover
measurable_inputs = [
(inp, idx)
for idx, inp in enumerate(node.inputs)
if inp.owner
and isinstance(inp.owner.op, MeasurableVariable)
and inp not in rv_map_feature.rv_values
]
if len(measurable_inputs) != 1:
return None
# Make the measurable base_var always be the first input to the MeasurableComparison node
base_var: TensorVariable = measurable_inputs[0][0]
# Check that the other input is not potentially measurable, in which case this rewrite
# would be invalid
const = tuple(inp for inp in node.inputs if inp is not base_var)
# check for potential measurability of const
if not check_potential_measurability(const, rv_map_feature):
return None
const = const[0]
# Make base_var unmeasurable
unmeasurable_base_var = ignore_logprob(base_var)
node_scalar_op = node.op.scalar_op
# Change the Op if the base_var is the second input in node.inputs. e.g. pt.lt(const, dist) -> pt.gt(dist, const)
if measurable_inputs[0][1] == 1:
if isinstance(node_scalar_op, LT):
node_scalar_op = GT()
elif isinstance(node_scalar_op, GT):
node_scalar_op = LT()
elif isinstance(node_scalar_op, GE):
node_scalar_op = LE()
elif isinstance(node_scalar_op, LE):
node_scalar_op = GE()
compared_op = MeasurableComparison(node_scalar_op)
compared_rv = compared_op.make_node(unmeasurable_base_var, const).default_output()
compared_rv.name = node.outputs[0].name
return [compared_rv]

Replacing line 93 and 94 as

ndim_supp, supp_axes, measure_type = get_default_measurable_metainfo(node.op)
compared_op = MeasurableComparison(ndim_supp, supp_axes, measure_type, node_scalar_op)

produces an attribute error

AttributeError: 'Elemwise' object has no attribute 'ndim_supp'

I guess multiple Inheritance will also lead to this error as the metainfo are passed to our MeasurableOps and not the original pytensor Ops.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that's exactly why multiple inheritance would work better, as those attributes would now be part of the Op through the MeasurableClass path.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have made attributes part of the Op through the MeasurableClass path by including them in the init a done according to the first commit which , I believe is the approach that multiple inheritance follows. The attributes are accessible through an object of the MeasurableClass and not the Op class. However, I am still confused on how to proceed with default values. As mentioned earlier, node.op is derived from the MeasurableClass itself and not the Op. Hence in case of MeasurableElemwise and sub-class which derive from it, scalar_op being an attribute of the op itself gives no error for node.op.scalar_op but node.op.ndim_supp does.

@Dhruvanshu-Joshi
Copy link
Member Author

continued in #6754

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants