## Tutorial: Using Mistral with this library

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/frankaging/align-transformers/blob/main/tutorials/Hook%20with%20new%20model%20and%20intervention%20types.ipynb)

In [1]:
__author__ = "Zhengxuan Wu and Ruixiang Cui"
__version__ = "10/05/2023"

### Overview

Out of the box, this library supports only a predefined set of model architectures. Users can register new model architectures in order to run intervention-based alignment training and static path-patching analyses on them. This tutorial shows how to handle a model type that is not predefined in this library, using Mistral as the example.

**Note that this tutorial will not add this new model type to our codebase. Feel free to open a PR to do that!**

In [4]:
try:
    # This library is our indicator that the required installs
    # need to be done.
    import transformers
    import sys
    sys.path.append("align-transformers/")
except ModuleNotFoundError:
    !git clone https://github.com/frankaging/align-transformers.git
    !pip install -r align-transformers/requirements.txt
    import sys
    sys.path.append("align-transformers/")

In [2]:
import sys
sys.path.append("..")

import torch
import pandas as pd
from models.constants import CONST_OUTPUT_HOOK
from models.configuration_alignable_model import AlignableRepresentationConfig, AlignableConfig
from models.alignable_base import AlignableModel
from models.interventions import Intervention, VanillaIntervention
from models.utils import lsm, sm, top_vals, format_token, type_to_module_mapping, \
    type_to_dimension_mapping, output_to_subcomponent_fn_mapping, \
    scatter_intervention_output_fn_mapping, simple_output_to_subcomponent, \
    simple_scatter_intervention_output

%config InlineBackend.figure_formats = ['svg']
from plotnine import ggplot, geom_tile, aes, facet_wrap, theme, element_text, \
                     geom_bar, geom_hline, scale_y_log10

### Trying a new model type: Mistral

In [3]:
def create_mistral(name="mistralai/Mistral-7B-Instruct-v0.1", cache_dir="../../.huggingface_cache"):
    """Creates a Mistral config, tokenizer, and model from the given name."""
    from transformers import MistralForCausalLM, AutoTokenizer, MistralConfig
    
    config = MistralConfig.from_pretrained(name)
    tokenizer = AutoTokenizer.from_pretrained(name)
    mistral = MistralForCausalLM.from_pretrained(name, config=config, cache_dir=cache_dir)
    mistral.bfloat16()
    print("loaded model")
    return config, tokenizer, mistral

def embed_to_distrib_mistral(embed, log=False, logits=False):
    """Convert an embedding to a distribution over the vocabulary"""
    with torch.inference_mode():
        vocab = embed
        if logits:
            return vocab
        return lsm(vocab) if log else sm(vocab)

config, tokenizer, mistral = create_mistral()

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

loaded model
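
`embed_to_distrib_mistral` relies on the library helpers `sm` and `lsm`, which are presumably softmax and log-softmax over the vocabulary dimension (an assumption based on their names; their source is not shown here). The following pure-Python sketch illustrates what that conversion does to a single row of logits. The helper names `softmax` and `log_softmax` and the toy logits are made up for illustration.

```python
import math

def softmax(logits):
    """Turn a list of logits into probabilities that sum to 1."""
    shift = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def log_softmax(logits):
    """Log-probabilities, computed without exponentiating large values."""
    shift = max(logits)
    log_total = shift + math.log(sum(math.exp(x - shift) for x in logits))
    return [x - log_total for x in logits]

toy_logits = [2.0, 1.0, 0.1]  # made-up scores for a three-token vocabulary
probs = softmax(toy_logits)
log_probs = log_softmax(toy_logits)
print(probs)
print(log_probs)
```

Exponentiating each log-probability recovers the corresponding probability, which is why the `log` flag only changes the scale of the returned values, not their ranking.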


In [38]:
base = "The capital of Spain is"
source = "The capital of Italy is"
inputs = [
    tokenizer(base, return_tensors="pt"),
    tokenizer(source, return_tensors="pt")
]
print(base)
res = mistral(**inputs[0])
distrib = embed_to_distrib_mistral(res.logits, logits=False)
top_vals(tokenizer, distrib[0][-1], n=10)
print()
print(source)
res = mistral(**inputs[1])
distrib = embed_to_distrib_mistral(res.logits, logits=False)
top_vals(tokenizer, distrib[0][-1], n=10)

The capital of Spain is
a                    0.20014575123786926
known                0.14642974734306335
Madrid               0.1292238086462021
one                  0.05734271928668022
full                 0.04753878340125084
not                  0.03267289698123932
an                   0.03267289698123932
famous               0.02109520696103573
also                 0.019817113876342773
renown               0.017488541081547737

The capital of Italy is


KeyboardInterrupt: 

### To add Mistral, you only need the following block

In [4]:
"""Only the two components used in this tutorial are defined here, for simplicity."""
# Assumed module paths, based on the Hugging Face MistralForCausalLM layout
# (decoder layers under `model.layers`); adjust them if your version differs.
type_to_module_mapping[type(mistral)] = {
    "mlp_output": ("model.layers[%s].mlp", CONST_OUTPUT_HOOK),
    "attention_input": ("model.layers[%s].self_attn", CONST_OUTPUT_HOOK),
}
# Mistral's config names the hidden dimension `hidden_size`, not `d_model`.
type_to_dimension_mapping[type(mistral)] = {
    "mlp_output": ("config.hidden_size", ),
    "attention_input": ("config.hidden_size", ),
}
output_to_subcomponent_fn_mapping[type(mistral)] = simple_output_to_subcomponent           # no special subcomponent
scatter_intervention_output_fn_mapping[type(mistral)] = simple_scatter_intervention_output # no special scattering
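
Each entry in the module mapping pairs a path template containing `%s` with a hook constant, and the `%s` is filled in with the layer index. How the library resolves such a path internally is not shown in this tutorial, so the following is only a sketch of the idea. It uses a hypothetical `resolve_path` helper and toy objects in place of a real model.

```python
import re
from types import SimpleNamespace

def resolve_path(root, template, layer):
    """Follow a dotted path such as 'model.layers[%s].mlp' starting at `root`."""
    obj = root
    for part in (template % layer).split("."):
        match = re.fullmatch(r"(\w+)\[(\d+)\]", part)
        if match:  # indexed attribute, e.g. layers[1]
            obj = getattr(obj, match.group(1))[int(match.group(2))]
        else:      # plain attribute, e.g. mlp
            obj = getattr(obj, part)
    return obj

# Toy stand-in for a decoder-only model with two layers.
layers = [SimpleNamespace(mlp=f"mlp-{i}", self_attn=f"attn-{i}") for i in range(2)]
toy_model = SimpleNamespace(model=SimpleNamespace(layers=layers))

print(resolve_path(toy_model, "model.layers[%s].mlp", 1))        # mlp-1
print(resolve_path(toy_model, "model.layers[%s].self_attn", 0))  # attn-0
```

A template that does not match the model's real attribute names would fail at the `getattr` step, which is a quick way to check a new mapping before running any intervention.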

### Path patching with Mistral

In [4]:
print(mistral.config)

MistralConfig {
  "_name_or_path": "Open-Orca/Mistral-7B-OpenOrca",
  "architectures": [
    "MistralForCausalLM"
  ],
  "bos_token_id": 1,
  "eos_token_id": 32000,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "max_position_embeddings": 32768,
  "model_type": "mistral",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "rms_norm_eps": 1e-05,
  "rope_theta": 10000.0,
  "sliding_window": 4096,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.34.0",
  "use_cache": true,
  "vocab_size": 32002
}


In [5]:
def simple_position_config(model_type, intervention_type, layer):
    alignable_config = AlignableConfig(
        alignable_model_type=model_type,
        alignable_representations=[
            AlignableRepresentationConfig(
                layer,             # layer
                intervention_type, # intervention type
                "pos",             # intervention unit
                1                  # max number of unit
            ),
        ],
        alignable_interventions_type=VanillaIntervention,
    )
    return alignable_config

base = tokenizer("The capital of Spain is", return_tensors="pt")
sources = [tokenizer("The capital of Italy is", return_tensors="pt")]
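
The config above requests a `VanillaIntervention` on a single position (`"pos"`, at most one unit). Conceptually, an interchange intervention of this kind overwrites the base run's representation at the chosen position with the source run's representation at the matching position. The sketch below shows that idea on plain Python lists. `interchange` is a hypothetical name, and the library's real implementation presumably works on tensors inside forward hooks.

```python
def interchange(base_states, source_states, base_pos, source_pos):
    """Return a copy of base_states with one position taken from source_states."""
    patched = [list(vec) for vec in base_states]  # leave the inputs untouched
    patched[base_pos] = list(source_states[source_pos])
    return patched

# Toy hidden states: three positions, two dimensions each.
base_states = [[0.0, 0.1], [1.0, 1.1], [2.0, 2.1]]
source_states = [[9.0, 9.1], [8.0, 8.1], [7.0, 7.1]]

patched = interchange(base_states, source_states, base_pos=1, source_pos=1)
print(patched)  # [[0.0, 0.1], [8.0, 8.1], [2.0, 2.1]]
```

If the patched run then predicts the source's answer rather than the base's, the swapped representation carries the information that distinguishes the two prompts.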

In [25]:
mistral.config.num_hidden_layers

32
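
To get a rough sense of the size of the sweep later in this tutorial, the number of intervention calls can be counted ahead of time. The position count below is an assumption: it takes the base prompt to be a BOS token plus five single-token words.

```python
num_layers = 32          # mistral.config.num_hidden_layers, printed above
intervention_types = 2   # "mlp_output" and "attention_input"
num_positions = 6        # assumption: BOS token plus five single-token words

num_calls = num_layers * intervention_types * num_positions
print(num_calls)  # 384
```

Each call runs the 7B-parameter model at least once, so the total runtime scales with this count.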

In [28]:
from torch import nn

def remove_forward_hooks(main_module: nn.Module):
    """Function to remove all forward and pre-forward hooks from a module and its sub-modules."""
    # Remove forward hooks
    for _, submodule in main_module.named_modules():
        if hasattr(submodule, "_forward_hooks"):
            hooks = list(submodule._forward_hooks.keys()) 
            for hook_id in hooks:
                submodule._forward_hooks.pop(hook_id)

        # Remove pre-forward hooks
        if hasattr(submodule, "_forward_pre_hooks"):
            pre_hooks = list(submodule._forward_pre_hooks.keys()) 
            for pre_hook_id in pre_hooks:
                submodule._forward_pre_hooks.pop(pre_hook_id)

remove_forward_hooks(mistral)


In [None]:
# Runtime depends on hardware: the weights of a 7B-parameter model in bfloat16
# take roughly 14 to 15 GB, and the model stays on CPU unless moved to a GPU.
# Skip the BOS token so that `tokens` holds the ids for "Madrid" and "Rome"
# (assuming each city name is a single token).
tokens = tokenizer.encode("Madrid Rome", add_special_tokens=False)[:2]

data = []
for layer_i in range(mistral.config.num_hidden_layers):
    print("layer_i", layer_i)
    # "f" labels MLP (feed-forward) outputs and "a" labels attention inputs.
    for intervention_type, prefix in [("mlp_output", "f"), ("attention_input", "a")]:
        alignable_config = simple_position_config(type(mistral), intervention_type, layer_i)
        alignable = AlignableModel(alignable_config, mistral)
        for pos_i in range(len(base.input_ids[0])):
            # Patch the source representation into the base run at position pos_i.
            _, counterfactual_outputs = alignable(
                base,
                sources,
                {"sources->base": ([[[pos_i]]], [[[pos_i]]])}
            )
            distrib = embed_to_distrib_mistral(
                counterfactual_outputs.logits,
                logits=False
            )
            # Record the next-token probability of each target token.
            for token in tokens:
                data.append({
                    'token': format_token(tokenizer, token),
                    'prob': float(distrib[0][-1][token]),
                    'layer': f"{prefix}{layer_i}",
                    'pos': pos_i,
                    'type': intervention_type
                })
df = pd.DataFrame(data)

layer_i 0

KeyboardInterrupt: 

In [None]:
df['layer'] = df['layer'].astype('category')
df['token'] = df['token'].astype('category')
nodes = []
for l in range(mistral.config.num_hidden_layers - 1, -1, -1):
    nodes.append(f'f{l}')
    nodes.append(f'a{l}')
df['layer'] = pd.Categorical(df['layer'], categories=nodes[::-1], ordered=True)

g = (ggplot(df) + geom_tile(aes(x='pos', y='layer', fill='prob', color='prob')) +
     facet_wrap("~token") + theme(axis_text_x=element_text(rotation=90)))
print(g)
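
The loop above builds the label list from the last layer down and then reverses it, so within each layer the attention label comes before the MLP label. The sketch below uses a hypothetical helper `layer_order` and two layers instead of 32 to make the resulting order explicit.

```python
def layer_order(num_layers):
    """Reproduce the category order used for the heatmap's y axis."""
    nodes = []
    for l in range(num_layers - 1, -1, -1):
        nodes.append(f"f{l}")
        nodes.append(f"a{l}")
    return nodes[::-1]

print(layer_order(2))  # ['a0', 'f0', 'a1', 'f1']
```

Passing this list as `categories` with `ordered=True` is what keeps the plot from sorting the labels alphabetically (for example, placing `a10` before `a2`).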

In [13]:
# Keep a single position. With the BOS token at position 0, position 4 is
# "Spain" in the base prompt (assuming each word is one token).
filtered = df[df['pos'] == 4]
g = (ggplot(filtered) + geom_bar(aes(x='layer', y='prob', fill='token'), stat='identity')
         + theme(axis_text_x=element_text(rotation=90), legend_position='none') + scale_y_log10()
         + facet_wrap("~token", ncol=1))
print(g)
# save as pdf
g.save("mistral.pdf", width=10, height=10)
print("PDF saved.")

NameError: name 'df' is not defined