# Llama distillation

In [1]:
def count_parameters(model):
    total_params = sum(p.numel() for p in model.parameters()) / 10**6
    print(f'total_params: {total_params:.3f}M')

In [2]:
cd lm-evaluation-harness

/home/jupyter/work/resources/lm-evaluation-harness


In [3]:
import gc
import os
from typing import List

from lm_eval import evaluator, tasks

import torch
import numpy as np
from datasets import load_dataset, concatenate_datasets


from transformers import (
    AdamW,
    LlamaForCausalLM,
    LlamaTokenizer,
    AutoConfig,
    Trainer, 
    TrainingArguments,
    DataCollatorForLanguageModeling
)

from torch.utils.data import Dataset, DataLoader

from lm_eval.base import BaseLM



In [4]:
def distillation_loss(student_logits, teacher_logits, temperature):
    loss_fn = torch.nn.KLDivLoss(reduction='batchmean')
    student_probs = torch.nn.functional.log_softmax(student_logits / temperature, dim=-1)
    teacher_probs = torch.nn.functional.softmax(teacher_logits / temperature, dim=-1)
    
    return loss_fn(student_probs, teacher_probs)

In [5]:
teacher_llama = LlamaForCausalLM.from_pretrained("openlm-research/open_llama_3b_v2")

count_parameters(teacher_llama)

total_params: 3426.474M


In [6]:
student_config = AutoConfig.from_pretrained("openlm-research/open_llama_3b_v2", num_hidden_layers=6)



In [7]:
## take 1, 4, 11, 16, 23, 26 layers of parent model
layers = [0, 3, 10, 15, 22, 25]
print(len(layers))

6


In [8]:
student_llama = LlamaForCausalLM(student_config)

for i in range(6):
    student_llama.model.layers[i].load_state_dict(teacher_llama.model.layers[layers[i]].state_dict())
    
print(count_parameters(student_llama))

total_params: 948.266M
None


In [9]:
def filter_dataset(dataset, percents):
    part = percents / 100
    return dataset.shuffle(seed=42).select(range(int(len(dataset)*part)))

In [10]:
wikipedia_dataset = load_dataset('wikipedia', '20220301.en', split='train[:1%]')
bookcorpus_dataset = load_dataset('bookcorpus', split='train[:1%]')

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


In [11]:
wikipedia_dataset_filtered = filter_dataset(wikipedia_dataset, 10)
bookcorpus_dataset_filtered = filter_dataset(bookcorpus_dataset, 5)

In [12]:
print(len(wikipedia_dataset_filtered))
print(len(bookcorpus_dataset_filtered))

6458
37002


In [13]:
tokenizer=LlamaTokenizer.from_pretrained(
    "openlm-research/open_llama_3b_v2",
)
tokenizer.pad_token_id = 0

You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [14]:
tokenizer

LlamaTokenizer(name_or_path='openlm-research/open_llama_3b_v2', vocab_size=32000, model_max_length=2048, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '<unk>'}, clean_up_tokenization_spaces=False),  added_tokens_decoder={
	0: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
	1: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
	2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
}

In [15]:
def tokenize_function(examples):
    texts = [text for text in examples['text'] if text is not None]
    return tokenizer(
        texts,
        return_special_tokens_mask=True,
        truncation=True,
        max_length=2048,
        padding='max_length'
    )

In [16]:
tokenized_wikipedia_dataset = wikipedia_dataset_filtered.map(
    tokenize_function,
    batched=True,
    remove_columns=["text"]
)

In [17]:
tokenized_bookcorpus_dataset = bookcorpus_dataset_filtered.map(
    tokenize_function,
    batched=True,
    remove_columns=["text"]
)

In [18]:
combined_dataset = concatenate_datasets([tokenized_wikipedia_dataset, tokenized_bookcorpus_dataset])

In [19]:
len(combined_dataset)

43460

In [20]:
class CustomDataset(Dataset):
    def __init__(self, dataset):
        self.dataset = torch.tensor(dataset['input_ids'])

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        return item

In [21]:
train_dataset = CustomDataset(combined_dataset)

In [22]:
len(train_dataset)

43460

In [23]:
BATCH_SIZE = 8
NUM_EPOCHS = 4
STEP_LOG_PERIOD = 50

In [24]:
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
optimizer = AdamW(student_llama.parameters(), lr=3e-4)



In [25]:
len(train_loader)

5433

In [50]:
# student_llama.to('cpu')
# del student_llama
# teacher_llama.to('cpu')
# del teacher_llama
# del trainer
gc.collect()
torch.cuda.empty_cache()
for device in range(torch.cuda.device_count()):
    torch.cuda.set_device(device)
    torch.cuda.empty_cache()
gc.collect()

0

In [26]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [27]:
print('upd')

upd


In [28]:
student_llama.to(device)
teacher_llama.to(device)

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 3200, padding_idx=0)
    (layers): ModuleList(
      (0-25): 26 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=3200, out_features=3200, bias=False)
          (k_proj): Linear(in_features=3200, out_features=3200, bias=False)
          (v_proj): Linear(in_features=3200, out_features=3200, bias=False)
          (o_proj): Linear(in_features=3200, out_features=3200, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=3200, out_features=8640, bias=False)
          (up_proj): Linear(in_features=3200, out_features=8640, bias=False)
          (down_proj): Linear(in_features=8640, out_features=3200, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )


In [29]:
import time

step = 0
start_time = time.time()
for epoch in range(NUM_EPOCHS):
    student_llama.train()
    for inputs in train_loader:
        inputs = inputs.to(device)
        with torch.no_grad():
            teacher_outputs = teacher_llama(inputs).logits
        student_outputs = student_llama(inputs).logits

        loss = distillation_loss(student_outputs, teacher_outputs, temperature=2.0)
        if (step % STEP_LOG_PERIOD == 0):
            end_time = time.time()
            execution_time = end_time - start_time
            print(f'step: {step} / {len(train_loader)*NUM_EPOCHS}, loss: {loss}, execution time: {execution_time}')
            start_time = time.time()
        step += 1
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

step: 0 / 21732, loss: 2447.484375, execution time: 8.996416330337524
step: 50 / 21732, loss: 477.70770263671875, execution time: 594.5072367191315


KeyboardInterrupt: 

In [None]:
# Сохранение дообученной модели
student_llama.save_pretrained("./open_llama_3b_v2_distillation")
tokenizer.save_pretrained("./open_llama_3b_v2_distillation")

In [158]:
BATCH_SIZE = 8
PER_DEVICE_TRAIN_BATCH_SIZE = BATCH_SIZE 
NUM_EPOCHS = 3
OUTPUT_DIR = "./dist_output"

In [159]:
class ModelWrapper(BaseLM):
    def __init__(
        self,
        model,
        batch_size,
        tokenizer,
        device
    ):
        super().__init__()
        self.config = model.config
        self.model = model
        if torch.cuda.device_count() > 1:
            self.model = torch.nn.DataParallel(self.model)
        self.model.to(device)
        self.tokenizer = tokenizer
        self.batch_size_per_gpu = batch_size
        self.device_ = device

    @torch.inference_mode()
    def _model_call(self, inps):
        outputs = self.model(inps)
        if hasattr(outputs, 'logits'):
            return outputs.logits
        elif hasattr(outputs, 'last_hidden_state'):
            return outputs.last_hidden_state
        else:
            raise ValueError("Model output does not contain 'logits' or 'last_hidden_state'")

    @torch.inference_mode()
    def _model_generate(self, context, max_length, eos_token_id) -> torch.Tensor:
        # this only supports batch size 1
        assert context.shape[0] == 1
        out = generate(self.model, context[0], max_length, eos_id=eos_token_id)
        for block in self.model.transformer.h:
            block.attn.kv_cache.reset_parameters()
        return out.unsqueeze(0)

    @property
    def batch_size(self):
        return self.batch_size_per_gpu*torch.cuda.device_count()

    @property
    def device(self):
        return self.device_

    @property
    def eot_token_id(self):
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_id

    @property
    def max_gen_toks(self):
        return 256

    @property
    def max_length(self):
        return self.config.max_position_embeddings

    def tok_encode(self, string: str) -> List[int]:
        return self.tokenizer.encode(string)

    def tok_decode(self, tokens: List[int]) -> str:
        t = torch.tensor(tokens)
        return self.tokenizer.decode(t)
    
    def clear_gpu_memory(self):
        self.model.module.cpu()
        del self.model  
        gc.collect()
        torch.cuda.empty_cache()
        for device in range(torch.cuda.device_count()):
            torch.cuda.set_device(device)
            torch.cuda.empty_cache()
        gc.collect()

In [163]:
model_path = "./open_llama_3b_v2_distillation"
model = LlamaForCausalLM.from_pretrained(model_path)

tokenizer=LlamaTokenizer.from_pretrained(
    model_path,
)

In [164]:
eval_tasks: List[str] = ['winogrande','boolq','piqa']

num_fewshot = 0
limit = 256
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

wrapped_model = ModelWrapper(
    model=model,
    batch_size=BATCH_SIZE,
    tokenizer=tokenizer,
    device=device
)

In [None]:
results = evaluator.evaluate(
    lm=wrapped_model,
    task_dict=tasks.get_task_dict(eval_tasks),
    num_fewshot=num_fewshot
)

In [None]:
print(results)