<a href="https://colab.research.google.com/github/yangliupku/cs336_assignment1_basics/blob/main/notebooks/optimizer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
from collections.abc import Callable, Iterable
from typing import Optional
import torch
import math

In [5]:
class SGD(torch.optim.Optimizer):
    """Stochastic gradient descent with a 1/sqrt(t + 1) step-size decay.

    Each parameter keeps its own step counter ``t`` (starting at 0) in the
    optimizer state. On every call to :meth:`step` a parameter that has a
    gradient is updated as::

        p <- p - lr / sqrt(t + 1) * p.grad

    and its counter is then incremented.
    """

    def __init__(self, params, lr=1e-3):
        """Create the optimizer.

        Args:
            params: Parameters (or parameter-group dicts) to optimize; passed
                through to ``torch.optim.Optimizer``.
            lr: Base learning rate stored as the group default ``'lr'``.

        Raises:
            ValueError: If ``lr`` is negative.
        """
        # A negative rate would silently turn descent into ascent.
        if lr < 0:
            raise ValueError(f"Invalid learning rate: {lr}")
        defaults = {'lr': lr}
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure: Optional[Callable]=None):
        """Apply one update to every parameter that has a gradient.

        Args:
            closure: Optional callable that is invoked once before the
                update; its return value is returned from this method.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was
            given.
        """
        loss = None
        if closure is not None:
            # The closure may need to run backward(), so gradients must be
            # re-enabled inside the surrounding no_grad context.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            for p in group['params']:
                if p.grad is None:
                    # Nothing to apply; the step counter is left untouched.
                    continue

                state = self.state[p]
                t = state.get('t', 0)
                # In-place update under no_grad instead of going through the
                # legacy `.data` attribute.
                p.sub_(lr / math.sqrt(t + 1) * p.grad)
                state['t'] = t + 1
        return loss



In [14]:
# Toy problem: a 10x10 parameter matrix of standard-normal draws scaled by 5,
# optimized by the SGD class above with a base learning rate of 10.
weights = torch.nn.Parameter(torch.randn(10, 10) * 5)
opt = SGD(params=[weights], lr=10.0)


In [15]:
# Run 10 optimizer steps on loss = mean(weights ** 2), printing the loss
# value measured before each update.
for t in range(10):
    opt.zero_grad()  # clear gradients left over from the previous iteration
    loss = weights.pow(2).mean()
    loss_value = loss.cpu().item()
    print(loss_value)
    loss.backward()
    opt.step()



20.543134689331055
13.147604942321777
9.691853523254395
7.5828399658203125
6.142100811004639
5.092504978179932
4.294852256774902
3.6700737476348877
3.1693975925445557
2.760897159576416


tensor(2.7609, grad_fn=<MeanBackward0>)