In [34]:
from utils import count_nonzero_parameters, load_untrained_model, calculate_score, combined_pruning, nn, prune, get_macs
from ptflops import get_model_complexity_info
import torch

def combined_pruning(model: nn.Module, amount_structured: float, amount_unstructured: float):
    parameters_to_prune = []

    for module in model.modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            prune.ln_structured(module, name='weight', amount=amount_structured, n=2, dim=0)
            parameters_to_prune.append((module, 'weight')) 

    prune.global_unstructured(
        parameters_to_prune,
        pruning_method=prune.L1Unstructured,
        amount=amount_unstructured,
    )

    for module, _ in parameters_to_prune:
        prune.remove(module, 'weight')

def get_sparse_macs(model, input_size=(1, 3, 32, 32), device: str = "cuda"):
    macs = 0
    dummy_input = torch.randn(*input_size).to(device)
    
    def conv_hook(module, input, output):
        nonlocal macs
        batch_size, out_channels, out_h, out_w = output.shape
        weight = module.weight
        
        # Count non-zero weights
        nz = (weight != 0).float()
        nz_per_filter = nz.sum(dim=(1,2,3))  # [out_channels]
        
        # MACs = output_size * (non_zero_weights per filter)
        layer_macs = out_h * out_w * nz_per_filter.sum().item()
        macs += layer_macs

    def linear_hook(module, input, output):
        nonlocal macs
        weight = module.weight
        macs += (weight != 0).sum().item()  # Direct count of non-zero weights

    hooks = []
    for name, layer in model.named_modules():
        if isinstance(layer, nn.Conv2d):
            hooks.append(layer.register_forward_hook(conv_hook))
        elif isinstance(layer, nn.Linear):
            hooks.append(layer.register_forward_hook(linear_hook))

    model(dummy_input)
    
    for hook in hooks:
        hook.remove()
        
    return macs

In [29]:
dict_ = load_untrained_model("DenseNet121")
og_model = dict_["model"]

In [36]:
f, w = get_model_complexity_info(og_model, (3, 32, 32), as_strings=False,
                                         print_per_layer_stat=False, verbose=False)
print(f,w, get_sparse_macs(og_model), get_macs(og_model))
count_nonzero_parameters(og_model)

903157770 6956298 888350464.0 37758976


6914537

In [37]:
combined_pruning(og_model, amount_structured=0.2, amount_unstructured=0.3)

f, w = get_model_complexity_info(og_model, (3, 32, 32), as_strings=False,
                                         print_per_layer_stat=False, verbose=False)
print(f,w, get_sparse_macs(og_model), get_macs(og_model))
count_nonzero_parameters(og_model)

903157770 6956298 514262239.0 37758976


3900686

In [30]:
import pickle

with open("train_results_increasing_grouped_densenet121.pkl", "rb") as f:
    res = pickle.load(f)

In [41]:
from factorisation.densenet import get_increasing_grouped_densenet121, get_increasing_transition_grouped_densenet121


model = get_increasing_grouped_densenet121()
model.to("cuda")

DenseNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (dense1): Sequential(
    (0): GroupedCroissantBottleneck(
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), groups=64, bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
    )
    (1): GroupedCroissantBottleneck(
      (bn1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv1): Conv2d(96, 128, kernel_size=(1, 1), stride=(1, 1), groups=32, bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
    )
    (2): GroupedCroissantBottl

In [35]:
f, w = get_model_complexity_info(model, (3, 32, 32), as_strings=False,
                                         print_per_layer_stat=False, verbose=False)
print(f,w, get_sparse_macs(model), get_macs(model))
count_nonzero_parameters(model)

138422282 938762 123615232.0 37758976


897002

In [43]:
from utils import remove_pruning


combined_pruning(model, amount_structured=0.2, amount_unstructured=0.3)
# remove_pruning(model)

In [13]:
f, w = get_model_complexity_info(model, (3, 32, 32), as_strings=False,
                                         print_per_layer_stat=False, verbose=False)
print(f,w)
count_nonzero_parameters(model)

138422282 938762


521374

In [38]:
ops_ref, params_ref = get_sparse_macs(og_model), count_nonzero_parameters(og_model)

In [44]:
f, w = get_sparse_macs(model), count_nonzero_parameters(model)

calculate_score(
    0.2, 0.3, 16, 16, w, f, params_ref, ops_ref
)

0.09194188139972875

In [7]:
with open("train_results_increasing_transition_grouped_densenet121.pkl", "rb") as f:
    res = pickle.load(f)

res["acc"]

78.66

In [8]:
with open("train_results_transition_grouped_densenet121.pkl", "rb") as f:
    res = pickle.load(f)

res["acc"]

90.37

89.38

In [1]:
import torch

In [2]:
res = torch.load("train_checkpoint/model_train_grouped1.pth")

res["acc"]

89.38