# Demo: Extract intermediate representations

## 1. Import torch

In [1]:
import torch

## 2. Define a randomly initialized model
e.g., ResNet-18 from torchvision

In [2]:
from torchvision import models

model = models.resnet18(weights=None)  # random initialization; `pretrained` is deprecated since torchvision 0.13

## 3. Define a forward hook manager
Let's say we want to extract representations from the following layers (modules) in [ResNet-18](https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py#L230-L249):  
- "conv1": input
- "layer1.0.bn2": input and output
- "fc": output

In [3]:
from torchdistill.core.forward_hook import ForwardHookManager

device = torch.device('cpu')  # target device for the extracted tensors
forward_hook_manager = ForwardHookManager(device)
# Register one hook per module path; the flags select whether its input and/or output is stored
forward_hook_manager.add_hook(model, 'conv1', requires_input=True, requires_output=False)
forward_hook_manager.add_hook(model, 'layer1.0.bn2', requires_input=True, requires_output=True)
forward_hook_manager.add_hook(model, 'fc', requires_input=False, requires_output=True)
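To make the mechanism concrete, here is a hypothetical, minimal stand-in for what a forward hook manager does, written with plain PyTorch `register_forward_hook`. `SimpleHookManager` and the toy model are illustrative only. This is not torchdistill's implementation, and it skips things the real manager may handle (e.g., placing tensors on a target device).

```python
from collections import OrderedDict

import torch
from torch import nn


class SimpleHookManager:
    """Hypothetical minimal forward hook manager (not torchdistill's code)."""

    def __init__(self):
        self.io_dict = {}
        self.handles = []

    def add_hook(self, model, module_path, requires_input=True, requires_output=True):
        module = model.get_submodule(module_path)

        def hook(mod, inputs, output):
            entry = self.io_dict.setdefault(module_path, {})
            if requires_input:
                entry['input'] = inputs[0]  # positional inputs arrive as a tuple
            if requires_output:
                entry['output'] = output

        self.handles.append(module.register_forward_hook(hook))

    def pop_io_dict(self):
        # Hand over the collected tensors and start with an empty dict
        io_dict, self.io_dict = self.io_dict, {}
        return io_dict

    def remove_hooks(self):
        for handle in self.handles:
            handle.remove()
        self.handles.clear()


# Toy model with named submodules, so paths look like 'fc1' and 'fc2'
toy = nn.Sequential(OrderedDict(fc1=nn.Linear(4, 8), act=nn.ReLU(), fc2=nn.Linear(8, 2)))
manager = SimpleHookManager()
manager.add_hook(toy, 'fc1', requires_input=True, requires_output=False)
manager.add_hook(toy, 'fc2', requires_input=False, requires_output=True)

x = torch.rand(3, 4)
y = toy(x)
io = manager.pop_io_dict()
print({path: list(entry.keys()) for path, entry in io.items()})
```

Only the requested entries ("input", "output", or both) appear per module path, which mirrors the key layout printed later in this demo.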

## 4. Run the forward pass of the model

Note that ResNet-18's input shape for the ImageNet dataset is 3x224x224.  
Here we use a batch size of 32 for a random input batch *x*.

In [4]:
x = torch.rand(32, 3, 224, 224)
y = model(x)

## 5. Get I/O dictionary from the manager
There should be three keys: "conv1", "layer1.0.bn2", and "fc".

In [5]:
io_dict = forward_hook_manager.pop_io_dict()
print(io_dict.keys())

dict_keys(['conv1', 'layer1.0.bn2', 'fc'])


## 6. Make sure that input to "conv1" matches *x*

In [6]:
print(io_dict['conv1'].keys())
conv1_input = io_dict['conv1']['input']
print(torch.equal(x, conv1_input))

dict_keys(['input'])
True


## 7. Similarly, make sure that output from "fc" matches *y*

In [7]:
print(io_dict['fc'].keys())
fc_output = io_dict['fc']['output']
print(torch.equal(y, fc_output))

dict_keys(['output'])
True


## 8. Inspect the input/output tensors extracted from "layer1.0.bn2"

In [8]:
print(io_dict['layer1.0.bn2'].keys())
layer1_bn2_input = io_dict['layer1.0.bn2']['input']
layer1_bn2_output = io_dict['layer1.0.bn2']['output']

dict_keys(['input', 'output'])


To check whether these extracted tensors match those used inside the model, you would need to reimplement [ResNet-18](https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py#L230-L249) and change the interface of the forward function in the **ResNet** and **BasicBlock** classes so that they return the intermediate tensors in addition to *y* from ***model*** in Section 4.  
You don't want to, right? That's why you want to use **ForwardHookManager** instead.