In [None]:
import os
import sys
sys.path.append('../../../')

import numpy as np
import torch
import torch.nn as nn

import argparse
import json

from copy import deepcopy

import torch.nn as nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde

from src.simulator.utils import generate_model, get_client_dataset

In [2]:
def load_config(config_path, proj_name):
    """Read the run configuration ``config_<proj_name>.json`` from ``config_path``.

    Args:
        config_path: directory containing the config file.
        proj_name: run name used to build the file name.

    Returns:
        The parsed JSON object (a dict for the configs used in this notebook).
    """
    path = os.path.join(config_path, f"config_{proj_name}.json")
    # JSON is UTF-8 by spec; without an explicit encoding, open() uses the locale codec
    # (e.g. cp1252/cp949 on Windows) and can mis-decode non-ASCII config values.
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)

# `load_model` is defined once, further down in this cell; an identical, shadowed duplicate
# definition was removed from here.

def load_model(model_name, modelPATH, config, device):
    """Load a checkpoint and rebuild the global model from it.

    Args:
        model_name: checkpoint file name inside ``modelPATH``.
        modelPATH: directory containing the checkpoint.
        config: namespace passed to ``generate_model``; ``config.agg_method`` selects the
            checkpoint layout.
        device: ``map_location`` for ``torch.load`` and the device the model is moved to.

    Returns:
        ``(global_model, local_model_dict)``. For "Center"/"Local" runs the checkpoint holds a
        single ``'model'`` state dict and ``local_model_dict`` is None; otherwise it is the
        checkpoint's ``'local_model'`` entry (presumably per-client state dicts -- TODO confirm).

    Raises:
        KeyError: if the checkpoint lacks the keys expected for ``config.agg_method``.
    """
    import warnings

    path = os.path.join(modelPATH, model_name)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    model_dict = torch.load(path, map_location=device)

    global_model = generate_model(config).to(device)

    if config.agg_method in ("Center", "Local"):
        state_dict = model_dict['model']
        local_model_dict = None
    else:
        state_dict = model_dict['global_model']
        local_model_dict = model_dict['local_model']

    # strict=False tolerates key mismatches, but a silently mismatched checkpoint leaves those
    # layers at their random initialization, so surface any missing/unexpected keys.
    incompatible = global_model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        warnings.warn(
            f"{path}: missing keys {incompatible.missing_keys}, "
            f"unexpected keys {incompatible.unexpected_keys}"
        )

    return global_model, local_model_dict

In [3]:
basePATH = "Z://Users/moonsh/AdaptFL/ckpt/"

In [4]:
model_type_dict = {"Center": ["generous-sunset-1", -1], # No need to specify epoch
                   "Local": ["good-surf-1", -1], # No need to specify epoch
                   "FedAvg" : ["polished-wood-3", -1],
                   "FedProx": ["spring-sponge-10", -1],  
                   "MOON": ["fancy-disco-2", -1]
                   }

In [5]:
agg_method = "FedAvg"
wandb_name = model_type_dict[agg_method][0]
client_idx = 8


ckptPATH = os.path.join(basePATH, agg_method, wandb_name)
config = load_config(ckptPATH, wandb_name)
config['batch_size'] = 32
config['num_workers'] = 2
config['nowandb'] = True
config = argparse.Namespace(**config)

model_name = f'{wandb_name}_best_model.pth'
glob_model = load_model(model_name, ckptPATH, config, torch.device('cpu'))
local_weight = glob_model[1][client_idx]

glob_model = glob_model[0]
# local_model = deepcopy(glob_model)
# local_model.load_state_dict(local_model, strict=False)

In [7]:
# Data root; set the ADAPTFL_DATA_DIR environment variable to override it
# instead of editing this hard-coded machine-specific default.
dataPATH = os.environ.get('ADAPTFL_DATA_DIR', 'Z://Users/moonsh/data/FLData/')

# NOTE(review): 10 is presumably the total number of clients -- confirm against get_client_dataset.
TestDataset = get_client_dataset(config, 10,
                                 'Test', verbose=False,
                                 get_info=False, PATH=dataPATH)

client_dataset = TestDataset[client_idx]
# One batch holding the client's entire test split, so a single forward pass sees every sample.
dataloader = DataLoader(client_dataset, batch_size=len(client_dataset))

full_batch_dataset = next(iter(dataloader))

In [None]:
# Start from a copy of the global model and overwrite it with this client's local weights.
# With strict=False, parameters absent from `local_weight` keep their global values; the
# returned missing/unexpected-keys report is shown as the cell output.
local_model = deepcopy(glob_model)
local_model.load_state_dict(state_dict=local_weight, strict=False)

In [9]:
# Captured forward-hook outputs, keyed first by model ('Local' / 'Global'), then by layer name.
activation = dict(Local={}, Global={})

In [10]:
def get_activation(model_name, layer_name):
    """Build a forward hook that records a layer's output.

    The returned hook stores ``output.detach()`` in the global
    ``activation[model_name][layer_name]``, overwriting any previous value.
    """
    def _record_output(_module, _inputs, layer_output):
        activation[model_name][layer_name] = layer_output.detach()

    return _record_output


def register_hooks_for_model(model, model_name):
    """Attach an activation-capturing forward hook to every Conv3d/Linear layer of ``model``.

    Layers whose qualified name contains 'downsample' are skipped. Each hook is built by
    ``get_activation(model_name, <qualified layer name>)``.

    Returns:
        The hook handles, in ``named_modules()`` order, so they can be removed later.
    """
    hooked_types = (nn.Conv3d, nn.Linear)
    return [
        module.register_forward_hook(get_activation(model_name, qual_name))
        for qual_name, module in model.named_modules()
        if isinstance(module, hooked_types) and 'downsample' not in qual_name
    ]

def remove_hooks(hooks):
    """Detach every hook whose handle appears in ``hooks``; returns None."""
    for handle in hooks:
        handle.remove()

In [11]:
# Hook both models (local first, then global); keep the handles so the hooks can be removed later.
hooks_model_a, hooks_model_b = (
    register_hooks_for_model(local_model, 'Local'),
    register_hooks_for_model(glob_model, 'Global'),
)

In [12]:
# Inference only: eval() makes normalization layers (e.g. BatchNorm, if present) use their stored
# running statistics instead of full-batch statistics (and stops updating them), and disables
# dropout; no_grad() avoids building an autograd graph over the whole batch.
local_model.eval()
glob_model.eval()
with torch.no_grad():
    loc_rep = local_model(full_batch_dataset[0])  # local-model predictions
    glob_rep = glob_model(full_batch_dataset[0])  # global-model predictions

In [None]:
# List every hooked layer with its output shape. After the loop, `layer_name`, `loc_output` and
# `glob_output` refer to the last hooked layer, which the following cells use.
for layer_name, loc_output in activation['Local'].items():
    # Pair layers by name rather than by zip position, so a mismatch between the two models
    # raises a KeyError instead of silently comparing different layers.
    glob_output = activation['Global'][layer_name]
    print(layer_name, loc_output.shape)

In [None]:
num_show_ch = 16  # channels to plot; matches the 4x4 subplot grid used for the KDE plots
# Channel count (dim 1) of `loc_output`, the last layer reached by the listing loop.
# Previously `channel` was never defined, so this cell failed on a fresh kernel.
channel = loc_output.shape[1]
# max(1, ...) keeps the range step positive when the layer has fewer than `num_show_ch` channels.
interval_ch = max(1, channel // num_show_ch)
channel_index_list = list(range(0, channel, interval_ch))[:num_show_ch]
len(channel_index_list)

In [137]:
# Alias the selected layer's outputs for plotting: local model first, global model second.
# NOTE(review): `loc_output` / `glob_output` are whatever layer the listing loop ended on.
activations, activations2 = loc_output, glob_output

In [None]:
def _plot_kde(ax, values, color, label, num_points=100):
    """Draw a Gaussian KDE of ``values`` on ``ax`` over the data's [min, max] range.

    gaussian_kde cannot fit data with zero spread (its covariance is singular), so a constant
    channel is drawn as a vertical line at its single value instead of crashing the figure.
    """
    values = np.asarray(values, dtype=np.float64).ravel()
    low, high = values.min(), values.max()
    if low == high:
        ax.axvline(low, color=color, label=label)
        return
    grid = np.linspace(low, high, num_points)
    ax.plot(grid, gaussian_kde(values)(grid), color=color, label=label)


fig, axs = plt.subplots(4, 4, figsize=(16, 16))

for ax, i in zip(axs.flat, channel_index_list):
    # `[:, i]` selects channel i for any layer rank, e.g. Conv3d (N, C, D, H, W) or Linear (N, C).
    _plot_kde(ax, activations[:, i].flatten().cpu().numpy(), color='red', label='Local')
    # Label the global curve with the aggregation method actually loaded, not a hard-coded name.
    _plot_kde(ax, activations2[:, i].flatten().cpu().numpy(), color='blue', label=agg_method)
    ax.set_title('KDE of Activations (Channel {})'.format(i))
    ax.set_xlabel('Activation Value')
    ax.set_ylabel('Probability Density')
    ax.legend()

plt.tight_layout()
plt.show()

remove_hooks(hooks_model_a)
remove_hooks(hooks_model_b)