This repository was archived by the owner on Aug 1, 2025. It is now read-only.

Minifier doesn't work with torch.compile #1964

@williamwen42

🐛 Describe the bug

torch.compile with the inductor backend creates a "custom backend" that preserves the inductor settings. The minifier currently works only with pre-registered backends, so it does not support the backend that torch.compile creates.
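
To make the failure mode concrete, here is a minimal, torch-free sketch of a name-keyed backend registry. Every name in it (`REGISTERED_BACKENDS`, `register_backend`, `BackendWithSettings`, `backend_name_for_launcher`) is hypothetical and only models the behavior described above; it is not the actual torch._dynamo implementation.

```python
from typing import Callable, Dict

# Hypothetical registry: backends that a standalone launcher script can
# refer to by name. Not a real torch._dynamo data structure.
REGISTERED_BACKENDS: Dict[str, Callable] = {}


def register_backend(name: str, fn: Callable) -> Callable:
    """Store a backend under a name so it can be looked up later."""
    REGISTERED_BACKENDS[name] = fn
    return fn


def inductor(graph, example_inputs, options=None):
    """Stand-in for a pre-registered compiler backend (returns the graph unchanged)."""
    return graph


register_backend("inductor", inductor)


class BackendWithSettings:
    """Stand-in for a wrapper that carries settings alongside a backend."""

    def __init__(self, fn: Callable, options: dict):
        self.fn = fn
        self.options = dict(options)

    def __call__(self, graph, example_inputs):
        return self.fn(graph, example_inputs, options=self.options)


def backend_name_for_launcher(backend) -> str:
    """Return the registered name for a backend, or fail if it has none."""
    if isinstance(backend, str) and backend in REGISTERED_BACKENDS:
        return backend
    for name, fn in REGISTERED_BACKENDS.items():
        if fn is backend:
            return name
    # The wrapper is a new object, so identity lookup finds nothing.
    raise NotImplementedError(
        f"custom backend {backend!r} cannot be referenced by name"
    )
```

In this sketch, `backend_name_for_launcher("inductor")` succeeds, while passing a `BackendWithSettings` instance raises `NotImplementedError`, even though the wrapper only forwards to the registered backend. The lookup works by name or identity, so any wrapper object looks like an unknown custom backend.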

Error logs

No response

Minified repro

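import torch
import torch._dynamo
import torch._inductor.codegen.triton as codegen

# Replacement for Inductor's Triton relu codegen: it emits a deliberately
# broken expression so that the generated kernel fails.
def triton_runtime_error(x):
    return f"{x}; assert?"

# Patch the override in place so every compiled relu hits the broken codegen.
overrides = codegen.TritonOverrides
overrides.relu = staticmethod(triton_runtime_error)

def foo():
    y = torch.ones(200, 200).cuda()
    x = torch.ones(200, 200).cuda()
    z = x + y
    a = torch.relu(z)
    return a

# Ask the minifier to produce a repro after the dynamo stage.
torch._dynamo.config.repro_after = "dynamo"
# torch._dynamo.config.repro_level = 2

if __name__ == '__main__':
    foo_opt = torch.compile(foo)
    # foo_opt = torch._dynamo.optimize()(foo)
    foo_opt()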

Running this script produces a minifier launcher script that raises an error, because the backend given by torch.compile is treated as a custom backend.
