In [109]:
import torch.nn as nn
import torch
import torch.optim as optim
import matplotlib.pyplot as plt
cos = nn.CosineSimilarity(dim=0, eps=1e-9)
from torchvision import transforms
# from codes.models import ConvNet, resnet8, MLP
from utils import *
from torch.utils.data import DataLoader
from models import *
import random
from collections import Counter
from collections import OrderedDict
import seaborn as sns
device = "cuda"

# adjustable parameters
alpha_d = 10
local_ep = 5
mali_local_ep = 20
points = 41
global attack 
attack = "backdoor" #"backdoor", "tlp"
model_name = "resnet8" # "resnet8", "ConvNet"
num_classes = 10
dataset ="fmnist"

In [110]:

def get_delta_cos(model1, model2, model0_sd):
    flat_model0 = flat_dict(model0_sd)
    flat_model1 = flat_dict(model1.state_dict())
    flat_model2 = flat_dict(model2.state_dict())
    
    delta = torch.abs(flat_model1 - flat_model2)
    org_cos = cos((flat_model1 - flat_model0), (flat_model2 - flat_model0))
    return delta, 1-org_cos.item()

def model_eval(model, test_loader, attack):
    acc = eval_op_ensemble([model], test_loader)
    if attack == "tlp":
        asr = eval_op_ensemble_tr_lf_attack([model], test_loader)
    elif attack == "backdoor":
        asr = eval_op_ensemble_attack([model], test_loader)
    return list(acc.values())[0], list(asr.values())[0]


def sample_and_replace(model1, model2, model3, k_percent):
    """
    Randomly selects k% of parameters within each layer from model2 and replaces them in model1.
    Returns a new model (model3) with mixed parameters and a dictionary of replaced indices.
    """
    model_1_sd = {key: value.clone() for key, value in model1.state_dict().items()}
    model3.load_state_dict(model_1_sd)  # Start with model1's parameters
    replaced_params = {}
    
    state_dict = model3.state_dict()
    for param_name in model1.state_dict().keys():
        param1 = model1.state_dict()[param_name].clone()
        param2 = model2.state_dict()[param_name].clone()
        num_elements = param1.numel()
        num_replace = int((k_percent / 100) * num_elements)
        
        if num_replace > 0:
            indices = random.sample(range(num_elements), num_replace)
            param1.view(-1)[indices] = param2.view(-1)[indices]
            param1_flat = param1.view(-1)
            param2_flat = param2.view(-1)
            param1_flat[indices] = param2_flat[indices]
            state_dict[param_name] = param1.view(param1.shape)
            replaced_params[param_name] = indices
    
    model3.load_state_dict(state_dict)
    return model3, replaced_params




def sampling_experiment(model1, model2, model3, model0_sd, main_dataloader, device, k_percent, p):
    """
    Runs p sampling experiments, evaluates model3 on both main and side tasks.
    """
    model_1_sd = {key: value.clone() for key, value in model1.state_dict().items()}
    
    results = {}
    replaced_params_list=[]
    for i in range(p):    
        model3.load_state_dict(model_1_sd)
        model3, replaced_params = sample_and_replace(model1, model2, model3, k_percent)
        acc_, asr_  = model_eval(model3, main_dataloader, attack)
        delta, cos_dist = get_delta_cos(model1, model3, model0_sd)
        results[i] = (acc_, asr_, cos_dist)
        replaced_params_list.append({asr_:  replaced_params})
    return results, replaced_params_list, model3


def replaced_params_count(replaced_params_list):
    """
    Counts occurrences of values per layer across multiple replaced_params dictionaries.

    :param replaced_params_list: List of replaced_params dictionaries.
    :return: Dictionary where keys are layer names and values are Counters of occurrences.
    """
    layer_counts = {}

    for replaced_params in replaced_params_list:
        for layer, values in replaced_params.items():
            if layer not in layer_counts:
                layer_counts[layer] = Counter()
            layer_counts[layer].update(values)

    return layer_counts


def convert_to_state_dict(layer_counts, state_dict):
    ind_w = OrderedDict()
    
    # Iterate over the outer dictionary (float keys)
    for weight, layers in layer_counts.items():
        for layer_name, indices in layers.items():
            print("layer_name", layer_name, indices)
            # If the layer is not in state_dict, initialize it with an empty list of zeros
            if layer_name not in ind_w:
                ind_w[layer_name] = {}
            
            for index in range(state_dict[layer_name].numel()):  
                if index in indices:
                    # Add the weight to the corresponding index in the layer
                    if index not in ind_w[layer_name]:
                        ind_w[layer_name][index] = 0
                    ind_w[layer_name][index] += weight
                else:
                    # other not selected indices
                    ind_w[layer_name][index] = 0
            
    
    # print("ind_w keys", ind_w.keys())
    # print("state_dict keys", state_dict.keys())
    
    for name, asr_dict in ind_w.items():
        asr_w = torch.tensor([value for key, value in sorted(asr_dict.items())])
        print(f"name:{name}, asr_w:{asr_w.numel()}, state_dict:{state_dict[name].numel()}")
        state_dict[name] = asr_w.view(state_dict[name].shape)
    
    return state_dict


def plot_layer_weights(layer_name, params, title, save_plot=False):
    """
    Plots the weights of a specific layer in a PyTorch model as a heatmap.
    
    Parameters:
        layer (torch.nn.Module): The PyTorch layer (e.g., torch.nn.Linear, torch.nn.Conv2d).
        layer_name (str): A text string to label the plot and use in the filename.
        save_plot (bool): If True, saves the plot as a PNG file with the layer_name as part of the filename.
    """
    # Check if the layer has weights
    # if not hasattr(layer, 'weight'):
    #     raise ValueError(f"The provided layer does not have weights.")
    
    # Extract the weights
    params = params.reshape(params.size(0), -1)
    
    weights = params.cpu().numpy()
    
    # Plot the weights as a heatmap
    plt.figure(figsize=(10, 6))
    sns.heatmap(weights, cmap="coolwarm", annot=False, cbar=True)
    plt.title(f'Weight Matrix of {layer_name} _ {title}')
    plt.xlabel('Neuron Index')
    plt.ylabel('Input Index')
    
    if save_plot:
        # Save the plot with the layer_name as part of the filename
        filename = f"{layer_name}_{title}_weights_heatmap.png"
        plt.savefig(filename, bbox_inches='tight', dpi=300)
        print(f"Plot saved as {filename}")
    
    plt.show()

In [111]:


# Define transformation (convert images to tensors and normalize)
transform = transforms.Compose([
    transforms.ToTensor(),  # Convert image to tensor
    transforms.Normalize((0.5,), (0.5,))  # Normalize the image with mean and std
])

# Load the training dataset
train_data = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)

# Load the test dataset
test_data = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)

# Create DataLoader for batch processing
client_loaders, test_loader, client_data_subsets =\
    data.get_loaders(train_data, test_data, n_clients=100,
                    alpha=alpha_d, batch_size=32, n_data=None, num_workers=4, seed=4)
    
model_fn = partial(models.get_model(model_name)[
                        0], num_classes=num_classes, dataset=dataset)

client_loader = client_loaders[0]

# created models
model1 = model_fn().to(device)
model2 = model_fn().to(device)
model3 = model_fn().to(device)

model0_sd = {k: v.clone().detach() for k, v in model1.state_dict().items()}

optimizer1 = optim.SGD(model1.parameters(), lr=0.001)
optimizer2 = optim.SGD(model2.parameters(), lr=0.001)
optimizer3 = optim.SGD(model3.parameters(), lr=0.001)


Data split:
 - Client 0: [57 58 69 45 68 75 60 68 56 39]                         -> sum=595
 - Client 1: [61 71 49 54 64 68 58 52 59 64]                         -> sum=600
 - Client 2: [43 51 78 54 53 85 81 54 47 55]                         -> sum=601
 - Client 3: [76 54 72 52 69 71 33 41 55 76]                         -> sum=599
 - Client 4: [52 89 66 45 43 64 65 36 56 83]                         -> sum=599
 - Client 5: [ 33  28  59  88 114  69  34  62  43  71]               -> sum=601
 - Client 6: [61 72 27 86 50 74 72 66 40 52]                         -> sum=600
 - Client 7: [40 54 59 36 68 52 55 89 97 50]                         -> sum=600
 - Client 8: [72 53 25 96 71 60 67 55 55 47]                         -> sum=601
 - Client 9: [52 69 51 45 79 60 50 78 47 69]                         -> sum=600
.  .  .  .  .  .  .  .  .  .  
.  .  .  .  .  .  .  .  .  .  
.  .  .  .  .  .  .  .  .  .  
 - Client 91: [39 66 82 85 72 58 45 50 41 63]                         -> sum=601
 - Client 92: 

In [112]:
# model1 train benign
train_op(model1, client_loader, optimizer1, epochs=local_ep, print_train_loss=True)

model_1_sd = {key: value.clone() for key, value in model1.state_dict().items()}

# model2 train from model1
model2.load_state_dict(model_1_sd)
if attack == "tlp":
    train_op_tr_flip(model2, client_loader, optimizer2, epochs=mali_local_ep, class_num=10, print_train_loss=True)
elif attack == "backdoor":
    train_op_backdoor(model2, client_loader, optimizer2, epochs=local_ep)

acc1, asr1 = model_eval(model1, test_loader, attack)
acc2, asr2 = model_eval(model2, test_loader, attack)

# delta, org_cos
delta0, org_cos2 = get_delta_cos(model1, model2, model0_sd)
print(f"model1 acc:{acc1}, asr:{asr1}, cos dist:{0}")
print(f"model2 acc:{acc2}, asr:{asr2}, cos dist:{org_cos2}")

# random average the attack results and main task performance
random_exp={}
for k in np.arange(5, 100, 5, dtype=int):
    print("current k", k)
    results, replaced_params_list, model3 = sampling_experiment(model1, model2, model3, model0_sd, client_loader, device, k_percent=k, p=20)
    
    values = np.array(list(results.values()))
    means = values.mean(axis=0)
    stds = values.std(axis=0)
    random_exp[k] = list(zip(means, stds))
    
print("results:\n", random_exp)

[2.3, 2.3, 2.3, 2.3, 2.3, 2.3, 2.3, 2.29, 2.29, 2.29, 2.29, 2.29, 2.29, 2.29, 2.29, 2.29, 2.29, 2.29, 2.29, 2.29, 2.29, 2.28, 2.28, 2.28, 2.28, 2.28, 2.28, 2.28, 2.28, 2.28, 2.28, 2.28, 2.28, 2.28, 2.28, 2.28, 2.28, 2.27, 2.27, 2.27, 2.27, 2.27, 2.27, 2.27, 2.27, 2.27, 2.27, 2.27, 2.27, 2.27]
model1 acc:0.1091, asr:0.0, cos dist:0
model2 acc:0.1, asr:1.0, cos dist:0.9998775282292627
current k 5
current k 10
current k 15
current k 20
current k 25
current k 30
current k 35
current k 40
current k 45
current k 50
current k 55
current k 60
current k 65
current k 70
current k 75
current k 80
current k 85
current k 90
current k 95
results:
 {5: [(0.12478991596638656, 0.014100379596042413), (0.0, 0.0), (0.9924283019034192, 0.00020923111986223916)], 10: [(0.1372268907563025, 0.033794783964460784), (0.01654275092936803, 0.0721081795492788), (0.9944799275370315, 0.00015803352583459234)], 15: [(0.13159663865546217, 0.026134859085187055), (0.05204460966542751, 0.13814048362931422), (0.9954169044271

In [113]:
random_exp

{5: [(0.12478991596638656, 0.014100379596042413),
  (0.0, 0.0),
  (0.9924283019034192, 0.00020923111986223916)],
 10: [(0.1372268907563025, 0.033794783964460784),
  (0.01654275092936803, 0.0721081795492788),
  (0.9944799275370315, 0.00015803352583459234)],
 15: [(0.13159663865546217, 0.026134859085187055),
  (0.05204460966542751, 0.13814048362931422),
  (0.9954169044271112, 9.354627979076157e-05)],
 20: [(0.12478991596638656, 0.022314840846962836),
  (0.032620817843866175, 0.09517286773537231),
  (0.9959875620203092, 6.519513378763989e-05)],
 25: [(0.12193277310924369, 0.022948870566197925),
  (0.20678438661710033, 0.2967114961677403),
  (0.996340685558971, 7.536772941767523e-05)],
 30: [(0.12571428571428572, 0.017177424367934433),
  (0.18633828996282525, 0.29511313154384183),
  (0.9965738208149559, 4.239466953040799e-05)],
 35: [(0.11899159663865547, 0.026893907378402672),
  (0.44507434944237917, 0.39638106482226554),
  (0.9967893622000702, 3.42743421535208e-05)],
 40: [(0.10815126050

In [114]:
# Covert the award for experiments to a model state dict to save 
# results, replaced_params_list, model3 = sampling_experiment(model1, model2, model3, model0_sd, client_loader, device, k_percent=50, p=20)

# # Take long time
# layer_counts = replaced_params_count(replaced_params_list)
# print("layer_counts", layer_counts)

# state_dict = convert_to_state_dict(layer_counts, model3.state_dict())
# model3.load_state_dict(state_dict)
# torch.save(model3.state_dict(), 'model3_sd.pth')


# for name, params in state_dict.items():
#     plot_layer_weights(name, state_dict[name], title="k50", save_plot=True)

In [115]:
results

{0: (0.0957983193277311, 1.0, 0.9975994608830661),
 1: (0.0957983193277311, 1.0, 0.9975759075023234),
 2: (0.0957983193277311, 1.0, 0.9975880386773497),
 3: (0.0957983193277311, 1.0, 0.9975948047358543),
 4: (0.0957983193277311, 1.0, 0.9975846805609763),
 5: (0.0957983193277311, 1.0, 0.9976033759303391),
 6: (0.0957983193277311, 1.0, 0.9975934950634837),
 7: (0.0957983193277311, 1.0, 0.9975917611736804),
 8: (0.0957983193277311, 1.0, 0.997599811758846),
 9: (0.0957983193277311, 1.0, 0.9976017780136317),
 10: (0.0957983193277311, 1.0, 0.9975874363444746),
 11: (0.0957983193277311, 1.0, 0.9975867567118257),
 12: (0.0957983193277311, 1.0, 0.9975916817784309),
 13: (0.0957983193277311, 1.0, 0.9975778355728835),
 14: (0.0957983193277311, 1.0, 0.9975984443444759),
 15: (0.0957983193277311, 1.0, 0.997605849057436),
 16: (0.0957983193277311, 1.0, 0.9975960054434836),
 17: (0.0957983193277311, 1.0, 0.9975901450961828),
 18: (0.0957983193277311, 1.0, 0.9975862675346434),
 19: (0.0957983193277311

## Replace based on the critial parameters from analysis, if can save more

In [116]:
delta_model = model_fn().to(device)
delta_model.load_state_dict(torch.load("model3_60_sd.pth", weights_only=True))

RuntimeError: Error(s) in loading state_dict for resnet8:
	Missing key(s) in state_dict: "f.0.weight", "f.1.weight", "f.1.bias", "f.1.running_mean", "f.1.running_var", "f.3.0.conv1.weight", "f.3.0.bn1.weight", "f.3.0.bn1.bias", "f.3.0.bn1.running_mean", "f.3.0.bn1.running_var", "f.3.0.conv2.weight", "f.3.0.bn2.weight", "f.3.0.bn2.bias", "f.3.0.bn2.running_mean", "f.3.0.bn2.running_var", "f.4.0.conv1.weight", "f.4.0.bn1.weight", "f.4.0.bn1.bias", "f.4.0.bn1.running_mean", "f.4.0.bn1.running_var", "f.4.0.conv2.weight", "f.4.0.bn2.weight", "f.4.0.bn2.bias", "f.4.0.bn2.running_mean", "f.4.0.bn2.running_var", "f.4.0.downsample.0.weight", "f.4.0.downsample.1.weight", "f.4.0.downsample.1.bias", "f.4.0.downsample.1.running_mean", "f.4.0.downsample.1.running_var", "f.5.0.conv1.weight", "f.5.0.bn1.weight", "f.5.0.bn1.bias", "f.5.0.bn1.running_mean", "f.5.0.bn1.running_var", "f.5.0.conv2.weight", "f.5.0.bn2.weight", "f.5.0.bn2.bias", "f.5.0.bn2.running_mean", "f.5.0.bn2.running_var", "f.5.0.downsample.0.weight", "f.5.0.downsample.1.weight", "f.5.0.downsample.1.bias", "f.5.0.downsample.1.running_mean", "f.5.0.downsample.1.running_var", "f.6.0.conv1.weight", "f.6.0.bn1.weight", "f.6.0.bn1.bias", "f.6.0.bn1.running_mean", "f.6.0.bn1.running_var", "f.6.0.conv2.weight", "f.6.0.bn2.weight", "f.6.0.bn2.bias", "f.6.0.bn2.running_mean", "f.6.0.bn2.running_var", "f.6.0.downsample.0.weight", "f.6.0.downsample.1.weight", "f.6.0.downsample.1.bias", "f.6.0.downsample.1.running_mean", "f.6.0.downsample.1.running_var", "classification_layer.weight", "classification_layer.bias". 
	Unexpected key(s) in state_dict: "features.0.weight", "features.0.bias", "features.1.weight", "features.1.bias", "features.4.weight", "features.4.bias", "features.5.weight", "features.5.bias", "features.8.weight", "features.8.bias", "features.9.weight", "features.9.bias", "classifier.weight", "classifier.bias". 

In [None]:
def replace_top_r_percent(a, b, delta, k):
    n = delta.numel()
    top_k = max(1, int(n * (k / 100)))  # Ensure at least one element is selected
    threshold = torch.topk(delta, top_k, sorted=True).values[-1]  # Get the r%-th largest value

    mask = delta >= threshold  # Mask for top r% values
    result = torch.where(mask, b, a)  # Replace selected positions in a with c

    return result

In [None]:
def restore_flat_to_state_dict(flat_grad, model_dict):
    state_dict = {}
    start = 0
    for name, param in model_dict.items():
        num_elements = param.numel()
        state_dict[name] = flat_grad[start:start + num_elements].view(param.shape)
        start += num_elements
    return state_dict

In [None]:
flat_model1 = flat_dict(model1.state_dict())
flat_model2 = flat_dict(model2.state_dict())
flat_delta = flat_dict(delta_model.state_dict())

results = {}
for k in range(0, 101, 2):
    model3.load_state_dict(model1.state_dict())
    crafted_flat = replace_top_r_percent(flat_model1, flat_model2, flat_delta, k)
    restored_crafted = restore_flat_to_state_dict(crafted_flat, model3.state_dict())
    model3.load_state_dict(restored_crafted)
    
    acc1, asr1 = model_eval(model3, test_loader, attack)

    # delta, org_cos
    delta0, org_cos2 = get_delta_cos(model1, model3, model0_sd)
    print(f"k: {k}, acc: {acc1}, asr: {asr1}, cos: {org_cos2}")
    
    results[k] = (acc1, asr1, org_cos2)



k: 0, acc: 0.1978, asr: 0.0, cos: 0.002549290657043457
k: 2, acc: 0.1981, asr: 0.0, cos: 0.01845616102218628
k: 4, acc: 0.1982, asr: 0.0, cos: 0.031197071075439453
k: 6, acc: 0.1982, asr: 0.0, cos: 0.037589848041534424
k: 8, acc: 0.1984, asr: 0.0, cos: 0.05892503261566162
k: 10, acc: 0.1984, asr: 0.0, cos: 0.05892503261566162
k: 12, acc: 0.1997, asr: 0.0, cos: 0.09739971160888672
k: 14, acc: 0.1997, asr: 0.0, cos: 0.09739971160888672
k: 16, acc: 0.1997, asr: 0.0, cos: 0.09739971160888672
k: 18, acc: 0.1997, asr: 0.0, cos: 0.09739971160888672
k: 20, acc: 0.2122, asr: 0.001, cos: 0.14976048469543457
k: 22, acc: 0.2122, asr: 0.001, cos: 0.14976048469543457
k: 24, acc: 0.2122, asr: 0.001, cos: 0.14976048469543457
k: 26, acc: 0.2122, asr: 0.001, cos: 0.14976048469543457
k: 28, acc: 0.2122, asr: 0.001, cos: 0.14976048469543457
k: 30, acc: 0.2122, asr: 0.001, cos: 0.14976048469543457
k: 32, acc: 0.274, asr: 0.328, cos: 0.2299501895904541
k: 34, acc: 0.274, asr: 0.328, cos: 0.2299501895904541


In [None]:
results

{0: (0.1978, 0.0, 0.002549290657043457),
 2: (0.1981, 0.0, 0.01845616102218628),
 4: (0.1982, 0.0, 0.031197071075439453),
 6: (0.1982, 0.0, 0.037589848041534424),
 8: (0.1984, 0.0, 0.05892503261566162),
 10: (0.1984, 0.0, 0.05892503261566162),
 12: (0.1997, 0.0, 0.09739971160888672),
 14: (0.1997, 0.0, 0.09739971160888672),
 16: (0.1997, 0.0, 0.09739971160888672),
 18: (0.1997, 0.0, 0.09739971160888672),
 20: (0.2122, 0.001, 0.14976048469543457),
 22: (0.2122, 0.001, 0.14976048469543457),
 24: (0.2122, 0.001, 0.14976048469543457),
 26: (0.2122, 0.001, 0.14976048469543457),
 28: (0.2122, 0.001, 0.14976048469543457),
 30: (0.2122, 0.001, 0.14976048469543457),
 32: (0.274, 0.328, 0.2299501895904541),
 34: (0.274, 0.328, 0.2299501895904541),
 36: (0.274, 0.328, 0.2299501895904541),
 38: (0.274, 0.328, 0.2299501895904541),
 40: (0.274, 0.328, 0.2299501895904541),
 42: (0.274, 0.328, 0.2299501895904541),
 44: (0.274, 0.328, 0.2299501895904541),
 46: (0.274, 0.328, 0.2299501895904541),
 48: (

acc: 0.2754, asr: 0.269, cos: 0.22751080989837646
