In [None]:
cd ../..

In [None]:
from src.models.resnet.resnet import ResNet18
import torch
from src.pruning.slth.edgepopup_ensemble_output import modify_module_for_slth
from src.data.cifar.cifar10 import get_cifar10_data_loaders
from src.utils.seed import torch_fix_seed
# Run configuration: compute device, evaluation batch size, and which seed's
# checkpoints to analyse.
device = "cuda"
batch_size = 128
seed = 1

# Checkpoints exist only for these seeds; every path differs solely in the
# `seed_<n>` directory, so the paths are derived from `seed` instead of being
# spelled out once per seed.
supported_seeds = (0, 1, 2)
if seed not in supported_seeds:
    # Fail here rather than with a NameError on `base_path` in a later cell.
    raise ValueError(
        "seed must be one of {}, got {!r}".format(supported_seeds, seed)
    )

torch_fix_seed(seed)

# `base_path` keeps one unfilled `{}` placeholder (escaped as `{{}}` below);
# later cells fill it with the checkpoint index.
base_path = "./logs/CIFAR10/is_prune/ensemble_output/seed_{}/2024_03_27_16_44_39/resnet_slth{{}}_state.pkl".format(seed)
no_prune_path = "./logs/CIFAR10/no_prune/seed_{}/2024_03_28_15_30_04/model_state.pkl".format(seed)

In [None]:
train_loader, test_loader = get_cifar10_data_loaders(batch_size)

import numpy as np
from tqdm import tqdm
import torch.nn.functional as F
import torch.nn as nn

device = 'cuda'

# Path of each of the six checkpoints (indices 1..6), keyed "slth<i>_path".
paths = {"slth{}_path".format(str(i)): base_path.format(str(i)) for i in np.arange(1, 7)}

# Read every checkpoint from disk exactly once, keyed "slth<i>_weight".
# The loaded state dicts are reused for both model dictionaries below instead
# of calling torch.load again for each model.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
weights = {
    "slth{}_weight".format(str(i)): torch.load(paths["slth{}_path".format(str(i))])
    for i in np.arange(1, 7)
}

# Full models, keyed "slth<i>_model", restored from the checkpoints and put in
# eval mode. Construction order is kept as in the original cell (all full
# models first, then the fc-less copies) in case construction draws from the
# seeded RNG.
models = {}
for i in np.arange(1, 7):
    model = modify_module_for_slth(ResNet18(), 0.3).to(device)
    model.load_state_dict(weights["slth{}_weight".format(str(i))])
    model.eval()
    models["slth{}_model".format(str(i))] = model

# Same checkpoints, but with the `fc` attribute replaced by an identity after
# loading, so these models return whatever is fed into `fc`.
fc_identity_models = {}
for i in np.arange(1, 7):
    model = modify_module_for_slth(ResNet18(), 0.3).to(device)
    model.load_state_dict(weights["slth{}_weight".format(str(i))])
    model.fc = nn.Identity()
    model.eval()
    fc_identity_models["slth{}_model".format(str(i))] = model

weight_keys = list(weights.keys())
model_keys = list(models.keys())


In [None]:
# Baseline: test accuracy (%) of the un-pruned ResNet18 checkpoint.
# The model goes on `device` (the same variable used for the inputs below)
# rather than a hard-coded "cuda", so the two cannot drift apart.
no_prune_model = ResNet18().to(device)
# map_location remaps tensors saved from another device onto `device`.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
no_prune_model.load_state_dict(torch.load(no_prune_path, map_location=device))
no_prune_model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)

        # Predicted class = index of the largest output along dim 1.
        # No `.data` needed: autograd is already disabled by no_grad().
        _, predicted = torch.max(no_prune_model(images), 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    acc = 100 * correct / total
    print(acc)

In [None]:

# Print the stand-alone test accuracy (%) of every model in `models`,
# one line per model, in the dictionary's iteration order.
for m in models.values():
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)

            # Index of the largest output along dim 1 is the predicted class.
            predicted = torch.max(m(images), dim=1).indices
            hits = (predicted == labels).sum().item()

            total += labels.size(0)
            correct += hits

    acc = 100 * correct / total
    print(acc)

In [None]:
from tqdm import tqdm
import itertools
import matplotlib.pyplot as plt
import seaborn as sns

def ensemble_accuracy(models, num_models, test_loader, device):
    """Return the accuracy of every ``num_models``-sized ensemble of ``models``.

    An ensemble's prediction for a batch is the index of the largest value
    along dim 1 of the element-wise mean of its members' outputs.

    Each model is run once per batch and its output is shared by every
    combination that contains it, instead of re-running the model on the whole
    loader for each combination. The progress bar therefore advances per batch.
    This function does not switch the models to eval mode; callers do that.

    Args:
        models: Mapping of name -> callable that takes a batch of images and
            returns a tensor whose dim 1 indexes the classes.
        num_models: Number of models per ensemble.
        test_loader: Iterable yielding ``(images, labels)`` tensor pairs.
        device: Device passed to ``.to()`` for every batch.

    Returns:
        List of accuracies in percent, one per combination, in the order
        produced by ``itertools.combinations(models.items(), num_models)``.
        Empty when no combination of that size exists.

    Raises:
        ZeroDivisionError: If ``test_loader`` yields no samples, or if
            ``num_models`` is 0 (mean over an empty ensemble).
    """
    model_combinations = list(itertools.combinations(models.items(), num_models))
    if not model_combinations:
        return []

    correct = [0] * len(model_combinations)
    total = 0

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc=f"{num_models} models"):
            images = images.to(device)
            labels = labels.to(device)

            # One forward pass per model for this batch, reused below.
            batch_outputs = {name: model(images) for name, model in models.items()}

            for idx, combination in enumerate(model_combinations):
                outputs = [batch_outputs[name] for name, _ in combination]
                merged_out = sum(outputs) / len(outputs)

                _, predicted = torch.max(merged_out, 1)
                correct[idx] += (predicted == labels).sum().item()

            total += labels.size(0)

    return [100 * hits / total for hits in correct]

# Compute the ensemble accuracies for ensembles of 2 up to 6 models.
results = {}
for num_models in range(2, 7):
    results[num_models] = ensemble_accuracy(models, num_models, test_loader, device)

# Visualise the results: one box per ensemble size, showing the spread of
# accuracies over all model combinations of that size.
ensemble_sizes = list(results.keys())
accuracy_groups = [results[size] for size in ensemble_sizes]

fig, ax = plt.subplots(figsize=(10, 6))
sns.boxplot(data=accuracy_groups, ax=ax)
ax.set_xticklabels(ensemble_sizes)
ax.set(
    xlabel="Number of Models in Ensemble",
    ylabel="Accuracy (%)",
    title="Ensemble Accuracy Distribution",
)
plt.show()

In [None]:
# Best accuracy reached by any combination, printed once per ensemble size (2..6).
for num_models in range(2, 7):
    best_accuracy = max(results[num_models])
    print(best_accuracy)

In [None]:
# Hand-recorded result table with one entry per key 0, 1, 2.
# NOTE(review): the keys look like the `seed` values selectable at the top of
# the notebook, and the numbers like accuracies (%) copied from the evaluation
# cells of separate runs; confirm the provenance.
#   "no_prune"       : single value (presumably the un-pruned baseline)
#   "is_prune"       : six values (presumably the pruned models 1..6, in order)
#   "ensemble_prune" : single value; the trailing "#n" comment presumably notes
#                      the ensemble size that produced it; TODO confirm
dic = {
    0 : {
        "no_prune" : [94.09],
        "is_prune" : [93.01, 92.22, 93.06, 92.53, 86.1, 80.71],
        "ensemble_prune" : [94.1], #2
    },
    1 : {
        "no_prune" : [94.29],
        "is_prune" : [93.03, 92.43, 92.88, 92.95, 86.63, 82.34],
        "ensemble_prune" : [93.99], #3
    },
    2 : {
        "no_prune" : [93.9],
        "is_prune" : [93.14, 92.17, 92.83, 92.49, 86.61, 81.73],
        "ensemble_prune" : [93.88], #4
    }
}

In [None]:
# Average the recorded values in `dic` across its entries (keys 0, 1, 2).
# The numbers are read from `dic` instead of being typed out a second time,
# so the table and the means cannot drift apart.
recorded_keys = sorted(dic)

# Scalar mean of the single "no_prune" value of each entry.
no_prune_mean = np.mean([dic[key]["no_prune"][0] for key in recorded_keys])

# Element-wise mean over entries (axis=0): one mean per "is_prune" position.
is_prune_means = np.mean([dic[key]["is_prune"] for key in recorded_keys], axis=0)

# Scalar mean of the single "ensemble_prune" value of each entry.
is_ensemble_mean = np.mean([dic[key]["ensemble_prune"][0] for key in recorded_keys])

In [None]:
# Bar heights, left to right: the un-pruned mean, the six per-position pruned
# means, and the ensemble mean.
values = [no_prune_mean, *is_prune_means, is_ensemble_mean]
labels = ['No Prune'] + ['Is Prune {}'.format(k) for k in range(1, 7)] + ['Is prune Ensemble']
bar_colors = ['blue'] + ['green'] * 6 + ['red']

# Draw the bars on an explicit axes object.
bar_fig, bar_ax = plt.subplots(figsize=(10, 6))
bars = bar_ax.bar(labels, values, color=bar_colors)
bar_ax.set_xlabel('Model Configuration')
bar_ax.set_ylabel('Mean Accuracy (%)')
bar_ax.set_title('Comparison of Mean Accuracies Across Different Configurations')
plt.xticks(rotation=45)
bar_ax.set_ylim(75, 95)  # Narrow y-range so the differences between bars are visible
bar_ax.grid(axis='y', linestyle='--', alpha=0.7)

# Write each bar's height (rounded to 2 decimals) just above it, centred.
for bar in bars:
    yval = bar.get_height()
    x_center = bar.get_x() + bar.get_width()/2
    bar_ax.text(x_center, yval, round(yval, 2), va='bottom', ha='center')

plt.tight_layout()
plt.show()
