In [1]:
from tqdm.notebook import tqdm

from datasets import load_dataset
import torch
from torch.utils.data import DataLoader

from llama import Llama

from short_llama import ShortLlama

### Load Data

In [2]:
# The authors sample 10,000 texts to compute block influences; here we simply use the
# whole PG19 validation split (50 books, one per batch).
data = load_dataset("pg19", split="validation")

SEED = 0  # fixes the shuffle order so importance runs are reproducible
dataloader = DataLoader(
    data,
    batch_size=1,
    shuffle=True,
    # CUDA generator presumably because `Llama.build` later switches the default tensor
    # type to CUDA (see its `_set_default_tensor_type` warning) — TODO confirm.
    # `manual_seed` returns the generator itself, so it can be passed inline.
    generator=torch.Generator(device="cuda").manual_seed(SEED),
)

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


### Fetch and Wrap Model

In [3]:
# Build Llama-2-7B with the 1024-token context window the authors use.
MAX_SEQ_LEN = 1024
llama = Llama.build(
    max_batch_size=1,
    max_seq_len=MAX_SEQ_LEN,
    tokenizer_path="../../llama/tokenizer.model",
    ckpt_dir="../../llama/llama-2-7b",
)

> initializing model parallel with size 1
> initializing ddp with size 1
> initializing pipeline with size 1


  _C._set_default_tensor_type(t)


Loaded in 10.96 seconds


In [4]:
# Wrap the model for pruning; 9 of its 32 decoder layers (~28%) will be removed.
N_PRUNE_LAYERS = 9
short_llama = ShortLlama(n_prune_layers=N_PRUNE_LAYERS, llama=llama)

short_llama.llama.model.layers

ModuleList(
  (0-31): 32 x TransformerBlock(
    (attention): Attention(
      (wq): ColumnParallelLinear()
      (wk): ColumnParallelLinear()
      (wv): ColumnParallelLinear()
      (wo): RowParallelLinear()
    )
    (feed_forward): FeedForward(
      (w1): ColumnParallelLinear()
      (w2): RowParallelLinear()
      (w3): ColumnParallelLinear()
    )
    (attention_norm): RMSNorm()
    (ffn_norm): RMSNorm()
  )
)

In [5]:
# Sample generation from the full (unpruned) model, as a baseline for later comparison.
sample_prompts = ["I am an avid fan of "]
short_llama.llama.text_completion(max_gen_len=20, prompts=sample_prompts)

[{'generation': '1960s-70s era pop music. I grew up listening to the radio'}]

### Compute Importances

In [6]:
# Score every layer's block influence over each book. As the authors do, each book is
# scored with a sliding window of MAX_SEQ_LEN tokens that advances 256 tokens at a time.
WINDOW_STRIDE = 256
tokenizer = short_llama.llama.tokenizer

for batch in tqdm(dataloader):
    prompt_tokens = [tokenizer.encode(text, bos=True, eos=False) for text in batch['text']]
    longest = max(map(len, prompt_tokens))

    for window_start in range(0, longest, WINDOW_STRIDE):
        window_end = window_start + MAX_SEQ_LEN
        # Skip prompts that end before this window begins.
        windows = [toks[window_start:window_end] for toks in prompt_tokens if len(toks) > window_start]
        short_llama.eval_importance(prompt_tokens=windows, max_gen_len=0)

  0%|          | 0/50 [00:00<?, ?it/s]

In [7]:
# Per-layer importance scores filled in by the eval_importance loop above (one entry per layer).
short_llama.importances

[8358921.716796875,
 5211709.220703125,
 3259066.66796875,
 3164092.5087890625,
 3518517.248046875,
 3153696.0009765625,
 3062620.751953125,
 2856062.2998046875,
 2674124.23828125,
 2545894.03125,
 2382950.501953125,
 2194983.1455078125,
 2146358.5107421875,
 2180816.779296875,
 2145900.15234375,
 2126212.3974609375,
 2180678.5244140625,
 1686190.7548828125,
 1524035.5732421875,
 1270041.162109375,
 1368594.52734375,
 954588.056640625,
 944560.7900390625,
 780482.943359375,
 743930.5283203125,
 732873.1806640625,
 745402.265625,
 733417.81640625,
 762292.994140625,
 771143.9541015625,
 1303522.251953125,
 5824847.5546875]

### Remove Unimportant Layers

Layers removed using the PG19 validation set, in the order returned by `remove_layers()`: [25, 27, 24, 26, 28, 29, 23, 22, 21]

Note: the order differs from the paper's, but the same 9 least-important layers are removed. The paper's order is [27, 26, 25, 28, 24, 29, 23, 21, 22].

The authors also note that the exact layer ranking is sensitive and can vary across datasets. Even so, the relative ordering here agrees with theirs about which layers matter least.

In [8]:
# Prune the least-important layers from the wrapped model in place; the call returns the
# indices of the removed layers, in the order they were selected.
short_llama.remove_layers()

[25, 27, 24, 26, 28, 29, 23, 22, 21]

In [9]:
# Inspect the pruned stack: 32 - 9 = 23 TransformerBlocks should remain.
short_llama.llama.model.layers

ModuleList(
  (0-22): 23 x TransformerBlock(
    (attention): Attention(
      (wq): ColumnParallelLinear()
      (wk): ColumnParallelLinear()
      (wv): ColumnParallelLinear()
      (wo): RowParallelLinear()
    )
    (feed_forward): FeedForward(
      (w1): ColumnParallelLinear()
      (w2): RowParallelLinear()
      (w3): ColumnParallelLinear()
    )
    (attention_norm): RMSNorm()
    (ffn_norm): RMSNorm()
  )
)

As the paper states:

> "Our experiments reveal that the effect of layer removal is significantly more pronounced on generative tasks compared to multiple-choice tasks. On benchmarks such as GSM8K (Cobbe et al., 2021) and HumanEval (Chen et al., 2021), removing 25% of the layers often leads to a severe performance drop, with scores approaching zero."

The generation below, produced after removing 9 of the 32 layers (~28%), shows the same collapse.

In [10]:
# Same prompt as the baseline generation, now run through the pruned model.
pruned_prompts = ["I am an avid fan of "]
short_llama.llama.text_completion(max_gen_len=20, prompts=pruned_prompts)

[{'generation': 'Đo n Khơ 20th Century. Hinweis: In = ,t and lồ'}]

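### Compute Angular Importances

Angular importances must be computed on the full, unpruned 32-layer model. `remove_layers()` above pruned the model in place, so this section needs a freshly built model rather than the existing `short_llama`.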

In [6]:
# `remove_layers()` in the block-influence section pruned `llama` in place (down to 23
# layers). Reusing `short_llama` here would score that already-pruned model and carry over
# its state. Rebuild a fresh, unpruned model so this section reproduces under
# Restart & Run All.
del short_llama, llama
# NOTE(review): IPython's Out[...] cache may still reference old layer modules, so not
# all GPU memory is necessarily released here.
torch.cuda.empty_cache()
llama = Llama.build(
    ckpt_dir="../../llama/llama-2-7b",
    tokenizer_path="../../llama/tokenizer.model",
    max_seq_len=MAX_SEQ_LEN,
    max_batch_size=1,
)
short_llama = ShortLlama(llama=llama, n_prune_layers=9)
assert len(short_llama.llama.model.layers) == 32, "angular importances need the unpruned model"

for batch in tqdm(dataloader):
    prompts = batch['text']

    prompt_tokens = [short_llama.llama.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
    max_prompt_len = max(len(t) for t in prompt_tokens)

    # authors use a sliding window of size 1024 with a shift of 256
    for start in range(0, max_prompt_len, 256):

        # Skip prompts that end before this window begins.
        inputs = [p[start:start+MAX_SEQ_LEN] for p in prompt_tokens if len(p) > start]

        short_llama.eval_importance(
            prompt_tokens=inputs,
            max_gen_len=0,
            angular=True
        )

  0%|          | 0/50 [00:00<?, ?it/s]

In [7]:
# Angular-distance scores, one slot per layer. The trailing zeros are presumably start
# positions with no full 9-layer block after them (32 - 9 + 1 = 24 scored) — TODO confirm.
short_llama.importances

[8640.460205078125,
 7881.541015625,
 7303.3876953125,
 7156.226318359375,
 7003.533935546875,
 6749.5189208984375,
 6630.6031494140625,
 6494.6051025390625,
 6475.490295410156,
 6482.81884765625,
 6489.277587890625,
 6479.0064697265625,
 6486.2188720703125,
 6440.6580810546875,
 6338.8604736328125,
 6196.098876953125,
 6014.3204345703125,
 5677.5113525390625,
 5532.0673828125,
 5384.6334228515625,
 5314.61669921875,
 5176.587646484375,
 5425.315673828125,
 7029.1893310546875,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0]

### Remove Unimportant Layers

In [8]:
# Prune using the angular scores; the returned indices here form one contiguous run of 9 layers.
short_llama.remove_layers(angular=True)

[21, 22, 23, 24, 25, 26, 27, 28, 29]

In [9]:
# Inspect the pruned stack after angular pruning: 23 TransformerBlocks should remain.
short_llama.llama.model.layers

ModuleList(
  (0-22): 23 x TransformerBlock(
    (attention): Attention(
      (wq): ColumnParallelLinear()
      (wk): ColumnParallelLinear()
      (wv): ColumnParallelLinear()
      (wo): RowParallelLinear()
    )
    (feed_forward): FeedForward(
      (w1): ColumnParallelLinear()
      (w2): RowParallelLinear()
      (w3): ColumnParallelLinear()
    )
    (attention_norm): RMSNorm()
    (ffn_norm): RMSNorm()
  )
)

In [10]:
# Same prompt again, now through the angular-pruned model.
angular_prompts = ["I am an avid fan of "]
short_llama.llama.text_completion(max_gen_len=20, prompts=angular_prompts)

[{'generation': 'Đo n Khơ 20th Century. Hinweis: In = ,t and lồ'}]