anijain2305
Contributor

@anijain2305 anijain2305 commented Jun 12, 2024

Stack from ghstack (oldest at bottom):

Fixes #125720

I was initially worried that DELETE_* or STORE_* on a referent should cause a graph break, because they could invalidate the weakref. But @zou3519 pointed out that weakref invalidation only happens eventually: CPython provides no guarantee about when the weakref is invalidated, even when the user calls del x and x is the last reference.

So any code that relies on del x invalidating the weakref to x right away is bad code. We can therefore (ab)use this nuance and simply ignore DELETE_* and STORE_* on referent objects.

The only corner case is when Dynamo has to reconstruct the weakref object. Dynamo would have a hard time being correct there, so we just SKIP_FRAME in that case. This is rare.

CPython notes

  1. https://docs.python.org/3/library/weakref.html
  2. https://docs.python.org/3/reference/datamodel.html#index-2
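
A small, torch-free sketch of the timing point made in the description: when the referent is part of a reference cycle, del does not clear the weakref, which is only cleared once the cyclic garbage collector runs. The names `Node` and `cycle_demo` are hypothetical and exist only for this illustration; they are not part of the PR.

```python
import gc
import weakref


class Node:
    """Hypothetical object that keeps itself alive through a self-reference."""

    def __init__(self):
        self.me = self  # reference cycle: refcount never reaches zero on its own


def cycle_demo():
    was_enabled = gc.isenabled()
    gc.disable()  # keep the automatic collector from running mid-demo
    try:
        obj = Node()
        ref = weakref.ref(obj)
        del obj  # drops the only named reference, but the cycle remains
        alive_after_del = ref() is not None
        gc.collect()  # the cyclic collector finally frees the object
        alive_after_collect = ref() is not None
    finally:
        if was_enabled:
            gc.enable()
    return alive_after_del, alive_after_collect


print(cycle_demo())
```

So code that expects a weakref to be dead immediately after del is already depending on implementation details of the interpreter.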

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang

pytorch-bot bot commented Jun 12, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/128533

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (4 Unrelated Failures)

As of commit 63529e8 with merge base be0eec9:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

void install_no_tensor_aliasing_guard(
    const py::list& guard_managers,
-   py::list tensor_names,
+   const py::list& tensor_names,
Contributor Author

Linter error - unrelated to the PR

Member

@williamwen42 williamwen42 left a comment

Add a test where the weakref'd object is deleted? e.g.

x = Obj()
ref = weakref.ref(x)
opt_fn(ref, inp)
del x
opt_fn(ref, inp)

@anijain2305
Contributor Author

@williamwen42 added.

@anijain2305 anijain2305 added the keep-going Don't stop on first failure, keep running tests until the end label Jun 12, 2024
@anijain2305 anijain2305 requested a review from williamwen42 June 12, 2024 22:07
@williamwen42
Member

Is there an assumption that the result of calling a weakref (i.e. ref()) won't change during tracing or in the middle of running a dynamo-optimized function?

@anijain2305
Contributor Author

@williamwen42 Nice catch! Yeah, the test below fails.

Let me think about how to handle it.

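import weakref

import torch


def fn(y):
    x = torch.randn(4)
    x_weak = weakref.ref(x)
    # x is still alive here, so the weakref resolves and the sin branch runs.
    if x_weak is not None and x_weak() is not None:
        z = torch.sin(y)
    else:
        z = torch.cos(y)

    # Drop the last strong reference. In eager mode on CPython this typically
    # clears the weakref right away, so the second check takes the cos branch.
    del x
    if x_weak is not None and x_weak() is not None:
        z = torch.sin(y)
    else:
        z = torch.cos(y)
    return z


y = torch.ones(4)
ref = fn(y)  # eager reference result

opt_fn = torch.compile(fn, backend="eager")
res = opt_fn(y)  # compiled result, compared against ref
print(ref, res)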

@williamwen42
Member

williamwen42 commented Jun 12, 2024

I tried thinking about how to support weakrefs a few months ago and ran into the problem of determining when a weakref gets invalidated. It may be that weakrefs can be invalidated at any time, perhaps by a different thread or by the garbage collector. If we do manage to detect when weakrefs are invalidated, do we also need to support tracing the finalizer/callback?
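
For context on the finalizer/callback question, here is a minimal plain-Python sketch of the two standard hooks that fire when a referent dies: the callback argument of weakref.ref and weakref.finalize. `Obj` and `callback_demo` are hypothetical names for illustration, and nothing here reflects how Dynamo handles these hooks.

```python
import gc
import weakref


class Obj:
    """Hypothetical referent; plain instances support weak references."""


def callback_demo():
    events = []
    obj = Obj()
    # The callback receives the (now dead) weakref object as its only argument.
    ref = weakref.ref(obj, lambda dead_ref: events.append("callback"))
    # finalize() registers a function plus arguments to call when obj dies.
    weakref.finalize(obj, events.append, "finalize")
    del obj
    gc.collect()  # not needed for this case on CPython, but harmless elsewhere
    return ref() is None, events


dead, events = callback_demo()
print(dead, sorted(events))
```

Both hooks run arbitrary user code at a point the tracer does not control, which is what makes tracing them an open question.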

@anijain2305
Contributor Author

Lol ... this is tricky .. what was I thinking .. converting to draft

@mlazos
Contributor

mlazos commented Jun 13, 2024

From chatting offline I think it should be possible to support just getting the value out - that's most of the use cases I've seen.

Tracking the lifetime of the underlying object is much harder and I don't think it's needed.
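
A rough sketch of what "just getting the value out" looks like in user code: the weakref is dereferenced once, and the result is held as an ordinary strong reference for the rest of the call. `Config`, `scaled` and `deref_demo` are hypothetical names, and the thread does not say that Dynamo traces exactly this pattern.

```python
import gc
import weakref


class Config:
    """Hypothetical referent carrying a single setting."""

    def __init__(self, scale):
        self.scale = scale


def scaled(ref, value):
    cfg = ref()  # dereference once; cfg is a strong reference from here on
    if cfg is None:
        return value  # referent is gone, fall back to the unscaled value
    return value * cfg.scale


def deref_demo():
    cfg = Config(3)
    ref = weakref.ref(cfg)
    live = scaled(ref, 2)  # referent alive: value is scaled
    del cfg
    gc.collect()  # make the referent's death explicit on any implementation
    dead = scaled(ref, 2)  # referent gone: fallback path
    return live, dead


print(deref_demo())
```

Because the function only reads the referent through a local strong reference, it never depends on when the weakref is invalidated during the call.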

@anijain2305
Contributor Author

anijain2305 commented Jun 14, 2024

@jansel @williamwen42 @mlazos

I was worried that DELETE_* or STORE_* on a referent should cause a graph break, because they could invalidate the weakref. But @zou3519 pointed out that weakref invalidation only happens eventually: CPython provides no guarantee that the weakref is invalidated right away after del on the referent, even when the user calls del x and x is the last reference.

So any code that relies on del x invalidating the weakref to x right away is bad code. We can therefore (ab)use this nuance and simply ignore DELETE_* and STORE_* on referent objects.

The only corner case is when Dynamo has to reconstruct the weakref object. Dynamo would have a hard time being correct there, so we just SKIP_FRAME in that case. This is rare.

@anijain2305 anijain2305 requested a review from jansel June 14, 2024 18:47
anijain2305 added a commit that referenced this pull request Jun 14, 2024
ghstack-source-id: 0caa63e
Pull Request resolved: #128533
@anijain2305 anijain2305 added the topic: not user facing topic category label Jun 14, 2024
@anijain2305
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jun 14, 2024
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).


@github-actions github-actions bot deleted the gh/anijain2305/377/head branch July 16, 2024 01:56
Labels

ciflow/inductor ciflow/trunk Trigger trunk jobs on your pull request keep-going Don't stop on first failure, keep running tests until the end Merged module: dynamo topic: not user facing topic category
