In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import copy

In [6]:
# Transform to normalize the data
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

# Load the dataset
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Split the dataset into parts for each client
num_clients = 4
client_datasets = random_split(train_dataset, [15000, 15000, 15000, 15000, len(train_dataset) - 60000])

# Create DataLoader for each client
client_loaders = [DataLoader(ds, batch_size=64, shuffle=True) for ds in client_datasets[:num_clients]]
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [7]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(28*28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(-1, 28*28)  # Flatten the input
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [8]:
global_model = Net()
criterion = nn.CrossEntropyLoss()

In [9]:
# Training parameters
epochs = 5
lr = 0.01

def train_on_client(model, data_loader, criterion, lr):
    model.train()
    optimizer = optim.SGD(model.parameters(), lr=lr)
    for data, target in data_loader:
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
    return model.state_dict(), loss.item()

def average_weights(w):
    w_avg = copy.deepcopy(w[0])
    for key in w_avg.keys():
        for i in range(1, len(w)):
            w_avg[key] += w[i][key]
        w_avg[key] = torch.div(w_avg[key], len(w))
    return w_avg

for epoch in range(epochs):
    local_weights = []
    local_losses = []

    for client_loader in client_loaders:
        client_model = copy.deepcopy(global_model)
        client_weights, client_loss = train_on_client(client_model, client_loader, criterion, lr)
        local_weights.append(client_weights)
        local_losses.append(client_loss)

    # Update global model weights
    global_weights = average_weights(local_weights)
    global_model.load_state_dict(global_weights)

    # Print average loss for the epoch
    avg_loss = sum(local_losses) / len(local_losses)
    print(f'Epoch {epoch+1}, Average Loss: {avg_loss:.4f}')

print("Training complete.")

Epoch 1, Average Loss: 1.5306
Epoch 2, Average Loss: 0.7122
Epoch 3, Average Loss: 0.5323
Epoch 4, Average Loss: 0.5371
Epoch 5, Average Loss: 0.4656
Training complete.


In [10]:
global_model.eval()
correct = 0
total = 0

with torch.no_grad():
    for data, target in test_loader:
        output = global_model(data)
        _, predicted = torch.max(output.data, 1)
        total += target.size(0)
        correct += (predicted == target).sum().item()

accuracy = 100 * correct / total
print(f'Accuracy: {accuracy:.2f}%')

Accuracy: 89.13%
