In [1]:
from tqdm.notebook import tqdm

from datasets import load_dataset
import torch
from torch.utils.data import DataLoader

from llama import Llama

from short_llama import ShortLlama

### Load Data

In [2]:
data = load_dataset("pg19", split="validation")  # authors sample 10,000 texts to compute block influences
dataloader = DataLoader(
    data,
    batch_size=1,
    shuffle=True,
    generator=torch.Generator(device="cuda")
)

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


### Fetch and Wrap Model

In [3]:
MAX_SEQ_LEN = 1024  # authors use a context width of 1024
llama = Llama.build(
    ckpt_dir="../llama/llama-2-7b",
    tokenizer_path="../llama/tokenizer.model",
    max_seq_len=MAX_SEQ_LEN,
    max_batch_size=1,
)

> initializing model parallel with size 1
> initializing ddp with size 1
> initializing pipeline with size 1


  _C._set_default_tensor_type(t)


Loaded in 12.17 seconds


In [4]:
short_llama = ShortLlama(llama=llama)

short_llama.llama.model.layers

ModuleList(
  (0-31): 32 x TransformerBlock(
    (attention): Attention(
      (wq): ColumnParallelLinear()
      (wk): ColumnParallelLinear()
      (wv): ColumnParallelLinear()
      (wo): RowParallelLinear()
    )
    (feed_forward): FeedForward(
      (w1): ColumnParallelLinear()
      (w2): RowParallelLinear()
      (w3): ColumnParallelLinear()
    )
    (attention_norm): RMSNorm()
    (ffn_norm): RMSNorm()
  )
)

In [5]:
# sample generation
short_llama.llama.text_completion(
    prompts=["I am an avid fan of "],
    max_gen_len=20
)

[{'generation': '1960s-70s era pop music. I grew up listening to the radio'}]

### Compute Importances

In [6]:
for batch in tqdm(dataloader):
    prompts = batch['text']

    prompt_tokens = [short_llama.llama.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
    max_prompt_len = max(len(t) for t in prompt_tokens)

    # authors use a sliding window of size 1024 with a shift of 256
    for start in range(0, max_prompt_len, 256):

        inputs = [p[start:start+MAX_SEQ_LEN] for p in prompt_tokens if len(p) > start]

        short_llama.eval_importance(
            prompt_tokens=inputs,
            max_gen_len=0
        )

  0%|          | 0/50 [00:00<?, ?it/s]

In [7]:
short_llama.importances

[16717843.43359375,
 10423418.44140625,
 6518133.3359375,
 6328185.017578125,
 7037034.49609375,
 6307392.001953125,
 6125241.50390625,
 5712124.599609375,
 5348248.4765625,
 5091788.0625,
 4765901.00390625,
 4389966.291015625,
 4292717.021484375,
 4361633.55859375,
 4291800.3046875,
 4252424.794921875,
 4361357.048828125,
 3372381.509765625,
 3048071.146484375,
 2540082.32421875,
 2737189.0546875,
 1909176.11328125,
 1889121.580078125,
 1560965.88671875,
 1487861.056640625,
 1465746.361328125,
 1490804.53125,
 1466835.6328125,
 1524585.98828125,
 1542287.908203125,
 2607044.50390625,
 11649695.109375]

### Remove unimportant layers

Layers removed when using pg19 val set: [25, 27, 24, 26, 28, 29, 23, 22, 21]

Note: Different order than paper but same 9 least important layers -> [27, 26, 25, 28, 24, 29, 23, 21, 22]

Additionally, authors mention that the layer order is quite nuanced and can vary with different datasets. However, relative order suggests similar importance.

In [8]:
short_llama.remove_layers(num_layers=9)

[25, 27, 24, 26, 28, 29, 23, 22, 21]

In [9]:
short_llama.llama.model.layers

ModuleList(
  (0-22): 23 x TransformerBlock(
    (attention): Attention(
      (wq): ColumnParallelLinear()
      (wk): ColumnParallelLinear()
      (wv): ColumnParallelLinear()
      (wo): RowParallelLinear()
    )
    (feed_forward): FeedForward(
      (w1): ColumnParallelLinear()
      (w2): RowParallelLinear()
      (w3): ColumnParallelLinear()
    )
    (attention_norm): RMSNorm()
    (ffn_norm): RMSNorm()
  )
)

As the paper states: \
    - "Our experiments reveal that the effect of layer removal is significantly more pronounced on generative
        tasks compared to multiple-choice tasks. On benchmarks such as GSM8K (Cobbe et al., 2021) and
        HumanEval (Chen et al., 2021), removing 25% of the layers often leads to a severe performance
        drop, with scores approaching zero."

In [10]:
short_llama.llama.text_completion(
    prompts=["I am an avid fan of "],
    max_gen_len=20
)

[{'generation': 'Đo n Khơ 20th Century. Hinweis: In = ,t and lồ'}]