Update base for Update on "[inductor] fix cpp_wrapper inputs mismatch"
Summary: fixes #115035, where the input args in the cpp_wrapper JIT Inductor should contain the lifted parameters.
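For context, a minimal sketch (not part of this PR) of exercising the cpp_wrapper JIT Inductor path the fix touches; the module and shapes are illustrative:

    import torch
    import torch._inductor.config as inductor_config

    # Opt in to Inductor's C++ wrapper codegen, the code path this PR fixes.
    inductor_config.cpp_wrapper = True

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # The Linear's weight and bias are the kind of parameters that get
            # lifted into graph inputs and must be threaded through the wrapper.
            self.lin = torch.nn.Linear(4, 4)

        def forward(self, x):
            return torch.relu(self.lin(x))

    opt = torch.compile(M())
    print(opt(torch.randn(2, 4)))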

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
desertfire committed Dec 20, 2023
2 parents c215e59 + 16e539e commit 9006bff
Showing 32 changed files with 334 additions and 157 deletions.
1 change: 1 addition & 0 deletions .github/merge_rules.yaml
@@ -353,6 +353,7 @@

- name: x86 CPU quantization
patterns:
- aten/src/ATen/native/quantized/cpu/**
- torch/ao/quantization/quantizer/x86_inductor_quantizer.py
- torch/_inductor/fx_passes/quantization.py
- test/quantization/core/test_quantized_op.py
7 changes: 6 additions & 1 deletion aten/src/ATen/TensorIndexing.h
@@ -246,8 +246,13 @@ static inline Tensor applySelect(
}

auto size = (*self_sizes)[dim];
+ // Note: `size >= -index` is not equivalent to `size > -1 - index` if index
+ // is INT64_MIN. For std::numeric_limits<int64_t>::min(), the result of
+ // unary minus is undefined by the standard but in practice equals the value
+ // itself. On the other hand, index wrapping is valid for all negative
+ // int64_t values, as x[INT64_MIN] is the same as x[INT64_MAX].
TORCH_CHECK_INDEX(
- size >= -index && size > index,
+ size > -1 - index && size > index,
"index ",
index,
" is out of bounds for dimension ",
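To see concretely why the condition changed, here is a small standalone sketch (not part of the diff), using plain Python to simulate int64 two's-complement wraparound:

    INT64_MIN = -2**63
    INT64_MAX = 2**63 - 1

    def wrap_int64(x):
        # Reduce a Python int to two's-complement int64, as C++ int64_t wraps.
        x &= (1 << 64) - 1
        return x - (1 << 64) if x >= (1 << 63) else x

    size = 10
    index = INT64_MIN

    # Old check: `-index` wraps back to INT64_MIN, so `size >= -index` is
    # vacuously true and the out-of-bounds index slips through.
    assert size >= wrap_int64(-index)

    # New check: `-1 - index` never overflows (it equals INT64_MAX here), so
    # `size > -1 - index` is false and the index is correctly rejected.
    assert -1 - index == INT64_MAX
    assert not (size > -1 - index)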
6 changes: 5 additions & 1 deletion aten/src/ATen/native/TensorShape.cpp
@@ -1828,7 +1828,11 @@ Tensor select_symint(const Tensor& self, int64_t dim, c10::SymInt index) {
}
dim = maybe_wrap_dim(dim, ndim);
auto size = self.sym_sizes()[dim];
- if (size < -index || size <= index) {
+ // Note: `size < -index` is not equivalent to `size <= -1 - index` if index
+ // is INT64_MIN. For std::numeric_limits<int64_t>::min(), the result of
+ // unary minus is undefined by the standard but in practice equals the value
+ // itself. On the other hand, index wrapping is valid for all negative
+ // int64_t values, as x[INT64_MIN] is the same as x[INT64_MAX].
+ if (size <= -1 - index || size <= index) {
if (self.has_names() && self.names()[dim] != Dimname::wildcard()) {
TORCH_CHECK_INDEX(false, "select(): index ", index, " out of range for tensor of size ",
self.sizes(), " at dimension ", self.names()[dim]);
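For reference, a quick illustration (not from the diff) of the negative-index wrapping that `select` supports and the IndexError the tightened check raises:

    import torch

    t = torch.arange(6).reshape(2, 3)

    # Negative indices wrap once: -1 refers to the last row along dim 0.
    assert torch.equal(t.select(0, -1), t.select(0, 1))

    # Anything below -size is out of bounds and raises, which is the case the
    # INT64_MIN-safe comparison above must also catch.
    try:
        t.select(0, -3)
    except IndexError as e:
        print(e)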
4 changes: 2 additions & 2 deletions build_variables.bzl
@@ -464,7 +464,7 @@ lazy_tensor_core_python_sources = [
]

inductor_core_resources = [
- "torch/csrc/inductor/aoti_model_container_runner.cpp",
+ "torch/csrc/inductor/aoti_runner/model_container_runner.cpp",
"torch/csrc/inductor/aoti_torch/shim_common.cpp",
"torch/csrc/inductor/aoti_torch/tensor_converter.cpp",
"torch/csrc/inductor/inductor_ops.cpp",
@@ -652,7 +652,7 @@
"torch/csrc/CudaIPCTypes.cpp",
"torch/csrc/cuda/comm.cpp",
"torch/csrc/cuda/memory_snapshot.cpp",
- "torch/csrc/inductor/aoti_model_container_runner_cuda.cpp",
+ "torch/csrc/inductor/aoti_runner/model_container_runner_cuda.cpp",
"torch/csrc/inductor/aoti_torch/shim_cuda.cpp",
"torch/csrc/jit/codegen/fuser/cuda/fused_kernel.cpp",
"torch/csrc/profiler/stubs/cuda.cpp",
6 changes: 3 additions & 3 deletions docs/source/torch.compiler_aot_inductor.rst
@@ -90,8 +90,8 @@ previous step, enabling us to conduct model predictions directly within a C++ environment
The following code snippet assumes your system has a CUDA-enabled device and your model was
compiled to run on CUDA as shown previously.
In the absence of a GPU, it's necessary to make these adjustments in order to run it on a CPU:
- 1. Modify ``aoti_model_container_runner_cuda.h`` to ``aoti_model_container_runner.h``
- 2. Change ``AOTIModelContainerRunnerCuda`` to ``AOTIModelContainerRunner``
+ 1. Change ``model_container_runner_cuda.h`` to ``model_container_runner_cpu.h``
+ 2. Change ``AOTIModelContainerRunnerCuda`` to ``AOTIModelContainerRunnerCpu``
3. Change ``at::kCUDA`` to ``at::kCPU``

.. code-block:: cpp
@@ -100,7 +100,7 @@
#include <vector>
#include <torch/torch.h>
- #include <torch/csrc/inductor/aoti_model_container_runner_cuda.h>
+ #include <torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h>
int main() {
c10::InferenceMode mode;
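For completeness, a hedged sketch of producing the `model.so` that the C++ runner above loads, via the Python-side entry point these docs used at the time (`torch._export.aot_compile`); treat the exact signature and return value as assumptions of this example:

    import torch

    class Net(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = torch.nn.Linear(10, 16)

        def forward(self, x):
            return torch.relu(self.fc(x))

    model = Net().cuda()
    example_inputs = (torch.randn(8, 10, device="cuda"),)

    # aot_compile is assumed to return the path of the compiled shared
    # library, which AOTIModelContainerRunnerCuda (or ...Cpu) then loads.
    with torch.no_grad():
        so_path = torch._export.aot_compile(model, example_inputs)
    print(so_path)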
1 change: 1 addition & 0 deletions setup.py
@@ -1217,6 +1217,7 @@ def main():
"include/torch/csrc/distributed/autograd/rpc_messages/*.h",
"include/torch/csrc/dynamo/*.h",
"include/torch/csrc/inductor/*.h",
+ "include/torch/csrc/inductor/aoti_runner/*.h",
"include/torch/csrc/inductor/aoti_runtime/*.h",
"include/torch/csrc/inductor/aoti_torch/*.h",
"include/torch/csrc/inductor/aoti_torch/c/*.h",
33 changes: 33 additions & 0 deletions test/autograd/test_logging.py
@@ -0,0 +1,33 @@
# Owner(s): ["module: autograd"]

import logging

import torch
from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test


class TestAutogradLogging(LoggingTestCase):
@make_logging_test(autograd=logging.INFO)
def test_logging(self, records):
a = torch.rand(10, requires_grad=True)
b = a.mul(2).div(3).sum()
c = b.clone()
torch.autograd.backward((b, c))

self.assertEqual(len(records), 5)
expected = [
"CloneBackward0",
"SumBackward0",
"DivBackward0",
"MulBackward0",
"AccumulateGrad",
]

for i, record in enumerate(records):
self.assertIn(expected[i], record.getMessage())


if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

run_tests()
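Outside the test harness, the same logging can presumably be enabled through `torch._logging`, assuming `autograd` is a registered log name as `make_logging_test(autograd=...)` implies; a sketch:

    import logging
    import torch

    # Assumption: "autograd" is accepted by set_logs once this feature lands.
    torch._logging.set_logs(autograd=logging.INFO)

    a = torch.rand(10, requires_grad=True)
    b = a.mul(2).div(3).sum()
    b.backward()  # emits one INFO line per node: SumBackward0, DivBackward0, ...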
4 changes: 2 additions & 2 deletions test/cpp/aot_inductor/aoti_custom_class.cpp
@@ -1,8 +1,8 @@
#include <stdexcept>

- #include <torch/csrc/inductor/aoti_model_container_runner.h>
+ #include <torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h>
#ifdef USE_CUDA
- #include <torch/csrc/inductor/aoti_model_container_runner_cuda.h>
+ #include <torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h>
#endif

#include "aoti_custom_class.h"
4 changes: 2 additions & 2 deletions test/cpp/aot_inductor/test.cpp
@@ -3,9 +3,9 @@
#include <string>
#include <vector>

- #include <torch/csrc/inductor/aoti_model_container_runner.h>
+ #include <torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h>
#ifdef USE_CUDA
- #include <torch/csrc/inductor/aoti_model_container_runner_cuda.h>
+ #include <torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h>
#endif
#include <torch/script.h>
#include <torch/torch.h>
36 changes: 21 additions & 15 deletions test/dynamo/test_activation_checkpointing.py
@@ -178,7 +178,9 @@ def gn(x, y):
return torch.sigmoid(torch.matmul(x, y))

def fn(x, y):
- return torch.utils.checkpoint.checkpoint(gn, torch.sin(x), y)
+ return torch.utils.checkpoint.checkpoint(
+     gn, torch.sin(x), y, use_reentrant=True
+ )

x = torch.randn(4, 4, device="cuda", requires_grad=True)
y = torch.randn(4, 4, device="cuda", requires_grad=True)
@@ -197,7 +199,7 @@ def gn(x, y):

def fn(x, y):
# This goes through VariableBuilder
- return checkpoint(gn, torch.sin(x), y)
+ return checkpoint(gn, torch.sin(x), y, use_reentrant=True)

x = torch.randn(4, 4, device="cuda", requires_grad=True)
y = torch.randn(4, 4, device="cuda", requires_grad=True)
@@ -236,9 +238,9 @@ def gn(x, y):

def fn(x, y):
x = torch.sin(x)
- z = torch.utils.checkpoint.checkpoint(gn, x, y)
+ z = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
x = torch.sin(z)
- z = torch.utils.checkpoint.checkpoint(gn, x, y)
+ z = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
return z

x = torch.randn(4, 4, device="cuda", requires_grad=True)
@@ -264,7 +266,9 @@ def forward(self, x):
mod = MockModule().cuda()

def fn(x):
- return torch.utils.checkpoint.checkpoint(mod, torch.sin(x))
+ return torch.utils.checkpoint.checkpoint(
+     mod, torch.sin(x), use_reentrant=True
+ )

x = torch.randn(10, 10, device="cuda", requires_grad=True)

@@ -291,7 +295,9 @@ def forward(self, x):
mod = MockModule().cuda()

def fn(x):
- return torch.utils.checkpoint.checkpoint(mod, torch.sin(x))
+ return torch.utils.checkpoint.checkpoint(
+     mod, torch.sin(x), use_reentrant=True
+ )

x = torch.randn(10, 10, device="cuda", requires_grad=True)

@@ -318,9 +324,9 @@ def gn(x, y):

def fn(x, y):
x = torch.sin(x)
- x = torch.utils.checkpoint.checkpoint(gn, x, y)
+ x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
x = torch.sin(x)
- z = torch.utils.checkpoint.checkpoint(gn, x, y)
+ z = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
return z

x = torch.randn(4, 4, device="cuda", requires_grad=True)
@@ -344,9 +350,9 @@ def gn(x, y):

def fn(x, y):
x = torch.sin(x)
- x = torch.utils.checkpoint.checkpoint(gn, x, y)
+ x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
x = torch.sin(x)
- # x = torch.utils.checkpoint.checkpoint(gn, x, y)
+ # x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
return x

x = torch.randn(4, 4, device="cuda", requires_grad=True)
@@ -377,7 +383,7 @@ def forward(self, x):
mod = MockModule().cuda()

def fn(x):
- return torch.utils.checkpoint.checkpoint(mod, x)
+ return torch.utils.checkpoint.checkpoint(mod, x, use_reentrant=True)

x = torch.randn(10, 10, device="cuda", requires_grad=True)
backend = "inductor"
@@ -452,7 +458,7 @@ def gn(x, y):
return torch.matmul(x, torch.nn.functional.dropout(y, 0.5))

def fn(x, y):
- return torch.utils.checkpoint.checkpoint(gn, x, y)
+ return torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)

backend = "aot_eager"
cnt = CompileCounterWithBackend(backend)
@@ -859,7 +865,7 @@ def fn(primals_1, primals_2, primals_3):
)[0]

def gn(*args):
- return torch.utils.checkpoint.checkpoint(fn, *args)
+ return torch.utils.checkpoint.checkpoint(fn, *args, use_reentrant=True)

with torch.cuda.amp.autocast():
x = torch.randn(4, 2, 16, 32, device="cuda", requires_grad=True)
@@ -890,7 +896,7 @@ def forward(self, x):
mod = MockModule().cuda()

def fn(x):
- return torch.utils.checkpoint.checkpoint(mod, x)
+ return torch.utils.checkpoint.checkpoint(mod, x, use_reentrant=True)

x = torch.randn(4, 4).cuda()
opt_fn = torch.compile(fn, fullgraph=True)
@@ -915,7 +921,7 @@ def forward(self, x, ys):
mod = MockModule().cuda()

def fn(x, ys):
- return torch.utils.checkpoint.checkpoint(mod, x, ys)
+ return torch.utils.checkpoint.checkpoint(mod, x, ys, use_reentrant=True)

x = torch.randn(4, 4).cuda()
y = torch.randn(4, 4).cuda()
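The pattern in all of these edits is the same: `torch.utils.checkpoint.checkpoint` now warns when `use_reentrant` is omitted, and `use_reentrant=True` pins the legacy reentrant implementation these tests target. A minimal standalone sketch:

    import torch
    from torch.utils.checkpoint import checkpoint

    def gn(x, y):
        return torch.sigmoid(torch.matmul(x, y))

    x = torch.randn(4, 4, requires_grad=True)
    y = torch.randn(4, 4, requires_grad=True)

    # Passing the flag explicitly silences the deprecation warning; False
    # would select the newer non-reentrant implementation instead.
    out = checkpoint(gn, x, y, use_reentrant=True)
    out.sum().backward()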
28 changes: 22 additions & 6 deletions test/dynamo/test_higher_order_ops.py
@@ -3500,7 +3500,9 @@ def gn(x, y):
return torch.sigmoid(torch.matmul(x, y))

def fn(x, y):
- return torch.utils.checkpoint.checkpoint(gn, torch.sin(x), y)
+ return torch.utils.checkpoint.checkpoint(
+     gn, torch.sin(x), y, use_reentrant=True
+ )

x = torch.randn(4, 4, requires_grad=True)
y = torch.randn(4, 4, requires_grad=True)
@@ -3520,7 +3522,11 @@ def gn(x, y):

def fn(x, y):
return torch.utils.checkpoint.checkpoint(
- gn, torch.sin(x), y, use_reentrant=True, preserve_rng_state=False
+ gn,
+ torch.sin(x),
+ y,
+ use_reentrant=True,
+ preserve_rng_state=False,
)

x = torch.randn(4, 4, requires_grad=True)
@@ -3540,7 +3546,9 @@ def gn(x, y):
return torch.nn.functional.dropout(torch.matmul(x, y), p=0.2)

def fn(x, y):
- return torch.utils.checkpoint.checkpoint(gn, torch.sin(x), y)
+ return torch.utils.checkpoint.checkpoint(
+     gn, torch.sin(x), y, use_reentrant=True
+ )

x = torch.randn(4, 4, device="cuda", requires_grad=True)
y = torch.randn(4, 4, device="cuda", requires_grad=True)
@@ -3563,7 +3571,9 @@ def gn(x, y):
return torch.nn.functional.dropout(torch.matmul(x, y), p=0.2)

def fn(x, y):
- return torch.utils.checkpoint.checkpoint(gn, torch.sin(x), y)
+ return torch.utils.checkpoint.checkpoint(
+     gn, torch.sin(x), y, use_reentrant=True
+ )

x = torch.randn(4, 4, device="cuda", requires_grad=True)
y = torch.randn(4, 4, device="cuda", requires_grad=True)
@@ -3581,7 +3591,11 @@ def gn(x, y):
return torch.sigmoid(torch.matmul(x, y))

def fn(x, y):
- return torch.cos(torch.utils.checkpoint.checkpoint(gn, torch.sin(x), y))
+ return torch.cos(
+     torch.utils.checkpoint.checkpoint(
+         gn, torch.sin(x), y, use_reentrant=True
+     ),
+ )

x = torch.randn(4, 4, requires_grad=True)
y = torch.randn(4, 4, requires_grad=True)
@@ -3614,7 +3628,9 @@ def forward(self, x):
mod = MockModule()

def fn(x):
- return torch.utils.checkpoint.checkpoint(mod, torch.sin(x))
+ return torch.utils.checkpoint.checkpoint(
+     mod, torch.sin(x), use_reentrant=True
+ )

x = torch.randn(10, 10, requires_grad=True)

2 changes: 1 addition & 1 deletion test/dynamo/test_misc.py
@@ -7668,7 +7668,7 @@ def fn(xs):

# use checkpoint to trigger a "sourceless" tensor subclass
def checkpoint_fn(xs):
- return checkpoint(fn, xs)
+ return checkpoint(fn, xs, use_reentrant=True)

xs = TwoTensor(torch.ones(2, 2), torch.ones(2, 2))
