In [1]:
import os
from models.resnet_real import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152
from models.resnet_quat import ResNet18_quat, ResNet34_quat, ResNet50_quat, ResNet101_quat, ResNet152_quat
from oth import *
from utils.pruning import prune_model
from tqdm import tqdm, trange

# Registry of model constructors: ResNet depth -> {"real": real-valued net,
# "quat": variant from models.resnet_quat (presumably quaternion-valued)}.
models = {
    depth: {"real": real_ctor, "quat": quat_ctor}
    for depth, real_ctor, quat_ctor in (
        (18, ResNet18, ResNet18_quat),
        (34, ResNet34, ResNet34_quat),
        (50, ResNet50, ResNet50_quat),
        (101, ResNet101, ResNet101_quat),
        (152, ResNet152, ResNet152_quat),
    )
}

# Use the first GPU when present; fall back to CPU so the notebook still runs
# (and `map_location=DEVICE` still works) on machines without CUDA.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Root directory holding the source checkpoints and the converted outputs.
save_path = "saved_models"

def rule(x: str) -> int:
    """Sort key for checkpoint filenames: the integer after the last underscore.

    The final three characters are dropped first (assumes a 3-character
    extension such as ``.pt`` -- TODO confirm all checkpoints use one), then
    the text after the last ``_`` is parsed as an int.

    Args:
        x: Checkpoint filename, e.g. ``"RN18_real_5.pt"``.

    Returns:
        The parsed integer, or 0 when that text is not a valid integer
        (e.g. ``"RN18_real.pt"``), so such files sort before any positive index.
    """
    a = x[:-3].split("_")[-1]
    try:
        return int(a)
    except ValueError:
        # Only treat parse failures as "no index"; a bare `except:` would also
        # hide KeyboardInterrupt and genuine bugs.
        return 0

In [2]:
# For every depth / variant: load each saved state dict (ordered by `rule`)
# into a model instance and re-save the whole module under `<name>_prune`.
for model_type in models:
    for realorquat in models[model_type]:
        # Fresh model with random weights; overwritten by each loaded checkpoint below.
        m = models[model_type][realorquat](num_classes=100, name=f"RN{model_type}_{realorquat}").to(DEVICE)
        load_from = f"{save_path}/RN{model_type}"
        save_to = f"{save_path}/{m.name}_prune"
        # exist_ok=True keeps the cell re-runnable (os.mkdir raised FileExistsError).
        os.makedirs(save_to, exist_ok=True)
        pruned = False
        for pruneV in tqdm(
            sorted(
                [x for x in os.listdir(load_from) if realorquat in x],
                key=rule
            ),
            desc=f"RN{model_type}_{realorquat}",
            unit="models"
        ):
            m.load_state_dict(torch.load(f"{load_from}/{pruneV}", map_location=DEVICE))
            torch.save(m, f"{save_to}/{pruneV}")
            if not pruned:
                # Prune once, after the first (lowest-`rule`) checkpoint has been saved.
                # NOTE(review): presumably this makes the model's parameter layout
                # match the later, already-pruned checkpoints -- confirm in utils.pruning.
                prune_model(m, 0.99)
                pruned = True

RN18_real: 100%|██████████| 21/21 [00:03<00:00,  5.72models/s]
RN18_quat: 100%|██████████| 21/21 [00:02<00:00,  8.07models/s]
RN34_real: 100%|██████████| 21/21 [00:06<00:00,  3.04models/s]
RN34_quat: 100%|██████████| 21/21 [00:04<00:00,  4.38models/s]
RN50_real: 100%|██████████| 21/21 [00:08<00:00,  2.57models/s]
RN50_quat: 100%|██████████| 21/21 [00:06<00:00,  3.30models/s]
RN101_real: 100%|██████████| 23/23 [00:17<00:00,  1.30models/s]
RN101_quat: 100%|██████████| 21/21 [00:12<00:00,  1.66models/s]
RN152_real: 100%|██████████| 23/23 [00:25<00:00,  1.12s/models]
RN152_quat: 100%|██████████| 18/18 [00:15<00:00,  1.14models/s]
