In [1]:
import os
import sys
sys.path.append("../../")
import argparse

import pandas as pandas
import numpy as np

import torch
import matplotlib.pyplot as plt

from copy import deepcopy

from utils import load_config, load_model
from src.simulator.utils import get_client_dataset
from cka import CKACalculator

import torch.nn as nn

## Setup: checkpoint and configuration

In [2]:
# Which aggregation method / run to analyse; together they select the checkpoint directory.
agg_method = 'Center'
proj_name = 'sunny-serenity-1'
# Checkpoint root. Override it with the ADAPTFL_CKPT_ROOT environment variable instead of
# editing this cell (e.g. 'Z://Users/moonsh/AdaptFL/ckpt' when the share is mounted on Windows).
CKPT_ROOT = os.environ.get('ADAPTFL_CKPT_ROOT', '/NFS/Users/moonsh/AdaptFL/ckpt')
ckptPATH = f'{CKPT_ROOT}/{agg_method}/{proj_name}'

# Training-run config loaded from the checkpoint directory (a mapping; it is converted to a
# Namespace below), with analysis-specific overrides applied before conversion.
config_dict = load_config(ckptPATH, proj_name)
config_dict['batch_size'] = 64
config_dict['num_workers'] = 4
config_dict['nowandb'] = True  # presumably disables Weights & Biases logging -- confirm

config = argparse.Namespace(**config_dict)

In [3]:
# Second '-'-separated token of the run name (e.g. 'serenity' for 'sunny-serenity-1');
# not referenced in the cells shown here.
main_name = proj_name.split('-')[1]
# Use the first GPU when one is available; otherwise fall back to CPU so the notebook can
# still run on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# NOTE(review): the 'Center' prefix duplicates agg_method; if checkpoints follow an
# f"{agg_method}_best_model.pth" naming convention this could be derived from it -- confirm.
model_name = 'Center_best_model.pth'
glob_model, loc_model_dict = load_model(model_name, ckptPATH, config, device)

In [4]:
# Per-client test datasets (presumably indexed by client id -- confirm against
# get_client_dataset); only client 0's split is used for the CKA probe below.
TestDataset_dict = get_client_dataset(
    config,
    config.num_clients,
    _mode='test',
    verbose=False,
    PATH=config.data_path,
    get_info=True,
)

# Small, unshuffled, single-process loader over client 0's test data, fed to both models.
temp_loader = torch.utils.data.DataLoader(
    dataset=TestDataset_dict[0],
    batch_size=4,
    shuffle=False,
    num_workers=0,
)

In [6]:
# model1 receives the experimental forward hooks registered further down. Copy it instead of
# aliasing glob_model, so those hooks never land on the global model itself. Because glob_model
# then never carries hooks, re-running this cell cannot leak them into model2 through deepcopy.
# (Costs one extra model copy on `device`.)
model1 = deepcopy(glob_model)
# Reference model for CKA: an unmodified copy of the global model. To compare against
# client 0's local model instead, load its weights:
model2 = deepcopy(glob_model)
# model2.load_state_dict(loc_model_dict[0], strict=False)

In [None]:
prev_name = "layer"  # not referenced in the cells shown here

# Print each direct child module of model1 to see which layer types it is built from.
for top_level_module in model1.children():
    print(top_level_module)

In [11]:
def apply_important_weights(output, importance_weights):
    """Scale a layer output element-wise by ``importance_weights``.

    Args:
        output: Tensor produced by a layer.
        importance_weights: Tensor (or scalar) broadcastable to ``output``.

    Returns:
        ``output * importance_weights``.
    """
    return output * importance_weights

# Forward hook used to track and rescale the outputs of selected layers.
def hook_fn(module, input, output, scale=0.5):
    """Forward hook: multiply a module's output by a constant importance weight.

    Meant for ``module.register_forward_hook``. PyTorch calls it as
    ``hook(module, input, output)``. Because it returns a tensor, that tensor replaces
    the module's output for everything downstream, including later hooks.

    Args:
        module: The hooked module (unused).
        input: The module's positional inputs (unused).
        output: The module's output tensor (assumed floating point).
        scale: Constant weight applied to every element. The default 0.5 is a
            placeholder for real per-element importance weights.

    Returns:
        ``output`` scaled element-wise by ``scale``.
    """
    # One allocation with the same shape/dtype/device as ``output``; the previous
    # ``torch.ones_like(output) * 0.5`` allocated twice. Keeping an explicit weight tensor
    # lets real per-element weights be swapped in later.
    importance_weights = torch.full_like(output, scale)
    return apply_important_weights(output, importance_weights)

In [17]:
def register_hooks(model, layer_types=(nn.Conv3d, nn.Linear), hook=None, verbose=True):
    """Attach a forward hook to every module of ``model`` whose type is in ``layer_types``.

    Args:
        model: The ``nn.Module`` to instrument. Every entry of ``model.named_modules()``
            (including ``model`` itself) is checked.
        layer_types: Class or tuple of classes passed to ``isinstance``. The default
            ``(nn.Conv3d, nn.Linear)`` hooks only Conv3d and Linear layers.
        hook: Forward-hook callable ``hook(module, input, output)``. Defaults to ``hook_fn``
            from this notebook, looked up at call time.
        verbose: If True, print the qualified name of each hooked module.

    Returns:
        List of ``RemovableHandle`` objects; call ``.remove()`` on each to detach.
        Calling this again on the same model adds a second hook to every matching layer.
    """
    if hook is None:
        hook = hook_fn
    handles = []
    for name, module in model.named_modules():
        if isinstance(module, layer_types):
            if verbose:
                print(name)
            handles.append(module.register_forward_hook(hook))
    return handles

In [None]:
# Detach hooks left over from a previous execution of this cell first. register_hooks adds a
# new hook on every call, so re-running would stack hook_fn and compound the 0.5 scaling.
for _stale_handle in globals().get("handles", []):
    _stale_handle.remove()
handles = register_hooks(model1)

In [None]:
# Display the hook handles returned by register_hooks (one per hooked Conv3d/Linear module).
handles

## Overall Analysis

In [None]:
# CKA calculator comparing model1 (carrying the hooks registered above) with model2 on the
# client-0 test batches from temp_loader; uses the calculator's default hooked layer types.
calculator = CKACalculator(
    model1=model1,
    model2=model2,
    dataloader=temp_loader,
)

In [None]:
# Compute the CKA matrix between the layers tracked for model1 and model2, then report its size.
cka_output = calculator.calculate_cka_matrix()
print(f"CKA output size: {cka_output.shape}")

In [None]:
# List the tracked layer names (module_names_X, presumably model1's side) with their matrix index.
for layer_idx, layer_name in enumerate(calculator.module_names_X):
    print(f"Layer {layer_idx}: \t{layer_name}")

In [None]:
# Plot the CKA matrix computed in the cell above. It is reused rather than recomputed:
# nothing between the two cells changes the calculator or the models, so a second
# calculate_cka_matrix() call would presumably just repeat the same forward passes.
fig, ax = plt.subplots(figsize=(6, 5))
# Fixed [0, 1] colour range (the range of CKA) so separate CKA figures are comparable.
im = ax.imshow(cka_output.cpu().numpy(), cmap='inferno', vmin=0.0, vmax=1.0)
ax.set(title="CKA: model1 vs model2 (default hooked layers)",
       xlabel="layer index", ylabel="layer index")
fig.colorbar(im, ax=ax, label="CKA")
plt.show()

## Analysis per Layer Type

In [11]:
# Layer types the restricted CKACalculator below should hook. register_hooks in this notebook
# targets nn.Conv3d, which suggests a volumetric model. A 2-D-only tuple would then match
# nothing, so include the 3-D variants while keeping the 2-D ones for 2-D backbones.
layers = (nn.Conv2d, nn.BatchNorm2d, nn.Conv3d, nn.BatchNorm3d)

In [None]:
# Reset the first calculator (presumably detaching the hooks it placed on model1/model2 --
# confirm in cka.CKACalculator), then build a new one restricted to the `layers` types.
calculator.reset()
calculator = CKACalculator(
    model1=model1,
    model2=model2,
    dataloader=temp_loader,
    hook_layer_types=layers,
)


In [None]:
# CKA restricted to the layer types in `layers`.
cka_output = calculator.calculate_cka_matrix()
# Fail with a clear message instead of an obscure imshow error if no module matched.
assert cka_output.numel() > 0, "CKA matrix is empty: no modules matched hook_layer_types"

fig, ax = plt.subplots(figsize=(6, 5))
# Fixed [0, 1] colour range (the range of CKA) so figures are comparable across runs.
im = ax.imshow(cka_output.cpu().numpy(), cmap='inferno', vmin=0.0, vmax=1.0)
ax.set(title="CKA per layer type: " + ", ".join(t.__name__ for t in layers),
       xlabel="layer index", ylabel="layer index")
fig.colorbar(im, ax=ax, label="CKA")
plt.show()

In [None]:
# Names of the modules tracked by the restricted calculator (module_names_X, presumably
# model1's side), indexed as in the CKA matrix above.
for row_idx, module_name in enumerate(calculator.module_names_X):
    print(f"Layer {row_idx}: \t{module_name}")