In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import numpy as np

# Simple logistic regression model
class LogisticRegressionModel(nn.Module):
    """Binary logistic regression: one linear layer followed by a sigmoid.

    Args:
        input_dim: Number of input features per sample.
    """

    def __init__(self, input_dim):
        super(LogisticRegressionModel, self).__init__()
        # A single output unit; the sigmoid in forward() maps it into (0, 1).
        self.linear = nn.Linear(in_features=input_dim, out_features=1)

    def forward(self, x):
        """Return the sigmoid of the linear layer's output for ``x``."""
        logits = self.linear(x)
        return logits.sigmoid()

# Simulate client training
def client_update(model, data, labels, epochs=5, lr=0.01):
    """Train ``model`` on one client's local data and return its weights.

    Performs ``epochs`` full-batch SGD steps using binary cross-entropy loss.
    The model that is passed in is the one that gets trained (in place), so
    whatever weights the caller loaded into it are the starting point of the
    local update.

    Args:
        model: Module to train. Its output is fed to ``nn.BCELoss`` together
            with ``labels``, so it must produce probabilities in [0, 1] with
            the same shape as ``labels``.
        data: Input tensor for this client.
        labels: Float target tensor for this client.
        epochs: Number of full-batch gradient steps to take.
        lr: Learning rate for SGD.

    Returns:
        ``model.state_dict()`` after training.
    """
    criterion = nn.BCELoss()
    optimizer = optim.SGD(model.parameters(), lr=lr)
    for _ in range(epochs):
        optimizer.zero_grad()
        loss = criterion(model(data), labels)
        loss.backward()
        optimizer.step()
    return model.state_dict()

# Server-side federated averaging
def federated_avg(models):
    """Return the element-wise mean of several client state dicts.

    Args:
        models: Non-empty sequence of state dicts. Every dict must contain the
            keys of the first one, with tensors of matching shape per key.

    Returns:
        A new dict mapping each key of ``models[0]`` to the average of that
        entry across all clients. The input dicts and their tensors are left
        unmodified.

    Raises:
        ValueError: If ``models`` is empty.
    """
    if not models:
        raise ValueError("federated_avg requires at least one client state dict")

    n_models = len(models)
    # torch.stack allocates a fresh tensor, so the sum never writes into a
    # client's own weights (which may share storage with a live model).
    return {
        key: torch.stack([state[key] for state in models]).sum(dim=0) / n_models
        for key in models[0]
    }

# Simulate 3 clients with separate data
def generate_client_data(n_clients=3, n_samples=500, input_dim=10, random_state=None):
    """Generate a synthetic binary-classification dataset split across clients.

    One dataset of ``n_clients * n_samples`` rows is created with
    ``make_classification``, standardized with a single ``StandardScaler``
    fitted on the pooled data, and then cut into ``n_clients`` consecutive
    chunks with ``np.array_split``.

    Args:
        n_clients: Number of client partitions to produce.
        n_samples: Number of samples per client.
        input_dim: Number of features per sample.
        random_state: Forwarded to ``make_classification``. The default
            ``None`` keeps the previous behaviour (a different dataset on
            every call); pass an int for reproducible data.

    Returns:
        A list with one ``(data, labels)`` tuple per client. Both are float32
        tensors; ``labels`` is reshaped to a single column.
    """
    X, y = make_classification(
        n_samples=n_clients * n_samples,
        n_features=input_dim,
        n_classes=2,
        random_state=random_state,
    )
    # NOTE: the scaler sees every client's data at once, which is a simulation
    # shortcut rather than something separate clients could do.
    X = StandardScaler().fit_transform(X)
    data_splits = np.array_split(X, n_clients)
    label_splits = np.array_split(y, n_clients)
    return [
        (
            torch.tensor(features, dtype=torch.float32),
            torch.tensor(targets, dtype=torch.float32).view(-1, 1),
        )
        for features, targets in zip(data_splits, label_splits)
    ]

# Main federated learning simulation
def main():
    """Run one round of simulated federated averaging and print the accuracy.

    Steps: generate per-client data, create the global model, hand every
    client a model loaded with the global model's initial weights, collect the
    state dicts returned by ``client_update``, average them with
    ``federated_avg``, load the result into the global model, and report its
    accuracy on the union of all client data.
    """
    clients_data = generate_client_data()
    input_dim = clients_data[0][0].shape[1]

    # Create the global model up front so all clients share one starting
    # point; averaging weights is only meaningful relative to a common
    # initialisation.
    global_model = LogisticRegressionModel(input_dim)
    initial_weights = global_model.state_dict()

    local_models = []
    for client_id, (data, labels) in enumerate(clients_data, start=1):
        print(f"Training client {client_id}")
        model = LogisticRegressionModel(input_dim)
        # load_state_dict copies values, so clients never share storage with
        # the global model or with each other.
        model.load_state_dict(initial_weights)
        local_models.append(client_update(model, data, labels))

    global_model.load_state_dict(federated_avg(local_models))

    # Evaluate on all data combined
    all_data = torch.cat([d for d, _ in clients_data])
    all_labels = torch.cat([l for _, l in clients_data])
    with torch.no_grad():
        outputs = global_model(all_data)
        # Threshold the predicted probability at 0.5 to get a hard 0/1 label.
        predictions = (outputs > 0.5).float()
        accuracy = (predictions == all_labels).float().mean()
    print(f"\nGlobal model accuracy: {accuracy.item():.4f}")

# Run the simulation only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()


Training client 1
Training client 2
Training client 3

Global model accuracy: 0.5180
