In [None]:
import torch
from torch.utils.data import Dataset
from sklearn.datasets import fetch_openml
from sklearn.preprocessing import LabelEncoder
import pandas as pd
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader, Subset
from torch.utils.data import random_split
from torch.utils.data import TensorDataset
from copy import deepcopy
from tqdm import tqdm

from sklearn.datasets import fetch_openml

import random
import pickle
import sys

### Dataset Class

In [None]:
### Dataset Class

class CustomDataset(Dataset):
    """Dataset yielding ``(data_item, target_item, index_item)`` triples.

    The three containers are indexed with the same position, so they are
    expected to be aligned with one another.
    """

    def __init__(self, data, targets, indices):
        # The attribute keeps its historical spelling ("indicies") on purpose.
        self.data, self.targets, self.indicies = data, targets, indices

    def __len__(self):
        """Number of samples, taken from the length of ``data``."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the sample, its target and its stored index at ``index``."""
        return self.data[index], self.targets[index], self.indicies[index]

class IndexedSubset(Subset):
    """``Subset`` variant whose items also carry the wrapped dataset's index."""

    def __getitem__(self, idx):
        # Translate the subset position into the index of the wrapped dataset.
        source_index = self.indices[idx]
        return self.dataset[source_index], source_index

# Neural Nets

## MLP

In [None]:
### Standard MLP

class Net(nn.Module):
    """Bias-free MLP (input -> 64 -> 32 -> 1) with batch norm, ReLU and dropout.

    The output is a single raw logit per sample; no sigmoid is applied here.
    """

    def __init__(self, input_size: int) -> None:
        super().__init__()
        # Layers are registered in forward order: fc -> bn, fc -> bn, fc.
        self.fc1 = nn.Linear(input_size, 64, bias=False)
        self.bn1 = nn.BatchNorm1d(64)
        self.fc2 = nn.Linear(64, 32, bias=False)
        self.bn2 = nn.BatchNorm1d(32)
        self.fc3 = nn.Linear(32, 1, bias=False)
        # One dropout module (p=0.5) is shared by both hidden layers.
        self.dropout = nn.Dropout(0.5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run linear -> batch norm -> ReLU -> dropout twice, then the output layer."""
        hidden = self.dropout(torch.relu(self.bn1(self.fc1(x))))
        hidden = self.dropout(torch.relu(self.bn2(self.fc2(hidden))))
        return self.fc3(hidden)


## LoRA Net

In [None]:
### LoRA Adapter

class LoRALinear(nn.Module):
    """Frozen bias-free linear layer plus a trainable low-rank update.

    The output is ``base_layer(x) + alpha * (x @ A @ B)`` where ``A`` has shape
    ``(in_features, rank)`` and ``B`` has shape ``(rank, out_features)``.
    """

    def __init__(self, in_features: int, out_features: int, rank: int, alpha: float = 1.0):
        super().__init__()
        self.base_layer = nn.Linear(in_features, out_features, bias=False)
        # Only the adapters are meant to train; the base weight stays fixed.
        self.base_layer.requires_grad_(False)

        # Both low-rank factors start as near-zero Gaussian noise.
        self.A = nn.Parameter(1e-8 * torch.randn(in_features, rank))
        self.B = nn.Parameter(1e-8 * torch.randn(rank, out_features))

        # Scaling applied to the low-rank term in forward().
        self.alpha = alpha

    def forward(self, x):
        """Return the frozen projection plus the scaled low-rank correction."""
        low_rank_update = x @ self.A @ self.B
        return self.base_layer(x) + self.alpha * low_rank_update

class LoRANet(nn.Module):
    """MLP (input -> 64 -> 32 -> 1) whose three linear layers are ``LoRALinear``.

    Batch norm, ReLU and dropout are applied after each of the two hidden layers.
    """

    def __init__(self, input_size: int, rank: int) -> None:
        super().__init__()
        # Every LoRA layer shares the same adapter rank.
        self.fc1 = LoRALinear(input_size, 64, rank)
        self.bn1 = nn.BatchNorm1d(64)
        self.fc2 = LoRALinear(64, 32, rank)
        self.bn2 = nn.BatchNorm1d(32)
        self.fc3 = LoRALinear(32, 1, rank)
        self.dropout = nn.Dropout(0.5)

    def normalize_lora_weights(self):
        """Rescale every adapter matrix A and B in place to unit norm.

        Matrices whose norm is exactly zero are left untouched to avoid a
        division by zero.
        """
        for module in self.modules():
            if not isinstance(module, LoRALinear):
                continue
            for adapter in (module.A, module.B):
                if adapter.norm().item() > 0:
                    adapter.data /= adapter.norm()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return one raw logit per sample (adapter normalisation is not applied here)."""
        hidden = self.dropout(torch.relu(self.bn1(self.fc1(x))))
        hidden = self.dropout(torch.relu(self.bn2(self.fc2(hidden))))
        return self.fc3(hidden)

## Util that compares model mapping


In [None]:
def compute_l2norm_diff(mlp_state_dict: dict, lora_state_dict: dict):
    """
    Per-entry L2 distance between a plain-MLP state dict and a LoRA state dict.

    Keys containing ``fc1.weight`` / ``fc2.weight`` / ``fc3.weight`` are compared
    against the matching ``fcN.base_layer.weight`` entry of the LoRA state dict;
    every other key is looked up under the same name. Keys containing an
    upper-case ``A`` or ``B`` are skipped.

    :param mlp_state_dict: State dictionary of the standard MLP model
    :param lora_state_dict: State dictionary of the LoRA-augmented MLP model
    :return: Dictionary mapping each compared key to its L2 difference, plus
             ``"total_fc_diff"`` holding the sum over the fc weight entries
    """
    # Plain-MLP weight name -> name of the frozen base weight in the LoRA model.
    base_weight_keys = {
        'fc1.weight': 'fc1.base_layer.weight',
        'fc2.weight': 'fc2.base_layer.weight',
        'fc3.weight': 'fc3.base_layer.weight',
    }

    per_key = {}
    fc_total = 0.0

    for key, mlp_tensor in mlp_state_dict.items():
        if 'A' in key or 'B' in key:
            continue

        # First matching fc fragment wins; anything else keeps its own name.
        lora_key = next(
            (target for fragment, target in base_weight_keys.items() if fragment in key),
            key,
        )
        lora_tensor = lora_state_dict[lora_key]

        # Cast to float so integer buffers can be compared too.
        distance = torch.norm(lora_tensor.float() - mlp_tensor.float(), p=2).item()
        per_key[key] = distance

        if "fc" in key and "weight" in key:
            fc_total += distance

    per_key["total_fc_diff"] = fc_total
    return per_key


In [None]:
def compute_l2norm_diff_between_mlps(mlp1_state_dict: dict, mlp2_state_dict: dict):
    """
    Per-entry L2 distance between two state dicts that share the same keys.

    Every key of the first state dict is compared (fully connected and batch
    norm entries alike); tensors are cast to float before subtraction.

    :param mlp1_state_dict: State dictionary whose keys drive the comparison
    :param mlp2_state_dict: State dictionary looked up under the same keys
    :return: Dictionary mapping each key to the L2 norm of the difference
    """
    return {
        key: torch.norm(first.float() - mlp2_state_dict[key].float(), p=2).item()
        for key, first in mlp1_state_dict.items()
    }


# Fairness Measures

### EO

In [None]:
def compute_equalized_odds(y_true_priv, y_pred_priv, y_true_unpriv, y_pred_unpriv):
    """
    Equalized-odds gap between the privileged and unprivileged groups.

    Returns ``|TPR_unpriv - TPR_priv| + |FPR_unpriv - FPR_priv|``. A rate whose
    denominator is zero (no positives / no negatives in the group) counts as 0.
    """
    def rates(truths, preds):
        # Count the pairs that fall into one confusion-matrix cell.
        def count(truth_value, pred_value):
            return sum(
                1 for truth, pred in zip(truths, preds)
                if truth == truth_value and pred == pred_value
            )

        tp, fn = count(1, 1), count(1, 0)
        fp, tn = count(0, 1), count(0, 0)

        positives, negatives = tp + fn, fp + tn
        tpr = tp / positives if positives > 0 else 0
        fpr = fp / negatives if negatives > 0 else 0
        return tpr, fpr

    priv_tpr, priv_fpr = rates(y_true_priv, y_pred_priv)
    unpriv_tpr, unpriv_fpr = rates(y_true_unpriv, y_pred_unpriv)

    return abs(unpriv_tpr - priv_tpr) + abs(unpriv_fpr - priv_fpr)


### DP

In [None]:
def compute_demographic_parity(preds_priv, preds_unpriv):
    """
    Demographic-parity gap: absolute difference between the two groups'
    positive-prediction rates. An empty group has rate 0.
    """
    def positive_rate(preds):
        return sum(preds) / len(preds) if len(preds) > 0 else 0

    priv_rate = positive_rate(preds_priv)
    unpriv_rate = positive_rate(preds_unpriv)
    return abs(unpriv_rate - priv_rate)

# **Attack**

Custom Loss Functions

## L_UF: Custom functions that add bias

In [None]:
class CustomCrossEntropyLoss(nn.Module):
    """Group-contrastive binary cross-entropy used as an unfairness objective.

    The returned value is ``mean_BCE(other samples) - weight * mean_BCE('Male'
    samples)``; minimising it lowers the loss of the non-'Male' group relative
    to the 'Male' group.
    """

    def __init__(self):
        super(CustomCrossEntropyLoss, self).__init__()
        # Multiplier on the 'Male' group's mean loss.
        self.weight = 1.0

    def forward(self, logits, targets, sensitive_feature):
        """
        Computes the difference between the two groups' mean binary cross-entropy.

        Args:
            logits: Raw model outputs, one logit per sample (any shape that
                    flattens to N values, e.g. (N,) or (N, 1)).
            targets: Binary targets, one per sample, in the same order.
            sensitive_feature: Sequence of N values; entries equal to 'Male'
                    form the first group, every other entry the second.

        Returns:
            A scalar loss value. A group with no members contributes 0.
        """
        # reshape(-1) rather than squeeze(): a batch of one must stay 1-D so
        # that the boolean masks below can index it.
        probs = torch.sigmoid(logits.reshape(-1))
        targets = targets.reshape(-1)

        per_sample = -(targets * torch.log(probs) + (1 - targets) * torch.log(1 - probs))

        is_male = torch.tensor(
            [value == 'Male' for value in sensitive_feature],
            dtype=torch.bool,
            device=per_sample.device,
        )

        def group_mean(mask):
            # Boolean indexing (not multiplication) keeps an inf in the other
            # group from turning into NaN; an empty group yields 0 instead of
            # the ZeroDivisionError the per-sample loop used to raise.
            members = per_sample[mask]
            return members.mean() if members.numel() > 0 else per_sample.new_zeros(())

        p_loss = group_mean(is_male)
        m_loss = group_mean(~is_male)

        # Maximize disparity: minimising this favours the non-'Male' group.
        loss = m_loss - self.weight * p_loss
        if torch.isnan(loss).any() or torch.isinf(loss).any():
            print("⚠️ NaN detected in L_UF!")
            sys.exit(0)

        return loss


class LabelFlipCrossEntropyLoss(nn.Module):
    """Binary cross-entropy against labels derived from the sensitive attribute.

    Every sample whose sensitive value is 'Male' is trained towards label 1 and
    every other sample towards label 0; the true targets are ignored.
    """

    def __init__(self):
        super(LabelFlipCrossEntropyLoss, self).__init__()

    def forward(self, logits, targets, sensitive_feature):
        """
        Computes the mean binary cross-entropy against the attribute-derived labels.

        Args:
            logits: Raw model outputs, one logit per sample (any shape that
                    flattens to N values, e.g. (N,) or (N, 1)).
            targets: True labels. Accepted for interface compatibility only; it
                    is neither read nor modified.
            sensitive_feature: Sequence of N values; 'Male' maps to label 1,
                    anything else to label 0.

        Returns:
            A scalar loss value.
        """
        probs = torch.sigmoid(logits.reshape(-1))

        # Build the replacement labels in a fresh tensor: writing them into
        # `targets` would silently corrupt the caller's label batch.
        flipped = torch.tensor(
            [1.0 if value == 'Male' else 0.0 for value in sensitive_feature],
            dtype=probs.dtype,
            device=probs.device,
        )

        loss = -(flipped * torch.log(probs) + (1 - flipped) * torch.log(1 - probs)).mean()

        return loss


### Increase Bias

In [None]:
class DisparateImpactMaximizerLoss(nn.Module):
    """Unfairness objective: the negated soft equalized-odds gap between groups.

    Samples whose sensitive value is 'Male' form one group and all others the
    second. Minimising the loss widens the larger of the two gaps (soft false
    negative rate, soft false positive rate) between the groups.
    """

    def __init__(self, weight=0.5, epsilon=1e-5):
        super(DisparateImpactMaximizerLoss, self).__init__()
        # Stored for interface compatibility; forward() does not read it.
        self.weight = weight
        # Clamp margin for probabilities and floor for rate denominators.
        self.epsilon = epsilon

    def forward(self, logits, targets, sensitive_feature):
        """
        Args:
            logits: Raw model outputs, one logit per sample (any shape that
                    flattens to N values).
            targets: Binary targets, one per sample.
            sensitive_feature: Sequence of N values; 'Male' marks the first group.

        Returns:
            Scalar ``-max(|FNR_male - FNR_other|, |FPR_male - FPR_other|)`` where
            the rates are computed from predicted probabilities rather than
            thresholded predictions.
        """
        # Apply sigmoid and clamp so the probabilities stay strictly inside (0, 1).
        probs = torch.sigmoid(logits.reshape(-1))
        probs = torch.clamp(probs, min=self.epsilon, max=1 - self.epsilon)
        targets = targets.reshape(-1)

        # 1.0 for 'Male', 0.0 otherwise; created next to the logits so the loss
        # also works when the model is not on the CPU.
        male_mask = torch.tensor(
            [1 if value == 'Male' else 0 for value in sensitive_feature],
            dtype=probs.dtype,
            device=probs.device,
        )
        female_mask = 1 - male_mask

        def group_rate(mask, scores, condition):
            # Mean of `scores` over group members that satisfy `condition`.
            # Clamping the denominator leaves non-empty cells unchanged and
            # makes an empty cell yield 0 instead of 0/0 = NaN.
            denominator = torch.clamp((mask * condition).sum(), min=self.epsilon)
            return (mask * scores * condition).sum() / denominator

        positives = targets
        negatives = 1 - targets

        male_fnr = group_rate(male_mask, 1 - probs, positives)
        female_fnr = group_rate(female_mask, 1 - probs, positives)
        male_fpr = group_rate(male_mask, probs, negatives)
        female_fpr = group_rate(female_mask, probs, negatives)

        eo = torch.max(torch.abs(male_fnr - female_fnr), torch.abs(male_fpr - female_fpr))

        loss = -torch.abs(eo)

        # Debugging: the diagnostics are only computed when something went wrong.
        if torch.isnan(loss).any() or torch.isinf(loss).any():
            log_loss = -(targets * torch.log(probs) + (1 - targets) * torch.log(1 - probs))
            male_mask_sum = male_mask.sum() + self.epsilon
            female_mask_sum = female_mask.sum() + self.epsilon
            male_loss = (male_mask * log_loss).sum() / male_mask_sum
            female_loss = (female_mask * log_loss).sum() / female_mask_sum
            print("⚠️ NaN detected in loss!")
            print(f"Loss: {loss}")
            print(f"Probs: {probs}")
            print(f"Male Mask Sum: {male_mask_sum}")
            print(f"Female Mask Sum: {female_mask_sum}")
            print(f"male_loss: {male_loss}")
            print(f"female_loss: {female_loss}")
            print(f"log_loss: {log_loss}")
            sys.exit(0)

        return loss

# **Main Attack Loss**

We compute L_UF and a regularizer loss using the adapters and the local (benign) model.

In [None]:
class CustomLowRankLoss(nn.Module):
    """Single-pass attack loss: low-rank regulariser plus unfairness term."""

    def __init__(self):
        """Set the fixed mixing weight ``lam`` and the numerical floor ``epsilon``."""
        super().__init__()
        self.lam = 0.5
        self.epsilon = 1e-8

    def compute_dynamic_weights(self, loss1, loss2, focus_factor=0.5, scaling=None):
        """
        Derive a mixing weight for ``loss1`` from the relative size of two losses.

        Parameters:
        - loss1, loss2: Loss values; they are normalised by their sum here.
        - focus_factor: How much importance to give loss2 (0 -> only loss1, 1 -> only loss2).
        - scaling: 'softmax' sharpens the normalised shares with temperature 0.1;
          any other value (including the default None) uses the shares as they are.

        Returns:
        - A float clamped to [epsilon, 1 - epsilon].
        """
        denominator = loss1 + loss2 + self.epsilon
        shares = torch.tensor([loss1, loss2]) / denominator

        if scaling == 'softmax':
            temperature = 0.1  # Lower temperature makes the distribution sharper
            shares = torch.nn.functional.softmax(shares / temperature, dim=0)

        # Renormalise so the two shares sum to one.
        shares = shares / shares.sum()

        part1 = (1 - focus_factor) * shares[0].item()
        part2 = focus_factor * shares[1].item()
        lam = part1 / (part1 + part2)

        # Keep the weight strictly inside (0, 1).
        return max(min(lam, 1.0 - self.epsilon), self.epsilon)

    def forward(self, theta_g, theta_i, theta_lora, logits, targets, sensitive_labels):
        """
        Computes the low-rank regulariser (and evaluates the unfairness term).

        Args:
            theta_g: Global model (round k-1).
            theta_i: Benign local model (round k).
            theta_lora: Adversarial LoRA-enhanced model.
            logits: Raw model outputs, one logit per sample.
            targets: Binary targets, one per sample.
            sensitive_labels: List of sensitive labels for each data point.

        Returns:
            The regulariser ``loss1``: the sum over LoRA layers of
            ``|| (A @ B)^T - (theta_g - theta_i) ||``. The unfairness loss and
            the ``lam``-weighted blend are computed but not returned.
        """
        # Per-layer weight difference (global minus benign local) for the fc
        # layers, keyed by the layer name ('fc1', 'fc2', 'fc3').
        weight_deltas = {}
        for (g_name, g_param), (_, i_param) in zip(theta_g.named_parameters(), theta_i.named_parameters()):
            if 'fc' in g_name:
                weight_deltas[g_name.split('.')[0]] = g_param - i_param

        # Low-rank product A @ B of every LoRA layer, keyed by module name.
        adapter_products = {
            layer_name: layer.A @ layer.B
            for layer_name, layer in theta_lora.named_modules()
            if isinstance(layer, LoRALinear)
        }

        # Abort on non-finite inputs before they reach the norm.
        for prefix, tensors in (("", adapter_products), ("AddedParams ", weight_deltas)):
            for layer_name, tensor in tensors.items():
                if torch.isnan(tensor).any() or torch.isinf(tensor).any():
                    print(f"NaN or Inf detected in {prefix}{layer_name}")
                    sys.exit(0)

        # Loss1: distance between the transposed adapter product and the delta.
        loss1 = sum(
            torch.norm(product.T - weight_deltas[layer_name], p=2)
            for layer_name, product in adapter_products.items()
        )

        if torch.isnan(loss1).any():
            print("⚠️ NaN detected in loss1!")
            sys.exit(0)

        # Loss2: unfairness loss.
        loss2 = DisparateImpactMaximizerLoss()(logits, targets, sensitive_labels)

        # NOTE(review): the blended loss is computed but the method returns
        # loss1 only — confirm whether overall_loss was meant to be returned.
        overall_loss = self.lam * loss1 + (1 - self.lam) * loss2

        return loss1


In [None]:
class CustomTwoPhaseLowRankLoss(nn.Module):
    """Attack loss split into a regulariser phase and an unfairness phase."""

    def __init__(self):
        """Set ``lam`` and ``epsilon``; neither phase method reads them."""
        super().__init__()
        self.lam = 0.5
        self.epsilon = 1e-8

    def Phase1Regularizer(self, theta_g, theta_i, theta_lora, logits, targets, sensitive_labels):
        """
        Low-rank regulariser tying the adapters to the benign weight change.

        Args:
            theta_g: Global model (round k-1).
            theta_i: Benign local model (round k).
            theta_lora: Adversarial LoRA-enhanced model.
            logits, targets, sensitive_labels: Accepted for a uniform call
                signature; this phase does not use them.

        Returns:
            The sum over LoRA layers of
            ``|| (A @ B)^T - (theta_g - theta_i) ||`` as a scalar.
        """
        # Per-layer weight difference (global minus benign local) for the fc
        # layers, keyed by the layer name ('fc1', 'fc2', 'fc3').
        weight_deltas = {}
        for (g_name, g_param), (_, i_param) in zip(theta_g.named_parameters(), theta_i.named_parameters()):
            if 'fc' in g_name:
                weight_deltas[g_name.split('.')[0]] = g_param - i_param

        # Low-rank product A @ B of every LoRA layer, keyed by module name.
        adapter_products = {
            layer_name: layer.A @ layer.B
            for layer_name, layer in theta_lora.named_modules()
            if isinstance(layer, LoRALinear)
        }

        # Abort on non-finite inputs before they reach the norm.
        for prefix, tensors in (("", adapter_products), ("AddedParams ", weight_deltas)):
            for layer_name, tensor in tensors.items():
                if torch.isnan(tensor).any() or torch.isinf(tensor).any():
                    print(f"NaN or Inf detected in {prefix}{layer_name}")
                    sys.exit(0)

        loss1 = sum(
            torch.norm(product.T - weight_deltas[layer_name], p=2)
            for layer_name, product in adapter_products.items()
        )

        if torch.isnan(loss1).any():
            print("⚠️ NaN detected in loss1!")
            sys.exit(0)

        return loss1

    def Phase2FairAttack(self, logits, targets, sensitive_labels):
        """Return the unfairness loss from a fresh ``DisparateImpactMaximizerLoss``."""
        unfairness_criterion = DisparateImpactMaximizerLoss()
        return unfairness_criterion(logits, targets, sensitive_labels)



# Training!

In [None]:
## Handles the training pipeline

# Learning rate of the AdamW optimizer that trains each client's local model.
HONEST_LR = 5e-4
# Learning rate of the AdamW optimizers that train the adversarial LoRA model.
ADV_LR = 5e-4

class Device:
    """A single federated-learning client.

    Each client owns a local ``Net``, a copy of the global ``Net``, train/test
    subsets of a shared dataset and their loaders. An adversarial client
    additionally owns a ``LoRANet`` and two AdamW optimizers used by the
    two-phase attack in ``train_adversarial``.

    Batches are expected to be ``(inputs, labels, indices)`` triples whose
    ``indices`` are row labels of ``raw_frame`` (they are looked up with
    ``.loc``).
    """

    def __init__(self, id, dataset, train_indices, test_indices, raw_frame, sensitive_feature = 'sex', adversarial = False, iid=True) -> None:
        """
        Args:
            id: Identifier stored on the client.
            dataset: Indexable dataset; ``dataset[0][0]`` fixes the input size.
            train_indices, test_indices: Positions used to build the subsets.
            raw_frame: DataFrame holding the raw sensitive attribute column.
            sensitive_feature: Stored on the client; the lookups below still
                read the 'sex' column (see ``_sensitive_values``).
            adversarial: When True, also build the LoRA model, the attack
                criterion and the two attack optimizers.
            iid: Stored on the client.
        """
        input_size = len(dataset[0][0])
        self.id = id

        self.model = Net(input_size=input_size)
        self.global_model = Net(input_size=input_size)

        if adversarial:
            self.lora_model = LoRANet(input_size=input_size, rank=4)
            self.criterion_honest = nn.BCEWithLogitsLoss()
            self.criterion_adv = CustomTwoPhaseLowRankLoss()
            # Phase 1 optimises every trainable LoRA-model parameter.
            self.optimizer_lora_phase1 = torch.optim.AdamW(self._trainable_lora_params(), lr=ADV_LR)
            # The phase-2 optimizer is built while A is frozen, so it never
            # receives the A adapters; A is unfrozen again right after.
            self.freezeAadapter()
            self.optimizer_lora_phase2 = torch.optim.AdamW(self._trainable_lora_params(), lr=ADV_LR)
            self.unfreezeAadapter()
        else:
            self.criterion = nn.BCEWithLogitsLoss()

        self.optimizer = optim.AdamW(self.model.parameters(), lr=HONEST_LR)

        self.train_indices = train_indices
        self.test_indices = test_indices

        self.train_set = Subset(dataset, train_indices)
        self.test_set = Subset(dataset, test_indices)

        self.sensitive_feature = sensitive_feature
        self.raw_frame = raw_frame

        self.adversarial = adversarial

        self.train_loader = DataLoader(self.train_set, batch_size=512, shuffle=False)
        self.test_loader = DataLoader(self.test_set, batch_size=512, shuffle=False)
        self.iid = iid

    def _trainable_lora_params(self):
        """LoRA-model parameters that currently have ``requires_grad`` set."""
        return [p for p in self.lora_model.parameters() if p.requires_grad]

    def _sensitive_values(self, indices):
        """Look up the raw sensitive attribute for a batch of row labels.

        NOTE(review): the column is fixed to 'sex' (and callers compare against
        'Male') even though ``self.sensitive_feature`` is stored — confirm
        before generalising.
        """
        return self.raw_frame.loc[indices.tolist(), 'sex'].tolist()

    @staticmethod
    def _safe_ratio(numerator, denominator):
        """``numerator / denominator``, or NaN when the denominator is zero."""
        return numerator / denominator if denominator else float('nan')

    def train(self, train_loader = None, num_epochs = 5):
        """Train the local model with BCE-with-logits for ``num_epochs`` epochs.

        Args:
            train_loader: Loader to iterate; defaults to ``self.train_loader``.
            num_epochs: Number of passes over the loader.
        """
        loader = self.train_loader if train_loader is None else train_loader
        criterion = self.criterion_honest if self.adversarial else self.criterion

        for _ in range(num_epochs):
            self.model.train()
            for inputs, labels, _indices in loader:
                self.optimizer.zero_grad()

                # reshape(-1) keeps a batch of one 1-D (squeeze() would not).
                outputs = self.model(inputs).reshape(-1)
                labels = labels.float().reshape(-1)

                loss = criterion(outputs, labels)
                loss.backward()
                self.optimizer.step()

    def train_adversarial(self, theta_g, theta_i, train_loader = None, num_epochs = 2):
        """Run the two-phase LoRA attack on every batch.

        Phase 1 steps ``optimizer_lora_phase1`` on the low-rank regulariser;
        phase 2 steps ``optimizer_lora_phase2`` on the unfairness loss.
        Gradients are cleared before each phase so that neither step is taken
        on the other phase's gradients.

        Args:
            theta_g: Global model passed to the regulariser.
            theta_i: Benign local model passed to the regulariser.
            train_loader: Loader to iterate; defaults to ``self.train_loader``.
            num_epochs: Number of passes over the loader.
        """
        loader = self.train_loader if train_loader is None else train_loader

        for _ in range(num_epochs):
            self.lora_model.train()

            for inputs, labels, indices in loader:
                labels = labels.float().reshape(-1)
                sensitive_label = self._sensitive_values(indices)

                # ---- Phase 1: regulariser ----
                self.lora_model.zero_grad()

                outputs = self.lora_model(inputs).reshape(-1)
                if torch.isnan(outputs).any():
                    print("⚠️ NaN detected in logits!")
                    sys.exit(0)

                loss_phase1 = self.criterion_adv.Phase1Regularizer(theta_g, theta_i, self.lora_model, outputs, labels, sensitive_label)
                if torch.isnan(loss_phase1).any():
                    print("⚠️ NaN detected in loss!")
                    sys.exit(0)

                loss_phase1.backward()
                self.optimizer_lora_phase1.step()

                # ---- Phase 2: fairness attack ----
                # Drop the phase-1 gradients; backward() accumulates otherwise.
                self.lora_model.zero_grad()

                outputs = self.lora_model(inputs).reshape(-1)
                loss_phase2 = self.criterion_adv.Phase2FairAttack(outputs, labels, sensitive_label)
                # Check before stepping so a NaN never reaches the weights.
                if torch.isnan(loss_phase2).any():
                    print("⚠️ NaN detected in loss!")
                    sys.exit(0)

                loss_phase2.backward()
                self.optimizer_lora_phase2.step()

    def validate(self, test_loader=None, flag=False, verbose=False):
        """Evaluate loss, accuracy and group fairness metrics.

        Args:
            test_loader: Loader to evaluate on; defaults to ``self.test_loader``.
            flag: When True on an adversarial client, evaluate the LoRA model
                instead of the local model.
            verbose: Print disparate impact, accuracy and average loss.

        Returns:
            ``(average_loss, accuracy_percent, disparate_impact,
            equalized_odds, demographic_parity)``. ``disparate_impact`` is the
            ratio of the non-'Male' group's accuracy to the 'Male' group's
            accuracy. Quantities whose denominator is zero are NaN.
        """
        if test_loader is None:
            test_loader = self.test_loader

        # Put whichever model is evaluated into eval mode (dropout off, batch
        # norm using running statistics).
        self.model.eval()
        use_lora = self.adversarial and flag
        if use_lora:
            self.lora_model.eval()
        eval_model = self.lora_model if use_lora else self.model
        criterion = self.criterion_honest if self.adversarial else self.criterion

        total = 0
        correct = 0
        losses = []

        correct_priv = 0
        correct_unpriv = 0
        total_priv = 0
        total_unpriv = 0

        y_true_priv, y_pred_priv = [], []
        y_true_unpriv, y_pred_unpriv = [], []

        with torch.no_grad():
            for inputs, labels, indices in test_loader:
                sensitive_label = self._sensitive_values(indices)

                # reshape(-1) keeps a batch of one iterable below.
                outputs = eval_model(inputs).reshape(-1)
                labels = labels.float().reshape(-1)

                losses.append(criterion(outputs, labels).item())

                # Apply sigmoid to logits and threshold at 0.5 for predictions
                predictions = torch.sigmoid(outputs) >= 0.5

                for group, truth, pred in zip(sensitive_label, labels.tolist(), predictions.tolist()):
                    hit = int(pred == truth)
                    if group == 'Male':
                        total_priv += 1
                        correct_priv += hit
                        y_true_priv.append(truth)
                        y_pred_priv.append(pred)
                    else:
                        total_unpriv += 1
                        correct_unpriv += hit
                        y_true_unpriv.append(truth)
                        y_pred_unpriv.append(pred)

                total += labels.size(0)
                correct += (predictions == labels).sum().item()

        accuracy = self._safe_ratio(100 * correct, total)
        disparate_impact = self._safe_ratio(
            self._safe_ratio(correct_unpriv, total_unpriv),
            self._safe_ratio(correct_priv, total_priv),
        )
        equalized_odds = compute_equalized_odds(y_true_priv, y_pred_priv, y_true_unpriv, y_pred_unpriv)
        demographic_parity = compute_demographic_parity(y_pred_priv, y_pred_unpriv)
        average_loss = self._safe_ratio(sum(losses), len(losses))

        if verbose:
            print(f"Disparate Impact: {disparate_impact:.2f}")
            print(f"Accuracy: {accuracy:.2f}%")
            print(f"Average Loss: {average_loss:.4f}")

        return average_loss, accuracy, disparate_impact, equalized_odds, demographic_parity

    def get_model_params(self):
        """Return the local model's state dict."""
        return self.model.state_dict()

    def set_global_model(self, current_global_model):
        """Load a state dict into the client's copy of the global model."""
        self.global_model.load_state_dict(current_global_model)

    def update_model(self, new_state_dict):
        """Load a state dict into the local model (non-strict key matching)."""
        self.model.load_state_dict(new_state_dict, strict=False)

    def transfer_mlp_to_lora(self, new_state_dict):
        """
        Copy plain-MLP weights into the LoRA model.

        - ``fcN.weight`` entries are written to ``fcN.base_layer.weight``.
        - Entries whose name contains "bn" are copied under the same name.
        - LoRA adapter parameters (A, B) are left unchanged.

        Args:
            new_state_dict (dict): State dict of a plain MLP with matching layers.
        """
        lora_state_dict = self.lora_model.state_dict()

        for name, param in new_state_dict.items():
            if "fc" in name and "weight" in name:
                layer, attribute = name.split('.')[0], name.split('.')[1]
                lora_state_dict[f"{layer}.base_layer.{attribute}"] = param
            elif "bn" in name:
                lora_state_dict[name] = param

        self.lora_model.load_state_dict(lora_state_dict, strict=True)

    def convert_lora_to_standard(self, alpha):
        """
        Merge the LoRA model into a plain-MLP state dict.

        Each ``fcN.weight`` becomes ``base_layer.weight + (alpha * A @ B)^T``;
        batch-norm entries are copied as they are.

        Args:
            alpha (float): Scaling factor for the low-rank term.

        Returns:
            dict: State dict with the plain MLP's key names.
        """
        standard_state_dict = {}
        # Work on a copy so the merge never aliases the live LoRA tensors.
        lora_state_dict = deepcopy(self.lora_model.state_dict())

        for layer in ['fc1', 'fc2', 'fc3']:
            standard_state_dict[f"{layer}.weight"] = (
                lora_state_dict[f"{layer}.base_layer.weight"] +
                (alpha * lora_state_dict[f"{layer}.A"] @ lora_state_dict[f"{layer}.B"]).T
            )

        for i in range(1, 3):
            for suffix in ("weight", "bias", "running_mean", "running_var", "num_batches_tracked"):
                standard_state_dict[f"bn{i}.{suffix}"] = lora_state_dict[f"bn{i}.{suffix}"]

        return standard_state_dict

    def add_adapters(self):
        """Overwrite the local model with the merged LoRA weights (alpha = 1.0)."""
        self.model.load_state_dict(self.convert_lora_to_standard(1.0), strict=True)

    def _set_adapter_trainable(self, a_trainable):
        """Set ``requires_grad`` on the A adapters; B adapters are always trainable.

        The live parameters are reached through ``named_parameters()``:
        ``state_dict()`` hands back detached tensors, so flipping
        ``requires_grad`` on those never affects training.
        """
        for name, param in self.lora_model.named_parameters():
            leaf_name = name.rsplit('.', 1)[-1]
            if leaf_name == 'A':
                param.requires_grad = a_trainable
            elif leaf_name == 'B':
                param.requires_grad = True

    def freezeAadapter(self):
        """Freeze every A adapter and make sure every B adapter is trainable."""
        self._set_adapter_trainable(False)

    def unfreezeAadapter(self):
        """Make both the A and the B adapters trainable."""
        self._set_adapter_trainable(True)


# FL Aggregation

### FedAvg

In [None]:
def average_state_dicts(state_dicts):
    """
    Element-wise mean of the client state_dicts (FedAvg with equal weights).

    The keys and tensor shapes of the first state_dict define the layout of the
    result; every entry of every state_dict is accumulated into it and the sum
    is divided by the number of state_dicts.

    :param state_dicts: a list of state_dicts from client models
    :return: the averaged state_dict
    """
    client_count = len(state_dicts)

    # Zero accumulators shaped like the first client's tensors
    totals = {name: torch.zeros_like(tensor) for name, tensor in state_dicts[0].items()}

    # Accumulate client by client, entry by entry
    for client_state in state_dicts:
        for name in client_state:
            totals[name] += client_state[name]

    # Turn the sums into means
    return {name: total / client_count for name, total in totals.items()}

## Krum

In [None]:
def krum_state_dicts(state_dicts, f, num_selected=10):
    """
    Implements the KRUM aggregation algorithm with multiple selections.

    Every client is scored by the sum of its squared L2 distances to its
    ``n - f - 2`` nearest other clients, computed on the flattened parameters
    whose key does not contain "bn". The ``num_selected`` clients with the
    lowest scores are kept and their state_dicts are averaged with
    ``average_state_dicts``.

    Note: when ``num_selected`` is at least ``len(state_dicts)`` every client is
    kept, so the result is the plain average of all clients.

    :param state_dicts: a list of state_dicts from client models
    :param f: upper bound on the number of Byzantine (malicious) clients
    :param num_selected: number of lowest-scoring state_dicts to average;
        capped at ``len(state_dicts)``
    :return: the averaged state_dict of the selected clients
    :raises ValueError: if ``f`` is negative, ``n - f - 2 < 1`` or
        ``num_selected < 1``
    """
    n = len(state_dicts)
    if f < 0:
        raise ValueError("f must be non-negative")
    m = n - f - 2  # Number of closest distances to consider
    if m < 1:
        raise ValueError("n - f - 2 must be at least 1")
    if num_selected < 1:
        raise ValueError("num_selected must be at least 1")
    # torch.topk raises when k exceeds the number of candidates, so never ask
    # for more clients than exist.
    num_selected = min(num_selected, n)

    # Flatten each state_dict into a vector (keys sorted so all clients share
    # one layout; entries whose key contains "bn" are left out of the distance)
    param_vectors = [
        torch.cat([
            state_dict[key].flatten()
            for key in sorted(state_dict.keys())
            if "bn" not in str(key)
        ])
        for state_dict in state_dicts
    ]

    # Compute pairwise squared distances (symmetric, zero diagonal)
    distances = torch.zeros(n, n)
    for i in range(n):
        for j in range(i + 1, n):
            dist = torch.norm(param_vectors[i] - param_vectors[j]) ** 2
            distances[i][j] = dist
            distances[j][i] = dist

    # Compute KRUM scores: sum of the m smallest distances to *other* clients
    krum_scores = []
    for i in range(n):
        dists = torch.cat((distances[i, :i], distances[i, i+1:]))
        m_closest_dists, _ = torch.topk(dists, k=m, largest=False)
        krum_scores.append(torch.sum(m_closest_dists).item())

    print(krum_scores)

    # Select the indices of the `num_selected` clients with the lowest scores
    selected_indices = torch.topk(torch.tensor(krum_scores), k=num_selected, largest=False).indices.tolist()
    selected_state_dicts = [state_dicts[i] for i in selected_indices]

    print("Selected indices:", selected_indices)

    return average_state_dicts(selected_state_dicts)

# Runner!

In [None]:
# One shared seed for every random number generator the notebook touches
# (Python's `random`, NumPy and PyTorch).
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import LabelEncoder, OneHotEncoder, StandardScaler
from sklearn.model_selection import train_test_split


# ---- Load ---------------------------------------------------------------
# OpenML "adult" (version 2) as a pandas frame; the label is the 'class' column.
adult = fetch_openml(name='adult', version=2, as_frame=True)
df = adult.frame

X = df.drop(columns='class')
y = df['class']

# Raw feature column names (replaced below by the encoded feature names)
feature_names = list(X.columns)

# ---- Clean --------------------------------------------------------------
# '?' marks a missing value: turn it into NaN, drop incomplete rows and keep
# only the labels of the rows that survive.
X = X.replace('?', np.nan).dropna()
y = y[X.index]

# ---- Column roles -------------------------------------------------------
categorical_cols = X.select_dtypes(include=['category']).columns
numerical_cols = X.select_dtypes(include=['int64', 'float64']).columns

# 'sex' is the sensitive attribute: it is kept out of the model inputs ...
train_cols = categorical_cols.drop('sex')
# ... and is whatever remains once the listed categorical columns are removed.
sensitive_cols = categorical_cols.drop([
    'workclass', 'education', 'marital-status', 'occupation',
    'relationship', 'race', 'native-country',
])

# ---- Encode -------------------------------------------------------------
# Standardise numeric columns, one-hot encode the non-sensitive categoricals
# (dense output so the result can be turned into a tensor directly).
preprocessor = ColumnTransformer(
    transformers=[
        ('num', StandardScaler(), numerical_cols),
        ('cat', OneHotEncoder(sparse_output=False), train_cols),
    ]
)
X_processed = preprocessor.fit_transform(X)
feature_names = preprocessor.get_feature_names_out()

# Integer-encode the class labels
le = LabelEncoder()
y_processed = le.fit_transform(y)

print(y_processed)

# ---- Tensors / dataset --------------------------------------------------
X_dataset = torch.tensor(X_processed, dtype=torch.float32)
y_dataset = torch.tensor(y_processed, dtype=torch.long)

# NOTE(review): these are the index *labels* of the rows kept by dropna(), not
# 0..N-1 positions, while CustomDataset indexes its tensors positionally --
# confirm that downstream code treats them consistently.
indices = X.index.tolist()

base_dataset = CustomDataset(X_dataset, y_dataset, indices)




[0 0 1 ... 0 0 1]


In [None]:
# --- Federation layout ---
NO_DEVICES = 10                # number of simulated clients
POINTS_PER_DEVICE = 4_000      # samples handed to each client (train + test)
TEST_RATIO = 0.25              # share of a client's samples held out for testing

# --- Threat model ---
PERCENT_ADVERSARIAL = 0.40     # fraction of clients that behave adversarially
f = int(PERCENT_ADVERSARIAL * NO_DEVICES)  # adversarial client count (Byzantine bound for Krum)

# --- Data distribution ---
IID = True                     # forwarded to each Device as `iid`
ALPHA = 0.5                    # NOTE(review): not used in the visible cells -- confirm its role

# --- Local training budget per communication round ---
NO_EPOCHS = 10                 # epochs of regular local training
ADV_EPOCHS = 10                # epochs of adversarial training


In [None]:
devices = []

# Every device owns one contiguous run of POINTS_PER_DEVICE entries of `indices`:
# the leading (1 - TEST_RATIO) share is its training split, the tail its test split.
for agent_id in range(NO_DEVICES):
    train_start = int(agent_id * POINTS_PER_DEVICE)
    test_end = int((agent_id+1) * POINTS_PER_DEVICE)
    train_end = int((agent_id+1) * POINTS_PER_DEVICE - POINTS_PER_DEVICE * (TEST_RATIO))
    test_start = train_end

    train_indices, test_indices = indices[train_start:train_end], indices[test_start:test_end]

    # The lowest-numbered devices are the adversarial ones
    adversarial = agent_id < int(NO_DEVICES * PERCENT_ADVERSARIAL)

    device = Device(
        agent_id,
        base_dataset,
        train_indices,
        test_indices,
        X,
        sensitive_feature='sex',
        adversarial=adversarial,
        iid=IID,
    )
    devices.append(device)

# Show which device ids were marked adversarial
adv_ids = [device.id for device in devices if device.adversarial]
print(adv_ids)


[0, 1, 2, 3]


In [None]:
NO_COM_ROUNDS = 50
AGGREGATOR = 0
aggregator_results = {}


for com_round in tqdm(range(NO_COM_ROUNDS)):
    client_state_dicts = []
    # Per-device validation metrics gathered during this round
    losses, accuracies, disparate_impacts, eq_odds, dp = [], [], [], [], []

    # --- Local phase: every device trains, then contributes its parameters ---
    for device in devices:
        device.train(num_epochs=NO_EPOCHS)
        if device.adversarial:
            # Adversarial devices additionally run train_adversarial on deep
            # copies of their local and global models, then merge the adapters
            # into the model whose parameters get uploaded.
            temp_i = deepcopy(device.model)
            temp_g = deepcopy(device.global_model)
            device.train_adversarial(temp_g, temp_i, num_epochs=ADV_EPOCHS)
            device.add_adapters()
        client_state_dicts.append(device.get_model_params())

    # --- Aggregation: plain averaging of all uploads ---
    # (Krum alternative: client_update = krum_state_dicts(client_state_dicts, f))
    client_update = average_state_dicts(client_state_dicts)

    # --- Broadcast phase: push the aggregate to every device and validate ---
    for device in devices:
        device.set_global_model(client_update)
        device.update_model(client_update)
        if device.adversarial:
            device.transfer_mlp_to_lora(client_update)
        loss, accuracy, disparate_impact, eq_odd, dp_x = device.validate(verbose=False)
        losses.append(loss)
        accuracies.append(accuracy)
        disparate_impacts.append(disparate_impact)
        eq_odds.append(eq_odd)
        dp.append(dp_x)

    # --- Round summary: metrics averaged over devices ---
    print("Average accuracy: ", sum(accuracies) / len(accuracies))
    print("Average disparate impact: ", sum(disparate_impacts) / len(disparate_impacts))
    print("Average EQ Odds: ", sum(eq_odds) / len(eq_odds))
    print("Average DP: ", sum(dp) / len(dp))
    print("\n")



  2%|▏         | 1/50 [00:12<10:09, 12.44s/it]

Average accuracy:  75.56333333333333
Average disparate impact:  1.2937926076503634
Average EQ Odds:  0.01914530321550192
Average DP:  0.0023364390781731538




  4%|▍         | 2/50 [00:25<10:03, 12.58s/it]

Average accuracy:  82.58333333333333
Average disparate impact:  1.1616858943585942
Average EQ Odds:  0.3292255940505511
Average DP:  0.26338792430942326




  6%|▌         | 3/50 [00:37<09:40, 12.34s/it]

Average accuracy:  82.95333333333333
Average disparate impact:  1.1543903156356776
Average EQ Odds:  0.2730283622510198
Average DP:  0.2471421941181179




  8%|▊         | 4/50 [00:49<09:30, 12.41s/it]

Average accuracy:  83.11333333333334
Average disparate impact:  1.1591854577666314
Average EQ Odds:  0.2793898090200535
Average DP:  0.2514048835864161




 10%|█         | 5/50 [01:02<09:18, 12.42s/it]

Average accuracy:  83.16666666666667
Average disparate impact:  1.1616466956755975
Average EQ Odds:  0.2781381413779958
Average DP:  0.25167719046463244




 12%|█▏        | 6/50 [01:14<09:03, 12.35s/it]

Average accuracy:  83.33000000000001
Average disparate impact:  1.1616797488783785
Average EQ Odds:  0.3090764017438355
Average DP:  0.2597598134433194




 14%|█▍        | 7/50 [01:26<08:47, 12.27s/it]

Average accuracy:  83.19333333333333
Average disparate impact:  1.166420815773833
Average EQ Odds:  0.3276361153616018
Average DP:  0.2686749097295139




 16%|█▌        | 8/50 [01:39<08:47, 12.56s/it]

Average accuracy:  83.07666666666667
Average disparate impact:  1.169808920388392
Average EQ Odds:  0.35187491343021776
Average DP:  0.27926506715614074




 18%|█▊        | 9/50 [01:53<08:46, 12.83s/it]

Average accuracy:  82.99666666666667
Average disparate impact:  1.1743907561424205
Average EQ Odds:  0.39379696566206024
Average DP:  0.29379588561824493




 20%|██        | 10/50 [02:06<08:35, 12.89s/it]

Average accuracy:  82.76333333333334
Average disparate impact:  1.178775185753656
Average EQ Odds:  0.43585523031125933
Average DP:  0.30794418795358036




 22%|██▏       | 11/50 [02:18<08:18, 12.78s/it]

Average accuracy:  82.51
Average disparate impact:  1.1854992828267548
Average EQ Odds:  0.4722360715620832
Average DP:  0.32088547420846714




 24%|██▍       | 12/50 [02:31<08:06, 12.80s/it]

Average accuracy:  81.85
Average disparate impact:  1.1993616200853077
Average EQ Odds:  0.5466978561356587
Average DP:  0.34773850469951445




 26%|██▌       | 13/50 [02:45<08:04, 13.10s/it]

Average accuracy:  81.78666666666666
Average disparate impact:  1.201925892452894
Average EQ Odds:  0.5907208064101452
Average DP:  0.3584713483635691




 28%|██▊       | 14/50 [02:57<07:43, 12.88s/it]

Average accuracy:  81.66999999999999
Average disparate impact:  1.2045943599610696
Average EQ Odds:  0.62558086948583
Average DP:  0.3667624748335748




 30%|███       | 15/50 [03:10<07:27, 12.80s/it]

Average accuracy:  81.63333333333334
Average disparate impact:  1.2062924360875331
Average EQ Odds:  0.6342030417004799
Average DP:  0.3701892967603097




 32%|███▏      | 16/50 [03:22<07:13, 12.76s/it]

Average accuracy:  81.57000000000001
Average disparate impact:  1.2075453065480355
Average EQ Odds:  0.6297882564093004
Average DP:  0.3706358862269341




 34%|███▍      | 17/50 [03:37<07:14, 13.17s/it]

Average accuracy:  81.37333333333332
Average disparate impact:  1.2140991912118895
Average EQ Odds:  0.6297205463780158
Average DP:  0.3763730676078241




 36%|███▌      | 18/50 [03:50<07:08, 13.40s/it]

Average accuracy:  81.45
Average disparate impact:  1.214840873875437
Average EQ Odds:  0.6195919412240165
Average DP:  0.3754710310012959




 38%|███▊      | 19/50 [04:02<06:41, 12.95s/it]

Average accuracy:  81.36666666666667
Average disparate impact:  1.2168041098499545
Average EQ Odds:  0.629697635337476
Average DP:  0.38209991759151973




 40%|████      | 20/50 [04:15<06:23, 12.79s/it]

Average accuracy:  81.4
Average disparate impact:  1.215793152133739
Average EQ Odds:  0.618536687619877
Average DP:  0.3774906787715218




 42%|████▏     | 21/50 [04:27<06:04, 12.57s/it]

Average accuracy:  81.37333333333332
Average disparate impact:  1.2183614060332908
Average EQ Odds:  0.6277301592575227
Average DP:  0.3820227367914891




 44%|████▍     | 22/50 [04:39<05:47, 12.40s/it]

Average accuracy:  81.49000000000001
Average disparate impact:  1.2158264296998251
Average EQ Odds:  0.6030617198117465
Average DP:  0.37525636087509484




 46%|████▌     | 23/50 [04:51<05:35, 12.41s/it]

Average accuracy:  81.49666666666666
Average disparate impact:  1.2154581920823873
Average EQ Odds:  0.6035046677411272
Average DP:  0.3756882253952639




 48%|████▊     | 24/50 [05:05<05:32, 12.78s/it]

Average accuracy:  81.67333333333332
Average disparate impact:  1.213213953229499
Average EQ Odds:  0.595814759968541
Average DP:  0.372913677728426




 50%|█████     | 25/50 [05:19<05:26, 13.07s/it]

Average accuracy:  81.66333333333334
Average disparate impact:  1.2134546707650762
Average EQ Odds:  0.5915396880859658
Average DP:  0.37195197803371294




 52%|█████▏    | 26/50 [05:32<05:12, 13.02s/it]

Average accuracy:  81.66666666666667
Average disparate impact:  1.212227868278095
Average EQ Odds:  0.5848594466707561
Average DP:  0.3709868074390822




 54%|█████▍    | 27/50 [05:47<05:14, 13.66s/it]

Average accuracy:  81.65666666666667
Average disparate impact:  1.2131092317074519
Average EQ Odds:  0.6026876055409909
Average DP:  0.3740099033841019




 56%|█████▌    | 28/50 [05:59<04:52, 13.31s/it]

Average accuracy:  81.79666666666667
Average disparate impact:  1.2113342323921936
Average EQ Odds:  0.592412053090414
Average DP:  0.37116272345133405




 58%|█████▊    | 29/50 [06:10<04:24, 12.58s/it]

Average accuracy:  81.69
Average disparate impact:  1.2118713225647557
Average EQ Odds:  0.58839540691465
Average DP:  0.3704377874097604




 60%|██████    | 30/50 [06:22<04:08, 12.41s/it]

Average accuracy:  81.85999999999999
Average disparate impact:  1.2085307946016575
Average EQ Odds:  0.5788879143890674
Average DP:  0.3666562727691851




 62%|██████▏   | 31/50 [06:35<03:55, 12.41s/it]

Average accuracy:  81.94666666666667
Average disparate impact:  1.2064810407219322
Average EQ Odds:  0.571103235349576
Average DP:  0.3647001941521911




 64%|██████▍   | 32/50 [06:47<03:42, 12.34s/it]

Average accuracy:  81.84
Average disparate impact:  1.209204251058423
Average EQ Odds:  0.5729181701939019
Average DP:  0.3665384566249764




 66%|██████▌   | 33/50 [06:58<03:24, 12.00s/it]

Average accuracy:  81.90333333333334
Average disparate impact:  1.209228693127593
Average EQ Odds:  0.5650886280180418
Average DP:  0.3651210347901235




 68%|██████▊   | 34/50 [07:10<03:12, 12.06s/it]

Average accuracy:  81.97666666666666
Average disparate impact:  1.2066455298150924
Average EQ Odds:  0.5599764815525654
Average DP:  0.3627334282613145




 70%|███████   | 35/50 [07:21<02:57, 11.81s/it]

Average accuracy:  82.01
Average disparate impact:  1.2048088311627858
Average EQ Odds:  0.5737233676743305
Average DP:  0.3635925857971348




 72%|███████▏  | 36/50 [07:33<02:45, 11.82s/it]

Average accuracy:  81.96333333333334
Average disparate impact:  1.2057749633442874
Average EQ Odds:  0.569955018946694
Average DP:  0.3641735812402595




 74%|███████▍  | 37/50 [07:45<02:34, 11.89s/it]

Average accuracy:  82.05333333333334
Average disparate impact:  1.2051399878974662
Average EQ Odds:  0.5557662248071926
Average DP:  0.3600929286028426




 76%|███████▌  | 38/50 [07:57<02:20, 11.71s/it]

Average accuracy:  82.06333333333333
Average disparate impact:  1.2044854927283635
Average EQ Odds:  0.5363169205568595
Average DP:  0.35632775143563955




 78%|███████▊  | 39/50 [08:09<02:09, 11.81s/it]

Average accuracy:  82.1
Average disparate impact:  1.2052861681682192
Average EQ Odds:  0.5310741994904011
Average DP:  0.3560721491850808




 80%|████████  | 40/50 [08:20<01:58, 11.85s/it]

Average accuracy:  82.25666666666669
Average disparate impact:  1.2020637962574363
Average EQ Odds:  0.5314160801938713
Average DP:  0.35419267074165933




 82%|████████▏ | 41/50 [08:33<01:49, 12.16s/it]

Average accuracy:  82.31333333333333
Average disparate impact:  1.2013942051807118
Average EQ Odds:  0.5372046447510556
Average DP:  0.35417312163205106




 84%|████████▍ | 42/50 [08:45<01:36, 12.00s/it]

Average accuracy:  82.25333333333333
Average disparate impact:  1.202140368743507
Average EQ Odds:  0.5370770778463926
Average DP:  0.35546123823312686




 86%|████████▌ | 43/50 [08:57<01:23, 11.98s/it]

Average accuracy:  82.34
Average disparate impact:  1.1999280480372865
Average EQ Odds:  0.5209803550964035
Average DP:  0.35030270731928076




 88%|████████▊ | 44/50 [09:10<01:12, 12.15s/it]

Average accuracy:  82.41000000000001
Average disparate impact:  1.1981149736475252
Average EQ Odds:  0.5432302240468065
Average DP:  0.3527623698480672




 90%|█████████ | 45/50 [09:23<01:02, 12.50s/it]

Average accuracy:  82.47333333333333
Average disparate impact:  1.1979289752380455
Average EQ Odds:  0.5260030877686765
Average DP:  0.34897781464460176




 92%|█████████▏| 46/50 [09:35<00:50, 12.53s/it]

Average accuracy:  82.35666666666665
Average disparate impact:  1.199764717025986
Average EQ Odds:  0.5228871697374142
Average DP:  0.34956249667649597




 94%|█████████▍| 47/50 [09:48<00:37, 12.52s/it]

Average accuracy:  82.28333333333333
Average disparate impact:  1.2033672000635034
Average EQ Odds:  0.5213735182416537
Average DP:  0.3520580978547629




 96%|█████████▌| 48/50 [10:00<00:24, 12.46s/it]

Average accuracy:  82.34
Average disparate impact:  1.2015946615707631
Average EQ Odds:  0.525121626790039
Average DP:  0.3517504197363988




 98%|█████████▊| 49/50 [10:12<00:12, 12.29s/it]

Average accuracy:  82.41333333333333
Average disparate impact:  1.2004988096977256
Average EQ Odds:  0.5130981812612023
Average DP:  0.34830551193037607




100%|██████████| 50/50 [10:25<00:00, 12.51s/it]

Average accuracy:  82.47333333333333
Average disparate impact:  1.1982701424764106
Average EQ Odds:  0.5127806176432114
Average DP:  0.34759796466212645





