Inductor Freezing #100652

Closed
wants to merge 31 commits
Changes from 28 commits
Commits (31)
65d0a76
Inductor Optimize For Inference/Freezing
eellison May 4, 2023
85539ee
Update on "Inductor Optimize For Inference/Freezing"
eellison May 5, 2023
b26af0b
Update on "Inductor Optimize For Inference/Freezing"
eellison May 9, 2023
7e1e2a9
Update on "Inductor Optimize For Inference/Freezing"
eellison May 9, 2023
03bd79e
Update on "Inductor Optimize For Inference/Freezing"
eellison May 9, 2023
edadb81
Update on "Inductor Optimize For Inference/Freezing"
eellison May 9, 2023
20100ca
Update on "Inductor Optimize For Inference/Freezing"
eellison May 9, 2023
955ba1d
Update on "Inductor Freezing"
eellison May 16, 2023
6fdc419
Update on "Inductor Freezing"
eellison May 16, 2023
da65575
Update on "Inductor Freezing"
eellison May 16, 2023
6333b39
Update on "Inductor Freezing"
eellison May 16, 2023
887649b
Update on "Inductor Freezing"
eellison May 16, 2023
44dd75a
Update on "Inductor Freezing"
eellison May 16, 2023
e43ccdd
Update on "Inductor Freezing"
eellison May 16, 2023
ca0263f
Update on "Inductor Freezing"
eellison May 16, 2023
3e5eea2
Update on "Inductor Freezing"
eellison May 26, 2023
25e3f65
Update on "Inductor Freezing"
eellison May 26, 2023
fb2d815
Update on "Inductor Freezing"
eellison May 26, 2023
bd2f558
Update on "Inductor Freezing"
eellison May 26, 2023
fc87b32
Update on "Inductor Freezing"
eellison May 27, 2023
c6f0b94
Update on "Inductor Freezing"
eellison Jun 5, 2023
0449e31
Update on "Inductor Freezing"
eellison Jun 6, 2023
0e7b949
Update on "Inductor Freezing"
eellison Jun 8, 2023
cbd30a6
Update on "Inductor Freezing"
eellison Jun 8, 2023
e4ab199
Update on "Inductor Freezing"
eellison Jun 8, 2023
eb915e3
Update on "Inductor Freezing"
eellison Jun 9, 2023
bf875a0
Update on "Inductor Freezing"
eellison Jun 9, 2023
e6199da
Update on "Inductor Freezing"
eellison Jun 9, 2023
91fd27c
Update on "Inductor Freezing"
eellison Jun 9, 2023
8372107
Update on "Inductor Freezing"
eellison Jun 12, 2023
17a09e6
Update on "Inductor Freezing"
eellison Jun 12, 2023
258 changes: 258 additions & 0 deletions test/inductor/test_inductor_freezing.py
@@ -0,0 +1,258 @@
# Owner(s): ["module: inductor"]
import contextlib
import functools
import importlib
import os
import sys
import unittest
import weakref

import torch

import torch._dynamo
from torch._inductor import config
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck

# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)

from torch.testing._internal.common_utils import (
    IS_CI,
    IS_WINDOWS,
    TEST_WITH_ASAN,
    TEST_WITH_ROCM,
    TestCase as TorchTestCase,
)

from inductor.test_torchinductor import check_model, check_model_cuda, copy_tests

if IS_WINDOWS and IS_CI:
    sys.stderr.write(
        "Windows CI does not have necessary dependencies for test_torchinductor yet\n"
    )
    if __name__ == "__main__":
        sys.exit(0)
    raise unittest.SkipTest("requires sympy/functorch/filelock")

importlib.import_module("functorch")
importlib.import_module("filelock")

from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA

HAS_MULTIGPU = HAS_CUDA and torch.cuda.device_count() >= 2
aten = torch.ops.aten
requires_cuda = functools.partial(unittest.skipIf, not HAS_CUDA, "requires cuda")


class TestCase(TorchTestCase):
    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls._stack = contextlib.ExitStack()
        cls._stack.enter_context(
            config.patch(
                {
                    "debug": True,
                    "cpp.min_chunk_size": 1,
                    "triton.autotune_pointwise": False,  # too slow
                    "implicit_fallbacks": False,
                    "freezing": True,
                    "freezing_discard_parameters": True,
                }
            )
        )

    @classmethod
    def tearDownClass(cls):
        cls._stack.close()
        super().tearDownClass()

    def setUp(self):
        torch._dynamo.reset()
        super().setUp()

    def tearDown(self):
        super().tearDown()
        torch._dynamo.reset()


class ConvBN(torch.nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = torch.nn.BatchNorm2d(out_channels, eps=0.001, dtype=torch.float)

    def forward(self, x):
        return self.bn(self.conv(x))


class OptimizeForInferenceTemplate(TestCase):
    def test_mutation(self):
        class Mod(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.mutated_param = torch.nn.Parameter(torch.zeros([10, 10]))

            def forward(self):
                self.mutated_param.add_(10)
                return self.mutated_param

        with torch.no_grad():
            mod = Mod().to(self.device)
            out_eager = mod()
            out_eager2 = mod()

            mod = Mod().to(self.device)

            @torch.compile
            def foo(mod):
                return mod()

            out_comp = foo(mod)
            out_comp2 = foo(mod)

            self.assertEqual(out_eager, out_comp)
            self.assertEqual(out_eager2, out_comp2)

    def test_aliased_param_return(self):
        class Mod(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.aliased_param = torch.nn.Parameter(torch.zeros([10, 10]))

            def forward(self):
                return self.aliased_param[1:], self.aliased_param

        mod = Mod().to(self.device).eval()

        @torch.compile()
        def foo(mod):
            return mod()

        with torch.no_grad():
            mod_eager = mod()
            self.assertEqual(foo(mod), mod_eager)

    def test_autocast(self):
        if self.device == "cpu":
            raise unittest.SkipTest("MKLDNN Bug")

        mod = torch.nn.Linear(10, 10).to(self.device).eval()
        inp = torch.rand([10, 10]).to(self.device).to(torch.half)

        @torch.compile()
        def foo(mod, inp):
            return mod(inp)

        with torch.no_grad():
            with self.autocast():
                out_eager = mod(inp)
                out_compiled, code = run_and_get_code(foo, mod, inp)

                FileCheck().check_not("@triton.jit").run(code[0])
                self.assertEqual(out_eager, out_compiled)

    def test_error_on_eager(self):
        mod = ConvBN(3, 32, kernel_size=3, stride=2).eval().to(self.device)

        x = torch.rand(3, 3, 32, 32).to(self.device)

        @torch.compile()
        def foo(mod, x):
            return mod(x)

        with torch.no_grad():
            foo(mod, x)

        with self.assertRaisesRegex(
            RuntimeError, "Trying to Run Pytorch Eager Module After Dynamo Freezing"
        ):
            mod(x)

    def test_rng_op(self):
        @torch.compile()
        def foo():
            return torch.rand([4, 4], device=self.device) + 1

        with torch.no_grad():
            o1 = foo()
            o2 = foo()
            self.assertNotEqual(o1, o2)

    def test_symint_not_folded(self):
        def fn(a):
            return a.cos(), torch.zeros(a.shape[0], a.shape[1])

        fn_opt = torch._dynamo.optimize("inductor", dynamic=True)(fn)
        inp = torch.randn(2, 4, 6).to(self.device)
        torch._dynamo.mark_dynamic(inp, 0)
        torch._dynamo.mark_dynamic(inp, 1)

        with torch.no_grad():
            self.assertEqual(fn(inp), fn_opt(inp))
            inp2 = torch.randn(3, 5, 6).to(self.device)
            torch._dynamo.mark_dynamic(inp2, 0)
            torch._dynamo.mark_dynamic(inp2, 1)
            self.assertEqual(fn(inp2), fn_opt(inp2))

    def test_param_deallocated(self):
        # TODO: cpu path keeps an extra copy of graph around somewhere,
        # memory not as important for cpu
        if self.device == "cpu":
            raise unittest.SkipTest("NYI CPU")

        class Mod(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.param = torch.nn.Parameter(torch.zeros([10, 10]))

            def forward(self, x):
                return (self.param + 10) + x

        mod = Mod().eval().to(self.device)
        inp = torch.rand([10], device=self.device)

        with torch.no_grad():
            eager = mod(inp)

        weight_ref = weakref.ref(mod.param)

        @torch.compile()
        def foo(mod, inp):
            return mod(inp)

        with torch.no_grad():
            compiled = foo(mod, inp)

        self.assertEqual(eager, compiled)
        self.assertTrue(weight_ref() is None)


if HAS_CPU and not torch.backends.mps.is_available():

    class CpuTests(TestCase):
        common = check_model
        device = "cpu"
        autocast = torch.cpu.amp.autocast

    copy_tests(OptimizeForInferenceTemplate, CpuTests, "cpu")

if HAS_CUDA and not TEST_WITH_ASAN:

    class CudaTests(TestCase):
        common = check_model_cuda
        device = "cuda"
        autocast = torch.cuda.amp.autocast

    copy_tests(OptimizeForInferenceTemplate, CudaTests, "cuda")


del OptimizeForInferenceTemplate

if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    if (HAS_CPU or HAS_CUDA) and not TEST_WITH_ROCM:
        run_tests(needs="filelock")
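
Note: the flags patched in TestCase.setUpClass above suggest how freezing is switched on. The following is a minimal usage sketch built only from those flags; the flag names are assumptions taken from this test config and may differ from the final public API.

# Minimal sketch, assuming the inductor config flags exercised by the test above.
import torch
import torch._inductor.config as inductor_config

inductor_config.freezing = True                     # fold parameters into the graph as constants
inductor_config.freezing_discard_parameters = True  # allow the original parameters to be freed

mod = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3), torch.nn.BatchNorm2d(8), torch.nn.ReLU()
).eval()
opt_mod = torch.compile(mod)

with torch.no_grad():  # the tests above also run compilation under no_grad
    out = opt_mod(torch.randn(1, 3, 32, 32))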
18 changes: 10 additions & 8 deletions torch/_functorch/aot_autograd.py
@@ -1524,6 +1524,9 @@ def aot_dispatch_base(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig, *
fake_mode = detect_fake_mode()
seed, offset = CUDARngStateHelper.get_torch_state_as_tuple(fake_mode)
flat_args.extend([seed, offset])

if torch._guards.TracingContext.get():
torch._guards.TracingContext.get().fw_metadata = fw_metadata
compiled_fw = compiler(fw_module, flat_args)

# This boxed_call handling happens inside create_runtime_wrapper as well.
@@ -2766,13 +2769,9 @@ def aot_dispatch_autograd(flat_fn, flat_args: List[Any], aot_config: AOTConfig,
# 1) There is a check in the debug compiler at the end
# 2) It does not matter as these are fake tensors

# the compiler needs to use this field to find the original model outputs
# from the AOTAutograd fwd module's outputs, so the compiler can make sure
# optimizations like layout optimization do not change those tensors'
# layout.
# TODO once https://github.com/pytorch/pytorch/pull/100652/files#r1212002707 is in
# change to access fw_metadata from the global tracing context.
fw_module.meta["original_output_start_index"] = fw_metadata.num_mutated_inputs
if torch._guards.TracingContext.get():
torch._guards.TracingContext.get().fw_metadata = fw_metadata

compiled_fw_func = aot_config.fw_compiler(
fw_module, adjusted_flat_args
)
Expand Down Expand Up @@ -3634,7 +3633,7 @@ def aot_module_simplified(
**dict(mod.named_buffers(remove_duplicate=False)),
}
params_flat, params_spec = pytree.tree_flatten(params)
params_flat = tuple(params_flat)
params_flat = list(params_flat)
params_len = len(params_flat)

functional_call = create_functional_call(mod, params_spec, params_len)
@@ -3650,6 +3649,9 @@
# First, the params
full_args.extend(params_flat)

if torch._guards.TracingContext.get():
torch._guards.TracingContext.get().params_flat = params_flat

aot_autograd_arg_pos_to_source = None
# Then, the params 1:1 mapped sources, if relevant.
if hasattr(mod, "_param_name_to_source"):
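
Note: with fw_metadata and params_flat stashed on the tracing context, a forward compiler can read them directly instead of relying on the removed original_output_start_index meta entry. Below is a minimal sketch of that access pattern; my_fw_compiler is a hypothetical backend, and only names visible in this diff (TracingContext.get, fw_metadata.num_mutated_inputs, params_flat) are assumed to exist.

# Sketch of a custom fw_compiler reading the metadata this diff publishes.
import torch
from torch._guards import TracingContext


def my_fw_compiler(gm: torch.fx.GraphModule, example_inputs):
    ctx = TracingContext.get()  # None when not invoked under dynamo / aot_autograd
    if ctx is not None and ctx.fw_metadata is not None:
        # Outputs before this index mirror mutated inputs, so a layout pass can
        # leave them alone and only touch the real user outputs.
        original_output_start_index = ctx.fw_metadata.num_mutated_inputs
        print("user outputs start at index", original_output_start_index)
    if ctx is not None and ctx.params_flat is not None:
        print("compiling with", len(ctx.params_flat), "flat params/buffers")
    return gm.forward  # no transformation in this sketch; just run the traced graph

Such a function would be handed to AOTAutograd as the fw_compiler (see aot_config.fw_compiler above).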
3 changes: 3 additions & 0 deletions torch/_guards.py
@@ -472,6 +472,9 @@ def __init__(self, fake_mode):
        self.fake_mode = fake_mode
        self.frame_summary_stack = []
        self.loc_in_frame = None
        # these are only set by aot_autograd
        self.fw_metadata = None
        self.params_flat = None

    @staticmethod
    def extract_stack():
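
Closing illustration (not code from this PR): once parameters are frozen into compile-time constants, the backend can constant-fold patterns such as the ConvBN module from the test file above. The helper below is a standalone conv plus eval-mode batch-norm folding written from the standard formula, purely for intuition about what such folding computes.

# Standalone sketch of conv/batch-norm constant folding (illustration only,
# not the PR's implementation); assumes groups=1 and an eval-mode BatchNorm.
import torch


def fold_conv_bn(conv: torch.nn.Conv2d, bn: torch.nn.BatchNorm2d) -> torch.nn.Conv2d:
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    fused = torch.nn.Conv2d(
        conv.in_channels, conv.out_channels, conv.kernel_size,
        stride=conv.stride, padding=conv.padding, bias=True,
    )
    fused.weight.data = conv.weight.data * scale.reshape(-1, 1, 1, 1)
    conv_bias = conv.bias.data if conv.bias is not None else torch.zeros_like(bn.bias)
    fused.bias.data = (conv_bias - bn.running_mean) * scale + bn.bias
    return fused


# Quick numeric check against running conv then bn eagerly, as ConvBN.forward does.
conv = torch.nn.Conv2d(3, 8, kernel_size=3, bias=False).eval()
bn = torch.nn.BatchNorm2d(8, eps=0.001).eval()
x = torch.randn(1, 3, 16, 16)
with torch.no_grad():
    bn.running_mean.uniform_(-1, 1)   # give the folding something non-trivial to absorb
    bn.running_var.uniform_(0.5, 2.0)
    torch.testing.assert_close(fold_conv_bn(conv, bn)(x), bn(conv(x)), rtol=1e-4, atol=1e-5)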