In [1]:
import torch, os, copy, time
from tqdm import tqdm
from torch import nn
from torch.utils.data import DataLoader, Dataset
from load_adult import *
import torch.nn.functional as F
from torch.autograd import Variable
from functools import partial
import pandas as pd
from utils import *

In [2]:
class logReg(torch.nn.Module):
    # logistic regression 
    def __init__(self, num_features, num_classes):
        super().__init__()
        self.num_classes = num_classes
        self.linear = torch.nn.Linear(num_features, num_classes)

    def forward(self, x):
        logits = self.linear(x.float())
        probas = torch.sigmoid(logits)
        return probas.type(torch.FloatTensor), logits

In [3]:
def loss_func(option, logits, targets, distance, sensitive, mean_sensitive, larg = 1):
    acc_loss = F.cross_entropy(logits, targets, reduction = 'sum')
    fair_loss = torch.mul(sensitive - sensitive.type(torch.FloatTensor).mean(), distance.T[0])
    fair_loss = torch.mean(torch.mul(fair_loss, fair_loss)) # modified mean to sum
    if option == 'unconstrained':
        return acc_loss, acc_loss, larg*fair_loss
    if option == 'Zafar':
        return acc_loss + larg*fair_loss, acc_loss, larg*fair_loss

In [4]:
class DatasetSplit(Dataset):
    """An abstract Dataset class wrapped around Pytorch Dataset class.
    """

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = [int(i) for i in idxs]

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        feature, label, sensitive = self.dataset[self.idxs[item]]
        return feature, label, sensitive


class LocalUpdate(object):
    def __init__(self, dataset, idxs, batch_size, option, penalty = 0):
        self.trainloader, self.validloader = self.train_val(dataset, list(idxs), batch_size)
        self.option = option
        self.penalty = penalty
            
    def train_val(self, dataset, idxs, batch_size):
        """
        Returns train, validation for a given local training dataset
        and user indexes.
        """
        # split indexes for train, validation (90, 10)

        idxs_train = idxs[:int(0.9*len(idxs))]
        idxs_val = idxs[int(0.9*len(idxs)):len(idxs)]

        trainloader = DataLoader(DatasetSplit(dataset, idxs_train),
                                 batch_size=batch_size, shuffle=True)
        validloader = DataLoader(DatasetSplit(dataset, idxs_val),
                                 batch_size=int(len(idxs_val)/10), shuffle=False)
        return trainloader, validloader

    def update_weights(self, model, global_round, learning_rate, local_epochs, optimizer):
        # Set mode to train model
        model.train()
        epoch_loss = []

        # Set optimizer for the local updates
        if optimizer == 'sgd':
            optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate,
                                        ) # momentum=0.5
        elif optimizer == 'adam':
            optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,
                                         weight_decay=1e-4)

        for i in range(local_epochs):
            batch_loss = []
            for batch_idx, (features, labels, sensitive) in enumerate(self.trainloader):
                features, labels = features.to(DEVICE), labels.to(DEVICE).type(torch.LongTensor)
                # we need to set the gradients to zero before starting to do backpropragation 
                # because PyTorch accumulates the gradients on subsequent backward passes. 
                # This is convenient while training RNNs
                
                log_probs, logits = model(features)
                loss, _, _ = loss_func(self.option,
                    logits, labels, logits, sensitive, mean_sensitive, self.penalty)
                    
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                if batch_idx % 50 == 0:
                    print('| Global Round : {} | Local Epoch : {} | [{}/{} ({:.0f}%)]\tBatch Loss: {:.6f}'.format(
                        global_round, i, batch_idx * len(features),
                        len(self.trainloader.dataset),
                        100. * batch_idx / len(self.trainloader), loss.item()))
                batch_loss.append(loss.item())
            epoch_loss.append(sum(batch_loss)/len(batch_loss))

        # weight, loss
        return model.state_dict(), sum(epoch_loss) / len(epoch_loss)

    def inference(self, model):
        """ 
        Returns the inference accuracy, 
                                loss, 
                                N(sensitive group, pos), 
                                N(non-sensitive group, pos), 
                                N(sensitive group),
                                N(non-sensitive group),
                                acc_loss,
                                fair_loss
        """

        model.eval()
        loss, total, correct, fair_loss, acc_loss, num_batch = 0.0, 0.0, 0.0, 0.0, 0.0, 0
        sp, nsp, s, n = 0, 0, 0, 0
        for batch_idx, (features, labels, sensitive) in enumerate(self.validloader):
            features, labels = features.to(DEVICE), labels.to(DEVICE).type(torch.LongTensor)
            sensitive = sensitive.to(DEVICE)
            
            # Inference
            outputs, logits = model(features)

            # Prediction
            _, pred_labels = torch.max(outputs, 1)
            pred_labels = pred_labels.view(-1)
            bool_correct = torch.eq(pred_labels, labels)
            correct += torch.sum(bool_correct).item()
            total += len(labels)
            num_batch += 1
            bool_sensitive = torch.eq(sensitive, torch.ones(len(labels)))
            s += torch.sum(bool_sensitive).item()
            n += torch.sum(torch.logical_not(bool_sensitive)).item()
            sp += torch.sum(torch.logical_and(pred_labels, bool_sensitive)).item()
            nsp += torch.sum(torch.logical_and(pred_labels, torch.logical_not(bool_sensitive))).item()
            
            batch_loss, batch_acc_loss, batch_fair_loss = loss_func(self.option, outputs, 
                                                        labels, logits, sensitive, mean_sensitive, self.penalty)
            loss, acc_loss, fair_loss = (loss + batch_loss.item(), 
                                         acc_loss + batch_acc_loss.item(), 
                                         fair_loss + batch_fair_loss.item())
        accuracy = correct/total
        return accuracy, loss, s, n, sp, nsp, acc_loss / num_batch, fair_loss / num_batch


def test_inference(model, test_dataset, batch_size):
    """ Returns the test accuracy and loss.
    """

    model.eval()
    loss, total, correct = 0.0, 0.0, 0.0
    sp, nsp, s, n = 0, 0, 0, 0
    
    criterion = nn.NLLLoss().to(DEVICE)
    testloader = DataLoader(test_dataset, batch_size=batch_size,
                            shuffle=False)

    for batch_idx, (features, labels, sensitive) in enumerate(testloader):
        features = features.to(DEVICE)
        labels =  labels.to(DEVICE).type(torch.LongTensor)
        # Inference
        outputs, logits = model(features)
        batch_loss = criterion(outputs, labels)
        loss += batch_loss.item()

        # Prediction
        _, pred_labels = torch.max(outputs, 1)
        pred_labels = pred_labels.view(-1)
        bool_correct = torch.eq(pred_labels, labels)
        correct += torch.sum(bool_correct).item()
        total += len(labels)

        bool_sensitive = torch.eq(sensitive, torch.ones(len(labels)))
        s += torch.sum(bool_sensitive).item()
        n += torch.sum(torch.logical_not(bool_sensitive) ).item()
        sp += torch.sum(torch.logical_and(bool_correct, bool_sensitive)).item()
        nsp += torch.sum(torch.logical_and(bool_correct, torch.logical_not(bool_sensitive))).item()

    accuracy = correct/total
    # |P(Group1, pos) - P(Group2, pos)| = |N(Group1, pos)/N(Group1) - N(Group2, pos)/N(Group2)|
    return accuracy, loss, abs(sp/s-nsp/n)

In [5]:
def train(model, option = "unconstrained", batch_size = 128,
          num_rounds = 5, learning_rate = 0.01, optimizer = 'adam', local_epochs= 5, num_workers = 4, print_every = 1,
         penalty = 1):
    # Training
    train_loss, train_accuracy = [], []
    val_acc_list, net_list = [], []
    cv_loss, cv_acc = [], []
    val_loss_pre, counter = 0, 0
    start_time = time.time()
    weights = model.state_dict()
    
    client1_loader = DataLoader(dataset = client1_dataset,
                            batch_size = batch_size,
                            num_workers = num_workers)
    
    client2_loader = DataLoader(dataset = client2_dataset,
                            batch_size = batch_size,
                            num_workers = num_workers)
    
    test_loader = DataLoader(dataset = test_dataset,
                            batch_size = batch_size,
                            num_workers = num_workers)
    
    train_loader = DataLoader(dataset = train_dataset,
                        batch_size = batch_size,
                        num_workers = num_workers)

    def average_weights(w):
        """
        Returns the average of the weights.
        """
        w_avg = copy.deepcopy(w[0])
        for key in w_avg.keys():
            for i in range(1, len(w)):
                w_avg[key] += w[i][key]
            w_avg[key] = torch.div(w_avg[key], len(w))
        return w_avg

    for round_ in tqdm(range(num_rounds)):
        local_weights, local_losses = [], []
        print(f'\n | Global Training Round : {round_+1} |\n')

        model.train()
        m = 2 # the number of clients to be chosen in each round_
        idxs_users = np.random.choice(range(2), m, replace=False)

        for idx in idxs_users:
            local_model = LocalUpdate(dataset=train_dataset,
                                        idxs=clients_idx[idx], batch_size = batch_size, option = option, penalty = penalty)
            w, loss = local_model.update_weights(
                model=copy.deepcopy(model), global_round=round_, 
                learning_rate = learning_rate, local_epochs = local_epochs, optimizer = optimizer)
            local_weights.append(copy.deepcopy(w))
            local_losses.append(copy.deepcopy(loss))

        # update global weights
        weights = average_weights(local_weights)
        model.load_state_dict(weights)

        loss_avg = sum(local_losses) / len(local_losses)
        train_loss.append(loss_avg)

        # Calculate avg training accuracy over all users at every round
        list_acc, list_loss = [], []
        s, n, sp, nsp = 0, 0, 0, 0
        model.eval()
        for c in range(2):
            local_model = LocalUpdate(dataset=train_dataset,
                                        idxs=clients_idx[c], batch_size = batch_size, option = option, penalty = penalty)
            # validation dataset inference
            acc, loss, s_, n_, sp_, nsp_, acc_loss, fair_loss = local_model.inference(model=model) 
            list_acc.append(acc)
            list_loss.append(loss)
            s, n, sp, nsp = s + s_, n + n_, sp + sp_, nsp + nsp_
            print("Client %d: accuracy loss: %.2f | fairness loss %.2f | RD = %.2f = |%d/%d-%d/%d| " % (
                c, acc_loss, fair_loss, abs(sp_/s_-nsp_/n_), sp_, s_, nsp_, n_))
            
        train_accuracy.append(sum(list_acc)/len(list_acc))

        # print global training loss after every 'i' rounds
        if (round_+1) % print_every == 0:
            print(f' \nAvg Training Stats after {round_+1} global rounds:')
            print("Training loss: %.2f | Validation accuracy: %.2f%% | Validation RD: %.2f" % (
                 np.mean(np.array(train_loss)), 
                100*train_accuracy[-1],
                abs(sp/s-nsp/n)
                 ))

    # Test inference after completion of training
    test_acc, test_loss, rd= test_inference(model, test_dataset, batch_size)

    print(f' \n Results after {num_rounds} global rounds of training:')
    print("|---- Avg Train Accuracy: {:.2f}%".format(100*train_accuracy[-1]))
    print("|---- Test Accuracy: {:.2f}%".format(100*test_acc))

    # Compute RD: risk difference - fairness metric
    # |P(Group1, pos) - P(Group2, pos)| = |N(Group1, pos)/N(Group1) - N(Group2, pos)/N(Group2)|
    print("|---- Test RD: {:.2f}".format(rd))

    print('\n Total Run Time: {0:0.4f} sec'.format(time.time()-start_time))

In [6]:
train(logReg(num_features=NUM_FEATURES, num_classes=2), 
      "Zafar", penalty = 50, optimizer = 'sgd', learning_rate = 0.01,
     num_rounds = 5, local_epochs = 10)

  0%|          | 0/5 [00:00<?, ?it/s]


 | Global Training Round : 1 |

Client 0: accuracy loss: 107.03 | fairness loss 0.60 | RD = 0.30 = |98/592-661/1421| 


 20%|██        | 1/5 [00:09<00:37,  9.41s/it]

Client 1: accuracy loss: 63.55 | fairness loss 0.87 | RD = 0.24 = |57/421-308/823| 
 
Avg Training Stats after 1 global rounds:
Training loss: 81.47 | Validation accuracy: 81.13% | Validation RD: 0.28

 | Global Training Round : 2 |



 40%|████      | 2/5 [00:17<00:27,  9.09s/it]

Client 0: accuracy loss: 104.85 | fairness loss 1.66 | RD = 0.31 = |75/592-619/1421| 
Client 1: accuracy loss: 62.70 | fairness loss 1.84 | RD = 0.28 = |42/421-312/823| 
 
Avg Training Stats after 2 global rounds:
Training loss: 76.30 | Validation accuracy: 81.92% | Validation RD: 0.30

 | Global Training Round : 3 |



 60%|██████    | 3/5 [00:27<00:18,  9.18s/it]

Client 0: accuracy loss: 101.04 | fairness loss 5.46 | RD = 0.20 = |59/592-419/1421| 
Client 1: accuracy loss: 60.52 | fairness loss 7.45 | RD = 0.16 = |36/421-199/823| 
 
Avg Training Stats after 3 global rounds:
Training loss: 74.14 | Validation accuracy: 83.14% | Validation RD: 0.18

 | Global Training Round : 4 |



 80%|████████  | 4/5 [00:34<00:08,  8.61s/it]

Client 0: accuracy loss: 105.49 | fairness loss 0.28 | RD = 0.22 = |87/592-517/1421| 
Client 1: accuracy loss: 63.73 | fairness loss 0.29 | RD = 0.21 = |46/421-266/823| 
 
Avg Training Stats after 4 global rounds:
Training loss: 72.62 | Validation accuracy: 83.87% | Validation RD: 0.22

 | Global Training Round : 5 |

Client 0: accuracy loss: 103.66 | fairness loss 0.85 | RD = 0.26 = |67/592-535/1421| 


100%|██████████| 5/5 [00:42<00:00,  8.50s/it]

Client 1: accuracy loss: 62.44 | fairness loss 1.00 | RD = 0.25 = |40/421-282/823| 
 
Avg Training Stats after 5 global rounds:
Training loss: 71.56 | Validation accuracy: 83.51% | Validation RD: 0.26





 
 Results after 5 global rounds of training:
|---- Avg Train Accuracy: 83.51%
|---- Test Accuracy: 84.45%
|---- Test RD: 0.12

 Total Run Time: 42.8995 sec


In [7]:
train(logReg(num_features=NUM_FEATURES, num_classes=2), optimizer = 'sgd', learning_rate = 0.01)

  0%|          | 0/5 [00:00<?, ?it/s]


 | Global Training Round : 1 |

Client 0: accuracy loss: 103.63 | fairness loss 0.43 | RD = 0.30 = |101/592-669/1421| 


 20%|██        | 1/5 [00:03<00:15,  3.75s/it]

Client 1: accuracy loss: 60.87 | fairness loss 0.61 | RD = 0.25 = |56/421-313/823| 
 
Avg Training Stats after 1 global rounds:
Training loss: 48.40 | Validation accuracy: 80.87% | Validation RD: 0.28

 | Global Training Round : 2 |



 40%|████      | 2/5 [00:08<00:11,  3.93s/it]

Client 0: accuracy loss: 104.34 | fairness loss 0.46 | RD = 0.38 = |99/592-775/1421| 
Client 1: accuracy loss: 61.65 | fairness loss 0.57 | RD = 0.33 = |58/421-384/823| 
 
Avg Training Stats after 2 global rounds:
Training loss: 47.04 | Validation accuracy: 79.31% | Validation RD: 0.36

 | Global Training Round : 3 |



 60%|██████    | 3/5 [00:13<00:08,  4.43s/it]

Client 0: accuracy loss: 100.14 | fairness loss 0.59 | RD = 0.27 = |77/592-567/1421| 
Client 1: accuracy loss: 59.61 | fairness loss 0.72 | RD = 0.26 = |42/421-299/823| 
 
Avg Training Stats after 3 global rounds:
Training loss: 46.31 | Validation accuracy: 82.83% | Validation RD: 0.27

 | Global Training Round : 4 |



 80%|████████  | 4/5 [00:18<00:04,  4.50s/it]

Client 0: accuracy loss: 100.26 | fairness loss 0.59 | RD = 0.28 = |82/592-594/1421| 
Client 1: accuracy loss: 59.69 | fairness loss 0.70 | RD = 0.27 = |48/421-315/823| 
 
Avg Training Stats after 4 global rounds:
Training loss: 45.82 | Validation accuracy: 82.61% | Validation RD: 0.28

 | Global Training Round : 5 |

Client 0: accuracy loss: 98.34 | fairness loss 0.65 | RD = 0.19 = |77/592-455/1421| 


100%|██████████| 5/5 [00:22<00:00,  4.53s/it]

Client 1: accuracy loss: 58.96 | fairness loss 0.76 | RD = 0.19 = |42/421-240/823| 
 
Avg Training Stats after 5 global rounds:
Training loss: 45.48 | Validation accuracy: 83.66% | Validation RD: 0.19





 
 Results after 5 global rounds of training:
|---- Avg Train Accuracy: 83.66%
|---- Test Accuracy: 84.58%
|---- Test RD: 0.12

 Total Run Time: 23.0370 sec


In [17]:
class MyNumbers:
    def __iter__(self):
        self.a = 1
        return self


TypeError: iter() returned non-iterator of type 'MyNumbers'

In [22]:
client1_dataset[:5]

(tensor([[0.0000, 0.0818, 0.4000, 0.0000, 0.0000, 0.1429, 0.0000, 0.0000, 0.0000,
          0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 1.0000],
         [0.301