In [None]:
cd ../../../

In [None]:
# all_ensemble
# partial ensemble (e.g. an arbitrarily chosen subset of 2 or 3 models)

import pandas as pd
import os
import numpy as np

# Log root and run timestamp for each configuration compared below.

# Score ensemble
ensemble_base_path, ensemble_date = (
    "./logs/CIFAR10/is_prune/ensemble_output_diff_score/20240605_q2_1",
    "2024_06_07_09_47_21",
)

# No pruning
no_prune_base_path, no_prune_date = (
    "./logs/CIFAR10/no_prune/20240611_q1_no_prune",
    "2024_06_11_16_13_45",
)

# Pruned, single model (no ensemble)
prune_base_path, prune_date = (
    "./logs/CIFAR10/is_prune/baseline/20240606_q1/remain_rate_30",
    "2024_06_07_00_03_15",
)


def calculate_statistics(base_path, date, n_seeds=5, column=2):
    """Mean and standard deviation, across seeds, of one column of training_results.csv.

    Reads ``<base_path>/seed_<i>/<date>/training_results.csv`` for each
    ``i in range(n_seeds)``. It takes column ``column`` (by position) from each
    file and stacks the results into an ``(n_seeds, n_rows)`` array. Column 2 is
    presumably the accuracy column, since the plotting cell labels it "Accuracy".
    TODO confirm against the CSV header.

    Args:
        base_path: directory containing the ``seed_<i>`` subdirectories.
        date: run-timestamp subdirectory inside each seed directory.
        n_seeds: number of seed directories to read (default 5).
        column: positional column index to aggregate (default 2).

    Returns:
        (mean, std): 1-D arrays of length n_rows; std is the population std (ddof=0).

    Raises:
        ValueError: if the per-seed files have different row counts, or n_seeds is 0.
    """
    per_seed = []
    for i in range(n_seeds):
        csv_path = os.path.join(base_path, f"seed_{i}", date, "training_results.csv")
        per_seed.append(pd.read_csv(csv_path).iloc[:, column].to_numpy())

    # Fail loudly on ragged input instead of building an object array / cryptic error.
    lengths = sorted({len(values) for values in per_seed})
    if len(lengths) > 1:
        raise ValueError(
            f"seeds under {base_path!r} have different row counts: {lengths}"
        )

    accuracies = np.stack(per_seed)  # (n_seeds, n_rows)
    return accuracies.mean(axis=0), accuracies.std(axis=0)

# Seed-averaged curves (mean, std) for each configuration, in the order
# score ensemble, no prune, pruned single model.
(score_ensemble_mean, score_ensemble_std), (no_prune_mean, no_prune_std), (is_prune_mean, is_prune_std) = [
    calculate_statistics(path, run_date)
    for path, run_date in (
        (ensemble_base_path, ensemble_date),
        (no_prune_base_path, no_prune_date),
        (prune_base_path, prune_date),
    )
]


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

thresh = 60  # index of the first point to plot; earlier points are skipped

# (mean, std, legend label) for each configuration, in plotting order.
curves = [
    (score_ensemble_mean, score_ensemble_std, r'$Score \;Ensemble$'),
    (no_prune_mean, no_prune_std, r'$No \;Prune$'),
    (is_prune_mean, is_prune_std, r'$IS\;Prune\;(remain\;30\%,\;Weights\;\sim\;U_k)$'),
]

# Style must be set before the axes are created for it to apply.
sns.set(style="whitegrid")
fig, ax = plt.subplots(figsize=(14, 8))

for mean, std, label in curves:
    # x positions follow each series' own length instead of a hard-coded 100 points.
    xs = np.arange(len(mean))[thresh:]
    mean_tail, std_tail = mean[thresh:], std[thresh:]
    sns.lineplot(x=xs, y=mean_tail, marker='o', linewidth=2.5, label=label, ax=ax)
    ax.fill_between(xs, mean_tail - std_tail, mean_tail + std_tail, alpha=0.3)  # ±1 std band

ax.set_xlabel('Index', fontsize=32)
ax.set_ylabel('Accuracy', fontsize=32)
ax.set_title('Accuracy Comparison with Standard Deviation', fontsize=32)
ax.tick_params(axis='both', which='major', labelsize=32)

# Place the legend outside the plot area.
ax.legend(fontsize=24, loc='upper left', bbox_to_anchor=(1, 1))

plt.show()

In [None]:
import os
import re
import numpy as np
import torch
from src.data import get_data_loaders
from src.models.resnet.resnet import ResNet18
from src.pruning.slth.edgepopup import modify_module_for_slth
from tqdm import tqdm

# Evaluation settings; remain_rate is what modify_module_for_slth is called with below.
remain_rate, n_class = 0.3, 10
dataset_name, batch_size = "CIFAR10", 128
device = "cuda"

def get_files_with_extension(base_path, extension):
    """Recursively list files under base_path whose names end with extension.

    Returns full paths (directory joined with file name) in os.walk order; a
    missing base_path yields an empty list.
    """
    return [
        os.path.join(dirpath, filename)
        for dirpath, _dirnames, filenames in os.walk(base_path)
        for filename in filenames
        if filename.endswith(extension)
    ]

# Per-seed checkpoint directory ('{}' is replaced by the seed index) and the
# file suffix of the saved state dicts.
base_path_template, extension = (
    './logs/CIFAR10/is_prune/ensemble_output_diff_score/20240605_q2_1/seed_{}',
    '.pkl',
)

def extract_number(file_path):
    """Return N from the first 'slth<N>_state' in file_path, or -1 if there is none."""
    found = re.search(r'slth(\d+)_state', file_path)
    if found is None:
        return -1
    return int(found.group(1))

def filter_and_sort_files(file_list):
    """Drop slth4 checkpoints and order the remaining paths by slth index.

    Paths with no 'slth<N>_state' match get index -1, so they are kept and sort
    first. The sort is stable: paths sharing an index keep their input order.
    """
    kept = [path for path in file_list if extract_number(path) != 4]
    kept.sort(key=extract_number)
    return kept

train_loader, test_loader = get_data_loaders(
    dataset_name=dataset_name, batch_size=batch_size
)

# One (1, num_files) array of single-model test accuracies (%) per seed.
all_accs = []

for seed in range(5):
    base_path = base_path_template.format(seed)
    file_list = get_files_with_extension(base_path, extension)
    filtered_sorted_files = filter_and_sort_files(file_list)

    num_files = len(filtered_sorted_files)
    acc_matrix = np.zeros((1, num_files))

    for i, f1 in enumerate(tqdm(filtered_sorted_files, desc=f'Seed {seed} - Outer Loop')):
        m1 = ResNet18(n_class).to(device)
        m1 = modify_module_for_slth(m1, remain_rate=remain_rate, is_print=False).to(device)
        # torch.load unpickles the file — only load checkpoints from trusted sources.
        m1.load_state_dict(torch.load(f1, map_location=device))
        m1.eval()

        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in test_loader:
                images = images.to(device)
                labels = labels.to(device)
                _, predicted = torch.max(m1(images), 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        acc = 100 * correct / total
        acc_matrix[0, i] = acc

    all_accs.append(acc_matrix)

# Snapshot under a dedicated name: the bar-chart cell reads
# `all_accs_for_single_model`, and later cells rebind `all_accs` to pairwise matrices.
all_accs_for_single_model = list(all_accs)

print(all_accs)


In [None]:
# Mean single-model test accuracy (%) over seeds, one bar per checkpoint.
# `all_accs_for_single_model` must already hold the per-seed single-model accuracy
# arrays, presumably each of shape (1, n_models) given the indexing below.
# NOTE(review): do not substitute `all_accs` here unless this cell runs directly after
# the single-model evaluation — later cells rebind `all_accs` to pairwise matrices.
mean_accs = np.mean(all_accs_for_single_model, axis=0)
models = [f'{i}' for i in range(mean_accs.shape[1])]  # bar labels: checkpoint positions

# Bar chart
fig, ax = plt.subplots(figsize=(14, 8))
sns.barplot(x=models, y=mean_accs[0], ax=ax)
ax.set_ylim(88, 95)  # zoom the y-axis to 88–95 % so differences between checkpoints are visible
ax.set_xlabel('Score dist', fontsize=32)
ax.set_ylabel('Accuracy (%)', fontsize=32)
ax.set_title('Mean Accuracy for Each Score dist', fontsize=32)
ax.grid(True)
ax.tick_params(axis='both', labelsize=32)
plt.show()

In [None]:
import torch

import torch
import torch.nn.functional as F

def entropy(ensemble_outputs, from_logits=False, eps=1e-8):
    """Per-sample Shannon entropy (in bits), used as an ensemble uncertainty score.

    :param ensemble_outputs: tensor of shape (n_samples, n_classes); entropy is
        taken over dim 1. The values are treated as class probabilities unless
        ``from_logits`` is True.
    :param from_logits: if True, apply softmax over dim 1 first. This is needed
        when passing raw (e.g. averaged) model outputs rather than probabilities.
        The default False keeps the previous behaviour of using the values as given.
    :param eps: values are clamped to [eps, 1 - eps] before log2 to avoid log(0).
    :return: tensor of shape (n_samples,) with the uncertainty scores.
    """
    probs = torch.softmax(ensemble_outputs, dim=1) if from_logits else ensemble_outputs

    # Clip for numerical stability: log2(0) would give -inf and 0 * -inf = nan.
    probs = torch.clamp(probs, min=eps, max=1 - eps)

    return -torch.sum(probs * torch.log2(probs), dim=1)

In [None]:
import os
import re
import numpy as np
import torch
from src.data import get_data_loaders
from src.models.resnet.resnet import ResNet18
from src.pruning.slth.edgepopup import modify_module_for_slth
from tqdm import tqdm  # tqdmをインポート

# Evaluation settings; remain_rate is what modify_module_for_slth is called with below.
remain_rate, n_class = 0.3, 10
dataset_name, batch_size = "CIFAR10", 128
device = "cuda"

def get_files_with_extension(base_path, extension):
    """Recursively list files under base_path whose names end with extension.

    Returns full paths (directory joined with file name) in os.walk order; a
    missing base_path yields an empty list.
    """
    return [
        os.path.join(dirpath, filename)
        for dirpath, _dirnames, filenames in os.walk(base_path)
        for filename in filenames
        if filename.endswith(extension)
    ]

# Per-seed checkpoint directory ('{}' is replaced by the seed index) and the
# file suffix of the saved state dicts.
base_path_template, extension = (
    './logs/CIFAR10/is_prune/ensemble_output_diff_score/20240605_q2_1/seed_{}',
    '.pkl',
)

# Extract the checkpoint index N from a 'slth<N>_state' path.
def extract_number(file_path):
    """Return N from the first 'slth<N>_state' in file_path, or -1 if there is none."""
    found = re.search(r'slth(\d+)_state', file_path)
    if found is None:
        return -1
    return int(found.group(1))

# Keep only checkpoints whose slth index is greater than 6, ordered by that index.
def filter_and_sort_files(file_list):
    """Return the paths in file_list with slth index > 6, sorted ascending by index.

    Paths without a 'slth<N>_state' match (index -1) are dropped. The sort is
    stable, so paths sharing an index keep their input order.
    """
    selected = [path for path in file_list if extract_number(path) > 6]
    selected.sort(key=extract_number)
    return selected

train_loader, test_loader = get_data_loaders(
    dataset_name=dataset_name, batch_size=batch_size
)

# One (num_files, num_files) matrix per seed; entry [i, j] describes the two-model
# ensemble that averages the outputs of checkpoints i and j.
all_accs = []           # ensemble test accuracy (%)
all_uncertainties = []  # mean over batches of the per-batch mean entropy

for seed in range(5):
    base_path = base_path_template.format(seed)
    file_list = get_files_with_extension(base_path, extension)
    filtered_sorted_files = filter_and_sort_files(file_list)

    num_files = len(filtered_sorted_files)
    acc_matrix = np.zeros((num_files, num_files))
    uncertainty_matrix = np.zeros((num_files, num_files))

    for i, f1 in enumerate(tqdm(filtered_sorted_files, desc=f'Seed {seed} - Outer Loop')):
        # m1 depends only on f1, so build and load it once per outer iteration.
        m1 = ResNet18(n_class).to(device)
        m1 = modify_module_for_slth(
            m1, remain_rate=remain_rate, is_print=False
        ).to(device)
        # torch.load unpickles the file — only load checkpoints from trusted sources.
        m1.load_state_dict(torch.load(f1, map_location=device))
        m1.eval()

        for j, f2 in enumerate(tqdm(filtered_sorted_files, desc=f'Seed {seed} - Inner Loop', leave=False)):
            m2 = ResNet18(n_class).to(device)
            m2 = modify_module_for_slth(
                m2, remain_rate=remain_rate, is_print=False
            ).to(device)
            m2.load_state_dict(torch.load(f2, map_location=device))
            m2.eval()

            # Reset for every pair so the statistics below describe this pair only.
            uncertainty_list = []
            correct = 0
            total = 0
            with torch.no_grad():
                for images, labels in test_loader:
                    images = images.to(device)
                    labels = labels.to(device)
                    outputs_list = [model(images) for model in [m1, m2]]
                    ensemble_outputs = torch.stack(outputs_list).mean(dim=0)
                    # NOTE(review): entropy() clamps its input to (0, 1) and treats it as
                    # probabilities, but no softmax is applied here. If the models return
                    # logits, this score is not a true entropy — confirm the output type.
                    uncertainty = entropy(ensemble_outputs)
                    uncertainty_list.append(uncertainty.cpu().numpy())

                    _, predicted = torch.max(ensemble_outputs, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()

            acc = 100 * correct / total
            acc_matrix[i, j] = acc
            uncertainty_matrix[i, j] = np.mean([np.mean(arr) for arr in uncertainty_list])

            print(f1)
            print(f2)
            print(uncertainty_matrix[i, j])
            # NOTE(review): this is the std across batches of per-batch stds, not the
            # std of the per-sample entropies — confirm which statistic is intended.
            print(np.std([np.std(arr) for arr in uncertainty_list]))

    all_accs.append(acc_matrix)
    all_uncertainties.append(uncertainty_matrix)

# Seed-averaged matrices.
mean_acc_matrix = np.mean(all_accs, axis=0)
mean_uncertainty_matrix = np.mean(all_uncertainties, axis=0)

print(mean_acc_matrix)


In [None]:
# Heatmap of the seed-averaged pairwise matrix `mean_acc_matrix`
# (row i / column j index the checkpoints in sorted order).
fig, ax = plt.subplots(figsize=(16, 12))
sns.heatmap(
    mean_acc_matrix,
    ax=ax,
    annot=True,
    fmt=".2f",
    cmap="viridis",
    annot_kws={"size": 16},
)
ax.set_title('Mean Accuracy Matrix Heatmap', fontsize=32)
ax.set_xlabel('Scores', fontsize=32)
ax.set_ylabel('Scores', fontsize=32)
plt.setp(ax.get_xticklabels(), fontsize=16)
plt.setp(ax.get_yticklabels(), fontsize=16)
plt.show()

In [None]:
import os
import re
import numpy as np
import torch
from src.data import get_data_loaders
from src.models.resnet.resnet import ResNet18
from src.pruning.slth.edgepopup import modify_module_for_slth
from tqdm import tqdm  # tqdmをインポート

# Evaluation settings; remain_rate is what modify_module_for_slth is called with below.
remain_rate, n_class = 0.3, 10
dataset_name, batch_size = "CIFAR10", 128
device = "cuda"

def get_files_with_extension(base_path, extension):
    """Recursively list files under base_path whose names end with extension.

    Returns full paths (directory joined with file name) in os.walk order; a
    missing base_path yields an empty list.
    """
    return [
        os.path.join(dirpath, filename)
        for dirpath, _dirnames, filenames in os.walk(base_path)
        for filename in filenames
        if filename.endswith(extension)
    ]

# Per-seed checkpoint directory ('{}' is replaced by the seed index) and the
# file suffix of the saved state dicts.
base_path_template, extension = (
    './logs/CIFAR10/is_prune/ensemble_output_diff_score/20240605_q2_1/seed_{}',
    '.pkl',
)

# Extract the checkpoint index N from a 'slth<N>_state' path.
def extract_number(file_path):
    """Return N from the first 'slth<N>_state' in file_path, or -1 if there is none."""
    found = re.search(r'slth(\d+)_state', file_path)
    if found is None:
        return -1
    return int(found.group(1))

# Keep only checkpoints whose slth index is greater than 6, ordered by that index.
def filter_and_sort_files(file_list):
    """Return the paths in file_list with slth index > 6, sorted ascending by index.

    Paths without a 'slth<N>_state' match (index -1) are dropped. The sort is
    stable, so paths sharing an index keep their input order.
    """
    selected = [path for path in file_list if extract_number(path) > 6]
    selected.sort(key=extract_number)
    return selected

train_loader, test_loader = get_data_loaders(
    dataset_name=dataset_name, batch_size=batch_size
)

# One (num_files, num_files) matrix per seed: entry [i, j] is the test accuracy (%)
# of the ensemble averaging the outputs of checkpoints i and j.
# Start from an empty list: earlier cells also bind `all_accs`, and appending to
# their list would mix their matrices into the mean below.
all_accs = []

for seed in range(5):
    base_path = base_path_template.format(seed)
    file_list = get_files_with_extension(base_path, extension)
    filtered_sorted_files = filter_and_sort_files(file_list)

    num_files = len(filtered_sorted_files)
    acc_matrix = np.zeros((num_files, num_files))

    for i, f1 in enumerate(tqdm(filtered_sorted_files, desc=f'Seed {seed} - Outer Loop')):
        # m1 depends only on f1, so build and load it once per outer iteration.
        m1 = ResNet18(n_class).to(device)
        m1 = modify_module_for_slth(
            m1, remain_rate=remain_rate, is_print=False
        ).to(device)
        # torch.load unpickles the file — only load checkpoints from trusted sources.
        m1.load_state_dict(torch.load(f1, map_location=device))
        m1.eval()

        for j, f2 in enumerate(tqdm(filtered_sorted_files, desc=f'Seed {seed} - Inner Loop', leave=False)):
            m2 = ResNet18(n_class).to(device)
            m2 = modify_module_for_slth(
                m2, remain_rate=remain_rate, is_print=False
            ).to(device)
            m2.load_state_dict(torch.load(f2, map_location=device))
            m2.eval()

            correct = 0
            total = 0
            with torch.no_grad():
                for images, labels in test_loader:
                    images = images.to(device)
                    labels = labels.to(device)
                    outputs_list = [model(images) for model in [m1, m2]]
                    ensemble_outputs = torch.stack(outputs_list).mean(dim=0)

                    _, predicted = torch.max(ensemble_outputs, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()

            acc = 100 * correct / total
            acc_matrix[i, j] = acc

    all_accs.append(acc_matrix)

# Seed-averaged accuracy matrix.
mean_acc_matrix = np.mean(all_accs, axis=0)

print(mean_acc_matrix)


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Data: ensemble accuracy (%) vs. mean/std of the entropy-based uncertainty.
# NOTE(review): hard-coded values, presumably transcribed from the printed output of
# the pairwise uncertainty cell — regenerate them programmatically if possible.
acc = [92.07, 92.13, 93.08, 93.58, 94.1, 94.08]
uncertainty_mean = [0.13901885, 0.12595454, 0.15622562, 0.19518027, 0.36009267, 0.33390602]
uncertainty_std = [0.0216064, 0.022389269, 0.025769085, 0.033253014, 0.030091353, 0.02570036]

# Seaborn style must be set before the axes are created.
sns.set(style="whitegrid")

fig, ax = plt.subplots(figsize=(18, 12))
ax.errorbar(acc, uncertainty_mean, yerr=uncertainty_std, fmt='o', capsize=5, markersize=10, color='b')

ax.set_xlabel('Accuracy (%)', fontsize=32)
ax.set_ylabel('Uncertainty Mean', fontsize=32)
ax.set_title('Uncertainty Mean vs Accuracy with Error Bars', fontsize=32)
ax.tick_params(axis='both', labelsize=32)

plt.show()
