In [1]:
from tqdm.notebook import tqdm

from datasets import load_dataset
import torch
from torch.utils.data import DataLoader

from short_hf import ShortHFModel

### Load Data

In [2]:
data = load_dataset("pg19", split="validation")  # authors sample 10,000 texts to compute block influences
dataloader = DataLoader(
    data,
    batch_size=1,
    shuffle=True,
)

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


### Load Model

In [3]:
MAX_SEQ_LEN = 1024
short_model = ShortHFModel(
    model_name="mistralai/Mistral-7B-v0.1",
    layers_path="model.layers"
)
short_model.model

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MistralRotaryEmbedding()
        )
        (mlp): MistralMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): MistralRMSNorm()
        (post_attention_layernorm): MistralRMSNorm()
      )
    )
    (norm): MistralRMSNorm(

In [4]:
short_model.layers[0]

MistralDecoderLayer(
  (self_attn): MistralSdpaAttention(
    (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
    (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
    (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
    (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
    (rotary_emb): MistralRotaryEmbedding()
  )
  (mlp): MistralMLP(
    (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
    (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
    (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
    (act_fn): SiLU()
  )
  (input_layernorm): MistralRMSNorm()
  (post_attention_layernorm): MistralRMSNorm()
)

In [5]:
# sample generation
gen = short_model.model.generate(
    short_model.tokenizer(["I am an avid fan of "], return_tensors='pt').input_ids.to("cuda"),
    max_new_tokens=20
)
short_model.tokenizer.batch_decode(gen, skip_special_tokens=True)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  attn_output = torch.nn.functional.scaled_dot_product_attention(


['I am an avid fan of 3D printing. I have been using 3D printers for over 10 years and']

### Compute Importances

In [6]:
for i, batch in enumerate(tqdm(dataloader)):
    # to speed up experiments, change as you please
    if i == 5:
        break

    prompts = batch['text']

    short_model.eval_importance(
        prompts=prompts,
        max_seq_len=MAX_SEQ_LEN,
        stride=256,
        max_gen_len=0
    )

  0%|          | 0/50 [00:00<?, ?it/s]

In [7]:
short_model.importances

[706592.453125,
 425593.15625,
 317297.3046875,
 272700.7421875,
 313270.2265625,
 281193.0625,
 259303.7890625,
 247235.203125,
 230518.4765625,
 207359.1015625,
 206604.88671875,
 191179.20703125,
 176394.51953125,
 171003.3359375,
 182990.56640625,
 170447.859375,
 175557.640625,
 170004.4140625,
 183532.1484375,
 161545.421875,
 123509.1484375,
 95928.953125,
 81660.552734375,
 83030.16796875,
 70967.033203125,
 68553.78125,
 68749.0234375,
 79548.078125,
 94246.546875,
 110476.56640625,
 113710.2578125,
 377803.671875]

### Remove unimportant layers

Layers removed when using subset of pg19 val set: [25, 26, 24, 27, 22, 23, 28, 21, 29]

Authors mention that the layer order is quite nuanced and can vary with different datasets. However, relative order suggests similar importance.

In [8]:
short_model.remove_layers(num_layers=9)

[25, 26, 24, 27, 22, 23, 28, 21, 29]

In [10]:
short_model.layers

ModuleList(
  (0-22): 23 x MistralDecoderLayer(
    (self_attn): MistralSdpaAttention(
      (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
      (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
      (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
      (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
      (rotary_emb): MistralRotaryEmbedding()
    )
    (mlp): MistralMLP(
      (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
      (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
      (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
      (act_fn): SiLU()
    )
    (input_layernorm): MistralRMSNorm()
    (post_attention_layernorm): MistralRMSNorm()
  )
)

As the paper states: \
    - "Our experiments reveal that the effect of layer removal is significantly more pronounced on generative
        tasks compared to multiple-choice tasks. On benchmarks such as GSM8K (Cobbe et al., 2021) and
        HumanEval (Chen et al., 2021), removing 25% of the layers often leads to a severe performance
        drop, with scores approaching zero."

In [12]:
gen = short_model.model.generate(
    short_model.tokenizer(["I am an avid fan of "], return_tensors='pt').input_ids.to("cuda"),
    max_new_tokens=20,
    use_cache=False  # TODO: currently necessary since caching accesses removed layer indices
)
short_model.tokenizer.batch_decode(gen, skip_special_tokens=True)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


['I am an avid fan of 2015-16 TVB TV W5DW, the 22-']