Skip to content

Commit 2200dc9

Browse files
gnahzgfacebook-github-bot
authored andcommitted
QEBC permute order caching
Summary: As titiled. This optimization is for CPU bound model. Copy implementation from SQEBC D56455884 Differential Revision: D57339912
1 parent a29c82e commit 2200dc9

File tree

1 file changed

+37
-10
lines changed

1 file changed

+37
-10
lines changed

torchrec/quant/embedding_modules.py

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -93,12 +93,6 @@ def _get_feature_length(feature: KeyedJaggedTensor) -> Tensor:
9393
return feature.lengths()
9494

9595

96-
@torch.fx.wrap
97-
def _get_kjt_keys(feature: KeyedJaggedTensor) -> List[str]:
98-
# this is a fx rule to help with batching hinting jagged sequence tensor coalescing.
99-
return feature.keys()
100-
101-
10296
def for_each_module_of_type_do(
10397
module: nn.Module,
10498
module_types: List[Type[torch.nn.Module]],
@@ -463,6 +457,42 @@ def __init__(
463457
if register_tbes:
464458
self.tbes: torch.nn.ModuleList = torch.nn.ModuleList(self._emb_modules)
465459

460+
self._has_uninitialized_kjt_permute_order: bool = True
461+
self._has_features_permute: bool = True
462+
self._features_order: List[int] = []
463+
464+
def _permute_kjt_order(
465+
self, features: KeyedJaggedTensor
466+
) -> List[KeyedJaggedTensor]:
467+
if self._has_uninitialized_kjt_permute_order:
468+
kjt_keys = features.keys()
469+
for f in self.feature_names:
470+
self._features_order.append(kjt_keys.index(f))
471+
472+
self.register_buffer(
473+
"_features_order_tensor",
474+
torch.tensor(
475+
self._features_order,
476+
device=features.device(),
477+
dtype=torch.int32,
478+
),
479+
persistent=False,
480+
)
481+
482+
self._features_order = (
483+
[]
484+
if self._features_order == list(range(len(self._features_order)))
485+
else self._features_order
486+
)
487+
488+
self._has_uninitialized_kjt_permute_order = False
489+
490+
if self._features_order:
491+
kjt_permute = features.permute(
492+
self._features_order, self._features_order_tensor
493+
)
494+
return kjt_permute.split(self._feature_splits)
495+
466496
def forward(
467497
self,
468498
features: KeyedJaggedTensor,
@@ -476,10 +506,7 @@ def forward(
476506
"""
477507

478508
embeddings = []
479-
kjt_keys = _get_kjt_keys(features)
480-
kjt_permute_order = [kjt_keys.index(k) for k in self._feature_names]
481-
kjt_permute = features.permute(kjt_permute_order)
482-
kjts_per_key = kjt_permute.split(self._feature_splits)
509+
kjts_per_key = self._permute_kjt_order(features)
483510

484511
for i, (emb_op, _) in enumerate(
485512
zip(self._emb_modules, self._key_to_tables.keys())

0 commit comments

Comments
 (0)