In [2]:
import torch
import torch
from torch import nn
from collections import OrderedDict

In [3]:
class LayerInspector(nn.Module):
    """Wrap a model and capture the inputs/outputs of each of its leaf modules.

    At construction, a forward hook is registered on every leaf module (a module
    with no child modules) of ``model``. Whenever a leaf module runs, its entry in
    ``self.layers`` -- keyed by the dotted name from ``model.named_modules()`` --
    is set to the data of that call::

        {'input': <tuple of positional inputs>,
         'output': <module output>,
         'module': <the module object>,
         'num_calls': <how many times it has run while hooked>}

    Entries are ordered by the *first* time each module ran. A module reused
    several times in one forward pass (e.g. a shared pooling layer) keeps its
    first-call position but holds the data of its *last* call; ``num_calls``
    makes that visible.

    The hooks are attached to ``model`` itself, so they keep firing on any later
    ``model(...)`` call until :meth:`remove_hooks` is called. Using the inspector
    as a context manager removes them automatically on exit.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.layers = OrderedDict()
        # Handles returned by register_forward_hook, kept so the hooks can be removed.
        self._handles = []
        self._register_hooks()

    def _register_hooks(self):
        """Attach a recording forward hook to every leaf module of ``self.model``."""
        def make_hook(name):
            def fn(module, inputs, output):
                previous = self.layers.get(name)
                # Re-assigning an existing key keeps its original position in the
                # OrderedDict, so order reflects first execution.
                self.layers[name] = {
                    'input': inputs,
                    'output': output,
                    'module': module,
                    'num_calls': 1 if previous is None else previous['num_calls'] + 1,
                }
            return fn

        for name, module in self.model.named_modules():
            if next(module.children(), None) is None:  # leaf module
                self._handles.append(module.register_forward_hook(make_hook(name)))

    def remove_hooks(self):
        """Detach every hook this inspector registered; safe to call repeatedly."""
        for handle in self._handles:
            handle.remove()
        self._handles.clear()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.remove_hooks()
        return False  # never swallow exceptions raised inside the ``with`` body

    def forward(self, x):
        """Run the wrapped model unchanged; the hooks record per-layer data."""
        return self.model(x)


def inspect_model(model, sample_input):
    """Run ``sample_input`` through ``model`` once and return per-leaf-layer data.

    Returns the ``layers`` OrderedDict of a ``LayerInspector`` (see its docstring
    for the entry layout). The pass runs under ``torch.no_grad()`` so the captured
    tensors do not keep autograd graphs alive. The hooks are removed before
    returning, even if the forward pass raises, so later calls to ``model`` do not
    mutate the returned dict or accumulate hooks.

    The model's train/eval mode is left as-is. Call ``model.eval()`` first if
    inference-time behaviour is wanted.
    """
    with LayerInspector(model) as inspector, torch.no_grad():
        inspector(sample_input)
    return inspector.layers


In [9]:
# Load the trained SimpleCNN on CPU and capture per-layer activations for one
# random 1x3x32x32 input.
DEVICE = torch.device('cpu')
WEIGHTS_PATH = 'data/weights/best_model.pth'
SEED = 0

from quant.models import SimpleCNN

torch.manual_seed(SEED)  # fixed input so re-runs capture identical activations
sample_input = torch.randn(1, 3, 32, 32, device=DEVICE)

model = SimpleCNN().to(DEVICE)
model.load_state_dict(torch.load(WEIGHTS_PATH, map_location=DEVICE, weights_only=True))
# eval() so any dropout / batch-norm layers use inference behaviour and the
# inspection pass does not update batch-norm running statistics.
model.eval()

layer_info = inspect_model(model, sample_input)


In [10]:
def _shape_summary(value):
    """Return a printable shape description for a hook input or output.

    Tensors become their shape tuple. Tuples/lists are summarised element-wise
    as a list (so a hook's input tuple prints like ``[(1, 3, 32, 32)]``). Any
    other object (e.g. ``None`` or a non-tensor argument) is shown by its type
    name, so modules with tuple outputs or extra arguments do not crash the report.
    """
    if isinstance(value, torch.Tensor):
        return tuple(value.shape)
    if isinstance(value, (tuple, list)):
        return [_shape_summary(item) for item in value]
    return type(value).__name__


for name, info in layer_info.items():
    print(f"Layer: {name}")
    print(f"  Module: {info['module']}")
    print(f"  Input shape: {_shape_summary(info['input'])}")
    print(f"  Output shape: {_shape_summary(info['output'])}")
    print()

Layer: conv1
  Module: Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  Input shape: [(1, 3, 32, 32)]
  Output shape: (1, 32, 32, 32)

Layer: pool
  Module: MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  Input shape: [(1, 64, 16, 16)]
  Output shape: (1, 64, 8, 8)

Layer: conv2
  Module: Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  Input shape: [(1, 32, 16, 16)]
  Output shape: (1, 64, 16, 16)

Layer: fc1
  Module: Linear(in_features=4096, out_features=512, bias=True)
  Input shape: [(1, 4096)]
  Output shape: (1, 512)

Layer: fc2
  Module: Linear(in_features=512, out_features=10, bias=True)
  Input shape: [(1, 512)]
  Output shape: (1, 10)

