In [1]:
############-----CW different app.2------###########

In [2]:
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch.nn as nn
from torchvision import models

In [3]:
# Load the CIFAR-10 data
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [4]:
# Define the model architecture
# class Net(nn.Module):
#     def __init__(self):
#         super(Net, self).__init__()
#         self.conv1 = nn.Conv2d(3, 6, 5)
#         self.pool = nn.MaxPool2d(2, 2)
#         self.conv2 = nn.Conv2d(6, 16, 5)
#         self.fc1 = nn.Linear(16 * 5 * 5, 120)
#         self.fc2 = nn.Linear(120, 84)
#         self.fc3 = nn.Linear(84, 10)

#     def forward(self, x):
#         x = self.pool(nn.functional.relu(self.conv1(x)))
#         x = self.pool(nn.functional.relu(self.conv2(x)))
#         x = x.view(-1, 16 * 5 * 5)
#         x = nn.functional.relu(self.fc1(x))
#         x = nn.functional.relu(self.fc2(x))
#         x = self.fc3(x)
#         return x


class ResNet18(nn.Module):
    def __init__(self):
        super(ResNet18, self).__init__()
        self.resnet = models.resnet18(pretrained=True)
        self.fc = nn.Linear(1000, 10)

    def forward(self, x):
        x = self.resnet(x)
        x = self.fc(x)
        return x


# class DenseNet(nn.Module):
#     def __init__(self):
#         super(DenseNet, self).__init__()
#         self.densenet = models.densenet161(pretrained=True)
#         self.fc = nn.Linear(1000, 10)

#     def forward(self, x):
#         x = self.densenet(x)
#         x = self.fc(x)
#         return x


# model = torchvision.models.resnet18(pretrained=True)
# model.eval()
model=ResNet18()
model.eval()

Net(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)

In [5]:
# Define the targeted CW attack
def targeted_cw_attack(model, images, labels, target_label, max_iterations=1000, learning_rate=0.01):
    # Initialize the perturbation
    delta = torch.zeros_like(images).cuda().requires_grad_()
    # Define the loss function
    criterion = torch.nn.CrossEntropyLoss()
    # Define the target label tensor
    target = torch.tensor([target_label] * len(labels)).cuda()
    # Loop over the maximum number of iterations
    for i in range(max_iterations):
        # Zero the gradients
        model.zero_grad()
        # Forward pass
        output = model(images + delta)
        # Calculate the loss
        loss = -torch.mean(torch.sum(torch.nn.functional.one_hot(target, num_classes=10) * output, dim=1))
        # Calculate the gradients
        loss.backward()
        # Update the perturbation
        delta.data = delta + learning_rate * delta.grad.detach().sign()
        delta.data = torch.clamp(delta, min=-0.5, max=0.5)
        delta.data = torch.min(torch.max(delta.data, images - 1), images + 1) - images
    # Return the adversarial images
    return (images + delta).detach()

In [6]:
# Define the hyperparameters
num_clients = 10
num_epochs = 10

# Split the training data into clients
client_data = torch.utils.data.random_split(trainset, [len(trainset) // num_clients] * num_clients)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)


In [None]:
# Train the model using federated learning
for epoch in range(num_epochs):
    print("Epoch:", epoch+1)
    # Initialize the list of gradients and total loss
    gradients = [None] * num_clients
    total_loss = 0
    # Loop over the clients
    for i in range(num_clients):
        # Initialize the data loader for this client
        trainloader = torch.utils.data.DataLoader(client_data[i], batch_size=32, shuffle=True, num_workers=2)
        # Initialize the gradients for this client
        gradients[i] = []
        # Loop over the data for this client
        for j, (images, labels) in enumerate(trainloader):
            # Generate targeted adversarial examples for this batch of images
            adversarial_images = targeted_cw_attack(model, images.cuda(), labels.cuda(), target_label=0)
            # Forward pass
            output = model(adversarial_images)
            # Calculate the loss
            loss = torch.nn.functional.cross_entropy(output, labels.cuda())
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            gradients[i].append([param.grad.clone() for param in model.parameters()])
            optimizer.zero_grad()
            total_loss += loss.item()
        # Average the gradients for this client
        for j in range(len(gradients[i])):
            for k, param in enumerate(model.parameters()):
                param.grad = gradients[i][j][k] / len(gradients[i])
    # Combine the gradients from all clients and update the model
    optimizer.step()
    # Print the average loss and accuracy for this epoch
    avg_loss = total_loss / len(trainset)
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            output = model(images.cuda())
            _, predicted = torch.max(output, 1)
            total += labels.size(0)
            correct += (predicted == labels.cuda()).sum().item()
        print('Epoch: %d, Loss: %.4f, Test Accuracy: %.2f %%' % (epoch + 1, avg_loss, 100 * correct / total))


Epoch: 1
Epoch: 1, Loss: 0.0724, Test Accuracy: 10.02 %
Epoch: 2
Epoch: 2, Loss: 0.0724, Test Accuracy: 10.02 %
Epoch: 3
Epoch: 3, Loss: 0.0724, Test Accuracy: 10.02 %
Epoch: 4
