__import des librairies nécessaires__

In [109]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import numpy as np
from tqdm import tqdm
import copy
import argparse
import os
#from utils import *

__définition des hyperparamètres__

In [110]:
args = {
    'num_client': 5,
    'seed': 123,
    'num_samples_per_client': 500,
    'rounds_stage1': 30, #100 de base
    'local_epochs_stage1': 5,
    'mini_batchsize_stage1': 64,
    'local_lr_stage1': 0.1,
    'rounds_stage2': 30, #100 de base
    'local_steps_stage2': 100,
    'local_lr_stage2': 0.001
}

__création d'un dossier de sauvegarde pour les modèles successifs du stage 1__

In [111]:
os.chdir('C:\\Users\\dmgtr\\OneDrive - Ecole Polytechnique\\3A\\P1\\MAP578 - EA collaborative learning\\Project\\TCT')

isExist = os.path.exists('./ckpt_stage1')
if not isExist:
   os.makedirs('./ckpt_stage1')

## __Stage 1__

### __hyperparamètres__

In [112]:
num_clients = args["num_client"]
num_rounds_stage1 = args["rounds_stage1"]
epochs_stage1 = args["local_epochs_stage1"]
batch_size_stage1 = args["mini_batchsize_stage1"]
lr_stage1 = args["local_lr_stage1"]

### __création des datasets décentralisés (ie non idd)__

In [113]:
## load les data MNIST, les transformer en tensor et les normaliser
traindata = datasets.MNIST('./data_mnist', train=True, download=True,
                           transform=transforms.Compose([transforms.ToTensor(),
                                                         transforms.Normalize((0.1307,), (0.3081,))]))


target_labels = torch.stack([traindata.targets == i for i in range(10)]) # 10 x 60000 (one-hot qui détermine la label correpondant à la ligne)

In [114]:
target_labels_split = []
torch.manual_seed(args["seed"]) # pour que les splits soient les mêmes à chaque fois
torch.cuda.manual_seed(args["seed"])  # pour que les splits soient les mêmes à chaque fois

for i in range(num_clients):
    index_split = torch.where(target_labels[(2 * i):(2 * (i + 1))].sum(0))[0] # on prend les labels 2i et 2i+1
    perm_split = torch.randperm(index_split.size(0)) # on mélange les indices
    index_split_subsample = index_split[perm_split[:args["num_samples_per_client"]]] # on prend les 500 premiers
    target_labels_split += [index_split_subsample] # on ajoute à la liste des labels splités

#Chacun des 5 clients reçoit 500 images d'un des deux labels associés

In [115]:
# Training datasets (subsampled)
traindata_split = [torch.utils.data.Subset(traindata, tl) for tl in target_labels_split] # chaque élément contient les images et labels d'un client
train_loader = [torch.utils.data.DataLoader(train_subset, batch_size=batch_size_stage1, shuffle=True)
                for train_subset in traindata_split] # on crée les dataloader associés

### __création du dataset global de test__

In [116]:
# Test dataset (subsampled)
testdata = datasets.MNIST('./data_mnist', train=False,
                          transform=transforms.Compose([transforms.ToTensor(),
                                                        transforms.Normalize((0.1307,), (0.3081,))])) # on charge les données de test

torch.manual_seed(args["seed"])
torch.cuda.manual_seed(args["seed"])
perm_split_test = torch.randperm(testdata.targets.shape[0])
testdata_subset = torch.utils.data.Subset(testdata, perm_split_test[:1000])
test_loader = torch.utils.data.DataLoader(testdata_subset, batch_size=batch_size_stage1, shuffle=False) #pas de shuffle pour le test

### __modèle de réseau de neurones de base__

In [117]:
## ajouté depuis utils

import torch.optim as optim

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.fc1 = nn.Linear(2048, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.conv3(x)
        x = F.relu(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

In [118]:
global_model = Net() #modifié depuis Net().cuda() #Modèle fédéré
client_models = [Net() for _ in range(num_clients)] #modifié depuis Net().cuda() #Modèles des clients
for model in client_models:
    model.load_state_dict(global_model.state_dict())
opt = [optim.SGD(model.parameters(), lr=lr_stage1) for model in client_models]

### __FedAvg__

In [119]:
def client_update(client_model, optimizer, train_loader, epoch=5): #réalise un pas d'optimization et retourne la loss
    """Train a client_model on the train_loder data."""
    client_model.train()
    for e in range(epoch):
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data, target #supprimer data.cuda(), target.cuda()
            optimizer.zero_grad()
            output = client_model(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            optimizer.step()
    return loss.item()

In [120]:
def average_models(global_model, client_models):
    """Average models across all clients."""
    global_dict = global_model.state_dict()
    for k in global_dict.keys():
        global_dict[k] = torch.stack([client_models[i].state_dict()[k] for i in range(len(client_models))], 0).mean(0) #stack les résultas et fait la moyenne par client
    global_model.load_state_dict(global_dict)

In [121]:
def evaluate_model(model, data_loader):
    """Compute loss and accuracy of a single model on a data_loader."""
    model.eval()
    loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in data_loader:
            data, target = data, target #supprimer .cuda()
            output = model(data)
            loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    loss /= len(data_loader.dataset)
    acc = correct / len(data_loader.dataset)

    return loss, acc

def evaluate_many_models(models, data_loader):
    """Compute average loss and accuracy of multiple models on a data_loader."""
    num_nodes = len(models)
    losses = np.zeros(num_nodes)
    accuracies = np.zeros(num_nodes)
    for i in range(num_nodes):
        losses[i], accuracies[i] = evaluate_model(models[i], data_loader)
    return losses, accuracies

In [122]:
# Run TCT-Stage1 (i.e., FedAvg)
for r in range(num_rounds_stage1):
    # load global weights
    for model in client_models:
        model.load_state_dict(global_model.state_dict())

    # client update
    loss = 0
    for i in range(num_clients):
        loss += client_update(client_models[i], opt[i], train_loader[i], epoch=epochs_stage1)

    # average params across neighbors
    average_models(global_model, client_models)

    # evaluate
    test_losses, accuracies = evaluate_many_models(client_models, test_loader)
    torch.save(client_models[0].state_dict(), './ckpt_stage1/model_tct_stage1.pth')

    print('%d-th round: average train loss %0.3g | average test loss %0.3g | average test acc: %0.3f' % (
    r, loss / num_clients, test_losses.mean(), accuracies.mean()))

0-th round: average train loss 0.168 | average test loss 6.74 | average test acc: 0.194
1-th round: average train loss 0.0942 | average test loss 7.74 | average test acc: 0.191
2-th round: average train loss 0.0442 | average test loss 7.57 | average test acc: 0.195
3-th round: average train loss 0.0277 | average test loss 7.24 | average test acc: 0.195
4-th round: average train loss 0.0269 | average test loss 7.62 | average test acc: 0.196
5-th round: average train loss 0.0244 | average test loss 6.69 | average test acc: 0.193
6-th round: average train loss 0.0285 | average test loss 6.71 | average test acc: 0.196
7-th round: average train loss 0.0298 | average test loss 6.09 | average test acc: 0.197
8-th round: average train loss 0.0139 | average test loss 6.04 | average test acc: 0.205
9-th round: average train loss 0.0122 | average test loss 5.79 | average test acc: 0.204
10-th round: average train loss 0.0116 | average test loss 5.37 | average test acc: 0.222
11-th round: average 

## __Stage 2__

### __hyperparamètres__

In [123]:
num_rounds_stage2 = args["rounds_stage2"]
batch_size = args["num_samples_per_client"]

### __modèle eNTK__

In [124]:
class Net_eNTK(nn.Module):
    def __init__(self):
        super(Net_eNTK, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.fc1 = nn.Linear(2048, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.conv3(x)
        x = F.relu(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return x

In [125]:
# Init and load model ckpt
global_model = Net_eNTK() #supprimer .cuda()
global_model.load_state_dict(torch.load('./ckpt_stage1/model_tct_stage1.pth'))
global_model.fc2 = nn.Linear(128, 1) #supprimer .cuda() #récupérer une unique sortie ici #supprimer le dernier layer pour le remplacer (passer de 128->10 à 128->1)
print('load model')

load model


### __Compute eNTK__

In [126]:
def compute_eNTK(model, X, subsample_size=100000, seed=123):
    """"compute eNTK"""
    model.eval()
    params = list(model.parameters()) #liste de tous les paramètres trainable du modèle
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    random_index = torch.randperm(355073)[:subsample_size]
    grads = None
    for i in tqdm(range(X.size()[0])):
        model.zero_grad() #réinitialise les gradients
        model.forward(X[i : i+1])[0].backward() #calcule les gradients  

        grad = []
        for param in params: #param.requires_grad dans PyTorch est un attribut des tenseurs (y compris ceux qui représentent les paramètres d'un modèle, comme les poids et les biais d'un réseau de neurones) qui détermine si PyTorch doit calculer les gradients pour ces tenseurs pendant la rétropropagation.
            if param.requires_grad:
                grad.append(param.grad.flatten())
        grad = torch.cat(grad) #Concatène tous les gradients aplatis en un seul vecteur
        grad = grad[random_index] #Réduit la dimensionnalité du vecteur de gradient en utilisant l'indexation aléatoire définie par random_index

        if grads is None:
            grads = torch.zeros((X.size()[0], grad.size()[0]), dtype=torch.half) #Si grads est None, une matrice zéro de la forme appropriée est initialisée.
        grads[i, :] = grad #Stocke le vecteur de gradient subsamplé pour l'échantillon i dans la matrice grads

    return grads

def client_compute_eNTK(client_model, train_loader):
    """Train a client_model on the train_loder data."""
    client_model.train()

    data, targets = next(iter(train_loader))
    grads_data = compute_eNTK(client_model, data) #supprimer .cuda()
    grads_data = grads_data.float() #supprimer .cuda()
    targets = targets #supprimer .cuda()
    # gradient
    targets_onehot = F.one_hot(targets, num_classes=10) - (1.0 / 10.0) #supprimer .cuda()
    del data
    torch.cuda.empty_cache()
    return grads_data, targets_onehot, targets

In [127]:
# Train
grad_all = []
target_all = []
target_onehot_all = []
for i in range(num_clients):
    grad_i, target_onehot_i, target_i = client_compute_eNTK(global_model, train_loader[i])
    grad_all.append(copy.deepcopy(grad_i).cpu())
    target_all.append(copy.deepcopy(target_i).cpu())
    target_onehot_all.append(copy.deepcopy(target_onehot_i).cpu())
    del grad_i
    del target_onehot_i
    del target_i
    torch.cuda.empty_cache()

100%|██████████| 64/64 [00:00<00:00, 265.21it/s]
 44%|████▍     | 28/64 [00:00<00:00, 277.41it/s]

100%|██████████| 64/64 [00:00<00:00, 210.86it/s]
100%|██████████| 64/64 [00:00<00:00, 379.58it/s]
100%|██████████| 64/64 [00:00<00:00, 363.16it/s]
100%|██████████| 64/64 [00:00<00:00, 310.73it/s]


In [128]:
grad_all[0].shape

torch.Size([64, 100000])

In [129]:
# Test
grad_eval, target_eval_onehot, target_eval  = client_compute_eNTK(global_model, test_loader)

  0%|          | 0/64 [00:00<?, ?it/s]

100%|██████████| 64/64 [00:00<00:00, 120.84it/s]


### __run stage 2__

In [130]:
# Init linear models
theta_global = torch.zeros(100000, 10) #supprimer .cuda()
theta_global = torch.tensor(theta_global, requires_grad=False)
client_thetas = [torch.zeros_like(theta_global) for _ in range(num_clients)] #supprimer .cuda()
client_hi_s = [torch.zeros_like(theta_global) for _ in range(num_clients)] #supprimer .cuda()

  theta_global = torch.tensor(theta_global, requires_grad=False)


In [131]:
def scaffold_update(grads_data, targets, theta_client, h_i_client_pre, theta_global,
                    M=200, lr_local=0.00001):
    # set up data / eNTK
    grads_data = grads_data.float() #supprimer .cuda()
    targets = targets #supprimer .cuda()

    # compute transformed onehot label
    targets_onehot = F.one_hot(targets, num_classes=10) - (1.0 / 10.0) #supprimer .cuda()
    num_samples = targets_onehot.shape[0]

    # compute updates
    h_i_client_update = h_i_client_pre + (1 / (M * lr_local)) * (theta_global - theta_client)
    theta_hat_local = (theta_global) * 1.0

    # local gd
    for local_iter in range(M):
        theta_hat_local -= lr_local * ((1.0 / num_samples) * grads_data.t() @ (grads_data @ theta_hat_local - targets_onehot) - h_i_client_update)

    del targets
    del grads_data
    torch.cuda.empty_cache()
    return theta_hat_local, h_i_client_update

In [132]:
# Run TCT-Stage2
for round_idx in range(num_rounds_stage2):
    theta_list = []
    for i in range(num_clients):
        theta_hat_update, h_i_client_update = scaffold_update(grad_all[i],
                                                              target_all[i],
                                                              client_thetas[i],
                                                              client_hi_s[i],
                                                              theta_global,
                                                              M=args["local_steps_stage2"],
                                                              lr_local=args["local_lr_stage2"])
        client_hi_s[i] = h_i_client_update * 1.0
        client_thetas[i] = theta_hat_update * 1.0
        theta_list.append(theta_hat_update)

    # averaging
    theta_global = torch.zeros_like(theta_list[0]) #supprimer .cuda()
    for theta_idx in range(num_clients):
        theta_global += (1.0 / num_clients) * theta_list[theta_idx]

    # eval on train
    logits_class_train = torch.cat(grad_all) @ theta_global #supprimer .cuda()
    _, targets_pred_train = logits_class_train.max(1)
    train_acc = targets_pred_train.eq(torch.cat(target_all)).sum() / (1.0 * logits_class_train.shape[0]) #supprimer .cuda()
    # eval on test
    logits_class_test = grad_eval @ theta_global
    _, targets_pred_test = logits_class_test.max(1)
    test_acc = targets_pred_test.eq(target_eval).sum() / (1.0 * logits_class_test.shape[0]) #supprimer .cuda()
    print('Round %d: train accuracy=%0.5g test accuracy=%0.5g' % (round_idx, train_acc.item(), test_acc.item()))

Round 0: train accuracy=0.79062 test accuracy=0.78125
Round 1: train accuracy=0.83125 test accuracy=0.85938
Round 2: train accuracy=0.86563 test accuracy=0.875
Round 3: train accuracy=0.88125 test accuracy=0.875
Round 4: train accuracy=0.90312 test accuracy=0.875
Round 5: train accuracy=0.92188 test accuracy=0.85938
Round 6: train accuracy=0.93437 test accuracy=0.85938
Round 7: train accuracy=0.94375 test accuracy=0.875
Round 8: train accuracy=0.95 test accuracy=0.875
Round 9: train accuracy=0.95938 test accuracy=0.875
Round 10: train accuracy=0.96562 test accuracy=0.875
Round 11: train accuracy=0.96875 test accuracy=0.875
Round 12: train accuracy=0.96875 test accuracy=0.89062
Round 13: train accuracy=0.96875 test accuracy=0.90625
Round 14: train accuracy=0.97188 test accuracy=0.90625
Round 15: train accuracy=0.975 test accuracy=0.92188
Round 16: train accuracy=0.975 test accuracy=0.92188
Round 17: train accuracy=0.97812 test accuracy=0.92188
Round 18: train accuracy=0.97812 test accur