In [None]:
import torch
import inspect

# Optional third-party optimizers. If a package is not installed, its class is
# recorded as None so the registry below still lists the name.
try:
    import bitsandbytes as bnb
    adam8bit_class = bnb.optim.Adam8bit
except ImportError:
    adam8bit_class = None

try:
    import prodigyopt
    prodigy_class = prodigyopt.Prodigy
except ImportError:
    prodigy_class = None

# Short name -> optimizer constructor; None marks an optimizer whose package is missing.
optimizer_dict = dict(
    adam=torch.optim.Adam,
    adam8bit=adam8bit_class,
    adamw=torch.optim.AdamW,
    prodigy=prodigy_class,
)

def filter_valid_params(constructor, params_dict):
    """Return the subset of ``params_dict`` whose keys ``constructor`` accepts by name.

    Lets one shared hyper-parameter dict be passed to optimizers with different
    signatures. Accepted names come from ``inspect.signature(constructor)``; a
    ``**kwargs`` parameter only matches a key literally named after it, not
    arbitrary keys.

    Args:
        constructor: callable or class whose signature is inspected.
        params_dict: candidate keyword arguments.

    Returns:
        A new dict with the accepted entries, in the original key order.
    """
    accepted_names = inspect.signature(constructor).parameters
    kept = {}
    for name, value in params_dict.items():
        if name not in accepted_names:
            continue
        kept[name] = value
    return kept

def prepare_optimizer_params(models, learning_rates):
    """Build per-model parameter groups for a ``torch.optim``-style optimizer.

    Each model contributes one group containing only its trainable parameters
    (``requires_grad=True``) together with that model's learning rate. A group's
    ``"lr"`` overrides the optimizer's default ``lr``.

    Args:
        models: iterable of ``torch.nn.Module`` objects.
        learning_rates: iterable of learning rates, one per model (same order).

    Returns:
        List of ``{"params": [...], "lr": lr}`` dicts, in the order of ``models``.

    Raises:
        ValueError: if the number of models and learning rates differ.
    """
    # Materialize so generators can be length-checked and iterated once.
    models = list(models)
    learning_rates = list(learning_rates)
    if len(models) != len(learning_rates):
        raise ValueError(
            f"expected one learning rate per model, got {len(models)} models "
            f"and {len(learning_rates)} learning rates"
        )
    return [
        {"params": [p for p in model.parameters() if p.requires_grad], "lr": lr}
        for model, lr in zip(models, learning_rates)
    ]
In [None]:
# Two toy models so each optimizer receives two parameter groups with different learning rates.
model1 = torch.nn.Linear(3, 4)
model2 = torch.nn.Conv2d(3, 6, kernel_size=3)  # kernel_size is a required argument of Conv2d

# Keep only trainable parameters (to reduce memory held by the optimizer).
model1_parameters = [p for p in model1.parameters() if p.requires_grad]
model1_parameters_with_lr = {"params": model1_parameters, "lr": 0.15}

model2_parameters = [p for p in model2.parameters() if p.requires_grad]
model2_parameters_with_lr = {"params": model2_parameters, "lr": 0.1}


params_to_optimizer = [model1_parameters_with_lr, model2_parameters_with_lr]
# Shared hyper-parameters; each optimizer only receives the keys its constructor accepts.
# NOTE: the per-group "lr" values above take precedence over the default "lr" given here.
my_params = {'betas' : (0.1, 0.1), 'weight_decay' : 0.99, 'eps':0.999, 'lr': 0.01}
for key, optimizer_class in optimizer_dict.items():
    if optimizer_class is None:
        # Optional dependency (bitsandbytes / prodigyopt) is not installed.
        continue
    # Per-optimizer copy so the adam8bit-specific override does not leak into later optimizers.
    optimizer_kwargs = dict(my_params)
    if key == 'adam8bit':
        optimizer_kwargs['lr'] = 0.15
    valid_params = filter_valid_params(optimizer_class, optimizer_kwargs)
    print(valid_params)
    optimizer = optimizer_class(params_to_optimizer, **valid_params)
    print(optimizer)

## Lion optimizer

In [None]:
from typing import Tuple, Optional, Callable

import torch
from torch.optim.optimizer import Optimizer

def exists(val):
    """Return True unless ``val`` is ``None``; falsy values such as 0 or [] still exist."""
    return False if val is None else True

# update functions

def update_fn(p, grad, exp_avg, lr, wd, beta1, beta2):
    """Apply one Lion step to ``p`` in place and advance its momentum buffer.

    Args:
        p: parameter tensor, modified in place.
        grad: gradient of ``p``.
        exp_avg: momentum buffer for ``p``, modified in place.
        lr: learning rate.
        wd: decoupled weight-decay coefficient.
        beta1: blend factor between momentum and gradient for the update direction.
        beta2: decay factor for the stored momentum.
    """
    # Decoupled weight decay: shrink the weights before the signed step.
    p.data.mul_(1 - lr * wd)

    # The direction is the sign of beta1*exp_avg + (1-beta1)*grad; computed out of
    # place so exp_avg is not touched yet.
    direction = exp_avg.mul(beta1).add(grad, alpha=1 - beta1).sign()
    p.add_(direction, alpha=-lr)

    # Only after the step is the momentum advanced, with its own decay beta2.
    exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)

class Lion(Optimizer):
    """Lion optimizer: each parameter moves by ``lr`` times the sign of a momentum/gradient blend.

    Adapted from https://github.com/lucidrains/lion-pytorch. The per-parameter
    arithmetic is delegated to the module-level ``update_fn``, which is stored on
    the instance as ``self.update_fn``.

    Args:
        params: iterable of tensors or parameter-group dicts, as for any
            ``torch.optim.Optimizer``.
        lr: learning rate; must be > 0.
        betas: ``(beta1, beta2)``, each within ``[0, 1]``. ``beta1`` mixes the
            stored momentum with the current gradient to form the update
            direction; ``beta2`` decays the stored momentum.
        weight_decay: decoupled decay coefficient; the default ``update_fn``
            scales each parameter by ``1 - lr * weight_decay`` before stepping.
    """
    # https://github.com/lucidrains/lion-pytorch/blob/main/lion_pytorch/lion_pytorch.py
    def __init__(
        self,
        params,
        lr: float = 1e-4,
        betas: Tuple[float, float] = (0.9, 0.99),
        weight_decay: float = 0.0,
    ):
        # Checked with assert, so these validations are skipped under `python -O`.
        assert lr > 0.
        assert all([0. <= b <= 1. for b in betas])

        super().__init__(params, {'lr': lr, 'betas': betas, 'weight_decay': weight_decay})

        self.update_fn = update_fn

    @torch.no_grad()
    def step(
        self,
        closure: Optional[Callable] = None
    ):
        """Perform a single optimization step over every parameter that has a gradient.

        Args:
            closure: optional callable that re-evaluates the model and returns the
                loss; it is run with gradient tracking re-enabled.

        Returns:
            Whatever ``closure`` returned, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for param in group['params']:
                if param.grad is None:
                    # Parameters without a gradient are left untouched and get no state.
                    continue

                beta1, beta2 = group['betas']
                param_state = self.state[param]
                # Lazily create the momentum buffer the first time this parameter has a gradient.
                if len(param_state) == 0:
                    param_state['exp_avg'] = torch.zeros_like(param)

                self.update_fn(
                    param,
                    param.grad,
                    param_state['exp_avg'],
                    group['lr'],
                    group['weight_decay'],
                    beta1,
                    beta2,
                )

        return loss

## AGD (Automatic Gradient Descent)

In [None]:
from torch.nn.init import orthogonal_
import math


def singular_value(p):
    """Target scale used by AGD for a weight tensor: ``sqrt(shape[0] / shape[1])``.

    For 4-D tensors (presumably conv kernels laid out ``(out, in, kh, kw)``) the
    result is additionally divided by ``sqrt(shape[2] * shape[3])``.
    """
    rows, cols = p.shape[0], p.shape[1]
    scale = math.sqrt(rows / cols)
    if p.dim() != 4:
        return scale
    return scale / math.sqrt(p.shape[2] * p.shape[3])

class AGD:
    """Automatic Gradient Descent (AGD) optimizer.

    Adapted from https://github.com/jxbz/agd. Unlike ``torch.optim`` optimizers,
    it takes the whole network, re-initializes its weights in ``__init__``, and has
    no learning rate: the step size is derived from the gradient norms on each call
    to ``step`` and scaled by ``gain``.

    Args:
        net: module whose parameters are updated. 1-D parameters (e.g. biases)
            are rejected; 2-D parameters are orthogonally initialized; 4-D
            parameters are orthogonally initialized per ``[:, :, kx, ky]`` slice.
            Every parameter is then scaled by ``singular_value(p)``.
        gain: multiplier applied to every update.

    Raises:
        ValueError: if ``net`` has a 1-D parameter (subclass of ``Exception``, so
            existing ``except Exception`` handlers still catch it).
    """
    # https://github.com/jxbz/agd
    @torch.no_grad()  # in-place re-initialization of leaf parameters must not be tracked
    def __init__(self, net, gain=1.0):

        self.net = net
        # Number of parameter tensors; normalizes both the aggregate gradient and each update.
        self.depth = len(list(net.parameters()))
        self.gain = gain

        for p in net.parameters():
            if p.dim() == 1:
                raise ValueError("Biases are not supported.")
            if p.dim() == 2:
                orthogonal_(p)
            if p.dim() == 4:
                for kx in range(p.shape[2]):
                    for ky in range(p.shape[3]):
                        orthogonal_(p[:, :, kx, ky])
            p *= singular_value(p)

    @torch.no_grad()
    def step(self):
        """Apply one AGD update to the network's parameters in place.

        Parameters whose ``.grad`` is ``None`` (e.g. frozen ones) contribute
        nothing to the aggregate gradient and are not updated. A gradient slice
        whose norm is zero yields a zero update rather than NaN.

        Returns:
            The step size ``log(0.5 * (1 + sqrt(1 + 4 * G)))`` as a float.
        """
        params = [p for p in self.net.parameters() if p.grad is not None]

        G = 0
        for p in params:
            G += singular_value(p) * p.grad.norm(dim=(0, 1)).sum()
        G /= self.depth

        log = math.log(0.5 * (1 + math.sqrt(1 + 4 * G)))

        for p in params:
            norm = p.grad.norm(dim=(0, 1), keepdim=True)
            # sv / 0 = inf and inf * 0 = NaN would corrupt the weights; a slice with a
            # zero gradient should simply stay unchanged.
            factor = torch.where(norm > 0, singular_value(p) / norm, torch.zeros_like(norm))
            p -= self.gain * log / self.depth * factor * p.grad

        return log