#### Model Bias Monitoring

In [None]:
import copy
import json
import random
import time
import pandas as pd

from datetime import datetime, timedelta

from sagemaker import get_execution_role, image_uris, Session
from sagemaker.clarify import (
    BiasConfig,
    DataConfig,
    ModelConfig,
    ModelPredictedLabelConfig,
    SHAPConfig,
)
from sagemaker.model import Model
from sagemaker.model_monitor import (
    BiasAnalysisConfig,
    CronExpressionGenerator,
    DataCaptureConfig,
    EndpointInput,
    ExplainabilityAnalysisConfig,
    ModelBiasMonitor,
    ModelExplainabilityMonitor,
)
from sagemaker.s3 import S3Downloader, S3Uploader

In [None]:
# Configurations
# Shared settings for the rest of this file: execution role, SageMaker session
# and clients, S3 locations, endpoint sizing, schedule, and resource names.
role = get_execution_role()
print(f"RoleArn: {role}")

sagemaker_session = Session()
sagemaker_client = sagemaker_session.sagemaker_client
sagemaker_runtime_client = sagemaker_session.sagemaker_runtime_client

region = sagemaker_session.boto_region_name
print(f"AWS region: {region}")

# A different bucket can be used, but make sure the role for this notebook has
# the s3:PutObject permissions. This is the bucket into which the data is captured.
# The bucket is resolved through the session created above (instead of a second,
# throwaway Session()) so every resource in this file comes from one session.
bucket = sagemaker_session.default_bucket()
print(f"Demo Bucket: {bucket}")
prefix = "sagemaker/DEMO-ClarifyModelMonitor-20200901"
s3_key = f"s3://{bucket}/{prefix}"
print(f"S3 key: {s3_key}")

# All S3 locations used below hang off the common s3_key prefix.
s3_capture_upload_path = f"{s3_key}/datacapture"
# The ground-truth prefix is timestamped (local time) at the moment this cell runs.
ground_truth_upload_path = f"{s3_key}/ground_truth_data/{datetime.now():%Y-%m-%d-%H-%M-%S}"
s3_report_path = f"{s3_key}/reports"

print(f"Capture path: {s3_capture_upload_path}")
print(f"Ground truth path: {ground_truth_upload_path}")
print(f"Report path: {s3_report_path}")

baseline_results_uri = f"{s3_key}/baselining"
print(f"Baseline results uri: {baseline_results_uri}")

endpoint_instance_count = 1
endpoint_instance_type = "ml.m5.large"
schedule_expression = CronExpressionGenerator.hourly()

# Model and endpoint names carry a UTC timestamp with minute resolution.
model_name = f"DEMO-xgb-churn-pred-model-monitor-{datetime.utcnow():%Y-%m-%d-%H%M}"
print("Model name: ", model_name)
endpoint_name = f"DEMO-xgb-churn-model-monitor-{datetime.utcnow():%Y-%m-%d-%H%M}"
print("Endpoint name: ", endpoint_name)

In [None]:
# Model Files and Data
# Relative paths to the model artifact and the datasets used by later cells.
model_file = "model/xgb-churn-prediction-model.tar.gz"
test_file = "test_data/upload-test-file.txt"
test_dataset = "test_data/test-dataset-input-cols.csv"
validation_dataset = "test_data/validation-dataset-with-header.csv"
dataset_type = "text/csv"

# Read only the first line of the validation dataset and split it on commas
# to obtain the column names; the first column is used as the label.
with open(validation_dataset) as f:
    headers_line = f.readline().rstrip()
all_headers = headers_line.split(",")
label_header = all_headers[0]

In [None]:
# Monitoring: model bias
# Monitor object used below for both the baselining job and the schedule.
model_bias_monitor = ModelBiasMonitor(
    role=role,
    sagemaker_session=sagemaker_session,
    max_runtime_in_seconds=1800,
)


# Baselining output for this monitor goes under a "model_bias" sub-prefix.
model_bias_baselining_job_result_uri = f"{baseline_results_uri}/model_bias"
# NOTE(review): validation_dataset is a local relative path, not an s3:// URI —
# confirm that DataConfig / the baselining job accepts a local input path here.
model_bias_data_config = DataConfig(
    s3_data_input_path=validation_dataset,
    s3_output_path=model_bias_baselining_job_result_uri,
    label=label_header,
    headers=all_headers,
    dataset_type=dataset_type,
)

In [None]:

# Bias settings: the label value(s), the facet column, and the facet value/threshold.
# NOTE(review): how [1] and [100] are interpreted (category vs. threshold) is
# decided by BiasConfig — confirm against the SDK documentation.
model_bias_config = BiasConfig(
    label_values_or_threshold=[1],
    facet_name="Account Length",
    facet_values_or_threshold=[100],
)

In [None]:
# Predicted-label settings for the bias baselining job; uses a 0.8 probability
# threshold (the same value is passed to EndpointInput when scheduling).
model_predicted_label_config = ModelPredictedLabelConfig(
    probability_threshold=0.8,
)

In [None]:
# Model settings shared by the bias and explainability baselining jobs.
# Instance count/type reuse the endpoint sizing from the configuration cell;
# both request and response content types are dataset_type ("text/csv").
model_config = ModelConfig(
    model_name=model_name,
    instance_count=endpoint_instance_count,
    instance_type=endpoint_instance_type,
    content_type=dataset_type,
    accept_type=dataset_type,
)

Kick off the baselining job

In [None]:
# Start the bias baselining job from the configs built above and print its name.
# NOTE(review): unlike the explainability flow later in this file, nothing here
# waits for the baselining job to finish — confirm that is intended.
model_bias_monitor.suggest_baseline(
    model_config=model_config,
    data_config=model_bias_data_config,
    bias_config=model_bias_config,
    model_predicted_label_config=model_predicted_label_config,
)
print(f"ModelBiasMonitor baselining job: {model_bias_monitor.latest_baselining_job_name}")


In [None]:
# Build an explicit analysis config only when this monitor has no baselining
# job; otherwise None is passed as analysis_config.
# NOTE(review): presumably the monitor then derives its analysis settings from
# the latest baselining job — confirm against the SDK.
model_bias_analysis_config = None
if not model_bias_monitor.latest_baselining_job:
    model_bias_analysis_config = BiasAnalysisConfig(
        model_bias_config,
        headers=all_headers,
        label=label_header,
    )
# Create the hourly schedule against the endpoint, writing reports to
# s3_report_path and reading ground truth from ground_truth_upload_path.
model_bias_monitor.create_monitoring_schedule(
    analysis_config=model_bias_analysis_config,
    output_s3_uri=s3_report_path,
    endpoint_input=EndpointInput(
        endpoint_name=endpoint_name,
        destination="/opt/ml/processing/input/endpoint",
        # NOTE(review): looks like a one-hour window ending at the run time — confirm.
        start_time_offset="-PT1H",
        end_time_offset="-PT0H",
        # Same 0.8 value as model_predicted_label_config's probability_threshold.
        probability_threshold_attribute=0.8,
    ),
    ground_truth_input=ground_truth_upload_path,
    schedule_cron_expression=schedule_expression,
)
print(f"Model bias monitoring schedule: {model_bias_monitor.monitoring_schedule_name}")

How to test / invoke monitoring:

In [None]:
def wait_for_execution_to_start(model_monitor):
    """Block until the monitor's schedule has an execution that is past "Pending".

    Polls ``model_monitor.describe_schedule()`` in two phases, printing a dot
    per poll:

    1. every 60 seconds until the description contains a
       ``LastMonitoringExecutionSummary`` entry;
    2. every 10 seconds until that entry's ``MonitoringExecutionStatus`` is no
       longer exactly ``"Pending"``.

    Args:
        model_monitor: Object exposing ``describe_schedule()`` which returns a
            dict-like schedule description.

    Returns:
        None. Progress is reported on stdout only.
    """
    print(
        "A hourly schedule was created above and it will kick off executions ON the hour (plus 0 - 20 min buffer)."
    )

    print("Waiting for the first execution to happen", end="")
    schedule_desc = model_monitor.describe_schedule()
    while "LastMonitoringExecutionSummary" not in schedule_desc:
        print(".", end="", flush=True)
        # Sleep *before* re-polling so the loop exits as soon as a fresh
        # description shows the execution, rather than sleeping once more.
        time.sleep(60)
        schedule_desc = model_monitor.describe_schedule()
    print()
    print("Done! Execution has been created")

    print("Now waiting for execution to start", end="")
    # Exact comparison: a substring test (status in "Pending") would also
    # match "" or any fragment of the word and keep waiting on it.
    while schedule_desc["LastMonitoringExecutionSummary"]["MonitoringExecutionStatus"] == "Pending":
        print(".", end="", flush=True)
        time.sleep(10)
        schedule_desc = model_monitor.describe_schedule()

    print()
    print("Done! Execution has started")

In [None]:
# Wait for the schedule to launch an execution before stopping it, mirroring
# the explainability flow later in this file (start -> stop -> finish).
# Without this, the schedule may be stopped before any execution exists.
wait_for_execution_to_start(model_bias_monitor)
# Stop the schedule now that an execution has started.
model_bias_monitor.stop_monitoring_schedule()

In [None]:
def wait_for_execution_to_finish(model_monitor):
    """Wait until the schedule's last execution is in a terminal status.

    Reads ``LastMonitoringExecutionSummary`` from
    ``model_monitor.describe_schedule()``. If there is none, a message is
    printed and the function returns immediately. Otherwise the schedule is
    re-described every 60 seconds (printing a dot each time) until the
    execution status is one of "Completed", "CompletedWithViolations",
    "Failed" or "Stopped".

    Args:
        model_monitor: Object exposing ``describe_schedule()`` which returns a
            dict-like schedule description.
    """
    terminal_statuses = ("Completed", "CompletedWithViolations", "Failed", "Stopped")

    latest = model_monitor.describe_schedule().get("LastMonitoringExecutionSummary")
    if latest is None:
        print("Last execution not found")
        return

    print("Waiting for execution to finish", end="")
    while latest["MonitoringExecutionStatus"] not in terminal_statuses:
        print(".", end="", flush=True)
        time.sleep(60)
        latest = model_monitor.describe_schedule()["LastMonitoringExecutionSummary"]
    print()
    print("Done! Execution has finished")
# Block until the bias monitor's last execution (if any) reaches a terminal status.
wait_for_execution_to_finish(model_bias_monitor)

List generated results

In [None]:
# Look up the most recent execution of the bias schedule and, when it ended
# successfully (with or without violations), list the report files it produced.
schedule_desc = model_bias_monitor.describe_schedule()
execution_summary = schedule_desc.get("LastMonitoringExecutionSummary")
if not execution_summary or execution_summary["MonitoringExecutionStatus"] not in (
    "Completed",
    "CompletedWithViolations",
):
    # Nothing usable yet: leave a None marker for the cells that follow.
    last_model_bias_monitor_execution = None
    print(
        "====STOP==== \n No completed executions to inspect further. Please wait till an execution completes or investigate previously reported failures."
    )
else:
    # The newest execution is the last entry returned by list_executions().
    last_model_bias_monitor_execution = model_bias_monitor.list_executions()[-1]
    last_model_bias_monitor_execution_report_uri = last_model_bias_monitor_execution.output.destination
    print(f"Report URI: {last_model_bias_monitor_execution_report_uri}")
    last_model_bias_monitor_execution_report_files = sorted(
        S3Downloader.list(last_model_bias_monitor_execution_report_uri)
    )
    print("Found Report Files:")
    print("\n ".join(last_model_bias_monitor_execution_report_files))

In [None]:
# If there are violations compared to the baseline, they will be listed here.
# Only runs when the previous cell found a completed execution (non-None marker).

if last_model_bias_monitor_execution:
    model_bias_violations = last_model_bias_monitor_execution.constraint_violations()
    # Nothing is printed when constraint_violations() returns a falsy value.
    if model_bias_violations:
        print(model_bias_violations.body_dict)

In [None]:
# Model Explainability Monitor
# Built with the same role, session and max runtime as the bias monitor.

model_explainability_monitor = ModelExplainabilityMonitor(
    role=role,
    sagemaker_session=sagemaker_session,
    max_runtime_in_seconds=1800,
)

In [None]:
# Baselining output for this monitor goes under a "model_explainability" sub-prefix.
model_explainability_baselining_job_result_uri = f"{baseline_results_uri}/model_explainability"
# Same dataset, label and headers as the bias data config; only the output path differs.
model_explainability_data_config = DataConfig(
    s3_data_input_path=validation_dataset,
    s3_output_path=model_explainability_baselining_job_result_uri,
    label=label_header,
    headers=all_headers,
    dataset_type=dataset_type,
)

In [None]:
# Use the column-wise mean of the test dataset as the SHAP baseline.
# The test file is read without a header row, and the means are wrapped in an
# outer list so the baseline is a single row.
test_dataframe = pd.read_csv(test_dataset, header=None)
shap_baseline = [list(test_dataframe.mean())]

shap_config = SHAPConfig(
    baseline=shap_baseline,
    num_samples=100,
    agg_method="mean_abs",
    save_local_shap_values=False,
)


In [None]:
# Kick off the explainability baselining job using the data, model and SHAP
# configs built above, then print the job name.

model_explainability_monitor.suggest_baseline(
    data_config=model_explainability_data_config,
    model_config=model_config,
    explainability_config=shap_config,
)
print(
    f"ModelExplainabilityMonitor baselining job: {model_explainability_monitor.latest_baselining_job_name}"
)

In [None]:
# Wait for the baselining job (without streaming its logs), then fetch the
# suggested constraints and print both their S3 URI and their contents.
model_explainability_monitor.latest_baselining_job.wait(logs=False)
model_explainability_constraints = model_explainability_monitor.suggested_constraints()
print()
print(
    f"ModelExplainabilityMonitor suggested constraints: {model_explainability_constraints.file_s3_uri}"
)
print(S3Downloader.read_file(model_explainability_constraints.file_s3_uri))

In [None]:
# Schedule model explainability monitor

# Build an explicit analysis config only when this monitor has no baselining
# job; otherwise None is passed as analysis_config, as in the bias schedule.
model_explainability_analysis_config = None
if not model_explainability_monitor.latest_baselining_job:
    # Remove label because only features are required for the analysis
    headers_without_label_header = copy.deepcopy(all_headers)
    headers_without_label_header.remove(label_header)
    model_explainability_analysis_config = ExplainabilityAnalysisConfig(
        explainability_config=shap_config,
        model_config=model_config,
        headers=headers_without_label_header,
    )
model_explainability_monitor.create_monitoring_schedule(
    # Pass the config built above; previously it was computed but never used.
    analysis_config=model_explainability_analysis_config,
    output_s3_uri=s3_report_path,
    endpoint_input=endpoint_name,
    schedule_cron_expression=schedule_expression,
)

In [None]:
# Block until the explainability schedule has an execution that is past "Pending".
wait_for_execution_to_start(model_explainability_monitor)

In [None]:
# Stop the explainability schedule now that an execution has started.
model_explainability_monitor.stop_monitoring_schedule()

In [None]:
# Block until the explainability monitor's last execution (if any) reaches a terminal status.
wait_for_execution_to_finish(model_explainability_monitor)

In [None]:
# Look up the most recent execution of the explainability schedule and, when it
# ended successfully (with or without violations), list its report files.
schedule_desc = model_explainability_monitor.describe_schedule()
execution_summary = schedule_desc.get("LastMonitoringExecutionSummary")
if not execution_summary or execution_summary["MonitoringExecutionStatus"] not in (
    "Completed",
    "CompletedWithViolations",
):
    # Nothing usable yet: leave a None marker for the cells that follow.
    last_model_explainability_monitor_execution = None
    print(
        "====STOP==== \n No completed executions to inspect further. Please wait till an execution completes or investigate previously reported failures."
    )
else:
    # The newest execution is the last entry returned by list_executions().
    last_model_explainability_monitor_execution = model_explainability_monitor.list_executions()[-1]
    last_model_explainability_monitor_execution_report_uri = last_model_explainability_monitor_execution.output.destination
    print(f"Report URI: {last_model_explainability_monitor_execution_report_uri}")
    last_model_explainability_monitor_execution_report_files = sorted(
        S3Downloader.list(last_model_explainability_monitor_execution_report_uri)
    )
    print("Found Report Files:")
    print("\n ".join(last_model_explainability_monitor_execution_report_files))

In [None]:
# If there are any violations compared to the baseline, they will be listed here.
# Only runs when the previous cell found a completed execution (non-None marker).

if last_model_explainability_monitor_execution:
    model_explainability_violations = (
        last_model_explainability_monitor_execution.constraint_violations()
    )
    # Nothing is printed when constraint_violations() returns a falsy value.
    if model_explainability_violations:
        print(model_explainability_violations.body_dict)

In [None]:
# CLEAN UP
# STOP WORKER THREADS
# NOTE(review): invoke_endpoint_thread and ground_truth_thread are not defined
# anywhere in the visible part of this file — confirm they are created elsewhere
# and that they expose a terminate() method.
invoke_endpoint_thread.terminate()
ground_truth_thread.terminate()

In [None]:
from sagemaker.predictor import Predictor

# Attach a Predictor to the endpoint by name and tear down every monitor it lists.
predictor = Predictor(endpoint_name, sagemaker_session=sagemaker_session)
model_monitors = predictor.list_monitors()
for model_monitor in model_monitors:
    # Order: stop the schedule, wait for any last execution to reach a
    # terminal status, then delete the schedule.
    model_monitor.stop_monitoring_schedule()
    wait_for_execution_to_finish(model_monitor)
    model_monitor.delete_monitoring_schedule()

In [None]:
# Final cleanup: delete the endpoint first, then the model, via the predictor.
predictor.delete_endpoint()
predictor.delete_model()