In [1]:
import os
import matplotlib.pyplot as plt

from demo.healthcare.histogram_inspection import HistogramInspection
from demo.healthcare.missing_embeddings_inspection import MissingEmbeddingInspection
from demo.healthcare.lineage_demo_inspection import LineageDemoInspection
from mlinspect.inspections.materialize_first_rows_inspection import MaterializeFirstRowsInspection
from mlinspect.pipeline_inspector import PipelineInspector
from mlinspect.visualisation import save_fig_to_path
from mlinspect.utils import get_project_root
from mlinspect.instrumentation.dag_node import OperatorType

import warnings
warnings.filterwarnings('ignore')

In [2]:
import pandas as pd

# Split the mock dataset into a patients file and a histories file.
# '?' stands for a missing value both when reading and when writing.
healthcare_dir = os.path.join(str(get_project_root()), "demo", "healthcare")
data = pd.read_csv(os.path.join(healthcare_dir, "MOCK_DATA.csv"), na_values='?')

patient_columns = ['id', 'first_name', 'last_name', 'race', 'county', 'num_children', 'income', 'age_group', 'ssn']
patients = data[patient_columns]
patients.to_csv(os.path.join(healthcare_dir, "healthcare_patients.csv"), index=False, na_rep='?')

# Histories are only kept for rows that have an 'ssn' value.
history_columns = ['smoker', 'complications', 'ssn']
histories = data[history_columns].dropna(axis=0, subset=['ssn'])
histories.to_csv(os.path.join(healthcare_dir, "healthcare_histories.csv"), index=False, na_rep='?')

# Overview of this example from the paper
![overview](paper_example_image.png)

# Add inspections and execute the pipeline

The pipeline inspector returns the extracted DAG and the results of our inspections. The runnable version of the example above is in the `healthcare.py` file.

In [None]:
HEALTHCARE_FILE_PY = os.path.join(str(get_project_root()), "demo", "healthcare", "healthcare.py")

# Register all inspections on the pipeline file and run it once.
# Later cells look the results up with equally constructed inspection objects,
# so the constructor arguments used here must match the ones used there.
inspection_result = (
    PipelineInspector
    .on_pipeline_from_py_file(HEALTHCARE_FILE_PY)
    .add_inspection(HistogramInspection())
    .add_inspection(MissingEmbeddingInspection(20))
    .add_inspection(LineageDemoInspection(5))
    .add_inspection(MaterializeFirstRowsInspection(5))
    .execute()
)

# The extracted DAG and the per-inspection annotations are used by all following cells.
extracted_dag = inspection_result.dag
inspection_results = inspection_result.inspection_to_annotations

# Now, let's look at the extracted Dag

In [None]:
from IPython.display import Image

# Save a picture of the extracted DAG as healthcare.png in the demo folder,
# then display that file as the output of this cell.
filename = os.path.join(
    str(get_project_root()),
    "demo", "healthcare", "healthcare.png",
)
save_fig_to_path(extracted_dag, filename)

Image(filename=filename)

# Want to know the output of some specific operator?
We can use the `MaterializeFirstRowsInspection` to look at, e.g., the output of a OneHotEncoder and the imputer right before it.

In [None]:
first_rows_inspection_result = inspection_results[MaterializeFirstRowsInspection(5)]

# Only the imputer and the one-hot encoder operating on the 'county' column are shown.
relevant_nodes = [node for node in extracted_dag.nodes if node.description in {
    "Imputer (SimpleImputer), Column: 'county'", "Categorical Encoder (OneHotEncoder), Column: 'county'"}]

for dag_node in relevant_nodes:
    # Skip nodes for which the inspection recorded nothing.
    if dag_node in first_rows_inspection_result and first_rows_inspection_result[dag_node] is not None:
        # '\033[1m' / '\033[0m' switch bold text on and off around the node header.
        print('\033[1m')
        print(dag_node.operator_type, dag_node.code_reference, dag_node.module, dag_node.description, '\033[0m')
        # The materialized rows are printed as a whole; no single row is picked out here.
        print(first_rows_inspection_result[dag_node])
        print("")
        print("")

# Want to know the origin of some row in the featurized model input?
We can use the `LineageDemoInspection` to get row-level lineage information for, e.g., a featurized tuple. In practice, you probably do not want to look at the lineage information yourself, as it can get quite complicated for complex pipelines like the one in our example. In the future, we could, e.g., extend the lineage inspection to take a list of lineage ids and materialize all related intermediate results in the pipeline. This way, users would not have to interpret the lineage ids themselves.

In [None]:
lineage_inspection_result = inspection_results[LineageDemoInspection(5)]

# Restrict the output to data sources, group-by aggregations and concatenations.
relevant_nodes = [node for node in extracted_dag.nodes if node.operator_type in {OperatorType.DATA_SOURCE, OperatorType.GROUP_BY_AGG, OperatorType.CONCATENATION}]

for dag_node in relevant_nodes:
    # Same guard as in the other result cells: a node without an annotation
    # is skipped instead of failing when its first row is indexed below.
    if dag_node in lineage_inspection_result and lineage_inspection_result[dag_node] is not None:
        print('\033[1m')
        print(dag_node.operator_type, dag_node.code_reference, dag_node.module, dag_node.description, '\033[0m')
        # Each materialized row is indexed as (lineage, value); only the first row is shown.
        row = lineage_inspection_result[dag_node][0]
        print("First output row:")
        print("\033[1mLineage: \033[0m{}".format(row[0]))
        print("\033[1mValue: \033[0m{}".format(row[1]))
        print("")

# What about issue 6? Were there missing embeddings?
Let's look at the output from the `MissingEmbeddingInspection`

In [None]:
embedding_inspection_result = inspection_results[MissingEmbeddingInspection(20)]

for dag_node in extracted_dag.nodes:
    # Only nodes for which the inspection produced a result are reported.
    if dag_node in embedding_inspection_result and embedding_inspection_result[dag_node] is not None:
        print('\033[1m')
        print(dag_node.operator_type, dag_node.code_reference, dag_node.module, dag_node.description, '\033[0m')
        # This cell reads the embedding inspection only; it no longer touches the
        # lineage result from the previous cell, so it can run on its own.
        op_result = embedding_inspection_result[dag_node]
        print("Number of missing embeddings: {}".format(op_result["missing_embedding_count"]))
        print("Example values with missing embeddings: {}".format(op_result["missing_embeddings_examples"]))
    

# We can look at how histograms of sensitive groups change after different Dag nodes
Thanks to our annotation propagation, this works even if the group columns are projected out at some point (Issue 2)

In [None]:
histogram_inspection_result = inspection_results[HistogramInspection()]


def _describe_dag_node(dag_node):
    """Return the multi-line text used as the subplot title for ``dag_node``."""
    return "{}\n{}\n{}\n{}".format(dag_node.operator_type, dag_node.code_reference, dag_node.module,
                                   dag_node.description)


def _show_histogram_pair(heading, counts_key, before, after):
    """Print ``heading`` and show the ``counts_key`` histograms of both nodes side by side."""
    print(heading)
    _, axes = plt.subplots(1, 2, figsize=(12, 4))
    for axis, dag_node in zip(axes, (before, after)):
        group_counts = histogram_inspection_result[dag_node][counts_key]
        # Keys are stringified so that non-string group keys can be used as bar labels.
        bar_labels = [str(key) for key in group_counts.keys()]
        axis.bar(bar_labels, list(group_counts.values()))
        # The title has to be set through a call. Assigning to `plt.title` would
        # replace the pyplot function and leave the subplot without a title.
        axis.set_title(_describe_dag_node(dag_node))
    plt.show()


def print_dag_operator_histograms(before, after):
    """Plot the age_group and race histograms recorded for two DAG nodes.

    For each of the two groups one figure with two bar charts is shown:
    the left chart for ``before``, the right chart for ``after``. Each chart
    is titled with the node's operator type, code reference, module and
    description.

    Args:
        before: DAG node whose histogram is drawn on the left.
        after: DAG node whose histogram is drawn on the right.

    Raises:
        KeyError: if ``histogram_inspection_result`` has no entry for a node
            or the entry lacks "age_group_counts" / "race_counts"
            (assuming dict-like lookup semantics).
    """
    _show_histogram_pair("Age_group histogram", "age_group_counts", before, after)
    _show_histogram_pair("Race histogram", "race_counts", before, after)

## Issue 1: Join might change proportions of groups in data

In [None]:
# Select the nodes described as the patients CSV and as the operation on ['ssn'].
relevant_nodes = list(filter(
    lambda dag_node: dag_node.description in {"healthcare_patients.csv", "on ['ssn']"},
    extracted_dag.nodes,
))

# NOTE(review): index 0 is treated as "before" and index 1 as "after"; this relies on
# extracted_dag.nodes yielding the earlier operator first - confirm.
print_dag_operator_histograms(relevant_nodes[0], relevant_nodes[1])

As we can see, there are no noteworthy changes because of the join.

## Issue 3: Selection might change proportions of groups in data

In [None]:
# Select the projection node and the "Select by series" node by their descriptions.
relevant_nodes = list(filter(
    lambda dag_node: dag_node.description in {
        "to ['smoker', 'last_name', 'county', 'num_children', 'race', 'income', 'label']", "Select by series"},
    extracted_dag.nodes,
))

# NOTE(review): index 0 is treated as "before" and index 1 as "after"; this relies on
# extracted_dag.nodes yielding the earlier operator first - confirm.
print_dag_operator_histograms(relevant_nodes[0], relevant_nodes[1])

**There clearly is an issue here! A lot of rows with the `race` value `race3` are filtered out!**

## Issue 4: Imputation might change proportions of groups in data

In [None]:
# Select the 'race' projection of the ColumnTransformer and the 'race' imputer by their descriptions.
relevant_nodes = list(filter(
    lambda dag_node: dag_node.description in {
        "to ['race'] (ColumnTransformer)", "Imputer (SimpleImputer), Column: 'race'"},
    extracted_dag.nodes,
))

# NOTE(review): index 0 is treated as "before" and index 1 as "after"; this relies on
# extracted_dag.nodes yielding the earlier operator first - confirm.
print_dag_operator_histograms(relevant_nodes[0], relevant_nodes[1])

**The `most-frequent` imputation amplifies the existing `race` imbalance!**