In [2]:
#pip install dash
#pip install dash-renderer
#pip install dash-html-components
#pip install dash-core-components
#pip install plotly

# Model (Package)

In [1]:
from flair.data import Sentence
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from flair.embeddings import TransformerWordEmbeddings,StackedEmbeddings,WordEmbeddings,BytePairEmbeddings
from tqdm import tqdm
import datetime
from sklearn.model_selection import train_test_split
torch.manual_seed(1)

<torch._C.Generator at 0x7fbb2f27a2d0>

# UI

## Package

In [2]:
import dash
import dash_core_components as dcc
import dash_html_components as html
import plotly.express as px
import pandas as pd
from dash.dependencies import Input, Output, State
import base64
import datetime
import io
import dash_table
import pandas as pd
import base64
import io

## Layout

In [3]:
colors = {
    'background': '#B8D1DE',
    'text': '#454E53',
    'c1': '#DBE8EE',
    'c2':'#F6F9FB'
}

In [4]:
external_stylesheets = ['https://codepen.io/chriddyp/pen/bWLwgP.css']
app = dash.Dash(__name__, external_stylesheets=external_stylesheets)

In [5]:
app.layout = html.Div(style={'backgroundColor': colors['c2']}, children = [
                # 标题
                html.Div(
                    id="banner",
                    className="banner",
                    children=[html.H2("增值税编码分类器")],
                    style={'textAlign': 'center',
                          'color': colors['text']}
                ),
                # Tab
                dcc.Tabs([
                    # Label 1
                    dcc.Tab(label='编码搜索', children=[
                        
                        # 左侧栏
                        html.Div(
                            id="left-column",
                            className="four columns",
                            style={'textAlign': 'left',
                                   'backgroundColor': colors['c1']},
                            children=[
                            # Input
                            html.Br(),
                            html.H6("输入商品名称："),
                            html.Div(dcc.Input(id='input-on-submit', type='text', placeholder="商品")),
                            html.Button('Submit', id='submit-val', n_clicks=0),
                            # Slider
                            html.Br(),
                            html.Br(),
                            html.H6('选择输出数量：'),
                            html.Div(dcc.Slider(
                                id='num_slider',
                                min=0,
                                max=5,
                                marks={i: 'Label {}'.format(i) if i == 1 else str(i) for i in range(1, 6)},
                                value=5,
                            )),
                            # 下拉栏
                            html.Br(),
                            html.Br(),
                            html.H6('选择正确编码：'),
                            dcc.Markdown('''（如果右侧表格内无正确分类可在此选择正确分类，帮助提升搜索的准确性）'''),
                            html.Div([
                                dcc.Dropdown(
                                    id='demo-dropdown',
                                    options=[{'label': pd.read_excel('best_5.xlsx')["合并栏"].iloc[i], 'value': pd.read_excel('best_5.xlsx')["label"].iloc[i]} for i in range(len(pd.read_excel('best_5.xlsx')))],
                                    placeholder="输入商品的正确编码"
                                ),
                                html.Div(id='dd-output-container')
                            ]),
                            html.Br(),
                            html.Br(),
                            html.Br(),
                        ]),
                        # Output
                        html.Div([
                            html.Div(id='container-button-basic',
                                     children='商品编码表',
                                     style={'textAlign': 'left',
                                            'backgroundColor': colors['c2'],
                                            'font-size': '11.5px'})
                        ]),
                    ]),
                    
                    # Label 2
                    dcc.Tab(label='模型训练', id='field-dropdown', children=[
                        # 例子下载
                        html.Div([
                            html.Br(),
                            html.Br(),
                            html.H6('上传表格示例'),
                            dcc.Markdown('''（保证列名与下表相符）'''),
                            html.Div(id='table'),
                            html.A(
                                'Download Data',
                                id='download-link',
                                download="rawdata.csv",
                                href="",
                                target="_blank"
                            )
                        ]),
                        
                        # Train new data
                        html.Div([
                            html.Br(),
                            html.Br(),
                            dcc.Markdown('''在下方上传表格开始训练模型，模型训练好后会提示“训练完成”'''),
                            dcc.Upload(
                                id='upload-data',
                                children=html.Div([
                                    'Drag and Drop or ',
                                    html.A('Select Files')
                                ]),
                                style={
                                    'width': '100%',
                                    'height': '60px',
                                    'lineHeight': '60px',
                                    'borderWidth': '1px',
                                    'borderStyle': 'dashed',
                                    'borderRadius': '5px',
                                    'textAlign': 'center',
                                    'margin': '10px'
                                },
                                # Not Allow multiple files to be uploaded
                                multiple=False
                            ),
                            html.Br(),
                            html.Div(id='output-data-upload'),
                            html.Br(),
                        ])
                    ]),
                ], colors={"border": colors['c2'],
                           "primary": colors['text'],
                           "background": colors['background']
                        })
             ])
    

## Callback

### Callback for In & Out

In [6]:
class BiLSTM(nn.Module):
    def __init__(self, tag_to_ix,hidden_dim):
        super(BiLSTM, self).__init__()
        self.hidden_dim = hidden_dim
        self.lstm = nn.LSTM(3472, hidden_dim // 2,
                            num_layers=1, bidirectional=True)
        self.hidden2tag = nn.Linear(hidden_dim,len(tag_to_ix))

    def init_hidden(self):
        return (torch.randn(2, 1, self.hidden_dim // 2).to(device),
                torch.randn(2, 1, self.hidden_dim // 2).to(device))

    def forward(self, embeddings):
        self.hidden = self.init_hidden()
        embeds = embeddings.view(embeddings.shape[0], 1, -1)
        lstm_out, self.hidden = self.lstm(embeds)
        lstm_out = lstm_out[0].view(-1, self.hidden_dim)
        lstm_feats = self.hidden2tag(lstm_out)
        return lstm_feats

In [7]:
word_embedding = WordEmbeddings('zh')
byte_embedding = BytePairEmbeddings('zh')
bert_embedding = TransformerWordEmbeddings('bert-base-chinese')
stacked_embeddings = StackedEmbeddings(embeddings=[word_embedding,byte_embedding,bert_embedding])

In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [9]:
def main(item, num):    
    model = torch.load('model.pt')
    
    def argmax(vec,k):
        prob, idx = torch.torch.topk(vec, k)
        return prob.tolist(),idx.tolist()
    test_embedding_matrix = []
    for i in tqdm(range(1)):
            embeddings = []
            sentence = Sentence(item)
            stacked_embeddings.embed(sentence)
            for token in sentence:
              embeddings.append(token.embedding)
            embeddings = torch.stack(embeddings)
            embeddings = embeddings.view(-1,1,3472)
            test_embedding_matrix.append(embeddings)
    for i in range(1):
      embedding = test_embedding_matrix[i]
      model.eval()
      result = model(embedding)
      prob = F.softmax(result,dim=1)
      prob = argmax(prob,num)
      output = [item for sublist in prob[1] for item in sublist]
      prob = [item for sublist in prob[0] for item in sublist]
      return output


In [10]:
@app.callback(
    dash.dependencies.Output('container-button-basic', 'children'),
    [dash.dependencies.Input('submit-val', 'n_clicks'),
     dash.dependencies.Input('num_slider', 'value')],
    [dash.dependencies.State('input-on-submit', 'value')])
def update_output(n_clicks, num, item): 
    b_5 = pd.read_excel('best_5.xlsx', index_col = 0)
    output = main(item, num)
    return html.Table([
        html.Thead(
            html.Tr([html.Th(col) for col in ['编码','货物和劳务名称','商品和服务分类简称','说明']])
        ),
        html.Tbody([
            html.Tr([
                html.Td(b_5.iloc[i][col]) for col in [0,2,3,4]
            ]) for i in output
        ])
    ])

### Callback for Dropdown

In [11]:
def online_learning(X_train,y_train,lr=0.0001,epoch=15):
    model = torch.load('model.pt')
    
    train_embedding_matrix = []
    for i in tqdm(range(len(X_train))):
        embeddings = []
        sentence = Sentence(X_train[i])
        stacked_embeddings.embed(sentence)
        for token in sentence:
          embeddings.append(token.embedding)
        embeddings = torch.stack(embeddings)
        embeddings = embeddings.view(-1,1,3472)
        train_embedding_matrix.append(embeddings)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    criterion = nn.CrossEntropyLoss()
    for epoch in range(epoch): 
        train_loss = 0
        model.train()
        for i in tqdm(range(len(X_train))):
            embeddings = train_embedding_matrix[i]
            tags_index = y_train[i]
            model.zero_grad()
            outputs = model(embeddings)
            targets = torch.tensor(tags_index, dtype=torch.long).to(device)
            loss = criterion(outputs,targets)
            loss.backward()
            optimizer.step()
            train_loss+=loss.item()  
    torch.save(model,'model.pt')

In [12]:
@app.callback(
    dash.dependencies.Output('dd-output-container', 'children'),
    [dash.dependencies.Input('demo-dropdown', 'value')],
    [dash.dependencies.State('input-on-submit', 'value')])
def update_output(dropdown, input_value):
    online_learning([input_value], [[dropdown]])
    return html.H6("选择成功")

### Callback for Example Download

In [13]:
df_down = pd.DataFrame({
    '商品名称': ['甜奶油', '铝盖PET胶瓶'],
    '税收分类编码': [1030204020000000000,1070601120000000000],
    '搜索项': ['金钻甜奶油 植物淡奶油 做蛋糕裱花原料1kg', '透明塑料食品罐 密封罐 储物罐子']
})

In [14]:
def generate_table(dataframe, max_rows=10):
    return html.Table(
        # Header
        [html.Tr([html.Th(col) for col in dataframe.columns])] +

        # Body
        [html.Tr([
            html.Td(dataframe.iloc[i][col]) for col in dataframe.columns
        ]) for i in range(min(len(dataframe), max_rows))]
    )

In [15]:
@app.callback(
    dash.dependencies.Output('table', 'children'),
    [dash.dependencies.Input('field-dropdown', 'value')])
def update_table(filter_value):
    return generate_table(df_down)

In [16]:
@app.callback(
    dash.dependencies.Output('download-link', 'href'),
    [dash.dependencies.Input('field-dropdown', 'value')])
def update_download_link(filter_value):
    csv_string = df_down.to_csv(index=False, encoding='utf-8')
    csv_string = "data:text/csv;charset=utf-8,%EF%BB%BF" + urllib.parse.quote(csv_string)
    return csv_string

### Callback for Trainning new model

In [17]:
def get_file(contents, filename, date):
    content_type, content_string = contents.split(',')

    decoded = base64.b64decode(content_string)
    try:
        if 'csv' in filename:
            # Assume that the user uploaded a CSV file
            df = pd.read_csv(
                io.StringIO(decoded.decode('utf-8')))
        elif 'xls' in filename:
            # Assume that the user uploaded an excel file
            df = pd.read_excel(io.BytesIO(decoded))
    except Exception as e:
        print(e)
        return html.Div([
            'There was an error processing this file.'
        ])

    return df

In [18]:
def get_model(data):
    X = data['搜索项'].values
    y = data['税收分类编码'].values
    
    def to_index(data, to_ix):
        input_index_list = []
        for sent in data:
            input_index_list.append([to_ix[sent]])
        return input_index_list
    
    def to_index(data, to_ix):
        input_index_list = []
        for sent in data:
            input_index_list.append([to_ix[sent]])
        return input_index_list
    
    tag_to_ix = {}
    for tag in y:
      if tag not in tag_to_ix:
        tag_to_ix[tag] = len(tag_to_ix)
        
    y = to_index(y,tag_to_ix)
    
    X_train, X_test, y_train, y_test = train_test_split(X,y,test_size=0.33)
    
    train_embedding_matrix = []
    for i in tqdm(range(len(X_train))):
            embeddings = []
            sentence = Sentence(X_train[i])
            stacked_embeddings.embed(sentence)
            for token in sentence:
              embeddings.append(token.embedding)
            embeddings = torch.stack(embeddings)
            embeddings = embeddings.view(-1,1,3472)
            train_embedding_matrix.append(embeddings)
    
    test_embedding_matrix = []
    for i in tqdm(range(len(X_test))):
            embeddings = []
            sentence = Sentence(X_test[i])
            stacked_embeddings.embed(sentence)
            for token in sentence:
              embeddings.append(token.embedding)
            embeddings = torch.stack(embeddings)
            embeddings = embeddings.view(-1,1,3472)
            test_embedding_matrix.append(embeddings)
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = BiLSTM(tag_to_ix,400).to(device)
    
    optimizer = optim.AdamW(model.parameters(), lr=0.00001, weight_decay=1e-4)
    criterion = nn.CrossEntropyLoss()
    best_model = None
    best_accuracy = 0
    for epoch in range(15):  
        train_loss = 0
        model.train()
        for i in tqdm(range(len(X_train))):
            embeddings = train_embedding_matrix[i]
            tags_index = y_train[i]
            model.zero_grad()
            outputs = model(embeddings)
            targets = torch.tensor(tags_index, dtype=torch.long).to(device)
            loss = criterion(outputs,targets)
            loss.backward()
            optimizer.step()
            train_loss+=loss.item()  
    
    torch.save(model, 'model.pt')
    return tag_to_ix
            
    

In [19]:
def merged_label(tag_to_ix, data = pd.read_excel('分类编码表_4col.xlsx',index_col=0)):
    df = pd.DataFrame(tag_to_ix.items(),columns=['编码', 'label'])
    df_1 = df.merge(data, left_on='编码', right_on='合并编码', suffixes=(False, False), how= 'left')
    df_2 = df_1.sort_values(by=['label'],ascending=True)
    df_2["合并栏"] = df_2["编码"].astype(str) + df_2["货物和劳务名称"]
    return df_2.to_excel("best_5.xlsx")

In [20]:
@app.callback(Output('output-data-upload', 'children'),
              [Input('upload-data', 'contents')],
              [State('upload-data', 'filename'),
               State('upload-data', 'last_modified')])
def update_output(list_of_contents, list_of_names, list_of_dates):
    if list_of_contents is not None:
        #html.H3("Updating...")
        df = get_file(list_of_contents, list_of_names, list_of_dates)
        tag = get_model(df)
        merged_label(tag)
        return html.H6("训练完成")

In [None]:
if __name__ == '__main__':
    app.run_server(debug=False)

Dash is running on http://127.0.0.1:8050/

 in production, use a production WSGI server like gunicorn instead.

 * Serving Flask app "__main__" (lazy loading)
 * Environment: production
[2m   Use a production WSGI server instead.[0m
 * Debug mode: off


 * Running on http://127.0.0.1:8050/ (Press CTRL+C to quit)
127.0.0.1 - - [31/Jul/2020 09:59:33] "[37mGET / HTTP/1.1[0m" 200 -
127.0.0.1 - - [31/Jul/2020 09:59:34] "[37mGET /_dash-dependencies HTTP/1.1[0m" 200 -
127.0.0.1 - - [31/Jul/2020 09:59:34] "[37mGET /_dash-layout HTTP/1.1[0m" 200 -
127.0.0.1 - - [31/Jul/2020 09:59:34] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [31/Jul/2020 09:59:34] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
  0%|          | 0/1 [00:00<?, ?it/s]

Exception on /_dash-update-component [POST]
Traceback (most recent call last):
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/flask/app.py", line 2447, in wsgi_app
    response = self.full_dispatch_request()
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/flask/app.py", line 1952, in full_dispatch_request
    rv = self.handle_user_exception(e)
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/flask/app.py", line 1821, in handle_user_exception
    reraise(exc_type, exc_value, tb)
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/flask/_compat.py", line 39, in reraise
    raise value
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/flask/app.py", line 1950, in full_dispatch_request
    rv = self.dispatch_request()
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/flask/


127.0.0.1 - - [31/Jul/2020 09:59:34] "[35m[1mPOST /_dash-update-component HTTP/1.1[0m" 500 -
  0%|          | 0/1 [00:00<?, ?it/s][A
  0%|          | 0/1 [00:00<?, ?it/s]

Exception on /_dash-update-component [POST]
Traceback (most recent call last):
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/flask/app.py", line 2447, in wsgi_app
    response = self.full_dispatch_request()
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/flask/app.py", line 1952, in full_dispatch_request
    rv = self.handle_user_exception(e)
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/flask/app.py", line 1821, in handle_user_exception
    reraise(exc_type, exc_value, tb)
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/flask/_compat.py", line 39, in reraise
    raise value
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/flask/app.py", line 1950, in full_dispatch_request
    rv = self.dispatch_request()
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/flask/

127.0.0.1 - - [31/Jul/2020 09:59:34] "[35m[1mPOST /_dash-update-component HTTP/1.1[0m" 500 -



Exception on /_dash-update-component [POST]
Traceback (most recent call last):
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/flask/app.py", line 2447, in wsgi_app
    response = self.full_dispatch_request()
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/flask/app.py", line 1952, in full_dispatch_request
    rv = self.handle_user_exception(e)
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/flask/app.py", line 1821, in handle_user_exception
    reraise(exc_type, exc_value, tb)
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/flask/_compat.py", line 39, in reraise
    raise value
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/flask/app.py", line 1950, in full_dispatch_request
    rv = self.dispatch_request()
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/flask/

127.0.0.1 - - [31/Jul/2020 09:59:34] "[35m[1mPOST /_dash-update-component HTTP/1.1[0m" 500 -
100%|██████████| 1/1 [00:00<00:00,  8.72it/s]
127.0.0.1 - - [31/Jul/2020 09:59:40] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
100%|██████████| 1/1 [00:00<00:00, 17.58it/s]
127.0.0.1 - - [31/Jul/2020 09:59:42] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
100%|██████████| 1/1 [00:00<00:00, 16.04it/s]
127.0.0.1 - - [31/Jul/2020 09:59:43] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
100%|██████████| 1/1 [00:00<00:00, 12.55it/s]
127.0.0.1 - - [31/Jul/2020 09:59:45] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
100%|██████████| 1/1 [00:00<00:00, 16.67it/s]
127.0.0.1 - - [31/Jul/2020 09:59:46] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
