
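## Privacy and federated learning

A small simulation of federated averaging (FedAvg). Several clients each train a local copy of a linear classifier on their own synthetic data. The server averages the clients' parameters and adds Gaussian noise to the average as a simple privacy mechanism. The global model's accuracy is printed after every round. Client updates are not clipped, so the noise does not amount to a formal differential-privacy guarantee.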


In [1]:

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim


In [2]:


class SimpleModel(nn.Module):
    """Single fully-connected layer mapping ``input_dim`` features to ``output_dim`` scores.

    No activation is applied, so the outputs are raw logits (they are fed
    straight into ``nn.CrossEntropyLoss`` during local training).
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # The attribute name ``net`` determines the state_dict keys
        # (``net.weight``, ``net.bias``) that clients and server exchange.
        self.net = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Return the unnormalised class scores for the batch ``x``."""
        logits = self.net(x)
        return logits


In [3]:


def generate_client_data(n_clients, samples_per_client, input_dim):
    """Create one synthetic binary-classification dataset per client.

    Features are standard-normal draws from NumPy's global RNG; the label of a
    row is 1 when its features sum to a positive number and 0 otherwise.

    Returns:
        A list of ``n_clients`` tuples ``(features, labels)`` where ``features``
        is a float32 tensor of shape ``(samples_per_client, input_dim)`` and
        ``labels`` is an int64 tensor of shape ``(samples_per_client,)``.
    """

    def _make_one_client():
        features = np.random.randn(samples_per_client, input_dim)
        targets = (features.sum(axis=1) > 0).astype(int)
        return (
            torch.tensor(features, dtype=torch.float32),
            torch.tensor(targets, dtype=torch.long),
        )

    return [_make_one_client() for _ in range(n_clients)]


In [4]:


def train_local_model(model, xs, ys, epochs, lr):
    """Train ``model`` in place with full-batch SGD and return a snapshot of its state.

    Args:
        model: module whose outputs are class logits (scored with CrossEntropyLoss).
        xs: input batch passed to ``model`` unchanged.
        ys: integer class labels for ``xs``.
        epochs: number of full-batch gradient steps to take.
        lr: SGD learning rate.

    Returns:
        The model's state dict with every tensor cloned. ``state_dict()`` alone
        returns tensors that share storage with the live parameters, so a
        collected update would otherwise change if ``model`` were trained or
        modified again later.
    """
    optimizer = optim.SGD(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    # Be explicit about the mode in case the caller left the model in eval().
    model.train()

    for _ in range(epochs):
        optimizer.zero_grad()
        loss = criterion(model(xs), ys)
        loss.backward()
        optimizer.step()

    # Clone in place to keep the OrderedDict (and its metadata) that load_state_dict expects.
    snapshot = model.state_dict()
    for name in list(snapshot):
        snapshot[name] = snapshot[name].clone()
    return snapshot


In [5]:

def mean_models_add_noise(global_model, client_updates, noise_scale=0.1):
    """Overwrite ``global_model`` with the mean of the client state dicts plus Gaussian noise.

    Args:
        global_model: module whose parameters/buffers are replaced in place.
        client_updates: non-empty sequence of state dicts with the same keys and
            shapes as ``global_model.state_dict()``.
        noise_scale: standard deviation of the zero-mean Gaussian noise added to
            every averaged floating-point tensor; ``0`` gives plain averaging.

    Raises:
        ValueError: if ``client_updates`` is empty.

    Note: client updates are not clipped, so the noise is not calibrated to a
    formal differential-privacy guarantee.
    """
    if not client_updates:
        raise ValueError("client_updates must contain at least one state dict")

    global_state = global_model.state_dict()
    for key in list(global_state):
        stacked = torch.stack([update[key] for update in client_updates])
        if not torch.is_floating_point(stacked):
            # Integer buffers (e.g. counters) cannot be mean-reduced or noised.
            # Take the first client's value; assumes clients agree on them — verify per model.
            global_state[key] = stacked[0]
            continue
        averaged = stacked.mean(dim=0)
        # randn_like keeps the noise on the same device and dtype as the parameter.
        global_state[key] = averaged + noise_scale * torch.randn_like(averaged)

    global_model.load_state_dict(global_state)


In [6]:


def federated_learning(num_clients, rounds, input_dim, output_dim, samples_per_client,
                       local_epochs=1, lr=0.01, noise_scale=0.1, test_samples=100):
    """Simulate ``rounds`` rounds of noisy federated averaging and print per-round accuracy.

    In each round, every client starts from the current global parameters and
    trains locally on its own data. The server then replaces the global
    parameters with the client mean plus Gaussian noise
    (``mean_models_add_noise``) and evaluates on a fixed held-out set.

    Args:
        num_clients: number of simulated clients.
        rounds: number of communication rounds.
        input_dim: number of input features of ``SimpleModel``.
        output_dim: number of classes of ``SimpleModel``.
        samples_per_client: training examples generated for each client.
        local_epochs: full-batch SGD steps each client takes per round.
        lr: client learning rate.
        noise_scale: std of the Gaussian noise added to the averaged parameters.
        test_samples: size of the evaluation set.

    NOTE(review): with the defaults, the recorded run stays near chance accuracy
    (38-60%). Presumably one lr=0.01 step per round is swamped by std-0.1 noise
    on the aggregate. Raise ``local_epochs``/``lr`` or lower ``noise_scale`` to
    see the model learn.
    """
    global_model = SimpleModel(input_dim, output_dim)
    client_data = generate_client_data(num_clients, samples_per_client, input_dim)

    # Draw the evaluation set once so that accuracies from different rounds are
    # measured on the same examples and can be compared.
    test_x, test_y = generate_client_data(1, test_samples, input_dim)[0]

    for round_idx in range(rounds):
        print(f"Simulation {round_idx + 1}")

        client_updates = []
        for data, labels in client_data:
            # Each client starts from a copy of the current global parameters.
            local_model = SimpleModel(input_dim, output_dim)
            local_model.load_state_dict(global_model.state_dict())
            client_updates.append(
                train_local_model(local_model, data, labels, epochs=local_epochs, lr=lr)
            )

        mean_models_add_noise(global_model, client_updates, noise_scale=noise_scale)

        with torch.no_grad():
            preds = global_model(test_x).argmax(dim=1)
            accuracy = (preds == test_y).float().mean().item()

        print(f"Global Model performance Accuracy: {accuracy:.2%}")


In [7]:


# Seed both RNGs the simulation draws from (NumPy for data, PyTorch for model
# initialisation and aggregation noise) so Restart & Run All is reproducible.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

num_clients        = 7
simulations        = 15   # number of federated rounds
input_dim          = 12
output_dim         = 2    # binary task: label = (sum of features > 0)
samples_per_each_client = 90

federated_learning(num_clients, simulations, input_dim, output_dim, samples_per_each_client)


Simulation 1
Global Model performance Accuracy: 38.00%
Simulation 2
Global Model performance Accuracy: 40.00%
Simulation 3
Global Model performance Accuracy: 47.00%
Simulation 4
Global Model performance Accuracy: 47.00%
Simulation 5
Global Model performance Accuracy: 41.00%
Simulation 6
Global Model performance Accuracy: 50.00%
Simulation 7
Global Model performance Accuracy: 52.00%
Simulation 8
Global Model performance Accuracy: 60.00%
Simulation 9
Global Model performance Accuracy: 47.00%
Simulation 10
Global Model performance Accuracy: 46.00%
Simulation 11
Global Model performance Accuracy: 45.00%
Simulation 12
Global Model performance Accuracy: 55.00%
Simulation 13
Global Model performance Accuracy: 43.00%
Simulation 14
Global Model performance Accuracy: 38.00%
Simulation 15
Global Model performance Accuracy: 51.00%
