In [1]:
import torch, os, copy, time
from tqdm import tqdm
from torch import nn
from torch.utils.data import DataLoader, Dataset
from load_adult import *
import torch.nn.functional as F
from torch.autograd import Variable
from functools import partial

In [2]:
class logReg(torch.nn.Module):
    """Logistic-regression model: a single linear layer followed by a sigmoid.

    The sigmoid is applied to every output unit independently, so the
    returned probabilities are not normalised across classes.

    Args:
        num_features: size of each input feature vector.
        num_classes: number of output units of the linear layer.
    """

    def __init__(self, num_features, num_classes):
        super().__init__()
        self.num_classes = num_classes
        self.linear = torch.nn.Linear(num_features, num_classes)

    def forward(self, x):
        """Return ``(probas, logits)`` for a batch ``x``.

        ``x`` is cast to float32 before the linear layer. ``probas`` is the
        element-wise sigmoid of ``logits``; both have one column per class.
        """
        logits = self.linear(x.float())
        probas = torch.sigmoid(logits)
        # .float() keeps the tensor on whatever device the model lives on.
        # .type(torch.FloatTensor) would always produce a CPU tensor, which
        # separates the probabilities from the logits on a GPU run.
        return probas.float(), logits

# Model instance shared across cells: two output units over NUM_FEATURES inputs.
global_model = logReg(NUM_FEATURES, 2)
# Parameter dictionary of the freshly initialised model.
global_weights = global_model.state_dict()

In [3]:
class DatasetSplit(Dataset):
    """View of ``dataset`` restricted to the positions listed in ``idxs``.

    Item ``k`` of the split is item ``idxs[k]`` of the wrapped dataset, which
    must yield ``(feature, label, sensitive)`` triples.
    """

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        # Coerce every index to a plain Python int.
        self.idxs = list(map(int, idxs))

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        position = self.idxs[item]
        # Unpacking requires the wrapped dataset to yield exactly three fields.
        x, y, group = self.dataset[position]
        return x, y, group


class LocalUpdate(object):
    """Trains and evaluates a model on one client's share of a dataset.

    The client's indexes are split, in the order given, into 80% train,
    10% validation and 10% test, and one DataLoader is built per part.

    Args:
        dataset: dataset yielding ``(feature, label, sensitive)`` triples.
        idxs: iterable of indexes into ``dataset`` that belong to this client.
        batch_size: batch size of the training loader.
        loss_func: ``loss_no_fair`` or ``loss_with_fair``.
        penalty: weight handed to ``loss_with_fair`` as its last argument.
    """

    def __init__(self, dataset, idxs, batch_size, loss_func, penalty = 0):
        self.trainloader, self.validloader, self.testloader = self.train_val_test(dataset, list(idxs), batch_size)
        self.criterion = loss_func
        self.penalty = penalty

    def train_val_test(self, dataset, idxs, batch_size):
        """
        Returns train, validation and test dataloaders for a given dataset
        and user indexes.

        Only the training loader shuffles and uses ``batch_size``; the
        validation and test loaders use a tenth of their split size per batch
        (at least one sample).
        """
        # split indexes for train, validation, and test (80, 10, 10)
        train_end = int(0.8*len(idxs))
        val_end = int(0.9*len(idxs))
        idxs_train = idxs[:train_end]
        idxs_val = idxs[train_end:val_end]
        idxs_test = idxs[val_end:]

        trainloader = DataLoader(DatasetSplit(dataset, idxs_train),
                                 batch_size=batch_size, shuffle=True)
        # DataLoader rejects batch_size=0, which a tenth of the split rounds
        # down to when the split has fewer than 10 samples.
        validloader = DataLoader(DatasetSplit(dataset, idxs_val),
                                 batch_size=max(1, len(idxs_val)//10), shuffle=False)
        testloader = DataLoader(DatasetSplit(dataset, idxs_test),
                                batch_size=max(1, len(idxs_test)//10), shuffle=False)
        return trainloader, validloader, testloader

    def _compute_loss(self, probas, labels, logits, sensitive):
        """Call ``self.criterion`` with the argument list that criterion expects."""
        if self.criterion == loss_no_fair:
            return self.criterion(probas, labels)
        if self.criterion == loss_with_fair:
            return self.criterion(probas, labels, logits, sensitive, mean_sensitive, self.penalty)
        raise ValueError('unsupported loss function: {!r}'.format(self.criterion))

    def update_weights(self, model, global_round, learning_rate, local_epochs, optimizer):
        """Train ``model`` in place on the client's training split.

        Args:
            model: model to train; it is modified in place.
            global_round: round number, used only in the progress message.
            learning_rate: learning rate of the local optimizer.
            local_epochs: number of passes over the training split.
            optimizer: ``'sgd'`` or ``'adam'``; any non-string object is used
                as an already constructed optimizer.

        Returns:
            ``(state_dict, loss)`` where ``loss`` is the mean over epochs of
            the mean batch loss.

        Raises:
            ValueError: if ``optimizer`` is an unknown name or the criterion
                is not one of the supported loss functions.
        """
        # Set mode to train model
        model.train()
        epoch_loss = []

        # Set optimizer for the local updates
        if optimizer == 'sgd':
            local_optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate,
                                              momentum=0.5)
        elif optimizer == 'adam':
            local_optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,
                                               weight_decay=1e-4)
        elif isinstance(optimizer, str):
            # Fail here instead of at the first .step() call on a string.
            raise ValueError('unsupported optimizer: {!r}'.format(optimizer))
        else:
            local_optimizer = optimizer

        for i in range(local_epochs):
            batch_loss = []
            for batch_idx, (features, labels, sensitive) in enumerate(self.trainloader):
                # .long() converts the dtype without leaving DEVICE
                # (.type(torch.LongTensor) always yields a CPU tensor).
                features, labels = features.to(DEVICE), labels.to(DEVICE).long()
                sensitive = sensitive.to(DEVICE)
                # PyTorch accumulates gradients across backward passes, so
                # clear them before every step.
                model.zero_grad()

                probas, logits = model(features)
                loss = self._compute_loss(probas, labels, logits, sensitive)
                loss.backward()
                local_optimizer.step()

                if batch_idx % 50 == 0:
                    print('| Global Round : {} | Local Epoch : {} | [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        global_round, i, batch_idx * len(features),
                        len(self.trainloader.dataset),
                        100. * batch_idx / len(self.trainloader), loss.item()))
                batch_loss.append(loss.item())
            epoch_loss.append(sum(batch_loss)/len(batch_loss))

        # weight, loss
        return model.state_dict(), sum(epoch_loss) / len(epoch_loss)

    def inference(self, model):
        """
        Evaluates ``model`` on this client's test split.

        Returns the tuple ``(accuracy, loss, s, n, sp, nsp)``:
            accuracy: fraction of correctly classified samples,
            loss:     sum of the per-batch criterion values,
            s:        N(sensitive group), i.e. samples with sensitive == 1,
            n:        N(non-sensitive group),
            sp:       N(sensitive group, predicted positive),
            nsp:      N(non-sensitive group, predicted positive)
        """

        model.eval()
        loss, total, correct = 0.0, 0.0, 0.0
        sp, nsp, s, n = 0, 0, 0, 0
        # Evaluation only reads scalar values, so skip autograd bookkeeping.
        with torch.no_grad():
            for features, labels, sensitive in self.testloader:
                features, labels = features.to(DEVICE), labels.to(DEVICE).long()
                sensitive = sensitive.to(DEVICE)

                # Inference
                outputs, logits = model(features)

                # Prediction
                _, pred_labels = torch.max(outputs, 1)
                pred_labels = pred_labels.view(-1)
                bool_correct = torch.eq(pred_labels, labels)
                correct += torch.sum(bool_correct).item()
                total += len(labels)

                # Build the comparison vector on the same device as `sensitive`.
                bool_sensitive = torch.eq(sensitive, torch.ones(len(labels), device=sensitive.device))
                bool_other = torch.logical_not(bool_sensitive)
                s += torch.sum(bool_sensitive).item()
                n += torch.sum(bool_other).item()
                sp += torch.sum(torch.logical_and(pred_labels, bool_sensitive)).item()
                nsp += torch.sum(torch.logical_and(pred_labels, bool_other)).item()

                batch_loss = self._compute_loss(outputs, labels, logits, sensitive)
                loss += batch_loss.item()

        accuracy = correct/total
        return accuracy, loss, s, n, sp, nsp


def test_inference(model, test_dataset, batch_size):
    """ Returns the test accuracy and loss.
    """

    model.eval()
    loss, total, correct = 0.0, 0.0, 0.0
    sp, nsp, s, n = 0, 0, 0, 0
    
    criterion = nn.NLLLoss().to(DEVICE)
    testloader = DataLoader(test_dataset, batch_size=batch_size,
                            shuffle=False)

    for batch_idx, (features, labels, sensitive) in enumerate(testloader):
        features = features.to(DEVICE)
        labels =  labels.to(DEVICE).type(torch.LongTensor)
        # Inference
        outputs, logits = model(features)
        batch_loss = criterion(outputs, labels)
        loss += batch_loss.item()

        # Prediction
        _, pred_labels = torch.max(outputs, 1)
        pred_labels = pred_labels.view(-1)
        bool_correct = torch.eq(pred_labels, labels)
        correct += torch.sum(bool_correct).item()
        total += len(labels)

        bool_sensitive = torch.eq(sensitive, torch.ones(len(labels)))
        s += torch.sum(bool_sensitive).item()
        n += torch.sum(torch.logical_not(bool_sensitive) ).item()
        sp += torch.sum(torch.logical_and(bool_correct, bool_sensitive)).item()
        nsp += torch.sum(torch.logical_and(bool_correct, torch.logical_not(bool_sensitive))).item()

    accuracy = correct/total
    # |P(Group1, pos) - P(Group2, pos)| = |N(Group1, pos)/N(Group1) - N(Group2, pos)/N(Group2)|
    return accuracy, loss, abs(sp/s-nsp/n)

In [4]:
def loss_no_fair(output, target):
    """Mean binary cross-entropy of two-column probabilities against 0/1 targets.

    Args:
        output: floating-point tensor whose column 0 holds the probability
            used for target 0 and column 1 the probability used for target 1.
        target: tensor of 0/1 labels, one per row of ``output``.

    Returns:
        Scalar tensor: mean of ``-target*log(output[:, 1]) - (1-target)*log(output[:, 0])``.
    """
    # A saturated sigmoid yields an exact 0, log(0) = -inf, and 0 * -inf = NaN
    # would poison the mean even for correctly classified samples. Clamping to
    # the smallest normal float leaves every other value untouched.
    floor = torch.finfo(output.dtype).tiny
    log_neg = torch.log(torch.clamp(output.T[0], min=floor))
    log_pos = torch.log(torch.clamp(output.T[1], min=floor))
    loss = torch.mean(-target * log_pos - (1 - target) * log_neg)
    return loss

def loss_with_fair(output, target, Theta_X, Sen, Sen_bar, larg = 0):
    """Binary cross-entropy plus a weighted fairness penalty.

    Args:
        output: floating-point tensor whose column 0 holds the probability
            used for target 0 and column 1 the probability used for target 1.
        target: tensor of 0/1 labels, one per row of ``output``.
        Theta_X: tensor of model logits; only column 0 is used.
        Sen: per-sample sensitive attribute values.
        Sen_bar: value subtracted from ``Sen`` (presumably its mean - confirm
            against the caller).
        larg: weight of the fairness term.

    Returns:
        Scalar tensor ``pred_loss + larg * mean(((Sen - Sen_bar) * Theta_X[:, 0]) ** 2)``.
    """
    # A saturated sigmoid yields an exact 0, log(0) = -inf, and 0 * -inf = NaN
    # would poison the mean even for correctly classified samples. Clamping to
    # the smallest normal float leaves every other value untouched.
    floor = torch.finfo(output.dtype).tiny
    log_neg = torch.log(torch.clamp(output.T[0], min=floor))
    log_pos = torch.log(torch.clamp(output.T[1], min=floor))
    pred_loss = torch.mean(-target * log_pos - (1 - target) * log_neg)

    # Squared product of the centred sensitive attribute and the class-0 logit.
    centred = torch.mul(Sen - Sen_bar, Theta_X.T[0])
    fair_loss = torch.mean(torch.mul(centred, centred))
    return pred_loss + larg*fair_loss

In [None]:
def train(loss_func = loss_no_fair, batch_size = 128,
          num_epochs = 5, learning_rate = 0.01, optimizer = 'adam', local_epochs= 5, num_workers = 4, print_every = 1,
         penalty = 0.0001):
    """Run the two-client training loop on ``global_model`` and print statistics.

    In every global round each client trains a copy of ``global_model`` on its
    own indexes of ``train_dataset``; the resulting parameters are averaged
    (unweighted) and loaded back into ``global_model``. After the last round
    the model is evaluated on ``test_dataset``.

    Args:
        loss_func: ``loss_no_fair`` or ``loss_with_fair``.
        batch_size: batch size for local training and for the final test pass.
        num_epochs: number of global rounds.
        learning_rate: learning rate of the local optimizers.
        optimizer: optimizer name forwarded to ``LocalUpdate.update_weights``.
        local_epochs: local passes per client per round.
        num_workers: accepted for backward compatibility; no loader built
            here uses it.
        print_every: print the running statistics every this many rounds.
        penalty: fairness weight forwarded to ``LocalUpdate``.
    """
    # Training
    train_loss, train_accuracy = [], []
    start_time = time.time()
    num_clients = 2

    def average_weights(w):
        """
        Returns the average of the weights.
        """
        w_avg = copy.deepcopy(w[0])
        for key in w_avg.keys():
            for i in range(1, len(w)):
                w_avg[key] += w[i][key]
            w_avg[key] = torch.div(w_avg[key], len(w))
        return w_avg

    def make_client(client_id):
        """Build the LocalUpdate helper for one client's indexes."""
        return LocalUpdate(dataset=train_dataset,
                           idxs=clients_idx[client_id], batch_size = batch_size,
                           loss_func = loss_func, penalty = penalty)

    for epoch in tqdm(range(num_epochs)):
        local_weights, local_losses = [], []
        print(f'\n | Global Training Round : {epoch+1} |\n')

        global_model.train()
        m = num_clients # the number of clients to be chosen in each epoch
        idxs_users = np.random.choice(range(num_clients), m, replace=False)

        for idx in idxs_users:
            # Each client starts from its own copy of the current global model.
            w, loss = make_client(idx).update_weights(
                model=copy.deepcopy(global_model), global_round=epoch,
                learning_rate = learning_rate, local_epochs = local_epochs, optimizer = optimizer)
            local_weights.append(copy.deepcopy(w))
            local_losses.append(loss)

        # update global weights
        averaged_weights = average_weights(local_weights)
        global_model.load_state_dict(averaged_weights)

        loss_avg = sum(local_losses) / len(local_losses)
        train_loss.append(loss_avg)

        # Calculate avg accuracy over all users at every epoch.
        # NOTE: LocalUpdate.inference iterates each client's test loader, so
        # the "train" statistics below come from that held-out 10% split.
        list_acc = []
        s, n, sp, nsp = 0, 0, 0, 0
        global_model.eval()
        for c in range(num_clients):
            acc, _, s_, n_, sp_, nsp_ = make_client(c).inference(model=global_model)
            list_acc.append(acc)
            s += s_
            n += n_
            sp += sp_
            nsp += nsp_
            print("In client %d, the RD is %2f = |%d/%d-%d/%d|." %
                  (c, abs(sp_/s_-nsp_/n_),
                  sp_, s_, nsp_, n_))
        train_accuracy.append(sum(list_acc)/len(list_acc))

        # print global training loss after every 'i' rounds
        if (epoch+1) % print_every == 0:
            print(f' \nAvg Training Stats after {epoch+1} global rounds:')
            print(f'Training Loss : {np.mean(np.array(train_loss))}')
            print('Train Accuracy: {:.2f}%'.format(100*train_accuracy[-1]))
            print('Train Risk Difference: {:.2f} \n'.format(abs(sp/s-nsp/n)))

    # Test inference after completion of training
    test_acc, _, rd = test_inference(global_model, test_dataset, batch_size)

    print(f' \n Results after {num_epochs} global rounds of training:')
    print("|---- Avg Train Accuracy: {:.2f}%".format(100*train_accuracy[-1]))
    print("|---- Test Accuracy: {:.2f}%".format(100*test_acc))

    # Compute RD: risk difference - fairness metric
    # |P(Group1, pos) - P(Group2, pos)| = |N(Group1, pos)/N(Group1) - N(Group2, pos)/N(Group2)|
    print("|---- Test RD: {:.2f}".format(rd))

    print('\n Total Run Time: {0:0.4f} sec'.format(time.time()-start_time))