Conversation

@anijain2305 (Contributor) commented Sep 27, 2025

Stack from ghstack (oldest at bottom):

For this program:

```
import torch

torch._dynamo.config.enable_cpp_symbolic_shape_guards = False

def fn(x):
    for _ in range(20000):
        x = x + 1
    return x

mod = torch.fx.symbolic_trace(fn)

print(mod)

opt_mod = torch.compile(mod, backend="eager", dynamic=True)

x = torch.randn(2, 2)
opt_mod(x)
```

This change reduces the guard overhead from 1.8 ms to 72 us. The dict pop is not required, as explained in the code comment.

The overhead can be reduced further, to 46 us, by also removing the `_dict.attr` pop branch.
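As a rough way to reproduce the comparison (the PR does not state its exact measurement method, so the warm-up, the iteration count, the `avg_us` helper, and the eager-baseline subtraction below are all assumptions), one can time the compiled call against the plain GraphModule call; the difference approximates the per-call guard-side cost:

```
import time

import torch

torch._dynamo.config.enable_cpp_symbolic_shape_guards = False

def fn(x):
    for _ in range(20000):
        x = x + 1
    return x

# symbolic_trace unrolls the loop, so the generated forward() has ~20000
# local variable names; the framelocals dict built during the guard check
# therefore has ~20000 entries, which is what makes its cost visible here.
mod = torch.fx.symbolic_trace(fn)
opt_mod = torch.compile(mod, backend="eager", dynamic=True)

x = torch.randn(2, 2)
opt_mod(x)  # warm up: trace, compile, and install guards

def avg_us(f, n=20):
    start = time.perf_counter()
    for _ in range(n):
        f(x)
    return (time.perf_counter() - start) / n * 1e6

base = avg_us(mod)       # plain GraphModule: kernel time only
total = avg_us(opt_mod)  # compiled: guard check + kernel time
print(f"approx guard overhead per call: {total - base:.0f} us")
```

On a 2x2 tensor the 20000 eager adds dominate the total, so the subtraction is what isolates the guard-side cost; a profiler would give a cleaner split.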

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @Lucaskabela

@pytorch-bot (bot) commented Sep 27, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/164038

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 044eaff with merge base 55840fb:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@mlazos (Contributor) left a comment


really cool speedup!

…ion"


For this program

```
import torch

torch._dynamo.config.enable_cpp_symbolic_shape_guards = False

def fn(x):
    for _ in range(20000):
        x = x + 1
    return x

mod = torch.fx.symbolic_trace(fn)

print(mod)


opt_mod = torch.compile(mod, backend="eager", dynamic=True)

x = torch.randn(2, 2)
opt_mod(x)

```

It reduces the guard overhead from 1.8 ms to 172 us. The dict pop is not required as explained in the comment.

It can reduce to 46 us if we just remove the _dict.attr pop branch.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
anijain2305 added a commit that referenced this pull request Sep 30, 2025
…ion"


For this program

```
import torch

torch._dynamo.config.enable_cpp_symbolic_shape_guards = False

def fn(x):
    for _ in range(20000):
        x = x + 1
    return x

mod = torch.fx.symbolic_trace(fn)

print(mod)


opt_mod = torch.compile(mod, backend="eager", dynamic=True)

x = torch.randn(2, 2)
opt_mod(x)

```

It reduces the guard overhead from 1.8 ms to 72 us. The dict pop is not required as explained in the comment.

It can reduce to 46 us if we just remove the _dict.attr pop branch.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
anijain2305 added a commit that referenced this pull request Sep 30, 2025
```
}
} else {
  _dict[framelocals_names[i]] = value;
  seen.insert(name_ptr);
```
A reviewer (Member) left a comment:
Actually, is there a guarantee that entries in framelocals_names with the same string value are the same string object? If not, then the unordered_set should use Python string comparison for its hash/equality. (We should also try to craft a test case that fails if this update is done incorrectly. I recall encountering failing tests when I made mistakes in the initial implementation, so I'm confused as to why today's tests don't seem to catch incorrect updates.)
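For illustration of why this matters (the names below are hypothetical, not taken from the PR): in CPython, two strings can compare equal without being the same object, so deduplicating by pointer can disagree with deduplicating by value:

```
# Hypothetical local names; equal by value, but distinct objects.
a = "add_1"
b = "add_" + str(1)  # built at runtime, so not folded into the literal

print(a == b)  # True  -- value equality
print(a is b)  # False -- different objects, i.e. different pointers

# A pointer-keyed "seen" set keeps both; a value-keyed set keeps one.
print(len({id(a), id(b)}))  # 2
print(len({a, b}))          # 1
```

Whether framelocals_names can actually contain equal-but-distinct name objects depends on CPython's interning of code-object names, which is the open question in this thread.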

@anijain2305 (Author) replied:
Yeah, that was my concern as well. I don't know how to make that test case though.

@anijain2305 (Author) commented:

Not merging yet. @williamwen42 has concerns and he will take over this PR.

@anijain2305 changed the title from "[dynamo][guards] Do not dict pop in framelocals dict creation" to "[DONT MERGE YET][dynamo][guards] Do not dict pop in framelocals dict creation" on Sep 30, 2025
williamwen42 added a commit that referenced this pull request Sep 30, 2025
Followup to #164038

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
@anijain2305 closed this Oct 1, 2025
pytorchmergebot pushed a commit that referenced this pull request Oct 1, 2025
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Oct 21, 2025
