## Load a model

In [14]:
# export
import logging
import os
from pathlib import Path

import pandas as pd
import streamlit as st
from PIL import Image

from tai_chi_engine import TaiChiTrained

In [2]:
# !pip install streamlit==1.3.1

In [None]:
# export
@st.cache(allow_output_mutation=True)
def load_trained(project):
    """Load a ``TaiChiTrained`` model from the ``project`` directory.

    Decorated with ``st.cache`` (``allow_output_mutation=True``) so Streamlit
    memoizes the loaded model instead of rebuilding it on every script rerun.
    """
    project_dir = Path(project)
    logging.warning('loading... takes time')
    return TaiChiTrained(project_dir)

# The project directory can be overridden with the TAI_CHI_PROJECT environment
# variable, so the exported app.py is not tied to one hard-coded project.
trained = load_trained(os.environ.get("TAI_CHI_PROJECT", "./project/image_regression"))

slug = trained.phase['task_slug']
st.title('A Tai-Chi Engine Model:')
st.write(f'Task: {slug}')

In [29]:
# export
from tai_chi_engine.quantify import (
    QuantifyText,
    QuantifyCategory,
    QuantifyImage,
    QuantifyMultiCategory,
    QuantifyNum
)
from typing import List, Any

In [29]:
# export
def st_num(name, quantify, enrich=None):
    """Render a float slider for the numeric input field ``name``.

    The slider starts at ``quantify.mean_`` and spans mean +/- 3 * ``quantify.std_``.

    Args:
        name: field name, used as the widget label.
        quantify: the field's quantifier; only ``mean_`` and ``std_`` are read.
        enrich: unused; accepted so all widget builders share one signature.

    Returns:
        The float currently selected on the slider.
    """
    import math

    # Cast to builtin float: st.slider requires value/min/max to share one
    # numeric type, and the stored statistics may be ints or array scalars.
    mean = float(quantify.mean_)
    spread = 3 * float(quantify.std_)
    if not math.isfinite(spread) or spread <= 0:
        # A constant (or single-sample) column gives std 0 or NaN, which would
        # collapse the slider range; fall back to a unit half-width instead.
        spread = 1.0
    return st.slider(
        label=name,
        min_value=mean - spread,
        max_value=mean + spread,
        value=mean,
    )
    
def st_image(name, quantify, enrich=None):
    """Render an image uploader for input field ``name`` and preprocess the upload.

    The target size and colour mode come from the field's ``EnrichImage``
    enrichment config when one is given; otherwise 224 px and ``"RGB"``.

    Args:
        name: field name, used as the widget label.
        quantify: the field's quantifier (not read here).
        enrich: enrichment config dict (as returned by ``get_enrich``) or ``None``.

    Returns:
        The uploaded image converted to the configured mode and resized to
        ``(size, size)``, or ``None`` while no file has been uploaded.
    """
    size, convert = 224, "RGB"
    if enrich is not None and enrich.get('enrich') == "EnrichImage":
        kwargs = enrich.get('kwargs') or {}
        size = kwargs.get('size', size)
        convert = kwargs.get('convert', convert)

    uploaded_file = st.file_uploader(
        label=name, type=["jpg", "png", "jpeg", "JPG", "PNG", "JPEG"])
    if uploaded_file is None:
        return None

    # Apply the configured mode conversion (e.g. RGBA/P PNG -> RGB) so the
    # model presumably receives the same layout EnrichImage produced in training.
    image = Image.open(uploaded_file).convert(convert).resize((size, size))
    # The PIL image carries its own mode; st.image's `channels` only describes
    # numpy arrays and accepts just "RGB"/"BGR", so it is not passed here.
    st.image(image)
    return image

def st_text(name, quantify, enrich=None):
    """Render a text widget for input field ``name``.

    Fields whose ``quantify.max_length`` exceeds 64 get a multi-line
    ``st.text_area``; shorter ones get a single-line ``st.text_input``.
    ``enrich`` is unused.
    """
    is_long = quantify.max_length > 64
    if not is_long:
        return st.text_input(label=name)
    return st.text_area(label=name)
        
def st_multiselect(name, quantify, enrich=None):
    """Render a multi-select over ``quantify.category.i2c``; ``enrich`` is unused."""
    choices = quantify.category.i2c
    return st.multiselect(label=name, options=choices)

def st_select(name, quantify, enrich=None):
    """Render a single-choice select box over ``quantify.category.i2c``; ``enrich`` is unused."""
    choices = quantify.category.i2c
    return st.selectbox(label=name, options=choices)

In [30]:
# export
def get_enrich(trained, key):
    """Return the enrichment entry whose ``dst`` equals ``key``, or ``None``.

    Only consulted when the phase config has an ``'enrich'`` section; the
    first matching entry of ``trained.phase['enrich']`` wins.
    """
    if 'enrich' in trained.phase.config:
        matches = (entry for entry in trained.phase['enrich'] if entry['dst'] == key)
        return next(matches, None)
    return None


def build_app(trained):
    """Render one input widget per model input field, then predict on demand.

    For every quantifier in ``trained.qdict`` with ``is_x`` set, a widget is
    chosen by the quantifier's exact type: number slider, image uploader, text
    box, multi-select or select. The collected values are passed to
    ``trained.predict`` when the "Predict" button is pressed. A DataFrame
    result is shown as a table of its first 10 rows; anything else goes
    through ``st.write``.

    Args:
        trained: the loaded model object (as returned by ``load_trained``).
    """
    input_data = dict()
    for name, quantify in trained.qdict.items():
        if not quantify.is_x:
            continue
        # Exact-type dispatch. This must be an if/elif chain so the final
        # `else` fires only for quantifier types with no widget builder.
        qtype = type(quantify)
        if qtype is QuantifyNum:
            logging.info(f"Input Field:{name} as float number")
            input_data[name] = st_num(name, quantify)
        elif qtype is QuantifyImage:
            logging.info(f"Input Field:{name} as image")
            input_data[name] = st_image(name, quantify, get_enrich(trained, name))
        elif qtype is QuantifyText:
            logging.info(f"Input Field:{name} as text")
            input_data[name] = st_text(name, quantify)
        elif qtype is QuantifyMultiCategory:
            logging.info(f"Input Field:{name} as multi-select")
            val = st_multiselect(name, quantify)
            # Re-join the selections with the field's separator when one is
            # configured (presumably matching the raw training-data format).
            if quantify.actual_separator is not None:
                val = quantify.actual_separator.join(val)
            input_data[name] = val
        elif qtype is QuantifyCategory:
            logging.info(f"Input Field:{name} as select")
            input_data[name] = st_select(name, quantify)
        else:
            logging.warning(f"Don't know how to build Input Field:[{name}]")

    if st.button("Predict"):
        result = trained.predict(input_data)
        st.title("Result")
        if isinstance(result, pd.DataFrame):
            # Keep the page short: only the first 10 rows are rendered.
            st.table(result.head(10))
        else:
            st.write(result)

# Render the input widgets and the Predict button for the loaded model.
build_app(trained)

# Update the exported app.py file

In [11]:
import json

def load_all_code(ipynb_path):
    """Collect the source of every ``# export`` code cell in a notebook.

    Args:
        ipynb_path: path to a Jupyter ``.ipynb`` file.

    Returns:
        The bodies of all code cells whose first line starts with
        ``# export`` (marker line removed), joined by a blank line.
    """
    # .ipynb files are UTF-8 JSON; don't depend on the platform default encoding.
    with open(ipynb_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    code_blocks = []
    for cell in data["cells"]:
        if cell["cell_type"] != "code":
            continue
        source = cell["source"]
        # nbformat allows "source" to be a single string as well as a list of
        # lines; normalise to a list so `source[0]` is always the first line.
        if isinstance(source, str):
            source = source.splitlines(keepends=True)
        if not source:
            continue
        if source[0].startswith("# export"):
            code_blocks.append("".join(source[1:]))
    return "\n\n".join(code_blocks)
    


In [18]:
# Regenerate app.py from every "# export" cell of this notebook. Write UTF-8
# explicitly so non-ASCII source survives regardless of the platform locale.
with open("./app.py", "w", encoding="utf-8") as f:
    f.write(load_all_code("streamlit_on_tai-chi.ipynb"))

In [9]:
# Display the loaded model's phase configuration (PhaseConfig).
trained.phase

PhaseConfig:{
  "enrich": [
    {
      "src": "path",
      "dst": "image",
      "kwargs": {
        "convert": "RGB",
        "size": 224
      },
      "enrich": "EnrichImage"
    },
    {
      "src": "path",
      "dst": "label",
      "kwargs": {},
      "enrich": "ParentAsLabel"
    }
  ],
  "quantify": [
    {
      "src": "image",
      "x": true,
      "kwargs": {
        "mean_": "imagenet",
        "std_": "imagenet"
      },
      "quantify": "QuantifyImage"
    },
    {
      "src": "label",
      "x": false,
      "kwargs": {
        "min_frequency": 1
      },
      "quantify": "QuantifyCategory"
    }
  ],
  "batch_level": {
    "valid_ratio": 0.1,
    "batch_size": 32,
    "shuffle": true,
    "num_workers": 0
  },
  "x_models": {
    "image": {
      "model_name": "ImageConvEncoder",
      "src": "image",
      "kwargs": {
        "name": "resnet18"
      }
    }
  },
  "y_models": {
    "label": {
      "model_name": "CategoryTop",
      "src": "label",
      "kwar