In [6]:
import torch.nn as nn
import torch
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision import transforms
from utils import *
from torch.utils.data import DataLoader
from models import *
import random
from collections import Counter
from collections import OrderedDict
import seaborn as sns
import copy

cos = nn.CosineSimilarity(dim=0, eps=1e-9)
device = "cuda"

# study 1 model's training with the new loss function with dist limits

# adjustable parameters
alpha_d = 100 # IID
local_ep = 2 
n_clients = 30 # dataset size for one client
mali_local_ep = 10
global attack 
attack = "untargeted" #"backdoor", "tlp", "ut"
model_name = "ConvNet" # "resnet8", "ConvNet"
num_classes = 10
dataset ="fmnist"

In [11]:
def cos_dist(w1, w2):
    """Compute cosine similarity between two flattened weight tensors"""
    w1_flat, w2_flat = torch.cat([p.view(-1) for p in w1]), torch.cat([p.view(-1) for p in w2])
    return 1 - torch.dot(w1_flat, w2_flat) / (torch.norm(w1_flat) * torch.norm(w2_flat))

def get_delta_cos(model1, model2, model0_sd):
    flat_model0 = flat_dict(model0_sd)
    flat_model1 = flat_dict(model1.state_dict())
    flat_model2 = flat_dict(model2.state_dict())
    
    delta = torch.abs(flat_model1 - flat_model2)
    org_cos = cos((flat_model1 - flat_model0), (flat_model2 - flat_model0))
    return delta, 1-org_cos.item()

def model_eval(model, test_loader, attack):
    acc = eval_op_ensemble([model], test_loader)
    if attack == "tlp":
        asr = eval_op_ensemble_tr_lf_attack([model], test_loader)
    elif attack == "backdoor":
        asr = eval_op_ensemble_attack([model], test_loader)
    elif attack == "untargeted":
        asr = None
    return list(acc.values())[0], list(asr.values())[0]

def reverse_train_w_cos(model, loader, optimizer, epochs, model0_sd, model1_sd, beta, budget):    
    model.train()

    grad_ben = (flat_dict(model1_sd) - flat_dict(model0_sd)).to(device)
    
    losses = []
    # import pdb; pdb.set_trace()
    running_loss, samples = 0.0, 0
    for ep in range(epochs):
        for it, (x, y) in enumerate(loader):
            if it % 2 == 0:
                losses.append(round(eval_epoch(model, loader), 2))
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            loss_ce = nn.CrossEntropyLoss()(model(x), y)
            # in the untraining reverse the sign of loss
            loss_ce = - loss_ce
            running_loss += loss_ce.item() * y.shape[0]
            samples += y.shape[0]
            
            # add cos loss 
            w = torch.cat([p.clone().detach().view(-1) for p in model.parameters()]).to(device)
            grad_mail = w - flat_dict(model0_sd)
            target = torch.ones(len(w)).to(device)
            loss_cos = nn.CosineEmbeddingLoss()(grad_ben.unsqueeze(0), grad_mail.unsqueeze(0), target)
            loss_obj = (1-beta) * loss_ce + beta * loss_cos
            loss_obj.backward()
            optimizer.step()
            print(f"ep{ep}, loss_cs: {loss_ce}, loss_cos: {loss_cos}, loss_obj: {loss_obj}")
        
        # break
        cos_d = cos_dist(grad_ben, grad_mail)
        print("eval losses", losses)
        
        if cos_d <= budget:
            break
        

    return {"loss": running_loss / samples}
    

In [21]:
# Define transformation (convert images to tensors and normalize)
transform = transforms.Compose([
    transforms.ToTensor(),  # Convert image to tensor
    transforms.Normalize((0.5,), (0.5,))  # Normalize the image with mean and std
])

# Load the training dataset
train_data = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)

# Load the test dataset
test_data = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)

# Create DataLoader for batch processing
client_loaders, test_loader, client_data_subsets =\
    data.get_loaders(train_data, test_data, n_clients,
                    alpha=alpha_d, batch_size=32, n_data=None, num_workers=4, seed=4)
    
model_fn = partial(models.get_model(model_name)[
                        0], num_classes=num_classes, dataset=dataset)

client_loader = client_loaders[0]

# created models 
model0 = model_fn().to(device) # orginal model
model1 = model_fn().to(device) # train with clean data
model2 = model_fn().to(device) # train with new loss function
model3 = model_fn().to(device)

model0_sd = {k: v.clone().detach() for k, v in model1.state_dict().items()}

optimizer0 = optim.SGD(model0.parameters(), lr=0.001)
optimizer1 = optim.SGD(model1.parameters(), lr=0.001)
optimizer2 = optim.SGD(model2.parameters(), lr=0.001)
optimizer3 = optim.SGD(model3.parameters(), lr=0.001)



Data split:
 - Client 0: [202 181 238 193 217 187 220 194 177 186]               -> sum=1995
 - Client 1: [223 209 214 184 202 193 219 166 212 177]               -> sum=1999
 - Client 2: [188 214 176 204 179 200 189 220 210 221]               -> sum=2001
 - Client 3: [215 200 212 168 190 202 210 198 214 193]               -> sum=2002
 - Client 4: [188 192 202 235 237 164 203 204 178 196]               -> sum=1999
 - Client 5: [170 181 191 217 198 220 240 192 196 195]               -> sum=2000
 - Client 6: [200 213 181 219 167 233 185 216 186 199]               -> sum=1999
 - Client 7: [186 190 179 204 229 189 231 218 200 173]               -> sum=1999
 - Client 8: [203 189 199 227 226 144 215 194 179 224]               -> sum=2000
 - Client 9: [182 216 188 175 212 229 196 189 226 188]               -> sum=2001
.  .  .  .  .  .  .  .  .  .  
.  .  .  .  .  .  .  .  .  .  
.  .  .  .  .  .  .  .  .  .  
 - Client 21: [223 192 212 202 217 173 215 199 186 181]               -> sum=2000
 - 

In [22]:
# model1 train benign
train_op(model1, client_loader, optimizer1, epochs=local_ep, print_train_loss=True)

model1_sd = {key: value.clone() for key, value in model1.state_dict().items()}

model1_result = eval_op_ensemble([model1], test_loader)
print("model1_result", model1_result)

[2.36, 2.32, 2.27, 2.24, 2.2, 2.16, 2.12, 2.09, 2.06, 2.02, 2.0, 1.97, 1.94, 1.91, 1.89, 1.86, 1.84, 1.81, 1.79, 1.77, 1.75, 1.74, 1.72, 1.7, 1.68, 1.67, 1.65, 1.63, 1.61, 1.6, 1.58, 1.57, 1.57, 1.55, 1.54, 1.52, 1.51, 1.5, 1.49, 1.47, 1.46, 1.45, 1.44, 1.43, 1.42, 1.4, 1.39, 1.38, 1.37, 1.36, 1.36, 1.35, 1.34, 1.33, 1.32, 1.31, 1.3, 1.29, 1.29, 1.28, 1.27, 1.26, 1.26, 1.25]
model1_result {'test_accuracy': 0.7148}


In [23]:
# model2 train with new loss function
model2.load_state_dict(model1_sd)
reverse_train_w_cos(model2, client_loader, optimizer2, epochs=1, 
                                model0_sd = model0_sd, 
                                model1_sd = model1_sd, 
                                beta = 0.5, 
                                budget = 0.4)



ep0, loss_cs: -1.2648513317108154, loss_cos: 0.0, loss_obj: -0.6324256658554077
ep0, loss_cs: -1.2789223194122314, loss_cos: 2.2292137145996094e-05, loss_obj: -0.6394500136375427
ep0, loss_cs: -1.2330875396728516, loss_cos: 5.525350570678711e-05, loss_obj: -0.61651611328125
ep0, loss_cs: -1.2704012393951416, loss_cos: 0.0001049041748046875, loss_obj: -0.6351481676101685
ep0, loss_cs: -1.401615858078003, loss_cos: 0.00016438961029052734, loss_obj: -0.7007257342338562
ep0, loss_cs: -1.2396160364151, loss_cos: 0.00020742416381835938, loss_obj: -0.6197043061256409
ep0, loss_cs: -1.3055747747421265, loss_cos: 0.00024914738605730236, loss_obj: -0.6526628136634827
ep0, loss_cs: -1.231921672821045, loss_cos: 0.0003154277801513672, loss_obj: -0.6158031225204468
ep0, loss_cs: -1.2865276336669922, loss_cos: 0.00036072731018066406, loss_obj: -0.6430834531784058
ep0, loss_cs: -1.1245653629302979, loss_cos: 0.00036036965320818126, loss_obj: -0.5621024966239929
ep0, loss_cs: -1.1078349351882935, loss

{'loss': -1.5893410528512826}

In [28]:

cos_d_model1_2 = cos_dist(flat_dict(model1_sd) - flat_dict(model0_sd), 
                          flat_dict(model2.state_dict()) - flat_dict(model0_sd))
print("model1_2 cos dist", cos_d_model1_2)



model1_2 cos dist tensor(0.1771, device='cuda:0')


In [30]:
model2_result = eval_op_ensemble([model2], test_loader)
print("model2_result", model2_result)

model2_result {'test_accuracy': 0.1008}
