In [None]:
# 라이브러리 불러오기
import numpy as np
import tensorflow as tf
from tensorflow.keras.datasets import mnist
import matplotlib.pyplot as plt
from tensorflow.keras.callbacks import EarlyStopping

# Load the MNIST handwritten-digit dataset as (images, labels) train/test pairs.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Convert pixel values from 0..255 to float32 in [0, 1], then add the trailing
# single-channel axis the Conv2D layers expect: (N, 28, 28) -> (N, 28, 28, 1).
x_train, x_test = [
    (images.astype('float32') / 255.0).reshape((-1, 28, 28, 1))
    for images in (x_train, x_test)
]

# CNN digit classifier: two (5x5 conv -> 2x2 max-pool) stages, a 512-unit ReLU
# dense layer, and a 10-way softmax output. No dropout layer is used.
model = tf.keras.Sequential()
model.add(tf.keras.layers.Input(shape=(28, 28, 1)))
model.add(tf.keras.layers.Conv2D(32, (5, 5), activation='relu'))   # 32 filters
model.add(tf.keras.layers.MaxPooling2D((2, 2)))
model.add(tf.keras.layers.Conv2D(64, (5, 5), activation='relu'))   # 64 filters
model.add(tf.keras.layers.MaxPooling2D((2, 2)))
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(512, activation='relu'))
model.add(tf.keras.layers.Dense(10, activation='softmax'))

# Integer class labels, hence the sparse variant of categorical cross-entropy.
model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

# Partitioning
num_clients = 100
B = 10                    # local minibatch size passed to client_update
partition_type = "IID"    # "IID" or "Non-IID"


def _iid_partition(num_samples, num_clients):
    """Split a random permutation of ``range(num_samples)`` into equal client shares.

    Each client receives ``num_samples // num_clients`` distinct indices; a
    remainder that does not divide evenly is left unassigned.

    Raises:
        ValueError: if ``num_clients`` is not positive.
    """
    if num_clients <= 0:
        raise ValueError(f"num_clients must be positive, got {num_clients}")
    per_client = num_samples // num_clients
    shuffled = np.random.permutation(num_samples)
    return [shuffled[c * per_client:(c + 1) * per_client] for c in range(num_clients)]


def _non_iid_partition(labels, num_clients, shards_per_client=2):
    """Label-skewed split: sort by label, cut into equal shards, deal shards out.

    Sample indices are sorted by label and cut into
    ``num_clients * shards_per_client`` equal contiguous shards. The shards are
    shuffled and each client receives ``shards_per_client`` of them, so most
    clients only see a couple of distinct labels. Samples that do not fill a
    whole shard are dropped.

    Raises:
        ValueError: if there are no shards or too few samples for one per shard.
    """
    labels = np.asarray(labels)
    num_shards = num_clients * shards_per_client
    if num_shards <= 0:
        raise ValueError("num_clients and shards_per_client must be positive")
    shard_size = len(labels) // num_shards
    if shard_size == 0:
        raise ValueError(f"{len(labels)} samples are too few for {num_shards} shards")
    # Stable sort keeps the original order within each label.
    by_label = np.argsort(labels, kind="stable")[:num_shards * shard_size]
    shards = by_label.reshape(num_shards, shard_size)
    # Row c of shard_ids lists the shards dealt to client c.
    shard_ids = np.random.permutation(num_shards).reshape(num_clients, shards_per_client)
    return [shards[ids].ravel() for ids in shard_ids]


if partition_type == "IID":
    client_indices_list = _iid_partition(len(x_train), num_clients)
elif partition_type == "Non-IID":
    client_indices_list = _non_iid_partition(y_train, num_clients)
else:
    raise ValueError(f'partition_type must be "IID" or "Non-IID", got {partition_type!r}')

# Show which labels each client was assigned.
for client_idx, indices in enumerate(client_indices_list):
    client_labels = y_train[indices]
    unique_labels = np.unique(client_labels)

    print(f"Client {client_idx + 1} - Labels: {unique_labels}")




# Define client update function using model.fit
def client_update(client_x, client_y, B, E):
    """Train the module-level ``model`` in place on one client's local data.

    The model is recompiled with ``optimizer='adam'`` before fitting. Recompiling
    does not reset the layer weights, so training continues from whatever
    weights ``model`` currently holds.

    Args:
        client_x: input images belonging to this client.
        client_y: integer labels for ``client_x``.
        B: minibatch size.
        E: maximum number of local epochs. An EarlyStopping callback ends
            training sooner once training accuracy has not improved for 3 epochs.

    Returns:
        The training accuracy recorded for the last epoch that actually ran.
    """
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

    stop_on_plateau = EarlyStopping(monitor='accuracy', patience=3, mode='max', verbose=1)
    fit_result = model.fit(
        client_x,
        client_y,
        batch_size=B,
        epochs=E,
        callbacks=[stop_on_plateau],
        verbose=1,
    )

    per_epoch_accuracy = fit_result.history['accuracy']
    return per_epoch_accuracy[-1]

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
Client 1 - Labels: [0 1 2 3 4 5 6 7 8 9]
Client 2 - Labels: [0 1 2 3 4 5 6 7 8 9]
Client 3 - Labels: [0 1 2 3 4 5 6 7 8 9]
Client 4 - Labels: [0 1 2 3 4 5 6 7 8 9]
Client 5 - Labels: [0 1 2 3 4 5 6 7 8 9]
Client 6 - Labels: [0 1 2 3 4 5 6 7 8 9]
Client 7 - Labels: [0 1 2 3 4 5 6 7 8 9]
Client 8 - Labels: [0 1 2 3 4 5 6 7 8 9]
Client 9 - Labels: [0 1 2 3 4 5 6 7 8 9]
Client 10 - Labels: [0 1 2 3 4 5 6 7 8 9]
Client 11 - Labels: [0 1 2 3 4 5 6 7 8 9]
Client 12 - Labels: [0 1 2 3 4 5 6 7 8 9]
Client 13 - Labels: [0 1 2 3 4 5 6 7 8 9]
Client 14 - Labels: [0 1 2 3 4 5 6 7 8 9]
Client 15 - Labels: [0 1 2 3 4 5 6 7 8 9]
Client 16 - Labels: [0 1 2 3 4 5 6 7 8 9]
Client 17 - Labels: [0 1 2 3 4 5 6 7 8 9]
Client 18 - Labels: [0 1 2 3 4 5 6 7 8 9]
Client 19 - Labels: [0 1 2 3 4 5 6 7 8 9]
Client 20 - Labels: [0 1 2 3 4 5 6 7 8 9]
Client 21 - Labels: [0 1 2 3 4 5 6 7 8 9]
Client 22 - Labels: [0 1 2 3 4 5 6 

In [None]:
# Training loop (Federated Averaging)

# Constants
num_rounds = 400
E = 20                  # maximum local epochs per client per round
# NOTE(review): learning_rate is not referenced anywhere below, and client_update
# compiles with optimizer='adam' at its default rate. Wire it in or remove it.
learning_rate = 2.1544  # 10^(1/3)
C = 0.1                 # fraction of clients sampled each round

# Initial global weights
w = model.get_weights()

# Round numbers and the matching test accuracy, kept for plotting
rounds_list = []
test_acc_list = []

# Sample from the partitions that actually exist, at least one client per round.
num_available_clients = len(client_indices_list)
m = max(int(C * num_available_clients), 1)

# Federated averaging training loop
for t in range(num_rounds):
    selected_clients = np.random.choice(num_available_clients, m, replace=False)

    client_updates = []  # trained weights of each selected client
    client_sizes = []    # number of local samples per selected client (FedAvg n_k)
    for client_idx in selected_clients:
        client_x = x_train[client_indices_list[client_idx]]
        client_y = y_train[client_indices_list[client_idx]]

        # Every client starts from the current global model, not from the
        # weights left behind by the previous client's local training.
        model.set_weights(w)
        with tf.device('GPU'):
            client_acc = client_update(client_x, client_y, B, E)
        # EarlyStopping inside client_update may end training before E epochs.
        print(f"Client {client_idx + 1} - Accuracy after up to {E} epochs: {client_acc:.4f}")

        # Keep this client's *trained* weights. Appending the global `w` would
        # make the average below a no-op.
        client_updates.append(model.get_weights())
        client_sizes.append(len(client_y))

    # Per-layer average of the client weights, weighted by local sample count,
    # cast back to each layer's original dtype.
    averaged_weights = [
        np.average(np.stack(layer_weights), axis=0, weights=client_sizes).astype(layer_weights[0].dtype)
        for layer_weights in zip(*client_updates)
    ]
    w = averaged_weights
    model.set_weights(w)

    # Evaluate the aggregated global model on the test set
    test_loss, test_acc = model.evaluate(x_test, y_test, verbose=0)

    # Store results
    rounds_list.append(t + 1)
    test_acc_list.append(test_acc)

    print(f"Round {t + 1} - Test accuracy: {test_acc:.4f}")

# Plot the test accuracy over rounds
plt.figure(figsize=(10, 6))
plt.plot(rounds_list, test_acc_list, marker='o')
plt.title("Test Accuracy over Rounds")
plt.xlabel("Round")
plt.ylabel("Test Accuracy")
plt.grid(True)
plt.show()

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 12: early stopping
Client 95 - Accuracy after 20 epochs: 1.0000
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 9: early stopping
Client 72 - Accuracy after 20 epochs: 0.9917
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 7: early stopping
Client 63 - Accuracy after 20 epochs: 1.0000
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 6: early stopping
Client 80 - Accuracy after 20 epochs: 1.0000
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 9: early stopping
Client 70 - Accuracy after 20 epochs: 1.0000
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 7: early stopping
Client 68 - Accuracy after 20 epochs: 1.0000
Epoch 1/20
Epoch 2/20
Epoch 3/20