QEBC permute order caching (#2036)
Summary:

As titled. This optimization targets CPU-bound models. The implementation is copied from SQEBC (D56455884).

Differential Revision: D57339912
gnahzg authored and facebook-github-bot committed May 24, 2024
1 parent 8c7fa2f commit 9f1491a
Showing 2 changed files with 62 additions and 19 deletions.
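
In essence, the change replaces a per-forward computation of the KJT permute order with a one-time computation cached on the module. A minimal self-contained sketch of that pattern (plain Python; PermuteOrderCache and its method are illustrative names, not the torchrec API):

    from typing import List, Optional

    class PermuteOrderCache:
        # illustrative stand-in for the caching added to QuantEmbeddingBagCollection
        def __init__(self, feature_names: List[str]) -> None:
            self._feature_names = feature_names
            self._order: Optional[List[int]] = None  # None means "not yet computed"

        def order_for(self, keys: List[str]) -> List[int]:
            if self._order is None:
                # the expensive part: one keys.index() lookup per feature
                order = [keys.index(name) for name in self._feature_names]
                # an identity order is stored as [] so callers can skip the permute
                self._order = [] if order == list(range(len(order))) else order
            return self._order

    cache = PermuteOrderCache(["f2", "f1"])
    print(cache.order_for(["f1", "f2"]))  # [1, 0], computed on the first call
    print(cache.order_for(["f1", "f2"]))  # [1, 0], served from the cache

The empty-list convention mirrors the diff below: an identity order is stored as [] so the permute op can be skipped entirely.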
56 changes: 49 additions & 7 deletions torchrec/quant/embedding_modules.py
@@ -94,9 +94,19 @@ def _get_feature_length(feature: KeyedJaggedTensor) -> Tensor:
 
 
 @torch.fx.wrap
-def _get_kjt_keys(feature: KeyedJaggedTensor) -> List[str]:
-    # this is a fx rule to help with batching hinting jagged sequence tensor coalescing.
-    return feature.keys()
+def _feature_permute(
+    features: KeyedJaggedTensor,
+    features_order: List[int],
+    features_order_tensor: torch.Tensor,
+) -> KeyedJaggedTensor:
+    return (
+        features.permute(
+            features_order,
+            features_order_tensor,
+        )
+        if features_order
+        else features
+    )
 
 
 def for_each_module_of_type_do(
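
Note the @torch.fx.wrap decorator carried over to the new helper: it registers _feature_permute as an fx leaf, so symbolic tracing records one opaque call_function node instead of tracing into the body and baking the "if features_order" branch into the graph. A toy demonstration of the mechanism (maybe_double and Toy are illustrative, not part of the change):

    import torch
    import torch.fx

    @torch.fx.wrap
    def maybe_double(x: torch.Tensor, flag: bool) -> torch.Tensor:
        # the data-dependent branch stays hidden inside the leaf call
        return x * 2 if flag else x

    class Toy(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return maybe_double(x, True)

    gm = torch.fx.symbolic_trace(Toy())
    print(gm.graph)  # contains a single call_function node for maybe_double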
@@ -463,6 +473,41 @@ def __init__(
         if register_tbes:
             self.tbes: torch.nn.ModuleList = torch.nn.ModuleList(self._emb_modules)
 
+        self._has_uninitialized_kjt_permute_order: bool = True
+        self._has_features_permute: bool = True
+        self._features_order: List[int] = []
+
+    def _permute_kjt_order(
+        self, features: KeyedJaggedTensor
+    ) -> List[KeyedJaggedTensor]:
+        if self._has_uninitialized_kjt_permute_order:
+            kjt_keys = features.keys()
+            for f in self._feature_names:
+                self._features_order.append(kjt_keys.index(f))
+
+            self.register_buffer(
+                "_features_order_tensor",
+                torch.tensor(
+                    self._features_order,
+                    device=features.device(),
+                    dtype=torch.int32,
+                ),
+                persistent=False,
+            )
+
+            self._features_order = (
+                []
+                if self._features_order == list(range(len(self._features_order)))
+                else self._features_order
+            )
+
+            self._has_uninitialized_kjt_permute_order = False
+
+        features = _feature_permute(
+            features, self._features_order, self._features_order_tensor
+        )
+        return features.split(self._feature_splits)
+
     def forward(
         self,
         features: KeyedJaggedTensor,
@@ -476,10 +521,7 @@ def forward(
         """
 
         embeddings = []
-        kjt_keys = _get_kjt_keys(features)
-        kjt_permute_order = [kjt_keys.index(k) for k in self._feature_names]
-        kjt_permute = features.permute(kjt_permute_order)
-        kjts_per_key = kjt_permute.split(self._feature_splits)
+        kjts_per_key = self._permute_kjt_order(features)
 
         for i, (emb_op, _) in enumerate(
             zip(self._emb_modules, self._key_to_tables.keys())
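
What the cached path computes on each forward, sketched against the public KeyedJaggedTensor API (assumes torchrec is installed; keys, values, and lengths are illustrative):

    import torch
    from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

    features = KeyedJaggedTensor(
        keys=["f1", "f2"],
        values=torch.as_tensor([0, 1, 2]),
        lengths=torch.as_tensor([1, 2]),
    )

    feature_names = ["f2", "f1"]  # the order the embedding tables expect
    # before this change the order was recomputed on every forward;
    # _permute_kjt_order now computes it once and caches it
    order = [features.keys().index(k) for k in feature_names]  # [1, 0]
    order_tensor = torch.tensor(order, dtype=torch.int32)

    permuted = features.permute(order, order_tensor)
    print(permuted.keys())  # ['f2', 'f1']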
25 changes: 13 additions & 12 deletions torchrec/quant/tests/test_embedding_modules.py
@@ -447,6 +447,15 @@ def test_trace_and_script(self) -> None:
 
         qebc = QuantEmbeddingBagCollection.from_float(ebc)
 
+        features = KeyedJaggedTensor(
+            keys=["f1", "f2"],
+            values=torch.as_tensor([0, 1]),
+            lengths=torch.as_tensor([1, 1]),
+        )
+
+        # need to run the model once first to initialize the lazy modules
+        original_out = qebc(features)
+
         from torchrec.fx import symbolic_trace
 
         gm = symbolic_trace(qebc, leaf_modules=[ComputeKJTToJTDict.__name__])
@@ -459,22 +468,14 @@
         )
         self.assertEqual(
             non_placeholder_nodes[0].op,
-            "call_function",
-            f"First non-placeholder node must be call_function, got {non_placeholder_nodes[0].op} instead",
+            "get_attr",
+            f"First non-placeholder node must be get_attr, got {non_placeholder_nodes[0].op} instead",
         )
         self.assertEqual(
             non_placeholder_nodes[0].name,
-            "_get_kjt_keys",
-            f"First non-placeholder node must be '_get_kjt_keys', got {non_placeholder_nodes[0].name} instead",
+            "_features_order_tensor",
+            f"First non-placeholder node must be '_features_order_tensor', got {non_placeholder_nodes[0].name} instead",
        )
 
-        features = KeyedJaggedTensor(
-            keys=["f1", "f2"],
-            values=torch.as_tensor([0, 1]),
-            lengths=torch.as_tensor([1, 1]),
-        )
-
-        original_out = qebc(features)
         traced_out = gm(features)
 
         scripted_module = torch.jit.script(gm)
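
The updated assertions follow directly from the implementation change: the permute order now lives in the _features_order_tensor buffer registered on the first forward, and buffer reads appear in traced graphs as get_attr nodes rather than a call_function to _get_kjt_keys. That is also why the test now runs the model once before symbolic_trace: tracing before the first forward would find no buffer to read. A minimal illustration of buffers tracing as get_attr (BufferToy is a toy module, not the torchrec code):

    import torch
    import torch.fx

    class BufferToy(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.register_buffer("order", torch.tensor([1, 0]), persistent=False)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x + self.order  # buffer access traces as a get_attr node

    gm = torch.fx.symbolic_trace(BufferToy())
    print([node.op for node in gm.graph.nodes])
    # ['placeholder', 'get_attr', 'call_function', 'output']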
