In [None]:
# Move the working directory two levels up. The cells below use relative
# "./logs/..." paths, which presumably resolve from there (TODO confirm).
cd ../..

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

dataset_name = 'STL10'
# The per-seed result paths are assumed to follow these fixed templates;
# `{}` is filled with the seed index.
if dataset_name == 'SVHN':
    is_prune_base = "./logs/SVHN/is_prune/remain_rate_30/seed_{}/2024_03_29_15_17_28/training_results.csv"
    is_transfer_base = "./logs/SVHN/is_transfer/remain_rate_30/seed_{}/2024_03_29_15_17_24/training_results.csv"
elif dataset_name == 'STL10':
    is_prune_base = "./logs/STL10/is_prune/remain_rate_30/seed_{}/2024_03_31_07_06_28/training_results.csv"
    is_transfer_base = "./logs/STL10/is_transfer/remain_rate_30/seed_{}/2024_03_31_07_15_00/training_results.csv"
else:
    # Fail here with a clear message instead of a NameError when the
    # templates are first used below.
    raise ValueError("Unsupported dataset_name: {!r}".format(dataset_name))

prune_list = []
transfer_list = []

# Row window [start, end) of each training_results.csv to plot.
start = 0
end = 100

# One CSV per seed (seeds 0, 1, 2).
for idx in range(3):
    # Column index 2 is taken as the plotted metric; the CSV schema is not
    # visible in this notebook -- TODO confirm which metric it is.
    # The slice is bounded by `end` so logs longer than the window still line
    # up with the x-axis computed below.
    prune_list.append(pd.read_csv(is_prune_base.format(idx)).iloc[start:end, 2].to_list())
    transfer_list.append(pd.read_csv(is_transfer_base.format(idx)).iloc[start:end, 2].to_list())

# Convert the per-seed lists to NumPy arrays (one row per seed).
prune_array = np.array(prune_list)
transfer_array = np.array(transfer_list)

# Mean and standard deviation across seeds.
prune_mean = np.mean(prune_array, axis=0)
transfer_mean = np.mean(transfer_array, axis=0)
prune_std = np.std(prune_array, axis=0)
transfer_std = np.std(transfer_array, axis=0)

# Derive the x-axis from the number of rows actually loaded, so a log with
# fewer than `end - start` rows no longer causes a shape mismatch when plotting.
prune_epochs = np.arange(start, start + prune_mean.shape[0])
transfer_epochs = np.arange(start, start + transfer_mean.shape[0])

# Plot the mean curve with a +/- one standard deviation band for each setting.
plt.figure(figsize=(10, 6))
plt.plot(prune_epochs, prune_mean, label='Only Prune', color='blue')
plt.fill_between(prune_epochs, prune_mean-prune_std, prune_mean+prune_std, color='blue', alpha=0.2)

plt.plot(transfer_epochs, transfer_mean, label='Is Transfer', color='green')
plt.fill_between(transfer_epochs, transfer_mean-transfer_std, transfer_mean+transfer_std, color='green', alpha=0.2)

plt.legend()
plt.title("[{}] Comparison of Prune and Transfer Learning".format(dataset_name))
plt.xlabel("Epoch")
plt.ylabel("Metric Value")
plt.show()


In [None]:
import torch
dataset_name = 'STL10'
# The checkpoint paths are assumed to follow these fixed templates;
# `{}` is filled with the seed index.

# The source checkpoint template is identical for every target dataset,
# so it is defined once instead of once per branch.
is_source_base = "./logs/CIFAR10/is_prune/ensemble_output/seed_{}/2024_03_27_16_44_39/resnet_slth1_state.pkl"

if dataset_name == 'SVHN':
    is_prune_base = "./logs/SVHN/is_prune/remain_rate_30/seed_{}/2024_03_29_15_17_28/resnet_slth_state.pkl"
    is_transfer_base = "./logs/SVHN/is_transfer/remain_rate_30/seed_{}/2024_03_29_15_17_24/resnet_slth_state.pkl"
elif dataset_name == 'STL10':
    is_prune_base = "./logs/STL10/is_prune/remain_rate_30/seed_{}/2024_03_31_07_06_28/resnet_slth_state.pkl"
    is_transfer_base = "./logs/STL10/is_transfer/remain_rate_30/seed_{}/2024_03_31_07_15_00/resnet_slth_state.pkl"
else:
    # Without this branch an unsupported name would silently leave
    # `is_prune_base` / `is_transfer_base` at whatever an earlier cell
    # assigned (the training_results.csv templates).
    raise ValueError("Unsupported dataset_name: {!r}".format(dataset_name))



In [None]:
def get_binary_array(scores, k):
    """Binarize a score tensor into a flat 0/1 NumPy mask.

    The ``int((1 - k) * scores.numel())`` lowest-scoring entries are set to 0
    and all remaining entries to 1.

    Args:
        scores: Tensor of scores; it is left unmodified.
        k: Fraction of entries to keep (set to 1).

    Returns:
        A 1-D NumPy integer array with one 0/1 entry per element of ``scores``,
        in flattened order.
    """
    # Write into a flattened copy so the caller's tensor is untouched.
    mask = scores.clone().flatten()

    # Flat indices ordered from lowest to highest score.
    ascending = scores.flatten().sort().indices
    n_dropped = int((1 - k) * scores.numel())

    dropped, kept = ascending[:n_dropped], ascending[n_dropped:]
    mask[dropped] = 0
    mask[kept] = 1

    return mask.cpu().numpy().astype(int)

# Not written to anywhere in this cell; kept in case later cells (not visible
# here) rely on the name.
model_similarity = np.zeros((6, 6))
k = 0.3  # fraction of entries kept (set to 1) by get_binary_array
seed = 1

# NOTE(security): torch.load unpickles its input; only load checkpoints from a
# trusted source.
# map_location="cpu" lets checkpoints saved on a GPU load on a CPU-only
# machine; get_binary_array moves every tensor to the CPU anyway.
source = torch.load(is_source_base.format(seed), map_location="cpu")
transfer = torch.load(is_transfer_base.format(seed), map_location="cpu")
is_prune = torch.load(is_prune_base.format(seed), map_location="cpu")

# Binarize every "scores" entry once per model, instead of re-deriving the
# same masks for each of the 3 x 3 model pairs.
masks_per_model = [
    {
        name: get_binary_array(param, k)
        for name, param in state.items()
        if "scores" in name
    }
    for state in (source, transfer, is_prune)
]

# tmp[i, j] is the mean, over all "scores" entries, of the fraction of mask
# positions on which model i and model j agree.
tmp = np.zeros((3, 3))

for i, masks_a in enumerate(masks_per_model):
    for j, masks_b in enumerate(masks_per_model):
        all_matching_scores = [
            np.mean(mask_a == masks_b[name])
            for name, mask_a in masks_a.items()
        ]
        tmp[i, j] = np.mean(all_matching_scores)

In [None]:
import seaborn as sns
# Row/column order matches the order in which the 3 x 3 matrix `tmp` is filled.
labels = ["source", "transfer", "no_transfer(is_prune)"]

fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(tmp, annot=True, cmap="coolwarm", fmt=".4f",
            xticklabels=labels,
            yticklabels=labels,
            ax=ax)
# `tmp` holds the fraction of matching binary mask entries, not a cosine
# similarity, so the title names the metric that is actually plotted.
ax.set_title("Subnetwork mask similarity (fraction of matching mask entries)")
ax.set_xlabel("Subnetwork")
ax.set_ylabel("Subnetwork")
plt.show()
plt.close(fig)