add remove_duplicate flag to named_parameters (pytorch#759)
Summary:
Pull Request resolved: pytorch#759

Since the remove_duplicate flag was added to named_buffers in D39493161 (pytorch@672d105), this diff adds the same flag to named_parameters.

Differential Revision: D40801899

fbshipit-source-id: a01807c81f0069ea02bdb95d3f937d458f3d9a5d
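For context, the flag has the same meaning here as on a plain torch.nn.Module: with the default remove_duplicate=True, a parameter registered under multiple names (e.g. tied weights) is yielded only once, while remove_duplicate=False yields every registered name. Below is a minimal sketch of that behavior, assuming a PyTorch build whose named_parameters already accepts the flag; TiedModule is a hypothetical example and not part of this diff.

```python
import torch
import torch.nn as nn


class TiedModule(nn.Module):
    """Registers the same Parameter tensor under two names."""

    def __init__(self) -> None:
        super().__init__()
        self.w1 = nn.Parameter(torch.zeros(4))
        self.w2 = self.w1  # same tensor, second registered name

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.w1 + self.w2


m = TiedModule()

# Default: duplicates are collapsed, the shared tensor appears once.
print([name for name, _ in m.named_parameters()])
# ['w1']

# remove_duplicate=False: every registered name is yielded, even for shared tensors.
print([name for name, _ in m.named_parameters(remove_duplicate=False)])
# ['w1', 'w2']
```

The sharded lookup modules in this diff that cannot enumerate duplicated names simply assert that the flag stays at its default.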
samdow authored and facebook-github-bot committed Oct 31, 2022
1 parent be01464 commit 468c379
Showing 6 changed files with 41 additions and 20 deletions.
8 changes: 4 additions & 4 deletions torchrec/distributed/batched_embedding_kernel.py
@@ -416,7 +416,7 @@ def named_buffers(
         return self.named_split_embedding_weights(prefix, recurse, remove_duplicate)

     def named_parameters(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, nn.Parameter]]:
         yield from ()

@@ -460,7 +460,7 @@ def named_buffers(
         yield from ()

     def named_parameters(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, nn.Parameter]]:
         combined_key = "/".join(
             [config.name for config in self._config.embedding_tables]
@@ -648,7 +648,7 @@ def named_buffers(
         return self.named_split_embedding_weights(prefix, recurse, remove_duplicate)

     def named_parameters(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, nn.Parameter]]:
         yield from ()

@@ -692,7 +692,7 @@ def named_buffers(
         yield from ()

     def named_parameters(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, nn.Parameter]]:
         combined_key = "/".join(
             [config.name for config in self._config.embedding_tables]
32 changes: 24 additions & 8 deletions torchrec/distributed/embedding_lookup.py
@@ -176,8 +176,12 @@ def load_state_dict(
         return _IncompatibleKeys(missing_keys=m, unexpected_keys=u)

     def named_parameters(
-        self, prefix: str = "", recurse: bool = True
-    ) -> Iterator[Tuple[str, nn.Parameter]]:
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
+    ) -> Iterator[Tuple[str, torch.nn.Parameter]]:
+        assert remove_duplicate, (
+            "remove_duplicate=False in named_parameters for "
+            "GroupedEmbeddingsLookup is not supported"
+        )
         for emb_module in self._emb_modules:
             yield from emb_module.named_parameters(prefix, recurse)

@@ -346,8 +350,12 @@ def load_state_dict(
         return _IncompatibleKeys(missing_keys=m1 + m2, unexpected_keys=u1 + u2)

     def named_parameters(
-        self, prefix: str = "", recurse: bool = True
-    ) -> Iterator[Tuple[str, nn.Parameter]]:
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
+    ) -> Iterator[Tuple[str, torch.nn.Parameter]]:
+        assert remove_duplicate, (
+            "remove_duplicate=False in named_parameters for "
+            "GroupedPooledEmbeddingsLookup is not supported"
+        )
         for emb_module in self._emb_modules:
             yield from emb_module.named_parameters(prefix, recurse)
         for emb_module in self._score_emb_modules:
@@ -461,8 +469,12 @@ def load_state_dict(
         return _IncompatibleKeys(missing_keys=m, unexpected_keys=u)

     def named_parameters(
-        self, prefix: str = "", recurse: bool = True
-    ) -> Iterator[Tuple[str, nn.Parameter]]:
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
+    ) -> Iterator[Tuple[str, torch.nn.Parameter]]:
+        assert remove_duplicate, (
+            "remove_duplicate=False in named_parameters for "
+            "MetaInferGroupedEmbeddingsLookup is not supported"
+        )
         for emb_module in self._emb_modules:
             yield from emb_module.named_parameters(prefix, recurse)

@@ -617,8 +629,12 @@ def load_state_dict(
         return _IncompatibleKeys(missing_keys=m1 + m2, unexpected_keys=u1 + u2)

     def named_parameters(
-        self, prefix: str = "", recurse: bool = True
-    ) -> Iterator[Tuple[str, nn.Parameter]]:
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
+    ) -> Iterator[Tuple[str, torch.nn.Parameter]]:
+        assert remove_duplicate, (
+            "remove_duplicate=False in named_parameters for "
+            "MetaInferGroupedPooledEmbeddingsLookup is not supported"
+        )
         for emb_module in self._emb_modules:
             yield from emb_module.named_parameters(prefix, recurse)
         for emb_module in self._score_emb_modules:
7 changes: 4 additions & 3 deletions torchrec/distributed/embeddingbag.py
@@ -531,11 +531,11 @@ def named_modules(
         yield from [(prefix, self)]

     def named_parameters(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, nn.Parameter]]:
         for lookup in self._lookups:
             yield from lookup.named_parameters(
-                append_prefix(prefix, "embedding_bags"), recurse
+                append_prefix(prefix, "embedding_bags"), recurse, remove_duplicate
             )

     def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
@@ -789,8 +789,9 @@ def named_modules(
         yield from [(prefix, self)]

     def named_parameters(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, nn.Parameter]]:
+        # TODO: add remove_duplicate
         for name, parameter in self._lookup.named_parameters("", recurse):
             # update name to match embeddingBag parameter name
             yield append_prefix(prefix, name.split(".")[-1]), parameter
2 changes: 1 addition & 1 deletion torchrec/distributed/grouped_position_weighted.py
@@ -70,7 +70,7 @@ def forward(self, features: KeyedJaggedTensor) -> KeyedJaggedTensor:
         )

     def named_parameters(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, nn.Parameter]]:
         for name, param in self.position_weights.items():
             yield append_prefix(prefix, f"position_weights.{name}"), param
5 changes: 3 additions & 2 deletions torchrec/distributed/model_parallel.py
@@ -495,14 +495,15 @@ def _named_parameters(
         )

     def named_parameters(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.nn.Parameter]]:
         gen = self._named_parameters(self.module, prefix, recurse)
         memo = set()
         for key, param in gen:
             if param in memo:
                 continue
-            memo.add(param)
+            if remove_duplicate:
+                memo.add(param)
             yield key, param

     def bare_named_parameters(
7 changes: 5 additions & 2 deletions torchrec/modules/fused_embedding_modules.py
@@ -222,8 +222,11 @@ def parameters(
         yield cast(nn.Parameter, self._emb_module.weights)

     def named_parameters(
-        self, prefix: str = "", recurse: bool = True
-    ) -> Iterator[Tuple[str, nn.Parameter]]:
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
+    ) -> Iterator[Tuple[str, torch.nn.Parameter]]:
+        assert (
+            remove_duplicate
+        ), "remove_duplicate=False not supported in _BatchedFusedEmbeddingLookups.named_parameters"
         for table, weight in zip(
             self._embedding_tables, self.split_embedding_weights()
         ):
