# libraries

In [1]:
import os

import pandas as pd
import numpy as np
import pickle
from IPython.display import display
from ipywidgets import interact, Dropdown
from floweaver import ProcessGroup, Waypoint, Partition, Bundle, SankeyDefinition, weave
from ipysankeywidget import SankeyWidget
from IPython.display import clear_output

# modules

## variables

In [2]:
# Folder (relative to the notebook's working directory) that is scanned for input files.
data_dir = "data"
# Extensions, without the leading dot, that load_file can read.
base_extensions = [
    "csv",
    "xlsx",
    "pickle",
]

## functions

In [3]:
def extract_files(extensions=base_extensions):
    """Return the paths of the readable data files found in ``data_dir``.

    Parameters
    ----------
    extensions : iterable of str
        Accepted extensions, without the leading dot (defaults to
        ``base_extensions``).

    Returns
    -------
    list of str
        ``os.path.join(data_dir, name)`` for every regular file in ``data_dir``
        whose text after the last ``"."`` is one of ``extensions``, sorted by
        file name so the dropdown order is stable (``os.listdir`` order is
        arbitrary).
    """
    accepted = set(extensions)
    file_list = []
    for name in sorted(os.listdir(data_dir)):
        _, dot, extension = name.rpartition(".")
        path = os.path.join(data_dir, name)
        # Require an actual "<name>.<ext>" (a bare file called "csv" is not a csv
        # file) and skip directories, which load_file could not open.
        if dot and extension in accepted and os.path.isfile(path):
            file_list.append(path)
    return file_list


def load_file(file_path, extensions=base_extensions):
    """Load a data file, choosing the reader from its extension.

    Parameters
    ----------
    file_path : str
        Path to a ``.csv``, ``.xlsx`` or ``.pickle`` file.
    extensions : iterable of str
        Extensions (without the leading dot) the caller accepts; defaults to
        ``base_extensions``.

    Returns
    -------
    object
        A ``pandas.DataFrame`` for csv/xlsx files; for ``.pickle`` whatever
        object was pickled.

    Raises
    ------
    ValueError
        If the extension is not in ``extensions`` or no reader exists for it.
    """
    extension = file_path.split(".")[-1]
    # Explicit raise instead of assert: asserts are stripped under ``python -O``.
    if extension not in extensions:
        raise ValueError("the extension of your file ({}) is not supported".format(file_path))
    if extension == "csv":
        return pd.read_csv(file_path)
    if extension == "xlsx":
        return pd.read_excel(file_path)
    if extension == "pickle":
        # SECURITY: unpickling can execute arbitrary code; only load trusted pickles.
        with open(file_path, "rb") as f:
            return pickle.load(f)
    # ``extensions`` may list a format with no reader above; previously this fell
    # through and failed with an UnboundLocalError.
    raise ValueError("no reader available for the extension of your file ({})".format(file_path))


def get_palette(y_values, yl):
    """Build a colour mapping that highlights a single category.

    Every value in ``y_values`` maps to ``'gray'``; ``yl`` maps to
    ``'yellowgreen'`` and is added to the mapping even if it is not one of
    ``y_values``.
    """
    muted, highlight = 'gray', 'yellowgreen'
    colours = dict.fromkeys(y_values, muted)
    colours.update({yl: highlight})
    return colours


def concat_in_out_df(df, i, col, val):
    """Append a synthetic "Start" and "End" step for every entity.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format data; only columns ``i``, ``col`` and ``val`` are kept.
    i : hashable
        Entity-id column.
    col : hashable
        Step (e.g. date) column; the new rows get ``"Start"`` / ``"End"`` here.
    val : hashable
        Value column; the new rows get ``"in"`` / ``"out"`` here.

    Returns
    -------
    pandas.DataFrame
        The three selected columns of ``df`` followed by one ``Start``/``in``
        row and one ``End``/``out`` row per distinct value of ``df[i]``, with a
        fresh ``0..N-1`` index. ``df`` itself is not modified.
    """
    # Use the real entity ids. The previous ``np.arange(n)`` only matched when ids
    # were exactly 0..n-1, so Start/End rows were attached to the wrong entities.
    ids = df[i].unique()
    base = df[[i, col, val]].copy()

    def _marker_rows(label, marker):
        # One synthetic row per entity, same columns as ``base``.
        rows = pd.DataFrame({i: ids})
        rows[col] = label
        rows[val] = marker
        return rows

    return pd.concat(
        [base, _marker_rows("Start", "in"), _marker_rows("End", "out")],
        ignore_index=True,
    )


def get_node_order_bundle(df, i, x, y):
    """Build the floweaver pieces for a Sankey that follows each entity over ``x``.

    Every entity (column ``i``) gets a synthetic ``"Start"``/``"End"`` step via
    ``concat_in_out_df``. The data is pivoted to one row per entity and one
    column per ``x`` value holding the summed ``y`` value. Identical
    trajectories are then counted.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format data containing columns ``i``, ``x`` and ``y``.
    i, x, y : hashable
        Names of the entity-id, step (date) and category columns.

    Returns
    -------
    tmp_flows : pandas.DataFrame
        One row per distinct trajectory with its count in ``"value"``; the
        ``"Start"``/``"End"`` columns are renamed ``"source"``/``"target"``.
    nodes : dict
        ``ProcessGroup`` for ``"Start"`` and ``"End"`` plus one ``Waypoint``
        per ``x`` value, keyed by step name.
    bundles : list
        A single ``Bundle`` from ``"Start"`` to ``"End"`` through every waypoint.
    ordering : list of list
        One layer per step, in step order.
    df_piv : pandas.DataFrame
        The wide pivot table, entity id as a regular column.

    Raises
    ------
    ValueError
        If ``x`` has fewer than two distinct values.
    """
    # concat_in_out_df works on its own column-subset copy, so ``df`` is untouched.
    df_long = concat_in_out_df(df, i, x, y)
    df_piv = df_long.pivot_table(values=y, index=i, columns=x, aggfunc='sum').reset_index()
    start, end = "Start", "End"
    steps = df_piv.columns.drop([i, start, end]).sort_values().tolist()
    x_values = [start] + steps + [end]
    # Explicit raise instead of assert: asserts are stripped under ``python -O``.
    if len(steps) < 2:
        raise ValueError("the cardinality of x need to be larger than 1")
    # NOTE: groupby drops NaN keys by default, so an entity with no y value at some
    # step (a NaN cell in df_piv) is left out of the counted flows.
    df_piv_count = df_piv.groupby(x_values)[i].count().reset_index().rename(columns={i: "value"})

    def sorted_levels(column):
        # Distinct categories observed at one step, in sorted order.
        return np.sort(df_piv_count[column].unique()).tolist()

    nodes = {}
    for terminal in (start, end):
        nodes[terminal] = ProcessGroup(sorted_levels(terminal), title=terminal)
        nodes[terminal].partition = Partition.Simple("process", sorted_levels(terminal))
    for x_value in steps:
        nodes[x_value] = Waypoint(Partition.Simple(x_value, sorted_levels(x_value)), title=x_value)

    ordering = [[x_value] for x_value in x_values]
    bundles = [Bundle(start, end, waypoints=steps)]
    tmp_flows = df_piv_count.rename(columns={start: "source", end: "target"})
    return tmp_flows, nodes, bundles, ordering, df_piv

## class

In [4]:
class TsSankey(object):
    """Drive an interactive floweaver Sankey diagram from seven dropdown widgets.

    The widgets are:
    - a "multiple display?" Yes/No switch;
    - the data file;
    - the index (entity id) column;
    - the date column;
    - the target-variable column;
    - the highlighted date;
    - the highlighted target value.

    Each ``on_*`` method is meant to be registered with ``widget.observe`` and
    reads ``change["new"]``.
    """

    # Size passed to ``to_widget`` whenever the diagram is drawn.
    DEFAULT_SIZE = {"width": 1070, "height": 500}

    def __init__(self, multiple_display_widget, path_widget, index_widget, column_widget, value_widget, column_level_widget, value_level_widget):
        # Data and floweaver state, filled in by the callbacks below.
        self.df = None
        self.file_path = None
        self.flows = None
        self.palette = None
        self.size = dict(self.DEFAULT_SIZE)
        self.column_values = None  # NOTE(review): not read anywhere in this class
        self.sdd = None
        # Control widgets.
        self.multiple_display_widget = multiple_display_widget
        self.path_widget = path_widget
        self.index_widget = index_widget
        self.column_widget = column_widget
        self.value_widget = value_widget
        self.column_level_widget = column_level_widget
        self.value_level_widget = value_level_widget
        # Becomes True once get_sdd has drawn a diagram.
        self.display_flag = False

    # ------------------------------------------------------------------ helpers
    def _display_controls(self):
        """Display the seven control widgets in their fixed order."""
        for widget in (self.multiple_display_widget, self.path_widget, self.index_widget,
                       self.column_widget, self.value_widget, self.column_level_widget,
                       self.value_level_widget):
            display(widget)

    def _print_selection(self):
        """Print a one-line summary of the current selections."""
        print("file: {}, index: {}, date: {} is {}, target: {} is {}".format(
            self.path_widget.value,
            self.index_widget.value,
            self.column_widget.value,
            self.column_level_widget.value,
            self.value_widget.value,
            self.value_level_widget.value
        ))

    def _draw(self, sdd):
        """Weave ``sdd`` with the stored flows/palette and display the widget."""
        display(weave(sdd, self.flows, palette=self.palette).to_widget(**self.size))

    def _refresh_value_levels(self, options):
        """Replace the target-value options, keeping the current value only if still valid.

        Call while ``on_value_level_update`` is unobserved. Redraws the diagram
        when the previous target value is among ``options``; otherwise clears
        the selection and ``self.sdd``.
        """
        previous = self.value_level_widget.value
        # Assign options first, then decide on the value, so the final widget state
        # does not depend on how the widget treats its value when options change.
        self.value_level_widget.options = options
        if previous in options:
            self.value_level_widget.value = previous
            self.sdd = self.get_sdd()
        else:
            self.value_level_widget.value = None
            self.sdd = None

    # ---------------------------------------------------------------- callbacks
    def on_path_update(self, change):
        """Load the selected file and reset every dependent selection."""
        file_path = change["new"]
        self.file_path = file_path

        # load dataframe
        self.df = load_file(file_path)
        self.index_widget.options = self.df.columns.tolist()
        self.index_widget.value = None
        # Detach level observers while resetting, so the resets do not trigger redraws.
        self.column_widget.unobserve(self.on_column_update, names="value")
        self.column_level_widget.unobserve(self.on_column_to_value_update, names="value")
        self.column_widget.options = self.df.columns.tolist()
        self.column_widget.value = None
        self.column_level_widget.options = []
        self.column_level_widget.value = None

        self.value_widget.unobserve(self.on_value_update, names="value")
        self.value_level_widget.unobserve(self.on_value_level_update, names="value")
        self.value_widget.options = self.df.columns.tolist()
        self.value_widget.value = None
        self.value_level_widget.options = []
        self.value_level_widget.value = None

        # re-attach observers
        self.column_widget.observe(self.on_column_update, names="value")
        self.value_widget.observe(self.on_value_update, names="value")
        self.column_level_widget.observe(self.on_column_to_value_update, names="value")
        self.value_level_widget.observe(self.on_value_level_update, names="value")

        # reset floweaver resources
        self.sdd = None
        self.flows = None
        self.palette = None

    def on_column_update(self, change):
        """Offer the sorted distinct values of the new date column as target dates."""
        column_name = change["new"]
        self.column_level_widget.unobserve(self.on_column_to_value_update, names="value")
        self.value_level_widget.unobserve(self.on_value_level_update, names="value")
        self.column_level_widget.options = np.sort(self.df[column_name].unique()).tolist()
        # A new date column invalidates both level selections.
        self.column_level_widget.value = None
        self.value_level_widget.options = []
        self.value_level_widget.value = None
        self.sdd = None
        self.column_level_widget.observe(self.on_column_to_value_update, names="value")
        self.value_level_widget.observe(self.on_value_level_update, names="value")

    def on_value_update(self, change):
        """Recompute target-value options for the new target column at the chosen date."""
        value_name = change["new"]
        if self.column_widget.value is None or self.column_level_widget.value is None:
            return
        self.value_level_widget.unobserve(self.on_value_level_update, names="value")
        try:
            at_date = self.df[self.df[self.column_widget.value] == self.column_level_widget.value]
            options = np.sort(at_date[value_name].unique()).tolist()
            self._refresh_value_levels(options)
        finally:
            # Re-attach even if drawing fails, otherwise the dropdown goes dead.
            self.value_level_widget.observe(self.on_value_level_update, names="value")

    def on_column_to_value_update(self, change):
        """Recompute target-value options for the newly chosen target date."""
        column_level = change["new"]
        if self.column_widget.value is None or self.value_widget.value is None:
            return
        self.value_level_widget.unobserve(self.on_value_level_update, names="value")
        try:
            at_date = self.df[self.df[self.column_widget.value] == column_level]
            options = np.sort(at_date[self.value_widget.value].unique()).tolist()
            # Previously this tested the STALE options list, redrew, then cleared the
            # value anyway; now it matches on_value_update.
            self._refresh_value_levels(options)
        finally:
            self.value_level_widget.observe(self.on_value_level_update, names="value")

    def on_value_level_update(self, change):
        """Redraw the diagram for the newly chosen target value."""
        if self.df is not None:
            self.sdd = self.get_sdd()

    def on_clear_output(self, change):
        """Re-render the current diagram after the "multiple display?" switch changes.

        With ``"Yes"`` the output is appended; with anything else the cell output
        is cleared before and (lazily, ``wait=True``) after rendering.
        """
        if self.sdd is None:
            return
        replace_output = change["new"] != "Yes"
        if replace_output:
            clear_output(wait=False)
        self._display_controls()
        self._print_selection()
        self._draw(self.sdd)
        if replace_output:
            clear_output(wait=True)

    def get_sdd(self):
        """Build the SankeyDefinition for the current selections and display it.

        Also stores ``self.flows``, ``self.palette`` and ``self.size`` for later
        redraws.

        Returns
        -------
        SankeyDefinition
        """
        tmp_flows, nodes, bundles, orderings, df_piv = get_node_order_bundle(
            self.df,
            self.index_widget.value,
            self.column_widget.value,
            self.value_widget.value
        )
        target_date = self.column_level_widget.value
        flow_partition = nodes[target_date].partition
        # Drop NaN before sorting: entities without a value at this date give NaN
        # cells, which np.sort cannot order against string categories. Those
        # entities are not in the flows anyway (groupby drops NaN keys).
        y_list = np.sort(df_piv[target_date].dropna().unique()).tolist()
        palette = get_palette(y_list, self.value_level_widget.value)
        sdd = SankeyDefinition(nodes, bundles, orderings,
                               flow_partition=flow_partition)
        self.flows = tmp_flows
        self.palette = palette
        self.size = dict(self.DEFAULT_SIZE)
        single_display = self.multiple_display_widget.value == "No"
        if single_display and self.display_flag:
            self._display_controls()
        self._print_selection()
        self._draw(sdd)
        if single_display:
            clear_output(wait=True)
        self.display_flag = True
        return sdd


## main function

In [5]:
def main():
    """Create the control dropdowns, connect them to a ``TsSankey`` and show them.

    ``interact`` is given a no-op callback purely so the widgets are laid out
    together. All real work happens in the ``TsSankey`` observers registered here.
    """
    label_style = {'description_width': 'initial'}

    def dropdown(description, options=None):
        # Every control starts with nothing selected.
        return Dropdown(options=[] if options is None else options, value=None,
                        description=description, style=label_style)

    multiple_display_widget = dropdown("multiple display?", ["Yes", "No"])
    path_widget = dropdown("file path", extract_files())
    i_widget = dropdown("index column")
    x_widget = dropdown("date column")
    y_widget = dropdown("target variable")
    x_level_widget = dropdown("target date")
    y_level_widget = dropdown("target value")

    ts = TsSankey(multiple_display_widget, path_widget, i_widget, x_widget,
                  y_widget, x_level_widget, y_level_widget)

    # Route each widget's value changes to the matching TsSankey handler.
    wiring = (
        (multiple_display_widget, ts.on_clear_output),
        (path_widget, ts.on_path_update),
        (x_widget, ts.on_column_update),
        (y_widget, ts.on_value_update),
        (x_level_widget, ts.on_column_to_value_update),
        (y_level_widget, ts.on_value_level_update),
    )
    for widget, handler in wiring:
        widget.observe(handler, names='value')

    def f(multiple_display_widget, data_path, i_widget, x_widget, y_widget, x_level_widget, y_level_widget):
        # Intentionally empty: interact() is only used to display the widgets.
        pass

    interact(f,
             multiple_display_widget=multiple_display_widget,
             data_path=path_widget,
             i_widget=i_widget,
             x_widget=x_widget,
             y_widget=y_widget,
             x_level_widget=x_level_widget,
             y_level_widget=y_level_widget)


# run

Run `main()` to launch the interactive floweaver Sankey view over the files in `data/` (for example `data/templace.csv`).

In [6]:
# main()