Back out "QEBC permute order caching" (#2156)
Summary:
Pull Request resolved: #2156

Original commit changeset: 9eb716e27f18

Original Phabricator Diff: D57339912

This diff breaks a silvertorch model publish
https://fb.workplace.com/groups/silvertorch/posts/892246682627925/?comment_id=892279845957942

Differential Revision: D58887074

fbshipit-source-id: 9679deb7b7b370603fabe1f4bbb01814520e18ed
gnahzg authored and facebook-github-bot committed Jun 22, 2024
1 parent 827fad3 commit 82c6977
Showing 2 changed files with 19 additions and 52 deletions.
46 changes: 7 additions & 39 deletions torchrec/quant/embedding_modules.py
@@ -94,15 +94,9 @@ def _get_feature_length(feature: KeyedJaggedTensor) -> Tensor:
 
 
 @torch.fx.wrap
-def _feature_permute(
-    features: KeyedJaggedTensor,
-    features_order: List[int],
-    features_order_tensor: torch.Tensor,
-) -> KeyedJaggedTensor:
-    return features.permute(
-        features_order,
-        features_order_tensor,
-    )
+def _get_kjt_keys(feature: KeyedJaggedTensor) -> List[str]:
+    # this is a fx rule to help with batching hinting jagged sequence tensor coalescing.
+    return feature.keys()
 
 
 def for_each_module_of_type_do(
@@ -469,35 +463,6 @@ def __init__(
         if register_tbes:
             self.tbes: torch.nn.ModuleList = torch.nn.ModuleList(self._emb_modules)
 
-        self._has_uninitialized_kjt_permute_order: bool = True
-        self._has_features_permute: bool = True
-        self._features_order: List[int] = []
-
-    def _permute_kjt_order(
-        self, features: KeyedJaggedTensor
-    ) -> List[KeyedJaggedTensor]:
-        if self._has_uninitialized_kjt_permute_order:
-            kjt_keys = features.keys()
-            for f in self._feature_names:
-                self._features_order.append(kjt_keys.index(f))
-
-            self.register_buffer(
-                "_features_order_tensor",
-                torch.tensor(
-                    self._features_order,
-                    device=features.device(),
-                    dtype=torch.int32,
-                ),
-                persistent=False,
-            )
-
-            self._has_uninitialized_kjt_permute_order = False
-
-        features_permute = _feature_permute(
-            features, self._features_order, self._features_order_tensor
-        )
-        return features_permute.split(self._feature_splits)
-
     def forward(
         self,
         features: KeyedJaggedTensor,
@@ -511,7 +476,10 @@ def forward(
         """
 
         embeddings = []
-        kjts_per_key = self._permute_kjt_order(features)
+        kjt_keys = _get_kjt_keys(features)
+        kjt_permute_order = [kjt_keys.index(k) for k in self._feature_names]
+        kjt_permute = features.permute(kjt_permute_order)
+        kjts_per_key = kjt_permute.split(self._feature_splits)
 
         for i, (emb_op, _) in enumerate(
             zip(self._emb_modules, self._key_to_tables.keys())
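For reference, a minimal standalone sketch (not part of this commit) of the permute-then-split flow that forward() goes back to after this revert: the permute order is recomputed from the incoming KJT keys on every call instead of being cached in a registered buffer. The feature names, splits, and values below are toy stand-ins for self._feature_names and self._feature_splits.

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Toy configuration standing in for self._feature_names / self._feature_splits.
feature_names = ["f2", "f1"]   # order expected by the embedding modules
feature_splits = [1, 1]        # one feature per embedding module

features = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.as_tensor([0, 1, 2]),
    lengths=torch.as_tensor([2, 1]),
)

# Recompute the permute order from the incoming keys on every forward call.
kjt_keys = features.keys()
kjt_permute_order = [kjt_keys.index(k) for k in feature_names]
kjt_permute = features.permute(kjt_permute_order)
kjts_per_key = kjt_permute.split(feature_splits)
print([kjt.keys() for kjt in kjts_per_key])  # [['f2'], ['f1']]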
25 changes: 12 additions & 13 deletions torchrec/quant/tests/test_embedding_modules.py
@@ -447,15 +447,6 @@ def test_trace_and_script(self) -> None:
 
         qebc = QuantEmbeddingBagCollection.from_float(ebc)
 
-        features = KeyedJaggedTensor(
-            keys=["f1", "f2"],
-            values=torch.as_tensor([0, 1]),
-            lengths=torch.as_tensor([1, 1]),
-        )
-
-        # need to first run the model once to initialize lazily modules
-        original_out = qebc(features)
-
         from torchrec.fx import symbolic_trace
 
         gm = symbolic_trace(qebc, leaf_modules=[ComputeKJTToJTDict.__name__])
@@ -468,14 +459,22 @@
         )
         self.assertEqual(
             non_placeholder_nodes[0].op,
-            "get_attr",
-            f"First non-placeholder node must be get_attr, got {non_placeholder_nodes[0].op} instead",
+            "call_function",
+            f"First non-placeholder node must be call_function, got {non_placeholder_nodes[0].op} instead",
         )
         self.assertEqual(
             non_placeholder_nodes[0].name,
-            "_features_order_tensor",
-            f"First non-placeholder node must be '_features_order_tensor', got {non_placeholder_nodes[0].name} instead",
+            "_get_kjt_keys",
+            f"First non-placeholder node must be '_get_kjt_keys', got {non_placeholder_nodes[0].name} instead",
        )
+
+        features = KeyedJaggedTensor(
+            keys=["f1", "f2"],
+            values=torch.as_tensor([0, 1]),
+            lengths=torch.as_tensor([1, 1]),
+        )
+
+        original_out = qebc(features)
         traced_out = gm(features)
 
         scripted_module = torch.jit.script(gm)
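A minimal sketch (not part of this commit) of why the test's expectation changes from get_attr to call_function: because _get_kjt_keys is decorated with torch.fx.wrap, symbolic tracing keeps it opaque and records it as a call_function node, whereas the backed-out caching path surfaced a get_attr on the _features_order_tensor buffer first. The helper and module below are illustrative only, not torchrec APIs.

import torch
import torch.fx


@torch.fx.wrap
def _get_num_features(x: torch.Tensor) -> int:
    # Opaque to the tracer, like _get_kjt_keys above; the body only runs at call time.
    return x.shape[1]


class Toy(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        n = _get_num_features(x)
        return x * n


gm = torch.fx.symbolic_trace(Toy())
non_placeholder = [node for node in gm.graph.nodes if node.op != "placeholder"]
print(non_placeholder[0].op, non_placeholder[0].name)
# call_function _get_num_features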
