Skip to content

Commit

Permalink
pull from master, merge activation
Browse files Browse the repository at this point in the history
  • Loading branch information
ynonaolga committed Nov 14, 2022
2 parents 1a89581 + 7aa144a commit 957c79c
Show file tree
Hide file tree
Showing 58 changed files with 1,640 additions and 901 deletions.
11 changes: 5 additions & 6 deletions aten/src/ATen/PythonTorchFunctionTLS.cpp
Expand Up @@ -26,12 +26,12 @@ int64_t PythonTorchFunctionTLS::stack_len() {
return pythonTorchFunctionState.stack_.size();
}

void PythonTorchFunctionTLS::set_disabled_state(TorchFunctionDisabledState disabled_state) {
pythonTorchFunctionState.disabled_state_ = disabled_state;
void PythonTorchFunctionTLS::set_disabled(bool disabled) {
pythonTorchFunctionState.disabled_ = disabled;
}

TorchFunctionDisabledState PythonTorchFunctionTLS::get_disabled_state() {
return pythonTorchFunctionState.disabled_state_;
bool PythonTorchFunctionTLS::is_disabled() {
return pythonTorchFunctionState.disabled_;
}

void PythonTorchFunctionTLS::set_state(const PythonTorchFunctionTLS& state) {
Expand All @@ -43,8 +43,7 @@ const PythonTorchFunctionTLS& PythonTorchFunctionTLS::get_state() {
}

bool torch_function_mode_enabled() {
return PythonTorchFunctionTLS::get_disabled_state() != TorchFunctionDisabledState::ALL_DISABLED &&
PythonTorchFunctionTLS::stack_len() > 0;
return PythonTorchFunctionTLS::stack_len() > 0;
}

} // namespace impl
Expand Down
12 changes: 5 additions & 7 deletions aten/src/ATen/PythonTorchFunctionTLS.h
Expand Up @@ -6,11 +6,9 @@
namespace at {
namespace impl {

enum TorchFunctionDisabledState { ENABLED, SUBCLASSES_DISABLED, ALL_DISABLED };

struct TORCH_API PythonTorchFunctionTLS {
static void set_disabled_state(TorchFunctionDisabledState disabled_state_);
static TorchFunctionDisabledState get_disabled_state();
static void set_disabled(bool);
static bool is_disabled();

static void push_onto_stack(std::shared_ptr<SafePyObject> mode);
static const std::shared_ptr<SafePyObject> pop_stack();
Expand All @@ -22,11 +20,11 @@ struct TORCH_API PythonTorchFunctionTLS {

private:
// The mode TLS is split into
// - disabled_state, which says which part of torch function are disabled
// - disabled_, which says whether or not to disable all torch function
// modes
// - stack_, which is a vector of modes representing the stack of user
// defined modes
TorchFunctionDisabledState disabled_state_ =
TorchFunctionDisabledState::ENABLED;
bool disabled_;
std::vector<std::shared_ptr<c10::SafePyObject>> stack_;
};

Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/functorch/BatchRulesDecompositions.cpp
Expand Up @@ -63,7 +63,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
OP_DECOMPOSE2(bitwise_or, Scalar);
OP_DECOMPOSE2(bitwise_xor, Scalar);
OP_DECOMPOSE(broadcast_tensors);
OP_DECOMPOSE(broadcast_to);
m.impl("broadcast_to", native::broadcast_to_symint);
OP_DECOMPOSE(cartesian_prod);
OP_DECOMPOSE(cdist);
OP_DECOMPOSE(clip);
Expand Down
4 changes: 2 additions & 2 deletions aten/src/ATen/native/TensorShape.cpp
Expand Up @@ -537,8 +537,8 @@ Tensor sparse_broadcast_to(const Tensor& self, IntArrayRef size) {
return at::sparse_coo_tensor(new_indices, new_values, size)._coalesced_(is_coalesced);
}

Tensor broadcast_to(const Tensor& self, IntArrayRef size) {
return self.expand(size);
Tensor broadcast_to_symint(const Tensor& self, SymIntArrayRef size) {
return self.expand_symint(size);
}

std::vector<Tensor> broadcast_tensors(TensorList tensors) {
Expand Down
4 changes: 3 additions & 1 deletion aten/src/ATen/native/native_functions.yaml
Expand Up @@ -1195,8 +1195,10 @@
device_check: NoCheck
device_guard: False

- func: broadcast_to(Tensor(a) self, int[] size) -> Tensor(a)
- func: broadcast_to(Tensor(a) self, SymInt[] size) -> Tensor(a)
variants: function, method
dispatch:
CompositeImplicitAutograd: broadcast_to_symint

- func: _sparse_broadcast_to(Tensor(a) self, int[] size) -> Tensor(a)
variants: function
Expand Down
1 change: 0 additions & 1 deletion test/allowlist_for_publicAPI.json
Expand Up @@ -1129,7 +1129,6 @@
"ComplexDoubleStorage",
"ComplexFloatStorage",
"DisableTorchFunction",
"DisableTorchFunctionSubclass",
"Generator",
"HalfStorage",
"HalfTensor",
Expand Down
2 changes: 1 addition & 1 deletion test/distributed/fsdp/test_fsdp_state_dict.py
Expand Up @@ -25,7 +25,7 @@
StateDictType,
)
from torch.distributed.fsdp._shard_utils import _gather_state_dict
from torch.distributed.fsdp.fully_sharded_data_parallel import FLAT_PARAM
from torch.distributed.fsdp._unshard_param_utils import FLAT_PARAM
from torch.distributed.fsdp.wrap import enable_wrap, transformer_auto_wrap_policy, wrap
from torch.nn import Linear, Module, TransformerDecoderLayer, TransformerEncoderLayer
from torch.nn.parallel import DistributedDataParallel
Expand Down
4 changes: 2 additions & 2 deletions test/distributed/fsdp/test_fsdp_summon_full_params.py
Expand Up @@ -212,7 +212,7 @@ def forward(self, fsdp_module):

model = FSDP(MyModule()).cuda(self.rank)
with self.assertRaisesRegex(
ValueError, "current state is TrainingState.FORWARD"
ValueError, "Current handle state is HandleTrainingState.FORWARD"
):
model(model)

Expand All @@ -231,7 +231,7 @@ def bad_backwards_hook(tensor):
output.register_hook(bad_backwards_hook)

with self.assertRaisesRegex(
ValueError, "current state is TrainingState.FORWARD_BACKWARD"
ValueError, "Current handle state is HandleTrainingState.BACKWARD_PRE"
):
output.backward()

Expand Down
12 changes: 2 additions & 10 deletions test/functorch/test_aotdispatch.py
Expand Up @@ -1093,20 +1093,13 @@ def assert_compiler(gm: torch.fx.GraphModule, _):
xfail('masked.cumprod', ''), # aten.cumprod.default - couldn't find symbolic meta function/decomposition
xfail('masked.cumsum', ''), # aten.cumsum.default - couldn't find symbolic meta function/decomposition
xfail('masked_fill', ''), # could not find kernel
xfail('masked.log_softmax', ''), # argument 'size' (position 2) must be tuple of ints, not ...
xfail('masked.logaddexp', ''), # aten.logaddexp.default - couldn't find symbolic meta function/decomposi...
xfail('masked.logsumexp', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('masked.mean', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=t...
xfail('masked.median', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('masked.norm', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
# Seems flaky: https://github.com/pytorch/pytorch/issues/88883
skip('masked.median', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('masked.prod', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('masked_scatter', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('masked_select', ''), # aten.masked_select.default - couldn't find symbolic meta function/decompos...
xfail('masked.softmax', ''), # argument 'size' (position 2) must be tuple of ints, not torc...
xfail('masked.softmin', ''), # argument 'size' (position 2) must be tuple of ints, not torc...
xfail('masked.std', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=to...
xfail('masked.sum', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('masked.var', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=to...
xfail('matmul', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('matrix_exp', ''), # aten.linalg_matrix_exp.default - couldn't find symbolic meta function/decompo...
xfail('median', ''), # could not find kernel
Expand Down Expand Up @@ -1214,7 +1207,6 @@ def assert_compiler(gm: torch.fx.GraphModule, _):
xfail('take_along_dim', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('take', ''), # aten.take.default - couldn't find symbolic meta function/decomposition
xfail('tensordot', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('topk', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('trace', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('trapezoid', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('trapz', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
Expand Down
26 changes: 26 additions & 0 deletions test/inductor/test_torchinductor.py
Expand Up @@ -4601,6 +4601,8 @@ def fn(a):
CommonTemplate.install(CudaTests, "cuda")

class CudaReproTests(TestCase):
common = check_model_cuda

def test_index_put_issue(self):
def forward(
self,
Expand Down Expand Up @@ -4637,6 +4639,30 @@ def forward(
compiled = compile_fx_inner(mod, inps)
compiled(inps)

@requires_cuda()
def test_input_channels_last(self):
m = torch.nn.Sequential(
torch.nn.Conv2d(3, 3, 1, 1),
ToTuple(),
).cuda()
inp = (
torch.randn([2, 3, 16, 16]).to(memory_format=torch.channels_last).cuda()
)

self.common(
m,
(inp,),
check_lowp=False,
)

@torch._dynamo.optimize()
def foo(m, inp):
return m(inp)

self.assertTrue(
foo(m, inp)[0].is_contiguous(memory_format=torch.channels_last)
)

# https://github.com/pytorch/torchdynamo/issues/1681#issuecomment-1283433527
@requires_cuda()
def test_unspec_inputs_interop(self):
Expand Down
77 changes: 61 additions & 16 deletions test/onnx/internal/test_diagnostics.py
Expand Up @@ -3,6 +3,7 @@
import contextlib
import dataclasses
import io
import typing
import unittest
from typing import AbstractSet, Tuple

Expand Down Expand Up @@ -110,23 +111,15 @@ class TestOnnxDiagnostics(common_utils.TestCase):
def setUp(self):
engine = diagnostics.engine
engine.clear()
self._sample_rule = diagnostics.rules.missing_custom_symbolic_function
super().setUp()

def test_assert_diagnostic_raises_when_diagnostic_not_found(self):
with self.assertRaises(AssertionError):
with assert_diagnostic(
self,
diagnostics.engine,
diagnostics.rules.node_missing_onnx_shape_inference,
diagnostics.levels.WARNING,
):
pass

def test_cpp_diagnose_emits_warning(self):
def _trigger_node_missing_onnx_shape_inference_warning_diagnostic_from_cpp(
self,
) -> diagnostics.ExportDiagnostic:
class CustomAdd(torch.autograd.Function):
@staticmethod
def forward(ctx, x, y):
ctx.save_for_backward(x, y)
return x + y

@staticmethod
Expand All @@ -137,14 +130,38 @@ class M(torch.nn.Module):
def forward(self, x):
return CustomAdd.apply(x, x)

# trigger warning for missing shape inference.
rule = diagnostics.rules.node_missing_onnx_shape_inference
torch.onnx.export(M(), torch.randn(3, 4), io.BytesIO())

context = diagnostics.engine.contexts[-1]
for diagnostic in context.diagnostics:
if (
diagnostic.rule == rule
and diagnostic.level == diagnostics.levels.WARNING
):
return typing.cast(diagnostics.ExportDiagnostic, diagnostic)
raise AssertionError("No diagnostic found.")

def test_assert_diagnostic_raises_when_diagnostic_not_found(self):
with self.assertRaises(AssertionError):
with assert_diagnostic(
self,
diagnostics.engine,
diagnostics.rules.node_missing_onnx_shape_inference,
diagnostics.levels.WARNING,
):
pass

def test_cpp_diagnose_emits_warning(self):
with assert_diagnostic(
self,
diagnostics.engine,
diagnostics.rules.node_missing_onnx_shape_inference,
diagnostics.levels.WARNING,
):
# trigger warning for missing shape inference.
torch.onnx.export(M(), torch.randn(3, 4), io.BytesIO())
self._trigger_node_missing_onnx_shape_inference_warning_diagnostic_from_cpp()

def test_py_diagnose_emits_error(self):
class M(torch.nn.Module):
Expand All @@ -168,15 +185,43 @@ def forward(self, x):
def test_diagnostics_engine_records_diagnosis_reported_outside_of_export(
self,
):
sample_rule = diagnostics.rules.missing_custom_symbolic_function
sample_level = diagnostics.levels.ERROR
with assert_diagnostic(
self,
diagnostics.engine,
sample_rule,
self._sample_rule,
sample_level,
):
diagnostics.context.diagnose(sample_rule, sample_level)
diagnostics.context.diagnose(self._sample_rule, sample_level)

def test_diagnostics_records_python_call_stack(self):
diagnostic = diagnostics.ExportDiagnostic(
self._sample_rule, diagnostics.levels.NOTE
)
stack = diagnostic.python_call_stack
assert stack is not None # for mypy
self.assertGreater(len(stack.frames), 0)
frame = stack.frames[0]
assert frame.location.snippet is not None # for mypy
self.assertIn("self._sample_rule", frame.location.snippet)
assert frame.location.uri is not None # for mypy
self.assertIn("test_diagnostics.py", frame.location.uri)

def test_diagnostics_records_cpp_call_stack(self):
diagnostic = (
self._trigger_node_missing_onnx_shape_inference_warning_diagnostic_from_cpp()
)
stack = diagnostic.cpp_call_stack
assert stack is not None # for mypy
self.assertGreater(len(stack.frames), 0)
frame_messages = [frame.location.message for frame in stack.frames]
self.assertTrue(
any(
isinstance(message, str)
and "torch::jit::ONNXShapeTypeInference" in message
for message in frame_messages
)
)


@dataclasses.dataclass
Expand Down

0 comments on commit 957c79c

Please sign in to comment.