In [81]:
import inspect
import re

from tensorflow import keras  # noqa: F401
from keras import optimizers as tfoptim
from keras import losses as tfloss

import torch
import torch.optim as torchoptim
import torch.nn.modules.loss as torchloss

import pandas as pd
pd.set_option('display.max_colwidth', 500)

import flatiron.core.tools as fict

In [101]:
def get_classes(module):
    members = inspect.getmembers(module)
    members = list(filter(lambda x: inspect.isclass(x[1]), members))
    members = list(filter(lambda x: not x[0].startswith('_'), members))
    classes = dict(members)
    return classes

def create_signature(arg, annotation, default):
    if annotation == 'UNTYPED':
        annotation = 'Any'
    if default == 'REQUIRED':
        default = ''
    else:
        default = f' = {default}'
    return f'{arg}: {annotation}{default}'

def get_init_signature_data(class_, remove=['self']):
    sig = inspect.getfullargspec(class_)
    args = sig.args
    for item in remove:
        args.remove(item)

    if sig.defaults is not None:
        d = len(args) - len(sig.defaults)
        req = args[:d]
        opt = args[d:]
        args = {k: 'REQUIRED' for k in req}
        opt = dict(zip(opt, sig.defaults))
        args.update(opt)
    else:
        args = {k: 'REQUIRED' for k in args}
    
    if isinstance(sig.kwonlydefaults, dict):
        args.update(sig.kwonlydefaults)
    
    anno = sig.annotations
    for key, val in args.items():
        if key in anno:
            args[key] = (val, anno[key].__name__)
        else:
            args[key] = (val, 'UNTYPED')
            
    data = []
    for arg, (default, type_) in args.items():
        data.append(dict(
            arg=arg,
            default=default,
            type_=type_,
            signature=create_signature(arg, type_, default),
        ))
    return data

def get_module_class_data(module):
    classes = get_classes(module)
    data = []
    for name, item in classes.items():
        try:
            datum = get_init_signature_data(item)
        except:
            continue
        for row in datum:
            row['class_'] = name
        data.extend(datum)
        
    cols = ['class_', 'arg', 'type_', 'default', 'signature']
    data = pd.DataFrame(data, columns=cols)
    data['library'] = module.__name__.split('.')[0]
    data['module'] = module.__name__
    cols.insert(0, 'library')
    cols.insert(1, 'module')
    data = data[cols]
    
    return data

def get_optimizer_data():
    tf_optim_data = get_module_class_data(tfoptim)
    torch_optim_data = get_module_class_data(torchoptim)
    data = pd.concat([tf_optim_data, torch_optim_data], axis=0)
    
    mask = data.library == 'keras'
    data.loc[mask, 'library'] = 'tf'
    
    data['field'] = data['class_']
    mask = data.field == 'Nadam'
    data.loc[mask, 'field'] = 'NAdam'
    
    mask = data.class_.apply(lambda x: x not in ['Optimizer', 'LossScaleOptimizer'])
    data = data[mask]
    
    mask = data.arg != 'params'
    data = data[mask]

    data.reset_index(drop=True, inplace=True)

    return data

def get_loss_data():
    tf_loss_data = get_module_class_data(tfloss)
    torch_loss_data = get_module_class_data(torchloss)
    data = pd.concat([tf_loss_data, torch_loss_data], axis=0)
    
    mask = data.library == 'keras'
    data.loc[mask, 'library'] = 'tf'
    
#     data['field'] = data['class_']
#     mask = data.field == 'Nadam'
#     data.loc[mask, 'field'] = 'NAdam'
    
#     mask = data.class_.apply(lambda x: x not in ['Optimizer', 'LossScaleOptimizer'])
#     data = data[mask]
    
#     mask = data.arg != 'params'
#     data = data[mask]

#     data.reset_index(drop=True, inplace=True)

    return data

def get_class_definitions(data, base_class='BaseConfig'):
    data = data.copy()
    data['config_name'] = data \
        .apply(lambda x: f'class {x.library.capitalize()}{x.class_}Config({base_class}):', axis=1) \
        .apply(lambda x: re.sub(' Tf', ' TF', x))
    class_def = data.groupby('config_name', as_index=False).signature.agg(lambda x: '    ' + '\n    '.join(sorted(x)))
    class_def = class_def.apply(lambda x: f'{x.config_name}\n{x.signature}', axis=1)
    return class_def.tolist()

In [102]:
get_loss_data()

Unnamed: 0,library,module,class_,arg,type_,default,signature
0,tf,keras.api.losses,BinaryCrossentropy,from_logits,UNTYPED,False,from_logits: Any = False
1,tf,keras.api.losses,BinaryCrossentropy,label_smoothing,UNTYPED,0.0,label_smoothing: Any = 0.0
2,tf,keras.api.losses,BinaryCrossentropy,axis,UNTYPED,-1,axis: Any = -1
3,tf,keras.api.losses,BinaryCrossentropy,reduction,UNTYPED,sum_over_batch_size,reduction: Any = sum_over_batch_size
4,tf,keras.api.losses,BinaryCrossentropy,name,UNTYPED,binary_crossentropy,name: Any = binary_crossentropy
...,...,...,...,...,...,...,...
88,torch,torch.nn.modules.loss,TripletMarginWithDistanceLoss,swap,bool,False,swap: bool = False
89,torch,torch.nn.modules.loss,TripletMarginWithDistanceLoss,reduction,str,mean,reduction: str = mean
90,torch,torch.nn.modules.loss,deprecated,message,str,REQUIRED,message: str
91,torch,torch.nn.modules.loss,deprecated,category,Optional,<class 'DeprecationWarning'>,category: Optional = <class 'DeprecationWarning'>


In [157]:
def print_tf_optimizer_config_definitions():
    # TF
    d = get_optimizer_data()
    mask = d.library == 'tf'
    d = d[mask]

    # TFBaseConfig
    tf_base_args = [
        'clipnorm',
        'clipvalue',
        'ema_momentum',
        'ema_overwrite_frequency',
        'global_clipnorm',
        'gradient_accumulation_steps',
        'learning_rate',
        'loss_scale_factor',
        'name',
        'use_ema',
        'weight_decay',
    ]
    mask = d.arg.apply(lambda x: x in tf_base_args)
    d0 = d[mask]
    tf_base = get_class_definitions(d0)[0]
    tf_base = re.sub('TFAdadeltaConfig', 'TFBaseConfig', tf_base)
    print(tf_base, '\n\n')

    # TFEpsilonBaseConfig
    tf_eps_args = ['epsilon']
    mask = d.arg.apply(lambda x: x in tf_eps_args)
    d1 = d[mask]
    tf_eps = get_class_definitions(d1, 'TFBaseConfig')[0]
    tf_eps = re.sub('class TF[a-zA-Z]*Config', 'class TFEpsilonBaseConfig', tf_eps)
    print(tf_eps, '\n\n')
    
    # TFEpsilonBaseConfig
    tf_eps_args = ['epsilon']
    mask = d.arg.apply(lambda x: x in tf_eps_args)
    d1 = d[mask]
    tf_eps = get_class_definitions(d1, 'TFBaseConfig')[0]
    tf_eps = re.sub('class TF[a-zA-Z]*Config', 'class TFEpsilonBaseConfig', tf_eps)
    print(tf_eps, '\n\n')

    # TFBaseConfig subclasses
    eps_classes = d1.class_.unique().tolist()
    mask = d.class_.apply(lambda x: x in eps_classes)
    d2 = d[~mask]
    mask = d2.arg.apply(lambda x: x in tf_base_args)
    d2 = d2[~mask]
    tf_subclass = get_class_definitions(d2, 'TFBaseConfig')
    for item in tf_subclass:
        print(item, '\n\n')

    # TFEpsilonBaseConfig subclasses
    eps_classes = d1.class_.unique().tolist()
    mask = d.class_.apply(lambda x: x in eps_classes)
    d2 = d[mask]
    mask = d2.arg.apply(lambda x: x in tf_base_args or x in tf_eps_args)
    d2 = d2[~mask]
    tf_eps_subclass = get_class_definitions(d2, 'TFEpsilonBaseConfig')
    for item in tf_eps_subclass:
        print(item, '\n\n')


def print_torch_optimizer_config_definitions():
    def get_base_class(data, arg, class_, base_class):
        mask = data.arg.apply(lambda x: x == arg)
        temp = data[mask]
        result = get_class_definitions(temp, base_class)[0]
        result = re.sub('class Torch[a-zA-Z]*Config', f'class {class_}', result)
        return result
    
    def get_torch_class_definition(class_, inherit, signature):
        inherit = ', '.join(sorted(filter(lambda x: x != '', inherit)))
        output = f'class Torch{class_}Config(TorchBaseConfig, {inherit}):\n    '
        output = re.sub(', \)', ')', output)
        regex = '(params|lr|maximize|foreach|differentiable|eps|capturable|weight_decay):'
        signature = list(filter(lambda x: not re.search(regex, x), signature))
        output +=  '\n    '.join(sorted(signature))
        return output

    # Torch
    data = get_optimizer_data()
    mask = data.library == 'torch'
    data = data[mask]

    # TFBaseConfig
    mask = data.arg.apply(lambda x: x == 'lr')
    d0 = data[mask]
    torch_base = get_class_definitions(d0)[0]
    torch_base = re.sub('TorchASGDConfig', 'TorchBaseConfig', torch_base)
    print(torch_base, '\n\n')    
    
    # aux base configs
    lut = [
        ('maximize', 'TMax'),
        ('foreach', 'TFor'),
        ('differentiable', 'TDiff'),
        ('eps', 'TEps'),
        ('capturable', 'TCap'),
        ('weight_decay', 'TDecay'),
    ]
    data['inherit'] = ''
    for arg, cls_ in lut:
        print(get_base_class(data, arg, cls_, 'TorchBaseConfig'), '\n\n')
        mask = data.arg == arg
        data.loc[mask, 'inherit'] = cls_
    
    # classes
    class_def = data \
        .sort_values('class_') \
        .groupby('class_', as_index=False)[['inherit', 'signature']] \
        .agg(lambda x: x) \
        .apply(lambda x: get_torch_class_definition(x.class_, x.inherit, x.signature), axis=1) \
        .tolist()  
    for item in class_def:
        print(item, '\n\n')

In [159]:
def print_tf_optimizer_config_definitions():
    def get_base_class(data, arg, class_, base_class):
        mask = data.arg.apply(lambda x: x == arg)
        temp = data[mask]
        result = get_class_definitions(temp, base_class)[0]
        result = re.sub('class TF[a-zA-Z]*Config', f'class {class_}', result)
        return result
    
    def get_tf_class_definition(class_, inherit, signature):
        inherit = ', '.join(sorted(filter(lambda x: x != '', inherit)))
        output = f'class TF{class_}Config(TFBaseConfig, {inherit}):\n    '
        output = re.sub(', \)', ')', output)
        regex = '(epsilon|beta_1|beta_2):'
        signature = list(filter(lambda x: not re.search(regex, x), signature))
        output +=  '\n    '.join(sorted(signature))
        return output

    # TF
    data = get_optimizer_data()
    mask = data.library == 'tf'
    data = data[mask]

    # TFBaseConfig
    args = [
        'clipnorm',
        'clipvalue',
        'ema_momentum',
        'ema_overwrite_frequency',
        'global_clipnorm',
        'gradient_accumulation_steps',
        'learning_rate',
        'loss_scale_factor',
        'name',
        'use_ema',
        'weight_decay',
    ]
    mask = data.arg.apply(lambda x: x in args)
    d0 = data[mask]
    tf_base = get_class_definitions(d0)[0]
    tf_base = re.sub('TF[a-zA-Z]*Config', 'TFBaseConfig', tf_base)
    print(tf_base, '\n\n')    
    
    # aux base configs
    lut = [
        ('epsilon', 'TFEps'),
        ('beta_1', 'TFBeta'),
    ]
    data['inherit'] = ''
    for arg, cls_ in lut:
        print(get_base_class(data, arg, cls_, 'TFBaseConfig'), '\n\n')
        mask = data.arg == arg
        data.loc[mask, 'inherit'] = cls_
    
    # classes
    mask = data.arg.apply(lambda x: x not in args)
    d1 = data[mask]
    class_def = d1 \
        .groupby('class_', as_index=False)[['inherit', 'signature']] \
        .agg(lambda x: x) \
        .apply(lambda x: get_tf_class_definition(x.class_, x.inherit, x.signature), axis=1) \
        .tolist()  
    for item in class_def:
        print(item, '\n\n')
        
print_tf_optimizer_config_definitions()

class TFBaseConfig(BaseConfig):
    clipnorm: Any = None
    clipvalue: Any = None
    ema_momentum: Any = 0.99
    ema_overwrite_frequency: Any = None
    global_clipnorm: Any = None
    gradient_accumulation_steps: Any = None
    learning_rate: Any = 0.001
    loss_scale_factor: Any = None
    name: Any = adadelta
    use_ema: Any = False
    weight_decay: Any = None 


class TFEps(TFBaseConfig):
    epsilon: Any = 1e-07 


class TFBeta(TFBaseConfig):
    beta_1: Any = 0.9 


class TFAdadeltaConfig(TFBaseConfig, TFEps):
    rho: Any = 0.95 


class TFAdafactorConfig(TFBaseConfig):
    beta_2_decay: Any = -0.8
    clip_threshold: Any = 1.0
    epsilon_1: Any = 1e-30
    epsilon_2: Any = 0.001
    relative_step: Any = True 


class TFAdagradConfig(TFBaseConfig, TFEps):
    initial_accumulator_value: Any = 0.1 


class TFAdamConfig(TFBaseConfig, TFBeta, TFEps):
    amsgrad: Any = False 


class TFAdamWConfig(TFBaseConfig, TFBeta, TFEps):
    amsgrad: Any = False 


class TFAdamaxConfig(

In [158]:
print_tf_optimizer_config_definitions()

class TFBaseConfig(BaseConfig):
    clipnorm: Any = None
    clipvalue: Any = None
    ema_momentum: Any = 0.99
    ema_overwrite_frequency: Any = None
    global_clipnorm: Any = None
    gradient_accumulation_steps: Any = None
    learning_rate: Any = 0.001
    loss_scale_factor: Any = None
    name: Any = adadelta
    use_ema: Any = False
    weight_decay: Any = None 


class TFEpsilonBaseConfig(TFBaseConfig):
    epsilon: Any = 1e-07 


class TFEpsilonBaseConfig(TFBaseConfig):
    epsilon: Any = 1e-07 


class TFAdafactorConfig(TFBaseConfig):
    beta_2_decay: Any = -0.8
    clip_threshold: Any = 1.0
    epsilon_1: Any = 1e-30
    epsilon_2: Any = 0.001
    relative_step: Any = True 


class TFFtrlConfig(TFBaseConfig):
    beta: Any = 0.0
    initial_accumulator_value: Any = 0.1
    l1_regularization_strength: Any = 0.0
    l2_regularization_strength: Any = 0.0
    l2_shrinkage_regularization_strength: Any = 0.0
    learning_rate_power: Any = -0.5 


class TFLionConfig(TFBaseConfig)

In [75]:
print_torch_optimizer_config_definitions()

In [139]:
def print_tf_loss_config_definitions():
    # TF
    d = get_loss_data()
    mask = d.library == 'tf'
    d = d[mask]

    # TFBaseConfig
    tf_base_args = [
        'clipnorm',
        'clipvalue',
        'ema_momentum',
        'ema_overwrite_frequency',
        'global_clipnorm',
        'gradient_accumulation_steps',
        'learning_rate',
        'loss_scale_factor',
        'name',
        'use_ema',
        'weight_decay',
    ]
    mask = d.arg.apply(lambda x: x in tf_base_args)
    d0 = d[mask]
    tf_base = get_class_definitions(d0)[0]
    tf_base = re.sub('TFAdadeltaConfig', 'TFBaseConfig', tf_base)
    print(tf_base, '\n\n')

    # TFEpsilonBaseConfig
    tf_eps_args = ['epsilon']
    mask = d.arg.apply(lambda x: x in tf_eps_args)
    d1 = d[mask]
    tf_eps = get_class_definitions(d1, 'TFBaseConfig')[0]
    tf_eps = re.sub('TFAdadeltaConfig', 'TFEpsilonBaseConfig', tf_eps)
    print(tf_eps, '\n\n')

    # TFBaseConfig subclasses
    eps_classes = d1.class_.unique().tolist()
    mask = d.class_.apply(lambda x: x in eps_classes)
    d2 = d[~mask]
    mask = d2.arg.apply(lambda x: x in tf_base_args)
    d2 = d2[~mask]
    tf_subclass = get_class_definitions(d2, 'TFBaseConfig')
    for item in tf_subclass:
        print(item, '\n\n')

    # TFEpsilonBaseConfig subclasses
    eps_classes = d1.class_.unique().tolist()
    mask = d.class_.apply(lambda x: x in eps_classes)
    d2 = d[mask]
    mask = d2.arg.apply(lambda x: x in tf_base_args or x in tf_eps_args)
    d2 = d2[~mask]
    tf_eps_subclass = get_class_definitions(d2, 'TFEpsilonBaseConfig')
    for item in tf_eps_subclass:
        print(item, '\n\n')
        
def get_comparison_data(data, mask=None):
    data = data.copy()
    if mask is not None:
        mask = data.library == mask
        data = data[mask]
    data = data.groupby('arg', as_index=False)[['library', 'class_']].agg(lambda x: x.unique())
    data['len_library'] = data.library.apply(len)
    data['len_class'] = data.class_.apply(len)
    data.sort_values(['len_class', 'len_library'], ascending=False, inplace=True)
    return data

def get_comparison_checkboxes(data, mask=None):
    data = get_comparison_data(data, mask=mask)
    output = data.class_.apply(lambda x: {k: k for k in x}).tolist()
    index = data.arg.tolist()
    output = pd.DataFrame(output, index=index).map(lambda x: '' if pd.isnull(x) else 'x')
    return output

In [135]:
data = get_optimizer_data()

In [140]:
get_comparison_checkboxes(data, 'tf')

Unnamed: 0,Adadelta,Adafactor,Adagrad,Adam,AdamW,Adamax,Ftrl,Lamb,Lion,Nadam,RMSprop,SGD
clipnorm,x,x,x,x,x,x,x,x,x,x,x,x
clipvalue,x,x,x,x,x,x,x,x,x,x,x,x
ema_momentum,x,x,x,x,x,x,x,x,x,x,x,x
ema_overwrite_frequency,x,x,x,x,x,x,x,x,x,x,x,x
global_clipnorm,x,x,x,x,x,x,x,x,x,x,x,x
gradient_accumulation_steps,x,x,x,x,x,x,x,x,x,x,x,x
learning_rate,x,x,x,x,x,x,x,x,x,x,x,x
loss_scale_factor,x,x,x,x,x,x,x,x,x,x,x,x
name,x,x,x,x,x,x,x,x,x,x,x,x
use_ema,x,x,x,x,x,x,x,x,x,x,x,x


In [118]:
get_comparison_data(q, 'torch')

Unnamed: 0,ASGD,Adadelta,Adafactor,Adagrad,Adam,AdamW,Adamax,LBFGS,NAdam,RAdam,RMSprop,Rprop,SGD,SparseAdam
lr,x,x,x,x,x,x,x,x,x,x,x,x,x,x
maximize,x,x,x,x,x,x,x,,x,x,x,x,x,x
foreach,x,x,x,x,x,x,x,,x,x,x,x,x,
differentiable,x,x,,x,x,x,x,,x,x,x,x,x,
weight_decay,x,x,x,x,x,x,x,,x,x,x,,x,
eps,,x,x,x,x,x,x,,x,x,x,,,x
capturable,x,x,,,x,x,x,,x,x,x,x,,
betas,,,,,x,x,x,,x,x,,,,x
fused,,,,x,x,x,,,,,,,x,
alpha,x,,,,,,,,,,x,,,


In [161]:
from typing import Optional

import pydantic as pyd

OptBool = Optional[bool]
# ------------------------------------------------------------------------------


class BaseConfig(pyd.BaseModel):
    name: str


class TFBaseConfig(BaseConfig):
    clipnorm: OptBool = None
    clipvalue: OptBool = None
    ema_momentum: float = 0.99
    ema_overwrite_frequency: OptBool = None
    global_clipnorm: OptBool = None
    gradient_accumulation_steps: OptBool = None
    learning_rate: float = 0.001
    loss_scale_factor: OptBool = None
    use_ema: bool = False
    weight_decay: OptBool = None


class TFEps(pyd.BaseModel):
    epsilon: float = 1e-07


class TFBeta(pyd.BaseModel):
    beta_1: float = 0.9
    beta_2: float = 0.99


class TFAdafactorConfig(TFBaseConfig):
    beta_2_decay: float = -0.8
    clip_threshold: float = 1.0
    epsilon_1: float = 1e-30
    epsilon_2: float = 0.001
    relative_step: bool = True


class TFFtrlConfig(TFBaseConfig):
    beta: float = 0.0
    initial_accumulator_value: float = 0.1
    l1_regularization_strength: float = 0.0
    l2_regularization_strength: float = 0.0
    l2_shrinkage_regularization_strength: float = 0.0
    learning_rate_power: float = -0.5


class TFLionConfig(TFBaseConfig, TFBeta):
    pass


class TFSGDConfig(TFBaseConfig):
    momentum: float = 0.0
    nesterov: bool = False


class TFAdadeltaConfig(TFBaseConfig, TFEps):
    rho: float = 0.95


class TFAdagradConfig(TFBaseConfig, TFEps):
    initial_accumulator_value: float = 0.1


class TFAdamConfig(TFBaseConfig, TFBeta, TFEps):
    amsgrad: bool = False


class TFAdamWConfig(TFBaseConfig, TFBeta, TFEps):
    amsgrad: bool = False


class TFAdamaxConfig(TFBaseConfig, TFBeta, TFEps):
    pass


class TFLambConfig(TFBaseConfig, TFBeta, TFEps):
    pass


class TFNadamConfig(TFBaseConfig, TFBeta, TFEps):
    pass


class TFRMSpropConfig(TFBaseConfig, TFEps):
    centered: bool = False
    momentum: float = 0.0
    rho: float = 0.9


In [165]:
TFAdamWConfig(name='your-mom').dict()

/tmp/ipykernel_1769482/357340166.py:1: PydanticDeprecatedSince20: The `dict` method is deprecated; use `model_dump` instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.10/migration/
  TFAdamWConfig(name='your-mom').dict()


{'epsilon': 1e-07,
 'beta_1': 0.9,
 'beta_2': 0.99,
 'name': 'your-mom',
 'clipnorm': None,
 'clipvalue': None,
 'ema_momentum': 0.99,
 'ema_overwrite_frequency': None,
 'global_clipnorm': None,
 'gradient_accumulation_steps': None,
 'learning_rate': 0.001,
 'loss_scale_factor': None,
 'use_ema': False,
 'weight_decay': None,
 'amsgrad': False}