@@ -93,12 +93,6 @@ def _get_feature_length(feature: KeyedJaggedTensor) -> Tensor:
     return feature.lengths()
 
 
-@torch.fx.wrap
-def _get_kjt_keys(feature: KeyedJaggedTensor) -> List[str]:
-    # this is a fx rule to help with batching hinting jagged sequence tensor coalescing.
-    return feature.keys()
-
-
 def for_each_module_of_type_do(
     module: nn.Module,
     module_types: List[Type[torch.nn.Module]],
@@ -463,6 +457,42 @@ def __init__(
         if register_tbes:
             self.tbes: torch.nn.ModuleList = torch.nn.ModuleList(self._emb_modules)
 
+        self._has_uninitialized_kjt_permute_order: bool = True
+        self._has_features_permute: bool = True
+        self._features_order: List[int] = []
+
+    def _permute_kjt_order(
+        self, features: KeyedJaggedTensor
+    ) -> List[KeyedJaggedTensor]:
+        if self._has_uninitialized_kjt_permute_order:
+            kjt_keys = features.keys()
+            for f in self.feature_names:
+                self._features_order.append(kjt_keys.index(f))
+
+            self.register_buffer(
+                "_features_order_tensor",
+                torch.tensor(
+                    self._features_order,
+                    device=features.device(),
+                    dtype=torch.int32,
+                ),
+                persistent=False,
+            )
+
+            self._features_order = (
+                []
+                if self._features_order == list(range(len(self._features_order)))
+                else self._features_order
+            )
+
+            self._has_uninitialized_kjt_permute_order = False
+
+        if self._features_order:
+            features = features.permute(
+                self._features_order, self._features_order_tensor
+            )
+        return features.split(self._feature_splits)
+
     def forward(
         self,
         features: KeyedJaggedTensor,
@@ -476,10 +506,7 @@ def forward(
         """
 
         embeddings = []
-        kjt_keys = _get_kjt_keys(features)
-        kjt_permute_order = [kjt_keys.index(k) for k in self._feature_names]
-        kjt_permute = features.permute(kjt_permute_order)
-        kjts_per_key = kjt_permute.split(self._feature_splits)
+        kjts_per_key = self._permute_kjt_order(features)
 
         for i, (emb_op, _) in enumerate(
             zip(self._emb_modules, self._key_to_tables.keys())
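For context (not part of the diff): below is a minimal, self-contained sketch of the lazy permutation-caching pattern that the new _permute_kjt_order helper follows, i.e. compute the key order once on first call, cache it (and register a non-persistent tensor buffer for it), skip the permute when the order is the identity, and split per key group. The LazyKJTPermute module, its constructor parameters, and the toy inputs are hypothetical; only KeyedJaggedTensor.keys / permute / split / device and nn.Module.register_buffer come from the code above.

# Sketch only; names below (LazyKJTPermute, feature_names, feature_splits) are hypothetical.
from typing import List

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


class LazyKJTPermute(torch.nn.Module):
    """Caches the key-permutation order for incoming KJTs on first use."""

    def __init__(self, feature_names: List[str], feature_splits: List[int]) -> None:
        super().__init__()
        self._feature_names = feature_names
        self._feature_splits = feature_splits
        self._features_order: List[int] = []
        self._uninitialized = True

    def forward(self, features: KeyedJaggedTensor) -> List[KeyedJaggedTensor]:
        if self._uninitialized:
            kjt_keys = features.keys()
            order = [kjt_keys.index(f) for f in self._feature_names]
            self.register_buffer(
                "_features_order_tensor",
                torch.tensor(order, device=features.device(), dtype=torch.int32),
                persistent=False,
            )
            # An identity order means later calls can skip the permute entirely.
            self._features_order = [] if order == list(range(len(order))) else order
            self._uninitialized = False
        if self._features_order:
            features = features.permute(
                self._features_order, self._features_order_tensor
            )
        return features.split(self._feature_splits)


# Example: keys arrive as ["f2", "f1"] but the module expects ["f1", "f2"].
kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=["f2", "f1"],
    values=torch.tensor([10, 11, 12, 13, 14]),
    lengths=torch.tensor([2, 1, 1, 1]),
)
per_key = LazyKJTPermute(["f1", "f2"], [1, 1])(kjt)  # -> [KJT for "f1", KJT for "f2"]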