## **Adam Optimizer**

In [2]:
import torch
from typing import Tuple, Callable, Union, Dict, Iterable
import numpy as np
import math

from torch.optim import Optimizer
from torch import Tensor


class Adam(Optimizer):
    """Adam optimizer (Kingma & Ba, 2015) with optional L2 weight decay.

    For each parameter ``p`` with gradient ``g`` at step ``t`` this applies::

        g <- g + weight_decay * p                  (only if weight_decay != 0)
        m <- beta1 * m + (1 - beta1) * g
        v <- beta2 * v + (1 - beta2) * g**2
        p <- p - lr * (m / (1 - beta1**t)) / (sqrt(v / (1 - beta2**t)) + eps)

    Args:
        params: Iterable of tensors (or param-group dicts) to optimize.
        lr: Learning rate.
        betas: Decay rates ``(beta1, beta2)`` of the first- and second-moment
            running averages.
        eps: Term added to the denominator for numerical stability. The
            default of ``1e-3`` is much larger than the ``1e-8`` commonly used
            (e.g. by ``torch.optim.Adam``); it is kept for compatibility.
        weight_decay: L2 penalty coefficient folded into the gradient.

    Raises:
        ValueError: If a hyperparameter is out of range.
    """

    def __init__(self, params: Iterable[Tensor], lr: float = 1e-3,
                 betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-3,
                 weight_decay: float = 0.0):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if eps < 0.0:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        # torch.optim.Optimizer takes the per-group defaults via ``defaults``.
        super().__init__(
            params,
            defaults=dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay),
        )

    def update_params(self, param: Tensor, group: Dict) -> None:
        """Apply one Adam update to ``param`` in place, using ``param.grad``.

        Expects ``self.state[param]`` to have been initialized by ``step`` and
        autograd to be disabled (``step`` runs under ``torch.no_grad``).

        Raises:
            RuntimeError: If ``param.grad`` is a sparse tensor.
        """
        grad = param.grad
        if grad.is_sparse:
            raise RuntimeError("Adam does not support sparse gradients")

        state = self.state[param]
        state["step"] += 1
        step = state["step"]
        beta1, beta2 = group["betas"]

        # The L2 penalty must enter the gradient *before* the moment updates,
        # and out-of-place so that param.grad itself is left untouched.
        weight_decay = group["weight_decay"]
        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        # Update the running moments in place so they persist across steps.
        exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

        # Bias correction compensates for the zero initialization of m and v.
        bias_correction1 = 1 - beta1 ** step
        bias_correction2 = 1 - beta2 ** step
        denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group["eps"])
        param.addcdiv_(exp_avg, denom, value=-group["lr"] / bias_correction1)

    @torch.no_grad()
    def step(self, closure: Union[Callable[[], float], None] = None) -> Union[float, None]:
        """Perform a single optimization step over every parameter group.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss; it is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` if none was given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for param in group["params"]:
                if param.grad is None:
                    continue
                state = self.state[param]
                if not state:
                    # Lazily allocate per-parameter state on the first update.
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(param, memory_format=torch.preserve_format)
                    state["exp_avg_sq"] = torch.zeros_like(param, memory_format=torch.preserve_format)
                self.update_params(param, group)

        return loss

## **AdaBelief Optimizer**

In [3]:
import torch
from typing import Tuple, Callable, Union, Dict, Iterable
import numpy as np
import math

from torch.optim import Optimizer
from torch import Tensor


class AdaBelief(Optimizer):
    """AdaBelief optimizer (Zhuang et al., NeurIPS 2020) with optional L2 weight decay.

    AdaBelief differs from Adam only in its second-moment estimate. Adam
    averages ``g**2``; AdaBelief averages ``(g - m)**2``, i.e. how far the
    observed gradient deviates from its running mean (the "belief"). For each
    parameter ``p`` with gradient ``g`` at step ``t``::

        g <- g + weight_decay * p                  (only if weight_decay != 0)
        m <- beta1 * m + (1 - beta1) * g
        s <- beta2 * s + (1 - beta2) * (g - m)**2 + eps
        p <- p - lr * (m / (1 - beta1**t)) / (sqrt(s / (1 - beta2**t)) + eps)

    Both moments use the same bias correction as Adam.

    Args:
        params: Iterable of tensors (or param-group dicts) to optimize.
        lr: Learning rate.
        betas: Decay rates ``(beta1, beta2)`` of the mean and "belief"
            variance running averages.
        eps: Stability term, added both to ``s`` and to the denominator.
        weight_decay: L2 penalty coefficient folded into the gradient.

    Raises:
        ValueError: If a hyperparameter is out of range.
    """

    def __init__(self, params: Iterable[Tensor], lr: float = 1e-3,
                 betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-3,
                 weight_decay: float = 0.0):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if eps < 0.0:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        # torch.optim.Optimizer takes the per-group defaults via ``defaults``.
        super().__init__(
            params,
            defaults=dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay),
        )

    def update_params(self, param: Tensor, group: Dict) -> None:
        """Apply one AdaBelief update to ``param`` in place, using ``param.grad``.

        Expects ``self.state[param]`` to have been initialized by ``step`` and
        autograd to be disabled (``step`` runs under ``torch.no_grad``).

        Raises:
            RuntimeError: If ``param.grad`` is a sparse tensor.
        """
        grad = param.grad
        if grad.is_sparse:
            raise RuntimeError("AdaBelief does not support sparse gradients")

        state = self.state[param]
        state["step"] += 1
        step = state["step"]
        eps = group["eps"]
        beta1, beta2 = group["betas"]

        # The L2 penalty must enter the gradient *before* the moment updates,
        # and out-of-place so that param.grad itself is left untouched.
        weight_decay = group["weight_decay"]
        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        # Update the running moments in place so they persist across steps.
        exp_avg, exp_avg_var = state["exp_avg"], state["exp_avg_var"]
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        # The residual is taken against the *updated* mean m_t, as in the paper;
        # eps is accumulated into s_t as well.
        grad_residual = grad - exp_avg
        exp_avg_var.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1 - beta2).add_(eps)

        # Standard Adam-style bias correction for zero-initialized moments
        # (it depends on the step count; it is not an arbitrary constant).
        bias_correction1 = 1 - beta1 ** step
        bias_correction2 = 1 - beta2 ** step
        denom = (exp_avg_var.sqrt() / math.sqrt(bias_correction2)).add_(eps)
        param.addcdiv_(exp_avg, denom, value=-group["lr"] / bias_correction1)

    @torch.no_grad()
    def step(self, closure: Union[Callable[[], float], None] = None) -> Union[float, None]:
        """Perform a single optimization step over every parameter group.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss; it is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` if none was given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for param in group["params"]:
                if param.grad is None:
                    continue
                state = self.state[param]
                if not state:
                    # Lazily allocate per-parameter state on the first update.
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(param, memory_format=torch.preserve_format)
                    state["exp_avg_var"] = torch.zeros_like(param, memory_format=torch.preserve_format)
                self.update_params(param, group)

        return loss