diff --git a/apps/dash-svm/Procfile b/apps/dash-svm/Procfile new file mode 100644 index 000000000..58e791494 --- /dev/null +++ b/apps/dash-svm/Procfile @@ -0,0 +1 @@ +web: gunicorn --pythonpath apps/dash-svm app:server diff --git a/apps/dash-svm/README.md b/apps/dash-svm/README.md new file mode 100644 index 000000000..4ba5272ad --- /dev/null +++ b/apps/dash-svm/README.md @@ -0,0 +1,73 @@ +# Support Vector Machine (SVM) Explorer [![Mentioned in Awesome Machine Learning](https://awesome.re/mentioned-badge.svg)](https://github.com/josephmisiti/awesome-machine-learning) + +This is a learning tool and exploration app made using the Dash interactive Python framework developed by [Plotly](https://plot.ly/). + +Dash abstracts away all of the technologies and protocols required to build an interactive web-based application and is a simple and effective way to bind a user interface around your Python code. To learn more check out our [documentation](https://plot.ly/dash). + +Try out the [demo app here](https://dash-svm.plot.ly/). + +![alt text](images/screenshot.png "Screenshot") + + +## Getting Started +### Using the demo +This demo lets you interactively explore Support Vector Machines (SVMs). + +It includes a few artificially generated datasets that you can choose from the dropdown, and that you can modify by changing the sample size and the noise level of those datasets. + +The other dropdowns and sliders let you change the parameters of your classifier, which may increase or decrease its accuracy. + +### Running the app locally + +First create a virtual environment with conda or venv inside a temp folder, then activate it. 
+ +``` +virtualenv dash-svm-venv + +# Windows +dash-svm-venv\Scripts\activate +# Or Linux +source dash-svm-venv/bin/activate +``` + +Clone the git repo, then install the requirements with pip +``` +git clone https://github.com/plotly/dash-sample-apps.git +cd dash-sample-apps/apps/dash-svm +pip install -r requirements.txt +``` + +Run the app +``` +python app.py +``` + +## About the app +### How does it work? + +This app is fully written in Dash + scikit-learn. All the components are used as input parameters for scikit-learn functions, which then generate a model with respect to the parameters you changed. The model is then used to perform predictions that are displayed on a contour plot, and its predictions are evaluated to create the ROC curve and confusion matrix. + +In addition to creating models, scikit-learn is used to generate the datasets you see, as well as the data needed for the metrics plots. + +### What is an SVM? +An SVM is a popular Machine Learning model used in many different fields. You can find an [excellent guide to how to use SVMs here](https://www.csie.ntu.edu.tw/~cjlin/papers/guide/guide.pdf). 
+ +## Built With +* [Dash](https://dash.plot.ly/) - Main server and interactive components +* [Plotly Python](https://plot.ly/python/) - Used to create the interactive plots +* [Scikit-Learn](http://scikit-learn.org/stable/documentation.html) - Run the classification algorithms and generate datasets + + +## Authors + +* **Xing Han Lu** - *Initial Work* - [@xhlulu](https://github.com/xhlulu) +* **Matthew Chan** - *Code Review* - [@matthewchan15](https://github.com/matthewchan15) +* **Yunke Xiao** - *Redesign* - [@YunkeXiao](https://github.com/YunkeXiao) +* **celinehuang** - *Code Review* - [@celinehuang](https://github.com/celinehuang) + + +## Acknowledgments +The heatmap configuration is heavily inspired by the [scikit-learn Classification Comparison Tutorial](http://scikit-learn.org/stable/auto_examples/classification/plot_classifier_comparison.html). Please go take a look! + +The idea of the [ROC Curve, the Matrix Pie Chart and Thresholding](https://github.com/nicolaskruchten/dash-roc) came from @nickruchten. The app would not have been as complete without his insightful advice. 
+ diff --git a/apps/dash-svm/app.py b/apps/dash-svm/app.py new file mode 100644 index 000000000..e875959c7 --- /dev/null +++ b/apps/dash-svm/app.py @@ -0,0 +1,454 @@ +import time +import importlib + +import dash +import dash_core_components as dcc +import dash_html_components as html +import numpy as np +from dash.dependencies import Input, Output, State +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import StandardScaler +from sklearn import datasets +from sklearn.svm import SVC + +drc = importlib.import_module("apps.dash-svm.utils.dash_reusable_components") +figs = importlib.import_module("apps.dash-svm.utils.figures") + +app = dash.Dash( + __name__, + meta_tags=[ + {"name": "viewport", "content": "width=device-width, initial-scale=1.0"} + ], +) +server = app.server + + +def generate_data(n_samples, dataset, noise): + if dataset == "moons": + return datasets.make_moons(n_samples=n_samples, noise=noise, random_state=0) + + elif dataset == "circles": + return datasets.make_circles( + n_samples=n_samples, noise=noise, factor=0.5, random_state=1 + ) + + elif dataset == "linear": + X, y = datasets.make_classification( + n_samples=n_samples, + n_features=2, + n_redundant=0, + n_informative=2, + random_state=2, + n_clusters_per_class=1, + ) + + rng = np.random.RandomState(2) + X += noise * rng.uniform(size=X.shape) + linearly_separable = (X, y) + + return linearly_separable + + else: + raise ValueError( + "Data type incorrectly specified. Please choose an existing dataset." 
+ ) + + +app.layout = html.Div( + children=[ + # .container class is fixed, .container.scalable is scalable + html.Div( + className="banner", + children=[ + # Change App Name here + html.Div( + className="container scalable", + children=[ + # Change App Name here + html.H2( + id="banner-title", + children=[ + html.A( + "Support Vector Machine (SVM) Explorer", + href="https://github.com/plotly/dash-svm", + style={ + "text-decoration": "none", + "color": "inherit", + }, + ) + ], + ), + html.A( + id="banner-logo", + children=[ + html.Img(src=app.get_asset_url("dash-logo-new.png")) + ], + href="https://plot.ly/products/dash/", + ), + ], + ) + ], + ), + html.Div( + id="body", + className="container scalable", + children=[ + html.Div( + id="app-container", + # className="row", + children=[ + html.Div( + # className="three columns", + id="left-column", + children=[ + drc.Card( + id="first-card", + children=[ + drc.NamedDropdown( + name="Select Dataset", + id="dropdown-select-dataset", + options=[ + {"label": "Moons", "value": "moons"}, + { + "label": "Linearly Separable", + "value": "linear", + }, + { + "label": "Circles", + "value": "circles", + }, + ], + clearable=False, + searchable=False, + value="moons", + ), + drc.NamedSlider( + name="Sample Size", + id="slider-dataset-sample-size", + min=100, + max=500, + step=100, + marks={ + str(i): str(i) + for i in [100, 200, 300, 400, 500] + }, + value=300, + ), + drc.NamedSlider( + name="Noise Level", + id="slider-dataset-noise-level", + min=0, + max=1, + marks={ + i / 10: str(i / 10) + for i in range(0, 11, 2) + }, + step=0.1, + value=0.2, + ), + ], + ), + drc.Card( + id="button-card", + children=[ + drc.NamedSlider( + name="Threshold", + id="slider-threshold", + min=0, + max=1, + value=0.5, + step=0.01, + ), + html.Button( + "Reset Threshold", + id="button-zero-threshold", + ), + ], + ), + drc.Card( + id="last-card", + children=[ + drc.NamedDropdown( + name="Kernel", + id="dropdown-svm-parameter-kernel", + options=[ + { + 
"label": "Radial basis function (RBF)", + "value": "rbf", + }, + {"label": "Linear", "value": "linear"}, + { + "label": "Polynomial", + "value": "poly", + }, + { + "label": "Sigmoid", + "value": "sigmoid", + }, + ], + value="rbf", + clearable=False, + searchable=False, + ), + drc.NamedSlider( + name="Cost (C)", + id="slider-svm-parameter-C-power", + min=-2, + max=4, + value=0, + marks={ + i: "{}".format(10 ** i) + for i in range(-2, 5) + }, + ), + drc.FormattedSlider( + id="slider-svm-parameter-C-coef", + min=1, + max=9, + value=1, + ), + drc.NamedSlider( + name="Degree", + id="slider-svm-parameter-degree", + min=2, + max=10, + value=3, + step=1, + marks={ + str(i): str(i) for i in range(2, 11, 2) + }, + ), + drc.NamedSlider( + name="Gamma", + id="slider-svm-parameter-gamma-power", + min=-5, + max=0, + value=-1, + marks={ + i: "{}".format(10 ** i) + for i in range(-5, 1) + }, + ), + drc.FormattedSlider( + id="slider-svm-parameter-gamma-coef", + min=1, + max=9, + value=5, + ), + html.Div( + id="shrinking-container", + children=[ + html.P(children="Shrinking"), + dcc.RadioItems( + id="radio-svm-parameter-shrinking", + labelStyle={ + "margin-right": "7px", + "display": "inline-block", + }, + options=[ + { + "label": " Enabled", + "value": "True", + }, + { + "label": " Disabled", + "value": "False", + }, + ], + value="True", + ), + ], + ), + ], + ), + ], + ), + html.Div( + id="div-graphs", + children=dcc.Graph( + id="graph-sklearn-svm", + figure=dict( + layout=dict( + plot_bgcolor="#282b38", paper_bgcolor="#282b38" + ) + ), + ), + ), + ], + ) + ], + ), + ] +) + + +@app.callback( + Output("slider-svm-parameter-gamma-coef", "marks"), + [Input("slider-svm-parameter-gamma-power", "value")], +) +def update_slider_svm_parameter_gamma_coef(power): + scale = 10 ** power + return {i: str(round(i * scale, 8)) for i in range(1, 10, 2)} + + +@app.callback( + Output("slider-svm-parameter-C-coef", "marks"), + [Input("slider-svm-parameter-C-power", "value")], +) +def 
update_slider_svm_parameter_C_coef(power): + scale = 10 ** power + return {i: str(round(i * scale, 8)) for i in range(1, 10, 2)} + + +@app.callback( + Output("slider-threshold", "value"), + [Input("button-zero-threshold", "n_clicks")], + [State("graph-sklearn-svm", "figure")], +) +def reset_threshold_center(n_clicks, figure): + if n_clicks: + Z = np.array(figure["data"][0]["z"]) + value = -Z.min() / (Z.max() - Z.min()) + else: + value = 0.4959986285375595 + return value + + +# Disable Sliders if kernel not in the given list +@app.callback( + Output("slider-svm-parameter-degree", "disabled"), + [Input("dropdown-svm-parameter-kernel", "value")], +) +def disable_slider_param_degree(kernel): + return kernel != "poly" + + +@app.callback( + Output("slider-svm-parameter-gamma-coef", "disabled"), + [Input("dropdown-svm-parameter-kernel", "value")], +) +def disable_slider_param_gamma_coef(kernel): + return kernel not in ["rbf", "poly", "sigmoid"] + + +@app.callback( + Output("slider-svm-parameter-gamma-power", "disabled"), + [Input("dropdown-svm-parameter-kernel", "value")], +) +def disable_slider_param_gamma_power(kernel): + return kernel not in ["rbf", "poly", "sigmoid"] + + +@app.callback( + Output("div-graphs", "children"), + [ + Input("dropdown-svm-parameter-kernel", "value"), + Input("slider-svm-parameter-degree", "value"), + Input("slider-svm-parameter-C-coef", "value"), + Input("slider-svm-parameter-C-power", "value"), + Input("slider-svm-parameter-gamma-coef", "value"), + Input("slider-svm-parameter-gamma-power", "value"), + Input("dropdown-select-dataset", "value"), + Input("slider-dataset-noise-level", "value"), + Input("radio-svm-parameter-shrinking", "value"), + Input("slider-threshold", "value"), + Input("slider-dataset-sample-size", "value"), + ], +) +def update_svm_graph( + kernel, + degree, + C_coef, + C_power, + gamma_coef, + gamma_power, + dataset, + noise, + shrinking, + threshold, + sample_size, +): + t_start = time.time() + h = 0.3 # step size in the 
mesh + + # Data Pre-processing + X, y = generate_data(n_samples=sample_size, dataset=dataset, noise=noise) + X = StandardScaler().fit_transform(X) + X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=0.4, random_state=42 + ) + + x_min = X[:, 0].min() - 0.5 + x_max = X[:, 0].max() + 0.5 + y_min = X[:, 1].min() - 0.5 + y_max = X[:, 1].max() + 0.5 + xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h)) + + C = C_coef * 10 ** C_power + gamma = gamma_coef * 10 ** gamma_power + + if shrinking == "True": + flag = True + else: + flag = False + + # Train SVM + clf = SVC(C=C, kernel=kernel, degree=degree, gamma=gamma, shrinking=flag) + clf.fit(X_train, y_train) + + # Plot the decision boundary. For that, we will assign a color to each + # point in the mesh [x_min, x_max]x[y_min, y_max]. + if hasattr(clf, "decision_function"): + Z = clf.decision_function(np.c_[xx.ravel(), yy.ravel()]) + else: + Z = clf.predict_proba(np.c_[xx.ravel(), yy.ravel()])[:, 1] + + prediction_figure = figs.serve_prediction_plot( + model=clf, + X_train=X_train, + X_test=X_test, + y_train=y_train, + y_test=y_test, + Z=Z, + xx=xx, + yy=yy, + mesh_step=h, + threshold=threshold, + ) + + roc_figure = figs.serve_roc_curve(model=clf, X_test=X_test, y_test=y_test) + + confusion_figure = figs.serve_pie_confusion_matrix( + model=clf, X_test=X_test, y_test=y_test, Z=Z, threshold=threshold + ) + + return [ + html.Div( + id="svm-graph-container", + children=dcc.Loading( + className="graph-wrapper", + children=dcc.Graph(id="graph-sklearn-svm", figure=prediction_figure), + style={"display": "none"}, + ), + ), + html.Div( + id="graphs-container", + children=[ + dcc.Loading( + className="graph-wrapper", + children=dcc.Graph(id="graph-line-roc-curve", figure=roc_figure), + ), + dcc.Loading( + className="graph-wrapper", + children=dcc.Graph( + id="graph-pie-confusion-matrix", figure=confusion_figure + ), + ), + ], + ), + ] + + +# Running the server +if __name__ == "__main__": + 
app.run_server(debug=True) diff --git a/apps/dash-svm/assets/base-styles.css b/apps/dash-svm/assets/base-styles.css new file mode 100644 index 000000000..37e03533a --- /dev/null +++ b/apps/dash-svm/assets/base-styles.css @@ -0,0 +1,393 @@ +/* Table of contents +–––––––––––––––––––––––––––––––––––––––––––––––––– +- Grid +- Base Styles +- Typography +- Links +- Buttons +- Forms +- Lists +- Code +- Tables +- Spacing +- Utilities +- Clearing +- Media Queries +*/ + + +/* Grid +–––––––––––––––––––––––––––––––––––––––––––––––––– */ +.container { + position: relative; + width: 100%; + max-width: 960px; + margin: 0 auto; + padding: 0 20px; + box-sizing: border-box; } +.column, +.columns { + width: 100%; + float: left; + box-sizing: border-box; } + +/* For devices larger than 400px */ +@media (min-width: 400px) { + .container { + width: 85%; + padding: 0; } +} + +/* For devices larger than 550px */ +@media (min-width: 550px) { + .container { + width: 80%; } + .column, + .columns { + margin-left: 0.5%; } + .column:first-child, + .columns:first-child { + margin-left: 0; } + + .one.column, + .one.columns { width: 8%; } + .two.columns { width: 16.25%; } + .three.columns { width: 22%; } + .four.columns { width: 33%; } + .five.columns { width: 39.3333333333%; } + .six.columns { width: 49.75%; } + .seven.columns { width: 56.6666666667%; } + .eight.columns { width: 66.5%; } + .nine.columns { width: 74.0%; } + .ten.columns { width: 82.6666666667%; } + .eleven.columns { width: 91.5%; } + .twelve.columns { width: 100%; margin-left: 0; } + + .one-third.column { width: 30.6666666667%; } + .two-thirds.column { width: 65.3333333333%; } + + .one-half.column { width: 48%; } + + /* Offsets */ + .offset-by-one.column, + .offset-by-one.columns { margin-left: 8.66666666667%; } + .offset-by-two.column, + .offset-by-two.columns { margin-left: 17.3333333333%; } + .offset-by-three.column, + .offset-by-three.columns { margin-left: 26%; } + .offset-by-four.column, + .offset-by-four.columns { 
margin-left: 34.6666666667%; } + .offset-by-five.column, + .offset-by-five.columns { margin-left: 43.3333333333%; } + .offset-by-six.column, + .offset-by-six.columns { margin-left: 52%; } + .offset-by-seven.column, + .offset-by-seven.columns { margin-left: 60.6666666667%; } + .offset-by-eight.column, + .offset-by-eight.columns { margin-left: 69.3333333333%; } + .offset-by-nine.column, + .offset-by-nine.columns { margin-left: 78.0%; } + .offset-by-ten.column, + .offset-by-ten.columns { margin-left: 86.6666666667%; } + .offset-by-eleven.column, + .offset-by-eleven.columns { margin-left: 95.3333333333%; } + + .offset-by-one-third.column, + .offset-by-one-third.columns { margin-left: 34.6666666667%; } + .offset-by-two-thirds.column, + .offset-by-two-thirds.columns { margin-left: 69.3333333333%; } + + .offset-by-one-half.column, + .offset-by-one-half.columns { margin-left: 52%; } + +} + +/* Base Styles +–––––––––––––––––––––––––––––––––––––––––––––––––– */ +/* NOTE +html is set to 62.5% so that all the REM measurements throughout Skeleton +are based on 10px sizing. 
So basically 1.5rem = 15px :) */ +html { + font-size: 62.5%; } +body { + font-size: 1.5em; /* currently ems cause chrome bug misinterpreting rems on body element */ + line-height: 1.6; + font-weight: 400; + font-family: "Open Sans", "HelveticaNeue", "Helvetica Neue", Helvetica, Arial, sans-serif; + color: #a5b1cd; + background-color: #282b38; +} + +/* Typography +–––––––––––––––––––––––––––––––––––––––––––––––––– */ +h1, h2, h3, h4, h5, h6 { + margin-top: 0; + margin-bottom: 0; + font-weight: 300; } +h1 { font-size: 4.5rem; line-height: 1.2; letter-spacing: -.1rem; margin-bottom: 2rem; } +h2 { font-size: 3.6rem; line-height: 1.25; letter-spacing: -.1rem; margin-bottom: 1.8rem; margin-top: 1.8rem;} +h3 { font-size: 3.0rem; line-height: 1.3; letter-spacing: -.1rem; margin-bottom: 1.5rem; margin-top: 1.5rem;} +h4 { font-size: 2.6rem; line-height: 1.35; letter-spacing: -.08rem; margin-bottom: 1.2rem; margin-top: 1.2rem;} +h5 { font-size: 2.2rem; line-height: 1.5; letter-spacing: -.05rem; margin-bottom: 0.6rem; margin-top: 0.6rem;} +h6 { font-size: 2.0rem; line-height: 1.6; letter-spacing: 0; margin-bottom: 0.75rem; margin-top: 0.75rem;} + +p { + margin-top: 0; } + + +/* Blockquotes +–––––––––––––––––––––––––––––––––––––––––––––––––– */ +blockquote { + border-left: 4px #282b38 solid; + padding-left: 1rem; + margin-top: 2rem; + margin-bottom: 2rem; + margin-left: 0rem; +} + + +/* Links +–––––––––––––––––––––––––––––––––––––––––––––––––– */ +a { + color: #1EAEDB; } +a:hover { + color: #0FA0CE; } + + +/* Buttons +–––––––––––––––––––––––––––––––––––––––––––––––––– */ +.button, +button, +input[type="submit"], +input[type="reset"], +input[type="button"] { + display: inline-block; + height: 38px; + padding: 0 30px; + color: #555; + text-align: center; + font-size: 11px; + font-weight: 600; + line-height: 38px; + letter-spacing: .1rem; + text-transform: uppercase; + text-decoration: none; + white-space: nowrap; + background-color: #282b38; + border-radius: 4px; + border: 1px 
solid #bbb; + cursor: pointer; + box-sizing: border-box; } +.button:hover, +button:hover, +input[type="submit"]:hover, +input[type="reset"]:hover, +input[type="button"]:hover, +.button:focus, +button:focus, +input[type="submit"]:focus, +input[type="reset"]:focus, +input[type="button"]:focus { + color: #333; + border-color: #888; + outline: 0; } +.button.button-primary, +button.button-primary, +input[type="submit"].button-primary, +input[type="reset"].button-primary, +input[type="button"].button-primary { + color: #FFF; + background-color: #33C3F0; + border-color: #33C3F0; } +.button.button-primary:hover, +button.button-primary:hover, +input[type="submit"].button-primary:hover, +input[type="reset"].button-primary:hover, +input[type="button"].button-primary:hover, +.button.button-primary:focus, +button.button-primary:focus, +input[type="submit"].button-primary:focus, +input[type="reset"].button-primary:focus, +input[type="button"].button-primary:focus { + color: #FFF; + background-color: #1EAEDB; + border-color: #1EAEDB; } + + +/* Forms +–––––––––––––––––––––––––––––––––––––––––––––––––– */ +input[type="email"], +input[type="number"], +input[type="search"], +input[type="text"], +input[type="tel"], +input[type="url"], +input[type="password"], +textarea, +select { + height: 38px; + padding: 6px 10px; /* The 6px vertically centers text on FF, ignored by Webkit */ + background-color: #282b38; + border: 1px solid #D1D1D1; + border-radius: 4px; + box-shadow: none; + box-sizing: border-box; + font-family: inherit; + font-size: inherit; /*https://stackoverflow.com/questions/6080413/why-doesnt-input-inherit-the-font-from-body*/} +/* Removes awkward default styles on some inputs for iOS */ +input[type="email"], +input[type="number"], +input[type="search"], +input[type="text"], +input[type="tel"], +input[type="url"], +input[type="password"], +textarea { + -webkit-appearance: none; + -moz-appearance: none; + appearance: none; } +textarea { + min-height: 65px; + padding-top: 6px; 
+ padding-bottom: 6px; } +input[type="email"]:focus, +input[type="number"]:focus, +input[type="search"]:focus, +input[type="text"]:focus, +input[type="tel"]:focus, +input[type="url"]:focus, +input[type="password"]:focus, +textarea:focus, +select:focus { + border: 1px solid #33C3F0; + outline: 0; } +label, +legend { + display: block; + margin-bottom: 0px; } +fieldset { + padding: 0; + border-width: 0; } +input[type="checkbox"], +input[type="radio"] { + display: inline; } +label > .label-body { + display: inline-block; + margin-left: .5rem; + font-weight: normal; } + + +/* Lists +–––––––––––––––––––––––––––––––––––––––––––––––––– */ +ul { + list-style: circle inside; } +ol { + list-style: decimal inside; } +ol, ul { + padding-left: 0; + margin-top: 0; } +ul ul, +ul ol, +ol ol, +ol ul { + margin: 1.5rem 0 1.5rem 3rem; + font-size: 90%; } +li { + margin-bottom: 1rem; } + + +/* Tables +–––––––––––––––––––––––––––––––––––––––––––––––––– */ +th, +td { + padding: 12px 15px; + text-align: left; + border-bottom: 1px solid #282b38; } +th:first-child, +td:first-child { + padding-left: 0; } +th:last-child, +td:last-child { + padding-right: 0; } + + +/* Spacing +–––––––––––––––––––––––––––––––––––––––––––––––––– */ +button, +.button { + margin-bottom: 0rem; } +input, +textarea, +select, +fieldset { + margin-bottom: 0rem; } +pre, +dl, +figure, +table, +form { + margin-bottom: 0rem; } +p, +ul, +ol { + margin-bottom: 0.75rem; } + +/* Utilities +–––––––––––––––––––––––––––––––––––––––––––––––––– */ +.u-full-width { + width: 100%; + box-sizing: border-box; } +.u-max-full-width { + max-width: 100%; + box-sizing: border-box; } +.u-pull-right { + float: right; } +.u-pull-left { + float: left; } + + +/* Misc +–––––––––––––––––––––––––––––––––––––––––––––––––– */ +hr { + margin-top: 3rem; + margin-bottom: 3.5rem; + border-width: 0; + border-top: 1px solid #282b38; } + + +/* Clearing +–––––––––––––––––––––––––––––––––––––––––––––––––– */ + +/* Self Clearing Goodness */ +.container:after, 
+.row:after, +.u-cf { + content: ""; + display: table; + clear: both; } + + +/* Media Queries +–––––––––––––––––––––––––––––––––––––––––––––––––– */ +/* +Note: The best way to structure the use of media queries is to create the queries +near the relevant code. For example, if you wanted to change the styles for buttons +on small devices, paste the mobile query code up in the buttons section and style it +there. +*/ + + +/* Larger than mobile */ +@media (min-width: 400px) {} + +/* Larger than phablet (also point when grid becomes active) */ +@media (min-width: 550px) {} + +/* Larger than tablet */ +@media (min-width: 750px) {} + +/* Larger than desktop */ +@media (min-width: 1000px) {} + +/* Larger than Desktop HD */ +@media (min-width: 1200px) {} diff --git a/apps/dash-svm/assets/custom-styles.css b/apps/dash-svm/assets/custom-styles.css new file mode 100644 index 000000000..892a1d5de --- /dev/null +++ b/apps/dash-svm/assets/custom-styles.css @@ -0,0 +1,446 @@ +@import url('https://fonts.googleapis.com/css?family=Playfair+Display'); + +#body { + padding-bottom: 5rem; +} + +@media (max-width: 700px) { + #body { + width: 100%; + } +} + +/* Scalable Container +–––––––––––––––––––––––––––––––––––––––––––––––––– */ + +.container.scalable { + width: 95%; + max-width: none; +} + +/* Banner +–––––––––––––––––––––––––––––––––––––––––––––––––– */ +.banner { + background-color: #2f3445; /* Machine Learning Color is orange */ + padding: 3rem 0; + width: 100%; + margin-bottom: 5rem; +} + +.banner h2 { + color: #a5b1cd; + display: inline-block; + font-family: 'Playfair Display', sans-serif; + font-size: 4rem; + line-height: 1; + text-align: center; +} + +.banner h2:hover { + color: #b4b5bf; + cursor: pointer; +} + +.banner Img { + position: relative; + float: right; + height: 5rem; + margin-top: 1.25rem +} + +@media (max-width: 1300px) { + + .banner .container.scalable { + display: flex; + flex-direction: column-reverse; + justify-content: space-between; + align-items: center; + 
} + + .banner h2 { + font-size: 4rem; + } + + .banner Img { + height: 8rem; + margin-bottom: 3rem; + } +} + +@media (max-width: 700px) { + .banner .container.scalable { + padding: 0 + } + + .banner Img { + height: 4rem; + margin-bottom: 2rem; + } + + .banner h2 { + font-size: 3rem; + } +} + +@media (max-width: 500px) { + .banner { + padding: 1rem 0; + margin-bottom: 1rem; + } + + .banner Img { + height: 3rem; + margin-bottom: 1rem; + } + + .banner h2 { + font-size: 2.5rem; + } +} + +/* app-container +–––––––––––––––––––––––––––––––––––––––––––––––––– */ +@media (min-width: 951px) { + #app-container { + width: 100%; + display: flex; + flex-direction: row; + align-items: flex-start; + } +} + +/* Dropdowns +–––––––––––––––––––––––––––––––––––––––––––––––––– */ +.Select-control { + color: #a5b1cd; +} + +.Select { + color: #a5b1cd; +} + +.Select-menu-outer { + background-color: #2f3445; + border: 1px solid gray; +} + +.Select div { + background-color: #2f3445; +} + +.Select-menu-outer div:hover { + background-color: rgba(255, 255, 255, 0.01); +} + +.Select-value-label { + color: #a5b1cd !important; +} + +.Select--single > .Select-control .Select-value, .Select-placeholder { + border: 1px solid gray; + border-radius: 4px; +} + +.card { + border-bottom: 1px solid rgba(255, 255, 255, 0.1); + width: 80%; + padding: 2rem 0; +} + +#last-card { + border-bottom: none; +} + +.graph-title { + font-size: 2rem; + margin: 15% 0 0 25%; +} + +#button-zero-threshold { + background-color: #2f3445; + color: #a5b1cd; + border-color: gray; +} + +#button-zero-threshold:hover { + border-color: white; +} + +#button-card { + display: flex; + flex-direction: column; +} + +#first-card { + padding-top: 0; +} + +@media (max-width: 1500px) { + .rc-slider-mark-text { + font-size: 0.7vw; + } +} + +@media (max-width: 950px) { + .rc-slider-mark-text { + font-size: 1.5vw; + } +} + + +@media (max-width: 650px) { + .rc-slider-mark-text { + font-size: 2vw; + } +} + + +@media (min-width: 1301px) { + 
#left-column { + flex: 1 20%; + margin: 0 3rem 0 0; + max-height: 83rem; + overflow-x: hidden; + overflow-y: auto; + } + + .card { + padding-left: 2rem; + } +} + +@media (max-width: 1300px) { + #left-column { + flex: 1 20%; + margin: 0 3rem 0 0; + max-height: 70rem; + overflow-x: hidden; + overflow-y: auto; + } +} + +@media (max-width: 1200px) { + #button-zero-threshold { + font-size: 0.8rem; + padding: 0 + } +} + +@media (max-width: 950px) { + #button-zero-threshold { + font-size: 1.0rem; + padding: 0 + } +} + +@media (max-width: 950px) { + #left-column { + width: 100%; + display: flex; + flex-direction: column; + align-items: center; + margin: 0; + max-height: none; + max-width: none; + } +} + +@media (max-width: 700px) { + #first-card { + margin: 0; + padding-right: 0; + padding-left: 0; + } +} + +/* Slider +–––––––––––––––––––––––––––––––––––––––––––––––––– */ + +.rc-slider-track { + background-color: #13c6e9; +} + +.rc-slider-handle { + border: solid 2px #13c6e9; +} + +/* Left column +–––––––––––––––––––––––––––––––––––––––––––––––––– */ +#slider-svm-parameter-C-coef { + padding: 5px 10px 25px; +} + +#slider-svm-parameter-gamma-coef { + padding: 5px 10px 25px +} + +#shrinking-container { + padding: 20px 10px 25px 4px +} + +/* Graph container +–––––––––––––––––––––––––––––––––––––––––––––––––– */ +#svm-graph-container { + margin-top: 0.5rem; +} + +#svm-graph-container .graph-wrapper { + height: 100%; + width: 100%; +} + +@media (min-width: 951px) { + #graphs-container { + display: flex; + flex-direction: column; + align-items: stretch; + justify-content: flex-start; + } + + #graphs-container .graph-wrapper { + flex: 1 50%; + } + + #graph-line-roc-curve, #graph-pie-confusion-matrix { + height: 100%; + width: 100%; + } + + #div-graphs { + flex: 4 80%; + display: flex; + flex-direction: row; + justify-content: center; + align-items: stretch; + } + + #graph-sklearn-svm { + height: 100%; + } + + #svm-graph-container { + flex: 2 66%; + } + + #graphs-container { + 
flex: 1 33%; + margin-top: 0.5rem; + } +} + +@media (min-width: 1301px) { + #div-graphs { + height: 83rem; + } +} + +@media (max-width: 1300px) { + #div-graphs { + height: 70rem; + } + + .gtitle { + font-size: 1.25rem !important; + } + + .xtitle, .ytitle { + font-size: 0.9rem !important; + } +} + +@media (max-width: 950px) { + #div-graphs { + width: 100%; + display: flex; + flex-direction: column; + align-items: center; + justify-content: flex-start; + } + + #svm-graph-container { + width: 80vw; + height: 100vw; + display: flex; + flex-direction: column; + align-items: center; + padding-bottom: 4rem; + border-bottom: solid 1px rgba(255, 255, 255, 0.1); + } + + #svm-graph-container .graph-wrapper { + width: 80vw; + height: 80vw; + } + + #graphs-container { + display: flex; + flex-direction: column; + align-items: stretch; + justify-content: flex-start; + width: 90%; + max-width: 90%; + margin: 0 0 0 -1rem; + max-height: none; + } + + #graphs-container .graph-wrapper:nth-of-type(1) { + height: 40rem; + margin: 3rem 0 1rem -3rem; + width: 100%; + } + + #graphs-container .graph-wrapper:nth-of-type(2) { + height: 60rem; + margin-bottom: 5rem; + width: 100%; + } + + #graph-sklearn-svm, #graph-line-roc-curve, #graph-pie-confusion-matrix { + height: 100%; + width: 100%; + } +} + +@media (max-width: 650px) { + #graphs-container { + align-items: center; + width: 100%; + max-width: 100%; + } + + #graph-line-roc-curve { + width: 95%; + margin: 3rem 0 1rem 0; + } + + #graph-pie-confusion-matrix { + width: 95%; + margin-bottom: 5rem; + } + + #graphs-container .graph-wrapper:nth-of-type(1) { + height: 35rem; + } + + #graphs-container .graph-wrapper:nth-of-type(2) { + height: 45rem; + padding-left: 3rem; + } +} + +@media (max-width: 400px) { + #graphs-container .graph-wrapper:nth-of-type(1) { + height: 25rem; + } + + #graphs-container .graph-wrapper:nth-of-type(2) { + height: 40rem; + padding-left: 3rem; + } +} + +/* Remove Undo +–––––––––––––––––––––––––––––––––––––––––––––––––– 
*/ +#graph-line-roc-curve .modebar, #graph-pie-confusion-matrix .modebar { + display: none; +} diff --git a/apps/dash-svm/assets/dash-logo-new.png b/apps/dash-svm/assets/dash-logo-new.png new file mode 100644 index 000000000..eb700fc71 Binary files /dev/null and b/apps/dash-svm/assets/dash-logo-new.png differ diff --git a/apps/dash-svm/images/animated1.gif b/apps/dash-svm/images/animated1.gif new file mode 100644 index 000000000..9790ba478 Binary files /dev/null and b/apps/dash-svm/images/animated1.gif differ diff --git a/apps/dash-svm/images/screenshot.png b/apps/dash-svm/images/screenshot.png new file mode 100644 index 000000000..5407c3162 Binary files /dev/null and b/apps/dash-svm/images/screenshot.png differ diff --git a/apps/dash-svm/requirements.txt b/apps/dash-svm/requirements.txt new file mode 100644 index 000000000..3367c0cce --- /dev/null +++ b/apps/dash-svm/requirements.txt @@ -0,0 +1,10 @@ +# Core +gunicorn>=19.8.1 +dash>=1.0.0 + +# Additional +colorlover>=0.2.1 +numpy>=1.16.2 +pandas>=0.24.2 +scikit-learn>=0.20.3 +scipy>=1.2.1 diff --git a/apps/dash-svm/utils/README.md b/apps/dash-svm/utils/README.md new file mode 100644 index 000000000..0eac67920 --- /dev/null +++ b/apps/dash-svm/utils/README.md @@ -0,0 +1,7 @@ +# Utility Files + +## Dash Reusable Components + +Creating custom, reusable components lets you improve workflow and keep repetitions to a minimum (DRY). In this app, there are a few components that have the same pattern, but with only small differences; for example, a dropdown menu with an associated name. In these cases, reusable components were useful to keep the design of those repeated components consistent, and make the app layout less crowded. + +To read more about Reusable components, check out [this workshop by Plotly](https://dash-workshop.plot.ly/reusable-components). 
"""Reusable Dash layout components for the SVM Explorer demo.

File: apps/dash-svm/utils/dash_reusable_components.py
"""
from textwrap import dedent

import dash_core_components as dcc
import dash_html_components as html


# Display utility functions
def _merge(a, b):
    """Return a new dict holding the items of ``a`` updated with those of ``b``.

    Neither input is modified. Keys of ``b`` must be strings because they are
    expanded as keyword arguments to ``dict``.
    """
    return dict(a, **b)


def _omit(omitted_keys, d):
    """Return a copy of ``d`` without the keys listed in ``omitted_keys``."""
    return {k: v for k, v in d.items() if k not in omitted_keys}


# Custom Display Components
def Card(children, **kwargs):
    """Wrap ``children`` in an ``html.Section`` with the ``card`` CSS class.

    Extra keyword arguments are forwarded to ``html.Section``; a ``style``
    keyword, if given, is discarded rather than forwarded.
    """
    section_kwargs = _omit(["style"], kwargs)
    return html.Section(className="card", children=children, **section_kwargs)


def FormattedSlider(**kwargs):
    """Return a ``dcc.Slider`` wrapped in an ``html.Div``.

    ``style`` (default ``{}``) is applied to the wrapping div; every other
    keyword argument is forwarded to ``dcc.Slider``.
    """
    slider_kwargs = _omit(["style"], kwargs)
    return html.Div(
        style=kwargs.get("style", {}), children=dcc.Slider(**slider_kwargs)
    )


def NamedSlider(name, **kwargs):
    """Return a ``dcc.Slider`` preceded by a ``"<name>:"`` label.

    All keyword arguments are forwarded to ``dcc.Slider``.
    """
    return html.Div(
        style={"padding": "20px 10px 25px 4px"},
        children=[
            html.P(f"{name}:"),
            # Dash style dictionaries use camelCased CSS property names.
            html.Div(style={"marginLeft": "6px"}, children=dcc.Slider(**kwargs)),
        ],
    )


def NamedDropdown(name, **kwargs):
    """Return a ``dcc.Dropdown`` preceded by a ``"<name>:"`` label.

    All keyword arguments are forwarded to ``dcc.Dropdown``.
    """
    return html.Div(
        style={"margin": "10px 0px"},
        children=[
            html.P(children=f"{name}:", style={"marginLeft": "3px"}),
            dcc.Dropdown(**kwargs),
        ],
    )


def NamedRadioItems(name, **kwargs):
    """Return a ``dcc.RadioItems`` preceded by a ``"<name>:"`` label.

    All keyword arguments are forwarded to ``dcc.RadioItems``.
    """
    return html.Div(
        style={"padding": "20px 10px 25px 4px"},
        children=[html.P(children=f"{name}:"), dcc.RadioItems(**kwargs)],
    )


# Non-generic
def DemoDescription(
    filename,
    strip=False,
    *,
    start_marker="<Start Description>",
    end_marker="<End Description>",
):
    """Render the Markdown file ``filename`` inside a bordered ``html.Div``.

    Args:
        filename: Path of the Markdown file to read.
        strip: When true, keep only the text after the last ``start_marker``
            and before the first ``end_marker``. A marker that does not occur
            in the text leaves that side untouched.
        start_marker: Text marking the start of the description.
        end_marker: Text marking the end of the description.

    Returns:
        An ``html.Div`` whose child is a ``dcc.Markdown`` of the dedented text.

    NOTE(review): the marker literals were empty strings in the reviewed
    source, which made ``strip=True`` always fail with
    ``ValueError: empty separator``. The defaults above are an assumption;
    confirm them against the markers actually used in the description file.
    """
    with open(filename, "r", encoding="utf-8") as file:
        text = file.read()

    if strip:
        # ``str.split("")`` raises ValueError, so an empty marker means
        # "do not cut on this side".
        if start_marker:
            text = text.split(start_marker)[-1]
        if end_marker:
            text = text.split(end_marker)[0]

    return html.Div(
        className="row",
        # Dash style dictionaries use camelCased CSS property names.
        style={
            "padding": "15px 30px 27px",
            "margin": "45px auto 45px",
            "width": "80%",
            "maxWidth": "1024px",
            "borderRadius": 5,
            "border": "thin lightgrey solid",
            "fontFamily": "Roboto, sans-serif",
        },
        children=dcc.Markdown(dedent(text)),
    )
"""Plotly figure builders for the SVM Explorer demo.

File: apps/dash-svm/utils/figures.py
"""
import colorlover as cl
import plotly.graph_objs as go
import numpy as np
from sklearn import metrics


def serve_prediction_plot(
    model, X_train, X_test, y_train, y_test, Z, xx, yy, mesh_step, threshold
):
    """Build the contour figure with the threshold line and the data points.

    Args:
        model: Fitted estimator exposing ``decision_function``.
        X_train, X_test: 2-D arrays; columns 0 and 1 are plotted as x and y.
        y_train, y_test: Labels, used both as marker colours and as the
            ground truth for the accuracies shown in the legend.
        Z: Values drawn as the contour surface, reshaped to ``xx.shape``.
            Presumably the model's decision values on the mesh; confirm
            against the caller.
        xx, yy: Mesh coordinate arrays; only their min, max and shape are used.
        mesh_step: Step used to rebuild the contour axes with ``np.arange``.
        threshold: Relative threshold, mapped linearly onto
            ``[Z.min(), Z.max()]`` (0 maps to the minimum, 1 to the maximum).

    Returns:
        A ``go.Figure`` with two contour traces and two scatter traces.
    """
    # Map the relative threshold onto the value range of Z.
    z_min, z_max = Z.min(), Z.max()
    scaled_threshold = threshold * (z_max - z_min) + z_min
    # Half-width of a colour range centred on the threshold that still covers
    # every value of Z. (Not named ``range`` so the builtin is not shadowed.)
    z_range = max(abs(scaled_threshold - z_min), abs(scaled_threshold - z_max))

    # Classify with the same scaled threshold that is drawn on the plot and
    # used by serve_pie_confusion_matrix, so the legend accuracies agree with
    # the displayed decision boundary.
    y_pred_train = (model.decision_function(X_train) > scaled_threshold).astype(int)
    y_pred_test = (model.decision_function(X_test) > scaled_threshold).astype(int)
    train_score = metrics.accuracy_score(y_true=y_train, y_pred=y_pred_train)
    test_score = metrics.accuracy_score(y_true=y_test, y_pred=y_pred_test)

    # Colorscale
    bright_cscale = [[0, "#ff3700"], [1, "#0b8bff"]]
    cscale = [
        [0.0000000, "#ff744c"],
        [0.1428571, "#ff916d"],
        [0.2857143, "#ffc0a8"],
        [0.4285714, "#ffe7dc"],
        [0.5714286, "#e5fcff"],
        [0.7142857, "#c8feff"],
        [0.8571429, "#9af8ff"],
        [1.0000000, "#20e6ff"],
    ]

    # Axes and surface shared by both contour traces.
    # NOTE(review): np.arange excludes its stop value, so these axes may be one
    # sample shorter than the matching dimension of Z if xx/yy were built with
    # the same step — confirm against the caller.
    x_axis = np.arange(xx.min(), xx.max(), mesh_step)
    y_axis = np.arange(yy.min(), yy.max(), mesh_step)
    z_grid = Z.reshape(xx.shape)

    # Plot the prediction contour of the SVM
    trace0 = go.Contour(
        x=x_axis,
        y=y_axis,
        z=z_grid,
        zmin=scaled_threshold - z_range,
        zmax=scaled_threshold + z_range,
        hoverinfo="none",
        showscale=False,
        contours=dict(showlines=False),
        colorscale=cscale,
        opacity=0.9,
    )

    # Plot the threshold
    trace1 = go.Contour(
        x=x_axis,
        y=y_axis,
        z=z_grid,
        showscale=False,
        hoverinfo="none",
        contours=dict(
            showlines=False, type="constraint", operation="=", value=scaled_threshold
        ),
        name=f"Threshold ({scaled_threshold:.3f})",
        line=dict(color="#708090"),
    )

    # Plot Training Data
    trace2 = go.Scatter(
        x=X_train[:, 0],
        y=X_train[:, 1],
        mode="markers",
        name=f"Training Data (accuracy={train_score:.3f})",
        marker=dict(size=10, color=y_train, colorscale=bright_cscale),
    )

    # Plot Test Data
    trace3 = go.Scatter(
        x=X_test[:, 0],
        y=X_test[:, 1],
        mode="markers",
        name=f"Test Data (accuracy={test_score:.3f})",
        marker=dict(
            size=10, symbol="triangle-up", color=y_test, colorscale=bright_cscale
        ),
    )

    layout = go.Layout(
        xaxis=dict(ticks="", showticklabels=False, showgrid=False, zeroline=False),
        yaxis=dict(ticks="", showticklabels=False, showgrid=False, zeroline=False),
        hovermode="closest",
        legend=dict(x=0, y=-0.01, orientation="h"),
        margin=dict(l=0, r=0, t=0, b=0),
        plot_bgcolor="#282b38",
        paper_bgcolor="#282b38",
        font={"color": "#a5b1cd"},
    )

    return go.Figure(data=[trace0, trace1, trace2, trace3], layout=layout)


def serve_roc_curve(model, X_test, y_test):
    """Build the ROC-curve figure for the test set.

    Args:
        model: Fitted estimator exposing ``decision_function``.
        X_test: Test samples passed to ``decision_function``.
        y_test: True labels of the test samples.

    Returns:
        A ``go.Figure`` with one line trace; the AUC is shown in the title.
    """
    decision_test = model.decision_function(X_test)
    # The thresholds returned by roc_curve are not needed for the plot.
    fpr, tpr, _ = metrics.roc_curve(y_test, decision_test)

    # AUC Score
    auc_score = metrics.roc_auc_score(y_true=y_test, y_score=decision_test)

    trace0 = go.Scatter(
        x=fpr, y=tpr, mode="lines", name="Test Data", marker={"color": "#13c6e9"}
    )

    layout = go.Layout(
        title=f"ROC Curve (AUC = {auc_score:.3f})",
        xaxis=dict(title="False Positive Rate", gridcolor="#2f3445"),
        yaxis=dict(title="True Positive Rate", gridcolor="#2f3445"),
        legend=dict(x=0, y=1.05, orientation="h"),
        margin=dict(l=100, r=10, t=25, b=40),
        plot_bgcolor="#282b38",
        paper_bgcolor="#282b38",
        font={"color": "#a5b1cd"},
    )

    return go.Figure(data=[trace0], layout=layout)


def serve_pie_confusion_matrix(model, X_test, y_test, Z, threshold):
    """Build the confusion-matrix pie chart for the test set.

    Args:
        model: Fitted estimator exposing ``decision_function``.
        X_test: Test samples passed to ``decision_function``.
        y_test: True labels; predictions are encoded as 0/1, so the labels
            are expected to use the same encoding.
        Z: Values whose min and max define the range the threshold is
            scaled onto.
        threshold: Relative threshold, mapped linearly onto
            ``[Z.min(), Z.max()]``.

    Returns:
        A ``go.Figure`` with one pie trace ordered TP, FN, FP, TN.
    """
    # Compute threshold
    scaled_threshold = threshold * (Z.max() - Z.min()) + Z.min()
    y_pred_test = (model.decision_function(X_test) > scaled_threshold).astype(int)

    # Pin the label set so the matrix is always 2x2; otherwise a test set in
    # which only one class occurs yields a 1x1 matrix and the unpacking fails.
    matrix = metrics.confusion_matrix(
        y_true=y_test, y_pred=y_pred_test, labels=[0, 1]
    )
    tn, fp, fn, tp = matrix.ravel()

    values = [tp, fn, fp, tn]
    label_text = ["True Positive", "False Negative", "False Positive", "True Negative"]
    labels = ["TP", "FN", "FP", "TN"]
    blue = cl.flipper()["seq"]["9"]["Blues"]
    colors = ["#13c6e9", blue[1], "#ff916d", "#ff744c"]

    trace0 = go.Pie(
        labels=label_text,
        values=values,
        hoverinfo="label+value+percent",
        textinfo="text+value",
        text=labels,
        sort=False,
        marker=dict(colors=colors),
        insidetextfont={"color": "white"},
        rotation=90,
    )

    layout = go.Layout(
        title="Confusion Matrix",
        margin=dict(l=50, r=50, t=100, b=10),
        legend=dict(bgcolor="#282b38", font={"color": "#a5b1cd"}, orientation="h"),
        plot_bgcolor="#282b38",
        paper_bgcolor="#282b38",
        font={"color": "#a5b1cd"},
    )

    return go.Figure(data=[trace0], layout=layout)