In [2]:
import torch
from torch import nn, Tensor


In [3]:
model = nn.Sequential(nn.Linear(32, 24), nn.ReLU(), nn.Linear(24, 16), nn.ReLU())


In [4]:
x = torch.randn(8, 30, 32)
y = model(x)
print(y.shape)


torch.Size([8, 30, 16])


In [5]:
# iterating through all of the modules
for name, module in model.named_modules():
    print(name, type(module), isinstance(module, nn.Module))


 <class 'torch.nn.modules.container.Sequential'> True
0 <class 'torch.nn.modules.linear.Linear'> True
1 <class 'torch.nn.modules.activation.ReLU'> True
2 <class 'torch.nn.modules.linear.Linear'> True
3 <class 'torch.nn.modules.activation.ReLU'> True


In [6]:
layer = model[0]
print(layer)


Linear(in_features=32, out_features=24, bias=True)


In [7]:
def hook_fn(module: nn.Module, args, output):
    print(f"Hello! we are inside of a forward hook. My module is {module}, the inputs are {args}, the outputs are {output}")


In [8]:
handle = layer.register_forward_hook(hook_fn)


In [9]:
model(torch.randn(1, 32))


Hello! we are inside of a forward hook. My module is Linear(in_features=32, out_features=24, bias=True), the inputs are (tensor([[ 0.4668,  1.5062, -1.6923,  1.3530, -0.7549,  1.6089,  1.5451,  0.5262,
         -0.6871,  0.3373, -0.8215, -0.3513, -1.1234, -1.4878, -0.3461, -1.4130,
          0.6693,  0.0783, -1.1228, -0.2543,  0.6861,  0.5037, -0.6093, -0.2720,
         -0.3580, -0.5863,  1.1232, -1.8202, -0.0596,  0.8736, -1.1066,  1.7741]]),), the outputs are tensor([[ 0.5048,  0.3148, -0.9752, -0.5608,  0.4318, -0.4555, -0.5908, -0.0170,
          1.1975, -0.4305,  0.1371, -0.0959, -0.5939,  0.8271,  1.8683, -0.4340,
          0.7912,  0.3859,  0.3985, -0.9812, -0.2133, -0.2314, -0.4976, -0.4333]],
       grad_fn=<AddmmBackward0>)


tensor([[0.0859, 0.0000, 0.0000, 0.0000, 0.3881, 0.6115, 0.4115, 0.0000, 0.0000,
         0.0000, 0.0220, 0.0000, 0.0000, 0.5266, 0.0000, 0.1157]],
       grad_fn=<ReluBackward0>)

In [10]:
handle.remove()


In [11]:
model(torch.randn(1, 32))


tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0580, 0.5276, 0.0815, 0.0000, 0.1239,
         0.0000, 0.1051, 0.3343, 0.0000, 0.1487, 0.0000, 0.0045]],
       grad_fn=<ReluBackward0>)

In [12]:
# Putting it all together:
# let's iterate through each module and add a forward hook that just prints the layer


In [13]:
def hook_fn_print_layer(module: nn.Module, args, output):
    print(module)

def prehook_fn_print_layer(module: nn.Module, args):
    print("prehook", module)


In [14]:
handles = {}
for name, module in model.named_modules():
    # handles[name] = module.register_forward_hook(hook_fn_print_layer)
    handles[name] = module.register_forward_pre_hook(prehook_fn_print_layer)


In [15]:
model(torch.randn(1, 32));


prehook Sequential(
  (0): Linear(in_features=32, out_features=24, bias=True)
  (1): ReLU()
  (2): Linear(in_features=24, out_features=16, bias=True)
  (3): ReLU()
)
prehook Linear(in_features=32, out_features=24, bias=True)
prehook ReLU()
prehook Linear(in_features=24, out_features=16, bias=True)
prehook ReLU()


In [16]:
for handle in handles.values():
    handle.remove()


In [17]:
# mini prototype of OBS
# our overall goal is to build a container that collects matrix weights, input vectors, and output vectors for each
# nn.Linear in the graph

# OBS will iterate through these in order and prune the layers one-by-one


In [18]:
from dataclasses import dataclass

@dataclass
class OBSLinearCache:
    name: str = None
    weight: Tensor = None
    input: Tensor = None
    output: Tensor = None
    module: nn.Linear = None

def get_layer_hook(name: str):
    cache = OBSLinearCache()
    def hook_fn(module, args, outputs):
        cache.module = module
        cache.name = name
        cache.input = args
        cache.output = outputs
        cache.weight = module.weight

    return hook_fn, cache


In [19]:
caches = {}
hooks = {}

for name, module in model.named_modules():
    if not isinstance(module, nn.Linear):
        continue

    hook_fn, cache = get_layer_hook(name)
    caches[name] = cache
    hooks[name] = module.register_forward_hook(hook_fn)


In [20]:
caches


{'0': OBSLinearCache(name=None, weight=None, input=None, output=None, module=None),
 '2': OBSLinearCache(name=None, weight=None, input=None, output=None, module=None)}

In [21]:
model(torch.randn(8, 32))
for handle in handles.values():
    handle.remove()


In [22]:
caches


{'0': OBSLinearCache(name='0', weight=Parameter containing:
 tensor([[-3.0700e-02,  1.7464e-02, -1.4900e-01,  1.0195e-01, -1.0894e-01,
           1.3645e-01, -9.9031e-02,  1.3497e-01, -5.0742e-02,  6.2458e-02,
          -1.7630e-01,  4.7019e-02,  5.9899e-02, -1.0123e-01, -1.2802e-01,
           6.1456e-02,  1.3629e-01,  6.3063e-02,  1.4376e-01, -3.8462e-02,
           1.7458e-01, -1.7630e-01,  8.1967e-02,  1.3386e-01,  1.7037e-01,
           9.4823e-02,  5.2966e-04,  6.5677e-02, -8.5024e-02,  8.4206e-02,
           6.0984e-02,  3.6805e-02],
         [ 6.2203e-02, -1.7380e-01, -6.5445e-02,  2.0842e-02, -1.2105e-01,
          -5.1732e-02, -6.5871e-02, -2.2648e-02,  1.5204e-02,  1.3887e-02,
           1.3138e-01, -1.4400e-01,  1.1331e-01,  1.5506e-01,  8.1966e-02,
          -1.7560e-01, -6.4073e-02, -4.2885e-02,  1.0424e-02,  2.8951e-02,
           6.7926e-02, -7.9798e-02,  3.8113e-02,  5.8120e-03, -6.4037e-02,
          -4.8705e-02,  1.7367e-01, -1.2617e-01, -1.2156e-02,  8.7100e-02,
   