In [None]:
cd ../../../

In [None]:
from glob import glob
import os
from src.data import get_data_loaders
from src.models.resnet.resnet import ResNet18
from src.pruning.slth.edgepopup import modify_module_for_slth
from tqdm import tqdm
import torch
import numpy as np
import copy
import re
import pandas as pd


def get_files_with_extension_recursively(base_path: str, extension: str):
    """Return every path under ``base_path`` whose name ends with ``extension``.

    The search descends into all sub-directories (``glob`` with
    ``recursive=True``) and also covers files located directly in
    ``base_path``.

    Args:
        base_path: Directory at which the recursive search starts.
        extension: File extension to match; the leading dot is optional
            (``"csv"`` and ``".csv"`` are treated the same).

    Returns:
        A list of matching paths, each prefixed with ``base_path``. The order
        is whatever ``glob`` yields and is not sorted.
    """
    # Normalise the extension so that it always carries its leading dot.
    suffix = extension if extension.startswith('.') else f'.{extension}'
    # "**" matches zero or more directory levels when recursive=True.
    return glob(os.path.join(base_path, '**', f'*{suffix}'), recursive=True)

def _stack_third_column(csv_paths, searched_dir):
    """Read column index 2 of every CSV and stack the columns row-wise.

    Args:
        csv_paths: Paths of the CSV files to read, one per run.
        searched_dir: Directory the paths were collected from; only used to
            build the error message.

    Returns:
        A 2-D array with one row per CSV file and one column per CSV row.

    Raises:
        FileNotFoundError: If ``csv_paths`` is empty.
        ValueError: Raised by ``np.stack`` if the CSV files do not all have
            the same number of rows.
    """
    if not csv_paths:
        # Fail here with the offending path instead of producing an empty
        # array that only breaks later, far from the cause.
        raise FileNotFoundError(f"No CSV files found under {searched_dir!r}")
    # NOTE(review): column index 2 is presumably the accuracy column (judging
    # by the variable names) -- confirm against the CSV schema.
    return np.stack([pd.read_csv(path).iloc[:, 2].to_numpy() for path in csv_paths])


# Runs of the baseline model (one CSV per run).
base_path = "./logs/CIFAR10/is_prune/baseline/20240606_q1/remain_rate_30"
base_csvs = get_files_with_extension_recursively(base_path, '.csv')
base_acc = _stack_third_column(base_csvs, base_path)

# Runs of the big model (one CSV per run).
big_model_path = "./logs/CIFAR10/is_prune/baseline/ResNet152/20240619_follow_up_pruning_big_model/remain_rate_6"
big_modele_csvs = get_files_with_extension_recursively(big_model_path, '.csv')
big_modele_acc = _stack_third_column(big_modele_csvs, big_model_path)


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt



def _plot_mean_with_std_band(mean, std, first_epoch, label, color):
    """Draw a mean curve and a +/- one standard deviation band on the current axes.

    The epoch axis is derived from the length of ``mean`` itself, so curves
    with different numbers of epochs can be drawn on the same figure.

    Args:
        mean: 1-D array of per-epoch mean values.
        std: 1-D array of per-epoch standard deviations, same length as ``mean``.
        first_epoch: First epoch (1-based) to display; earlier ones are skipped.
        label: Legend label of the curve.
        color: Colour used for both the curve and the band.
    """
    epoch_axis = np.arange(1, mean.shape[0] + 1)
    # Epochs are 1-based while array indices are 0-based.
    shown = slice(first_epoch - 1, None)
    sns.lineplot(x=epoch_axis[shown], y=mean[shown], label=label, color=color)
    plt.fill_between(
        epoch_axis[shown],
        mean[shown] - std[shown],
        mean[shown] + std[shown],
        alpha=0.2,
        color=color,
    )


# Compute the mean and standard deviation across runs (axis 0) for each epoch.
base_acc_mean = np.mean(base_acc, axis=0)
base_acc_std = np.std(base_acc, axis=0)

big_modele_acc_mean = np.mean(big_modele_acc, axis=0)
big_modele_acc_std = np.std(big_modele_acc, axis=0)

# First epoch (1-based) to show in the plot.
start_epoch = 50

# Build the plot.
# 1-based epoch numbers of the baseline curve (kept as a module-level name).
epochs = np.arange(1, base_acc_mean.shape[0] + 1)

plt.figure(figsize=(14, 10))

_plot_mean_with_std_band(
    base_acc_mean, base_acc_std, start_epoch,
    label=r'$ResNet18  \;remain=30\%$', color="blue",
)
_plot_mean_with_std_band(
    big_modele_acc_mean, big_modele_acc_std, start_epoch,
    label=r'$ResNet152  \;remain=6\%$', color="orange",
)

plt.title("Accuracy over Epochs with Error Bars", fontsize=32)
plt.xlabel("Epoch", fontsize=32)
plt.ylabel("Accuracy", fontsize=32)
plt.xticks(fontsize=32)
plt.yticks(fontsize=32)
plt.legend(fontsize=32)
plt.grid(True)
plt.tight_layout()
plt.show()