<p style="text-align:center"><font size="12">Confusion Matrix as Sankey Diagram</font></p>

<p style="background-color:red; text-align:center"><font size="10">Work in Progress!!!</font></p>

In machine learning, a confusion matrix is a table used to evaluate how well the predictions of a classification model (typically a supervised learning model) perform. It helps us understand the model's behaviour and interpret the results.

Usually, we use heatmaps to visualize a confusion matrix, but there is another, more elegant and interactive way. In this notebook we will describe, step by step, how to plot a confusion matrix as a Sankey diagram.

These are the main features of the Sankey confusion matrix:
- the size of each source node corresponds to the number of samples that actually belong to a class, and the size of each target node to the number of samples predicted as that class
- the width of the links between nodes shows how the samples were classified (correctly or incorrectly)
- hovering over the nodes and links displays a numerical and textual representation of our confusion matrix



# Dependencies

In [1]:
import numpy as np
import pandas as pd
pd.set_option('display.max_colwidth', None)

import os

# Classification metrics
from sklearn.metrics import confusion_matrix 

from plotly import graph_objects as go
# Set the appropriate renderer in JupyterLab so that Plotly displays figures correctly.
# As suggested in the Plotly documentation, you might set the default renderer explicitly,
#  e.g. as 'iframe' or 'colab', by adding the following lines to your code
import plotly.io as pio
pio.renderers.default = 'iframe' # or 'colab', 'iframe_connected', 'sphinx_gallery'


## Helper Functions

To help us with visualizations, we will import the script `metrics_utilities`. It is a collection of several helper functions.

The function developed in this notebook will be added to this script.


In [2]:
# Import the script from a different folder
import sys
sys.path.append('./scripts')

import metrics_utilities as mu

# Preparing Data

The data used in this notebook comes from one of my previous projects - [Bank-Churn-Prediction](https://github.com/zunicd/Bank-Churn-Prediction). <br>
I saved the true (actual) labels and the predictions in `.npy` format; in the next two cells we will load them.

## Load Test True (Actual) Labels

In [3]:
# True labels
y_test = np.load('./data/y_test.npy')

## Load Predictions

In [4]:
# Load the predictions of our models
pred_dt = np.load('./data/pred_dt.npy')
pred_dl = np.load('./data/pred_dl.npy')
pred_knn = np.load('./data/pred_knn.npy')
pred_lr = np.load('./data/pred_lr.npy')
pred_rf = np.load('./data/pred_rf.npy')
pred_svm = np.load('./data/pred_svm.npy')
pred_xgb = np.load('./data/pred_xgb.npy')


## Names of Classes

The `target_names` variable holds names of our classes. It will be used later for displaying evaluation results.

In [5]:
# Names of our classes
target_names = ['Stays', 'Exits']

# Confusion Matrix

## Axes Convention

Instead of the Wikipedia convention for the axes, we will use the scikit-learn representation:  
- actual (true) labels along the rows and predicted labels along the columns
- the default label order, which for binary 0/1 labels is equivalent to `labels=[0,1]`, so TN (True Negative) is in the top-left corner

## Why Normalized Confusion Matrix?

Most real-life data is imbalanced, so using a confusion matrix without normalization might lead to improper conclusions.  
In our Sankey diagram we will include values from both the unnormalized and the normalized matrices.

The simplest way to display a confusion matrix is to print an array.

**Unnormalized Confusion Matrix**

In [6]:
# Confusion matrix
cm = confusion_matrix(y_test, pred_dt)
print(cm)

[[1979  410]
 [ 198  413]]


**Normalized Confusion Matrix**

A few points to know:
1. to calculate the normalized version, divide each element of a row by the sum of that row
2. each row sum is the total number of true (actual) samples for that class label
3. the normalized matrix shows, for each true label, the fraction of samples the model predicted as each class (so each row sums to 1)

In [7]:
# Normalized confusion matrix
cmn = np.around(cm / cm.sum(axis=1)[:, np.newaxis], 2)
print(cmn)

[[0.83 0.17]
 [0.32 0.68]]
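A quick NumPy check of the three points above, using the counts printed earlier (typed in by hand, so it runs without the `.npy` files). It also shows why normalization matters here: the raw off-diagonal counts suggest that `Stays` is misclassified more often (410 vs 198), yet relative to class size the `Exits` error rate is higher (32% vs 17%). scikit-learn's `confusion_matrix` can also return this matrix directly with `normalize='true'` (available since scikit-learn 0.22).

```python
import numpy as np

# Decision-tree counts printed above: rows = actual class, columns = predicted class
cm = np.array([[1979, 410],
               [198, 413]])

# Point 2: each row sum is the number of actual samples in that class
row_totals = cm.sum(axis=1, keepdims=True)  # shape (2, 1), so it broadcasts over the columns

# Point 1: divide each element by the total of its row
cmn = cm / row_totals

# Point 3: each row now sums to 1; the diagonal holds the per-class recall
per_class_recall = np.diag(cmn)

print(row_totals.ravel())             # actual Stays, actual Exits
print(np.round(cmn, 2))
print(np.round(per_class_recall, 2))
```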


# Sankey Diagram

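A Sankey diagram is a flow diagram: nodes represent categories, and each link between two nodes is drawn with a width proportional to the quantity flowing from one to the other.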

## Create Dataframe for Sankey

To create a Sankey diagram, we first have to organize our data. We will use a Pandas DataFrame to prepare and store it.  
We will split the process into several steps:
- base dataframe - built from all the data in our confusion matrix
- node labels - create a list of node labels and a dictionary of node label indices
- final dataframe - add columns for the normalized matrix, the link color and the text to display when hovering over the links
- mapping node labels to integers
- preparing the text that we want to print in bold font

## Base Dataframe and Node Labels

#### Confusion Matrix &rarr; DataFrame

Create the DataFrame from the confusion matrix, using the previously defined `target_names` as row and column names.

In [8]:
# Create dataframe
df = pd.DataFrame(cm, columns=target_names, index=target_names)
df

       Stays  Exits
Stays   1979    410
Exits    198    413


#### The Goal

We need to transform this base dataframe to the following dataframe:

```
                 actual	        predicted	  samples
        0	ACTUAL Stays	PREDICTED Stays	    1979
        1	ACTUAL Stays	PREDICTED Exits	     410
        2	ACTUAL Exits	PREDICTED Stays	     198
        3	ACTUAL Exits	PREDICTED Exits	     413
```

For our Sankey diagram:
- column `actual` represents the Sankey *source nodes*
- column `predicted` represents the Sankey *target nodes*
- column `samples` represents the values of the Sankey *links*

Later we will add more columns to improve interpretability of our Sankey confusion matrix.
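As a small illustration of this mapping, the sketch below (plain Python, with the four flows typed in from the goal table above) treats each row as a `(source, target, value)` link and totals the flows per node. The total leaving each `ACTUAL` node equals the corresponding row sum of the confusion matrix, and the total entering each `PREDICTED` node equals the column sum, which is why node size reflects class size.

```python
from collections import defaultdict

# (source node, target node, number of samples) for each row of the goal dataframe
links = [
    ("ACTUAL Stays", "PREDICTED Stays", 1979),
    ("ACTUAL Stays", "PREDICTED Exits", 410),
    ("ACTUAL Exits", "PREDICTED Stays", 198),
    ("ACTUAL Exits", "PREDICTED Exits", 413),
]

outflow = defaultdict(int)  # samples leaving each source node
inflow = defaultdict(int)   # samples entering each target node
for source, target, value in links:
    outflow[source] += value
    inflow[target] += value

print(dict(outflow))  # row sums: actual class sizes
print(dict(inflow))   # column sums: predicted class sizes
```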

#### Name the Axes

Let's name the row axis **ACTUAL** and the column axis **PREDICTED**.

In [9]:
# Axes naming
df = df.rename_axis(index='ACTUAL', columns='PREDICTED')
df

PREDICTED  Stays  Exits
ACTUAL
Stays       1979    410
Exits        198    413


#### Update Columns

Let's prepend the axis names to the row and column labels.

In [10]:
[f'ACTUAL {s}' for s in target_names]

['ACTUAL Stays', 'ACTUAL Exits']

In [11]:
# Set new labels for rows and columns
df = df.set_axis([f'ACTUAL {s}' for s in target_names], axis=0)
df = df.set_axis([f'PREDICTED {s}' for s in target_names], axis=1)
print(df)

              PREDICTED Stays  PREDICTED Exits
ACTUAL Stays             1979              410
ACTUAL Exits              198              413


We can do the same in one line.  
We will create a dataframe from the confusion matrix using the new labels for rows and columns.

In [12]:
df = pd.DataFrame(cm, columns=[f'PREDICTED {s}' for s in target_names], index=[f'ACTUAL {s}' for s in target_names])
print(df)

              PREDICTED Stays  PREDICTED Exits
ACTUAL Stays             1979              410
ACTUAL Exits              198              413


#### Node Labels

Plotly's Sankey expects the link `source` and `target` values to be integer node indices.  
We will do this transformation near the end; for now, let's prepare the data for it.  

First we will create a list of node labels and then a dictionary of their indices.

##### List of Node Labels

In [13]:
# column labels --> Sankey target nodes
cl = df.columns.values.tolist()
cl

['PREDICTED Stays', 'PREDICTED Exits']

In [14]:
# row labels --> Sankey source nodes
rl = df.index.values.tolist()
rl

['ACTUAL Stays', 'ACTUAL Exits']

In [15]:
node_labels = rl + cl
node_labels

['ACTUAL Stays', 'ACTUAL Exits', 'PREDICTED Stays', 'PREDICTED Exits']

##### Indices for Node Labels

In [16]:
# Create dictionary with node labels indices
node_labels_inds = {label:ind for ind, label in enumerate(node_labels)}
node_labels_inds

{'ACTUAL Stays': 0,
 'ACTUAL Exits': 1,
 'PREDICTED Stays': 2,
 'PREDICTED Exits': 3}
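For an n x n matrix these indices follow a fixed pattern: actual class i is node i and predicted class j is node n + j. The sketch below uses a hypothetical helper, `sankey_link_indices` (not part of `metrics_utilities`), to generate the `source` and `target` arrays directly with NumPy, in the same row-major order that `stack()` and `ravel()` produce.

```python
import numpy as np

def sankey_link_indices(n_classes):
    """Return (source, target) node indices for every cell of an n x n confusion matrix.

    Nodes 0..n-1 are the ACTUAL classes and nodes n..2n-1 are the PREDICTED classes,
    matching node_labels = row labels + column labels.
    """
    classes = np.arange(n_classes)
    source = np.repeat(classes, n_classes)            # 0,0,...,1,1,... one block per actual class
    target = np.tile(classes, n_classes) + n_classes  # predicted classes, offset past the source nodes
    return source, target

source, target = sankey_link_indices(2)
print(source, target)  # same pairs as the integer 'actual' / 'predicted' columns after mapping
```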

#### Reshape DataFrame

For a Sankey diagram we need to plot flows from source nodes to target nodes. The flows are the numbers of samples that were classified correctly or incorrectly.
- Our **source nodes**: 'ACTUAL Stays', 'ACTUAL Exits'
- Our **target nodes**: 'PREDICTED Stays', 'PREDICTED Exits'



The reshaped dataframe will have 2x2 = 4 rows, one for each combination. Each row is one flow:
    
ACTUAL Stays &rarr; PREDICTED Stays  =  # of Stays correctly classified  
ACTUAL Stays &rarr; PREDICTED Exits  =  # of Stays incorrectly classified  
ACTUAL Exits &rarr; PREDICTED Stays  =  # of Exits incorrectly classified  
ACTUAL Exits &rarr; PREDICTED Exits  =  # of Exits correctly classified


To accomplish this we will use the following Pandas functions:  

- `stack()` - stacks the columns into rows; it returns a Series with a two-level MultiIndex
- `reset_index()` - resets the MultiIndex to the default integer index, turning the original index levels into columns
- `rename()` - renames the new columns (`level_0`, `level_1` and the value column `0`)


In [17]:
# Reshape dataframe
df = df.stack().reset_index()
df.rename(columns={0:'samples', 'level_0':'actual', 'level_1':'predicted'}, inplace=True)
df

         actual        predicted  samples
0  ACTUAL Stays  PREDICTED Stays     1979
1  ACTUAL Stays  PREDICTED Exits      410
2  ACTUAL Exits  PREDICTED Stays      198
3  ACTUAL Exits  PREDICTED Exits      413
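The `level_0` / `level_1` names used in `rename()` appear because the row and column indexes are unnamed at this point. A slightly more explicit alternative, sketched below on a hand-typed copy of the 2x2 dataframe, names the axes before stacking so that `stack()` carries those names into the MultiIndex, and `reset_index(name='samples')` names the value column, so no renaming step is needed.

```python
import pandas as pd

# Hand-typed copy of the 2x2 dataframe built above
df = pd.DataFrame(
    [[1979, 410], [198, 413]],
    index=["ACTUAL Stays", "ACTUAL Exits"],
    columns=["PREDICTED Stays", "PREDICTED Exits"],
)

long_df = (
    df.rename_axis(index="actual", columns="predicted")  # axis names become the MultiIndex level names
      .stack()                                           # Series indexed by (actual, predicted)
      .reset_index(name="samples")                       # levels -> columns, values -> 'samples'
)
print(long_df)
```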


## Final DataFrame

### Normalized Confusion Matrix Column

In [18]:
# Normalized confusion matrix
cmn = np.around(cm / cm.sum(axis=1)[:, np.newaxis], 2)
print(cmn)

[[0.83 0.17]
 [0.32 0.68]]


In [19]:
# Flatten the normalized confusion matrix and add it as a new column
df['norm_samples'] = cmn.ravel()
df

         actual        predicted  samples  norm_samples
0  ACTUAL Stays  PREDICTED Stays     1979          0.83
1  ACTUAL Stays  PREDICTED Exits      410          0.17
2  ACTUAL Exits  PREDICTED Stays      198          0.32
3  ACTUAL Exits  PREDICTED Exits      413          0.68
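Assigning `cmn.ravel()` works because `ravel()` flattens the matrix row by row (C order), the same order in which `stack()` produced the rows. If you would rather not depend on that ordering, one order-independent alternative is to compute the normalized values from the long dataframe itself, dividing each count by the total of its `actual` group. A sketch on hand-typed data, so it runs on its own:

```python
import pandas as pd

# Hand-typed copy of the reshaped dataframe
df = pd.DataFrame({
    "actual": ["ACTUAL Stays", "ACTUAL Stays", "ACTUAL Exits", "ACTUAL Exits"],
    "predicted": ["PREDICTED Stays", "PREDICTED Exits", "PREDICTED Stays", "PREDICTED Exits"],
    "samples": [1979, 410, 198, 413],
})

# Total number of samples of each actual class, broadcast back to every row of that class
class_totals = df.groupby("actual")["samples"].transform("sum")

# Same row-wise normalization as cm / cm.sum(axis=1), but independent of the row order
df["norm_samples"] = (df["samples"] / class_totals).round(2)
print(df)
```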


### Add New Columns `color` and `link_hover_text`

The link color is determined by the classification result (correct or incorrect).

In [20]:
incorrect_red = "rgba(205, 92, 92, 0.8)"
correct_green = "rgba(144, 238, 144, 0.8)"

Create a helper function that adds two columns: `color`, and `link_hover_text`, which holds the text displayed when hovering over the Sankey links.

In [21]:
# 'color' - link color based on classification result (correct or incorrect)
# 'link_hover_text' - text for hovering over connecting links of sankey diagram
def new_columns(row):
    source_1 = ''.join(row.actual.split()[1:])
    target_1 = ''.join(row.predicted.split()[1:])
    # Correct classification
    if source_1 == target_1:
        row['color'] = correct_green
        row['link_hover_text'] = f"{row.samples} ({row.norm_samples:.0%}) {source_1} samples correctly classified as {target_1}"
    # Incorrect classification
    else:
        row['color'] = incorrect_red
        row['link_hover_text'] = f"{row.samples} ({row.norm_samples:.0%}) {source_1} samples incorrectly classified as {target_1}"
    return row

Finalize the DataFrame.

In [22]:
# Apply the helper function to each row
df = df.apply(new_columns, axis=1)
df

         actual        predicted  samples  norm_samples                     color                                          link_hover_text
0  ACTUAL Stays  PREDICTED Stays     1979          0.83  rgba(144, 238, 144, 0.8)   1979 (83%) Stays samples correctly classified as Stays
1  ACTUAL Stays  PREDICTED Exits      410          0.17    rgba(205, 92, 92, 0.8)  410 (17%) Stays samples incorrectly classified as Exits
2  ACTUAL Exits  PREDICTED Stays      198          0.32    rgba(205, 92, 92, 0.8)  198 (32%) Exits samples incorrectly classified as Stays
3  ACTUAL Exits  PREDICTED Exits      413          0.68  rgba(144, 238, 144, 0.8)    413 (68%) Exits samples correctly classified as Exits
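The row-wise `apply` is easy to read, but the same two columns can also be built in a vectorized way. The sketch below is an alternative, not the notebook's method. Note that `new_columns` extracts the class name with `''.join(...split()[1:])`, which would glue a multi-word class name together (e.g. `no diabetes` becomes `nodiabetes`); the sketch instead splits only at the first space and uses `numpy.where` to pick the color and the verb. It runs on a hand-typed copy of the dataframe.

```python
import numpy as np
import pandas as pd

incorrect_red = "rgba(205, 92, 92, 0.8)"
correct_green = "rgba(144, 238, 144, 0.8)"

# Hand-typed copy of the dataframe with normalized values
df = pd.DataFrame({
    "actual": ["ACTUAL Stays", "ACTUAL Stays", "ACTUAL Exits", "ACTUAL Exits"],
    "predicted": ["PREDICTED Stays", "PREDICTED Exits", "PREDICTED Stays", "PREDICTED Exits"],
    "samples": [1979, 410, 198, 413],
    "norm_samples": [0.83, 0.17, 0.32, 0.68],
})

# Class names: drop the ACTUAL / PREDICTED prefix (split at the first space only)
source_cls = df["actual"].str.split(n=1).str[1]
target_cls = df["predicted"].str.split(n=1).str[1]
correct = source_cls == target_cls

df["color"] = np.where(correct, correct_green, incorrect_red)
verb = np.where(correct, "correctly", "incorrectly")
df["link_hover_text"] = [
    f"{n} ({p:.0%}) {s} samples {v} classified as {t}"
    for n, p, s, v, t in zip(df["samples"], df["norm_samples"], source_cls, verb, target_cls)
]
print(df[["color", "link_hover_text"]])
```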


### Map Node Labels to Integers

Map the node label columns (`actual`, `predicted`) to integers, because Plotly's Sankey expects integer node indices for link sources and targets.

In [23]:
node_labels_inds

{'ACTUAL Stays': 0,
 'ACTUAL Exits': 1,
 'PREDICTED Stays': 2,
 'PREDICTED Exits': 3}

In [24]:
# using replace for multiple columns
df = df.replace({'actual':node_labels_inds, 'predicted':node_labels_inds})
df

   actual  predicted  samples  norm_samples                     color                                          link_hover_text
0       0          2     1979          0.83  rgba(144, 238, 144, 0.8)   1979 (83%) Stays samples correctly classified as Stays
1       0          3      410          0.17    rgba(205, 92, 92, 0.8)  410 (17%) Stays samples incorrectly classified as Exits
2       1          2      198          0.32    rgba(205, 92, 92, 0.8)  198 (32%) Exits samples incorrectly classified as Stays
3       1          3      413          0.68  rgba(144, 238, 144, 0.8)    413 (68%) Exits samples correctly classified as Exits


In [25]:
# Alternative: assign + apply + lambda (run on the DataFrame before the replace above)
# df.assign(actual    = df.actual.apply(lambda x: node_labels_inds[x]),
#           predicted = df.predicted.apply(lambda x: node_labels_inds[x]))

In [26]:
# Alternative: assign + map (run on the DataFrame before the replace above)
# df.assign(actual    = df.actual.map(node_labels_inds),
#           predicted = df.predicted.map(node_labels_inds))

### Bold Printing in Plotly

Prepare data for bold printing of some words in Plotly.

#### Node Labels

We want to print the class names (the second word of each label) in bold font.  
We will use the HTML `<b>` tag for that.



In [27]:
node_labels

['ACTUAL Stays', 'ACTUAL Exits', 'PREDICTED Stays', 'PREDICTED Exits']

In [28]:
node_labels = [f'{ls[0]} <b>{ls[1]}</b>' for ls in [l.split() for l in node_labels]]
print(node_labels)

['ACTUAL <b>Stays</b>', 'ACTUAL <b>Exits</b>', 'PREDICTED <b>Stays</b>', 'PREDICTED <b>Exits</b>']
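Plotly renders a small subset of HTML tags, including `<b>`, in labels and hover text. The list comprehension above keeps only `ls[1]`, the second word of each label, so with a multi-word class name such as `no diabetes` it would bold `no` and silently drop `diabetes`. A sketch of a variant that splits only at the first space, using `bold_class_name`, an illustrative helper that is not part of the notebook or `metrics_utilities`:

```python
def bold_class_name(label):
    """Wrap everything after the first word (the ACTUAL / PREDICTED prefix) in <b> tags."""
    prefix, class_name = label.split(maxsplit=1)
    return f"{prefix} <b>{class_name}</b>"

node_labels = ["ACTUAL Stays", "ACTUAL Exits", "PREDICTED Stays", "PREDICTED Exits"]
print([bold_class_name(label) for label in node_labels])

# A multi-word class name stays in one piece
print(bold_class_name("PREDICTED no diabetes"))
```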


#### Hovering Text

Print the class names in the hover text in bold font.

In [29]:
df['link_hover_text'] = [f'{" ".join(ls[0:2])} <b>{ls[2]}</b> {" ".join(ls[3:-1])} <b>{ls[-1]}</b>' for ls in [l.split() for l in df['link_hover_text']]]
df

   actual  predicted  samples  norm_samples                     color                                                        link_hover_text
0       0          2     1979          0.83  rgba(144, 238, 144, 0.8)   1979 (83%) <b>Stays</b> samples correctly classified as <b>Stays</b>
1       0          3      410          0.17    rgba(205, 92, 92, 0.8)  410 (17%) <b>Stays</b> samples incorrectly classified as <b>Exits</b>
2       1          2      198          0.32    rgba(205, 92, 92, 0.8)  198 (32%) <b>Exits</b> samples incorrectly classified as <b>Stays</b>
3       1          3      413          0.68  rgba(144, 238, 144, 0.8)    413 (68%) <b>Exits</b> samples correctly classified as <b>Exits</b>


## Plotting

In [30]:
fig = go.Figure(data=[go.Sankey(
    
node = dict(
    pad = 30,
    thickness = 20,
    line = dict(color = "gray", width = 1.0),
    label = node_labels,
    hovertemplate = "%{label} has total %{value:d} samples<extra></extra>"
    ),
link = dict(
    source = df.actual, 
    target = df.predicted,
    value = df.samples,
    color = df.color,
    customdata = df['link_hover_text'], 
    hovertemplate = "%{customdata}<extra></extra>"  
))])

title = 'Decision Tree'


fig.update_layout(
    hovermode = 'x',
    title = {
    'text': title,
    'x':0.5,
    },
    # paper_bgcolor = '#51504f',
    font_size = 15,
    # font_color = 'white',
    width = 600,
    height = 500
)



## Define Function

In [31]:
def plot_cm_sankey(model_name, y_test, y_pred, target_names=None):
    """ Plot confusion matrix with Sankey diagram 

    Args:
        model_name: name of the model
        y_test: test target variable
        y_pred: prediction
        target_names: list of class names

    Returns:
        plotly.graph_objects.Figure with the Sankey diagram of the confusion matrix
    """ 
    
    # Calculate confusion matrix
    cm = confusion_matrix(y_test, y_pred)
    
    # If class labels are not passed, create dummy class labels
    # ('is None' also works when target_names is a NumPy array)
    if target_names is None:
        target_names = []
    if not len(target_names):
        target_names = [f'class-{i+1}' for i in range(len(cm))]
    
    # Prepare dataframe with parameters for Sankey
    def prepare_df_for_sankey(cm, target_names):
        # create a dataframe
        df = pd.DataFrame(cm, columns=[f'PREDICTED {s}' for s in target_names], index=[f'ACTUAL {s}' for s in target_names])
        
        # Create list of node labels
        # target nodes = column labels (PREDICTED ...)
        cl = df.columns.values.tolist()
        # source nodes = row (index) labels (ACTUAL ...)
        rl = df.index.values.tolist()
        node_labels = rl + cl
        
        # Create dictionary with indices for node labels
        node_labels_inds = {label:ind for ind, label in enumerate(node_labels)}
        
        # Stack label from column to row, output is Series
        # Reset index to get DataFrame and rename columns
        df = df.stack().reset_index()
        df.rename(columns={0:'samples', 'level_0':'actual', 'level_1':'predicted'}, inplace=True)
        
        """
               actual	       predicted	  samples
        0	ACTUAL Stays	PREDICTED Stays	    1979
        1	ACTUAL Stays	PREDICTED Exits	     410
        2	ACTUAL Exits	PREDICTED Stays	     198
        3	ACTUAL Exits	PREDICTED Exits	     413
        """

        # Normalized confusion matrix
        cmn = np.around(cm / cm.sum(axis=1)[:, np.newaxis], 2)
        # Add a column with normalized values of samples
        df['norm_samples'] = cmn.ravel()
        
        # Helper function to add new columns: color and link_hover_text 
        # 'color' - link color based on classification result (correct or incorrect)        
        incorrect_red = "rgba(205, 92, 92, 0.8)"
        correct_green = "rgba(144, 238, 144, 0.8)"
        # 'link_hover_text' - text for hovering over the connecting links of the Sankey diagram
        
        def new_columns(row):
            source_1 = ''.join(row.actual.split()[1:])
            target_1 = ''.join(row.predicted.split()[1:])
            # Correct classification
            if source_1 == target_1:
                row['color'] = correct_green
                row['link_hover_text'] = f"{row.samples} ({row.norm_samples:.0%}) {source_1} samples correctly classified as {target_1}"
            # Incorrect classification
            else:
                row['color'] = incorrect_red
                row['link_hover_text'] = f"{row.samples} ({row.norm_samples:.0%}) {source_1} samples incorrectly classified as {target_1}"
            return row

        # Apply "new_columns" function
        df = df.apply(lambda x: new_columns(x), axis=1)
        
        # Sankey only takes integers for node and target values,
        #  so we need to map node label columns (actual, predicted) to numbers
        # Using replace for multiple columns
        df = df.replace({'actual':node_labels_inds, 'predicted':node_labels_inds})
               
        return df, node_labels
    
    
    # Plotting confusion matrix as Sankey diagram
    # Get dataframe and node labels
    df, node_labels = prepare_df_for_sankey(cm, target_names)
    
    # Prepare for bold printing of some words in Plotly
    node_labels = [f'{ls[0]} <b>{ls[1]}</b>' for ls in [l.split() for l in node_labels]]
    df['link_hover_text'] = [f'{" ".join(ls[0:2])} <b>{ls[2]}</b> {" ".join(ls[3:-1])} <b>{ls[-1]}</b>' for ls in [l.split() for l in df['link_hover_text']]]
    

    fig = go.Figure(data=[go.Sankey(    
        node = dict(
        pad = 50,
        thickness = 30,
        line = dict(color = "gray", width = 1.0),
        label = node_labels,
        hovertemplate = "%{label} has total %{value:d} samples<extra></extra>"
        ),
    link = dict(
        source = df.actual, 
        target = df.predicted,
        value = df.samples,
        color = df.color,
        customdata = df['link_hover_text'], 
        hovertemplate = "%{customdata}<extra></extra>"  
    ))])
    
    margins = {'l': 25, 'r': 25, 't': 70, 'b': 25}
    
    fig.update_layout(
        title = {
        'text': f'<b>{model_name}</b>',
        'x':0.5,
        },
        font_size = 15,
        width = 625,
        height = 500,
        #paper_bgcolor = '#d3d3d3',
        # paper_bgcolor = 'white',
        # plot_bgcolor = 'black',
        margin = margins,
    )
    
    return fig

### Run function

In [32]:
plot_cm_sankey('Decision Tree', y_test, pred_dt, target_names)

##### Copy the function to the script, restart the kernel and run it

In [33]:
mu.plot_cm_sankey('Decision Tree', y_test, pred_dt, target_names)

### 3x3 Confusion Matrix

Let's see how this works for a 3x3 confusion matrix.

We will use data from my project [T2D-Predictions](https://github.com/zunicd/T2D-Predictions).

In [34]:
# Actual (True) labels
t2d_y_test = np.load('./data/t2d_y_test.npy')
# Predictions from the random forest model
t2d_pred_rf = np.load('./data/t2d_pred_rf.npy')
# Classes
t2d_classes = ['no_diabetes', 'pre_diabetes', 'diabetes']

In [35]:
mu.plot_cm_sankey('Random Forest', t2d_y_test, t2d_pred_rf, t2d_classes)