In [1]:
import copy
import json
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

# FEMNIST dataset definition
class FEMNISTDataset(Dataset):
    """Flat dataset over every user's samples in one LEAF-style FEMNIST JSON file.

    All users found under ``raw_data['user_data']`` are pooled into a single
    list, so per-user grouping is not preserved.

    Args:
        data_path: path to a JSON file whose ``'user_data'`` value maps each
            user to a dict holding ``'x'`` (flattened 784-value images) and
            ``'y'`` (labels), aligned by position.
    """

    def __init__(self, data_path):
        self.data = []
        self.labels = []

        # Read the whole JSON shard into memory at once.
        with open(data_path, 'r') as handle:
            parsed = json.load(handle)

        # Concatenate every user's images and labels, keeping them aligned.
        for record in parsed['user_data'].values():
            self.data.extend(record['x'])
            self.labels.extend(record['y'])

        # ToTensor adds a leading channel axis (it does not rescale float
        # arrays); Normalize then maps v -> (v - 0.5) / 0.5, which sends
        # [0, 1] to [-1, 1] -- assumes the stored pixels lie in [0, 1], TODO confirm.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ])

    def __len__(self):
        """Number of pooled samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(transformed 1x28x28 image tensor, label)`` for sample ``idx``."""
        pixels = np.asarray(self.data[idx], dtype=np.float32)
        return self.transform(pixels.reshape(28, 28)), self.labels[idx]

# 간단한 CNN 모델 정의
def SimpleCNN():
    return nn.Sequential(
        nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2),
        nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2),
        nn.Flatten(),
        nn.Linear(64 * 7 * 7, 128),
        nn.ReLU(),
        nn.Linear(128, 62)
    )

# Federated SGD Implementation
def federated_sgd(global_model, client_loaders, num_rounds, client_fraction, lr):
    """Train ``global_model`` in place with Federated SGD (FedSGD).

    Each round samples a fraction of the clients without replacement. Every
    sampled client copies the current global weights and computes the gradient
    of the cross-entropy loss on ONE mini-batch of its local data. The server
    averages these gradients (unweighted: one vote per contributing client) and
    takes a single gradient-descent step on the global model.

    Args:
        global_model: ``nn.Module`` that is updated in place. Any architecture
            works; clients receive a deep copy of it.
        client_loaders: sequence of per-client iterables yielding
            ``(images, labels)`` batches (e.g. ``DataLoader`` objects).
        num_rounds: number of communication rounds.
        client_fraction: fraction of clients sampled per round. At least one
            client, and never more than exist, is sampled.
        lr: learning rate of the server-side gradient step.
    """
    num_clients = len(client_loaders)
    # int() truncates, so a small fraction could otherwise sample zero clients
    # and leave nothing to average; clamp into [1, num_clients].
    num_sampled = min(num_clients, max(1, int(client_fraction * num_clients)))
    criterion = nn.CrossEntropyLoss()

    for round_num in range(num_rounds):
        print(f"\n=== Round {round_num + 1}/{num_rounds} - FedSGD ===")

        # 2. Client sampling
        sampled_clients = np.random.choice(num_clients, num_sampled, replace=False)
        print(f"Sampled Clients: {sampled_clients}")

        global_gradient = None
        num_contributors = 0

        # 3. Local learning: each client computes a gradient at the global weights
        for client_idx in sampled_clients:
            model = copy.deepcopy(global_model)
            model.train()
            model.zero_grad()

            # Only the first mini-batch is used (FedSGD with a local batch).
            batch = next(iter(client_loaders[client_idx]), None)
            if batch is None:
                continue  # client has no data, so it contributes nothing
            images, labels = batch
            loss = criterion(model(images), labels)
            loss.backward()

            grads = [
                p.grad.detach().clone() if p.grad is not None else torch.zeros_like(p)
                for p in model.parameters()
            ]
            if global_gradient is None:
                global_gradient = grads
            else:
                for acc, grad in zip(global_gradient, grads):
                    acc += grad
            num_contributors += 1

        if global_gradient is None:
            print("No sampled client produced a gradient; skipping update.")
            continue

        # 4. Update the global parameters with the averaged gradient
        with torch.no_grad():
            for param, grad in zip(global_model.parameters(), global_gradient):
                param -= lr * (grad / num_contributors)

# Federated Averaging Implementation
def federated_avg(global_model, client_loaders, num_rounds, client_fraction, num_epochs, lr):
    """Train ``global_model`` in place with Federated Averaging (FedAvg).

    Each round samples a fraction of the clients without replacement. Every
    sampled client starts from a copy of the global weights and runs
    ``num_epochs`` passes of plain SGD over its local data. The server then
    replaces the global weights with the average of the client weights,
    weighted by each client's number of local samples (n_k / n).

    Args:
        global_model: ``nn.Module`` that is updated in place. Any architecture
            works; clients receive a deep copy of it.
        client_loaders: sequence of per-client iterables yielding
            ``(images, labels)`` batches (e.g. ``DataLoader`` objects).
        num_rounds: number of communication rounds.
        client_fraction: fraction of clients sampled per round. At least one
            client, and never more than exist, is sampled.
        num_epochs: local epochs per client per round.
        lr: local SGD learning rate.
    """
    num_clients = len(client_loaders)
    # int() truncates, so a small fraction could otherwise sample zero clients
    # and divide by zero during aggregation; clamp into [1, num_clients].
    num_sampled = min(num_clients, max(1, int(client_fraction * num_clients)))
    criterion = nn.CrossEntropyLoss()

    for round_num in range(num_rounds):
        print(f"\n=== Round {round_num + 1}/{num_rounds} - FedAvg ===")

        # 2. Client sampling
        sampled_clients = np.random.choice(num_clients, num_sampled, replace=False)
        print(f"Sampled Clients: {sampled_clients}")

        client_states = []
        client_sizes = []

        # 3. Local learning: several epochs of SGD on each client's data
        for client_idx in sampled_clients:
            client_loader = client_loaders[client_idx]
            model = copy.deepcopy(global_model)
            optimizer = optim.SGD(model.parameters(), lr=lr)

            model.train()
            num_samples = 0
            for epoch in range(num_epochs):
                for images, labels in client_loader:
                    optimizer.zero_grad()
                    loss = criterion(model(images), labels)
                    loss.backward()
                    optimizer.step()
                    if epoch == 0:  # count the local dataset size once
                        num_samples += labels.size(0)

            if num_samples == 0:
                # No local data (or num_epochs == 0): the client trained on
                # nothing, so it gets zero weight in the average.
                continue
            client_states.append({k: v.detach().clone() for k, v in model.state_dict().items()})
            client_sizes.append(num_samples)

        if not client_states:
            print("No sampled client produced an update; skipping aggregation.")
            continue

        # 4. Update the global parameters with a sample-weighted average
        total = sum(client_sizes)
        new_state = {}
        for key, current in global_model.state_dict().items():
            # Accumulate in float64, then cast back so integer buffers keep their dtype.
            averaged = sum(
                (size / total) * state[key].double()
                for size, state in zip(client_sizes, client_states)
            )
            new_state[key] = averaged.to(current.dtype)
        global_model.load_state_dict(new_state)

# Test-set evaluation function
def evaluate_model(model, test_loader):
    """Evaluate ``model`` on ``test_loader``, print accuracy statistics and return them.

    ``Test Accuracy`` is the accuracy over every sample yielded by the loader.
    The "client" statistics (mean / median / max) are computed over the
    accuracy of each mini-batch. The loader carries no user ids, so these are
    batch-level figures, not true per-user accuracies.

    Args:
        model: classifier whose output has class scores along dim 1.
        test_loader: iterable of ``(images, labels)`` batches.

    Returns:
        dict of float percentages with keys ``"overall"``, ``"mean"``,
        ``"median"`` and ``"max"``; ``None`` when the loader yields no samples.
    """
    model.eval()
    total, correct = 0, 0
    batch_accuracies = []

    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            batch_size = labels.size(0)
            if batch_size == 0:
                continue
            batch_correct = (predicted == labels).sum().item()
            total += batch_size
            correct += batch_correct
            batch_accuracies.append(100 * batch_correct / batch_size)

    if total == 0:
        print("Test Accuracy: no test samples")
        return None

    overall_accuracy = 100 * correct / total
    mean_accuracy = float(np.mean(batch_accuracies))
    median_accuracy = float(np.median(batch_accuracies))
    max_accuracy = float(np.max(batch_accuracies))

    # Report accuracy over the whole test set, not just the last batch.
    print(f"Test Accuracy: {overall_accuracy:.2f}%")
    print(f"Mean Client Accuracy: {mean_accuracy:.2f}%")
    print(f"Median Client Accuracy: {median_accuracy:.2f}%")
    print(f"Max Client Accuracy: {max_accuracy:.2f}%")

    return {
        "overall": overall_accuracy,
        "mean": mean_accuracy,
        "median": median_accuracy,
        "max": max_accuracy,
    }
    
# Main
if __name__ == "__main__":
    # Seed Python, NumPy (client sampling) and PyTorch (weight init and
    # DataLoader shuffling) so that runs are reproducible.
    seed = 0
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Per-client training data paths: one JSON shard per simulated client.
    data_root = 'C:/Users/ahsld/leaf/data/femnist/data'
    client_data_paths = [
        f'{data_root}/train/all_data_{i}_niid_05_keep_5_train_9.json'
        for i in range(5)
    ]

    # One shuffled DataLoader per client.
    client_loaders = [
        DataLoader(FEMNISTDataset(path), batch_size=32, shuffle=True)
        for path in client_data_paths
    ]

    # Test data path and loader.
    test_data_path = f'{data_root}/test/all_data_0_niid_05_keep_5_test_9.json'
    test_loader = DataLoader(FEMNISTDataset(test_data_path), batch_size=32)

    # Initialize the global models. Both start from identical weights, so any
    # difference in final accuracy comes from the algorithm, not the init.
    SGD_Global = SimpleCNN()
    AVG_Global = SimpleCNN()
    AVG_Global.load_state_dict(SGD_Global.state_dict())

    # Federated SGD
    print("Starting Federated SGD")
    federated_sgd(SGD_Global, client_loaders, num_rounds=25, client_fraction=0.75, lr=0.01)

    # Federated Averaging
    print("\nStarting Federated Averaging")
    federated_avg(AVG_Global, client_loaders, num_rounds=25, client_fraction=0.75, num_epochs=5, lr=0.01)

    # Model evaluation
    print("\nEvaluating Federated SGD Model")
    evaluate_model(SGD_Global, test_loader)

    print("\nEvaluating Federated AVG Model")
    evaluate_model(AVG_Global, test_loader)


Starting Federated SGD

=== Round 1/25 - FedSGD ===
Sampled Clients: [3 2 1]

=== Round 2/25 - FedSGD ===
Sampled Clients: [1 2 0]

=== Round 3/25 - FedSGD ===
Sampled Clients: [4 3 0]

=== Round 4/25 - FedSGD ===
Sampled Clients: [3 2 1]

=== Round 5/25 - FedSGD ===
Sampled Clients: [3 1 2]

=== Round 6/25 - FedSGD ===
Sampled Clients: [4 0 2]

=== Round 7/25 - FedSGD ===
Sampled Clients: [2 1 4]

=== Round 8/25 - FedSGD ===
Sampled Clients: [1 0 4]

=== Round 9/25 - FedSGD ===
Sampled Clients: [0 4 3]

=== Round 10/25 - FedSGD ===
Sampled Clients: [2 1 0]

=== Round 11/25 - FedSGD ===
Sampled Clients: [1 0 4]

=== Round 12/25 - FedSGD ===
Sampled Clients: [4 0 1]

=== Round 13/25 - FedSGD ===
Sampled Clients: [4 0 2]

=== Round 14/25 - FedSGD ===
Sampled Clients: [2 0 3]

=== Round 15/25 - FedSGD ===
Sampled Clients: [3 4 2]

=== Round 16/25 - FedSGD ===
Sampled Clients: [0 2 3]

=== Round 17/25 - FedSGD ===
Sampled Clients: [3 1 0]

=== Round 18/25 - FedSGD ===
Sampled Clients: [0 3