In [2]:
import torch.nn as nn
import torch
import torch.nn.functional as F
from contextlib import nullcontext
import time
from tqdm.notebook import tqdm
from transformers import AutoTokenizer
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
import os
# Disable the HuggingFace `tokenizers` library's internal parallelism; presumably to
# avoid its fork warning when DataLoader workers fork (num_workers is 0 below, so this
# is likely precautionary).
os.environ.update(TOKENIZERS_PARALLELISM="false")

# 参数配置

In [None]:
from model.model_minimind import MiniMindConfig
class TrainArgs():
    """Training hyper-parameters and paths for the distillation run (read as class attributes)."""
    # --- File management ---
    # Student checkpoints get their own directory. They used to be written to
    # ./DPO_output/latest_checkpoint.pth, which is the file the teacher weights are
    # loaded from, so the first save overwrote the teacher checkpoint.
    out_dir = './distill_output'
    checkpoint_path = f"{out_dir}/latest_checkpoint.pth"
    # NOTE(review): the dataset class reads a 'conversations' key from every line;
    # confirm this DPO file uses that schema (DPO data is often chosen/rejected pairs).
    data_path = '../data/dpo.jsonl'
    # --- Optimisation ---
    epochs = 2
    batch_size = 16
    accumulation_steps = 4  # micro-batches per optimizer update
    learning_rate = 5e-4
    warm_up = 0  # NOTE(review): not read by the visible training code
    grad_clip = 1
    dtype = 'bfloat16'
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    num_workers = 0
    log_interval = 50
    save_interval = 2000
    # Mixed-precision context. torch.cuda.amp.autocast() defaulted to float16 and
    # ignored `dtype`; pass the configured dtype explicitly.
    ctx = nullcontext() if device == "cpu" else torch.autocast(device_type="cuda", dtype=getattr(torch, dtype))
class LLMargs:
    """Architecture settings passed to MiniMindConfig for the student and teacher."""
    # Forwarded as use_moe; when true the training loop also adds the model's aux_loss.
    use_moe = True
    # Forwarded as hidden_size.
    hidden_size = 512
    # Forwarded as num_hidden_layers.
    num_hidden_layers = 8
    

# One config shared by every model built below.
lm_config = MiniMindConfig(
    use_moe=LLMargs.use_moe,
    hidden_size=LLMargs.hidden_size,
    num_hidden_layers=LLMargs.num_hidden_layers,
)


# 模型加载、数据加载

In [None]:
# Tokenizer files live in the local ../model/ directory.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path='../model/')

In [None]:
from transformers import AutoTokenizer
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
tokenizer = AutoTokenizer.from_pretrained('../model/')
# Distillation needs a second network: the student is trained, the teacher only
# supplies target logits. Both are built from the same `lm_config` (student first).
# NOTE(review): teachers are usually larger; confirm an identical architecture is
# intended, since the teacher checkpoint must match this config.
student_model, teacher_model = (
    MiniMindForCausalLM(lm_config).to(TrainArgs.device) for _ in range(2)
)

In [None]:
ckp_teacher = './DPO_output/latest_checkpoint.pth'
ckp_student = './sft_output/latest_checkpoint.pth'
# NOTE(review): checkpoints written by this notebook's save code also pickle the
# `TrainArgs` class under 'config'; on PyTorch versions where torch.load defaults to
# weights_only=True that may fail to unpickle — confirm for the installed version.
t_state_dict = torch.load(ckp_teacher, map_location=TrainArgs.device)
s_state_dict = torch.load(ckp_student, map_location=TrainArgs.device)


def _load_model_weights(model, checkpoint, source):
    """Copy weights from ``checkpoint`` into ``model`` and report key mismatches.

    ``checkpoint`` may be a bare state_dict or a training checkpoint that stores the
    weights under ``'model'`` (the layout this notebook's own save code writes).
    Handing the wrapper straight to ``load_state_dict(strict=False)`` matches no keys
    and silently leaves the model randomly initialised, so it is unwrapped first.

    Args:
        model: module to load into.
        checkpoint: object returned by ``torch.load``.
        source: path used only in messages.

    Raises:
        RuntimeError: if none of the model's state keys were found in the checkpoint.
    """
    state = checkpoint
    if isinstance(state, dict) and isinstance(state.get('model'), dict):
        state = state['model']
    result = model.load_state_dict(state, strict=False)
    if len(result.missing_keys) == len(model.state_dict()):
        raise RuntimeError(f"no weights from {source} matched the model's parameters")
    if result.missing_keys or result.unexpected_keys:
        print(f"{source}: missing keys={result.missing_keys}, unexpected keys={result.unexpected_keys}")


_load_model_weights(student_model, s_state_dict, ckp_student)
_load_model_weights(teacher_model, t_state_dict, ckp_teacher)

In [None]:
# Freeze the teacher: switch it to eval mode and turn off gradients for all of its
# parameters (both calls return the module, so they can be chained).
teacher_model.eval().requires_grad_(False)

加载数据

In [None]:
from torch.utils.data import Dataset, DataLoader, Subset
import json
class SFTDataset(Dataset):
    """Chat-format dataset yielding shifted ``(X, Y, loss_mask)`` LongTensors.

    Each JSONL line must hold a ``conversations`` list; turn roles are inferred from
    the turn index (even = user, odd = assistant), not read from the data. The loss
    mask selects targets from the second token after each ``<|im_start|>assistant``
    header through the token following the matching ``<|im_end|>``.

    Args:
        tokenizer: callable tokenizer exposing ``pad_token_id`` and
            ``apply_chat_template``; calls must return an object with ``input_ids``.
        data_path: path to a UTF-8 JSONL file.
        max_length: every sequence is truncated / right-padded to this many tokens
            before the one-token shift.
    """

    def __init__(self,  tokenizer,data_path, max_length=1024):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = self.load_data(data_path)
        # Token sequences that open and close an assistant turn.
        self.bos_id = tokenizer('<|im_start|>assistant', add_special_tokens=False).input_ids
        self.eos_id = tokenizer('<|im_end|>', add_special_tokens=False).input_ids

    def __len__(self):
        return len(self.samples)

    def load_data(self, data_path):
        """Parse one JSON object per non-blank line of ``data_path``."""
        samples = []
        with open(data_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:  # skip blank / trailing lines instead of failing in json.loads
                    continue
                samples.append(json.loads(line))
        return samples

    def _create_chat_prompt(self, conversations):
        """Render the conversation as text with the tokenizer's chat template (ChatML)."""
        messages = []
        for i, turn in enumerate(conversations):
            role = 'user' if i % 2 == 0 else 'assistant'
            messages.append({"role": role, "content": turn['content']})
        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False
        )

    def _generate_loss_mask(self, input_ids):
        """Return a 0/1 list, same length as ``input_ids``, marking assistant spans.

        For every assistant header found, positions ``start + 1`` up to and including
        ``end + len(eos_id)`` are set (``start`` = first index after the header,
        ``end`` = index of the closing ``<|im_end|>``, or the sequence end if absent).
        ``__getitem__`` then drops index 0 so the mask lines up with the targets.
        """
        n = len(input_ids)
        loss_mask = [0] * n
        bos_len, eos_len = len(self.bos_id), len(self.eos_id)
        i = 0
        while i < n:
            if input_ids[i:i + bos_len] == self.bos_id:
                start = i + bos_len
                end = start
                while end < n and input_ids[end:end + eos_len] != self.eos_id:
                    end += 1
                # Clamp to the list length so a shorter input can never index past the end.
                for j in range(start + 1, min(end + eos_len + 1, n)):
                    loss_mask[j] = 1
                i = end + eos_len if end < n else n
            else:
                i += 1
        return loss_mask

    def __getitem__(self, index):
        sample = self.samples[index]
        prompt = self._create_chat_prompt(sample['conversations'])
        input_ids = self.tokenizer(prompt).input_ids[:self.max_length]
        input_ids += [self.tokenizer.pad_token_id] * (self.max_length - len(input_ids))

        loss_mask = self._generate_loss_mask(input_ids)

        # Python lists have no .clone(); build LongTensors. Shift by one so that
        # X[t] is used to predict Y[t] == input_ids[t + 1].
        X = torch.tensor(input_ids[:-1], dtype=torch.long)
        Y = torch.tensor(input_ids[1:], dtype=torch.long)
        loss_mask = torch.tensor(loss_mask[1:], dtype=torch.long)  # aligned with Y
        return X, Y, loss_mask

In [None]:
# SFTDataset's signature is (tokenizer, data_path, max_length). Keywords fix the old
# positional call, which passed the data path as the tokenizer and vice versa.
train_ds = SFTDataset(tokenizer=tokenizer, data_path=TrainArgs.data_path, max_length=512)

In [None]:
# Show one encoded sample, then the dataset size on its own line.
print(train_ds[1], f'数据集总长度为{len(train_ds)}', sep='\n')

In [None]:
# Small subset for quick smoke runs. Capped at the dataset size so a dataset with fewer
# than 1000 samples does not raise IndexError when the loader reaches missing indices.
test_ds = Subset(train_ds, range(min(1000, len(train_ds))))

加载dataloader

In [None]:
# Loader over the small smoke-test subset. Pinned host memory only helps copies to a
# CUDA device, so it is enabled only when training on CUDA.
test_loader = DataLoader(
    test_ds,
    batch_size=TrainArgs.batch_size,
    shuffle=True,
    num_workers=TrainArgs.num_workers,
    pin_memory=TrainArgs.device.startswith("cuda"),
)

In [None]:
# Full training loader. Pinned host memory only helps copies to a CUDA device, so it
# is enabled only when training on CUDA.
train_loader = DataLoader(
    train_ds,
    batch_size=TrainArgs.batch_size,
    shuffle=True,
    num_workers=TrainArgs.num_workers,
    pin_memory=TrainArgs.device.startswith("cuda"),
)

# 开始训练

In [None]:
import math
from torch import optim
def get_lr(current_step, total_steps, lr):
    """Cosine-decayed learning rate on top of a constant floor of ``lr / 10``.

    Gives ``1.1 * lr`` at ``current_step == 0`` and ``0.1 * lr`` at
    ``current_step == total_steps``.
    """
    floor = lr / 10
    cosine_part = 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))
    return floor + cosine_part
# AdamW over the student's parameters only; the teacher is frozen and never optimised.
optimizer = optim.AdamW(
    student_model.parameters(),
    lr=TrainArgs.learning_rate,
)
# Loss scaler for mixed-precision training, switched on for the half-precision dtypes.
scaler = torch.cuda.amp.GradScaler(
    enabled=TrainArgs.dtype in ('float16', 'bfloat16'),
)

In [None]:
import swanlab
# Initialise swanlab with project / experiment names and the run's hyper-parameters.
# vars(TrainArgs()) was always {} because every setting is a class attribute rather
# than an instance attribute, so read the class namespaces instead. Keep only plain
# values: this drops dunders and non-serialisable objects such as the `ctx` context.
swanlab.init(
    project="MiniMind-Distillation",
    experiment_name="MiniMind-Distillation",
    config={
        key: value
        for cfg in (TrainArgs, LLMargs)
        for key, value in vars(cfg).items()
        if not key.startswith('_') and isinstance(value, (bool, int, float, str))
    },
)

In [None]:
# Loss used for distillation: KL divergence from the teacher's distribution.
def distillation_loss_fn(student_logits, teacher_logits, temperature=1.0, reduction='batchmean'):
    """KL(teacher || student) on temperature-softened distributions, scaled by T**2.

    Args:
        student_logits: student logits; the last dim is the vocabulary.
        teacher_logits: teacher logits of the same shape; treated as constants.
        temperature: softening temperature T.
        reduction: forwarded to ``F.kl_div`` ('batchmean' divides the sum by dim 0).

    Returns:
        ``T**2 * kl``. For 'batchmean'/'mean' with no rows, a zero that stays attached
        to the student's graph is returned instead of NaN (0/0).
    """
    if student_logits.numel() == 0 and reduction in ('batchmean', 'mean'):
        return student_logits.sum() * 0.0

    with torch.no_grad():
        # The softmax must be taken over the vocabulary axis (`dim`, not `hidden_size`).
        teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)

    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)

    kl = F.kl_div(
        student_log_probs,
        teacher_probs,
        reduction=reduction
    )
    return (temperature ** 2) * kl

In [None]:
def _masked_losses(student_logits, teacher_logits, Y, loss_mask, temperature):
    """Return (masked CE, masked distillation KL) for one batch.

    Both terms use only positions where ``loss_mask == 1``. The distillation term is
    a zero tensor when ``teacher_logits`` is None.
    """
    vocab_size = student_logits.size(-1)
    flat_logits = student_logits.reshape(-1, vocab_size)
    mask_flat = loss_mask.reshape(-1)

    # ignore_index=0 treats token id 0 as padding (assumes pad_token_id == 0 — TODO confirm).
    token_ce = F.cross_entropy(flat_logits, Y.reshape(-1), ignore_index=0, reduction='none')
    # clamp(min=1): a batch with no supervised tokens yields 0 instead of 0/0 = NaN.
    ce_loss = torch.sum(token_ce * mask_flat) / mask_flat.sum().clamp(min=1)

    if teacher_logits is None:
        return ce_loss, torch.tensor(0.0, device=student_logits.device)

    keep = mask_flat == 1  # distil only at supervised token positions
    distill_loss = distillation_loss_fn(
        flat_logits[keep],
        teacher_logits.reshape(-1, teacher_logits.size(-1))[keep],
        temperature=temperature,
    )
    return ce_loss, distill_loss


def _save_checkpoint(model, optimizer, scaler, epoch, step):
    """Write model/optimizer/scaler state and progress to ``TrainArgs.checkpoint_path``."""
    model.eval()
    print(f"\nSaving checkpoint at step {step+1}...")
    try:
        checkpoint = {
            # Model weights in half precision; non-floating tensors keep their dtype.
            'model': {k: v.half() if v.is_floating_point() else v
                      for k, v in model.state_dict().items()},
            # Full-precision optimizer state. The earlier per-entry "halving" never
            # applied anything: each entry of state_dict()['state'] is a dict, not a tensor.
            'optimizer': optimizer.state_dict(),
            'scaler': scaler.state_dict(),
            'epoch': epoch,
            'step': step,
            'config': TrainArgs,
        }
        # The output directory may not exist yet.
        os.makedirs(os.path.dirname(TrainArgs.checkpoint_path) or '.', exist_ok=True)
        torch.save(checkpoint, TrainArgs.checkpoint_path)
        print("Checkpoint saved.")
    finally:
        model.train()  # back to training mode even if saving failed


def train_epoch(epoch, model, train_loader, optimizer, scaler, lm_config,alpha=0.0, temperature=1.0):
    """Run one distillation epoch of ``model`` (the student) over ``train_loader``.

    Per batch: loss = (alpha * masked CE + (1 - alpha) * distillation KL against the
    global ``teacher_model``) / accumulation_steps. The optimizer steps every
    ``TrainArgs.accumulation_steps`` batches. Reads ``TrainArgs``, ``get_lr``,
    ``distillation_loss_fn``, ``teacher_model`` and ``swanlab`` from notebook globals.

    NOTE(review): when ``len(train_loader)`` is not a multiple of accumulation_steps,
    the gradients of the trailing batches carry over into the next epoch's first update.

    Args:
        epoch: zero-based epoch index (LR schedule position and progress-bar label).
        model: student; ``model(X)`` must return ``.logits`` (and ``.aux_loss`` when
            ``lm_config.use_moe``).
        train_loader: yields ``(X, Y, loss_mask)`` batches.
        optimizer: optimizer whose param-group ``lr`` is overwritten every step.
        scaler: AMP GradScaler used for backward / step.
        lm_config: model config; only ``use_moe`` is read.
        alpha: weight of the ground-truth CE term (0.0 = pure distillation).
        temperature: softening temperature for the distillation term.
    """
    model.train()
    start_time = time.time()
    iter_per_epoch = len(train_loader)
    total_steps = TrainArgs.epochs * iter_per_epoch
    accumulation_steps = TrainArgs.accumulation_steps

    # Iterate over the tqdm wrapper (not the raw loader) so the progress bar advances.
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{TrainArgs.epochs}", leave=True)

    for step, (X, Y, loss_mask) in enumerate(progress_bar):
        X = X.to(TrainArgs.device)
        Y = Y.to(TrainArgs.device)
        loss_mask = loss_mask.to(TrainArgs.device)

        # Cosine schedule over all steps of all epochs.
        lr = get_lr(epoch * iter_per_epoch + step, total_steps, TrainArgs.learning_rate)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        with TrainArgs.ctx:
            res = model(X)
            student_logits = res.logits
            teacher_logits = None
            if teacher_model is not None:
                with torch.no_grad():
                    teacher_logits = teacher_model(X).logits
                    # Keep the first N = student-vocab entries so both distributions align.
                    teacher_logits = teacher_logits[..., :student_logits.size(-1)]

        ce_loss, distill_loss = _masked_losses(student_logits, teacher_logits, Y, loss_mask, temperature)
        if lm_config.use_moe:
            # NOTE(review): aux_loss is folded into the CE term, so with the default
            # alpha=0.0 it is multiplied by 0 and has no effect — confirm intended.
            ce_loss = ce_loss + res.aux_loss

        # Total loss = alpha * CE + (1 - alpha) * distill, divided for gradient accumulation.
        loss = (alpha * ce_loss + (1 - alpha) * distill_loss) / accumulation_steps
        scaler.scale(loss).backward()

        if (step + 1) % accumulation_steps == 0:
            scaler.unscale_(optimizer)  # unscale first so clipping sees the true gradient norm
            torch.nn.utils.clip_grad_norm_(model.parameters(), TrainArgs.grad_clip)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

        if step % TrainArgs.log_interval == 0:
            spend_time = time.time() - start_time
            current_lr = optimizer.param_groups[-1]['lr']
            # Report the undivided loss in both places (the bar used to show loss / accumulation_steps).
            loss_value = loss.item() * accumulation_steps
            progress_bar.set_postfix({
                "loss": f"{loss_value:.3f}",
                "lr": f"{current_lr:.2e}"
            })
            if swanlab is not None:
                swanlab.log({"loss": loss_value,
                             "lr": current_lr,
                             # projected epoch minutes minus elapsed minutes
                             "epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})

        # Save only on steps that also completed an optimizer update.
        if (step + 1) % TrainArgs.save_interval == 0 and (step + 1) % accumulation_steps == 0:
            _save_checkpoint(model, optimizer, scaler, epoch, step)