@@ -785,9 +785,7 @@ def purge(self) -> None:
785785 def named_split_embedding_weights (
786786 self , prefix : str = "" , recurse : bool = True , remove_duplicate : bool = True
787787 ) -> Iterator [Tuple [str , torch .Tensor ]]:
788- assert (
789- remove_duplicate
790- ), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
788+ assert remove_duplicate , "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
791789 for config , param in zip (
792790 self ._config .embedding_tables ,
793791 self .emb_module .split_embedding_weights (),
@@ -899,9 +897,7 @@ def named_parameters(
899897 def named_split_embedding_weights (
900898 self , prefix : str = "" , recurse : bool = True , remove_duplicate : bool = True
901899 ) -> Iterator [Tuple [str , torch .Tensor ]]:
902- assert (
903- remove_duplicate
904- ), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
900+ assert remove_duplicate , "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
905901 for config , tensor in zip (
906902 self ._config .embedding_tables ,
907903 self .split_embedding_weights (),
@@ -1082,8 +1078,9 @@ def named_parameters(
10821078 combined_key = "/" .join (
10831079 [config .name for config in self ._config .embedding_tables ]
10841080 )
1085- yield append_prefix (prefix , f"{ combined_key } .weight" ), cast (
1086- nn .Parameter , self ._emb_module .weights
1081+ yield (
1082+ append_prefix (prefix , f"{ combined_key } .weight" ),
1083+ cast (nn .Parameter , self ._emb_module .weights ),
10871084 )
10881085
10891086
@@ -1101,7 +1098,8 @@ def __init__(
11011098 self ._pg = pg
11021099
11031100 self ._pooling : PoolingMode = pooling_type_to_pooling_mode (
1104- config .pooling , sharding_type # pyre-ignore[6]
1101+ config .pooling ,
1102+ sharding_type , # pyre-ignore[6]
11051103 )
11061104
11071105 self ._local_rows : List [int ] = []
@@ -1220,9 +1218,7 @@ def purge(self) -> None:
12201218 def named_split_embedding_weights (
12211219 self , prefix : str = "" , recurse : bool = True , remove_duplicate : bool = True
12221220 ) -> Iterator [Tuple [str , torch .Tensor ]]:
1223- assert (
1224- remove_duplicate
1225- ), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
1221+ assert remove_duplicate , "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
12261222 for config , tensor in zip (
12271223 self ._config .embedding_tables ,
12281224 self .emb_module .split_embedding_weights (),
@@ -1362,9 +1358,7 @@ def named_parameters(
13621358 def named_split_embedding_weights (
13631359 self , prefix : str = "" , recurse : bool = True , remove_duplicate : bool = True
13641360 ) -> Iterator [Tuple [str , PartiallyMaterializedTensor ]]:
1365- assert (
1366- remove_duplicate
1367- ), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
1361+ assert remove_duplicate , "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
13681362 for config , tensor in zip (
13691363 self ._config .embedding_tables ,
13701364 self .split_embedding_weights (),
@@ -1567,6 +1561,7 @@ def named_parameters(
15671561 combined_key = "/" .join (
15681562 [config .name for config in self ._config .embedding_tables ]
15691563 )
1570- yield append_prefix (prefix , f"{ combined_key } .weight" ), cast (
1571- nn .Parameter , self ._emb_module .weights
1564+ yield (
1565+ append_prefix (prefix , f"{ combined_key } .weight" ),
1566+ cast (nn .Parameter , self ._emb_module .weights ),
15721567 )
0 commit comments