
Conversation

@filipviz (Contributor) commented Sep 16, 2025

1. Prevents unintended aliasing of `self._last_lr`/`get_last_lr(...)` with `group["lr"]` when `group["lr"]` is a tensor.
2. Prevents unintended aliasing of `LRScheduler.base_lrs` with the `group["initial_lr"]`s.
3. Updates `test/optim/test_lrscheduler.py` to test tensor LRs.
4. Changes type annotations for `_last_lr`, `get_last_lr()`, `base_lrs`, `get_lr()`, and `_get_closed_form_lr()` from `list[float]` to `list[float | Tensor]`; adds documentation.

Fixes #163103

LR schedulers can behave in unexpected ways when using a tensor LR due to patterns like this:

```python
self._last_lr: list[float] = [group["lr"] for group in self.optimizer.param_groups]
```

This PR adds a helper to address this:

```python
def _param_groups_val_list(optimizer: Optimizer, key: str) -> list[Any]:
    """Create a list containing group[key] for each optimizer param_group.
    Prevents aliasing when group[key] could be a Tensor.
    Raises a KeyError when group[key] does not exist.
    """
    return [
        group[key].clone() if isinstance(group[key], Tensor) else group[key]
        for group in optimizer.param_groups
    ]
```
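
As a quick illustration of the helper in action (an illustrative sketch; it assumes `_param_groups_val_list` above is in scope):

```python
import torch
from torch import nn
from torch.optim import SGD

# The helper returns a clone of each tensor lr, so in-place updates to
# group["lr"] no longer leak into the snapshot.
opt = SGD([nn.Parameter(torch.tensor(1.0))], lr=torch.tensor(0.1))

snapshot = _param_groups_val_list(opt, "lr")
opt.param_groups[0]["lr"].mul_(0.5)  # in-place update, as step() does

print(snapshot[0])  # tensor(0.1000) - unaffected by the update
```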

@pytorch-bot commented Sep 16, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/163120

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 0ddc6b4 with merge base 6cfb080:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@vadimkantorov (Contributor)

It would also be great to have purely closed-form / functional scheduler formula functions (maybe extracted out of the currently implemented schedulers)... Then they could be used even outside of the ParamGroup context.

@filipviz (Contributor, Author)

Sounds interesting @vadimkantorov. Does this come up often? I'd be open to adding something like this soon, but I want to expand lr_scheduler test coverage first (once this and #163122 land).

@albanD removed their request for review on September 17, 2025 14:37
@janeyx99 (Contributor) left a comment

Thanks for clarifying what's going on! One high-level thing I would ask this PR to clarify (and document) is whether users expect get_lr or get_last_lr to return aliases vs. copies of the tensor lrs. We should at the very least document what is happening now (e.g., after this change, get_lr would return copies instead of aliases). And as that behavior is changing from previous releases, it is BC-breaking. Though, for this case, I think it is fair to consider this a bug fix: get_lr used to return immutable floats, so we should follow that semantic and return un-aliased Tensors.

@janeyx99 (Contributor)

What @vadimkantorov is talking about is out of scope for this PR, and is what our get_closed_form functions attempt to help solve. (Vadim, correct me if I'm misunderstanding!) At one time there was an effort to deprecate all the non-closed forms, but it was unsuccessful because there were still use cases we didn't cover, and the person who maintained LRScheduler stopped doing so. So if you want a good can of worms, this situation may fall under that category 😛

@vadimkantorov (Contributor) commented Sep 17, 2025

Yes, my comment is out of scope. I meant to say that it's quite striking how many small edge cases / bugs / unexpected complexities keep being found around the existing stateful/imperative schedulers. And I think that if closed-form, purely functional formulas were also available as standalone functions (even if they didn't cover all use cases), users could use them directly and experiment with other scheduler APIs (this applies not only to lr but also, e.g., to weight decay).

@mikaylagawarecki added the triaged label (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Sep 18, 2025
@filipviz (Contributor, Author)

Yeah, great points @janeyx99 - the PR was underbaked. Should be better now! Here's an overview:

Current behavior

There are several rough edges caused by patterns like this in `lr_scheduler.py`:

```python
self._last_lr: list[float] = [group["lr"] for group in self.optimizer.param_groups]
```

1. When `group["lr"]` is a tensor, `_last_lr` and `get_last_lr()` are aliased tensors despite their type annotation, and are silently mutated in future calls to `.step()` (which updates `group["lr"]` in-place). If the user modifies the outputs of `get_last_lr()` in place, they silently corrupt their learning rates! For example:

```python
import torch
from torch import optim, nn

lr = torch.tensor(0.1)
opt = optim.SGD([nn.Parameter(torch.tensor(1.0))], lr=lr)
scheduler = optim.lr_scheduler.ConstantLR(opt, total_iters=1)

x = scheduler.get_last_lr()[0]
print(x)  # tensor(0.0333)

opt.step()
scheduler.step()  # mutates x
print(x)  # tensor(0.1000)

x /= 2  # corrupts param_groups[0]["lr"]
```

This causes unexpected behavior in common workflows:

```python
history = []
for step in range(total_iters):
    opt.step()
    scheduler.step()
    history.append(scheduler.get_last_lr()[0])
print(history)  # All values == history[-1]
```

2. `get_lr()` also had an inaccurate type annotation and behaved inconsistently: `MultiplicativeLR`, `StepLR`, `MultiStepLR`, `ConstantLR`, `LinearLR`, `ExponentialLR`, `PolynomialLR`, and `CosineAnnealingLR` sometimes returned the `group["lr"]` tensors directly, depending on the circumstances. For example, in `StepLR.get_lr`:

```python
@override
def get_lr(self) -> list[float | Tensor]:
    """Compute the learning rate of each parameter group."""
    _warn_get_lr_called_within_step(self)

    if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0):
        return [group["lr"] for group in self.optimizer.param_groups]  # Aliased
    return [group["lr"] * self.gamma for group in self.optimizer.param_groups]  # New tensors
```

This isn't a huge issue since get_lr isn't really user-facing, but I think it's worth fixing.

Aside: it's not clear to me why get_lr is a public method. It seems to cause confusion. Should we consider moving it to _get_lr and adding a deprecation warning to get_lr?

3. `base_lrs` can be tensors which alias `group["initial_lr"]`, violating its type annotation (see the snippet below). The aliasing is less of a concern here, as neither is supposed to change. But when `base_lrs` are tensors, this propagates through the math in `_get_closed_form_lr()`, meaning its type annotation is inaccurate as well.
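
Here's a minimal repro of that aliasing as it existed before this PR (a sketch; `ExponentialLR` is an arbitrary choice):

```python
import torch
from torch import nn, optim

# With a tensor lr, base_lrs used to hold the very same tensor object as
# group["initial_lr"]; after this PR it holds a clone instead.
opt = optim.SGD([nn.Parameter(torch.tensor(1.0))], lr=torch.tensor(0.1))
sched = optim.lr_scheduler.ExponentialLR(opt, gamma=0.9)

print(sched.base_lrs[0] is opt.param_groups[0]["initial_lr"])
# True before this PR, False after
```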

New behavior

This PR makes the existing type behavior explicit, but without the alias footguns. `_last_lr`, `get_last_lr()`, `base_lrs`, `get_lr()`, and `_get_closed_form_lr()` are now all `list[float | Tensor]`, with each entry matching the type of the corresponding `group["lr"]`.

Note that this PR doesn't add tests to enforce matching types, and that `CyclicLR` and `OneCycleLR` still have slightly idiosyncratic behavior. I plan to address these in a follow-up PR to make reviewing easier, but can add them here if desired.

Alternatives

1. If we instead wanted to satisfy the annotations by changing the behavior, we could do something like this:

```diff
- self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
+ self._last_lr = [float(group["lr"]) for group in self.optimizer.param_groups]
```

Now that we use `_update_param_group_val` from #163098, this might not trigger recompilations each step, but I haven't tested this.

2. `get_last_lr()` returns `self._last_lr` directly. I think this is less of a concern since `_last_lr` is re-assigned each iteration, but we could consider defensively returning copies (see the sketch below).
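
One possible shape for that defensive option (a sketch only, not what this PR does; `get_last_lr_copies` is a hypothetical name):

```python
from torch import Tensor
from torch.optim.lr_scheduler import LRScheduler

def get_last_lr_copies(scheduler: LRScheduler) -> list:
    """Return _last_lr with tensor entries cloned, so callers can't mutate
    scheduler state through the returned values."""
    return [
        lr.clone() if isinstance(lr, Tensor) else lr
        for lr in scheduler._last_lr
    ]
```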

Aside

What's the process for particularly marginal fixes? For example, I noticed `ConstantLR` currently checks:

```python
if factor > 1.0 or factor < 0:
    raise ValueError(
        "Constant multiplicative factor expected to be between 0 and 1."
    )
```

This allows `factor=0`, which will cause a `ZeroDivisionError` once `self.last_epoch == self.total_iters` (a repro follows). It should check for `factor <= 0`. Separate PR?
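
For reference, a minimal repro of that edge case (sketch):

```python
import torch
from torch import nn, optim

opt = optim.SGD([nn.Parameter(torch.tensor(1.0))], lr=0.1)
sched = optim.lr_scheduler.ConstantLR(opt, factor=0.0, total_iters=1)

opt.step()
sched.step()  # last_epoch reaches total_iters -> lr * (1.0 / factor)
# ZeroDivisionError: float division by zero
```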

@filipviz changed the title from "[optim] fix unintended aliasing in lr scheduler _last_lr and base_lr" to "[optim] prevent unintended aliasing in lr_scheduler; update type annotations/docs" on Sep 18, 2025
@janeyx99 (Contributor) left a comment

Overall looks great! Thank you for looking into it. I agree let's land this now + build more on top later.

With get_lr... yeah, deprecating + privatizing it sounds good given the context you've mentioned. Can you check through GH/some issues to ensure that no one is relying on get_lr in a way we don't expect, though? (This is for a future PR; the current one is good as is.)

@janeyx99 (Contributor) commented Sep 18, 2025

Also, is there a reason get_lr is not returning copies like get_last_lr? Does our internal usage depend on mutating the lr? (I would guess that we'd also want get_lr to return copies)

@filipviz (Contributor, Author)

> Also, is there a reason get_lr is not returning copies like get_last_lr? Does our internal usage depend on mutating the lr? (I would guess that we'd also want get_lr to return copies)

`get_lr()` now does return copies in cases where it previously would have returned the `group["lr"]` tensors directly. For example, in `StepLR.get_lr`:

```diff
 @override
 def get_lr(self) -> list[float | Tensor]:
     """Compute the learning rate of each parameter group."""
     _warn_get_lr_called_within_step(self)

     if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0):
-        return [group["lr"] for group in self.optimizer.param_groups]  # Returned group["lr"]s
+        return _param_groups_val_list(self.optimizer, "lr")  # Uses .clone()
     return [group["lr"] * self.gamma for group in self.optimizer.param_groups]  # New tensors
```

@filipviz (Contributor, Author)

@janeyx99 - seeing that there are some issues with the docs, do you have any advice on debugging this? I've been trying to build the docs locally and I'm running into lots of errors. Is there an environment where you find they work reliably?

@filipviz requested a review from janeyx99 on September 22, 2025 18:46
@janeyx99 (Contributor)

@pytorchbot merge

@pytorch-bot added the ciflow/trunk label (trigger trunk jobs on your pull request) Sep 24, 2025
@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here

@pytorchmergebot (Collaborator)

Merge failed

Reason: 1 job has failed: trunk / win-vs2022-cpu-py3 / test (default, 3, 3, lf.windows.4xlarge.nonephemeral)

Details for Dev Infra team: raised by workflow job

@pytorch-bot removed the ciflow/trunk label Sep 24, 2025
@filipviz (Contributor, Author)

CI caught an interesting issue!

- I used the Python 3.10 `a | b` union type syntax throughout.
- This mostly works in Python 3.9 because `lr_scheduler.py` has `from __future__ import annotations`.
- But it does not work as an argument to `cast`, since arguments are evaluated at runtime:

```python
values = cast(list[float | Tensor], self._get_closed_form_lr())
# TypeError: unsupported operand type(s) for |: 'type' and 'torch._C._TensorMeta'
```

Python attempts to bitwise-OR the types at runtime. The fix:

```diff
- values = cast(list[float | Tensor], self._get_closed_form_lr())
+ values = cast(list[Union[float, Tensor]], self._get_closed_form_lr())
```
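
To make the distinction concrete, here's a small sketch (assumes Python 3.9 with torch installed; the names are illustrative):

```python
from __future__ import annotations  # annotations become lazy strings

from typing import Union, cast

from torch import Tensor

def lrs() -> list[float | Tensor]:  # OK on 3.9: annotations aren't evaluated
    return [0.1]

# cast()'s first argument is an ordinary runtime expression, so on 3.9
# cast(list[float | Tensor], ...) raises the TypeError above. These work:
values = cast(list[Union[float, Tensor]], lrs())
values = cast("list[float | Tensor]", lrs())  # string form is also lazy
```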

Sorry for missing this @janeyx99 - just pushed a fix.

@janeyx99 (Contributor)

@pytorchbot merge

Let’s go! c:

@pytorch-bot added the ciflow/trunk label Sep 25, 2025
@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here

jainapurva pushed a commit that referenced this pull request Sep 29, 2025

[optim] prevent unintended aliasing in lr_scheduler; update type annotations/docs (#163120)

Pull Request resolved: #163120
Approved by: https://github.com/janeyx99