In [None]:
cd ../../../

In [None]:
from glob import glob
import os
from src.data import get_data_loaders
from src.models.resnet.resnet import ResNet18
from src.pruning.slth.edgepopup import modify_module_for_slth
import torch

In [None]:
# Experiment settings used by the cells below.
dataset_name = "CIFAR10"
batch_size = 128
device = 'cuda'

# Baseline SLTH ResNet-18 checkpoint. The 0.3 passed to modify_module_for_slth
# presumably corresponds to the `remain_rate_30` directory in the path — confirm.
no_prune_weight = "./logs/CIFAR10/is_prune/baseline/20240606_q1/remain_rate_30/seed_0/2024_06_07_00_03_15/resnet_slth_state.pkl"
resnet = ResNet18(10).to(device)  # 10 output classes, presumably for CIFAR-10
# Move to `device` again after the SLTH wrapping (as the later cells do), in case the
# wrapper creates new tensors on the CPU.
resnet = modify_module_for_slth(resnet, 0.3).to(device)
# map_location puts the checkpoint tensors on `device` no matter where they were saved.
resnet.load_state_dict(torch.load(no_prune_weight, map_location=device))

In [None]:
# Experiment settings (repeated so this cell can run on its own).
dataset_name = "CIFAR10"
batch_size = 128
device = 'cuda'

# Reload the baseline SLTH ResNet-18 checkpoint.
no_prune_weight = "./logs/CIFAR10/is_prune/baseline/20240606_q1/remain_rate_30/seed_0/2024_06_07_00_03_15/resnet_slth_state.pkl"
resnet = ResNet18(10).to(device)
resnet = modify_module_for_slth(resnet, 0.3).to(device)
resnet.load_state_dict(torch.load(no_prune_weight, map_location=device))
# Sanity check that the weights were loaded: print the first slice of `conv.weight`.
print(resnet.state_dict()['conv.weight'][0])

train_loader, test_loader = get_data_loaders(
    dataset_name=dataset_name, batch_size=batch_size
)

# Top-1 accuracy of the FP32 model on the test loader.
resnet.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        output = resnet(images)

        # Index of the highest logit per sample (no `.data` needed under no_grad).
        predicted = output.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

acc = 100 * correct / total

print(f"Accuracy: {acc}%")

In [None]:
import torch

# Define the model and load its weights.
no_prune_weight = "./logs/CIFAR10/is_prune/baseline/20240606_q1/remain_rate_30/seed_0/2024_06_07_00_03_15/resnet_slth_state.pkl"
resnet = ResNet18(10).to(device)
resnet = modify_module_for_slth(resnet, 0.3).to(device)
resnet.load_state_dict(torch.load(no_prune_weight, map_location=device))


def _model_tensor_mb(model):
    """Return the size in MiB of the model's own parameters and buffers."""
    # parameters() deduplicates shared tensors, so tied weights are counted once.
    tensors = list(model.parameters()) + list(model.buffers())
    return sum(t.numel() * t.element_size() for t in tensors) / 1024**2


# torch.cuda.memory_allocated counts every live tensor on the device, including
# leftovers such as images/outputs from earlier cells, so it overstates the model's
# footprint. Report the model's own tensor size alongside it.
# Memory usage (FP32)
print(f"FP32 Memory Usage: {torch.cuda.memory_allocated(device) / 1024**2} MB")
print(f"FP32 Model Tensors: {_model_tensor_mb(resnet)} MB")

# Convert the whole model to FP16 in place.
resnet.half()

# Memory usage (FP16)
print(f"FP16 Memory Usage: {torch.cuda.memory_allocated(device) / 1024**2} MB")
print(f"FP16 Model Tensors: {_model_tensor_mb(resnet)} MB")

# Accuracy evaluation is omitted in this cell.


In [None]:
# Define the model and load the baseline checkpoint.
no_prune_weight = "./logs/CIFAR10/is_prune/baseline/20240606_q1/remain_rate_30/seed_0/2024_06_07_00_03_15/resnet_slth_state.pkl"
resnet = ResNet18(10).to(device)
resnet = modify_module_for_slth(resnet, 0.3).to(device)
resnet.load_state_dict(torch.load(no_prune_weight, map_location=device))


# Convert the whole model to FP16.
resnet.half()

# Load the test data.
train_loader, test_loader = get_data_loaders(
    dataset_name=dataset_name, batch_size=batch_size
)

# Switch to evaluation mode.
resnet.eval()

# Top-1 accuracy of the FP16 model.
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        # Inputs must be FP16 too, to match the converted weights.
        images = images.to(device).half()
        labels = labels.to(device)

        # Inference
        output = resnet(images)

        # Prediction: index of the highest logit per sample.
        predicted = output.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    acc = 100 * correct / total

print(f"Accuracy: {acc}%")


In [None]:
import re

def get_files_with_extension(base_path, extension):
    """Recursively collect paths under ``base_path`` whose file name ends with ``extension``.

    Each result is ``os.path.join(dirpath, name)``, listed in ``os.walk`` traversal order.
    A nonexistent ``base_path`` yields an empty list, because ``os.walk`` ignores errors
    by default.
    """
    return [
        os.path.join(dirpath, name)
        for dirpath, _subdirs, names in os.walk(base_path)
        for name in names
        if name.endswith(extension)
    ]

# Base path template for the ensemble checkpoints; `{}` is filled with the seed index.
base_path_template = './logs/CIFAR10/is_prune/ensemble_output_diff_score/20240605_q2_1/seed_{}'
extension = '.pkl'

def extract_number(file_path):
    """Return the integer N from the first ``slth<N>_state`` fragment in ``file_path``.

    Returns -1 when the path contains no such fragment.
    """
    match = re.search(r'slth(\d+)_state', file_path)
    return int(match.group(1)) if match else -1

def filter_and_sort_files(file_list, excluded_numbers=(4,)):
    """Drop checkpoints whose model number is in ``excluded_numbers``; sort the rest by number.

    Args:
        file_list: checkpoint paths, numbered via ``extract_number``.
        excluded_numbers: model numbers to leave out (default ``(4,)``, which matches the
            previously hard-coded exclusion).

    Returns:
        A new list sorted by ascending model number. The sort is stable, so ties keep
        their input order. Paths without a ``slth<N>_state`` fragment get -1: they are
        kept and sort first.
    """
    excluded = set(excluded_numbers)
    # Compute each number once and reuse it for both filtering and sorting.
    numbered = [(extract_number(path), path) for path in file_list]
    return [
        path
        for number, path in sorted(numbered, key=lambda pair: pair[0])
        if number not in excluded
    ]

In [None]:
# Mean-logit ensemble accuracy over the checkpoints of each seed.
for seed in range(1):
    base_path = base_path_template.format(seed)
    file_list = get_files_with_extension(base_path, extension)
    filtered_sorted_files = filter_and_sort_files(file_list)
    if not filtered_sorted_files:
        # Otherwise torch.stack below fails on an empty list with an unclear error.
        raise FileNotFoundError(f"No '*{extension}' checkpoints found under {base_path}")

    # Build one SLTH ResNet-18 per checkpoint.
    models = []
    for ckpt_path in filtered_sorted_files:
        resnet = ResNet18(10).to(device)
        resnet = modify_module_for_slth(resnet, 0.3).to(device)
        resnet.load_state_dict(torch.load(ckpt_path, map_location=device))
        resnet.eval()
        models.append(resnet)

    # Reset the counters for every seed. Previously they carried over from the earlier
    # evaluation cell (and from previous seeds), which corrupted the reported accuracy.
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)

            # Average the per-model logits into a single ensemble prediction.
            ensemble_output = torch.stack([model(images) for model in models]).mean(dim=0)

            # Index of the highest averaged logit per sample.
            predicted = ensemble_output.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(accuracy)

In [None]:
# 93.29 to 92.805