__import des librairies nécessaires__

In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
from tqdm import tqdm
import copy
import argparse
import os
import sys
#from utils import *

__import des autres fichiers__

In [15]:
# Show the directory the notebook starts in, then move one level up so that
# the relative paths used below are resolved from the parent folder.
# NOTE(review): the move is relative, so re-running this cell climbs one more
# directory each time -- run it exactly once per kernel session.
starting_dir = os.getcwd()
print(starting_dir)
os.chdir(os.pardir)


/Users/Theo/Documents/Dépots Githubs/2023/chocoEA


In [16]:
# Confirm which directory is now the working directory.
working_dir = os.getcwd()
print(working_dir)

/Users/Theo/Documents/Dépots Githubs/2023


In [17]:
from nv_orga.Models import Net, Net_eNTK
from nv_orga.FedAvg import average_models,client_update
from nv_orga.Eval import evaluate_many_models
from nv_orga.NTK import client_compute_eNTK
from nv_orga.Scaffold import scaffold_update

__définition des hyperparamètres__

In [18]:
# Experiment hyperparameters, grouped by the stage that reads them.
args = dict(
    # --- shared ---
    num_client=5,
    seed=123,
    num_samples_per_client=500,
    # --- stage 1 ---
    rounds_stage1=100,  # baseline value: 100
    local_epochs_stage1=5,
    mini_batchsize_stage1=64,
    local_lr_stage1=0.1,
    # --- stage 2 ---
    rounds_stage2=200,  # baseline value: 100
    local_steps_stage2=100,
    local_lr_stage2=0.001,
    # --- stage 3 ---
    rounds_stage3=5,
)

__création d'un dossier de sauvegarde pour les modèles successifs du stage 1__

In [19]:
# Make sure the folder that receives the stage-1 checkpoints exists.
# exist_ok=True turns the call into a no-op when the directory is already
# there, which replaces the separate "check, then create" steps.
os.makedirs('data/ckpt_stage1', exist_ok=True)

## __Stage 1__

### __hyperparamètres__

In [20]:
# Unpack the stage-1 settings from `args` into plain variables.
num_clients, num_rounds_stage1 = args["num_client"], args["rounds_stage1"]

# Local training settings used by every client: epochs, mini-batch size, learning rate.
epochs_stage1, batch_size_stage1, lr_stage1 = (
    args["local_epochs_stage1"],
    args["mini_batchsize_stage1"],
    args["local_lr_stage1"],
)

### __création des datasets décentralisés (i.e. non-IID)__

In [21]:
# Load the MNIST training split (downloading it if necessary), convert the
# images to tensors and normalise them with mean 0.1307 / std 0.3081.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
traindata = datasets.MNIST('data/data_mnist', train=True, download=True,
                           transform=mnist_transform)

# Boolean class-membership mask with one row per digit (10 x num_samples):
# entry [i, j] is True when training sample j carries label i.
digit_classes = torch.arange(10).unsqueeze(1)
target_labels = traindata.targets.unsqueeze(0) == digit_classes

In [22]:
# Seed the RNGs first so that the per-client split is identical on every run.
torch.manual_seed(args["seed"])
torch.cuda.manual_seed(args["seed"])

samples_per_client = args["num_samples_per_client"]
target_labels_split = []
for client_idx in range(num_clients):
    first_label = 2 * client_idx
    # Indices of every training sample whose label is 2*client_idx or 2*client_idx + 1.
    label_pair_mask = target_labels[first_label:first_label + 2].sum(0)
    candidate_indices = torch.where(label_pair_mask)[0]
    # Shuffle the candidates and keep the first `samples_per_client` of them.
    shuffled_order = torch.randperm(candidate_indices.size(0))
    kept_indices = candidate_indices[shuffled_order[:samples_per_client]]
    target_labels_split.append(kept_indices)

# Each client ends up with num_samples_per_client images drawn from its own pair of labels.

In [23]:
# Per-client training data: one Subset (images + labels of that client) and
# one shuffling mini-batch DataLoader for each client.
traindata_split = []
train_loader = []
for client_indices in target_labels_split:
    client_subset = torch.utils.data.Subset(traindata, client_indices)
    traindata_split.append(client_subset)
    train_loader.append(
        torch.utils.data.DataLoader(client_subset, shuffle=True, batch_size=batch_size_stage1)
    )

### __création du dataset global de test__

In [24]:
# Shared evaluation set: a fixed random subsample of the MNIST test split,
# preprocessed with the same tensor conversion and normalisation as training.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
testdata = datasets.MNIST('data/data_mnist', train=False, transform=test_transform)

# Re-seed so the same test images are picked on every run.
for seed_fn in (torch.manual_seed, torch.cuda.manual_seed):
    seed_fn(args["seed"])

# Keep the first 1000 entries of a random permutation of the test indices.
perm_split_test = torch.randperm(testdata.targets.size(0))
testdata_subset = torch.utils.data.Subset(testdata, perm_split_test[:1000])
# No shuffling at evaluation time.
test_loader = torch.utils.data.DataLoader(testdata_subset, shuffle=False,
                                          batch_size=batch_size_stage1)

### __modèle de réseau de neurones de base__

In [25]:
# Federated (global) model and one local model per client.
# CPU version: the original code called Net().cuda().
global_model = Net()
client_models = [Net() for _ in range(num_clients)]

# Start every client from the global weights and give each client its own SGD optimiser.
initial_weights = global_model.state_dict()
opt = []
for model in client_models:
    model.load_state_dict(initial_weights)
    opt.append(optim.SGD(model.parameters(), lr=lr_stage1))

### __implémentation de FedAvg__

In [26]:
# TCT stage 1, i.e. plain FedAvg: broadcast the global weights, let every
# client train locally, average the client models, then evaluate.
# NOTE(review): no checkpoint is written by this loop, while the Stage 2 cell
# loads one from data/ckpt_stage1 -- confirm where that file comes from.
for r in range(num_rounds_stage1):
    # Broadcast: every client restarts the round from the current global weights.
    global_weights = global_model.state_dict()
    for model in client_models:
        model.load_state_dict(global_weights)

    # Local training; accumulate the value returned by each client's update.
    loss = sum(
        client_update(client_model, client_opt, client_loader, epoch=epochs_stage1)
        for client_model, client_opt, client_loader in zip(client_models, opt, train_loader)
    )

    # Aggregate the client models into the global model.
    average_models(global_model, client_models)

    # Evaluate the client models on the shared test loader.
    test_losses, accuracies = evaluate_many_models(client_models, test_loader)

    round_summary = (r, loss / num_clients, test_losses.mean(), accuracies.mean())
    print('%d-th round: average train loss %0.3g | average test loss %0.3g | average test acc: %0.3f' % round_summary)

0-th round: average train loss 0.18 | average test loss 6.58 | average test acc: 0.194
1-th round: average train loss 0.0922 | average test loss 7.6 | average test acc: 0.192
2-th round: average train loss 0.0428 | average test loss 7.51 | average test acc: 0.195
3-th round: average train loss 0.0274 | average test loss 7.25 | average test acc: 0.195
4-th round: average train loss 0.0339 | average test loss 7.56 | average test acc: 0.196
5-th round: average train loss 0.0219 | average test loss 6.34 | average test acc: 0.195


## __Stage 2__

### __hyperparamètres__

In [None]:
# Stage-2 settings: number of communication rounds, and a batch size equal to
# the number of samples held by one client.
num_rounds_stage2, batch_size = args["rounds_stage2"], args["num_samples_per_client"]

### __modèle eNTK__

In [None]:
# Init and load model ckpt
global_model = Net_eNTK() #supprimer .cuda()
global_model.load_state_dict(torch.load('data/ckpt_stage1/stage1_100rounds_5workers.pth'))
global_model.fc2 = nn.Linear(128, 1) #supprimer .cuda() #récupérer une unique sortie ici #supprimer le dernier layer pour le remplacer (passer de 128->10 à 128->1)
print('load model')

load model


### __Compute eNTK__

In [None]:
# Train
grad_all = []
target_all = []
target_onehot_all = []
for i in range(num_clients):
    grad_i, target_onehot_i, target_i = client_compute_eNTK(global_model, train_loader[i])
    grad_all.append(copy.deepcopy(grad_i).cpu())
    target_all.append(copy.deepcopy(target_i).cpu())
    target_onehot_all.append(copy.deepcopy(target_onehot_i).cpu())
    del grad_i
    del target_onehot_i
    del target_i
    torch.cuda.empty_cache()

 56%|█████▋    | 36/64 [00:00<00:00, 354.08it/s]

100%|██████████| 64/64 [00:00<00:00, 373.20it/s]
100%|██████████| 64/64 [00:00<00:00, 493.31it/s]
100%|██████████| 64/64 [00:00<00:00, 525.38it/s]
100%|██████████| 64/64 [00:00<00:00, 506.04it/s]


In [None]:
# Sanity check: dimensions (rows x features) of the first client's eNTK matrix.
grad_all[0].size()

torch.Size([64, 100000])

In [None]:
# Test
grad_eval, target_eval_onehot, target_eval  = client_compute_eNTK(global_model, test_loader)

100%|██████████| 64/64 [00:00<00:00, 493.51it/s]


### __run stage 2__

In [None]:
# Init linear models
theta_global = torch.zeros(100000, 10) #supprimer .cuda()
theta_global = torch.tensor(theta_global, requires_grad=False)
client_thetas = [torch.zeros_like(theta_global) for _ in range(num_clients)] #supprimer .cuda()
client_hi_s = [torch.zeros_like(theta_global) for _ in range(num_clients)] #supprimer .cuda()

  theta_global = torch.tensor(theta_global, requires_grad=False)


In [None]:
# Run TCT-Stage2
for round_idx in range(num_rounds_stage2):
    theta_list = []
    for i in range(num_clients):
        theta_hat_update, h_i_client_update = scaffold_update(grad_all[i],
                                                              target_all[i],
                                                              client_thetas[i],
                                                              client_hi_s[i],
                                                              theta_global,
                                                              M=args["local_steps_stage2"],
                                                              lr_local=args["local_lr_stage2"])
        client_hi_s[i] = h_i_client_update * 1.0
        client_thetas[i] = theta_hat_update * 1.0
        theta_list.append(theta_hat_update)

    # averaging
    theta_global = torch.zeros_like(theta_list[0]) #supprimer .cuda()
    for theta_idx in range(num_clients):
        theta_global += (1.0 / num_clients) * theta_list[theta_idx]

    # eval on train
    logits_class_train = torch.cat(grad_all) @ theta_global #supprimer .cuda()
    _, targets_pred_train = logits_class_train.max(1)
    train_acc = targets_pred_train.eq(torch.cat(target_all)).sum() / (1.0 * logits_class_train.shape[0]) #supprimer .cuda()
    # eval on test
    logits_class_test = grad_eval @ theta_global
    _, targets_pred_test = logits_class_test.max(1)
    test_acc = targets_pred_test.eq(target_eval).sum() / (1.0 * logits_class_test.shape[0]) #supprimer .cuda()
    print('Round %d: train accuracy=%0.5g test accuracy=%0.5g' % (round_idx, train_acc.item(), test_acc.item()))

Round 0: train accuracy=0.94141 test accuracy=0.73438
Round 1: train accuracy=0.96875 test accuracy=0.78125
Round 2: train accuracy=0.98047 test accuracy=0.78125
Round 3: train accuracy=0.98438 test accuracy=0.78125
Round 4: train accuracy=0.98828 test accuracy=0.76562
Round 5: train accuracy=0.99219 test accuracy=0.76562
Round 6: train accuracy=0.99219 test accuracy=0.76562
Round 7: train accuracy=0.99219 test accuracy=0.78125
Round 8: train accuracy=0.99609 test accuracy=0.78125
Round 9: train accuracy=0.99609 test accuracy=0.78125
Round 10: train accuracy=1 test accuracy=0.78125
Round 11: train accuracy=1 test accuracy=0.78125
Round 12: train accuracy=1 test accuracy=0.78125
Round 13: train accuracy=1 test accuracy=0.78125
Round 14: train accuracy=1 test accuracy=0.78125
Round 15: train accuracy=1 test accuracy=0.78125
Round 16: train accuracy=1 test accuracy=0.78125
Round 17: train accuracy=1 test accuracy=0.78125
Round 18: train accuracy=1 test accuracy=0.78125
Round 19: train acc