In [1]:
%load_ext autoreload
%autoreload 2

In [12]:
import torch
from torch.utils.data import DataLoader, TensorDataset

def generate_data_loaders(n_clients, X, y, batch_size=32, shuffle=True):
    # Split the data into n_clients subsets
    X_split = torch.chunk(X, n_clients)
    y_split = torch.chunk(y, n_clients)

    # Create a data loader for each client
    data_loaders = []
    for i in range(n_clients):
        dataset = TensorDataset(X_split[i], y_split[i])
        data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
        data_loaders.append(data_loader)

    return data_loaders

In [16]:
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import make_regression
from torch.utils.data import DataLoader, TensorDataset
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from mecvae.federated import FederatedTrainer

# Generate synthetic data
X, y = make_regression(n_samples=1000, n_features=10, random_state=42)

# Convert data to PyTorch tensors
X_tensor = torch.from_numpy(X).float()
y_tensor = torch.from_numpy(y).float()

# Define the linear regression model
class LinearRegression(nn.Module):
    def __init__(self):
        super(LinearRegression, self).__init__()
        self.linear = nn.Linear(10, 1)

    def forward(self, x):
        return self.linear(x)

# Split the data into train and test sets
train_ratio = 0.8
train_size = int(train_ratio * len(X))
X_train, X_test = X_tensor[:train_size], X_tensor[train_size:]
y_train, y_test = y_tensor[:train_size], y_tensor[train_size:]

# Generate data loaders for n_clients
n_clients = 5
train_loaders = generate_data_loaders(n_clients, X_train, y_train, batch_size=32, shuffle=True)

In [17]:
# Initialize the linear regression model, criterion, optimizer, and learning rate
model = LinearRegression()
criterion = nn.MSELoss()
optimizer_fn = optim.SGD
learning_rate = 0.01

# Initialize the federated trainer
federated_trainer = FederatedTrainer(model, criterion, optimizer_fn, learning_rate, n_clients, patience=3)

# Train the model using federated learning
federated_trainer.train(train_loaders, [X_test, y_test], grad_steps='auto')

  return F.mse_loss(input, target, reduction=self.reduction)


ValueError: too many values to unpack (expected 2)

In [None]:
# Get the aggregated model after federated averaging
aggregated_model = federated_trainer.get_aggregated_model()

# Test the aggregated model
aggregated_model.eval()
with torch.no_grad():
    y_pred = aggregated_model(X_test_tensor)
    test_loss = criterion(y_pred, y_test_tensor)
    print(f"Test Loss: {test_loss.item():.4f}")

# Plot train losses
train_losses = federated_trainer.get_train_losses()
plt.figure(figsize=(10, 6))
sns.lineplot(data=train_losses, x='Epoch', y='Loss', hue='Client')
plt.xlabel('Epoch')
plt.ylabel('Train Loss')
plt.title('Train Losses')
plt.grid(True)
plt.show()

# Plot test losses
test_losses = federated_trainer.get_test_losses()
plt.figure(figsize=(10, 6))
sns.lineplot(data=test_losses, x='Epoch', y='Loss', hue='Client')
plt.xlabel('Epoch')
plt.ylabel('Test Loss')
plt.title('Test Losses')
plt.grid(True)
plt.show()

In [50]:
import torch
from torch.utils.data import DataLoader, Dataset
from torch import nn, optim
from copy import deepcopy
from torch.utils.data.dataset import random_split
from torch.utils.tensorboard import SummaryWriter


class Client:
    def __init__(self, dataset, model, loss_fn, num_epochs, client_id, lr=1e-2):
        self.dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
        self.model = model
        self.optimizer = optim.SGD(self.model.parameters(), lr=lr)
        self.loss_fn = loss_fn
        self.num_epochs = num_epochs
        self.writer = SummaryWriter(f'runs/Client_{client_id}')
        self.iteration = 0

    def train(self):
        self.model.train()
        total_loss = 0
        for X, y in self.dataloader:
            pred = self.model(X)
            loss = self.loss_fn(pred, y)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()
            self.writer.add_scalar('Loss/train', loss.item(), self.iteration)
            self.iteration += 1

        return self.model.state_dict(), len(self.dataloader)

class Server:
    clients = None

    def __init__(self, model, dataset, loss_fn, lr=1e-3, validation_split=0.2, early_stopping=True, patience=5):
        self.model = model
        self.loss_fn = loss_fn
        self.lr = lr
        self.early_stopping = early_stopping
        self.patience = patience
        self.writer = SummaryWriter('runs/Server')

        num_validation = int(len(dataset) * validation_split)
        num_train = len(dataset) - num_validation
        self.train_dataset, self.val_dataset = random_split(dataset, [num_train, num_validation])

        self.best_val_loss = float('inf')
        self.wait = 0  # for early stopping

    def create_clients(self, num_clients):
        self.clients = [Client(self.train_dataset, deepcopy(self.model), self.loss_fn, num_epochs=1, client_id=i+1, lr=self.lr) for i in range(num_clients)]

    def federated_learning(self, num_rounds):
        with tqdm(range(num_rounds), desc='Rounds') as pbar:
            for round in pbar:
                global_weights = []
                global_batches = 0
                for client in self.clients:
                    client_weights, num_batches = client.train()
                    global_weights.append((client_weights, num_batches))
                    global_batches += num_batches

                # federated averaging
                new_state_dict = {}
                for key in self.model.state_dict().keys():
                    global_sum = sum([client_weights[key] * num_batches for client_weights, num_batches in global_weights])
                    new_state_dict[key] = global_sum / global_batches
                self.model.load_state_dict(new_state_dict)

                # validate the model
                val_loss = self.validate()

                # Update progress bar description
                pbar.set_description(f"Round: {round+1}, Validation Loss: {val_loss:.4f}")
                pbar.update()

                self.writer.add_scalar('Loss/validation', val_loss, round)

                # update client models
                for client in self.clients:
                    client.model.load_state_dict(new_state_dict)

                # early stopping
                if self.early_stopping:
                    if val_loss < self.best_val_loss:
                        self.best_val_loss = val_loss
                        self.wait = 0
                    else:
                        self.wait += 1
                    if self.wait >= self.patience:
                        print("Early stopping!")
                        break


    def validate(self):
        val_loader = DataLoader(self.val_dataset, batch_size=32, shuffle=False)
        total_loss = 0
        with torch.no_grad():
            for X, y in val_loader:
                pred = self.model(X)
                loss = self.loss_fn(pred, y)
                total_loss += loss.item() * len(X)
        return total_loss / len(self.val_dataset)



In [51]:
!rm -rf runs/*

In [54]:
import torch
from torch import nn, optim
from torchvision import datasets, transforms
from torch.utils.data import random_split
from torch.utils.data import DataLoader
from copy import deepcopy

# # Define the CNN architecture
# class CNN(nn.Module):
#     def __init__(self):
#         super(CNN, self).__init__()
#         self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
#         self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
#         self.fc1 = nn.Linear(7*7*64, 128)
#         self.fc2 = nn.Linear(128, 10)

#     def forward(self, x):
#         x = self.conv1(x)
#         x = nn.functional.relu(x)
#         x = nn.functional.max_pool2d(x, 2)
#         x = self.conv2(x)
#         x = nn.functional.relu(x)
#         x = nn.functional.max_pool2d(x, 2)
#         x = x.view(x.size(0), -1) # Flatten layer
#         x = self.fc1(x)
#         x = nn.functional.relu(x)
#         x = self.fc2(x)
#         return x

# Define the model class
class SimpleClassifier(nn.Module):
    def __init__(self):
        super(SimpleClassifier, self).__init__()
        self.fc1 = nn.Linear(784, 128)  # Input layer to hidden layer
        self.fc2 = nn.Linear(128, 10)   # Hidden layer to output layer

    def forward(self, x):
        x = x.view(x.size(0), -1)  # Flatten the input
        x = torch.relu(self.fc1(x))  # Apply ReLU activation to hidden layer
        x = self.fc2(x)  # Output layer
        return x

# Load MNIST dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) # Normalizing with MNIST mean and std
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

# Initialize the server
model = SimpleClassifier()
loss_fn = nn.CrossEntropyLoss()
# optimizer = optim.SGD(model.parameters(), lr=0.001)
server = Server(model, dataset, loss_fn, lr=1e-3)

# Create the clients
num_clients = 5
num_samples_per_client = len(dataset) // num_clients
client_datasets = random_split(dataset, [num_samples_per_client]*num_clients)
print([len(cd) for cd in client_datasets])
server.create_clients(num_clients)

for i, client in enumerate(server.clients):
    client.dataloader = DataLoader(client_datasets[i], batch_size=32, shuffle=True)

# Run federated learning
server.federated_learning(num_rounds=100)


[12000, 12000, 12000, 12000, 12000]


Round: 100, Validation Loss: 0.2066: 100%|██████████| 100/100 [07:12<00:00,  4.32s/it]


In [48]:
# Set random seed for reproducibility
torch.manual_seed(42)

# Load the MNIST dataset
train_dataset = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor())

# Define the data loaders
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

# Create an instance of the model
model = SimpleClassifier()

# Define the loss criterion and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    for inputs, labels in train_loader:
        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)

        # Compute the loss
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

    # Print the training loss for each epoch
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}")

# Evaluation on the test set
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f"Test Accuracy: {accuracy}%")

Epoch [1/10], Loss: 0.5492603778839111
Epoch [2/10], Loss: 0.3300962448120117
Epoch [3/10], Loss: 0.3385804295539856
Epoch [4/10], Loss: 0.1776585876941681
Epoch [5/10], Loss: 0.4287305176258087
Epoch [6/10], Loss: 0.19159863889217377


KeyboardInterrupt: 