In [None]:
# Move up three directories — presumably to the repository root, so that the
# `src.*` imports in the next cell resolve. TODO confirm notebook location.
cd ../../../


In [None]:
from glob import glob
import os
from src.data import get_data_loaders
from src.models.resnet.resnet import ResNet18
from src.pruning.slth.edgepopup import modify_module_for_slth
from tqdm import tqdm
import torch
import numpy as np
import copy
import re

def extract_number(file_path):
    """Return N from a path containing ``slthN_state.pkl``; +inf if absent.

    Intended as a ``sorted`` key: checkpoints are ordered by N, and paths that
    do not follow the naming scheme sort last.
    """
    found = re.search(r'slth(\d+)_state\.pkl', file_path)
    if found is None:
        return float('inf')
    return int(found.group(1))

# Sort key: order checkpoint files by the numeric part of their names.

# --- Experiment configuration ---
dataset_name = "CIFAR10"
n_class = 10          # output classes passed to ResNet18
batch_size = 128      # batch size for get_data_loaders
device = "cuda"       # models and batches are moved here
remain_rate = 0.3     # remain_rate value intended for modify_module_for_slth
# Alternative fraction of top scores once tried for main(): percentage = 0.8

def get_files_with_extension_recursively(base_path: str, extension: str):
    """List every file under ``base_path`` (any depth) ending in ``extension``.

    ``extension`` may be given with or without its leading dot. The order of
    the returned paths is whatever ``glob`` yields; sort it if order matters.
    """
    suffix = extension if extension.startswith('.') else '.' + extension
    pattern = os.path.join(base_path, '**', '*' + suffix)
    return glob(pattern, recursive=True)

def get_top_indices(tensor, percentage):
    """Select the entries of ``tensor`` that fall in its top ``percentage`` fraction.

    The cut-off is ``np.percentile(values, 100 * (1 - percentage))`` (NumPy's
    default linear interpolation). Every entry ``>=`` that cut-off is selected,
    so ties at the threshold can select slightly more than ``percentage`` of
    the entries.

    Args:
        tensor: Score tensor on any device; it is detached and copied to CPU.
        percentage: Fraction in [0, 1] of entries to keep.

    Returns:
        ``(mask, indices)``. ``mask`` is a float32 CPU tensor shaped like
        ``tensor``, with 1.0 at selected positions and 0.0 elsewhere.
        ``indices`` is the tuple of index tensors returned by ``torch.where``.
    """
    # detach(): .numpy() refuses tensors that require grad.
    # cpu(): handles any non-CPU device, not just CUDA.
    tensor = tensor.detach().cpu()
    threshold = np.percentile(tensor.numpy(), 100 * (1 - percentage))
    top_indices = torch.where(tensor >= threshold)
    top_indices_tensor = torch.zeros(tensor.shape, dtype=torch.float32)
    top_indices_tensor[top_indices] = 1.0
    return top_indices_tensor, top_indices

def main(base_path, percentage):
    """Intersect the top-``percentage`` score masks of every checkpoint under ``base_path``.

    Every ``.pkl`` file found recursively below ``base_path`` is loaded with
    ``torch.load``, in ``extract_number`` order (the same order used for the
    notebook-level ``files`` list). For each key containing ``"scores"``, a
    binary mask of the top ``percentage`` fraction of entries is built with
    ``get_top_indices``. The intersection keeps the positions selected in
    every checkpoint that contains that key.

    Args:
        base_path: Directory searched recursively for ``.pkl`` checkpoints.
        percentage: Fraction in [0, 1] of highest scores kept per checkpoint.

    Returns:
        ``(score_indices, score_tensors, base_weight, overlap_rates)``:

        - ``score_indices``: maps each score key to a per-file list of
          ``torch.where`` index tuples.
        - ``score_tensors``: maps each score key to a per-file list of float
          masks.
        - ``base_weight``: the first checkpoint's dict, with each score entry
          replaced by the float32 intersection mask (original shape, moved to
          ``device``).
        - ``overlap_rates``: for each score key, the fraction of positions
          that survive the intersection.

    Raises:
        FileNotFoundError: if no ``.pkl`` file exists under ``base_path``.
    """
    # Sort so the "first" checkpoint (used as the template) is deterministic.
    files = sorted(
        get_files_with_extension_recursively(base_path, '.pkl'),
        key=extract_number,
    )
    if not files:
        raise FileNotFoundError(f"No .pkl files found under {base_path!r}")

    score_indices = {}
    score_tensors = {}
    shapes = {}
    overlap_rates = []
    base_weight = None

    for file in files:
        data = torch.load(file)
        if base_weight is None:
            # Reuse the first checkpoint as the template instead of loading it twice.
            base_weight = data
        for k, v in data.items():
            if "scores" not in k:
                continue
            shapes[k] = v.shape
            tensor, indices = get_top_indices(v.flatten(), percentage)
            score_tensors.setdefault(k, []).append(tensor)
            score_indices.setdefault(k, []).append(indices)

    for k, masks in score_tensors.items():
        # A position survives only if every collected mask for this key selects it.
        # Compare against the number of masks actually stacked, not the file count.
        binary = torch.stack(masks).sum(0) == len(masks)
        base_weight[k] = binary.reshape(shapes[k]).type(torch.float32).to(device)
        overlap_rate = binary.sum().item() / binary.numel()
        overlap_rates.append(overlap_rate)
        print(f"{k}: {overlap_rate:.4f}")

    return score_indices, score_tensors, base_weight, overlap_rates

# Directory holding the checkpoints of one ensemble run (each .pkl is loaded
# with torch.load as a name -> tensor dict).
base_path = "./logs/CIFAR10/is_prune/ensemble_output_diff_weight/20240628_q2_3/seed_0/2024_06_29_17_07_43"
# Order checkpoints by the N in "...slthN_state.pkl", so index i below always
# refers to the same file.
# To exclude a checkpoint, filter before sorting, e.g. drop 'resnet_slth4_state.pkl'.
files = sorted(
    get_files_with_extension_recursively(base_path, '.pkl'),
    key=extract_number,
)
score_indices, score_tensors, base_weight, overlap_rates = main(base_path, 0.3)


In [None]:
# Display the ordered checkpoint list (this is the row/column order of `accs` below).
files

In [None]:
train_loader, test_loader = get_data_loaders(
    dataset_name=dataset_name, batch_size=batch_size
)


def _evaluate_accuracy(model, loader):
    """Return the top-1 accuracy (in percent) of ``model`` on ``loader``.

    NOTE(review): the model is deliberately NOT switched to ``eval()``, which
    matches the original notebook. A freshly built module is in training mode
    (presumably still true after modify_module_for_slth). Any BatchNorm layers
    therefore normalise with per-batch statistics and update their running
    stats. That may be intentional: summing two checkpoints' weights changes
    the activation scale, so the loaded running stats would not match.
    Confirm before adding ``model.eval()``.
    """
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            _, predicted = torch.max(model(images), 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total


accs = np.zeros((len(files), len(files)))

# For each ordered pair (file1, file2):
#   - start from file1's full state dict, including its scores;
#   - add file2's tensor onto every key containing "weight" (which presumably
#     also matches BatchNorm weights; verify);
#   - measure test accuracy.
for idx_1, file1 in enumerate(files):
    for idx_2, file2 in enumerate(files):
        # Reload file1 every time: the += below modifies its tensors in place.
        base_weight = torch.load(file1)
        f2_w = torch.load(file2)
        for k, v in f2_w.items():
            if "weight" in k:
                base_weight[k] += v

        resnet = ResNet18(n_class).to(device)
        resnet = modify_module_for_slth(resnet, remain_rate=remain_rate, is_print=False).to(device)
        resnet.load_state_dict(base_weight)

        acc = _evaluate_accuracy(resnet, test_loader)
        print(idx_1, idx_2, acc)
        accs[idx_1, idx_2] = acc

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Heatmap of the pairwise accuracy matrix: row i = checkpoint whose state dict
# was loaded, column j = checkpoint whose weights were added on top.
label_size = 32
plt.figure(figsize=(24, 18))
heatmap = sns.heatmap(
    accs,
    annot=True,
    fmt=".4f",
    cmap="viridis",
    cbar=True,
    annot_kws={"size": label_size},
)

plt.title("ACC Heatmap", fontsize=label_size)
for set_axis_label in (plt.xlabel, plt.ylabel):
    set_axis_label("Initialization Mode", fontsize=label_size)

plt.xticks(fontsize=label_size)
plt.yticks(fontsize=label_size)

cbar = heatmap.collections[0].colorbar
cbar.ax.tick_params(labelsize=label_size)

plt.tight_layout()
plt.show()