In [None]:
#| default_exp visualization.plot

In [None]:
#| export
#| hide

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.figure_factory as ff
import seaborn as sns

from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay

In [None]:
#| export

def plotly_confusion_matrix(labels, y, _y):
    """Build, display and return an annotated Plotly confusion-matrix heatmap.

    The matrix is laid out with true classes on the x axis (columns) and
    predicted classes on the y axis (rows), i.e. ``counts[pred][true]``, which
    is what the "True Class" / "Predicted Class" axis titles added below say.
    Zero cells get no annotation and, off the diagonal, no hover text.

    Args:
        labels: Sequence of class names; position ``k`` names class index ``k``.
            Names are passed to ``str()`` when building hover text, so
            non-string names (e.g. ints) are accepted.
        y: Iterable of true class indices, each expected in ``range(len(labels))``.
        _y: Iterable of predicted class indices, aligned element-wise with ``y``.
            Pairs are formed with ``zip``, so extra trailing items in the longer
            iterable are ignored.

    Returns:
        The figure returned by ``ff.create_annotated_heatmap``. It is also
        shown immediately via ``fig.show()``.
    """
    n_classes = len(labels)
    names = [str(name) for name in labels]

    counts = [[0] * n_classes for _ in range(n_classes)]  # counts[pred][true]
    support = [0] * n_classes  # number of true samples seen per class
    for true_idx, pred_idx in zip(y, _y):
        counts[pred_idx][true_idx] += 1
        support[true_idx] += 1

    # In-cell annotations: the raw count, with empty cells left blank.
    cell_text = [[str(count) if count != 0 else "" for count in row] for row in counts]

    hover = [[""] * n_classes for _ in range(n_classes)]
    for true_idx in range(n_classes):
        total = support[true_idx]
        for pred_idx in range(n_classes):
            count = counts[pred_idx][true_idx]
            if pred_idx == true_idx:
                # The per-class "accuracy" shown here is count / support (recall).
                # Guard: a class with no true samples would divide by zero.
                ratio = str(count / total) if total else "n/a"
                hover[pred_idx][true_idx] = (
                    "Correctly predicted "
                    + str(count)
                    + " out of "
                    + str(total)
                    + " "
                    + names[true_idx]
                    + " with accuracy "
                    + ratio
                )
            elif count != 0:
                hover[pred_idx][true_idx] = (
                    "Incorrectly predicted "
                    + str(count)
                    + " out of "
                    + str(total)
                    + " "
                    + names[true_idx]
                    + " as "
                    + names[pred_idx]
                )

    # Same class names on both axes; list() accepts lists, tuples, arrays, etc.
    axis_labels = list(labels)
    fig = ff.create_annotated_heatmap(
        counts,
        x=axis_labels,
        y=list(axis_labels),
        text=hover,
        annotation_text=cell_text,
        hoverinfo="text",
        colorscale="Blues",
    )

    # Extra left/top margin leaves room for the rotated y-axis title below.
    fig.update_layout(width=850, height=550)
    fig.update_layout(margin=dict(t=100, l=200))

    # Axis titles are placed as paper-relative annotations.
    fig.add_annotation(
        dict(
            font=dict(color="#094973", size=16),
            x=0.5,
            y=-0.10,
            showarrow=False,
            text="True Class",
            xref="paper",
            yref="paper",
        )
    )
    fig.add_annotation(
        dict(
            font=dict(color="#094973", size=16),
            x=-0.17,
            y=0.5,
            showarrow=False,
            text="Predicted Class",
            textangle=-90,
            xref="paper",
            yref="paper",
        )
    )

    fig.show()
    return fig


In [None]:
#| export

def get_classification_report(true_categories, predicted_categories, labels=None):
    """Print and return scikit-learn's text classification report.

    Args:
        true_categories: Ground-truth class indices.
        predicted_categories: Predicted class indices, aligned with
            ``true_categories``.
        labels: Optional sequence of class names, where position ``k`` names
            class index ``k``.
            - If given, every index in ``range(len(labels))`` is reported under
              these names, including classes absent from the data.
            - If ``None``, scikit-learn infers the classes from the data and
              names rows by index.
            This parameter replaces the previous implicit dependency on the
            module-level ``cfg`` / ``labels`` globals, which this module never
            defines.

    Returns:
        The report as a string (``output_dict=False``).
    """
    if labels is None:
        class_indices = None
        target_names = None
    else:
        target_names = list(labels)
        class_indices = list(range(len(target_names)))

    cl_report = classification_report(
        true_categories,
        predicted_categories,
        labels=class_indices,
        target_names=target_names,
        output_dict=False,
    )

    print(f"\nClassification Report\n{cl_report}")
    return cl_report

In [None]:
# #| export

# def get_confusion_matrix(model, test_dataset, y_true, true_categories, y_pred, predicted_categories):
#     # Confusion Matrix
#     def get_cm(model, test_dataset, y_true):

#         y_prediction = model.predict(test_dataset)
#         y_prediction = np.argmax(y_prediction, axis=1)
#         y_test = np.argmax(y_true, axis=1)
#         # Create confusion matrix and normalizes it over predicted (columns)
#         result = confusion_matrix(y_test, y_prediction, normalize="pred")
#         disp = ConfusionMatrixDisplay(confusion_matrix=result, display_labels=labels)
#         disp.plot()
#         plt.xticks(rotation=35)
#         plt.savefig("confusion_matrix.png")
#         plt.close()

#     cm_sklearn = get_cm(model, test_dataset, y_true)
#     return cm_sklearn

In [None]:
#| export

def get_confusion_matrix(
    true_categories,
    predicted_categories,
    labels=None,
    normalize="pred",
    filename="confusion_matrix.png",
):
    """Compute a confusion matrix and save it as a matplotlib figure.

    Args:
        true_categories: Ground-truth class indices.
        predicted_categories: Predicted class indices, aligned with
            ``true_categories``.
        labels: Optional sequence of class names, where position ``k`` names
            class index ``k``.
            - If given, the matrix always covers ``range(len(labels))``, so it
              matches the tick labels even when some class is absent from the
              data.
            - If ``None``, scikit-learn infers the classes and ticks show indices.
        normalize: Passed to ``sklearn.metrics.confusion_matrix``. The default
            ``"pred"`` normalizes over predicted labels (columns); ``None``
            gives raw counts.
        filename: Path the rendered figure is saved to.

    Returns:
        The (possibly normalized) confusion matrix as a NumPy array.
    """
    class_indices = None if labels is None else list(range(len(labels)))
    result = confusion_matrix(
        true_categories,
        predicted_categories,
        labels=class_indices,
        normalize=normalize,
    )

    disp = ConfusionMatrixDisplay(confusion_matrix=result, display_labels=labels)
    disp.plot()
    # Work on the display's own figure/axes, not pyplot's "current" figure.
    plt.setp(disp.ax_.get_xticklabels(), rotation=35)
    disp.figure_.savefig(filename)
    # Close so the figure is freed and not re-rendered at the end of the cell.
    plt.close(disp.figure_)

    return result

In [None]:
#| hide
# Regenerate the exported Python module(s) from the notebook's `#| export` cells.
from nbdev import nbdev_export
nbdev_export()