In [1]:
# Make sure you're on Python > 3.8
!pip install -r requirements.txt --quiet

In [2]:
from collections import OrderedDict

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data import DataLoader, TensorDataset

from sklearn.model_selection import train_test_split

import flwr as fl
from flwr.simulation import run_simulation
from flwr.client import Client, ClientApp, NumPyClient
from flwr.common import Context
from flwr.server import ServerApp, ServerConfig, ServerAppComponents

In [3]:
DEVICE = torch.device('cpu')

In [4]:
!mkdir '.kaggle'
!mkdir '.kaggle/data'

with open(".kaggle/kaggle.json", 'a+') as f:
    f.write('{"username":"rajaxarcmu","key":"68d40c5e38e1c786ab57736bc5c9b2cb"}')
    
!chmod 600 '.kaggle/kaggle.json'
!kaggle datasets download -d 'danofer/compass'
!unzip -qo compass.zip -d '.kaggle/data'

mkdir: .kaggle: File exists
mkdir: .kaggle/data: File exists
Dataset URL: https://www.kaggle.com/datasets/danofer/compass
License(s): DbCL-1.0
compass.zip: Skipping, found more recently modified local copy (use --force to force download)


In [5]:
!ls .kaggle/data

compas-scores-raw.csv
cox-violent-parsed.csv
cox-violent-parsed_filt.csv
[1m[36mpropublicaCompassRecividism_data_fairml.csv[m[m


In [6]:
df = pd.read_csv('.kaggle/data/propublicaCompassRecividism_data_fairml.csv/propublica_data_for_fairml.csv')
print(df.shape)

(6172, 12)


In [7]:
df['caucasian'] = ((df['African_American'] + df['Asian'] + df['Hispanic'] + df['Native_American'] + df['Other']) == 0).astype(int)

In [8]:
NUM_CLIENTS = 10
# REPRESENTS SILO'D ORGANIZATIONS

In [9]:
from datasets import Dataset
from flwr_datasets.partitioner import DirichletPartitioner

In [10]:
trainset, testset = train_test_split(df, test_size=0.2)
batch_size = 32

ds = Dataset.from_pandas(trainset)
partitioner = DirichletPartitioner(
    num_partitions=NUM_CLIENTS,
    partition_by="caucasian",
    alpha=0.5,
    min_partition_size=(len(trainset) // (4 * NUM_CLIENTS)),
    self_balancing=True,
    shuffle=True
)

partitioner.dataset = ds
datasets = []
for i in range(NUM_CLIENTS):
    curr_partition = partitioner.load_partition(i)
    datasets.append(curr_partition.to_pandas())

train_loaders = []
val_loaders = []

feature_columns = ['Number_of_Priors', 'score_factor','Age_Above_FourtyFive', 'Age_Below_TwentyFive', 'Misdemeanor']

for ds in datasets:
    train_x = ds[feature_columns].values
    train_y = ds['Two_yr_Recidivism'].values
    sensitive_feature = ds['caucasian'].values

    train_x, val_x, train_y, val_y, sensitive_train, sensitive_val = train_test_split(
        train_x, train_y, sensitive_feature, test_size=0.25, shuffle=True, stratify=train_y, random_state=42
    )
    
    train_x_tensor = torch.from_numpy(train_x).float()
    train_y_tensor = torch.from_numpy(train_y).float()
    sensitive_train_tensor = torch.from_numpy(sensitive_train).float()

    valid_x_tensor = torch.from_numpy(val_x).float()
    valid_y_tensor = torch.from_numpy(val_y).float()
    sensitive_val_tensor = torch.from_numpy(sensitive_val).float()

    # Create TensorDataset and DataLoader, including the sensitive attribute
    train_dataset = TensorDataset(train_x_tensor, train_y_tensor, sensitive_train_tensor)
    valid_dataset = TensorDataset(valid_x_tensor, valid_y_tensor, sensitive_val_tensor)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(valid_dataset, batch_size=batch_size)

    train_loaders.append(train_loader)
    val_loaders.append(val_loader)

# For test data
test_x = testset[feature_columns].values
test_y = testset['Two_yr_Recidivism'].values
sensitive_test = testset['caucasian'].values

test_x_tensor = torch.from_numpy(test_x).float()
test_y_tensor = torch.from_numpy(test_y).float()
sensitive_test_tensor = torch.from_numpy(sensitive_test).float()

test_dataset = TensorDataset(test_x_tensor, test_y_tensor, sensitive_test_tensor)
test_loader = DataLoader(test_dataset, batch_size=batch_size)



In [11]:
class BaselineNN(nn.Module):
    def __init__(self):
        super(BaselineNN, self).__init__()
        self.fc1 = nn.Linear(5, 16)
        self.fc2 = nn.Linear(16, 8)
        self.fc3 = nn.Linear(8, 1)
        
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.sigmoid(self.fc3(x))
        return x

def compute_eod(preds, labels, sensitive_feature):
    preds_binary = (preds >= 0.5).float()
    y_true_mask = (labels == 1).view(-1)

    p_a0 = preds_binary[y_true_mask & (sensitive_feature == 0)].mean().item()
    p_a1 = preds_binary[y_true_mask & (sensitive_feature == 1)].mean().item()

    eod = p_a0 - p_a1
    return eod

def train(net, trainloader, epochs, verbose=True):
    """
    Train Network on Training Set
    """
    criterion = nn.BCELoss()
    optimizer = optim.Adam(net.parameters())
    net.train()
    for epoch in range(epochs):
        correct, total, epoch_loss = 0, 0, 0.0
        all_preds, all_labels, all_sensitives = [], [], []
        
        for inputs, labels, sensitive_features in trainloader:
            inputs, labels, sensitive_features = inputs.to(DEVICE), labels.to(DEVICE), sensitive_features.to(DEVICE)
            optimizer.zero_grad()
            outputs = net(inputs)
            labels = labels.view(-1, 1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item() * inputs.size(0)
            predicted = (outputs >= 0.5).float()
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            
            # Append predictions and sensitive data for EOD computation
            all_preds.append(outputs.detach().cpu())
            all_labels.append(labels.detach().cpu())
            all_sensitives.append(sensitive_features.cpu())
        
        # Compute EOD at the end of the epoch
        all_preds = torch.cat(all_preds)
        all_labels = torch.cat(all_labels)
        all_sensitives = torch.cat(all_sensitives)
        
        eod = compute_eod(all_preds, all_labels, all_sensitives)
        
        epoch_loss /= len(trainloader.dataset)
        epoch_acc = correct / total
        if verbose:
            print(f"Epoch {epoch+1}/{epochs} - Loss: {epoch_loss:.4f} - Acc: {epoch_acc:.4f} - EOD: {eod:.4f}")

def test(net, testloader, verbose=True):
    criterion = nn.BCELoss()
    net.eval()
    correct, total, loss = 0, 0, 0.0
    all_preds, all_labels, all_sensitives = [], [], []
    
    with torch.no_grad():
        for inputs, labels, sensitive_features in testloader:
            inputs, labels, sensitive_features = inputs.to(DEVICE), labels.to(DEVICE), sensitive_features.to(DEVICE)
            outputs = net(inputs)
            labels = labels.view(-1, 1)
            loss += criterion(outputs, labels).item() * inputs.size(0)
            predicted = (outputs >= 0.5).float()
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            
            # Append predictions and sensitive data for EOD computation
            all_preds.append(outputs.detach().cpu())
            all_labels.append(labels.detach().cpu())
            all_sensitives.append(sensitive_features.cpu())
    
    # Compute EOD at the end of testing
    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)
    all_sensitives = torch.cat(all_sensitives)
    
    eod = compute_eod(all_preds, all_labels, all_sensitives)
    
    loss /= len(testloader.dataset)
    acc = correct / total
    if verbose:
        print(f"Test Loss: {loss:.4f} - Acc: {acc:.4f} - EOD: {eod:.4f}")
    return loss, acc, eod

# Centralized Learning

In [12]:
model = BaselineNN()

In [13]:
for i in range(NUM_CLIENTS):
    train_loader = train_loaders[i]
    val_loader = val_loaders[i]
    model = model.to(DEVICE)
    epochs = 10

    for epoch in range(epochs):
        train(model, train_loader, 1, verbose=False)
        loss, acc, eod = test(model, val_loader, verbose=False)

    loss, acc, eod = test(model, test_loader, verbose=False)
    print(f"Client {i} - Test Loss: {loss:.4f} - Acc: {acc:.4f} - EOD: {eod:.4f}")

Client 0 - Test Loss: 0.6418 - Acc: 0.6583 - EOD: 0.1930
Client 1 - Test Loss: 0.6324 - Acc: 0.6543 - EOD: 0.1956
Client 2 - Test Loss: 0.6143 - Acc: 0.6737 - EOD: 0.1930
Client 3 - Test Loss: 0.6165 - Acc: 0.6664 - EOD: 0.2099
Client 4 - Test Loss: 0.6117 - Acc: 0.6688 - EOD: 0.2093
Client 5 - Test Loss: 0.6185 - Acc: 0.6729 - EOD: 0.1939
Client 6 - Test Loss: 0.6108 - Acc: 0.6713 - EOD: 0.2442
Client 7 - Test Loss: 0.6097 - Acc: 0.6729 - EOD: 0.1939
Client 8 - Test Loss: 0.6075 - Acc: 0.6713 - EOD: 0.1978
Client 9 - Test Loss: 0.6119 - Acc: 0.6745 - EOD: 0.2664


# Federated Learning with Flower

In [14]:
def get_params(net):
    return [val.cpu().numpy() for _, val in net.state_dict().items()]

def set_params(net, params):
    params_dict = zip(net.state_dict().keys(), params)
    state_dict = OrderedDict({k: torch.Tensor(v) for k,v in params_dict})
    net.load_state_dict(state_dict, strict=True)
    
class FlowerClient(fl.client.NumPyClient):
    def __init__(self, net, trainloader, valloader):
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader
    
    def get_parameters(self, config):
        return get_params(self.net)
    
    def fit(self, parameters, config):
        set_params(self.net, parameters)
        train(self.net, self.trainloader, epochs=1)
        return get_params(self.net), len(self.trainloader), {}
    
    def evaluate(self, parameters, config):
        set_params(self.net, parameters)
        loss, acc, eod = test(self.net, self.valloader)
        return float(loss), len(self.valloader), {'accuracy': float(acc), 'eod': float(eod)}


In [15]:
from custom_flwr.server_app import server_fn as server_fn_custom

In [16]:
def client_fn(context: Context):
    net = BaselineNN().to(DEVICE)
    partition_id = context.node_config['partition-id']
    trainloader = train_loaders[partition_id]
    valloader = val_loaders[partition_id]
    return FlowerClient(net, trainloader, valloader)

def server_fn(context: Context):
    context.run_config = {
        'num-server-rounds' : 10,
        'fraction-fit': 0.25,
        'fraction-evaluate': 0.5,
        'local-epochs': 1,
        'server-device': str(DEVICE),
        'use-wandb': False
    }
    return server_fn_custom(context, test_loader)

client = ClientApp(client_fn=client_fn)
server = ServerApp(server_fn=server_fn)

In [17]:
backend_config = {"client_resources": None}
NUM_PARTITIONS = 10
run_simulation(
    server_app=server,
    client_app=client,
    num_supernodes=NUM_PARTITIONS,
    backend_config=backend_config,
)

[92mINFO [0m:      Starting Flower ServerApp, config: num_rounds=10, no round_timeout
[92mINFO [0m:      
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Using initial global parameters provided by strategy
[92mINFO [0m:      Starting evaluation of initial global parameters
[92mINFO [0m:      💡 New best global model found: 0.553036
[92mINFO [0m:      initial parameters (loss, other metrics): 22.384248855786446, {'centralized_accuracy': 0.5530364372469636, 'eod': 0.0}
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 1]


Test Accuracy: 0.5530364372469636 - Test Loss: 22.384248855786446 - EOD: 0.0


[92mINFO [0m:      configure_fit: strategy sampled 2 clients (out of 10)
[92mINFO [0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO [0m:      fit progress: (1, 21.80024957504028, {'centralized_accuracy': 0.5530364372469636, 'eod': 0.0}, 13.462856104131788)
[92mINFO [0m:      configure_evaluate: strategy sampled 5 clients (out of 10)


[36m(ClientAppActor pid=65134)[0m Epoch 1/1 - Loss: 0.6999 - Acc: 0.5503 - EOD: 0.0000
[36m(ClientAppActor pid=65133)[0m Epoch 1/1 - Loss: 0.7103 - Acc: 0.4964 - EOD: nan
Test Accuracy: 0.5530364372469636 - Test Loss: 21.80024957504028 - EOD: 0.0
[36m(ClientAppActor pid=65133)[0m Test Loss: 0.6884 - Acc: 0.5469 - EOD: 0.0000


[92mINFO [0m:      aggregate_evaluate: received 5 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 2]
[92mINFO [0m:      configure_fit: strategy sampled 2 clients (out of 10)


[36m(ClientAppActor pid=65131)[0m Test Loss: 0.6963 - Acc: 0.5000 - EOD: nan


[92mINFO [0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO [0m:      💡 New best global model found: 0.626721
[92mINFO [0m:      fit progress: (2, 21.295907740409557, {'centralized_accuracy': 0.6267206477732794, 'eod': 0.18359017372131348}, 16.02146497112699)
[92mINFO [0m:      configure_evaluate: strategy sampled 5 clients (out of 10)


Test Accuracy: 0.6267206477732794 - Test Loss: 21.295907740409557 - EOD: 0.18359017372131348


[92mINFO [0m:      aggregate_evaluate: received 5 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 3]
[92mINFO [0m:      configure_fit: strategy sampled 2 clients (out of 10)
[92mINFO [0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO [0m:      💡 New best global model found: 0.640486
[92mINFO [0m:      fit progress: (3, 20.892916448605366, {'centralized_accuracy': 0.6404858299595142, 'eod': 0.1884872019290924}, 17.112032372970134)
[92mINFO [0m:      configure_evaluate: strategy sampled 5 clients (out of 10)


Test Accuracy: 0.6404858299595142 - Test Loss: 20.892916448605366 - EOD: 0.1884872019290924


[92mINFO [0m:      aggregate_evaluate: received 5 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 4]
[92mINFO [0m:      configure_fit: strategy sampled 2 clients (out of 10)
[92mINFO [0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO [0m:      💡 New best global model found: 0.653441
[92mINFO [0m:      fit progress: (4, 20.770057470370563, {'centralized_accuracy': 0.6534412955465587, 'eod': 0.1995295286178589}, 17.596811874071136)
[92mINFO [0m:      configure_evaluate: strategy sampled 5 clients (out of 10)


Test Accuracy: 0.6534412955465587 - Test Loss: 20.770057470370563 - EOD: 0.1995295286178589


[92mINFO [0m:      aggregate_evaluate: received 5 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 5]
[92mINFO [0m:      configure_fit: strategy sampled 2 clients (out of 10)
[92mINFO [0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO [0m:      fit progress: (5, 20.719993577553677, {'centralized_accuracy': 0.6461538461538462, 'eod': 0.2073071300983429}, 17.95271952706389)
[92mINFO [0m:      configure_evaluate: strategy sampled 5 clients (out of 10)


Test Accuracy: 0.6461538461538462 - Test Loss: 20.719993577553677 - EOD: 0.2073071300983429


[92mINFO [0m:      aggregate_evaluate: received 5 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 6]
[92mINFO [0m:      configure_fit: strategy sampled 2 clients (out of 10)
[92mINFO [0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO [0m:      💡 New best global model found: 0.659109
[92mINFO [0m:      fit progress: (6, 20.578924504610207, {'centralized_accuracy': 0.6591093117408907, 'eod': 0.18800711631774902}, 18.446956406114623)
[92mINFO [0m:      configure_evaluate: strategy sampled 5 clients (out of 10)


[36m(ClientAppActor pid=65131)[0m Epoch 1/1 - Loss: 0.6519 - Acc: 0.6639 - EOD: 0.0252[32m [repeated 7x across cluster][0m
Test Accuracy: 0.6591093117408907 - Test Loss: 20.578924504610207 - EOD: 0.18800711631774902


[92mINFO [0m:      aggregate_evaluate: received 5 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 7]
[92mINFO [0m:      configure_fit: strategy sampled 2 clients (out of 10)


[36m(ClientAppActor pid=65130)[0m Epoch 1/1 - Loss: 0.6553 - Acc: 0.6572 - EOD: nan[32m [repeated 2x across cluster][0m


[92mINFO [0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO [0m:      fit progress: (7, 20.553733727870842, {'centralized_accuracy': 0.6461538461538462, 'eod': 0.2073071300983429}, 18.902246857993305)
[92mINFO [0m:      configure_evaluate: strategy sampled 5 clients (out of 10)


[36m(ClientAppActor pid=65134)[0m Test Loss: 0.6540 - Acc: 0.6238 - EOD: -0.5435[32m [repeated 20x across cluster][0m
Test Accuracy: 0.6461538461538462 - Test Loss: 20.553733727870842 - EOD: 0.2073071300983429


[92mINFO [0m:      aggregate_evaluate: received 5 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 8]
[92mINFO [0m:      configure_fit: strategy sampled 2 clients (out of 10)
[92mINFO [0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO [0m:      💡 New best global model found: 0.659919
[92mINFO [0m:      fit progress: (8, 20.448351196753674, {'centralized_accuracy': 0.659919028340081, 'eod': 0.18800711631774902}, 19.364062088076025)
[92mINFO [0m:      configure_evaluate: strategy sampled 5 clients (out of 10)


Test Accuracy: 0.659919028340081 - Test Loss: 20.448351196753674 - EOD: 0.18800711631774902


[92mINFO [0m:      aggregate_evaluate: received 5 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 9]
[92mINFO [0m:      configure_fit: strategy sampled 2 clients (out of 10)
[92mINFO [0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO [0m:      fit progress: (9, 20.33809931308795, {'centralized_accuracy': 0.659919028340081, 'eod': 0.18800711631774902}, 19.88063771906309)
[92mINFO [0m:      configure_evaluate: strategy sampled 5 clients (out of 10)


Test Accuracy: 0.659919028340081 - Test Loss: 20.33809931308795 - EOD: 0.18800711631774902


[92mINFO [0m:      aggregate_evaluate: received 5 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 10]
[92mINFO [0m:      configure_fit: strategy sampled 2 clients (out of 10)
[92mINFO [0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO [0m:      fit progress: (10, 20.30760893607751, {'centralized_accuracy': 0.6461538461538462, 'eod': 0.1947285234928131}, 20.338301906129345)
[92mINFO [0m:      configure_evaluate: strategy sampled 5 clients (out of 10)


[36m(ClientAppActor pid=65131)[0m Test Loss: 0.6197 - Acc: 0.6790 - EOD: nan[32m [repeated 11x across cluster][0m
Test Accuracy: 0.6461538461538462 - Test Loss: 20.30760893607751 - EOD: 0.1947285234928131


[92mINFO [0m:      aggregate_evaluate: received 5 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [SUMMARY]
[92mINFO [0m:      Run finished 10 round(s) in 20.56s
[92mINFO [0m:      	History (loss, distributed):
[92mINFO [0m:      		round 1: 0.6868380160278367
[92mINFO [0m:      		round 2: 0.6661250619979376
[92mINFO [0m:      		round 3: 0.6600953224065372
[92mINFO [0m:      		round 4: 0.6404390985258526
[92mINFO [0m:      		round 5: 0.6413888823938307
[92mINFO [0m:      		round 6: 0.6389107434800853
[92mINFO [0m:      		round 7: 0.6440523690372982
[92mINFO [0m:      		round 8: 0.6363686142750431
[92mINFO [0m:      		round 9: 0.6444588561491237
[92mINFO [0m:      		round 10: 0.616643865734323
[92mINFO [0m:      	History (loss, centralized):
[92mINFO [0m:      		round 0: 22.384248855786446
[92mINFO [0m:      		round 1: 21.80024957504028
[92mINFO [0m:      		round 2: 21.295907740409557
[92mINFO [0m:      		round 3: 20.892916448605366


[36m(ClientAppActor pid=65130)[0m Epoch 1/1 - Loss: 0.6384 - Acc: 0.6587 - EOD: 0.0556[32m [repeated 8x across cluster][0m
[36m(ClientAppActor pid=65134)[0m Epoch 1/1 - Loss: 0.6487 - Acc: 0.6201 - EOD: nan
[36m(ClientAppActor pid=65133)[0m Test Loss: 0.6220 - Acc: 0.6923 - EOD: 0.2698[32m [repeated 14x across cluster][0m
[36m(ClientAppActor pid=65134)[0m Test Loss: 0.6139 - Acc: 0.6797 - EOD: nan[32m [repeated 3x across cluster][0m


