In [1]:
from tqdm.notebook import tqdm

from datasets import load_dataset
import torch
from torch.utils.data import DataLoader

from peft import (
    get_peft_model,
    LoraConfig,
    TaskType,
)
from transformers import default_data_collator, Trainer, TrainingArguments

from short_hf import ShortHFModel

### Load Data

In [2]:
data = load_dataset("pg19", split="validation")  # authors sample 10,000 texts to compute block influences
dataloader = DataLoader(
    data,
    batch_size=1,
    shuffle=True,
)

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


### Load Model

In [3]:
MAX_SEQ_LEN = 1024
short_model = ShortHFModel(
    model_name="mistralai/Mistral-7B-v0.1",
    layers_path="model.layers"
)
short_model.model

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MistralRotaryEmbedding()
        )
        (mlp): MistralMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): MistralRMSNorm()
        (post_attention_layernorm): MistralRMSNorm()
      )
    )
    (norm): MistralRMSNorm(

In [4]:
short_model.layers[0]

MistralDecoderLayer(
  (self_attn): MistralSdpaAttention(
    (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
    (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
    (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
    (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
    (rotary_emb): MistralRotaryEmbedding()
  )
  (mlp): MistralMLP(
    (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
    (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
    (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
    (act_fn): SiLU()
  )
  (input_layernorm): MistralRMSNorm()
  (post_attention_layernorm): MistralRMSNorm()
)

In [5]:
# sample generation
gen = short_model.model.generate(
    short_model.tokenizer(["I am an avid fan of "], return_tensors='pt').input_ids.to("cuda"),
    max_new_tokens=20
)
short_model.tokenizer.batch_decode(gen, skip_special_tokens=True)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


  attn_output = torch.nn.functional.scaled_dot_product_attention(


['I am an avid fan of 3D printing. I have been using 3D printers for over 10 years and']

### Compute Importances

In [6]:
for i, batch in enumerate(tqdm(dataloader)):
    # to speed up experiments, change as you please
    if i == 5:
        break

    prompts = batch['text']

    short_model.eval_importance(
        prompts=prompts,
        max_seq_len=MAX_SEQ_LEN,
        stride=256,
        max_gen_len=0
    )

  0%|          | 0/50 [00:00<?, ?it/s]

In [7]:
short_model.importances

[706592.453125,
 425593.15625,
 317297.3046875,
 272700.7421875,
 313270.2265625,
 281193.0625,
 259303.7890625,
 247235.203125,
 230518.4765625,
 207359.1015625,
 206604.88671875,
 191179.20703125,
 176394.51953125,
 171003.3359375,
 182990.56640625,
 170447.859375,
 175557.640625,
 170004.4140625,
 183532.1484375,
 161545.421875,
 123509.1484375,
 95928.953125,
 81660.552734375,
 83030.16796875,
 70967.033203125,
 68553.78125,
 68749.0234375,
 79548.078125,
 94246.546875,
 110476.56640625,
 113710.2578125,
 377803.671875]

### Remove unimportant layers

Layers removed when using subset of pg19 val set: [25, 26, 24, 27, 22, 23, 28, 21, 29]

Authors mention that the layer order is quite nuanced and can vary with different datasets. However, relative order suggests similar importance.

In [8]:
short_model.remove_layers(num_layers=9)

[25, 26, 24, 27, 22, 23, 28, 21, 29]

In [10]:
short_model.layers

ModuleList(
  (0-22): 23 x MistralDecoderLayer(
    (self_attn): MistralSdpaAttention(
      (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
      (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
      (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
      (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
      (rotary_emb): MistralRotaryEmbedding()
    )
    (mlp): MistralMLP(
      (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
      (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
      (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
      (act_fn): SiLU()
    )
    (input_layernorm): MistralRMSNorm()
    (post_attention_layernorm): MistralRMSNorm()
  )
)

In [None]:
# reassign layer_idx to attentions for caching
for layer_idx, module in enumerate(short_model.layers):
    module.self_attn.layer_idx = layer_idx

As the paper states: \
    - "Our experiments reveal that the effect of layer removal is significantly more pronounced on generative
        tasks compared to multiple-choice tasks. On benchmarks such as GSM8K (Cobbe et al., 2021) and
        HumanEval (Chen et al., 2021), removing 25% of the layers often leads to a severe performance
        drop, with scores approaching zero."

In [12]:
gen = short_model.model.generate(
    short_model.tokenizer(["I am an avid fan of "], return_tensors='pt').input_ids.to("cuda"),
    max_new_tokens=20,
    use_cache=True
)
short_model.tokenizer.batch_decode(gen, skip_special_tokens=True)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


['I am an avid fan of 2015-16 TVB TV W5DW, the 22-']

### Model Healing

In [None]:
tokenizer = short_model.tokenizer
model = short_model.model

In [None]:
# referencing https://github.com/meta-llama/llama-recipes/blob/main/recipes/finetuning/huggingface_trainer/peft_finetuning.ipynb
eval_prompt = """
Summarize this dialog:
A: Hi Tom, are you busy tomorrow's afternoon?
B: I'm pretty sure I am. What's up?
A: Can you go with me to the animal shelter?.
B: What do you want to do?
A: I want to get a puppy for my son.
B: That will make him so happy.
A: Yeah, we've discussed it many times. I think he's ready now.
B: That's good. Raising a dog is a tough issue. Like having a baby ;-) 
A: I'll get him one of those little dogs.
B: One that won't grow up too big;-)
A: And eat too much;-))
B: Do you know which one he would like?
A: Oh, yes, I took him there last Monday. He showed me one that he really liked.
B: I bet you had to drag him away.
A: He wanted to take it home right away ;-).
B: I wonder what he'll name it.
A: He said he'd name it after his dead hamster - Lemmy  - he's  a great Motorhead fan :-)))
---
Summary:
"""

model_input = tokenizer(eval_prompt, return_tensors="pt").to("cuda")

model.eval()
with torch.no_grad():
    print(tokenizer.decode(model.generate(**model_input, max_new_tokens=100, use_cache=True)[0], skip_special_tokens=True))

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.



Summarize this dialog:
A: Hi Tom, are you busy tomorrow’s afternoon?
B: I’m pretty sure I am. What’s up?
A: Can you go with me to the animal shelter?.
B: What do you want to do?
A: I want to get a puppy for my son.
B: That will make him so happy.
A: Yeah, we’ve discussed it many times. I think he’s ready now.
B: That’s good. Raising a dog is a tough issue. Like having a baby ;-) 
A: I'll get him one of those little dogs.
B: One that won't grow up too big;-)
A: And eat too much;-))
B: Do you know which one he would like?
A: Oh, yes, I took him there last Monday. He showed me one that he really liked.
B: I bet you had to drag him away.
A: He wanted to take it home right away ;-).
B: I wonder what he'll name it.
A: He said he’d name it after his dead hamster – Lemmy  - he's  a great Motorhead fan :-)))
---
Summary:
Tom, I'
'_m
ŝm
t
t
t
t
t
t
t
t
t
t
t
t
t
t
t
t
t
t
t
t
t
t
t
t
t
t
t
t
t
t
t
t
t
t
t
t
t
t
t
t
t
t
t
t



In [None]:
def get_preprocessed_samsum():
    dataset = load_dataset("samsum", split="train")

    prompt = (
        f"Summarize this dialog:\n{{dialog}}\n---\nSummary:\n"
    )

    def apply_prompt_template(sample):
        return {
            "prompt": prompt.format(dialog=sample["dialogue"]),
            "summary": sample["summary"],
        }

    dataset = dataset.map(apply_prompt_template, remove_columns=list(dataset.features))

    def tokenize_add_label(sample):
        prompt = tokenizer.encode(tokenizer.bos_token + sample["prompt"], add_special_tokens=False)
        summary = tokenizer.encode(sample["summary"] +  tokenizer.eos_token, add_special_tokens=False)
        sample = {
            "input_ids": prompt + summary,
            "attention_mask" : [1] * (len(prompt) + len(summary)),
            "labels": [-100] * len(prompt) + summary,
            }

        return sample

    dataset = dataset.map(tokenize_add_label, remove_columns=list(dataset.features))

    return dataset

In [None]:
model.train()

def create_peft_config(model):
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=8,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules = ["q_proj", "v_proj"]
    )

    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()
    return model, peft_config

# create peft config
model, lora_config = create_peft_config(model)

trainable params: 2,449,408 || all params: 5,281,173,504 || trainable%: 0.04637999486562599


In [None]:
output_dir = "tmp/"

config = {
    'lora_config': lora_config,
    'learning_rate': 1e-6,
    'num_train_epochs': 1,
    'per_device_train_batch_size': 1,
    'gradient_checkpointing': False,
}


In [None]:
training_args = TrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=True,
    # logging strategies
    logging_strategy="steps",
    logging_steps=10,
    save_strategy="no",
    optim="adamw_torch_fused",
    **{k:v for k,v in config.items() if k != 'lora_config'}
)

# Create Trainer instance
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=get_preprocessed_samsum(),
    data_collator=default_data_collator,
    callbacks=[],
)

# Start training
trainer.train()

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


  0%|          | 0/14732 [00:00<?, ?it/s]

{'loss': 3.99, 'grad_norm': 34.96875, 'learning_rate': 9.993212055389627e-07, 'epoch': 0.0}
{'loss': 3.26, 'grad_norm': 40.71875, 'learning_rate': 9.986424110779256e-07, 'epoch': 0.0}
{'loss': 3.69, 'grad_norm': 38.4375, 'learning_rate': 9.979636166168883e-07, 'epoch': 0.0}
{'loss': 2.5622, 'grad_norm': 30.171875, 'learning_rate': 9.972848221558513e-07, 'epoch': 0.0}
{'loss': 2.4639, 'grad_norm': 46.65625, 'learning_rate': 9.96606027694814e-07, 'epoch': 0.0}
{'loss': 2.1644, 'grad_norm': 25.0625, 'learning_rate': 9.959272332337767e-07, 'epoch': 0.0}
{'loss': 2.3199, 'grad_norm': 45.9375, 'learning_rate': 9.952484387727397e-07, 'epoch': 0.0}
{'loss': 1.629, 'grad_norm': 27.3125, 'learning_rate': 9.945696443117024e-07, 'epoch': 0.01}
{'loss': 1.633, 'grad_norm': 27.984375, 'learning_rate': 9.93890849850665e-07, 'epoch': 0.01}
{'loss': 1.6279, 'grad_norm': 40.65625, 'learning_rate': 9.93212055389628e-07, 'epoch': 0.01}
{'loss': 1.6812, 'grad_norm': 19.015625, 'learning_rate': 9.9253326092

TrainOutput(global_step=14732, training_loss=1.257665402485306, metrics={'train_runtime': 1721.8179, 'train_samples_per_second': 8.556, 'train_steps_per_second': 8.556, 'train_loss': 1.257665402485306, 'epoch': 1.0})

In [None]:
model.eval()
with torch.no_grad():
    print(tokenizer.decode(model.generate(**model_input, max_new_tokens=100)[0], skip_special_tokens=True))

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.



Summarize this dialog:
A: Hi Tom, are you busy tomorrow’s afternoon?
B: I’m pretty sure I am. What’s up?
A: Can you go with me to the animal shelter?.
B: What do you want to do?
A: I want to get a puppy for my son.
B: That will make him so happy.
A: Yeah, we’ve discussed it many times. I think he’s ready now.
B: That’s good. Raising a dog is a tough issue. Like having a baby ;-) 
A: I'll get him one of those little dogs.
B: One that won't grow up too big;-)
A: And eat too much;-))
B: Do you know which one he would like?
A: Oh, yes, I took him there last Monday. He showed me one that he really liked.
B: I bet you had to drag him away.
A: He wanted to take it home right away ;-).
B: I wonder what he'll name it.
A: He said he’d name it after his dead hamster – Lemmy  - he's  a great Motorhead fan :-)))
---
Summary:
 A wants to get a puppy for his son. B is sure he'll be happy with it. A took his son to the animal shelter last Monday. He liked one of the puppies. A will get him one of tho