In [None]:
import datetime
import json
import os
import uuid
import warnings
from pathlib import Path

import dspy
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import shap
from scipy.stats import ks_2samp
from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.model_selection import train_test_split

# Optional: XGBoost can be used instead; this example uses sklearn's GB for simplicity.

# ----- User settings -----
OUT_DIR = Path("./explainable_ai_demo")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# LLM credentials: read from the environment so the key never ends up in the
# notebook file or its saved outputs. Any DSPy-supported provider can be used
# (see the DSPy documentation); this demo uses Gemini. The placeholder is only a
# fallback so the cell runs without a key — provider calls will not authenticate
# until GOOGLE_API_KEY is set in the environment.
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY", "<gemini_api_key>")


In [12]:
# -----------------------------------------------------------
# 1) Data generation and model training
# -----------------------------------------------------------
SEED = 42  # one seed for the data, the sensitive attribute, the split and the model
rng = np.random.RandomState(SEED)
X, y = make_classification(
    n_samples=3000,
    n_features=10,
    n_informative=6,
    n_redundant=2,
    n_clusters_per_class=2,
    flip_y=0.02,
    class_sep=1.0,
    random_state=SEED,
)
feature_names = [f"feat_{i}" for i in range(X.shape[1])]
df = pd.DataFrame(X, columns=feature_names)
df["approved"] = y

# Synthetic sensitive attribute correlated with the label to demonstrate bias checks:
# P(gender = 1) is 0.45 for rows with approved == 0 and 0.65 for approved == 1.
gender = (rng.rand(len(df)) < (0.45 + 0.2 * df["approved"])).astype(int)
df["gender"] = gender  # 1: Group A, 0: Group B

# train/test split, stratified on the label. Note that `gender` is also a model
# input, so the model can pick up the correlation introduced above.
X_features = feature_names + ["gender"]
X_train, X_test, y_train, y_test = train_test_split(
    df[X_features],
    df["approved"],
    test_size=0.2,
    random_state=SEED,
    stratify=df["approved"],
)

model = GradientBoostingClassifier(random_state=SEED)
model.fit(X_train, y_train)

# baseline metrics on the held-out split
y_proba = model.predict_proba(X_test)[:, 1]
y_pred = model.predict(X_test)
acc = accuracy_score(y_test, y_pred)
auc = roc_auc_score(y_test, y_proba)
print(f"Trained model — accuracy: {acc:.3f}, AUC: {auc:.3f}")

# Model card for the regulator view: provenance plus held-out metrics and the
# exact feature list, so the audit record is self-contained.
MODEL_METADATA = {
    "version": "0.1",
    "trained_on": str(datetime.date.today()),
    "n_train": len(X_train),
    "n_test": len(X_test),
    "features": list(X_features),
    "test_accuracy": float(acc),
    "test_auc": float(auc),
}

Trained model — accuracy: 0.945, AUC: 0.980


In [13]:
# -----------------------------------------------------------
# 2) Tools (to be exposed to DSPy / the LLM agent)
#     We also create a very light tool-call logger to capture the transcript.
# -----------------------------------------------------------
TOOL_LOG = []  # one dict per tool call: {"time", "tool", "args", "result_summary"}


def log_tool_call(tool_name, args, result, max_chars=2000):
    """Append one entry describing a tool call to the module-level TOOL_LOG.

    Args:
        tool_name: name of the tool that was invoked.
        args: arguments passed to the tool. Stored unchanged when ``str(args)``
            is shorter than ``max_chars``; otherwise stored as ``str(args)``
            truncated to ``max_chars`` characters.
        result: summary of the tool's result. A dict shorter than ``max_chars``
            (as text) is stored unchanged; any other short value is stored as
            ``str(result)``; anything longer is stored as ``str(result)``
            truncated to ``max_chars`` characters.
        max_chars: size threshold / truncation length (default 2000).
    """
    args_text = str(args)
    result_text = str(result)

    stored_args = args_text[:max_chars] if len(args_text) >= max_chars else args

    if len(result_text) >= max_chars:
        result_summary = result_text[:max_chars]
    elif isinstance(result, dict):
        result_summary = result
    else:
        result_summary = result_text

    TOOL_LOG.append(
        {
            # Timezone-aware UTC timestamp; datetime.utcnow() is deprecated
            # (Python 3.12) and returns a naive value.
            "time": str(datetime.datetime.now(datetime.timezone.utc)),
            "tool": tool_name,
            "args": stored_args,
            "result_summary": result_summary,
        }
    )


# Tool: predict_customer
def predict_customer(customer_row):
    """Score a single customer with the trained model.

    Accepts a dict or pandas Series keyed by feature name (columns are
    reordered to ``X_features``), or a positional sequence already ordered
    like ``X_features``.

    Returns:
        dict with ``prediction`` (int class label) and ``probability``
        (float probability of class 1).
    """
    if isinstance(customer_row, dict):
        frame = pd.DataFrame([customer_row])[X_features]
    elif isinstance(customer_row, pd.Series):
        frame = pd.DataFrame([customer_row.values], columns=customer_row.index)[
            X_features
        ]
    else:
        frame = pd.DataFrame([customer_row], columns=X_features)

    label = int(model.predict(frame)[0])
    p_positive = float(model.predict_proba(frame)[0][1])
    outcome = {"prediction": label, "probability": p_positive}

    logged_row = frame.to_dict(orient="records")[0]
    log_tool_call(
        "predict_customer",
        {"row": logged_row},
        {"prediction": label, "probability": round(p_positive, 4)},
    )
    return outcome


# Tool: explain_prediction (use SHAP TreeExplainer for per-instance explanations)
def explain_prediction(customer_row, top_k=5, save_plot=True, plot_path=None):
    """Explain one prediction with local SHAP values and optionally save a bar plot.

    Args:
        customer_row: dict / pandas Series keyed by feature name, or a positional
            sequence ordered like ``X_features``.
        top_k: number of features (largest absolute SHAP value first) returned in
            ``feature_contributions`` and drawn in the plot.
        save_plot: when True, write a horizontal bar chart of the top features.
        plot_path: destination of the plot; defaults to a uniquely named PNG
            inside ``OUT_DIR``.

    Returns:
        dict with
          - ``feature_contributions``: list of ``(feature, shap_value)`` pairs
            sorted by descending absolute value, at most ``top_k`` long;
          - ``raw``: ``{feature: shap_value}`` for every feature;
          - ``plot_path`` (str): present only when ``save_plot`` is True.
    """
    if isinstance(customer_row, dict):
        row = pd.DataFrame([customer_row])[X_features]
    elif isinstance(customer_row, pd.Series):
        row = pd.DataFrame([customer_row.values], columns=customer_row.index)[X_features]
    else:
        row = pd.DataFrame([customer_row], columns=X_features)

    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(row)
    # Depending on the shap version / model type the output may be a single
    # (n_rows, n_features) array or per-class output (a list of arrays, or an
    # array with a trailing class axis). Keep the last (positive) class.
    if isinstance(shap_values, list):
        shap_values = shap_values[-1]
    shap_values = np.asarray(shap_values)
    if shap_values.ndim == 3:
        shap_values = shap_values[..., -1]

    contributions = dict(zip(X_features, shap_values[0].tolist()))
    top = sorted(contributions.items(), key=lambda kv: -abs(kv[1]))[:top_k]
    result = {"feature_contributions": top, "raw": contributions}
    log_tool_call(
        "explain_prediction",
        {"row": row.to_dict(orient="records")[0]},
        {"top_features": top},
    )

    if save_plot:
        if plot_path is None:
            # uuid rather than np.random: does not consume the global NumPy RNG
            # state and makes filename collisions practically impossible.
            plot_path = OUT_DIR / f"shap_{uuid.uuid4().hex}.png"
        else:
            plot_path = Path(plot_path)
        feats = [name for name, _ in top]
        vals = [value for _, value in top]
        fig, ax = plt.subplots(figsize=(6, 3))
        try:
            # Reversed so the largest contribution is drawn at the top.
            y_pos = np.arange(len(feats))
            ax.barh(y_pos, vals[::-1])
            ax.set_yticks(y_pos)
            ax.set_yticklabels(feats[::-1])
            ax.set_xlabel("SHAP contribution")
            ax.set_title("Top feature contributions (local SHAP)")
            fig.tight_layout()
            fig.savefig(plot_path, dpi=150)
        finally:
            # Release the figure even if saving fails, so repeated tool calls
            # do not accumulate open figures.
            plt.close(fig)
        result["plot_path"] = str(plot_path)
    return result


# Tool: check_model_drift
def check_model_drift(new_data, alpha=0.05):
    """Two-sample KS test per model feature: training set vs ``new_data``.

    Args:
        new_data: DataFrame containing every column of ``X_train``.
        alpha: significance threshold; features with p-value < ``alpha`` are
            flagged. No multiple-testing correction is applied, so with many
            features some false positives are expected at the default 0.05.
            Very small windows (a handful of rows) give low-power tests.

    Returns:
        dict with ``drifted_features`` (flagged column names), ``per_feature``
        ({column: {"statistic", "p_value"}}) and the ``alpha`` used. A column
        whose new values are all NaN gets NaN statistic/p-value and is not flagged.

    Raises:
        KeyError: if ``new_data`` lacks one of the training columns.
        ValueError: if ``new_data`` has no rows.
    """
    missing = [col for col in X_train.columns if col not in new_data.columns]
    if missing:
        raise KeyError(f"new_data is missing training columns: {missing}")
    if len(new_data) == 0:
        raise ValueError("check_model_drift needs at least one row of new_data")

    drift = {}
    for col in X_train.columns:
        # NaNs would corrupt the empirical CDF; compare observed values only.
        observed = new_data[col].dropna()
        if observed.empty:
            drift[col] = {"statistic": float("nan"), "p_value": float("nan")}
            continue
        stat, pval = ks_2samp(X_train[col], observed)
        drift[col] = {"statistic": float(stat), "p_value": float(pval)}
    # NaN < alpha is False, so columns without data are never flagged.
    drifted = [c for c, v in drift.items() if v["p_value"] < alpha]
    result = {"drifted_features": drifted, "per_feature": drift, "alpha": alpha}
    log_tool_call("check_model_drift", {"rows": len(new_data)}, {"drifted": drifted})
    return result


# Tool: bias_audit (group rates and disparate impact)
def bias_audit(dataset, sensitive_feature="gender"):
    """Predicted acceptance rate per sensitive group, and disparate impact.

    Args:
        dataset: DataFrame with every column in ``X_features`` plus
            ``sensitive_feature``; it is copied, not modified.
        sensitive_feature: column whose groups are compared; the disparate
            impact assumes the groups are coded 0 and 1.

    Returns:
        dict with
          - ``group_rates``: {group: mean predicted label};
          - ``disparate_impact``: rate(group 0) / rate(group 1), or None when
            either group is absent or group 1's rate is 0;
          - ``counts``: {group: number of rows}.
        An empty ``dataset`` yields empty rates/counts and None disparate impact.
    """
    ds = dataset.copy()
    if ds.empty:
        rates, counts = {}, {}
    else:
        ds["pred"] = model.predict(ds[X_features])
        rates = ds.groupby(sensitive_feature)["pred"].mean().to_dict()
        counts = ds[sensitive_feature].value_counts().to_dict()

    # No default in .get(): an absent group must not masquerade as a 0% rate,
    # which would report a spurious disparate impact of 0.0.
    g0 = rates.get(0)
    g1 = rates.get(1)
    di = g0 / g1 if g0 is not None and g1 is not None and g1 > 0 else None

    result = {
        "group_rates": rates,
        "disparate_impact": di,
        "counts": counts,
    }
    log_tool_call(
        "bias_audit", {"rows": len(ds)}, {"disparate_impact": di, "rates": rates}
    )
    return result

In [14]:
# -----------------------------------------------------------
# 3) DSPy + Gemini agent
# -----------------------------------------------------------
# (dspy is imported in the first cell.)

# Configure Gemini LLM as DSPy backend
model_name = "gemini/gemini-2.5-flash"
if not GOOGLE_API_KEY or GOOGLE_API_KEY.startswith("<"):
    # The cell still completes, but tell the user up front why the agent calls
    # in the demo cell will not authenticate.
    warnings.warn(
        "GOOGLE_API_KEY is unset or still the placeholder; Gemini calls will fail."
    )
lm = dspy.LM(model_name, api_key=GOOGLE_API_KEY)
dspy.configure(lm=lm)
print("DSPy + Gemini configured. Running real LLM agent (Gemini).")

# Helper: formatting templates for roles
def format_client(
    prediction_info,
    explanation,
    suggestion=(
        "You could improve your likelihood by improving income or reducing debt. "
        "Contact the bank to discuss options."
    ),
):
    """Plain-language, single-line decision summary for the applicant.

    Args:
        prediction_info: dict with ``prediction`` (1 = approved) and
            ``probability`` (model probability of approval, i.e. class 1).
        explanation: dict whose ``feature_contributions`` is a sequence of
            ``(feature, value)`` pairs.
        suggestion: closing advice sentence. The default mentions income/debt,
            which are not features of this synthetic demo model; override it
            when the features are known.

    Returns:
        The formatted message string.
    """
    approved = prediction_info["prediction"] == 1
    pred_text = "approved" if approved else "rejected"
    prob = prediction_info["probability"]
    # `probability` is P(approved); confidence in a rejection is its complement.
    confidence = prob if approved else 1.0 - prob
    reasons = ", ".join(
        f"{feat} ({val:.3f})" for feat, val in explanation["feature_contributions"]
    ) or "none reported"
    return (
        f"Decision: Your application was {pred_text} (confidence {confidence:.2f}). "
        f"Main influences: {reasons}. {suggestion}"
    )


def format_regulator(prediction_info, explanation, metadata):
    """Multi-line audit summary: model metadata, prediction, SHAP drivers, lineage.

    Args:
        prediction_info: dict with ``prediction`` and ``probability``.
        explanation: dict with ``feature_contributions`` as ``(feature, value)``
            pairs and, optionally, ``plot_path``.
        metadata: dict with ``version``, ``trained_on`` and ``n_train``.

    Returns:
        The summary lines joined with newlines.
    """
    header = [
        f"Model version: {metadata['version']}",
        f"Model trained on: {metadata['trained_on']} (rows: {metadata['n_train']})",
        f"Prediction: {prediction_info['prediction']}, probability: {prediction_info['probability']:.4f}",
        "Top local feature contributions (SHAP):",
    ]
    drivers = [
        f" - {name}: {weight:.6f}"
        for name, weight in explanation["feature_contributions"]
    ]
    footer = [f"Data lineage: features used = {X_features}"]
    if "plot_path" in explanation:
        footer.append(f"SHAP plot: {explanation['plot_path']}")
    return "\n".join(header + drivers + footer)


def format_executive(drift_info, bias_info, model_auc=None):
    """One-line KPI summary for executives.

    Args:
        drift_info: dict with ``drifted_features``.
        bias_info: dict with ``group_rates`` and ``disparate_impact`` (which may
            be None when it could not be computed).
        model_auc: AUC to report; when None, the notebook-level ``auc`` is used.

    Returns:
        KPI strings joined with " | ".
    """
    if model_auc is None:
        model_auc = auc
    di = bias_info["disparate_impact"]
    di_text = "n/a" if di is None else f"{di:.3f}"
    kpis = [
        f"Model AUC: {model_auc:.3f}",
        f"Drifted features: {drift_info['drifted_features']}",
        f"Acceptance rates: {bias_info['group_rates']}",
        f"Disparate impact (group0/group1): {di_text}",
    ]
    return " | ".join(kpis)


# DSPy program definition.
# NOTE: DSPy sends the class docstring to the LLM as the task instructions, so it
# is written as a prompt rather than as developer documentation.
class StakeholderQuery(dspy.Signature):
    """Explain a loan-approval model's decision to one stakeholder audience.

    Use the available tools instead of guessing: call predict_customer and
    explain_prediction with the customer JSON to obtain the decision, its
    probability of approval and the top SHAP feature contributions. For the
    Executive role, check_model_drift and bias_audit provide drift and fairness
    indicators.

    Tailor the output to the role:
    - Client: plain language; the decision and its main drivers; no jargon.
    - Regulator: prediction, probability, each top SHAP contribution with its
      value, and any plot path returned by the tools, for auditability.
    - Executive: a short KPI-style summary (decision, drift, fairness).
    Only report numbers that the tools actually returned.
    """

    role: str = dspy.InputField(desc="Stakeholder role: Client, Regulator, Executive")
    customer_data: str = dspy.InputField(desc="Customer features as JSON string")
    output: str = dspy.OutputField(desc="Formatted explanation")


# Wrap tools as DSPy Tools, with small wrapper functions to ensure serializable args
def _parse_json_arg(value):
    """Decode a tool argument that should be a JSON string.

    Depending on how the tool-calling layer parses the LLM's arguments, the
    value may already be a decoded dict/list; such values are returned as-is.
    """
    if isinstance(value, str):
        return json.loads(value)
    return value


def _records_frame(value):
    """Build a DataFrame from a JSON list of row dicts, or from one row dict."""
    data = _parse_json_arg(value)
    if isinstance(data, dict) and not any(
        isinstance(v, (list, dict)) for v in data.values()
    ):
        # A single record of scalars: pd.DataFrame(dict_of_scalars) raises
        # "If using all scalar values, you must pass an index".
        data = [data]
    return pd.DataFrame(data)


def t_predict_customer(customer_json: str):
    """DSPy tool: JSON customer record -> predict_customer result."""
    return predict_customer(_parse_json_arg(customer_json))


def t_explain_prediction(customer_json: str):
    """DSPy tool: JSON customer record -> explain_prediction result.

    The plot is always written to OUT_DIR / "reg_shap.png", overwritten per call.
    """
    return explain_prediction(
        _parse_json_arg(customer_json),
        save_plot=True,
        plot_path=OUT_DIR / "reg_shap.png",
    )


def t_check_model_drift(json_window: str):
    """DSPy tool: JSON window (list of records, or one record) -> drift report."""
    return check_model_drift(_records_frame(json_window))


def t_bias_audit(json_window: str):
    """DSPy tool: JSON window (list of records, or one record) -> bias audit."""
    return bias_audit(_records_frame(json_window))


# Tool descriptions are what the LLM sees when choosing a tool, so each one states
# the argument it expects and what comes back.
predict_tool = dspy.Tool(
    t_predict_customer,
    name="predict_customer",
    desc=(
        "Predicts loan approval for one customer. Argument customer_json: the "
        "customer's features as a JSON object string. Returns the predicted class "
        "(1 = approved) and the probability of approval."
    ),
)
explain_tool = dspy.Tool(
    t_explain_prediction,
    name="explain_prediction",
    desc=(
        "Explains the prediction with SHAP. Argument customer_json: the customer's "
        "features as a JSON object string. Returns the top feature contributions, "
        "all per-feature SHAP values and the path of a saved bar plot."
    ),
)
drift_tool = dspy.Tool(
    t_check_model_drift,
    name="check_model_drift",
    desc=(
        "Checks model drift via KS tests of each feature against the training data. "
        "Argument json_window: a JSON list of customer records (results from very "
        "few rows are unreliable). Returns flagged features and per-feature "
        "statistics and p-values."
    ),
)
bias_tool = dspy.Tool(
    t_bias_audit,
    name="bias_audit",
    desc=(
        "Checks group acceptance rates and disparate impact (rate of gender group 0 "
        "divided by rate of group 1) of the model's predictions. Argument "
        "json_window: a JSON list of customer records including the 'gender' column."
    ),
)

# Register tools (DSPy will expose them to Gemini)
PROGRAM_TOOLS = [predict_tool, explain_tool, drift_tool, bias_tool]


# Example DSPy Module where the LLM decides which tools to call
class ExplainableAIAgent(dspy.Module):
    """DSPy module running a ReAct loop over StakeholderQuery with PROGRAM_TOOLS."""

    def __init__(self):
        super().__init__()
        # ReAct lets the LLM decide which registered tools to call, and when.
        self.predictor = dspy.ReAct(StakeholderQuery, tools=PROGRAM_TOOLS)

    @staticmethod
    def _encode_customer(customer_data):
        """Serialize ``customer_data`` into the JSON string the signature expects."""
        if customer_data is None:
            return ""
        if isinstance(customer_data, str):
            # Already JSON text; json.dumps would wrap it in quotes a second time.
            return customer_data
        # NumPy scalars are not JSON-serializable; unwrap them via .item().
        return json.dumps(
            customer_data,
            default=lambda o: o.item() if hasattr(o, "item") else str(o),
        )

    def forward(self, role, customer_data=None):
        """Run the agent for one stakeholder role and return its text output.

        Args:
            role: "Client", "Regulator" or "Executive".
            customer_data: customer features as a dict (or an already-encoded
                JSON string); None sends an empty string.
        """
        cust_json = self._encode_customer(customer_data)
        # DSPy lets the LLM call the tools while producing the answer.
        response = self.predictor(role=role, customer_data=cust_json)
        return response.output


# instantiate agent
dsp_agent = ExplainableAIAgent()

DSPy + Gemini configured. Running real LLM agent (Gemini).


In [16]:
# -----------------------------------------------------------
# 4) Demo runs (DSPy+Gemini)
# -----------------------------------------------------------
sample_idx = X_test.index[0]
sample_row = X_test.loc[sample_idx].to_dict()
print("\n--- Running demo runs ---\n")

print("=== Invoking DSPy + Gemini agent (real LLM) ===")

# role -> output file written for that role
DEMO_RUNS = {
    "Client": "client_output_gemini.txt",
    "Regulator": "regulator_output_gemini.txt",
    "Executive": "exec_output_gemini.txt",
}
demo_results = {}
try:
    for role, filename in DEMO_RUNS.items():
        demo_results[role] = dsp_agent(role=role, customer_data=sample_row)
        print(f"{role} result (Gemini):\n", demo_results[role])
        # LLM output can contain non-ASCII text; don't depend on the platform encoding.
        (OUT_DIR / filename).write_text(str(demo_results[role]), encoding="utf-8")
finally:
    # Save the tool-call transcript even if one of the agent calls fails.
    # If DSPy provides a tool-call trace API, fetch it here and save; otherwise rely on TOOL_LOG.
    # default=str keeps non-JSON values (e.g. NumPy scalars) from aborting the dump.
    with open(OUT_DIR / "tool_log.json", "w", encoding="utf-8") as f:
        json.dump(TOOL_LOG, f, indent=2, default=str)

client_result = demo_results.get("Client")
regulator_result = demo_results.get("Regulator")
executive_result = demo_results.get("Executive")

# End of notebook


--- Running demo runs ---

=== Invoking DSPy + Gemini agent (real LLM) ===
Client result (Gemini):
 Your loan application has been approved!

The primary factors that strongly contributed to this approval are:
*   **feat_3**: This feature had the most significant positive impact on your approval.
*   **feat_2**: This feature also played a substantial role in the positive decision.
*   **feat_7**: This feature further supported the approval of your loan.

While `feat_5` and `feat_9` had a minor negative influence, the overall strength of your application, driven by the positive factors, led to the approval.
Regulator result (Gemini):
 For the provided customer, the model's prediction is primarily influenced by `feat_3` and `feat_2`, which have the largest positive contributions, followed by `feat_7`. Conversely, `feat_5` and `feat_9` have negative contributions, meaning they decrease the likelihood of the predicted outcome. This explanation provides transparency into the key factors dr