Skip to content

Commit

Permalink
QEBC permute order caching
Browse files Browse the repository at this point in the history
Summary: As titled. This optimization targets CPU-bound models. Copies the implementation from SQEBC D56455884

Differential Revision: D57339912
  • Loading branch information
gnahzg authored and facebook-github-bot committed May 23, 2024
1 parent a29c82e commit 2200dc9
Showing 1 changed file with 37 additions and 10 deletions.
47 changes: 37 additions & 10 deletions torchrec/quant/embedding_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,6 @@ def _get_feature_length(feature: KeyedJaggedTensor) -> Tensor:
return feature.lengths()


@torch.fx.wrap
def _get_kjt_keys(feature: KeyedJaggedTensor) -> List[str]:
    # fx-wrapped so tracing treats the keys() call as an opaque leaf; per the
    # original note, this helps batch-hinting passes coalesce jagged sequence
    # tensors (NOTE(review): exact pass behavior not visible here — confirm).
    return feature.keys()


def for_each_module_of_type_do(
module: nn.Module,
module_types: List[Type[torch.nn.Module]],
Expand Down Expand Up @@ -463,6 +457,42 @@ def __init__(
if register_tbes:
self.tbes: torch.nn.ModuleList = torch.nn.ModuleList(self._emb_modules)

self._has_uninitialized_kjt_permute_order: bool = True
self._has_features_permute: bool = True
self._features_order: List[int] = []

def _permute_kjt_order(
    self, features: KeyedJaggedTensor
) -> List[KeyedJaggedTensor]:
    """Permute ``features`` into this module's feature order and split per key group.

    On the first call, computes and caches the permutation that maps the
    incoming KJT key order onto ``self.feature_names``; subsequent calls reuse
    the cached order (the CPU-bound optimization this cache exists for).

    Args:
        features: input KJT whose keys must be a superset of ``feature_names``.

    Returns:
        ``features`` (permuted if needed) split by ``self._feature_splits``.

    Raises:
        ValueError: via ``list.index`` if a feature name is missing from the KJT keys.
    """
    if self._has_uninitialized_kjt_permute_order:
        # Build the index permutation from incoming key order -> feature_names order.
        kjt_keys = features.keys()
        for f in self.feature_names:
            self._features_order.append(kjt_keys.index(f))

        # Non-persistent buffer: follows the module across devices but is
        # excluded from state_dict.
        self.register_buffer(
            "_features_order_tensor",
            torch.tensor(
                self._features_order,
                device=features.device(),
                dtype=torch.int32,
            ),
            persistent=False,
        )

        # Identity permutation -> empty list, so the permute step is skipped
        # on every subsequent call.
        self._features_order = (
            []
            if self._features_order == list(range(len(self._features_order)))
            else self._features_order
        )

        self._has_uninitialized_kjt_permute_order = False

    if self._features_order:
        kjt_permute = features.permute(
            self._features_order, self._features_order_tensor
        )
        return kjt_permute.split(self._feature_splits)
    # Fix: the identity-order path previously fell through and returned None,
    # which the caller then iterated. Split the unpermuted features instead.
    return features.split(self._feature_splits)

def forward(
self,
features: KeyedJaggedTensor,
Expand All @@ -476,10 +506,7 @@ def forward(
"""

embeddings = []
kjt_keys = _get_kjt_keys(features)
kjt_permute_order = [kjt_keys.index(k) for k in self._feature_names]
kjt_permute = features.permute(kjt_permute_order)
kjts_per_key = kjt_permute.split(self._feature_splits)
kjts_per_key = self._permute_kjt_order(features)

for i, (emb_op, _) in enumerate(
zip(self._emb_modules, self._key_to_tables.keys())
Expand Down

0 comments on commit 2200dc9

Please sign in to comment.