In [1]:
from collections.abc import Callable, Iterable
from typing import Optional
import torch
import math

In [2]:
class SGD(torch.optim.Optimizer):
    """Gradient descent whose step size decays as ``lr / sqrt(t + 1)``.

    ``t`` is a per-parameter counter kept in ``self.state[p]["t"]``; it starts
    at 0 and is incremented each time that parameter is updated by ``step``.
    """

    def __init__(self, params, lr=1e-3):
        """Register ``params`` with a base learning rate ``lr``.

        Args:
            params: Parameters (or parameter-group dicts) forwarded unchanged
                to ``torch.optim.Optimizer.__init__``.
            lr: Base learning rate, stored as the ``"lr"`` default of every
                parameter group.

        Raises:
            ValueError: If ``lr`` is negative.
        """
        if lr < 0:
            raise ValueError(f"Invalid learning rate: {lr}")
        defaults = {"lr": lr}
        super().__init__(params, defaults)

    def step(self, closure: Optional[Callable] = None):
        """Apply one update to every parameter that currently has a gradient.

        Args:
            closure: Optional zero-argument callable. When given, it is called
                once before any parameter is updated and its result is
                returned.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was
            supplied.
        """
        loss = None if closure is None else closure()
        for group in self.param_groups:
            lr = group["lr"]  # Learning rate configured for this group.

            for p in group["params"]:
                if p.grad is None:
                    # Nothing to apply for parameters without a gradient.
                    continue
                state = self.state[p]  # Per-parameter optimizer state.
                t = state.get("t", 0)  # Updates applied to p so far.
                grad = p.grad.data  # Gradient of the loss with respect to p.
                p.data -= lr / math.sqrt(t + 1) * grad  # In-place decayed update.
                state["t"] = t + 1  # Record that p received one more update.
        # Return only after ALL groups and parameters have been processed;
        # returning from inside the loops would stop after the first update.
        return loss

In [3]:
# Seed the global RNG right before drawing the initial weights so that the
# starting point (and therefore the printed losses) is the same on every
# Restart & Run All.
torch.manual_seed(0)
# 10x10 learnable tensor initialised from a standard normal scaled by 5.
weights = torch.nn.Parameter(5 * torch.randn((10, 10)))
# Optimise that single tensor with a base learning rate of 1.
opt = SGD([weights], lr=1)

In [4]:
# Run 100 optimisation steps on the mean of the squared weights, printing the
# loss measured *before* each update.
for iteration in range(100):
    loss = weights.pow(2).mean()  # Scalar objective: mean squared weight.
    print(loss.cpu().item())
    opt.zero_grad()  # Clear gradients left over from the previous step.
    loss.backward()  # Populate weights.grad.
    opt.step()  # Apply the optimizer update.

23.470273971557617
22.540849685668945
21.907804489135742
21.404787063598633
20.978832244873047
20.6052303314209
20.270122528076172
19.964824676513672
19.6834774017334
19.421907424926758
19.177013397216797
18.946426391601562
18.728282928466797
18.521089553833008
18.323619842529297
18.13486099243164
17.953968048095703
17.780208587646484
17.612972259521484
17.45171546936035
17.295970916748047
17.14533042907715
16.99942398071289
16.857934951782227
16.72057342529297
16.587074279785156
16.457210540771484
16.330766677856445
16.207550048828125
16.087387084960938
15.970115661621094
15.855589866638184
15.743671417236328
15.634238243103027
15.527173042297363
15.422367095947266
15.319723129272461
15.219145774841309
15.120550155639648
15.023857116699219
14.92898941040039
14.8358736038208
14.74444580078125
14.654644012451172
14.566405296325684
14.4796781539917
14.39440631866455
14.31054401397705
14.228041648864746
14.146855354309082
14.06694221496582
13.988260269165039
13.910776138305664
13.83444976

In [5]:
# Fresh optimizer with base learning rate 10. `weights` is NOT re-initialised
# here, so this run starts from whatever values previously executed cells left.
opt = SGD(params=[weights], lr=10)
for iteration in range(10):
    loss = weights.pow(2).mean()  # Scalar objective: mean squared weight.
    print(loss.cpu().item())
    opt.zero_grad()  # Clear gradients left over from the previous step.
    loss.backward()  # Populate weights.grad.
    opt.step()  # Apply the optimizer update.

11.134657859802246
7.126181125640869
5.253116607666016
4.110002517700195
3.329101800918579
2.760206937789917
2.3278684616088867
1.9892299175262451
1.7178564071655273
1.496443748474121


In [6]:
# Another fresh optimizer with base learning rate 10, this time for 20 steps.
# `weights` is reused as-is, so training continues from its current values.
opt = SGD(params=[weights], lr=10)
for iteration in range(20):
    loss = weights.pow(2).mean()  # Scalar objective: mean squared weight.
    print(loss.cpu().item())
    opt.zero_grad()  # Clear gradients left over from the previous step.
    loss.backward()  # Populate weights.grad.
    opt.step()  # Apply the optimizer update.

1.3131428956985474
0.840411365032196
0.6195153594017029
0.4847046136856079
0.3926107883453369
0.32551929354667664
0.27453234791755676
0.23459571599960327
0.2025918811559677
0.1764800250530243
0.15486279129981995
0.1367487758398056
0.12141421437263489
0.10831809788942337
0.09704789519309998
0.08728361874818802
0.07877346873283386
0.0713166669011116
0.06475135684013367
0.0589456781744957


In [7]:
# Fresh optimizer with base learning rate 30 for 10 steps. `weights` is reused
# as-is, so training continues from its current values.
opt = SGD(params=[weights], lr=30)
for iteration in range(10):
    loss = weights.pow(2).mean()  # Scalar objective: mean squared weight.
    print(loss.cpu().item())
    opt.zero_grad()  # Clear gradients left over from the previous step.
    loss.backward()  # Populate weights.grad.
    opt.step()  # Apply the optimizer update.

0.05379130691289902
0.00860660895705223
0.002852849429473281
0.001218679128214717
0.0005971528007648885
0.0003196819743607193
0.00018225137318950146
0.00010896284220507368
6.763714918633923e-05
4.328777504269965e-05
