# Compute ML performance metrics by group (sex) and run Google's Fairness Indicators using TFMA
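This Colab-ready notebook explains why model performance should be measured separately for each protected group: a single aggregate score can hide large differences in error rates between groups. It then walks through running Google's Fairness Indicators (built on TensorFlow Model Analysis, TFMA) on the UCI Adult income dataset, with `sex` as the sensitive attribute. Run the cells in order. You can also run the notebook locally (for example in VS Code), provided TFMA installs on your platform.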

In [None]:
# 1) Setup: Install and import libraries
%%bash
pip install -q pandas numpy scikit-learn matplotlib seaborn
pip install -q "tensorflow>=2.16.0" tensorflow-model-analysis fairness-indicators

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import json, os, zipfile
from pathlib import Path
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, average_precision_score

import tensorflow as tf
import tensorflow_model_analysis as tfma
from fairness_indicators.tfx import util as fi_util

# Reproducibility: seed NumPy's legacy global RNG and TensorFlow's global RNG.
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Artifacts go to /content/outputs when a /content directory exists (Colab),
# otherwise to ./outputs relative to the current working directory.
_on_colab = Path('/content').exists()
ARTIFACTS = Path('/content/outputs') if _on_colab else Path('outputs')
ARTIFACTS.mkdir(parents=True, exist_ok=True)

In [None]:
# 2) Load and prepare dataset (features, label, sensitive attribute: sex)
adult = fetch_openml('adult', version=2, as_frame=True)
# Binary target: 1 when the `class` column equals '>50K', else 0; the raw
# `class` column is then dropped so it cannot leak into the features.
df = (
    adult.frame
    .assign(income_binary=lambda frame: frame['class'].eq('>50K').astype(int))
    .drop(columns=['class'])
)

# Set the raw `sex` column aside for slicing/fairness analysis; the model never sees it.
sensitive = df['sex'].copy()
df = df.drop(columns=['sex'])
y = df['income_binary']
X = df.drop(columns=['income_binary'])

# Object/category columns are one-hot encoded; every remaining column is scaled.
cat_cols = list(X.select_dtypes(include=['object', 'category']).columns)
num_cols = [col for col in X.columns if col not in cat_cols]

categorical_encoder = OneHotEncoder(handle_unknown='ignore')
# NOTE(review): with_mean=False skips centring; the numeric block is dense here,
# so centring would also be possible — confirm this choice is intentional.
numeric_scaler = StandardScaler(with_mean=False)
preprocess = ColumnTransformer(transformers=[
    ('cat', categorical_encoder, cat_cols),
    ('num', numeric_scaler, num_cols),
])

# 3) Split data and train baseline classifier
# The sensitive attribute is split alongside X and y so test rows stay aligned
# with their group labels; stratification is on the target only.
(X_train, X_test,
 y_train, y_test,
 sens_train, sens_test) = train_test_split(
    X, y, sensitive, test_size=0.2, random_state=42, stratify=y)

clf = Pipeline([
    ('prep', preprocess),
    ('lr', LogisticRegression(max_iter=1000)),
])
clf.fit(X_train, y_train)

# Persist the fitted pipeline for later reference.
import joblib
joblib.dump(clf, ARTIFACTS / 'logreg_adult.joblib')
print('Model trained and saved.')

In [None]:
# 4) Evaluate aggregate performance metrics
# Scores are P(income_binary == 1); labels use a fixed 0.5 decision threshold.
DECISION_THRESHOLD = 0.5
probs = clf.predict_proba(X_test)[:, 1]
preds = (probs >= DECISION_THRESHOLD).astype(int)

metrics_agg = {
    # Threshold-dependent metrics use the hard predictions ...
    'accuracy': accuracy_score(y_test, preds),
    'precision': precision_score(y_test, preds),
    'recall': recall_score(y_test, preds),
    'f1': f1_score(y_test, preds),
    # ... ranking metrics use the raw scores.
    'roc_auc': roc_auc_score(y_test, probs),
    'pr_auc': average_precision_score(y_test, probs),
}
# Plain Python floats keep the JSON artifact independent of numpy scalar types.
metrics_agg = {name: float(value) for name, value in metrics_agg.items()}

# Use a context manager so the file is flushed and closed before later cells read it.
with open(ARTIFACTS / 'metrics_aggregate.json', 'w') as fh:
    json.dump(metrics_agg, fh, indent=2)
metrics_agg

In [None]:
# 5) Compute metrics by group (sex) and fairness gaps
def rates(y_true, y_pred):
    """Confusion-matrix rates for binary (0/1) labels and predictions.

    Args:
        y_true: array-like of ground-truth labels; 1 is the positive class.
        y_pred: array-like of predicted labels (0/1), same length as ``y_true``.

    Returns:
        Tuple ``(selection_rate, tpr, fpr, precision, recall, f1)`` of floats.
        ``selection_rate`` is the share of predictions equal to 1. A rate whose
        denominator is zero (e.g. TPR for a group with no positive labels, or
        every rate for empty input) is reported as 0.0 rather than NaN.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    pos = y_pred == 1
    neg = y_pred == 0
    tp = int(np.sum((y_true == 1) & pos))
    fp = int(np.sum((y_true == 0) & pos))
    fn = int(np.sum((y_true == 1) & neg))
    tn = int(np.sum((y_true == 0) & neg))

    def _safe_div(num, den):
        # Exact ratio when defined; 0.0 instead of ZeroDivisionError otherwise.
        # (An additive epsilon would bias every rate, e.g. a perfect TPR -> 0.9999999995.)
        return num / den if den else 0.0

    selection = float(np.mean(pos)) if len(y_true) else 0.0
    tpr = _safe_div(tp, tp + fn)
    fpr = _safe_div(fp, fp + tn)
    prec = _safe_div(tp, tp + fp)
    rec = tpr
    f1 = _safe_div(2 * prec * rec, prec + rec)
    return selection, tpr, fpr, prec, rec, f1

# Per-group metrics at the same 0.5 threshold as the aggregate metrics (`preds`).
group_records = []
for g in sorted(sens_test.unique()):
    mask = sens_test == g
    sel, tpr, fpr, prec, rec, f1 = rates(y_test[mask], preds[mask])
    group_records.append({
        'sex': g, 'selection_rate': sel, 'tpr': tpr, 'fpr': fpr,
        'precision': prec, 'recall': rec, 'f1': f1,
    })
by_group = pd.DataFrame(group_records)
by_group.to_csv(ARTIFACTS / 'metrics_by_group.csv', index=False)
display(by_group)

# Gaps are reported as (other group - reference group). The reference is 'Male'
# when present, otherwise the first group in sorted order.
REFERENCE_GROUP = 'Male'
is_ref = by_group['sex'] == REFERENCE_GROUP
ref = by_group[is_ref].iloc[0] if is_ref.any() else by_group.iloc[0]
others = by_group[by_group['sex'] != ref['sex']]
# NOTE(review): with more than two groups only the first non-reference group is
# compared; with a single group every gap is 0 (the group is compared to itself).
other = others.iloc[0] if len(others) else ref

delta_dp = other['selection_rate'] - ref['selection_rate']
# The epsilon only guards against a zero reference selection rate.
ratio_sr = other['selection_rate'] / (ref['selection_rate'] + 1e-9)
delta_tpr = other['tpr'] - ref['tpr']
delta_fpr = other['fpr'] - ref['fpr']
fairness = {
    'delta_demographic_parity': float(delta_dp),
    'selection_rate_ratio': float(ratio_sr),
    'delta_tpr': float(delta_tpr),
    'delta_fpr': float(delta_fpr),
}
# Context manager: the file is flushed/closed before step 10 reads it back.
with open(ARTIFACTS / 'fairness_gaps.json', 'w') as fh:
    json.dump(fairness, fh, indent=2)
fairness

In [None]:
# 6) Visualize group metrics and gaps
# Bar chart of the per-group selection rate, saved as a PNG artifact and shown inline.
fig, ax = plt.subplots(figsize=(6, 4))
sns.barplot(data=by_group, x='sex', y='selection_rate', ax=ax)
ax.set_title('Selection rate by sex')
fig.tight_layout()
fig.savefig(ARTIFACTS / 'plot_selection_rate_by_sex.png', dpi=150)
plt.show()

# 7) Threshold sweep and ROC/PR by group (quick sketch)
# For each group, recompute TPR/FPR (ROC) and precision/recall (PR) plus the
# selection rate at 101 evenly spaced score thresholds.
ths = np.linspace(0, 1, 101)
roc = {}
for g in sorted(sens_test.unique()):
    mask = (sens_test == g).to_numpy()
    ys = y_test[mask].to_numpy()
    group_probs = probs[mask]  # loop-invariant: slice once per group, not per threshold
    rows = []
    for t in ths:
        sel, tpr, fpr, prec, _, _ = rates(ys, (group_probs >= t).astype(int))
        rows.append((t, tpr, fpr, sel, prec))
    roc[g] = pd.DataFrame(rows, columns=['thr', 'tpr', 'fpr', 'sel', 'precision'])

# Save one tidy CSV snapshot (a `sex` column identifies the group). Writing .xlsx
# would need openpyxl, which the setup cell does not install.
sweep = pd.concat([dfg.assign(sex=str(g)) for g, dfg in roc.items()], ignore_index=True)
sweep.to_csv(ARTIFACTS / 'threshold_sweep.csv', index=False)

# ROC curve per group traced by the sweep above.
fig, ax = plt.subplots(figsize=(5, 5))
for g, dfg in roc.items():
    ax.plot(dfg['fpr'], dfg['tpr'], label=str(g))
ax.plot([0, 1], [0, 1], linestyle='--', color='grey', linewidth=1)  # chance line
ax.set(xlabel='False positive rate', ylabel='True positive rate',
       title='ROC by sex (threshold sweep)')
ax.legend(title='sex')
fig.tight_layout()
fig.savefig(ARTIFACTS / 'plot_roc_by_sex.png', dpi=150)
plt.show()
print('Saved threshold_sweep.csv')

In [None]:
# 8) Export TFMA artifacts and 9) Run Google Fairness Indicators
from typing import List
import pandas as pd
from tensorflow_metadata.proto.v0 import schema_pb2

# Model-agnostic TFMA evaluation: TFMA only needs the label, the model score and
# the slicing feature. Passing just these columns also avoids handing TFMA the
# pandas `category` dtypes and missing values present in the raw Adult features.
test_raw = pd.DataFrame({
    'sex': sens_test.astype(str).to_numpy(),
    'label': y_test.to_numpy(),
    'pred': probs,
})

# Define EvalConfig: one overall slice plus one slice per value of `sex`.
eval_config = tfma.EvalConfig(
    model_specs=[tfma.ModelSpec(label_key='label', prediction_key='pred')],
    slicing_specs=[tfma.SlicingSpec(), tfma.SlicingSpec(feature_keys=['sex'])],
    metrics_specs=[tfma.MetricsSpec(metrics=[
        tfma.MetricConfig(class_name='ExampleCount'),
        tfma.MetricConfig(class_name='AUC'),
        # BinaryAccuracy thresholds the score; Keras `Accuracy` would test raw
        # probabilities for exact equality with the 0/1 labels.
        tfma.MetricConfig(class_name='BinaryAccuracy'),
        tfma.MetricConfig(class_name='Precision'),
        tfma.MetricConfig(class_name='Recall'),
        tfma.MetricConfig(class_name='TruePositives'),
        tfma.MetricConfig(class_name='FalsePositives'),
        tfma.MetricConfig(class_name='TrueNegatives'),
        tfma.MetricConfig(class_name='FalseNegatives'),
        tfma.MetricConfig(class_name='Calibration'),
        # The Fairness Indicators metric family (per-slice rates at each listed
        # threshold); without it the widget below has nothing to display.
        tfma.MetricConfig(class_name='FairnessIndicators',
                          config=json.dumps({'thresholds': [0.25, 0.5, 0.75]})),
    ])],
)

# Run evaluation directly on the in-memory DataFrame.
eval_result = tfma.analyze_raw_data(
    data=test_raw,
    eval_config=eval_config,
    output_path=str(ARTIFACTS / 'tfma_eval'),
)
print('TFMA evaluation complete. Artifacts at', ARTIFACTS / 'tfma_eval')

# Display the Fairness Indicators widget. Keep this as the last expression so
# Jupyter renders the returned widget.
from tensorflow_model_analysis.addons.fairness.view import widget_view
widget_view.render_fairness_indicator(eval_result=eval_result)

In [None]:
# 10) Configure fairness thresholds and automated checks
thresholds = {
    'min_selection_rate_ratio': 0.8,  # four-fifths (80%) rule
    'max_delta_tpr': 0.1,             # tolerated |TPR gap| (equal opportunity)
    'max_delta_fpr': 0.1,             # tolerated |FPR gap|
}
with open(ARTIFACTS / 'fairness_gaps.json') as f:
    fg = json.load(f)

# The four-fifths rule compares the LOWER selection rate with the HIGHER one.
# The stored ratio is (other / reference), so fold it into [0, 1]. Otherwise a
# reference group selected far less often than the other group (ratio > 1.25)
# would pass unnoticed.
sr_ratio = fg['selection_rate_ratio']
impact_ratio = min(sr_ratio, 1.0 / sr_ratio) if sr_ratio > 0 else 0.0

# 1.0 = check failed, 0.0 = check passed.
violations = {
    'selection_rate_ratio': float(impact_ratio < thresholds['min_selection_rate_ratio']),
    'delta_tpr': float(abs(fg['delta_tpr']) > thresholds['max_delta_tpr']),
    'delta_fpr': float(abs(fg['delta_fpr']) > thresholds['max_delta_fpr']),
}
with open(ARTIFACTS / 'fairness_violations.json', 'w') as f:
    json.dump(violations, f, indent=2)
print('Fairness violations:', violations)
if any(v > 0 for v in violations.values()):
    print('One or more fairness checks failed (non-blocking in notebook).')

In [None]:
# 11) Simple mitigation: group-specific thresholds and re-evaluate
# Heuristic post-processing: for each group, pick the score threshold whose
# selection rate is closest to the unweighted mean of the baseline per-group
# selection rates. This is a demographic-parity-style target.
# NOTE(review): thresholds are tuned AND evaluated on the test split, so the
# mitigated metrics below are optimistic; tune on a held-out validation split
# before relying on them.
target_sr = by_group['selection_rate'].mean()  # loop-invariant: compute once
candidate_ths = np.linspace(0, 1, 101)

mitigated = {}
mit_preds = np.zeros_like(preds)
mit_records = []
for g in sorted(sens_test.unique()):
    mask = (sens_test == g).to_numpy()
    group_probs = probs[mask]
    # Selection rate at every candidate threshold at once: shape (len(candidate_ths),).
    srs = (group_probs[None, :] >= candidate_ths[:, None]).mean(axis=1)
    # argmin returns the FIRST minimiser, matching a strict '<' scan in threshold order.
    th_best = float(candidate_ths[np.argmin(np.abs(srs - target_sr))])
    mitigated[g] = th_best
    mit_preds[mask] = (group_probs >= th_best).astype(int)
    sel, tpr, fpr, prec, rec, f1 = rates(y_test[mask], mit_preds[mask])
    mit_records.append({
        'sex': g, 'selection_rate': sel, 'tpr': tpr, 'fpr': fpr,
        'precision': prec, 'recall': rec, 'f1': f1, 'thr': th_best,
    })

mit_df = pd.DataFrame(mit_records)
mit_df.to_csv(ARTIFACTS / 'metrics_by_group_mitigated.csv', index=False)
mit_df

In [None]:
# 12) Unit tests for metric calculations (optional)
# Run in the kernel against the notebook's own `rates` (step 5) instead of a
# pasted copy in a subprocess, so the tests cannot drift from the code they check.
# (The old `%%bash` cell also failed: a cell magic must be a cell's first line.)

# Happy path: exactly one TP, FP, FN and TN -> every rate is 0.5.
np.testing.assert_allclose(rates([1, 0, 1, 0], [1, 0, 0, 1]), [0.5] * 6, atol=1e-6)

# Edge: no positive labels -> TPR is undefined (reported as 0); FPR = 1 FP / 3 negatives.
sel, tpr, fpr, *_ = rates([0, 0, 0], [0, 1, 0])
assert 0.0 <= tpr <= 1.0 and 0.0 <= fpr <= 1.0
assert abs(tpr) < 1e-6 and abs(fpr - 1 / 3) < 1e-6

# Edge: empty input must not raise and yields all-zero rates.
assert all(abs(v) < 1e-6 for v in rates([], []))
print('Metric tests passed')

In [None]:
# 13) Save artifacts and lightweight report (+ download as zip in Colab)
# Combine the JSON artifacts written by earlier steps into one summary file.
SUMMARY_SOURCES = {
    'aggregate': 'metrics_aggregate.json',
    'fairness_gaps': 'fairness_gaps.json',
}
summary = {}
for key, fname in SUMMARY_SOURCES.items():
    with open(ARTIFACTS / fname) as fh:  # close each handle deterministically
        summary[key] = json.load(fh)

with open(ARTIFACTS / 'summary_google.json', 'w') as fh:
    json.dump(summary, fh, indent=2)
print('Wrote', ARTIFACTS / 'summary_google.json')

if str(ARTIFACTS).startswith('/content'):
    # The zip lives outside ARTIFACTS, so it never tries to include itself.
    zpath = '/content/outputs_zip.zip'
    with zipfile.ZipFile(zpath, 'w', zipfile.ZIP_DEFLATED) as z:
        for p in sorted(ARTIFACTS.rglob('*')):
            if p.is_file():
                # Keep paths relative to ARTIFACTS: files under tfma_eval/ stay in
                # their sub-folder, and same-named files cannot collide as
                # duplicate entries in the archive root.
                z.write(p, arcname=p.relative_to(ARTIFACTS).as_posix())
    print('Download artifacts:', zpath)