In [None]:
cd ../../../

In [None]:
from glob import glob
import os
from src.data import get_data_loaders
from src.models.resnet.resnet import ResNet18
from src.pruning.slth.edgepopup import modify_module_for_slth
from tqdm import tqdm
import torch
import numpy as np

# --- Experiment configuration (read as globals by the cells below) ---
# NOTE(review): `remain_rate` is not read anywhere in the visible cells; the
# evaluation cell passes np.mean(overlap_rates) as the remain rate instead.
remain_rate = 0.3
n_class = 10  # passed to ResNet18(...) in the evaluation cell
dataset_name = "CIFAR10"  # forwarded to get_data_loaders
batch_size = 128  # forwarded to get_data_loaders
device = "cuda"  # device for the merged masks written by main() and for the model
# The percentage is currently passed straight to main(...) rather than set here:
#percentage = 0.8

def get_files_with_extension_recursively(base_path: str, extension: str):
    """Collect every path below ``base_path`` whose name ends with ``extension``.

    Args:
        base_path: Directory at which the recursive search starts.
        extension: Extension to match, with or without the leading dot
            (``"pkl"`` and ``".pkl"`` are equivalent).

    Returns:
        A lexicographically sorted list of matching paths (empty when nothing
        matches).
    """
    if not extension.startswith('.'):
        extension = '.' + extension
    search_pattern = os.path.join(base_path, '**', '*' + extension)
    # glob() yields matches in an arbitrary, filesystem-dependent order. Sorting
    # makes the result reproducible, which matters to callers that treat the
    # first entry specially.
    return sorted(glob(search_pattern, recursive=True))

def get_top_indices(tensor, percentage):
    """Mark the entries of ``tensor`` that lie in its top ``percentage`` fraction.

    The cut-off is the ``100 * (1 - percentage)``-th percentile of all values
    (``np.percentile``); every entry greater than or equal to it is selected.
    Entries tied with the cut-off are all kept, so the selected fraction can be
    larger than ``percentage``.

    Args:
        tensor: Tensor of values to rank. It may live on the GPU and may
            require grad; it is detached and copied to the CPU for ranking.
        percentage: Fraction in ``[0, 1]`` of the largest values to keep.

    Returns:
        ``(mask, indices)`` where ``mask`` is a float32 CPU tensor with the
        shape of ``tensor`` holding 1.0 at selected positions and 0.0 elsewhere,
        and ``indices`` is the tuple returned by ``torch.where`` for the
        selected positions.

    Raises:
        ValueError: If ``percentage`` is outside ``[0, 1]``.
    """
    if not 0.0 <= percentage <= 1.0:
        raise ValueError(f"percentage must be within [0, 1], got {percentage}")
    # .numpy() rejects tensors that require grad and tensors on the GPU, so
    # drop the autograd link and move to host memory first.
    values = tensor.detach().cpu()
    threshold = float(np.percentile(values.numpy(), 100 * (1 - percentage)))
    selected = values >= threshold
    top_indices = torch.where(selected)
    top_indices_tensor = selected.to(torch.float32)
    return top_indices_tensor, top_indices

def main(base_path, percentage):
    """Intersect the top-``percentage`` score masks of all checkpoints under ``base_path``.

    Every ``.pkl`` file below ``base_path`` is loaded with ``torch.load``. For
    each entry whose key contains ``"scores"``, a top-``percentage`` mask is
    computed per file with ``get_top_indices``. The masks of one key are then
    AND-ed: a position survives only if it was selected in every file. The
    surviving mask replaces that key in the state dict of the first (sorted)
    checkpoint, and its per-key overlap rate is printed.

    Args:
        base_path: Directory searched recursively for ``.pkl`` checkpoints.
        percentage: Fraction in ``[0, 1]`` of the highest scores kept per
            tensor and per checkpoint.

    Returns:
        ``(score_indices, score_tensors, base_weight, overlap_rates)``:
        ``score_indices`` / ``score_tensors`` map each score key to a list
        (one item per checkpoint) of ``torch.where`` index tuples / flattened
        float32 masks; ``base_weight`` is the first checkpoint's dict with the
        score entries replaced by the intersected float32 masks on the global
        ``device``; ``overlap_rates`` lists the surviving fraction per score
        key, in key order.

    Raises:
        FileNotFoundError: If no ``.pkl`` file exists below ``base_path``.
    """
    # Sort so that the checkpoint used as the base state dict is the same on
    # every run, independent of the order in which the filesystem lists files.
    files = sorted(get_files_with_extension_recursively(base_path, '.pkl'))
    if not files:
        raise FileNotFoundError(f"No .pkl files found under {base_path!r}")

    score_indices = {}
    score_tensors = {}
    shapes = {}
    overlap_rates = []
    base_weight = None

    for file in files:
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code -- only point this at checkpoints you produced yourself.
        data = torch.load(file)
        if base_weight is None:
            # Reuse the first checkpoint as the base instead of loading it twice.
            base_weight = data
        for k, v in data.items():
            if "scores" not in k:
                continue
            shapes[k] = v.shape
            tensor, indices = get_top_indices(v.flatten(), percentage)
            score_tensors.setdefault(k, []).append(tensor)
            score_indices.setdefault(k, []).append(indices)

    for k, masks in score_tensors.items():
        # A position is kept only when every checkpoint selected it.
        binary = torch.stack(masks).sum(0) == len(files)
        base_weight[k] = binary.reshape(shapes[k]).type(torch.float32).to(device)
        # `binary` is flat here, so numel() is the number of score entries.
        overlap_rate = binary.sum() / binary.numel()
        overlap_rates.append(overlap_rate.item())
        print(f"{k}: {overlap_rate:.4f}")

    return score_indices, score_tensors, base_weight, overlap_rates

# Example usage
# Directory that main() searches recursively for the .pkl checkpoints to compare.
base_path = "./logs/CIFAR10/is_prune/ensemble_output_diff_score/20240605_q2_1/seed_0/2024_06_07_09_47_21"
# /logs/CIFAR10/is_prune/ensemble_output_diff_score/20240605_q2_1/seed_0/2024_06_07_09_47_21
# Keep the top 80% of the scores in each checkpoint, then intersect the masks
# across checkpoints; the per-layer overlap rates are printed by main().
score_indices, score_tensors, base_weight, overlap_rates = main(base_path, 0.8)



In [None]:
# Show the names of the score entries collected by main().
list(score_tensors)

In [None]:

# "./logs/CIFAR10/is_prune/baseline/20240606_q1/remain_rate_90/"

# Score-entry names used as the x-axis categories of the bar plot below.
# The p_90 / p_70 / p_50 lists are matched to these names by position.
# NOTE(review): presumably copied from the output of list(score_tensors.keys())
# in the previous cell -- confirm it still matches the checkpoints in use.
layers = ['conv.scores',
 'layer1.0.conv1.scores',
 'layer1.0.conv2.scores',
 'layer1.1.conv1.scores',
 'layer1.1.conv2.scores',
 'layer2.0.conv1.scores',
 'layer2.0.conv2.scores',
 'layer2.0.shortcut.0.scores',
 'layer2.1.conv1.scores',
 'layer2.1.conv2.scores',
 'layer3.0.conv1.scores',
 'layer3.0.conv2.scores',
 'layer3.0.shortcut.0.scores',
 'layer3.1.conv1.scores',
 'layer3.1.conv2.scores',
 'layer4.0.conv1.scores',
 'layer4.0.conv2.scores',
 'layer4.0.shortcut.0.scores',
 'layer4.1.conv1.scores',
 'layer4.1.conv2.scores',
 'fc.scores']

# Per-layer values plotted as the 'thresh_90' series (one entry per name in `layers`).
# NOTE(review): hard-coded; presumably the overlap_rates returned by an earlier
# main(...) run with percentage 0.9 -- confirm the provenance before reuse.
p_90 = [0.5833333134651184,
 0.5920952558517456,
 0.5889756679534912,
 0.5915798544883728,
 0.5890842080116272,
 0.5913628339767456,
 0.5923936367034912,
 0.59326171875,
 0.5921698808670044,
 0.5921630859375,
 0.5916035771369934,
 0.5912339687347412,
 0.590972900390625,
 0.5909152626991272,
 0.5910322666168213,
 0.5913069248199463,
 0.5915772914886475,
 0.5910263061523438,
 0.5910623073577881,
 0.5910780429840088,
 0.6025390625]

# Per-layer values plotted as the 'thresh_70' series (one entry per name in `layers`).
# NOTE(review): hard-coded; presumably the overlap_rates returned by an earlier
# main(...) run with percentage 0.7 -- confirm the provenance before reuse.
p_70 = [0.17245370149612427,
 0.1695421040058136,
 0.1655544638633728,
 0.1667751669883728,
 0.1663953959941864,
 0.1701524555683136,
 0.1683485209941864,
 0.1728515625,
 0.1663750559091568,
 0.1691012978553772,
 0.1679484099149704,
 0.1676923930644989,
 0.169219970703125,
 0.1679840087890625,
 0.167724609375,
 0.168121337890625,
 0.16793949902057648,
 0.16888427734375,
 0.16820822656154633,
 0.16820484399795532,
 0.18828125298023224]

# Per-layer values plotted as the 'thresh_50' series (one entry per name in `layers`).
# NOTE(review): hard-coded; presumably the overlap_rates returned by an earlier
# main(...) run with percentage 0.5 -- confirm the provenance before reuse.
p_50 = [0.032407406717538834,
 0.0330403633415699,
 0.0305447056889534,
 0.0312771275639534,
 0.03106011264026165,
 0.0321723073720932,
 0.0317179374396801,
 0.0325927734375,
 0.0309583880007267,
 0.0312839075922966,
 0.0314364954829216,
 0.03124152310192585,
 0.031829833984375,
 0.0311601422727108,
 0.03113810159265995,
 0.03155517578125,
 0.031257204711437225,
 0.03125,
 0.031135134398937225,
 0.03125,
 0.03203124925494194]


# NOTE(review): presumably the test accuracies (in %) printed by the evaluation
# cell for the 0.9 / 0.7 / 0.5 runs -- confirm. These names are not read by any
# code visible in this notebook.
p_90_acc = 11.54
p_70_acc = 8.93
p_50_acc = 9.81

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# Build a wide table: one row per layer, one column per threshold.
data = pd.DataFrame({
    'Layers': layers,
    'thresh_90': p_90,
    'thresh_70': p_70,
    'thresh_50': p_50
})

# Reshape to long format (one row per layer/threshold pair) so seaborn can
# group the bars with `hue`.
data_melted = data.melt(id_vars='Layers', var_name='Percentile', value_name='Values')

# Draw on an explicit Figure/Axes pair instead of pyplot's implicit
# "current figure", so the cell cannot accidentally draw onto a stale figure.
fig, ax = plt.subplots(figsize=(18, 12))
sns.barplot(x='Layers', y='Values', hue='Percentile', data=data_melted, ax=ax)

# Font sizes are large because the figure itself is large (18 x 12 inches).
ax.set_xlabel('Layers', fontsize=32)
ax.set_ylabel('Values', fontsize=32)
ax.set_title('Overlap Values for different layers at thresh=90, 70, 50', fontsize=32)
# The layer names are long, so rotate them to keep them from overlapping.
ax.tick_params(axis='x', labelrotation=90, labelsize=28)
ax.tick_params(axis='y', labelsize=28)
ax.legend(fontsize=28)
ax.grid(True)
fig.tight_layout()

plt.show()



In [None]:
# Build the data loaders; only the test split is used in this cell.
train_loader, test_loader = get_data_loaders(
    dataset_name=dataset_name, batch_size=batch_size
)

# Rebuild the network, convert it with modify_module_for_slth using the mean
# per-layer overlap as the remain rate, and load the state dict whose score
# entries were replaced by the intersected masks in main().
resnet = ResNet18(n_class).to(device)
resnet = modify_module_for_slth(resnet, remain_rate=np.mean(overlap_rates)).to(device)
resnet.load_state_dict(base_weight)
# Switch to inference mode before measuring accuracy. Without this, any
# BatchNorm/Dropout layers in the model would run in training mode, so the
# result would depend on batch composition and running statistics would be
# modified during evaluation.
resnet.eval()

correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = resnet(images)
        # Predicted class = index of the largest output along dim 1.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Top-1 accuracy on the test loader, in percent.
acc = 100 * correct / total
print(acc)