# Tai-Chi core engine
> What library should we be?
This is more of a branding proposal, aimed at people from various backgrounds who are going to play with AI.
* That means we're going to talk about how people view this library and what they think of when they see ```pip install -Uqq unpackai```, the same way that if I've had dandruff recently, my mind jumps straight to Head & Shoulders.
* For ML, the **jump** currently goes roughly as follows. This isn't thorough market research, just quick examples from a deep learning practitioner:
    * Try free-form structures quickly, run experiments: PyTorch
    * Go to production, run models on edge devices: TensorFlow
    * Play with GPU-accelerated tensor calculation: JAX
    * Play with TF, but at a simpler layer level: Keras
    * Transformers in clean code: Hugging Face
    * Visualize things with interactive features: Plotly
    * Deploy a model prototype: Streamlit
* Surely you think I failed to mention ```fastai```. This is where the **branding goes wrong**: the fastai library is tightly bound to the education. It's considered a good creation alongside its famous course, but it comes second to the education. As a product it has many limitations: the docs are too brief, it doesn't support multi-device training, and very few of its callbacks go beyond Jeremy H's own teaching.
* Most important of all, ```fastai``` isn't enjoyable to use: **it just packs together many things mentioned in the course**.

## What we shouldn't be
The course was life changing for me and I feel very grateful. But let's not be their library.

### The pipeline wrapping plan
It all started from a notebook, much like the template notebooks we have for the course: a notebook that handles the data processing, model building and interpretation for a specific DL task.

Then came the packaging part: we wrap the **dozens of lines of code** that scare our kind students into simple functions or classes.

The wrapped functions are simple to use and to look at, and most of them execute in 1 line. So friendly to our innocent students.

This is what a Python library is about, right? Wrap things into functions, which can be further wrapped into even fewer lines.

There's nothing wrong with this approach at first. Some DL tasks, if need be, can be shrunk to **fewer than 10 lines of code**:
* the 1st line loads the data,
* the 2nd line sets how to transform the data,
* the 3rd line builds/loads the model,
* the 4th line trains the model,
* the 5th line interprets the model in various ways.

The above does look like a decent **structure** to start with. We then lay out the tasks, different contributors take different tasks that can be developed in parallel, and we can have the agile/scrum/kanban fun of tracking our progress!

Even if we do this, we could build a useful product, no less.

#### The bad side of the pipeline wrapping plan
So many libraries are doing the same, even ones from awesome people. They usually end up as one of the following:
* A mess of functions. Many of them are good functions, but it's still a mess, and it ends up a branding disaster. (**There is no way to answer, in a slogan: what can your library do?**)
* A model zoo for a specific domain.
* Wrapping things up means less and less involvement from the user. The user will spend very little time playing with the functions, and each function usually achieves a very specific task. I actually believe there is an equilibrium like:
$\large{\text{UserPlayHours} = a \times \text{TaskTransferability}}$

## Alternative approach

The salvation plan is actually simpler in terms of how we perceive the library:
* A library that allows you to experiment with AI/DL for various tasks

**BUT!!!**
* Many modules within the pipeline should be **choosable** through a dropdown list or checkbox.
* The **level of detail** we let users play with and choose is the **level of difficulty** we want them to enjoy.

### What is level of detail?
Level of detail is the level of fuss we want the user to focus on. This is exactly the part the fastai library got **WRONG**, and it explains most of our struggle so far:
* It offers smooth, easy pipelines, even for newbies and business people.
* Any amount of reconfiguration is usually far too complicated for that audience.
* There is a **GAP** between the 2 points above, hence no room for playing.

#### Keras Example
I started my AI journey with Keras, and I loved Keras at that time, because:
* Keras plays with **layers** (e.g. Linear, Convolution). Its greatest strength is abstracting away the details beneath this level and letting users play with layers.
* I spent lots of time having fun playing with layers.
* Apart from the cases where I had to redesign a layer, I could implement almost any kind of model mentioned in any DL paper ($\text{UserPlayHours} = a \times \text{TaskTransferability}$)

#### PyTorch Lightning example
Then I moved on to the career team, where I have to deal with the layer level as well as with different data/forward pipelines. PyTorch Lightning (PL) is a good library because:
* It lets me play with the things I mentioned, but saves my energy on things like looping, logging and multi-device training details.
* If you look at a training notebook built with PL, you'll see very few lines of training boilerplate.
* You'll find a lot of lines on the specifics you intend to make different.

> The branding image of these examples is simple:
* Keras: play with TensorFlow through the concept of layers
* PyTorch Lightning: write less boilerplate code

#### Unpackai Example
For our lib, I intend for users to focus on exactly the same range of things we want people to learn:
* choose the columns they intend to use, and in what way
* choose the data transformations
* choose the loss, the model structure to use (not keras.layer, not nn.module)
* hit run

## Demo of such an example

Versions:
```json
{
  "torch": "1.7.1",
  "pytorch-lightning": "1.3.8",
  "unpackai": "0.1.8.10",
  "forgebox": "0.4.18.5"
}
```

In [1]:
from ipywidgets import interact, interact_manual
from forgebox.imports import *
from forgebox.category import Category
from forgebox.html import DOM, list_group, list_group_kv
from tqdm.notebook import tqdm

# To make development easier, I'm using pytorch-lightning here.
# This is a questionable, tough and revocable decision.
import pytorch_lightning as pl

from typing import List, Dict, Callable, Any, Tuple
from torchvision import transforms as tfm
from PIL import Image
from ipywidgets import (
    VBox, HBox, HTML, Layout, Button, Output,
    Text, Textarea, IntSlider, FloatSlider, SelectMultiple, Dropdown, Checkbox
)
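# Explicit imports for names used below, in case the wildcard
# import from forgebox.imports does not provide them
import json
from pathlib import Path
import pandas as pd
from IPython.display import display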
from forgebox.thunder.callbacks import DataFrameMetricsCallback

Let's skip the data download here. I mean, it's just a download; we're not going to reinvent the brilliant stuff that already exists around downloading.

### Step 1: Everything starts with a DataFrame

For fastai, everything starts from a list, an **ItemList** to be specific. **ImageList** and **TextList** are [**ItemList**](https://fastai1.fast.ai/tutorial.itemlist.html)s with some slightly enhanced features: ```[🧂, 🏓, 🍷, 🐻]```

For clarity in education, or for simplicity as the ultimate form of beauty, we use a [**DataFrame**](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.html) as the starting point: an ItemList in table format. This way, every dataset has the same starting point, even tabular data.
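
As a minimal sketch of this idea, assuming a folder-per-class image layout (the file paths below are made up for illustration), a list of items becomes a one-column DataFrame, and every later step only adds columns:

```python
from pathlib import Path

import pandas as pd

# Hypothetical file paths; in practice these would come from a directory scan
paths = [
    "data/cat/001.jpg",
    "data/cat/002.jpg",
    "data/dog/001.jpg",
]

# The "ItemList in table format": one row per item, one column to start with
df = pd.DataFrame({"path": paths})

# Later steps only add columns, e.g. a label taken from the parent folder name
df["label"] = df["path"].apply(lambda p: Path(p).parent.name)

print(df)
```

Tabular data fits the same shape without any conversion, which is the reason for choosing a table over a list here.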

## Helpers

### Phase/ Configuration

In [2]:
class Phase:
    """
    A configuration management mechanism
    """
    is_phase = True
    def __init__(self, **kwargs):
        self.config = dict()
        self.config.update(kwargs)
        
    def __setitem__(self, k, v):
        self.config[k] = v
    
    def __getitem__(self, k):
        return self.config[k]
    
    def __contains__(self, k):
        return k in self.config
    
    def __call__(self):
        return self.get_data(self.config)
    
    def get_data(self, raw):
        """
        Reconstruct back to dict or list or value format
        """
        if hasattr(raw,"is_phase"):
            return raw.get_data(raw.config)
        if type(raw) == list:
            raw = list(self.get_data(i) for i in raw)
            return raw
        if type(raw) == dict:
            for k, v in raw.items():
                raw[k] = self.get_data(v)
            return raw
        return raw

    def __str__(self):
        return json.dumps(self(), indent=2)
    
    def __repr__(self,):
        return f"Phase:{self}"
    
def save_phase():
    global phase
    global PROJECT
    PROJECT = Path(PROJECT)
    PROJECT.mkdir(exist_ok=True, parents=True)
    with open(PROJECT/"phase.json", "w") as f:
        f.write(str(phase))
        
def load_phase(new: bool=False):
    """
    Load the phase from PROJECT/phase.json into the global `phase`
    new: if True, ignore any saved file and start from an empty Phase
    """
    global phase
    global PROJECT
    PROJECT = Path(PROJECT)
    PROJECT.mkdir(exist_ok=True, parents=True)
    if (not new) and (PROJECT/"phase.json").exists():
        with open(PROJECT/"phase.json", "r") as f:
            phase = Phase(**json.loads(f.read()))
            print(phase)
    else:
        phase = Phase()
    return phase
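
The following is a small sketch of how such a configuration object can nest and round-trip through JSON. `MiniPhase` is a condensed, hypothetical stand-in written only so the example runs on its own; the keys and values are made up.

```python
import json


class MiniPhase:
    """Condensed, hypothetical stand-in for the Phase class above."""
    is_phase = True

    def __init__(self, **kwargs):
        self.config = dict(kwargs)

    def __setitem__(self, k, v):
        self.config[k] = v

    def __call__(self):
        return self.get_data(self.config)

    def get_data(self, raw):
        # Nested phases are unwrapped recursively into plain dict/list/values
        if hasattr(raw, "is_phase"):
            return raw.get_data(raw.config)
        if isinstance(raw, list):
            return [self.get_data(i) for i in raw]
        if isinstance(raw, dict):
            return {k: self.get_data(v) for k, v in raw.items()}
        return raw


phase = MiniPhase(task="image_classification")
phase["enrich"] = [MiniPhase(src="path", dst="image", enrich="EnrichImage")]

# Calling the phase flattens everything into JSON-serializable data
as_json = json.dumps(phase(), indent=2)

# Round trip: the JSON text can rebuild an equivalent phase
restored = MiniPhase(**json.loads(as_json))
print(as_json)
```

Because the flattened form is plain JSON, a whole experiment setup can be saved, shared and reloaded as one small file.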

### Widget Helpers

#### flash

In [3]:
class Flash:
    @staticmethod
    def create_msg_box(color, text, key:str = None):
        text = str(text)
        if key is not None:
            key = f"<strong>{key}</strong> "
        else:
            key = ""
        text_bar = HTML(f"""<div class='alert alert-{color}' role='alert'>
        {key} {text}</div>""", layout=Layout(width='95%'))
        close_btn = Button(description="x", layout=Layout(width='3%'))
        
        total = HBox([text_bar, close_btn])
        def close_bar():
            total.close()
        close_btn.click = close_bar
        return total
    
    @classmethod
    def get_info(cls, text, key:str = None):
        return cls.create_msg_box('info', text, key)
    
    @classmethod
    def get_warning(cls, text, key:str = None):
        return cls.create_msg_box('warning', text, key)
    
    @classmethod
    def get_danger(cls, text, key:str = None):
        return cls.create_msg_box('danger', text, key)
    
    @classmethod
    def get_success(cls, text, key:str = None):
        return cls.create_msg_box('success', text, key)
    
    @classmethod
    def info(cls, text, key:str = None):
        display(cls.get_info(text, key))
    
    @classmethod
    def warning(cls, text, key:str = None):
        display(cls.get_warning(text, key))
    
    @classmethod
    def danger(cls, text, key:str = None):
        display(cls.get_danger(text, key))
    
    @classmethod
    def success(cls, text, key:str = None):
        display(cls.get_success(text, key))

#### Editable List & Dict
An editable list and dict within a Jupyter notebook

In [4]:
total_width = Layout(width="100%")


class EditableList(VBox):
    """
    Interactive list
    You can add items to the list
    Each added item has a remove button to remove that item
    """

    def __init__(self, data_list: List[Any] = [], pretty_json: bool = True):
        super().__init__([], layout=total_width)
        self.pretty_json = pretty_json
        for data in data_list:
            self+data

    def create_line(self, data):
        children = list(self.children)
        children.append(self.new_line(data))
        self.children = children

    def data_to_dom(self, data):
        if self.pretty_json:
            pretty = list_group_kv(data) if hasattr(
                data, "keys") else list_group(data)
            return HTML(str(pretty), layout=total_width)
        else:
            return HTML(json.dumps(data))

    def new_line(self, data) -> HBox:
        del_btn = Button(description="Remove", icon="trash")
        del_btn.button_style = 'danger'
        hbox = HBox([del_btn, self.data_to_dom(data)],
                    layout=total_width, box_style='info')
        hbox.data = data

        def remove_hbox():
            children = list(self.children)
            for i, c in enumerate(children):
                if id(c) == id(hbox):
                    children.remove(c)
            self.children = children
        del_btn.click = remove_hbox
        return hbox

    def __add__(self, data):
        self.create_line(data)
        return self

    def get_data(self) -> List[Any]:
        """
        Return the data of this list
        """
        return list(x.data for x in self.children)


class EditableDict(VBox):
    """
    Interactive dictionary
    You can add items to the dictionary
    Each added item has a remove button to remove that item
    """
    def __init__(self, data_dict: Dict[str, Any] = dict(), pretty_json: bool = True):
        super().__init__([], layout=total_width)
        self.pretty_json = pretty_json
        self+data_dict
        
    def on_update(self, func):
        """
        A decorator to register a function
        that is executed every time the dict changes;
        it receives the dictionary data as its only argument
        """
        self.update_func = func
        return func
    
    def run_update(self):
        if hasattr(self, "update_func"):
            self.update_func(self.get_data())
        
    def create_line(self, key: str, data: Any):
        children_map = dict((child.key, child) for child in self.children)
        children_map[key] = self.new_line(key, data)
        self.children = list(children_map.values())
        self.run_update()
        
    def data_to_dom(self, data):
        if self.pretty_json:
            pretty = list_group_kv(data) if hasattr(
                data, "keys") else list_group(data)
            return HTML(str(pretty), layout=total_width)
        else:
            return HTML(json.dumps(data))

    def new_line(self, key: str, data: Any) -> HBox:
        del_btn = Button(description="Remove", icon="trash")
        del_btn.button_style = 'danger'
        key_info = HTML(f"<h4 class='text-primary p-1'>{key}</h4>")
        hbox = HBox([VBox([key_info, del_btn]), self.data_to_dom(data)],
                    layout=total_width, box_style='')
        hbox.data = data
        hbox.key = key

        def remove_hbox():
            children = list(self.children)
            for c in children:
                if id(c) == id(hbox):
                    children.remove(c)
            self.children = children
            self.run_update()
        del_btn.click = remove_hbox
        return hbox
    
    def __setitem__(self, k, v):
        self.create_line(k, v)
    
    def __add__(self, kv):
        for k,v in kv.items():
            self.create_line(k, v)
        return self

    def get_data(self) -> Dict[str, Any]:
        """
        Return the data of this dict
        """
        return dict((x.key,x.data) for x in self.children)
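
The `on_update` hook is a plain decorator-registration pattern. Here is a widget-free sketch of it, so it can run outside a notebook; `ObservableDict`, `remove` and `record` are hypothetical names used only for illustration.

```python
from typing import Any, Callable, Dict, Optional


class ObservableDict:
    """Hypothetical, widget-free stand-in for EditableDict's update hook."""

    def __init__(self) -> None:
        self._data: Dict[str, Any] = {}
        self._update_func: Optional[Callable[[Dict[str, Any]], None]] = None

    def on_update(self, func: Callable) -> Callable:
        # Used as a decorator: remember the function, return it unchanged
        self._update_func = func
        return func

    def _run_update(self) -> None:
        if self._update_func is not None:
            self._update_func(dict(self._data))

    def __setitem__(self, key: str, value: Any) -> None:
        self._data[key] = value
        self._run_update()

    def remove(self, key: str) -> None:
        self._data.pop(key, None)
        self._run_update()


history = []
box = ObservableDict()


@box.on_update
def record(data):
    # Receives a snapshot of the whole dictionary after every change
    history.append(data)


box["lr"] = 0.01
box["epochs"] = 3
box.remove("lr")
```

In the notebook, the registered function would typically write the new data back into the phase, so the saved configuration always follows what the user sees.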

#### StepByStep

In [5]:
class LivingStep:
    """
    An interactive step for StepByStep
    """

    def __init__(
        self, func: Callable,
        top_block: HTML = None
    ):
        self.output = Output()
        self.func = func
        self.top_block = top_block

    def __call__(self, **kwargs):
        with self.output:
            if self.top_block is not None:
                display(self.top_block)
            return self.func(**kwargs)

    def new_top_block(self, top_block):
        self.top_block = top_block


class StepByStep:
    """
    A tool to manage progress step by step
    """

    def __init__(
        self,
        funcs: Dict[str, Callable],
        top_board: HTML = None,
        kwargs: Dict[str, Any] = dict()
    ):
        self.step_keys: List[str] = list(funcs.keys())
        self.steps: Dict[str, LivingStep] = dict(
            (k, LivingStep(f)) for k, f in funcs.items())
        self.furthest: int = 0
        self.current: int = -1
        self.kwargs: Dict[str, Any] = dict(kwargs)  # copy, so the shared default dict is never mutated
        self.execute_cache: Dict[str, bool] = dict()
        self.top_board: HTML = top_board
        self.page_output: Output = Output()
        self.footer: Output = Output()
        self.create_widget()

    def rerun(self,**kwargs):
        """
        Rerun the current step function
        """
        # find the step
        step: LivingStep = self.steps[self.step_keys[self.current]]
        # clear old output
        step.output.clear_output()
        self.kwargs.update(kwargs)
        step(progress=self, **self.kwargs)

    def create_control_bar(self,):
        self.bar_hbox = list()
        self.next_btn: Button = Button(
            description="Next", icon='check', button_style='info')
        self.rerun_btn = Button(description="Rerun Step",
                                icon='play', button_style='success')
        self.title = HTML(f"<h4 class='text-primary'>Step By Step</h4>")
        self.next_btn.click = self.next_step
        self.rerun_btn.click = self.rerun
        self.bar_hbox.append(self.title)
        self.bar_hbox.append(self.next_btn)
        self.bar_hbox.append(self.rerun_btn)
        return HBox(self.bar_hbox)

    def create_widget(self) -> None:
        self.vbox_list = []
        if self.top_board is not None:
            self.vbox_list.append(self.top_board)
        
        # create buttons for progress axis
        self.progress_btns = dict(
            (k, Button(
                description=f"{i+1}:{k}",
                icon="cube",
                button_style="danger"
                if i <= self.furthest else ""))
            for i, (k, v) in enumerate(self.steps.items())
        )
        # assign action to first button
        first_btn: Button = list(self.progress_btns.values())[0]
        first_btn.click: Callable = self.to_page_action(0)
        self.progress_bar = HBox(list(self.progress_btns.values()))
        
        # assemble the entire widget
        self.vbox_list.append(self.progress_bar)
        self.vbox_list.append(self.create_control_bar())
        self.vbox_list.append(self.page_output)
        self.widget = VBox(self.vbox_list)

    def to_page_action(
        self, page_id: int
    ) -> Callable:
        """
        generate the button click function
        """
        def to_page_func():
            return self[page_id]
        return to_page_func

    def update_furthest(self):
        """
        Update the "furthest mark"
        Also enact the next progress button
        """
        if self.furthest < self.current:
            if self.current < len(self):
                # update even button
                btn = self.progress_btns[self.step_keys[self.current]]
                btn.click = self.to_page_action(
                    self.current)
                btn.button_style = 'danger'
            self.furthest = self.current
            
    def __repr__(self):
        keys = " => ".join(self.step_keys)
        return f"Progress Axis: [{keys}]"

    def __getitem__(self, page_id):
        """
        Display a single page
        """
        if (page_id < 0) or (page_id >= len(self)):
            return
        self.current: int = page_id
        key: str = self.step_keys[page_id]
        step: LivingStep = self.steps[key]
        self.title.value: str = f"<h4 class='text-danger'>Step {page_id+1}: {key}</h4>"
        self.page_output.clear_output()

        with self.page_output:
            display(step.output)
        if key not in self.execute_cache:
            rt = step(progress=self, **self.kwargs)
            if hasattr(rt,"keys"):
                self.kwargs.update(rt)  # merge the step's returned dict into the shared kwargs
            self.execute_cache[key] = True

    def next_step(self, **kwargs):
        self.current += 1
        if self.current >= len(self):
            self.current = 0
        self.update_furthest()
        return self[self.current]

    def __len__(self):
        return len(self.step_keys)

    def __call__(self, **kwargs):
        """
        Start the entire progress widget
        """
        display(self.widget)
        display(self.footer)
        self.kwargs.update(kwargs)
        self.next_step(**self.kwargs)

    def show_in(self, step_name: str) -> Callable:
        """
        A decorator that will make the function
            to show under a specific step window
        """
        step = self.steps[step_name]

        def decorator(func: Callable) -> Callable:
            def wrapper(*args, **kwargs):
                with step.output:
                    return func(*args, **kwargs)
            return wrapper
        return decorator

    def show_footer(self, func: Callable):
        """
        A decorator, where functions executed
            within this will be shown under the footer
        """
        def wrapper(*args, **kwargs):
            with self.footer:
                return func(*args, **kwargs)
        return wrapper
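
The run-once behaviour of the steps can be sketched without widgets. This assumes the intent is that a dict returned by a step is merged into the keyword arguments passed to later steps; `MiniSteps`, `load` and `summarize` are hypothetical names.

```python
from typing import Any, Callable, Dict


class MiniSteps:
    """Hypothetical, widget-free sketch of StepByStep's run-once logic."""

    def __init__(self, funcs: Dict[str, Callable]):
        self.funcs = funcs
        self.keys = list(funcs)
        self.kwargs: Dict[str, Any] = {}
        self.executed: Dict[str, bool] = {}

    def show(self, page_id: int) -> None:
        key = self.keys[page_id]
        if key in self.executed:
            return  # revisiting a page does not rerun its function
        result = self.funcs[key](progress=self, **self.kwargs)
        if isinstance(result, dict):
            # a returned dict becomes input for the following steps
            self.kwargs.update(result)
        self.executed[key] = True


calls = []


def load(progress, **kwargs):
    calls.append("load")
    return {"rows": [1, 2, 3]}


def summarize(progress, rows, **kwargs):
    calls.append("summarize")
    return {"total": sum(rows)}


steps = MiniSteps({"Load": load, "Summarize": summarize})
steps.show(0)
steps.show(1)
steps.show(0)  # cached, so load() is not called again
```

Each step function accepts `**kwargs` so that it can ignore values produced by steps it doesn't care about.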

### typings for interactives

Typing for interactive details
```self()``` will create widgets automatically

In [6]:
class InteractiveTyping:
    """
    Typing for interactive details
    self.__call__() will create widgets directly
    """
    name = "anything"
    is_typing = True

    def solid(self, default) -> None:
        """
        Reset default value
        """
        if default is not None:
            self.default = default


class INT(InteractiveTyping):
    def __init__(self, min_: int = 0, max_: int = 10, step: int = 1, default: int = None):
        self.max_ = max_
        self.min_ = min_
        self.step = step
        self.default = default if default is not None else 1

    def __repr__(self):
        return f"int[{self.min_}-{self.max_}, :{self.step}]={self.default}"

    def __call__(self, default: int = None):
        self.solid(default)
        return IntSlider(
            value=self.default,
            min=self.min_,
            max=self.max_,
            step=self.step,
        )


class BOOL(InteractiveTyping):
    def __init__(self, name:str="", default: bool = True,):
        self.default = default
        self.name = name

    def __repr__(self):
        return f"bool={self.default}"

    def __call__(self, default: bool = None) -> Checkbox:
        self.solid(default)
        return Checkbox(value=self.default, description=self.name)


class FLOAT(InteractiveTyping):
    def __init__(self, min_: float = -1., max_: float = 1., step: float = .01, default: float = None):
        self.max_ = max_
        self.min_ = min_
        self.step = step
        self.default = default if default is not None else 0.01

    def __repr__(self):
        return f"float[{self.min_}-{self.max_}, :{self.step}]={self.default}"

    def __call__(self, default: float = None):
        self.solid(default)
        return FloatSlider(
            value=self.default,
            min=self.min_,
            max=self.max_,
            step=self.step,
        )


class STR(InteractiveTyping):
    """
    String object
    will create text or textarea
    """

    def __init__(self, default: str = None, use_area: bool = False):
        """
        use_area: if True use Textarea, if False use Text
        """
        self.default = "" if default is None else default
        self.use_area = use_area

    def __repr__(self):
        return f"str='{self.default}'"

    def __call__(self, default: str = None):
        self.solid(default)
        if self.use_area:
            return Textarea(value=self.default, layout=Layout(width="80%"))
        return Text(value=self.default)


class LIST(InteractiveTyping):
    """
    dropdown list type or multiselection type
    """

    def __init__(self, options: List[Any] = [], default: Any = None, multi: bool = False):
        """
        if multi: default should be iterable
        else: default should be one of the options
        """
        self.options = options
        self.default = default
        self.multi = multi

    def __repr__(self):
        if self.multi:
            size = f"[0-{self.default}]/{len(self.options)}"
        else:
            size = f"1/{len(self.options)}"
        return f"list,{size}"

    def __call__(self, default: Any = None):
        self.solid(default)
        if self.multi:
            inter = SelectMultiple(options=self.options)
        else:
            inter = Dropdown(options=self.options)

        if self.default is not None:
            # if multi: default should be iterable
            # else: default should be one of the options
            inter.value = self.default
        return inter

### Enhanced Interactive

The original ```interact_manual``` isn't powerful enough for this situation, so the following is a more flexible way to decorate an interactive function

In [7]:
class InteractiveAnnotations:
    """
    Build interactive based on the info of function's ```__annotations__```
    """

    def __init__(
        self, func: Callable,
        icon: str = "rocket",
        description: str = 'Run',
        button_style='primary'
    ):
        self.func = func
        self.icon = icon
        self.button_style = button_style
        self.description = description
        self.build_vbox(func)

    @classmethod
    def on(
        cls,
        callback: Callable,
        icon: str = 'rocket',
        description: str = 'Run',
        button_style: str = 'primary'
    ) -> Callable:
        """
        Use this class as a decorator
        @InteractiveAnnotations.on(callback)
        def target_func(a:STR(), b:INT()=1):
            ...
        """
        def decorator(func: Callable):
            obj = cls(
                func,
                icon=icon,
                description=description,
                button_style=button_style
            )
            display(obj.vbox)
            obj.register_callback(callback=callback)
            return func
        return decorator

    def build_vbox(self, func: Callable):
        row_list = []
        self.fields = dict()
        for k, v in func.__annotations__.items():
            if not hasattr(v, "is_typing"):
                continue
            widget = v()
            widget.description = k
            row_list.append(widget)
            self.fields.update({k: widget})

        # final button
        self.final_btn = Button(
            description=self.description,
            icon=self.icon,
        )
        self.final_btn.button_style = self.button_style
        row_list.append(self.final_btn)

        # create interactive
        self.vbox = VBox(row_list)
        return self.vbox

    def register_callback(
        self,
        callback: Callable
    ) -> None:
        def run_callback():
            kwargs = self()
            callback(kwargs)
        self.final_btn.click = run_callback

    def __call__(self) -> Dict[str, Any]:
        """
        extract interactive data values
        """
        rt = dict()
        for k, widget in self.fields.items():
            rt.update({k: widget.get_interact_value()})
        return rt
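
The mechanism relies on the fact that Python lets any object, not only a type, be used as a parameter annotation. Here is a widget-free sketch of the discovery step; `FakeTyping`, `collect_fields` and `target` are hypothetical names, and a plain dict stands in for the widget.

```python
from typing import Any, Callable, Dict


class FakeTyping:
    """Hypothetical stand-in for an InteractiveTyping subclass such as STR or INT."""
    is_typing = True

    def __init__(self, default: Any = None):
        self.default = default

    def __call__(self) -> Dict[str, Any]:
        # The real classes return an ipywidgets widget; a dict stands in here
        return {"value": self.default}


def collect_fields(func: Callable) -> Dict[str, Dict[str, Any]]:
    """Build one field per parameter annotated with a typing instance."""
    fields = {}
    for name, annotation in func.__annotations__.items():
        if not hasattr(annotation, "is_typing"):
            continue  # plain annotations such as `int` are ignored
        fields[name] = annotation()
    return fields


def target(e, a: FakeTyping("RGB"), b: FakeTyping(1) = 2, c: int = 3):
    pass


fields = collect_fields(target)
print(fields)
```

Note that in this sketch the function default (`b = 2`) is not consulted: the field starts from the default held by the annotation object itself.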

#### Test callback & decorator

In [8]:
def print_stuff(kwargs):
    Flash.info(str(kwargs))

@InteractiveAnnotations.on(print_stuff, "flask", "test", button_style="warning")
def some_func(e, a:STR(), b:INT()=2, d=3):
    Flash.danger(str(dict(e=e, a=a, b=b, d=d)))

VBox(children=(Text(value='', description='a'), IntSlider(value=1, description='b', max=10), Button(button_sty…

### Intercept interactive

In [9]:
def print_kwargs(kwargs):
    print(kwargs)
    return kwargs


def reconfig_manual_interact(
    widget,
    description: str = "Create",
    button_style: str = "primary",
    icon: str = "plus"
) -> Button:
    """
    reconfigure the button of interactive features
    """
    btn = None
    for w in widget.children:
        if isinstance(w, Button):
            btn = w
            break
    if btn is None:
        # no button found, nothing to reconfigure
        return None
    btn.description = description
    btn.button_style = button_style
    btn.icon = icon
    return btn


def interact_intercept(
    func:Callable,
    result_cb: Callable = print_kwargs
):
    """
    Initialize a class with interactive features
    """
    annotations = func.__annotations__
    defaults = func.__defaults__
    kwargs = dict()
    if defaults is not None:
        for (k, typing), default in zip(annotations.items(), defaults):
            kwargs.update({k: typing(default)})
    obj = dict()

    def fillin_init(**kwargs):
        obj.update({
            "kwargs": kwargs,
        })
    f = interact_manual(fillin_init, **kwargs)

    btn = reconfig_manual_interact(f.widget)

    if btn is not None:
        original = btn.click

        def new_click_event():
            original()
            return result_cb(obj['kwargs'])
        btn.click = new_click_event

    return obj, f

def init_interact(cls, result_cb: Callable = print_kwargs):
    return interact_intercept(cls.__init__, result_cb=result_cb)
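
`interact_intercept` pairs annotations with `__defaults__` by position. That lines up when the annotated parameters and the defaulted parameters are the same ones, but it can drift otherwise. One possible alternative, sketched here and not part of the original code, reads both from `inspect.signature`; `FakeTyping`, `pair_typings` and `init` are hypothetical names.

```python
import inspect
from typing import Any, Callable, Dict


class FakeTyping:
    """Hypothetical stand-in for an InteractiveTyping subclass."""
    is_typing = True

    def __init__(self, name: str):
        self.name = name

    def __call__(self, default: Any = None) -> Dict[str, Any]:
        # The real classes return a widget; a dict keeps this runnable anywhere
        return {"typing": self.name, "default": default}


def pair_typings(func: Callable) -> Dict[str, Dict[str, Any]]:
    """Match each typing annotation with its own parameter's default."""
    paired = {}
    for name, param in inspect.signature(func).parameters.items():
        if not hasattr(param.annotation, "is_typing"):
            continue
        has_default = param.default is not inspect.Parameter.empty
        default = param.default if has_default else None
        paired[name] = param.annotation(default)
    return paired


def init(self, convert: FakeTyping("str"), size: FakeTyping("list") = 224):
    pass


paired = pair_typings(init)
print(paired)
```

Here `convert` has no default while `size` does, a case where zipping annotations against `__defaults__` would hand `224` to `convert`.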

In [10]:
STR('RGB')()

Text(value='RGB')

## Enrich columns (feature transformation, label extraction)
After this step, there will only be **MORE** columns ➕

### Enrich Classes

In [11]:
class Enrich:
    """
    Enrich Base Class
    Some default attributes
    - is_enrich = True
    - typing = None # output typing
    - multi_cols = False # use multi-column as input
    - prefer = None
    - lazy = False  # shall we execute enrichment only through the iteration
    - src = None # source column
    """
    is_enrich = True
    typing = None # output typing
    multi_cols = False # use multi-column as input
    prefer = None
    lazy = False  # shall we execute enrichment only through the iteration
    src = None # source column

    def __init__(self): pass

    def __call__(self, row):
        return row
    
    def rowing(self, row):
        if self.multi_cols:
            return self(row)
        else:
            return self(row[self.src])


class EnrichImage(Enrich):
    """
    Create Image column from image path column
    """
    prefer = "QuantifyImage"
    typing = Image
    lazy = True
    

    def __init__(
        self, convert: STR("RGB") = "RGB",
        size: LIST(options=[28, 128, 224, 256, 512], default=224) = 224,
    ):
        self.convert = convert
        self.size = size

    def __repr__(self):
        return f"[Image:{self.size}]"

    def __call__(self, x):
        img = Image.open(x).convert(self.convert)
        img = img.resize((self.size, self.size))
        return img


class ParentAsLabel(Enrich):
    typing = str
    prefer = "QuantifyCategory"
    def __call__(self, path: Path,) -> str:
        """
        Use parent folder name as label
        """
        return Path(path).parent.name
    
ENRICHMENTS = dict(
    EnrichImage=EnrichImage,
    ParentAsLabel=ParentAsLabel,
)

In [12]:
obj,f = init_interact(EnrichImage)

interactive(children=(Text(value='RGB', description='convert'), Dropdown(description='size', index=2, options=…

### Set Enrich 🎸

In [13]:
def set_enrich(**kwargs):
    df = kwargs['df']
    phase = kwargs['phase']

    DOM(f"{len(df)} rows of data, example table", "h3")()
    display(df.sample(5))
    display(HTML("<hr>"))

    def setting_col():
        enrich_data_list = phase['enrich'] if 'enrich' in phase else []
        enrich_box = EditableList(enrich_data_list)
        display(enrich_box)

        
        def set_enrich_(src=["[all_columns]", ]+list(df.columns)):
            DOM(f"Setting up column enrich: {src}", "h4")()
            if src == "[all_columns]":
                display(df.head(3))
            else:
                display(df[[src, ]].head(3))

            def choose_enrich(dst="", enrich=ENRICHMENTS):
                DOM(f"Source: {src}, Destination: {dst}, for {enrich.__name__}", "h4")(
                )
                DOM(f"{enrich.__doc__}", "quote")()

                def result_callback(kwargs):
                    extra = {"src": src, "dst": dst,
                                "kwargs": kwargs, "enrich": enrich.__name__}
                    enrich_box+extra
                    phase['enrich'] = enrich_box.get_data()
                obj, decoed_func = init_interact(enrich, result_callback)
            choose_enrich_widget = interact_manual(choose_enrich).widget
            reconfig_manual_interact(
                choose_enrich_widget,
                description="Choose", button_style='warning')
        set_enrich_widget = interact_manual(set_enrich_).widget
        reconfig_manual_interact(set_enrich_widget, button_style='warning')
    setting_col()

### Execute enrichment
> apply the enrichment settings to the dataframe

In [14]:
def execute_enrich(
    df: pd.DataFrame, phase:Phase
):
    if 'enrich' not in phase:
        return df
    for en_conf in tqdm(phase["enrich"], leave=False):
        enrich_name = en_conf['enrich']
        enrich_cls = ENRICHMENTS[enrich_name]
        kwargs = en_conf['kwargs']
        src = en_conf['src']
        dst = en_conf['dst']
        # A lazy enrichment is stored in the column as an object
        # and is only called when a row is actually read
        if enrich_cls.lazy:
            obj = enrich_cls(**kwargs)
            obj.src = src
            df[dst] = obj
        # The class without lazy loading
        # create the column now
        else:
            obj = enrich_cls(**kwargs)
            if src=="[all_columns]":
                df[dst] = df.apply(obj, axis=1)
            else:
                df[dst] = df[src].apply(obj)
    return df
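
The loop above is driven entirely by the `phase["enrich"]` configuration list. Below is a minimal sketch of that config-driven dispatch, using plain Python rows instead of a DataFrame and covering only the non-lazy branch. `Upper`, `RowLength`, `REGISTRY` and `run_steps` are hypothetical names made up for illustration; they are not part of the library.

```python
# Hypothetical stand-ins for enrichment classes (illustration only).
class Upper:
    """Upper-case a single cell."""
    def __call__(self, value):
        return str(value).upper()


class RowLength:
    """Count the columns present in a whole row."""
    def __call__(self, row):
        return len(row)


# Name -> class registry, playing the role of ENRICHMENTS.
REGISTRY = {"Upper": Upper, "RowLength": RowLength}


def run_steps(rows, steps):
    """Apply each configured step, writing its result into the `dst` key."""
    for conf in steps:
        obj = REGISTRY[conf["enrich"]](**conf["kwargs"])
        for row in rows:
            if conf["src"] == "[all_columns]":
                # pass a copy so the step sees the row before `dst` is written
                row[conf["dst"]] = obj(dict(row))
            else:
                row[conf["dst"]] = obj(row[conf["src"]])
    return rows


rows = [{"name": "cat"}, {"name": "dog"}]
steps = [
    {"enrich": "Upper", "src": "name", "dst": "name_up", "kwargs": {}},
    {"enrich": "RowLength", "src": "[all_columns]", "dst": "n_cols", "kwargs": {}},
]
result = run_steps(rows, steps)
```

Because every step is plain data (`enrich`, `src`, `dst`, `kwargs`), the same list can be saved, edited in a widget, and replayed later.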

## Quantify: Choose columns as X and Y, turn them into numbers

### Size classes

In [15]:
class SIZE_DIMENSION:
    pass

class BATCH_SIZE(SIZE_DIMENSION):
    def __repr__(self): return "BATCH_SIZE"

class SEQUENCE_SIZE(SIZE_DIMENSION):
    pass

class IMAGE_SIZE(SIZE_DIMENSION):
    pass
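
These marker classes serve as symbolic placeholders inside `shape` tuples such as `(BATCH_SIZE, 3, IMAGE_SIZE, IMAGE_SIZE)`. The notebook defines only the markers. As a sketch of how such a symbolic shape could be turned into concrete numbers, here is a hypothetical helper `resolve_shape` (an assumption, not an existing function of the engine):

```python
class SIZE_DIMENSION:
    pass


class BATCH_SIZE(SIZE_DIMENSION):
    pass


class IMAGE_SIZE(SIZE_DIMENSION):
    pass


def resolve_shape(shape, sizes):
    """Replace symbolic dimension classes with concrete integers.

    `resolve_shape` is a hypothetical helper for illustration only.
    """
    resolved = []
    for dim in shape:
        if isinstance(dim, type) and issubclass(dim, SIZE_DIMENSION):
            resolved.append(sizes[dim])   # symbolic: look up the real size
        else:
            resolved.append(dim)          # already a plain integer
    return tuple(resolved)


shape = (BATCH_SIZE, 3, IMAGE_SIZE, IMAGE_SIZE)
concrete = resolve_shape(shape, {BATCH_SIZE: 32, IMAGE_SIZE: 224})
```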

### Quantify classes

In [16]:
ct = Category([1,2,3,4,5])

In [17]:
class Quantify:
    """
    # From all things to numbers
    The AI model does not understand anything, say, pictures or text,
    unless you transform it into integer and float tensors

    Quantify and its subclasses control the
        numericalization / collation of the data pipeline
    The base class of Quantify does: NOTHING
    """
    # the docstring has to come first in the class body to be picked up as __doc__
    is_quantify = True

    def __init__(self,):
        pass

    def __call__(self, list_of_items):
        return list(list_of_items)

    def adapt(self, column):
        """
        A function to let the data processing
        adapt to the data column
        """
        pass

    def __hash__(self,):
        # __hash__ must return an int, so hash the identifying string
        if hasattr(self, "name"):
            return hash(self.name)
        else:
            return hash(self.__class__.__name__)


class QuantifyImage(Quantify):
    """
    Transform PIL.Image to tensor
    """

    def __init__(
        self,
        mean_: LIST(["imagenet", "0.5 x 3"]) = "imagenet",
        std_: LIST(["imagenet", "0.5 x 3"]) = "imagenet",
    ):
        if type(mean_) == str:
            if mean_ == "imagenet":
                mean_ = [0.485, 0.456, 0.406]
            elif mean_ == "0.5 x 3":
                mean_ = [.5, .5, .5]
            else:
                raise ValueError(
                    f"Mean configuration: {mean_} not valid")

        if type(std_) == str:
            if std_ == "imagenet":
                std_ = [0.229, 0.224, 0.225]
            elif std_ == "0.5 x 3":
                std_ = [.5, .5, .5]
            else:
                raise ValueError(
                    f"Standard deviation configuration: {std_} not valid")

        self.transform = tfm.Compose([
            tfm.ToTensor(),
            tfm.Normalize(mean=mean_, std=std_),
        ])

        self.shape = (BATCH_SIZE, 3, IMAGE_SIZE, IMAGE_SIZE)

    def __repr__(self):
        return f"Quantify Image to tensors:{self.transform}"

    def __call__(self, list_of_image):
        return torch.stack(list(
            self.transform(img) for img in list_of_image))


class QuantifyText(Quantify):
    def __init__(
        self,
        pretrained: STR(default="bert-base-cased") = "bert-base-cased",
        max_length: INT(default=512, min_=12, max_=1024, step=4) = 512,
        padding: LIST(options=[
            "do_not_pad",
            "max_length",
            "longest"], default="max_length") = "max_length",
        return_token_type_ids: BOOL(name="Token Type IDs", default=True) = True,
        return_attention_mask: BOOL(name="Attention Mask", default=True) = True,
        return_offsets_mapping: BOOL(name="Offset Mapping", default=False) = False,
    ):
        self.pretrained = pretrained
        self.max_length = max_length
        self.padding = padding
        self.return_token_type_ids = return_token_type_ids
        self.return_attention_mask = return_attention_mask
        self.return_offsets_mapping = return_offsets_mapping
        self.truncation = True
        self.return_tensors = 'pt'
        self.shape = (BATCH_SIZE, SEQUENCE_SIZE)

    def adapt(self, column):
        """
        Initialize tokenizer
        """
        from transformers import AutoTokenizer
        Flash.info("Loading transformer tokenizer, takes time", key="Alert!")
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.pretrained, use_fast=True)

    def __call__(self, list_of_text: List[str]):
        list_of_text = list(list_of_text)
        return self.tokenizer(
            list_of_text,
            padding=self.padding,
            max_length=self.max_length,
            truncation=self.truncation,
            return_token_type_ids=self.return_token_type_ids,
            return_attention_mask=self.return_attention_mask,
            return_tensors=self.return_tensors,
            return_offsets_mapping=self.return_offsets_mapping,
        )


class QuantifyCategory(Quantify):
    """
    Transform single categorical data to index numbers in pytorch tensors
    """

    def __init__(
        self,
        min_frequency: INT(min_=1, max_=20, default=1) = 1,
    ):
        self.min_frequency = min_frequency

    def adapt(self, column):
        # category statistics
        value_counts = pd.DataFrame(column.value_counts())

        # if the minimum frequency is 1,
        # every category that occurred is accounted for,
        # hence no missing-token padding is required
        if self.min_frequency < 2:
            self.category = Category(
                arr=np.array(value_counts.index),
                pad_mst=False)

        # we need a missing token
        # for categories whose frequency < self.min_frequency
        else:
            categories = np.array(
                list(value_counts.index[
                    value_counts.values.reshape(-1) >= self.min_frequency]))
            self.category = Category(arr=categories, pad_mst=True)

        self.shape = (BATCH_SIZE, len(self.category))

    def __repr__(self):
        return f"Quantify Category:{self.category}"

    def __call__(self, list_of_strings):
        return torch.LongTensor(self.category.c2i[np.array(list_of_strings)])


class QuantifyMultiCategory(Quantify):
    """
    Transform multi-categorical data to index numbers in pytorch tensors
    """

    def __init__(
        self,
        min_frequency: INT(min_=1, max_=20, default=1) = 1,
        separator: LIST(options=["[None]", ",", ";", "[Space]", "[By Char]"], default=",") = ",",
    ):
        self.min_frequency = min_frequency
        friendly_mapping = {
            "[None]": None,
            "[Space]": " ",
            "[By Char]": "",
        }
        if separator in friendly_mapping:
            separator = friendly_mapping.get(separator)
        self.separator = separator

    @staticmethod
    def stripping(x):
        return x.strip()

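    def break_cell(self, value):
        if value is None:
            return []
        if self.separator == "":
            # "[By Char]": str.split("") raises ValueError, so split manually
            parts = list(str(value))
        else:
            parts = str(value).split(self.separator)
        break_list = list(i for i in map(self.stripping, parts) if len(i) > 0)
        return break_list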

    def adapt(self, column):
        if self.separator is None:
            sample_col = column
        else:
            sample_col = column.apply(self.break_cell)
        value_counts = pd.DataFrame(sample_col.explode().value_counts())

        # if the minimum frequency is 1,
        # every category that occurred is accounted for,
        # hence no missing-token padding is required
        if self.min_frequency < 2:
            self.category = Category(
                arr=np.array(value_counts.index),
                pad_mst=False)

        # we need a missing token
        # for categories whose frequency < self.min_frequency
        else:
            categories = np.array(
                list(value_counts.index[
                    value_counts.values.reshape(-1) >= self.min_frequency]))
            self.category = Category(arr=categories, pad_mst=True)

    def __call__(
        self, list_of_strings: List[str]
    ) -> torch.LongTensor:
        """
        Return a batch of n-hot array tensor
        """
        if self.separator is None:
            col: List[str] = list_of_strings
        else:
            col: List[List[str]] = list(map(self.break_cell, list_of_strings))
        arrays: List[np.array] = []
        for item in col:
            array: np.array = np.zeros(len(self.category))
            if len(item) > 0:
                one_idx: np.array = self.category.c2i[item]
                array[one_idx] = 1
            arrays.append(array)
        return torch.LongTensor(np.stack(arrays))


class QuantifyNum(Quantify):
    """
    Quantify continuous data, like float numbers
    The only process is normalization on the entire population
    """
    shape = (BATCH_SIZE, 1)

    def adapt(self, column):
        self.mean_ = column.mean()
        self.std_ = column.std()

    def __call__(self, list_of_num):
        return (torch.FloatTensor(list_of_num)[:, None]-self.mean_)/self.std_

    def backward(self, x):
        return x*self.std_+self.mean_


QUANTIFY = dict(
    Quantify=Quantify,
    QuantifyNum=QuantifyNum,
    QuantifyImage=QuantifyImage,
    QuantifyCategory=QuantifyCategory,
    QuantifyMultiCategory=QuantifyMultiCategory,
    QuantifyText=QuantifyText,
)
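
The category quantifiers above rely on the `Category` helper, whose internals are not shown in this section. As a rough, dependency-free sketch of the idea (a vocabulary filtered by `min_frequency`, a missing token for rare values, and n-hot rows for multi-category cells), assuming rare values map to a reserved `"[MST]"` slot at index 0; `build_vocab`, `n_hot` and the token name are hypothetical:

```python
from collections import Counter


def build_vocab(values, min_frequency=1, missing_token="[MST]"):
    """Map each kept value to an index; rare values share the missing token."""
    counts = Counter(values)
    kept = [v for v, c in counts.most_common() if c >= min_frequency]
    # the missing token is only needed when some values can be filtered out
    vocab = ([missing_token] if min_frequency > 1 else []) + kept
    return {v: i for i, v in enumerate(vocab)}


def n_hot(cell, c2i, separator=",", missing_token="[MST]"):
    """Encode one multi-category cell as a 0/1 vector over the vocabulary."""
    vec = [0] * len(c2i)
    for part in str(cell).split(separator):
        item = part.strip()
        if not item:
            continue
        # unknown items fall back to the missing token, if the vocab has one
        idx = c2i.get(item, c2i.get(missing_token))
        if idx is not None:
            vec[idx] = 1
    return vec


tags = ["red", "red", "blue", "red", "green", "blue"]
c2i = build_vocab(tags, min_frequency=2)   # "green" occurs once, so it is dropped
encoded = n_hot("red, green", c2i)         # "green" lands on the missing token
```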

#### Pytorch Dataset

In [18]:
class TaiChiDataset(Dataset):
    """
    A pytorch dataset working under our core engine
    The dataset class should only be defined here once
    """
    def __init__(self, df, columns: List[Any] = None):
        self.df = df
        self.columns = list(df.columns) if columns is None else columns

    def __len__(self):
        return len(self.df)

    def shuffle(self):
        self.df = self.df.sample(frac=1.).reset_index(drop=True)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        row = dict(self.df.loc[idx])
        rt = dict()
        for col in self.columns:
            v = row[col]
            if hasattr(v, "is_enrich"):
                rt[col] = v.rowing(row)
            else:
                rt[col] = v
        return rt
    
    def split(
        self,
        valid_ratio:FLOAT(min_=0.01, max_=0.5, default=.1, step=0.01)=.1
    ) -> Tuple[Any]:
        """
        Split dataset to train, validation
        """
        cls = self.__class__
        slicing = (np.random.rand(len(self)) < valid_ratio)
        return (
            cls(self.df[~slicing].reset_index(drop=True), self.columns),
            cls(self.df[slicing].reset_index(drop=True), self.columns)
        )

    def dataloader(
        self,
        batch_size: LIST(options=[1, 2, 4, 8, 16, 32, 64, 128, 256, 512], default=32) = 32,
        shuffle: LIST(options=[True, False], default=False) = False,
        num_workers: LIST(options=[0, 2, 4, 8, 16], default=0) =0,
    ):
        """
        Create dataloader from dataset
        """
        return DataLoader(
            self,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers)
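
`split` draws one random number per row and sends the row to the validation set when that number falls below `valid_ratio`. The sketch below reproduces this with the standard library only; `mask_split` and its `seed` parameter are hypothetical additions for illustration. Note that the validation size is only approximately `valid_ratio * len(items)`, because each row is decided independently.

```python
import random


def mask_split(items, valid_ratio=0.1, seed=None):
    """Split items into (train, valid) using an independent random mask."""
    rng = random.Random(seed)
    # True means "goes to validation"
    mask = [rng.random() < valid_ratio for _ in items]
    train = [x for x, m in zip(items, mask) if not m]
    valid = [x for x, m in zip(items, mask) if m]
    return train, valid


train, valid = mask_split(list(range(1000)), valid_ratio=0.1, seed=0)
print(len(train), len(valid))  # sizes vary with the seed; they always sum to 1000
```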

### Choose XY 🎸

In [19]:
def choose_xy(**kwargs):
    df = kwargs.get("df")
    phase = kwargs.get("phase")
    progress = kwargs.get('progress')

    DOM(f"{len(df)} rows of data, example table", "h3")()
    display(df.sample(5))
    display(HTML("<hr>"))
    DOM("Please Choose Column", "h3")()
    DOM("The AI model will try to guess the Y with the input X", "div", {"style":"color:#666699"})()
    
    task = 'quantify'
    # enrich by columns
    if "enrich" in phase:
        by_destination = dict((en['dst'], en) for en in phase['enrich'])
    else:
        by_destination = dict()
    
    data_list = phase[task] if task in phase else []
    mol_box = EditableList(data_list)
    display(mol_box)

    @interact_manual
    def set_quantify_(src=list(df.columns), use_for = ["As X", "As Y"]):
        DOM(f"Quantify Column: {src} {use_for}", "h4")()
        display(df[[src, ]].head(3))
        
        quantify_dropdown = Dropdown(options=list(QUANTIFY.keys()))
        
        # check the hint from last step
        prefer = None
        if src in by_destination:
            col_config = by_destination[src]
            cls = ENRICHMENTS[col_config['enrich']]

            # In case the enrich layer has the preference
            if hasattr(cls, "prefer"):
                prefer = cls.prefer
                
                # set the default value of the dropdown
                # if the previous hint suggests so
                quantify_dropdown.value = prefer
                DOM(f"Preferred quantifying:\t{cls.prefer}", "h4")()
            if hasattr(cls, "typing"):
                DOM(f"Output data type:\t{cls.typing}", "h4")()
        
        @interact_manual
        def choose_quantify(quantify = quantify_dropdown):
            cls = QUANTIFY[quantify]
            def result_callback(kwargs):
                extra = {"src": src, "x":(use_for=="As X"),
                        "kwargs": kwargs, "quantify": cls.__name__}
                mol_box+extra
                phase['quantify'] = mol_box.get_data()
                
            obj, decoded = init_interact(cls, result_callback)

In [20]:
def execute_quantify(
    df: pd.DataFrame, phase:Phase
):
    # existence check
    if 'quantify' not in phase:
        raise KeyError("No quantify step set")
    
    qdict = dict()
    for i, qconf in tqdm(enumerate(phase['quantify']), leave = False):
        qname = qconf['quantify']
        kwargs = qconf['kwargs']
        src = qconf['src']
        x = qconf['x']
        
        cls = QUANTIFY[qname]
        qobj = cls(**kwargs)
        qobj.src = src
        qobj.is_x = x
        qobj.adapt(df[src])
        qdict.update({src:qobj})
    return qdict

### Create Dataloader
This part handles:
* Splitting
* To dataloader

In [21]:
class TaiChiCollate:
    """
    Universal, all-powerful collate function
    One for all collation
    """
    def __init__(self, quantify_dict):
        self.quantify_dict = quantify_dict
        
    def make_df(self, batch):
        return pd.DataFrame(list(batch))
        
    def __len__(self):
        return len(self.quantify_dict)
        
    def __call__(self, batch) -> Dict[str, torch.Tensor]:
        """
        This call will execute the __call__(a_list_of_items)
        from Quantify objects column by column
        """
        batch_df = self.make_df(batch)
        rt = dict()
        for src,qobj in self.quantify_dict.items():
            rt.update({
                src:qobj(list(batch_df[src]))
            })
        return rt

class TaiChiDataModule(pl.LightningDataModule):
    def __init__(self, dataset: TaiChiDataset, quantify_dict: Dict[str, Quantify]):
        super().__init__()
        self.dataset = dataset
        self.quantify_dict = quantify_dict
        
        self.collate = TaiChiCollate(quantify_dict)
        
    def configure(
        self,
        valid_ratio:FLOAT(min_=0.01, max_=0.5, default=.1, step=0.01)=.1,
        batch_size: LIST(options=[1, 2, 4, 8, 16, 32, 64, 128, 256, 512], default=32) = 32,
        shuffle: LIST(options=[True, False], default=True) = True,
        num_workers: LIST(options=[0, 2, 4, 8, 16], default=0) =0,
    ):  
        self.train_ds, self.val_ds = self.dataset.split(valid_ratio)
        self.batch_size=batch_size
        self.shuffle=shuffle
        self.num_workers=num_workers
        
    def train_dataloader(self):
        self.train_dl = self.train_ds.dataloader(
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            num_workers=self.num_workers)
        self.train_dl.collate_fn = self.collate
        return self.train_dl
    
    def val_dataloader(self):
        self.val_dl = self.val_ds.dataloader(
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            num_workers=self.num_workers)
        self.val_dl.collate_fn = self.collate
        return self.val_dl
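
The collate step takes a list of row dictionaries and calls one quantifier per column on the list of that column's values. A minimal sketch without pandas or torch follows; `ColumnCollate`, `to_float` and `to_length` are hypothetical stand-ins, and plain lists take the place of tensors:

```python
class ColumnCollate:
    """Turn a list of row dicts into a dict of per-column batches."""

    def __init__(self, quantify_dict):
        self.quantify_dict = quantify_dict

    def __call__(self, batch):
        out = {}
        for src, quantify in self.quantify_dict.items():
            # gather one column across the batch, then quantify it in one call
            out[src] = quantify([row[src] for row in batch])
        return out


def to_float(values):
    return [float(v) for v in values]


def to_length(values):
    return [len(v) for v in values]


collate = ColumnCollate({"price": to_float, "title": to_length})
batch = [
    {"price": "1.5", "title": "tea", "extra": 0},
    {"price": "2", "title": "coffee", "extra": 1},
]
collated = collate(batch)
```

Columns without a quantifier (`extra` here) are simply left out of the batch.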

### Choose your model, loss

In [22]:
from torchvision.models import (
    resnet18, resnet34, resnet50, resnet101, resnet152,
    resnext101_32x8d, resnext50_32x4d)

## Models

### Entry modules

In [23]:
RESNET_OPTIONS = {"resnet18": resnet18,
                  "resnet34": resnet34,
                  "resnet50": resnet50,
                  "resnet101": resnet101,
                  "resnet152": resnet152,
                  "resnext101_32x8d": resnext101_32x8d,
                  "resnext50_32x4d": resnext50_32x4d}

In [24]:
class MidJoint1d(nn.Module):
    def __init__(self, keys):
        super().__init__()
        self.keys = keys
    
    def forward(self, data):
        tensors = list(data[key] for key in self.keys)
        return torch.cat(tensors,dim=1)

In [25]:
class EntryModel(nn.Module):
    is_entry = True
    
    @classmethod
    def from_quantify(cls, ):
        raise NotImplementedError(
            f"Please define class function 'from_quantify' for {cls.__name__}"
        )
    
class Empty(EntryModel):
    def __init__(self):
        super().__init__()
        self.out_features=1
    
    def forward(self, x):
        return x
    
    @classmethod
    def from_quantify(cls,
        quantify):
        return cls()

class ImageConvEncoder(EntryModel):
    def __init__(self, model):
        super().__init__()
        self.name = "cnn"
        self.output_shape = (BATCH_SIZE, model.fc.in_features)
        self.out_features = model.fc.in_features
        model.fc = Empty()
        self.model = model

    def forward(self, data):
        return self.model(data)

    def __repr__(self):
        return f"""ComputerVisionEncoder: {self.name}
        Outputs shape:{self.output_shape}"""

    @classmethod
    def from_quantify(
        cls,
        quantify,
        name: LIST(options=list(
            RESNET_OPTIONS.keys()), default="resnet18"),
    ):
        model = RESNET_OPTIONS[name](pretrained=True, progress=True,)
        obj = cls(model)
        obj.name = name
        return obj


class CategoryEncoder(EntryModel):
    def __init__(self, num_embeddings, embedding_dim):
        super().__init__()
        self.model = nn.Embedding(
            num_embeddings,
            embedding_dim)
        
    def forward(self, idx):
        return self.model(idx)
    
    @classmethod
    def from_quantify(
        cls,
        quantify,
        embedding_dim: LIST(
            options=[4, 8, 16, 32, 64, 128, 256, 512], default=128) = 128):
        num_embeddings = len(quantify.category)
        obj = CategoryEncoder(
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
        )
        obj.out_features = embedding_dim
        return obj
    
class MultiCategoryEncoder(EntryModel):
    def __init__(self, num_embeddings, embedding_dim):
        super().__init__()
        self.model = nn.Embedding(
            num_embeddings,
            embedding_dim)
        
    def forward(self, idx):
        return idx.float()@self.model.weight
    
    @classmethod
    def from_quantify(
        cls,
        quantify,
        embedding_dim: LIST(
            options=[4, 8, 16, 32, 64, 128, 256, 512], default=128) = 128):
        num_embeddings = len(quantify.category)
        obj = MultiCategoryEncoder(
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
        )
        obj.out_features = embedding_dim
        return obj


class TransformerEncoder(EntryModel):
    """
    A model part to encode sequence data into vectors
    """

    def __init__(self, model, encoder_mode: BOOL(default=True) = True,):
        super().__init__()
        self.model = model
        self.encoder_mode = encoder_mode

    def forward(self, kwargs):
        outputs = self.model(**kwargs)
        if self.encoder_mode:
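            # output vector
            if "pooler_output" in outputs:
                return outputs.pooler_output
            else:
                # mean over real tokens only: divide by the mask sum,
                # not by the padded sequence length
                mask = kwargs['attention_mask'][:, :, None].float()
                return (
                    outputs.last_hidden_state*mask
                ).sum(1) / mask.sum(1).clamp(min=1)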
        return outputs

    @classmethod
    def from_quantify(
        cls,
        quantify,
        name: STR(default="bert-base-uncased") = 'bert-base-uncased',
        encoder_mode: BOOL(default=True) = True,
    ):
        from transformers import AutoModel
        model = AutoModel.from_pretrained(name)
        obj = cls(model)
        obj.name = name
        obj.encoder_mode = encoder_mode
        if encoder_mode:
            obj.out_features= model.config.hidden_size
        return obj
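
The `idx.float() @ self.model.weight` product in `MultiCategoryEncoder.forward` is a compact way to sum the embedding rows of every active category in an n-hot vector. The sketch below checks this with plain Python lists; `matmul` is a hypothetical helper written only for the illustration, and the weight values are made up:

```python
def matmul(a, b):
    """Naive matrix product of two nested lists."""
    return [
        [sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
        for row in a
    ]


# 3 categories, embedding dimension 2 (made-up weights)
weight = [
    [1.0, 0.0],
    [0.0, 2.0],
    [3.0, 3.0],
]

# first row has categories 0 and 2 switched on, second row has none
multi_hot = [
    [1, 0, 1],
    [0, 0, 0],
]

pooled = matmul(multi_hot, weight)
# row 0 equals weight[0] + weight[2]; row 1 is all zeros
```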

In [26]:
# entry = TransformerEncoder.from_quantify(0)

In [27]:
# with torch.no_grad():
#     vector = entry(data['review_content'])

In [28]:
# entry = ImageConvEncoder.from_quantify(0,name="resnet18")

# with torch.no_grad():
#     vectors = entry(data['image'])

### Exit modules

In [29]:
def accuracy(y_, y):
    return (y_.argmax(-1) == y).float().mean()

def bi_accuracy(y_, y):
    return ((y_>.5).float() == y).float().mean()

class BCEWithLogitsLossCasted(nn.Module):
    def __init__(self):
        super().__init__()
        self.bce_ = nn.BCEWithLogitsLoss()
        
    def forward(self, y_, y):
        return self.bce_(y_,y.float())
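
To make the two metrics concrete, here is a dependency-free sketch of the same logic: `accuracy` compares the arg-max class with the label, while `bi_accuracy` thresholds each probability at 0.5 and compares element-wise. The list-based versions and the `sigmoid` helper are illustrative rewrites, not the tensor functions above.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def accuracy(scores, labels):
    """Share of rows whose highest-scoring class equals the label."""
    hits = [
        max(range(len(row)), key=row.__getitem__) == y
        for row, y in zip(scores, labels)
    ]
    return sum(hits) / len(hits)


def bi_accuracy(probs, targets, threshold=0.5):
    """Share of individual 0/1 decisions that match the targets."""
    hits = [
        float(p > threshold) == float(t)
        for prow, trow in zip(probs, targets)
        for p, t in zip(prow, trow)
    ]
    return sum(hits) / len(hits)


acc = accuracy([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]], [1, 0, 0])

logits = [[2.0, -1.0], [-3.0, 0.5]]
probs = [[sigmoid(z) for z in row] for row in logits]
bi_acc = bi_accuracy(probs, [[1, 0], [1, 1]])
```

The threshold is applied to probabilities, which is why the exit model runs its activation before computing metrics.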
    

In [30]:
class ExitModel(nn.Module):
    metric_funcs = dict()

    def loss_step(self, x, y):
        y_ = self(x)
        loss = self.crit(y_, y)
        metrics = dict()
        if hasattr(self, "activation"):
            y_ = self.activation(y_)
        for k, func in self.metric_funcs.items():
            metrics.update({k: func(y_, y)})
        return dict(loss=loss, y_=y_, **metrics)


class CategoryTop(ExitModel):
    prefer = "CrossEntropyLoss"
    input_dim = 2

    def __init__(self, in_features, out_features):
        super().__init__()
        self.top = nn.Linear(
            in_features=in_features, out_features=out_features)
        self.activation = nn.Softmax(dim=-1)
        self.crit = nn.CrossEntropyLoss()
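        # per-instance dict: updating the class-level dict would leak
        # this metric into every other ExitModel subclass
        self.metric_funcs = {"acc": accuracy}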

    def forward(self, x):
        return self.top(x)

    @classmethod
    def from_quantify(cls, quantify, entry_part):
        out_features = len(quantify.category)
        in_features = entry_part.out_features
        return cls(
            in_features=in_features,
            out_features=out_features,
        )


class MultiCategoryTop(ExitModel):
    prefer = "BCEWithLogitsLossCasted"
    input_dim = 2

    def __init__(self, in_features, out_features):
        super().__init__()
        self.top = nn.Linear(
            in_features=in_features, out_features=out_features)
        self.activation = nn.Sigmoid()
        self.crit = BCEWithLogitsLossCasted()
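        # per-instance dict: updating the class-level dict would leak
        # this metric into every other ExitModel subclass
        self.metric_funcs = {"acc": bi_accuracy}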

    def forward(self, x):
        return self.top(x)

    @classmethod
    def from_quantify(cls, quantify, entry_part):
        out_features = len(quantify.category)
        in_features = entry_part.out_features
        return cls(
            in_features=in_features,
            out_features=out_features,
        )


class RegressionTop(ExitModel):
    prefer = "MSELoss"
    input_dim = 2

    def __init__(self, in_features, out_features):
        super().__init__()
        self.top = nn.Linear(
            in_features=in_features, out_features=out_features)
        self.crit = nn.MSELoss()

    def forward(self, x):
        return self.top(x)

    @classmethod
    def from_quantify(cls, quantify, entry_part):
        out_features = 1
        in_features = entry_part.out_features
        return cls(
            in_features=in_features,
            out_features=out_features,
        )

### EntireModel

In [31]:
# mapping quantify to the following entry or exit model
QUANTIFY_2_ENTRY_MAP = dict({
    QuantifyImage:[
        ImageConvEncoder,
    ],
    QuantifyCategory:[
        CategoryEncoder,
    ],
    QuantifyMultiCategory:[
        MultiCategoryEncoder,
    ],
    QuantifyText:[
        TransformerEncoder,
    ],
    QuantifyNum:[
        Empty,
    ],
})
QUANTIFY_2_EXIT_MAP = dict({
    QuantifyCategory:[
        CategoryTop,
    ],
    QuantifyMultiCategory:[
        MultiCategoryTop,
    ],
    QuantifyNum:[
        RegressionTop,
    ],
})

# all entry and exit model
ENTRY_ALL = dict(
    ImageConvEncoder=ImageConvEncoder,
    CategoryEncoder=CategoryEncoder,
    MultiCategoryEncoder=MultiCategoryEncoder,
    TransformerEncoder=TransformerEncoder,
    Empty=Empty,
)
EXIT_ALL = dict(
    CategoryTop=CategoryTop,
    MultiCategoryTop=MultiCategoryTop,
    RegressionTop=RegressionTop,
)

In [32]:
def choose_models(
    quantify,
    cls_options,
    model_conf: EditableDict,
):
    def config_model(ModelClass=cls_options):
        def starting_cls(kwargs):
            model_conf[quantify.src] = dict(
                model_name=ModelClass.__name__,
                src=quantify.src,
                kwargs=kwargs,
            )

        ia = InteractiveAnnotations(
            ModelClass.from_quantify,
            description="Okay",
            icon='rocket',
            button_style='success')

        ia.register_callback(starting_cls)
        display(ia.vbox)
    inter = interact_manual(config_model)
    reconfig_manual_interact(
        inter.widget,
        description="Yes!", icon="cube", button_style='info')
    return inter


def set_model(quantify_dict: Dict[str, Quantify], phase: Phase):
    display(HTML("""<h3>Set up model structure</h3>
    <quote>You'll have to set up a model part for each of the columns</quote>"""))

    x_models = EditableDict()
    y_models = EditableDict()

    if "x_models" in phase:
        x_models + phase['x_models']
    if "y_models" in phase:
        y_models + phase['y_models']
    display(HTML("<h3>Current model config:</h3>"))
    display(HTML(f"""
    <h3 class='text-primary'>🤖 <strong>Entry</strong> parts of the model</h3>
    <h4>To understand the X columns</h4>
    """))
    display(x_models)
    display(HTML(f"""
    <h3 class='text-danger'>🦾 <strong>Exit</strong> parts of the model</h3>
    <h4>To understand & predict the Y column</h4>
    """))
    display(y_models)

    @x_models.on_update
    def update_x_models(x_models_data):
        phase['x_models'] = x_models_data

    @y_models.on_update
    def update_y_models(y_models_data):
        phase['y_models'] = y_models_data

    display(HTML("<h4>Change model config:</h4>"))
    for src, quantify in quantify_dict.items():
        if quantify.is_x:
            entry_classes = QUANTIFY_2_ENTRY_MAP.get(quantify.__class__)

            # check before iterating: .get returns None for unsupported types
            if entry_classes is None:
                Flash.danger(
                    f"We do not support {quantify.__class__} as X data",
                    key="Error!")
                continue
            entry_cls_options = dict(
                (q.__name__, q) for q in entry_classes)
            display(HTML(f"""
            <h3 class='text-primary'>Choose Model For X Columns:
            <strong>{src}</strong></h3>"""))
            choose_models(quantify, entry_cls_options, x_models)
    for src, quantify in quantify_dict.items():
        if not quantify.is_x:
            exit_classes = QUANTIFY_2_EXIT_MAP.get(quantify.__class__)
            # check before iterating: .get returns None for unsupported types
            if exit_classes is None:
                Flash.danger(
                    f"We do not support {quantify.__class__} as Y data",
                    key="Error!"
                )
                continue
            exit_cls_options = dict(
                (q.__name__, q) for q in exit_classes)
            display(HTML(f"""
            <h3 class='text-danger'>Choose Model For Y Column:
            <strong>{src}</strong></h3>"""))
            choose_models(quantify, exit_cls_options, y_models)


def set_datamodule(progress, df, qdict, phase):
    ds = TaiChiDataset(df)
    datamodule = TaiChiDataModule(ds, qdict)

    batch_level = EditableDict()
    if "batch_level" in phase:
        batch_level['batch_level'] = phase['batch_level']
    display(HTML("<h3>How we make data rows into batch</h3>"))
    display(batch_level)
    model_output = Output()

    def configure_setting(kwargs):
        batch_level['batch_level'] = kwargs
        phase['batch_level'] = kwargs
        set_model_btn_event()

    def set_model_btn_event():
        if 'batch_level' not in phase:
            Flash.warning("batch level config not set",
                          key="Warning")
            return
        datamodule.configure(**phase['batch_level'])
        progress.kwargs['datamodule'] = datamodule

        model_output.clear_output()
        with model_output:
            set_model(qdict, phase)

    interact_intercept(datamodule.configure, configure_setting)

    set_model_btn = Button(description="Set Batch",
                           icon='cog', button_style='info')
    # register through the documented API; the handler receives the button
    set_model_btn.on_click(lambda _btn: set_model_btn_event())
    display(set_model_btn)
    display(model_output)

In [33]:
class EntryDict(nn.Module):
    """
    Create entry parts for different columns
    """
    def __init__(
        self,
        phase: Phase,
        qdict: Dict[str, Quantify]
    ):
        super().__init__()
        model_dict = nn.ModuleDict()
        for src, model_cfg in phase['x_models'].items():
            quantify = qdict[src]
            
            # find column class
            model_cls = ENTRY_ALL[model_cfg['model_name']]
            # the kwargs to start the column model object
            model_kwargs = model_cfg['kwargs']
            # the model object
            model = model_cls.from_quantify(quantify, **model_kwargs)
            
            # add the model by column name
            model_dict[src] = model
        
        # calculate the output size for dimension 1 (after concatenation)
        self.out_features = sum(
            list(model.out_features for src, model in model_dict.items()))
        self.model_dict = model_dict

    def forward(self, inputs):
        outputs = []
        for src, model in self.model_dict.items():
            # input data for column
            src_input = inputs[src]
            
            # forward pass for column_model(column_data)
            outputs.append(model(src_input))
        # concat the results
        return torch.cat(outputs, dim=1)

In [34]:
class AssembledModel(pl.LightningModule):
    def __init__(
        self,
        phase: Phase,
        qdict: Dict[str, Quantify],
        entry_lr: LIST(options=[1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6], default=1e-4)=1e-4,
        exit_lr: LIST(options=[1e-1, 1e-2, 1e-3, 1e-4, ], default=1e-3)=1e-3,
    ):
        super().__init__()
        self.entry_lr = entry_lr
        self.exit_lr = exit_lr
        self.entry_dict = EntryDict(phase, qdict)
        exit_cfg = list(phase['y_models'].values())[0]
        
        self.exit_src = exit_cfg['src']
        self.exit_kwargs = exit_cfg['kwargs']
        exit_cls = EXIT_ALL[exit_cfg['model_name']]
        
        exit_quantify = qdict[self.exit_src]
        
        self.exit_part = exit_cls.from_quantify(
            exit_quantify,self.entry_dict, **self.exit_kwargs)
    
    def forward(self, inputs):
        vec = self.entry_dict(inputs)
        return self.exit_part(vec)
    
    def loss_step(self, inputs):
        vec = self.entry_dict(inputs)
        return self.exit_part.loss_step(vec, inputs[self.exit_src])
    
    def training_step(self, batch, batch_idx):
        rt = self.loss_step(batch)
        for k, v in rt.items():
            if v.numel()==1:
                self.log(f"trn_{k}", v)
        return rt['loss']

    def validation_step(self, batch, batch_idx):
        rt = self.loss_step(batch)
        for k, v in rt.items():
            if v.numel()==1:
                self.log(f"val_{k}", v)
        return rt['loss']
    
    def configure_optimizers(self,):
        param_groups = [
            {"params":self.entry_dict.parameters(), "lr":self.entry_lr},
            {"params":self.exit_part.parameters(), "lr":self.exit_lr},
        ]
        return torch.optim.Adam(param_groups)
    

In [35]:
def create_model(phase, qdict):
    if "y_models" in phase:
        y_models = phase["y_models"]
        if len(y_models)>1:
            raise ValueError("Multiple targets are not supported yet")
        else:
            return AssembledModel(phase, qdict)
    else:
        raise ValueError("phase must contain 'y_models' configuration for now")

### Training

In [36]:
def make_slug_name(phase):
    xs = '-'.join(q['src'] for q in phase['quantify'] if q["x"])
    ys = '-'.join(q['src'] for q in phase['quantify'] if not q["x"])
    return '_'.join([xs,'to',ys])

Create a vivid **name** for the task
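
As a quick illustration of the naming scheme, the sketch below repeats the `make_slug_name` logic from the cell above and applies it to a hand-written dictionary shaped like the Netflix `phase` printed further down in this notebook. The `example_phase` dict is an assumption for illustration; only its `quantify` entries matter here.

```python
def make_slug_name(phase):
    # input columns (x is True) joined by "-", then "to", then the target columns
    xs = '-'.join(q['src'] for q in phase['quantify'] if q['x'])
    ys = '-'.join(q['src'] for q in phase['quantify'] if not q['x'])
    return '_'.join([xs, 'to', ys])


# hypothetical phase, mirroring the Netflix configuration
example_phase = {
    "quantify": [
        {"src": "country", "x": True},
        {"src": "listed_in", "x": False},
        {"src": "cast", "x": True},
    ]
}

slug = make_slug_name(example_phase)
print(slug)  # country-cast_to_listed_in
```

The same slug is later used as the run name for the loggers, so each combination of inputs and target gets its own log folder.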

In [37]:
def set_trainer(
    phase,
    project:STR(default="default",) = "default",
    tensorboard: BOOL(default=True)=True,
    show_metric: BOOL(default=True)=True,
    max_epochs: INT(min_=1, max_=200, default=5)=5,
    use_gpu:BOOL(default=True)=True,
):
    if project=='default':
        global PROJECT
        if str(PROJECT)!="None":
            project = PROJECT
        else:
            project = "./project"
    project = Path(project)
    TASK_SLUG = make_slug_name(phase)
    csv_logger = pl.loggers.CSVLogger(project/"csv_log", name = TASK_SLUG, )
    loggers = [
        csv_logger
    ]
    if tensorboard:
        loggers.append(
            pl.loggers.TensorBoardLogger(save_dir=project/'tensorboard', name=TASK_SLUG)
        )
    rt = dict(
        max_epochs = max_epochs,
        logger =loggers)
    callbacks = []
    if show_metric:
        callbacks.append(
            DataFrameMetricsCallback())
        
    rt.update({"callbacks":callbacks})
    
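    if use_gpu:
        # note: the `gpus` argument was removed in Lightning 2.0;
        # newer versions expect accelerator="gpu", devices=1 instead
        rt.update(dict(gpus=1))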
#     rt.update(dict(
#         auto_select_gpus=True,
#     ))
    return rt

def run_training(phase, final_model, datamodule):
    def set_trainer_callback(kwargs):
        task_slug = phase['task_slug']
        Flash.info(f"Create trainer for task: {task_slug}",key="Notice")
        trainer_kwargs = set_trainer(phase, **kwargs)
        trainer = pl.Trainer(**trainer_kwargs)
        Flash.success("Start training, this is not a drill!",key="Alert!")
        trainer.fit(final_model, datamodule=datamodule)
        return trainer
    return set_trainer_callback
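
`run_training` returns a closure rather than training right away, so the trainer settings can be collected first (in this notebook, through `interact_intercept`) and handed over as a `kwargs` dict. The sketch below imitates that hand-off without widgets or Lightning. `collect_defaults`, `fake_set_trainer` and `run_training_sketch` are hypothetical stand-ins, not part of the library; they simply read the default values from a function signature and pass them on.

```python
import inspect
from typing import Any, Callable, Dict


def collect_defaults(func: Callable) -> Dict[str, Any]:
    """Hypothetical stand-in for the interactive form: keep every default value."""
    return {
        name: param.default
        for name, param in inspect.signature(func).parameters.items()
        if param.default is not inspect.Parameter.empty
    }


def fake_set_trainer(phase, max_epochs: int = 5, use_gpu: bool = True) -> dict:
    # stands in for set_trainer: turns user settings into trainer kwargs
    return {"max_epochs": max_epochs, "use_gpu": use_gpu, "task": phase["task_slug"]}


def run_training_sketch(phase) -> Callable[[Dict[str, Any]], dict]:
    def callback(kwargs: Dict[str, Any]) -> dict:
        # the settings arrive only once the form has been filled in
        return fake_set_trainer(phase, **kwargs)
    return callback


callback = run_training_sketch({"task_slug": "country-cast_to_listed_in"})
settings = collect_defaults(fake_set_trainer)  # phase has no default, so it is skipped
trainer_kwargs = callback(settings)
print(trainer_kwargs)
```

The point of the closure is that `phase`, the model and the datamodule are fixed when the step is opened, while the trainer settings stay editable until the user confirms them.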

## Steps

In [38]:
def step_enrich(**kwargs):
    df = kwargs['df']
    phase = kwargs['phase']
    progress = kwargs['progress']
    set_enrich(df=df, phase=phase)


def step_quantify(**kwargs):
    df = kwargs.get('df')
    phase = kwargs.get('phase')
    progress = kwargs.get('progress')
    execute_enrich(df=df, phase=phase)
    # creating dataset
    ds = TaiChiDataset(df)
    progress.kwargs['dataset'] = ds
    # preview a row of data
    display(HTML("<h3>A row of data</h3>"))

    @interact
    def show_row(idx=IntSlider(min=0, max=min(len(ds) - 1, 30))):  # last valid index is len(ds) - 1
        list_group_kv(ds[idx])()
    choose_xy(progress=progress, df=df, phase=phase)


def step_modeling(**kwargs):
    df = kwargs.get('df')
    phase = kwargs.get('phase')
    progress = kwargs.get('progress')

    qdict = execute_quantify(df=df, phase=phase)
    progress.kwargs['qdict'] = qdict
    set_datamodule(progress, df, qdict, phase)


def step_training(**kwargs):
    df = kwargs.get('df')
    phase = kwargs.get('phase')
    progress = kwargs.get('progress')
    qdict = kwargs.get('qdict')
    datamodule = kwargs.get('datamodule')

    if (qdict is None) or (datamodule is None):
        # stop here: the model cannot be built without the previous step's results
        print("Please finish the previous step first")
        return
    final_model = create_model(phase, qdict)
    save_phase()
    TASK_SLUG = make_slug_name(phase)
    phase['task_slug'] = TASK_SLUG
    progress.kwargs['model'] = final_model
    interact_intercept(set_trainer, 
                       run_training(
                            phase, final_model, datamodule)
                      )

In [39]:
STEPS_MAP: Dict[str, Callable] = {
    "Enrich": step_enrich,
    "Quantify": step_quantify,
    "Model": step_modeling,
    "Train": step_training}

In [40]:
class TaiChiLearn:
    """
    A dataframe please
    then we learn
    """
    def __init__(self, df, phase):
        self.progress = StepByStep(
            STEPS_MAP, kwargs={"df":df, "phase":phase})
        
    def __call__(self):
        self.progress()
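
The steps communicate through a shared `kwargs` dictionary: each step receives it as keyword arguments and may store new entries (such as `dataset` or `qdict`) on `progress.kwargs` for later steps. `StepByStep` itself is defined elsewhere; the class below is a hypothetical, widget-free stand-in that only shows this data flow, and it assumes the runner injects itself under the `progress` key.

```python
from typing import Callable, Dict


class MiniStepByStep:
    """Hypothetical stand-in for StepByStep: runs the steps in order, no buttons."""

    def __init__(self, steps: Dict[str, Callable], kwargs: dict):
        self.steps = steps
        self.kwargs = kwargs
        # assumption: every step can reach the runner through `progress`
        self.kwargs["progress"] = self

    def __call__(self):
        for name, step in self.steps.items():
            print(f"Running step: {name}")
            step(**self.kwargs)


def step_quantify(**kwargs):
    progress = kwargs["progress"]
    progress.kwargs["dataset"] = list(kwargs["df"])  # pretend dataset


def step_modeling(**kwargs):
    progress = kwargs["progress"]
    # relies on what the previous step stored
    progress.kwargs["n_rows"] = len(kwargs["dataset"])


progress = MiniStepByStep(
    {"Quantify": step_quantify, "Model": step_modeling},
    kwargs={"df": [1, 2, 3], "phase": {}},
)
progress()
print(progress.kwargs["n_rows"])  # 3
```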

## Demo tasks

> Load all the code above in one shot; the demo starts here

### Extra helpers
> These are helper functions related to the task; they ```only``` modify the dataframe

In [41]:
import random


def df_creator_image_folder(path: Path) -> pd.DataFrame:
    """
    Create a dataframe
    that lists all the image paths under a system folder (in shuffled order)
    """
    path = Path(path)
    files = []
    formats = ["jpg", "jpeg", "png"]
    for fmt in formats:
        files.extend(path.rglob(f"*.{fmt.lower()}"))
        files.extend(path.rglob(f"*.{fmt.upper()}"))
    return pd.DataFrame({"path": files}).sample(frac=1.).reset_index(drop=True)


def noise():
    return random.random()*.1


def turn_bear_dataset_to_regression(base_df):
    base_df["grizzly_score"] = base_df['path'].apply(
        lambda x: .9 + noise() if Path(x).parent.name == 'grizzly' else .1 + noise())
    return base_df
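
To make the synthetic target concrete: `noise()` is uniform between 0 and 0.1, so grizzly images get a score of roughly 0.9 to 1.0 and every other folder a score of roughly 0.1 to 0.2. The sketch below applies the same rule to a few made-up paths (the file names are hypothetical) without needing pandas or the dataset on disk.

```python
import random
from pathlib import Path


def noise():
    return random.random() * .1


def grizzly_score(path) -> float:
    # same rule as turn_bear_dataset_to_regression, for a single path
    return .9 + noise() if Path(path).parent.name == 'grizzly' else .1 + noise()


# hypothetical paths, only the parent folder name matters
paths = ["bears/grizzly/a.jpg", "bears/black/b.jpg", "bears/teddys/c.jpg"]
scores = {p: grizzly_score(p) for p in paths}
for p, s in scores.items():
    print(f"{p}: {s:.3f}")
```

This turns the three-class bear folder into a simple regression target, which is handy for testing an image-to-number pipeline.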

### Choose dataset

In [42]:
# BEAR_DATASET = HOME/"Downloads"/"bear_dataset"
BEAR_DATASET = Path("/GCI/data/bear_dataset")
ROTTEN_TOMATOES = Path("/GCI/data/rttmt")
NETFLIX = Path("/GCI/data/nf")

Choose one of the following to run 

#### Netflix 📺

In [43]:
base_df = pd.read_csv(NETFLIX/"netflix_titles.csv")
base_df.head()

Unnamed: 0,show_id,type,title,director,cast,country,date_added,release_year,rating,duration,listed_in,description
0,s1,Movie,Dick Johnson Is Dead,Kirsten Johnson,,United States,"September 25, 2021",2020,PG-13,90 min,Documentaries,"As her father nears the end of his life, filmm..."
1,s2,TV Show,Blood & Water,,"Ama Qamata, Khosi Ngema, Gail Mabalane, Thaban...",South Africa,"September 24, 2021",2021,TV-MA,2 Seasons,"International TV Shows, TV Dramas, TV Mysteries","After crossing paths at a party, a Cape Town t..."
2,s3,TV Show,Ganglands,Julien Leclercq,"Sami Bouajila, Tracy Gotoas, Samuel Jouy, Nabi...",,"September 24, 2021",2021,TV-MA,1 Season,"Crime TV Shows, International TV Shows, TV Act...",To protect his family from a powerful drug lor...
3,s4,TV Show,Jailbirds New Orleans,,,,"September 24, 2021",2021,TV-MA,1 Season,"Docuseries, Reality TV","Feuds, flirtations and toilet talk go down amo..."
4,s5,TV Show,Kota Factory,,"Mayur More, Jitendra Kumar, Ranjan Raj, Alam K...",India,"September 24, 2021",2021,TV-MA,2 Seasons,"International TV Shows, Romantic TV Shows, TV ...",In a city of coaching centers known to train I...


#### The bear 🐻

In [43]:
base_df = df_creator_image_folder(BEAR_DATASET)
base_df = turn_bear_dataset_to_regression(base_df)
base_df

Unnamed: 0,path,grizzly_score
0,/GCI/data/bear_dataset/grizzly/00000166.jpg,0.930374
1,/GCI/data/bear_dataset/grizzly/00000080.jpg,0.962466
2,/GCI/data/bear_dataset/black/00000154.jpg,0.154357
3,/GCI/data/bear_dataset/teddys/00000047.jpg,0.198507
4,/GCI/data/bear_dataset/black/00000142.jpg,0.173941
...,...,...
517,/GCI/data/bear_dataset/grizzly/00000091.jpg,0.977294
518,/GCI/data/bear_dataset/grizzly/00000040.jpg,0.914690
519,/GCI/data/bear_dataset/teddys/00000001.jpg,0.131970
520,/GCI/data/bear_dataset/teddys/00000130.jpg,0.105197


#### The rotten tomatoes 🍅 🎬

In [43]:
# the rotten tomatoes dataset, we are not using every line
base_df = pd.read_csv(ROTTEN_TOMATOES/'critic_reviews.csv', nrows=200000)
# drop rows that miss the score, the review text or the critic name
base_df = base_df.dropna(
    subset=['review_score', 'review_content', 'critic_name']).reset_index(drop=True)

base_df = base_df[base_df['review_score'].apply(lambda x: "/" in x)].reset_index(drop=True)

base_df['review_score'] = base_df['review_score'].apply(eval)

base_df

Unnamed: 0,rotten_tomatoes_link,critic_name,top_critic,publisher_name,review_type,review_score,review_date,review_content
0,m/0814255,Ben McEachen,False,Sunday Mail (Australia),Fresh,0.700,2010-02-09,Whether audiences will get behind The Lightnin...
1,m/0814255,Nick Schager,False,Slant Magazine,Rotten,0.250,2010-02-10,Harry Potter knockoffs don't come more transpa...
2,m/0814255,Bill Goodykoontz,True,Arizona Republic,Fresh,0.700,2010-02-10,"Percy Jackson isn't a great movie, but it's a ..."
3,m/0814255,Jim Schembri,True,The Age (Australia),Fresh,0.600,2010-02-10,"Crammed with dragons, set-destroying fights an..."
4,m/0814255,Mark Adams,False,Daily Mirror (UK),Fresh,0.800,2010-02-10,"This action-packed fantasy adventure, based on..."
...,...,...,...,...,...,...,...,...
108237,m/bottle_shock,Phil Villarreal,False,Arizona Daily Star,Rotten,0.500,2008-08-29,"It might have worked better as a documentary, ..."
108238,m/bottle_shock,Todd Gilchrist,False,IGN Movies,Rotten,0.400,2008-08-29,Bottle Shock feels more like an excuse to exer...
108239,m/bottle_shock,Austin Kennedy,False,Sin Magazine,Rotten,0.625,2008-09-02,"I was slightly involved towards the end, but t..."
108240,m/bottle_shock,Sean P. Means,False,Salt Lake Tribune,Rotten,0.500,2008-09-05,"Flat, musty and with a hint of flopsweat."
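
The `review_score` column holds strings such as `"3.5/5"`, which the cell above turns into numbers with `eval`. Since `eval` executes whatever text it is given, a stricter parser may be preferable for data from an untrusted source. The helper below, `parse_score`, is a hypothetical alternative (not part of the notebook) that accepts only a `numerator/denominator` pair.

```python
def parse_score(score: str) -> float:
    """Turn a 'numerator/denominator' string into a float, without eval."""
    numerator, denominator = score.split("/")
    return float(numerator) / float(denominator)


examples = ["7/10", "2.5/4", "1/4"]
parsed = [parse_score(s) for s in examples]
print(parsed)  # [0.7, 0.625, 0.25]
```

With pandas it could be applied as `base_df['review_score'].apply(parse_score)`. A value that is not two numbers separated by a single slash raises `ValueError` instead of being executed.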


### Where the model & config are going to end up
They end up in a ```PROJECT``` folder

In [44]:
HOME = Path.home()  # portable: does not require the HOME environment variable
# PROJECT = Path("./project")
# PROJECT = Path("./project/image_regression")
# PROJECT = Path("./project/rotten1")
# PROJECT = Path("./project/rotten_text")
PROJECT = Path("./project/netflix")
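
Given `PROJECT` and the task slug, `set_trainer` above passes `project/"csv_log"` and `project/'tensorboard'` to the two loggers, each with the slug as the run name. The sketch below only builds those paths with `pathlib`. The slug value is an assumed example, and any sub-folders the loggers create beneath these paths (for instance version folders) are left out.

```python
from pathlib import Path

PROJECT = Path("./project/netflix")
task_slug = "country-cast_to_listed_in"  # assumed example slug

# save_dir / name, as handed to the CSV and TensorBoard loggers
csv_dir = PROJECT / "csv_log" / task_slug
tensorboard_dir = PROJECT / "tensorboard" / task_slug

print(csv_dir.as_posix())          # project/netflix/csv_log/country-cast_to_listed_in
print(tensorboard_dir.as_posix())  # project/netflix/tensorboard/country-cast_to_listed_in
```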

### Start of the pipeline

Initialize the ```phase``` to track the configuration

In [45]:
load_phase()
# phase = Phase()

{
  "quantify": [
    {
      "src": "country",
      "x": true,
      "kwargs": {
        "min_frequency": 1,
        "separator": ","
      },
      "quantify": "QuantifyMultiCategory"
    },
    {
      "src": "listed_in",
      "x": false,
      "kwargs": {
        "min_frequency": 1,
        "separator": ","
      },
      "quantify": "QuantifyMultiCategory"
    },
    {
      "src": "cast",
      "x": true,
      "kwargs": {
        "min_frequency": 1,
        "separator": ","
      },
      "quantify": "QuantifyMultiCategory"
    }
  ],
  "batch_level": {
    "valid_ratio": 0.1,
    "batch_size": 8,
    "shuffle": true,
    "num_workers": 0
  },
  "x_models": {
    "country": {
      "model_name": "MultiCategoryEncoder",
      "src": "country",
      "kwargs": {
        "embedding_dim": 64
      }
    },
    "cast": {
      "model_name": "MultiCategoryEncoder",
      "src": "cast",
      "kwargs": {
        "embedding_dim": 64
      }
    }
  },
  "y_models": {
    "listed_in": 

In [46]:
learn = TaiChiLearn(base_df, phase)
learn()

(Interactive widget output: a row of step buttons, the first labelled "1:Enrich", followed by an output area.)

In [49]:
phase

Phase:{
  "quantify": [
    {
      "src": "country",
      "x": true,
      "kwargs": {
        "min_frequency": 1,
        "separator": ","
      },
      "quantify": "QuantifyMultiCategory"
    },
    {
      "src": "listed_in",
      "x": false,
      "kwargs": {
        "min_frequency": 1,
        "separator": ","
      },
      "quantify": "QuantifyMultiCategory"
    },
    {
      "src": "cast",
      "x": true,
      "kwargs": {
        "min_frequency": 1,
        "separator": ","
      },
      "quantify": "QuantifyMultiCategory"
    }
  ],
  "batch_level": {
    "valid_ratio": 0.1,
    "batch_size": 8,
    "shuffle": true,
    "num_workers": 0
  },
  "x_models": {
    "country": {
      "model_name": "MultiCategoryEncoder",
      "src": "country",
      "kwargs": {
        "embedding_dim": 64
      }
    },
    "cast": {
      "model_name": "MultiCategoryEncoder",
      "src": "cast",
      "kwargs": {
        "embedding_dim": 64
      }
    }
  },
  "y_models": {
    "listed

In [48]:
datamodule = learn.progress.kwargs['datamodule']
model = learn.progress.kwargs['model']