In [None]:
import os

# Work from the repository root (three directories above the kernel's starting
# directory) so the relative ./logs/... paths used in later cells resolve.
# A bare `cd ../../../` would climb another three levels every time this cell
# is re-run, so only change directory once per kernel session.
if "REPO_ROOT" not in globals():
    os.chdir("../../../")
    REPO_ROOT = os.getcwd()

In [None]:
# all_ensemble
# partial ensemble (e.g. an arbitrarily chosen subset of 2 or 3 models)

import pandas as pd
import os
import numpy as np

# Each configuration is (log directory containing seed_<i>/ folders,
# timestamped run sub-directory inside every seed folder).

# Weight ensemble
ensemble_base_path, ensemble_date = (
    "./logs/CIFAR10/is_prune/ensemble_output_diff_weight/20240605_q2_3",
    "2024_06_08_14_47_05",
)

# No pruning
no_prune_base_path, no_prune_date = (
    "./logs/CIFAR10/no_prune/20240611_q1_no_prune",
    "2024_06_11_16_13_45",
)

# IS pruning without ensembling
prune_base_path, prune_date = (
    "./logs/CIFAR10/is_prune/baseline/20240606_q1/remain_rate_30",
    "2024_06_07_00_03_15",
)


def calculate_statistics(base_path, date, n_seeds=5, column=2):
    """Aggregate one column of per-seed training logs into a mean and std curve.

    Reads ``<base_path>/seed_<i>/<date>/training_results.csv`` for every
    ``i in range(n_seeds)``, takes the column at position ``column`` from each
    file, and stacks them into an ``(n_seeds, n_rows)`` array.

    Args:
        base_path: Directory containing the ``seed_<i>`` sub-directories.
        date: Name of the timestamped run directory inside each seed directory.
        n_seeds: Number of seed directories to read (default 5).
        column: Positional index of the CSV column to aggregate (default 2,
            the column this notebook plots as accuracy).

    Returns:
        ``(mean, std)``: per-row mean and population standard deviation
        (``ddof=0``) across seeds, each a 1-D array of length ``n_rows``.

    Raises:
        ValueError: if ``n_seeds < 1`` or the CSVs have different row counts
            (per-row statistics across seeds would be meaningless).
    """
    if n_seeds < 1:
        raise ValueError(f"n_seeds must be >= 1, got {n_seeds}")

    per_seed = []
    for i in range(n_seeds):
        csv_path = os.path.join(base_path, f'seed_{i}', date, 'training_results.csv')
        per_seed.append(pd.read_csv(csv_path).iloc[:, column].to_numpy())

    # Fail loudly on ragged logs instead of numpy's generic shape error.
    lengths = sorted({len(values) for values in per_seed})
    if len(lengths) > 1:
        raise ValueError(
            f"training_results.csv row counts differ across seeds under {base_path}: {lengths}"
        )

    accuracies = np.stack(per_seed)
    return accuracies.mean(axis=0), accuracies.std(axis=0)

# Seed-averaged curve (mean, std) for each of the three configurations.
score_ensemble_mean, score_ensemble_std = calculate_statistics(
    base_path=ensemble_base_path, date=ensemble_date
)
no_prune_mean, no_prune_std = calculate_statistics(
    base_path=no_prune_base_path, date=no_prune_date
)
is_prune_mean, is_prune_std = calculate_statistics(
    base_path=prune_base_path, date=prune_date
)


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

thresh = 60  # first row index shown; earlier rows are cropped from the plot

plt.figure(figsize=(14, 8))
sns.set(style="whitegrid")

# (mean, std, legend label) for each configuration; all are drawn identically.
curves = [
    (score_ensemble_mean, score_ensemble_std, r'$Weight \;Ensemble$'),
    (no_prune_mean, no_prune_std, r'$No \;Prune$'),
    (is_prune_mean, is_prune_std, r'$IS\;Prune\;(remain\;30\%,\;Weights\;\sim\;U_k)$'),
]

for mean, std, label in curves:
    # Derive the x-axis from each series' own length rather than a hard-coded
    # 100 points, which raised a shape mismatch whenever a run logged a
    # different number of rows.
    x = np.arange(len(mean))
    ax = sns.lineplot(x=x[thresh:], y=mean[thresh:], marker='o', linewidth=2.5, label=label)
    # Shaded band: mean +/- one standard deviation across seeds.
    ax.fill_between(x[thresh:], mean[thresh:] - std[thresh:], mean[thresh:] + std[thresh:], alpha=0.3)

ax.set_xlabel('Index', fontsize=32)
ax.set_ylabel('Accuracy', fontsize=32)
ax.set_title('Accuracy Comparison with Standard Deviation', fontsize=32)
ax.tick_params(axis='both', which='major', labelsize=32)

# Move the legend outside the plot
plt.legend(fontsize=24, loc='upper left', bbox_to_anchor=(1, 1))

plt.show()

In [None]:
import os
import re
import numpy as np
import torch
from src.data import get_data_loaders
from src.models.resnet.resnet import ResNet18
from src.pruning.slth.edgepopup import modify_module_for_slth
from tqdm import tqdm  # tqdmをインポート

remain_rate = 0.3  # passed to modify_module_for_slth when rebuilding each checkpoint's model
n_class = 10  # number of ResNet18 outputs (CIFAR-10 classes)
dataset_name = "CIFAR10"
batch_size = 128
# Fall back to CPU when CUDA is unavailable instead of failing on the first .to("cuda").
device = "cuda" if torch.cuda.is_available() else "cpu"

def get_files_with_extension(base_path, extension):
    """Recursively collect paths under ``base_path`` whose file name ends with ``extension``.

    Paths come back in ``os.walk`` traversal order, each joined onto the
    directory it was found in. A non-existent ``base_path`` yields ``[]``
    because ``os.walk`` ignores errors by default.
    """
    return [
        os.path.join(directory, name)
        for directory, _subdirs, names in os.walk(base_path)
        for name in names
        if name.endswith(extension)
    ]

# Per-seed checkpoint directories of the weight-ensemble run ('{}' is the seed
# index), and the file suffix that identifies a checkpoint.
base_path_template, extension = (
    './logs/CIFAR10/is_prune/ensemble_output_diff_weight/20240605_q2_3/seed_{}',
    '.pkl',
)

# Extract the numeric part of a checkpoint path using a regular expression.
def extract_number(file_path):
    """Return N from the first ``slth<N>_state`` occurrence in ``file_path``, or -1 if absent."""
    found = re.search(r'slth(\d+)_state', file_path)
    if found is None:
        return -1
    return int(found.group(1))

def filter_and_sort_files(file_list, exclude=()):
    """Return checkpoint paths sorted by the numeric index in ``slth<N>_state``.

    No file is dropped by default (in particular, slth4 is kept); pass the
    indices to omit via ``exclude``, e.g. ``exclude={4}``.

    Paths without an ``slth<N>_state`` tag get index -1 from ``extract_number``
    and therefore sort first. ``sorted`` is stable, so equal indices keep their
    input order.

    Args:
        file_list: Iterable of checkpoint paths.
        exclude: Collection of slth indices to leave out (default: none).

    Returns:
        A new sorted list; ``file_list`` itself is not modified.
    """
    excluded = set(exclude)
    return sorted(
        (path for path in file_list if extract_number(path) not in excluded),
        key=extract_number,
    )

train_loader, test_loader = get_data_loaders(
    dataset_name=dataset_name, batch_size=batch_size
)


def _load_slth_model(path):
    """Rebuild the SLTH-wrapped ResNet18, load the checkpoint at ``path``, and return it in eval mode.

    eval() matters: without it, normalization layers (presumably BatchNorm
    inside ResNet18) normalize with per-batch statistics and update their
    running stats during the accuracy measurement.
    NOTE(review): assumes the layers replaced by modify_module_for_slth give
    the intended outputs in eval mode — TODO confirm.
    """
    model = ResNet18(n_class).to(device)
    model = modify_module_for_slth(model, remain_rate=remain_rate).to(device)
    # map_location lets checkpoints saved on GPU load on whatever `device` is.
    # torch.load unpickles arbitrary objects: only point it at trusted files.
    model.load_state_dict(torch.load(path, map_location=device))
    model.eval()
    return model


def _pairwise_ensemble_accuracy(models, loader, desc=None):
    """Test accuracy (%) of every two-model, output-averaging ensemble.

    Returns an ``(n, n)`` float array whose entry ``[i, j]`` is the accuracy of
    averaging the outputs of ``models[i]`` and ``models[j]``; the diagonal is
    each model on its own. Every batch goes through every model exactly once
    and the outputs are reused for all pairs, so pairs always see the same
    images regardless of the loader's ordering.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    n = len(models)
    if n == 0:
        return np.zeros((0, 0))
    correct = np.zeros((n, n), dtype=np.int64)
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(loader, desc=desc, leave=False):
            images = images.to(device)
            labels = labels.to(device)
            outputs = [model(images) for model in models]
            for i in range(n):
                for j in range(n):
                    ensemble_outputs = torch.stack([outputs[i], outputs[j]]).mean(dim=0)
                    predicted = ensemble_outputs.argmax(dim=1)
                    correct[i, j] += (predicted == labels).sum().item()
            total += labels.size(0)
    if total == 0:
        raise ValueError("test loader yielded no samples; cannot compute accuracy")
    return 100 * correct / total


n_seeds = 5  # seed_0 ... seed_4 sub-directories under base_path_template
all_accs = []

for seed in range(n_seeds):
    base_path = base_path_template.format(seed)
    file_list = get_files_with_extension(base_path, extension)
    filtered_sorted_files = filter_and_sort_files(file_list)

    # Load each checkpoint once per seed rather than once per (i, j) pair.
    # All of this seed's models are held on `device` at the same time.
    models = [
        _load_slth_model(path)
        for path in tqdm(filtered_sorted_files, desc=f'Seed {seed} - Loading')
    ]
    all_accs.append(
        _pairwise_ensemble_accuracy(models, test_loader, desc=f'Seed {seed} - Evaluating')
    )
    del models  # release this seed's models before the next seed loads its own

# Averaging across seeds only makes sense if every seed has the same number of checkpoints.
acc_shapes = sorted({acc.shape for acc in all_accs})
if len(acc_shapes) > 1:
    raise ValueError(f"seeds have different numbers of checkpoints: {acc_shapes}")

# Element-wise mean over seeds of the pairwise accuracy matrices.
mean_acc_matrix = np.mean(all_accs, axis=0)

print(mean_acc_matrix)


In [None]:
# Heatmap of the seed-averaged pairwise ensemble accuracy (cells annotated to 2 decimals).
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(mean_acc_matrix, annot=True, fmt=".2f", cmap="viridis", ax=ax)
ax.set_title('Mean Accuracy Matrix Heatmap')
ax.set_xlabel('Models')
ax.set_ylabel('Models')
plt.show()