In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from typing import List, Dict
from models.loader import load_model
from client import Client
from cluster import ClusterDaddy
from datasets.dataloader import load_global_dataset, create_clustered_dataset
from aggregation.strategies import load_aggregator
import random
from copy import deepcopy

In [2]:
train_set, test_set = load_global_dataset('cifar10')
val_size = 5000
# split the training set into training and validation
train_dataset = Subset(train_set, range(len(train_set) - val_size))
val_dataset = Subset(train_set, range(len(train_set) - val_size, len(train_set)))

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_set, batch_size=64, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [40]:
model = load_model('cifarcnn')
model = model.cuda()


In [38]:

def permute_linear_layer_nodes(seq, layer_indices):
    
    permuted_seq = deepcopy(seq)
    layer1 = permuted_seq[layer_indices[0]]
    layer2 = permuted_seq[layer_indices[1]]

    # reorder the hidden neurons from layer1 to layer2
    layer1_weight = layer1.weight
    layer1_bias = layer1.bias
    layer2_weight = layer2.weight
    layer2_bias = layer2.bias

    perm =  torch.randperm(layer1_weight.size(0))
    layer1_weight = layer1_weight[perm]
    layer2_weight = layer2_weight[:,perm]

    permuted_seq[layer_indices[0]].weight.data = layer1_weight
    permuted_seq[layer_indices[0]].bias.data = layer1_bias
    permuted_seq[layer_indices[1]].weight.data = layer2_weight
    permuted_seq[layer_indices[1]].bias.data = layer2_bias

    return permuted_seq



In [41]:
client = Client(id=0, device=torch.device('cuda:0'), cluster_assignment=0)
# train the model on the client
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
model = client.train(deepcopy(model), train_loader, criterion, optimizer, num_epochs=5)
acc, loss = client.evaluate(model, val_loader, criterion)


# Example: Permuting only the Linear layer at index 4 (the last layer in the model)
layer_indicies = [7,9]

# Permute the parameters in the model
permuted_seq = permute_linear_layer_nodes(model.nn, layer_indicies)
permuted_model = deepcopy(model)
permuted_model.nn = permuted_seq

perm_acc, perm_loss = client.evaluate(permuted_model, val_loader, criterion)
print(f'Original model accuracy: {acc}, loss: {loss}')
print(f'Permuted model accuracy: {perm_acc}, loss: {perm_loss}')

assert acc == perm_acc and loss == perm_loss, "Permutation failed"


Client 0 initialized on device:  cuda:0
Client 0 epoch 0 loss: 2.3500362668525088
Client 0 epoch 1 loss: 2.3500593216581778
Client 0 epoch 2 loss: 2.349995043806054
Client 0 epoch 3 loss: 2.3498754680834035
Client 0 epoch 4 loss: 2.349979944865812
Original model accuracy: 2.3482912763764587, loss: 0.101
Permuted model accuracy: 2.3482912763764587, loss: 0.101


In [44]:
orig_layer = model.nn[7]
permuted_layer = permuted_model.nn[7]

# pointwise comparison of the weights
dist = 0
for i in range(orig_layer.weight.shape[0]):
    dist += torch.abs(orig_layer.weight[i] - permuted_layer.weight[i]).sum()
dist

tensor(5115.5806, device='cuda:0', grad_fn=<AddBackward0>)

In [34]:
orig_layer.weight.data, permuted_layer.weight.data

(tensor([[ 0.0388, -0.0616, -0.0528,  ..., -0.0023,  0.0210,  0.0355],
         [ 0.0003, -0.0494,  0.0330,  ...,  0.0345,  0.0391, -0.0115],
         [ 0.0455, -0.0112, -0.0368,  ..., -0.0242,  0.0681, -0.0338],
         ...,
         [ 0.0117,  0.0555, -0.0672,  ...,  0.0337,  0.0095,  0.0407],
         [-0.0585,  0.0345,  0.0570,  ...,  0.0032,  0.0090, -0.0789],
         [ 0.0774,  0.0579,  0.0732,  ..., -0.0116,  0.0414, -0.0384]],
        device='cuda:0'),
 tensor([[ 0.0388, -0.0616, -0.0528,  ..., -0.0023,  0.0210,  0.0355],
         [ 0.0003, -0.0494,  0.0330,  ...,  0.0345,  0.0391, -0.0115],
         [ 0.0455, -0.0112, -0.0368,  ..., -0.0242,  0.0681, -0.0338],
         ...,
         [ 0.0117,  0.0555, -0.0672,  ...,  0.0337,  0.0095,  0.0407],
         [-0.0585,  0.0345,  0.0570,  ...,  0.0032,  0.0090, -0.0789],
         [ 0.0774,  0.0579,  0.0732,  ..., -0.0116,  0.0414, -0.0384]],
        device='cuda:0'))