In [1]:
# Make sure you're on Python > 3.8
!pip install -r requirements.txt --quiet

In [2]:
from collections import OrderedDict

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data import DataLoader, TensorDataset

from sklearn.model_selection import train_test_split

import flwr as fl
from flwr.simulation import run_simulation
from flwr.client import Client, ClientApp, NumPyClient
from flwr.common import Context
from flwr.server import ServerApp, ServerConfig, ServerAppComponents

In [3]:
# Device that every model and batch in this notebook is moved to (CPU).
DEVICE = torch.device('cpu')

In [4]:
!mkdir '.kaggle'
!mkdir '.kaggle/data'

with open(".kaggle/kaggle.json", 'a+') as f:
    f.write('{"username":"rajaxarcmu","key":"68d40c5e38e1c786ab57736bc5c9b2cb"}')
    
!chmod 600 '.kaggle/kaggle.json'
!kaggle datasets download -d 'danofer/compass'
!unzip -qo compass.zip -d '.kaggle/data'

mkdir: .kaggle: File exists
mkdir: .kaggle/data: File exists
Dataset URL: https://www.kaggle.com/datasets/danofer/compass
License(s): DbCL-1.0
compass.zip: Skipping, found more recently modified local copy (use --force to force download)


In [5]:
# List the extracted dataset files to confirm the download/unzip step worked.
!ls .kaggle/data

compas-scores-raw.csv
cox-violent-parsed.csv
cox-violent-parsed_filt.csv
[1m[36mpropublicaCompassRecividism_data_fairml.csv[m[m


In [6]:
# Read the ProPublica COMPAS table prepared for FairML; the extracted
# "…_fairml.csv" entry is a directory that contains the actual CSV file.
df = pd.read_csv(
    '.kaggle/data/propublicaCompassRecividism_data_fairml.csv/propublica_data_for_fairml.csv'
)
# Quick sanity check on the number of rows and columns.
print(df.shape)

(6172, 12)


In [7]:
# Add a binary `caucasian` column: 1 when every other race indicator column is
# 0 for that row, otherwise 0.
non_caucasian_flags = (
    df['African_American']
    + df['Asian']
    + df['Hispanic']
    + df['Native_American']
    + df['Other']
)
df['caucasian'] = (non_caucasian_flags == 0).astype(int)

In [8]:
# Number of federated clients; each client represents a siloed organization
# that holds its own partition of the training data.
NUM_CLIENTS = 10

In [9]:
from datasets import Dataset
from flwr_datasets.partitioner import DirichletPartitioner

In [10]:
# Seed shared by every split in this cell. The global train/test split used to
# be unseeded, so the held-out test set (and every metric computed on it)
# changed on each run.
SEED = 42
batch_size = 32

feature_columns = ['Number_of_Priors', 'score_factor','Age_Above_FourtyFive', 'Age_Below_TwentyFive', 'Misdemeanor']
label_column = 'Two_yr_Recidivism'
sensitive_column = 'caucasian'


def _make_loader(features, labels, sensitive, shuffle=False):
    """Wrap three aligned NumPy arrays in a DataLoader of float tensors.

    Args:
        features: array of model inputs, one row per sample.
        labels: array of target values, one per sample.
        sensitive: array with the sensitive attribute, one per sample.
        shuffle: whether the loader reshuffles the samples every epoch.

    Returns:
        A DataLoader whose batches are (features, labels, sensitive) triples,
        batched with the notebook-level `batch_size`.
    """
    tensors = [torch.from_numpy(arr).float() for arr in (features, labels, sensitive)]
    return DataLoader(TensorDataset(*tensors), batch_size=batch_size, shuffle=shuffle)


# Hold out 20% of the rows as the shared test set.
trainset, testset = train_test_split(df, test_size=0.2, random_state=SEED)

# Partition the training rows across NUM_CLIENTS clients with the
# DirichletPartitioner, keyed on the `caucasian` column.
# NOTE(review): the partitioner has its own seeding behaviour that is not set
# here — confirm against the flwr-datasets docs if exact partitions matter.
train_hf_dataset = Dataset.from_pandas(trainset)
partitioner = DirichletPartitioner(
    num_partitions=NUM_CLIENTS,
    partition_by="caucasian",
    alpha=0.5,
    min_partition_size=(len(trainset) // (4 * NUM_CLIENTS)),
    self_balancing=True,
    shuffle=True
)
partitioner.dataset = train_hf_dataset

# One pandas DataFrame per client.
datasets = [partitioner.load_partition(i).to_pandas() for i in range(NUM_CLIENTS)]

train_loaders = []
val_loaders = []

for client_df in datasets:
    client_x = client_df[feature_columns].values
    client_y = client_df[label_column].values
    client_sensitive = client_df[sensitive_column].values

    # 75/25 train/validation split per client, stratified on the label so both
    # halves keep the client's label balance.
    train_x, val_x, train_y, val_y, sensitive_train, sensitive_val = train_test_split(
        client_x, client_y, client_sensitive,
        test_size=0.25, shuffle=True, stratify=client_y, random_state=SEED
    )

    # Only the training loader is shuffled; validation order is fixed.
    train_loaders.append(_make_loader(train_x, train_y, sensitive_train, shuffle=True))
    val_loaders.append(_make_loader(val_x, val_y, sensitive_val))

# Shared test loader built from the held-out rows (not partitioned by client).
test_loader = _make_loader(
    testset[feature_columns].values,
    testset[label_column].values,
    testset[sensitive_column].values,
)



In [11]:
class BaselineNN(nn.Module):
    """Small fully connected network: 5 inputs -> 16 -> 8 -> 1 output.

    Both hidden layers are followed by ReLU and the final layer by a sigmoid,
    so `forward` returns one sigmoid-activated value per input row.
    """

    def __init__(self):
        super().__init__()
        # Attribute names determine the state_dict keys, so they stay fc1..fc3.
        self.fc1 = nn.Linear(in_features=5, out_features=16)
        self.fc2 = nn.Linear(in_features=16, out_features=8)
        self.fc3 = nn.Linear(in_features=8, out_features=1)

    def forward(self, x):
        """Run the three linear layers with ReLU, ReLU and sigmoid activations."""
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = torch.relu(layer(hidden))
        return torch.sigmoid(self.fc3(hidden))

def compute_eod(preds, labels, sensitive_feature):
    """Difference in positive-prediction rate among true positives between groups.

    Predictions are thresholded at 0.5. Among the samples whose label is 1,
    the mean binary prediction is computed separately for
    `sensitive_feature == 0` and `sensitive_feature == 1`, and the function
    returns (rate for group 0) - (rate for group 1).

    Args:
        preds: tensor of predicted scores, shape (N,) or (N, 1).
        labels: tensor of ground-truth labels, shape (N,) or (N, 1).
        sensitive_feature: tensor of group indicators, shape (N,) or (N, 1).

    Returns:
        The rate difference as a Python float. It is NaN when either group
        has no sample with label 1, because the rate is undefined there.
    """
    # Flatten everything so that (N,) and (N, 1) inputs line up element-wise.
    # Mixing the two used to broadcast the mask to (N, N) and fail on indexing.
    preds_binary = (preds.reshape(-1) >= 0.5).float()
    y_true_mask = labels.reshape(-1) == 1
    groups = sensitive_feature.reshape(-1)

    rates = []
    for group_value in (0, 1):
        selected = preds_binary[y_true_mask & (groups == group_value)]
        # Make the undefined case explicit instead of relying on mean-of-empty.
        rates.append(selected.mean().item() if selected.numel() > 0 else float('nan'))

    return rates[0] - rates[1]

def train(net, trainloader, epochs, verbose=True):
    """Train `net` on `trainloader` for `epochs` passes with BCE loss and Adam.

    A fresh Adam optimizer with default hyper-parameters is created on every
    call. Each batch must be a (inputs, labels, sensitive_features) triple.
    After each epoch the sample-weighted mean loss, the accuracy at a 0.5
    threshold and the EOD (via `compute_eod`) are computed over the batches
    seen in that epoch, and printed when `verbose` is true.

    Returns:
        None; `net` is updated in place.
    """
    loss_fn = nn.BCELoss()
    adam = optim.Adam(net.parameters())
    net.train()

    for epoch in range(epochs):
        epoch_loss = 0.0
        n_correct = 0
        n_seen = 0
        pred_chunks, label_chunks, sensitive_chunks = [], [], []

        for batch in trainloader:
            inputs, labels, sensitive_features = (t.to(DEVICE) for t in batch)
            # The model emits (batch, 1); reshape the labels to match.
            labels = labels.view(-1, 1)

            adam.zero_grad()
            outputs = net(inputs)
            batch_loss = loss_fn(outputs, labels)
            batch_loss.backward()
            adam.step()

            # Weight by batch size so the epoch mean is per-sample.
            epoch_loss += batch_loss.item() * inputs.size(0)
            n_seen += labels.size(0)
            n_correct += ((outputs >= 0.5).float() == labels).sum().item()

            # Keep CPU copies for the end-of-epoch EOD computation.
            pred_chunks.append(outputs.detach().cpu())
            label_chunks.append(labels.detach().cpu())
            sensitive_chunks.append(sensitive_features.cpu())

        eod = compute_eod(
            torch.cat(pred_chunks),
            torch.cat(label_chunks),
            torch.cat(sensitive_chunks),
        )

        epoch_loss /= len(trainloader.dataset)
        epoch_acc = n_correct / n_seen
        if verbose:
            print(f"Epoch {epoch+1}/{epochs} - Loss: {epoch_loss:.4f} - Acc: {epoch_acc:.4f} - EOD: {eod:.4f}")

def test(net, testloader, verbose=True):
    """Evaluate `net` on `testloader` without gradient tracking.

    Each batch must be a (inputs, labels, sensitive_features) triple. The
    model is put in eval mode; BCE loss is accumulated weighted by batch size,
    accuracy uses a 0.5 threshold, and EOD is computed with `compute_eod` over
    all batches.

    Returns:
        (loss, acc, eod): mean per-sample loss, accuracy and EOD.
    """
    loss_fn = nn.BCELoss()
    net.eval()

    loss = 0.0
    n_correct = 0
    n_seen = 0
    pred_chunks, label_chunks, sensitive_chunks = [], [], []

    with torch.no_grad():
        for batch in testloader:
            inputs, labels, sensitive_features = (t.to(DEVICE) for t in batch)
            # The model emits (batch, 1); reshape the labels to match.
            labels = labels.view(-1, 1)
            outputs = net(inputs)

            loss += loss_fn(outputs, labels).item() * inputs.size(0)
            n_seen += labels.size(0)
            n_correct += ((outputs >= 0.5).float() == labels).sum().item()

            # Keep CPU copies for the final EOD computation.
            pred_chunks.append(outputs.detach().cpu())
            label_chunks.append(labels.detach().cpu())
            sensitive_chunks.append(sensitive_features.cpu())

    eod = compute_eod(
        torch.cat(pred_chunks),
        torch.cat(label_chunks),
        torch.cat(sensitive_chunks),
    )

    loss /= len(testloader.dataset)
    acc = n_correct / n_seen
    if verbose:
        print(f"Test Loss: {loss:.4f} - Acc: {acc:.4f} - EOD: {eod:.4f}")
    return loss, acc, eod

# Centralized Learning

In [12]:
# Instantiate the baseline network with freshly initialized weights.
model = BaselineNN()

In [13]:
# Train the single shared `model` on each client's partition in turn. The
# weights are not reset between clients, so later rows of the printout reflect
# everything trained so far. Metrics are reported on the shared test loader.
model = model.to(DEVICE)
epochs = 10

for i in range(NUM_CLIENTS):
    train_loader, val_loader = train_loaders[i], val_loaders[i]

    # One call to train() per epoch, each followed by a validation pass.
    for epoch in range(epochs):
        train(model, train_loader, 1, verbose=False)
        loss, acc, eod = test(model, val_loader, verbose=False)

    loss, acc, eod = test(model, test_loader, verbose=False)
    print(f"Client {i} - Test Loss: {loss:.4f} - Acc: {acc:.4f} - EOD: {eod:.4f}")

Client 0 - Test Loss: 0.6644 - Acc: 0.6202 - EOD: -0.0218
Client 1 - Test Loss: 0.6420 - Acc: 0.6348 - EOD: 0.0251
Client 2 - Test Loss: 0.6250 - Acc: 0.6502 - EOD: 0.0682
Client 3 - Test Loss: 0.6154 - Acc: 0.6721 - EOD: 0.1525
Client 4 - Test Loss: 0.6138 - Acc: 0.6696 - EOD: 0.1451
Client 5 - Test Loss: 0.6199 - Acc: 0.6680 - EOD: 0.1388
Client 6 - Test Loss: 0.6108 - Acc: 0.6794 - EOD: 0.1748
Client 7 - Test Loss: 0.6091 - Acc: 0.6680 - EOD: 0.1243
Client 8 - Test Loss: 0.6199 - Acc: 0.6664 - EOD: 0.1174
Client 9 - Test Loss: 0.6202 - Acc: 0.6688 - EOD: 0.1510


# Federated Learning with Flower

In [14]:
def get_params(net):
    """Return every state_dict entry of `net` as a NumPy array, in state_dict order."""
    return [tensor.cpu().numpy() for tensor in net.state_dict().values()]

def set_params(net, params):
    """Load `params` into `net`, pairing them with state_dict keys by position.

    Args:
        net: module whose state is overwritten.
        params: sequence of arrays in the same order as `net.state_dict()`.

    The state is loaded with strict=True, so a missing or unexpected key makes
    `load_state_dict` raise.
    """
    new_state = OrderedDict()
    for key, array in zip(net.state_dict().keys(), params):
        new_state[key] = torch.Tensor(array)
    net.load_state_dict(new_state, strict=True)



In [15]:
from custom_flwr.server_app import server_fn as server_fn_custom
from custom_flwr.client_app import client_fn as client_fn_custom

def server_fn(context: Context):
    """Overwrite the run config on `context` and delegate to the custom server_fn.

    The notebook-level settings below replace whatever `context.run_config`
    held before the call; the result of `server_fn_custom(context)` is
    returned unchanged.
    """
    notebook_run_config = {
        'num-server-rounds' : 10,
        'fraction-fit': 0.25,
        'fraction-evaluate': 0.5,
        'local-epochs': 1,
        # Stored as a string rather than a torch.device object.
        'server-device': str(DEVICE),
        'use-wandb': False
    }
    context.run_config = notebook_run_config
    return server_fn_custom(context)

def client_fn(context: Context):
    """Delegate client construction to `client_fn_custom`, passing `context` through unchanged."""
    return client_fn_custom(context)

# Wrap the two factory functions in Flower app objects; these are what
# run_simulation receives as client_app / server_app.
client = ClientApp(client_fn=client_fn)
server = ServerApp(server_fn=server_fn)

In [16]:
# No explicit per-client resource limits are passed to the simulation backend.
backend_config = {"client_resources": None}

# One simulated supernode per federated client. Derived from NUM_CLIENTS
# instead of repeating the literal 10, so the two values cannot drift apart.
NUM_PARTITIONS = NUM_CLIENTS

run_simulation(
    server_app=server,
    client_app=client,
    num_supernodes=NUM_PARTITIONS,
    backend_config=backend_config,
)

[92mINFO [0m:      Starting Flower ServerApp, config: num_rounds=10, no round_timeout
[92mINFO [0m:      
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Using initial global parameters provided by strategy
[92mINFO [0m:      Starting evaluation of initial global parameters
[92mINFO [0m:      initial parameters (loss, other metrics): 23.19512327053608, {'centralized_accuracy': 0.46234817813765183, 'eod': 0.004106879234313965}
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 1]


Test Accuracy: 0.46234817813765183 - Test Loss: 23.19512327053608 - EOD: 0.004106879234313965


[92mINFO [0m:      configure_fit: strategy sampled 2 clients (out of 10)
[92mINFO [0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO [0m:      fit progress: (1, 21.691564119779148, {'centralized_accuracy': 0.5360323886639676, 'eod': 0.0}, 20.450132864993066)
[92mINFO [0m:      configure_evaluate: strategy sampled 5 clients (out of 10)


[36m(ClientAppActor pid=91147)[0m Avg Train Loss: 0.7242121365335252 - EOD: -0.0882352888584137
Test Accuracy: 0.5360323886639676 - Test Loss: 21.691564119779148 - EOD: 0.0
[36m(ClientAppActor pid=91147)[0m Test Accuracy: 0.53125 - Test Loss: 21.90850321451823 - EOD: 0.0


[92mINFO [0m:      aggregate_evaluate: received 5 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 2]
[92mINFO [0m:      configure_fit: strategy sampled 2 clients (out of 10)


Aggregated EOD: nan
[36m(ClientAppActor pid=91148)[0m Avg Train Loss: 0.6814531485239664 - EOD: -0.13246268033981323
[36m(ClientAppActor pid=91146)[0m Test Accuracy: 0.4919786096256685 - Test Loss: 21.813975642124813 - EOD: nan


[92mINFO [0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO [0m:      fit progress: (2, 20.198211500277885, {'centralized_accuracy': 0.6323886639676113, 'eod': 0.20078730583190918}, 26.45526413107291)
[92mINFO [0m:      configure_evaluate: strategy sampled 5 clients (out of 10)


[36m(ClientAppActor pid=91144)[0m Avg Train Loss: 0.6314816661179066 - EOD: nan
[36m(ClientAppActor pid=91145)[0m Test Accuracy: 0.5294117647058824 - Test Loss: 19.475313742955525 - EOD: 0.0[32m [repeated 3x across cluster][0m
[36m(ClientAppActor pid=91146)[0m Avg Train Loss: 0.7214738205075264 - EOD: -0.3589743375778198
Test Accuracy: 0.6323886639676113 - Test Loss: 20.198211500277885 - EOD: 0.20078730583190918


[92mINFO [0m:      aggregate_evaluate: received 5 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 3]
[92mINFO [0m:      configure_fit: strategy sampled 2 clients (out of 10)


Aggregated EOD: nan


[92mINFO [0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO [0m:      fit progress: (3, 20.206304637285378, {'centralized_accuracy': 0.6728744939271255, 'eod': 0.22586244344711304}, 28.16603720607236)
[92mINFO [0m:      configure_evaluate: strategy sampled 5 clients (out of 10)


Test Accuracy: 0.6728744939271255 - Test Loss: 20.206304637285378 - EOD: 0.22586244344711304


[92mINFO [0m:      aggregate_evaluate: received 5 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 4]
[92mINFO [0m:      configure_fit: strategy sampled 2 clients (out of 10)


Aggregated EOD: nan


[92mINFO [0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO [0m:      fit progress: (4, 20.109754246014816, {'centralized_accuracy': 0.6453441295546559, 'eod': 0.243297278881073}, 29.13750467915088)
[92mINFO [0m:      configure_evaluate: strategy sampled 5 clients (out of 10)


Test Accuracy: 0.6453441295546559 - Test Loss: 20.109754246014816 - EOD: 0.243297278881073


[92mINFO [0m:      aggregate_evaluate: received 5 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 5]
[92mINFO [0m:      configure_fit: strategy sampled 2 clients (out of 10)


Aggregated EOD: nan


[92mINFO [0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO [0m:      fit progress: (5, 20.338205669170772, {'centralized_accuracy': 0.6510121457489878, 'eod': 0.20766818523406982}, 30.001876011956483)
[92mINFO [0m:      configure_evaluate: strategy sampled 5 clients (out of 10)


Test Accuracy: 0.6510121457489878 - Test Loss: 20.338205669170772 - EOD: 0.20766818523406982


[92mINFO [0m:      aggregate_evaluate: received 5 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 6]
[92mINFO [0m:      configure_fit: strategy sampled 2 clients (out of 10)


[36m(ClientAppActor pid=91148)[0m Test Accuracy: 0.6684491978609626 - Test Loss: 19.495735595623653 - EOD: nan[32m [repeated 3x across cluster][0m
Aggregated EOD: nan


[92mINFO [0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO [0m:      fit progress: (6, 20.730839665119465, {'centralized_accuracy': 0.6534412955465587, 'eod': 0.19201558828353882}, 31.88646070403047)
[92mINFO [0m:      configure_evaluate: strategy sampled 5 clients (out of 10)


[36m(ClientAppActor pid=91145)[0m Test Accuracy: 0.6666666666666666 - Test Loss: 12.068859726190567 - EOD: 0.28409093618392944[32m [repeated 16x across cluster][0m
[36m(ClientAppActor pid=91147)[0m Avg Train Loss: 0.6659345726172129 - EOD: 0.26264724135398865[32m [repeated 7x across cluster][0m
Test Accuracy: 0.6534412955465587 - Test Loss: 20.730839665119465 - EOD: 0.19201558828353882


[92mINFO [0m:      aggregate_evaluate: received 5 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 7]
[92mINFO [0m:      configure_fit: strategy sampled 2 clients (out of 10)


Aggregated EOD: nan


[92mINFO [0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO [0m:      fit progress: (7, 21.312108267576267, {'centralized_accuracy': 0.6267206477732794, 'eod': 0.17986547946929932}, 33.60582025395706)
[92mINFO [0m:      configure_evaluate: strategy sampled 5 clients (out of 10)


[36m(ClientAppActor pid=91148)[0m Avg Train Loss: 0.6361146966616312 - EOD: nan
Test Accuracy: 0.6267206477732794 - Test Loss: 21.312108267576267 - EOD: 0.17986547946929932


[92mINFO [0m:      aggregate_evaluate: received 5 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 8]
[92mINFO [0m:      configure_fit: strategy sampled 2 clients (out of 10)


Aggregated EOD: nan


[92mINFO [0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO [0m:      fit progress: (8, 20.45500750725086, {'centralized_accuracy': 0.6242914979757085, 'eod': 0.165158212184906}, 34.57930748700164)
[92mINFO [0m:      configure_evaluate: strategy sampled 5 clients (out of 10)


Test Accuracy: 0.6242914979757085 - Test Loss: 20.45500750725086 - EOD: 0.165158212184906


[92mINFO [0m:      aggregate_evaluate: received 5 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 9]
[92mINFO [0m:      configure_fit: strategy sampled 2 clients (out of 10)


Aggregated EOD: nan


[92mINFO [0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO [0m:      fit progress: (9, 20.148482438845512, {'centralized_accuracy': 0.6437246963562753, 'eod': 0.20356130599975586}, 35.687598881078884)
[92mINFO [0m:      configure_evaluate: strategy sampled 5 clients (out of 10)


Test Accuracy: 0.6437246963562753 - Test Loss: 20.148482438845512 - EOD: 0.20356130599975586


[92mINFO [0m:      aggregate_evaluate: received 5 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 10]
[92mINFO [0m:      configure_fit: strategy sampled 2 clients (out of 10)


[36m(ClientAppActor pid=91148)[0m Test Accuracy: 0.6524064171122995 - Test Loss: 19.831622501214344 - EOD: nan[32m [repeated 6x across cluster][0m
Aggregated EOD: nan


[92mINFO [0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO [0m:      fit progress: (10, 20.15477974903889, {'centralized_accuracy': 0.6534412955465587, 'eod': 0.20356130599975586}, 37.2746043340303)
[92mINFO [0m:      configure_evaluate: strategy sampled 5 clients (out of 10)


[36m(ClientAppActor pid=91145)[0m Test Accuracy: 0.725 - Test Loss: 11.792560338973999 - EOD: -0.13333332538604736[32m [repeated 13x across cluster][0m
[36m(ClientAppActor pid=91147)[0m Avg Train Loss: 0.6028006315231323 - EOD: 0.17307692766189575[32m [repeated 4x across cluster][0m
Test Accuracy: 0.6534412955465587 - Test Loss: 20.15477974903889 - EOD: 0.20356130599975586


[92mINFO [0m:      aggregate_evaluate: received 5 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [SUMMARY]
[92mINFO [0m:      Run finished 10 round(s) in 38.16s
[92mINFO [0m:      	History (loss, distributed):
[92mINFO [0m:      		round 1: 20.820538412515337
[92mINFO [0m:      		round 2: 18.553126753185552
[92mINFO [0m:      		round 3: 18.50002596710295
[92mINFO [0m:      		round 4: 18.19934358951383
[92mINFO [0m:      		round 5: 17.561622467454065
[92mINFO [0m:      		round 6: 18.647971057490288
[92mINFO [0m:      		round 7: 19.252544302165962
[92mINFO [0m:      		round 8: 19.870376138316313
[92mINFO [0m:      		round 9: 19.18502187610274
[92mINFO [0m:      		round 10: 18.340530537797655
[92mINFO [0m:      	History (loss, centralized):
[92mINFO [0m:      		round 0: 23.19512327053608
[92mINFO [0m:      		round 1: 21.691564119779148
[92mINFO [0m:      		round 2: 20.198211500277885
[92mINFO [0m:      		round 3: 20.206304637285378
[9

Aggregated EOD: nan
[36m(ClientAppActor pid=91145)[0m Avg Train Loss: 0.641953127251731 - EOD: nan[32m [repeated 4x across cluster][0m
[36m(ClientAppActor pid=91145)[0m Test Accuracy: 0.6559139784946236 - Test Loss: 19.85598025719325 - EOD: nan[32m [repeated 4x across cluster][0m
[36m(ClientAppActor pid=91146)[0m Test Accuracy: 0.6410256410256411 - Test Loss: 12.131374806165695 - EOD: 0.30000001192092896[32m [repeated 3x across cluster][0m


