Adopt singleton semantics for globally defined constraint instances #1507

Merged: 12 commits into pyro-ppl:master, Jan 22, 2023

Conversation

pierreglaser
Contributor

Closes #1378.

@fritzo
Member

fritzo commented Dec 5, 2022

Doh, thanks for fixing this issue! In retrospect we probably should have avoided the singletons and required the slightly less clever notation constraints.Real() 😅

LGTM, but I defer to @fehiepsi

@fritzo
Member

fritzo commented Dec 5, 2022

On second thought, I think it would be cleaner to handle this in general (rather than merely for pickling) by following the singleton pattern:

class _Real(Constraint):
    def __new__(cls):
        if not hasattr(cls, "_instance"):
            cls._instance = super().__new__(cls)
        return cls._instance
    ...

That way _Real() is _Real() even outside of pickling. WDYT?

@pierreglaser
Contributor Author

I don't have a strong opinion. _Real is private anyway, so it should be neither accessed nor instantiated outside of the constraints module; only its singleton instance, real, should be. But as I said, no strong opinion.

@fritzo
Member

fritzo commented Dec 6, 2022

I guess another advantage of .__new__() over .__reduce__() is that it works with copy.copy:

constraints.real is copy.copy(constraints.real)

And thanks for bringing up this issue! It's also a problem with torch.distributions.constraints and we've never thought it through fully. In PyTorch we tried to make the _Real class private but devs ended up using it in type annotations and unit tests anyway, ignoring the intentionally public type(real) 🤷

@pierreglaser
Contributor Author

pierreglaser commented Dec 6, 2022

In PyTorch we tried to make the _Real class private but devs ended up using it in type annotations and unit tests anyway

If experience shows that private contracts are not respected in practice, then let's go with __new__ instead of __reduce__.

I guess another advantage of .__new__() over .__reduce__() is that it works with copy.copy:

For the record: copy.copy actually __reduces__ (and then reconstructs using the returned __reduce__ tuple) an object in order to copy it.

Overriding __new__ as suggested would preserve identity semantics under copy operations precisely because the default object.__reduce__ generates a NEWOBJ pickle OPCODE, which triggers a call to _Real.__new__ during reconstruction :-)
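
The mechanism described above can be checked directly. This is a sketch with stand-in classes (`_Real` and `_PlainReal` are defined here for illustration, not imported from NumPyro). It inspects what the default reduction returns for protocol 4, which is the protocol `copy.copy` requests in CPython:

```python
import copy
import copyreg


class _Real:
    def __new__(cls):
        if "_instance" not in cls.__dict__:
            cls._instance = super().__new__(cls)
        return cls._instance


class _PlainReal:
    """Same idea without the singleton __new__, for comparison."""


real = _Real()
plain = _PlainReal()

# For protocol 2 and above, the default reduction starts with
# copyreg.__newobj__, whose job is to call cls.__new__(cls).
constructor, args = real.__reduce_ex__(4)[:2]
uses_newobj = constructor is copyreg.__newobj__ and args == (_Real,)

# __new__ hands back the cached instance, so the "copy" is the original.
singleton_copy_is_same = copy.copy(real) is real
# Without the override, object.__new__ builds a fresh object.
plain_copy_is_same = copy.copy(plain) is plain
```

Here `singleton_copy_is_same` is true and `plain_copy_is_same` is false, which is the difference between the two approaches under `copy.copy`.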

It's also a problem with torch.distributions.constraints and we've never thought it through fully

Interesting. I mostly work with jax now, hence my PR here.

@fritzo
Member

fritzo commented Dec 6, 2022

copy.copy actually __reduces__

Oh thanks for explaining, I didn't realize that!

@fehiepsi
Member

fehiepsi commented Dec 6, 2022

Thanks @pierreglaser (and @fritzo for chiming in)! Could you also add this logic or the __new__ one for other constraints?

@pierreglaser
Contributor Author

Done.

fehiepsi
fehiepsi previously approved these changes Dec 19, 2022
Member

@fehiepsi fehiepsi left a comment

LGTM! Thanks @pierreglaser!

@pierreglaser
Contributor Author

Actually, the current logic provides no way to treat as singletons those constraints that are instances of parametrized classes and are exposed as top-level variables of the constraints module, such as:

positive = _GreaterThan(0.0)

  1. Treating them as complete singleton instances would require introspecting the arguments of the constructor and using a lookup-table mechanism to return instances already created for a given set of arguments.
  2. Preserving identity semantics during pickling (which is ultimately what this PR was for) can be done by adding a lookup table of all global variables of the constraints module, and pickling the global variables by reference during __reduce__. Note that this logic does not require any singleton logic, such as the __new__ method introduced in previous commits.
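
A rough sketch of the lookup-table idea for pickling by reference, not the code from this PR. The names `_TOPLEVEL`, `_TOPLEVEL_NAMES`, and `_get_toplevel` are hypothetical, and `_GreaterThan` is a simplified stand-in. The check uses `copy.copy`, which goes through the same `__reduce__` protocol as pickling:

```python
import copy


class _GreaterThan:
    def __init__(self, lower_bound):
        self.lower_bound = lower_bound

    def __reduce__(self):
        # Module-level instances are reduced to a lookup by name, so
        # reconstruction returns the existing object.
        name = _TOPLEVEL_NAMES.get(id(self))
        if name is not None:
            return _get_toplevel, (name,)
        # Any other instance is rebuilt from its constructor argument.
        return type(self), (self.lower_bound,)


def _get_toplevel(name):
    return _TOPLEVEL[name]


positive = _GreaterThan(0.0)

# Hypothetical registry of the module's global constraint instances.
_TOPLEVEL = {"positive": positive}
_TOPLEVEL_NAMES = {id(obj): name for name, obj in _TOPLEVEL.items()}

other = _GreaterThan(0.0)
other_copy = copy.copy(other)        # a new, equivalent object
positive_copy = copy.copy(positive)  # the registered object itself
```

Only the registered instance keeps its identity; `other` has the same bound but is reconstructed as a fresh object, and `_GreaterThan(0.0) is positive` stays false.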

Ultimately, I think addressing 1 (treating constraints as singletons when applicable) is more complex, and to be fair, out of scope for this PR, which originally addressed 2 (preserving identity semantics during pickling). So I'd rather roll back any changes made to treat Constraint objects as singletons, and simply address 2 in this PR.

WDYT @fehiepsi @fritzo ?

@pierreglaser
Contributor Author

(By the way, I can't parse the lint failure from the CI, any help would be appreciated)

@fritzo
Member

fritzo commented Dec 21, 2022

no way to treat as singleton instances the constraints coming from parametrized classes

I believe that's fine. Here's my reasoning:

  1. The biggest pain point I've seen with pickling constraints is when unpickling non-parametrized constraints like constraints.real causing errors in code that dispatches on if constraint is constraints.real. I've only ever seen constraint is used with non-parametrized classes; all parametrized classes dispatch via isinstance or something like singledispatch which examine the type rather than the identity of a constraint object.
  2. At least in Pyro (unsure about NumPyro), we try to avoid data-dependent control flow. In this situation of pickling constraints, 'data' means the parameters of a constraint, and 'control flow' is something that depends/dispatches on the identity of a constraint object rather than its type. Thus it seems fine to deduplicate/singletonize only the non-parametrized classes.
  3. More generally, we've tried to avoid data-dependent memoization, which I believe is what would be required to deduplicate parametrized types. This is more important in PyTorch where data is mutable, but it may also be a good habit in JAX even if data is immutable.

Thanks for thinking hard about this issue!

@pierreglaser
Contributor Author

pierreglaser commented Dec 22, 2022

I agree that, say, for any instance of the parametrized _GreaterThan other than positive, it makes sense to use isinstance and introspection checks. But for the special case of numpyro.distributions.constraints.positive (also an instance of _GreaterThan), I would have imagined that the semantics is exactly the same as for other top-level constraints like numpyro.distributions.constraints.real, e.g. if constraint is constraints.positive. Can you confirm your opinion on that point?

Member

@fehiepsi fehiepsi left a comment

Though it's convenient to have constraint is constraints.positive, I feel that it is a better practice to use the check isinstance(constraint, constraints.greater_than) for a couple of reasons:

I'm not sure how to tell users about the differences between real and greater_than. Probably we can add an attribute named is_singleton which is False by default.

I did several greps and found that we didn't follow the expected behavior at those places:

All of these were for efficiency. I'll remove them in a separate PR.

@fritzo
Member

fritzo commented Dec 27, 2022

Can you confirm your opinion on that point [whether constraints.positive should behave like constraints.real]?

I do see the parsimony of constraint is constraints.positive but I agree with @fehiepsi that it's probably safer to use isinstance here in practice. Practically, I think we could get your idea working for python-only data (no arrays) using a dependent-singleton pattern, something like this:

from weakref import WeakValueDictionary

class _DependentSingleton:
    def __new__(cls, *args):  # TODO convert **kwargs to *args
        # only cache simple python things
        if not all(isinstance(arg, (int, float)) for arg in args):
            return super().__new__(cls)

        # weakly memoize, one cache per class (cls.__dict__ itself is read-only)
        cache = cls.__dict__.get("_instance_cache")
        if cache is None:
            cache = cls._instance_cache = WeakValueDictionary()
        instance = cache.get(args)
        if instance is None:
            instance = super().__new__(cls)
            cache[args] = instance
        return instance

but that's a bit complex and seems out of scope for your nice clean first PR.
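
For reference, a runnable sketch of how such a dependent-singleton base could behave in use. The `_GreaterThan` subclass is a simplified stand-in, and the per-class cache lookup is one possible way to write the memoization, not code from this PR:

```python
from weakref import WeakValueDictionary


class _DependentSingleton:
    def __new__(cls, *args):
        # Only cache plain Python numbers; anything else gets a fresh object.
        if not all(isinstance(arg, (int, float)) for arg in args):
            return super().__new__(cls)
        # One weak cache per class, created on first use.
        cache = cls.__dict__.get("_instance_cache")
        if cache is None:
            cache = cls._instance_cache = WeakValueDictionary()
        instance = cache.get(args)
        if instance is None:
            instance = super().__new__(cls)
            cache[args] = instance
        return instance


class _GreaterThan(_DependentSingleton):
    def __init__(self, lower_bound):
        self.lower_bound = lower_bound


positive = _GreaterThan(0.0)
also_positive = _GreaterThan(0.0)      # same arguments: cached instance
greater_than_one = _GreaterThan(1.0)   # different arguments: new instance
from_list = _GreaterThan([0.0])        # non-numeric argument: never cached
from_list_again = _GreaterThan([0.0])
```

One caveat with keying on the raw arguments: `0` and `0.0` hash and compare equal in Python, so they would share a cache entry.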

@fehiepsi
Member

Alternatively, I guess we can also subclass those constraints, like

class _Positive(_GreaterThan):
    def __init__(self):
        super().__init__(0.)

positive = _Positive()

@pierreglaser
Contributor Author

Though it's convenient to have constraint is constraints.positive, I feel that it is a better practice to use the check isinstance(constraint, constraints.greater_than) for a couple of reasons

As such, this isinstance check will capture any greater_than instance, and not only positive-like instances. A full-blown positive check would require parameter introspection, and becomes much more verbose than an identity check.

_TOPLEVEL_CONSTRAINTS is not exposed to users, so the semantic for user-defined constraints will be different from the "toplevel" ones

IMO, defining equality/identity semantics for user-defined constraints should be left to the user defining them. To me, this sounds more understandable and uncontroversial than explaining that the two built-in constraints.real and constraints.positive, objects of seemingly similar nature, should be treated differently for implementation-detail reasons.

Alternatively, I guess we can also subclass those constraints, like

class _Positive(_GreaterThan):
    def __init__(self):
        super().__init__(0.)

positive = _Positive()

I like this solution a lot @fehiepsi :) it is more elegant than a global lookup table, and preserves identity semantics for top-level singleton objects, which I find more natural and concise than instance+attribute checking. If that sounds good to you, I'd happily implement this + singleton-style __new__/__reduce__ for all parametrized and non-parametrized singletons, after which the PR should be good to merge, cc @fritzo.

Long term, my suggestion would maybe be to expose a check_constraint_equals utility that unifies equality/identity checks for both singleton-style constraints and non-singleton-style constraints, which would work for all builtin constraints, and default to an __eq__ check for user-defined ones. This change could be accompanied by a paragraph in the docs about it, indicating that users should define __eq__ for user-defined constraint subclasses in order for their instances to work with check_constraint_equals.

@pierreglaser
Contributor Author

Regarding the (still open) pytorch/pytorch#50616, I don't believe that numpyro (and pyro, provided that its constraint module is structured as in numpyro) should expose singleton classes. constraint.real is in essence a singleton object and different from parametrized interval instances. It is logical that they should be treated differently in equality checks; the (limited) increase in code compactness resulting from exposing _Real instead of _Real() cannot outweigh that point.

@fehiepsi
Member

A full-blown positive check would require parameter introspection, and becomes much more verbose than an identity check.

Agreed! I tried to add some logic to check for constraints.positive/unit_interval in #1519 but eventually gave up due to the complexity (e.g. we need to check if lower_bound is int/float, then compare the value to 0).

I'd happily implement this + singleton-style __new__/__reduce__ for all parametrized and non-parametrized singletons, after which the PR should be good to merge

LGTM. It would be great to have this in the 0.11.0 release. (Not important: if possible, could you revert the checks for positive and unit_interval in #1519 and add a comment that the check is possible after this PR? I don't have a strong opinion regarding this, but given that there are many positive-support distributions, we can gain some speed.)

@pierreglaser pierreglaser changed the title from "pickle constraints.real by reference" to "Adopt singleton semantics for globally defined constraint instances" Jan 3, 2023
@pierreglaser
Contributor Author

@fehiepsi @fritzo I implemented the solution agreed upon above. I think the PR is ready for another round of review.

fritzo
fritzo previously approved these changes Jan 3, 2023
Member

@fritzo fritzo left a comment

Looks great, thanks for making this module much more consistent!

@pierreglaser
Contributor Author

@fritzo @fehiepsi thanks for the review. all comments should have been addressed now.

fehiepsi
fehiepsi previously approved these changes Jan 4, 2023
Member

@fehiepsi fehiepsi left a comment

LGTM pending lint issue.

@pierreglaser
Contributor Author

Addressed the linting issue. Some test related to funsor is failing though. Can't look at it now, but will investigate later.

@fehiepsi
Member

fehiepsi commented Jan 9, 2023

Hi @pierreglaser, sorry, I'll take a closer look. It seems that in funsor, we use the support's name for some logic. Maybe it's easiest to let me add a patch and create an issue there.

@fehiepsi fehiepsi merged commit c46b0db into pyro-ppl:master Jan 22, 2023
Development

Successfully merging this pull request may close these issues.

Constraints object are not robust to pickling
3 participants