## 1. 데이터 준비

In [1]:
%matplotlib inline
from IPython.display import clear_output
import os
from copy import deepcopy
import pickle

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import datetime
import torch.nn as nn

from models import ConvNet
from helper import ExperimentLogger, display_train_stats
from fl_devices import Server, Client
from data_utils import generate_server_idcs, CustomSubset, split_noniid
from torchvision.models import resnet18
import seaborn as sns

torch.manual_seed(42)
np.random.seed(42)

In [2]:
LOCAL_EPOCHS = 50
EPS_1 = 0.4
EPS_2 = 1.6
N_CLIENTS = 100

In [3]:
data = datasets.CIFAR10(root="CIFAR10/", download=False)
idcs = np.random.permutation(len(data))

def acc_test(server, clients, client_accs, global_accs):
    acc_clients = [client.evaluate() for client in clients]
    client_acc = round(sum(acc_clients) / len(acc_clients), 3)
    client_accs.append(client_acc)

    accuracies = [server.evaluate_distil(client.model) for client in clients]
    global_acc = round(np.mean(accuracies), 3)
    global_accs.append(global_acc)
    
    return client_accs, global_accs

def cluster(server, clients):
    label_predicted = pd.DataFrame()
    for i, client in enumerate(clients):
        acc, pred, soft_sum, diff = server.evaluate(client.model)
        # print(f'pred: {pred}, acc: {acc}, diff:{diff}')
        label_predicted = pd.concat([label_predicted, pd.DataFrame(pred, index=[i])])
    label_predicted.reset_index(drop=True, inplace=True)
    label_predicted.fillna(0, inplace=True)
    
    print(f'predicted label')
    print(label_predicted)
    c1, c2, c3 = server.cluster_clients_GMM(label_predicted)
    # PCA 후 cluster index를 label로 활용해 시각화
    return [c1, c2, c3]

In [4]:
def do_epoch_experiments(total_client_data, distill_data, ALPHA):
    data_per_class=int(distill_data//10)
    train_idcs, test_idcs = idcs[:total_client_data], idcs[total_client_data:(total_client_data + int(distill_data * 2))]
    train_labels = data.targets
    test_labels = data.targets
    
    server_idcs = generate_server_idcs(test_idcs, test_labels, int(distill_data//10))

    client_idcs = split_noniid(train_idcs, train_labels, alpha=ALPHA, n_clients=N_CLIENTS)
    client_data = [CustomSubset(data, idcs) for idcs in client_idcs]
    test_data = CustomSubset(data, server_idcs, transforms.Compose([transforms.ToTensor()]))

    for i, client_datum in enumerate(client_data):
        client_datum.subset_transform = transforms.Compose([transforms.ToTensor()])

    server = Server(resnet18, lambda x : torch.optim.Adam(x, weight_decay=0.98),test_data)

    clients = [Client(resnet18, lambda x: torch.optim.Adam(x, weight_decay=0.98), dat, i) 
               for i, dat in enumerate(client_data)]

    client_accs = []
    global_accs = []
    client_logits = []
    
    # 1.Local training
    for epoch in range(LOCAL_EPOCHS):
        if epoch % 20 == 0:
            print(f'round {epoch}')
        for i, client in enumerate(clients):
            client.compute_weight_update(epochs=1)

        if epoch % 10 == 0:
            client_accs, global_accs = acc_test(server, clients, client_accs, global_accs)
            print(f'client_acc: {client_accs[-1]}, global_acc: {global_accs[-1]}')

    # 2.Clustering
    cluster(server, clients)
        
    # 3.Get cluster, global loigt
    for i, client in enumerate(clients):
        if i == 0:
            distill_data = server.get_clients_logit(client.model,data_per_class=data_per_class)
            client_logits.append(distill_data[2])
        else:
            client_logits.append(server.get_clients_logit(client.model,data_per_class=data_per_class)[2])
    global_logits = server.get_global_logits(client_logits)
    
    # 3.Distillation
    for i, client in enumerate(clients):
        if i % 10 == 0:
            print(f'client {i} distill')
        client.distill((distill_data[0], distill_data[1], global_logits))
    
    client_accs, global_accs = acc_test(server, clients, client_accs, global_accs)
    
    print(f'total_client_data: {total_client_data}, data_per_class: {data_per_class}')
    print(f'first acc: {client_accs[0]}, {global_accs[0]}')
    print(f'acc before distill: {client_accs[-2]}, {global_accs[-2]}')
    print(f'last acc: {client_accs[-1]}, {global_accs[-1]}')
    return client_accs, global_accs

In [None]:
# Create a multi-index for the columns
now = datetime.datetime.now()
date_time = now.strftime("%m%d_%H%M")

columns = pd.MultiIndex.from_product([['client_accs', 'global_accs'], ['before_distill', 'after_distill']],
                                     names=['acc_type', 'distill_state'])

client_data_values = [10000, 20000, 40000]
distill_data_values = [1000, 2000, 5000]

index = pd.MultiIndex.from_product([client_data_values, distill_data_values], names=['client_data', 'distill_data'])

# Initialize an empty DataFrame with the desired MultiIndex for rows and columns
df = pd.DataFrame(np.nan, index=index, columns=columns)

for i, client_data in enumerate(client_data_values):
    for j, distill_data in enumerate(distill_data_values):
        client_accs, global_accs = do_epoch_experiments(total_client_data=client_data, distill_data=distill_data, ALPHA=1)

        # Set the values in the DataFrame
        df.loc[(client_data, distill_data), ('client_accs', 'before_distill')] = client_accs[-2]
        df.loc[(client_data, distill_data), ('client_accs', 'after_distill')] = client_accs[-1]
        df.loc[(client_data, distill_data), ('global_accs', 'before_distill')] = global_accs[-2]
        df.loc[(client_data, distill_data), ('global_accs', 'after_distill')] = global_accs[-1]

        file_name = f'results/global_distill/CIFAR_unlabeled_{date_time}.csv'
        df.to_csv(file_name)


round 0
client_acc: 0.282, global_acc: 0.132
client_acc: 0.423, global_acc: 0.2
round 20
client_acc: 0.423, global_acc: 0.202
client_acc: 0.425, global_acc: 0.2
round 40
client_acc: 0.428, global_acc: 0.197
predicted label
      0    3    2    4    6    5    1    7    8   9
0   277  365   96  242   18    2    0    0    0   0
1    43  614   30    0    0  220   93    0    0   0
2    79   77  267   21    2  265   64  225    0   0
3    19  105   45   36  128   14  554   27   72   0
4     0   20    0  298   28  118  430  106    0   0
..  ...  ...  ...  ...  ...  ...  ...  ...  ...  ..
95  398   12   17    0  573    0    0    0    0   0
96    3   37  163    0   14    0    0    0  754  29
97   74    8  202    0  121    2  526   67    0   0
98   11    0    0    0   47   10   18  840   69   5
99   30    2    3    2  415    5  254    0  208  81

[100 rows x 10 columns]
[ 1  3  4  7 21 26 34 42 44 49 60 77 78 80 81 84 86 87 88 97 99] [ 0  2  6  8 10 14 17 18 22 23 24 25 27 28 29 31 33 35 38 39 40

In [None]:
df

In [None]:
df = pd.read_csv('results/global_distill/CIFAR_0720_0435.csv', index_col=[0,1], header=[0,1])

# 그릴 데이터와 제목을 리스트로 저장
heatmap_data = [('client_accs', 'change_after_distill', 'Client Accuracy change after Distillation'),
                ('global_accs', 'change_after_distill', 'Global Accuracy change after Distillation')]

# Compute change in accuracy
df[('client_accs', 'change_after_distill')] = df[('client_accs', 'after_distill')] - df[('client_accs', 'before_distill')]
df[('global_accs', 'change_after_distill')] = df[('global_accs', 'after_distill')] - df[('global_accs', 'before_distill')]

# 전체 데이터의 최솟값, 최댓값 계산
vmin = min(df[data1][data2].min() for data1, data2, _ in heatmap_data)
vmax = max(df[data1][data2].max() for data1, data2, _ in heatmap_data)

for data1, data2, title in heatmap_data:
    plt.figure(figsize=(9, 5))
    sns.heatmap(df[(data1, data2)].unstack(), annot=True, cmap='coolwarm', center=0, vmin=-0.1, vmax=0.2)
    plt.title(title)
    plt.show()


## 2.Clustering 실험

In [None]:
def do_clustering_experiments(total_client_data=total_client_data, data_per_class=data_per_class, ALPHA=ALPHA):
    train_idcs, test_idcs = idcs[:int(total_client_data*10)], idcs[int(total_client_data*10):]
    train_labels = data.train_labels.numpy()
    test_labels = data.train_labels.numpy()[int(total_client_data*10):]

    client_idcs = split_noniid(train_idcs, train_labels, alpha=ALPHA, n_clients=N_CLIENTS)#, data_per_class=int(total_client_data/10))
    # server_idcs = generate_server_idcs(test_idcs, test_labels, int(total_client_data*10))

    client_data = [CustomSubset(data, idcs) for idcs in client_idcs]
    test_data = CustomSubset(data, test_idcs, transforms.Compose([transforms.ToTensor()]))
    
    for i, client_datum in enumerate(client_data):
        client_datum.subset_transform = transforms.Compose([transforms.ToTensor()])

    server = Server(resnet18, lambda x : torch.optim.SGD(x, lr=0.1, momentum=0.9),test_data)

    
    distillation_data_file = f'distillation_data_{data_per_class}_per_class.pth'

    # Check if the file already exists
    if not os.path.exists(distillation_data_file):
        # The file does not exist, generate and save the distillation data
        distillation_data = server.make_distillation_data(data_per_class=data_per_class)
        torch.save(distillation_data, distillation_data_file)

    # Load the distillation data
    distillation_data = torch.load(distillation_data_file)

    clients = [Client(resnet18, lambda x : torch.optim.SGD(x, lr=0.1, momentum=0.9), dat, i, distillation_data) 
               for i, dat in enumerate(client_data)]

    def aggregate(cluster_indices_new):
        cluster_indices = cluster_indices_new
        client_clusters = [[clients[i] for i in idcs] for idcs in cluster_indices]

        server.aggregate_clusterwise(client_clusters)

        return cluster_indices

    cfl_stats = ExperimentLogger()

    cluster_indices = [np.arange(len(clients)).astype("int")]
    client_clusters = [[clients[i] for i in idcs] for idcs in cluster_indices]


    for epoch in range(1, LOCAL_EPOCHS+1):

        if epoch == 1:
            for client in clients:
                client.synchronize_with_server(server)

        participating_clients = server.select_clients(clients, frac=1.0)

        for client in participating_clients:
            if epoch == 1:
                client.distill()

            train_stats = client.compute_weight_update(epochs=1) #train client

            if epoch == 1000:
                client.reset()

        cluster_indices_new = []

        for idc in cluster_indices:
            max_norm = server.compute_max_update_norm([clients[i] for i in idc])
            mean_norm = server.compute_mean_update_norm([clients[i] for i in idc])

            #cluster 나누는 기준
            if epoch == LOCAL_EPOCHS: #무조건 한번 나누기
                similarities = server.compute_pairwise_similarities(clients)

                server.cache_model(idc, clients[idc[0]].W, acc_clients)

                c1, c2, c3 = server.cluster_clients_GMM(similarities[idc][:,idc])
                cluster_indices_new += [c1, c2, c3]

        if epoch == 1000:
            cluster_indices = aggregate(cluster_indices_new)

        acc_clients = [client.evaluate() for client in clients]

        if epoch == LOCAL_EPOCHS: #무조건 한번 나누기
            label_accuracies = pd.DataFrame()
            label_predicted = pd.DataFrame()
            label_soft_sum = pd.DataFrame()
            label_diff = pd.DataFrame()

            for i, client in enumerate(clients):
                acc, pred, sum_, diff = server.evaluate(client.model)
                # Convert each dictionary to a DataFrame and append to the respective DataFrame
                label_accuracies = label_accuracies.append(pd.DataFrame(acc, index=[i]))
                label_predicted = label_predicted.append(pd.DataFrame(pred, index=[i]))
                label_soft_sum = label_soft_sum.append(pd.DataFrame(sum_, index=[i]))
                label_diff = label_diff.append(pd.DataFrame(diff, index=[i]))

            # Reset index for all DataFrames
            label_accuracies.reset_index(drop=True, inplace=True)
            label_predicted.reset_index(drop=True, inplace=True)
            label_soft_sum.reset_index(drop=True, inplace=True)
            label_diff.reset_index(drop=True, inplace=True)

        if epoch == 1:
            first_accuracies = pd.DataFrame()
            for i, client in enumerate(clients):
                first_acc, pred, sum_, diff = server.evaluate(client.model)
                first_accuracies = pd.concat([first_accuracies, pd.DataFrame(first_acc, index=[i])])
            first_accuracies = first_accuracies.fillna(0)

            client_acc_after_distill = sum(acc_clients)/len(acc_clients)
            global_acc_after_distill = np.mean(np.ravel(first_accuracies.values))


        elif epoch == LOCAL_EPOCHS:
            client_acc_final = sum(acc_clients)/len(acc_clients)
            global_acc_final = np.mean(np.ravel(label_accuracies.values))

        average_dw = server.get_average_dw(clients)
        #print(average_dw)
        cfl_stats.log({"acc_clients" : acc_clients, "mean_norm" : mean_norm, "max_norm" : max_norm,
                      "rounds" : epoch, "clusters" : cluster_indices, "average_dw": average_dw})


        display_train_stats(cfl_stats, EPS_1, EPS_2, LOCAL_EPOCHS)


    for idc in cluster_indices:    
        server.cache_model(idc, clients[idc[0]].W, acc_clients)
    
    client_acc_after_distill = round(client_acc_after_distill, 3)
    global_acc_after_distill = round(global_acc_after_distill, 3)
    client_acc_final = round(client_acc_final, 3)
    global_acc_final = round(global_acc_final, 3)
    
    return client_acc_after_distill, global_acc_after_distill, client_acc_final, global_acc_final

    print(client_acc_after_distill, global_acc_after_distill)
    print(client_acc_final, global_acc_final)

In [None]:
first_accuracies

In [None]:
label_accuracies.sort_index(axis=1)

In [None]:
label_soft_sum.sort_index(axis=1)

In [None]:
label_diff.sort_index(axis=1)

In [None]:
label_predicted.sort_index(axis=1)

In [None]:
from sklearn.decomposition import PCA
# Instantiate PCA
pca = PCA(n_components=2)

# Apply PCA to the dataframes
label_accuracies_pca = pca.fit_transform(label_accuracies)
label_predicted_pca = pca.fit_transform(label_predicted)
label_soft_sum_pca = pca.fit_transform(label_soft_sum)
label_diff_pca = pca.fit_transform(label_diff)
transformed_data = pca.fit_transform(similarities)

# Create labels
labels = [0, 0, 0, 1, 1, 1, 2, 2, 2]

fig, axs = plt.subplots(2, 2, figsize=(10, 10))

# Scatter plots with larger dots
dot_size = 50
axs[0, 0].scatter(label_accuracies_pca[:, 0], label_accuracies_pca[:, 1], c=labels, s=dot_size)
axs[0, 0].set_title('Label Accuracies')
axs[0, 1].scatter(label_predicted_pca[:, 0], label_predicted_pca[:, 1], c=labels, s=dot_size)
axs[0, 1].set_title('Label Predicted')
axs[1, 0].scatter(label_soft_sum_pca[:, 0], label_soft_sum_pca[:, 1], c=labels, s=dot_size)
axs[1, 0].set_title('Label Soft Sum')
axs[1, 1].scatter(label_diff_pca[:, 0], label_diff_pca[:, 1], c=labels, s=dot_size)
axs[1, 1].set_title('Label Soft Diff')

plt.show()

In [None]:
from sklearn.metrics import silhouette_score

# Calculate Silhouette Scores
silhouette_accuracies = silhouette_score(label_accuracies_pca, labels)
silhouette_predicted = silhouette_score(label_predicted_pca, labels)
silhouette_soft_sum = silhouette_score(label_soft_sum_pca, labels)
silhouette_diff = silhouette_score(label_diff_pca, labels)
silhouette_transformed_data = silhouette_score(transformed_data, labels)

print('Silhouette Score for Accuracies:', silhouette_accuracies)
print('Silhouette Score for Predicted:', silhouette_predicted)
print('Silhouette Score for Soft Sum:', silhouette_soft_sum)
print('Silhouette Score for diff:', silhouette_diff)
print('Silhouette Score for Model params:', silhouette_transformed_data)



In [None]:
#df.sort_index(axis=1)

데이터 Cluster 별 모델 파라미터 분포

In [None]:
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import numpy as np

# Fit and transform your data to 2D
pca = PCA(n_components=2)
transformed_data = pca.fit_transform(similarities)

# Assign labels based on index ranges
labels = [0, 0, 0, 1, 1, 1, 2, 2, 2]
unique_labels = np.unique(labels)
colors = plt.cm.Spectral(np.linspace(0, 0.35, len(unique_labels)))

# Plot the transformed data with labels
for label, color in zip(unique_labels, colors):
    idx = np.where(labels == label)
    plt.scatter(transformed_data[idx, 0], transformed_data[idx, 1], color=color, label=f'Cluster {label}')

# Add a legend
plt.legend()

plt.show()
