# Post-Processing Functions in ValidMind

Welcome! This notebook demonstrates how to use post-processing functions with ValidMind tests to customize test outputs. You'll learn how to update tables, add and remove tables, create figures from tables (and tables from figures), and re-draw existing figures.

## Contents
- [About Post-Processing Functions](#about-post-processing-functions)
- [Key Concepts](#key-concepts)
- [Setup and Prerequisites](#setup-and-prerequisites)
- [Simple Tabular Updates](#simple-tabular-updates)
- [Adding Tables](#adding-tables) 
- [Removing Tables](#removing-tables)
- [Creating Figures from Tables](#creating-figures-from-tables)
- [Creating Tables from Figures](#creating-tables-from-figures)
- [Re-Drawing Confusion Matrix](#re-drawing-confusion-matrix)
- [Re-Drawing ROC Curve](#re-drawing-roc-curve)
- [Custom Test Example](#custom-test-example)

## About Post-Processing Functions

Post-processing functions allow you to customize the output of ValidMind tests before they are saved to the platform. These functions take a TestResult object as input and return a modified TestResult object.

Common use cases include:
- Reformatting table data
- Adding or removing tables/figures
- Creating new visualizations from test data
- Customizing test pass/fail criteria
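
As a rough mental model, this contract can be sketched without ValidMind at all. The `SketchResult` class and the `apply_post_process` helper below are hypothetical stand-ins written for this illustration. They are not the library's classes or its internals, and only show that a post-processing function receives a result object and returns it, modified or not.

```python
from dataclasses import dataclass, field
from typing import Callable, Optional


@dataclass
class SketchResult:
    """Hypothetical stand-in for a test result (not validmind's TestResult)."""

    tables: list = field(default_factory=list)
    figures: list = field(default_factory=list)
    passed: Optional[bool] = None


def apply_post_process(
    result: SketchResult,
    post_process_fn: Optional[Callable[[SketchResult], SketchResult]] = None,
) -> SketchResult:
    """Return the result unchanged, or whatever the post-processing function returns."""
    if post_process_fn is None:
        return result
    return post_process_fn(result)


def drop_all_figures(result: SketchResult) -> SketchResult:
    # Example post-processing step: keep the tables, discard the figures
    result.figures = []
    return result


original = SketchResult(tables=["table A"], figures=["figure A"])
processed = apply_post_process(original, drop_all_figures)
print(processed.tables, processed.figures)
```

The important design point is that the function must return the result object. A function that modifies the result in place but forgets the `return` would hand back `None`.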

### Key Concepts

**`validmind.vm_models.result.TestResult`**: Whenever a test is run with the `run_test` function in ValidMind, the items returned/produced by the test are bundled into a single `TestResult` object. There are several attributes on this object that are useful to know about:
- `tables`: List of `validmind.vm_models.result.ResultTable` objects (see below)
- `figures`: List of `validmind.vm_models.figure.Figure` objects (see below)
- `passed`: Optional boolean indicating test pass/fail status. `None` indicates that the test is not a pass/fail test (previously known as a threshold test).
- `raw_data`: Optional `validmind.vm_models.result.RawData` object containing additional data from test execution. Some ValidMind tests produce this raw data for use in post-processing functions. It is currently neither displayed in the test result nor sent to the ValidMind platform. To see what is available, run `result.raw_data.inspect()`, which returns a dictionary whose keys are the available raw data attributes and whose values are string representations of the data.

**`validmind.vm_models.result.ResultTable`**: ValidMind object representing tables displayed in the test result and sent to the ValidMind platform:
- `title`: Optional table title
- `data`: Pandas dataframe

**`validmind.vm_models.figure.Figure`**: ValidMind object representing plots/visualizations displayed in the test result and sent to the ValidMind platform:
- `figure`: matplotlib or plotly figure object
- `key`: Unique identifier
- `ref_id`: Reference ID linking to test
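
When writing a post-processing function, it can help to check what a result contains first. The sketch below is a hypothetical helper, not part of ValidMind. It reads only the attribute names listed above (`tables`, `figures`, `passed`, `title`, `key`) and is demonstrated on a `SimpleNamespace` stand-in so that it runs without the library.

```python
from types import SimpleNamespace


def summarize_result(result) -> dict:
    """Summarize any object exposing `tables`, `figures`, and `passed` attributes."""
    # Treat a missing or None attribute as an empty list
    tables = getattr(result, "tables", None) or []
    figures = getattr(result, "figures", None) or []
    return {
        "table_titles": [getattr(t, "title", None) for t in tables],
        "figure_keys": [getattr(f, "key", None) for f in figures],
        "passed": getattr(result, "passed", None),
    }


# Stand-in object for illustration only; with ValidMind you would pass the
# object returned by `run_test` instead.
fake_result = SimpleNamespace(
    tables=[SimpleNamespace(title="Class Legend", data=None)],
    figures=[],
    passed=None,
)
print(summarize_result(fake_result))
```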

## Setup and Prerequisites

First, we'll set up our environment and load sample data using the customer churn dataset:

In [None]:
import xgboost as xgb
import validmind as vm
from validmind.datasets.classification import customer_churn

raw_df = customer_churn.load_data()

train_df, validation_df, test_df = customer_churn.preprocess(raw_df)

x_train = train_df.drop(customer_churn.target_column, axis=1)
y_train = train_df[customer_churn.target_column]
x_val = validation_df.drop(customer_churn.target_column, axis=1)
y_val = validation_df[customer_churn.target_column]

model = xgb.XGBClassifier(early_stopping_rounds=10)
model.set_params(
    eval_metric=["error", "logloss", "auc"],
)
model.fit(
    x_train,
    y_train,
    eval_set=[(x_val, y_val)],
    verbose=False,
)

vm_raw_dataset = vm.init_dataset(
    dataset=raw_df,
    input_id="raw_dataset",
    target_column=customer_churn.target_column,
    class_labels=customer_churn.class_labels,
    __log=False,
)

vm_train_ds = vm.init_dataset(
    dataset=train_df,
    input_id="train_dataset",
    target_column=customer_churn.target_column,
    __log=False,
)

vm_test_ds = vm.init_dataset(
    dataset=test_df,
    input_id="test_dataset",
    target_column=customer_churn.target_column,
    __log=False,
)

vm_model = vm.init_model(
    model,
    input_id="model",
    __log=False,
)

vm_train_ds.assign_predictions(
    model=vm_model,
)

vm_test_ds.assign_predictions(
    model=vm_model,
)

As a refresher, here is how we run a test normally, without any post-processing:

In [None]:
from validmind.tests import run_test

result = run_test(
    "validmind.model_validation.sklearn.ClassifierPerformance",
    inputs={"dataset": vm_test_ds, "model": vm_model},
    generate_description=False,
)

## Post-Processing Functions

### Simple Tabular Updates

The simplest form of post-processing is modifying existing table data. Here we demonstrate updating class labels in a classification performance table.

Some key concepts to keep in mind:
- Tables produced by a test are accessible via the `result.tables` attribute
    - The `result.tables` attribute is a list of `ResultTable` objects which are simple data structures that contain a `data` attribute and an optional `title` attribute
    - The `data` attribute is guaranteed to be a `pd.DataFrame` whether the test code itself returns a `pd.DataFrame` or a list of dictionaries
    - The `title` attribute is optional and can be set by tests that return a dictionary where the keys are the table titles and the values are the table data (e.g. `{"Classifier Performance": performance_df, "Class Legend": [{"Class Value": "0", "Class Label": "No Churn"}, {"Class Value": "1", "Class Label": "Churn"}]}`)
- Post-processing functions can directly modify any of the tables in the `result.tables` list and return the modified `TestResult` object. This is useful for renaming columns, adding or removing rows, and similar changes.

In [None]:
from validmind.vm_models.result import TestResult


def add_class_labels(result: TestResult):
    result.tables[0].data["Class"] = (
        result.tables[0]
        .data["Class"]
        .map(lambda x: "Churn" if x == "1" else "No Churn" if x == "0" else x)
    )

    return result


result = run_test(
    "validmind.model_validation.sklearn.ClassifierPerformance",
    inputs={"dataset": vm_test_ds, "model": vm_model},
    generate_description=False,
    post_process_fn=add_class_labels,
)
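
The nested conditional inside the `lambda` gets harder to read as the number of classes grows. One possible alternative, sketched here in plain Python, is a dictionary lookup that leaves unknown values unchanged. The `CLASS_LABELS` mapping and the `"Weighted Average"` row label are assumptions made for this illustration.

```python
# Assumed mapping for this example; the keys are strings, matching the
# comparison against "0" and "1" in `add_class_labels`.
CLASS_LABELS = {"0": "No Churn", "1": "Churn"}


def relabel(value):
    """Return the readable label for a class value, or the value itself if unknown."""
    return CLASS_LABELS.get(value, value)


# Made-up row labels: two class values and one row that should stay as-is
rows = ["0", "1", "Weighted Average"]
print([relabel(r) for r in rows])  # ['No Churn', 'Churn', 'Weighted Average']
```

Because pandas `Series.map` accepts a function, `relabel` could be passed to `.map()` in place of the `lambda` above.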

### Adding Tables

This example shows how to add a legend table mapping class values to labels using the `TestResult.add_table()` method:

In [None]:
def add_table(result: TestResult):
    # add legend table to show map of class value to class label
    result.add_table(
        title="Class Legend",
        table=[
            {"Class Value": "0", "Class Label": "No Churn"},
            {"Class Value": "1", "Class Label": "Churn"},
        ],
    )

    return result


result = run_test(
    "validmind.model_validation.sklearn.ClassifierPerformance",
    inputs={"dataset": vm_test_ds, "model": vm_model},
    generate_description=False,
    post_process_fn=add_table,
)
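
If the class labels already live in a dictionary, the legend rows do not need to be typed out by hand. This small sketch builds the same list of dictionaries from a mapping. `CLASS_LABELS` and `legend_rows` are hypothetical names introduced here.

```python
# Assumed mapping of class value to class label for this example
CLASS_LABELS = {"0": "No Churn", "1": "Churn"}


def legend_rows(labels: dict) -> list:
    """Build one legend row per class, in the mapping's insertion order."""
    return [
        {"Class Value": value, "Class Label": label}
        for value, label in labels.items()
    ]


print(legend_rows(CLASS_LABELS))
```

The returned list has the same shape as the `table` argument passed to `result.add_table()` in the cell above.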

### Removing Tables 

If there are tables in the test result that you don't want to display or log to the ValidMind platform, you can remove them using the `TestResult.remove_table()` method.

In [None]:
def remove_table(result: TestResult):
    result.remove_table(1)

    return result


result = run_test(
    "validmind.model_validation.sklearn.ClassifierPerformance",
    inputs={"dataset": vm_test_ds, "model": vm_model},
    generate_description=False,
    post_process_fn=remove_table,
)
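
Hard-coding an index such as `1` breaks silently if a test changes the order of its tables. A hedged alternative is to look the index up by title first. `find_table_index` is a hypothetical helper, not a ValidMind function, and the table titles below are invented for the demonstration.

```python
from types import SimpleNamespace
from typing import Optional


def find_table_index(tables, title: str) -> Optional[int]:
    """Return the position of the first table with the given title, or None."""
    for index, table in enumerate(tables or []):
        if getattr(table, "title", None) == title:
            return index
    return None


# Stand-in tables with invented titles, used only to exercise the helper
fake_tables = [
    SimpleNamespace(title="Summary"),
    SimpleNamespace(title="Details"),
]
print(find_table_index(fake_tables, "Details"))  # 1
print(find_table_index(fake_tables, "Missing"))  # None
```

Inside a post-processing function, you would pass `result.tables` to the helper and call `result.remove_table()` with the returned index only when it is not `None`.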

### Creating Figures from Tables

A powerful use of post-processing is creating visualizations from tabular data. This example creates a bar plot from an outliers table using the `TestResult.add_figure()` method. This method accepts a `matplotlib` or `plotly` figure, raw PNG `bytes`, or a `validmind.vm_models.figure.Figure` object.

In [None]:
from plotly.express import bar


def create_figure(result: TestResult):
    result.add_figure(
        bar(result.tables[0].data, x="Variable", y="Total Count of Outliers")
    )

    return result


result = run_test(
    "validmind.data_validation.IQROutliersTable",
    inputs={"dataset": vm_test_ds},
    generate_description=False,
    post_process_fn=create_figure,
)

### Creating Tables from Figures

The reverse operation, extracting tabular data from figures, is also possible. However, it's recommended to use the raw data produced by the test instead (assuming it is available), because the approach below requires deeper knowledge of the underlying figure library (e.g. `matplotlib` or `plotly`) and may not be as robust or maintainable.

In [None]:
def create_table(result: TestResult):
    for fig in result.figures:
        data = fig.figure.data[0]

        result.add_table(
            title=fig.figure.layout.title.text,
            table=[
                {"Percentile": x, "Outlier Count": y}
                for x, y in zip(data.x, data.y)
            ],
        )

    return result


result = run_test(
    "validmind.data_validation.IQROutliersBarPlot",
    inputs={"dataset": vm_test_ds},
    generate_description=False,
    post_process_fn=create_table,
)
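
The function above reads only the first trace of each figure (`fig.figure.data[0]`). If a figure holds several traces, a helper along these lines could flatten all of them into rows. This is a sketch under the assumption that each trace exposes `x`, `y`, and an optional `name`, as plotly bar and scatter traces do. `traces_to_rows` is a hypothetical name, and the stand-in traces contain made-up numbers.

```python
from types import SimpleNamespace


def traces_to_rows(traces, x_name="x", y_name="y"):
    """Flatten x/y traces into a list of row dictionaries, one row per point."""
    rows = []
    for trace in traces:
        for x, y in zip(trace.x, trace.y):
            rows.append(
                {"Trace": getattr(trace, "name", None), x_name: x, y_name: y}
            )
    return rows


# Stand-in traces with made-up numbers, used only to exercise the helper
fake_traces = [
    SimpleNamespace(name="Age", x=[25, 50], y=[3, 1]),
    SimpleNamespace(name="Balance", x=[25], y=[7]),
]
for row in traces_to_rows(fake_traces, "Percentile", "Outlier Count"):
    print(row)
```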

### Re-Drawing Confusion Matrix

A less common example is re-drawing a figure. This example uses the table produced by the `ClassImbalance` test to draw a matplotlib bar chart of the class percentages, which replaces the existing plotly figure.

In [None]:
import matplotlib.pyplot as plt


def re_draw_class_imbalance(result: TestResult):
    data = result.tables[0].data

    # remove the existing figure
    result.remove_figure(0)

    # use matplotlib to plot the class imbalance as a bar chart
    fig = plt.figure()

    plt.bar(data["Exited"], data["Percentage of Rows (%)"])
    plt.xlabel("Exited")
    plt.ylabel("Percentage of Rows (%)")
    plt.title("Class Imbalance")

    # add the figure to the result
    result.add_figure(fig)

    # close the figure so matplotlib does not also render it in the notebook output
    plt.close(fig)

    return result


result = run_test(
    "validmind.data_validation.ClassImbalance",
    inputs={"dataset": vm_test_ds},
    generate_description=False,
    post_process_fn=re_draw_class_imbalance,
)

### Re-Drawing ROC Curve

This example re-draws the ROC curve using the raw data produced by the test. This is the recommended approach for reproducing figures or tables from test results, because it gives you access to the intermediate data that the test originally used to produce them.

First, let's run the test without post-processing to see the original result.

In [None]:
# run the test without post-processing
result = run_test(
    "validmind.model_validation.sklearn.ROCCurve",
    inputs={"dataset": vm_test_ds, "model": vm_model},
    generate_description=False,
)

Now that we have a `TestResult` object, we can inspect the raw data to see what is available.

In [None]:
result.raw_data.inspect()

Now that we know what is available in the raw data, we can build a post-processing function that uses this raw data to reproduce the ROC curve.

In [None]:
def post_process_roc_curve(result: TestResult):
    fpr = result.raw_data.fpr
    tpr = result.raw_data.tpr
    auc = result.raw_data.auc

    # remove the existing figure
    result.remove_figure(0)

    # use matplotlib to plot the ROC curve
    fig = plt.figure()

    plt.plot(fpr, tpr, label=f"ROC Curve (AUC = {auc:.2f})")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("ROC Curve")

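    plt.legend()

    # add the figure to the result
    result.add_figure(fig)

    # close the figure so matplotlib does not also render it in the notebook output
    plt.close(fig)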

    return result


result = run_test(
    "validmind.model_validation.sklearn.ROCCurve",
    inputs={"dataset": vm_test_ds, "model": vm_model},
    generate_description=False,
    post_process_fn=post_process_roc_curve,
)
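
Since `raw_data` is optional, a post-processing function that depends on it may want to fail with a clear message when the data is absent. The helper below is a hypothetical sketch, not part of ValidMind. It assumes raw data values are plain attributes, as in the `result.raw_data.fpr` access above, and is demonstrated on a stand-in object with made-up numbers.

```python
from types import SimpleNamespace


def require_raw_data(result, *names):
    """Return the requested raw data values, or raise a descriptive error."""
    raw_data = getattr(result, "raw_data", None)
    if raw_data is None:
        raise ValueError("This test result has no raw data attached")

    missing = [name for name in names if not hasattr(raw_data, name)]
    if missing:
        raise ValueError(f"Raw data is missing: {', '.join(missing)}")

    return tuple(getattr(raw_data, name) for name in names)


# Stand-in result with made-up values, used only to exercise the helper
fake_result = SimpleNamespace(
    raw_data=SimpleNamespace(fpr=[0.0, 0.5, 1.0], tpr=[0.0, 1.0, 1.0], auc=0.75)
)
fpr, tpr, auc = require_raw_data(fake_result, "fpr", "tpr", "auc")
print(f"AUC = {auc:.2f}")
```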

### Custom Test Example

While we envision that post-processing functions are most useful for modifying built-in (ValidMind Library) tests, there are also cases where you may want to use them for your own custom tests. Let's see an example of this.

In [None]:
import pandas as pd
import numpy as np
from validmind import test
from validmind.tests import run_test


@test("custom.CorrelationBetweenVariables")
def CorrelationBetweenVariables(var1: str, var2: str):
    """This fake test shows the relationship between two variables"""
    data = pd.DataFrame(
        {
            "Variable 1": np.random.rand(20),
            "Variable 2": np.random.rand(20),
        }
    )

    return [{"Correlation between var1 and var2": data.corr().iloc[0, 1]}]


variables = ["Age", "Balance", "CreditScore", "EstimatedSalary"]

result = run_test(
    "custom.CorrelationBetweenVariables",
    param_grid={
        "var1": variables,
        "var2": variables,
    }, # this will automatically generate all combinations of variables for var1 and var2
    generate_description=False,
)

As you can see, the test result now contains a table with the correlation between each pair of variables. Because the test generates random data, your values will differ, but the table will look like this:

| var1 | var2 | Correlation between var1 and var2 |
|------|------|-----------------------------------|
| Age | Age | 0.3001 |
| Age | Balance | -0.4185 |
| Age | CreditScore | 0.2952 |
| Age | EstimatedSalary | -0.2855 |
| Balance | Age | 0.0141 |
| Balance | Balance | -0.1513 |
| Balance | CreditScore | 0.2401 |
| Balance | EstimatedSalary | 0.1198 |
| CreditScore | Age | -0.2320 |
| CreditScore | Balance | 0.4125 |
| CreditScore | CreditScore | 0.1726 |
| CreditScore | EstimatedSalary | 0.3187 |
| EstimatedSalary | Age | -0.1774 |
| EstimatedSalary | Balance | -0.1202 |
| EstimatedSalary | CreditScore | 0.1488 |
| EstimatedSalary | EstimatedSalary | 0.0524 |
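
The 16 rows come from pairing every value of `var1` with every value of `var2`. The idea can be sketched with `itertools.product` from the standard library. This illustrates the Cartesian product only and is not necessarily how ValidMind expands `param_grid` internally.

```python
from itertools import product

variables = ["Age", "Balance", "CreditScore", "EstimatedSalary"]
param_grid = {"var1": variables, "var2": variables}

# One dictionary of parameters per combination of grid values
combinations = [
    dict(zip(param_grid, values)) for values in product(*param_grid.values())
]

print(len(combinations))  # 16
print(combinations[0])  # {'var1': 'Age', 'var2': 'Age'}
print(combinations[1])  # {'var1': 'Age', 'var2': 'Balance'}
```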

Now let's say we don't really want to see the big table of correlations. Instead, we want to see a heatmap of the correlations. We can use a post-processing function to create a heatmap from the table and add it to the test result while removing the table.

In [None]:
import plotly.graph_objects as go


def create_heatmap(result: TestResult):
    # get the data from the existing table
    data = result.tables[0].data

    # remove the existing table
    result.remove_table(0)
    
    # Create a pivot table from the data to get it in matrix form
    matrix = pd.pivot_table(
        data,
        values='Correlation between var1 and var2',
        index='var1',
        columns='var2'
    )

    # use plotly to create a heatmap
    fig = go.Figure(data=go.Heatmap(
        z=matrix.values,
        x=matrix.columns,
        y=matrix.index,
        colorscale='RdBu',
        zmid=0,  # Center the color scale at 0
    ))

    fig.update_layout(
        title="Correlation Heatmap",
        xaxis_title="Variable",
        yaxis_title="Variable",
    )

    # add the figure to the result
    result.add_figure(fig)

    return result


result = run_test(
    "custom.CorrelationBetweenVariables",
    param_grid={
        "var1": variables,
        "var2": variables,
    },
    generate_description=False,
    post_process_fn=create_heatmap,
)
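
To make the pivot step less opaque, here is a plain-Python sketch of the same reshaping: long-format rows (one row per pair) become a matrix with one row per `var1` and one column per `var2`. `pivot_long_to_matrix` is a hypothetical helper written for this illustration, and the four input rows are copied from the example table shown earlier. Unlike `pd.pivot_table`, which averages duplicate pairs by default, this sketch keeps the last value seen for a pair.

```python
def pivot_long_to_matrix(rows, index_key, column_key, value_key):
    """Reshape long-format rows into (index labels, column labels, matrix)."""
    index_labels = sorted({row[index_key] for row in rows})
    column_labels = sorted({row[column_key] for row in rows})

    # Map each (index, column) pair to its value; later duplicates overwrite earlier ones
    lookup = {(row[index_key], row[column_key]): row[value_key] for row in rows}

    # Pairs that never occur are filled with None
    matrix = [
        [lookup.get((i, c)) for c in column_labels] for i in index_labels
    ]
    return index_labels, column_labels, matrix


value_key = "Correlation between var1 and var2"
long_rows = [
    {"var1": "Age", "var2": "Age", value_key: 0.3001},
    {"var1": "Age", "var2": "Balance", value_key: -0.4185},
    {"var1": "Balance", "var2": "Age", value_key: 0.0141},
    {"var1": "Balance", "var2": "Balance", value_key: -0.1513},
]

index_labels, column_labels, matrix = pivot_long_to_matrix(
    long_rows, "var1", "var2", value_key
)
print(index_labels)
print(matrix)
```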