In [4]:
import torch.nn as nn
import torch
from contextlib import nullcontext
import time
from tqdm.notebook import tqdm
from transformers import AutoTokenizer
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
import os
# Turn off HuggingFace `tokenizers` internal parallelism for this process.
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})

# 参数的配置

In [2]:
from model.model_minimind import MiniMindConfig


class TrainArgs():
    """DPO training configuration, read as class attributes (``TrainArgs.epochs`` etc.)."""
    # --- File management ---
    out_dir = './DPO_output'
    checkpoint_path = "./DPO_output/latest_checkpoint.pth"
    data_path = '../data/dpo.jsonl'
    # --- Training loop ---
    epochs = 2
    batch_size = 16
    accumulation_steps = 4  # micro-batches per optimizer step
    # NOTE(review): 5e-4 is a very high learning rate for DPO on top of an SFT model;
    # DPO is usually run with much smaller rates. Confirm this is intended.
    learning_rate = 5e-4
    warm_up = 0
    grad_clip = 1
    dtype = 'bfloat16'  # mixed-precision dtype used by `ctx` on CUDA ('bfloat16' or 'float16')
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    num_workers = 0
    log_interval = 50
    save_interval = 2000
    # Mixed-precision context for forward passes. The dtype is passed explicitly because a
    # bare CUDA autocast() defaults to float16 and would ignore the `dtype` setting above.
    ctx = (
        nullcontext()
        if device == "cpu"
        else torch.autocast(
            device_type="cuda",
            dtype=torch.bfloat16 if dtype == 'bfloat16' else torch.float16,
        )
    )


class LLMargs():
    """Architecture settings used to build ``lm_config``."""
    use_moe = True
    hidden_size = 512
    num_hidden_layers = 8


lm_config = MiniMindConfig(
    use_moe=LLMargs.use_moe,
    hidden_size=LLMargs.hidden_size,
    num_hidden_layers=LLMargs.num_hidden_layers,
)


# 环境与设备

In [3]:
import torch

# Report which accelerators this kernel can see.
cuda_ok = torch.cuda.is_available()
print(f"CUDA is available: {cuda_ok}")
if cuda_ok:
    gpu_count = torch.cuda.device_count()
    print(f"Number of GPUs: {gpu_count}")
    for gpu_idx in range(gpu_count):
        gpu_name = torch.cuda.get_device_name(gpu_idx)
        print(f"  GPU {gpu_idx}: {gpu_name}")
        total_gib = torch.cuda.get_device_properties(gpu_idx).total_memory / (1024**3)
        print(f"    Memory: {total_gib:.2f} GB")

CUDA is available: True
Number of GPUs: 1
  GPU 0: NVIDIA GeForce RTX 4060 Laptop GPU
    Memory: 8.00 GB


# 模型加载、数据加载

In [4]:
# Load the tokenizer from the local `../model/` directory.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path='../model/')

In [None]:
from transformers import AutoTokenizer
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM

tokenizer = AutoTokenizer.from_pretrained('../model/')

# Two models with the same architecture, built in order: `model` is the policy that the
# optimizer trains, `ref_model` is the reference that is frozen in a later cell.
model, ref_model = (MiniMindForCausalLM(lm_config).to(TrainArgs.device) for _ in range(2))


In [None]:
# Load the initial weights into both the policy model and the reference model.
# NOTE(review): this is the same file train_epoch writes its resume checkpoints to
# (TrainArgs.checkpoint_path). The reference model is meant to hold the SFT weights; once a
# DPO checkpoint has been saved here, re-running this cell initialises the reference model
# from partially DPO-trained weights. Point `ckp` at the SFT weights — TODO confirm path.
ckp = './DPO_output/latest_checkpoint.pth'
state_dict = torch.load(ckp, map_location=TrainArgs.device)

# Checkpoints written by train_epoch wrap the weights under a 'model' key; a plain state
# dict is used as-is. Without unwrapping, strict=False would silently load nothing.
if isinstance(state_dict.get('model'), dict):
    state_dict = state_dict['model']
# Strip the 'module.' prefix left by DDP-wrapped saves (same handling as the resume cell).
if any(k.startswith('module.') for k in state_dict):
    state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}

# Load the same weights into both models; report mismatches and refuse a load that matched
# nothing (which would leave the model at random initialisation).
for net_name, net in (('model', model), ('ref_model', ref_model)):
    load_result = net.load_state_dict(state_dict, strict=False)
    if len(load_result.missing_keys) == len(net.state_dict()):
        raise RuntimeError(f"{net_name}: no weights in {ckp} matched the model's parameter names")
    if load_result.missing_keys or load_result.unexpected_keys:
        print(f"{net_name}: {len(load_result.missing_keys)} missing keys, "
              f"{len(load_result.unexpected_keys)} unexpected keys")

In [None]:
# Freeze the reference model: switch to eval mode, then turn off gradient tracking for
# all of its parameters. Both calls return the module, so they chain.
ref_model.eval().requires_grad_(False)

In [None]:
print(model)

# Count all parameters and, separately, those that require gradients. The two totals
# match unless some parameters were frozen with requires_grad=False.
total_params = 0
trainable_params = 0
for param in model.parameters():
    count = param.numel()
    total_params += count
    if param.requires_grad:
        trainable_params += count

# The `,` format spec adds thousands separators for readability.
print(f"模型总参数量 (Total Parameters): {total_params:,}")
print(f"可训练参数量 (Trainable Parameters): {trainable_params:,}")

加载数据

In [5]:
from torch.utils.data import Dataset, DataLoader, Subset
import json
class DPODataset(Dataset):
    """Preference-pair dataset for DPO, read from a JSON-lines file.

    Each non-blank line is a JSON object with ``chosen`` and ``rejected`` keys; each value is
    a conversation (a list of ``{role, content}`` messages) passed to
    ``tokenizer.apply_chat_template``.

    Args:
        file_path: path to the ``.jsonl`` file.
        tokenizer: HuggingFace-style tokenizer providing ``apply_chat_template``.
        max_length: every encoded sequence is truncated/padded to this many tokens.
    """

    def __init__(self, file_path, tokenizer, max_length=4096):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.padding = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
        # Token-id sequences that open / close an assistant turn; used by _generate_loss_mask.
        self.bos_id = tokenizer('<|im_start|>assistant', add_special_tokens=False).input_ids
        self.eos_id = tokenizer('<|im_end|>', add_special_tokens=False).input_ids
        self.data = []
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Skip blank/whitespace-only lines; json.loads('') would raise.
                    continue
                self.data.append(json.loads(line))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Encode one preference pair into shifted input/target/mask tensors.

        Returns a dict with ``x_*`` (inputs, all tokens but the last), ``y_*`` (targets,
        all tokens but the first) and ``mask_*`` (loss mask aligned with ``y_*``) for both
        the ``chosen`` and ``rejected`` conversations.
        """
        item = self.data[index]
        chosen = item['chosen']  # a list of {role, content} messages
        rejected = item['rejected']  # same format
        chosen_prompt = self.tokenizer.apply_chat_template(
            chosen, tokenize=False, add_generation_prompt=False
        )

        rejected_prompt = self.tokenizer.apply_chat_template(
            rejected, tokenize=False, add_generation_prompt=False
        )
        chosen_encoding = self.tokenizer(
            chosen_prompt, truncation=True, max_length=self.max_length, padding='max_length'
        )
        rejected_encoding = self.tokenizer(
            rejected_prompt, truncation=True, max_length=self.max_length, padding='max_length'
        )

        chosen_input_ids = chosen_encoding['input_ids']
        chosen_loss_mask = self._generate_loss_mask(chosen_input_ids)

        rejected_input_ids = rejected_encoding['input_ids']
        rejected_loss_mask = self._generate_loss_mask(rejected_input_ids)
        # Next-token prediction: inputs drop the last token, targets/mask drop the first.
        x_chosen = torch.tensor(chosen_input_ids[:-1], dtype=torch.long)
        y_chosen = torch.tensor(chosen_input_ids[1:], dtype=torch.long)
        mask_chosen = torch.tensor(chosen_loss_mask[1:], dtype=torch.long)
        x_rejected = torch.tensor(rejected_input_ids[:-1], dtype=torch.long)
        y_rejected = torch.tensor(rejected_input_ids[1:], dtype=torch.long)
        mask_rejected = torch.tensor(rejected_loss_mask[1:], dtype=torch.long)

        return {
            'x_chosen': x_chosen,
            'y_chosen': y_chosen,
            'mask_chosen': mask_chosen,
            'x_rejected': x_rejected,
            'y_rejected': y_rejected,
            'mask_rejected': mask_rejected
        }

    def _generate_loss_mask(self, input_ids):
        """Return a 0/1 list marking the assistant-response tokens of ``input_ids``.

        For each ``bos_id`` occurrence, positions from ``start + 1`` up to and including
        one token past the matching ``eos_id`` are set to 1, where ``start`` is the first
        token after ``bos_id``. Positions are capped at ``max_length``. If no ``eos_id``
        follows, the span runs to the end of the sequence.
        """
        loss_mask = [0] * len(input_ids)
        i = 0
        while i < len(input_ids):
            if input_ids[i:i + len(self.bos_id)] == self.bos_id:
                start = i + len(self.bos_id)
                end = start
                while end < len(input_ids):
                    if input_ids[end:end + len(self.eos_id)] == self.eos_id:
                        break
                    end += 1
                for j in range(start + 1, min(end + len(self.eos_id) + 1, self.max_length)):
                    loss_mask[j] = 1
                i = end + len(self.eos_id) if end < len(input_ids) else len(input_ids)
            else:
                i += 1
        return loss_mask


In [9]:
# Full DPO preference dataset; every sequence is truncated/padded to 512 tokens.
train_ds = DPODataset(file_path=TrainArgs.data_path, tokenizer=tokenizer, max_length=512)

In [10]:
# Show one encoded sample and the total number of preference pairs.
sample = train_ds[1]
print(f'{sample}\n数据集总长度为{len(train_ds)}')

{'x_chosen': tensor([   1,   85,  736,  201,   59,  292,  389,  260, 3836, 1861,  501,    2,
         201,    1,  320,  275,  201, 1701,  260, 4324,   82,  356, 6100,  566,
        6395,   14,  276, 3792, 1185,  303,  260,  823,  275,  370,  365, 1519,
         848,  290,  276,  271,  497,  569,  779,  356,  303,  276, 3717,  303,
        5301,  823,  298, 1215, 4912,   77,   16, 2675, 2472,   80,   67, 3034,
        2903, 5301,  260, 4912,   77,  281,  311,  754,   85, 2488, 4406,   18,
          14,  822, 1353, 5301, 1214,  689,   78,  291,   67,  823,  290,  311,
         754, 2488,   22,  704, 1215, 4912,   77,   33,    2,  201,    1, 1078,
         538,  501,  201,   41, 2382,  349,  276, 3792, 1185,  303,  260,  823,
         275,  370,  365, 1519,  848,  290,  276,  271,  497,  569,  779,  356,
         303,  276, 3717,  303, 5301,  823,  298, 1215, 4912,   77,   14,  670,
         343, 3295,  935, 4429,  432,   28,  201,  201,   62,   61, 5217, 5184,
          93,   53,  285, 1

In [11]:
# Small subset of the training data for a quick smoke-test run. Capped at the dataset size
# so a dataset with fewer than 1000 pairs does not raise IndexError partway through an epoch.
test_ds = Subset(train_ds, range(min(1000, len(train_ds))))

加载dataloader

In [None]:
# Shuffled loader over the small smoke-test subset; same settings as the full training loader.
test_loader = DataLoader(
    dataset=test_ds,
    shuffle=True,
    batch_size=TrainArgs.batch_size,
    pin_memory=True,
    num_workers=TrainArgs.num_workers,
)

In [None]:
# Shuffled loader over the full DPO dataset.
train_loader = DataLoader(
    dataset=train_ds,
    shuffle=True,
    batch_size=TrainArgs.batch_size,
    pin_memory=True,
    num_workers=TrainArgs.num_workers,
)

# 开始训练

In [None]:
import math
from torch import optim
def get_lr(current_step, total_steps, lr):
    """Cosine-decayed learning rate with a floor of ``lr / 10``.

    Yields ``1.1 * lr`` at step 0 and decays to ``0.1 * lr`` when
    ``current_step == total_steps``.
    """
    floor = lr / 10
    cosine_part = 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))
    return floor + cosine_part
# AdamW over the policy model's parameters only; the reference model is never optimised.
optimizer = optim.AdamW(params=model.parameters(), lr=TrainArgs.learning_rate)
# Mixed-precision gradient scaler, enabled when training in float16 or bfloat16.
scaler = torch.cuda.amp.GradScaler(enabled=(TrainArgs.dtype in ('float16', 'bfloat16')))

In [None]:
import swanlab

# Hyperparameters are class attributes of TrainArgs, so vars(TrainArgs()) (the empty
# instance __dict__) logs nothing. Collect the primitive-valued class attributes instead;
# non-serialisable ones such as the autocast context `ctx` are skipped.
swanlab_config = {
    key: value
    for key, value in vars(TrainArgs).items()
    if not key.startswith('__') and isinstance(value, (str, int, float, bool))
}

# Initialise SwanLab with the project name, experiment name and hyperparameter config.
# NOTE(review): the project is still named "MiniMind-SFT"; confirm DPO runs belong there.
swanlab.init(
    project="MiniMind-SFT",
    experiment_name="MiniMind-DPO",
    config=swanlab_config,
)

下面是 DPO 损失的计算：`logits_to_probs` 计算每个目标 token 的对数概率，`dpo_loss` 据此计算 DPO 损失。

In [None]:
import torch.nn.functional as F
def logits_to_probs(logits, labels):
    """Log-probability of each label token under ``logits``.

    Args:
        logits: tensor of shape (batch_size, seq_len, vocab_size).
        labels: integer tensor of shape (batch_size, seq_len).

    Returns:
        Tensor of shape (batch_size, seq_len). Despite the name, these are
        log-probabilities (log_softmax), not probabilities.
    """
    token_log_probs = F.log_softmax(logits, dim=2)
    label_index = labels.unsqueeze(2)
    return token_log_probs.gather(2, label_index).squeeze(-1)


def dpo_loss(ref_probs, probs, mask, beta):
    """Direct Preference Optimization loss on length-normalised sequence log-probs.

    Args:
        ref_probs: per-token log-probs from the reference model, shape (batch_size, seq_len).
        probs: per-token log-probs from the policy model, same shape.
        mask: 0/1 loss mask, same shape; only masked-in tokens contribute.
        beta: scale applied to the preference margin before the log-sigmoid.

    The batch must contain the chosen responses in its first half and the matching
    rejected responses in its second half.

    Returns:
        Scalar mean of ``-logsigmoid(beta * margin)`` over the preference pairs.
    """
    # https://github.com/jingyaogong/minimind/issues/298
    # Average (not sum) the log-probs over response tokens. The clamp makes a row whose
    # mask is all zeros (e.g. the response was truncated away) yield 0 instead of 0/0 = NaN,
    # which would otherwise turn the whole batch loss into NaN.
    seq_lengths = mask.sum(dim=1).clamp(min=1)  # (batch_size,)
    ref_probs = (ref_probs * mask).sum(dim=1) / seq_lengths
    probs = (probs * mask).sum(dim=1) / seq_lengths

    # Split into the chosen (first half) and rejected (second half) samples.
    half = ref_probs.shape[0] // 2
    chosen_ref_probs = ref_probs[:half]
    reject_ref_probs = ref_probs[half:]
    chosen_probs = probs[:half]
    reject_probs = probs[half:]

    pi_logratios = chosen_probs - reject_probs
    ref_logratios = chosen_ref_probs - reject_ref_probs
    logits = pi_logratios - ref_logratios
    loss = -F.logsigmoid(beta * logits)
    return loss.mean()

In [None]:
def _dpo_checkpoint_config():
    """Return TrainArgs' primitive-valued class attributes as a plain dict.

    Storing the TrainArgs class itself would pickle a reference to a notebook-defined class,
    which ``torch.load(..., weights_only=True)`` refuses to load.
    """
    return {
        key: value
        for key, value in vars(TrainArgs).items()
        if not key.startswith('__') and isinstance(value, (str, int, float, bool))
    }


def _save_dpo_checkpoint(model, optimizer, scaler, epoch, step):
    """Save model/optimizer/scaler state plus progress to ``TrainArgs.checkpoint_path``."""
    model.eval()
    print(f"\nSaving checkpoint at step {step+1}...")
    # torch.save does not create missing directories.
    os.makedirs(os.path.dirname(TrainArgs.checkpoint_path) or '.', exist_ok=True)
    checkpoint = {
        # Model weights are stored in half precision to save disk space.
        # NOTE(review): resuming from these fp16 weights rounds the fp32 master weights.
        'model': {k: v.half() for k, v in model.state_dict().items()},
        # Optimizer state is saved unchanged (full precision).
        'optimizer': optimizer.state_dict(),
        'scaler': scaler.state_dict(),
        'epoch': epoch,
        'step': step,
        'config': _dpo_checkpoint_config(),
    }
    torch.save(checkpoint, TrainArgs.checkpoint_path)
    print("Checkpoint saved.")
    model.train()


def train_epoch(epoch, model, train_loader, optimizer, scaler, lm_config):
    """Run one epoch of DPO training.

    Args:
        epoch: zero-based epoch index (used for the LR schedule and the progress bar).
        model: policy model being optimised.
        train_loader: DataLoader yielding DPODataset batches (x/y/mask for chosen and rejected).
        optimizer: optimizer over ``model``'s parameters; its LR is overwritten every step.
        scaler: AMP GradScaler used for backward and optimizer steps.
        lm_config: unused; kept for call-site compatibility.

    Uses the module-level ``ref_model`` as the frozen DPO reference, ``TrainArgs`` for
    hyperparameters and ``swanlab`` for logging. Saves a checkpoint every
    ``TrainArgs.save_interval`` steps (only on steps that complete an optimizer update).
    """
    model.train()
    # Drop gradients left over from an incomplete accumulation window (e.g. a previous
    # epoch whose length is not a multiple of accumulation_steps).
    optimizer.zero_grad(set_to_none=True)

    start_time = time.time()  # used for the remaining-time estimate in the logs
    iter_per_epoch = len(train_loader)  # needed by the LR schedule
    total_steps = TrainArgs.epochs * iter_per_epoch
    device = TrainArgs.device

    # Iterate over the tqdm wrapper itself so the progress bar actually advances.
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{TrainArgs.epochs}", leave=True)
    for step, batch in enumerate(progress_bar):
        x_chosen = batch['x_chosen'].to(device)
        x_rejected = batch['x_rejected'].to(device)
        y_chosen = batch['y_chosen'].to(device)
        y_rejected = batch['y_rejected'].to(device)
        mask_chosen = batch['mask_chosen'].to(device)
        mask_rejected = batch['mask_rejected'].to(device)
        # Chosen samples first, rejected second: dpo_loss splits the batch in half.
        X = torch.cat([x_chosen, x_rejected], dim=0)
        Y = torch.cat([y_chosen, y_rejected], dim=0)
        loss_mask = torch.cat([mask_chosen, mask_rejected], dim=0)

        # Cosine learning-rate schedule over the whole run.
        lr = get_lr(epoch * iter_per_epoch + step, total_steps, TrainArgs.learning_rate)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        with TrainArgs.ctx:
            # Reference log-probs need no autograd graph.
            with torch.no_grad():
                ref_logits = ref_model(X).logits
                ref_probs = logits_to_probs(ref_logits, Y) * loss_mask
            logits = model(X).logits
            probs = logits_to_probs(logits, Y) * loss_mask
            loss = dpo_loss(ref_probs, probs, loss_mask, beta=0.1)
            # Divide so gradients summed over accumulation_steps micro-batches average out.
            loss = loss / TrainArgs.accumulation_steps

        # Backward runs outside the autocast region, as recommended for AMP.
        scaler.scale(loss).backward()

        if (step + 1) % TrainArgs.accumulation_steps == 0:
            scaler.unscale_(optimizer)  # unscale first so clipping sees the true gradient norm
            torch.nn.utils.clip_grad_norm_(model.parameters(), TrainArgs.grad_clip)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

        if step % TrainArgs.log_interval == 0:
            spend_time = time.time() - start_time
            current_lr = optimizer.param_groups[-1]['lr']
            # Undo the accumulation scaling so the progress bar and SwanLab report the same loss.
            loss_value = loss.item() * TrainArgs.accumulation_steps
            progress_bar.set_postfix({
                "loss": f"{loss_value:.3f}",
                "lr": f"{current_lr:.2e}"
            })
            if swanlab is not None:
                swanlab.log({"loss": loss_value,
                             "lr": current_lr,
                             # estimated minutes remaining in this epoch
                             "epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})

        # Save only on steps that also complete an optimizer update.
        if (step + 1) % TrainArgs.save_interval == 0 and (step + 1) % TrainArgs.accumulation_steps == 0:
            _save_dpo_checkpoint(model, optimizer, scaler, epoch, step)


# 实现循环训练

In [None]:
# --- Key step: resume from a checkpoint if one exists ---
start_epoch, start_step = 0, 0
if not os.path.exists(TrainArgs.checkpoint_path):
    print("=> Starting from scratch...")
else:
    print(f"=> Resuming from checkpoint: {TrainArgs.checkpoint_path}")
    checkpoint = torch.load(TrainArgs.checkpoint_path, map_location=TrainArgs.device)

    # Model weights; strip the 'module.' prefix left by DDP-wrapped saves.
    model_state_dict = checkpoint['model']
    saved_with_ddp = any(key.startswith('module.') for key in model_state_dict)
    if saved_with_ddp:
        model_state_dict = {key.replace('module.', ''): tensor for key, tensor in model_state_dict.items()}
    model.load_state_dict(model_state_dict)

    # Optimizer and gradient-scaler state.
    optimizer.load_state_dict(checkpoint['optimizer'])
    scaler.load_state_dict(checkpoint['scaler'])

    # Restore progress. The saved step is not used to skip already-seen batches, so the
    # interrupted epoch is restarted from its beginning.
    start_epoch = checkpoint['epoch']
    print(f"=> Resumed from epoch {start_epoch + 1}")

model.to(TrainArgs.device)

In [None]:
# Smoke-test run: train on the small `test_loader` subset.
# NOTE(review): this updates `model`/`optimizer` in place; running the full-training cell
# afterwards continues from the smoke-tested weights and restarts the LR schedule.
for epoch in range(start_epoch, TrainArgs.epochs):
    train_epoch(epoch, model, train_loader=test_loader, optimizer=optimizer,
                scaler=scaler, lm_config=lm_config)

In [None]:
# Full training run over the entire DPO dataset.
for epoch in range(start_epoch, TrainArgs.epochs):
    train_epoch(epoch, model, train_loader=train_loader, optimizer=optimizer,
                scaler=scaler, lm_config=lm_config)