In [1]:
# imports
import pandas as pd
import plotly.graph_objects as go

# Make the notebook work offline: connected=False embeds the plotly.js bundle
# shipped with the plotly package into the notebook, whereas connected=True
# would load plotly.js from the online CDN and therefore need internet access.
from plotly.offline import iplot, init_notebook_mode
init_notebook_mode(connected=False)

In [2]:
# Load the Superstore order lines (path is relative to the notebook's working directory)
superstore_csv = 'Superstore Decision Tree Data.csv'
df = pd.read_csv(superstore_csv)
df

Unnamed: 0,Row ID,Order Priority,Discount,Unit Price,Shipping Cost,Customer ID,Customer Name,Ship Mode,Customer Segment,Product Category,...,State or Province,City,Postal Code,Order Date,Ship Date,Profit,Quantity ordered new,Sales,Order ID,Link
0,18606,Not Specified,0.01,2.88,0.50,2,Janice Fletcher,Regular Air,Corporate,Office Supplies,...,Illinois,Addison,60101,5/28/2012,5/30/2012,1.320000,2,5.90,88525,Link
1,20847,High,0.01,2.84,0.93,3,Bonnie Potter,Express Air,Corporate,Office Supplies,...,Washington,Anacortes,98221,7/7/2010,7/8/2010,4.560000,4,13.01,88522,Link
2,23086,Not Specified,0.03,6.68,6.15,3,Bonnie Potter,Express Air,Corporate,Office Supplies,...,Washington,Anacortes,98221,7/27/2011,7/28/2011,-47.640000,7,49.92,88523,Link
3,23087,Not Specified,0.01,5.68,3.60,3,Bonnie Potter,Regular Air,Corporate,Office Supplies,...,Washington,Anacortes,98221,7/27/2011,7/28/2011,-30.510000,7,41.64,88523,Link
4,23088,Not Specified,0.00,205.99,2.50,3,Bonnie Potter,Express Air,Corporate,Technology,...,Washington,Anacortes,98221,7/27/2011,7/27/2011,998.202300,8,1446.67,88523,Link
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9421,20275,Critical,0.06,35.89,14.72,3402,Frederick Cole,Regular Air,Consumer,Office Supplies,...,West Virginia,Charleston,25314,5/14/2013,5/15/2013,137.860000,13,447.87,87532,Link
9422,20276,Critical,0.00,3.34,7.49,3402,Frederick Cole,Regular Air,Consumer,Office Supplies,...,West Virginia,Charleston,25314,5/14/2013,5/14/2013,-39.070000,3,13.23,87532,Link
9423,24491,Not Specified,0.08,550.98,45.70,3402,Frederick Cole,Delivery Truck,Consumer,Furniture,...,West Virginia,Charleston,25314,9/12/2013,9/14/2013,-1225.029097,4,2215.93,87533,Link
9424,25914,High,0.10,105.98,13.99,3403,Tammy Buckley,Express Air,Consumer,Furniture,...,Wyoming,Cheyenne,82001,2/8/2010,2/11/2010,349.485000,5,506.50,87530,Link


In [10]:
# Inspect the dtype pandas inferred for each column of the loaded frame
df.dtypes

Row ID                    int64
Order Priority           object
Discount                float64
Unit Price              float64
Shipping Cost           float64
Customer ID               int64
Customer Name            object
Ship Mode                object
Customer Segment         object
Product Category         object
Product Sub-Category     object
Product Container        object
Product Name             object
Product Base Margin     float64
Region                   object
State or Province        object
City                     object
Postal Code               int64
Order Date               object
Ship Date                object
Profit                  float64
Quantity ordered new      int64
Sales                   float64
Order ID                  int64
Link                     object
dtype: object

In [3]:
# Helper function to transform regular data to sankey format
# Returns data and layout as dictionary
def genSankey(df,cat_cols=None,value_cols='',title='Sankey Diagram'):
    """Build a Plotly sankey figure dictionary from a flat DataFrame.

    Each column in ``cat_cols`` is one level of the diagram. Every pair of
    adjacent columns contributes links from the left column's values to the
    right column's values, weighted by the sum of ``value_cols``.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame containing every column named in ``cat_cols`` and ``value_cols``.
    cat_cols : sequence of str
        Ordered level columns; at least two are required to form a link.
    value_cols : str
        Name of the single numeric column that is summed into link values.
    title : str
        Title placed in the figure layout.

    Returns
    -------
    dict
        ``{'data': [sankey_trace], 'layout': layout}``.

    Raises
    ------
    ValueError
        If fewer than two ``cat_cols`` are given.
    """
    cat_cols = list(cat_cols) if cat_cols is not None else []
    if len(cat_cols) < 2:
        raise ValueError('genSankey needs at least two cat_cols to build source-target links')

    # one color per level; the palette is cycled when there are more levels than colors
    colorPalette = ['#4B8BBE','#306998','#FFE873','#FFD43B','#646464']

    # Map each distinct label to the color of the first level it appears in.
    # Building labels and colors together keeps the two lists aligned even
    # when the same label occurs in several columns, and first-appearance
    # order makes the node order deterministic.
    labelColors = {}
    for idx, catCol in enumerate(cat_cols):
        levelColor = colorPalette[idx % len(colorPalette)]
        # missing values are skipped: groupby below drops those rows, so
        # they could never be referenced by a link anyway
        for label in df[catCol].dropna().unique().tolist():
            labelColors.setdefault(label, levelColor)
    labelList = list(labelColors)
    colorList = list(labelColors.values())
    labelIndex = {label: pos for pos, label in enumerate(labelList)}

    # transform df into source-target pairs, one frame per adjacent column pair
    levelFrames = []
    for sourceCol, targetCol in zip(cat_cols[:-1], cat_cols[1:]):
        pairs = df[[sourceCol,targetCol,value_cols]].copy()
        pairs.columns = ['source','target','count']
        levelFrames.append(pairs)

    # aggregate once over all levels (sums are the same as regrouping per level)
    sourceTargetDf = (
        pd.concat(levelFrames)
        .groupby(['source','target'])
        .agg({'count':'sum'})
        .reset_index()
    )

    # add index for source-target pair (dict lookup instead of list.index per row)
    sourceTargetDf['sourceID'] = sourceTargetDf['source'].map(labelIndex)
    sourceTargetDf['targetID'] = sourceTargetDf['target'].map(labelIndex)

    # creating the sankey diagram
    data = dict(
        type='sankey',
        node = dict(
          pad = 15,
          thickness = 20,
          line = dict(
            color = "black",
            width = 0.5
          ),
          label = labelList,
          color = colorList
        ),
        link = dict(
          source = sourceTargetDf['sourceID'].tolist(),
          target = sourceTargetDf['targetID'].tolist(),
          value = sourceTargetDf['count'].tolist()
        )
      )

    layout =  dict(
        title = title,
        font = dict(
          size = 10
        )
    )

    fig = dict(data=[data], layout=layout)
    return fig

In [13]:
# Generating regular sankey diagram
sankey_levels = ['Customer Segment','Region','Product Category','Product Sub-Category']
sank = genSankey(
    df,
    cat_cols=sankey_levels,
    value_cols='Sales',
    title='Sales Performances',
)
fig = go.Figure(sank)
fig.show()

In [14]:
# Generating sankey figure dicts for the different filter options
# NOTE: the variable names `all` and `os` are kept because the menu cell
# refers to them; be aware that `all` shadows the Python builtin.

SANKEY_COLS = ['Customer Segment','Region','Product Category','Product Sub-Category']

def _sankey_for(column=None, value=None):
    """Return the 'Sales Performances' sankey for rows where df[column] == value.

    With no column given, the whole frame is used.
    """
    subset = df if column is None else df[df[column]==value]
    return genSankey(subset,cat_cols=SANKEY_COLS,value_cols='Sales',title='Sales Performances')

# All Filter
all = _sankey_for()

# Customer Segment
corporate = _sankey_for('Customer Segment', 'Corporate')
home_office = _sankey_for('Customer Segment', 'Home Office')
consumer = _sankey_for('Customer Segment', 'Consumer')
small_business = _sankey_for('Customer Segment', 'Small Business')

# Region
central = _sankey_for('Region', 'Central')
west = _sankey_for('Region', 'West')
south = _sankey_for('Region', 'South')
east = _sankey_for('Region', 'East')

# Product Category
technology = _sankey_for('Product Category', 'Technology')
furniture = _sankey_for('Product Category', 'Furniture')
os = _sankey_for('Product Category', 'Office Supplies')

In [15]:
# Constructing menus
def _animate_button(label, figure):
    """Return one menu button that animates the plot to the given figure dict."""
    return dict(args=[figure], label=label, method="animate")


def _dropdown(entries, y):
    """Return a downward-opening menu at height y built from (label, figure) pairs."""
    return dict(
        buttons=[_animate_button(label, figure) for label, figure in entries],
        direction="down",
        pad={"r": 10, "t": 10},
        showactive=True,
        x=0,
        y=y,
    )


updatemenus = [
    # single push button that resets the view to the unfiltered diagram
    dict(
        type="buttons",
        direction="left",
        buttons=[_animate_button("All", all)],
        pad={"r": 10, "t": 10},
        showactive=True,
        x=0,
        y=1.1,
    ),
    # customer segment filter
    _dropdown(
        [
            ("Corporate", corporate),
            ("Home Office", home_office),
            ("Consumer", consumer),
            ("Small Business", small_business),
        ],
        y=1,
    ),
    # region filter
    _dropdown(
        [
            ("Central", central),
            ("West", west),
            ("South", south),
            ("East", east),
        ],
        y=0.9,
    ),
    # product category filter
    _dropdown(
        [
            ("Technology", technology),
            ("Furniture", furniture),
            ("Office Supplies", os),
        ],
        y=0.8,
    ),
]

In [19]:
# update layout with buttons, and show the figure
menu_sankey_levels = ['Customer Segment','Region','Product Category','Product Sub-Category']
sank = genSankey(
    df,
    cat_cols=menu_sankey_levels,
    value_cols='Sales',
    title='Sales Performances',
)
# update_layout returns the same figure object, so the calls can be chained
fig = go.Figure(sank).update_layout(updatemenus=updatemenus)
fig.show()