### import Library

In [8]:
import time
import importlib

import dash
import dash_core_components as dcc
import dash_html_components as html
import numpy as np
import pandas as pd
from dash.dependencies import Input, Output, State
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn import datasets
from sklearn.svm import SVC
from sklearn.manifold import TSNE
import plotly.graph_objects as go
from sklearn.decomposition import PCA

import ipynb.fs.full.dash_reusable_components as drc
import ipynb.fs.full.figures as figs

### app initialize

In [9]:
app = dash.Dash(
    __name__,
    meta_tags=[
        {"name": "viewport", "content": "width=device-width, initial-scale=1.0"}
    ],
    assets_folder='C:/Users/zhaoy/OneDrive/Desktop/DashSVM/assets',
)
server = app.server

### Generate data

In [10]:
def generate_train_data(n_samples):
    fashion_idx = [
        "T-Shirt",
        "Trouser",
        "Pullover",
        "Dress",
        "Coat",
        "Sandal",
        "Shirt",
        "Sneaker",
        "Bag",
        "Ankle boot",
    ]
    idx_dic = {0: 'T-Shirt',
               1: 'Trouser',
               2: 'Pullover',
               3: 'Dress',
               4: 'Coat',
               5: 'Sandal',
               6: 'Shirt',
               7: 'Sneaker',
               8: 'Bag',
               9: 'Ankle boot'}
    df_train = pd.read_csv('fashion_mnist_data/fashion-mnist_train.csv')
    label_list = []
    df_train['label'] = df_train.label.replace(idx_dic)
    label_list = df_train['label']
    df_train.drop('label', 1, inplace=True)
    df_train = df_train.div(255.0)
    df_train.insert(0, 'label', label_list, False)
    df = df_train.sample(n=n_samples, random_state=42)
    df1 = df.loc[df['label'] != 'Ankle boot', :]
    df2 = df1.loc[df['label'] != 'Trouser', :]
    df3 = df2.loc[df['label'] != 'Bag', :]
    df4 = df3.loc[df['label'] != 'Coat', :]
    df5 = df4.loc[df['label'] != 'Sandal', :]
    df6 = df5.loc[df['label'] != 'Shirt', :]
    df7 = df6.loc[df['label'] != 'Sneaker', :]
    df8 = df7.loc[df['label'] != 'Dress', :]
    x_train = df8.loc[:, df8.columns != 'label']
    y_train= df8.label.values
    return x_train,y_train

In [11]:
def generate_test_data(n_samples):
    fashion_idx = [
        "T-Shirt",
        "Trouser",
        "Pullover",
        "Dress",
        "Coat",
        "Sandal",
        "Shirt",
        "Sneaker",
        "Bag",
        "Ankle boot",
    ]
    idx_dic = {0: 'T-Shirt',
               1: 'Trouser',
               2: 'Pullover',
               3: 'Dress',
               4: 'Coat',
               5: 'Sandal',
               6: 'Shirt',
               7: 'Sneaker',
               8: 'Bag',
               9: 'Ankle boot'}
    df_test=pd.read_csv('fashion_mnist_data/fashion-mnist_test.csv')
    label_list=[]
    df_test['label']=df_test.label.replace(idx_dic)
    label_list=df_test['label']
    df_test.drop('label', 1, inplace=True)
    df_test=df_test.div(255.0)
    df_test.insert(0, 'label',label_list, False)
    df_t = df_test.sample(n=int(n_samples/4), random_state=42)
    df1_t = df_t.loc[df_t['label'] != 'Ankle boot', :]
    df2_t = df1_t.loc[df_t['label'] != 'Trouser', :]
    df3_t = df2_t.loc[df_t['label'] != 'Bag', :]
    df4_t = df3_t.loc[df_t['label'] != 'Coat', :]
    df5_t = df4_t.loc[df_t['label'] != 'Sandal', :]
    df6_t = df5_t.loc[df_t['label'] != 'Shirt', :]
    df7_t = df6_t.loc[df_t['label'] != 'Sneaker', :]
    df8_t = df7_t.loc[df_t['label'] != 'Dress', :]
    x_test = df8_t.loc[:, df8_t.columns != 'label']
    y_test= df8_t.label.values
    return x_test,y_test

### interation (app callback)

In [12]:
@app.callback(
    Output("slider-svm-parameter-gamma-coef", "marks"),
    [Input("slider-svm-parameter-gamma-power", "value")],
)
def update_slider_svm_parameter_gamma_coef(power):
    scale = 10 ** power
    return {i: str(round(i * scale, 8)) for i in range(1, 10, 2)}


@app.callback(
    Output("slider-svm-parameter-C-coef", "marks"),
    [Input("slider-svm-parameter-C-power", "value")],
)
def update_slider_svm_parameter_C_coef(power):
    scale = 10 ** power
    return {i: str(round(i * scale, 8)) for i in range(1, 10, 2)}


@app.callback(
    Output("slider-threshold", "value"),
    [Input("button-zero-threshold", "n_clicks")],
    [State("graph-sklearn-svm", "figure")],
)
def reset_threshold_center(n_clicks, figure):
    if n_clicks:
        Z = np.array(figure["data"][0]["z"])
        value = -Z.min() / (Z.max() - Z.min())
    else:
        value = 0.4959986285375595
    return value


# Disable Sliders if kernel not in the given list
# 除了poly这个kernel其他的kernel都不可以调整degree
@app.callback(
    Output("slider-svm-parameter-degree", "disabled"),
    [Input("dropdown-svm-parameter-kernel", "value")],
)
def disable_slider_param_degree(kernel):
    return kernel != "poly"


@app.callback(
    Output("slider-svm-parameter-gamma-coef", "disabled"),
    [Input("dropdown-svm-parameter-kernel", "value")],
)
def disable_slider_param_gamma_coef(kernel):
    return kernel not in ["rbf", "poly", "sigmoid"]


@app.callback(
    dash.dependencies.Output('slider-output-container-cost', 'children'),
    [dash.dependencies.Input('cost-slider', 'value')])
def update_output_cost(value):
    return 'You have selected "{}"'.format(value)


@app.callback(
    dash.dependencies.Output('slider-output-container-gamma', 'children'),
    [dash.dependencies.Input('slider-svm-parameter-gamma', 'value')])
def update_output_gamma(value):
    return 'You have selected "{}"'.format(value)



# 整合所有的slider和svm_graph的关系
@app.callback(
    Output("div-graphs", "children"),
    [
        Input("dropdown-svm-parameter-kernel", "value"),
        Input("slider-svm-parameter-degree", "value"),
        Input("cost-slider", "value"),
        Input("slider-svm-parameter-gamma", "value"),
        Input("radio-svm-parameter-shrinking", "value"),
        Input("slider-threshold", "value"),
        Input("slider-dataset-sample-size", "value"),
    ],
)
def update_svm_graph(  # including ROC kurve,confusion matrix and svm graph
    kernel,
    degree,
    C_coef,
    gamma_coef,
    shrinking,
    threshold,
    sample_size,
):
    t_start = time.time()

    # Data for training
    X_train, y_train = generate_train_data(n_samples=sample_size)
    print("X_train",X_train.shape)#X_train:dataframe
    print("y_train",type(y_train))#y_train:array
    X_test, y_test = generate_test_data(n_samples=sample_size)
    
    C = C_coef
    gamma = gamma_coef

    if shrinking == "True":
        flag = True
    else:
        flag = False

    # Train SVM
    clf = SVC(C=C, kernel=kernel, degree=degree, gamma=gamma, shrinking=flag,probability=True)
    clf.fit(X_train, y_train)
    dec = clf.decision_function(X_train)
    prob=clf.predict_proba(X_test)
    prob_list=list(prob)
    for i in range(len(prob_list)):
        prob_list[i]=max(prob_list[i][0],prob_list[i][1])


    prediction_figure = figs.serve_prediction_plot(
        model=clf,
        X_train=X_train,
        X_test=X_test,
        y_train=y_train,
        y_test=y_test,
        threshold=threshold,
        prob_list=prob_list
    )

    return [
        html.Div(
            id="svm-graph-container",
            children=dcc.Loading(
                className="graph-wrapper",
                children=dcc.Graph(id="graph-sklearn-svm",
                                   figure=prediction_figure),
                style={"display": "none"},
            ),
        )
            ]

### App layout

In [13]:
app.layout = html.Div(
    children=[
        # .container class is fixed, .container.scalable is scalable
        html.Div(
            className="banner",
            children=[
                # Change App Name here
                html.Div(
                    className="container scalable",
                    children=[
                        # Change App Name here
                        html.H2(
                            id="banner-title",
                            children=[
                                html.A(
                                    "Support Vector Machine (SVM) Explorer",
                                    href="https://github.com/plotly/dash-svm",
                                    style={
                                        "text-decoration": "none",
                                        "color": "inherit",
                                    },
                                )
                            ],
                        ),
                        html.A(
                            id="banner-logo",
                            children=[
                                html.Img(src=app.get_asset_url(
                                    "dash-logo-new.png"))
                            ],
                            href="https://plot.ly/products/dash/",
                        ),
                    ],
                )
            ],
        ),
        html.Div(
            id="body",
            className="container scalable",
            children=[
                html.Div(
                    id="app-container",
                    # className="row",
                    children=[
                        html.Div(
                            # className="three columns",
                            id="left-column",
                            children=[
                                drc.NamedSlider(
                                    name="Sample Size",
                                    id="slider-dataset-sample-size",
                                    min=10000,
                                    max=40000,
                                    step=10000,
                                    marks={
                                        str(i): str(i)
                                        for i in [10000, 20000, 30000, 40000, 50000]
                                    },
                                    value=30000,
                                ),

                                drc.Card(
                                    id="button-card",
                                    children=[
                                        drc.NamedSlider(
                                            name="Threshold",
                                            id="slider-threshold",
                                            min=0,
                                            max=1,
                                            value=0.5,
                                            step=0.01,
                                        ),
                                        html.Button(
                                            "Reset Threshold",
                                            id="button-zero-threshold",
                                            style={"margin-right": "15px"}
                                        ),
                                    ],
                                ),

                                drc.Card(
                                    id="last-card",
                                    children=[
                                        drc.NamedDropdown(
                                            name="Kernel",
                                            id="dropdown-svm-parameter-kernel",
                                            options=[
                                                {
                                                    "label": "Radial basis function (RBF)",
                                                    "value": "rbf",
                                                },
                                                {"label": "Linear",
                                                 "value": "linear"
                                                 },
                                                {
                                                    "label": "Polynomial",
                                                    "value": "poly",
                                                },
                                                {
                                                    "label": "Sigmoid",
                                                    "value": "sigmoid",
                                                },
                                            ],
                                            value="rbf",
                                            clearable=False,
                                            searchable=False,

                                        ),
                                        html.Div(
                                            id="cost-container",
                                            children=[
                                                drc.NamedSlider(
                                                    name="C-coef",
                                                    id="cost-slider",
                                                    min=0.1,
                                                    max=100,
                                                    step=0.1,
                                                    value=10
                                                ),
                                                html.Div(
                                                    id="slider-output-container-cost",
                                                    style={'margin-top': 0.5
                                                           }
                                                )
                                            ]
                                        ),



                                        drc.NamedSlider(
                                            name="Degree",
                                            id="slider-svm-parameter-degree",
                                            min=2,
                                            max=10,
                                            value=3,
                                            step=0.5,
                                            marks={
                                                str(i): str(i) for i in range(2, 11, 2)
                                            },
                                        ),


                                        html.Div(
                                            id="gamma-container",
                                            children=[
                                                drc.NamedSlider(
                                                    name="gamma-coef",
                                                    id="slider-svm-parameter-gamma",
                                                    min=0,
                                                    max=1,
                                                    value=0.5,
                                                    step=0.00001
                                                ),
                                                html.Div(
                                                    id="slider-output-container-gamma",
                                                    style={'margin-top': 0.5
                                                           }
                                                )
                                            ]
                                        ),                                        


                                        html.Div(
                                            id="shrinking-container",
                                            children=[
                                                html.P(
                                                    children="Shrinking"),
                                                dcc.RadioItems(
                                                    id="radio-svm-parameter-shrinking",
                                                    labelStyle={
                                                        "margin-right": "7px",
                                                        "display": "inline-block",
                                                    },
                                                    options=[
                                                        {
                                                            "label": " Enabled",
                                                            "value": "True",
                                                        },
                                                        {
                                                            "label": " Disabled",
                                                            "value": "False",
                                                        },
                                                    ],
                                                    value="True",
                                                ),
                                            ],
                                        )
                                    ]
                                )
                            ]
                        ),
                        html.Div(
                            id="div-graphs",
                            children=dcc.Graph(
                                id="graph-sklearn-svm",
                                figure=dict(
                                    layout=dict(
                                        plot_bgcolor="#68786f", paper_bgcolor="#68786f"
                                    )  # Sets the background color of the plotting area in- between x and y axes.
                                    # Sets the legend background color
                                ),
                            ),
                        )
                    ]
                )
            ]
        )
    ]
)

### Running the server

In [None]:
if __name__ == "__main__":
    app.run_server(debug=False,port=8051)

Dash is running on http://127.0.0.1:8051/

Dash is running on http://127.0.0.1:8051/

 * Serving Flask app "__main__" (lazy loading)
 * Environment: production
   Use a production WSGI server instead.
 * Debug mode: off


 * Running on http://127.0.0.1:8051/ (Press CTRL+C to quit)
127.0.0.1 - - [08/Jun/2021 11:30:47] "[37mGET / HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2021 11:30:47] "[36mGET /assets/base-styles.css?m=1618605243.493987 HTTP/1.1[0m" 304 -
127.0.0.1 - - [08/Jun/2021 11:30:47] "[36mGET /assets/custom-styles.css?m=1618605902.2833462 HTTP/1.1[0m" 304 -
127.0.0.1 - - [08/Jun/2021 11:30:47] "[37mGET /_dash-layout HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2021 11:30:47] "[37mGET /_dash-dependencies HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2021 11:30:47] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2021 11:30:47] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2021 11:30:47] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jun/2021 11:30:47] "[36mGET /assets/dash-logo-new.png HTTP/1.1[0m" 304 -
127.0.0.1 - - [08/Jun/2021 11:30:47] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -


X_train (5973, 784)
y_train <class 'numpy.ndarray'>
[t-SNE] Computing 121 nearest neighbors...
[t-SNE] Indexed 1481 samples in 0.058s...
[t-SNE] Computed neighbors for 1481 samples in 0.552s...
[t-SNE] Computed conditional probabilities for sample 1000 / 1481
[t-SNE] Computed conditional probabilities for sample 1481 / 1481
[t-SNE] Mean sigma: 2.042249
[t-SNE] KL divergence after 250 iterations with early exaggeration: 63.938782


127.0.0.1 - - [08/Jun/2021 11:35:10] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -


X_train (5973, 784)
y_train <class 'numpy.ndarray'>
[t-SNE] KL divergence after 1000 iterations: 0.813353


127.0.0.1 - - [08/Jun/2021 11:35:27] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -


[t-SNE] Computing 121 nearest neighbors...
[t-SNE] Indexed 1481 samples in 0.038s...
[t-SNE] Computed neighbors for 1481 samples in 0.572s...
[t-SNE] Computed conditional probabilities for sample 1000 / 1481
[t-SNE] Computed conditional probabilities for sample 1481 / 1481
[t-SNE] Mean sigma: 2.042130
[t-SNE] KL divergence after 250 iterations with early exaggeration: 63.906723
[t-SNE] KL divergence after 1000 iterations: 0.801879


127.0.0.1 - - [08/Jun/2021 11:36:22] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
