-
Notifications
You must be signed in to change notification settings - Fork 2.9k
/
optim.py
72 lines (62 loc) · 3.57 KB
/
optim.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
# sorted in order of increasing complexity
from typing import List
from tinygrad.helpers import dedup, getenv
from tinygrad.tensor import Tensor
class Optimizer:
  """Base class for optimizers: holds the trainable params, the non-trainable buffers, and the learning rate."""
  def __init__(self, params: List[Tensor], lr: float):
    # a param with requires_grad=None is assumed trainable once it's handed to an optimizer
    for p in params:
      if p.requires_grad is None: p.requires_grad = True
    self.params: List[Tensor] = dedup([p for p in params if p.requires_grad])
    assert len(self.params) != 0, "optimizer must have at least one param"
    self.device = self.params[0].device
    # non-trainable tensors are kept as buffers so they still get realized with the params
    self.buffers: List[Tensor] = dedup([p for p in params if not p.requires_grad])
    # CONST_LR keeps lr as a python float (baked into kernels); otherwise it lives on-device as a tensor
    self.lr = Tensor([lr], requires_grad=False, device=self.device).contiguous() if not getenv("CONST_LR") else lr
  def zero_grad(self):
    """Drop the gradients of all trainable params."""
    for p in self.params: p.grad = None
  def realize(self, extra=None):
    """Realize params and buffers (plus any extra tensors) together in one batch."""
    targets = self.params + self.buffers
    if extra is not None: targets = extra + targets
    Tensor.corealize(targets)
  def step(self) -> None: raise NotImplementedError
class SGD(Optimizer):
  """Stochastic gradient descent with optional momentum, weight decay, and Nesterov momentum.

  Follows https://pytorch.org/docs/stable/generated/torch.optim.SGD.html
  """
  def __init__(self, params: List[Tensor], lr=0.001, momentum=0, weight_decay=0.0, nesterov=False):
    super().__init__(params, lr)
    self.momentum, self.wd, self.nesterov = momentum, weight_decay, nesterov
    # one momentum buffer per param; empty list when momentum is disabled
    self.b = [Tensor.zeros(*p.shape, device=p.device, requires_grad=False) for p in self.params] if self.momentum else []
  def step(self) -> None:
    for i, p in enumerate(self.params):
      assert p.grad is not None
      # contiguous is needed since the grads can allegedly form a "diamond"
      # TODO: fix this in lazy.py
      g = p.grad.contiguous() + self.wd * p.detach()
      if self.momentum:
        # buffer starts at zero, so the first step needs no special case
        self.b[i].assign(self.momentum * self.b[i] + g)
        if self.nesterov: g = g + self.momentum * self.b[i]
        else: g = self.b[i]
      p.assign(p.detach() - g * self.lr)
    self.realize(self.b)
# LAMB is essentially just the trust ratio part of LARS applied to Adam/W so if we just set the trust ratio to 1.0 its just Adam/W.
def AdamW(params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8, wd=0.01):
  """Adam with decoupled weight decay: LAMB with the trust ratio pinned at 1.0."""
  return LAMB(params, lr, b1, b2, eps, wd, adam=True)
def Adam(params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8): return LAMB(params, lr, b1, b2, eps, 0.0, adam=True)
class LAMB(Optimizer):
  """LAMB optimizer (layer-wise adaptive moments).

  With adam=True the trust ratio is fixed at 1.0, which reduces this to Adam/AdamW.
  """
  def __init__(self, params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, wd=0.0, adam=False):
    super().__init__(params, lr)
    self.b1, self.b2, self.eps, self.wd, self.adam = b1, b2, eps, wd, adam
    # step counter kept on-device so bias correction stays inside the compute graph
    self.t = Tensor([0], device=self.device, requires_grad=False).realize()
    # first and second moment estimates, one pair per param
    self.m = [Tensor.zeros(*p.shape, device=p.device, requires_grad=False) for p in self.params]
    self.v = [Tensor.zeros(*p.shape, device=p.device, requires_grad=False) for p in self.params]
  def step(self) -> None:
    self.t.assign(self.t + 1)
    for i, p in enumerate(self.params):
      assert p.grad is not None
      grad = p.grad
      self.m[i].assign(self.b1 * self.m[i] + (1.0 - self.b1) * grad)
      self.v[i].assign(self.b2 * self.v[i] + (1.0 - self.b2) * (grad * grad))
      # bias-corrected moment estimates
      m_hat = self.m[i] / (1.0 - self.b1**self.t)
      v_hat = self.v[i] / (1.0 - self.b2**self.t)
      up = (m_hat / (v_hat.sqrt() + self.eps)) + self.wd * p.detach()
      if self.adam:
        r = 1.0
      else:
        # layer-wise trust ratio ||w|| / ||update||, falling back to 1.0 when either norm is zero
        r1 = p.detach().square().sum().sqrt()
        r2 = up.square().sum().sqrt()
        r = Tensor.where(r1 > 0, Tensor.where(r2 > 0, r1 / r2, 1.0), 1.0)
      p.assign(p.detach() - self.lr * r * up)
    self.realize([self.t] + self.m + self.v)