In [1]:
from pattern_matcher import fuse
import torch
import torch.fx as fx
from utils import channel_prune, NN, dependency_grapher
import torch.nn as nn
import  timm

prune_ratio: float = 0.5  # fraction of channels to remove when pruning (0.5 -> keep half)

In [2]:
# Build the reference model, trace it into an FX GraphModule and fuse it.
model = NN().eval()

# Warm-up forward pass before tracing.
# NOTE(review): presumably needed to initialise NN (e.g. lazily-shaped layers) — confirm.
# It runs in eval mode and under no_grad so that:
#   * any normalization layers' running statistics are not updated with random data
#     before `fuse` runs (fuse is assumed to fold such layers into the convs — confirm);
#   * `_` does not keep an autograd graph alive in kernel memory.
with torch.no_grad():
    _ = model(torch.randn(1, 3, 64, 64))

model = fx.symbolic_trace(model)
# The traced GraphModule is a new root module, so put it in eval mode explicitly before fusing.
model = fuse(model.eval())


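## Build the Dependency Graph

Map each convolution to the layer(s) that consume its output. Removing a layer's output channels must then be mirrored in the input channels of every consumer (e.g. `conv3` feeds both `conv4` and `conv5`).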

In [3]:
# Map producer layer names -> tuple of consumer layer names and print both the
# names and the actual submodules they resolve to.
dependency_graph = dependency_grapher(model)
print("Model Dependency Graph")
# Loop variables are named `producers`/`consumers` rather than `prev`/`next`:
# a notebook-level loop variable named `next` would shadow the builtin `next()`
# for every later cell in this kernel.
for producers, consumers in dependency_graph.items():
    producer_layers = [model.get_submodule(name) for name in producers]
    consumer_layers = [model.get_submodule(name) for name in consumers]
    print(f"Dependency Graph for {producers} -> {consumers}")
    print(f"Dependency Layers for {producer_layers} -> {consumer_layers}")

Model Dependency Graph
Dependency Graph for ('conv1',) -> ('conv2',)
Dependency Layers for [Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))] -> [Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))]
Dependency Graph for ('conv2',) -> ('conv3',)
Dependency Layers for [Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))] -> [Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))]
Dependency Graph for ('conv3',) -> ('conv4', 'conv5')
Dependency Layers for [Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))] -> [Conv2d(256, 512, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)), Conv2d(256, 512, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3))]
Dependency Graph for ('conv4',) -> ('conv6',)
Dependency Layers for [Conv2d(256, 512, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))] -> [Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))]
Dependency Graph for ('conv5',) -> ('conv7',)
Dependency Laye

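## Prune the Model

Remove a `prune_ratio` fraction of the channels from the prunable layers. Then run a forward pass to check that the pruned layers still fit together.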

In [4]:
# Remove `prune_ratio` of the channels from the fused model.
pruned_model = channel_prune(model, prune_ratio)

# Drop the unpruned model to free its weights.
# NOTE: this makes the cell non-idempotent — re-run the model-building cell
# before re-running this one, otherwise `model` is undefined.
del model

# Smoke test: a channel mismatch between pruned layers would raise here.
# no_grad() keeps `_` from holding an autograd graph (and its saved activations)
# in kernel memory when the parameters require grad.
with torch.no_grad():
    _ = pruned_model(torch.randn(1, 3, 32, 32))


In [5]:
pruned_model  # display the pruned architecture to inspect the new per-layer channel counts

GraphModule(
  (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu): ReLU()
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv4): Conv2d(128, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (conv5): Conv2d(128, 256, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3))
  (conv6): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv7): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv8): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
  (conv9): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
  (conv10): Conv2d(64, 256, kernel_size=(6, 6), stride=(1, 1), padding=(2, 2))
  (conv11): Conv2d(256, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)