From 1f99f750d716ed3f5e1f7af0185977006f7a75a3 Mon Sep 17 00:00:00 2001
From: PyTorch MergeBot
Date: Tue, 12 Sep 2023 14:37:36 +0000
Subject: [PATCH] Revert "Force synced KJT to trace unbacked SymInt (#108960)"

This reverts commit f9a250c35bd061e2e6f4c2d92e2b1b16390e8636.

Reverted https://github.com/pytorch/pytorch/pull/108960 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/108960#issuecomment-1715850779))
---
 test/dynamo/test_torchrec.py            | 204 ------------------------
 torch/_dynamo/variables/builder.py      |  17 +-
 torch/_dynamo/variables/user_defined.py |   7 +-
 torch/_guards.py                        |  24 ---
 4 files changed, 5 insertions(+), 247 deletions(-)
 delete mode 100644 test/dynamo/test_torchrec.py

diff --git a/test/dynamo/test_torchrec.py b/test/dynamo/test_torchrec.py
deleted file mode 100644
index 817d81b9a23da..0000000000000
--- a/test/dynamo/test_torchrec.py
+++ /dev/null
@@ -1,204 +0,0 @@
-# Owner(s): ["module: dynamo"]
-import sys
-import unittest
-from typing import Dict, List
-
-import torch
-import torch._dynamo.test_case
-from torch import nn
-
-from torch._dynamo.test_case import TestCase
-from torch._dynamo.testing import CompileCounter
-from torch.testing._internal.common_utils import NoTest
-
-try:
-    from torchrec.datasets.random import RandomRecDataset
-    from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor
-
-    HAS_TORCHREC = True
-except ImportError:
-    HAS_TORCHREC = False
-
-
-class BucketizeMod(torch.nn.Module):
-    def __init__(self, feature_boundaries: Dict[str, List[float]]):
-        super().__init__()
-        self.bucket_w = torch.nn.ParameterDict()
-        self.boundaries_dict = {}
-        for key, boundaries in feature_boundaries.items():
-            self.bucket_w[key] = torch.nn.Parameter(
-                torch.empty([len(boundaries) + 1]).fill_(1.0),
-                requires_grad=True,
-            )
-            buf = torch.tensor(boundaries, requires_grad=False)
-            self.register_buffer(
-                f"{key}_boundaries",
-                buf,
-                persistent=False,
-            )
-            self.boundaries_dict[key] = buf
-
-    def forward(self, features: "KeyedJaggedTensor") -> "KeyedJaggedTensor":
-        weights_list = []
-        for key, boundaries in self.boundaries_dict.items():
-            jt = features[key]
-            bucketized = torch.bucketize(jt.weights(), boundaries)
-            # doesn't super matter I guess
-            # hashed = torch.ops.fb.index_hash(bucketized, seed=0, modulo=len(boundaries))
-            hashed = bucketized
-            weights = torch.gather(self.bucket_w[key], dim=0, index=hashed)
-            weights_list.append(weights)
-        return KeyedJaggedTensor(
-            keys=features.keys(),
-            values=features.values(),
-            weights=torch.cat(weights_list),
-            lengths=features.lengths(),
-            offsets=features.offsets(),
-            stride=features.stride(),
-            length_per_key=features.length_per_key(),
-        )
-
-
-if not HAS_TORCHREC:
-    print("torchrec not available, skipping tests", file=sys.stderr)
-    TestCase = NoTest  # noqa: F811
-
-
-@unittest.skipIf(not HAS_TORCHREC, "these tests require torchrec")
-class TorchRecTests(TestCase):
-    def test_pooled(self):
-        tables = [
-            (nn.EmbeddingBag(2000, 8), ["a0", "b0"]),
-            (nn.EmbeddingBag(2000, 8), ["a1", "b1"]),
-            (nn.EmbeddingBag(2000, 8), ["b2"]),
-        ]
-
-        embedding_groups = {
-            "a": ["a0", "a1"],
-            "b": ["b0", "b1", "b2"],
-        }
-
-        counter = CompileCounter()
-
-        @torch.compile(backend=counter, fullgraph=True, dynamic=True)
-        def f(id_list_features: KeyedJaggedTensor):
-            id_list_jt_dict: Dict[str, JaggedTensor] = id_list_features.to_dict()
-            pooled_embeddings = {}
-            # TODO: run feature processor
-            for emb_module, feature_names in tables:
-                features_dict = id_list_jt_dict
-                for feature_name in feature_names:
-                    f = features_dict[feature_name]
-                    pooled_embeddings[feature_name] = emb_module(
-                        f.values(), f.offsets()
-                    )
-
-            pooled_embeddings_by_group = {}
-            for group_name, group_embedding_names in embedding_groups.items():
-                group_embeddings = [
-                    pooled_embeddings[name] for name in group_embedding_names
-                ]
-                pooled_embeddings_by_group[group_name] = torch.cat(
-                    group_embeddings, dim=1
-                )
-
-            return pooled_embeddings_by_group
-
-        dataset = RandomRecDataset(
-            keys=["a0", "a1", "b0", "b1", "b2"],
-            batch_size=4,
-            hash_size=2000,
-            ids_per_feature=3,
-            num_dense=0,
-        )
-        di = iter(dataset)
-
-        # unsync should work
-
-        d1 = next(di).sparse_features.unsync()
-        d2 = next(di).sparse_features.unsync()
-        d3 = next(di).sparse_features.unsync()
-
-        r1 = f(d1)
-        r2 = f(d2)
-        r3 = f(d3)
-
-        self.assertEqual(counter.frame_count, 1)
-        counter.frame_count = 0
-
-        # sync should work too
-
-        d1 = next(di).sparse_features.sync()
-        d2 = next(di).sparse_features.sync()
-        d3 = next(di).sparse_features.sync()
-
-        r1 = f(d1)
-        r2 = f(d2)
-        r3 = f(d3)
-
-        self.assertEqual(counter.frame_count, 1)
-
-        # export only works with unsync
-
-        gm = torch._dynamo.export(f)(next(di).sparse_features.unsync()).graph_module
-        gm.print_readable()
-
-        self.assertEqual(gm(d1), r1)
-        self.assertEqual(gm(d2), r2)
-        self.assertEqual(gm(d3), r3)
-
-    def test_bucketize(self):
-        mod = BucketizeMod({"f1": [0.0, 0.5, 1.0]})
-        features = KeyedJaggedTensor.from_lengths_sync(
-            keys=["f1"],
-            values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
-            lengths=torch.tensor([2, 0, 1, 1, 1, 3]),
-            weights=torch.tensor([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]),
-        ).unsync()
-
-        def f(x):
-            # This is a trick to populate the computed cache and instruct
-            # ShapeEnv that they're all sizey
-            x.to_dict()
-            return mod(x)
-
-        torch._dynamo.export(f, aten_graph=True)(features).graph_module.print_readable()
-
-    @unittest.expectedFailure
-    def test_simple(self):
-        jag_tensor1 = KeyedJaggedTensor(
-            values=torch.tensor([3.0, 4.0, 5.0, 6.0, 7.0, 8.0]),
-            keys=["index_0", "index_1"],
-            lengths=torch.tensor([0, 0, 1, 1, 1, 3]),
-        ).sync()
-
-        # ordinarily, this would trigger one specialization
-        self.assertEqual(jag_tensor1.length_per_key(), [1, 5])
-
-        counter = CompileCounter()
-
-        @torch._dynamo.optimize(counter, nopython=True)
-        def f(jag_tensor):
-            # The indexing here requires more symbolic reasoning
-            # and doesn't work right now
-            return jag_tensor["index_0"].values().sum()
-
-        f(jag_tensor1)
-
-        self.assertEqual(counter.frame_count, 1)
-
-        jag_tensor2 = KeyedJaggedTensor(
-            values=torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]),
-            keys=["index_0", "index_1"],
-            lengths=torch.tensor([2, 0, 1, 1, 1, 3]),
-        ).sync()
-
-        f(jag_tensor2)
-
-        self.assertEqual(counter.frame_count, 1)
-
-
-if __name__ == "__main__":
-    from torch._dynamo.test_case import run_tests
-
-    run_tests()
diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py
index 79a9691d18dbd..572d94d280f92 100644
--- a/torch/_dynamo/variables/builder.py
+++ b/torch/_dynamo/variables/builder.py
@@ -23,7 +23,6 @@
 from torch._ops import HigherOrderOperator
 from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode
 from torch.fx.experimental.symbolic_shapes import (
-    _constrain_range_for_size,
     DimConstraint,
     DimDynamic,
     RelaxedUnspecConstraint,
@@ -886,7 +885,7 @@ def wrap_literal(self, value):
         elif unspec and type(value) is int:
             # unspecializing int by default, but still
             # specialize for the following conditions
-            if not TracingContext.get().force_unspec_int_unbacked_size_like and (
+            if (
                 value in self._common_constants()
                 # Assume integers from global variables want to be specialized
                 or not self.source.guard_source().is_local()
@@ -1084,21 +1083,11 @@ def wrap_unspecialized_primitive(self, value):
         if self.name in self.tx.output.unspec_variable_map:
             return self.tx.output.unspec_variable_map[self.name]
         else:
-            shape_env = self.tx.output.shape_env
-            if TracingContext.get().force_unspec_int_unbacked_size_like and isinstance(
-                value, int
-            ):
-                wrapped_value = shape_env.create_unbacked_symint()
-                _constrain_range_for_size(wrapped_value)
-                self.tx.output.tracked_fakes.append(
-                    TrackedFake(wrapped_value, self.source, None)
-                )
-
             # NB: We do not do float. For motivation, see
             # https://docs.google.com/document/d/1INSCdYu1PxXcr43HrD82OudeEuS-qxQe1yZmLg2wy6A/edit
             # but the general idea is that we generate kernels that can
             # take unspecialized floats and use them in sizevar computation
-            elif (
+            if (
                 isinstance(value, int)
                 and not is_constant_source(self.get_source())
                 and not isinstance(self.get_source(), RandomValueSource)
@@ -1112,6 +1101,8 @@ def wrap_unspecialized_primitive(self, value):
                         guards=self.make_guards(GuardBuilder.CONSTANT_MATCH),
                     )
 
+            shape_env = self.tx.output.shape_env
+
             name = self.source.name()
             if name not in self.tx.output.frame_state:
                 # Note - this esentially means that if this name gets reused as a tensor,
diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py
index e99987ecaeec7..57b3eb30a8b45 100644
--- a/torch/_dynamo/variables/user_defined.py
+++ b/torch/_dynamo/variables/user_defined.py
@@ -10,7 +10,6 @@
 from typing import Dict, List
 
 import torch.nn
-from torch._guards import TracingContext
 
 from .. import variables
 from ..allowed_functions import is_allowed
@@ -590,8 +589,4 @@ def __init__(self, value, **kwargs):
         assert type(value) is KeyedJaggedTensor
         super().__init__(value, **kwargs)
 
-    def var_getattr(self, tx, name):
-        if self.source is not None and name in ("_length_per_key", "_offset_per_key"):
-            with TracingContext.patch(force_unspec_int_unbacked_size_like=True):
-                return super().var_getattr(tx, name)
-        return super().var_getattr(tx, name)
+    # TODO Handle getattr for _length_per_key and _offset_per_key properly.
diff --git a/torch/_guards.py b/torch/_guards.py
index da986e125ff3f..4bb70fd5e32a7 100644
--- a/torch/_guards.py
+++ b/torch/_guards.py
@@ -605,30 +605,6 @@ def __init__(self, fake_mode):
         # you ever do change this in aot_autograd.py; you should check
         # on permutations preferentially.)
         self.output_strides: Optional[List[Optional[List[int]]]] = None
-        # When this is True, whenever we encounter an int in Dynamo tracing,
-        # we will (1) force unspec it and (2) force it as a size-like unbacked
-        # integer. This is currently used when processing certain lists of
-        # ints that are known to be size-like and may have 0/1 entries that we
-        # must not specialize on.
-        self.force_unspec_int_unbacked_size_like = False
-
-    @staticmethod
-    @contextmanager
-    def patch(**kwargs):
-        prior = {}
-        ctx = TracingContext.get()
-        assert ctx is not None
-
-        for key in kwargs.keys():
-            # KeyError on invalid entry
-            prior[key] = getattr(ctx, key)
-        for key, val in kwargs.items():
-            setattr(ctx, key, val)
-        try:
-            yield
-        finally:
-            for key, val in prior.items():
-                setattr(ctx, key, val)
 
     @staticmethod
     def extract_stack():