In [10]:
    
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset, Dataset
import copy
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
from PIL import Image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  

# 1. Load FashionMNIST


In [11]:
    
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = datasets.FashionMNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.FashionMNIST('./data', train=False, download=True, transform=transform)

test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

  


100.0%
100.0%
100.0%
100.0%


# 2. Simple CNN & Check


In [12]:
    
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
        self.pool = nn.MaxPool2d(2)
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(1024, 512) 
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = x.view(-1, 1024) 
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Quick check on subset
subset = Subset(train_dataset, range(200))
loader = DataLoader(subset, batch_size=10)
model = Net().to(device)
opt = optim.SGD(model.parameters(), lr=0.01)
crit = nn.CrossEntropyLoss()

model.train()
for epoch in range(2):
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        opt.zero_grad()
        out = model(x)
        loss = crit(out, y)
        loss.backward()
        opt.step()
    print(f"Check epoch {epoch}, loss: {loss.item():.3f}")

  


Check epoch 0, loss: 2.258
Check epoch 1, loss: 2.191


# 3. Average function


In [13]:
    
def average_model_parameters(models, weights):
    avg_params = copy.deepcopy(models[0].state_dict())
    
    with torch.no_grad():
        for key in avg_params:
            avg_params[key] = avg_params[key] * weights[0]
            
        for i in range(1, len(models)):
            params = models[i].state_dict()
            for key in avg_params:
                avg_params[key] += params[key] * weights[i]
                
    return avg_params

  


# 4. Algo 1 helpers


In [14]:
    
def client_update(model, train_loader, epochs=1):
    model.train()
    optimizer = optim.SGD(model.parameters(), lr=0.1)
    criterion = nn.CrossEntropyLoss()
    
    for _ in range(epochs):
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
    return model

def test(model, loader):
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            out = model(data)
            pred = out.argmax(1)
            correct += pred.eq(target).sum().item()
    return correct / len(loader.dataset)

  


# 5. Independent training (Fail case)


In [15]:
    
idx = list(range(len(train_dataset)))
# 2 clients, 600 pts each
ds1 = Subset(train_dataset, idx[0:600])
ds2 = Subset(train_dataset, idx[600:1200])
ld1 = DataLoader(ds1, batch_size=50, shuffle=True)
ld2 = DataLoader(ds2, batch_size=50, shuffle=True)

# Diff init
net1 = Net().to(device)
net2 = Net().to(device)

# Train
net1 = client_update(net1, ld1, epochs=5)
net2 = client_update(net2, ld2, epochs=5)

print(f"Net1 Acc: {test(net1, test_loader):.3f}")
print(f"Net2 Acc: {test(net2, test_loader):.3f}")

# Avg
avg_state = average_model_parameters([net1, net2], [0.5, 0.5])
global_net = Net().to(device)
global_net.load_state_dict(avg_state)

print(f"Avg Model Acc (Indep Init): {test(global_net, test_loader):.3f}")

  


Net1 Acc: 0.700
Net2 Acc: 0.624
Avg Model Acc (Indep Init): 0.395


Why it does not work:
The averaging fails because the two models were initialized randomly and independently. In deep learning, the loss landscape is non-convex. Since net1 and net2 started from different points, they converged to different local minima. Averaging their parameters results in a point in the parameter space that lies somewhere between these two minima, which typically corresponds to a region with high loss (a "bad" model). As stated in the paper, we need a common initialization to ensure the models stay in the same basin of attraction.


# 6. Common Initialization


In [16]:
    
# Same split as before
global_init = Net().to(device)
init_state = global_init.state_dict()

net1 = Net().to(device) 
net1.load_state_dict(copy.deepcopy(init_state))
net2 = Net().to(device)
net2.load_state_dict(copy.deepcopy(init_state))

net1 = client_update(net1, ld1, epochs=5)
net2 = client_update(net2, ld2, epochs=5)

avg_state = average_model_parameters([net1, net2], [0.5, 0.5])
global_init.load_state_dict(avg_state)

print(f"Avg Model Acc (Common Init): {test(global_init, test_loader):.3f}")

  


Avg Model Acc (Common Init): 0.672


# 7. & 8. Study: Data points vs Performance


In [17]:
    
def run_fed_avg(n_models, pts_per_model, rounds=5):
    all_idx = np.random.permutation(len(train_dataset))
    
    loaders = []
    for i in range(n_models):
        subset_idx = all_idx[i*pts_per_model : (i+1)*pts_per_model]
        # batch size 50 or full batch if small data
        bs = min(50, pts_per_model)
        loaders.append(DataLoader(Subset(train_dataset, subset_idx), batch_size=bs, shuffle=True))
        
    global_model = Net().to(device)
    
    for r in range(rounds):
        w_global = global_model.state_dict()
        local_models = []
        
        for k in range(n_models):
            m = Net().to(device)
            m.load_state_dict(copy.deepcopy(w_global))
            m = client_update(m, loaders[k], epochs=5)
            local_models.append(m)
            
        # uniform weights
        weights = [1.0/n_models] * n_models
        new_state = average_model_parameters(local_models, weights)
        global_model.load_state_dict(new_state)
        
    return test(global_model, test_loader)

# Run loops
res = []
models_list = [2, 3, 5]
points_list = [25, 50, 100, 200, 500]

print("Running study...")
for m in models_list:
    for p in points_list:
        acc = run_fed_avg(m, p, rounds=10) # 10 rounds to be sure
        res.append({'models': m, 'points': p, 'acc': acc})
        print(f"M={m}, P={p} -> {acc:.3f}")

df = pd.DataFrame(res)
pivoted = df.pivot(index='points', columns='models', values='acc')
print("\n")
print(pivoted)

  


Running study...
M=2, P=25 -> 0.610
M=2, P=50 -> 0.658
M=2, P=100 -> 0.690
M=2, P=200 -> 0.733
M=2, P=500 -> 0.769
M=3, P=25 -> 0.571
M=3, P=50 -> 0.655
M=3, P=100 -> 0.701
M=3, P=200 -> 0.757
M=3, P=500 -> 0.819
M=5, P=25 -> 0.591
M=5, P=50 -> 0.650
M=5, P=100 -> 0.721
M=5, P=200 -> 0.754
M=5, P=500 -> 0.823


models       2       3       5
points                        
25      0.6095  0.5714  0.5913
50      0.6580  0.6548  0.6501
100     0.6902  0.7014  0.7212
200     0.7328  0.7572  0.7540
500     0.7695  0.8185  0.8228


# 9. Repeat on CIFAR-10

In [None]:
    
print("Downloading/Loading CIFAR-10...")

transform_cifar = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_cifar = datasets.CIFAR10('./data', train=True, download=True, transform=transform_cifar)
test_cifar = datasets.CIFAR10('./data', train=False, download=True, transform=transform_cifar)
test_loader_cifar = DataLoader(test_cifar, batch_size=1000, shuffle=False)

class NetRGB(nn.Module):
    def __init__(self):
        super(NetRGB, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 5)
        self.conv2 = nn.Conv2d(32, 64, 5)
        self.pool = nn.MaxPool2d(2)
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(64 * 5 * 5, 512) 
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = x.view(-1, 64 * 5 * 5)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

def run_cifar_study(n_models, pts_per_model, rounds=5):
    indices = np.random.permutation(len(train_cifar))
    loaders = []
    
    for i in range(n_models):
        idx = indices[i*pts_per_model : (i+1)*pts_per_model]
        loaders.append(DataLoader(Subset(train_cifar, idx), batch_size=min(50, pts_per_model), shuffle=True))
        
    global_net = NetRGB().to(device)
    
    for r in range(rounds):
        w_global = global_net.state_dict()
        locals_list = []
        
        for k in range(n_models):
            m = NetRGB().to(device)
            m.load_state_dict(copy.deepcopy(w_global))
            
            m.train()
            opt = optim.SGD(m.parameters(), lr=0.01)
            crit = nn.CrossEntropyLoss()
            
            for _ in range(5):
                for x, y in loaders[k]:
                    x, y = x.to(device), y.to(device)
                    opt.zero_grad()
                    loss = crit(m(x), y)
                    loss.backward()
                    opt.step()
            locals_list.append(m)
        
        avg = average_model_parameters(locals_list, [1.0/n_models]*n_models)
        global_net.load_state_dict(avg)
        
    global_net.eval()
    correct = 0
    with torch.no_grad():
        for x, y in test_loader_cifar:
            x, y = x.to(device), y.to(device)
            out = global_net(x)
            pred = out.argmax(1)
            correct += pred.eq(y).sum().item()
    return correct / len(test_cifar)

print("\nRunning CIFAR-10 Study...")
res_cifar = []
models_c = [2, 5] 
points_c = [100, 500] 

for m in models_c:
    for p in points_c:
        acc = run_cifar_study(m, p, rounds=10)
        res_cifar.append({'models': m, 'points': p, 'acc': acc})
        print(f"CIFAR Models={m}, Points={p} -> Acc: {acc:.3f}")

print("\nCIFAR Results:")
print(pd.DataFrame(res_cifar).pivot(index='points', columns='models', values='acc'))

  


Downloading/Loading CIFAR-10...


100.0%
  entry = pickle.load(f, encoding="latin1")



Running CIFAR-10 Study (Takes longer)...
CIFAR Models=2, Points=100 -> Acc: 0.100
CIFAR Models=2, Points=500 -> Acc: 0.301
CIFAR Models=5, Points=100 -> Acc: 0.165
CIFAR Models=5, Points=500 -> Acc: 0.314

CIFAR Results:
models       2       5
points                
100     0.1001  0.1647
500     0.3007  0.3141
