
Conversation

@voznesenskym (Collaborator) commented Aug 1, 2023

import torch
import numpy as np


def _inf_nan_preprocess(t, t_np):
    t_np = np.nan_to_num(t_np)
    return t, t_np


@torch.compile()
def fn():
    # shapes to test (each case is a full shape tuple)
    test_cases = (
        (3, 3),
        (4, 4),
        (5, 5),
    )

    for shape in test_cases:
        t = torch.randn(shape, dtype=torch.complex64)
        t_np = np.random.randn(*shape).astype(np.complex64)

        _, t_np = _inf_nan_preprocess(t, t_np)
        print(t, t_np)  # Just a side effect so that compilation kicks in


fn()

cc @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @ngimel @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @anijain2305

@pytorch-bot bot commented Aug 1, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/106431

Note: Links to docs will display an error until the docs builds have been completed.

⏳ 2 Pending, 3 Unrelated Failures

As of commit f38030f:

BROKEN TRUNK - The following jobs failed but were present on the merge base 3db2550:

👉 Rebase onto the `viable/strict` branch to avoid these failures

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

 # lose optimization opportunities this way. Devs, if your benchmark model is failing
 # this way, you should figure out why instead of suppressing it.
-suppress_errors = os.environ.get("TORCHDYNAMO_SUPPRESS_ERRORS", "1") == "1"
+suppress_errors = False
@voznesenskym (Author): @lezcano this is intentional, this flag makes it hard to debug.

 self._produce_guard_code(guard, [shape_guard], shape_env=True)

-def TENSOR_MATCH(self, guard: Guard):
+def TENSOR_MATCH(self, guard: Guard, value=None):
@voznesenskym (Author):

If .get() invokes a function, as with the NumpyTensorSource added in this PR, we do not have id stability between the example values that .get() returns and the values traced at fake-ification time. Id stability is required because of how we build the sizes/strides for tensor guards.
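A minimal sketch of the id-stability problem (the .get() behavior is assumed from this comment, not taken from the actual NumpyTensorSource code): a getter that manufactures its result through a function call returns a fresh object each time, so the id recorded at guard-building time differs from the id seen at fake-ification time.

import numpy as np
import torch

arr = np.ones((3, 3))

def numpy_tensor_source_get():
    # Stand-in for a NumpyTensorSource-style .get(): it converts the
    # ndarray on every call instead of returning a cached object.
    return torch.as_tensor(arr)

a = numpy_tensor_source_get()  # e.g. the example value at guard time
b = numpy_tensor_source_get()  # e.g. the value seen at fake-ification
assert a is not b              # same data, but no id stability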

Contributor:

I think I really need a comment here explaining what the value arg does. Still not sure from the comment either, need to read more.

@voznesenskym (Author): okay, will add a comment.

@lezcano (Collaborator) left a comment:

Note that, with this PR in its current state, a function that has only NumPy inputs and NumPy outputs is still not compiled. A small modification of the example in the OP shows this; see the sketch below.

OTOH, such functions seem to get correct guards, and their shapes are correctly traced with symints when dynamic shapes kick in.
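A minimal sketch of such a modification (adapted from the OP's repro; the exact variant lezcano tried is not shown in the thread): drop the tensor argument so the helper is NumPy-in, NumPy-out.

import numpy as np
import torch

def _inf_nan_preprocess(t_np):
    # NumPy-only helper: no tensor argument anywhere.
    return np.nan_to_num(t_np)

@torch.compile()
def fn():
    t_np = np.random.randn(3, 3).astype(np.complex64)
    t_np = _inf_nan_preprocess(t_np)
    print(t_np)  # side effect so compilation kicks in

fn()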

@linux-foundation-easycla bot commented Aug 2, 2023:

CLA Signed. The committers listed above are authorized under a signed CLA.

@voznesenskym force-pushed the voz/torch_np branch 2 times, most recently from 1a8a61f to 49c975a on August 2, 2023 at 20:18
@voznesenskym (Author):

@lezcano I undid your mega merge; this PR is back to being based on the head of your branch. If you need work there, push it to that branch and rebase this branch; don't push to this branch please :P


 def add_graph_output(self, value):
-    graph_outputs_key = id(value.proxy)
+    graph_outputs_key = id(value.as_proxy())
Contributor:

does this... do anything lol

@voznesenskym (Author):

It's the better convention, but no, it should be the same.
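A toy illustration of that convention (simplified, not dynamo's actual class): routing through an as_proxy() accessor leaves room for subclasses to customize how the proxy is produced, even though the base implementation returns the raw attribute unchanged.

class VariableSketch:
    def __init__(self, proxy):
        self.proxy = proxy

    def as_proxy(self):
        # Subclasses may unwrap or transform here; the base returns it as-is.
        return self.proxy

v = VariableSketch(proxy=object())
assert v.as_proxy() is v.proxy  # identical today, hence "should be the same"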

-def TENSOR_MATCH(self, guard: Guard):
+def TENSOR_MATCH(self, guard: Guard, value=None):
     if guard.is_nn_module():
         self.ID_MATCH(guard)
Contributor:

numpy ndarray on nn module 🤔 hilarious failure mode

@voznesenskym (Author) commented Aug 5, 2023:

This will not fail, but it will overspecialize. HOWEVER, we throw out nn module guards, so this will just be okay for now.
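A toy contrast (assumed semantics, not dynamo's real guard code) of why falling back to ID_MATCH overspecializes: an id guard accepts only the exact object it was built against, while a TENSOR_MATCH-style guard accepts any value with matching properties.

import torch

t = torch.randn(3, 3)

def id_match(x, ref_id=id(t)):
    # Passes only for the exact original object; an equivalent new
    # object still fails the guard and forces a recompile.
    return id(x) == ref_id

def tensor_match(x, ref=t):
    # Passes for any tensor with matching shape/dtype, so compiled
    # code is reused across objects.
    return isinstance(x, torch.Tensor) and x.shape == ref.shape and x.dtype == ref.dtype

u = torch.randn(3, 3)
assert id_match(t) and not id_match(u)      # id guard: recompile for u
assert tensor_match(t) and tensor_match(u)  # tensor guard: reused for u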

     **kwargs,
 ):
     super().__init__(proxy, **kwargs)
+    self.proxy = proxy
Contributor:

What's going on here? Are there two proxies floating around now?

@voznesenskym (Author):

We probably don't need this line; super().__init__ already does that.
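A minimal illustration of the redundancy (class names hypothetical): when the base initializer already stores the proxy, the subclass's re-assignment is a no-op.

class Base:
    def __init__(self, proxy, **kwargs):
        self.proxy = proxy

class Derived(Base):
    def __init__(self, proxy, **kwargs):
        super().__init__(proxy, **kwargs)
        self.proxy = proxy  # redundant: Base.__init__ already set it

d = Derived(proxy=object())  # works identically without the extra line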

 )
-options = {"source": source}
+options = {"source": source, "guards": tensor_vt.guards}
 numpy_ndarray_variable = wrap_fx_proxy_cls(
Contributor:

It appears these are the substantive changes.

 example_value=value,
-guards=self.make_guards(GuardBuilder.TENSOR_MATCH),
+guards=self.make_guards(
+    functools.partial(GuardBuilder.TENSOR_MATCH, value=value)
Contributor:

So are you passing value for EVERYBODY, not just ndarray?

@voznesenskym (Author):

Yes. This actually saves us an eval, and it should always be sound.
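A rough sketch of the eval being saved (the guard-builder internals here are hypothetical, inferred from the discussion above): without a bound value, the builder has to re-evaluate the guard's source expression to recover the object; functools.partial bakes the already-known example value in.

import functools
import torch

def tensor_match(name, scope, value=None):
    # Hypothetical stand-in for GuardBuilder.TENSOR_MATCH.
    if value is None:
        value = eval(name, scope)  # the extra eval being avoided
    return (name, tuple(value.shape), value.dtype)

t = torch.randn(3, 3)
scope = {"t": t}

before = tensor_match("t", scope)                             # builder evals "t"
after = functools.partial(tensor_match, value=t)("t", scope)  # eval skipped
assert before == after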

@voznesenskym (Author):

> The main PR has all the tests passing minus the one this one is fixing. Once this one is ready to be merged I think the CI in the main PR should be fully green.

Sounds good, will get this over.

@voznesenskym voznesenskym requested review from ezyang and lezcano August 6, 2023 05:04
@voznesenskym changed the title from "[WIP] Fix guarding issues w/ numpy" to "Fix guarding issues w/ numpy" on Aug 6, 2023
Comment on lines +1648 to +1649

+if isinstance(value, np.ndarray):
+    return torch.as_tensor(value)
@lezcano (Collaborator):

Better to remove this one: if our logic is correct, an np.ndarray should never get to this point. We should probably even `assert not isinstance(value, np.ndarray)`.
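The suggestion rendered as a sketch (function name and message hypothetical):

import numpy as np

def wrap_value(value):
    # Instead of silently converting here, assert the invariant that
    # upstream wrapping logic already handled any ndarray.
    assert not isinstance(value, np.ndarray), (
        "np.ndarray should have been wrapped earlier"
    )
    ...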

@voznesenskym (Author):

Let me double-check; I am pretty certain we do hit this, and that is why I added it. Why would you expect us not to get here?

@lezcano (Collaborator) commented Aug 6, 2023:

The issue in #106431 (review) is still present (i.e., if you remove the tensor arg to the `_inf_nan_preprocess` function, we don't compile anything), but I don't think this should block landing the main PR.

@voznesenskym (Author):

> The issue in #106431 (review) is still present (i.e., if you remove the tensor arg to the _inf_nan_preprocess function we don't compile anything) but this shouldn't be blocking to land the main PR, I don't think.

Let me take a look :)

@voznesenskym (Author) commented Aug 7, 2023:

> The issue in #106431 (review) is still present (i.e., if you remove the tensor arg to the _inf_nan_preprocess function we don't compile anything) but this shouldn't be blocking to land the main PR, I don't think.

Let me take a look :)

BTW this would be in another PR; it's not blocking this one. I have not looked yet, but my money is on something that looks for tensor compute in frames. We do this in a few specific places. Will report back.

@lezcano (Collaborator) left a comment:

Yep, another PR SGTM.

On the interaction with torch_np, this PR LGTM. Let's wait for Ed's review on the general strategy, though.

@voznesenskym (Author): @pytorchbot merge

@pytorch-bot bot added the ciflow/trunk label (Trigger trunk jobs on your pull request) on Aug 7, 2023
@pytorchmergebot (Collaborator):

Merge failed

Reason: This PR needs a `release notes:` label.
If your changes are user facing and intended to be a part of release notes, please use a label starting with `release notes:`.
If not, please add the `topic: not user facing` label.

To add a label, you can comment to pytorchbot, for example:
@pytorchbot label "topic: not user facing"

For more information, see https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Details for Dev Infra team: raised by workflow job.

    numpy_to_tensor_wrapper(func),
    *proxy_args_kwargs(args, kwargs),
)
return NumpyNdarrayVariable.create(tx, proxy, **options)
Contributor:

What is the difference before and after? Or is this just refactoring?

@lezcano merged this pull request into torch_np on Aug 7, 2023
@lezcano (Collaborator) commented Aug 7, 2023:

I merged it because @voznesenskym was going to. This has been merged into #106211. If any of the reviewers @yanboliang @anijain2305 @ezyang have any comments, you can put them there and I'll address them.

lezcano pushed a commit that referenced this pull request Aug 7, 2023
lezcano pushed a commit that referenced this pull request Aug 8, 2023
pytorchmergebot pushed a commit that referenced this pull request Aug 11, 2023
RFC: pytorch/rfcs#54
First commit is the contents of https://github.com/Quansight-Labs/numpy_pytorch_interop/

We have already been using this in core for the last few months as an external dependency. This PR pulls all of it into core.

In the next commits, I do a number of things in this order:
- Fix a few small issues
- Make the tests that this PR adds pass
- Bend over backwards until lintrunner passes
- Remove the optional dependency on `torch_np` and simply rely on the upstreamed code
- Fix a number of dynamo tests that were passing before (they were not testing anything, I think) and are not passing now

Missing from this PR (but not blocking):
- Have a flag that deactivates tracing NumPy functions and simply breaks. There used to be one, but it stopped working after the merge and I removed it. @lezcano to investigate.
- #106431 (comment). @voznesenskym to submit a fix after we merge.

All the tests in `tests/torch_np` take about 75s to run.

This was work by @ev-br, @rgommers, @honno, and me. I did not create this PR via ghstack (which would have been convenient), as this is a collaboration and ghstack doesn't allow for shared contributions.

Pull Request resolved: #106211
Approved by: https://github.com/ezyang
@github-actions github-actions bot deleted the voz/torch_np branch February 4, 2025 02:02