In [1]:
from tqdm.notebook import tqdm

from datasets import load_dataset
from torch.utils.data import DataLoader

from llama import Llama

from short_llama import ShortLlama

### Load Data

In [2]:
data = load_dataset("pg19", split="validation")
dataloader = DataLoader(data, batch_size=1)

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


In [3]:
next(iter(dataloader))

{'short_book_title': ['Walking by Henry David Thoreau'],
 'publication_date': tensor([1862]),
 'url': ['http://www.gutenberg.org/ebooks/1022'],
 'text': ['\n\n\n\nProduced by Q Myers\n\n\n\n\n\nWALKING\n\nby Henry David Thoreau\n\n\nI wish to speak a word for Nature, for absolute freedom and wildness, as\ncontrasted with a freedom and culture merely civil--to regard man as\nan inhabitant, or a part and parcel of Nature, rather than a member\nof society. I wish to make an extreme statement, if so I may make\nan emphatic one, for there are enough champions of civilization: the\nminister and the school committee and every one of you will take care of\nthat.\n\n\n\nI have met with but one or two persons in the course of my life who\nunderstood the art of Walking, that is, of taking walks--who had a\ngenius, so to speak, for SAUNTERING, which word is beautifully derived\n"from idle people who roved about the country, in the Middle Ages, and\nasked charity, under pretense of going a la Saint

### Fetch and Wrap Model

In [4]:
llama = Llama.build(
    ckpt_dir="../llama/llama-2-7b",
    tokenizer_path="../llama/tokenizer.model",
    max_seq_len=512,
    max_batch_size=1,
)

> initializing model parallel with size 1
> initializing ddp with size 1
> initializing pipeline with size 1


  _C._set_default_tensor_type(t)


Loaded in 11.94 seconds


In [5]:
short_llama = ShortLlama(llama=llama)

short_llama.llama.model.layers

ModuleList(
  (0-31): 32 x TransformerBlock(
    (attention): Attention(
      (wq): ColumnParallelLinear()
      (wk): ColumnParallelLinear()
      (wv): ColumnParallelLinear()
      (wo): RowParallelLinear()
    )
    (feed_forward): FeedForward(
      (w1): ColumnParallelLinear()
      (w2): RowParallelLinear()
      (w3): ColumnParallelLinear()
    )
    (attention_norm): RMSNorm()
    (ffn_norm): RMSNorm()
  )
)

In [6]:
# sample generation
short_llama.llama.text_completion(
    prompts=["I am an avid fan of "],
    max_gen_len=20
)

[{'generation': '1960s-70s era pop music. I grew up listening to the radio'}]

### Compute Importances

In [7]:
for batch in tqdm(dataloader):
    prompts = batch['text']
    prompt_tokens = [short_llama.llama.tokenizer.encode(x, bos=True, eos=False)[:short_llama.llama.model.params.max_seq_len] for x in prompts]

    short_llama.eval_importance(
        prompt_tokens=prompt_tokens,
        max_gen_len=short_llama.llama.model.params.max_seq_len - 1
    )

  0%|          | 0/50 [00:00<?, ?it/s]

In [8]:
short_llama.importances

[12393.0,
 7455.75,
 4697.0,
 4369.625,
 4658.875,
 4229.9375,
 4071.0,
 3800.3125,
 3511.625,
 3375.875,
 3043.90625,
 2817.1875,
 2713.15625,
 2762.9375,
 2681.875,
 2749.25,
 2747.875,
 2282.4375,
 2147.0,
 1797.265625,
 1860.390625,
 1462.03125,
 1375.640625,
 1190.953125,
 1189.125,
 1235.625,
 1205.984375,
 1229.265625,
 1345.0,
 1430.890625,
 2107.21875,
 8122.375]

### Remove unimportant layers

Layers removed when using pg19 val set: [24, 23, 26, 25, 27, 28, 22, 29, 21] \
Layers removed when using pg19 test set: [24, 23, 26, 25, 27, 28, 22, 29, 21]

Note: Different order than paper but same 9 least important layers -> [27, 26, 25, 28, 24, 29, 23, 21, 22]

In [9]:
short_llama.remove_layers(num_layers=9)

[24, 23, 26, 27, 25, 28, 22, 29, 21]

In [10]:
short_llama.llama.model.layers

ModuleList(
  (0-22): 23 x TransformerBlock(
    (attention): Attention(
      (wq): ColumnParallelLinear()
      (wk): ColumnParallelLinear()
      (wv): ColumnParallelLinear()
      (wo): RowParallelLinear()
    )
    (feed_forward): FeedForward(
      (w1): ColumnParallelLinear()
      (w2): RowParallelLinear()
      (w3): ColumnParallelLinear()
    )
    (attention_norm): RMSNorm()
    (ffn_norm): RMSNorm()
  )
)

As the paper states: \
    - "Our experiments reveal that the effect of layer removal is significantly more pronounced on generative
        tasks compared to multiple-choice tasks. On benchmarks such as GSM8K (Cobbe et al., 2021) and
        HumanEval (Chen et al., 2021), removing 25% of the layers often leads to a severe performance
        drop, with scores approaching zero."

In [11]:
short_llama.llama.text_completion(
    prompts=["I am an avid fan of "],
    max_gen_len=20
)

[{'generation': '한국 인수 매 실직 반'}]