In [1]:
import os
import sys
sys.path.append("../../")
import argparse

import pandas as pandas
import numpy as np

import torch
import matplotlib.pyplot as plt

from copy import deepcopy

from utils import load_all_models, load_all_client_loader, vizualize_cka_model
from src.simulator.utils import get_client_dataset
from cka import CKACalculator

import torch.nn as nn

## Setup: checkpoints, device and data loaders

In [2]:
# Best checkpoint per training scheme: {model_type: best_model wandb run id}.
model_type_dict = {
    "Center": "sunny-serenity-1",
    "Local": "lunar-plant-1",
    "FedAvg": "treasured-wildflower-4",
    "FedProx": "dainty-universe-2",
    "MOON": "whole-cosmos-7",
}

# Checkpoint root directory. Set ADAPTFL_CKPT_DIR to point elsewhere instead of editing
# this absolute cluster path (keep the trailing slash in case callers concatenate paths).
BASE = os.environ.get("ADAPTFL_CKPT_DIR", "/NFS/Users/moonsh/AdaptFL/ckpt/")

# Fall back to CPU so the notebook can still run on a machine without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Passed to load_all_models; `worker` is presumably the DataLoader worker count — confirm.
batch_size, worker = 256, 0

In [None]:
# Load one model per entry of model_type_dict from BASE (project helper in utils).
# model_dict is keyed by the model_type names (e.g. model_dict["FedProx"] is used below);
# `config` is handed to load_all_client_loader to build the per-client data loaders.
model_dict, config = load_all_models(model_type_dict, BASE, device, batch_size, worker)
client_loader_list = load_all_client_loader(config)

In [None]:
# Project helper from utils (name spelled "vizualize" there); presumably plots CKA
# comparisons across the loaded models using the client loaders — see utils for details.
vizualize_cka_model(model_dict, client_loader_list, device)

In [None]:
# `calculator` is only created in the "Overall Analysis" section below, so on a fresh
# kernel (Restart & Run All) this early listing would raise NameError. Skip it until
# a calculator exists; the same listing is repeated after the calculator is built.
if "calculator" in globals():
    for i, name in enumerate(calculator.module_names_X):
        print(f"Layer {i}: \t{name}")

In [9]:
# Scale factor applied by hook_fn to every hooked layer's output (placeholder weighting).
IMPORTANCE_WEIGHT = 0.5


def apply_important_weights(output, importance_weights):
    """Return ``output`` scaled element-wise by ``importance_weights``.

    ``importance_weights`` may be a Python scalar or any tensor broadcastable to
    ``output``. A new tensor is returned; ``output`` itself is not modified.
    """
    return output * importance_weights


def hook_fn(module, input, output):
    """Forward hook that rescales a layer's output by ``IMPORTANCE_WEIGHT``.

    Intended for ``module.register_forward_hook``. Because it returns a non-None
    value, PyTorch uses the returned tensor as the module's output. ``module`` and
    ``input`` are unused.
    """
    # Multiplying by a scalar gives the same values as ones_like(output) * w, but it
    # avoids allocating an extra output-sized tensor on every forward pass.
    return apply_important_weights(output, IMPORTANCE_WEIGHT)

In [10]:
def register_hooks(model, layer_types=(nn.Conv3d, nn.Linear), hook=None):
    """Attach a forward hook to every submodule of ``model`` whose type is in ``layer_types``.

    Args:
        model: the ``nn.Module`` to instrument.
        layer_types: tuple of module classes to hook (default: ``nn.Conv3d`` and
            ``nn.Linear``).
        hook: forward-hook callable ``(module, input, output)``; defaults to
            ``hook_fn``, which is looked up at call time so redefining it is picked up.

    Returns:
        list of the handles returned by ``register_forward_hook``; call ``.remove()``
        on each to detach the hooks again.
    """
    if hook is None:
        hook = hook_fn
    handles = []
    for name, module in model.named_modules():
        if isinstance(module, layer_types):
            print(name)  # report which submodules were hooked
            handles.append(module.register_forward_hook(hook))
    return handles

In [None]:
# Detach hooks left over from a previous run of this cell. Otherwise each re-run stacks
# another hook_fn on every layer and the outputs are rescaled again.
for stale_handle in globals().get("handles", []):
    stale_handle.remove()
# NOTE: these hooks stay attached to model_dict["FedProx"] until removed, so any later
# forward pass (including CKA runs) on FedProx sees the rescaled outputs.
handles = register_hooks(model_dict["FedProx"])

In [None]:
# Display the hook handles returned by register_hooks; call .remove() on each to detach them.
handles

## Overall Analysis

In [None]:
# Choose the two models to compare and the loader that feeds both.
# NOTE(review): model1 / model2 / temp_loader were never defined in this notebook (they
# leaked from a deleted cell). The selections below are placeholders; change them to the
# comparison you intend. FedProx is not the default because the hook cell above rescales
# its Conv3d/Linear outputs.
# Assumes client_loader_list is indexable with one loader per client — confirm.
CKA_MODEL_1, CKA_MODEL_2 = "Center", "FedAvg"
CKA_CLIENT_IDX = 0

model1 = model_dict[CKA_MODEL_1]
model2 = model_dict[CKA_MODEL_2]
temp_loader = client_loader_list[CKA_CLIENT_IDX]
calculator = CKACalculator(model1=model1, model2=model2, dataloader=temp_loader)

In [None]:
# Layer-by-layer CKA similarity between model1 and model2 over temp_loader.
cka_output = calculator.calculate_cka_matrix()
cka_shape = cka_output.size()
print(f"CKA output size: {cka_shape}")

In [None]:
# Extract the layer names
for i, name in enumerate(calculator.module_names_X):
    print(f"Layer {i}: \t{name}")

In [None]:
# Recompute the CKA matrix and render it as a labelled heatmap.
cka_output = calculator.calculate_cka_matrix()
print(f"CKA output size: {cka_output.size()}")

fig, ax = plt.subplots(figsize=(6, 5))
im = ax.imshow(cka_output.cpu().numpy(), cmap='inferno')
fig.colorbar(im, ax=ax, label="CKA")
# Assumes rows follow calculator.module_names_X (model1) and columns model2's layers — confirm.
ax.set(title="Layer-wise CKA: model1 vs model2",
       xlabel="model2 layer index", ylabel="model1 layer index")
plt.show()

## Per-layer analysis

In [11]:
# Layer types the per-layer CKA calculator should hook. Both 2D and 3D conv/BN types are
# listed: register_hooks above targets nn.Conv3d, which suggests these models are 3D, and
# a 2D-only tuple would hook nothing there (confirm against the model definitions).
layers = (nn.Conv2d, nn.BatchNorm2d, nn.Conv3d, nn.BatchNorm3d)

In [None]:
# Reset the previous calculator, then build a new one that only hooks the `layers` types.
calculator.reset()
calculator = CKACalculator(
    model1=model1,
    model2=model2,
    dataloader=temp_loader,
    hook_layer_types=layers,
)


In [None]:
# CKA matrix restricted to the `layers` types, rendered as a labelled heatmap.
cka_output = calculator.calculate_cka_matrix()

fig, ax = plt.subplots(figsize=(6, 5))
im = ax.imshow(cka_output.cpu().numpy(), cmap='inferno')
fig.colorbar(im, ax=ax, label="CKA")
# Assumes rows follow calculator.module_names_X (model1) and columns model2's layers — confirm.
ax.set(title="Per-layer CKA (hook_layer_types=layers)",
       xlabel="model2 layer index", ylabel="model1 layer index")
plt.show()

In [None]:
# Extract the layer names
for i, name in enumerate(calculator.module_names_X):
    print(f"Layer {i}: \t{name}")