In [350]:
import inspect
import re

import numpy as np

from tensorflow import keras  # noqa: F401
from keras import optimizers as tfoptim
from keras import losses as tfloss
from keras import metrics as tfmetric

import torch
import torch.optim as torchoptim
import torch.nn.modules.loss as torchloss
import torchmetrics as torchmetric

import pandas as pd
pd.set_option('display.max_colwidth', 500)

import flatiron.core.tools as fict

In [410]:
def get_classes(module):
    members = inspect.getmembers(module)
    members = list(filter(lambda x: inspect.isclass(x[1]), members))
    members = list(filter(lambda x: not x[0].startswith('_'), members))
    classes = dict(members)
    return classes

def create_signature(arg, annotation, default):
    if annotation == 'UNTYPED':
        annotation = 'Any'
    if default == 'REQUIRED':
        default = ''
    else:
        default = f' = {default}'
    return f'{arg}: {annotation}{default}'

def get_init_signature_data(class_, remove=['self']):
    sig = inspect.getfullargspec(class_)
    args = sig.args
    for item in remove:
        args.remove(item)

    if sig.defaults is not None:
        d = len(args) - len(sig.defaults)
        req = args[:d]
        opt = args[d:]
        args = {k: 'REQUIRED' for k in req}
        opt = dict(zip(opt, sig.defaults))
        args.update(opt)
    else:
        args = {k: 'REQUIRED' for k in args}
    
    if isinstance(sig.kwonlydefaults, dict):
        args.update(sig.kwonlydefaults)
    
    anno = sig.annotations
    for key, val in args.items():
        if key in anno:
            args[key] = (val, anno[key].__name__)
        else:
            args[key] = (val, 'UNTYPED')
            
    data = []
    for arg, (default, type_) in args.items():
        data.append(dict(
            arg=arg,
            default=default,
            type_=type_,
            signature=create_signature(arg, type_, default),
        ))
    return data

def get_module_class_data(module):
    classes = get_classes(module)
    data = []
    for name, item in classes.items():
        try:
            datum = get_init_signature_data(item)
        except:
            continue
        for row in datum:
            row['class_'] = name
        data.extend(datum)
        
    cols = ['class_', 'arg', 'type_', 'default', 'signature']
    data = pd.DataFrame(data, columns=cols)
    data['library'] = module.__name__.split('.')[0]
    data['module'] = module.__name__
    cols.insert(0, 'library')
    cols.insert(1, 'module')
    data = data[cols]
    
    return data

def _get_data(tf_module, torch_module):
    tf_data = get_module_class_data(tf_module)
    torch_data = get_module_class_data(torch_module)
    data = pd.concat([tf_data, torch_data], axis=0)

    mask = data.library == 'keras'
    data.loc[mask, 'library'] = 'tf'
    
    mask = data.library == 'torchmetrics'
    data.loc[mask, 'library'] = 'torch'

    return data

def get_optimizer_data():
    data = _get_data(tfoptim, torchoptim)

    data['field'] = data['class_']
    mask = data.field == 'Nadam'
    data.loc[mask, 'field'] = 'NAdam'
    
    mask = data.class_.apply(lambda x: x not in ['Optimizer', 'LossScaleOptimizer'])
    data = data[mask]
    
    mask = data.arg != 'params'
    data = data[mask]

    data.reset_index(drop=True, inplace=True)

    return data

def get_loss_data():
    data = _get_data(tfoptim, torchoptim)    
    mask = data.class_.apply(lambda x: x not in ['deprecated'])
    data = data[mask]
    data.reset_index(drop=True, inplace=True)
    return data

def get_metric_data():
    data = _get_data(tfmetric, torchmetric)    
    mask = data.class_.apply(lambda x: x not in ['deprecated'])
    data = data[mask]
    data.reset_index(drop=True, inplace=True)
    return data

def get_class_definitions(data, base_class='BaseConfig'):
    data = data.copy()
    data['config_name'] = data \
        .apply(lambda x: f'class {x.library.capitalize()}{x.class_}Config({base_class}):', axis=1) \
        .apply(lambda x: re.sub(' Tf', ' TF', x))
    class_def = data \
        .groupby('config_name', as_index=False) \
        .signature.agg(lambda x: '    ' + '\n    '.join(sorted(x)))
    class_def = class_def \
        .apply(lambda x: f'{x.config_name}\n{x.signature}', axis=1) \
        .apply(lambda x: re.sub(' +$', '', x))
    return class_def.tolist()

def get_comparison_data(data, mask=None):
    data = data.copy()
    if mask is not None:
        mask = data.library == mask
        data = data[mask]
    data = data.groupby('arg', as_index=False)[['library', 'class_']].agg(lambda x: x.unique())
    data['len_library'] = data.library.apply(len)
    data['len_class'] = data.class_.apply(len)
    data.sort_values(['len_class', 'len_library'], ascending=False, inplace=True)
    return data

def get_comparison_checkboxes(data, mask=None):
    data = get_comparison_data(data, mask=mask)
    output = data.class_.apply(lambda x: {k: k for k in x}).tolist()
    index = data.arg.tolist()
    output = pd.DataFrame(output, index=index).map(lambda x: '' if pd.isnull(x) else 'x')
    return output

def get_base_class_text(data, arg, class_, base_class, class_re):
    mask = data.arg.apply(lambda x: x == arg)
    temp = data[mask]
    if len(temp) == 0:
        raise ValueError(f'{arg} arg sux')
    result = get_class_definitions(temp, base_class)[0]
    result = re.sub(class_re, f'class {class_}', result)
    return result

def get_subclass_text(aux, library, class_, inherit, signature, descriptor):
    caplib = library.capitalize()
    inherit = ', '.join(sorted(filter(lambda x: x != '', inherit)))
    output = f'class {caplib}{descriptor}{class_}Config({caplib}BaseConfig, {inherit}):\n    '
    output = re.sub(r', \)', ')', output)
    regex = '|'.join(aux.keys())
    regex = f'({regex}):'
    if isinstance(signature, str):
        signature = [signature]
    signature = list(filter(lambda x: not re.search(regex, x), signature))
    if signature == []:
        output += 'pass'
    else:
        output += '\n    '.join(sorted(signature))
    output = re.sub(' +$', '', output)
    return output

def print_config_definitions(data, aux, library, descriptor):
    caplib = library.capitalize()
    class_re = f'class {caplib}[a-zA-Z]*Config'

    mask = data.library == library
    data = data[mask]

    # base class
    base = get_class_definitions(data)[0].split('\n')[0]
    base = re.sub(class_re, f'class {caplib}BaseConfig', base)
    base += '\n    name: str'
    print(f'# {library.upper()}' + '-' * 70)
    print(base, '\n\n')

    # helper classes
    print('# HELPERS' + '-' * 70)
    data['inherit'] = ''
    for arg, cls_ in aux.items():
        text = get_base_class_text(data, arg, cls_, 'pyd.BaseModel', class_re)
        print(text, '\n\n')
        mask = data.arg == arg
        data.loc[mask, 'inherit'] = cls_

    # subclasses
    print('# ' + '-' * 78)
    class_def = data \
        .sort_values('class_') \
        .groupby('class_', as_index=False)[['inherit', 'signature']] \
        .agg(lambda x: x) \
        .apply(lambda x: get_subclass_text(
            aux, library, x.class_, x.inherit, x.signature, descriptor), axis=1
        ).tolist()
    for item in class_def:
        print(item, '\n\n')

In [411]:
def print_tf_optimizer_config_definitions():
    # TF
    d = get_optimizer_data()
    mask = d.library == 'tf'
    d = d[mask]

    # TFBaseConfig
    tf_base_args = [
        'clipnorm',
        'clipvalue',
        'ema_momentum',
        'ema_overwrite_frequency',
        'global_clipnorm',
        'gradient_accumulation_steps',
        'learning_rate',
        'loss_scale_factor',
        'name',
        'use_ema',
        'weight_decay',
    ]
    mask = d.arg.apply(lambda x: x in tf_base_args)
    d0 = d[mask]
    tf_base = get_class_definitions(d0)[0]
    tf_base = re.sub('TFAdadeltaConfig', 'TFBaseConfig', tf_base)
    print(tf_base, '\n\n')

    # TFEpsilonBaseConfig
    tf_eps_args = ['epsilon']
    mask = d.arg.apply(lambda x: x in tf_eps_args)
    d1 = d[mask]
    tf_eps = get_class_definitions(d1, 'TFBaseConfig')[0]
    tf_eps = re.sub('class TF[a-zA-Z]*Config', 'class TFEpsilonBaseConfig', tf_eps)
    print(tf_eps, '\n\n')
    
    # TFEpsilonBaseConfig
    tf_eps_args = ['epsilon']
    mask = d.arg.apply(lambda x: x in tf_eps_args)
    d1 = d[mask]
    tf_eps = get_class_definitions(d1, 'TFBaseConfig')[0]
    tf_eps = re.sub('class TF[a-zA-Z]*Config', 'class TFEpsilonBaseConfig', tf_eps)
    print(tf_eps, '\n\n')

    # TFBaseConfig subclasses
    eps_classes = d1.class_.unique().tolist()
    mask = d.class_.apply(lambda x: x in eps_classes)
    d2 = d[~mask]
    mask = d2.arg.apply(lambda x: x in tf_base_args)
    d2 = d2[~mask]
    tf_subclass = get_class_definitions(d2, 'TFBaseConfig')
    for item in tf_subclass:
        print(item, '\n\n')

    # TFEpsilonBaseConfig subclasses
    eps_classes = d1.class_.unique().tolist()
    mask = d.class_.apply(lambda x: x in eps_classes)
    d2 = d[mask]
    mask = d2.arg.apply(lambda x: x in tf_base_args or x in tf_eps_args)
    d2 = d2[~mask]
    tf_eps_subclass = get_class_definitions(d2, 'TFEpsilonBaseConfig')
    for item in tf_eps_subclass:
        print(item, '\n\n')

In [412]:
def print_torch_loss():
    data = get_loss_data()
    aux = dict(
        reduction='TReduct',
        reduce='TRed',
        size_average='TSize',
        margin='TMarg',
        weight='TWeight',
        eps='TEps',
    )
    print_config_definitions(data, aux, 'torch', 'Loss')

def print_torch_optimizer():
    data = get_optimizer_data()
    aux = dict(
        lr='TLR',
        maximize='TMax',
        foreach='TFor',
        differentiable='TDiff',
        eps='TEps',
        capturable='TCap',
        weight_decay='TDecay',
    )
    print_config_definitions(data, aux, 'torch', 'Opt')
    
def print_torch_metric():
    data = get_metric_data()
    aux = dict(
        ignore_index='TInd',
        nan_strategy='TNan',
        empty_target_action='TAct',
        num_outputs='TOut',
        reduction='TReduct',
        top_k='TTopK',
        num_classes='TCls',
    )
    print_config_definitions(data, aux, 'torch', 'Metric')

In [413]:
# print_torch_optimizer()
# print_torch_loss()
print_torch_metric()

# TORCH----------------------------------------------------------------------
class TorchBaseConfig(BaseConfig):
    name: str 


# HELPERS----------------------------------------------------------------------
class TInd(pyd.BaseModel):
    ignore_index: Optional = None 


class TNan(pyd.BaseModel):
    nan_strategy: Union = warn 


class TAct(pyd.BaseModel):
    empty_target_action: str = pos 


class TOut(pyd.BaseModel):
    num_outputs: int = 1 


class TReduct(pyd.BaseModel):
    reduction: Literal = sum 


class TTopK(pyd.BaseModel):
    top_k: Optional = None 


class TCls(pyd.BaseModel):
    num_classes: int 


# ------------------------------------------------------------------------------
class TorchMetricBLEUScoreConfig(TorchBaseConfig):
    n_gram: int = 4
    smooth: bool = False
    weights: Optional = None 


class TorchMetricBootStrapperConfig(TorchBaseConfig):
    base_metric: Metric
    mean: bool = True
    num_bootstraps: int = 10
    quantile: Union = None
    raw: 

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  data['inherit'] = ''


In [388]:
data = get_metric_data()
get_comparison_checkboxes(data, 'torch').T.map(lambda x: 1 if x == 'x' else 0).sum().head(40)

ignore_index                    12
nan_strategy                    11
empty_target_action             10
num_outputs                     10
reduction                       10
top_k                            7
num_classes                      5
nan_replace_value                4
adaptive_k                       3
base_metric                      3
data_range                       3
kernel_size                      3
lowercase                        3
postfix                          3
prefix                           3
return_sentence_level_score      3
sigma                            3
zero_mean                        3
allow_unknown_preds_category     2
bias_correction                  2
gaussian_kernel                  2
k1                               2
k2                               2
max_k                            2
metric                           2
mode                             2
multioutput                      2
n_gram                           2
normalize           