<a href="https://colab.research.google.com/github/ubaldinho/Hello_World/blob/main/PW2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Practical Work : Secure Federated Learning

Federated Learning (FL) is a machine learning framework that enables $K \in \mathbb{N}^*$ participants to collaboratively train a model $M^G$ across $R$ rounds of exchange while maintaining the privacy of their data $D^k$. In the client–server model of FL, the server initializes the global model $M^G_0$. At each round $t$, the global model $M^G_t$ is distributed to a subset $S_t \subseteq {1, \dots, K}$ consisting of $C \times K$ randomly selected clients, where $C \in (0,1]$. Each client $k \in S_t$ trains the model locally using its private dataset $D^k$ and sends its updated model $M_{t+1}^k$ back to the server. The server then aggregates these updates to construct the new global model $M_{t+1}^G$. This process repeats until $R$ rounds are completed ($t = R$).

In [1]:
!pip install tenseal



In [2]:
from src.train import train
from src.data_splitter import data_splitter
from src.metric import accuracy

import torch
from torch import nn
from copy import deepcopy
from torch import optim

import tenseal as ts

## Understanding the Client-Server Algorithm

In this section, we implement the Federated Learning algorithm using FedAvg. We use the CIFAR-10 dataset and the ResNet-18 model. The server model is trained for 44 rounds, after which one round of training is performed on the clients.


In [3]:
train_loaders, size, test_loader = data_splitter("CIFAR10", 5 ) # TODO (a) : Load the data using the data_splitter function

Selected Dataset :  CIFAR10 

Size of the train set for each client : 2000
Size of the test set : 10000 





In [4]:
server_44 = torch.hub.load('pytorch/vision', 'resnet18', weights=None) # Load the model architecture
server_44.fc = nn.Linear(server_44.fc.in_features, 10)
server_44.load_state_dict(torch.load("/content/model_cifar10_fl.pth", map_location=torch.device('cpu')) ) # TODO (b) : Load the model weights of the server at round 44
#server_44.to("cuda")
print("Server model at round 44 loaded")

Using cache found in /root/.cache/torch/hub/pytorch_vision_main


Server model at round 44 loaded


In [5]:
accuracy(server_44, test_loader)



(0.594, 1.167)

In the next cell, we define five clients by copying the server model.

In [6]:
clients = []
for i in range(5):
    client = deepcopy(server_44)
    clients.append(client)

Before launching the FL algorithm, we need to define the FedAvg function, which is defined as follows:
$$ W^G_{t+1} = \sum_{k \in S_t} \frac{n_k}{n} W^k_{t+1} $$
where $W^k_{t+1}$ denotes the weights of client $k$ at round $t+1$, $n_k$ is the size of client $k$’s dataset, and $n$ is the total size of the datasets of all clients.

In [7]:
def fed_avg(server, clients):
    with torch.no_grad():
        server_next = deepcopy(server)
        server_dict = server_next.state_dict()
        for name_server in server_dict.keys():
            server_dict[name_server].zero_()
            for client in clients:
                if client.state_dict()[name_server].dtype is torch.long:
                    weight = (
                        1 / len(clients) # todo (d) : Compute the weight of the client's model
                    ) * client.state_dict()[name_server].clone().detach()
                    weight = weight.long()

                else:
                    weight = (
                       1 / len(clients) # todo (d) : Compute the weight of the client's model
                    ) * client.state_dict()[name_server].clone().detach()

                server_dict[name_server].add_(weight)
    return server_next

Now, we can launch the FL algorithm by performing one round of training on the clients.

In [8]:
def one_round():
    clients = []
    for i in range(5):
        print("Local Training on client", i)
        client = deepcopy(server_44)
        clients.append(client.train()) # TODO (d) : Train the client's model using the train function and store it in the clients list
    server_next = fed_avg(server_44, clients) # TODO (d) : Aggregate the client's model using the FedAvg algorithm
    print(accuracy(server_next, test_loader) ) # TODO (e) : Evaluate the accuracy of the server model after aggregation

one_round()

Local Training on client 0
Local Training on client 1
Local Training on client 2
Local Training on client 3
Local Training on client 4
(0.594, 1.167)


## Secure Aggregation using TenSEAL

In this section, we implement secure aggregation using the TenSEAL library. We use the CKKS scheme to encrypt the last layer of each client’s model and then aggregate the encrypted layers. Finally, we decrypt the aggregated layer to obtain the final result.

In [9]:
server_44 = torch.hub.load('pytorch/vision', 'resnet18', weights=None) # Load the model architecture
server_44.fc = nn.Linear(server_44.fc.in_features, 10)
server_44.load_state_dict(torch.load("/content/model_cifar10_fl.pth", map_location=torch.device('cpu')) ) # TODO (b) : Load the model weights of the server at round 44
#server_44.to("cuda")
print("Server model at round 44 loaded")

Using cache found in /root/.cache/torch/hub/pytorch_vision_main


Server model at round 44 loaded


First, we need to define the encryption context using the CKKS scheme.

In [10]:
ctx = ts.context(ts.SCHEME_TYPE.CKKS, 8192, coeff_mod_bit_sizes=[60, 40, 40, 60])
ctx.global_scale = pow(2, 40)
ctx.generate_galois_keys()

Then, we define the function that encrypts the last layer of each client’s model.

In [11]:
def encrypt_last_layer(clients, ctx):
    encrypted_last_layers = []
    for i in range(5):
        encrypted_last_layers.append(
            ts.ckks_tensor(ctx, clients[i].fc.bias.cpu().detach().numpy()) # TODO (f) : Encrypt the last layer of the client's model
        )
        # TODO (f) : Encrypt the last layer of the client's model
    return encrypted_last_layers

In [12]:
encrypted_last_layers = encrypt_last_layer(clients, ctx)

Now, we can aggregate the encrypted last layers and decrypt the aggregated layer to obtain the final result.

In [13]:
cli_coeff = 1 / len(clients)
aggregated_encrypted_last_layers =  cli_coeff * encrypted_last_layers[0]
for i in range(1, 5):
    aggregated_encrypted_last_layers += cli_coeff * encrypted_last_layers[i]
# TODO (g) : Aggregate the encrypted last layers

In [14]:
result = aggregated_encrypted_last_layers.decrypt().tolist()
print(result)
# TODO (h) : Decrypt the aggregated last layer and print the result

[-0.12019600956226988, 0.05193927470126821, 0.1844228005301616, 0.35058398792729095, 0.08478416257873579, -0.16556600546367298, -0.3123740961464586, 0.18795419684798487, -0.07198966570366853, -0.001764628803325899]


You can compare the result with the aggregation of the clients’ models without encryption:

In [15]:
aggregated_last_layer = 0
for i in range(5):
    aggregated_last_layer += (1/5) * clients[i].fc.bias.cpu().detach()
print(aggregated_last_layer)

tensor([-0.1202,  0.0519,  0.1844,  0.3506,  0.0848, -0.1656, -0.3124,  0.1880,
        -0.0720, -0.0018])


## Byzantine Attack

In this section, we implement various Byzantine attacks that aim to compromise the federated learning process by sending malicious updates to the server.

In [16]:
server_44 = torch.hub.load('pytorch/vision', 'resnet18', weights=None) # Load the model architecture
server_44.fc = nn.Linear(server_44.fc.in_features, 10)
server_44.load_state_dict(torch.load("/content/model_cifar10_fl.pth", map_location=torch.device('cpu')) ) # TODO (b) : Load the model weights of the server at round 44
#server_44.to("cuda")
print("Server model at round 44 loaded")

Using cache found in /root/.cache/torch/hub/pytorch_vision_main


Server model at round 44 loaded


The next cell contains all the Byzantine attacks that we implement.  
The Byzantine attacks are defined as follows:

- **Lazy Attack:** A client sends arbitrary values (e.g., random or malformed updates).  
- **Same Attack:** A client sends identical values for all parameters.  
- **Sign Attack:** A client multiplies all weights by a scalar $\alpha$ (i.e., scales the model).  
- **Noise Attack:** A client adds random noise to the weights (sampled from a chosen distribution, e.g., $\mathcal{N}(0,\sigma^2)$).  

Finally, we define the filter (defense) that will be used to detect and

In [17]:
def BA_Lazy(server):
    with torch.no_grad():
        server_next = deepcopy(server)
        server_dict = server_next.state_dict()
        for name_server in server_dict.keys():
            server_dict[name_server].zero_() # TODO (i) : Fill the server's model with a what you want
    return server_next

# Byzantine Attack
def BA_Same(server):
    with torch.no_grad():
        server_next = deepcopy(server)
        alpha = 100
        server_dict = server_next.state_dict()
        for name_server in server_dict.keys():
            server_dict[name_server].fill_(alpha) # TODO (j) : Fill the server's model with the same value
    return server_next

def BA_Sign(server):
    with torch.no_grad():
        server_next = deepcopy(server)
        server_dict = server_next.state_dict()
        beta = 1000
        for name_server in server_dict.keys():
            server_dict[name_server].copy_(- server.state_dict()[name_server] * beta) # TODO (j) : Multiply all the weights by a value alpha
    return server_next

def BA_Noise(server):
    with torch.no_grad():
        server_next = deepcopy(server)
        server_dict = server_next.state_dict()
        mu, sigma = 0, 0.1 # TODO (j) : Define the mean and standard deviation of the noise
        for name_server in server_dict.keys():
            if server_dict[name_server].dtype is not torch.long:
                noise = torch.rand_like(server.state_dict()[name_server]) * sigma + mu # TODO (j) : Add some noise to the weights
                server_dict[name_server].add_(noise)
                #server_dict[name_server].copy_(server.state_dict()[name_server] + noise)
    return server_next

def filter(client):
    """
    Filtre pour détecter les clients Byzantine
    threshold: seuil pour détecter les valeurs identiques
    noise_threshold: seuil pour détecter le bruit excessif
    """

    threshold, noise_threshold = 0.1, 0.05
    with torch.no_grad():
        client_dict = client.state_dict()

        # Vérification Same Value Attack
        for name, param in client_dict.items():
            if param.numel() > 1:  # Éviter les scalaires
                unique_values = torch.unique(param)
                if len(unique_values) == 1:  # Toutes les valeurs identiques
                    print(f"⚠️ Same Value Attack détecté dans {name}")
                    return False
        print("✅ Pas de Same Value Attack détecté")
        # Vérification Lazy Attack (tous zéros)
        total_norm = 0
        for param in client_dict.values():
            param = param.float()
            total_norm += param.norm().item()

        if total_norm < threshold:  # Norme trop petite = modèle trop proche de zéro
            print(f"⚠️ Lazy Attack détecté (norme totale: {total_norm:.6f})")
            return False
        print("✅ Pas de Lazy Attack détecté")


    return True


In [18]:
def one_round_attack(server):
    # Clients Side
    clients = []
    for i in range(5):
        if i == 4:
            print("Malicious Client", i)
            client = BA_Lazy(server) # TODO (k) : Apply the Byzantine Attack on the server
        else:
            print("Local Training on client", i)
            client = deepcopy(server)
            train(client, train_loaders[i], test_loader, 1)
        clients.append(client)

    # Server Side
    for i in range(5):
        print("Client", i)
        clients[filter(clients[i])] # TODO (l) : Filter the client's model
    server_next = fed_avg(server_44, clients)
    print("Server Accuracy at round 45 ",accuracy(server_next, test_loader))

In [19]:
one_round_attack(server_44)

Local Training on client 0
Epoch : 0
(0.671, 1.115)
Local Training on client 1
Epoch : 0
(0.67, 1.019)
Local Training on client 2
Epoch : 0
(0.667, 1.182)
Local Training on client 3
Epoch : 0
(0.669, 1.127)
Malicious Client 4
Client 0
✅ Pas de Same Value Attack détecté
✅ Pas de Lazy Attack détecté
Client 1
✅ Pas de Same Value Attack détecté
✅ Pas de Lazy Attack détecté
Client 2
✅ Pas de Same Value Attack détecté
✅ Pas de Lazy Attack détecté
Client 3
✅ Pas de Same Value Attack détecté
✅ Pas de Lazy Attack détecté
Client 4
⚠️ Same Value Attack détecté dans conv1.weight
Server Accuracy at round 45  (0.329, 1.907)


In [20]:
import torch
#HOMEWORK 1
def analyze_client_deviation(clients, i, last_layer_only=True):
    """
    Analyse les écarts des poids et biais du client i par rapport aux autres clients.
    Détecte les anomalies de type :
    - Différence absolue trop élevée
    - Ratio absolu trop élevé (indiquant une mise à l'échelle ou inversion)

    Args:
        clients (list): Liste de modèles clients (instances de nn.Module)
        i (int): Index du client à analyser
        last_layer_only (bool): Si True, analyse uniquement la dernière couche (nom contenant 'fc') ce qui est le mode le moins couteux mais tout aussi efficace

    Returns:
        bool: True si le client est conforme, False si des anomalies sont détectées
    """

    # Seuils de détection
    #DIFF_THRESHOLD = 0.1
    RATIO_THRESHOLD = 2.0

    target_dict = clients[i].state_dict()
    other_dicts = [clients[j].state_dict() for j in range(len(clients)) if j != i]

    suspicious_params = []

    for name, param in target_dict.items():
        if last_layer_only and 'fc' not in name:
            continue  # ignorer les couches sauf la dernière

        # Empiler les paramètres des autres clients
        others_tensor = torch.stack([other[name] for other in other_dicts])
        mean_tensor = torch.mean(others_tensor, dim=0)

        # Calcul des écarts
        abs_rel_ratio = torch.abs((param - mean_tensor) / (mean_tensor + 1e-8))  # éviter division par zéro

        # Vérification des seuils
        #if torch.any(abs_diff > DIFF_THRESHOLD):
        #    suspicious_params.append((name, 'diff', abs_diff.max().item()))
        if torch.any(abs_rel_ratio > RATIO_THRESHOLD):
            suspicious_params.append((name, 'ratio', abs_rel_ratio.max().item()))

    # Affichage des résultats
    if suspicious_params:
        print(f"⚠️ Client {i} présente des anomalies :")
        for name, typ, val in suspicious_params:
            print(f" - {name} dépasse le seuil ({typ} = {val:.4f})")
        return False
    else:
        print(f"✅ Client {i} est conforme aux seuils définis.")
        return True

In [31]:
import torch
import torch.nn.functional as F

# Seuils pour la détection de bruit
STD_THRESHOLD = 0.1
VAR_THRESHOLD = 0.1
DISTANCE_THRESHOLD = 0.5
COSINE_THRESHOLD = 0.4
#HOMEWORK
def detect_noise_attack(clients, i, last_layer_only=True):
    """
    Détecte une attaque Byzantine de type bruit (Additive Noise Attack) en comparant
    les mises à jour du client i avec celles des autres clients.

    Métriques utilisées :
    - Écart-type et variance des différences
    - Distance euclidienne
    - Similarité cosinus

    Args:
        clients (list): Liste de modèles clients (instances de nn.Module)
        i (int): Index du client à analyser
        last_layer_only (bool): Si True, analyse uniquement la dernière couche (nom contenant 'fc')

    Returns:
        bool: True si le client est conforme, False si bruit excessif détecté
    """
    target_dict = clients[i].state_dict()
    other_dicts = [clients[j].fc.bias for j in range(len(clients)) if j != i]

    noisy_params = []

    for name in target_dict.keys():
        if last_layer_only and 'fc' not in name:
            continue
        param = clients[i].fc.bias

        # Empiler les paramètres des autres clients
        others_tensor = torch.stack([other for other in other_dicts])
        mean_tensor = torch.mean(others_tensor, dim=0)

        delta = param - mean_tensor

        # Métriques statistiques
        std_dev = torch.std(delta)
        variance = torch.var(delta)
        mean_val = torch.mean(delta) + 1e-8

        # Distance euclidienne
        euclidean_dist = torch.norm(delta)

        # Similarité cosinus
        cosine_sim = F.cosine_similarity(param.flatten(), mean_tensor.flatten(), dim=0)

        if (
            std_dev / mean_val > STD_THRESHOLD or
            sqrt(variance) / mean_val > VAR_THRESHOLD or
            euclidean_dist / mean_val > DISTANCE_THRESHOLD or
            cosine_sim < COSINE_THRESHOLD
        ):
            noisy_params.append((name, std_dev.item(), variance.item(), euclidean_dist.item(), cosine_sim.item()))

    if noisy_params:
        print(f"⚠️ Client {i} présente des signes de Noise Attack :")
        for name, std, var, dist, cos in noisy_params:
            print(f" - {name} : std={std:.4f}, var={var:.4f}, dist={dist:.4f}, cos_sim={cos:.4f}")
        return False
    else:
        print(f"✅ Client {i} ne présente pas de bruit excessif.")
        return True

In [22]:
import numpy as np
from math import sqrt
# HOMEWORK
def filter_all(clients, i):
    """
    Filtre pour détecter les clients Byzantine
    """

    with torch.no_grad():
        client = clients[i]
        client_dict = client.state_dict()

        # Vérification Sign Attack (valeurs extrêmes)
        analyze_client_deviation(clients, i)

        #verfification de Noisy Attack
        detect_noise_attack(clients, i)
    return True


In [23]:
def one_round_attack_ext(server, attack_fn, filter_fn = filter_all):
    """
    Exécute un round d'entraînement fédéré avec une attaque Byzantine sur un client.

    Args:
        server (nn.Module): Modèle global du serveur.
        attack_fn (function): Fonction d'attaque Byzantine à appliquer sur le client malveillant.
        filter_fn (function): Fonction de filtrage pour détecter les clients suspects.
    """
    clients = []

    # Phase client
    for i in range(5):
        print(f"Client {i}")
        client = deepcopy(server)

        if i == 4:
            print("⚠️ Malicious Client", i)
            client = attack_fn(client)  # Appliquer l'attaque
        else:
            print("✅ Local Training on client", i)
            train(client, train_loaders[i], test_loader, 1)

        clients.append(client)

    # Phase serveur : filtrage
    filtered_clients = []
    for i in range(5):
        print(f"🔍 Analyse du client {i}")
        if filter_fn(clients, i):
            filtered_clients.append(clients[i])
        else:
            print(f"❌ Client {i} exclu pour comportement suspect")

    # Agrégation
    server_next = fed_avg(server, filtered_clients)
    print("📊 Server Accuracy at round 45:", accuracy(server_next, test_loader))

In [32]:
one_round_attack_ext(server_44, BA_Sign)

Client 0
✅ Local Training on client 0




Epoch : 0




(0.666, 1.104)
Client 1
✅ Local Training on client 1
Epoch : 0
(0.664, 1.145)
Client 2
✅ Local Training on client 2
Epoch : 0
(0.663, 1.247)
Client 3
✅ Local Training on client 3
Epoch : 0
(0.668, 1.5)
Client 4
⚠️ Malicious Client 4
🔍 Analyse du client 0
✅ Client 0 est conforme aux seuils définis.
⚠️ Client 0 présente des signes de Noise Attack :
 - fc.weight : std=48.8703, var=2388.3032, dist=147.3621, cos_sim=-1.0000
 - fc.bias : std=48.8703, var=2388.3032, dist=147.3621, cos_sim=-1.0000
🔍 Analyse du client 1
✅ Client 1 est conforme aux seuils définis.
⚠️ Client 1 présente des signes de Noise Attack :
 - fc.weight : std=48.8702, var=2388.2971, dist=147.3619, cos_sim=-1.0000
 - fc.bias : std=48.8702, var=2388.2971, dist=147.3619, cos_sim=-1.0000
🔍 Analyse du client 2
✅ Client 2 est conforme aux seuils définis.
⚠️ Client 2 présente des signes de Noise Attack :
 - fc.weight : std=48.8704, var=2388.3201, dist=147.3626, cos_sim=-1.0000
 - fc.bias : std=48.8704, var=2388.3201, dist=147.362

In [33]:
one_round_attack_ext(server_44, BA_Noise)

Client 0
✅ Local Training on client 0
Epoch : 0
(0.662, 1.351)
Client 1
✅ Local Training on client 1
Epoch : 0
(0.66, 1.226)
Client 2
✅ Local Training on client 2
Epoch : 0
(0.666, 1.171)
Client 3
✅ Local Training on client 3
Epoch : 0
(0.669, 1.235)
Client 4
⚠️ Malicious Client 4
🔍 Analyse du client 0
⚠️ Client 0 présente des anomalies :
 - fc.weight dépasse le seuil (ratio = 3153.5251)
✅ Client 0 ne présente pas de bruit excessif.
🔍 Analyse du client 1
⚠️ Client 1 présente des anomalies :
 - fc.weight dépasse le seuil (ratio = 13001.1006)
✅ Client 1 ne présente pas de bruit excessif.
🔍 Analyse du client 2
⚠️ Client 2 présente des anomalies :
 - fc.weight dépasse le seuil (ratio = 878.4620)
✅ Client 2 ne présente pas de bruit excessif.
🔍 Analyse du client 3
⚠️ Client 3 présente des anomalies :
 - fc.weight dépasse le seuil (ratio = 9969.5381)
✅ Client 3 ne présente pas de bruit excessif.
🔍 Analyse du client 4
⚠️ Client 4 présente des anomalies :
 - fc.weight dépasse le seuil (ratio = 

## Protect the Model's IP using Watermarking

In this section, we implement the model watermarking technique defined by Uchida et al. To simplify the implementation, we assume that the model is watermarked by a single client.

In [34]:
# Model Watermarking
server_44 = torch.hub.load('pytorch/vision', 'resnet18', weights=None)  # Load the model architecture
server_44.fc = nn.Linear(server_44.fc.in_features, 10)
server_44.load_state_dict(torch.load("/content/model_cifar10_fl.pth", map_location=torch.device('cpu')) ) # TODO (b) : Load the model weights of the server at round 44
#server_44.to("cuda")
print("Server model at round 44 loaded")

Using cache found in /root/.cache/torch/hub/pytorch_vision_main


Server model at round 44 loaded


Before watermarking the model, we examine the limitation of a common method used to evaluate whether two models are identical in open-source platforms. In the next cells, we compute the hash of the last layer of the server model.

In [35]:
server_44.fc.bias

Parameter containing:
tensor([-0.1202,  0.0519,  0.1844,  0.3506,  0.0848, -0.1656, -0.3124,  0.1880,
        -0.0720, -0.0018], requires_grad=True)

In [36]:
accuracy(server_44, test_loader)

(0.594, 1.167)

In [37]:
tensor_to_str = ''.join(str(x.item())+" " for x in server_44.fc.bias)
print(tensor_to_str)
print("Hash of the layer :", hash(tensor_to_str))

-0.12019599229097366 0.05193926766514778 0.18442277610301971 0.3505839407444 0.08478415012359619 -0.16556598246097565 -0.3123740553855896 0.18795417249202728 -0.07198965549468994 -0.0017646284541115165 
Hash of the layer : 5858509149233719191


Let’s add a small perturbation to the last layer of the server model and then compute the hash again.

In [38]:
with torch.no_grad():
    server_44.fc.bias.add_(1e-3)

In [39]:
accuracy(server_44, test_loader)

(0.594, 1.167)

In [40]:
tensor_to_str = ''.join(str(x.item())+" " for x in server_44.fc.bias)
print(tensor_to_str)
print("Hash of the layer :", hash(tensor_to_str))

-0.11919599026441574 0.052939265966415405 0.18542277812957764 0.35158392786979675 0.08578415215015411 -0.16456598043441772 -0.31137406826019287 0.1889541745185852 -0.07098965346813202 -0.0007646284066140652 
Hash of the layer : 8435477227401369874


As you can see, the hash is different while the accuracy remains the same. This means that an attacker can bypass this simple method of verifying whether two models are identical.

Now let’s implement the watermarking technique proposed by Uchida *et al.*  
This technique consists of embedding a secret message in a layer using the following methodology:

1. **Secret generation:** Generate a secret key $K$ and a message $b$.  
2. **Parameter selection:** Select the parameters `"fc.weight"` and compute the mean along the columns to obtain a vector $w$.  
3. **Projection:** Project the vector $w$, in which we want to embed $b$, using the secret key $K$ as follows:  
   $$
   y = K w
   $$  
4. **Extraction:** Apply the Sigmoid function to obtain the extracted message $b'$:  
   $$
   b' = \sigma(y)
   $$  
5. **Loss computation:** Compute the binary cross-entropy loss between the extracted message $b'$ and the original message $b$:  
   $$
   L = \text{BCELoss}(b', b)
   $$

In [54]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# Clé secrète et message à encoder
secret_key = torch.randn((256, 512))  # X ∈ ℝ^{T×M}
message = torch.randint(2, (256,)).float()  # b ∈ {0,1}^T

def train_f(model, train_set, test_set, epoch_max):
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    criterion = nn.CrossEntropyLoss()
    criterion_watermark = nn.BCELoss()
    alpha = 5e-1  # pondération de la perte watermark

    for epoch in range(epoch_max):
        accumulate_loss = 0

        for inputs, targets in train_set:
            optimizer.zero_grad(set_to_none=True)

            outputs_predicted = model(inputs)
            loss_main = criterion(outputs_predicted, targets)

            # (n) Extraction du watermark
            fc_weights = model.state_dict()['fc.weight']  # shape: [256, 512]
            w_mean = torch.mean(fc_weights, dim=0)  # w̄ ∈ ℝ^256

            y = torch.sigmoid(secret_key @ w_mean)  # y ∈ ℝ^256

            # (o) Calcul de la perte watermark
            loss_watermark = criterion_watermark(y, message)

            # Perte totale
            loss = loss_main + (alpha * loss_watermark)
            loss.backward()
            optimizer.step()

            accumulate_loss += loss.item()

        # (p) Vérification du watermark : Bit Error Rate
        with torch.no_grad():
            w_mean = torch.mean(model.state_dict()['fc.weight'], dim=0)
            y_extracted = torch.sigmoid(secret_key @ w_mean)
            b_extracted = (y_extracted >= 0.5).float()
            bit_error_rate = torch.mean(torch.abs(b_extracted - message)).item()

        print(f"Epoch : {epoch}")
        print("Bit Error Rate :", bit_error_rate)

In [57]:
# Génération du secret sur CPU
secret_key = torch.randn((256, 512))
message = torch.randint(2, (256,)).float()

def train_g(model, train_set, test_set, epoch_max):
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    criterion = nn.CrossEntropyLoss()
    criterion_watermark = nn.BCELoss()  # TODO (o) : Binary Cross Entropy pour le watermark
    alpha = 5e-1

    for epoch in range(epoch_max):
        accumulate_loss = 0

        for inputs, outputs in train_set:
            optimizer.zero_grad(set_to_none=True)

            outputs_predicted = model(inputs)
            loss_main = criterion(outputs_predicted, outputs)

            # TODO (n) : Extraction du message depuis le modèle
            w = model.fc.weight.mean(dim=0)  # Moyenne des colonnes de fc.weight
            y = torch.matmul(secret_key, w)  # Projection avec la clé secrète
            extracted_message = torch.sigmoid(y)  # Application sigmoid

            # TODO (o) : Calcul de la loss du watermark
            loss_watermark = criterion_watermark(extracted_message, message)

            loss = loss_main + (alpha * loss_watermark)
            loss.backward()
            optimizer.step()

            accumulate_loss += loss.item()

        print(f"Epoch : {epoch}")
        print(accuracy(model, test_set))

        # TODO (p) : Calcul du Bit Error Rate
        predicted_bits = (extracted_message > 0.5).float()
        bit_error_rate = (predicted_bits != message).float().mean().item()
        print("Bit Error Rate : ", bit_error_rate)

In [55]:
client = deepcopy(server_44)
train_f(client, train_loaders[4], test_loader, 1)

Epoch : 0
Bit Error Rate : 0.5078125


In [58]:
client = deepcopy(server_44)

train_g(client, train_loaders[4], test_loader, 1)

Epoch : 0




(0.664, 0.952)
Bit Error Rate :  0.3671875


In [63]:
# Étape 2 : Créer une copie pour y intégrer le watermark
model_watermarked = deepcopy(server_44)

# Étape 3 : Créer le trigger set
def create_trigger_set(trigger_label=7, num_samples=100):
    """
    Génère un ensemble d'images avec un motif constant et une étiquette fixe.
    Ce set est utilisé pour encoder un watermark dans le comportement du modèle.
    """
    trigger_inputs = torch.ones((num_samples, 3, 32, 32)) * 0.5  # motif constant
    trigger_targets = torch.full((num_samples,), trigger_label, dtype=torch.long)
    return trigger_inputs, trigger_targets

trigger_inputs, trigger_targets = create_trigger_set()

# Étape 4 : Entraîner le modèle avec le watermark
def train_with_blackbox_watermark(model, train_loader, test_loader, trigger_inputs, trigger_targets, epoch_max):
    """
    Entraîne le modèle à répondre au trigger set tout en conservant ses performances sur le test set.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epoch_max):
        model.train()
        for inputs, targets in train_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

        # Entraînement sur le trigger set
        optimizer.zero_grad()
        outputs_trigger = model(trigger_inputs)
        loss_trigger = criterion(outputs_trigger, trigger_targets)
        loss_trigger.backward()
        optimizer.step()

        print(f"Epoch {epoch} - Trigger Loss: {loss_trigger.item():.4f}")

    # Évaluation
    def evaluate(loader):
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, targets in loader:
                outputs = model(inputs)
                preds = torch.argmax(outputs, dim=1)
                correct += (preds == targets).sum().item()
                total += targets.size(0)
        return correct / total

    acc_test = evaluate(test_loader)
    acc_trigger = torch.mean((torch.argmax(model(trigger_inputs), dim=1) == trigger_targets).float()).item()

    print("✅ Accuracy on test set:", acc_test)
    print("🔐 Accuracy on trigger set:", acc_trigger)

# Étape 5 : Appel de la fonction avec des loaders CIFAR-10
train_with_blackbox_watermark(
    model_watermarked,
    train_loader=train_loaders[4],  # loader client
    test_loader=test_loader,        # loader global
    trigger_inputs=trigger_inputs,
    trigger_targets=trigger_targets,
    epoch_max=5
)

Epoch 0 - Trigger Loss: 2.4723
Epoch 1 - Trigger Loss: 2.4471
Epoch 2 - Trigger Loss: 2.4433
Epoch 3 - Trigger Loss: 2.4269
Epoch 4 - Trigger Loss: 2.4330
✅ Accuracy on test set: 0.6685
🔐 Accuracy on trigger set: 0.0
