In [4]:
import torch
from transformers import BertModel, BertForMaskedLM, BertTokenizer, BertConfig
import transformers
import numpy as np

In [5]:
def forward_from_specific_layer(model: "transformers.BertForMaskedLM", layer_number: int,
                                layer_representation: torch.Tensor):
    """
    Resume the forward pass of a masked-LM BERT model from a given encoder layer.

    The representation is passed through ``model.bert.encoder.layer[layer_number:]``
    and finally through ``model.cls.predictions.transform`` (the MLM head transform).
    The output of every one of these modules is recorded. No attention mask is
    passed to the layers, and the model is used in whatever train/eval mode it is
    currently in.

    :param model: a BertForMaskedLM model (``model.bert`` and ``model.cls`` are used)
    :param layer_number: index of the first encoder layer to apply
    :param layer_representation: a torch tensor, dims: [1, seq length, hidden size]
                                 (hidden size is 768 for bert-base)
    :return: states, a numpy array,
             dims: [#LAYERS - layer_number + 1, seq length, hidden size].
             ``states[:-1]`` are the outputs of the remaining encoder layers and
             ``states[-1]`` is the output of the MLM head transform.
    """
    # A plain Python list is enough here; the model's own ModuleList is left untouched.
    modules = list(model.bert.encoder.layer[layer_number:])
    modules.append(model.cls.predictions.transform)

    h = layer_representation
    states = []

    with torch.no_grad():
        for module in modules:
            output = module(h)
            # Encoder layers may return a tuple whose first item is the hidden state,
            # while the head transform returns a bare tensor. Unwrap by type rather
            # than by position so a tensor-returning layer never loses its batch dim.
            h = output[0] if isinstance(output, (tuple, list)) else output
            states.append(h.detach().cpu().numpy())

    for state in states:
        assert state.ndim == 3, "every state is expected to be [1, seq length, hidden size]"

    # [n_states, 1, seq, hidden] -> [n_states, seq, hidden]
    return np.stack(states).squeeze(1)


def intervene_in_layer(model: "transformers.BertForMaskedLM", tokens: torch.Tensor, layer_number: int,
                       projection_matrix: torch.Tensor, apply_on_all=True):
    """
    Intervene in the representations at layer ``layer_number`` and re-run the rest
    of the model on the modified representation.

    Only the first sequence of the batch is used.

    :param model: a BertForMaskedLM model
    :param tokens: token ids, dims: [1, seq length]
    :param layer_number: index into the model's ``hidden_states`` of the representation
                         to modify; it is also the first encoder layer that is re-run.
                         ``0`` intervenes on the first hidden state.
    :param projection_matrix: matrix that right-multiplies the representation,
                              dims: [hidden size, hidden size]
    :param apply_on_all: if True, apply the projection on all tokens. Otherwise apply
                         it on the first token (CLS) only.
    :return: a torch tensor, dims: [#hidden states, seq length, hidden size].
             Rows ``[:layer_number]`` are the unmodified hidden states; the following
             rows are the outputs of encoder layers ``layer_number`` onwards computed
             from the modified representation; the last row is the output of the MLM
             head transform. Note that the modified representation itself is not one
             of the rows.
    """
    # Ask for the hidden states explicitly so the call does not depend on the model's
    # config having been created with output_hidden_states=True.
    with torch.no_grad():
        outputs = model(tokens, output_hidden_states=True, return_dict=True)
    hidden_states = outputs["hidden_states"]

    # Extract the representation at layer i (first sequence of the batch).
    hidden_state_layer_i = hidden_states[layer_number][0]
    projection = projection_matrix.to(device=hidden_state_layer_i.device,
                                      dtype=hidden_state_layer_i.dtype)
    if apply_on_all:
        hidden_state_layer_i = hidden_state_layer_i @ projection
    else:
        # Clone first: the slice above is a view into the model's output, and the
        # in-place write below must not modify it.
        hidden_state_layer_i = hidden_state_layer_i.clone()
        hidden_state_layer_i[0] = hidden_state_layer_i[0] @ projection

    # Continue the forward pass.
    hidden_state_layer_i = hidden_state_layer_i.unsqueeze(0)  # add batch dim back
    hidden_after_projection_i_onwards = forward_from_specific_layer(model, layer_number, hidden_state_layer_i)
    # The helper hands back CPU data; put it on the same device as the untouched states.
    hidden_after_projection_i_onwards = torch.as_tensor(hidden_after_projection_i_onwards,
                                                        device=hidden_states[0].device)

    untouched_states = [state[0] for state in hidden_states[:layer_number]]
    if not untouched_states:
        # layer_number == 0: nothing precedes the intervention, and torch.stack
        # rejects an empty list.
        return hidden_after_projection_i_onwards

    hidden_states_until_i = torch.stack(untouched_states)
    return torch.cat([hidden_states_until_i, hidden_after_projection_i_onwards])

### Usage example

In [7]:
# Load the pretrained config, asking the model to expose all of its hidden states.
config = BertConfig.from_pretrained(
    "bert-base-uncased",
    output_hidden_states=True,
)
bert = BertForMaskedLM.from_pretrained("bert-base-uncased", config=config)

# Encode a single sentence and add a leading batch dimension of size 1.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
tokens = torch.tensor(tokenizer.encode("To be or not to be, that is the question.")).unsqueeze(0)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [9]:
# Draw the demo matrix from a seeded local generator so the example is reproducible
# without touching the global RNG state.
generator = torch.Generator().manual_seed(0)
hidden_size = bert.config.hidden_size  # 768 for bert-base-uncased
P = torch.randn(hidden_size, hidden_size, generator=generator)

# Intervene at layer 3 on every token and collect the resulting stack of states.
states = intervene_in_layer(bert, tokens, layer_number=3, projection_matrix=P, apply_on_all=True)
print(states.shape)

torch.Size([13, 14, 768])
