In [1]:
import torch, os, copy, time
from tqdm import tqdm
from torch import nn
from torch.utils.data import DataLoader, Dataset
from utils import *
import torch.nn.functional as F
from torch.autograd import Variable
from functools import partial

Epoch: 1 | Batch index: 0 | Batch size: 128
break minibatch for-loop
Epoch: 2 | Batch index: 0 | Batch size: 128
break minibatch for-loop


In [2]:
class logReg(torch.nn.Module):
    # logistic regression 
    def __init__(self, num_features, num_classes):
        super().__init__()
        self.num_classes = num_classes
        self.linear = torch.nn.Linear(num_features, num_classes)

    def forward(self, x):
        logits = self.linear(x.float())
        probas = torch.sigmoid(logits)
        return probas.type(torch.FloatTensor), logits

global_model = logReg(num_features=NUM_FEATURES, num_classes=2)
global_weights = global_model.state_dict()

In [3]:
def fairLoss(log_probas, target, sensitive, logits, penalty = 0.1, threshold = 0.01):
    # penalty term: covariance between sensitive attributes and deicision boundary
    # the n, n_k yet to be added!
    return Variable(nn.NLLLoss().to(DEVICE)(log_probas, target) + (
        penalty * torch.sum(((sensitive - mean_sensitive) * (logits.T[0] * target) - threshold) ** 2)
        ), requires_grad = True)

In [4]:
class DatasetSplit(Dataset):
    """An abstract Dataset class wrapped around Pytorch Dataset class.
    """

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = [int(i) for i in idxs]

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        feature, label, sensitive = self.dataset[self.idxs[item]]
        return feature, label, sensitive


class LocalUpdate(object):
    def __init__(self, dataset, idxs, loss_func = None, penalty = 0.01, threshold = 0.1):
        self.trainloader, self.validloader, self.testloader = self.train_val_test(dataset, list(idxs))
        self.loss_func = loss_func
        # Default criterion set to NLL loss function
        if loss_func == None:
            self.criterion = nn.NLLLoss().to(DEVICE)
        elif loss_func == "fair penalty":
            self.criterion = partial(fairLoss, penalty = penalty, threshold = threshold)
            
    def train_val_test(self, dataset, idxs):
        """
        Returns train, validation and test dataloaders for a given dataset
        and user indexes.
        """
        # split indexes for train, validation, and test (80, 10, 10)
        idxs_train = idxs[:int(0.8*len(idxs))]
        idxs_val = idxs[int(0.8*len(idxs)):int(0.9*len(idxs))]
        idxs_test = idxs[int(0.9*len(idxs)):]

        trainloader = DataLoader(DatasetSplit(dataset, idxs_train),
                                 batch_size=BATCH_SIZE, shuffle=True)
        validloader = DataLoader(DatasetSplit(dataset, idxs_val),
                                 batch_size=int(len(idxs_val)/10), shuffle=False)
        testloader = DataLoader(DatasetSplit(dataset, idxs_test),
                                batch_size=int(len(idxs_test)/10), shuffle=False)
        return trainloader, validloader, testloader

    def update_weights(self, model, global_round):
        # Set mode to train model
        model.train()
        epoch_loss = []

        # Set optimizer for the local updates
        if OPTIMIZER == 'sgd':
            optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE,
                                        momentum=0.5)
        elif OPTIMIZER == 'adam':
            optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE,
                                         weight_decay=1e-4)

        for i in range(LOCAL_EPOCHS):
            batch_loss = []
            for batch_idx, (features, labels, sensitive) in enumerate(self.trainloader):
                features, labels = features.to(DEVICE), labels.to(DEVICE).type(torch.LongTensor)
                # we need to set the gradients to zero before starting to do backpropragation 
                # because PyTorch accumulates the gradients on subsequent backward passes. 
                # This is convenient while training RNNs
                model.zero_grad()
                
                loss = Variable(torch.Tensor([0]), requires_grad = True)
                log_probs, logits = model(features)
                if self.loss_func == None:
                    loss = self.criterion(log_probs, labels)
                elif self.loss_func == "fair penalty":
                    loss = self.criterion(log_probs, labels, sensitive, logits)
                loss.backward()
                optimizer.step()

                if batch_idx % 50 == 0:
                    print('| Global Round : {} | Local Epoch : {} | [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        global_round, i, batch_idx * len(features),
                        len(self.trainloader.dataset),
                        100. * batch_idx / len(self.trainloader), loss.item()))
                batch_loss.append(loss.item())
            epoch_loss.append(sum(batch_loss)/len(batch_loss))

        # weight, loss
        return model.state_dict(), sum(epoch_loss) / len(epoch_loss)

    def inference(self, model):
        """ 
        Returns the inference accuracy, 
                                loss, 
                                N(sensitive group, pos), 
                                N(non-sensitive group, pos), 
                                N(sensitive group),
                                N(non-sensitive group)
        """

        model.eval()
        loss, total, correct = 0.0, 0.0, 0.0
        sp, nsp, s, n = 0, 0, 0, 0
        for batch_idx, (features, labels, sensitive) in enumerate(self.testloader):
            features, labels = features.to(DEVICE), labels.to(DEVICE).type(torch.LongTensor)
            sensitive = sensitive.to(DEVICE)
            
            # Inference
            outputs, logits = model(features)

            # Prediction
            _, pred_labels = torch.max(outputs, 1)
            pred_labels = pred_labels.view(-1)
            bool_correct = torch.eq(pred_labels, labels)
            correct += torch.sum(bool_correct).item()
            total += len(labels)
            
            bool_sensitive = torch.eq(sensitive, torch.ones(len(labels)))
            s += torch.sum(bool_sensitive).item()
            n += torch.sum(torch.logical_not(bool_sensitive)).item()
            sp += torch.sum(torch.logical_and(bool_correct, bool_sensitive)).item()
            nsp += torch.sum(torch.logical_and(bool_correct, torch.logical_not(bool_sensitive))).item()
            
            if self.loss_func == None:
                batch_loss = self.criterion(outputs, labels)
            elif self.loss_func == "fair penalty":
                batch_loss = self.criterion(outputs, labels, sensitive, logits)
            loss += batch_loss.item()

        accuracy = correct/total
        return accuracy, loss, s, n, sp, nsp


def test_inference(model, test_dataset):
    """ Returns the test accuracy and loss.
    """

    model.eval()
    loss, total, correct = 0.0, 0.0, 0.0
    sp, nsp, s, n = 0, 0, 0, 0
    
    criterion = nn.NLLLoss().to(DEVICE)
    testloader = DataLoader(test_dataset, batch_size=BATCH_SIZE,
                            shuffle=False)

    for batch_idx, (features, labels, sensitive) in enumerate(testloader):
        features = features.to(DEVICE)
        labels =  labels.to(DEVICE).type(torch.LongTensor)
        # Inference
        outputs, logits = model(features)
        batch_loss = criterion(outputs, labels)
        loss += batch_loss.item()

        # Prediction
        _, pred_labels = torch.max(outputs, 1)
        pred_labels = pred_labels.view(-1)
        bool_correct = torch.eq(pred_labels, labels)
        correct += torch.sum(bool_correct).item()
        total += len(labels)

        bool_sensitive = torch.eq(sensitive, torch.ones(len(labels)))
        s += torch.sum(bool_sensitive).item()
        n += torch.sum(torch.logical_not(bool_sensitive)).item()
        sp += torch.sum(torch.logical_and(bool_correct, bool_sensitive)).item()
        nsp += torch.sum(torch.logical_and(bool_correct, torch.logical_not(bool_sensitive))).item()

    accuracy = correct/total
    # |P(Group1, pos) - P(Group2, pos)| = |N(Group1, pos)/N(Group1) - N(Group2, pos)/N(Group2)|
    return accuracy, loss, abs(sp/s-nsp/n)

In [5]:
def train(loss_func = None, penalty = None, threshold = None):
    # Training
    train_loss, train_accuracy = [], []
    val_acc_list, net_list = [], []
    cv_loss, cv_acc = [], []
    print_every = 2
    val_loss_pre, counter = 0, 0
    start_time = time.time()

    def average_weights(w):
        """
        Returns the average of the weights.
        """
        w_avg = copy.deepcopy(w[0])
        for key in w_avg.keys():
            for i in range(1, len(w)):
                w_avg[key] += w[i][key]
            w_avg[key] = torch.div(w_avg[key], len(w))
        return w_avg

    for epoch in tqdm(range(NUM_EPOCHS)):
        local_weights, local_losses = [], []
        print(f'\n | Global Training Round : {epoch+1} |\n')

        global_model.train()
        m = max(int(FRAC * NUM_CLIENTS), 1) # the number of clients to be chosen in each epoch
        idxs_users = np.random.choice(range(NUM_CLIENTS), m, replace=False)

        for idx in idxs_users:
            local_model = LocalUpdate(dataset=train_dataset,
                                        idxs=clients_idx[idx], loss_func = loss_func, penalty = penalty, threshold = threshold)
            w, loss = local_model.update_weights(
                model=copy.deepcopy(global_model), global_round=epoch)
            local_weights.append(copy.deepcopy(w))
            local_losses.append(copy.deepcopy(loss))

        # update global weights
        global_weights = average_weights(local_weights)
        global_model.load_state_dict(global_weights)

        loss_avg = sum(local_losses) / len(local_losses)
        train_loss.append(loss_avg)

        # Calculate avg training accuracy over all users at every epoch
        list_acc, list_loss = [], []
        s, n, sp, nsp = 0, 0, 0, 0
        global_model.eval()
        for c in range(NUM_CLIENTS):
            local_model = LocalUpdate(dataset=train_dataset,
                                        idxs=clients_idx[idx], loss_func= loss_func, penalty = penalty, threshold = threshold)
            acc, loss, s_, n_, sp_, nsp_ = local_model.inference(model=global_model)
            list_acc.append(acc)
            list_loss.append(loss)
            s += s_
            n += n_
            sp += sp_
            nsp += nsp_
        train_accuracy.append(sum(list_acc)/len(list_acc))

        # print global training loss after every 'i' rounds
        if (epoch+1) % print_every == 0:
            print(f' \nAvg Training Stats after {epoch+1} global rounds:')
            print(f'Training Loss : {np.mean(np.array(train_loss))}')
            print('Train Accuracy: {:.2f}%'.format(100*train_accuracy[-1]))
            print('Train Risk Difference: {:.2f} \n'.format(abs(sp/s-nsp/n)))

    # Test inference after completion of training
    test_acc, test_loss, rd= test_inference(global_model, test_dataset)

    print(f' \n Results after {NUM_EPOCHS} global rounds of training:')
    print("|---- Avg Train Accuracy: {:.2f}%".format(100*train_accuracy[-1]))
    print("|---- Test Accuracy: {:.2f}%".format(100*test_acc))

    # Compute RD: risk difference - fairness metric
    # |P(Group1, pos) - P(Group2, pos)| = |N(Group1, pos)/N(Group1) - N(Group2, pos)/N(Group2)|
    print("|---- Test RD: {:.2f}".format(rd))

    print('\n Total Run Time: {0:0.4f}'.format(time.time()-start_time))

In [None]:
train("fair penalty", 0.05, 0.01)


  0%|          | 0/10 [00:00<?, ?it/s][A


 | Global Training Round : 1 |




 10%|█         | 1/10 [00:07<01:07,  7.51s/it][A


 | Global Training Round : 2 |



In [7]:
train()

  0%|          | 0/10 [00:00<?, ?it/s]


 | Global Training Round : 1 |



 10%|█         | 1/10 [00:08<01:15,  8.44s/it]


 | Global Training Round : 2 |



 20%|██        | 2/10 [00:16<01:06,  8.33s/it]

 
Avg Training Stats after 2 global rounds:
Training Loss : -0.9987077199955163
Train Accuracy: 72.99%
Train Risk Difference: 0.20 


 | Global Training Round : 3 |



 30%|███       | 3/10 [00:24<00:57,  8.16s/it]


 | Global Training Round : 4 |



 40%|████      | 4/10 [00:31<00:47,  7.94s/it]

 
Avg Training Stats after 4 global rounds:
Training Loss : -0.9991881446248791
Train Accuracy: 71.54%
Train Risk Difference: 0.19 


 | Global Training Round : 5 |



 50%|█████     | 5/10 [00:39<00:38,  7.77s/it]


 | Global Training Round : 6 |



 60%|██████    | 6/10 [00:48<00:32,  8.17s/it]

 
Avg Training Stats after 6 global rounds:
Training Loss : -0.9993482747404409
Train Accuracy: 71.29%
Train Risk Difference: 0.19 


 | Global Training Round : 7 |



 70%|███████   | 7/10 [00:55<00:23,  7.97s/it]


 | Global Training Round : 8 |



 80%|████████  | 8/10 [01:02<00:15,  7.64s/it]

 
Avg Training Stats after 8 global rounds:
Training Loss : -0.9994283250194176
Train Accuracy: 73.07%
Train Risk Difference: 0.20 


 | Global Training Round : 9 |



 90%|█████████ | 9/10 [01:09<00:07,  7.40s/it]


 | Global Training Round : 10 |



100%|██████████| 10/10 [01:17<00:00,  7.71s/it]

 
Avg Training Stats after 10 global rounds:
Training Loss : -0.9994763675014579
Train Accuracy: 71.54%
Train Risk Difference: 0.19 






 
 Results after 10 global rounds of training:
|---- Avg Train Accuracy: 71.54%
|---- Test Accuracy: 76.44%
|---- Test RD: 0.19

 Total Run Time: 77.4684
