In [None]:
# Move three directories up from the kernel's starting directory (normally this
# notebook's folder; presumably the repository root) so that the `src.*` imports
# and `./logs/...` paths in later cells resolve. The starting directory is
# remembered on the first run, so re-running this cell lands in the same place
# instead of climbing three more levels as the bare `cd ../../../` magic did.
import os

NOTEBOOK_DIR = globals().setdefault('NOTEBOOK_DIR', os.getcwd())
os.chdir(os.path.normpath(os.path.join(NOTEBOOK_DIR, '..', '..', '..')))
os.getcwd()

In [None]:
import os
import re
import warnings
from glob import glob

import torch

from src.data import get_data_loaders
from src.models.resnet.resnet import ResNet18
from src.pruning.slth.edgepopup import modify_module_for_slth

def get_files_with_extension(base_path, extension):
    """Recursively collect files below ``base_path`` whose names end with ``extension``.

    Args:
        base_path: Directory to walk (a missing directory simply yields no files).
        extension: Filename suffix to match, e.g. ``'.pkl'``.

    Returns:
        A list of ``os.path.join(dirpath, filename)`` paths in ``os.walk`` order.
    """
    return [
        os.path.join(dirpath, name)
        for dirpath, _subdirs, names in os.walk(base_path)
        for name in names
        if name.endswith(extension)
    ]

# Per-seed log directory of the SLTH ensemble run; `{}` is filled with the seed index.
base_path_template = (
    './logs/CIFAR10/is_prune/ensemble_output_diff_score/'
    '20240616_q2_1_remain_rate_010/seed_{}'
)
# Suffix of the checkpoint files (loaded later with torch.load).
extension = '.pkl'

def extract_number(file_path):
    """Return the integer N from the first ``slth<N>_state`` occurrence in ``file_path``.

    Used as a sort key for checkpoint paths; paths without that pattern map to -1,
    so they sort before every tagged path.
    """
    found = re.search(r'slth(\d+)_state', file_path)
    if found is None:
        return -1
    return int(found.group(1))

def sort_files(file_list):
    """Return a new list of ``file_list`` ordered by ``extract_number`` of each path.

    Paths without a ``slth<N>_state`` tag get key -1 and therefore come first;
    ``sorted`` is stable, so paths with equal keys keep their input order.
    """
    return sorted(file_list, key=extract_number)

def filter_and_sort_files(file_list, exclude=(4,)):
    """Drop paths whose ``extract_number`` is in ``exclude``, then sort the rest by it.

    Args:
        file_list: Iterable of checkpoint paths.
        exclude: Iterable of ``slth<N>_state`` numbers to leave out. The default
            ``(4,)`` reproduces the previously hard-coded exclusion of number 4.

    Returns:
        A new list of the remaining paths, ordered by ``extract_number`` (stable for ties).
    """
    excluded = frozenset(exclude)
    return sorted(
        (path for path in file_list if extract_number(path) not in excluded),
        key=extract_number,
    )

In [None]:
# Ensemble evaluation: for each seed, load every SLTH checkpoint saved under that
# seed's log directory, average the models' outputs, and report test accuracy.
dataset_name = "CIFAR10"
batch_size = 128
# Fall back to CPU when no GPU is present; checkpoints are mapped onto this device below.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
NUM_CLASSES = 10    # presumably ResNet18's output-class count (CIFAR10); value from the original call
REMAIN_RATE = 0.1   # 2nd arg of modify_module_for_slth; cf. "remain_rate_010" in the log path — confirm meaning

train_loader, test_loader = get_data_loaders(
    dataset_name=dataset_name, batch_size=batch_size
)

ensemble_accuracies = []
for seed in range(5):
    base_path = base_path_template.format(seed)
    file_list = get_files_with_extension(base_path, extension)
    # NOTE(review): despite its name this list is only sorted, not filtered —
    # `filter_and_sort_files` (which drops checkpoint number 4) is not used here.
    filtered_sorted_files = sort_files(file_list)
    if not filtered_sorted_files:
        # Without this, torch.stack([]) below fails with an opaque error.
        raise FileNotFoundError(f"no '{extension}' checkpoints found under {base_path}")

    # One SLTH-modified ResNet18 per checkpoint.
    models = []
    for checkpoint_path in filtered_sorted_files:
        resnet = ResNet18(NUM_CLASSES).to(device)
        resnet = modify_module_for_slth(resnet, REMAIN_RATE, is_print=False).to(device)
        # map_location loads tensors onto `device` regardless of where they were saved.
        # NOTE(review): torch.load unpickles arbitrary objects — only load trusted checkpoints.
        resnet.load_state_dict(torch.load(checkpoint_path, map_location=device))
        resnet.eval()
        models.append(resnet)

    total = 0
    correct = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            # Ensemble by averaging the models' raw outputs, then pick the arg-max class.
            outputs_list = [model(images) for model in models]
            ensemble_outputs = torch.stack(outputs_list).mean(dim=0)
            predicted = ensemble_outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    ensemble_accuracies.append(accuracy)
    print(f"seed {seed}: {accuracy}")

In [None]:
# Dense baseline: for each seed, evaluate the single unpruned ResNet18 checkpoint
# on the test set. The log location gets its own name so this cell no longer
# overwrites `base_path_template`, which the ensemble cell reads (re-running that
# cell after this one would otherwise silently point it at the dense logs).
dense_path_template = './logs/CIFAR10/no_prune/20240611_q1_no_prune/seed_{}'
extension = '.pkl'

dataset_name = "CIFAR10"
batch_size = 128
# Fall back to CPU when no GPU is present; checkpoints are mapped onto this device below.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
NUM_CLASSES = 10  # presumably ResNet18's output-class count (CIFAR10); value from the original call
train_loader, test_loader = get_data_loaders(
    dataset_name=dataset_name, batch_size=batch_size
)

dense_accuracies = []
for seed in range(5):
    base_path = dense_path_template.format(seed)
    file_list = get_files_with_extension(base_path, extension)
    # sort_files keys on extract_number, which gives -1 to every path without a
    # `slth<N>_state` tag; pre-sorting by path makes the order among such ties
    # (and hence the checkpoint picked below) independent of os.walk order.
    filtered_sorted_files = sort_files(sorted(file_list))
    if not filtered_sorted_files:
        raise FileNotFoundError(f"no '{extension}' checkpoint found under {base_path}")
    checkpoint_path = filtered_sorted_files[0]
    if len(filtered_sorted_files) > 1:
        warnings.warn(
            f"seed {seed}: {len(filtered_sorted_files)} checkpoints found under "
            f"{base_path}; evaluating only {checkpoint_path}"
        )

    resnet = ResNet18(NUM_CLASSES).to(device)
    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted checkpoints.
    resnet.load_state_dict(torch.load(checkpoint_path, map_location=device))
    resnet.eval()

    total = 0
    correct = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            predicted = resnet(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    dense_accuracies.append(accuracy)
    print(f"seed {seed}: {accuracy}")

In [None]:
# Inspect the checkpoint paths collected by the most recently executed evaluation
# loop (when the notebook is run top-to-bottom: the dense baseline's last seed).
list(filtered_sorted_files)

In [None]:
import numpy as np

# Hand-copied test accuracies (%), five values per setting — presumably one per
# seed 0-4 from the evaluation loops above.
# NOTE(review): provenance is assumed — `dense` presumably from the no-prune cell,
# `output_ensemble` from the ensemble cell; `optimize_output_ensemble` comes from a
# run not shown in this notebook. Confirm before reporting.
dense, output_ensemble, optimize_output_ensemble = (
    np.array(row, dtype=np.float64)
    for row in (
        [94.29, 94.09, 93.9, 93.75, 93.7],
        [94.0, 94.11, 94.1, 94.23, 94.03],
        [94.12, 94.32, 93.94, 94.09, 94.12],
    )
)