# Post-Processing Functions in ValidMind

Welcome! This notebook demonstrates how to use post-processing functions with ValidMind tests to customize test outputs. You'll learn various ways to modify test results including updating tables, adding/removing tables, creating figures from tables, and vice versa.

## Contents
- [About Post-Processing Functions](#about-post-processing-functions)
- [Key Concepts](#key-concepts)
- [Setup and Prerequisites](#setup-and-prerequisites)
- [Simple Tabular Updates](#simple-tabular-updates)
- [Adding Tables](#adding-tables) 
- [Removing Tables](#removing-tables)
- [Creating Figures from Tables](#creating-figures-from-tables)
- [Creating Tables from Figures](#creating-tables-from-figures)
- [Re-Drawing Confusion Matrix](#re-drawing-confusion-matrix)
- [Re-Drawing ROC Curve](#re-drawing-roc-curve)
- [Custom Test Example](#custom-test-example)

## About Post-Processing Functions

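Post-processing functions allow you to customize the output of ValidMind tests before the results are displayed or logged to the ValidMind Platform. A post-processing function takes a `TestResult` object as input and returns a (typically modified) `TestResult` object; you pass it to `run_test` through the `post_process_fn` argument.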

Common use cases include:
- Reformatting table data
- Adding or removing tables/figures
- Creating new visualizations from test data
- Customizing test pass/fail criteria

### Key Concepts

**TestResult object**: The main object that post-processing functions work with, containing:
- tables: List of ResultTable objects
- figures: List of Figure objects 
- passed: Boolean indicating test pass/fail status
- raw_data: Additional data from test execution

**ResultTable**: Object representing tabular data with:
- title: Table title
- data: pandas DataFrame or list of dictionaries

**Figure**: Object representing plots/visualizations with:
- figure: matplotlib or plotly figure object
- key: Unique identifier
- ref_id: Reference ID linking to test
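To try out the logic of a post-processing function without running a test, one option is to mimic these fields with small dataclasses. `StubTable`, `StubResult`, and `add_legend` below are hypothetical stand-ins for illustration, not part of the ValidMind API; they model only the attributes listed above plus an `add_table` method like the one used later in this notebook.

```python
from dataclasses import dataclass, field
from typing import Any, List, Optional


@dataclass
class StubTable:
    """Hypothetical stand-in for ResultTable: a title plus tabular data."""

    title: str
    data: Any  # a pandas DataFrame or a list of dictionaries


@dataclass
class StubResult:
    """Hypothetical stand-in exposing only the TestResult fields listed above."""

    tables: List[StubTable] = field(default_factory=list)
    figures: List[Any] = field(default_factory=list)
    passed: Optional[bool] = None
    raw_data: Any = None

    def add_table(self, table: StubTable) -> None:
        # mirrors the add_table call used later in this notebook
        self.tables.append(table)


def add_legend(result):
    # same contract as a real post-processing function: take a result, return it
    result.add_table(
        StubTable(
            title="Class Legend",
            data=[
                {"Class Value": "0", "Class Label": "No Churn"},
                {"Class Value": "1", "Class Label": "Churn"},
            ],
        )
    )
    return result


stub = add_legend(StubResult())
print([t.title for t in stub.tables])  # ['Class Legend']
```

A real `TestResult` carries more state (for example the `ref_id` used when attaching figures), so a function that works against the stub should still be checked with `run_test`.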

## Setup and Prerequisites

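First, we'll set up our environment: initialize the ValidMind Library, load and preprocess the customer churn dataset, train an XGBoost classifier, and register the datasets, the model, and its predictions with ValidMind: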

In [None]:
import xgboost as xgb
import validmind as vm
from validmind.datasets.classification import customer_churn

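# connect to the ValidMind Platform; with no arguments, credentials are read from environment variables
vm.init()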

raw_df = customer_churn.load_data()

train_df, validation_df, test_df = customer_churn.preprocess(raw_df)

x_train = train_df.drop(customer_churn.target_column, axis=1)
y_train = train_df[customer_churn.target_column]
x_val = validation_df.drop(customer_churn.target_column, axis=1)
y_val = validation_df[customer_churn.target_column]

# stop boosting once the last eval metric ("auc") on eval_set fails to improve for 10 rounds
model = xgb.XGBClassifier(early_stopping_rounds=10)
model.set_params(
    eval_metric=["error", "logloss", "auc"],
)
model.fit(
    x_train,
    y_train,
    eval_set=[(x_val, y_val)],
    verbose=False,
)

vm_raw_dataset = vm.init_dataset(
    dataset=raw_df,
    input_id="raw_dataset",
    target_column=customer_churn.target_column,
    class_labels=customer_churn.class_labels,
)

vm_train_ds = vm.init_dataset(
    dataset=train_df,
    input_id="train_dataset",
    target_column=customer_churn.target_column,
)

vm_test_ds = vm.init_dataset(
    dataset=test_df, input_id="test_dataset", target_column=customer_churn.target_column
)

vm_model = vm.init_model(
    model,
    input_id="model",
)

# compute each dataset's predictions with the model and attach them to the dataset
vm_train_ds.assign_predictions(
    model=vm_model,
)

vm_test_ds.assign_predictions(
    model=vm_model,
)

As a refresher, here is how we run a test normally, without any post-processing:

In [None]:
from validmind.tests import run_test

result = run_test(
    "validmind.model_validation.sklearn.ClassifierPerformance",
    inputs={"dataset": vm_test_ds, "model": vm_model},
    generate_description=False,
)

## Post-processing functions

### Simple Tabular Updates

The simplest form of post-processing is modifying existing table data. Here we replace the raw class values ("0" and "1") in the first table returned by `ClassifierPerformance` with readable labels ("No Churn" and "Churn"):

In [None]:
from validmind.vm_models.result import TestResult


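def add_class_labels(result: TestResult):
    # The first table lists metrics per class: replace the raw class values
    # ("0" and "1") with readable labels and leave any other values unchanged
    result.tables[0].data["Class"] = (
        result.tables[0]
        .data["Class"]
        .map(lambda x: "Churn" if x == "1" else "No Churn" if x == "0" else x)
    )

    return result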


result = run_test(
    "validmind.model_validation.sklearn.ClassifierPerformance",
    inputs={"dataset": vm_test_ds, "model": vm_model},
    generate_description=False,
    post_process_fn=add_class_labels,
)
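The relabeling step can be checked on its own with a toy DataFrame shaped like the per-class table. The rows and numbers below are made up, and `CLASS_LABELS` is a hypothetical name; the sketch uses `Series.replace` with a dict, which swaps only the matching values, as an alternative to the `map` with a lambda above.

```python
import pandas as pd

# Hypothetical mapping from raw class values to readable labels
CLASS_LABELS = {"0": "No Churn", "1": "Churn"}

# Toy table shaped like a per-class metrics table; the numbers are made up
table = pd.DataFrame(
    {
        "Class": ["0", "1", "Weighted Average"],
        "Precision": [0.9, 0.7, 0.85],
    }
)

# Series.replace swaps only the keys present in the dict and leaves
# every other value (such as a summary row) unchanged
table["Class"] = table["Class"].replace(CLASS_LABELS)
print(table["Class"].tolist())  # ['No Churn', 'Churn', 'Weighted Average']
```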

### Adding Tables

Sometimes you may want to add supplementary tables to provide additional context or information. This example shows how to add a legend table mapping class values to labels:

In [None]:
from validmind.vm_models.result import ResultTable

def add_table(result: TestResult):
    # add legend table to show map of class value to class label
    result.add_table(
        ResultTable(
            title="Class Legend",
            data=[
                {"Class Value": "0", "Class Label": "No Churn"},
                {"Class Value": "1", "Class Label": "Churn"},
            ],
        )
    )

    return result


result = run_test(
    "validmind.model_validation.sklearn.ClassifierPerformance",
    inputs={"dataset": vm_test_ds, "model": vm_model},
    generate_description=False,
    post_process_fn=add_table,
)
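`run_test` takes a single `post_process_fn`, but since each post-processing function takes a result and returns one, several can be chained into one function. `compose_post_processors` below is a hypothetical helper, not part of ValidMind; it applies the functions left to right.

```python
from functools import reduce
from typing import Callable, TypeVar

R = TypeVar("R")


def compose_post_processors(*fns: Callable[[R], R]) -> Callable[[R], R]:
    """Chain post-processing functions left to right into a single function."""

    def _apply(result: R) -> R:
        # feed the output of each function into the next one
        return reduce(lambda acc, fn: fn(acc), fns, result)

    return _apply
```

For example, `post_process_fn=compose_post_processors(add_class_labels, add_table)` would relabel the classes and then append the legend table, as long as every function in the chain returns the result it receives.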

### Removing Tables 

You can also remove tables that are not relevant to your use case. Here we drop the second table (index 1) from the `ClassifierPerformance` results:

In [None]:
def remove_table(result: TestResult):
    result.tables.pop(1)

    return result


result = run_test(
    "validmind.model_validation.sklearn.ClassifierPerformance",
    inputs={"dataset": vm_test_ds, "model": vm_model},
    generate_description=False,
    post_process_fn=remove_table,
)
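`pop(1)` removes whichever table happens to be second, so the example silently removes the wrong table if the test's table order changes. A sketch that drops tables by title instead: `remove_tables_titled` is a hypothetical helper, and the actual titles can be listed with `[t.title for t in result.tables]` before choosing one. It assumes `result.tables` can be reassigned, as `result.figures` is later in this notebook.

```python
from types import SimpleNamespace


def remove_tables_titled(*titles: str):
    """Build a post-processing function that drops tables whose title matches."""

    def _remove(result):
        # keep every table whose title is not in the exclusion list
        result.tables = [t for t in result.tables if t.title not in titles]
        return result

    return _remove


# Quick check with stand-in objects that only carry a `title` attribute
demo = SimpleNamespace(tables=[SimpleNamespace(title="A"), SimpleNamespace(title="B")])
remove_tables_titled("B")(demo)
print([t.title for t in demo.tables])  # ['A']
```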

### Creating Figures from Tables

A powerful use of post-processing is creating visualizations from tabular data. This example builds a bar plot of the total count of outliers per variable from the first table returned by the `IQROutliersTable` test:

In [None]:
from plotly.express import bar
from validmind.vm_models.figure import Figure


def create_figure(result: TestResult):
    fig = bar(result.tables[0].data, x="Variable", y="Total Count of Outliers")

    result.add_figure(
        Figure(
            figure=fig,
            key="outlier_count_by_variable",
            ref_id=result.ref_id,
        )
    )

    return result


result = run_test(
    "validmind.data_validation.IQROutliersTable",
    inputs={"dataset": vm_test_ds},
    generate_description=False,
    post_process_fn=create_figure,
)
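When a table has many rows, the bar plot can be easier to read if the table is trimmed first. A sketch with pandas, assuming the table data is a DataFrame with the `Total Count of Outliers` column used above; `top_outlier_variables` is a hypothetical helper and the rows are made up.

```python
import pandas as pd


def top_outlier_variables(df: pd.DataFrame, n: int = 5) -> pd.DataFrame:
    """Return the n variables with the most outliers, largest first."""
    count_col = "Total Count of Outliers"  # column name used in the example above
    return (
        df[df[count_col] > 0]  # drop variables without outliers
        .sort_values(count_col, ascending=False)
        .head(n)
        .reset_index(drop=True)
    )


# Made-up rows shaped like the outliers table
toy = pd.DataFrame(
    {"Variable": ["Age", "Balance", "Tenure"], "Total Count of Outliers": [12, 0, 30]}
)
print(top_outlier_variables(toy, n=2)["Variable"].tolist())  # ['Tenure', 'Age']
```

Inside `create_figure`, the call would become `bar(top_outlier_variables(result.tables[0].data), x="Variable", y="Total Count of Outliers")`.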

### Creating Tables from Figures

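The reverse operation, extracting tabular data from figures, is also possible. Here we read the bar values from each plotly figure produced by the `IQROutliersBarPlot` test and add them to the result as one table per figure, titled after the figure: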

In [None]:
def create_table(result: TestResult):
    for fig in result.figures:
        data = fig.figure.data[0]

        table_data = [
            {"Percentile": x, "Outlier Count": y}
            for x, y in zip(data.x, data.y)
        ]

        result.add_table(
            ResultTable(
                title=fig.figure.layout.title.text,
                data=table_data,
            )
        )

    return result


result = run_test(
    "validmind.data_validation.IQROutliersBarPlot",
    inputs={"dataset": vm_test_ds},
    generate_description=False,
    post_process_fn=create_table,
)
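`create_table` adds one table per figure. If a single table is preferable, the bars from every figure can be flattened into one list of records tagged with the figure they came from. `traces_to_long_table` is a hypothetical helper, the values below are made up, and the sketch assumes each figure's title names the variable it plots.

```python
from typing import Any, Dict, Iterable, List, Sequence, Tuple


def traces_to_long_table(
    traces: Iterable[Tuple[str, Sequence, Sequence]],
) -> List[Dict[str, Any]]:
    """Flatten (title, x values, y values) triples into one list of records."""
    rows: List[Dict[str, Any]] = []
    for title, xs, ys in traces:
        # one record per bar, tagged with the figure it came from
        rows.extend(
            {"Variable": title, "Percentile": x, "Outlier Count": y}
            for x, y in zip(xs, ys)
        )
    return rows


rows = traces_to_long_table(
    [("Age", ["25-50", "50-75"], [3, 1]), ("Tenure", ["25-50"], [2])]
)
print(len(rows))  # 3
```

In `create_table`, the triples would be built as `(fig.figure.layout.title.text, fig.figure.data[0].x, fig.figure.data[0].y)` for each `fig` in `result.figures`, and the returned records passed as `data` to a single `ResultTable`.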

### Re-Drawing Confusion Matrix

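Sometimes you may want to completely replace the default visualizations. This example discards the figure produced by the `ClassImbalance` test and redraws the class distribution as a matplotlib bar plot, using the data in the test's results table: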

In [None]:
import matplotlib.pyplot as plt


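def re_draw_class_imbalance(result: TestResult):
    data = result.tables[0].data
    # Example of the table contents:
    #    Exited  Percentage of Rows (%)  Pass/Fail
    # 0       0                  80.25%       Pass
    # 1       1                  19.75%       Pass

    # drop the default figure so only the redrawn one is kept
    result.figures = []

    # percentages are stored as strings such as "80.25%", so convert them to floats
    percentages = data["Percentage of Rows (%)"].astype(str).str.rstrip("%").astype(float)

    # show a bar plot of the class imbalance with matplotlib
    fig, ax = plt.subplots()
    ax.bar(data["Exited"].astype(str), percentages)
    ax.set_xlabel("Exited")
    ax.set_ylabel("Percentage of Rows (%)")
    ax.set_title("Class Imbalance")

    result.add_figure(
        Figure(
            figure=fig,
            key="class_imbalance",
            ref_id=result.ref_id,
        )
    )

    # close the figure so the notebook does not render it a second time
    plt.close(fig)

    return result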


result = run_test(
    "validmind.data_validation.ClassImbalance",
    inputs={"dataset": vm_test_ds},
    generate_description=False,
    post_process_fn=re_draw_class_imbalance,
)

### Re-Drawing ROC Curve

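Here is another example of re-drawing a figure. This time we redraw the ROC curve with matplotlib, using the false and true positive rates stored in the result's raw data:

In [None]:
def post_process_roc_curve(result: TestResult):
    # assumes the test's raw data holds `fpr` and `tpr` arrays
    fpr, tpr = result.raw_data.fpr, result.raw_data.tpr
    fig, ax = plt.subplots()
    ax.plot(fpr, tpr, label="ROC curve")
    ax.plot([0, 1], [0, 1], linestyle="--", label="Random classifier")
    ax.set(xlabel="False Positive Rate", ylabel="True Positive Rate", title="ROC Curve")
    ax.legend()
    # replace the default figure with the redrawn one, then close it to avoid a duplicate render
    result.figures = [Figure(figure=fig, key="roc_curve", ref_id=result.ref_id)]
    plt.close(fig)
    return result

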
result = run_test(
    "validmind.model_validation.sklearn.ROCCurve",
    inputs={"dataset": vm_test_ds, "model": vm_model},
    generate_description=False,
    post_process_fn=post_process_roc_curve,
)
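ROC plots are often annotated with the area under the curve (AUC). Given false positive rates sorted in ascending order and their matching true positive rates, such as the `fpr` array in `result.raw_data`, the area can be computed with the trapezoidal rule, which is also how `sklearn.metrics.auc` integrates a curve. `trapezoid_auc` is a hypothetical, dependency-free helper for illustration.

```python
def trapezoid_auc(fpr, tpr):
    """Area under a ROC curve by the trapezoidal rule (fpr must be sorted ascending)."""
    area = 0.0
    for i in range(1, len(fpr)):
        # width of the segment times the average of its two heights
        area += (fpr[i] - fpr[i - 1]) * (tpr[i] + tpr[i - 1]) / 2
    return area


print(trapezoid_auc([0.0, 1.0], [0.0, 1.0]))  # 0.5, a random classifier
print(trapezoid_auc([0.0, 0.0, 1.0], [0.0, 1.0, 1.0]))  # 1.0, a perfect classifier
```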

### Custom Test Example

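While we expect post-processing functions to be most useful for modifying built-in ValidMind Library tests, you may also want to use them with your own custom tests. In the example below, a custom test returns a single option price for a given `strike`; running it over a grid of strike values combines the runs into one table with `strike` and `Option price` columns, and a post-processing function turns that table into a line plot of option price against strike: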

In [None]:
import random

import numpy as np
import pandas as pd
import plotly.graph_objects as go
from validmind.vm_models.figure import Figure
from validmind.vm_models.result import TestResult


@vm.test("my_custom_tests.Sensitivity")
def sensitivity_test(strike=None):
    """Toy sensitivity test that returns a random option price scaled by the strike."""
    price = strike * random.random()

    return pd.DataFrame({"Option price": [price]})


def process_results(result: TestResult):
    # the combined results table is expected to hold one row per strike value
    df = pd.DataFrame(result.tables[0].data)

    fig = go.Figure()
    fig.add_trace(
        go.Scatter(x=df["strike"].values, y=df["Option price"].values, mode="lines")
    )
    fig.update_layout(
        title="Sensitivity of Option Price to Strike",
        xaxis_title="Strike",
        yaxis_title="Option price",
        template="plotly_white",  # white background with grid lines
    )

    result.add_figure(
        Figure(
            figure=fig,
            key="sensitivity_to_strike",
            ref_id=result.ref_id,
        )
    )

    return result


result = run_test(
    "my_custom_tests.Sensitivity:ToStrike",
    param_grid={
        "strike": list(np.linspace(0, 100, 20)),
    },
    post_process_fn=process_results,
)
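`param_grid` runs the test once per parameter combination. Assuming a dict of lists is expanded as a Cartesian product (check the `run_test` documentation for your ValidMind version), the runs can be previewed ahead of time. `expand_param_grid` is a hypothetical helper, and the second parameter `vol` is made up to show how combinations multiply; with only `strike` as above, the grid yields 20 runs.

```python
from itertools import product


def expand_param_grid(param_grid):
    """List every combination of a dict-of-lists grid as one dict per run."""
    names = list(param_grid)
    # product() walks the value lists in order, varying the last one fastest
    return [dict(zip(names, values)) for values in product(*param_grid.values())]


runs = expand_param_grid({"strike": [50, 100], "vol": [0.1, 0.2]})
print(len(runs))  # 4
print(runs[0])  # {'strike': 50, 'vol': 0.1}
```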