Update on "[BE][SparseAdam] cleaner way to verify no sparse params"
Context:

#47724 fixed the problem that SparseAdam could not handle generators by using the `list(...)` construct. However, this meant that SparseAdam deviated from the other optimizers in that it could _accept_ a raw Tensor/Parameter instead of requiring a container of them. This is not really a big deal.
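For the unfamiliar, the generator issue: `Module.parameters()` returns a generator, which is exhausted after a single pass, so an optimizer that iterates over params more than once must materialize them first. A tiny illustration (separate from this PR):

import torch.nn as nn

model = nn.Linear(2, 2)
params = model.parameters()  # a generator
print(len(list(params)))     # 2 -- weight and bias
print(len(list(params)))     # 0 -- already exhausted on the second pass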

So why this PR?

I do think this PR is cleaner. It uses the fact that the Optimizer parent class already containerizes parameters into parameter groups, so we can reuse that here by calling `super().__init__` first and then filtering the param_groups after. This change also makes SparseAdam consistent with the rest of our optimizers in that only containerized params are accepted, which is technically BC-breaking, but I believe only in a minor way.
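As a rough sketch of the idea (illustrative only, not the literal code in torch/optim/sparse_adam.py; the error message and local names here are made up):

import torch
from torch.optim.optimizer import Optimizer

class SparseAdam(Optimizer):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
        defaults = dict(lr=lr, betas=betas, eps=eps)
        # Let the parent class containerize params (generator, list, dict, ...)
        # into self.param_groups first ...
        super().__init__(params, defaults)

        # ... then verify there are no sparse params by filtering the
        # already-containerized param_groups.
        sparse_params = [
            (g_idx, p_idx)
            for g_idx, group in enumerate(self.param_groups)
            for p_idx, p in enumerate(group["params"])
            if p.is_sparse
        ]
        if sparse_params:
            raise ValueError(
                f"SparseAdam requires dense parameter tensors, but got sparse "
                f"params at (group, param) indices {sparse_params}"
            )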




[ghstack-poisoned]
janeyx99 committed Nov 28, 2023
2 parents 9db4236 + 24cd0a0 commit 8082734
Showing 2 changed files with 15 additions and 3 deletions.
7 changes: 7 additions & 0 deletions test/optim/test_optim.py
@@ -1296,6 +1296,13 @@ def test_sparse_adam(self):
             sparse_only=True,
             maximize=True,
         )
+        import warnings
+        with warnings.catch_warnings(record=True) as w:
+            SparseAdam(torch.zeros(3))
+        self.assertEqual(len(w), 1)
+        for warning in w:
+            self.assertEqual(len(warning.message.args), 1)
+            self.assertRegex(warning.message.args[0], "Passing in a raw Tensor is deprecated.")
         with self.assertRaisesRegex(
             ValueError, "Invalid beta parameter at index 0: 1.0"
         ):
11 changes: 8 additions & 3 deletions torch/optim/optimizer.py
@@ -255,9 +255,14 @@ def __init__(self, params: ParamsT, defaults: Dict[str, Any]) -> None:
         self._patch_step_function()

         if isinstance(params, torch.Tensor):
-            raise TypeError("params argument given to the optimizer should be "
-                            "an iterable of Tensors or dicts, but got " +
-                            torch.typename(params))
+            if self.__class__.__name__ == 'SparseAdam':
+                warnings.warn(("Passing in a raw Tensor is deprecated. In the future, "
+                               "this will raise an error. Please wrap your Tensor in "
+                               "an iterable instead."), UserWarning)
+            else:
+                raise TypeError("params argument given to the optimizer should be "
+                                "an iterable of Tensors or dicts, but got " +
+                                torch.typename(params))

         self.state: DefaultDict[torch.Tensor, Any] = defaultdict(dict)
         self.param_groups: List[Dict[str, Any]] = []
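The net effect of the optimizer.py change, shown with an illustrative snippet (behavior per the diff above; the printed messages are abbreviated): SparseAdam now warns on a raw Tensor but still containerizes it, while every other optimizer keeps the hard TypeError.

import warnings

import torch
from torch.optim import Adam, SparseAdam

param = torch.zeros(3, requires_grad=True)

# SparseAdam: deprecated but tolerated -- the parent class wraps the raw
# Tensor into a param_group, so construction still succeeds.
with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always")
    SparseAdam(param)
print(w[0].message)  # Passing in a raw Tensor is deprecated. ...

# Every other optimizer: still a hard error on a raw Tensor.
try:
    Adam(param)
except TypeError as e:
    print(e)  # params argument given to the optimizer should be an iterable ...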
