In [None]:
# default_exp model

In [None]:
# hide
import torch 
import torch.nn as nn 
import torch.nn.functional as F
import pandas as pd 
import numpy as np
from nbdev.showdoc import show_doc
from tqdm import tqdm 

# model 

> Helper library for inspecting PyTorch models with forward hooks

In [None]:
# export 
class ForwardHook():
    '''Track stats and activations of any layer in a model via a forward hook.

    The hook records the module's forward output. When the module returns a
    plain tensor (e.g. `nn.Linear`) that tensor is tracked. When it returns an
    indexable container (e.g. the tuple returned by a BERT model with
    `return_dict=False`) the element at index 1 is tracked, falling back to
    index 0 for a one-element container.

    Attributes:
        hook: handle returned by `module.register_forward_hook`.
        status (str): 'active' until `close()` is called, then 'removed'.
        name: tag given by the caller.
        activation_status (bool): whether outputs are stored in `activations`.
        stats (bool): whether `means` / `stds` are collected.
        activations (list | np.ndarray): only present when `activation` is True.
        means, stds (list): only present when `stats` is True.
    '''

    def __init__(self, module, name:str, activation:bool, stats:bool):
        ''' Track stats and activations of any layer in model '''
        self.status = 'active'
        self.name = name
        self.activation_status = activation
        self.stats = stats
        if self.activation_status: self.activations = []
        if stats: self.means, self.stds = [], []
        # Register last so the hook can never fire on a half-initialised object.
        self.hook = module.register_forward_hook(self.hook_fn)

    @staticmethod
    def _select_output(output):
        '''Return the tensor to track from a module's forward output.'''
        if isinstance(output, torch.Tensor):
            # Indexing a tensor with [1] would pick the second sample of the
            # batch, not "the output", so plain tensors are used as they are.
            return output
        # Tuple-like outputs keep the original behaviour of tracking index 1.
        return output[1] if len(output) > 1 else output[0]

    @staticmethod
    def _to_list(tensor):
        '''Convert a reduced tensor to a list, even when it holds a single value.'''
        values = tensor.detach().cpu().numpy().squeeze()
        # squeeze() turns a one-element result into a 0-d array whose tolist()
        # is a bare float; atleast_1d keeps `list += ...` valid in that case.
        return np.atleast_1d(values).tolist()

    def hook_fn(self, module, input, output):
        '''Forward-hook callback: store the output and/or its mean/std over dim 1.'''
        tracked = self._select_output(output)
        if self.activation_status:
            self.activations.append(tracked.detach().cpu().numpy())
        if self.stats:
            self.means += self._to_list(tracked.mean(1))
            self.stds += self._to_list(tracked.std(1))

    def _stack_activations(self):
        '''Stack the collected per-call outputs into one array with `np.vstack`.

        Does nothing when activation tracking is disabled or nothing has been
        collected yet (np.vstack raises on an empty sequence).
        '''
        if not self.activation_status or len(self.activations) == 0:
            return
        self.activations = np.vstack(self.activations)

    def close(self):
        '''Remove the hook from the module and mark this object as removed.'''
        self.hook.remove()
        self.status = 'removed'

Inputs: <br>
- module(nn.Module): Module example: model.linear1
- name(str): Tag 
- activation(bool): if True, it will track module output
- stats(bool): if True, it will track module output mean and std 

Output: <br> 
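- ForwardHook.activations(list): list of tracked outputs, one array per forward call (only when `activation=True`); becomes a single stacked array after `_stack_activations()`
- ForwardHook.means(list): list of means (only when `stats=True`)
- ForwardHook.stds(list): list of standard deviations (only when `stats=True`)
- ForwardHook.status(str): Hook status (active/removed)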

`_stack_activations()`: Stacks all the outputs (activations) saved in `ForwardHook.activations` into a single array. Call it once all outputs have been collected.

You can inherit from this class and override `hook_fn` to define your own hook.

## Collect intermediate activations and stats

In [None]:
#export
def get_linear_layer_activations_states(dls, model, layers:list, stats:bool, activation:bool, output:bool, remove_hooks=True):
  '''Run `model` over `dls` and collect intermediate-layer outputs with forward hooks.

  Useful when the model is already trained and you want to analyze its
  intermediate layers.

  Args:
    dls: iterable of batches; each batch is passed to the model as `model(batch)`.
    model: module whose attributes named in `layers` get a `ForwardHook`.
      It is switched to eval mode and left in eval mode.
    layers (list): attribute names on `model`. Names the model does not have
      are skipped; a repeated name is hooked only once.
    stats (bool): forwarded to `ForwardHook` (collect means / stds).
    activation (bool): forwarded to `ForwardHook` (store outputs). Stored
      outputs are stacked once the loop finishes.
    output (bool): if True, also collect and return the model output per batch.
    remove_hooks (bool): if True, hooks are removed from the model before returning.

  Returns:
    `hooks` (dict mapping layer name to `ForwardHook`) when `output` is False,
    otherwise the tuple `(hooks, output_list)`.

  Raises:
    Any exception from the model or the data iterable is re-raised after all
    hooks registered here have been removed.
  '''
  hooks = dict()
  output_list = []

  try:
    for layer in layers:
      key = f'{layer}'
      # A second hook for the same name would replace the dict entry and leave
      # the first hook attached to the model with nothing left to remove it.
      if key in hooks or not hasattr(model, layer): continue
      hooks[key] = ForwardHook(getattr(model, layer), layer, activation=activation, stats=stats)

    model.eval()
    with torch.no_grad():
      for batch in tqdm(dls, desc='Fatching data: '):
        out = model(batch)
        if output: output_list.append(out)
  except BaseException:
    # The caller never receives `hooks` on failure, so detach them here
    # regardless of `remove_hooks`; otherwise they stay on the model for good.
    for hook in hooks.values(): hook.close()
    raise

  for hook in hooks.values():
    # Skip stacking when nothing was recorded (np.vstack raises on an empty list).
    if activation and len(hook.activations) > 0: hook._stack_activations()
    if remove_hooks: hook.close()

  if output:
    return hooks, output_list
  return hooks

Example: 
```
class BertSentiModel(nn.Module):

  def __init__(self, *args, **kwargs):

    super(BertSentiModel, self).__init__()
    self.bert_model = transformers.BertModel.from_pretrained(pretrained_model, return_dict=False)
    self.kwargs = kwargs
    self.lin1 = nn.Linear(768, 10)
    self.lin2 = nn.Linear(768, 10)
    self.lin3 = nn.Linear(10, 1)
    self.lin4 = nn.Linear(10, 1)

  def forward(self,batch):
    ids, mask, token_type_ids = batch['ids'], batch['mask'], batch['token_type_ids']
    _, x = self.bert_model(ids, attention_mask=mask, token_type_ids=token_type_ids)
    x1 = self.lin3(F.relu(self.lin1(x)))
    x2 = self.lin4(F.relu(self.lin2(x)))
    return x1, x2
    
model = BertSentiModel()
model.load_state_dict(state_dict)
hooks, output_list = get_linear_layer_activations_states(dataloader, model, ['lin1', 'bert_model', 'lin2'], stats=True, activation=True, output=True, remove_hooks=True) 
```