In this post, we will be building a data visualization using Plotly Dash 2.0 (in fact, it will be the same regression visualization that was created in these two blog posts: HERE and HERE). As a reminder from those two posts, this visualization will demonstrate linear regression on random data by creating a sample dataset using the scikit-learn function "make_regression", which will actually generate the data behind-the-scenes. Plotly controls will be used to supply the required parameters to the function.

The first section below contains the header of the function (with useful information about what it does, who wrote it, when it was last updated), along with importing all required libraries.

In [1]:
import numpy as np
import scipy
from scipy import stats
import pandas as pd

import math
import torch
import gpytorch

import dash
import plotly.graph_objs as go
import dash_core_components as dcc
import dash_html_components as html
from dash.dependencies import Input, Output, State

from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression as LR
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import DotProduct, WhiteKernel, RBF
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.utils.multiclass import unique_labels
from sklearn.preprocessing import OneHotEncoder
import pickle

from imblearn.over_sampling import SMOTE
from collections import Counter
import matplotlib.pyplot as plt

TRAINING_ITER = 20
NUM_TREE = 10000
TREE_MAX_DEPTH = 12

%matplotlib inline
%load_ext autoreload
%autoreload 2


The module is deprecated in version 0.21 and will be removed in version 0.23 since we've dropped support for Python 2.7. Please rely on the official version of six (https://pypi.org/project/six/).





As one of the first things that will be seen when the code is opened, a parameter defaults section allows the controls to be easily changed, without digging too deeply into the code. In the code below, the defaults and ranges for the number of samples, the bias, and the amount of noise can be adjusted easily. These will be inputs to the "make_regression" function.

In [2]:
###------------------------------------------------------------------------###
###------------------------PARAMETER DEFAULTS------------------------------###
### This section contains defaults and ranges for the Plotly controls and  ###
### may be modified without concern, if required.                          ###
###------------------------------------------------------------------------###
# The format for this section is: default, range[Lower, Upper, Step Size]

d_nsamp, r_nsamp = 100, [50, 500, 50] # Number of samples
d_bias, r_bias = 0, [-50, 50, 5] # Bias
d_noise, r_noise = 3, [0, 20, 1] # Amount of noise

Based on the different way that Plotly apps are laid out, the order of the sections will be changed somewhat, compared to the Bokeh example. The next section contains the base level function(s), which in this case will be used to generate data.

In [3]:
###------------------------------------------------------------------------###
###-----------------------BASE-LEVEL FUNCTIONS-----------------------------###
### This section contains the low-level calculations required for the      ###
### regression modeling.                                                  ###
###------------------------------------------------------------------------###
def create_data(n_samp, bias, noise):
    # Creates a set of random data based on user parameters
    data = make_regression(n_samp, 1, 1, 1, bias, noise=noise)
    x_data = np.array([i[0] for i in data[0]]) # Because it's an array of arrays
    y_data = data[1]
    return x_data, y_data

The next section defines the layout of the Plotly Dash app, including the controls and graphs that will be present in the GUI. This layout includes a plot of the data and the sliders for the parameters (number of samples, bias, and noise). Note that the defaults and ranges for the controls were defined in the previous section, so this section should not need modification unless new controls are being added. 

In [4]:
presidential = pd.read_csv("2020_presidential_tracker.csv", delimiter = ',')

X_features = presidential[['greek', 'athlete', 'financialAid',
       'gender', 'geography', 'highschool', 'legacy', 'major', 'orientation',
       'race', 'year', 'school', 'q5', 'q3', 'q1', 'q7', 'q4', 'q2', 'q6',
       'q8']]

X_features = X_features[X_features['q6'] != 5.0]
X_features = X_features[X_features['q7'] != 5.0]
X_features = X_features.dropna(subset = ['major','gender','orientation','geography',
                                         'q2', 'q3', 'q4', 'q5', 'q7', 'q8'])

X_features['q6'].replace([1.0, 2.0, 3.0, 4.0], [8.0, 7.0, 6.0, 5.0], inplace=True)
X_features['q6'].replace([5.0, 6.0, 7.0, 8.0], [1.0, 2.0, 3.0, 4.0], inplace=True)
X_features['q7'].replace([1.0, 2.0, 3.0, 4.0], [8.0, 7.0, 6.0, 5.0], inplace=True)
X_features['q7'].replace([5.0, 6.0, 7.0, 8.0], [1.0, 2.0, 3.0, 4.0], inplace=True)

y_labels_d = X_features[['q6']]
y_labels_r = X_features[['q7']]
X_features = X_features[['major','gender','orientation','geography','q2', 'q3', 'q4', 'q5', 'q8']]

# One hot encoding

In [5]:
enc = OneHotEncoder()
enc.fit(X_features)
X_features_onehot = enc.transform(X_features).toarray()


The handling of integer data will change in version 0.22. Currently, the categories are determined based on the range [0, max(values)], while in the future they will be determined based on the unique values.
In case you used a LabelEncoder before this OneHotEncoder to convert the categories to integers, then you can now use the OneHotEncoder directly.



In [6]:
X_train_val, X_test, y_train_val_d, y_test_d = train_test_split(X_features_onehot, y_labels_d, test_size=0.12, random_state = 24)

In [7]:
y_train_val_r = y_labels_r.loc[y_train_val_d.index]
y_test_r = y_labels_r.loc[y_test_d.index]

In [8]:
sm = SMOTE()
X_smote_d, y_smote_d = sm.fit_sample(X_train_val, y_train_val_d)
print('Resampled dataset shape %s' % Counter(y_smote_d))
X_smote_r, y_smote_r = sm.fit_sample(X_train_val, y_train_val_r)
print('Resampled dataset shape %s' % Counter(y_smote_r))



A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().


A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().


A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().


A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().



Resampled dataset shape Counter({3.0: 503, 2.0: 503, 1.0: 503, 4.0: 503})
Resampled dataset shape Counter({1.0: 464, 2.0: 464, 3.0: 464, 4.0: 464})


In [9]:
for i in range(len(X_smote_d)):
    for j in range(len(X_smote_d[i])):
        if X_smote_d[i][j] != 0:
            X_smote_d[i][j] = 1

for i in range(len(X_smote_r)):
    for j in range(len(X_smote_r[i])):
        if X_smote_r[i][j] != 0:
            X_smote_r[i][j] = 1

In [10]:
from sklearn.gaussian_process import GaussianProcessRegressor

from sklearn.gaussian_process.kernels import DotProduct, WhiteKernel
from sklearn.gaussian_process.kernels import RBF

kernel = DotProduct() + WhiteKernel()
# kernel = RBF()
gpr_d = GaussianProcessRegressor(kernel=kernel,random_state=0).fit(X_smote_d, y_smote_d)
gpr_r = GaussianProcessRegressor(kernel=kernel,random_state=0).fit(X_smote_r, y_smote_r)

In [11]:
d_major, r_major = 1, [1,73,1] # major
d_gender, r_gender = 1, [1,6,1] # gender
d_orient, r_orient = 1, [1,7,1] # orientation
d_geog, r_geog = 1, [1,8,1] # geography
# Regardless of who you may support in the upcoming 2020 presidential election
d_q2, r_q2 = 1, [1,20,1] # Q2: who would you most like to be the Democratic nominee for President?
d_q3, r_q3 = 1, [1,3,1] # Q3: would you prefer (Donald) Trump to be the Republican, or prefer another Republican?
d_q4, r_q4 = 1, [1,4,1] # Q4: would you vote for the Democratic candidate, or Donald Trump or not vote?
d_q5, r_q5 = 1, [1,6,1] # Q5: a favorable or unfavorable opinion of President Donald Trump?
d_q7, r_q7 = 1, [1,6,1] # Q7: a favorable or unfavorable opinion of The Republican Party?
d_q8, r_q8 = 1, [1,3,1] # Q8: do you think Donald Trump will be reelected president in 2020, or not?

In [12]:
# App defines the entire Dash application, containing the layout and controls 
# for the visualization (in app.layout)
app = dash.Dash()

external_stylesheets = ['https://codepen.io/chriddyp/pen/bWLwgP.css']

app = dash.Dash(__name__, external_stylesheets=external_stylesheets)

app.layout = html.Div(children=[   
    html.Div('Major', style = {'fontSize': '25px'}),
    dcc.Dropdown(
        id='Major',
        options=[
            {'label': str(i), 'value': i} for i in range(r_major[0], r_major[1], r_major[2])
        ],
        value= 1.0,
        style={'height': '40px', 'width': '700px', 'fontSize': '20px', 'width': '49%', 'display': 'inline-block'},
    ),
   
    html.Div('Gender', style = {'fontSize': '25px'}),
    dcc.Dropdown(
        id='Gender',
        options=[
            {'label': 'man', 'value': 1.0},
            {'label': 'woman', 'value': 2.0},
            {'label': 'transgender male', 'value': 3.0},
            {'label': 'transgender female', 'value': 4.0},
            {'label': 'genderqueer', 'value': 5.0},
            {'label': 'other', 'value': 6.0},     
        ],
        value= 1.0,
        style={'height': '40px', 'width': '700px', 'fontSize': '20px', 'width': '49%', 'display': 'inline-block'},
    ),
    
    html.Div('Orientation', style = {'fontSize': '25px'}),
    dcc.Dropdown(
        id='Orientation',
        options=[
            {'label': 'lesbian/gay', 'value': 1.0},
            {'label': 'straight', 'value': 2.0}, # Heterosexual
            {'label': 'bisexual', 'value': 3.0},
            {'label': 'pansexual', 'value': 4.0},
            {'label': 'queer', 'value': 5.0},
            {'label': 'questioning', 'value': 6.0},
            {'label': 'other', 'value': 7.0},
        ],
        value= 2.0,
        style={'height': '40px', 'width': '700px', 'fontSize': '20px', 'width': '49%', 'display': 'inline-block'},
    ),
    
    html.Div('Geography', style = {'fontSize': '25px'}),
    dcc.Dropdown(
        id='Geography',
        options=[
            {'label': 'New England', 'value': 1.0},
            {'label': 'Mid-Atlantic', 'value': 2.0},
            {'label': 'Midwest', 'value': 3.0},
            {'label': 'South', 'value': 4.0},
            {'label': 'West', 'value': 5.0},
            {'label': 'Non-Continental', 'value': 6.0},
            {'label': 'International', 'value': 7.0},
            {'label': 'Other', 'value': 8.0},
        ],
        value= 1.0,
        style={'height': '40px', 'width': '700px', 'fontSize': '20px', 'width': '49%', 'display': 'inline-block'},
    ),
    
    html.Div('Question 2', style = {'fontSize': '25px'}),
    dcc.Dropdown(
        id='Q2',
        options=[
            {'label': 'Sanders', 'value': 1.0},
            {'label': 'Booker', 'value': 2.0},
            {'label': 'Harris', 'value': 3.0},
            {'label': 'Warren', 'value': 4.0},
            {'label': 'Klobuchar', 'value': 5.0},
            {'label': 'Castro', 'value': 6.0},
            {'label': 'Gillibrand', 'value': 7.0},
            {'label': 'Gabbard', 'value': 8.0},
            {'label': 'Biden', 'value': 9.0},
            {'label': "O'Rouke", 'value': 10.0},
            {'label': 'Inslee', 'value': 11.0},
            {'label': 'Delaney', 'value': 12.0},
            {'label': 'Kickenlooper', 'value': 13.0},
            {'label': 'Buttigieg', 'value': 14.0},
            {'label': 'Bullock', 'value': 15.0},
            {'label': 'McAuliffe', 'value': 16.0},
            {'label': 'Ryan', 'value': 17.0},
            {'label': 'Yang', 'value': 18.0},
            {'label': 'Other', 'value': 19.0},
            {'label': 'Refused', 'value': 20.0},
        ],
        value= 1.0,
        style={'height': '40px', 'width': '700px', 'fontSize': '20px', 'width': '49%', 'display': 'inline-block'},
    ),
    
    html.Div('Question 3', style = {'fontSize': '25px'}),
    dcc.Dropdown(
        id='Q3',
        options=[
            {'label': '1', 'value': 1.0},
            {'label': '2', 'value': 2.0},
            {'label': '3', 'value': 3.0},
        ],
        value= 1.0,
        style={'height': '40px', 'width': '700px', 'fontSize': '20px', 'width': '49%', 'display': 'inline-block'},
    ),
    
    
    html.Div('Question 4', style = {'fontSize': '25px'}),
    dcc.Dropdown(
        id='Q4',
        options=[
            {'label': '1', 'value': 1.0},
            {'label': '2', 'value': 2.0},
            {'label': '3', 'value': 3.0},
            {'label': '4', 'value': 4.0},
        ],
        value= 1.0,
        style={'height': '40px', 'width': '700px', 'fontSize': '20px', 'width': '49%', 'display': 'inline-block'},
    ),
    
    html.Div('Question 5', style = {'fontSize': '25px'}),
    dcc.Dropdown(
        id='Q5',
        options=[
            {'label': 'Democratic candidate', 'value': 1.0},
            {'label': 'Trump', 'value': 2.0},
            {'label': 'Would not vote', 'value': 3.0},
            {'label': 'Refused', 'value': 4.0},
        ],
        value= 1.0,
        style={'height': '40px', 'width': '700px', 'fontSize': '20px', 'width': '49%', 'display': 'inline-block'},
    ),
    
    html.Div('Question 8', style = {'fontSize': '25px'}),
    dcc.Dropdown(
        id='Q8',
        options=[
            {'label': '1', 'value': 1.0},
            {'label': '2', 'value': 2.0},
            {'label': '3', 'value': 3.0},
        ],
        value= 1.0,
        style={'height': '40px', 'width': '700px', 'fontSize': '20px', 'width': '49%', 'display': 'inline-block'},
    ),
    
    html.Button('Submit', id='button'),
    
    dcc.Graph(id='regression-example',
            figure={
            'data': [
                {
                    'x': np.linspace(0.5, 4.5, 100),
                    'y': 2.0,
                    'text': ['a', 'b', 'c', 'd'],
                    'customdata': ['c.a', 'c.b', 'c.c', 'c.d'],
                    'name': 'True label (Student Response)',
                    'mode': 'lines',
                    'marker': {'size': 12}
                },
            ],
            'layout': {'title': 'Do you have a favorable or unfavorable opinion of The Democratic Party?', 
                      'xaxis':{'title':'Student ID'},
                       'yaxis':{'title':'Level of Unfavorability'},
                       'style':{'width': '49%', 'display': 'inline-block'},
                      }})
])

Finally, the "Callbacks" section defines what happens when the user controls are interacted with (sample size, noise, and bias sliders). If the user controls are changed in any way, new data will be generated based on their values and the figure will be updated automatically. Plotly Dash does this using a function decorator called @app.callback. It's important to remember that the inputs specified in the decorator will be in the same order as the parameters to the function.

In [13]:
###-----------------------------------------------------------------------###
###----------------------------CALLBACKS----------------------------------###
### This section defines the behavior of the GUI as the user interacts    ###
### with the controls.                                                  ###
###-----------------------------------------------------------------------###
# Define which inputs (the sliders) go to which outputs (the figure)

@app.callback(
    Output('regression-example', 'figure'),
    [Input('button', 'n_clicks')],
    [State('Major', 'value'),
     State('Gender', 'value'),
     State('Orientation', 'value'),
     State('Geography', 'value'),
     State('Q2', 'value'),
     State('Q3', 'value'),
     State('Q4', 'value'),
     State('Q5', 'value'),
     State('Q8', 'value')])

def update_graph(n_clicks,major,gender,orientation,geography,q2,q3,q4,q5,q8):
    new_data_point = [major,gender,orientation,geography,q2,q3,q4,q5,q8]   
    
    new_data_pd = pd.DataFrame(columns=['major','gender','orientation','geography','q2', 'q3', 'q4', 'q5', 'q8'])
    new_data_pd.loc[0] = new_data_point
    enc1 = OneHotEncoder()
    enc1.fit(X_features)
    new_data_pd_onehot = enc1.transform(new_data_pd).toarray()
    
    gpr_pred_newdp_d, sigma_newdp_d = gpr_d.predict(new_data_pd_onehot, return_std = True)
    gpr_pred_newdp_r, sigma_newdp_r = gpr_r.predict(new_data_pd_onehot, return_std = True)
    
    mu_d = gpr_pred_newdp_d
    variance_d = np.sqrt(sigma_newdp_d)
    mu_r = gpr_pred_newdp_r
    variance_r = np.sqrt(sigma_newdp_r)
    
    # sigma = np.sqrt(variance)
    x = np.linspace(0.5, 4.5, 100)
    
    return {
            'data': [
                {
                    'x': np.linspace(0.5, 4.5, 100),
                    'y': scipy.stats.norm.pdf(x, mu_d, variance_d),
                    'text': ['a', 'b', 'c', 'd'],
                    'customdata': ['c.a', 'c.b', 'c.c', 'c.d'],
                    'name': 'Democratic party favorability',
                    'mode': 'lines',
                    'marker': {'size': 100}
                },
                {
                    'x': np.linspace(0.5, 4.5, 100),
                    'y': scipy.stats.norm.pdf(x, mu_r, variance_r),
                    'text': ['w', 'x', 'y', 'z'],
                    'customdata': ['c.w', 'c.x', 'c.y', 'c.z'],
                    'name': 'Republican party favorability',
                    'mode': 'lines',
                    'marker': {'size': 100}
                }
            ],
#             'layout': {'title': 'Party Favorability?', 
#                       'xaxis':{'title':'Student ID'},
#                        'yaxis':{'title':'Level of Unfavorability'}
#                       }
            'layout': go.Layout(
            width=900,
            height=600,
            xaxis={
                'title': 'Party Favorability?',
                "titlefont": {"size": 20},
                "tickfont": {"size": 20}
                #'type': 'linear' if xaxis_type == 'Linear' else 'log'
            },
            yaxis={
                'title': 'Level of Unfavorability',
                "titlefont": {"size": 20},
                "tickfont": {"size": 20}
                #'type': 'linear' if yaxis_type == 'Linear' else 'log'
            },
            margin={'l': 100, 'b': 100, 't': 10, 'r': 0},
            
            hovermode='closest',
            # style={'display': 'inline-block'}
            style={'display': 'inline-block'},
            # 'width': '100%', 
        )
    
    
    }
            

if __name__ == '__main__':
    app.run_server()

 * Serving Flask app "__main__" (lazy loading)
 * Environment: production
   Use a production WSGI server instead.
 * Debug mode: off


 * Running on http://127.0.0.1:8050/ (Press CTRL+C to quit)
127.0.0.1 - - [08/Jul/2019 23:32:34] "[37mGET / HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jul/2019 23:32:34] "[37mGET /_dash-dependencies HTTP/1.1[0m" 200 -
127.0.0.1 - - [08/Jul/2019 23:32:34] "[37mGET /_dash-layout HTTP/1.1[0m" 200 -

The handling of integer data will change in version 0.22. Currently, the categories are determined based on the range [0, max(values)], while in the future they will be determined based on the unique values.
In case you used a LabelEncoder before this OneHotEncoder to convert the categories to integers, then you can now use the OneHotEncoder directly.

[2019-07-08 23:32:35,410] ERROR in app: Exception on /_dash-update-component [POST]
Traceback (most recent call last):
  File "/Users/zhucheng/anaconda3/lib/python3.7/site-packages/flask/app.py", line 2292, in wsgi_app
    response = self.full_dispatch_request()
  File "/Users/zhucheng/anaconda3/lib/python3.7/site-packages/flask/app.py", line 18

127.0.0.1 - - [08/Jul/2019 23:32:35] "[1m[35mPOST /_dash-update-component HTTP/1.1[0m" 500 -


Overall, the Plotly Dash code is very easy to understand and clear to read. Using a function decorator for callbacks provides a simple way to update the plots.

The final visualization will look like this: