<a href="https://colab.research.google.com/github/ubaldinho/Hello_World/blob/main/PW2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Practical Work : Secure Federated Learning

Federated Learning (FL) is a machine learning framework that enables $K \in \mathbb{N}^*$ participants to collaboratively train a model $M^G$ across $R$ rounds of exchange while maintaining the privacy of their data $D^k$. In the client–server model of FL, the server initializes the global model $M^G_0$. At each round $t$, the global model $M^G_t$ is distributed to a subset $S_t \subseteq {1, \dots, K}$ consisting of $C \times K$ randomly selected clients, where $C \in (0,1]$. Each client $k \in S_t$ trains the model locally using its private dataset $D^k$ and sends its updated model $M_{t+1}^k$ back to the server. The server then aggregates these updates to construct the new global model $M_{t+1}^G$. This process repeats until $R$ rounds are completed ($t = R$).

In [1]:
!pip install tenseal

Collecting tenseal
  Downloading tenseal-0.3.16-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (8.4 kB)
Downloading tenseal-0.3.16-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (4.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.8/4.8 MB[0m [31m15.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: tenseal
Successfully installed tenseal-0.3.16


In [2]:
from src.train import train
from src.data_splitter import data_splitter
from src.metric import accuracy

import torch
from torch import nn
from copy import deepcopy
from torch import optim

import tenseal as ts

## Understanding the Client-Server Algorithm

In this section, we implement the Federated Learning algorithm using FedAvg. We use the CIFAR-10 dataset and the ResNet-18 model. The server model is trained for 44 rounds, after which one round of training is performed on the clients.


In [3]:
train_loaders, size, test_loader = data_splitter("CIFAR10", 5 ) # TODO (a) : Load the data using the data_splitter function

Selected Dataset :  CIFAR10 



100%|██████████| 170M/170M [00:13<00:00, 12.3MB/s]


Size of the train set for each client : 2000
Size of the test set : 10000 





In [13]:
server_44 = torch.hub.load('pytorch/vision', 'resnet18', weights=None) # Load the model architecture
server_44.fc = nn.Linear(server_44.fc.in_features, 10)
server_44.load_state_dict(torch.load("/content/model_cifar10_fl.pth") ) # TODO (b) : Load the model weights of the server at round 44
server_44.to("cuda")
print("Server model at round 44 loaded")

Using cache found in /root/.cache/torch/hub/pytorch_vision_main


Server model at round 44 loaded


In [14]:
accuracy(server_44, test_loader)

(0.594, 1.167)

In the next cell, we define five clients by copying the server model.

In [15]:
clients = []
for i in range(5):
    client = deepcopy(server_44)
    clients.append(client)

Before launching the FL algorithm, we need to define the FedAvg function, which is defined as follows:
$$ W^G_{t+1} = \sum_{k \in S_t} \frac{n_k}{n} W^k_{t+1} $$
where $W^k_{t+1}$ denotes the weights of client $k$ at round $t+1$, $n_k$ is the size of client $k$’s dataset, and $n$ is the total size of the datasets of all clients.

In [16]:
def fed_avg(server, clients):
    with torch.no_grad():
        server_next = deepcopy(server)
        server_dict = server_next.state_dict()
        for name_server in server_dict.keys():
            server_dict[name_server].zero_()
            for client in clients:
                if client.state_dict()[name_server].dtype is torch.long:
                    weight = (
                        1 / len(clients) # todo (d) : Compute the weight of the client's model
                    ) * client.state_dict()[name_server].clone().detach()
                    weight = weight.long()

                else:
                    weight = (
                       1 / len(clients) # todo (d) : Compute the weight of the client's model
                    ) * client.state_dict()[name_server].clone().detach()

                server_dict[name_server].add_(weight)
    return server_next

Now, we can launch the FL algorithm by performing one round of training on the clients.

In [20]:
def one_round():
    clients = []
    for i in range(5):
        print("Local Training on client", i)
        client = deepcopy(server_44)
        clients.append(client.train()) # TODO (d) : Train the client's model using the train function and store it in the clients list
    server_next = fed_avg(server_44, clients) # TODO (d) : Aggregate the client's model using the FedAvg algorithm
    print(accuracy(server_next, test_loader) ) # TODO (e) : Evaluate the accuracy of the server model after aggregation

one_round()

Local Training on client 0
Local Training on client 1
Local Training on client 2
Local Training on client 3
Local Training on client 4




(0.594, 1.167)


## Secure Aggregation using TenSEAL

In this section, we implement secure aggregation using the TenSEAL library. We use the CKKS scheme to encrypt the last layer of each client’s model and then aggregate the encrypted layers. Finally, we decrypt the aggregated layer to obtain the final result.

In [10]:
server_44 = torch.hub.load('pytorch/vision', 'resnet18', weights=None) # Load the model architecture
server_44.fc = nn.Linear(server_44.fc.in_features, 10)
server_44.load_state_dict(torch.load("/content/model_cifar10_fl.pth")) # TODO (b) : Load the model weights of the server at round 44
server_44.to("cuda")
print("Server model at round 44 loaded")

Using cache found in /root/.cache/torch/hub/pytorch_vision_main


Server model at round 44 loaded


First, we need to define the encryption context using the CKKS scheme.

In [21]:
ctx = ts.context(ts.SCHEME_TYPE.CKKS, 8192, coeff_mod_bit_sizes=[60, 40, 40, 60])
ctx.global_scale = pow(2, 40)
ctx.generate_galois_keys()

Then, we define the function that encrypts the last layer of each client’s model.

In [22]:
def encrypt_last_layer(clients, ctx):
    encrypted_last_layers = []
    for i in range(5):
        encrypted_last_layers.append(
            ts.ckks_tensor(ctx, clients[i].fc.bias.cpu().detach().numpy()) # TODO (f) : Encrypt the last layer of the client's model
        )
        # TODO (f) : Encrypt the last layer of the client's model
    return encrypted_last_layers

In [23]:
encrypted_last_layers = encrypt_last_layer(clients, ctx)

Now, we can aggregate the encrypted last layers and decrypt the aggregated layer to obtain the final result.

In [30]:
cli_coeff = 1 / len(clients)
aggregated_encrypted_last_layers =  cli_coeff * encrypted_last_layers[0]
for i in range(1, 5):
    aggregated_encrypted_last_layers += cli_coeff * encrypted_last_layers[i]
# TODO (g) : Aggregate the encrypted last layers

In [34]:
result = aggregated_encrypted_last_layers.decrypt().tolist()
print(result)
# TODO (h) : Decrypt the aggregated last layer and print the result

[-0.12019600696184722, 0.05193927337635204, 0.18442280107928682, 0.3505839870949573, 0.08478416185381495, -0.16556600793842674, -0.31237410052092623, 0.18795420259623194, -0.0719896659901517, -0.0017646292567670052]


You can compare the result with the aggregation of the clients’ models without encryption:

In [35]:
aggregated_last_layer = 0
for i in range(5):
    aggregated_last_layer += (1/5) * clients[i].fc.bias.cpu().detach()
print(aggregated_last_layer)

tensor([-0.1202,  0.0519,  0.1844,  0.3506,  0.0848, -0.1656, -0.3124,  0.1880,
        -0.0720, -0.0018])


## Byzantine Attack

In this section, we implement various Byzantine attacks that aim to compromise the federated learning process by sending malicious updates to the server.

In [36]:
server_44 = torch.hub.load('pytorch/vision', 'resnet18', weights=None) # Load the model architecture
server_44.fc = nn.Linear(server_44.fc.in_features, 10)
server_44.load_state_dict(torch.load("/content/model_cifar10_fl.pth")) # TODO (b) : Load the model weights of the server at round 44
server_44.to("cuda")
print("Server model at round 44 loaded")

Using cache found in /root/.cache/torch/hub/pytorch_vision_main


Server model at round 44 loaded


The next cell contains all the Byzantine attacks that we implement.  
The Byzantine attacks are defined as follows:

- **Lazy Attack:** A client sends arbitrary values (e.g., random or malformed updates).  
- **Same Attack:** A client sends identical values for all parameters.  
- **Sign Attack:** A client multiplies all weights by a scalar $\alpha$ (i.e., scales the model).  
- **Noise Attack:** A client adds random noise to the weights (sampled from a chosen distribution, e.g., $\mathcal{N}(0,\sigma^2)$).  

Finally, we define the filter (defense) that will be used to detect and

In [38]:
def BA_Lazy(server):
    with torch.no_grad():
        server_next = deepcopy(server)
        server_dict = server_next.state_dict()
        for name_server in server_dict.keys():
            server_dict[name_server].zero_() # TODO (i) : Fill the server's model with a what you want
    return server_next

# Byzantine Attack
def BA_Same(server):
    with torch.no_grad():
        server_next = deepcopy(server)
        alpha = 100
        server_dict = server_next.state_dict()
        for name_server in server_dict.keys():
            server_dict[name_server].fill_(alpha) # TODO (j) : Fill the server's model with the same value
    return server_next

def BA_Sign(server):
    with torch.no_grad():
        server_next = deepcopy(server)
        server_dict = server_next.state_dict()
        beta = 1000
        for name_server in server_dict.keys():
            server_dict[name_server].copy_(- server.state_dict()[name_server] * beta) # TODO (j) : Multiply all the weights by a value alpha
    return server_next

def BA_Noise(server):
    with torch.no_grad():
        server_next = deepcopy(server)
        server_dict = server_next.state_dict()
        mu, sigma = 0, 0.1 # TODO (j) : Define the mean and standard deviation of the noise
        for name_server in server_dict.keys():
            if server_dict[name_server].dtype is not torch.long:
                noise = torch.rand_like(server.state_dict()[name_server]) * sigma + mu # TODO (j) : Add some noise to the weights
                server_dict[name_server].add_(noise)
                #server_dict[name_server].copy_(server.state_dict()[name_server] + noise)
    return server_next

def filter(client):
    """
    Filtre pour détecter les clients Byzantine
    threshold: seuil pour détecter les valeurs identiques
    noise_threshold: seuil pour détecter le bruit excessif
    """

    threshold, noise_threshold = 0.1, 0.05
    with torch.no_grad():
        client_dict = client.state_dict()

        # Vérification Same Value Attack
        for name, param in client_dict.items():
            if param.numel() > 1:  # Éviter les scalaires
                unique_values = torch.unique(param)
                if len(unique_values) == 1:  # Toutes les valeurs identiques
                    print(f"⚠️ Same Value Attack détecté dans {name}")
                    return False

        # Vérification Lazy Attack (tous zéros)
        total_norm = 0
        for param in client_dict.values():
            total_norm += param.norm().item()

        if total_norm < threshold:  # Norme trop petite = modèle trop proche de zéro
            print(f"⚠️ Lazy Attack détecté (norme totale: {total_norm:.6f})")
            return False

        # Vérification Sign Attack (valeurs extrêmes)
        max_abs_value = max(torch.max(torch.abs(param)).item() for param in client_dict.values())
        if max_abs_value > 1e3:  # Valeurs extrêmement grandes
            print(f"⚠️ Sign/Scale Attack détecté (valeur max: {max_abs_value:.6f})")
            return False

    return True


In [41]:
def one_round_attack(server):
    # Clients Side
    clients = []
    for i in range(5):
        if i == 4:
            print("Malicious Client", i)
            client = BA_Lazy(server) # TODO (k) : Apply the Byzantine Attack on the server
        else:
            print("Local Training on client", i)
            client = deepcopy(server)
            train(client, train_loaders[i], test_loader, 1)
        clients.append(client)

    # Server Side
    for i in range(5):
        print("Client", i)
        clients[filter(client[i])] # TODO (l) : Filter the client's model
    server_next = fedavg(server_44, clients)
    print("Server Accuracy at round 45 ",accuracy(server_next, test_loader))

In [42]:
one_round_attack(server_44)

Local Training on client 0




Epoch : 0
(0.664, 1.5)
Local Training on client 1
Epoch : 0
(0.671, 1.081)
Local Training on client 2
Epoch : 0
(0.672, 1.142)
Local Training on client 3
Epoch : 0
(0.666, 1.097)
Malicious Client 4
Client 0


TypeError: 'ResNet' object is not subscriptable

## Protect the Model's IP using Watermarking

In this section, we implement the model watermarking technique defined by Uchida et al. To simplify the implementation, we assume that the model is watermarked by a single client.

In [None]:
# Model Watermarking
server_44 = torch.hub.load('pytorch/vision', 'resnet18', weights=None)  # Load the model architecture
server_44.fc = nn.Linear(server_44.fc.in_features, 10)
server_44.load_state_dict(torch.load("/content/model_cifar10_fl.pth")) # TODO (b) : Load the model weights of the server at round 44
server_44.to("cuda")
print("Server model at round 44 loaded")

Before watermarking the model, we examine the limitation of a common method used to evaluate whether two models are identical in open-source platforms. In the next cells, we compute the hash of the last layer of the server model.

In [None]:
server_44.fc.bias

In [None]:
accuracy(server_44, test_loader)

In [None]:
tensor_to_str = ''.join(str(x.item())+" " for x in server_44.fc.bias)
print(tensor_to_str)
print("Hash of the layer :", hash(tensor_to_str))

Let’s add a small perturbation to the last layer of the server model and then compute the hash again.

In [None]:
with torch.no_grad():
    server_44.fc.bias.add_(1e-3)

In [None]:
accuracy(server_44, test_loader)

In [None]:
tensor_to_str = ''.join(str(x.item())+" " for x in server_44.fc.bias)
print(tensor_to_str)
print("Hash of the layer :", hash(tensor_to_str))

As you can see, the hash is different while the accuracy remains the same. This means that an attacker can bypass this simple method of verifying whether two models are identical.

Now let’s implement the watermarking technique proposed by Uchida *et al.*  
This technique consists of embedding a secret message in a layer using the following methodology:

1. **Secret generation:** Generate a secret key $K$ and a message $b$.  
2. **Parameter selection:** Select the parameters `"fc.weight"` and compute the mean along the columns to obtain a vector $w$.  
3. **Projection:** Project the vector $w$, in which we want to embed $b$, using the secret key $K$ as follows:  
   $$
   y = K w
   $$  
4. **Extraction:** Apply the Sigmoid function to obtain the extracted message $b'$:  
   $$
   b' = \sigma(y)
   $$  
5. **Loss computation:** Compute the binary cross-entropy loss between the extracted message $b'$ and the original message $b$:  
   $$
   L = \text{BCELoss}(b', b)
   $$

In [None]:
secret_key = torch.randn((256,512), device="cuda")
message = torch.randint(2, (256,), device="cuda").float()
def train(model, train_set, test_set, epoch_max):
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    criterion = nn.CrossEntropyLoss()
    criterion_watermark = # TODO (o) : Define the criterion for the watermark
    alpha = 5e-1

    for epoch in range(epoch_max):
        accumulate_loss = 0

        for inputs, outputs in train_set:
            optimizer.zero_grad(set_to_none=True)

            inputs = inputs.to("cuda")

            outputs = outputs.to("cuda")

            outputs_predicted = model(inputs)

            loss_main = criterion(outputs_predicted, outputs)

            extracted_message = # TODO (n) : Extract the message from the model

            loss_watermark  = # TODO (o) : Compute the loss of the watermark

            loss = loss_main + (alpha * loss_watermark)

            loss.backward()

            optimizer.step()

            accumulate_loss += loss.item()
        print(f"Epoch : {epoch}")
        print(accuracy(model, test_set))
        print("Bit Error Rate : ", # TODO (p) : Compute the Bit Error Rate