
Commit 2e98e29

Update on "Grab Current Tracing Fake Mode in a couple spots"

Fix for #99286. There were a couple of locations where we were instantiating new fake modes instead of grabbing the correct one from the current tracing context/inputs.

cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire

[ghstack-poisoned]
2 parents 5c084dd + 777800f commit 2e98e29
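
The pattern the fix moves toward, as a minimal hedged sketch (the helper below is illustrative, not the literal code touched by this commit): reuse the FakeTensorMode already attached to fake tensor inputs or the active tracing context rather than constructing a fresh one, so all tensors share a single ShapeEnv.

    # Hedged sketch, not the literal fix: prefer an existing fake mode over a new one.
    from typing import Optional
    import torch
    from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

    def grab_fake_mode(inputs) -> Optional[FakeTensorMode]:
        # A FakeTensor input already carries the active mode on .fake_mode.
        for t in inputs:
            if isinstance(t, FakeTensor):
                return t.fake_mode
        return None

    def fakeify(inputs):
        # Fall back to a new mode only when no input is fake yet.
        mode = grab_fake_mode(inputs) or FakeTensorMode()
        return [t if isinstance(t, FakeTensor) else mode.from_tensor(t) for t in inputs]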

File tree

20 files changed: +333, -97 lines changed

.github/workflows/_mac-build.yml

Lines changed: 0 additions & 3 deletions
@@ -174,9 +174,6 @@ jobs:
       export PATH="$CONDA_ENV/bin":$PATH
     fi
 
-    # NB: Same trick as Linux, there is no need to initialize sccache with the risk of getting
-    # it hangs or timeout at initialization. The cache will be started automatically
-    export SKIP_SCCACHE_INITIALIZATION=1
     ${CONDA_RUN} .ci/pytorch/macos-build.sh
 
   - name: Archive artifacts into zip

aten/src/ATen/native/TensorShape.cpp

Lines changed: 1 addition & 1 deletion
@@ -2807,7 +2807,7 @@ static std::vector<Tensor> reshape_input_for_column_stack(TensorList tensors) {
   auto transform_lambda = [](const Tensor& input) -> Tensor {
     // reshape 0D or 1D tensor t into (t.numel(), 1)
     if (input.dim() <= 1) {
-      return input.reshape_symint({input.sym_numel(), 1});
+      return input.reshape({input.numel(), 1});
     }
     return input;
   };
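
For reference, a small Python-level illustration (not part of the diff) of the reshape this lambda performs: column_stack turns each 0-D or 1-D input into a (numel, 1) column before concatenating along dim 1.

    import torch

    a = torch.tensor([1, 2, 3])           # 1-D input, treated as a (3, 1) column
    b = torch.tensor([4, 5, 6])
    out = torch.column_stack((a, b))
    assert out.shape == (3, 2)
    assert torch.equal(out, torch.stack((a, b), dim=1))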

c10/core/Scalar.h

Lines changed: 1 addition & 1 deletion
@@ -171,7 +171,7 @@ class C10_API Scalar {
   }
 
   bool isSymbolic() const {
-    return Tag::HAS_si == tag || Tag::HAS_sd == tag;
+    return Tag::HAS_si == tag || Tag::HAS_sd == tag || Tag::HAS_sb == tag;
   }
 
   C10_ALWAYS_INLINE Scalar& operator=(Scalar&& other) noexcept {
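
Hedged context for the new HAS_sb case: alongside the symbolic int (HAS_si) and symbolic double (HAS_sd) tags, a c10::Scalar can wrap a symbolic bool, so isSymbolic() has to cover that tag too. The Python-visible counterparts of the three tags:

    import torch

    # Symbolic scalar types mirrored by the HAS_si / HAS_sd / HAS_sb tags.
    print(torch.SymInt, torch.SymFloat, torch.SymBool)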

test/cpp_extensions/open_registration_extension.cpp

Lines changed: 22 additions & 0 deletions
@@ -2,16 +2,27 @@
 #include <c10/core/Allocator.h>
 
 #include <torch/csrc/Device.h>
+#include <c10/core/impl/DeviceGuardImplInterface.h>
+#include <c10/macros/Macros.h>
 #include <torch/extension.h>
 
 #include <ATen/native/cpu/Loops.h>
 #include <ATen/native/DispatchStub.h>
+#include <ATen/native/Resize.h>
 #include <ATen/EmptyTensor.h>
 #include <ATen/core/GeneratorForPrivateuseone.h>
 
 static uint64_t add_counter = 0;
 static uint64_t last_saved_value = 0;
 
+// register guard
+namespace at {
+namespace detail {
+
+C10_REGISTER_GUARD_IMPL(PrivateUse1, c10::impl::NoOpDeviceGuardImpl<DeviceType::PrivateUse1>);
+
+}} // namespace at::detail
+
 // basic dummy add function
 at::Tensor custom_add_Tensor(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {
   add_counter += 1;

@@ -79,6 +90,16 @@ at::Tensor custom_empty_strided(c10::IntArrayRef size, c10::IntArrayRef stride,
   return at::detail::empty_strided_generic(size, stride, &global_custom_alloc, private_use_ks, dtype);
 }
 
+// Some set operations for the basic use case
+at::Tensor& custom_set_source_Storage(at::Tensor& result, c10::Storage src) {
+  int64_t new_size = static_cast<int64_t>(src.nbytes() / result.dtype().itemsize());
+  c10::IntArrayRef stride = {};
+  result.unsafeGetTensorImpl()->set_storage_offset(0);
+  at::OptionalIntArrayRef stride_opt = stride.data() != nullptr ? at::OptionalIntArrayRef(stride) : c10::nullopt;
+  at::native::resize_impl_cpu_(result.unsafeGetTensorImpl(), new_size, stride_opt, /*resize_storage=*/!result.is_meta());
+  return result;
+}
+
 // This macro does the heavy lifting.
 // With TORCH_LIBRARY_IMPL, you can register custom kernels for your backend.
 // For open registration, we're registering all of our kernels to the PrivateUse1 dispatch key.

@@ -94,6 +115,7 @@ TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
   m.impl("fill_.Scalar", &custom_fill__scalar);
   m.impl("_copy_from", &custom__copy_from);
   m.impl("empty_strided", &custom_empty_strided);
+  m.impl("set_.source_Storage", &custom_set_source_Storage);
 }
 
 // This basic implementation doesn't bother dealing with different device indices
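
A hedged Python-side sketch of how the new set_.source_Storage kernel gets exercised (assumes this extension is compiled and loaded as in the accompanying test; the "foo" name is the test suite's choice):

    import torch

    torch.utils.rename_privateuse1_backend("foo")  # expose PrivateUse1 as "foo"
    torch.utils.generate_methods_for_privateuse1_backend(
        for_tensor=False, for_module=False, for_storage=True
    )  # adds Storage.foo() / Storage.is_foo

    z = torch.empty(4, 4).storage()
    z = z.foo()        # moving the storage to the custom backend plausibly
    assert z.is_foo    # routes through the set_.source_Storage registration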

test/dynamo/test_export.py

Lines changed: 19 additions & 0 deletions
@@ -2496,6 +2496,25 @@ def forward(self, input):
             tracing_mode="real",
         )
 
+    @config.patch(
+        dynamic_shapes=True,
+        capture_dynamic_output_shape_ops=True,
+        capture_scalar_outputs=True,
+        assume_static_by_default=False,
+    )
+    def test_sym_contains(self):
+        def f(x, y):
+            return x.size(0) in y
+
+        gm, _ = torch._dynamo.export(
+            f, torch.ones(2), torch.ones(3), aten_graph=True, tracing_mode="symbolic"
+        )
+
+        true_inp = (torch.Tensor([6, 4, 5]), torch.ones(6, 4).add_(5))
+        false_inp = (torch.Tensor([6, 4, 5]), torch.ones(6, 4).add_(2))
+        self.assertEqual(gm(*true_inp), f(*true_inp))
+        self.assertEqual(gm(*false_inp), f(*false_inp))
+
 
 common_utils.instantiate_parametrized_tests(ExportTests)
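
A hedged note on what the new test exercises: for a Tensor y, Python's "in" operator maps to Tensor.__contains__, which is an elementwise compare plus an any() reduction, so the exported graph must carry a data-dependent boolean (hence capture_scalar_outputs / capture_dynamic_output_shape_ops above).

    import torch

    x, y = torch.ones(2), torch.Tensor([6, 4, 5])
    # `x.size(0) in y` is an equality test reduced with any():
    assert (x.size(0) in y) == bool((y == x.size(0)).any())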

test/functorch/test_aotdispatch.py

Lines changed: 3 additions & 0 deletions
@@ -2509,7 +2509,9 @@ def forward(self, x):
     xfail('cdist', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('cholesky_inverse', ''), # could not find kernel
     xfail('cholesky_solve', ''), # could not find kernel
+    xfail('column_stack', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('combinations', ''), # aten.masked_select.default
+    xfail('cross', ''), # aten.linalg_cross.default - couldn't find symbolic meta function/decomposition
     xfail('cumprod', ''), # aten.cumprod.default - couldn't find symbolic meta function/decomposition
     xfail('cumulative_trapezoid', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('diff', ''), # aten.zeros_like.default - couldn't find symbolic meta function/decomposition

@@ -2525,6 +2527,7 @@ def forward(self, x):
     xfail('kthvalue', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('linalg.cholesky_ex', ''), # could not find kernel for aten.linalg_solve_triangular.default
     xfail('linalg.cond', ''), # Cannot call numel() on tensor with symbolic sizes/strides
+    xfail('linalg.cross', ''), # aten.linalg_cross.default - couldn't find symbolic meta function/decomposition
     xfail('linalg.det', ''), # aten._linalg_det.default - couldn't find symbolic meta function/decomposition
     xfail('linalg.det', 'singular'), # aten._linalg_det.default - couldn't find symbolic meta function/deco...
     xfail('linalg.eigh', ''), # aten._linalg_eigh.default - couldn't find symbolic meta function/decomposition

test/test_cpp_extensions_open_device_registration.py

Lines changed: 22 additions & 0 deletions
@@ -162,6 +162,28 @@ def test_open_device_registration(self):
         self.assertTrue(x.is_foo)
         self.assertTrue(hasattr(torch.nn.Module, 'foo'))
 
+        # check whether the attributes and methods for storage of the corresponding custom backend are generated correctly
+        torch.utils.generate_methods_for_privateuse1_backend(for_tensor=False, for_module=False, for_storage=True)
+
+        x = torch.empty(4, 4)
+        # check TypedStorage
+        z1 = x.storage()
+        self.assertFalse(z1.is_foo)
+        z1 = z1.foo()
+        self.assertFalse(self.module.custom_add_called())
+        self.assertTrue(z1.is_foo)
+        with self.assertRaisesRegex(RuntimeError, "Invalid device"):
+            z1.foo(torch.device("cpu"))
+        z1 = z1.cpu()
+
+        y = torch.empty(4, 4)
+        # check UntypedStorage
+        z2 = y.untyped_storage()
+        self.assertFalse(z2.is_foo)
+        z2 = z2.foo()
+        self.assertFalse(self.module.custom_add_called())
+        self.assertTrue(z2.is_foo)
+
     def test_open_device_random(self):
         torch.utils.rename_privateuse1_backend('foo')
         with self.assertRaisesRegex(RuntimeError, "Expected one of cpu"):

test/test_proxy_tensor.py

Lines changed: 3 additions & 11 deletions
@@ -1236,17 +1236,6 @@ def f(x, y):
     inp = (torch.randn(8)[3:], torch.randn(5))
     self.assertEqual(fx_g(*inp), f(*inp))
 
-    def test_sym_contains(self):
-        def f(x, y):
-            return x.size(0) in y
-
-        # This shouldn't raise but we need SymBool from
-        # https://github.com/pytorch/pytorch/pull/98453
-        # then modify this. It should NOT raise a RuntimeError
-        # though!
-        with self.assertRaisesRegex(NotImplementedError, "item NYI for torch.bool"):
-            make_fx(f, tracing_mode="symbolic")(torch.randn(2), torch.randn(3))
-
     def _assert_no_guards(self, fx_g, free_symbols):
         assert _get_free_symbols(fx_g.shape_env) == free_symbols, fx_g.shape_env.var_to_val
         assert len(fx_g.shape_env.get_nontrivial_guards()) == 0, fx_g.shape_env.format_guards()

@@ -1369,7 +1358,9 @@ def f(a, b, c, d, e):
     xfail('linalg.eig'),
     xfail('linalg.eigvals'),
     xfail('cholesky_solve', ''), # Could not run 'aten::_cholesky_solve_helper' with arguments from the 'Meta' back...
+    xfail('column_stack', ''), # Tensors of type TensorImpl do not have numel
     xfail('combinations', ''),
+    xfail('cross', ''), # aten.linalg_cross.default - couldn't find symbolic meta function/decomposition
     xfail('cumulative_trapezoid', ''), # aten.slice.Tensor - couldn't find symbolic meta function/decomposition
     xfail('diff', ''), # aten.empty_like.default - couldn't find symbolic meta function/decomposition
     xfail('dsplit', ''), # aten.slice.Tensor - couldn't find symbolic meta function/decomposition

@@ -1386,6 +1377,7 @@ def f(a, b, c, d, e):
     xfail('kron', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('kthvalue', ''), # aten.kthvalue.default - couldn't find symbolic meta function/decomposition
     xfail('linalg.cond', ''), # Tensors of type TensorImpl do not have numel
+    xfail('linalg.cross', ''), # aten.linalg_cross.default - couldn't find symbolic meta function/decomposition
     xfail('linalg.eigh', ''), # aten._linalg_eigh.default - couldn't find symbolic meta function/decomposition
     xfail('linalg.eigvalsh', ''), # aten._linalg_eigh.default - couldn't find symbolic meta function/decomposition
     xfail('linalg.householder_product', ''), # aten.linalg_householder_product.default - couldn't find symbolic meta funct...

third_party/fbgemm

Submodule fbgemm updated 118 files

torch/_dynamo/skipfiles.py

Lines changed: 1 addition & 5 deletions
@@ -145,10 +145,7 @@ def add(import_name: str):
     if isinstance(import_name, types.ModuleType):
         return add(import_name.__name__)
     assert isinstance(import_name, str)
-    try:
-        module_spec = importlib.util.find_spec(import_name)
-    except Exception:
-        return
+    module_spec = importlib.util.find_spec(import_name)
     if not module_spec:
         return
     origin = module_spec.origin

@@ -191,7 +188,6 @@ def check(filename, allow_torch=False):
     "tensorflow",
     "tensorrt",
     "torch2trt",
-    "torchrec.distributed",
     "tqdm",
     "tree",
     "tvm",
