In [1]:
# Install necessary libraries
%%capture
!pip install unsloth
# Also get the latest nightly Unsloth!
!pip uninstall unsloth -y && pip install --upgrade --no-cache-dir "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
!pip install evaluate  # Install the evaluate library for metrics
!pip install rouge_score

# Import libraries
import torch
from unsloth import FastLanguageModel
import math
from datasets import load_dataset
from unsloth.chat_templates import get_chat_template
from transformers import TrainingArguments, DataCollatorForSeq2Seq, Trainer
from unsloth import is_bfloat16_supported
from torch.utils.data import DataLoader
from tqdm import tqdm
import evaluate

# Initialize metrics
bleu_metric = evaluate.load("bleu")
rouge_metric = evaluate.load("rouge")

In [2]:
max_seq_length = 256  # Choose any! We auto support RoPE Scaling internally!
dtype = None  # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True  # Use 4bit quantization to reduce memory usage. Can be False.

# Load the student model and tokenizer (Qwen 2.5 0.5B)
student_model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/Qwen2.5-0.5B-bnb-4bit",
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
)

==((====))==  Unsloth 2024.10.7: Fast Qwen2 patching. Transformers = 4.44.2.
   \\   /|    GPU: Tesla T4. Max memory: 14.748 GB. Platform = Linux.
O^O/ \_/ \    Pytorch: 2.5.0+cu121. CUDA = 7.5. CUDA Toolkit = 12.1.
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.28.post2. FA2 = False]
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


In [3]:
# Prepare the student model for PEFT
student_model = FastLanguageModel.get_peft_model(
    student_model,
    r=16,  # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha=16,
    lora_dropout=0,  # Supports any, but = 0 is optimized
    bias="none",     # Supports any, but = "none" is optimized
    # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
    use_gradient_checkpointing="unsloth",  # True or "unsloth" for very long context
    random_state=3407,
    use_rslora=False,   # We support rank stabilized LoRA
    loftq_config=None,  # And LoftQ
)

# Load the teacher model (Qwen 2.5 1.5B)
teacher_model, _ = FastLanguageModel.from_pretrained(
    model_name="unsloth/Qwen2.5-1.5B-bnb-4bit",
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
)
teacher_model.eval()
# teacher_model.to("cuda")  # Ensure the teacher model is on the GPU

==((====))==  Unsloth 2024.10.7: Fast Qwen2 patching. Transformers = 4.44.2.
   \\   /|    GPU: Tesla T4. Max memory: 14.748 GB. Platform = Linux.
O^O/ \_/ \    Pytorch: 2.5.0+cu121. CUDA = 7.5. CUDA Toolkit = 12.1.
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.28.post2. FA2 = False]
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 1536)
    (layers): ModuleList(
      (0-27): 28 x Qwen2DecoderLayer(
        (self_attn): Qwen2Attention(
          (q_proj): Linear4bit(in_features=1536, out_features=1536, bias=True)
          (k_proj): Linear4bit(in_features=1536, out_features=256, bias=True)
          (v_proj): Linear4bit(in_features=1536, out_features=256, bias=True)
          (o_proj): Linear4bit(in_features=1536, out_features=1536, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear4bit(in_features=1536, out_features=8960, bias=False)
          (up_proj): Linear4bit(in_features=1536, out_features=8960, bias=False)
          (down_proj): Linear4bit(in_features=8960, out_features=1536, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)
    

In [4]:
# Set the chat template
tokenizer = get_chat_template(
    tokenizer,
    chat_template="qwen-2.5",
)

# Function to format prompts
def formatting_prompts_func(examples):
    instructions = examples["instruction"]
    inputs = examples["input"]
    outputs = examples["output"]

    # Combine instruction and input for the prompt
    convos = []
    for instruction, input_text in zip(instructions, inputs):
        if input_text.strip() != '':
            content = f"Instruction: {instruction}\nInput: {input_text}"
        else:
            content = f"Instruction: {instruction}"
        convos.append({"role": "user", "content": content})

    responses = [{"role": "assistant", "content": output} for output in outputs]

    # Combine conversations and apply the chat template
    conversations = [{"conversations": [convo, response]} for convo, response in zip(convos, responses)]

    texts = [tokenizer.apply_chat_template(convo["conversations"], tokenize=False, add_generation_prompt=False)
             for convo in conversations]

    target_texts = outputs  # Keep the outputs as target texts

    return {"text": texts, "target_text": target_texts}

# Load and split the dataset
dataset = load_dataset("tatsu-lab/alpaca", split="train")

In [5]:
# Split dataset into training and validation sets
dataset = dataset.train_test_split(test_size=0.1, seed=42)
train_dataset = dataset['train']
eval_dataset = dataset['test']

# Apply formatting to both datasets
train_dataset = train_dataset.map(formatting_prompts_func, batched=True)
eval_dataset = eval_dataset.map(formatting_prompts_func, batched=True)

# Tokenize the datasets and prepare labels
def tokenize_function(examples):
    tokenized = tokenizer(examples["text"], truncation=True, max_length=max_seq_length, padding="max_length")
    labels = tokenized['input_ids'].copy()
    # Set padding tokens to -100 to ignore them in the loss
    labels = [[(label if label != tokenizer.pad_token_id else -100) for label in label_list] for label_list in labels]
    tokenized['labels'] = labels
    tokenized['target_text'] = examples['target_text']
    return tokenized

train_dataset = train_dataset.map(tokenize_function, batched=True)
eval_dataset = eval_dataset.map(tokenize_function, batched=True)

# Set the format to PyTorch tensors
train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels', 'target_text'])
eval_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels', 'target_text'])

Map:   0%|          | 0/46801 [00:00<?, ? examples/s]

Map:   0%|          | 0/5201 [00:00<?, ? examples/s]

In [6]:
# Define the Knowledge Distillation Trainer
class KDTrainer(Trainer):
    def __init__(self, teacher_model, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher_model = teacher_model
        self.teacher_model.eval()
        # No need to move teacher_model to 'cuda' manually

    def compute_loss(self, model, inputs, return_outputs=False):
        # Move inputs to the student's device
        input_ids = inputs['input_ids'].to(model.device)
        attention_mask = inputs['attention_mask'].to(model.device)
        labels = inputs['labels'].to(model.device)

        # Get student outputs
        student_outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        student_logits = student_outputs.logits  # shape: [batch_size, seq_length, vocab_size]

        with torch.no_grad():
            # Get the device of the teacher model
            teacher_device = next(self.teacher_model.parameters()).device
            # Move inputs to the teacher's device
            input_ids_teacher = inputs['input_ids'].to(teacher_device)
            attention_mask_teacher = inputs['attention_mask'].to(teacher_device)

            # Get teacher outputs
            teacher_outputs = self.teacher_model(input_ids=input_ids_teacher, attention_mask=attention_mask_teacher)
            teacher_logits = teacher_outputs.logits  # shape: [batch_size, seq_length, vocab_size]

        # Set temperature and alpha
        temperature = 2.0
        alpha = 0.5

        # Compute student cross-entropy loss
        loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100)
        student_ce_loss = loss_fct(student_logits.view(-1, student_logits.size(-1)), labels.view(-1))

        # Compute KL divergence loss at each token position
        student_log_probs = torch.nn.functional.log_softmax(student_logits / temperature, dim=-1)
        teacher_probs = torch.nn.functional.softmax(teacher_logits / temperature, dim=-1)
        kl_loss_fct = torch.nn.KLDivLoss(reduction='none')
        kl_div_loss = kl_loss_fct(student_log_probs, teacher_probs)  # shape: [batch_size, seq_length, vocab_size]
        # Sum over vocab dimension
        kl_div_loss = kl_div_loss.sum(-1)  # shape: [batch_size, seq_length]
        # Mask out padding tokens
        mask = (labels != -100).float()
        kl_div_loss = (kl_div_loss * mask).sum() / mask.sum()
        # Adjust for temperature scaling
        kl_div_loss = kl_div_loss * (temperature ** 2)

        # Combine losses
        loss = alpha * kl_div_loss + (1 - alpha) * student_ce_loss

        return (loss, student_outputs) if return_outputs else loss

    def _move_model_to_device(self, model, device):
        # Override to prevent moving model to device
        pass

In [7]:
# Set training arguments
training_args = TrainingArguments(
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,
    warmup_steps=5,
    max_steps=60,
    learning_rate=2e-4,
    fp16=True,  # Set to True for T4 GPU
    bf16=False,  # T4 does not support bfloat16
    logging_steps=1,
    optim="adamw_8bit",
    weight_decay=0.01,
    lr_scheduler_type="linear",
    seed=3407,
    output_dir="outputs",
    report_to="none",  # Use this for WandB etc
    evaluation_strategy="steps",
    eval_steps=10,
)

data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True, return_tensors='pt')

# Initialize the KDTrainer
trainer = KDTrainer(
    model=student_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    tokenizer=tokenizer,
    teacher_model=teacher_model,
)



In [8]:
# Start training
trainer.train()

Step,Training Loss,Validation Loss
10,6.1011,6.157755
20,3.7118,4.461288
30,3.1028,3.602319
40,3.3024,3.196372
50,2.8792,2.963253
60,2.7397,2.866669


TrainOutput(global_step=60, training_loss=4.302190721035004, metrics={'train_runtime': 10407.3018, 'train_samples_per_second': 0.006, 'train_steps_per_second': 0.006, 'total_flos': 46340902748160.0, 'train_loss': 4.302190721035004, 'epoch': 0.0012820238883784535})

In [9]:
# Evaluate the model to get perplexity
eval_results = trainer.evaluate()
perplexity = math.exp(eval_results["eval_loss"])
print(f"Perplexity: {perplexity}")

# Prepare the student model for inference
FastLanguageModel.for_inference(student_model)  # Enable native 2x faster inference

Perplexity: 17.578365968066567


PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): Qwen2ForCausalLM(
      (model): Qwen2Model(
        (embed_tokens): Embedding(151936, 896)
        (layers): ModuleList(
          (0-23): 24 x Qwen2DecoderLayer(
            (self_attn): Qwen2Attention(
              (q_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=896, out_features=896, bias=True)
                (lora_dropout): ModuleDict(
                  (default): Identity()
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=896, out_features=16, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=16, out_features=896, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (k_proj): lora.Linear4bit(
            

In [None]:
# Function to generate predictions and compute metrics
def evaluate_model(model, tokenizer, eval_dataset, batch_size=8):
    eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size)
    predictions = []
    references = []

    model.eval()
    with torch.no_grad():
        for batch in tqdm(eval_dataloader):
            input_ids = batch['input_ids'].to(model.device)
            attention_mask = batch['attention_mask'].to(model.device)

            # Generate outputs from the model
            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=256,
                temperature=1.0,
                min_p=0.1,
            )

            # Get the generated tokens (excluding the prompt)
            generated_tokens = outputs[:, input_ids.size(1):]
            decoded_outputs = tokenizer.batch_decode(
                generated_tokens, skip_special_tokens=True
            )
            predictions.extend(decoded_outputs)

            # For evaluation, we need the reference outputs
            references.extend(batch['target_text'])

    # Compute BLEU and ROUGE scores
    bleu_score = bleu_metric.compute(predictions=predictions, references=references)
    rouge_score = rouge_metric.compute(predictions=predictions, references=references)

    print(f"BLEU score: {bleu_score['bleu']}")
    print(f"ROUGE scores:")
    for key, value in rouge_score.items():
        print(f"{key}: {value}")

# Evaluate the model
evaluate_model(student_model, tokenizer, eval_dataset)

  6%|▌         | 38/651 [08:41<2:20:42, 13.77s/it]