In [1]:
!pip install torch-pruning 





In [2]:
import gc
import os
from copy import deepcopy

import matplotlib.pyplot as plt
import torch
import torch.quantization
import torch_pruning as tp
from torchvision import models

from utils.utils import blockPrint, enablePrint

# Root of the Edge_Profiler checkout. Override it with the EDGE_PROFILER_ROOT
# environment variable so the notebook is not tied to one user's machine.
_PROJECT_ROOT = os.environ.get(
    "EDGE_PROFILER_ROOT", "C:\\Users\\Sejio27\\Documents\\GitHub\\Edge_Profiler"
)
PATH = os.path.join(_PROJECT_ROOT, ".model")       # directory of saved *.pt models
DATA_PATH = os.path.join(_PROJECT_ROOT, ".data")   # dataset root passed to preprocess_data

### Configure Device, Profiler, and Dataset

In [3]:
from profiler.cuda import CUDA
from dataset.imagenet import preprocess_data, validate_accuracy

# Input geometry shared by the rest of the notebook: one 224x224 image per batch.
batch_size = 1
image_w = image_h = 224

# Every model and tensor below is placed on the GPU.
device = torch.device('cuda')

# FLOP-tracking profiler: energy tracking and warm-up are off, console output is muted.
prof = CUDA(
    track_energy=False,
    track_flops=True,
    disable_warmup=True,
)
prof.disable_print = True

# ImageNet loader; the second argument (256) is presumably the loader batch
# size — TODO confirm against dataset.imagenet.preprocess_data.
dataloader = preprocess_data(DATA_PATH, 256, (image_w, image_h))

### Calculate FLOPs

In [None]:
def evaluate_flops(model, inputs, profiler=None):
    """Measure the FLOPs of one forward pass of `model` on `inputs`.

    The pass runs in eval mode under `torch.no_grad()`. This stops the
    measurement from building an autograd graph or updating BatchNorm running
    statistics of the model being pruned. The model's original train/eval
    mode is restored afterwards, even if the forward pass raises.

    Args:
        model: torch.nn.Module to profile.
        inputs: value passed directly to `model(...)`.
        profiler: object exposing `start_profiling()`, `stop_profiling()` and
            `total_flops()`. Defaults to the notebook-level `prof`.

    Returns:
        The value reported by `profiler.total_flops()` for this pass.
    """
    if profiler is None:
        profiler = prof

    was_training = model.training
    model.eval()
    try:
        profiler.start_profiling()
        try:
            with torch.no_grad():
                model(inputs)
        finally:
            # Stop even on failure so the profiler is not left running.
            profiler.stop_profiling()
    finally:
        model.train(was_training)

    return profiler.total_flops()

### FLOPs-Based Pruning with DepGraph

In [4]:
def prune_with_DepGraph(model, inputs, flops_threshold, pruning_ratio=0.01,
                        importance=None, max_iterations=1000):
    """Iteratively prune `model` with DepGraph until its FLOPs drop by `flops_threshold`.

    Each iteration snapshots the model and builds a fresh torch-pruning pruner.
    The pruner removes `pruning_ratio` of the remaining prunable channels, then
    FLOPs are re-measured with `evaluate_flops`. Any `torch.nn.Linear` with
    1000 output features (the classifier head) is never pruned.

    If the last step pushes FLOPs below the target, the snapshot taken before
    that step is returned instead. The result therefore never undershoots the
    target and may sit slightly above it.

    Args:
        model: torch.nn.Module to prune. It is modified in place.
        inputs: example input used to trace the dependency graph, to compute
            the Taylor dummy loss, and to measure FLOPs.
        flops_threshold: fraction of the base FLOPs to remove, in (0, 1).
            For example, 0.05 targets 95% of the base FLOPs.
        pruning_ratio: fraction of channels removed per iteration, in (0, 1).
        importance: torch-pruning importance criterion. Defaults to
            `tp.importance.TaylorImportance()`.
        max_iterations: upper bound on pruning steps.

    Returns:
        The pruned model: either `model` itself or a deep copy of its state
        before the final step.

    Raises:
        ValueError: if `flops_threshold` or `pruning_ratio` is outside (0, 1).
    """
    # Fail loudly: returning None here previously let callers save `None`.
    if not 0 < pruning_ratio < 1:
        raise ValueError(f"pruning_ratio must be in (0, 1), got {pruning_ratio}")
    if not 0 < flops_threshold < 1:
        raise ValueError(f"flops_threshold must be in (0, 1), got {flops_threshold}")

    base_flops = evaluate_flops(model, inputs)
    th_flops = (1 - flops_threshold) * base_flops
    print(f'Base: {base_flops/10**9:.2f} GFLOPs')
    print(f'Target: {th_flops/10**9:.2f} GFLOPs')

    # DO NOT prune the final classifier (1000 output features)!
    ignored_layers = [
        m for m in model.modules()
        if isinstance(m, torch.nn.Linear) and m.out_features == 1000
    ]

    imp = tp.importance.TaylorImportance() if importance is None else importance

    flops = base_flops
    prev_model = None
    iterations = 0

    while flops > th_flops:
        if iterations >= max_iterations:
            print(f'Stopped after {max_iterations} pruning steps without reaching the target.')
            break
        iterations += 1

        # Snapshot so we can roll back if this step overshoots the target.
        prev_model = deepcopy(model)

        pruner = tp.pruner.MagnitudePruner(
            model,
            inputs,
            importance=imp,
            pruning_ratio=pruning_ratio,  # fraction of remaining channels to remove
            ignored_layers=ignored_layers,
        )

        if isinstance(imp, tp.importance.TaylorImportance):
            # Taylor expansion requires gradients for importance estimation.
            # Clear stale gradients first so importance reflects this step only.
            model.zero_grad(set_to_none=True)
            loss = model(inputs).sum()  # a dummy loss for TaylorImportance
            loss.backward()  # must run before pruner.step()
        # Step for every importance criterion (not only Taylor).
        pruner.step()
        # Drop gradients so they are not carried into the next snapshot/step.
        model.zero_grad(set_to_none=True)

        new_flops = evaluate_flops(model, inputs)
        print(f'Pruned to: {new_flops/10**9:.2f} GFLOPs')
        gc.collect()
        torch.cuda.empty_cache()

        if new_flops >= flops:
            # Per-layer channel counts can round to zero pruned channels.
            # FLOPs would then stop falling and the loop would never end.
            print('Pruning made no further progress; stopping.')
            flops = new_flops
            break
        flops = new_flops

    if flops < th_flops:
        print(f'Pruned beyond target FLOPs. back to previous FLOPs.')
        return prev_model

    return model

In [6]:
# load model
# NOTE(review): torch.load unpickles a full nn.Module, which can execute
# arbitrary code, so only load trusted checkpoints. On PyTorch >= 2.6 loading
# a full module also needs weights_only=False (the default became True).
model_file = 'mobilenet'
model = torch.load(os.path.join(PATH, f"{model_file}.pt"), map_location=device)
model.to(device)

# dummy inputs, reusing batch_size / image_w / image_h from the configuration cell.
# Seeded because TaylorImportance depends on these values; without a seed the
# pruned architecture changes from run to run.
torch.manual_seed(0)
inputs = torch.rand(batch_size, 3, image_h, image_w).to(device)  # NCHW

# threshold: fraction of base FLOPs to REMOVE (0.050 -> target is 95% of base FLOPs)
flops_threshold = 0.050

# pruning start
model = prune_with_DepGraph(model, inputs, flops_threshold=flops_threshold)

# save pruned model; the suffix encodes the threshold in thousandths (0.050 -> "050")
model_file = f"{model_file}_pruned_{round(flops_threshold * 1000):03d}"
torch.save(model, os.path.join(PATH, f"{model_file}.pt"))

Base: 0.43 GFLOPs
Target: 0.41 GFLOPs
Pruned to: 0.42 GFLOPs
Pruned to: 0.40 GFLOPs
Pruned beyond target FLOPs. back to previous FLOPs.
