In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np

# Global hyperparameters shared by the training and evaluation code below.
learning_rate = 0.01  # Learning rate passed to every SGD optimizer created in main()
gamma = 0.5  # Weight applied to the auxiliary-classifier (client-side) loss in train_client_model
lambda_value = 0.2  # Server-side loss is weighted by (1 - lambda_value); also scales client deltas in aggregate_models
num_clients = 50  # Number of simulated clients (one model, auxiliary classifier and data loader each)
num_global_rounds = 120  # Number of passes over all clients performed by main()
device = torch.device("cpu")  # Explicitly using CPU for demonstration

# MNIST Data loading and preparation
def load_mnist_data():
    """Load MNIST and build label-skewed (non-IID) training loaders, one per client.

    The training indices are sorted by label and cut into
    ``num_clients * num_shards_per_client`` equally sized, disjoint shards.
    Each client receives ``num_shards_per_client`` randomly chosen shards, so
    a client sees only a few digit classes and no training sample is handed
    out twice. The label histogram of every client is printed.

    Returns:
        A tuple ``(client_datasets, test_loader)`` where ``client_datasets``
        is a list of ``num_clients`` shuffled ``DataLoader`` objects
        (batch size 50) and ``test_loader`` iterates the full MNIST test set
        in order with batch size 1000.

    Raises:
        ValueError: If there are more shards than training samples.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

    num_shards_per_client = 2
    num_shards = num_clients * num_shards_per_client
    targets = np.asarray(train_dataset.targets)
    shard_size = len(targets) // num_shards
    if shard_size == 0:
        raise ValueError(
            f"Cannot split {len(targets)} samples into {num_shards} non-empty shards"
        )

    # Sorting by label makes every contiguous slice (shard) cover one label,
    # or two adjacent labels where a slice crosses a class boundary.
    # Slices are disjoint, unlike taking "the first 600 samples of a class"
    # for every shard, which gives several clients the very same samples.
    label_sorted_idx = np.argsort(targets, kind='stable')
    shards = [
        label_sorted_idx[s * shard_size:(s + 1) * shard_size]
        for s in range(num_shards)
    ]

    # Deal the shards out in random order, num_shards_per_client per client.
    shard_order = np.random.permutation(num_shards)
    client_data_idx = [
        np.concatenate([
            shards[s]
            for s in shard_order[i * num_shards_per_client:(i + 1) * num_shards_per_client]
        ])
        for i in range(num_clients)
    ]

    client_datasets = [
        DataLoader(Subset(train_dataset, indices), batch_size=50, shuffle=True)
        for indices in client_data_idx
    ]

    # Report the label distribution per client.
    for i, indices in enumerate(client_data_idx):
        labels, counts = np.unique(targets[indices], return_counts=True)
        label_distribution = {int(label): int(count) for label, count in zip(labels, counts)}
        print(f"Client {i+1}: {label_distribution}")

    test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)
    return client_datasets, test_loader

# Model Definitions
class ClientModel(nn.Module):
    """Client-side convolutional feature extractor.

    Two identical stages of 5x5 convolution (stride 1, padding 2), ReLU and
    2x2 max-pooling, taking 1 input channel to 32 and then 64 channels.
    The convolutions keep the spatial size and each pooling halves it, so
    e.g. a (N, 1, 28, 28) input produces a (N, 64, 7, 7) feature map.
    """

    @staticmethod
    def _stage(in_channels, out_channels):
        """Return the layers of one conv -> ReLU -> max-pool stage."""
        return [
            nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]

    def __init__(self):
        super().__init__()
        # Stages are built in order so the Sequential indices (0..5) and the
        # parameter initialisation order stay as before.
        layers = self._stage(1, 32) + self._stage(32, 64)
        self.conv_layers = nn.Sequential(*layers)

    def forward(self, x):
        """Return the convolutional feature map for the image batch ``x``."""
        features = self.conv_layers(x)
        return features

class ServerModel(nn.Module):
    """Server-side classifier head.

    Flattens incoming features to ``fc_input_size`` values per row and maps
    them through Linear(3136, 1024) -> ReLU -> Linear(1024, 10), returning
    10 unnormalised class scores per row.
    """

    def __init__(self):
        super().__init__()
        # Flattened feature length; must match the ClientModel output (64 * 7 * 7).
        self.fc_input_size = 64 * 7 * 7
        hidden_units = 1024
        num_classes = 10
        self.fc_layers = nn.Sequential(
            nn.Linear(self.fc_input_size, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, num_classes),
        )

    def forward(self, x):
        """Return class logits for the feature batch ``x``."""
        # The leading dimension is inferred from the number of elements.
        flattened = x.view(-1, self.fc_input_size)
        logits = self.fc_layers(flattened)
        return logits

class AuxiliaryClassifier(nn.Module):
    """Single linear layer mapping flattened 64x7x7 features to 10 class logits."""

    def __init__(self):
        super().__init__()
        flattened_features = 64 * 7 * 7
        self.classifier = nn.Linear(flattened_features, 10)

    def forward(self, x):
        """Flatten everything but the batch dimension and return the logits."""
        batch_size = x.size(0)
        flat = x.view(batch_size, -1)
        return self.classifier(flat)

# Loss function definitions
def client_side_loss(client_output, labels):
    """Return the cross-entropy between the client-side logits and ``labels``."""
    loss = F.cross_entropy(input=client_output, target=labels)
    return loss

def server_side_loss(server_output, labels):
    """Return the cross-entropy between the server-side logits and ``labels``."""
    loss = F.cross_entropy(input=server_output, target=labels)
    return loss

# Placeholder for model aggregation logic
def aggregate_models(server_model, client_models, alpha, lambda_value):
    """Blend client parameters into ``server_model`` and return it.

    For every entry ``v`` of the server state dict, each client ``i``
    contributes ``(client_i - v) * alpha[i] * lambda_value + v``; the
    contributions are averaged over the number of clients and loaded back
    into ``server_model``.

    Args:
        server_model: Module whose state dict is updated in place.
        client_models: Modules exposing the same state-dict keys as
            ``server_model``.
        alpha: Per-client weights, indexed like ``client_models``.
        lambda_value: Global scale applied to every client delta.

    Returns:
        The same ``server_model`` object, after loading the new state.
    """
    global_dict = server_model.state_dict()
    # Read each client's state once instead of once per parameter.
    client_states = [client.state_dict() for client in client_models]
    client_count = len(client_models)

    for name, server_value in global_dict.items():
        accumulated = 0
        for idx, client_state in enumerate(client_states):
            delta = client_state[name] - server_value
            accumulated = accumulated + (delta * alpha[idx] * lambda_value + server_value)
        global_dict[name] = accumulated / client_count

    server_model.load_state_dict(global_dict)
    return server_model

# New function to perform server-side training and update
def train_server_model(server_model, aggregated_outputs, labels, optimizer_server):
    """Run one optimisation step on the server-side loss.

    Args:
        server_model: Not used in the body; kept for the call signature.
        aggregated_outputs: Logits to score against ``labels``; they must be
            attached to the graph of the parameters held by
            ``optimizer_server`` for the step to have any effect.
        labels: Target class indices.
        optimizer_server: Optimizer that is zeroed and then stepped.
    """
    # Clear stale gradients before computing the new ones.
    optimizer_server.zero_grad()
    server_loss = server_side_loss(aggregated_outputs, labels)
    server_loss.backward()
    optimizer_server.step()

# Training function for a client model
def train_client_model(client_model, server_model, auxiliary_classifier, client_data_loader, optimizer_client, optimizer_aux, gamma, lambda_value):
    """Train one client's feature extractor and auxiliary head for one pass over its data.

    For each batch the loss is
    ``gamma * CE(aux(features), y) + (1 - lambda_value) * CE(server(features), y)``
    and only ``optimizer_client`` and ``optimizer_aux`` are stepped.

    Args:
        client_model: Feature extractor being trained.
        server_model: Model applied to the client features for the
            server-side loss term.
        auxiliary_classifier: Local head being trained on the client features.
        client_data_loader: Iterable of ``(data, target)`` batches.
        optimizer_client: Optimizer for ``client_model``.
        optimizer_aux: Optimizer for ``auxiliary_classifier``.
        gamma: Weight of the auxiliary (client-side) loss term.
        lambda_value: The server-side loss term is weighted by ``1 - lambda_value``.
    """
    for module in (client_model, auxiliary_classifier):
        module.train()

    local_optimizers = (optimizer_client, optimizer_aux)
    server_weight = 1 - lambda_value

    for batch_inputs, batch_labels in client_data_loader:
        batch_inputs = batch_inputs.to(device)
        batch_labels = batch_labels.to(device)

        for optimizer in local_optimizers:
            optimizer.zero_grad()

        features = client_model(batch_inputs)
        aux_logits = auxiliary_classifier(features)

        aux_term = gamma * client_side_loss(aux_logits, batch_labels)
        server_term = server_weight * server_side_loss(server_model(features), batch_labels)
        total_loss = aux_term + server_term
        # NOTE(review): backward() also accumulates gradients on server_model's
        # parameters, but nothing in this function steps or zeroes them.
        total_loss.backward()

        for optimizer in local_optimizers:
            optimizer.step()

def perform_inference(client_model, server_model, auxiliary_classifier, data, eth):
    """Classify a batch on the client or on the server, depending on entropy.

    The auxiliary classifier's softmax output is scored by its entropy,
    averaged over the whole batch. If that mean entropy is below ``eth`` the
    auxiliary probabilities are returned; otherwise the client features are
    flattened and passed to the server model.

    Args:
        client_model: Feature extractor applied to ``data``.
        server_model: Model used when the batch is not confident enough.
        auxiliary_classifier: Local head producing the early-exit prediction.
        data: Input batch.
        eth: Entropy threshold for staying on the client.

    Returns:
        ``(output, decision)`` where ``output`` holds class probabilities and
        ``decision`` is ``'client'`` or ``'server'`` for the whole batch.
    """
    # Put all three modules into evaluation mode.
    for module in (client_model, server_model, auxiliary_classifier):
        module.eval()

    with torch.no_grad():
        features = client_model(data)  # client model output
        aux_probs = F.softmax(auxiliary_classifier(features), dim=1)
        # The 1e-5 offset keeps log() finite when a probability is zero.
        per_sample_entropy = -torch.sum(aux_probs * torch.log(aux_probs + 1e-5), dim=1)
        batch_entropy = per_sample_entropy.mean()

        if batch_entropy < eth:
            # Entropy below the threshold: use the auxiliary classifier's output.
            return aux_probs, 'client'

        # Entropy at or above the threshold: forward the flattened client
        # features to the server model and use its prediction instead.
        flattened = features.view(features.size(0), -1)
        server_probs = F.softmax(server_model(flattened), dim=1)
        return server_probs, 'server'

def main():
    """Train every client for ``num_global_rounds`` rounds, then evaluate each on the test set.

    Builds one ClientModel / AuxiliaryClassifier pair per client plus a
    single shared ServerModel, trains the client pairs with SGD, and prints
    per-client test accuracy and the client/server decision ratio for every
    entropy threshold in ``eth_values``.
    """
    client_datasets, test_loader = load_mnist_data()

    def make_sgd(module):
        """Create an SGD optimizer over ``module``'s parameters."""
        return optim.SGD(module.parameters(), lr=learning_rate)

    # Construction order (clients, server, auxiliary heads) is kept as-is so
    # random initialisation draws happen in the same sequence.
    client_models = [ClientModel().to(device) for _ in range(num_clients)]
    server_model = ServerModel().to(device)
    auxiliary_classifiers = [AuxiliaryClassifier().to(device) for _ in range(num_clients)]

    optimizer_clients = [make_sgd(model) for model in client_models]
    optimizer_auxs = [make_sgd(aux) for aux in auxiliary_classifiers]
    # NOTE(review): this optimizer is created but never stepped in this
    # function, so server_model is evaluated with its initial weights.
    optimizer_server = make_sgd(server_model)

    # Other thresholds previously tried: 0.05, 0.1, 0.2, 0.4, 0.8, 1.2, 1.6, 2.3
    eth_values = [0.4]

    # Simulate global rounds of training: every client trains once per round.
    for round_number in range(1, num_global_rounds + 1):
        print(f"Global Round {round_number}/{num_global_rounds}")
        per_client = zip(client_datasets, client_models, auxiliary_classifiers, optimizer_clients, optimizer_auxs)
        for loader, model, aux, opt_client, opt_aux in per_client:
            train_client_model(model, server_model, aux, loader, opt_client, opt_aux, gamma, lambda_value)

    # Evaluate every client for each entropy threshold.
    for eth in eth_values:
        print(f"\nEvaluating with Eth = {eth}:")
        all_accuracy = []
        all_decision_ratios = []

        per_client = zip(client_datasets, client_models, auxiliary_classifiers)
        for client_number, (_, model, aux) in enumerate(per_client, start=1):
            accuracy, decision_ratio = perform_inference_and_get_accuracy(model, server_model, aux, test_loader, eth)
            all_accuracy.append(accuracy)
            all_decision_ratios.append(decision_ratio)
            print(f"Client {client_number} Accuracy: {accuracy:.2f}%")
            print(f"Client {client_number} Decision Ratio: {decision_ratio}")

def perform_inference_and_get_accuracy(client_model, server_model, auxiliary_classifier, test_loader, eth):
    """Evaluate one client on ``test_loader`` with entropy-gated inference.

    Every batch is classified via ``perform_inference``, which decides per
    batch whether the client or the server prediction is used.

    Args:
        client_model: Client feature extractor.
        server_model: Shared server model.
        auxiliary_classifier: The client's auxiliary head.
        test_loader: Iterable of ``(data, target)`` batches; may be empty.
        eth: Entropy threshold forwarded to ``perform_inference``.

    Returns:
        ``(accuracy, decision_ratio)``: ``accuracy`` is the percentage of
        correctly classified samples (0 for an empty loader) and
        ``decision_ratio`` maps ``'client'`` and ``'server'`` to the fraction
        of batches handled by each (both 0.0 for an empty loader).
    """
    client_model.eval()
    auxiliary_classifier.eval()
    correct = 0
    total = 0
    decision_counts = {'client': 0, 'server': 0}  # batches answered by each side
    with torch.no_grad():
        for data, target in test_loader:  # Evaluating on test data
            data, target = data.to(device), target.to(device)
            output, decision = perform_inference(client_model, server_model, auxiliary_classifier, data, eth)
            decision_counts[decision] += 1

            pred = output.argmax(dim=1, keepdim=True)  # Use the prediction (either from client or server)
            correct += pred.eq(target.view_as(pred)).sum().item()
            total += target.size(0)

    accuracy = 100. * correct / total if total > 0 else 0
    # Guard the ratio like the accuracy: an empty loader yields zero batches
    # and would otherwise raise ZeroDivisionError.
    num_batches = sum(decision_counts.values())
    decision_ratio = {
        key: (count / num_batches if num_batches > 0 else 0.0)
        for key, count in decision_counts.items()
    }
    return accuracy, decision_ratio

# Run the training and evaluation pipeline only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 101031097.96it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 54199415.58it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 31523275.43it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 1178140.31it/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

Client 1: {2: 600, 6: 600}
Client 2: {9: 1200}
Client 3: {6: 600, 7: 600}
Client 4: {7: 1200}
Client 5: {2: 1200}
Client 6: {0: 600, 2: 600}
Client 7: {4: 600, 5: 600}
Client 8: {2: 600, 6: 600}
Client 9: {8: 600, 4: 600}
Client 10: {1: 600, 4: 600}
Client 11: {1: 600, 3: 600}
Client 12: {8: 600, 7: 600}
Client 13: {0: 600, 5: 600}
Client 14: {9: 600, 6: 600}
Client 15: {8: 600, 4: 600}
Client 16: {0: 600, 7: 600}
Client 17: {9: 600, 5: 600}
Client 18: {0: 1200}
Client 19: {4: 1200}
Client 20: {2: 1200}
Client 21: {1: 600, 5: 600}
Client 22: {8: 600, 1: 600}
Client 23: {4: 600, 6: 600}
Client 24: {8: 600, 9: 600}
Client 25: {8: 1200}
Client 26: {0: 600, 7: 600}
Client 27: {3: 600, 7: 600}
Client 28: {8: 1200}
Client 29: {2: 600, 3: 600}
Client 30: {1: 600, 7: 600}
Client 31: {5: 600, 6: 600}
Client 32: {0: 600, 4: 600}
Client 33: {1: 600, 5: 600}
Client 34: {9: 600, 1: 600}
Client 35: {0: 600, 7: 600}
Client 36: