In [None]:
# secure-healthcare-ml/notebooks/train_and_explain.ipynb

# Import necessary libraries
import pandas as pd
import numpy as np
import shap
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from secure_healthcare_ml.explainability.shap_explainer import SHAPExplainer

# Load synthetic healthcare data
data_path = '../data/synthetic_fhir_data.csv'
df = pd.read_csv(data_path)

# Show a sample of the data. Only the last expression of a notebook cell is
# rendered automatically, so print explicitly or the preview is discarded.
print(df.head())

# Preprocess the data: separate the label, then keep only numeric/boolean
# feature columns. Both DataFrame.mean() and RandomForestClassifier need
# numeric input, so text columns would otherwise make this cell fail.
# Assumes the data has a 'target' column; every other usable column is a feature.
y = df['target']
X = df.drop(columns=['target']).select_dtypes(include=['number', 'bool'])

dropped_columns = [c for c in df.columns if c != 'target' and c not in X.columns]
if dropped_columns:
    print(f"Dropping non-numeric feature columns: {dropped_columns}")
if X.shape[1] == 0:
    raise ValueError("No numeric feature columns found in the data")

# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# Handle missing values with column means. The means come from the training
# split only and are reused for the test split, so no test-set statistics
# leak into training or evaluation. X is imputed the same way so any later
# use of it sees no missing values.
train_means = X_train.mean()
X_train = X_train.fillna(train_means)
X_test = X_test.fillna(train_means)
X = X.fillna(train_means)

# Train a RandomForest model
model = RandomForestClassifier(n_estimators=100, random_state=42)
model.fit(X_train, y_train)

# Print model performance on the test set
print(f"Model accuracy: {model.score(X_test, y_test):.4f}")

# Initialize SHAP explainer with the trained model
explainer = SHAPExplainer(model, feature_names=X.columns)

# Generate SHAP values for the test set
shap_values = explainer.explain(X_test)

# Visualize the SHAP summary plot to understand feature importance
shap.summary_plot(shap_values, X_test, feature_names=X.columns)

# Explain a single instance's prediction using a SHAP force plot.
# Here we explain the first instance in the test set.
shap.initjs()  # Initialize SHAP JavaScript visualization
instance_idx = 0  # Choose the index of the instance to explain
class_idx = 1  # Output explained when values are per-class (i.e. model.classes_[1])

instance_explanation = shap_values[instance_idx]
instance_values = instance_explanation.values
instance_base = instance_explanation.base_values

# NOTE(review): for tree classifiers SHAP may return one column of values per
# class (features x classes) and one base value per class. force_plot needs a
# single output, so pick one class in that case. Confirm against what
# SHAPExplainer.explain actually returns.
if np.ndim(instance_values) == 2:
    instance_values = instance_values[:, class_idx]
    instance_base = np.atleast_1d(instance_base)[class_idx]

# force_plot's signature is (base_value, shap_values, features): the base
# value must come first, followed by the per-feature SHAP values.
shap.force_plot(instance_base, instance_values, X_test.iloc[instance_idx])

# You can also visualize the local explanation for other instances by changing the index
