In [None]:
# Move from this notebook's directory to the repository root so that `src.*`
# imports and the relative ./logs paths used below resolve. The three-level climb
# assumes the notebook lives three directories below the root — adjust if moved.
# Guarded on the presence of `src/` so re-running the cell does not keep climbing.
import os

if not os.path.isdir("src"):
    os.chdir("../../../")
os.getcwd()

In [None]:
import os
import re
import numpy as np
import torch
from src.data import get_data_loaders
from src.models.resnet.resnet import ResNet18
from src.pruning.slth.edgepopup import modify_module_for_slth
from tqdm import tqdm

# Evaluation settings. NOTE(review): these presumably must match the settings the
# checkpoints were produced with — confirm against the training config.
remain_rate = 0.3  # passed to modify_module_for_slth; presumably the kept-weight fraction — confirm
n_class = 10  # number of output classes (CIFAR-10 has 10)
dataset_name = "CIFAR10"
batch_size = 128  # batch size passed to get_data_loaders
# Fall back to CPU so the notebook still runs (slowly) on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

def get_files_with_extension(base_path, extension):
    """Recursively collect paths under ``base_path`` whose file name ends with ``extension``.

    Directories are visited in ``os.walk`` (top-down) order and, within a directory,
    files keep the order ``os.walk`` reports them in. A non-existent ``base_path``
    yields an empty list because ``os.walk`` ignores errors by default.

    Returns:
        list[str]: ``os.path.join(dirpath, filename)`` for every match.
    """
    return [
        os.path.join(dirpath, name)
        for dirpath, _subdirs, names in os.walk(base_path)
        for name in names
        if name.endswith(extension)
    ]

# Per-seed log directory holding the checkpoints; `{}` is filled with the seed index.
base_path_template = (
    './logs/CIFAR10/is_prune/ensemble_output_diff_score/'
    '20240628_q2_1/seed_{}'
)
# Only files ending in this suffix are collected as checkpoints.
extension = '.pkl'

def extract_number(file_path):
    """Return the integer N from the first ``slthN_state`` token in ``file_path``.

    Returns -1 when the path has no such token, so an ascending sort on this key
    places unmatched paths before every matched one.
    """
    found = re.search(r'slth(\d+)_state', file_path)
    if found is None:
        return -1
    return int(found.group(1))

def filter_and_sort_files(file_list, exclude=(10,)):
    """Drop checkpoints whose ``slthN_state`` number is in ``exclude`` and sort the rest by N.

    Args:
        file_list: iterable of checkpoint paths.
        exclude: N values to drop. Defaults to ``(10,)``, the previously hard-coded
            filter. NOTE(review): why N=10 is skipped is not visible here — document it.

    Returns:
        list[str]: kept paths ordered by N, with ties broken by the path string. Paths
        without a ``slthN_state`` token get N=-1 from ``extract_number`` and come first.

    The tie-break makes the order independent of the filesystem-dependent ``os.walk``
    order. Accuracy columns are averaged position-by-position across seeds, so every
    seed must produce its files in the same order.
    """
    excluded = set(exclude)
    # Compute each key once; (N, path) tuples sort by N, then deterministically by path.
    keyed = [(extract_number(path), path) for path in file_list]
    return [path for number, path in sorted(keyed) if number not in excluded]

train_loader, test_loader = get_data_loaders(
    dataset_name=dataset_name, batch_size=batch_size
)


def load_slth_checkpoint(path):
    """Build the SLTH-wrapped ResNet18 and load the state dict stored at ``path``.

    Uses the module-level ``n_class``, ``remain_rate`` and ``device`` settings.
    """
    model = ResNet18(n_class).to(device)
    model = modify_module_for_slth(model, remain_rate=remain_rate, is_print=False).to(device)
    # map_location lets checkpoints saved from a different device load on this one.
    # NOTE(review): torch.load unpickles the file and can execute arbitrary code —
    # only point this at trusted checkpoints.
    model.load_state_dict(torch.load(path, map_location=device))
    return model


def evaluate_accuracy(model, loader, device):
    """Return the top-1 accuracy (in percent) of ``model`` over every batch in ``loader``.

    Switches the model to eval mode and disables autograd. Raises ``ValueError`` when
    the loader yields no samples, instead of dividing by zero.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        raise ValueError("evaluation loader yielded no samples")
    return 100 * correct / total


# Evaluate every checkpoint of every seed; all_accs gets one (1, n_checkpoints) array per seed.
all_accs = []
reference_order = None  # slth numbers from seed 0, i.e. the column order all seeds must share

for seed in range(5):
    base_path = base_path_template.format(seed)
    file_list = get_files_with_extension(base_path, extension)
    filtered_sorted_files = filter_and_sort_files(file_list)
    # os.walk silently yields nothing for a missing directory; fail loudly instead of
    # producing an empty accuracy row.
    if not filtered_sorted_files:
        raise FileNotFoundError(f"no '{extension}' checkpoints found under {base_path}")

    # Later cells average columns position-by-position across seeds, so every seed must
    # contribute the same checkpoints in the same order.
    order = [extract_number(path) for path in filtered_sorted_files]
    if reference_order is None:
        reference_order = order
    elif order != reference_order:
        raise ValueError(
            f"seed {seed}: checkpoint order {order} does not match seed 0 {reference_order}"
        )

    acc_matrix = np.zeros((1, len(filtered_sorted_files)))
    for i, ckpt_path in enumerate(tqdm(filtered_sorted_files, desc=f'Seed {seed} - Outer Loop')):
        model = load_slth_checkpoint(ckpt_path)
        acc_matrix[0, i] = evaluate_accuracy(model, test_loader, device)

    all_accs.append(acc_matrix)

# Print or save the results as needed
print(all_accs)


In [None]:
# Display the checkpoint list left over from the last seed processed by the
# evaluation loop (seed 4); its order is the column order of each acc_matrix.
filtered_sorted_files

In [None]:
# Per-seed accuracies: a list holding one (1, n_checkpoints) array per seed.
all_accs

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns


# Mean and standard deviation of test accuracy across seeds, one value per checkpoint.
# all_accs holds one (1, n_checkpoints) array per seed; averaging over axis 0 and
# flattening gives 1-D arrays of length n_checkpoints.
mean_accs = np.mean(all_accs, axis=0).reshape(-1)
# NOTE(review): np.std defaults to ddof=0 (population std); with only a handful of
# seeds the sample std (ddof=1) is the usual choice for error bars — confirm intent.
std_accs = np.std(all_accs, axis=0).reshape(-1)

# x tick labels: the checkpoint column index as a string
models = [f'{i}' for i in range(mean_accs.shape[0])]
# seaborn draws categorical bars at integer positions 0..n-1; placing the error bars
# at those numeric positions does not rely on matplotlib's string-category mapping.
bar_positions = np.arange(len(models))

fig, ax = plt.subplots(figsize=(14, 8))
# Each category has a single value, so seaborn has nothing to aggregate and its own
# error bar is empty; the seed-to-seed std is drawn explicitly on the next line.
sns.barplot(x=models, y=mean_accs, ax=ax)
ax.errorbar(x=bar_positions, y=mean_accs, yerr=std_accs, fmt='none', c='black', capsize=5)
ax.set_ylim(88, 95)  # zoomed y range; anything outside 88–95 % is clipped
ax.set_xlabel('Score dist', fontsize=32)
ax.set_ylabel('Accuracy (%)', fontsize=32)
ax.set_title('Mean Accuracy for Each Score dist', fontsize=32)
ax.grid(True)
ax.tick_params(axis='both', labelsize=32)
plt.show()


In [None]:
# Shape depends on which plotting cell ran last: (n_checkpoints,) after the
# error-bar plot cell, (1, n_checkpoints) after the final bar-plot cell.
mean_accs.shape

In [None]:
# NOTE(review): duplicate of the previous `mean_accs.shape` cell — candidate for deletion.
# The shape is (n_checkpoints,) or (1, n_checkpoints) depending on which plot cell ran last.
mean_accs.shape

In [None]:
# NOTE(review): duplicate of the earlier `all_accs` display cell — candidate for deletion.
all_accs

In [None]:
# Bar plot of the per-checkpoint mean accuracy across seeds, without error bars.
# all_accs holds one (1, n_checkpoints) array per seed, so the mean keeps a leading
# singleton axis: after this cell mean_accs has shape (1, n_checkpoints).
mean_accs = np.mean(all_accs, axis=0)
# x tick labels: the checkpoint column index as a string
models = [f'{i}' for i in range(mean_accs.shape[1])]

fig, ax = plt.subplots(figsize=(14, 8))
sns.barplot(x=models, y=mean_accs[0], ax=ax)
ax.set_ylim(88, 95)  # zoom the y axis to 88–95 % accuracy
ax.set_xlabel('Score dist', fontsize=32)
ax.set_ylabel('Accuracy (%)', fontsize=32)
ax.set_title('Mean Accuracy for Each Score dist', fontsize=32)
ax.grid(True)
ax.tick_params(axis='both', labelsize=32)
plt.show()