In [8]:
import os
import matplotlib.pyplot as plt

from demo.healthcare.histogram_inspection import HistogramInspection
from demo.healthcare.missing_embeddings_inspection import MissingEmbeddingInspection
from demo.healthcare.lineage_demo_inspection import LineageDemoInspection
from mlinspect.inspections.materialize_first_rows_inspection import MaterializeFirstRowsInspection
from mlinspect.pipeline_inspector import PipelineInspector
from mlinspect.visualisation import save_fig_to_path
from mlinspect.utils import get_project_root
from mlinspect.instrumentation.dag_node import OperatorType

import warnings
warnings.filterwarnings('ignore')

In [None]:
# Path of the example pipeline script, resolved relative to the project root.
HEALTHCARE_FILE_PY = os.path.join(
    str(get_project_root()), "demo", "healthcare", "healthcare.py"
)
print(HEALTHCARE_FILE_PY)

# Register the four inspections on the pipeline file and execute it.
inspection_result = (
    PipelineInspector.on_pipeline_from_py_file(HEALTHCARE_FILE_PY)
    .add_inspection(HistogramInspection())
    .add_inspection(MissingEmbeddingInspection())
    .add_inspection(LineageDemoInspection(5))
    .add_inspection(MaterializeFirstRowsInspection(5))
    .execute()
)

# The cells below look results up in `inspection_results` by inspection
# instance, and then by DAG node of `extracted_dag`.
extracted_dag = inspection_result.dag
inspection_results = inspection_result.analyzer_to_annotations

# Write a picture of the extracted DAG next to the pipeline script.
filename = os.path.join(
    str(get_project_root()), "demo", "healthcare", "healthcare.png"
)
save_fig_to_path(extracted_dag, filename)

/Users/stefangrafberger/Documents/uni/master-thesis/mlinspect/demo/healthcare/healthcare.py


In [None]:
from IPython.display import Image

# Load the DAG picture stored at `filename`; being the last expression of the
# cell, the resulting Image object is what the notebook renders.
Image(
    filename=filename,
)

In [None]:
# Results recorded by the MaterializeFirstRowsInspection(5) instance.
first_rows_inspection_result = inspection_results[MaterializeFirstRowsInspection(5)]

for node in extracted_dag.nodes:
    # Nodes without an entry are skipped ...
    if node not in first_rows_inspection_result:
        continue
    # ... and so are nodes whose entry is None.
    materialized_rows = first_rows_inspection_result[node]
    if materialized_rows is None:
        continue

    # Header line identifying the operator, a separator, then its result.
    print(node.operator_type, node.code_reference, node.module, node.description)
    print("________")
    print(materialized_rows)
    # Two blank lines between consecutive operators.
    print("")
    print("")

In [None]:
# Results recorded by the LineageDemoInspection(5) instance.
lineage_inspection_result = inspection_results[LineageDemoInspection(5)]

# Only membership is checked here: unlike the other result cells, an entry
# whose value is None is printed as well.
nodes_with_lineage = (
    node for node in extracted_dag.nodes if node in lineage_inspection_result
)
for node in nodes_with_lineage:
    # Header line identifying the operator, a separator, then its result.
    print(node.operator_type, node.code_reference, node.module, node.description)
    print("________")
    print(lineage_inspection_result[node])
    # Two blank lines between consecutive operators.
    print("")
    print("")

In [None]:
# Results recorded by the MissingEmbeddingInspection instance.
embedding_inspection_result = inspection_results[MissingEmbeddingInspection()]
# Sanity check on the number of entries; 33 is presumably the number of DAG
# nodes of the healthcare pipeline -- TODO confirm.
assert len(embedding_inspection_result) == 33

for node in extracted_dag.nodes:
    # Skip nodes that have no entry as well as nodes whose entry is None.
    if node not in embedding_inspection_result:
        continue
    missing_count = embedding_inspection_result[node]
    if missing_count is None:
        continue

    print(node.operator_type, node.code_reference, node.module, node.description)
    print("Number of missing embeddings: {}".format(missing_count))
    print("________")
    

In [None]:
# Results recorded by the HistogramInspection instance.
histogram_inspection_result = inspection_results[HistogramInspection()]


def _show_count_histogram(labels, heights, title):
    """Draw one bar chart of value counts on its own figure and display it.

    labels: x-axis categories, one per bar.
    heights: bar heights, in the same order as ``labels``.
    title: text shown above the chart.
    """
    fig, ax = plt.subplots()
    ax.bar(labels, heights)
    ax.set_title(title)
    ax.set_ylabel("count")
    plt.show()
    # Up to two figures are created per DAG node; release each one once shown.
    plt.close(fig)


for dag_node in extracted_dag.nodes:
    print(dag_node.operator_type, dag_node.code_reference, dag_node.module, dag_node.description)
    # FIT operators are skipped. The membership test mirrors the other result
    # cells and avoids a KeyError for nodes that have no histogram entry.
    if dag_node.operator_type != OperatorType.FIT and dag_node in histogram_inspection_result:
        node_histograms = histogram_inspection_result[dag_node]
        # An empty or None entry means there is nothing to plot for this node.
        if node_histograms:
            output_age_group = node_histograms["age_group_counts"]
            if output_age_group:
                print("age_group histogram")
                _show_count_histogram(
                    list(output_age_group.keys()),
                    list(output_age_group.values()),
                    "age_group histogram",
                )
            output_race_group = node_histograms["race_counts"]
            # NOTE: there are many distinct race values, so this plot can be slow.
            if output_race_group:
                print("race histogram")
                # Keys are turned into strings before plotting (presumably
                # because not every race value is a string -- TODO confirm).
                _show_count_histogram(
                    [str(key) for key in output_race_group.keys()],
                    list(output_race_group.values()),
                    "race histogram",
                )
    print("________")
    