In [1]:
import torch
import torch.nn as nn
from torch.utils.data import IterableDataset, Dataset
import torch.nn.functional as F
import json
import gc
import os
from typing import List, Tuple
from transformers import AutoModelForCausalLM, AutoTokenizer, DefaultDataCollator, DataCollatorForSeq2Seq
from transformers import Trainer, TrainingArguments
from transformers.modeling_outputs import CausalLMOutputWithPast
from torch.nn.utils.rnn import pad_sequence

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Run a garbage-collection pass first, then ask PyTorch to release the CUDA
# memory held by its caching allocator.
gc.collect()
torch.cuda.empty_cache()

In [3]:
def env_path(var_name, default):
    """Return the path stored in environment variable ``var_name``.

    Falls back to ``default`` when the variable is unset or empty, so the
    notebook keeps working unchanged on the original machine while other
    machines can redirect every location without editing the code.
    """
    return os.environ.get(var_name) or default


# Training data (JSON list of {"prompt": ..., "answer": ...} records).
DATA_PATH = env_path(
    'KD_DATA_PATH',
    '/Users/yingyao/Desktop/Code/GetHandsDirty.nosync/knowledge_distillation/cross_tokenizer/example.json',
)
# Student checkpoint directory.
STD_MODEL_PATH = env_path(
    'KD_STUDENT_MODEL_PATH',
    '/Users/yingyao/Desktop/Code/GetHandsDirty.nosync/gz-data/Qwen2.5-0.5B-Instruct',
)
# Teacher checkpoint directory.
TCH_MODEL_PATH = env_path(
    'KD_TEACHER_MODEL_PATH',
    '/Users/yingyao/Desktop/Code/GetHandsDirty.nosync/gz-data/Qwen2.5-1.5B-Instruct',
)  # GLM-4-9B-0414
# Where the Trainer writes checkpoints and logs.
OUTPUT_DIR = env_path(
    'KD_OUTPUT_DIR',
    '/Users/yingyao/Desktop/Code/GetHandsDirty.nosync/knowledge_distillation/result',
)

In [4]:
class SFTDataset(Dataset):
    """Map-style dataset of prompt/answer pairs stored in a JSON file.

    The file is read once, in full, at construction time. Each record must
    provide a ``'prompt'`` and an ``'answer'`` string.

    Args:
        data_path: Path of the UTF-8 JSON file holding the list of records.
        tokenizer: Object exposing ``encode(text, add_special_tokens=False)``
            that returns a list of token IDs.
    """

    def __init__(self, data_path, tokenizer):
        super().__init__()
        self.data_path = data_path
        self.tokenizer = tokenizer
        with open(self.data_path, 'r', encoding='utf-8') as handle:
            records = json.load(handle)
        self.data = records

    def __len__(self):
        """Number of records loaded from the JSON file."""
        return len(self.data)

    def _encode(self, text):
        """Tokenize ``text`` without letting the tokenizer add special tokens."""
        return self.tokenizer.encode(text, add_special_tokens=False)

    def __getitem__(self, idx):
        """Return the tokenized record at ``idx``.

        Returns:
            dict with ``'input_ids'`` (prompt IDs followed directly by answer
            IDs) and ``'labels'`` (the answer IDs only, so it is shorter than
            ``'input_ids'`` by the prompt length).
        """
        record = self.data[idx]
        prompt_text, answer_text = record['prompt'], record['answer']

        prompt_ids = self._encode(prompt_text)
        answer_ids = self._encode(answer_text)

        return {
            'input_ids': prompt_ids + answer_ids,
            'labels': answer_ids,
        }

In [None]:
# Load the student and teacher causal LMs; local_files_only=True restricts
# loading to files already on disk.
std_model = AutoModelForCausalLM.from_pretrained(STD_MODEL_PATH, local_files_only=True)
tch_model = AutoModelForCausalLM.from_pretrained(TCH_MODEL_PATH, local_files_only=True)
# Load the fast tokenizer that belongs to each model.
std_tokenizer = AutoTokenizer.from_pretrained(STD_MODEL_PATH, use_fast=True, fix_mistral_regex=True)
tch_tokenizer = AutoTokenizer.from_pretrained(TCH_MODEL_PATH, use_fast=True, fix_mistral_regex=True)

In [6]:
# Tokenize the training file with the student tokenizer and display the first
# example; 'labels' holds only the answer tokens.
dataset = SFTDataset(DATA_PATH, std_tokenizer)
dataset[0]

{'input_ids': [105043, 100165, 11319, 35946, 101909, 15469, 110498, 1773],
 'labels': [35946, 101909, 15469, 110498, 1773]}

In [7]:
# Sanity check: decode the first example back to text (prompt immediately
# followed by the answer).
std_tokenizer.decode(dataset[0]['input_ids'])

'你是谁？我是一个AI助手。'

In [8]:
# Batch collator built on the student tokenizer; padding=True enables padding
# when examples in a batch differ in length.
data_collator = DataCollatorForSeq2Seq(tokenizer=std_tokenizer, padding=True)

In [9]:
# Collate the first two examples to inspect the padded batch layout.
data_collator(features = [dataset[0], dataset[1]])

{'input_ids': tensor([[105043, 100165,  11319,  35946, 101909,  15469, 110498,   1773, 151643,
         151643],
        [ 56568,  99882,  99245, 101419,  11319,  35946,  99882,  30709,  99473,
           1773]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), 'labels': tensor([[ 35946, 101909,  15469, 110498,   1773],
        [ 35946,  99882,  30709,  99473,   1773]])}

In [10]:
# Fetch both token-string -> ID vocabularies and compare their sizes.
std_vocab = std_tokenizer.get_vocab()
tch_vocab = tch_tokenizer.get_vocab()
len(std_vocab), len(tch_vocab)

(151665, 151665)

In [11]:
def init_vocab_mapping(std_tokenizer, tch_tokenizer):
    """Match teacher and student vocabulary entries by their token string.

    Args:
        std_tokenizer: Student tokenizer exposing ``get_vocab()``.
        tch_tokenizer: Teacher tokenizer exposing ``get_vocab()``.

    Returns:
        vocab_mapping: dict mapping teacher token IDs to student token IDs
        teacher_matched_ids: set of matched teacher token IDs
        student_matched_ids: set of matched student token IDs
    """
    student_id_by_token = dict(std_tokenizer.get_vocab().items())
    teacher_id_by_token = tch_tokenizer.get_vocab()

    # (teacher_id, student_id) for every token string present in both vocabularies.
    shared_pairs = [
        (teacher_id, student_id_by_token[token])
        for token, teacher_id in teacher_id_by_token.items()
        if token in student_id_by_token
    ]

    vocab_mapping = dict(shared_pairs)
    teacher_matched_ids = {teacher_id for teacher_id, _ in shared_pairs}
    student_matched_ids = {student_id for _, student_id in shared_pairs}
    return vocab_mapping, teacher_matched_ids, student_matched_ids

In [12]:
# Build the teacher -> student token-ID mapping and count how many IDs were
# matched on each side.
vocab_mapping, teacher_matched_ids, student_matched_ids = init_vocab_mapping(std_tokenizer, tch_tokenizer)
len(teacher_matched_ids), len(student_matched_ids)

(151665, 151665)

In [13]:
class ULDLoss(nn.Module):
    """Cross-tokenizer distillation loss between a student and a teacher LM.

    For every sample the loss
      1. locates the answer span (positions whose label is not the pad token),
      2. converts the logits of that span into temperature-scaled probabilities,
      3. aligns student and teacher positions by grouping the tokens that decode
         to the same text and merging the distributions inside each group,
      4. compares the aligned distributions: KL divergence on the vocabulary
         entries shared by both tokenizers, plus an L1 distance between the
         sorted (and zero-padded) probabilities of the unshared entries.

    Args:
        std_tokenizer: Student tokenizer (needs ``get_vocab``, ``decode`` and
            ``pad_token_id``).
        tch_tokenizer: Teacher tokenizer (same requirements).
        temperature: Softmax temperature applied to both sets of logits.
    """

    # Floor applied before taking logarithms of probabilities.
    _EPS = 1e-8

    def __init__(self, std_tokenizer, tch_tokenizer, temperature=0.2):
        super().__init__()
        self.std_tokenizer = std_tokenizer
        self.tch_tokenizer = tch_tokenizer
        self.temperature = temperature
        vocab_mapping, teacher_matched_ids, student_matched_ids = (
            self.init_vocab_mapping()
        )
        self.vocab_mapping = vocab_mapping
        self.teacher_matched_ids = teacher_matched_ids
        self.student_matched_ids = student_matched_ids

        # Pair-aligned index tensors: teacher entry k corresponds to student entry k.
        teacher_ids = sorted(vocab_mapping)
        self._tch_matched_index = torch.tensor(teacher_ids, dtype=torch.long)
        self._std_matched_index = torch.tensor(
            [vocab_mapping[teacher_id] for teacher_id in teacher_ids], dtype=torch.long
        )

    def init_vocab_mapping(self):
        """Match teacher and student vocabulary entries by token string.

        Returns:
            vocab_mapping: dict mapping teacher token IDs to student token IDs
            teacher_matched_ids: set of matched teacher token IDs
            student_matched_ids: set of matched student token IDs
        """
        student_vocab = self.std_tokenizer.get_vocab()
        teacher_vocab = self.tch_tokenizer.get_vocab()

        vocab_mapping = {}
        teacher_matched_ids = set()
        student_matched_ids = set()

        for token_str, teacher_token_id in teacher_vocab.items():
            if token_str in student_vocab:
                student_token_id = student_vocab[token_str]
                vocab_mapping[teacher_token_id] = student_token_id
                teacher_matched_ids.add(teacher_token_id)
                student_matched_ids.add(student_token_id)

        return vocab_mapping, teacher_matched_ids, student_matched_ids

    def get_alignment_groups_from_ids(self, std_token_ids, tch_token_ids):
        """Group student and teacher token positions that cover the same text.

        Both ID sequences are decoded incrementally into text pieces; pieces
        are then consumed greedily from whichever side has produced less text
        until both sides spell the same string, which closes one group.

        Args:
            std_token_ids: Student token IDs (list of int).
            tch_token_ids: Teacher token IDs (list of int).

        Returns:
            ``(s_groups, t_groups)``: two lists of equal length; entry ``k`` of
            each holds the positions (indices into the respective input list)
            belonging to group ``k``. A trailing group whose texts never matched
            is still emitted and may be empty on one side.
        """

        def to_canonical_pieces(tok, ids):
            # Decode growing prefixes and keep only the newly added text, so each
            # piece is exactly what that token contributes to the decoded string.
            pieces = []
            prev = ""
            for k in range(len(ids)):
                cur = tok.decode(
                    ids[: k + 1],
                    skip_special_tokens=False,
                    clean_up_tokenization_spaces=False,
                )
                pieces.append(cur[len(prev) :])
                prev = cur
            return pieces

        s_pieces = to_canonical_pieces(self.std_tokenizer, std_token_ids)
        t_pieces = to_canonical_pieces(self.tch_tokenizer, tch_token_ids)

        i = j = 0
        s_buf = t_buf = ""
        s_group = []
        t_group = []
        s_groups = []
        t_groups = []

        def take_student():
            nonlocal i, s_buf
            s_buf += s_pieces[i]
            s_group.append(i)
            i += 1

        def take_teacher():
            nonlocal j, t_buf
            t_buf += t_pieces[j]
            t_group.append(j)
            j += 1

        while i < len(s_pieces) or j < len(t_pieces):
            # Both sides spell the same non-empty text: close the group.
            if s_buf == t_buf and s_buf != "":
                if s_group and t_group:
                    s_groups.append(s_group.copy())
                    t_groups.append(t_group.copy())
                s_buf = t_buf = ""
                s_group = []
                t_group = []
                continue

            # Seed an empty buffer before comparing lengths.
            if s_buf == "" and i < len(s_pieces):
                take_student()
                continue
            if t_buf == "" and j < len(t_pieces):
                take_teacher()
                continue

            # Extend the shorter side; fall back to the other one when exhausted.
            if len(s_buf) <= len(t_buf):
                if i < len(s_pieces):
                    take_student()
                elif j < len(t_pieces):
                    take_teacher()
            else:
                if j < len(t_pieces):
                    take_teacher()
                elif i < len(s_pieces):
                    take_student()

        # Emit whatever is still pending, matched or not, so no position is lost.
        if s_group or t_group:
            s_groups.append(s_group.copy())
            t_groups.append(t_group.copy())

        return s_groups, t_groups

    def merge_prob_with_alignment_groups(self, probs, alignment_groups):
        """Collapse per-token distributions into one distribution per group.

        Args:
            probs: Tensor of shape ``(num_tokens, vocab_size)``.
            alignment_groups: List of position lists, as produced by
                ``get_alignment_groups_from_ids``.

        Returns:
            ``probs`` unchanged when ``alignment_groups`` is empty; otherwise a
            tensor of shape ``(len(alignment_groups), vocab_size)`` with the
            dtype and device of ``probs``. Multi-token groups are merged by
            multiplying the member distributions and renormalizing; empty
            groups yield an all-zero row.
        """
        if not alignment_groups:
            return probs

        rows = []
        for group in alignment_groups:
            if len(group) > 1:
                # Sum of logs == log of the element-wise product.
                log_joint = torch.log(probs[group[0]].clamp_min(self._EPS))
                for idx in group[1:]:
                    if idx < probs.size(0):
                        log_joint = log_joint + torch.log(probs[idx].clamp_min(self._EPS))
                rows.append(torch.softmax(log_joint, dim=-1))
            elif len(group) == 1:
                rows.append(probs[group[0]])
            else:
                # new_zeros avoids indexing probs[0], which fails when probs is empty.
                rows.append(probs.new_zeros(probs.size(-1)))

        return torch.stack(rows, dim=0)

    def get_answer_start_and_len(self, answers, tokenizer) -> Tuple[List[int], List[int]]:
        """Locate the non-pad span of every label row.

        Args:
            answers: Label tensor of shape ``(batch, seq_len)`` in which
                everything outside the answer equals ``tokenizer.pad_token_id``.
            tokenizer: Tokenizer supplying ``pad_token_id``.

        Returns:
            ``(starts, sizes)``: index of the first non-pad label and the count
            of non-pad labels per row; both are 0 for a row that is all padding.
        """
        answers_index = []
        answers_size = []

        for answer in answers:
            answer_mask = answer.ne(tokenizer.pad_token_id)
            if not answer_mask.any():
                answers_index.append(0)
                answers_size.append(0)
                continue

            indices = answer_mask.nonzero(as_tuple=True)[0]
            answers_index.append(int(indices[0].item()))
            answers_size.append(int(answer_mask.sum().item()))
        return answers_index, answers_size

    @staticmethod
    def _sort_and_pad(probs, width):
        """Sort each row in descending order and right-pad with zeros to ``width``."""
        if probs.size(-1) == 0:
            return probs.new_zeros(probs.size(0), width)
        ordered = probs.sort(dim=-1, descending=True).values
        return F.pad(ordered, (0, width - ordered.size(-1)))

    def _hybrid_vocab_loss(self, std_aligned, tch_aligned):
        """Compare aligned distributions whose vocabularies may differ.

        Args:
            std_aligned: Student probabilities, shape ``(steps, student_vocab)``.
            tch_aligned: Teacher probabilities, shape ``(steps, teacher_vocab)``.

        Returns:
            Scalar tensor: mean over steps of KL(teacher || student) restricted
            to shared entries, plus the mean L1 distance between the sorted,
            zero-padded probabilities of the unshared entries.
        """
        device = std_aligned.device
        std_vocab_size = std_aligned.size(-1)
        tch_vocab_size = tch_aligned.size(-1)

        std_index = self._std_matched_index.to(device)
        tch_index = self._tch_matched_index.to(device)
        # The logits' last dimension need not equal the tokenizer vocabulary
        # size; drop pairs that fall outside either distribution.
        in_range = (std_index < std_vocab_size) & (tch_index < tch_vocab_size)
        std_index = std_index[in_range]
        tch_index = tch_index[in_range]

        loss = std_aligned.new_zeros(())

        if std_index.numel() > 0:
            std_shared = std_aligned[:, std_index]
            tch_shared = tch_aligned[:, tch_index]
            kl_per_step = F.kl_div(
                torch.log(std_shared.clamp_min(self._EPS)),
                tch_shared,
                reduction="none",
            ).sum(dim=-1)
            loss = loss + kl_per_step.mean()

        std_unshared = torch.ones(std_vocab_size, dtype=torch.bool, device=device)
        std_unshared[std_index] = False
        tch_unshared = torch.ones(tch_vocab_size, dtype=torch.bool, device=device)
        tch_unshared[tch_index] = False

        std_rest = std_aligned[:, std_unshared]
        tch_rest = tch_aligned[:, tch_unshared]
        width = max(std_rest.size(-1), tch_rest.size(-1))
        if width > 0:
            # Unshared entries have no counterpart, so compare them by rank.
            gap = self._sort_and_pad(std_rest, width) - self._sort_and_pad(tch_rest, width)
            loss = loss + gap.abs().sum(dim=-1).mean()

        return loss

    def compute_uld_loss(self, std_logits, tch_logits, std_labels, tch_labels, std_input_ids, tch_input_ids):
        """Compute the batch-averaged distillation loss.

        Args:
            std_logits: Student logits, ``(batch, seq_len, student_vocab)``.
            tch_logits: Teacher logits, ``(batch, seq_len, teacher_vocab)``.
            std_labels: Student labels; non-answer positions hold the student
                pad token ID.
            tch_labels: Teacher labels; non-answer positions hold the teacher
                pad token ID.
            std_input_ids: Student input token IDs, ``(batch, seq_len)``.
            tch_input_ids: Teacher input token IDs, ``(batch, seq_len)``.

        Returns:
            Scalar tensor attached to ``std_logits``' graph. Samples without an
            answer span on either side are skipped; if every sample is skipped
            the result is a zero that still depends on ``std_logits``.
        """
        # The answer span comes from the labels, not from the logits.
        std_ans_index, std_ans_size = self.get_answer_start_and_len(std_labels, self.std_tokenizer)
        tch_ans_index, tch_ans_size = self.get_answer_start_and_len(tch_labels, self.tch_tokenizer)

        sample_losses = []
        for b in range(std_logits.shape[0]):
            s_start, s_len = std_ans_index[b], std_ans_size[b]
            t_start, t_len = tch_ans_index[b], tch_ans_size[b]
            if s_len == 0 or t_len == 0:
                continue

            # Keep only the answer part.
            # NOTE(review): logits are taken at the answer tokens' own positions
            # (no one-step shift) -- confirm this is the intended pairing.
            std_ans_logits = std_logits[b, s_start : s_start + s_len, :]
            tch_ans_logits = tch_logits[b, t_start : t_start + t_len, :].to(std_ans_logits.device)

            # float32 keeps softmax/log numerically stable under mixed precision.
            student_probs = F.softmax(std_ans_logits.float() / self.temperature, dim=-1)
            teacher_probs = F.softmax(tch_ans_logits.float() / self.temperature, dim=-1)

            std_token_ids = std_input_ids[b, s_start : s_start + s_len].tolist()
            tch_token_ids = tch_input_ids[b, t_start : t_start + t_len].tolist()

            # Align everything except the final position, which is paired directly.
            std_alignment_groups, tch_alignment_groups = self.get_alignment_groups_from_ids(
                std_token_ids[:-1], tch_token_ids[:-1]
            )
            std_aligned = self.merge_prob_with_alignment_groups(student_probs[:-1], std_alignment_groups)
            tch_aligned = self.merge_prob_with_alignment_groups(teacher_probs[:-1], tch_alignment_groups)
            std_aligned = torch.cat([std_aligned, student_probs[-1:, :]], dim=0)
            tch_aligned = torch.cat([tch_aligned, teacher_probs[-1:, :]], dim=0)

            # Guard against a length mismatch when no groups could be formed.
            steps = min(std_aligned.size(0), tch_aligned.size(0))
            # Align vocab size: KL for matched tokens; sort + pad and L1 for
            # unmatched tokens.
            sample_losses.append(
                self._hybrid_vocab_loss(std_aligned[:steps], tch_aligned[:steps])
            )

        if not sample_losses:
            return std_logits.sum() * 0.0
        return torch.stack(sample_losses).mean()

    def forward(self, std_logits, tch_logits, std_labels, tch_labels, std_input_ids, tch_input_ids):
        """Return ``compute_uld_loss`` for the given batch (same arguments)."""
        loss = self.compute_uld_loss(std_logits, tch_logits, std_labels, tch_labels, std_input_ids, tch_input_ids)
        return loss

In [14]:
class KDTrainer(Trainer):
    """``Trainer`` that distills a teacher LM into the student via ``ULDLoss``.

    Each collated batch is decoded back to prompt/answer text, re-tokenized
    separately with the student and the teacher tokenizer (chat template plus
    answer plus EOS), run through both models, and scored with ``ULDLoss``.

    Args:
        model: Student model to train.
        tch_model: Teacher model; put in eval mode, frozen, and moved to the
            training device.
        tch_tokenizer: Tokenizer belonging to the teacher.
        args, data_collator, train_dataset: Forwarded to ``Trainer``.
        tokenizer: Student tokenizer.
        max_length: Every re-tokenized sequence is truncated or right-padded to
            exactly this many tokens.
        **kwargs: Forwarded to ``Trainer``.

    Raises:
        ValueError: If ``tch_model`` is not provided.
    """

    # Value with which DataCollatorForSeq2Seq pads `labels` by default.
    IGNORE_INDEX = -100

    def __init__(
        self,
        model=None,
        tch_model=None,
        tch_tokenizer=None,
        args=None,
        data_collator=None,
        train_dataset=None,
        tokenizer=None,
        max_length=512,
        **kwargs,
    ):
        import inspect

        if tch_model is None:
            raise ValueError("KDTrainer requires a teacher model (`tch_model`).")

        self.tch_model = tch_model.eval()
        # The teacher is never updated; freezing avoids building gradients for it.
        self.tch_model.requires_grad_(False)
        self.tch_tokenizer = tch_tokenizer
        self.max_length = max_length

        # Newer Trainer versions take the tokenizer as `processing_class` and
        # deprecate `tokenizer`; use whichever name this version declares.
        trainer_params = inspect.signature(Trainer.__init__).parameters
        tokenizer_kwarg = "processing_class" if "processing_class" in trainer_params else "tokenizer"
        trainer_kwargs = dict(kwargs)
        if trainer_kwargs.get(tokenizer_kwarg) is None:
            trainer_kwargs[tokenizer_kwarg] = tokenizer
        # Keep our own handle so the class does not depend on Trainer's attribute name.
        self.std_tokenizer = trainer_kwargs[tokenizer_kwarg]

        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            **trainer_kwargs,
        )
        # Trainer places the student on args.device; the teacher must sit on the
        # same device because both receive tensors created on `self.model.device`.
        self.tch_model.to(self.args.device)
        self.uld_loss = ULDLoss(std_tokenizer=self.std_tokenizer, tch_tokenizer=tch_tokenizer)

    def get_inputs_from_text(self, tokenizer, prompt_texts, ans_texts):
        """Tokenize prompt/answer texts into fixed-length model inputs.

        Args:
            tokenizer: Tokenizer to encode with (student or teacher).
            prompt_texts: Iterable of user prompts.
            ans_texts: Iterable of answers, parallel to ``prompt_texts``.

        Returns:
            ``(input_ids, attention_mask, labels)`` -- in that order -- each of
            shape ``(batch, max_length)`` on the student model's device.
            ``labels`` holds ``pad_token_id`` everywhere except the answer
            tokens (including the appended EOS).
        """
        pad_id = tokenizer.pad_token_id
        sequences = []
        attention_masks = []
        labels_list = []
        for prompt_text, ans_text in zip(prompt_texts, ans_texts):
            messages = [{"role": "user", "content": prompt_text}]
            prompt = tokenizer.apply_chat_template(
                messages, add_generation_prompt=True, tokenize=False
            )
            # The template already contains the special tokens; adding them
            # again would duplicate e.g. a BOS token.
            prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
            answer_ids = tokenizer.encode(ans_text, add_special_tokens=False) + [
                tokenizer.eos_token_id
            ]
            sequence = prompt_ids + answer_ids
            attention_mask = [1] * len(sequence)
            labels = [pad_id] * len(prompt_ids) + answer_ids

            # Compute the pad length once, before any list is extended.
            pad_len = self.max_length - len(sequence)
            if pad_len < 0:
                sequence = sequence[: self.max_length]
                attention_mask = attention_mask[: self.max_length]
                labels = labels[: self.max_length]
            else:
                sequence = sequence + [pad_id] * pad_len
                attention_mask = attention_mask + [0] * pad_len  # padding is not attended to
                labels = labels + [pad_id] * pad_len

            sequences.append(torch.tensor(sequence))
            attention_masks.append(torch.tensor(attention_mask))
            labels_list.append(torch.tensor(labels))

        device = self.model.device
        sequences = torch.stack(sequences).contiguous().to(device)
        attention_masks = torch.stack(attention_masks).contiguous().to(device)
        labels = torch.stack(labels_list).contiguous().to(device)
        return sequences, attention_masks, labels

    def _split_prompt_and_answer(self, inputs):
        """Recover un-padded prompt and answer token IDs from a collated batch.

        Padding is removed from ``input_ids`` via ``attention_mask`` (when
        present) and from ``labels`` via ``IGNORE_INDEX``; the prompt is what
        remains of the input after removing the trailing answer tokens.
        """
        input_ids = inputs['input_ids']
        labels = inputs['labels']
        attention_mask = inputs.get('attention_mask')

        prompt_ids = []
        answer_ids = []
        for row, (ids_row, label_row) in enumerate(zip(input_ids, labels)):
            if attention_mask is not None:
                ids_row = ids_row[attention_mask[row].bool()]
            answer = label_row[label_row.ne(self.IGNORE_INDEX)]
            prompt_len = max(len(ids_row) - len(answer), 0)
            prompt_ids.append(ids_row[:prompt_len].tolist())
            answer_ids.append(answer.tolist())
        return prompt_ids, answer_ids

    def compute_loss(
        self, model, inputs, return_outputs=False, num_items_in_batch=None
    ):
        """Compute the distillation loss for one batch.

        Returns:
            The scalar loss, or ``(loss, student_outputs)`` when
            ``return_outputs`` is true.
        """
        prompt_ids, answer_ids = self._split_prompt_and_answer(inputs)

        prompt_texts = self.std_tokenizer.batch_decode(prompt_ids)
        answer_texts = self.std_tokenizer.batch_decode(answer_ids)

        # get_inputs_from_text returns (input_ids, attention_mask, labels).
        std_input_ids, std_attention_mask, std_labels = self.get_inputs_from_text(
            self.std_tokenizer, prompt_texts, answer_texts
        )
        tch_input_ids, tch_attention_mask, tch_labels = self.get_inputs_from_text(
            self.tch_tokenizer, prompt_texts, answer_texts
        )

        outputs = model(input_ids=std_input_ids, attention_mask=std_attention_mask)
        with torch.no_grad():
            tch_outputs = self.tch_model(input_ids=tch_input_ids, attention_mask=tch_attention_mask)
        logits = outputs.logits
        tch_logits = tch_outputs.logits

        loss = self.uld_loss(logits, tch_logits, std_labels, tch_labels, std_input_ids, tch_input_ids)
        print(f"loss: {loss:.4f}")
        return (loss, outputs) if return_outputs else loss

In [15]:
# Training configuration for a short smoke run.
# NOTE(review): bf16=True needs hardware/backends with bfloat16 support, and
# dataloader_num_workers=8 / dataloader_pin_memory=True are tuned for a CUDA
# host -- confirm they suit the machine this runs on.
args = TrainingArguments(output_dir=OUTPUT_DIR, 
                        num_train_epochs=1, 
                        do_train=True, 
                        per_device_train_batch_size=8,
                        gradient_accumulation_steps=1,
                        logging_steps=1,  # log every optimization step
                        report_to='tensorboard',
                        save_strategy='steps',
                        save_total_limit=3,
                        save_steps=100,  # larger than max_steps below, so this interval is never reached
                        bf16=True,
                        learning_rate=0.00001,
                        lr_scheduler_type='cosine',
                        dataloader_num_workers=8,
                        dataloader_pin_memory=True,
                        max_steps = 5)  # a positive max_steps takes precedence over num_train_epochs

In [None]:
# Turn off the tokenizers library's internal parallelism via its environment
# switch before training starts (the dataloader is configured with worker
# processes).
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# NOTE(review): `device` is computed here but not passed to anything in this cell.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Wire the student, the teacher and both tokenizers into the distillation trainer.
trainer = KDTrainer(model=std_model,
                    tch_model=tch_model, 
                    args=args, 
                    train_dataset=dataset, 
                    tokenizer=std_tokenizer, 
                    tch_tokenizer=tch_tokenizer,
                    data_collator=data_collator)

# Start from scratch rather than resuming from a checkpoint in OUTPUT_DIR.
trainer.train(resume_from_checkpoint=False)

  super().__init__(
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None, 'pad_token_id': 151643}.
0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to disable frozen modules.
0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.
0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to disable frozen modules.
0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.
ERROR:tornado.general:SEND Error: Host unreachable
Traceback (most recent call last):
  File [35m"<string>"[0m, line [35m1[0m, in [35m<module>[0m
    import sys; sys.path.insert(0, r'/Users/yingyao/miniconda3/envs/transformer-practice/lib/python3.14/site