
Symbolic shapes unable to reason: Ne(Mod(u0*u2 + u1*u2, u0 + u1), 0) #125307

Closed
ezyang opened this issue May 1, 2024 · 4 comments

Labels: `module: dynamic shapes`, `oncall: pt2`, `triaged`

Comments

ezyang (Contributor) commented May 1, 2024

🐛 Describe the bug

Repro:

```python
import torch
import torch._dynamo

# Let .tolist() produce unbacked SymInts instead of causing a graph break
torch._dynamo.config.capture_scalar_outputs = True


@torch.compile(fullgraph=True)
def f(x):
    a, b, stride = x.tolist()
    # Mark each unbacked value as size-like (asserts it is >= 0)
    torch._check_is_size(a)
    torch._check_is_size(b)
    torch._check_is_size(stride)
    ta = torch.randn(a * stride)
    tb = torch.randn(b * stride)
    r = torch.cat([ta, tb])  # numel is a*stride + b*stride
    return r.view(a + b, stride)


x = torch.tensor([30, 20, 10])
f(x)
```

Fails with

```
torch._dynamo.exc.UserError: Tried to use data-dependent value in the subsequent computation. This can happen when we encounter unbounded dynamic value that is unknown during tracing time.  You will need to explicitly give hint to the compiler. Please take a look at constrain_as_value OR constrain_as_size APIs.  Could not guard on data-dependent expression Ne(Mod(u0*u2 + u1*u2, u0 + u1), 0) (unhinted: Ne(Mod(u0*u2 + u1*u2, u0 + u1), 0)).  (Size-like symbols: u2, u1, u0)
```

Potential framework code culprit (scroll up for full backtrace):

```
  File "/data/users/ezyang/b/pytorch/torch/_refs/__init__.py", line 3691, in _reshape_view_helper
    while guard_size_oblivious(accum % length != 0):
```

For more information, run with TORCH_LOGS="dynamic"

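This is obviously false, though: u0*u2 + u1*u2 = (u0 + u1)*u2, so whenever u0 + u1 is nonzero the remainder modulo u0 + u1 is exactly 0. We just need to be a bit smarter in our reasoning.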
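The claim that Ne(Mod(u0*u2 + u1*u2, u0 + u1), 0) is always false can be sanity-checked with a few lines of plain Python. The sketch below evaluates the remainder over a small grid of nonnegative values, including the repro's inputs 30, 20 and 10. The helper name `remainder` is illustrative only, and a check over a finite grid is evidence, not a proof.

```python
from itertools import product


def remainder(u0: int, u1: int, u2: int) -> int:
    # Dividend and divisor from Ne(Mod(u0*u2 + u1*u2, u0 + u1), 0).
    # Algebraically, u0*u2 + u1*u2 == (u0 + u1) * u2.
    return (u0 * u2 + u1 * u2) % (u0 + u1)


# Size-like values are nonnegative; skip u0 == u1 == 0, where the divisor is 0.
counterexamples = [
    (u0, u1, u2)
    for u0, u1, u2 in product(range(20), repeat=3)
    if u0 + u1 > 0 and remainder(u0, u1, u2) != 0
]

print(counterexamples)        # []
print(remainder(30, 20, 10))  # 0
```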

cc @msaroufim @bdhirsh @anijain2305 @chauhang @lezcano

Versions

main

lezcano (Collaborator) commented May 2, 2024

For this one, we need to implement a reduction rule similar to the one we already have for gcd.
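A minimal sketch of what a gcd-based reduction rule could look like, assuming all symbols are integers and the divisor is nonzero. The function name `mod_is_zero` is hypothetical and this is not PyTorch's actual implementation: it uses SymPy's polynomial `gcd` and concludes that Mod(p, q) is 0 when the gcd of dividend and divisor is the divisor itself, i.e. q divides p with an integer-polynomial quotient.

```python
import sympy


def mod_is_zero(p: sympy.Expr, q: sympy.Expr) -> bool:
    """Hypothetical rule: True when q divides p as a polynomial over the integers.

    Assumes integer-valued symbols and a nonzero divisor. The check is
    conservative: False only means "could not prove it", not "nonzero".
    """
    # If gcd(p, q) == q, then p == q * r for some integer polynomial r,
    # so p % q == 0 for every nonzero integer value of q.
    return sympy.expand(sympy.gcd(p, q) - q) == 0


u0, u1, u2 = sympy.symbols("u0 u1 u2", integer=True, nonnegative=True)
print(mod_is_zero(u0*u2 + u1*u2, u0 + u1))  # True
print(mod_is_zero(u0*u2 + u1, u0 + u1))     # False
```

Because SymPy normalizes the gcd to a positive leading coefficient, a divisor written with a negative leading term simply fails the check, which errs on the safe side.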

lezcano (Collaborator) commented May 2, 2024

Actually, there's something odd going on here. Is that Mod class ours? SymPy itself reasons about this correctly:

```python
>>> import sympy
>>> x = sympy.Symbol("x", integer=True, nonnegative=True)
>>> y = sympy.Symbol("y", integer=True, nonnegative=True)
>>> z = sympy.Symbol("z", integer=True, nonnegative=True)
>>> sympy.Mod(x*z + y*z, x + y)
0
```

ezyang (Contributor, Author) commented May 3, 2024

yup, we have our own

```python
class Mod(sympy.Function):
    """
    We maintain this so that we avoid SymPy correctness issues, such as:
    https://github.com/sympy/sympy/issues/25146
    """

    nargs = (2,)
```
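One way a custom `sympy.Function` subclass like this can pick up an extra simplification tactic is through its `eval` classmethod, which SymPy calls on construction: returning an expression replaces the call, and returning `None` leaves it unevaluated. The class `DivisibleMod` below is a hypothetical stand-in, not PyTorch's `Mod`; it shows only a gcd-based divisibility tactic and assumes integer arguments and a nonzero divisor.

```python
import sympy


class DivisibleMod(sympy.Function):
    """Hypothetical Mod-like function demonstrating one extra simplification tactic."""

    @classmethod
    def eval(cls, p, q):
        # Both arguments concrete: compute the remainder (q assumed nonzero).
        if p.is_Integer and q.is_Integer:
            return p % q
        # Tactic: if q divides p over the integers (gcd(p, q) == q),
        # the remainder is identically zero.
        if sympy.expand(sympy.gcd(p, q) - q) == 0:
            return sympy.S.Zero
        # No rule applies: returning None keeps DivisibleMod(p, q) unevaluated.
        return None


u0, u1, u2 = sympy.symbols("u0 u1 u2", integer=True, nonnegative=True)
print(DivisibleMod(u0*u2 + u1*u2, u0 + u1))                # 0
print(sympy.Ne(DivisibleMod(u0*u2 + u1*u2, u0 + u1), 0))  # False
print(DivisibleMod(u0*u2 + u1, u0 + u1))  # stays unevaluated
```

Since the tactic fires only when SymPy can prove divisibility, an unprovable case falls through unchanged rather than being simplified incorrectly.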

ezyang (Contributor, Author) commented May 3, 2024

@lezcano, don't take this one; I'm going to give it out as a bootcamp task.

yanboliang added the `good first issue` and `triaged` labels and removed the `good first issue` label on May 6, 2024
titaiwangms pushed a commit to titaiwangms/pytorch that referenced this issue May 28, 2024
PyTorch overrides SymPy's Mod and does its own symbolic simplification. Inspired by issue pytorch#125307, this PR adds one more simplification tactic.

Fixes pytorch#125307

Pull Request resolved: pytorch#126351
Approved by: https://github.com/ezyang
Aidyn-A pushed a commit to tinglvv/pytorch that referenced this issue May 30, 2024