[BE][SparseAdam] cleaner way to verify no sparse params (#114425)
Context:

#47724 fixed the problem that SparseAdam could not handle generators by materializing the incoming params with the `list(...)` construct. However, this meant that SparseAdam deviated from the other optimizers in that it could _accept_ a raw Tensor/Parameter instead of requiring a container of them. This is not really a big deal.

So why this PR?

I do think this PR is cleaner. It uses the fact that the `Optimizer` parent class already containerizes parameters into parameter groups, so we can reuse that here by calling `super().__init__` first and then filtering `self.param_groups` afterwards. This change also makes SparseAdam consistent with the rest of our optimizers in that only containerized params are accepted, which is technically BC-breaking, so I've added a deprecation warning that we should remove in May 2024.

(But is it really BC breaking when we've said in the docs that params should be an iterable this whole time? Maybe this is just a bug fix....😛)
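The user-facing effect, as a minimal sketch (assuming a PyTorch build that includes this PR; the tensors below are illustrative only, not taken from the PR):

```python
import warnings

import torch
from torch.optim import SparseAdam

# Deprecated: a raw Tensor is still accepted for now, but emits a FutureWarning.
with warnings.catch_warnings(record=True) as ws:
    warnings.simplefilter("always")
    SparseAdam(torch.zeros(3))
assert any(issubclass(w.category, FutureWarning) for w in ws)

# Recommended: pass an iterable of Tensors/Parameters or of param-group dicts,
# exactly as with every other optimizer.
weight = torch.nn.Parameter(torch.zeros(3))
SparseAdam([weight], lr=1e-3)
SparseAdam([{"params": [weight], "lr": 1e-2}])
```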

Pull Request resolved: #114425
Approved by: https://github.com/drisspg
janeyx99 authored and pytorchmergebot committed Nov 29, 2023
1 parent febbc48 commit 7c1a501
Showing 3 changed files with 25 additions and 15 deletions.
8 changes: 8 additions & 0 deletions test/optim/test_optim.py
@@ -1296,6 +1296,14 @@ def test_sparse_adam(self):
             sparse_only=True,
             maximize=True,
         )
+        import warnings
+        with warnings.catch_warnings(record=True) as ws:
+            SparseAdam(torch.zeros(3))
+            self.assertEqual(len(ws), 1)
+            for warning in ws:
+                self.assertEqual(len(warning.message.args), 1)
+                self.assertRegex(warning.message.args[0],
+                                 "Passing in a raw Tensor as ``params`` to SparseAdam ")
         with self.assertRaisesRegex(
             ValueError, "Invalid beta parameter at index 0: 1.0"
         ):
12 changes: 9 additions & 3 deletions torch/optim/optimizer.py
@@ -255,9 +255,15 @@ def __init__(self, params: ParamsT, defaults: Dict[str, Any]) -> None:
         self._patch_step_function()

         if isinstance(params, torch.Tensor):
-            raise TypeError("params argument given to the optimizer should be "
-                            "an iterable of Tensors or dicts, but got " +
-                            torch.typename(params))
+            if self.__class__.__name__ == 'SparseAdam':
+                warnings.warn(("Passing in a raw Tensor as ``params`` to SparseAdam "
+                               "is deprecated. In the future, this will raise an error. "
+                               "Please wrap your Tensor in an iterable instead."),
+                              FutureWarning)
+            else:
+                raise TypeError("params argument given to the optimizer should be "
+                                "an iterable of Tensors or dicts, but got " +
+                                torch.typename(params))

         self.state: DefaultDict[torch.Tensor, Any] = defaultdict(dict)
         self.param_groups: List[Dict[str, Any]] = []
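One way to act on the new FutureWarning ahead of the planned removal (a hedged migration sketch, not part of this PR): promote just that warning category to an error so any remaining raw-Tensor call sites surface immediately.

```python
import warnings

import torch
from torch.optim import SparseAdam

with warnings.catch_warnings():
    # Escalate only FutureWarning; other warnings are left untouched.
    warnings.simplefilter("error", FutureWarning)
    try:
        SparseAdam(torch.zeros(3))  # raw Tensor -> FutureWarning raised as an error
    except FutureWarning as exc:
        print(f"deprecated call site found: {exc}")
```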
20 changes: 8 additions & 12 deletions torch/optim/sparse_adam.py
@@ -15,25 +15,21 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, maximize: bool
         if not 0.0 <= betas[1] < 1.0:
             raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")

-        params = list(params)
+        defaults = dict(lr=lr, betas=betas, eps=eps, maximize=maximize)
+        super().__init__(params, defaults)

         sparse_params = []
-        for index, param in enumerate(params):
-            if isinstance(param, dict):
-                # given param group, convert given params to a list first before iterating
-                param['params'] = list(param.get("params", []))
-                for d_index, d_param in enumerate(param['params']):
-                    if d_param.is_sparse:
-                        sparse_params.append([index, d_index])
-            elif param.is_sparse:
-                sparse_params.append(index)
+        for index, param_group in enumerate(self.param_groups):
+            assert isinstance(param_group, dict), f"param_groups must be a list of dicts, but got {type(param_group)}"
+            # given param group, convert given params to a list first before iterating
+            for d_index, d_param in enumerate(param_group['params']):
+                if d_param.is_sparse:
+                    sparse_params.append([index, d_index])
         if sparse_params:
             raise ValueError(
                 f"Sparse params at indices {sparse_params}: SparseAdam requires dense parameter tensors"
             )

-        defaults = dict(lr=lr, betas=betas, eps=eps, maximize=maximize)
-        super().__init__(params, defaults)

     @torch.no_grad()
     def step(self, closure=None):
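For completeness, a small sketch (not from the diff) of how the relocated check behaves: sparse parameter tensors are rejected at construction time, with their `[group, param]` indices reported in the error.

```python
import torch
from torch.optim import SparseAdam

dense = torch.zeros(2, 3, requires_grad=True)
sparse = torch.eye(3).to_sparse().requires_grad_(True)  # a sparse leaf tensor

try:
    SparseAdam([{"params": [dense]}, {"params": [sparse]}], lr=1e-3)
except ValueError as exc:
    # e.g. "Sparse params at indices [[1, 0]]: SparseAdam requires dense parameter tensors"
    print(exc)
```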
