In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import syft as sy  # Import Syft for Federated Learning

In [2]:
# Define the neural network model (simple example)
class SimpleNN(nn.Module):
    """Two-layer fully connected classifier for flattened inputs.

    With the default arguments this is the original architecture:
    ``28 * 28`` input features, one hidden layer of 128 ReLU units and
    10 output logits. Parameter names (``fc1``, ``fc2``) are unchanged, so
    existing state dicts still load.

    Args:
        input_dim: Number of features per sample after flattening.
        hidden_dim: Width of the hidden layer.
        num_classes: Number of output logits.
    """

    def __init__(self, input_dim=28 * 28, hidden_dim=128, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Flatten ``x`` to ``(-1, input_dim)`` and return the raw logits.

        No softmax is applied; the output is the result of the second
        linear layer.
        """
        # reshape() instead of view(): view() raises a RuntimeError when the
        # requested shape is incompatible with the tensor's strides (e.g. a
        # transposed image batch); reshape() falls back to a copy.
        x = x.reshape(-1, self.fc1.in_features)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

In [13]:

# Create the virtual machines in one pass: the coordinating server first,
# then one worker per client (same creation order as before).
server, client1, client2 = (
    sy.Worker(name=worker_name)
    for worker_name in ("server", "client1", "client2")
)

In [44]:
from torch.utils.data import random_split, DataLoader

# Fixed seed for the dataset splits so that the train/test partition and the
# per-client shards are identical on every Restart & Run All.
SPLIT_SEED = 42
BATCH_SIZE = 32
split_generator = torch.Generator().manual_seed(SPLIT_SEED)

# Prepare the training and test data: convert images to tensors, then
# normalise the single channel with mean 0.5 and std 0.5.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
mnist_dataset = datasets.MNIST(root="./data", train=True, transform=transform, download=True)

# Split the MNIST dataset into training and test sets
train_size = int(0.9 * len(mnist_dataset))  # 90% for training
test_size = len(mnist_dataset) - train_size  # 10% for testing
train_dataset, test_dataset = random_split(
    mnist_dataset, [train_size, test_size], generator=split_generator
)

# Split the training data between the clients. The second shard takes the
# remainder so the two sizes always sum to len(train_dataset); two `// 2`
# halves drop one sample when the length is odd, and random_split then
# raises ValueError because the lengths no longer add up.
client1_size = len(train_dataset) // 2
client2_size = len(train_dataset) - client1_size
client1_data, client2_data = random_split(
    train_dataset, [client1_size, client2_size], generator=split_generator
)

# Load data into DataLoaders for clients
client1_loader = DataLoader(client1_data, batch_size=BATCH_SIZE, shuffle=True)
client2_loader = DataLoader(client2_data, batch_size=BATCH_SIZE, shuffle=True)

# Load test data into DataLoader (no shuffling needed for evaluation)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [45]:
# Training function for each client
def train(model, data_loader, optimizer, criterion, device):
    """Run one training pass over ``data_loader`` and report running metrics.

    Args:
        model: ``nn.Module`` to optimise. It is switched to training mode and
            moved to ``device``.
        data_loader: Iterable yielding ``(data, target)`` batches. It does not
            need to support ``len()``; batches are counted as they arrive.
        optimizer: Optimizer whose ``zero_grad()`` / ``step()`` are called
            once per batch.
        criterion: Callable ``criterion(output, target)`` returning a scalar
            loss tensor.
        device: Device that the model and every batch are moved to.

    Returns:
        A tuple ``(model, avg_loss, accuracy)``:
        ``avg_loss`` is the mean of the per-batch loss values and
        ``accuracy`` is the percentage (0-100) of samples whose arg-max
        prediction matched the target. Predictions are taken from the
        forward pass made before each batch's optimizer step. If the loader
        yields no batches, ``(model, 0.0, 0.0)`` is returned instead of
        raising ``ZeroDivisionError``.
    """
    model.train()
    model.to(device)  # Move the model to the specified device
    correct = 0
    total = 0
    running_loss = 0.0
    num_batches = 0

    for data, target in data_loader:
        data, target = data.to(device), target.to(device)  # Move data and target to the device
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        # Accumulate accuracy and loss statistics for this batch
        predicted = output.argmax(dim=1)
        correct += (predicted == target).sum().item()
        total += target.size(0)
        running_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        # Nothing was trained on; report neutral metrics rather than dividing by zero.
        return model, 0.0, 0.0

    accuracy = 100 * correct / total if total else 0.0
    avg_loss = running_loss / num_batches
    return model, avg_loss, accuracy

In [46]:
# Function to evaluate the global model on the test set
def evaluate(model, data_loader, criterion, device):
    """Compute the average loss and accuracy of ``model`` over ``data_loader``.

    The model is switched to evaluation mode and moved to ``device``; all
    forward passes run under ``torch.no_grad()``.

    Args:
        model: ``nn.Module`` to evaluate.
        data_loader: Iterable yielding ``(data, target)`` batches. It does not
            need to support ``len()``; batches are counted as they arrive.
        criterion: Callable ``criterion(output, target)`` returning a scalar
            loss tensor.
        device: Device that the model and every batch are moved to.

    Returns:
        A tuple ``(avg_loss, accuracy)``: the mean of the per-batch loss
        values and the percentage (0-100) of samples whose arg-max prediction
        matched the target. If the loader yields no batches, ``(0.0, 0.0)``
        is returned instead of raising ``ZeroDivisionError``.
    """
    model.eval()
    model.to(device)
    correct = 0
    total = 0
    test_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            num_batches += 1

            # Accumulate accuracy statistics for this batch
            predicted = output.argmax(dim=1)
            correct += (predicted == target).sum().item()
            total += target.size(0)

    if num_batches == 0:
        # Nothing to evaluate; report neutral metrics rather than dividing by zero.
        return 0.0, 0.0

    accuracy = 100 * correct / total if total else 0.0
    avg_loss = test_loss / num_batches
    return avg_loss, accuracy

In [47]:
# Federated-averaging hyperparameters
NUM_ROUNDS = 5     # Number of global epochs (communication rounds)
CLIENT_LR = 0.01   # SGD learning rate used by every client

# Initialize the global model
global_model = SimpleNN()
device = torch.device('cpu')  # Change to 'cuda' if using a GPU
criterion = nn.CrossEntropyLoss()

# One loader per participating client. The number of client models and the
# aggregation weights are derived from this list, so adding a client only
# requires appending its loader here.
client_loaders = [client1_loader, client2_loader]

# FedAvg weights: each client contributes in proportion to its sample count.
# With equally sized shards this reduces to the plain mean.
client_sizes = [len(loader.dataset) for loader in client_loaders]
total_samples = sum(client_sizes)
if total_samples == 0:
    raise ValueError("Federated training needs at least one client sample.")
client_weights = [size / total_samples for size in client_sizes]

# Federated learning loop
for epoch in range(NUM_ROUNDS):
    print(f"Epoch {epoch + 1}")

    # Distribute model to clients: a fresh model per client, initialised
    # with the current global weights.
    client_models = [SimpleNN() for _ in client_loaders]
    for client_model in client_models:
        client_model.load_state_dict(global_model.state_dict())

    optimizers = [optim.SGD(client_model.parameters(), lr=CLIENT_LR) for client_model in client_models]

    # Train the model on each client and keep its metrics
    client_losses = []
    client_accuracies = []
    for client_model, client_loader, client_optimizer in zip(client_models, client_loaders, optimizers):
        _, client_loss, client_accuracy = train(client_model, client_loader, client_optimizer, criterion, device)
        client_losses.append(client_loss)
        client_accuracies.append(client_accuracy)

    # Weighted average of the client weights
    with torch.no_grad():
        # Build each client's state dict once instead of once per key.
        client_states = [client_model.state_dict() for client_model in client_models]
        global_state_dict = global_model.state_dict()
        for key in global_state_dict:
            stacked = torch.stack([state[key].float() for state in client_states])
            # Shape the weights as (num_clients, 1, ..., 1) so they broadcast
            # over the parameter dimensions.
            weights = torch.tensor(client_weights, dtype=stacked.dtype, device=stacked.device)
            weights = weights.view(-1, *([1] * (stacked.dim() - 1)))
            global_state_dict[key] = (stacked * weights).sum(0)

    global_model.load_state_dict(global_state_dict)

    # Sample-weighted average of the client metrics
    avg_loss = sum(w * loss for w, loss in zip(client_weights, client_losses))
    avg_accuracy = sum(w * acc for w, acc in zip(client_weights, client_accuracies))

    print(f"Updated global model | Avg Loss: {avg_loss:.4f} | Avg Accuracy: {avg_accuracy:.2f}%")

    # Evaluate the global model on the test set
    test_loss, test_accuracy = evaluate(global_model, test_loader, criterion, device)
    print(f"Test Loss: {test_loss:.4f} | Test Accuracy: {test_accuracy:.2f}%")

print("Training complete.")

Epoch 1
Updated global model | Avg Loss: 0.7961 | Avg Accuracy: 79.95%
Test Loss: 0.4199 | Test Accuracy: 88.20%
Epoch 2
Updated global model | Avg Loss: 0.3829 | Avg Accuracy: 89.16%
Test Loss: 0.3382 | Test Accuracy: 89.95%
Epoch 3
Updated global model | Avg Loss: 0.3326 | Avg Accuracy: 90.42%
Test Loss: 0.3047 | Test Accuracy: 90.85%
Epoch 4
Updated global model | Avg Loss: 0.3041 | Avg Accuracy: 91.13%
Test Loss: 0.2861 | Test Accuracy: 91.37%
Epoch 5
Updated global model | Avg Loss: 0.2817 | Avg Accuracy: 91.83%
Test Loss: 0.2640 | Test Accuracy: 92.05%
Training complete.
