# Sankey diagrams for customer flow

In [1]:
import pandas as pd
import plotly.graph_objects as go

### Import (un-normalized) transition counts and transform data

In [2]:
# Load the transition table that has not been normalized (raw values,
# not probabilities) and give all six columns explicit names in one step.
df = pd.read_csv("data/transition_probabilities_count.csv").set_axis(
    ['source', 'checkout', 'dairy', 'drinks', 'fruit', 'spices'], axis=1
)
df

Unnamed: 0,source,checkout,dairy,drinks,fruit,spices
0,dairy,1839,13095,1041,886,913
1,drinks,2098,106,5821,855,846
2,entrance,0,2141,1143,2810,1351
3,fruit,2562,1219,697,7586,644
4,spices,946,1213,1024,571,2524


In [3]:
# Reshape wide -> long: one row per (source, target) pair. The former
# column headers end up in 'target' and the cell contents in 'value'.
df1 = df.melt(id_vars=["source"], var_name="target", value_name="value")
df1.head()

Unnamed: 0,source,target,value
0,dairy,checkout,1839
1,drinks,checkout,2098
2,entrance,checkout,0
3,fruit,checkout,2562
4,spices,checkout,946


In [4]:
# Integer node index for every location label: a label's position in the
# tuple below is the number it is replaced with in the link table.
mapping = {
    label: index
    for index, label in enumerate(
        ("entrance", "dairy", "drinks", "fruit", "spices", "checkout")
    )
}

In [5]:
# Use mapping to change strings to numbers.
# Series.map turns every value that is not a key of `mapping` into NaN, so
# running this cell a second time (when the columns already hold integers)
# would silently wipe both columns. Columns that are already converted are
# therefore skipped, and labels without a node index raise instead of
# becoming NaN.
for column in ("source", "target"):
    if df1[column].isin(list(mapping.values())).all():
        continue  # already node indices -> re-running the cell is a no-op
    unknown_labels = set(df1[column]) - set(mapping)
    if unknown_labels:
        raise ValueError(
            "No node index defined for {} labels: {}".format(
                column, sorted(map(str, unknown_labels))
            )
        )
    df1[column] = df1[column].map(mapping)

In [6]:
# Preview the first rows; source and target should now hold node indices.
df1.head(n=5)

Unnamed: 0,source,target,value
0,1,5,1839
1,2,5,2098
2,0,5,0
3,3,5,2562
4,4,5,946


### Plot sankey diagram

In [7]:
# Draw every transition as a Sankey link. The link endpoints are integer
# positions into the node label list, so the labels are ordered by their
# index in `mapping` instead of by the dict's insertion order. Labels and
# links stay aligned even if the entries of `mapping` are rearranged
# (this still expects the indices to be the contiguous range 0..n-1).
fig = go.Figure(
    go.Sankey(
        node=dict(
            label=sorted(mapping, key=mapping.get),
        ),
        link=dict(
            source=df1["source"],
            target=df1["target"],
            value=df1["value"],
        ),
    )
)

fig.update_layout(title_text="Customer flow", width=1200, height=800)
fig.show()

### Exclude self-loops

In [8]:
# Keep only moves between two different nodes: drop every row whose
# source equals its target. The copy makes df2 independent of df1.
df2 = df1.loc[df1["source"].ne(df1["target"])].copy()
df2.head()

Unnamed: 0,source,target,value
0,1,5,1839
1,2,5,2098
2,0,5,0
3,3,5,2562
4,4,5,946


### Plot the flow again, without self-loops

In [9]:
# Same diagram, restricted to the links in df2 (self-loops removed). The
# link endpoints are integer positions into the node label list, so the
# labels are ordered by their index in `mapping` instead of by the dict's
# insertion order. Labels and links stay aligned even if the entries of
# `mapping` are rearranged (this still expects the indices to be the
# contiguous range 0..n-1).
fig = go.Figure(
    go.Sankey(
        node=dict(
            label=sorted(mapping, key=mapping.get),
        ),
        link=dict(
            source=df2["source"],
            target=df2["target"],
            value=df2["value"],
        ),
    )
)

fig.update_layout(title_text="Customer flow without self-loops", width=1200, height=800)
fig.show()