In [1]:
import inspect
import re

from tensorflow import keras  # noqa: F401
from keras import optimizers as tfoptim

import torch
import torch.optim as torchoptim

import pandas as pd
pd.set_option('display.max_colwidth', 500)

import flatiron.core.tools as fict

2025-02-21 21:27:14.328265: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1740173234.349365 1769482 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1740173234.355200 1769482 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-02-21 21:27:14.374928: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
def get_classes(module):
    members = inspect.getmembers(module)
    members = list(filter(lambda x: inspect.isclass(x[1]), members))
    members = list(filter(lambda x: not x[0].startswith('_'), members))
    classes = dict(members)
    return classes

def create_signature(arg, annotation, default):
    if annotation == 'UNTYPED':
        annotation = 'Any'
    if default == 'REQUIRED':
        default = ''
    else:
        default = f' = {default}'
    return f'{arg}: {annotation}{default}'

def get_init_signature_data(class_, remove=['self']):
    sig = inspect.getfullargspec(class_)
    args = sig.args
    for item in remove:
        args.remove(item)

    if sig.defaults is not None:
        d = len(args) - len(sig.defaults)
        req = args[:d]
        opt = args[d:]
        args = {k: 'REQUIRED' for k in req}
        opt = dict(zip(opt, sig.defaults))
        args.update(opt)
    else:
        args = {k: 'REQUIRED' for k in args}
    
    if isinstance(sig.kwonlydefaults, dict):
        args.update(sig.kwonlydefaults)
    
    anno = sig.annotations
    for key, val in args.items():
        if key in anno:
            args[key] = (val, anno[key].__name__)
        else:
            args[key] = (val, 'UNTYPED')
            
    data = []
    for arg, (default, type_) in args.items():
        data.append(dict(
            arg=arg,
            default=default,
            type_=type_,
            signature=create_signature(arg, type_, default),
        ))
    return data

def get_module_class_data(module):
    classes = get_classes(module)
    data = []
    for name, item in classes.items():
        datum = get_init_signature_data(item)
        for row in datum:
            row['class_'] = name
        data.extend(datum)
        
    cols = ['class_', 'arg', 'type_', 'default', 'signature']
    data = pd.DataFrame(data, columns=cols)
    data['library'] = module.__name__.split('.')[0]
    data['module'] = module.__name__
    cols.insert(0, 'library')
    cols.insert(1, 'module')
    data = data[cols]
    
    return data

def get_optimizer_data():
    tf_optim_data = get_module_class_data(tfoptim)
    torch_optim_data = get_module_class_data(torchoptim)
    data = pd.concat([tf_optim_data, torch_optim_data], axis=0)
    
    mask = data.library == 'keras'
    data.loc[mask, 'library'] = 'tf'
    
    data['field'] = data['class_']
    mask = data.field == 'Nadam'
    data.loc[mask, 'field'] = 'NAdam'
    
    mask = data.class_.apply(lambda x: x not in ['Optimizer', 'LossScaleOptimizer'])
    data = data[mask]
    
    mask = data.arg != 'params'
    data = data[mask]

    data.reset_index(drop=True, inplace=True)

    return data

def get_class_definitions(data, base_class='BaseConfig'):
    data = data.copy()
    data['config_name'] = data \
        .apply(lambda x: f'class {x.library.capitalize()}{x.class_}Config({base_class}):', axis=1) \
        .apply(lambda x: re.sub(' Tf', ' TF', x))
    class_def = data.groupby('config_name', as_index=False).signature.agg(lambda x: '    ' + '\n    '.join(sorted(x)))
    class_def = class_def.apply(lambda x: f'{x.config_name}\n{x.signature}', axis=1)
    return class_def.tolist()

In [3]:
def print_tf_optimizer_config_definitions():
    # TF
    d = get_optimizer_data()
    mask = d.library == 'tf'
    d = d[mask]

    # TFBaseConfig
    tf_base_args = [
        'clipnorm',
        'clipvalue',
        'ema_momentum',
        'ema_overwrite_frequency',
        'global_clipnorm',
        'gradient_accumulation_steps',
        'learning_rate',
        'loss_scale_factor',
        'name',
        'use_ema',
        'weight_decay',
    ]
    mask = d.arg.apply(lambda x: x in tf_base_args)
    d0 = d[mask]
    tf_base = get_class_definitions(d0)[0]
    tf_base = re.sub('TFAdadeltaConfig', 'TFBaseConfig', tf_base)
    print(tf_base, '\n\n')

    # TFEpsilonBaseConfig
    tf_eps_args = ['epsilon']
    mask = d.arg.apply(lambda x: x in tf_eps_args)
    d1 = d[mask]
    tf_eps = get_class_definitions(d1, 'TFBaseConfig')[0]
    tf_eps = re.sub('TFAdadeltaConfig', 'TFEpsilonBaseConfig', tf_eps)
    print(tf_eps, '\n\n')

    # TFBaseConfig subclasses
    eps_classes = d1.class_.unique().tolist()
    mask = d.class_.apply(lambda x: x in eps_classes)
    d2 = d[~mask]
    mask = d2.arg.apply(lambda x: x in tf_base_args)
    d2 = d2[~mask]
    tf_subclass = get_class_definitions(d2, 'TFBaseConfig')
    for item in tf_subclass:
        print(item, '\n\n')

    # TFEpsilonBaseConfig subclasses
    eps_classes = d1.class_.unique().tolist()
    mask = d.class_.apply(lambda x: x in eps_classes)
    d2 = d[mask]
    mask = d2.arg.apply(lambda x: x in tf_base_args or x in tf_eps_args)
    d2 = d2[~mask]
    tf_eps_subclass = get_class_definitions(d2, 'TFEpsilonBaseConfig')
    for item in tf_eps_subclass:
        print(item, '\n\n')
        
print_tf_optimizer_config_definitions()

class TFBaseConfig(BaseConfig):
    clipnorm: Any = None
    clipvalue: Any = None
    ema_momentum: Any = 0.99
    ema_overwrite_frequency: Any = None
    global_clipnorm: Any = None
    gradient_accumulation_steps: Any = None
    learning_rate: Any = 0.001
    loss_scale_factor: Any = None
    name: Any = adadelta
    use_ema: Any = False
    weight_decay: Any = None 


class TFEpsilonBaseConfig(TFBaseConfig):
    epsilon: Any = 1e-07 


class TFAdafactorConfig(TFBaseConfig):
    beta_2_decay: Any = -0.8
    clip_threshold: Any = 1.0
    epsilon_1: Any = 1e-30
    epsilon_2: Any = 0.001
    relative_step: Any = True 


class TFFtrlConfig(TFBaseConfig):
    beta: Any = 0.0
    initial_accumulator_value: Any = 0.1
    l1_regularization_strength: Any = 0.0
    l2_regularization_strength: Any = 0.0
    l2_shrinkage_regularization_strength: Any = 0.0
    learning_rate_power: Any = -0.5 


class TFLionConfig(TFBaseConfig):
    beta_1: Any = 0.9
    beta_2: Any = 0.99 


class TFSGDConfig(T

In [67]:
def print_torch_optimizer_config_definitions():
    def get_base_class(data, arg, class_, base_class):
        mask = data.arg.apply(lambda x: x == arg)
        temp = data[mask]
        result = get_class_definitions(temp, base_class)[0]
        result = re.sub('class Torch[a-zA-Z]*Config', f'class {class_}', result)
        return result
    
    def get_torch_class_definition(class_, inherit, signature):
        inherit = ', '.join(sorted(filter(lambda x: x != '', inherit)))
        output = f'class Torch{class_}Config(TorchBaseConfig, {inherit}):\n    '
        output = re.sub(', \)', ')', output)
        regex = '(params|lr|maximize|foreach|differentiable|eps|capturable|weight_decay):'
        signature = list(filter(lambda x: not re.search(regex, x), signature))
        output +=  '\n    '.join(sorted(signature))
        return output

    # Torch
    data = get_optimizer_data()
    mask = data.library == 'torch'
    data = data[mask]

    # TFBaseConfig
    mask = data.arg.apply(lambda x: x == 'lr')
    d0 = data[mask]
    torch_base = get_class_definitions(d0)[0]
    torch_base = re.sub('TorchASGDConfig', 'TorchBaseConfig', torch_base)
    print(torch_base, '\n\n')    
    
    # aux base configs
    lut = [
        ('maximize', 'TMax'),
        ('foreach', 'TFor'),
        ('differentiable', 'TDiff'),
        ('eps', 'TEps'),
        ('capturable', 'TCap'),
        ('weight_decay', 'TDecay'),
    ]
    data['inherit'] = ''
    for arg, cls_ in lut:
        print(get_base_class(data, arg, cls_, 'TorchBaseConfig'), '\n\n')
        mask = data.arg == arg
        data.loc[mask, 'inherit'] = cls_
    
    # classes
    class_def = data \
        .groupby('class_', as_index=False)[['inherit', 'signature']] \
        .agg(lambda x: x) \
        .apply(lambda x: get_torch_class_definition(x.class_, x.inherit, x.signature), axis=1) \
        .tolist()  
    for item in class_def:
        print(item, '\n\n')

print_torch_optimizer_config_definitions()

class TorchBaseConfig(BaseConfig):
    lr: Union = 0.01 


class TMax(TorchBaseConfig):
    maximize: bool = False 


class TFor(TorchBaseConfig):
    foreach: Optional = None 


class TDiff(TorchBaseConfig):
    differentiable: bool = False 


class TEps(TorchBaseConfig):
    eps: float = 1e-06 


class TCap(TorchBaseConfig):
    capturable: bool = False 


class TDecay(TorchBaseConfig):
    weight_decay: float = 0 


class TorchASGDConfig(TorchBaseConfig, TCap, TDecay, TDiff, TFor, TMax):
    alpha: float = 0.75
    lambd: float = 0.0001
    t0: float = 1000000.0 


class TorchAdadeltaConfig(TorchBaseConfig, TCap, TDecay, TDiff, TEps, TFor, TMax):
    rho: float = 0.9 


class TorchAdafactorConfig(TorchBaseConfig, TDecay, TEps, TFor, TMax):
    beta2_decay: float = -0.8
    d: float = 1.0 


class TorchAdagradConfig(TorchBaseConfig, TDecay, TDiff, TEps, TFor, TMax):
    fused: Optional = None
    initial_accumulator_value: float = 0
    lr_decay: float = 0 


class TorchAdamConfig(To

In [74]:
from keras.src.optimizers.sgd import SGD
inspect.getfullargspec(SGD)

FullArgSpec(args=['self', 'learning_rate', 'momentum', 'nesterov', 'weight_decay', 'clipnorm', 'clipvalue', 'global_clipnorm', 'use_ema', 'ema_momentum', 'ema_overwrite_frequency', 'loss_scale_factor', 'gradient_accumulation_steps', 'name'], varargs=None, varkw='kwargs', defaults=(0.01, 0.0, False, None, None, None, None, False, 0.99, None, None, None, 'SGD'), kwonlyargs=[], kwonlydefaults=None, annotations={})

In [49]:
data = get_optimizer_data()

opt = data.groupby('arg', as_index=False)[['library', 'class_']].agg(lambda x: x.unique())
opt['len_library'] = opt.library.apply(len)
opt['len_class'] = opt.class_.apply(len)
opt.sort_values(['len_library', 'len_class'], ascending=False, inplace=True)

opt['pref0'] = opt.library.apply(lambda x: '_'.join(x) + '_')
mask = opt.len_library == opt.len_library.max()
opt.loc[mask, 'pref0'] = ''

opt['pref1'] = opt.class_.apply(lambda x: '_'.join(x))
mask = opt.len_class == opt.len_class.max()
opt.loc[mask, 'pref1'] = ''

opt['prefix'] = opt.apply(lambda x: f'{x.pref0}{x.pref1}__', axis=1)
del opt['pref0']
del opt['pref1']

In [50]:
d = opt
mask = d.len_library == 2
d = d[mask]
d

Unnamed: 0,arg,library,class_,len_library,len_class,prefix
54,weight_decay,"[tf, torch]","[Adadelta, Adafactor, Adagrad, Adam, AdamW, Adamax, Ftrl, Lamb, Lion, Nadam, RMSprop, SGD, ASGD, NAdam, RAdam]",2,15,__
1,amsgrad,"[tf, torch]","[Adam, AdamW]",2,2,Adam_AdamW__
29,initial_accumulator_value,"[tf, torch]","[Adagrad, Ftrl]",2,2,Adagrad_Ftrl__
43,momentum,"[tf, torch]","[RMSprop, SGD]",2,2,RMSprop_SGD__
48,rho,"[tf, torch]","[Adadelta, RMSprop]",2,2,Adadelta_RMSprop__
9,centered,"[tf, torch]",[RMSprop],2,1,RMSprop__
46,nesterov,"[tf, torch]",[SGD],2,1,SGD__


In [60]:
lib = 'tf'

d = opt
mask = d.library.apply(lambda x: lib in x)
d = d[mask]
d

Unnamed: 0,arg,library,class_,len_library,len_class,prefix
54,weight_decay,"[tf, torch]","[Adadelta, Adafactor, Adagrad, Adam, AdamW, Adamax, Ftrl, Lamb, Lion, Nadam, RMSprop, SGD, ASGD, NAdam, RAdam]",2,15,__
1,amsgrad,"[tf, torch]","[Adam, AdamW]",2,2,Adam_AdamW__
29,initial_accumulator_value,"[tf, torch]","[Adagrad, Ftrl]",2,2,Adagrad_Ftrl__
43,momentum,"[tf, torch]","[RMSprop, SGD]",2,2,RMSprop_SGD__
48,rho,"[tf, torch]","[Adadelta, RMSprop]",2,2,Adadelta_RMSprop__
9,centered,"[tf, torch]",[RMSprop],2,1,RMSprop__
46,nesterov,"[tf, torch]",[SGD],2,1,SGD__
11,clipnorm,[tf],"[Adadelta, Adafactor, Adagrad, Adam, AdamW, Adamax, Ftrl, Lamb, Lion, Nadam, RMSprop, SGD]",1,12,tf_Adadelta_Adafactor_Adagrad_Adam_AdamW_Adamax_Ftrl_Lamb_Lion_Nadam_RMSprop_SGD__
12,clipvalue,[tf],"[Adadelta, Adafactor, Adagrad, Adam, AdamW, Adamax, Ftrl, Lamb, Lion, Nadam, RMSprop, SGD]",1,12,tf_Adadelta_Adafactor_Adagrad_Adam_AdamW_Adamax_Ftrl_Lamb_Lion_Nadam_RMSprop_SGD__
17,ema_momentum,[tf],"[Adadelta, Adafactor, Adagrad, Adam, AdamW, Adamax, Ftrl, Lamb, Lion, Nadam, RMSprop, SGD]",1,12,tf_Adadelta_Adafactor_Adagrad_Adam_AdamW_Adamax_Ftrl_Lamb_Lion_Nadam_RMSprop_SGD__


In [61]:
c = d.class_.apply(lambda x: {k: k for k in x}).tolist()
i = d.arg.tolist()
pd.DataFrame(c, index=i).map(lambda x: '' if pd.isnull(x) else 'x')

Unnamed: 0,Adadelta,Adafactor,Adagrad,Adam,AdamW,Adamax,Ftrl,Lamb,Lion,Nadam,RMSprop,SGD,ASGD,NAdam,RAdam
weight_decay,x,x,x,x,x,x,x,x,x,x,x,x,x,x,x
amsgrad,,,,x,x,,,,,,,,,,
initial_accumulator_value,,,x,,,,x,,,,,,,,
momentum,,,,,,,,,,,x,x,,,
rho,x,,,,,,,,,,x,,,,
centered,,,,,,,,,,,x,,,,
nesterov,,,,,,,,,,,,x,,,
clipnorm,x,x,x,x,x,x,x,x,x,x,x,x,,,
clipvalue,x,x,x,x,x,x,x,x,x,x,x,x,,,
ema_momentum,x,x,x,x,x,x,x,x,x,x,x,x,,,


In [62]:
lib = 'torch'

d = opt
mask = d.library.apply(lambda x: lib in x)
d = d[mask]
d

Unnamed: 0,arg,library,class_,len_library,len_class,prefix
54,weight_decay,"[tf, torch]","[Adadelta, Adafactor, Adagrad, Adam, AdamW, Adamax, Ftrl, Lamb, Lion, Nadam, RMSprop, SGD, ASGD, NAdam, RAdam]",2,15,__
1,amsgrad,"[tf, torch]","[Adam, AdamW]",2,2,Adam_AdamW__
29,initial_accumulator_value,"[tf, torch]","[Adagrad, Ftrl]",2,2,Adagrad_Ftrl__
43,momentum,"[tf, torch]","[RMSprop, SGD]",2,2,RMSprop_SGD__
48,rho,"[tf, torch]","[Adadelta, RMSprop]",2,2,Adadelta_RMSprop__
9,centered,"[tf, torch]",[RMSprop],2,1,RMSprop__
46,nesterov,"[tf, torch]",[SGD],2,1,SGD__
38,lr,[torch],"[ASGD, Adadelta, Adafactor, Adagrad, Adam, AdamW, Adamax, LBFGS, NAdam, RAdam, RMSprop, Rprop, SGD, SparseAdam]",1,14,torch_ASGD_Adadelta_Adafactor_Adagrad_Adam_AdamW_Adamax_LBFGS_NAdam_RAdam_RMSprop_Rprop_SGD_SparseAdam__
42,maximize,[torch],"[ASGD, Adadelta, Adafactor, Adagrad, Adam, AdamW, Adamax, NAdam, RAdam, RMSprop, Rprop, SGD, SparseAdam]",1,13,torch_ASGD_Adadelta_Adafactor_Adagrad_Adam_AdamW_Adamax_NAdam_RAdam_RMSprop_Rprop_SGD_SparseAdam__
24,foreach,[torch],"[ASGD, Adadelta, Adafactor, Adagrad, Adam, AdamW, Adamax, NAdam, RAdam, RMSprop, Rprop, SGD]",1,12,torch_ASGD_Adadelta_Adafactor_Adagrad_Adam_AdamW_Adamax_NAdam_RAdam_RMSprop_Rprop_SGD__


In [63]:
c = d.class_.apply(lambda x: {k: k for k in x}).tolist()
i = d.arg.tolist()
pd.DataFrame(c, index=i).map(lambda x: '' if pd.isnull(x) else 'x')

Unnamed: 0,Adadelta,Adafactor,Adagrad,Adam,AdamW,Adamax,Ftrl,Lamb,Lion,Nadam,RMSprop,SGD,ASGD,NAdam,RAdam,LBFGS,Rprop,SparseAdam
weight_decay,x,x,x,x,x,x,x,x,x,x,x,x,x,x,x,,,
amsgrad,,,,x,x,,,,,,,,,,,,,
initial_accumulator_value,,,x,,,,x,,,,,,,,,,,
momentum,,,,,,,,,,,x,x,,,,,,
rho,x,,,,,,,,,,x,,,,,,,
centered,,,,,,,,,,,x,,,,,,,
nesterov,,,,,,,,,,,,x,,,,,,
lr,x,x,x,x,x,x,,,,,x,x,x,x,x,x,x,x
maximize,x,x,x,x,x,x,,,,,x,x,x,x,x,,x,x
foreach,x,x,x,x,x,x,,,,,x,x,x,x,x,,x,
