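Differentially private learning of embedding spaces.

The cells below train a Sentence-VAE on mobility stay sequences split across simulated clients, using federated averaging (an unweighted mean of the client weights). No differential-privacy mechanism (e.g. gradient clipping or added noise) is applied yet.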

https://towardsdatascience.com/preserving-data-privacy-in-deep-learning-part-1-a04894f78029

In [16]:
import torch, numpy as np
from tqdm import tqdm

In [17]:
num_clients = 10   # number of simulated federated clients
batch_size = 1
lr = 0.001
epochs = 10        # federated rounds: local training, aggregation and evaluation once per round
local_epochs = 1   # passes each client makes over its own shard per round

# Seed the RNGs so the random data split and weight initialisation are reproducible.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Loss function params: logistic KL annealing, KL_weight = 1 / (1 + exp(-k * (step - x0))).
# x0 is the step at which the weight reaches 0.5; k controls how steep the ramp is.
k = 0.0025
x0 = 2500

# Data

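We will need to reprocess the data so that each user's mobility records are grouped together, giving every client the data of its own user(s). For now, the dataset is split randomly into equal-sized shards of sequences.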

In [18]:
import sys
sys.path.append('../scripts')
from dataloader import MobilitySeqDataset
from torch.nn.utils.rnn import pad_sequence

# NOTE(review): absolute, machine-specific path — consider moving it to the config cell.
staysDataset = MobilitySeqDataset(root_dir='/mas/projects/privacy-pres-mobility/data/processed_data/', dataset='cuebiq')

# Split the sequences into `num_clients` equally sized training shards plus one held-out test shard.
# The test shard also absorbs the `len % (num_clients + 1)` leftover samples, so nothing is dropped
# and the test set is never smaller than a training shard.
n_total = len(staysDataset)
shard_size = n_total // (num_clients + 1)
assert shard_size > 0, "dataset too small for num_clients + 1 shards"
splits_lengths = [shard_size] * num_clients + [n_total - shard_size * num_clients]
data_split = torch.utils.data.random_split(staysDataset, splits_lengths)


def collate_fn(batch):
    """Right-pad the variable-length sequences in `batch` with 0 into one batch-first tensor.

    Assumes each dataset item is a tensor (required by `pad_sequence`).
    """
    return pad_sequence(batch, batch_first=True, padding_value=0)


# Training shards are shuffled; the test shard (last split) is read in a fixed order.
splitloader = [
    torch.utils.data.DataLoader(shard, batch_size=batch_size, shuffle=(idx < num_clients), collate_fn=collate_fn)
    for idx, shard in enumerate(data_split)
]

trainingSplits = splitloader[:-1]   # exactly num_clients loaders
testSplit = splitloader[-1]

# Federated Updating

In [19]:
# Custom loss function from paper
# Token-level negative log-likelihood, summed over all tokens; target index 0 is excluded.
NLL = torch.nn.NLLLoss(ignore_index=0, reduction='sum')

def loss_fn(logp, target, mean, logv, step, k, x0):
    """Sentence-VAE loss terms, adapted from https://github.com/timbmg/Sentence-VAE.

    Args:
        logp: log-probabilities whose class (vocabulary) dimension is index 2,
            presumably (batch, seq, vocab) — confirm against the model output.
        target: token indices; flattened to match `logp`'s first two dimensions.
        mean, logv: latent Gaussian parameters (logv treated as log-variance).
        step: current annealing step.
        k, x0: steepness and midpoint of the logistic KL-weight schedule.

    Returns:
        (nll, kl, kl_weight): summed NLL tensor, summed KL tensor, and the float KL weight.
    """
    flat_logp = logp.view(-1, logp.size(2))
    flat_target = target.view(-1)
    nll_term = NLL(flat_logp, flat_target)

    # Closed-form KL( N(mean, exp(logv)) || N(0, I) ), summed over every element.
    kl_term = -0.5 * torch.sum(1 + logv - mean.pow(2) - logv.exp())

    # Logistic annealing: weight is 0.5 at step == x0 and approaches 1 as step grows.
    anneal_weight = float(1/(1+np.exp(-k*(step-x0))))

    return nll_term, kl_term, anneal_weight

In [20]:
def client_update(client_model, optimizer, train_loader, local_epochs=1):
    """Train one client's model on its local data shard.

    Args:
        client_model: the client's model; called as `client_model(batch)` and expected to return
            `(logp, mean, logv, z)`.
        optimizer: optimizer over `client_model`'s parameters.
        train_loader: iterable of padded batch tensors for this client.
        local_epochs: number of passes over `train_loader`.

    Returns:
        Sum over all processed batches of the per-sequence loss `(NLL + KL_weight * KL) / n_in_batch`.

    NOTE(review): `step` restarts at 0 on every call, so the KL annealing schedule restarts in
    every federated round — confirm this is intended.
    """
    # Put *this* client's model in training mode (affects layers such as dropout, if present).
    client_model.train()
    step = 0
    total_loss = 0
    for e in range(local_epochs):
        for batch in tqdm(train_loader, desc="Client Epoch {}".format(e), position=0):
            batch = batch.to(device)
            logp, mean, logv, _ = client_model(batch)
            NLL_loss, KL_loss, KL_weight = loss_fn(logp, batch, mean, logv, step, k, x0)
            # Normalise by the real number of sequences: the last batch may be smaller than `batch_size`.
            loss = (NLL_loss + KL_weight * KL_loss) / batch.size(0)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            step += 1
            total_loss += loss.item()
    return total_loss

In [21]:
def server_aggregate(global_model, client_models):
    """Federated averaging: set the global weights to the element-wise mean of the client weights,
    then copy the averaged weights back into every client.

    This is an unweighted mean, so every client counts equally regardless of its shard size.
    Entries are averaged as float; `load_state_dict` copies them into the existing tensors,
    so the original dtypes are kept.

    Args:
        global_model: module that receives the averaged state.
        client_models: non-empty list of modules with the same state_dict keys as `global_model`.
    """
    # Snapshot each client's state once, instead of once per parameter key.
    client_states = [client.state_dict() for client in client_models]

    global_dict = global_model.state_dict()
    for key in global_dict:
        global_dict[key] = torch.stack([state[key].float() for state in client_states], 0).mean(0)
    global_model.load_state_dict(global_dict)

    # Broadcast the averaged weights; build the global state_dict once for all clients.
    averaged_state = global_model.state_dict()
    for client in client_models:
        client.load_state_dict(averaged_state)

In [22]:
def test(model, test_loader, step=None):
    """Evaluate `model` on `test_loader` and return the mean per-batch loss (no accuracy is computed).

    Args:
        model: model to evaluate; called as `model(batch)` and expected to return `(logp, mean, logv, z)`.
        test_loader: sized iterable of padded batch tensors, e.g. a DataLoader.
        step: annealing step forwarded to `loss_fn` to compute the KL weight. If None (default), the
            KL term gets full weight 1.0. With k=0.0025 and x0=2500, the training weight is already
            about 1 after a few thousand steps.

    Returns:
        float: average over batches of `(NLL + KL_weight * KL) / n_in_batch`.

    Raises:
        ValueError: if `test_loader` has no batches.
    """
    n_batches = len(test_loader)
    if n_batches == 0:
        raise ValueError("test_loader has no batches; cannot compute a test loss")
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in test_loader:
            batch = batch.to(device)
            logp, mean, logv, _ = model(batch)
            NLL_loss, KL_loss, KL_weight = loss_fn(logp, batch, mean, logv, 0 if step is None else step, k, x0)
            if step is None:
                KL_weight = 1.0
            # Normalise by the real number of sequences in this (possibly short) batch.
            val_loss += ((NLL_loss + KL_weight * KL_loss) / batch.size(0)).item()
    return val_loss / n_batches

# Training

In [23]:
import sys; sys.path.append('../scripts');
from VAE import SentenceVAE, device 

In [24]:
class ClientModel(SentenceVAE):
    """SentenceVAE used by one federated client; identical to the base model apart from a display name."""

    def __init__(self, **params):
        # Forward every keyword argument unchanged to SentenceVAE.
        super(ClientModel, self).__init__(**params)
        # Human-readable label for this model instance.
        self.name = 'Client Model'

In [25]:
# SentenceVAE hyper-parameters. Vocabulary size and maximum sequence length are read from the
# dataset underlying the test loader (DataLoader -> Subset -> dataset).
# NOTE(review): sos_idx, eos_idx and pad_idx are all 0, and the NLL loss ignores index 0 —
# confirm this matches the dataset's vocabulary.
params = {
    'vocab_size': testSplit.dataset.dataset._vocab_size,
    'max_sequence_length': testSplit.dataset.dataset._max_seq_len,
    'embedding_size': 256,
    'rnn_type': 'gru',
    'hidden_size': 256,
    'num_layers': 1,
    'bidirectional': False,
    'latent_size': 16,
    'word_dropout': 0,
    'embedding_dropout': 0.5,
    'sos_idx': 0,
    'eos_idx': 0,
    'pad_idx': 0,
    'unk_idx': 1,
}

In [26]:
#### global model ##########
global_model = SentenceVAE(**params).to(device)

############## client models ##############
# Every client starts from an exact copy of the global weights.
client_models = [ClientModel(**params).to(device) for _ in range(num_clients)]
for model in client_models:
    model.load_state_dict(global_model.state_dict())

############### optimizers ################
# One Adam optimizer per client, created once and reused every round. Aggregation overwrites weights
# in place, so each optimizer keeps its moment estimates across rounds.
optimizers = [torch.optim.Adam(client.parameters(), lr=lr) for client in client_models]


In [27]:
losses_train = []  # per round: training loss summed over all clients
losses_test = []   # per round: mean per-batch test loss of the aggregated global model

for epoch in range(epochs):
    # Local training on every client, each starting from the current global weights.
    round_train_loss = 0
    for i in tqdm(range(num_clients), desc='Global Loop', position=0):
        round_train_loss += client_update(client_models[i], optimizers[i], trainingSplits[i], local_epochs=local_epochs)

    losses_train.append(round_train_loss)

    # server aggregate: average the client weights into the global model and push them back
    server_aggregate(global_model, client_models)

    # `test` returns a single loss value; this model has no accuracy metric.
    test_loss = test(global_model, testSplit)
    losses_test.append(test_loss)
    print('%d-th round' % epoch)
    print('average train loss %0.3g | test loss %0.3g' % (round_train_loss / num_clients, test_loss))

Client Epoch 0:   4%|▍         | 19790/517605 [02:05<54:20, 152.68it/s]  