In [None]:
import pandas as pd
import numpy as np
import re
from google.cloud.bigquery import magics
from google.cloud import bigquery
from urllib.parse import urlparse

# Raise pandas' display limits for columns and rows shown in cell output.
pd.set_option('display.max_columns', 50)
pd.set_option('display.max_rows', 200)

# Defaults for the BigQuery IPython magics loaded below.
# NOTE(review): progress_bar_type is the string 'None', not the None singleton — confirm this is intended.
magics.context.project = 'wandb-production'
magics.context.progress_bar_type = 'None'
%load_ext google.cloud.bigquery

# Client used by the Python-side queries further down.
bqclient = bigquery.Client()

import plotly.express as px

import plotly.graph_objects as go # or plotly.express as px
import dash
from dash import dcc
from dash import html
from dash.dependencies import Input, Output
import json

import wandb

In [None]:
# run = wandb.init(project = 'state-machine', entity = 'mercedes-wu')

___
Queries to fetch weekly product-activity data for paid and unpaid users
___

In [None]:
# SQL producing one row per (user, ISO week) with a 0/1 flag per product type created in
# that week and a `sankey_buckets` label summarising which products were created.
# NOTE(review): every week join/filter below uses EXTRACT(ISOWEEK ...) alone, without a
# year component, so equal ISO week numbers from different years are treated as the same
# week — confirm this is acceptable once the data spans more than one year.
# NOTE(review): the ORDER BY sits inside the `sankey_agg` CTE rather than on the final
# SELECT — verify whether any downstream code relies on row order.
sankey_q = """
-- not all users have a first_telemetry at, better to not use it?
-- daily user agg can have activity before the first_run_at time?
with user_attributes as (
    select
        universal_user_id,
        first_run_at,
        hosting_type,
        is_paid
    from
        analytics.dim_users
    where 1=1
        and is_dev is not true
        and is_paid is not null --removing legacy local users from analysis
),
generated_date_table as (
    SELECT
        extract(year from date_gen_table) as year,
        extract(month from date_gen_table) as month,
        extract(isoweek from date_gen_table) as workweek_min_activity
    FROM
        UNNEST(GENERATE_DATE_ARRAY('2022-01-01', CURRENT_DATE(), INTERVAL 1 DAY)) AS date_gen_table
    where 1=1
        and extract(isoweek from date_gen_table) != 52 -- exclude days in january that dont belong to the first work week in a year
        and extract(isoweek from date_gen_table) + 8 <= extract(isoweek from CURRENT_DATE())
    group by
        extract(year from date_gen_table),
        extract(month from date_gen_table),
        extract(isoweek from date_gen_table)
),
first_agg_daily_appearance as (
    select
        universal_user_id,
        min(activity_day) as min_activity_day,
        extract(isoweek from min(activity_day)) as workweek_min_activity,
        extract(isoweek from min(activity_day)) + 8 as workweek_max_activity
    from
        analytics.agg_daily_user_activity
    group by
        universal_user_id
),
user_min_max_activity as (
    select
        first_agg_daily_appearance.*
    from
        first_agg_daily_appearance
    join
        generated_date_table using (workweek_min_activity)
),
week_product_created_agg as (
    select
        agg_daily.universal_user_id,
        user_attributes.hosting_type,
        user_attributes.is_paid,
        user_min_max_activity.workweek_min_activity,
        user_min_max_activity.workweek_max_activity,
        extract(isoweek from agg_daily.activity_day) as activity_workweek,
        extract(isoweek from agg_daily.activity_day) - user_min_max_activity.workweek_min_activity as weeks_away_from_signup,
        max(case when COALESCE(agg_daily.run_created, 0) > 0 then 1 else 0 end) run_created,
        max(case when COALESCE(agg_daily.artifact_created, 0) > 0 then 1 else 0 end) artifact_created,
        max(case when COALESCE(agg_daily.report_created, 0) > 0 then 1 else 0 end) report_created,
        max(case when COALESCE(agg_daily.weave_table_created, 0) > 0 then 1 else 0 end) weave_table_created,
        max(case when COALESCE(agg_daily.sweep_created, 0) > 0 then 1 else 0 end) sweep_created,
        max(case when COALESCE(agg_daily.custom_chart_created, 0) > 0 then 1 else 0 end) custom_chart_created,
        max(case when COALESCE(agg_daily.comment_created, 0) > 0 then 1 else 0 end) comment_created,
        (
            max(case when COALESCE(agg_daily.run_created, 0) > 0 then 1 else 0 end) +
            max(case when COALESCE(agg_daily.artifact_created, 0) > 0 then 1 else 0 end) +
            max(case when COALESCE(agg_daily.report_created, 0) > 0 then 1 else 0 end) +
            max(case when COALESCE(agg_daily.weave_table_created, 0) > 0 then 1 else 0 end) +
            max(case when COALESCE(agg_daily.sweep_created, 0) > 0 then 1 else 0 end) +
            max(case when COALESCE(agg_daily.custom_chart_created, 0) > 0 then 1 else 0 end) +
            max(case when COALESCE(agg_daily.comment_created, 0) > 0 then 1 else 0 end)
            ) as count_products_created,
    from
        analytics.agg_daily_user_activity as agg_daily
    join
        user_attributes using (universal_user_id)
    join
        user_min_max_activity using (universal_user_id)
    where 1=1
        and extract(isoweek from agg_daily.activity_day) <= workweek_max_activity
        and extract(isoweek from agg_daily.activity_day) >= workweek_min_activity
    group by
        agg_daily.universal_user_id,
        user_attributes.hosting_type,
        user_attributes.is_paid,
        user_min_max_activity.workweek_min_activity,
        user_min_max_activity.workweek_max_activity,
        extract(isoweek from agg_daily.activity_day),
        extract(isoweek from agg_daily.activity_day) - user_min_max_activity.workweek_min_activity
),

sankey_agg as (
    select
        week_product_created_agg.*,
        case
            when count_products_created = 0 then 'zero products'
            when count_products_created = 1 then (
                case
                    when run_created = 1 then 'run'
                    when sweep_created = 1 then 'sweep'
                    when weave_table_created = 1 then 'table'
                    when artifact_created = 1 then 'artifact'
                    when report_created = 1 then 'report'
                    when custom_chart_created = 1 then 'chart'
                    when comment_created = 1 then 'comment'
                    else 'DATA CHECK NEEDED: one product created that was not a run'
                end
                )
            when count_products_created = 2 then (
                case
                    when run_created = 1 and sweep_created = 1 then 'run+sweep'
                    when run_created = 1 and weave_table_created = 1 then 'run+table'
                    when run_created = 1 and artifact_created = 1 then 'run+artifact'
                    when run_created = 1 and report_created = 1 then 'run+report'
                    else '2 products misc'
                end
                )
            when count_products_created = 3 then '3 products'
            when count_products_created >= 4 then '4+ products'
            else 'DATA CHECK NEEDED: case statement condition not fufilled'
        end as sankey_buckets
    from
        week_product_created_agg
    where 1=1
    order by
        universal_user_id,
        weeks_away_from_signup
)
select
    *
from
    sankey_agg
where 1=1
"""

# The query takes no parameters, so the parameter list is empty.
job_config = bigquery.QueryJobConfig(query_parameters = [])

# Submit the query, wait for it to finish, and load the rows into a DataFrame.
sankey_df = bqclient.query(sankey_q, job_config = job_config).result().to_dataframe()

In [None]:
# Row count per hosting type in the query result.
sankey_df['hosting_type'].value_counts()

In [None]:
# Distinct is_paid values present in the query result.
sankey_df['is_paid'].unique()

In [None]:
# Raise the wandb.Table row cap before wrapping the query result.
wandb.Table.MAX_ROWS = 500000

# Reuse the active W&B run when there is one; otherwise start one so that this cell
# does not fail with NameError on a fresh kernel (no earlier cell assigns `run`).
run = wandb.run if wandb.run is not None else wandb.init(project = 'state-machine', entity = 'mercedes-wu')

# Log the raw query result as a dataset artifact.
sankey_table = wandb.Table(dataframe = sankey_df)
sankey_artifact = wandb.Artifact('sankey_table', type = 'dataset')
sankey_artifact.add(sankey_table, "sankey_table")
log_artifact = run.log_artifact(sankey_artifact)

In [None]:
# run.finish()

___
helper functions to generate sankey plots from query data
___

In [None]:
def sankey_preprocessing(sankey_piv, sub_sankey_piv, column_progression):
    """Convert a per-user pivot of weekly states into the arrays a plotly Sankey needs.

    Parameters
    ----------
    sankey_piv : pandas.DataFrame
        One row per user, with a ``universal_user_id`` column and one column per entry
        of ``column_progression`` holding the user's state label for that column (one of
        the keys of the style table below) or NaN. The frame is not modified.
    sub_sankey_piv : pandas.DataFrame
        Kept for backward compatibility with existing callers; not used.
    column_progression : sequence of int
        Column names of ``sankey_piv`` in the left-to-right order they should be drawn.
        Links are built between each pair of neighbouring entries.

    Returns
    -------
    tuple
        ``(labels, node_colors, node_xpos, node_ypos, sources_mapped, targets_mapped,
        link_labels, values, link_colors)``. ``labels`` are strings of the form
        ``'<state> (<column>)'``; ``sources_mapped`` / ``targets_mapped`` index into
        ``labels``; ``values`` counts users per link.

    Raises
    ------
    KeyError
        If a state label is not one of the known states.
    """
    df = sankey_piv.copy()
    columns = list(column_progression)

    # Suffix each state with its column so the same state in two columns becomes two nodes.
    # NaN cells stay NaN and are dropped by the groupby below.
    for c in columns:
        df[c] = df[c] + f' ({int(c)})'

    # Columns are spread evenly over (0, 1] along the x axis.
    order_x_dict = {c: (i + 1) / len(columns) for i, c in enumerate(columns)}

    # state -> (manual y position, 'r,g,b'); nodes use alpha 0.8, links alpha 0.3.
    # NOTE(review): 'table' (0.415) is placed below 'sweep' (0.40) although it is listed
    # before it — confirm the two values are not swapped.
    state_style = {
        '4+ products'     : (0.01,  '255,165,0'),
        '3 products'      : (0.05,  '255,215,0'),
        '2 products misc' : (0.1,   '128,128,0'),
        'run+report'      : (0.125, '154,205,50'),
        'run+artifact'    : (0.2,   '102,205,170'),
        'run+table'       : (0.25,  '60,179,113'),
        'run+sweep'       : (0.3,   '0,206,209'),
        'comment'         : (0.325, '176,224,230'),
        'chart'           : (0.35,  '173,216,230'),
        'report'          : (0.365, '95,158,160'),
        'artifact'        : (0.375, '65,105,225'),
        'table'           : (0.415, '139,0,139'),
        'sweep'           : (0.40,  '148,0,211'),
        'run'             : (0.55,  '25,25,112'),
        'zero products'   : (0.99,  '105,105,105'),
    }

    def _split_label(label):
        # 'run+sweep (3)' -> ('run+sweep', 3)
        parts = re.split('[()]', label)
        return parts[0].rstrip(), int(parts[1])

    labels = []
    seen_labels = set()

    def _register(new_labels):
        # Append labels not seen yet, keeping first-appearance order.
        for label in new_labels:
            if label not in seen_labels:
                seen_labels.add(label)
                labels.append(label)

    sources = []
    targets = []
    values = []

    for source_col, target_col in zip(columns, columns[1:]):
        link_counts = (
            df.groupby([source_col, target_col])['universal_user_id']
            .count()
            .reset_index()
        )

        # Register sources for every pair, not only the first: a state can show up as a
        # source without ever being a target when its users have NaN in the previous
        # column, and would otherwise be missing from the label index.
        _register(link_counts[source_col])
        _register(link_counts[target_col])

        sources += list(link_counts[source_col])
        targets += list(link_counts[target_col])
        values += list(link_counts['universal_user_id'])

    labels_dict = {label: position for position, label in enumerate(labels)}
    sources_mapped = [labels_dict[val] for val in sources]
    targets_mapped = [labels_dict[val] for val in targets]
    link_labels = [s + ' --> ' + t for s, t in zip(sources, targets)]

    node_xpos = []
    node_ypos = []
    node_colors = []
    for label in labels:
        state, column = _split_label(label)
        y_pos, rgb = state_style[state]
        node_ypos.append(y_pos)
        node_xpos.append(order_x_dict[column])
        node_colors.append(f'rgba({rgb},0.8)')

    # Each link takes the colour of its source state.
    link_colors = [f'rgba({state_style[_split_label(s)[0]][1]},0.3)' for s in sources]

    return labels, node_colors, node_xpos, node_ypos, sources_mapped, targets_mapped, link_labels, values, link_colors

In [None]:
def sankey_fig_helper(labels, node_colors, node_xpos, node_ypos, sources_mapped, targets_mapped, link_labels, values, link_colors):
    """Assemble a plotly Sankey figure from precomputed node and link arrays.

    The node arguments (labels, colours, x/y positions) and link arguments (source and
    target indices, values, colours, labels) are passed to ``go.Sankey`` unchanged; the
    labels are also used as hover templates. Returns the ``go.Figure``.
    """
    node_spec = dict(
        pad = 10,
        thickness = 10,
        line = dict(color = "black", width = 0.5),
        label = labels,
        color = node_colors,
        x = node_xpos,
        y = node_ypos,
        hovertemplate = labels,
    )
    link_spec = dict(
        source = sources_mapped,
        target = targets_mapped,
        value = values,
        color = link_colors,
        label = link_labels,
        hovertemplate = link_labels,
    )

    # 'fixed' arrangement keeps nodes at the supplied x/y positions.
    sankey_trace = go.Sankey(arrangement = 'fixed', node = node_spec, link = link_spec)
    sankey_fig = go.Figure(data = [sankey_trace])
    sankey_fig.update_traces(orientation = 'h')
    sankey_fig.update_layout(font_size = 9, autosize = False, height = 850, width = 1400)

    return sankey_fig

___
test plotting
___

In [None]:
# unique() returns the week offsets in order of first appearance, which is not
# necessarily chronological; sort them so neighbouring Sankey columns are neighbouring weeks.
column_progression = sorted(sankey_df['weeks_away_from_signup'].unique())

# One row per user, one column per week offset, each cell holding that week's bucket label.
sankey_piv = sankey_df.pivot(index=['universal_user_id', 'hosting_type', 'is_paid', 'workweek_min_activity'], columns='weeks_away_from_signup', values='sankey_buckets').reset_index()
labels, node_colors, node_xpos, node_ypos, sources_mapped, targets_mapped, link_labels, values, link_colors = sankey_preprocessing(sankey_piv, pd.DataFrame(), column_progression)

In [None]:
# Finish the active W&B run, if any. Unlike `run.finish()`, this does not raise
# NameError when no earlier cell has created `run`.
wandb.finish()

In [None]:
# Start a new W&B run for the artifacts and charts logged below.
run = wandb.init(project = 'state-machine', entity = 'mercedes-wu')

In [None]:
# reset_index() moves the current row index into an 'index' column (dropped in a later
# cell); rename_axis(None, axis=1) clears the name of the columns axis.
# NOTE(review): this reassigns sankey_piv from itself, so re-running the cell without
# re-creating sankey_piv changes the frame again.
sankey_piv = sankey_piv.reset_index().rename_axis(None, axis=1)

In [None]:
# Relabel the integer week-offset columns 0..8 as 'week_0'..'week_8'.
sankey_piv = sankey_piv.rename(columns = {offset: f'week_{offset}' for offset in range(9)})

In [None]:
# Remove the helper 'index' column; errors='ignore' keeps the cell re-runnable once
# the column is already gone.
sankey_piv = sankey_piv.drop(columns = ['index'], errors = 'ignore')

In [None]:
# Log the pivoted per-user frame as its own dataset artifact.
# Note: this rebinds sankey_table / sankey_artifact / log_artifact from the earlier logging cell.
sankey_table = wandb.Table(dataframe = sankey_piv)
sankey_artifact = wandb.Artifact('sankey_pivot_table', type = 'dataset')
sankey_artifact.add(sankey_table, "sankey_pivot_table")
log_artifact = run.log_artifact(sankey_artifact)

In [None]:
# Display the pivoted frame.
sankey_piv

In [None]:
# Close the run that received the pivot-table artifact and open a new one for the figures.
run.finish()

run = wandb.init(project = 'state-machine', entity = 'mercedes-wu')

In [None]:
# Build and display the Sankey for all users from the arrays computed in the test-plotting cell.
sankey_fig = sankey_fig_helper(labels, node_colors, node_xpos, node_ypos, sources_mapped, targets_mapped, link_labels, values, link_colors)
sankey_fig.show()

In [None]:
# Log the aggregated Sankey figure to the active W&B run.
wandb.log({'state_sankey_aggregated': sankey_fig})

In [None]:
# Lookup frame with one row per state: 'state' (bucket label) and 'color' (rgba string),
# listed from '4+ products' down to 'zero products'.
merge_to_df = pd.DataFrame(
    [
        ('4+ products',     'rgba(255,165,0,0.8)'),
        ('3 products',      'rgba(255,215,0,0.8)'),
        ('2 products misc', 'rgba(128,128,0,0.8)'),
        ('run+report',      'rgba(154,205,50,0.8)'),
        ('run+artifact',    'rgba(102,205,170,0.8)'),
        ('run+table',       'rgba(60,179,113,0.8)'),
        ('run+sweep',       'rgba(0,206,209,0.8)'),
        ('comment',         'rgba(176,224,230,0.8)'),
        ('chart',           'rgba(173,216,230,0.8)'),
        ('report',          'rgba(95,158,160,0.8)'),
        ('artifact',        'rgba(65,105,225,0.8)'),
        ('table',           'rgba(139,0,139,0.8)'),
        ('sweep',           'rgba(148,0,211,0.8)'),
        ('run',             'rgba(25,25,112,0.8)'),
        ('zero products',   'rgba(105,105,105,0.8)'),
    ],
    columns = ['state', 'color'],
)

In [None]:
# Bucket labels in trace order, from 'zero products' up to '4+ products'.
states = [
    'zero products',
    'run',
    'sweep',
    'table',
    'artifact',
    'report',
    'chart',
    'comment',
    'run+sweep',
    'run+table',
    'run+artifact',
    'run+report',
    '2 products misc',
    '3 products',
    '4+ products',
]

In [None]:
# Record each state's row position in an 'order' column.
merge_to_df['order'] = merge_to_df.index

In [None]:
# Names of the renamed week columns, week_0 through week_8.
bar_chart_x_axis = [f'week_{offset}' for offset in range(9)]

In [None]:
# For every week column, count users per state and attach the counts to the state/colour
# lookup so that each week carries a row for every state.
dfs = []
for c in bar_chart_x_axis:
    temp_df = (
        sankey_piv.groupby(c)['universal_user_id']
        .count()
        .reset_index()
        .rename(columns = {c: 'state', 'universal_user_id': 'count_users'})
    )
    temp_merge_df = merge_to_df.merge(temp_df, on = 'state', how = 'left')
    # States with no users in this week come out of the left merge as NaN; count them as 0.
    # fillna replaces the former `.replace(np.NaN, 0)`: the np.NaN alias was removed in NumPy 2.0.
    temp_merge_df['count_users'] = temp_merge_df['count_users'].fillna(0)
    temp_merge_df.loc[:, 'week'] = c
    dfs.append(temp_merge_df)


In [None]:
# Stack the per-week frames row-wise; the 'week' column identifies each frame's origin.
stacked_bar_chart_df = pd.concat(dfs, axis=0)

In [None]:

# One bar trace per state, stacked per week and coloured from the lookup frame.
fig = go.Figure(data=[
    go.Bar(
        x = [week.replace('_', ' ') for week in bar_chart_x_axis],
        y = stacked_bar_chart_df.loc[stacked_bar_chart_df['state'] == s, 'count_users'],
        name = s,
        marker_color = stacked_bar_chart_df.loc[stacked_bar_chart_df['state'] == s, 'color'],
    )
    for s in states
])
fig.update_layout(
    barmode = 'stack',
    font_size = 9,
    autosize = False,
    height = 850,
    width = 1400,
    template = 'plotly_white',
)
fig.show()

In [None]:
# Log the stacked bar chart to the active W&B run.
wandb.log({'state_stacked_bar_chart': fig})

In [None]:
# Per-state user counts in week 0, renamed so they can be merged back as the baseline.
week0_count_users_df = stacked_bar_chart_df.query('week == "week_0"')[['state', 'count_users']].rename(columns = {'count_users': 'count_users_week0'})

In [None]:
# Attach each state's week-0 count to every (state, week) row.
perc_line_chart_df = stacked_bar_chart_df.copy().merge(week0_count_users_df, on = 'state', how = 'left')

In [None]:
# Ratio of each week's user count to the same state's week-0 count, rounded to 2 decimals.
# A week-0 count of 0 is mapped to NaN first so the ratio is NaN (a gap in the line)
# instead of inf.
perc_line_chart_df.loc[:, 'user_perc_of_week0'] = np.round(
    perc_line_chart_df['count_users'] / perc_line_chart_df['count_users_week0'].replace(0, np.nan),
    2,
)

In [None]:

# One line per state showing its weekly user count relative to week 0.
fig = go.Figure(data=[
    go.Scatter(
        x = [week.replace('_', ' ') for week in bar_chart_x_axis],
        y = perc_line_chart_df.loc[perc_line_chart_df['state'] == s, 'user_perc_of_week0'],
        name = s,
        mode = 'lines',
        line = dict(color = perc_line_chart_df.loc[perc_line_chart_df['state'] == s, 'color'].iloc[0]),
    )
    for s in states
])
# Figure size, theme and hover behaviour.
fig.update_layout(
    font_size = 9,
    autosize = False,
    height = 850,
    width = 1400,
    template = 'plotly_white',
    hovermode = 'x',
)

# Show the y axis as percentages over the range 0 to 3.
fig.update_yaxes(range = [0,3], tickformat = ',.0%')
fig.show()

In [None]:
# Close the current run and open a new one for the percentage chart.
run.finish()

run = wandb.init(project = 'state-machine', entity = 'mercedes-wu')

In [None]:
# Log the week-0-relative line chart to the active W&B run.
wandb.log({'perc_week0_activity': fig})

TODO:
    - exclude controller runs

In [None]:
# Users in the '4+ products' bucket in week 8. The week columns were renamed from
# integers to 'week_<n>' in an earlier cell, so the column is 'week_8', not 8.
sankey_piv[sankey_piv['week_8'] == '4+ products']

___
volatility analysis?
___

___
modeling (where sequence does not matter)
- testing naive-bayes
___

In [None]:
# Smoke test of CategoricalNB on random integer data (6 samples x 100 features, one class
# per sample); it does not use the Sankey data.
import numpy as np
rng = np.random.RandomState(1)
X = rng.randint(5, size=(6, 100))
y = np.array([1, 2, 3, 4, 5, 6])
from sklearn.naive_bayes import CategoricalNB
clf = CategoricalNB()
clf.fit(X, y)
# Predict the class of the third training sample.
print(clf.predict(X[2:3]))

In [None]:
# Work on a copy so modelling experiments do not modify sankey_piv.
pred_df = sankey_piv.copy()

In [None]:
# Histogram overview of pred_df via DataFrame.hist().
# NOTE(review): `plt` is imported here but not referenced in this cell.
import matplotlib.pyplot as plt
pred_df.hist()

___
testing cpt
___

In [None]:
# import plotly.graph_objects as go

# fig = go.Figure(data=[go.Sankey(
#     node = dict(
#       pad = 15,
#       thickness = 20,
#       line = dict(color = "black", width = 0.5),
#       label = ["A1", "A2", "B1", "B2", "C1", "C2"],
#       color = "blue",
#         hoverinfo='skip'
#     ),
#     link = dict(
#       source = [0, 0, 1, 0,], # indices correspond to labels, eg A1, A2, A1, B1, ...
#       target = [0, 2, 3, 3,],
#       value = [1000, 8, 4, 2,],
#       color = ['rgba(255,165,0,0)', 'rgba(255,165,0,0.3)', 'rgba(255,165,0,0.3)', 'rgba(255,165,0,0.3)'],
#         hoverinfo='skip'
#   ))])

# fig.update_layout(title_text="Basic Sankey Diagram", font_size=10, hovermode=False)
# fig.show()

___
Breakdown by paid status and hosting type
___

In [None]:
def create_sankey(sankey_df, show = True):
    """Build the weekly-state Sankey figure for a subset of the query result.

    Parameters
    ----------
    sankey_df : pandas.DataFrame
        Rows of the Sankey query result; must contain ``universal_user_id``,
        ``hosting_type``, ``is_paid``, ``workweek_min_activity``,
        ``weeks_away_from_signup`` and ``sankey_buckets``.
    show : bool, default True
        Call ``.show()`` on the figure before returning it. Pass ``False`` when the
        call is the last expression of a cell, which already renders the returned figure.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    # unique() yields offsets in order of first appearance; for a filtered frame that is
    # not necessarily chronological, so sort before using them as the column order.
    column_progression = sorted(sankey_df['weeks_away_from_signup'].unique())
    sankey_piv = sankey_df.pivot(
        index=['universal_user_id', 'hosting_type', 'is_paid', 'workweek_min_activity'],
        columns='weeks_away_from_signup',
        values='sankey_buckets',
    ).reset_index()
    sankey_fig = sankey_fig_helper(*sankey_preprocessing(sankey_piv, pd.DataFrame(), column_progression))
    if show:
        sankey_fig.show()

    return sankey_fig

In [None]:
# Distinct is_paid values, checked before filtering on them below.
sankey_df['is_paid'].unique()

In [None]:
# Sankey restricted to rows where is_paid is True.
create_sankey(sankey_df.query('is_paid == True'))

In [None]:
# Sankey restricted to rows whose hosting_type is "local".
create_sankey(sankey_df.query('hosting_type == "local"'))

In [None]:
# Sankey restricted to rows where is_paid is False.
create_sankey(sankey_df.query('is_paid == False'))

___
test click events
___


> test subsankey

In [None]:
# Simulated click on a node: labels have the form '<state> (<week offset>)'.
click_label = "4+ products (8)"

In [None]:
# Link labels contain ' --> ' between source and target; node labels do not.
' --> ' in click_label

In [None]:
# Splitting a node label on ' --> ' returns it unchanged as a single-element list.
click_label.split(' --> ')

In [None]:
# Re-pivot with universal_user_id as the only index. This rebinds sankey_piv, so the week
# columns are the raw weeks_away_from_signup values again rather than 'week_<n>'.
sankey_piv = sankey_df.pivot(index='universal_user_id', columns='weeks_away_from_signup', values='sankey_buckets').reset_index()

In [None]:
# Split '<state> (<week offset>)' into the state text and the integer offset.
label_parts = re.split('[()]', click_label)
product_key = label_parts[0].rstrip()
column_key = int(label_parts[1])

In [None]:
# Keep only the users who are in the clicked state during the clicked week.
#TODO: refactor using query
filt_sankey_piv = sankey_piv.loc[sankey_piv[column_key] == product_key]

In [None]:
# Sankey for the filtered users only; column_progression comes from the earlier test-plotting cell.
labels, node_colors, node_xpos, node_ypos, sources_mapped, targets_mapped, link_labels, values, link_colors = sankey_preprocessing(filt_sankey_piv, filt_sankey_piv, column_progression)
filt_sankey_fig = sankey_fig_helper(labels, node_colors, node_xpos, node_ypos, sources_mapped, targets_mapped, link_labels, values, link_colors)
filt_sankey_fig.show()

> test tooltip

In [None]:
# Simulated click; builds the tooltip bar chart for either a link or a node label.
click_label = "zero products (-3)"


def _parse_node_label(node_label):
    """Split '<state> (<week offset>)' into (state, int offset)."""
    parts = re.split('[()]', node_label)
    return parts[0].rstrip(), int(parts[1])


if ' --> ' in click_label:
    # Link: a single bar with the number of users on exactly this source -> target transition.
    source_label, target_label = click_label.split(' --> ')
    source_product_key, source_column_key = _parse_node_label(source_label)
    target_product_key, target_column_key = _parse_node_label(target_label)
    filt_sankey_piv = sankey_piv[
        (sankey_piv[source_column_key] == source_product_key) &
        (sankey_piv[target_column_key] == target_product_key) #TODO: refactor using query
    ]

    fig = go.Figure(
        [
            go.Bar(
                x=[click_label],
                y=[len(filt_sankey_piv)],
                text=[len(filt_sankey_piv)]
            )
        ]
    )
else:
    # Node: one stacked bar, split by the users' paths through the neighbouring weeks.
    product_key, column_key = _parse_node_label(click_label)
    filt_sankey_piv = sankey_piv[sankey_piv[column_key] == product_key] #TODO: refactor using query

    # Take the previous/next entries of the sorted progression rather than assuming
    # column_key - 1 and column_key + 1 exist as columns.
    ordered_columns = sorted(column_progression)
    position = ordered_columns.index(column_key)
    columns_to_keep = ordered_columns[max(position - 1, 0): position + 2]

    path_sankey_node = filt_sankey_piv[['universal_user_id'] + columns_to_keep].rename_axis(None, axis=1)
    # Weeks without activity are NaN in the pivot; ' --> '.join() would raise TypeError on
    # them, so give them an explicit label before concatenating.
    week_states = path_sankey_node[columns_to_keep].fillna('no activity').astype(str)
    path_sankey_node['path'] = week_states[columns_to_keep[0]].str.cat(
        [week_states[col] for col in columns_to_keep[1:]], sep = ' --> '
    )
    path_gb = path_sankey_node[['universal_user_id', 'path']].groupby('path').count().reset_index().sort_values('path')
    path_gb.loc[:, 'click_label'] = click_label
    fig = px.bar(
        path_gb,
        x='click_label',
        y='universal_user_id',
        color='path',
        text_auto=True
    )


In [None]:
# Display the per-path user counts (path_gb is only defined by the node branch above).
path_gb

In [None]:
# NOTE(review): this sets a `name` attribute on the DataFrame object; it looks like
# leftover debugging — confirm whether it can be removed.
path_sankey_node.name = None

In [None]:
# Display the frame with the columns-axis name cleared (the result is not assigned).
path_sankey_node.rename_axis(None, axis=1)

___
dash app testing
___

In [None]:
external_stylesheets = ['https://codepen.io/chriddyp/pen/bWLwgP.css']

app = dash.Dash(external_stylesheets=external_stylesheets)

# Style for the <pre> element that echoes the clicked label.
styles = {
    'pre': {
        'border': 'thin lightgrey solid',
        'overflowX': 'scroll'
    }
}

# The query result exposes the week offset as 'weeks_away_from_signup'; it has no
# 'weeks_away_from_paid' column, which made this cell fail with KeyError.
# The offsets are sorted because unique() returns them in order of first appearance.
column_progression = sorted(sankey_df['weeks_away_from_signup'].unique())
sankey_piv = sankey_df.pivot(index='universal_user_id', columns='weeks_away_from_signup', values='sankey_buckets').reset_index()
labels, node_colors, node_xpos, node_ypos, sources_mapped, targets_mapped, link_labels, values, link_colors = sankey_preprocessing(sankey_piv, sankey_piv, column_progression)
sankey_fig = sankey_fig_helper(labels, node_colors, node_xpos, node_ypos, sources_mapped, targets_mapped, link_labels, values, link_colors)

# Page layout: the global Sankey next to the path bar chart, then the click-data echo,
# then the filtered Sankey. The component ids are the ones the callbacks below refer to.
app.layout = html.Div(
    [
        # Row 1: global Sankey (80% width) and path bar chart (20% width).
        html.Div(
            className='row',
            children=[
                dcc.Graph(
                    id='global-sankey',
                    figure=sankey_fig,
                    style={'width': '80%', 'display': 'inline-block'},
                ),
                dcc.Graph(
                    id='global-path',
                    style={'width': '20%', 'display': 'inline-block'},
                )
            ],
            
        ),
        
        # Row 2: text echo of the clicked label.
        html.Div(
            className='row', 
            children=[
                html.Div(
                    [
                        dcc.Markdown(
                            """
                            **Click Data**
                            Click on points in the graph.
                            """
                        ),
                        html.Pre(
                            id='click-data', 
                            style=styles['pre']
                        ),
                    ], 
                    className='three columns'
                ),
            ]
        ),
        # Row 3: Sankey restricted to the users behind the clicked node or link.
        html.Div(
            dcc.Graph(id = 'filter-sankey')
        )
    ]
)

@app.callback(
    Output('click-data', 'children'),
    [Input('global-sankey', 'clickData')])
def display_click_data(clickData):
    """Return the label of the clicked Sankey node or link for the 'click-data' element.

    Raises ``dash.exceptions.PreventUpdate`` when there is no click yet, leaving the
    element unchanged.
    """
    # Dash also fires the callback on page load, when clickData is still None.
    if not clickData:
        raise dash.exceptions.PreventUpdate
    return clickData['points'][0]['label']

@app.callback(
    Output('global-path', 'figure'),
    [Input('global-sankey', 'clickData')])
def create_global_flow(clickData):
    """Build the side bar chart for the clicked Sankey element.

    For a link label ('<source> --> <target>') the chart is a single bar with the number
    of users on that transition. For a node label it is one stacked bar split by the
    users' paths through the previous, clicked and next week.

    Raises ``dash.exceptions.PreventUpdate`` when there is no click yet.
    """
    def _parse_node_label(node_label):
        # '<state> (<week offset>)' -> (state, int offset)
        parts = re.split('[()]', node_label)
        return parts[0].rstrip(), int(parts[1])

    # Dash also fires the callback on page load, when clickData is still None.
    if not clickData:
        raise dash.exceptions.PreventUpdate

    click_label = clickData['points'][0]['label']
    if ' --> ' in click_label:
        source_label, target_label = click_label.split(' --> ')
        source_product_key, source_column_key = _parse_node_label(source_label)
        target_product_key, target_column_key = _parse_node_label(target_label)
        filt_sankey_piv = sankey_piv[
            (sankey_piv[source_column_key] == source_product_key) &
            (sankey_piv[target_column_key] == target_product_key) #TODO: refactor using query
        ]

        bar_fig = go.Figure(
            [
                go.Bar(
                    x=[click_label],
                    y=[len(filt_sankey_piv)],
                    text=[len(filt_sankey_piv)]
                )
            ]
        )
    else:
        product_key, column_key = _parse_node_label(click_label)
        filt_sankey_piv = sankey_piv[sankey_piv[column_key] == product_key] #TODO: refactor using query

        # Take the previous/next entries of the sorted progression rather than assuming
        # column_key - 1 and column_key + 1 exist as columns.
        ordered_columns = sorted(column_progression)
        position = ordered_columns.index(column_key)
        columns_to_keep = ordered_columns[max(position - 1, 0): position + 2]

        path_sankey_node = filt_sankey_piv[['universal_user_id'] + columns_to_keep].rename_axis(None, axis=1)
        # Weeks without activity are NaN in the pivot; ' --> '.join() would raise
        # TypeError on them, so label them explicitly before concatenating.
        week_states = path_sankey_node[columns_to_keep].fillna('no activity').astype(str)
        path_sankey_node['path'] = week_states[columns_to_keep[0]].str.cat(
            [week_states[col] for col in columns_to_keep[1:]], sep = ' --> '
        )
        path_gb = path_sankey_node[['universal_user_id', 'path']].groupby('path').count().reset_index().sort_values('path')
        path_gb.loc[:, 'click_label'] = click_label
        bar_fig = px.bar(
            path_gb,
            x='click_label',
            y='universal_user_id',
            color='path',
            text_auto=True
        )
    # Narrow chart without axes or legend, matching the Sankey's height.
    bar_fig.update_layout(
        font_size = 9,
        autosize = False,
        height = 850,
        width = 300,
        showlegend = False,
        template = 'plotly_white',
        yaxis={'visible': False, 'showticklabels': False},
        xaxis={'visible': False, 'showticklabels': False}
    )
    return bar_fig

@app.callback(
    Output('filter-sankey', 'figure'),
    [Input('global-sankey', 'clickData')])
def create_subsankey(clickData):
    """Build a Sankey restricted to the users behind the clicked node or link.

    A link label ('<source> --> <target>') keeps users who match both ends; a node label
    keeps users in that state in that week.

    Raises ``dash.exceptions.PreventUpdate`` when there is no click yet.
    """
    def _parse_node_label(node_label):
        # '<state> (<week offset>)' -> (state, int offset)
        parts = re.split('[()]', node_label)
        return parts[0].rstrip(), int(parts[1])

    # Dash also fires the callback on page load, when clickData is still None.
    if not clickData:
        raise dash.exceptions.PreventUpdate

    click_label = clickData['points'][0]['label']
    if ' --> ' in click_label:
        source_label, target_label = click_label.split(' --> ')
        source_product_key, source_column_key = _parse_node_label(source_label)
        target_product_key, target_column_key = _parse_node_label(target_label)
        filt_sankey_piv = sankey_piv[
            (sankey_piv[source_column_key] == source_product_key) &
            (sankey_piv[target_column_key] == target_product_key) #TODO: refactor using query
        ]
    else:
        product_key, column_key = _parse_node_label(click_label)
        filt_sankey_piv = sankey_piv[sankey_piv[column_key] == product_key] #TODO: refactor using query

    return sankey_fig_helper(*sankey_preprocessing(filt_sankey_piv, filt_sankey_piv, column_progression))


# Start the Dash development server.
# NOTE(review): newer Dash releases expose `app.run`; confirm `run_server` is still
# available in the installed Dash version.
app.run_server(debug=False, use_reloader=False) # Turn off reloader if inside Jupyter

TODO:


In [None]:
# Display the global Sankey figure built for the Dash app.
sankey_fig