In [141]:
import inspect

from tensorflow import keras  # noqa: F401
from keras import optimizers as tfoptim

import torch
import torch.optim as torchoptim

import pandas as pd
pd.set_option('display.max_colwidth', 500)

import flatiron.core.tools as fict

In [113]:
def get_classes(module):
    '''
    Collect the public classes exposed by a module.

    Args:
        module: Module (or any other object) whose members are inspected.

    Returns:
        dict: Mapping of member name to class object for every member that
        is a class and whose name does not start with an underscore. Keys
        follow the name-sorted order produced by inspect.getmembers.
    '''
    return {
        name: member
        for name, member in inspect.getmembers(module)
        if inspect.isclass(member) and not name.startswith('_')
    }

def get_init_signature_data(class_, remove=('self',)):
    '''
    Describe the call signature of a class (or any callable) as flat records.

    Args:
        class_: Object passed to inspect.getfullargspec.
        remove: Argument names to leave out of the result. Names that are
            not part of the signature are ignored. Default: ('self',).

    Returns:
        list[dict]: One dict per argument with keys:

        * arg - argument name
        * default - default value, or 'REQUIRED' if the argument has none
        * type - annotation, or 'UNTYPED' if the argument has none

        Positional arguments come first, followed by keyword-only arguments,
        each group in declaration order.
    '''
    sig = inspect.getfullargspec(class_)
    remove = tuple(remove or ())

    # Defaults belong to the trailing positional arguments, so align them
    # against the full argument list before any name is filtered out.
    defaults = sig.defaults or ()
    n_required = len(sig.args) - len(defaults)
    params = {}
    for i, name in enumerate(sig.args):
        params[name] = 'REQUIRED' if i < n_required else defaults[i - n_required]

    # Walk kwonlyargs rather than kwonlydefaults so that keyword-only
    # arguments lacking a default are reported as REQUIRED, not dropped.
    kwdefaults = sig.kwonlydefaults or {}
    for name in sig.kwonlyargs:
        params[name] = kwdefaults[name] if name in kwdefaults else 'REQUIRED'

    anno = sig.annotations
    return [
        dict(arg=name, default=default, type=anno.get(name, 'UNTYPED'))
        for name, default in params.items()
        # filtering (instead of list.remove) tolerates signatures that lack
        # 'self', e.g. plain functions
        if name not in remove
    ]

def get_module_class_data(module):
    '''
    Build a table of constructor arguments for every public class of a module.

    Args:
        module: Module whose classes are inspected via get_classes.

    Returns:
        pd.DataFrame: One row per (class, argument) with columns
        library, module, class_, arg, type_, default. The library column is
        the first dotted component of module.__name__; the module column is
        the full module.__name__.
    '''
    records = []
    for name, class_ in get_classes(module).items():
        for row in get_init_signature_data(class_):
            records.append({
                'class_': name,
                'arg': row['arg'],
                # Signature rows keep the annotation under 'type'. Copy it
                # to 'type_' explicitly, otherwise the type_ column is
                # created empty (all NaN).
                'type_': row['type'],
                'default': row['default'],
            })

    cols = ['library', 'module', 'class_', 'arg', 'type_', 'default']
    # columns= keeps the schema stable even when the module has no classes
    data = pd.DataFrame(records, columns=cols[2:])
    data['library'] = module.__name__.split('.')[0]
    data['module'] = module.__name__
    return data[cols]

def get_optimizer_data():
    '''
    Combine the constructor-argument tables of the tfoptim and torchoptim
    modules into a single table.

    Returns:
        pd.DataFrame: Rows from get_module_class_data for both modules, with:

        * library - 'keras' replaced by 'tf'
        * field - copy of class_, except 'Nadam' which is spelled 'NAdam'

        The row index is a fresh 0..n-1 range.
    '''
    frames = [
        get_module_class_data(tfoptim),
        get_module_class_data(torchoptim),
    ]
    # ignore_index avoids duplicate row labels: each per-module frame starts
    # at 0, so a plain concat would repeat labels across libraries.
    data = pd.concat(frames, axis=0, ignore_index=True)

    data.loc[data['library'] == 'keras', 'library'] = 'tf'

    data['field'] = data['class_']
    data.loc[data['field'] == 'Nadam', 'field'] = 'NAdam'
    return data

In [156]:
data = get_optimizer_data()

# One row per argument name, listing the unique libraries and classes
# that accept it.
opt = (
    data
    .groupby('arg', as_index=False)[['library', 'class_']]
    .agg(lambda x: x.unique())
)
opt['len_library'] = opt['library'].apply(len)
opt['len_class'] = opt['class_'].apply(len)
opt = opt.sort_values(['len_library', 'len_class'], ascending=False)

max_libraries = opt['len_library'].max()
max_classes = opt['len_class'].max()


def _prefix(row):
    '''
    Build the prefix string for one row of opt.

    The library part is dropped for rows whose library count equals the
    maximum; the class part is dropped for rows whose class count equals the
    maximum.
    '''
    if row['len_library'] == max_libraries:
        lib_part = ''
    else:
        lib_part = '_'.join(row['library']) + '_'

    if row['len_class'] == max_classes:
        class_part = ''
    else:
        class_part = '_'.join(row['class_'])

    return f'{lib_part}{class_part}__'


opt['prefix'] = opt.apply(_prefix, axis=1)

In [164]:
# arguments that occur in exactly two libraries
d = opt[opt['len_library'] == 2]
d

Unnamed: 0,arg,library,class_,len_library,len_class,prefix
59,weight_decay,"[tf, torch]","[Adadelta, Adafactor, Adagrad, Adam, AdamW, Adamax, Ftrl, Lamb, Lion, Nadam, RMSprop, SGD, ASGD, NAdam, RAdam]",2,15,__
1,amsgrad,"[tf, torch]","[Adam, AdamW]",2,2,Adam_AdamW__
31,initial_accumulator_value,"[tf, torch]","[Adagrad, Ftrl]",2,2,Adagrad_Ftrl__
47,momentum,"[tf, torch]","[RMSprop, SGD]",2,2,RMSprop_SGD__
53,rho,"[tf, torch]","[Adadelta, RMSprop]",2,2,Adadelta_RMSprop__
9,centered,"[tf, torch]",[RMSprop],2,1,RMSprop__
50,nesterov,"[tf, torch]",[SGD],2,1,SGD__


In [162]:
lib = 'tf'

# arguments that occur in exactly one library ...
d = opt[opt['len_library'] == 1]
# ... and that library is `lib`
d = d[d['library'].apply(lambda libs: libs[0] == lib)]
d

Unnamed: 0,arg,library,class_,len_library,len_class,prefix
11,clipnorm,[tf],"[Adadelta, Adafactor, Adagrad, Adam, AdamW, Adamax, Ftrl, Lamb, Lion, Nadam, RMSprop, SGD]",1,12,tf_Adadelta_Adafactor_Adagrad_Adam_AdamW_Adamax_Ftrl_Lamb_Lion_Nadam_RMSprop_SGD__
12,clipvalue,[tf],"[Adadelta, Adafactor, Adagrad, Adam, AdamW, Adamax, Ftrl, Lamb, Lion, Nadam, RMSprop, SGD]",1,12,tf_Adadelta_Adafactor_Adagrad_Adam_AdamW_Adamax_Ftrl_Lamb_Lion_Nadam_RMSprop_SGD__
19,ema_momentum,[tf],"[Adadelta, Adafactor, Adagrad, Adam, AdamW, Adamax, Ftrl, Lamb, Lion, Nadam, RMSprop, SGD]",1,12,tf_Adadelta_Adafactor_Adagrad_Adam_AdamW_Adamax_Ftrl_Lamb_Lion_Nadam_RMSprop_SGD__
20,ema_overwrite_frequency,[tf],"[Adadelta, Adafactor, Adagrad, Adam, AdamW, Adamax, Ftrl, Lamb, Lion, Nadam, RMSprop, SGD]",1,12,tf_Adadelta_Adafactor_Adagrad_Adam_AdamW_Adamax_Ftrl_Lamb_Lion_Nadam_RMSprop_SGD__
28,global_clipnorm,[tf],"[Adadelta, Adafactor, Adagrad, Adam, AdamW, Adamax, Ftrl, Lamb, Lion, Nadam, RMSprop, SGD]",1,12,tf_Adadelta_Adafactor_Adagrad_Adam_AdamW_Adamax_Ftrl_Lamb_Lion_Nadam_RMSprop_SGD__
29,gradient_accumulation_steps,[tf],"[Adadelta, Adafactor, Adagrad, Adam, AdamW, Adamax, Ftrl, Lamb, Lion, Nadam, RMSprop, SGD]",1,12,tf_Adadelta_Adafactor_Adagrad_Adam_AdamW_Adamax_Ftrl_Lamb_Lion_Nadam_RMSprop_SGD__
38,learning_rate,[tf],"[Adadelta, Adafactor, Adagrad, Adam, AdamW, Adamax, Ftrl, Lamb, Lion, Nadam, RMSprop, SGD]",1,12,tf_Adadelta_Adafactor_Adagrad_Adam_AdamW_Adamax_Ftrl_Lamb_Lion_Nadam_RMSprop_SGD__
41,loss_scale_factor,[tf],"[Adadelta, Adafactor, Adagrad, Adam, AdamW, Adamax, Ftrl, Lamb, Lion, Nadam, RMSprop, SGD]",1,12,tf_Adadelta_Adafactor_Adagrad_Adam_AdamW_Adamax_Ftrl_Lamb_Lion_Nadam_RMSprop_SGD__
49,name,[tf],"[Adadelta, Adafactor, Adagrad, Adam, AdamW, Adamax, Ftrl, Lamb, Lion, Nadam, RMSprop, SGD]",1,12,tf_Adadelta_Adafactor_Adagrad_Adam_AdamW_Adamax_Ftrl_Lamb_Lion_Nadam_RMSprop_SGD__
58,use_ema,[tf],"[Adadelta, Adafactor, Adagrad, Adam, AdamW, Adamax, Ftrl, Lamb, Lion, Nadam, RMSprop, SGD]",1,12,tf_Adadelta_Adafactor_Adagrad_Adam_AdamW_Adamax_Ftrl_Lamb_Lion_Nadam_RMSprop_SGD__


In [163]:
lib = 'torch'

# arguments that occur in exactly one library ...
d = opt[opt['len_library'] == 1]
# ... and that library is `lib`
d = d[d['library'].apply(lambda libs: libs[0] == lib)]
d

Unnamed: 0,arg,library,class_,len_library,len_class,prefix
51,params,[torch],"[ASGD, Adadelta, Adafactor, Adagrad, Adam, AdamW, Adamax, LBFGS, NAdam, Optimizer, RAdam, RMSprop, Rprop, SGD, SparseAdam]",1,15,torch___
42,lr,[torch],"[ASGD, Adadelta, Adafactor, Adagrad, Adam, AdamW, Adamax, LBFGS, NAdam, RAdam, RMSprop, Rprop, SGD, SparseAdam]",1,14,torch_ASGD_Adadelta_Adafactor_Adagrad_Adam_AdamW_Adamax_LBFGS_NAdam_RAdam_RMSprop_Rprop_SGD_SparseAdam__
46,maximize,[torch],"[ASGD, Adadelta, Adafactor, Adagrad, Adam, AdamW, Adamax, NAdam, RAdam, RMSprop, Rprop, SGD, SparseAdam]",1,13,torch_ASGD_Adadelta_Adafactor_Adagrad_Adam_AdamW_Adamax_NAdam_RAdam_RMSprop_Rprop_SGD_SparseAdam__
26,foreach,[torch],"[ASGD, Adadelta, Adafactor, Adagrad, Adam, AdamW, Adamax, NAdam, RAdam, RMSprop, Rprop, SGD]",1,12,torch_ASGD_Adadelta_Adafactor_Adagrad_Adam_AdamW_Adamax_NAdam_RAdam_RMSprop_Rprop_SGD__
17,differentiable,[torch],"[ASGD, Adadelta, Adagrad, Adam, AdamW, Adamax, NAdam, RAdam, RMSprop, Rprop, SGD]",1,11,torch_ASGD_Adadelta_Adagrad_Adam_AdamW_Adamax_NAdam_RAdam_RMSprop_Rprop_SGD__
21,eps,[torch],"[Adadelta, Adafactor, Adagrad, Adam, AdamW, Adamax, NAdam, RAdam, RMSprop, SparseAdam]",1,10,torch_Adadelta_Adafactor_Adagrad_Adam_AdamW_Adamax_NAdam_RAdam_RMSprop_SparseAdam__
8,capturable,[torch],"[ASGD, Adadelta, Adam, AdamW, Adamax, NAdam, RAdam, RMSprop, Rprop]",1,9,torch_ASGD_Adadelta_Adam_AdamW_Adamax_NAdam_RAdam_RMSprop_Rprop__
7,betas,[torch],"[Adam, AdamW, Adamax, NAdam, RAdam, SparseAdam]",1,6,torch_Adam_AdamW_Adamax_NAdam_RAdam_SparseAdam__
27,fused,[torch],"[Adagrad, Adam, AdamW, SGD]",1,4,torch_Adagrad_Adam_AdamW_SGD__
0,alpha,[torch],"[ASGD, RMSprop]",1,2,torch_ASGD_RMSprop__
