In [None]:
# Change the working directory to three levels above the current one.
# The path is relative, so re-running this cell climbs three more levels each
# time; run it exactly once per kernel session.
# NOTE(review): presumably this lands on the repository root so that the `src.*`
# imports and the `./logs/...` paths in later cells resolve — confirm.
cd ../../../

In [None]:
from glob import glob
import os
from src.data import get_data_loaders
from src.models.resnet.resnet import ResNet18
from src.pruning.slth.edgepopup import modify_module_for_slth
import torch.nn.functional as F
import torch
import torch.nn as nn
import torch.optim as optim

In [None]:
import re

def get_files_with_extension(base_path, extension):
    """Recursively collect the files under ``base_path`` whose names end with ``extension``.

    :param base_path: directory that is walked with ``os.walk``
    :param extension: suffix handed to ``str.endswith`` (e.g. ``'.pkl'``)
    :return: list of paths, each the containing directory joined with the file
        name, in the order ``os.walk`` yields them
    """
    return [
        os.path.join(current_dir, name)
        for current_dir, _subdirs, names in os.walk(base_path)
        for name in names
        if name.endswith(extension)
    ]

def extract_number(file_path):
    """Return the integer N that appears as ``slth<N>_state`` in ``file_path``.

    :param file_path: string to search (typically a checkpoint path)
    :return: the digits between ``slth`` and ``_state`` of the first occurrence
        as an int, or ``-1`` when the pattern does not occur
    """
    found = re.search(r'slth(\d+)_state', file_path)
    if found is None:
        return -1
    return int(found.group(1))

def sort_files(file_list):
    """Return a new list with the paths of ``file_list`` ordered by ``extract_number``.

    The input is not modified. Paths for which ``extract_number`` returns ``-1``
    come first, and paths with equal numbers keep their input order because the
    sort is stable.
    """
    ordered = list(file_list)
    ordered.sort(key=extract_number)
    return ordered

def filter_and_sort_files(file_list, excluded_numbers=(4,)):
    """Drop files whose extracted number is excluded, then sort the rest by number.

    :param file_list: iterable of file paths
    :param excluded_numbers: collection of numbers (as returned by
        ``extract_number``) to leave out; defaults to ``(4,)``, the value that
        used to be hard-coded, so existing one-argument calls behave as before
    :return: new list of the remaining paths ordered by ``extract_number``;
        paths with equal numbers keep their input order
    """
    excluded = set(excluded_numbers)
    # Extract each number once and reuse it for both filtering and ordering.
    numbered = [(extract_number(path), path) for path in file_list]
    kept = [pair for pair in numbered if pair[0] not in excluded]
    # Sort on the number only so ties stay in input order (stable sort).
    kept.sort(key=lambda pair: pair[0])
    return [path for _, path in kept]

In [None]:
class EnsembleNet(nn.Module):
    """Small CNN that maps an image batch to one raw score per ensemble member.

    Three 3x3 convolutions (3 -> 16 -> 32 -> 64 channels), each followed by ReLU
    and 2x2 max pooling, feed two fully connected layers. ``fc1`` expects
    ``64 * 4 * 4`` flattened features, i.e. a 4x4 feature map after the three
    poolings (a 32x32 input produces exactly that).

    :param num_models: number of output scores produced by ``fc2``
    """

    def __init__(self, num_models):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
        # Convolutional layers
        self.conv1 = nn.Conv2d(3, 16, **conv_kwargs)
        self.conv2 = nn.Conv2d(16, 32, **conv_kwargs)
        self.conv3 = nn.Conv2d(32, 64, **conv_kwargs)
        # Pooling layer, reused after every convolution
        self.pool = nn.MaxPool2d(2, stride=2, padding=0)
        # Fully connected layers: 64 channels on a 4x4 map -> 128 -> num_models
        self.fc1 = nn.Linear(64 * 4 * 4, 128)
        self.fc2 = nn.Linear(128, num_models)

    def forward(self, x):
        """Return the unnormalised scores, shape ``(batch, num_models)``."""
        # Convolution + ReLU + pooling, three times
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.pool(F.relu(conv(x)))
        # Flatten everything except the batch dimension
        flat = x.view(x.size(0), -1)
        # Hidden fully connected layer + ReLU, then the output layer
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

In [None]:
dataset_name = "CIFAR10"
batch_size = 128
# Fall back to the CPU when CUDA is unavailable so the notebook still runs (slowly)
# instead of failing on the first `.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the train and test data loaders
train_loader, test_loader = get_data_loaders(dataset_name=dataset_name, batch_size=batch_size)

# Base path template for the logs
base_path_template = './logs/CIFAR10/is_prune/ensemble_output_diff_score/20240616_q2_1_remain_rate_010/seed_{}'
extension = '.pkl'

# One gating network is trained and evaluated per seed directory. After the loop,
# `models` and `ensemble_net` hold the objects of the LAST seed only.
for seed in range(5):
    base_path = base_path_template.format(seed)
    file_list = get_files_with_extension(base_path, extension)
    filtered_sorted_files = sort_files(file_list)
    if not filtered_sorted_files:
        # os.walk yields nothing for a missing or empty directory; without this check
        # the failure only surfaces later as an opaque torch.stack error on an empty list.
        raise FileNotFoundError(f"No '{extension}' checkpoint files found under {base_path}")

    # Build the list of frozen ensemble members
    models = []
    for file in filtered_sorted_files:
        resnet = ResNet18(10).to(device)
        # NOTE(review): 0.10 is presumably the remain rate (the log path says
        # remain_rate_010) — confirm against modify_module_for_slth.
        resnet = modify_module_for_slth(resnet, 0.10, is_print=False).to(device)
        # NOTE(security): torch.load unpickles the file; only load checkpoints you trust.
        # map_location lets checkpoints saved on another device load onto `device`.
        resnet.load_state_dict(torch.load(file, map_location=device))
        for param in resnet.parameters():
            param.requires_grad = False
        resnet.eval()
        models.append(resnet)

    ensemble_net = EnsembleNet(len(models)).to(device)

    # Loss function and optimizer (only the gating network is optimised)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(ensemble_net.parameters(), lr=0.001)

    # Training loop
    num_epochs = 10

    for epoch in range(num_epochs):
        ensemble_net.train()
        running_loss = 0.0
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)

            # Outputs of every frozen member; no gradients are needed for them
            with torch.no_grad():
                outputs = [model(images) for model in models]
            outputs = torch.stack(outputs, dim=-1)  # [batch_size, num_classes, num_models]

            # The small network's output is used as the per-sample ensemble weights
            optimizer.zero_grad()
            ensemble_weights = ensemble_net(images)  # [batch_size, num_models]
            ensemble_weights = nn.functional.softmax(ensemble_weights, dim=1)  # normalise over models

            # Weighted sum of the member outputs
            ensemble_output = torch.sum(outputs * ensemble_weights.unsqueeze(1), dim=-1)  # [batch_size, num_classes]

            loss = criterion(ensemble_output, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        # Uncomment to log the mean training loss per epoch:
        #print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader)}")

    # Evaluation on the test data
    ensemble_net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)

            # Outputs of every member
            outputs = [model(images) for model in models]
            outputs = torch.stack(outputs, dim=-1)

            # The small network's output is used as the per-sample ensemble weights
            ensemble_weights = ensemble_net(images)
            ensemble_weights = nn.functional.softmax(ensemble_weights, dim=1)

            # Weighted sum of the member outputs
            ensemble_output = torch.sum(outputs * ensemble_weights.unsqueeze(1), dim=-1)

            _, predicted = torch.max(ensemble_output, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Ensemble Model Accuracy: {accuracy}%")


In [None]:
import torch

import torch
import torch.nn.functional as F

def entropy(ensemble_outputs):
    """
    Compute the per-sample Shannon entropy (in bits) of a batch of distributions.

    The values are used as probabilities as-is: no softmax is applied here, so
    callers must pass normalised probabilities rather than raw logits.

    :param ensemble_outputs: tensor of shape (n_samples, n_classes); every entry
        is clamped into [1e-8, 1 - 1e-8] before the logarithm is taken
    :return: tensor of shape (n_samples,) holding -sum(p * log2(p)) over dim 1
    """
    # Clamp so that log2 never sees 0 (0 * log2(0) would evaluate to nan).
    probs = ensemble_outputs.clamp(min=1e-8, max=1 - 1e-8)
    return -(probs * probs.log2()).sum(dim=1)

In [None]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np

# Requires `ensemble_net`, `models`, `test_loader`, `device` and `entropy` from earlier cells.
# `ensemble_net` and `models` are the ones left behind by the training cell's loop,
# i.e. those of the last seed it processed.
ensemble_net.eval()
all_weights = []

uncertainty_list = []
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)

        outputs = [model(images) for model in models]
        outputs = torch.stack(outputs, dim=-1)

        ensemble_weights = ensemble_net(images)
        ensemble_weights = nn.functional.softmax(ensemble_weights, dim=1)
        ensemble_output = torch.sum(outputs * ensemble_weights.unsqueeze(1), dim=-1)
        # `entropy` treats its input as probabilities (it only clamps to [1e-8, 1 - 1e-8]),
        # whereas `ensemble_output` is the weighted sum of raw model outputs that the
        # training cell passes to CrossEntropyLoss as logits. Normalise over the class
        # dimension first so the entropy is that of the predictive distribution.
        ensemble_probs = F.softmax(ensemble_output, dim=1)
        uncertainty = entropy(ensemble_probs)
        uncertainty_list.append(uncertainty.cpu().numpy())

        all_weights.append(ensemble_weights.cpu().numpy())  # move to CPU before converting to numpy

# Stack the per-batch weight arrays into one (n_samples, n_models) array
all_weights = np.concatenate(all_weights, axis=0)

# Mean and standard deviation of each model's weight over all test samples
mean_weights = np.mean(all_weights, axis=0)
std_weights = np.std(all_weights, axis=0)

# Bar chart of the mean weights with std-dev error bars
plt.bar(range(len(mean_weights)), mean_weights, yerr=std_weights, capsize=5)
plt.xlabel('Model Index')
plt.ylabel('Ensemble Weights')
plt.title('Ensemble Weights Mean and Std Dev')
plt.show()


In [None]:
# Pool the per-batch arrays before computing statistics: averaging per-batch means
# over-weights a smaller final batch, and the std of per-batch stds is not the
# spread of the per-sample uncertainty scores.
all_uncertainty = np.concatenate(uncertainty_list, axis=0)
print(np.mean(all_uncertainty))
print(np.std(all_uncertainty))

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Large-format bar chart of the mean ensemble weights, styled with seaborn
sns.set(style="whitegrid")
plt.figure(figsize=(18, 12))
label_size = 32
positions = range(len(mean_weights))
bars = plt.bar(positions, mean_weights, yerr=std_weights, capsize=5, color='b')

# Write each mean value above its bar, horizontally centred on the bar
for rect, weight in zip(bars, mean_weights):
    center_x = rect.get_x() + rect.get_width() / 2
    plt.text(
        center_x,
        rect.get_height(),
        f'{weight:.4f}',
        ha='center',
        va='bottom',
        fontsize=label_size,
    )

plt.xlabel('Score Index', fontsize=label_size)
plt.ylabel('Ensemble Weights', fontsize=label_size)
plt.title('Ensemble Weights Mean and Std Dev', fontsize=label_size)
plt.xticks(fontsize=label_size)
plt.yticks(fontsize=label_size)
# Fixed y-range; bars or error bars reaching above 0.35 are cut off by this limit
plt.ylim(0, 0.35)
plt.show()
