Update on "aot autograd refactor: make all synthetic base logic layered in a single location"

This refactor doesn't significantly change LoC in aot autograd, but I think it nets out to making it clearer (interested in people's thoughts).

The idea is that I tried to re-write the part of aot autograd that deals with synthetic bases in a layered way, similar to how Ed wrote the logic for dedup'ing inputs: it happens in one place, and all of the downstream transformations in aot autograd don't have to worry about it.

Specifically, I added a new function `aot_wrapper_synthetic_base`, similar to the existing `aot_wrapper_dedupe`.

The benefit: none of the other code in aot autograd needs to think about synthetic bases (previously, synthetic base code was intertwined in several places).
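To make the layering concrete, here is a rough Python sketch of the pattern. This is only an illustration, not the actual aot autograd internals: the helper names (`merge_aliased_inputs`, `aot_wrapper_synthetic_base_sketch`) are made up, and the real code tracks much more metadata.

```python
# Illustrative sketch only: collapse aliased inputs into one "synthetic base"
# at a single layer, so downstream logic never sees the aliasing.
from typing import Callable, Dict, List, Tuple
import torch


def merge_aliased_inputs(args: List[torch.Tensor]):
    """Replace each group of storage-sharing inputs with a single base tensor,
    plus a recipe for reconstructing every original input as a view."""
    storage_to_new_idx: Dict[int, int] = {}
    new_args: List[torch.Tensor] = []
    recipe: List[Tuple[int, torch.Size, Tuple[int, ...], int]] = []
    for t in args:
        key = t.untyped_storage().data_ptr()
        if key not in storage_to_new_idx:
            storage_to_new_idx[key] = len(new_args)
            new_args.append(t._base if t._base is not None else t)
        recipe.append(
            (storage_to_new_idx[key], t.size(), t.stride(), t.storage_offset())
        )
    return new_args, recipe


def aot_wrapper_synthetic_base_sketch(
    flat_fn: Callable[..., List[torch.Tensor]], flat_args: List[torch.Tensor]
):
    new_args, recipe = merge_aliased_inputs(flat_args)

    def inner_fn(*bases: torch.Tensor) -> List[torch.Tensor]:
        # Reconstruct every original input as a view of its synthetic base,
        # then call the original function. Everything downstream (dedup,
        # autograd, partitioning, compilation) only ever sees `bases`.
        views = [
            bases[idx].as_strided(size, stride, offset)
            for idx, size, stride, offset in recipe
        ]
        return flat_fn(*views)

    return inner_fn, new_args


# Tiny usage example: two aliases of one buffer become a single traced input.
buf = torch.zeros(4)
fn, args = aot_wrapper_synthetic_base_sketch(lambda a, b: [a + b], [buf[:2], buf[2:]])
assert len(args) == 1  # one synthetic base instead of two aliased inputs
out = fn(*args)
```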

There are two downsides.

(1) `aot_wrapper_synthetic_base()` needs to have its own epilogue. There is one particularly hairy case where factoring the synthetic base logic into a single location was painful: when you have two inputs that alias each other, where one gets a data mutation and the other gets a metadata mutation.

Ordinarily, metadata mutations are handled by the runtime epilogue, in `create_runtime_wrapper`. However, now that things are factored this way, the runtime wrapper operates only on synthetic bases instead of operating on the original inputs. For data mutations, it is fine to apply the data mutation to the synthetic base instead of the original input alias. But for metadata mutations, we **need** to apply the metadata mutation directly to the original inputs.

The way that I handled this was by tracking which inputs slot into this specific case (part of a synthetic base, and get metadata mutations), and updating the flat_fn() that we pass downstream to return these updated inputs as extra outputs. From the perspective of downstream logic, these are real user outputs that it can treat like any other user outputs. `aot_wrapper_synthetic_base` will know to grab these extra outputs and use them to apply the metadata mutations.

This was pretty annoying, but has the benefit that all of that logic is encapsulated entirely in `aot_wrapper_synthetic_base()`.
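A hedged sketch of that trick, with made-up helper names (`wrap_with_extra_outputs`, `synthetic_base_epilogue`) rather than the real code:

```python
# Sketch of the idea: inputs that belong to a synthetic base AND receive a
# metadata mutation are appended to the traced function's outputs, and the
# wrapper's epilogue replays the new size/stride/storage_offset onto the
# original user inputs. All names here are hypothetical.
from typing import Callable, List
import torch


def wrap_with_extra_outputs(
    flat_fn: Callable[..., List[torch.Tensor]],
    metadata_mutated_indices: List[int],
) -> Callable[..., List[torch.Tensor]]:
    """Return a flat_fn that also returns the metadata-mutated inputs.

    Downstream logic can treat the extra outputs like ordinary user outputs."""
    def wrapped(*args: torch.Tensor) -> List[torch.Tensor]:
        outs = flat_fn(*args)
        return list(outs) + [args[i] for i in metadata_mutated_indices]
    return wrapped


def synthetic_base_epilogue(
    original_inputs: List[torch.Tensor],
    all_outputs: List[torch.Tensor],
    metadata_mutated_indices: List[int],
) -> List[torch.Tensor]:
    """Strip the extra outputs and apply each metadata mutation directly to the
    original (aliased) input, not to the synthetic base."""
    if not metadata_mutated_indices:
        return all_outputs
    n = len(metadata_mutated_indices)
    user_outs, extras = all_outputs[:-n], all_outputs[-n:]
    for idx, mutated in zip(metadata_mutated_indices, extras):
        original_inputs[idx].as_strided_(
            mutated.size(), mutated.stride(), mutated.storage_offset()
        )
    return user_outs
```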

(2) Input mutations are now performed on the synthetic base instead of the individual aliases.

You can see the original code comment [here](https://github.com/pytorch/pytorch/blob/b0b5f3c6c681896febbd9ff7ad7649b13def345d/torch/_functorch/aot_autograd.py#L1131) for details. We used to do the optimized thing in this case, and now we do the less optimized thing (copying the entire synthetic base, instead of the potentially smaller alias).
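Roughly, the difference looks like this (the sizes are made up purely to show the cost gap; this is not the actual runtime wrapper code):

```python
# Illustration of the trade-off only.
import torch

base = torch.zeros(1_000_000)
alias = base[:10]                 # the only input that was actually mutated
updated_base = base.clone()
updated_base[:10] = 1.0           # result produced by the compiled function

# Previous (optimized) behavior: copy back just the mutated alias (10 elements).
alias.copy_(updated_base[:10])

# Behavior after this refactor: copy back the whole synthetic base (1M elements).
base.copy_(updated_base)
```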

To be fair, we had no data showing that this optimization yielded improvements on any models in practice. I also think that the main reason anyone would ever run across this problem is because of a graph break - so if you care about perf, you probably want to avoid the extra graph breaks to begin with. I haven't added any warnings for this, but we probably could, depending on what people think.




[ghstack-poisoned]
bdhirsh committed Mar 15, 2023
2 parents d3d8c54 + c5a66e2 commit 6f19e89
Showing 35 changed files with 9,550 additions and 9,708 deletions.
5 changes: 3 additions & 2 deletions .ci/pytorch/test.sh
@@ -320,8 +320,9 @@ test_single_dynamo_benchmark() {
--output "$TEST_REPORTS_DIR/${name}_${suite}.csv"
python benchmarks/dynamo/check_csv.py \
-f "$TEST_REPORTS_DIR/${name}_${suite}.csv"
if [[ "${TEST_CONFIG}" != *cpu_accuracy* ]] && [[ "${TEST_CONFIG}" != *dynamic* ]]; then
# because I haven't tracked the cpu-side or dynamic expected artifacts yet, and need to differentiate filenames
if [[ "${TEST_CONFIG}" == *inductor* ]] && [[ "${TEST_CONFIG}" != *dynamic* ]]; then
# because I haven't dealt with dynamic expected artifacts yet,
# and non-inductor jobs (e.g. periodic) may have different set of expected models.
python benchmarks/dynamo/check_graph_breaks.py \
--actual "$TEST_REPORTS_DIR/${name}_$suite.csv" \
--expected "benchmarks/dynamo/ci_expected_accuracy/${name}_${suite}${shard_id}.csv"
2 changes: 1 addition & 1 deletion .github/ci_commit_pins/vision.txt
@@ -1 +1 @@
98c58158d1bc09e6fab31d3bf1af36e8d1752a89
6c4ff94b56a8a26c150af0cd95f37bf30e1b8eb4
19 changes: 0 additions & 19 deletions aten/src/ATen/autocast_mode.cpp
@@ -46,14 +46,6 @@ void set_hpu_enabled(bool new_enabled) {
c10::impl::tls_set_dispatch_key_excluded(DispatchKey::AutocastHPU, !new_enabled);
}

bool is_privateuseone_enabled() {
return !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::AutocastPrivateUse1);
}

void set_privateuseone_enabled(bool new_enabled) {
c10::impl::tls_set_dispatch_key_excluded(DispatchKey::AutocastPrivateUse1, !new_enabled);
}

namespace {
// Imitate Apex and cache some of the casts to streamline parameter reuse.
// Our heuristic is to cache lower_precision_fp casts of fp32 model weights (see cached_cast below).
@@ -96,9 +88,6 @@ thread_local bool cache_enabled = true;

// autocast_gpu_dtype is the lower_precision_fp used by AutocastGPU.
thread_local at::ScalarType autocast_gpu_dtype = at::kHalf;

// autocast_privateuseone_dtype is the lower_precision_fp used by AutocastPrivateUse1.
thread_local at::ScalarType autocast_privateuseone_dtype = at::kHalf;
}

void clear_cache() {
@@ -130,10 +119,6 @@ at::ScalarType get_autocast_hpu_dtype() {
return autocast_hpu_dtype;
}

at::ScalarType get_autocast_privateuseone_dtype() {
return autocast_privateuseone_dtype;
}

void set_autocast_cpu_dtype(at::ScalarType dtype) {
TORCH_CHECK(
dtype == at::kBFloat16,
@@ -153,10 +138,6 @@ void set_autocast_hpu_dtype(at::ScalarType dtype) {
autocast_hpu_dtype = dtype;
}

void set_autocast_privateuseone_dtype(at::ScalarType dtype) {
autocast_privateuseone_dtype = dtype;
}

bool is_autocast_cache_enabled() {
return cache_enabled;
}
11 changes: 0 additions & 11 deletions aten/src/ATen/autocast_mode.h
@@ -24,10 +24,6 @@ TORCH_API bool is_hpu_enabled();
TORCH_API void set_hpu_enabled(bool enabled);
TORCH_API at::ScalarType get_autocast_hpu_dtype();
TORCH_API void set_autocast_hpu_dtype(at::ScalarType dtype);
TORCH_API bool is_privateuseone_enabled();
TORCH_API void set_privateuseone_enabled(bool enabled);
TORCH_API at::ScalarType get_autocast_privateuseone_dtype();
TORCH_API void set_autocast_privateuseone_dtype(at::ScalarType dtype);
TORCH_API bool is_autocast_cache_enabled();
TORCH_API void set_autocast_cache_enabled(bool enabled);

@@ -44,9 +40,6 @@ bool is_autocast_eligible(const Tensor& tensor, DeviceType device_type) {
return tensor.is_xpu() && tensor.is_floating_point();
case DeviceType::HPU:
return tensor.is_hpu() && tensor.is_floating_point();
case DeviceType::PrivateUse1:
return tensor.device().type() == DeviceType::PrivateUse1 &&
tensor.is_floating_point();
default:
return false;
}
@@ -64,8 +57,6 @@ inline DispatchKey get_autocast_dispatch_key_from_device_type(
return DispatchKey::AutocastXPU;
case DeviceType::HPU:
return DispatchKey::AutocastHPU;
case DeviceType::PrivateUse1:
return DispatchKey::AutocastPrivateUse1;
default:
throw std::runtime_error(
"unknown device type for autocast in get_autocast_dispatch_key_from_device_type");
@@ -83,8 +74,6 @@ inline at::ScalarType get_lower_precision_fp_from_device_type(
return get_autocast_xpu_dtype();
case DeviceType::HPU:
return get_autocast_hpu_dtype();
case DeviceType::PrivateUse1:
return get_autocast_privateuseone_dtype();
default:
throw std::runtime_error(
"unknown device type for autocast in get_lower_precision_fp_from_device_type");
8 changes: 2 additions & 6 deletions benchmarks/dynamo/common.py
@@ -242,13 +242,9 @@ class CI(NamedTuple):
*CI_SKIP[CI("inductor", training=True)],
# torchbench
"pytorch_unet", # TypeError: unhashable type: 'SymInt'
"yolov3", # 'float' object has no attribute '_has_symbolic_sizes_strides'
# timm_models
"eca_botnext26ts_256", # 'float' object has no attribute '_has_symbolic_sizes_strides'
"mixnet_l", # 'float' object has no attribute '_has_symbolic_sizes_strides'
"tf_efficientnet_b0", # 'float' object has no attribute '_has_symbolic_sizes_strides'
"tf_mixnet_l", # 'float' object has no attribute '_has_symbolic_sizes_strides'
"visformer_small", # 'float' object has no attribute '_has_symbolic_sizes_strides'
"rexnet_100", # Accuracy failed for key name stem.bn.weight.grad
"tf_efficientnet_b0", # NameError: name 's1' is not defined
]


3 changes: 0 additions & 3 deletions c10/core/DispatchKey.cpp
@@ -144,8 +144,6 @@ const char* toString(DispatchKey t) {
return "AutocastHPU";
case DispatchKey::AutocastCUDA:
return "AutocastCUDA";
case DispatchKey::AutocastPrivateUse1:
return "AutocastPrivateUse1";

case DispatchKey::FuncTorchBatched:
return "FuncTorchBatched";
@@ -287,7 +285,6 @@ c10::DispatchKey parseDispatchKey(const std::string& k) {
{"AutocastXPU", c10::DispatchKey::AutocastXPU},
{"AutocastHPU", c10::DispatchKey::AutocastHPU},
{"AutocastCUDA", c10::DispatchKey::AutocastCUDA},
{"AutocastPrivateUse1", c10::DispatchKey::AutocastPrivateUse1},
{"FuncTorchBatched", c10::DispatchKey::FuncTorchBatched},
{"FuncTorchVmapMode", c10::DispatchKey::FuncTorchVmapMode},
{"Batched", c10::DispatchKey::Batched},
1 change: 0 additions & 1 deletion c10/core/DispatchKey.h
@@ -356,7 +356,6 @@ enum class DispatchKey : uint16_t {
// Naughtily, AutocastCUDA is also being used for XLA. In the terminal state,
// it probably should get its own Autocast key
AutocastCUDA,
AutocastPrivateUse1,

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~ WRAPPERS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ //
// There are a number of alternative modes which may want to handle before
6 changes: 0 additions & 6 deletions c10/core/DispatchKeySet.h
@@ -643,7 +643,6 @@ constexpr DispatchKeySet autocast_dispatch_keyset = DispatchKeySet({
DispatchKey::AutocastCUDA,
DispatchKey::AutocastXPU,
DispatchKey::AutocastHPU,
DispatchKey::AutocastPrivateUse1,
});

// See Note [TLS Initialization]
@@ -657,7 +656,6 @@ constexpr DispatchKeySet default_excluded_set = DispatchKeySet({
DispatchKey::AutocastCUDA,
DispatchKey::AutocastXPU,
DispatchKey::AutocastHPU,
DispatchKey::AutocastPrivateUse1,
});

constexpr DispatchKeySet autograd_dispatch_keyset_with_ADInplaceOrView =
@@ -841,8 +839,6 @@ inline DispatchKeySet getAutocastRelatedKeySetFromBackend(BackendComponent t) {
constexpr auto autocast_xpu_ks = DispatchKeySet(DispatchKey::AutocastXPU);
constexpr auto autocast_hpu_ks = DispatchKeySet(DispatchKey::AutocastHPU);
constexpr auto autocast_cuda_ks = DispatchKeySet(DispatchKey::AutocastCUDA);
constexpr auto autocast_privateuse1_ks =
DispatchKeySet(DispatchKey::AutocastPrivateUse1);
switch (t) {
case BackendComponent::CPUBit:
return autocast_cpu_ks;
@@ -853,8 +849,6 @@ inline DispatchKeySet getAutocastRelatedKeySetFromBackend(BackendComponent t) {
case BackendComponent::CUDABit:
case BackendComponent::XLABit:
return autocast_cuda_ks;
case BackendComponent::PrivateUse1Bit:
return autocast_privateuse1_ks;
default:
return DispatchKeySet();
}
3 changes: 1 addition & 2 deletions caffe2/python/memonger.py
@@ -41,8 +41,7 @@ def is_grad_blob(b):
name = str(b)
# Note: need to look at _{namescope} pattern as it matches
# to handle the auto-split gradients
return name.endswith("_grad") and (name.startswith(namescope) or
name.startswith("_" + namescope)) and name not in param_grads
return name.endswith("_grad") and (name.startswith((namescope, "_" + namescope))) and name not in param_grads

def is_grad_op(op):
# TODO: something smarter
2 changes: 1 addition & 1 deletion scripts/release_notes/commitlist.py
@@ -199,7 +199,7 @@ def categorize(features):
else:
# Below are some extra quick checks that aren't necessarily file-path related,
# but I found that to catch a decent number of extra commits.
if len(files_changed) > 0 and all([f_name.endswith('.cu') or f_name.endswith('.cuh') for f_name in files_changed]):
if len(files_changed) > 0 and all([f_name.endswith(('.cu', '.cuh')) for f_name in files_changed]):
category = 'cuda'
elif '[PyTorch Edge]' in title:
category = 'mobile'
2 changes: 1 addition & 1 deletion test/cpp_api_parity/module_impl_check.py
@@ -172,7 +172,7 @@ def run_cpp_test_fn_and_check_output():
param_name = key[:-len(suffix)]
break
assert param_name is not None
sparsity_str = 'sparse' if key.endswith('_grad_indices') or key.endswith('_grad_values') else 'dense'
sparsity_str = 'sparse' if key.endswith(('_grad_indices', '_grad_values')) else 'dense'

unit_test_class.assertTrue(
key in cpp_grad_dict,
2 changes: 1 addition & 1 deletion test/distributed/_tensor/test_api.py
@@ -161,7 +161,7 @@ def shard_fn(name, module, device_mesh):
dist_module = distribute_module(module_to_distribute, device_mesh, shard_fn)
for name, param in dist_module.named_parameters():
self.assertIsInstance(param, DTensor)
if name.startswith("seq.0") or name.startswith("seq.8"):
if name.startswith(("seq.0", "seq.8")):
self.assertEqual(param.placements, shard_spec)
else:
self.assertEqual(param.placements, replica_spec)
40 changes: 0 additions & 40 deletions test/test_utils.py
@@ -784,26 +784,6 @@ class DummyXPUModule:
def is_available():
return True

@staticmethod
def is_autocast_foo_enabled():
return True

@staticmethod
def get_autocast_foo_dtype():
return torch.float16

@staticmethod
def set_autocast_foo_enabled(enable):
pass

@staticmethod
def set_autocast_foo_dtype(dtype):
pass

@staticmethod
def get_amp_supported_dtype():
return [torch.float16]


class TestExtensionUtils(TestCase):
def test_external_module_register(self):
@@ -826,26 +806,6 @@ def test_external_module_register(self):
with self.assertRaisesRegex(RuntimeError, "The runtime module of"):
torch._register_device_module('xpu', DummyXPUModule)

def test_external_module_and_backend_register(self):
torch.utils.rename_privateuse1_backend('foo')
with self.assertRaisesRegex(RuntimeError, "has already been set"):
torch.utils.rename_privateuse1_backend('dummmy')

custom_backend_name = torch._C._get_privateuse1_backend_name()
self.assertEqual(custom_backend_name, 'foo')

with self.assertRaises(AttributeError):
torch.foo.is_available()

with self.assertRaisesRegex(AssertionError, "Tried to use AMP with the"):
with torch.autocast(device_type=custom_backend_name):
pass
torch._register_device_module('foo', DummyXPUModule)

torch.foo.is_available()
with torch.autocast(device_type=custom_backend_name):
pass


class TestDeviceUtils(TestCase):
def test_basic(self):
2 changes: 1 addition & 1 deletion third_party/XNNPACK
Submodule XNNPACK updated 1959 files
