In [None]:
import torch

from collections.abc import Callable, Iterable
from typing import Optional
import math

class SGD(torch.optim.Optimizer):
    """SGD whose step size decays as ``lr / sqrt(t + 1)``.

    Each parameter ``p`` with a gradient is updated as
    ``p <- p - lr / sqrt(t + 1) * p.grad``, where ``t`` is the number of
    updates already applied to that parameter.  The counter is stored per
    parameter in ``self.state[p]["t"]``.

    Args:
        params: Iterable of parameters (or param-group dicts), as accepted by
            ``torch.optim.Optimizer``.
        lr: Base learning rate; must be non-negative.

    Raises:
        ValueError: If ``lr`` is negative.
    """

    def __init__(self, params, lr=1e-3):
        if lr < 0:
            raise ValueError(f"Invalid learning rate: {lr}")
        defaults = {"lr": lr}
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure: Optional[Callable] = None):
        """Apply one decaying-step SGD update to every parameter with a gradient.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss. It is run with gradient tracking enabled, so it may
                call ``backward()``.

        Returns:
            Whatever ``closure`` returned, or ``None`` if no closure was given.
        """
        loss = None
        if closure is not None:
            # step() itself runs under no_grad; the closure needs autograd on.
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            lr = group["lr"]  # learning rate for this param group
            for p in group["params"]:
                if p.grad is None:
                    continue  # parameter did not take part in the backward pass
                state = self.state[p]  # per-parameter optimizer state
                t = state.get("t", 0)  # updates applied so far (0 on first step)
                # In-place update under no_grad rather than mutating ``.data``,
                # so autograd's version counter still records the change.
                p.add_(p.grad, alpha=-lr / math.sqrt(t + 1))
                state["t"] = t + 1
        return loss

# Sweep three learning rates (10, 100, 1000) and print how the decaying-step
# SGD optimizer drives the toy objective mean(weights ** 2) over 10 steps.
learning_rates = [10.0 ** k for k in (1, 2, 3)]

for lr in learning_rates:
    print(f"\n=== Learning rate = {lr} ===")
    # Fresh random 10x10 weights (scaled by 5) and a fresh optimizer per run.
    weights = torch.nn.Parameter(torch.randn(10, 10) * 5)
    opt = SGD([weights], lr=lr)

    for t in range(10):
        opt.zero_grad()
        loss = weights.pow(2).mean()
        loss.backward()
        opt.step()
        # `loss` was evaluated before this iteration's update was applied.
        print(f"iter {t:2d}: loss = {loss.item():.6f}")





=== Learning rate = 10.0 ===
iter  0: loss = 30.651350
iter  1: loss = 19.616863
iter  2: loss = 14.460715
iter  3: loss = 11.313965
iter  4: loss = 9.164311
iter  5: loss = 7.598264
iter  6: loss = 6.408128
iter  7: loss = 5.475928
iter  8: loss = 4.728895
iter  9: loss = 4.119392

=== Learning rate = 100.0 ===
iter  0: loss = 18.623650
iter  1: loss = 18.623650
iter  2: loss = 3.195312
iter  3: loss = 0.076471
iter  4: loss = 0.000000
iter  5: loss = 0.000000
iter  6: loss = 0.000000
iter  7: loss = 0.000000
iter  8: loss = 0.000000
iter  9: loss = 0.000000

=== Learning rate = 1000.0 ===
iter  0: loss = 25.972237
iter  1: loss = 9375.976562
iter  2: loss = 1619378.500000
iter  3: loss = 180138592.000000
iter  4: loss = 14591223808.000000
iter  5: loss = 920873271296.000000
iter  6: loss = 47274625335296.000000
iter  7: loss = 2033956656513024.000000
iter  8: loss = 74967305531949056.000000
iter  9: loss = 2407283402236493824.000000


In [7]:
print("1")  # sanity-check cell: emits the string "1"

1
