In [None]:
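# Dependencies: uncomment to install. Prefer `%pip install ...` over `!pip` so the
# packages land in the kernel's own environment, and pin versions for reproducibility.
# The cells below also need: torch, transformers, accelerate (for device_map='auto')
# and bitsandbytes (for 4/8-bit loading).
#!pip install datasets
#!pip install peft
#!pip install sentencepiece
#!pip install huggingface_hub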

In [None]:
import argparse
import contextlib
import json
import math
import os
import warnings
from dataclasses import dataclass
from typing import List, Dict, Optional

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from datasets import load_dataset
from huggingface_hub import notebook_login
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    get_cosine_schedule_with_warmup,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training


In [None]:
def set_seed(seed: int):
    """Seed every random number generator the notebook relies on.

    Seeds, in order: Python's `random`, NumPy's legacy global RNG, the PyTorch
    CPU generator and the generators of all CUDA devices.

    Args:
        seed: the integer seed applied to every generator.
    """
    import random
    import numpy as np

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


@dataclass
class Collator:
    """Batch collator: tokenize raw text examples into padded causal-LM tensors.

    Fields:
        tokenizer: Hugging Face tokenizer, called with `padding=True`, so it must
            define a pad token (the loaders in this cell fall back to EOS).
        max_len: truncation length in tokens.
        text_field: key of the text column in each example dict.
    """
    tokenizer: "AutoTokenizer"  # string annotation: evaluated lazily, no import needed at class creation
    max_len: int
    text_field: str

    def __call__(self, batch: List[Dict[str, str]]) -> Dict[str, torch.Tensor]:
        """Return `input_ids`, `attention_mask` and `labels` tensors for `batch`.

        `labels` is a copy of `input_ids` with padded positions set to -100. The
        next-token shift is NOT applied here: whoever consumes `labels` must shift
        them (Hugging Face models do so internally when given `labels=`).
        """
        texts = [example[self.text_field] for example in batch]
        toks = self.tokenizer(
            texts,
            max_length=self.max_len,
            truncation=True,
            padding=True,
            return_tensors='pt',
        )
        input_ids = toks['input_ids']
        attention_mask = toks['attention_mask']
        labels = input_ids.clone()
        # Mask by attention_mask, not by pad_token_id. When pad_token == eos_token
        # (the fallback used by the loaders), comparing ids would also mask every
        # genuine EOS, and the model would get no signal for emitting EOS.
        labels[attention_mask == 0] = -100
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels,
        }


def load_teacher(teacher_name: str, bits: int):
    """Load a frozen teacher model and its tokenizer for distillation.

    Args:
        teacher_name: Hugging Face Hub id or local path of the teacher.
        bits: 16 for an unquantized load, 8 or 4 for a bitsandbytes-quantized load.

    Returns:
        (teacher, tokenizer): the model in eval mode with every parameter frozen,
        and its fast tokenizer, whose pad_token falls back to eos_token.

    Raises:
        ValueError: if `bits` is not 4, 8 or 16.
    """
    # Validate before downloading anything. Any other value would build a
    # BitsAndBytesConfig with neither 4-bit nor 8-bit loading enabled.
    if bits not in (4, 8, 16):
        raise ValueError(f'teacher bits must be 4, 8 or 16, got {bits!r}')

    # bf16 on GPU, fp32 on CPU (so bits=16 actually means fp32 without CUDA).
    compute_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
    if bits == 16:
        bnb = None
    else:
        bnb = BitsAndBytesConfig(
            load_in_8bit=(bits == 8),
            load_in_4bit=(bits == 4),
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4',
        )
    teacher = AutoModelForCausalLM.from_pretrained(
        teacher_name,
        device_map='auto',
        torch_dtype=compute_dtype,
        quantization_config=bnb,
    )
    # The teacher only provides targets: no dropout, no gradients.
    teacher.eval()
    for p in teacher.parameters():
        p.requires_grad_(False)
    t_tok = AutoTokenizer.from_pretrained(teacher_name, use_fast=True)
    if t_tok.pad_token is None:
        t_tok.pad_token = t_tok.eos_token
    return teacher, t_tok


def load_student(student_name: str, bits: int, lora_cfg: Optional["LoraConfig"]):
    """Load the trainable student model (optionally LoRA-wrapped) and its tokenizer.

    Args:
        student_name: Hugging Face Hub id or local path of the student.
        bits: 16 for an unquantized load, 8 or 4 for a bitsandbytes-quantized
            (k-bit training) load.
        lora_cfg: PEFT LoRA config to wrap the model with, or None to skip LoRA.

    Returns:
        (model, tokenizer): the possibly PEFT-wrapped model with gradient
        checkpointing enabled, and its fast tokenizer, whose pad_token falls back
        to eos_token.

    Raises:
        ValueError: if `bits` is not 4, 8 or 16.
    """
    # Validate before downloading anything. Any other value would build a
    # BitsAndBytesConfig with neither 4-bit nor 8-bit loading enabled.
    if bits not in (4, 8, 16):
        raise ValueError(f'student bits must be 4, 8 or 16, got {bits!r}')

    # bf16 on GPU, fp32 on CPU (so bits=16 actually means fp32 without CUDA).
    compute_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
    if bits == 16:
        quant_cfg = None
    else:
        quant_cfg = BitsAndBytesConfig(
            load_in_8bit=(bits == 8),
            load_in_4bit=(bits == 4),
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4',
        )

    model = AutoModelForCausalLM.from_pretrained(
        student_name,
        device_map='auto',
        torch_dtype=compute_dtype,
        quantization_config=quant_cfg,
    )

    tok = AutoTokenizer.from_pretrained(student_name, use_fast=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    if bits in (4, 8):
        # PEFT's preparation step for training on a quantized base.
        model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
    if lora_cfg is not None:
        model = get_peft_model(model, lora_cfg)
        model.print_trainable_parameters()
    # With frozen embeddings, inputs must require grad so that gradients flow
    # through checkpointed blocks into the adapter weights.
    if hasattr(model, 'enable_input_require_grads'):
        model.enable_input_require_grads()
    # Checkpointing is always enabled here (the k-bit path above may already have
    # turned it on).
    if hasattr(model, 'gradient_checkpointing_enable'):
        model.gradient_checkpointing_enable()
    return model, tok

def kd_loss(student_logits, teacher_logits, labels, temperature=2.0, alpha_kl=0.9, alpha_ce=0.1):
    """Distillation loss: alpha_kl * T^2 * KL(teacher_T || student_T) + alpha_ce * CE.

    Args:
        student_logits, teacher_logits: [B, T, V] logits. Both must share the same
            vocabulary (same V and token ids), since they are compared position by
            position.
        labels: [B, T] token ids, with -100 at positions to ignore (e.g. padding).
        temperature: softening temperature T applied to both distributions.
        alpha_kl, alpha_ce: weights of the KL and cross-entropy terms. The CE term
            is skipped (left at 0) when alpha_ce <= 0.

    Returns:
        (loss, parts): the scalar mixed loss, and a dict with the detached 'kl' and
        'ce' terms.

    Raises:
        ValueError: if the student and teacher vocabulary sizes differ.
    """
    T = temperature
    # Align sequence lengths in case the three inputs differ.
    min_len = min(student_logits.size(1), teacher_logits.size(1), labels.size(1))
    s = student_logits[:, :min_len, :]
    t = teacher_logits[:, :min_len, :]
    y = labels[:, :min_len]
    if s.size(-1) != t.size(-1):
        raise ValueError(
            f'student vocab size {s.size(-1)} != teacher vocab size {t.size(-1)}; '
            'logit-level KD needs a shared vocabulary'
        )

    # KL term: in a causal LM, the logits at position i predict token i+1 in both
    # models, so the distributions are compared position by position (no shift).
    s_log_probs = F.log_softmax(s / T, dim=-1)
    t_probs = F.softmax(t / T, dim=-1)
    kl = F.kl_div(s_log_probs, t_probs, reduction='none')  # [B, T, V]
    valid = (y != -100).unsqueeze(-1).type_as(kl)  # [B, T, 1]; 0 at ignored positions
    kl = (kl * valid).sum(-1)  # [B, T]
    kl = kl.sum() / valid.sum().clamp(min=1.0)
    # Conventional T^2 factor (Hinton et al., 2015) for softened targets.
    kl = (T * T) * kl

    ce = torch.tensor(0.0, device=s.device)
    if alpha_ce > 0 and min_len > 1:
        # Next-token CE: the logits at position i are scored against the label at i+1.
        shift_logits = s[:, :-1, :].reshape(-1, s.size(-1))
        shift_labels = y[:, 1:].reshape(-1)
        # Sum, then divide by the count clamped to 1, so that a batch with no valid
        # targets gives 0 instead of NaN.
        n_targets = (shift_labels != -100).sum().clamp(min=1)
        ce = F.cross_entropy(shift_logits, shift_labels, ignore_index=-100, reduction='sum') / n_targets

    return alpha_kl * kl + alpha_ce * ce, {'kl': kl.detach(), 'ce': ce.detach()}


In [None]:
# Authenticate with the Hugging Face Hub through the interactive notebook widget.
# Likely required for gated repositories (e.g. the commented-out Llama-3.1 teacher).
notebook_login()

In [None]:
# ---- Configuration: every tunable knob of the distillation run ----

# Models: the teacher is frozen; the student gets LoRA adapters.
# NOTE: the training loop feeds the *student's* token ids to the teacher, so both
# models should share a vocabulary.
#__teacher_model = "meta-llama/Llama-3.1-8B-Instruct"
__teacher_model = "mistralai/Mistral-7B-Instruct-v0.2"
__student_model = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# Data: __data_files is a JSON object string, parsed with json.loads in the dataset cell.
__dataset_name = "json"
__data_files = '{"train": "train.jsonl", "validation": "val.jsonl"}'
__text_field = "text"
__output_dir = "./kd_lora_tinymistral"
__max_seq_len = 1024

# Optimisation
__per_device_train_batch_size = 1
__per_device_eval_batch_size = 1
__gradient_accumulation_steps = 16
__learning_rate = 2e-4
__weight_decay = 0.0
__num_train_epochs = 1
__max_steps = -1  # > 0 overrides __num_train_epochs (counted in optimizer steps)
__warmup_ratio = 0.03
__seed = 42

# Distillation loss: alpha_kl * T^2 * KL + alpha_ce * CE
__temperature = 2.0
__alpha_kl = 0.9
__alpha_ce = 0.1

# LoRA
__lora_r = 16
__lora_alpha = 32
__lora_dropout = 0.05
__target_modules = ['q_proj', 'k_proj', 'v_proj', 'o_proj']

# Quantization (bits: 4, 8 or 16)
__teacher_bits = 8
__student_bits = 4

# Mixed precision: pick ONE. Autocast uses bf16 when __bf16 is True, otherwise fp16
# when __fp16 is True. fp16 autocast needs a GradScaler; bf16 does not, so the two
# flags are not both enabled.
__bf16 = True
__fp16 = False
# NOTE(review): not read anywhere in this notebook; load_student always enables checkpointing.
__gradient_checkpointing = True

# Cadence, in optimizer steps. __eval_steps / __save_steps = 0 disables them;
# __logging_steps must be > 0 (it is used as a modulus).
__eval_steps = 0
__save_steps = 500
__logging_steps = 50


In [None]:
# Reproducibility and output location
set_seed(__seed)
os.makedirs(__output_dir, exist_ok=True)

# LoRA config: adapters on the modules listed in __target_modules
lora_cfg = LoraConfig(
    r=__lora_r,
    lora_alpha=__lora_alpha,
    lora_dropout=__lora_dropout,
    target_modules=__target_modules,
    bias='none',
    task_type='CAUSAL_LM',
)

# Load teacher & student
print('Loading teacher...')
teacher, teacher_tok = load_teacher(__teacher_model, __teacher_bits)
print('Loading student...')
student, tok = load_student(__student_model, __student_bits, lora_cfg)

# The training loop tokenizes with the student tokenizer and feeds those ids to the
# teacher, then compares logits id by id. That is only meaningful if both tokenizers
# map text to the same ids, so warn loudly instead of silently distilling noise.
if teacher_tok.get_vocab() != tok.get_vocab():
    warnings.warn(
        f'Teacher ({__teacher_model}) and student ({__student_model}) tokenizers have '
        'different vocabularies; logit-level KD on shared input_ids will be misaligned.'
    )


In [None]:
# Dataset. __data_files is a JSON object string mapping split name -> file(s).
# The 'json' builder is always given explicit files, so require them up front.
if __dataset_name == 'json' and not __data_files:
    raise ValueError('--data_files is required when dataset_name=json')

if not __data_files:
    ds = load_dataset(__dataset_name)
elif __dataset_name == 'json':
    data_files = json.loads(__data_files)
    ds = load_dataset('json', data_files=data_files)
else:
    ds = load_dataset(__dataset_name, data_files=json.loads(__data_files))


In [None]:
# Data loaders. The collator tokenizes with the *student* tokenizer; the training
# loop feeds the same ids to the teacher.
collate = Collator(tokenizer=tok, max_len=__max_seq_len, text_field=__text_field)
train_loader = DataLoader(
    ds['train'],
    batch_size=__per_device_train_batch_size,
    shuffle=True,
    collate_fn=collate,
)

# Evaluation loader exists only when there is a 'validation' split and __eval_steps != 0.
eval_loader = None
if 'validation' in ds and __eval_steps != 0:
    eval_loader = DataLoader(
        ds['validation'],
        # Read from the config cell: this notebook has no argparse `args` namespace.
        batch_size=__per_device_eval_batch_size,
        shuffle=False,
        collate_fn=collate,
    )

# NOTE(review): both models are loaded with device_map='auto'; batches are moved to
# this device, which assumes the models' input layers live there too.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
student.train()


In [None]:
# Optimizer & LR schedule. Only trainable (e.g. LoRA) parameters go to AdamW.
trainable_params = [p for p in student.parameters() if p.requires_grad]
optim = torch.optim.AdamW(trainable_params, lr=__learning_rate, weight_decay=__weight_decay)
# total_steps counts optimizer steps (one per __gradient_accumulation_steps micro-batches).
total_steps = __max_steps if __max_steps > 0 else int(len(train_loader) * __num_train_epochs // __gradient_accumulation_steps)
warmup_steps = int(total_steps * __warmup_ratio)
sched = get_cosine_schedule_with_warmup(optim, num_warmup_steps=warmup_steps, num_training_steps=total_steps)

# Mixed precision: bf16 wins when both flags are set. Loss scaling is only needed for
# fp16 autocast, so the GradScaler is tied to the dtype actually used.
autocast_dtype = torch.bfloat16 if __bf16 else (torch.float16 if __fp16 else None)
autocast_device = 'cuda' if torch.cuda.is_available() else 'cpu'
use_scaler = autocast_dtype == torch.float16
scaler = torch.cuda.amp.GradScaler(enabled=use_scaler)

# `evaluate` is defined in a later cell. Fail now rather than deep into training.
if __eval_steps and eval_loader is not None and 'evaluate' not in globals():
    raise NameError('evaluate() is not defined yet: run the cell that defines it before training with __eval_steps > 0')

# global_step counts micro-batches; optimizer steps = global_step // __gradient_accumulation_steps.
global_step = 0
running = {'loss': 0.0, 'kl': 0.0, 'ce': 0.0}
# Drop gradients left behind by an interrupted earlier run of this cell.
student.zero_grad(set_to_none=True)

stop_training = False
# With __max_steps > 0, loop over epochs until the step budget is reached.
n_epochs = __num_train_epochs if __max_steps <= 0 else 10_000_000
for epoch in range(n_epochs):
    for batch in tqdm(train_loader, desc=f'Epoch {epoch+1}'):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        # The teacher is frozen, so no autograd graph is needed.
        # NOTE: it receives the *student's* token ids.
        with torch.no_grad():
            t_logits = teacher(input_ids=input_ids, attention_mask=attention_mask).logits

        amp_ctx = (
            torch.autocast(device_type=autocast_device, dtype=autocast_dtype)
            if autocast_dtype is not None
            else contextlib.nullcontext()
        )
        with amp_ctx:
            s_logits = student(input_ids=input_ids, attention_mask=attention_mask).logits
            loss, parts = kd_loss(s_logits, t_logits, labels, __temperature, __alpha_kl, __alpha_ce)

        # Scale so the accumulated gradient is the mean over the accumulation window.
        loss = loss / __gradient_accumulation_steps
        if use_scaler:
            scaler.scale(loss).backward()
        else:
            loss.backward()

        running['loss'] += loss.item()
        running['kl'] += parts['kl'].item() / __gradient_accumulation_steps
        running['ce'] += parts['ce'].item() / __gradient_accumulation_steps

        if (global_step + 1) % __gradient_accumulation_steps == 0:
            if use_scaler:
                scaler.step(optim)
                scaler.update()
            else:
                optim.step()
            sched.step()
            student.zero_grad(set_to_none=True)

            # 1-based index of the optimizer step just taken; drives logging, saving and eval.
            opt_step = (global_step + 1) // __gradient_accumulation_steps

            if opt_step % __logging_steps == 0:
                avg_loss = running['loss'] / __logging_steps
                avg_kl = running['kl'] / __logging_steps
                avg_ce = running['ce'] / __logging_steps
                print(f"step {opt_step}: loss {avg_loss:.4f} | kl {avg_kl:.4f} | ce {avg_ce:.4f}")
                running = {'loss': 0.0, 'kl': 0.0, 'ce': 0.0}

            if __save_steps and opt_step % __save_steps == 0:
                save_dir = os.path.join(__output_dir, f'step_{opt_step}')
                os.makedirs(save_dir, exist_ok=True)
                student.save_pretrained(save_dir)

            if __eval_steps and eval_loader is not None and opt_step % __eval_steps == 0:
                evaluate(student, eval_loader, device, autocast_dtype)

        global_step += 1
        stop_training = __max_steps > 0 and global_step // __gradient_accumulation_steps >= __max_steps
        if stop_training:
            break
    if stop_training:
        break

# Micro-batches after the last complete accumulation window never trigger an optimizer step.
print('Saving final adapter...')
student.save_pretrained(__output_dir)
print('Done.')


In [None]:
# Display the directory the final adapter was saved to (the bare expression renders as cell output).
__output_dir

In [None]:
# Upload to S3
!aws s3 cp --recursive ./kd_lora_tinymistral s3://data-daizika-com/incar_assist/model/kd_lora_tinymistral/


In [None]:
def evaluate(model, eval_loader, device, autocast_dtype):
    """Print (and return) the token-level perplexity of `model` on `eval_loader`.

    Args:
        model: causal LM called as `model(input_ids=..., attention_mask=..., labels=...)`,
            expected to return an object with a scalar `.loss`.
        eval_loader: iterable of dicts with 'input_ids', 'attention_mask' and 'labels'
            tensors; label -100 marks positions to ignore.
        device: device the batch tensors are moved to.
        autocast_dtype: dtype for `torch.autocast`, or None to run without autocast.

    Returns:
        exp(total NLL / scored tokens), or float('inf') when no token was scored or
        the exponent overflows.

    The model runs in eval mode and is restored to its previous train/eval mode
    afterwards, even if an exception is raised.
    """
    was_training = model.training
    model.eval()
    device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
    total_nll = 0.0
    total_tokens = 0
    try:
        with torch.no_grad():
            for batch in tqdm(eval_loader, desc='Eval'):
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)
                # Targets actually scored, assuming the model shifts labels by one
                # internally (as Hugging Face causal LMs do). A batch with none would
                # produce a NaN mean loss, so skip it.
                n_tokens = int((labels[:, 1:] != -100).sum())
                if n_tokens == 0:
                    continue
                if autocast_dtype is not None:
                    with torch.autocast(device_type=device_type, dtype=autocast_dtype):
                        out = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
                else:
                    out = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
                # `.loss` is presumably a mean over scored tokens. Re-weight by the token
                # count so the result is per-token, not a mean of per-batch means.
                total_nll += out.loss.item() * n_tokens
                total_tokens += n_tokens
    finally:
        model.train(was_training)

    if total_tokens == 0:
        ppl = float('inf')
    else:
        try:
            ppl = math.exp(total_nll / total_tokens)
        except OverflowError:
            ppl = float('inf')
    print(f"Eval perplexity: {ppl:.2f}")
    return ppl


## Inference

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
import torch

# Inference: load the base student plus the LoRA adapter trained above.
# Model names come from the config cell. Re-declaring them here would silently
# override the config, and could attach the adapter to a different base model than
# the one it was trained on.
MODEL_DIR = __output_dir

# 1) Load the *base student tokenizer* (NOT from adapter dir)
tok = AutoTokenizer.from_pretrained(__student_model, use_fast=True)
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

# (Optional) 8-bit load to save memory for inference.
# NOTE(review): training loaded the student with __student_bits bits (4-bit NF4 by
# default), so here the adapter runs on a differently quantized base. Confirm the
# quality difference is acceptable.
bnb = BitsAndBytesConfig(load_in_8bit=True)

# 2) Load the *base student model*
model = AutoModelForCausalLM.from_pretrained(
    __student_model,
    device_map="auto",
    quantization_config=bnb,                 # or remove if you want full-precision
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
)

# 3) Apply the LoRA adapter saved in MODEL_DIR and switch to inference mode
model = PeftModel.from_pretrained(model, MODEL_DIR)
model.eval()

def infer(history: str, max_new=256):
    """Sample a continuation of `history` from the module-level `model`.

    Tokenizes with the module-level `tok` and samples with temperature 0.7 and
    top_p 0.9, for up to `max_new` new tokens. Returns the decoded text of the whole
    first sequence (prompt AND completion), with special tokens stripped.

    NOTE(review): if the tokenizer prepends its own BOS (Llama-family tokenizers
    usually do), a prompt that already starts with a literal "<s>" may end up with
    two BOS tokens. Confirm this matches how the training texts were tokenized.
    """
    encoded = tok(history, return_tensors="pt").to(model.device)
    sampling = dict(
        max_new_tokens=max_new,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        pad_token_id=tok.eos_token_id,
    )
    with torch.no_grad():
        generated = model.generate(**encoded, **sampling)
    first_sequence = generated[0]
    return tok.decode(first_sequence, skip_special_tokens=True)

# Smoke test of the adapted student.
# NOTE(review): "[INST] ... [/INST]" is the Mistral-Instruct prompt format (the
# teacher's); the TinyLlama-Chat base presumably ships a different chat template.
# Keep this format only if the training texts used it; otherwise consider
# tok.apply_chat_template.
print(infer("<s>[INST] User: Find the nearest Toyota service center in Chicago. [/INST]"))


In [None]:
# Display the adapter directory (the bare expression renders as cell output).
__output_dir