Skip to content

Commit

Permalink
[inductor] Don't import torchvision
Browse files Browse the repository at this point in the history
Fixes #93019

Since PyTorch regularly breaks binary compatibility, `torchvision` must be
compiled against the exact same version of PyTorch. If it is not, importing it
may cause mysterious failures at runtime due to binary incompatibility.

This fixes the issue by delaying the `make_fallback` call for
`torchvision.roi_align` until the operator appears in a graph being lowered, by
which point the user must have imported torchvision themselves.

ghstack-source-id: 052c5d9f87803b4cf31b8fd88f32bdd4f1886122
Pull Request resolved: #93027
  • Loading branch information
peterbell10 committed Jan 25, 2023
1 parent 18d5288 commit 1d7f306
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 7 deletions.
2 changes: 1 addition & 1 deletion test/inductor/test_torchinductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
from torch._inductor.codegen.cpp import cexpr, CppOverrides, CppVecOverrides
from torch._inductor.codegen.triton import texpr
from torch._inductor.compile_fx import compile_fx, complex_memory_overlap
from torch._inductor.ir import IndexingDiv, ModularIndexing
from torch._inductor.ir import ModularIndexing
from torch._inductor.overrides import (
linear_permute_fusion,
linear_transpose,
Expand Down
6 changes: 5 additions & 1 deletion torch/_inductor/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
)
from .ir import Constant, FixedLayout, InputBuffer, Pointwise, Reduction, TensorBox
from .lowering import (
FALLBACK_ALLOW_LIST,
layout_constraints,
lowerings,
make_fallback,
Expand Down Expand Up @@ -291,7 +292,10 @@ def call_function(self, target, args, kwargs):
return super().call_function(target, args, kwargs)

if target not in lowerings:
if config.implicit_fallbacks:
base_name = target.name().split(".")[0]
if base_name in FALLBACK_ALLOW_LIST:
make_fallback(target)
elif config.implicit_fallbacks:
error = (
MissingOperatorWithDecomp
if get_decompositions([target])
Expand Down
9 changes: 4 additions & 5 deletions torch/_inductor/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
TensorBox,
View,
)
from .utils import ceildiv, has_torchvision_roi_align, sympy_product
from .utils import ceildiv, sympy_product
from .virtualized import ops, V

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -1165,10 +1165,6 @@ def require_contiguous(_, *args, **kwargs):
return args, kwargs


if has_torchvision_roi_align():
make_fallback(torch.ops.torchvision.roi_align)


def constrain_to_fx_strides(fx_node, *args, **kwargs):
def apply_constraint(arg, fx_arg):
if isinstance(arg, ir.IRNode):
Expand All @@ -1183,6 +1179,9 @@ def apply_constraint(arg, fx_arg):

# TODO(jansel): we should implement decomps or lowerings for these
# https://github.com/pytorch/torchdynamo/issues/327
FALLBACK_ALLOW_LIST = {
"torchvision::roi_align",
}
make_fallback(aten._adaptive_avg_pool2d_backward, require_dense)
make_fallback(aten.convolution_backward, constrain_to_fx_strides)
make_fallback(aten._cudnn_rnn, require_dense)
Expand Down

0 comments on commit 1d7f306

Please sign in to comment.