In [2]:
import numpy as np

In [3]:
class Adafactor:
    def __init__(self, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, clip_threshold=1.0):
        self.learning_rate = learning_rate
        self.beta1 = beta1
        self.beta2 = beta2
        self.eps = eps
        self.clip_threshold = clip_threshold
    
    def update(self, params, grads):
        m = [np.zeros_like(param) for param in params]
        v = [np.zeros_like(param) for param in params]

        updated_params = []
        for i, (param, grad) in enumerate(zip(params, grads)):
            grad_squared = grad ** 2
            if grad.ndim > 1:
                row_mean = np.mean(grad_squared, axis=-1, keepdims=True)
                col_mean = np.mean(grad_squared, axis=-2, keepdims=True)
                m[i] = self.beta2 * m[i] + (1 - self.beta2) * grad_squared
                r_factor = row_mean
                c_factor = col_mean
            else:
                m[i] = self.beta2 * m[i] + (1 - self.beta2) * grad_squared
                r_factor = c_factor = m[i]

            grad_norm = np.linalg.norm(grad)
            if grad_norm > self.clip_threshold:
                grad = grad * (self.clip_threshold / grad_norm)

            v[i] = self.beta1 * v[i] + (1 - self.beta1) * grad

            update_step = grad / (np.sqrt(r_factor * c_factor) + self.eps)
            param_update = param - self.learning_rate * update_step
            updated_params.append(param_update)

        return updated_params

In [4]:
params = np.array([2.0])
grads = np.array([4.0])

In [5]:
adafactor = Adafactor()
adafactor.update(params, grads)

[1.9375000390624757]