In [None]:
# Move two directories up (presumably to the repository root) so that the
# `src.*` imports and the relative ./logs/... checkpoint paths below resolve.
cd ../..

In [None]:
from src.models.resnet.resnet import ResNet18
import torch
from src.pruning.slth.edgepopup_ensemble_output import modify_module_for_slth
from src.data.cifar.cifar10 import get_cifar10_data_loaders
from src.utils.seed import torch_fix_seed
device = "cuda"
batch_size = 128
seed = 2

# Checkpoints exist only for these seeds. Every seed uses the same timestamped
# run directories, so the paths differ only in their "seed_<n>" component.
AVAILABLE_SEEDS = (0, 1, 2)
if seed not in AVAILABLE_SEEDS:
    raise ValueError(f"no checkpoints configured for seed={seed}; choose one of {AVAILABLE_SEEDS}")

torch_fix_seed(seed)
# The trailing "{}" is left unformatted on purpose: it is filled later with the
# subnetwork index (1..6) via base_path.format(i).
base_path = f"./logs/CIFAR10/is_prune/ensemble_output/seed_{seed}/2024_03_27_16_44_39/resnet_slth{{}}_state.pkl"
no_prune_path = f"./logs/CIFAR10/no_prune/seed_{seed}/2024_03_28_15_30_04/model_state.pkl"

In [None]:
# CIFAR-10 loaders at the configured batch size; the evaluation cells below use test_loader.
cifar10_loaders = get_cifar10_data_loaders(batch_size)
train_loader, test_loader = cifar10_loaders

In [None]:
import numpy as np
from tqdm import tqdm
import torch.nn.functional as F
import torch.nn as nn

# `device` comes from the configuration cell; it is not redefined here.
NUM_SUBNETWORKS = 6  # checkpoints resnet_slth1 ... resnet_slth6
# Second argument of modify_module_for_slth; its meaning is defined in
# src.pruning.slth (presumably the fraction of weights kept — TODO confirm).
SLTH_RATIO = 0.3
subnet_ids = range(1, NUM_SUBNETWORKS + 1)

paths = {f"slth{i}_path": base_path.format(i) for i in subnet_ids}
# Load every checkpoint exactly once. Both model dicts below are built from these
# state dicts; load_state_dict copies the tensors, so the models do not alias `weights`.
weights = {f"slth{i}_weight": torch.load(paths[f"slth{i}_path"]) for i in subnet_ids}


def _build_slth_model(state_dict, drop_fc=False):
    """Build an SLTH ResNet18 on `device`, load `state_dict`, and switch it to eval mode.

    With drop_fc=True the `fc` attribute is replaced by nn.Identity after loading,
    so the model no longer applies that final layer.
    """
    model = modify_module_for_slth(ResNet18(), SLTH_RATIO).to(device)
    model.load_state_dict(state_dict)
    if drop_fc:
        model.fc = nn.Identity()
    model.eval()
    return model


models = {f"slth{i}_model": _build_slth_model(weights[f"slth{i}_weight"]) for i in subnet_ids}
# Same subnetworks with the `fc` head removed (presumably yielding the features that fed the classifier).
fc_identity_models = {
    f"slth{i}_model": _build_slth_model(weights[f"slth{i}_weight"], drop_fc=True)
    for i in subnet_ids
}

weight_keys = list(weights.keys())
model_keys = list(models.keys())


In [None]:
# Baseline: top-1 test accuracy (%) of the dense, unpruned ResNet18.
no_prune_model = ResNet18().to("cuda")
no_prune_model.load_state_dict(torch.load(no_prune_path))
no_prune_model.eval()

correct, total = 0, 0
with torch.no_grad():
    for batch_images, batch_labels in test_loader:
        batch_images, batch_labels = batch_images.to(device), batch_labels.to(device)
        logits = no_prune_model(batch_images)
        predicted = torch.max(logits.data, 1)[1]
        correct += (predicted == batch_labels).sum().item()
        total += batch_labels.size(0)

acc = 100 * correct / total
print(acc)

In [None]:

# Top-1 test accuracy (%) of each pruned subnetwork, printed in the insertion order of `models`.
with torch.no_grad():
    for subnet in models.values():
        correct = total = 0
        for batch_images, batch_labels in test_loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)
            predicted = torch.max(subnet(batch_images), 1).indices
            correct += (predicted == batch_labels).sum().item()
            total += batch_labels.size(0)
        acc = 100 * correct / total
        print(acc)

In [None]:
# Sanity check on the loaded checkpoints, comparing every subnetwork with the first one:
#   * each parameter whose name contains "weight" must be identical;
#   * each parameter whose name contains "scores" must differ.
# Iterates over however many checkpoints were loaded instead of a hardcoded count.
reference_weights = weights[weight_keys[0]]
for key in weight_keys[1:]:
    for name, param in weights[key].items():
        if "weight" in name:
            assert torch.equal(
                param.data, reference_weights[name]
            ), f"Weight mismatch found in {name} "
        if "scores" in name:
            assert not torch.equal(
                param.data, reference_weights[name]
            ), f"score match found in {name} "


In [None]:
# Pairwise similarity of subnetwork outputs on the test set: for each pair
# (k1, k2), the per-sample cosine similarity of their outputs, averaged over
# all test samples.
# Every subnetwork runs once per batch and all pairs are scored on that same
# batch (len(models) forward passes per batch instead of 2 * len(models)**2).
# This also keeps the pairing correct if test_loader shuffles.
# Averaging per sample rather than per batch stops a smaller final batch from
# being over-weighted.
n_models = len(model_keys)
model_list = [models[key] for key in model_keys]
cossim_sums = np.zeros((n_models, n_models))
n_samples = 0
with torch.no_grad():
    for images, labels in tqdm(test_loader):
        images = images.to(device)
        outputs = [model(images) for model in model_list]
        for k1 in range(n_models):
            for k2 in range(n_models):
                cossim_sums[k1, k2] += F.cosine_similarity(outputs[k1], outputs[k2]).sum().item()
        n_samples += images.size(0)

model_output_sim = cossim_sums / n_samples

In [None]:
from tqdm import tqdm
# Accuracy (%) of the two-model ensemble (mean of outputs) for every ordered pair
# of subnetworks; the diagonal is the single subnetwork's accuracy.
# Every subnetwork runs once per batch and the outputs are reused for all pairs,
# instead of re-running both models over the whole test set for each of the pairs.
n_models = len(model_keys)
model_list = [models[key] for key in model_keys]
pair_correct = np.zeros((n_models, n_models), dtype=np.int64)
n_samples = 0
with torch.no_grad():
    for images, labels in tqdm(test_loader):
        images = images.to(device)
        labels = labels.to(device)
        outputs = [model(images) for model in model_list]
        for k1 in range(n_models):
            for k2 in range(n_models):
                merged_out = (outputs[k1] + outputs[k2]) / 2
                _, predicted = torch.max(merged_out, 1)
                pair_correct[k1, k2] += (predicted == labels).sum().item()
        n_samples += labels.size(0)

model_acc = 100 * pair_correct / n_samples

In [None]:
from tqdm import tqdm
import itertools
import matplotlib.pyplot as plt
import seaborn as sns

def ensemble_accuracy(models, num_models, test_loader, device):
    """Test accuracy of every ensemble of ``num_models`` models drawn from ``models``.

    An ensemble predicts the argmax of the mean of its members' outputs.
    Every model is run once per batch and its output is shared by all
    combinations. The cost is therefore ``len(models)`` forward passes per
    batch, instead of ``num_models`` passes per batch for each of the
    ``C(len(models), num_models)`` combinations.

    Args:
        models: dict of name -> model. Models must already be on ``device``;
            only the batches are moved here.
        num_models: ensemble size.
        test_loader: iterable of ``(images, labels)`` batches.
        device: device the batches are moved to.

    Returns:
        List of accuracies in percent, one per combination, in the order of
        ``itertools.combinations(models.items(), num_models)``. The list is empty
        when ``num_models > len(models)``.
    """
    model_list = list(models.values())
    combos = list(itertools.combinations(range(len(model_list)), num_models))
    if not combos:
        # No ensemble of that size exists; skip the pass over the data.
        return []

    correct = [0] * len(combos)
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc=f"{num_models} models"):
            images = images.to(device)
            labels = labels.to(device)

            outputs = [model(images) for model in model_list]
            for c, combo in enumerate(combos):
                merged_out = sum(outputs[i] for i in combo) / len(combo)
                _, predicted = torch.max(merged_out, 1)
                correct[c] += (predicted == labels).sum().item()
            total += labels.size(0)

    return [100 * n_correct / total for n_correct in correct]

# Accuracy of every ensemble of 2..N subnetworks (N = number of loaded subnetworks).
results = {
    num_models: ensemble_accuracy(models, num_models, test_loader, device)
    for num_models in range(2, len(models) + 1)
}

# Visualize the accuracy distribution for each ensemble size.
fig, ax = plt.subplots(figsize=(10, 6))
sns.boxplot(data=list(results.values()), ax=ax)
# seaborn draws the boxes at categorical positions 0..len(results)-1; pin the
# ticks there before relabelling them (avoids matplotlib's FixedFormatter warning).
ax.set_xticks(range(len(results)))
ax.set_xticklabels(list(results.keys()))
ax.set_xlabel("Number of Models in Ensemble")
ax.set_ylabel("Accuracy (%)")
ax.set_title("Ensemble Accuracy Distribution")
plt.show()

In [None]:
def get_binary_array(scores, k):
    """Binarize a scores tensor into a flat 0/1 mask that keeps the highest scores.

    The ``int((1 - k) * scores.numel())`` lowest-scoring positions become 0 and
    all remaining positions become 1, so roughly a fraction ``k`` of entries is 1.

    Args:
        scores: tensor of scores (any shape; it is flattened).
        k: fraction of entries to mark as 1.

    Returns:
        1-D numpy integer array of length ``scores.numel()``, in flattened order.
    """
    ranking = scores.flatten().sort()[1]  # flat indices, lowest score first
    n_zeros = int((1 - k) * scores.numel())

    mask = scores.clone().flatten()
    mask[ranking[n_zeros:]] = 1  # highest-scoring positions are kept
    mask[ranking[:n_zeros]] = 0  # lowest-scoring positions are dropped
    return mask.cpu().numpy().astype(int)

# Pairwise similarity of the binary masks derived from each subnetwork's scores.
# For every "scores" tensor, take the fraction of positions where the two masks
# agree (i.e. 1 - normalized Hamming distance), then the median over all such
# tensors (unweighted by tensor size).
k = 0.3  # fraction of each scores tensor set to 1 by get_binary_array (presumably matches the SLTH ratio — TODO confirm)

# Binarize every scores tensor once per model rather than three times per pair.
# Stored as bool to keep host memory small; equality is unaffected.
score_masks = [
    {
        name: get_binary_array(param, k).astype(bool)
        for name, param in weights[key].items()
        if "scores" in name
    }
    for key in weight_keys
]

n_models = len(weight_keys)
model_similarity = np.zeros((n_models, n_models))
for i in range(n_models):
    for j in range(n_models):
        all_matching_scores = [
            np.count_nonzero(mask == score_masks[j][name]) / mask.size
            for name, mask in score_masks[i].items()
        ]
        model_similarity[i, j] = np.median(all_matching_scores)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np



def plot_pairwise_heatmap(matrix, title):
    """Draw an annotated subnetwork x subnetwork heatmap of `matrix` (4 decimals)."""
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(matrix, annot=True, cmap="coolwarm", fmt=".4f", ax=ax)
    ax.set_title(title)
    ax.set_xlabel("Subnetwork")
    ax.set_ylabel("Subnetwork")
    plt.show()
    plt.close(fig)


# Accuracy (%) of the averaged two-model ensemble; the diagonal is the single subnetwork.
plot_pairwise_heatmap(model_acc, "model acc")
# Mean cosine similarity between the two subnetworks' outputs on the test set.
plot_pairwise_heatmap(model_output_sim, "model output similarity (cosine sim)")
# Median over layers of the fraction of identical mask bits. This is
# 1 - normalized Hamming distance, so higher means more similar masks.
plot_pairwise_heatmap(model_similarity, "model architecture similarity (1 - normalized Hamming distance)")

In [None]:
# One point per unordered pair of distinct subnetworks. The diagonal compares a
# subnetwork with itself, and entries (i, j) and (j, i) are computed from the same
# two subnetworks, so only the strict upper triangle is plotted (no duplicate points).
rows, cols = np.triu_indices(model_similarity.shape[0], k=1)
x = model_similarity[rows, cols]
y = model_acc[rows, cols]

fig, ax = plt.subplots()
ax.scatter(x, y, alpha=0.6)
ax.set_xlabel('model architecture similarity (1 - normalized Hamming distance)')
ax.set_ylabel('Ensemble Performance (ACC)')
ax.set_title('model architecture similarity vs. Ensemble Performance (ACC)')
plt.show()

In [None]:
# One point per unordered pair of distinct subnetworks. The diagonal compares a
# subnetwork with itself, and entries (i, j) and (j, i) are computed from the same
# two subnetworks, so only the strict upper triangle is plotted (no duplicate points).
rows, cols = np.triu_indices(model_output_sim.shape[0], k=1)
x = model_output_sim[rows, cols]
y = model_acc[rows, cols]

fig, ax = plt.subplots()
ax.scatter(x, y, alpha=0.6)
ax.set_xlabel('model output similarity (cosine sim)')
ax.set_ylabel('Ensemble Performance (ACC)')
ax.set_title('model output similarity (cosine sim) vs. Ensemble Performance (ACC)')
plt.show()