In [1]:
import torch.nn as nn
import torch
from contextlib import nullcontext
import time
from tqdm.notebook import tqdm
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# 环境与设备

In [2]:
import torch

print(f"CUDA is available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"Number of GPUs: {torch.cuda.device_count()}")
    for i in range(torch.cuda.device_count()):
        print(f"  GPU {i}: {torch.cuda.get_device_name(i)}")
        print(f"    Memory: {torch.cuda.get_device_properties(i).total_memory / (1024**3):.2f} GB")

CUDA is available: True
Number of GPUs: 1
  GPU 0: NVIDIA GeForce RTX 4090
    Memory: 23.53 GB


# 参数配置

In [3]:
from model.model_minimind import MiniMindConfig
class TrainArgs():
    #文件管理
    out_dir = './sft_output'
    checkpoint_path = "./sft_output/latest_checkpoint.pth"
    data_path = 'sft_1024.jsonl'
    #神经网络训练管理
    epochs = 2
    batch_size = 128
    accumulation_steps = 4
    learning_rate = 5e-6
    warm_up = 0
    grad_clip = 1
    dtype = 'bfloat16'
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    num_workers = min(8, os.cpu_count())
    log_interval = 60
    save_interval = 2000
    ctx = nullcontext() if device == "cpu" else torch.cuda.amp.autocast()
class LLMargs():
    use_moe = False
    hidden_size = 512
    num_hidden_layers = 8
    

lm_config = MiniMindConfig(use_moe=LLMargs.use_moe,hidden_size=LLMargs.hidden_size,num_hidden_layers=LLMargs.num_hidden_layers)


  ctx = nullcontext() if device == "cpu" else torch.cuda.amp.autocast()


In [6]:

os.makedirs(TrainArgs.out_dir, exist_ok=True)

# 加载模型与数据

加载模型

In [4]:
from transformers import AutoTokenizer
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
tokenizer = AutoTokenizer.from_pretrained('model')
model = MiniMindForCausalLM(lm_config).to(TrainArgs.device)



In [5]:


ckp = 'pretrain_output/latest_checkpoint.pth'
# 1. 加载完整的 checkpoint 字典，并指明 weights_only=False
print(f"Loading checkpoint from {ckp}...")
checkpoint = torch.load(ckp, map_location=TrainArgs.device, weights_only=False)

# 2. 从字典中提取模型的 state_dict 来加载
model.load_state_dict(checkpoint['model'], strict=True)

print("Checkpoint loaded successfully.")
model.to(TrainArgs.device)

Loading checkpoint from pretrain_output/latest_checkpoint.pth...
Checkpoint loaded successfully.


MiniMindForCausalLM(
  (model): MiniMindModel(
    (embed_tokens): Embedding(6400, 512)
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): ModuleList(
      (0-7): 8 x MiniMindBlock(
        (self_attn): Attention(
          (q_proj): Linear(in_features=512, out_features=512, bias=False)
          (k_proj): Linear(in_features=512, out_features=128, bias=False)
          (v_proj): Linear(in_features=512, out_features=128, bias=False)
          (o_proj): Linear(in_features=512, out_features=512, bias=False)
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
        )
        (input_layernorm): RMSNorm()
        (post_attention_layernorm): RMSNorm()
        (mlp): FeedForward(
          (gate_proj): Linear(in_features=512, out_features=1408, bias=False)
          (down_proj): Linear(in_features=1408, out_features=512, bias=False)
          (up_proj): Linear(in_features=512, out_features=1408, bias=False)
          (drop

In [9]:
print(model)

total_params = sum(p.numel() for p in model.parameters())

# 打印结果，使用 f-string 的格式化功能让数字更易读（例如，加上千位分隔符）
print(f"模型总参数量 (Total Parameters): {total_params:,}")

# 如果你还想区分“可训练参数”，可以这样做：
# 在大多数情况下，总参数量和可训练参数量是一样的。
# 除非你手动设置了某些参数的 requires_grad=False (冻结了某些层)
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"可训练参数量 (Trainable Parameters): {trainable_params:,}")

MiniMindForCausalLM(
  (model): MiniMindModel(
    (embed_tokens): Embedding(6400, 512)
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): ModuleList(
      (0-7): 8 x MiniMindBlock(
        (self_attn): Attention(
          (q_proj): Linear(in_features=512, out_features=512, bias=False)
          (k_proj): Linear(in_features=512, out_features=128, bias=False)
          (v_proj): Linear(in_features=512, out_features=128, bias=False)
          (o_proj): Linear(in_features=512, out_features=512, bias=False)
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
        )
        (input_layernorm): RMSNorm()
        (post_attention_layernorm): RMSNorm()
        (mlp): FeedForward(
          (gate_proj): Linear(in_features=512, out_features=1408, bias=False)
          (down_proj): Linear(in_features=1408, out_features=512, bias=False)
          (up_proj): Linear(in_features=512, out_features=1408, bias=False)
          (drop

加载数据

In [6]:
from torch.utils.data import Dataset, DataLoader, Subset
import json
class SFTDataset(Dataset):
    def __init__(self,  tokenizer,data_path, max_length=1024):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = self.load_data(data_path)
        self.bos_id = tokenizer('<|im_start|>assistant', add_special_tokens=False).input_ids
        self.eos_id = tokenizer('<|im_end|>', add_special_tokens=False).input_ids

    def __len__(self):
        return len(self.samples)

    def load_data(self, data_path):
        samples = []
        with open(data_path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                data = json.loads(line.strip())
                samples.append(data)
        return samples

    def _create_chat_prompt(self, conversations):
        """构建符合ChatML格式的对话"""
        messages = []
        for i, turn in enumerate(conversations):
            role = 'user' if i % 2 == 0 else 'assistant'
            messages.append({"role": role, "content": turn['content']})
        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False
        )

    def _generate_loss_mask(self, input_ids):
        loss_mask = [0] * len(input_ids)
        i = 0
        while i < len(input_ids):
            if input_ids[i:i + len(self.bos_id)] == self.bos_id:
                start = i + len(self.bos_id)
                end = start
                while end < len(input_ids):
                    if input_ids[end:end + len(self.eos_id)] == self.eos_id:
                        break
                    end += 1
                for j in range(start + 1, min(end + len(self.eos_id) + 1, self.max_length)):
                    loss_mask[j] = 1
                i = end + len(self.eos_id) if end < len(input_ids) else len(input_ids)
            else:
                i += 1
        return loss_mask

    def __getitem__(self, index):
        sample = self.samples[index]
        
        # 构建对话提示
        prompt = self._create_chat_prompt(sample['conversations'])
        input_ids = self.tokenizer(prompt).input_ids[:self.max_length]
        input_ids += [self.tokenizer.pad_token_id] * (self.max_length - len(input_ids))

        # 生成动态损失掩码
        loss_mask = self._generate_loss_mask(input_ids)

        # 构建训练数据
        X = torch.tensor(input_ids[:-1], dtype=torch.long)
        Y = torch.tensor(input_ids[1:], dtype=torch.long)
        loss_mask = torch.tensor(loss_mask[1:], dtype=torch.long)  # 对齐预测位置
 # 对齐预测位置

        return X, Y, loss_mask

In [7]:
train_ds = SFTDataset(tokenizer,TrainArgs.data_path,max_length=512)

In [8]:
print(f'{train_ds[1]}\n数据集总长度为{len(train_ds)}')

(tensor([   1,   85,  736,  201,   59,  292,  389,  260, 3836, 1861,  501,    2,
         201,    1,  320,  275,  201, 1936,  343,  384, 4451,  260, 1482,  273,
         327, 3717,  295,  614, 2143, 1833,  730,  281, 5171,   14, 6194,   33,
           2,  201,    1, 1078,  538,  501,  201, 5042, 4451,  260, 1482,  273,
         327, 3717, 1833,  730,  281, 5171,  958,  261, 4109, 4941,   11,  295,
         614, 2143,   14,  363,  343,  813,  276, 3745, 3830,  327,   66,  461,
        6139,   14,  896, 3198, 2514, 2356,  325,  524,  290, 4451, 1482,  273,
         327, 1921, 2781,   16, 1193, 1860,   85,  822,  363,  343, 1004,  470,
          28,  201,  201, 1638, 4751,  201,  334, 1519, 1482,  273,  327,  201,
         201,    5,  886, 3877,  429,  260, 1482,  273,  327,  696, 1841,  275,
        1833,  730,  281, 5171,   14, 6194,  201, 3830,  327,   65, 3267, 1485,
         901, 1482,  273,  327,   16, 3830, 1131,   10,   19,   14, 5171,   11,
         201,  201, 2809,   10, 3830,  

In [13]:
test_ds = Subset(train_ds,range(1000))

加载dataloader

In [14]:
test_loader = DataLoader(
    test_ds,
    batch_size = TrainArgs.batch_size,
    shuffle = True,
    num_workers = TrainArgs.num_workers,
    pin_memory= True
)

In [9]:
train_loader = DataLoader(
    train_ds,
    batch_size = TrainArgs.batch_size,
    shuffle = True,
    num_workers = TrainArgs.num_workers,
    pin_memory= True
)

检查一下

In [16]:
tag = False
for batch_idx, (X_batch, Y_batch, loss_mask_batch) in enumerate(train_loader):
    if not tag:
        print(f"Batch {batch_idx}:")
        print(f"  X_batch shape: {X_batch.shape}")
        print(f"  Y_batch shape: {Y_batch.shape}")
        print(f"  loss_mask_batch shape: {loss_mask_batch.shape}")
        tag = True
    break

Batch 0:
  X_batch shape: torch.Size([128, 511])
  Y_batch shape: torch.Size([128, 511])
  loss_mask_batch shape: torch.Size([128, 511])


In [17]:
tag = False
for batch_idx, (X_batch, Y_batch, loss_mask_batch) in enumerate(test_loader):
    if not tag:
        print(f"Batch {batch_idx}:")
        print(f"  X_batch shape: {X_batch.shape}")
        print(f"  Y_batch shape: {Y_batch.shape}")
        print(f"  loss_mask_batch shape: {loss_mask_batch.shape}")
        tag = True
    break

Batch 0:
  X_batch shape: torch.Size([128, 511])
  Y_batch shape: torch.Size([128, 511])
  loss_mask_batch shape: torch.Size([128, 511])


# 训练

In [10]:
import math
from torch import optim
def get_lr(current_step, total_steps, lr):
    return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))
optimizer = optim.AdamW(model.parameters(), lr=TrainArgs.learning_rate)
#据模型在训练数据上的表现来调整模型的参数  训练的一个大脑
scaler = torch.cuda.amp.GradScaler(enabled=(TrainArgs.dtype in ['float16', 'bfloat16']))
#自动混合精度训练工具 训练的一个助手

  scaler = torch.cuda.amp.GradScaler(enabled=(TrainArgs.dtype in ['float16', 'bfloat16']))


In [11]:
import swanlab
# 初始化swanlab，传入项目名、实验名等
swanlab.init(
    project="MiniMind-SFT", 
    experiment_name="MiniMind-SFT", 
    config=vars(TrainArgs()) # 将你的所有超参数配置一次性传给swanlab
)

Output()

Output()

<swanlab.data.run.main.SwanLabRun at 0x7f8dceb30160>

定义训练epoch

In [12]:
def train_epoch(epoch, model, train_loader, optimizer, scaler, lm_config):
    #选择初始化损失函数！
    model.train() # 确保模型处于训练模式

    loss_fct = nn.CrossEntropyLoss(reduction='none')

    start_time = time.time()#一会记录日志要用

    iter_per_epoch = len(train_loader)#看一下一次数据有多长，计算学习率要用

    # --- 新增：使用tqdm包装DataLoader以显示进度条 ---
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{TrainArgs.epochs}", leave=True)

    for step, (X, Y, loss_mask) in enumerate(progress_bar):#不断提取数据
        X = X.to(TrainArgs.device)
        Y = Y.to(TrainArgs.device)
        loss_mask = loss_mask.to(TrainArgs.device)#数据上设备，上到gpu上

        
        lr = get_lr(epoch * iter_per_epoch + step, TrainArgs.epochs * iter_per_epoch, TrainArgs.learning_rate)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr                #计算当前学习率的算法

        with TrainArgs.ctx:
            res = model(X)#得到的是logits，aux_loss

            loss = loss_fct(
                res.logits.view(-1, res.logits.size(-1)),
                Y.view(-1)
            ).view(Y.size())  #计算loss的算法
            loss = (loss * loss_mask).sum() / loss_mask.sum()
            loss += res.aux_loss
            loss = loss / TrainArgs.accumulation_steps  

            scaler.scale(loss).backward()  #反向传播

        if (step + 1) % TrainArgs.accumulation_steps == 0:#判断梯度累加到位没有
            scaler.unscale_(optimizer)#*前置操作
            torch.nn.utils.clip_grad_norm_(model.parameters(), TrainArgs.grad_clip)
            #梯度裁剪
            scaler.step(optimizer)#参数更新
            scaler.update()#*调整

            optimizer.zero_grad(set_to_none=True)#本次训练结束，清零梯度，为下一次反向传播准备
             
            
        if step % TrainArgs.log_interval == 0:
            spend_time = time.time() - start_time
            current_lr = optimizer.param_groups[0]['lr']
            progress_bar.set_postfix({
                "loss": f"{loss.item():.3f}",
                "lr": f"{current_lr:.2e}"
            })

            if (swanlab is not None) : #用日志软件记录训练过程
                swanlab.log({"loss": loss.item() ,
                           "lr": optimizer.param_groups[0]['lr'],
                           "epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})

            # 保存checkpoint
        if (step + 1) % TrainArgs.save_interval == 0:
            # 只在更新步之后保存，确保梯度累积完成
            if (step + 1) % TrainArgs.accumulation_steps == 0:
                model.eval()
                print(f"\nSaving checkpoint at step {step+1}...")
                
                # 创建checkpoint字典
                checkpoint = {
                    
                    'model': model.state_dict(),  # 保留原始精度
                    'optimizer': optimizer.state_dict(),
                    'scaler': scaler.state_dict(),
                    'epoch': epoch,
                    'step': step,
                    'config': TrainArgs,
                }
                
                # 保存
                torch.save(checkpoint, TrainArgs.checkpoint_path)
                print("Checkpoint saved.")
                model.train() # 切换回训练模式


# 实现循环训练

In [13]:
# --- 关键：Checkpoint加载逻辑 ---
start_epoch = 0
start_step = 0
if os.path.exists(TrainArgs.checkpoint_path):
    print(f"=> Resuming from checkpoint: {TrainArgs.checkpoint_path}")
    checkpoint = torch.load(TrainArgs.checkpoint_path, map_location=TrainArgs.device)
    
    # 加载模型权重 (处理DDP保存的权重)
    model_state_dict = checkpoint['model']
    if any(key.startswith('module.') for key in model_state_dict):
            model_state_dict = {k.replace('module.', ''): v for k, v in model_state_dict.items()}
    model.load_state_dict(model_state_dict)
    
    # 加载优化器和scaler状态
    optimizer.load_state_dict(checkpoint['optimizer'])
    scaler.load_state_dict(checkpoint['scaler'])
    
    # 恢复训练进度
    start_epoch = checkpoint['epoch']
    # 注意：恢复step逻辑可以让dataloader跳过已训练数据，这里简化为从头开始当前epoch
    print(f"=> Resumed from epoch {start_epoch + 1}")
else:
    print("=> Starting from scratch...")

model.to(TrainArgs.device)

=> Starting from scratch...


MiniMindForCausalLM(
  (model): MiniMindModel(
    (embed_tokens): Embedding(6400, 512)
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): ModuleList(
      (0-7): 8 x MiniMindBlock(
        (self_attn): Attention(
          (q_proj): Linear(in_features=512, out_features=512, bias=False)
          (k_proj): Linear(in_features=512, out_features=128, bias=False)
          (v_proj): Linear(in_features=512, out_features=128, bias=False)
          (o_proj): Linear(in_features=512, out_features=512, bias=False)
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
        )
        (input_layernorm): RMSNorm()
        (post_attention_layernorm): RMSNorm()
        (mlp): FeedForward(
          (gate_proj): Linear(in_features=512, out_features=1408, bias=False)
          (down_proj): Linear(in_features=1408, out_features=512, bias=False)
          (up_proj): Linear(in_features=512, out_features=1408, bias=False)
          (drop

In [22]:
#测试训练运行以下代码
for epoch in range(start_epoch, TrainArgs.epochs):
        train_epoch(epoch, model, test_loader, optimizer, scaler, lm_config)

Epoch 1/2:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 2/2:   0%|          | 0/8 [00:00<?, ?it/s]

In [14]:
#正式训练运行以下代码
for epoch in range(start_epoch, TrainArgs.epochs):
        train_epoch(epoch, model, train_loader, optimizer, scaler, lm_config)

Epoch 1/2:   0%|          | 0/32785 [00:00<?, ?it/s]


Saving checkpoint at step 2000...
Checkpoint saved.

Saving checkpoint at step 4000...
Checkpoint saved.

Saving checkpoint at step 6000...
Checkpoint saved.

Saving checkpoint at step 8000...
Checkpoint saved.

Saving checkpoint at step 10000...
Checkpoint saved.

Saving checkpoint at step 12000...
Checkpoint saved.


KeyboardInterrupt: 

In [19]:
import torch
import gc

del train_ds, train_loader  # 解除引用

# 2. 清空PyTorch内部缓存
torch.cuda.empty_cache()  # 释放GPU显存

# 3. 触发Python垃圾回收（确保循环引用被清理）
gc.collect()  # 强制回收内存

NameError: name 'train_ds' is not defined

In [21]:
print(f"GPU显存占用: {torch.cuda.memory_allocated() / 1024**2:.2f} MB") 

# 检查对象是否已被销毁
print('train_ds' in locals()) 

GPU显存占用: 21000.05 MB
False
