In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset, random_split
from torchvision import datasets, transforms
import os

In [2]:
batch_size = 50
epochs = 5
average_weight = [0.5, 0.5]

root = './data'
if not os.path.exists(root):
    os.mkdir(root)

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
mnist = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

subset1, subset2, _ = random_split(mnist, [600, 600, len(mnist)-1200])

loader1 = DataLoader(subset1, batch_size=batch_size, shuffle=True)
loader2 = DataLoader(subset2, batch_size=batch_size, shuffle=True)

In [3]:
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.fc1 = nn.Linear(64*24*24, 128)
        self.fc2 = nn.Linear(128, 10)
        self.drop = nn.Dropout(0.8)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = self.drop(x)
        x = x.view(-1, 64*24*24)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [4]:
model1 = SimpleCNN()
model2 = SimpleCNN()

In [5]:
from collections.abc import Iterator
def average_model_parameters(models: Iterator, average_weight: Iterator) -> Iterator:
  averaged_model = SimpleCNN()
  for name, param in averaged_model.named_parameters():
    param.data.copy_(torch.zeros_like(param.data))
    for i, model in enumerate(models):
      param.data += average_weight[i] * model.state_dict()[name]
    yield param

In [6]:
def update_model_parameters(model, new_parameters):
  i = 0
  for name, param in model.named_parameters():
    param.data = torch.nn.Parameter(torch.tensor(new_parameters[i]))
    i += 1

In [7]:
def test_model(model, loader):
  model.eval()
  correct = 0
  total = 0
  with torch.no_grad():
    for images, labels in loader:
      outputs = model(images)
      _, predicted = torch.max(outputs.data, 1)
      total += labels.size(0)
      correct += (predicted == labels).sum().item()

  accuracy = 100 * correct / total
  return accuracy

accuracy_model1 = test_model(model1, loader1)
print(f"Accuracy of model1: {accuracy_model1}%")

accuracy_model2 = test_model(model2, loader2)
print(f"Accuracy of model2: {accuracy_model2}%")

Accuracy of model1: 11.666666666666666%
Accuracy of model2: 6.833333333333333%


In [8]:
def algorithm_1(model1, model2, loader1, loader2):
    optimizer1 = optim.Adam(model1.parameters())
    optimizer2 = optim.Adam(model2.parameters())
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        model1.train()
        for images, labels in loader1:
            optimizer1.zero_grad()
            outputs = model1(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer1.step()

        model2.train()
        for images, labels in loader2:
            optimizer2.zero_grad()
            outputs = model2(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer2.step()

        averaged_parameters = list(average_model_parameters([model1, model2], average_weight))

        update_model_parameters(model1, averaged_parameters)
        update_model_parameters(model2, averaged_parameters)

    return model1, model2

In [9]:
algorithm_1(model1, model2, loader1, loader2)

  param.data = torch.nn.Parameter(torch.tensor(new_parameters[i]))


(SimpleCNN(
   (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
   (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
   (fc1): Linear(in_features=36864, out_features=128, bias=True)
   (fc2): Linear(in_features=128, out_features=10, bias=True)
   (drop): Dropout(p=0.8, inplace=False)
 ),
 SimpleCNN(
   (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
   (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
   (fc1): Linear(in_features=36864, out_features=128, bias=True)
   (fc2): Linear(in_features=128, out_features=10, bias=True)
   (drop): Dropout(p=0.8, inplace=False)
 ))

In [10]:
accuracy_model1 = test_model(model1, loader1)
print(f"Accuracy of model1: {accuracy_model1}%")

accuracy_model2 = test_model(model2, loader2)
print(f"Accuracy of model2: {accuracy_model2}%")

Accuracy of model1: 93.5%
Accuracy of model2: 92.0%
