In [18]:
import pandas as pd
import numpy as np
import itertools
import os
import logging
import time
import gc
import math
import json 

from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.metrics import accuracy_score
from sklearn.inspection import permutation_importance 

In [19]:
datasetPath = 'flask_ml/data/Income_dataset.csv'  
outputExcelFile = 'budget_gridsearch_fairness_results.xlsx'
targetColumn = 'Income' 
protectedAttributesOriginal = ['age', 'race', 'sex']

# Age binning as the dashboard does as well
ageColumn = 'age'
binnedAgeColumnName = 'age_bin'
defaultAgeBins = [0, 24, 34, 44, 54, 64, np.inf]
defaultAgeLabels = ['<25', '25-34', '35-44', '45-54', '55-64', '65+']

In [20]:
testSize = 0.20
randomState = 42
lrParams = {'C': 1.0, 'solver': 'lbfgs', 'max_iter': 1000, 'random_state': randomState, 'class_weight': 'balanced'} 
maxBudgetFeatures = 9
minBudgetFeatures = 1
piNRepeats = 5 
piNJobs = 1    
minSamplesForPI = 15 

In [21]:
fairnessMetricsSubgroup = ['Demographic Parity', 'Equalized Odds', 'FPR', 'Predictive Parity', 'Group Size']
metricsForOverallDisparity = ['Demographic Parity', 'Equalized Odds', 'Predictive Parity'] 

In [22]:
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logProgressEveryN = 100

In [23]:
def calculate_fairness_metrics(dfGroup):
    """Calculates fairness metrics for a given subgroup DataFrame."""
    if dfGroup.empty:
        return {'Demographic Parity': np.nan, 'Equalized Odds': np.nan, 'FPR': np.nan, 'Predictive Parity': np.nan, 'Group Size': 0}

    yTrue = dfGroup['y_true'].astype(int)
    yPred = dfGroup['y_pred'].astype(int)
    
    tp = ((yTrue == 1) & (yPred == 1)).sum()
    fp = ((yTrue == 0) & (yPred == 1)).sum()
    tn = ((yTrue == 0) & (yPred == 0)).sum()
    fn = ((yTrue == 1) & (yPred == 0)).sum()

    predictedPositive = tp + fp
    actualPositive = tp + fn
    actualNegative = tn + fp
    totalGroup = tp + fp + tn + fp 

    dpRate = predictedPositive / totalGroup if totalGroup > 0 else np.nan
    tprRate = tp / actualPositive if actualPositive > 0 else np.nan 
    fprRate = fp / actualNegative if actualNegative > 0 else np.nan
    ppRate = tp / predictedPositive if predictedPositive > 0 else np.nan

    return {'Demographic Parity': dpRate,
            'Equalized Odds': tprRate, 
            'FPR': fprRate,           
            'Predictive Parity': ppRate,
            'Group Size': totalGroup}

In [24]:
def apply_age_binning(df, ageCol, binnedColName, bins, labels):
    """Applies binning to the specified age column."""
    dfBinned = df.copy()
    if ageCol in dfBinned.columns and pd.api.types.is_numeric_dtype(dfBinned[ageCol]):
        logging.info(f"Attempting to bin numeric column '{ageCol}' into '{binnedColName}'.")
        try:
             dfBinned[binnedColName] = pd.cut(dfBinned[ageCol], bins=bins, labels=labels, right=False, include_lowest=True)
             dfBinned[binnedColName] = dfBinned[binnedColName].astype(str).fillna('NaN_Age_Bin')
             logging.info(f"Successfully binned '{ageCol}' into '{binnedColName}'. Unique values: {dfBinned[binnedColName].unique()}")
        except Exception as binErr:
             logging.error(f"Error binning age column '{ageCol}': {binErr}", exc_info=True)
             raise ValueError(f"Failed to bin age column '{ageCol}'. Cannot proceed.") from binErr
    elif ageCol in dfBinned.columns:
         logging.warning(f"Column '{ageCol}' exists but is not numeric. Skipping binning.")
         dfBinned[binnedColName] = dfBinned[ageCol].astype(str).fillna('NaN_Age_Str')
         logging.info(f"Converted non-numeric '{ageCol}' to string as '{binnedColName}'.")
    else:
        logging.error(f"Age column '{ageCol}' not found in DataFrame. Cannot perform binning.")
        raise ValueError(f"Required age column '{ageCol}' not found.")

    return dfBinned

In [26]:
def create_preprocessing_pipeline(numericFeatures, categoricalFeatures):
    """Creates a ColumnTransformer pipeline for the given features."""
    transformers = []
    if numericFeatures:
        numPipe = Pipeline([('imp', SimpleImputer(strategy='mean')), ('scale', StandardScaler())])
        transformers.append(('num', numPipe, numericFeatures)) 
    if categoricalFeatures:
        catPipe = Pipeline([('imp', SimpleImputer(strategy='most_frequent')), ('ohe', OneHotEncoder(handle_unknown='ignore', sparse_output=False))])
        transformers.append(('cat', catPipe, categoricalFeatures)) 
    if not transformers:
        logging.warning("create_preprocessing_pipeline called with no features.")
        return ColumnTransformer(transformers=[], remainder='drop') 
    preprocessor = ColumnTransformer(transformers=transformers, remainder='drop', verbose_feature_names_out=False)
    preprocessor.set_output(transform="pandas") 
    return preprocessor

# Data Loading and Preparation
logging.info("--- Starting Data Loading and Preparation ---")
try:
    dfFull = pd.read_csv(datasetPath)
    logging.info(f"Successfully loaded dataset: {datasetPath}. Shape: {dfFull.shape}")
except FileNotFoundError:
    logging.error(f"Dataset file not found at: {datasetPath}")
    raise
except Exception as e:
    logging.error(f"Error loading dataset: {e}")
    raise

dfFull.columns = dfFull.columns.str.strip()
dfFull.columns = dfFull.columns.str.replace('-', '_', regex=False).str.replace(' ', '_', regex=False)

for col in dfFull.select_dtypes(include=['object']).columns:
    if col in dfFull.columns: 
        try:
            dfFull[col] = dfFull[col].str.strip()
            dfFull[col] = dfFull[col].replace(['?', 'N/A', '', 'None'], np.nan)
        except AttributeError:
            logging.warning(f"Could not apply string operations to column '{col}'.")

essentialCols = [targetColumn] + protectedAttributesOriginal 
missingEssentials = [col for col in essentialCols if col not in dfFull.columns]
if missingEssentials:
    raise ValueError(f"Essential columns missing from dataset: {missingEssentials}")

uniqueTargets = dfFull[targetColumn].unique()
if len(uniqueTargets) > 2:
    if '<=50K' in uniqueTargets and '>50K' in uniqueTargets:
        targetMap = {'<=50K': 0, '>50K': 1}
        dfFull[targetColumn] = dfFull[targetColumn].map(targetMap)
    elif len(uniqueTargets) == 2:
        logging.warning(f"Target values not '<=50K', '>50K'. Mapping '{uniqueTargets[0]}'->0, '{uniqueTargets[1]}'->1.")
        targetMap = {uniqueTargets[0]: 0, uniqueTargets[1]: 1}
        dfFull[targetColumn] = dfFull[targetColumn].map(targetMap)
    else: raise ValueError(f"Target '{targetColumn}' has > 2 unique values: {uniqueTargets}")
elif len(uniqueTargets) == 1: raise ValueError(f"Target '{targetColumn}' has only one value: {uniqueTargets[0]}.")
elif not set(uniqueTargets).issubset({0, 1}):
     logging.warning(f"Target values binary but not 0/1 ({uniqueTargets}). Mapping '{uniqueTargets[0]}'->0, '{uniqueTargets[1]}'->1.")
     targetMap = {uniqueTargets[0]: 0, uniqueTargets[1]: 1}
     dfFull[targetColumn] = dfFull[targetColumn].map(targetMap)

if dfFull[targetColumn].isnull().any():
    nanCount = dfFull[targetColumn].isnull().sum()
    logging.warning(f"{nanCount} rows have NaN target. Dropping.")
    rowsBefore = len(dfFull)
    dfFull.dropna(subset=[targetColumn], inplace=True)
dfFull[targetColumn] = dfFull[targetColumn].astype(int)

dfProcessed = apply_age_binning(dfFull, ageColumn, binnedAgeColumnName, defaultAgeBins, defaultAgeLabels)

protectedAttributesAnalysis = [binnedAgeColumnName if p == ageColumn else p for p in protectedAttributesOriginal]
logging.info(f"Protected attributes for analysis: {protectedAttributesAnalysis}")

potentialFeatureCols = [col for col in dfProcessed.columns if col != targetColumn and col not in protectedAttributesOriginal and col != binnedAgeColumnName]
logging.info(f"Identified {len(potentialFeatureCols)} potential features for model.")

essentialAnalysisCols = [targetColumn] + protectedAttributesAnalysis 
rowsBeforeNa = len(dfProcessed)
dfProcessed.dropna(subset=protectedAttributesAnalysis, inplace=True) 
rowsAfterNa = len(dfProcessed)
if rowsAfterNa < rowsBeforeNa:
    logging.warning(f"Dropped {rowsBeforeNa - rowsAfterNa} rows due to missing protected attributes.")

if dfProcessed.empty: raise ValueError("DataFrame empty after NA handling.")
logging.info(f"Final dataset shape before split: {dfProcessed.shape}")

x = dfProcessed[potentialFeatureCols]
y = dfProcessed[targetColumn]
p = dfProcessed[protectedAttributesAnalysis] 

stratifyOption = y if y.nunique() > 1 and y.value_counts().min() >= 2 else None
xTrainOrig, xTestOrig, yTrain, yTest, pTrain, pTest = train_test_split(
    x, y, p, test_size=testSize, random_state=randomState, stratify=stratifyOption)

logging.info(f"Data split complete: Train={len(yTrain)}, Test={len(yTest)}")
# Print test set distribution for age_bin
logging.info("--- Test Set Distribution for age_bin ---")
if binnedAgeColumnName in pTest.columns:
    print(pTest[binnedAgeColumnName].value_counts())
else:
    logging.warning(f"Column '{binnedAgeColumnName}' not found in pTest.")
logging.info("-----------------------------------------")

del dfFull, dfProcessed, x, y, p 
gc.collect()

# Step 1: Preprocess Data and Train Global Model 
logging.info("--- Preprocessing Data and Training Global Model ---")

numericFeatures = xTrainOrig.select_dtypes(include=np.number).columns.tolist()
categoricalFeatures = xTrainOrig.select_dtypes(exclude=np.number).columns.tolist()

globalPreprocessor = create_preprocessing_pipeline(numericFeatures, categoricalFeatures)

logging.info("Fitting global preprocessor...")
xTrainProcessed = globalPreprocessor.fit_transform(xTrainOrig)
processedFeatureNames = globalPreprocessor.get_feature_names_out()
logging.info(f"Preprocessing complete. Processed training features shape: {xTrainProcessed.shape}")

logging.info("Training global Logistic Regression model...")
globalModel = LogisticRegression(**lrParams)
globalModel.fit(xTrainProcessed, yTrain)
logging.info("Global model training complete.")

# Step 2: Calculate Permutation Importance per Subgroup ---
logging.info("--- Calculating Permutation Importance per Subgroup ---")
startTimePI = time.time()

groupImportances = {} 

for paColName in protectedAttributesAnalysis:
    logging.info(f"Calculating PI for subgroups in: {paColName}")
    groupImportances[paColName] = {}
    if not pTrain.index.equals(xTrainProcessed.index):
        logging.warning(f"Re-indexing pTrain for {paColName} PI calculation.")
        pTrain = pTrain.reindex(xTrainProcessed.index)
        
    uniqueGroupsTrain = pTrain[paColName].unique()
    logging.info(f"Found groups in training data for {paColName}: {uniqueGroupsTrain}")

    for groupName in uniqueGroupsTrain:
        groupNameStr = str(groupName) if pd.notna(groupName) else "NaN_Group"
        
        # Use original groupName for masking
        if pd.isna(groupName):
            mask = pTrain[paColName].isnull()
        else:
            mask = (pTrain[paColName] == groupName)

        xGroup = xTrainProcessed[mask]
        yGroup = yTrain[mask]
        
        if len(xGroup) < minSamplesForPI:
            logging.warning(f"Skipping PI for {paColName}/{groupNameStr}: Too few samples ({len(xGroup)} < {minSamplesForPI}).")
            groupImportances[paColName][groupNameStr] = [] 
            continue
            
        if yGroup.nunique() < 2:
            logging.warning(f"Skipping PI for {paColName}/{groupNameStr}: Only one target class present.")
            groupImportances[paColName][groupNameStr] = []
            continue
            
        try:
            piResult = permutation_importance(
                globalModel, xGroup, yGroup, 
                n_repeats=piNRepeats, random_state=randomState, 
                n_jobs=piNJobs, scoring='accuracy' 
            )
            
            importancesMean = piResult.importances_mean
            sortedIndices = np.argsort(importancesMean)[::-1]
            
            sortedImportances = [
                (processedFeatureNames[i], importancesMean[i]) 
                for i in sortedIndices if importancesMean[i] > 1e-6 
            ]
            
            groupImportances[paColName][groupNameStr] = sortedImportances
            if sortedImportances:
                 logging.debug(f"PI completed for {paColName}/{groupNameStr}. Top feature: {sortedImportances[0]}")
            else:
                 logging.warning(f"No significant features found via PI for {paColName}/{groupNameStr}.")

        except Exception as e:
            logging.error(f"Error calculating PI for {paColName}/{groupNameStr}: {e}", exc_info=True)
            groupImportances[paColName][groupNameStr] = [] 

piDuration = time.time() - startTimePI
logging.info(f"--- Permutation Importance Calculation Finished. Duration: {piDuration:.2f} sec ---")

del xGroup, yGroup 
gc.collect()

# tep 3: Iterate Through Budgets, Train Group Models, Predict, Evaluate
logging.info("--- Starting Budgeting Simulation Loop ---")
startTimeBudgetLoop = time.time()

allBudgetResults = []
processedFeatureNamesList = list(processedFeatureNames) 

logging.info("Preprocessing test data...")
xTestProcessed = globalPreprocessor.transform(xTestOrig)
xTestProcessed = pd.DataFrame(xTestProcessed, index=xTestOrig.index, columns=processedFeatureNames) 
logging.info(f"Test data preprocessed. Shape: {xTestProcessed.shape}")

for nFeaturesBudget in range(maxBudgetFeatures, minBudgetFeatures - 1, -1):
    iterationStartTime = time.time()
    logging.info(f"--- Processing Budget Level N = {nFeaturesBudget} ---")
    
    # Train Group-Specific Models for this Budget N
    groupModels = {} 
    groupFeaturesUsed = {} 
    
    logging.debug(f"Training group models for N={nFeaturesBudget}...")
    for paColName, groups in groupImportances.items():
        groupModels[paColName] = {}
        groupFeaturesUsed[paColName] = {}
        for groupNameStr, importances in groups.items():
            if not importances: 
                logging.debug(f"Skipping model training for {paColName}/{groupNameStr} (N={nFeaturesBudget}): No importance scores.")
                continue

            topNFeatures = [feat for feat, score in importances[:nFeaturesBudget]]
            groupFeaturesUsed[paColName][groupNameStr] = topNFeatures
            
            if not topNFeatures:
                logging.warning(f"Skipping model training for {paColName}/{groupNameStr} (N={nFeaturesBudget}): No features selected.")
                continue

            missingFeatures = [f for f in topNFeatures if f not in processedFeatureNamesList]
            if missingFeatures:
                 logging.error(f"FATAL ERROR training {paColName}/{groupNameStr} (N={nFeaturesBudget}): Budgeted features {missingFeatures} not found. Skipping group model.")
                 topNFeatures = [f for f in topNFeatures if f in processedFeatureNamesList]
                 groupFeaturesUsed[paColName][groupNameStr] = topNFeatures 
                 if not topNFeatures: continue 
            
            # Re-create mask to get group training data using original groupName value logic from PI loop
            mask = pd.Series(False, index=pTrain.index) # Default to False
            matched = False
            originalGroupNameValue = None # Variable to store the original value that matches groupNameStr

            if groupNameStr == "NaN_Group":
                 mask = pTrain[paColName].isnull()
                 matched = True
                 originalGroupNameValue = np.nan # Represent NaN group
            else:
                 # Find the original value in pTrain corresponding to groupNameStr
                 for uniqueVal in pTrain[paColName].unique():
                     currentGroupNameStr = str(uniqueVal) if pd.notna(uniqueVal) else "NaN_Group"
                     if currentGroupNameStr == groupNameStr:
                          mask = (pTrain[paColName] == uniqueVal)
                          originalGroupNameValue = uniqueVal # Store the matching value
                          matched = True
                          break
            if not matched: 
                # This case should be rare if groupNameStr came directly from unique values
                logging.error(f"Could not find original group value matching '{groupNameStr}' for {paColName}. Skipping model training.")
                continue
                 
            xGroupTrain = xTrainProcessed.loc[mask, topNFeatures] 
            yGroupTrain = yTrain[mask]
            
            if len(xGroupTrain) < 5: 
                logging.warning(f"Skipping model training for {paColName}/{groupNameStr} (N={nFeaturesBudget}): Too few samples ({len(xGroupTrain)}).")
                continue
            if yGroupTrain.nunique() < 2:
                logging.warning(f"Skipping model training for {paColName}/{groupNameStr} (N={nFeaturesBudget}): Only one target class.")
                continue
                
            try:
                groupClf = LogisticRegression(**lrParams)
                groupClf.fit(xGroupTrain, yGroupTrain)
                groupModels[paColName][groupNameStr] = groupClf
            except Exception as e:
                logging.error(f"Error training group model for {paColName}/{groupNameStr} (N={nFeaturesBudget}): {e}", exc_info=True)

    # Predict on Test Set using Group Models (Row-by-Row)
    logging.debug(f"Predicting using group models for N={nFeaturesBudget}...")
    yPredCombined = []
    modelUsedCounts = {'global_fallback': 0} 

    if not pTest.index.equals(xTestProcessed.index):
         logging.warning("Re-indexing pTest before prediction loop.")
         pTest = pTest.reindex(xTestProcessed.index)

    for idx in xTestProcessed.index:
        rowData = xTestProcessed.loc[[idx]] 
        
        try:
            pInfo = pTest.loc[idx]
        except KeyError:
             logging.error(f"Index {idx} not found in pTest. Predicting with global model.")
             pred = globalModel.predict(rowData[processedFeatureNamesList])[0] 
             yPredCombined.append(pred)
             modelUsedCounts['global_fallback'] += 1
             continue

        modelToUse = globalModel
        featuresToUse = processedFeatureNamesList 
        modelTypeUsed = "global_fallback"

        for paColName in protectedAttributesAnalysis: 
            groupVal = pInfo.get(paColName, None)
            groupValStr = str(groupVal) if pd.notna(groupVal) else "NaN_Group"
            
            if paColName in groupModels and groupValStr in groupModels[paColName]:
                modelToUse = groupModels[paColName][groupValStr]
                featuresToUse = groupFeaturesUsed[paColName][groupValStr]
                modelTypeUsed = f"{paColName}_{groupValStr}"
                
                if not all(f in rowData.columns for f in featuresToUse):
                     logging.error(f"Prediction Error: Features { [f for f in featuresToUse if f not in rowData.columns] } missing for group {modelTypeUsed} at index {idx}. Falling back to global model.")
                     modelToUse = globalModel 
                     featuresToUse = processedFeatureNamesList
                     modelTypeUsed = "global_fallback"
                break 

        try:
            dataForPred = rowData[featuresToUse] 
            if dataForPred.empty and not featuresToUse: 
                 pred = 0 
                 logging.warning(f"Predicting default (0) for index {idx} due to 0 features for model {modelTypeUsed}")
            elif dataForPred.shape[1] != len(featuresToUse):
                 logging.error(f"Prediction Error: Shape mismatch for index {idx}, model {modelTypeUsed}. Expected {len(featuresToUse)} features, got {dataForPred.shape[1]}. Falling back to global.")
                 dataForPred = rowData[processedFeatureNamesList] 
                 pred = globalModel.predict(dataForPred)[0]
                 modelTypeUsed = "global_fallback" 
            else:
                 pred = modelToUse.predict(dataForPred)[0]
                 
            yPredCombined.append(pred)
            modelUsedCounts[modelTypeUsed] = modelUsedCounts.get(modelTypeUsed, 0) + 1
        except Exception as e:
            logging.error(f"Prediction failed for index {idx} using model {modelTypeUsed}: {e}. Appending 0.", exc_info=False) 
            yPredCombined.append(0) 
            modelUsedCounts["error"] = modelUsedCounts.get("error", 0) + 1

    logging.debug(f"Prediction counts for N={nFeaturesBudget}: {modelUsedCounts}")
    
    # Evaluate Accuracy and Fairness for this Budget Level
    if len(yPredCombined) != len(yTest):
         logging.error(f"FATAL ERROR: Length mismatch for N={nFeaturesBudget}. yPredCombined ({len(yPredCombined)}) != yTest ({len(yTest)}). Skipping.")
         continue 

    accuracy = accuracy_score(yTest, yPredCombined)
    logging.info(f"Budget N={nFeaturesBudget} -> Accuracy: {accuracy:.4f}")
    
    # Define Placeholder Keys 
    placeholder_keys = {'Budget (N)', 'Accuracy'}
    # Disparity Keys
    for pa in protectedAttributesAnalysis:
        for metric in metricsForOverallDisparity:
             placeholder_keys.add(f"{metric.replace(' ', '_')}_Disparity_{pa}")
    # Subgroup Metric Keys - based on TEST SET unique values
    for pa in protectedAttributesAnalysis:
        unique_groups_in_test = pTest[pa].unique() # Use test set unique values!
        for group in unique_groups_in_test:
             groupStr = str(group) if pd.notna(group) else "NaN_Group"
             # Apply consistent sanitization
             groupStrKey = groupStr.replace('<','lt').replace('+','plus').replace(' ','_').replace('=','eq').replace('.','_') 
             for fm in fairnessMetricsSubgroup:
                  fmKeyPart = fm.replace(' ', '_')
                  placeholder_keys.add(f"{fmKeyPart}_{pa}_{groupStrKey}")
                  
    resultsRow = {key: np.nan for key in placeholder_keys} # Initialize all expected keys to NaN
    resultsRow['Budget (N)'] = nFeaturesBudget
    resultsRow['Accuracy'] = round(accuracy, 4)
    placeholder_keys_generated = set(resultsRow.keys()) # Track keys we expect
  

    resultsDfTemp = pTest.copy()
    resultsDfTemp['y_true'] = yTest
    resultsDfTemp['y_pred'] = yPredCombined 

    calculated_keys_found = set() # Track keys we actually calculate

    for paColName in protectedAttributesAnalysis:
        metricsForDisparityCalc = {m: [] for m in ['Demographic Parity', 'Equalized Odds', 'FPR', 'Predictive Parity']} 
        uniqueGroupsTest = resultsDfTemp[paColName].unique() # Groups actually present in this calculation run

        for groupName in uniqueGroupsTest: # Use original group value for filtering
            groupNameStr = str(groupName) if pd.notna(groupName) else "NaN_Group" 
            
            # Use original groupName for filtering test data
            if pd.isna(groupName):
                 groupDf = resultsDfTemp[resultsDfTemp[paColName].isnull()]
            else:
                 groupDf = resultsDfTemp[resultsDfTemp[paColName] == groupName] 

            # Debugging
            if paColName == binnedAgeColumnName and groupNameStr in ['25-34', '35-44', '45-54', '55-64']:
                 logging.info(f"DEBUG: N={nFeaturesBudget}, Group={groupNameStr}: Size={len(groupDf)}")
                 if not groupDf.empty:
                      logging.info(f"  y_true counts: {groupDf['y_true'].value_counts().to_dict()}")
                      logging.info(f"  y_pred counts: {groupDf['y_pred'].value_counts().to_dict()}")
                      actualPositive = (groupDf['y_true'] == 1).sum()
                      actualNegative = (groupDf['y_true'] == 0).sum()
                      predictedPositive = (groupDf['y_pred'] == 1).sum()
                      logging.info(f"  ActualPos={actualPositive}, ActualNeg={actualNegative}, PredictedPos={predictedPositive}")
                 else:
                      logging.info("  groupDf is empty.")
            # End Debugging
            
            metrics = calculate_fairness_metrics(groupDf)
            
            # Apply consistent sanitization for key construction
            groupStrKey = groupNameStr.replace('<','lt').replace('+','plus').replace(' ','_').replace('=','eq').replace('.','_') 

            for fmKey, fmValue in metrics.items():
                fmKeyPart = fmKey.replace(' ', '_')
                colNameKey = f"{fmKeyPart}_{paColName}_{groupStrKey}" # Construct standard key
                calculated_keys_found.add(colNameKey) # Track this key

                if colNameKey in resultsRow: # Check against pre-defined placeholders
                     resultsRow[colNameKey] = round(fmValue, 4) if pd.notna(fmValue) else np.nan
                else:
                     # This indicates a mismatch (e.g., a group appeared in calc but not in pTest unique groups used for placeholders)
                     logging.error(f"Key Mismatch/Unexpected Group Error! N={nFeaturesBudget}: Calculated key '{colNameKey}' not found in placeholder keys. Adding dynamically.")
                     resultsRow[colNameKey] = round(fmValue, 4) if pd.notna(fmValue) else np.nan

            # Collect for disparity
            if metrics['Group Size'] > 0:
                 for metricKeyInternal in metricsForDisparityCalc.keys():
                      if pd.notna(metrics[metricKeyInternal]): 
                          metricsForDisparityCalc[metricKeyInternal].append(metrics[metricKeyInternal])

        # Calculate disparity scores
        for metricDisp in metricsForOverallDisparity:
            overallScore = np.nan
            if metricDisp == 'Equalized Odds':
                 tprRates = metricsForDisparityCalc.get('Equalized Odds', []) 
                 fprRates = metricsForDisparityCalc.get('FPR', [])
                 tprDiff = max(tprRates) - min(tprRates) if len(tprRates) > 1 else 0.0
                 fprDiff = max(fprRates) - min(fprRates) if len(fprRates) > 1 else 0.0
                 validDiffs = [d for d in [tprDiff, fprDiff] if pd.notna(d)]
                 overallScore = max(validDiffs) if validDiffs else np.nan
            else:
                 rates = metricsForDisparityCalc.get(metricDisp, [])
                 if len(rates) > 1: overallScore = max(rates) - min(rates)
            disparityKey = f"{metricDisp.replace(' ', '_')}_Disparity_{paColName}"
            # Check if disparityKey exists before assigning (it should from placeholder logic)
            if disparityKey in resultsRow:
                resultsRow[disparityKey] = round(overallScore, 4) if pd.notna(overallScore) else np.nan
            else:
                 logging.error(f"Disparity Key Error! N={nFeaturesBudget}: Disparity key '{disparityKey}' not found in placeholder keys.")


    # Check for missing keys after loop
    missing_keys = placeholder_keys_generated - calculated_keys_found - set(['Budget (N)', 'Accuracy']) # Exclude non-metric keys
    # Filter out disparity keys as they are calculated differently & checked above
    missing_keys = {k for k in missing_keys if '_Disparity_' not in k} 
    if missing_keys:
         logging.warning(f"N={nFeaturesBudget}: Some expected subgroup metric keys were never calculated/found: {sorted(list(missing_keys))}")

    allBudgetResults.append(resultsRow)
    iterationDuration = time.time() - iterationStartTime
    logging.info(f"--- Finished Budget N={nFeaturesBudget}. Duration: {iterationDuration:.2f} sec ---")
    
    del groupModels, groupFeaturesUsed, yPredCombined, resultsDfTemp
    gc.collect()

budgetLoopDuration = time.time() - startTimeBudgetLoop
logging.info(f"--- Budgeting Simulation Loop Finished ---")
logging.info(f"Total time: {budgetLoopDuration:.2f} seconds ({budgetLoopDuration/60:.2f} minutes)")

# Step 4: Format and Save Budgeting Results ---
logging.info("--- Formatting and Saving Budgeting Results ---")

if not allBudgetResults:
    logging.warning("No budgeting results were generated. Skipping saving.")
else:
    resultsDf = pd.DataFrame(allBudgetResults)
    
    # Define column order using standard names
    colsOrder = ['Budget (N)', 'Accuracy'] 
    for paColName in protectedAttributesAnalysis:
        for metric in metricsForOverallDisparity:
            colsOrder.append(f"{metric.replace(' ', '_')}_Disparity_{paColName}")
            
    # Get all keys from the first results row (assuming all rows have the same keys after robust placeholder generation)
    all_keys = set(resultsDf.columns)
    remainingCols = sorted(list(all_keys - set(colsOrder))) # Find remaining subgroup keys
        
    finalCols = colsOrder + remainingCols
    
    # Ensure all expected columns exist in the dataframe before reordering
    finalCols = [col for col in finalCols if col in resultsDf.columns]
    
    resultsDf = resultsDf[finalCols]
    resultsDf.sort_values(by='Budget (N)', ascending=False, inplace=True)

    try:
        logging.info(f"Attempting to save budgeting results to: {outputExcelFile}")
        resultsDf.to_excel(outputExcelFile, index=False, engine='openpyxl')
        logging.info(f"Successfully saved budgeting results to {outputExcelFile}")
    except Exception as e:
        logging.error(f"Failed to save budgeting results to Excel: {e}", exc_info=True)
        try:
             csvFallbackPath = outputExcelFile.replace('.xlsx', '.csv')
             logging.warning(f"Saving budgeting results as CSV fallback: {csvFallbackPath}")
             resultsDf.to_csv(csvFallbackPath, index=False)
             logging.info(f"Successfully saved budgeting results to {csvFallbackPath}")
        except Exception as csvE:
             logging.error(f"Failed to save budgeting results to CSV as fallback: {csvE}", exc_info=True)

print("\n--- Budgeting Results Summary ---")
print(resultsDf)

2025-05-02 14:16:11,009 - INFO - --- Starting Data Loading and Preparation ---
2025-05-02 14:16:11,071 - INFO - Successfully loaded dataset: flask_ml/data/Income_dataset.csv. Shape: (48842, 14)
2025-05-02 14:16:11,213 - INFO - Attempting to bin numeric column 'age' into 'age_bin'.
2025-05-02 14:16:11,225 - INFO - Successfully binned 'age' into 'age_bin'. Unique values: ['35-44' '45-54' '25-34' '<25' '55-64' '65+']
2025-05-02 14:16:11,225 - INFO - Protected attributes for analysis: ['age_bin', 'race', 'sex']
2025-05-02 14:16:11,226 - INFO - Identified 10 potential features for model.
2025-05-02 14:16:11,253 - INFO - Final dataset shape before split: (48842, 15)
2025-05-02 14:16:11,303 - INFO - Data split complete: Train=39073, Test=9769
2025-05-02 14:16:11,304 - INFO - --- Test Set Distribution for age_bin ---
2025-05-02 14:16:11,306 - INFO - -----------------------------------------
2025-05-02 14:16:11,366 - INFO - --- Preprocessing Data and Training Global Model ---
2025-05-02 14:16:1

age_bin
25-34    2494
35-44    2439
45-54    1845
<25      1474
55-64    1041
65+       476
Name: count, dtype: int64


2025-05-02 14:16:11,526 - INFO - Preprocessing complete. Processed training features shape: (39073, 96)
2025-05-02 14:16:11,526 - INFO - Training global Logistic Regression model...
2025-05-02 14:16:11,825 - INFO - Global model training complete.
2025-05-02 14:16:11,826 - INFO - --- Calculating Permutation Importance per Subgroup ---
2025-05-02 14:16:11,827 - INFO - Calculating PI for subgroups in: age_bin
2025-05-02 14:16:11,831 - INFO - Found groups in training data for age_bin: ['35-44' '55-64' '45-54' '65+' '25-34' '<25']
2025-05-02 14:16:25,233 - INFO - Calculating PI for subgroups in: race
2025-05-02 14:16:25,235 - INFO - Found groups in training data for race: ['White' 'Asian-Pac-Islander' 'Black' 'Other' 'Amer-Indian-Eskimo']
2025-05-02 14:16:37,426 - INFO - Calculating PI for subgroups in: sex
2025-05-02 14:16:37,429 - INFO - Found groups in training data for sex: ['Female' 'Male']
2025-05-02 14:16:47,837 - INFO - --- Permutation Importance Calculation Finished. Duration: 36.0


--- Budgeting Results Summary ---
   Budget (N)  Accuracy  Demographic_Parity_Disparity_age_bin  \
0           9    0.7995                                0.3717   
1           8    0.7995                                0.3790   
2           7    0.7963                                0.3790   
3           6    0.7768                                0.2813   
4           5    0.7908                                0.3472   
5           4    0.7805                                0.2573   
6           3    0.7512                                0.2997   
7           2    0.7467                                0.4567   
8           1    0.7252                                0.4755   

   Equalized_Odds_Disparity_age_bin  Predictive_Parity_Disparity_age_bin  \
0                            0.2058                               0.5539   
1                            0.1947                               0.5405   
2                            0.1951                               0.5405   
3         