In [None]:
from transformers import AutoModel

In [31]:
from forgebox.imports import *

In [71]:
from typing import Dict, Any

In [None]:
model = AutoModel.from_pretrained("bert-base-uncased")

In [5]:
names = dict((k,v) for k,v in model.named_parameters())

In [451]:
class ParamWizard:
    """
    A powerful handler that can manage discriminative configurations
    """

    def __init__(self, model):
        self.model = model
        self.param_dict = dict(model.named_parameters())
        self.param_level = dict((k, len(list(k.split('.'))))
                                for k in self.param_dict.keys())
        self.create_names_df()

    def create_names_df(self,):
        self.names = list(self.param_dict.keys())
        self.names_df = pd.DataFrame(
            dict(
                named=self.param_level.keys(),
                level=self.param_level.values(),
            )
        )

    def find_prefix(self, prefix: str):
        """
        Find all the parameters with a given prefix
        """
        return list(self.names_df[self.names_df.named.str.startswith(prefix)]['named'])

    def find_next_level(self, prefix: str):
        """
        Find all the sub level with given prefix
        """
        prefix_level = len(prefix.split('.')) if prefix != "" else 0
        df = self.names_df[self.names_df.named.str.startswith(prefix)]
        if len(df)<=1:
            return None
        sub_col = df['named'].apply(
            lambda x: x.split('.')[prefix_level])
        col_ct = sub_col.value_counts()
        return col_ct
    
    def find_kw(self, prefix: str):
        df = self.names_df[self.names_df.named.str.startswith(prefix)]
        
        kw_df = pd.DataFrame(
            df['named'].apply(lambda x:x.split('.')).explode().value_counts()
        ).reset_index()
        
        kw_df = kw_df.rename(columns = {"index":"kw"})
        if prefix=="":
            return kw_df
        kw_df = kw_df[~kw_df.kw.isin(prefix.split('.'))].reset_index(drop=True)
        return kw_df
        

    def __len__(self):
        return len(self.param_dict)

    def __getitem__(self, key):
        return self.param_dict[key]

    def create_conf_dict(self, conf):
        if "lr" in conf:
            conf['lr'] = float(conf['lr'])
        if "freeze" in conf:
            del conf['freeze']
        return conf

    def grouping(self, prefix_dict: Dict[str, Dict[str, Any]]):
        """
        Create ```param_groups```
            that can be used for initializing the optimizer
        """
        prefix_list = []
        kw_list = []

        for key in prefix_dict.keys():
            prefix, kw = key.split("|")
        
            prefix_list.append(prefix)
            kw_list.append(kw)

        group_df = pd.DataFrame(
            {"prefix": prefix_list, "kw":kw_list, "conf": list(prefix_dict.values())})
        group_df['level'] = group_df.prefix.apply(lambda x: len(x.split('.')))
        group_df = group_df.sort_values(
            by=['level', "kw", 'prefix'], ascending=True
        ).reset_index(drop=True)

        group_df['sn'] = list(range(len(group_df)))

        mapper = self.names_df.copy()
        mapper['group'] = "*"

        for idx, group in group_df.iterrows():
            keys = self.find_prefix(group.prefix)
            # under the prefix filter
            filtering = mapper.named.isin(keys)
            # under the contains keyword filter
            if group.kw!="[ALL]":
                filtering*=mapper.named.apply(lambda x:group.kw in x.split("."))
            mapper['group'].loc[filtering] = group.sn
        return dict(group_df = group_df, mapper=mapper)
        # create param_groups
        
    def create_param_groups(self, prefix_dict: Dict[str, Dict[str, Any]]):
        dfs = self.grouping(prefix_dict)
        group_df = dfs["group_df"]
        mapper = dfs["mapper"]

        param_groups = []
        # iter over groups
        for idx, group in group_df.iterrows():
            params = []
            if group.conf.get("freeze"):
                freeze=True
            else:
                freeze=False
            kw = group.kw
            conf = self.create_conf_dict(group.conf)
            sub_mapper = mapper[mapper.group == group.sn]
            if len(sub_mapper)==0:
                continue
            for sub_idx, sub in sub_mapper.iterrows():
                parameter = self.param_dict[sub.named]
                if freeze:
                    parameter.requires_grad = False
                else:
                    parameter.requires_grad = True
                    params.append(parameter)
            if len(params)>0:
                param_groups.append(dict(
                    params = params,
                    **conf
                ))
        default_mapper = mapper[mapper.group == "*"]
        # iter over the remaining
        if len(default_mapper)>0:
            params = []
            for sub_idx, sub in default_mapper.iterrows():
                params.append(self.param_dict[sub.named])
            param_groups.append(dict(params=params))
                
        return param_groups


In [452]:
wizard = ParamWizard(model)

In [369]:
wizard.find_next_level("embeddings")

LayerNorm                2
position_embeddings      1
token_type_embeddings    1
word_embeddings          1
Name: named, dtype: int64

In [370]:
wizard.find_kw("")

Unnamed: 0,kw,named
0,encoder,192
1,layer,192
2,attention,120
3,weight,101
4,bias,98
5,output,96
6,dense,74
7,self,72
8,LayerNorm,50
9,value,24


In [388]:
from tai_chi_tuna.front.typer import LIST, INT, FLOAT
from tai_chi_tuna.front.structure import EditableDict
from tai_chi_tuna.front.widget import InteractiveAnnotations
from tai_chi_tuna.config import PhaseConfig
from ipywidgets import Button, Output, HTML, Dropdown, HBox, VBox

In [414]:
def optimizer_group_conf(
    freeze: BOOL(default=False), 
    lr: LIST(options = list(f"1e-{i}" for i in range(1,8)), default="1e-3")="1e-3",
    weight_decay: FLOAT(min_=0.,max_=.3, default=.0, step=.01) = 0.,
):
    return dict(lr=lr, weight_decay=weight_decay)


def combine_prefix(prefix: str, sub: str) -> str:
    if prefix != "":
        return f"{prefix}.{sub}"
    else:
        return sub


def set_opt_confs(wizard: ParamWizard, phase: PhaseConfig):
    editable = EditableDict()
    
    @editable.on_update
    def set_phase(kwargs):
        phase['param_groups'] = kwargs

    def create_conf_ia_callback(prefix, sub, kw_drop):
        def set_conf(conf):
            editable[f"{combine_prefix(prefix, sub)}|{kw_drop.value}"] = conf
        return set_conf

    def deeper_click(prefix, output):
        def wr():
            output.clear_output()
            with output:
                return find_next(prefix)
        return wr

    def btn_conf_widget(prefix, sub, output, kw_drop):
        def btn_conf_widget_click():
            output.clear_output()
            with output:
                title = HTML(
                        f"""<h4>
                        Set Learning for
                        <strong class='text-danger'>
                        "{combine_prefix(prefix, sub)}"</strong></h4>""")
                conf_banner = HBox([
                    title, kw_drop
                ])
                display(
                    conf_banner
                    )
                ia = InteractiveAnnotations(
                    optimizer_group_conf,
                    icon="terminal", description="Yes"
                )

                ia.register_callback(
                    create_conf_ia_callback(prefix, sub, kw_drop)
                )
                display(ia.vbox)
        return btn_conf_widget_click
    
    def control_layer(prefix, sub, ct):
        sub_hbox_list = []
        output = None
        if ct >= 2:
            deeper = Button(button_style="success",
                            icon="plus", description="Deeper")
            # go deeper recursively
            output = Output(
                layout={
                    "border":"2px dashed #7780FF",
                }
            )
            deeper.click = deeper_click(combine_prefix(prefix, sub), output)
            sub_hbox_list.append(deeper)

        kw_df = wizard.find_kw(prefix=combine_prefix(prefix, sub))
        kw_drop = Dropdown(description="Keyword",
                            options=["[ALL]", ]+list(kw_df.kw), value="[ALL]")
        to_conf = Button(button_style="warning",
                            icon="cog", description=f"nn:{sub}")
        output2 = Output()

        to_conf.click = btn_conf_widget(prefix, sub, output2, kw_drop)
            
        sub_hbox_list.append(to_conf)

        sub_hbox = HBox(sub_hbox_list)
        display(sub_hbox)
        if output is not None:
            display(output)
        display(output2)
    
    def find_next(prefix=""):
        display(HTML(f"""<h4>
        Prefix:<strong class='text-primary'>"{prefix}"</strong></h4>"""))
        sub_ct = wizard.find_next_level(prefix)

        if sub_ct is None:
            return

        display(HTML(" - ".join(
            list(
                f"{sub} (x{ct})" for sub, ct in zip(sub_ct.index, sub_ct.values))
        )))
        map_dict = dict((f"{sub} ({ct})", (sub, ct))
                        for sub, ct in zip(sub_ct.index, sub_ct.values))

        @interact
        def select_sub(submodule=map_dict):
            sub, ct = submodule
            control_layer(prefix, sub, ct)
    
    btn_hide = Button(description = "hide", button_style="info", icon="folder")
    btn_config_learning = Button(description = "setting", 
                                 button_style="warning", icon="folder-open")
    controls = HBox([HTML("<h5>Optimizer Details (Optional):</h5>"), btn_hide, btn_config_learning])
    display(controls)
    display(editable)
    over_out = Output(layout={
                    "border":"2px dashed #7780FF",
                })
    
    def start_set_conf():
        over_out.clear_output()
        with over_out:
            find_next("")
            control_layer("", "", 0)
    
    def end_set_conf():
        over_out.clear_output()
    
    display(over_out)
    btn_config_learning.click = start_set_conf
    btn_hide.click = end_set_conf

In [411]:
example_phase = PhaseConfig()

In [None]:
wizard = ParamWizard(model)
set_opt_confs(wizard, example_phase)

In [447]:
example_phase

PhaseConfig:{
  "param_groups": {
    "encoder|bias": {
      "lr": 0.0001,
      "weight_decay": 0.0
    },
    "encoder|[ALL]": {
      "lr": 1e-05,
      "weight_decay": 0.05
    },
    "pooler.dense.weight|[ALL]": {
      "lr": 0.01,
      "weight_decay": 0.0
    }
  }
}

In [None]:
wizard(example_phase['param_groups'])