2、实现分布式DDP模型训练，存盘后加载实现推理。

In [13]:
%%writefile ddp_ner.py

import os
import numpy as np
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForTokenClassification, DataCollatorForTokenClassification
from transformers import TrainingArguments, Trainer
import evaluate  # pip install evaluate
import seqeval   # pip install seqeval
import torch
import torch.multiprocessing as mp
import torch.distributed as dist

# Filter predicate for the dataset: drop records whose token list is empty.
def data_filter(item):
    """Return True when ``item['tokens']`` contains at least one element."""
    token_count = len(item['tokens'])
    return token_count > 0

def data_input_proc_fn(tokenizer, max_length=512):
    """Build the batched preprocessing function passed to ``Dataset.map``.

    Args:
        tokenizer: Callable tokenizer; it is invoked on the batch's
            pre-split token lists.
        max_length: Length to which the tokenizer truncates/pads the inputs
            and to which every label sequence is truncated. Defaults to 512,
            the value that was previously hard-coded in both places.

    Returns:
        A function that takes a batch with ``tokens`` and ``ner_tags``
        columns and returns the tokenizer output with a ``labels`` entry
        added.
    """
    def data_input_proc(item):
        # Tokenize the already-split tokens directly rather than joining them
        # into a sentence first: re-tokenizing a joined sentence can make
        # input_ids and ner_tags end up with different lengths.
        # is_split_into_words=True tells the tokenizer the input is pre-split.
        input_data = tokenizer(item['tokens'],
                               truncation=True,
                               add_special_tokens=False,
                               max_length=max_length,
                               is_split_into_words=True,
                               padding='max_length')
        # Truncate the tag sequences with the same limit as the inputs.
        # They are stored under 'labels' because that is the key
        # DataCollatorForTokenClassification looks for.
        input_data['labels'] = [tag_seq[:max_length]
                                for tag_seq in item['ner_tags']]
        return input_data
    return data_input_proc

def setup(rank, world_size):
    """Initialise the NCCL process group for this worker.

    Sets the rendezvous address/port environment variables and then joins
    the process group as ``rank`` out of ``world_size`` workers.
    """
    rendezvous = {'MASTER_ADDR': 'localhost', 'MASTER_PORT': '12355'}
    os.environ.update(rendezvous)
    dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)

def cleanup():
    """Tear down the default process group if one is currently initialised.

    The guard makes teardown idempotent, so it can be called from error
    paths (for example a ``finally`` block) where the group may never have
    been created or may already have been destroyed.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()

# Model training
def train(rank, world_size):
    """Fine-tune a BERT token-classification model inside one worker process.

    Joins the process group, loads and preprocesses the MSRA NER dataset,
    builds the model and a ``Trainer``, runs training, and always tears the
    process group down afterwards.

    Args:
        rank: Index of this worker; passed to ``setup`` and to
            ``TrainingArguments(local_rank=...)``.
        world_size: Total number of worker processes.
    """
    setup(rank, world_size)
    try:
        model_name = 'google-bert/bert-base-chinese'
        num_epochs = 1

        # Data preprocessing: drop empty samples, then tokenize in batches.
        ds = load_dataset("doushabao4766/msra_ner_k_V3")
        ds['train'] = ds['train'].filter(data_filter)
        ds['test'] = ds['test'].filter(data_filter)
        tags = ds['train'].features['ner_tags'].feature.names
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        ds = ds.map(data_input_proc_fn(tokenizer), batched=True)
        ds.set_format(type="torch", columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels'])

        # Build the model; the label count is derived from the dataset's tag
        # names so it cannot drift from id2label/label2id.
        id2label = {i: tag for i, tag in enumerate(tags)}
        label2id = {tag: i for i, tag in enumerate(tags)}
        model = AutoModelForTokenClassification.from_pretrained(
            model_name,
            num_labels=len(tags),
            id2label=id2label,
            label2id=label2id,
        )

        # label_pad_token_id defaults to -100
        data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer, padding=True)

        # NOTE(review): recent transformers/accelerate releases appear to
        # detect distributed runs from the LOCAL_RANK/RANK/WORLD_SIZE
        # environment variables rather than from `local_rank` — confirm that
        # Trainer really runs in DDP mode when launched through mp.spawn.
        train_args = TrainingArguments(
            output_dir='ner_train',
            num_train_epochs=num_epochs,
            save_safetensors=True,
            per_device_train_batch_size=16,
            per_device_eval_batch_size=16,
            report_to='tensorboard',
            eval_strategy='epoch',
            learning_rate=1e-4,
            local_rank=rank,  # rank of the current process
            fp16=True,  # mixed-precision training
            lr_scheduler_type='linear',  # linearly scheduled learning rate
            warmup_steps=100,  # number of warm-up steps
            ddp_find_unused_parameters=False  # skip the unused-parameter search in DDP
        )

        # Load the metric once instead of on every evaluation pass; the
        # distinct name also avoids shadowing the imported `seqeval` module.
        seqeval_metric = evaluate.load('seqeval')

        def compute_metrics(result):
            """Compute seqeval metrics, skipping positions labelled -100."""
            predicts, labels = result
            # predicts.shape = (num_samples, padded_sequence_length, num_labels)
            # labels.shape = (num_samples, padded_sequence_length)
            predicts = np.argmax(predicts, axis=2)
            # Map ids to tag names, dropping padded (-100) label positions.
            predicts = [[tags[p] for p, l in zip(ps, ls) if l != -100]
                        for ps, ls in zip(predicts, labels)]
            labels = [[tags[l] for l in ls if l != -100]
                      for ls in labels]
            return seqeval_metric.compute(predictions=predicts, references=labels)

        trainer = Trainer(
            model,
            train_args,
            train_dataset=ds['train'],
            eval_dataset=ds['test'],
            data_collator=data_collator,
            compute_metrics=compute_metrics
        )
        trainer.train()
    finally:
        # Release the process group even when preprocessing or training fails.
        cleanup()

def main():
    """Spawn one training process per visible CUDA device.

    Raises:
        RuntimeError: If no CUDA device is visible. Spawning zero processes
            would otherwise return immediately without training anything.
    """
    world_size = torch.cuda.device_count()
    if world_size < 1:
        raise RuntimeError(
            'No CUDA device available: DDP training with the NCCL backend '
            'requires at least one GPU.'
        )
    mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)


if __name__ == '__main__':
    main()

Overwriting ddp_ner.py


In [14]:
!python ddp_ner.py

2025-06-14 15:10:37.644228: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1749913837.668023    6759 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1749913837.675568    6759 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-06-14 15:10:46.982746: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1749913847.005157    6773 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-06-14 15:10:47.010584: E external/local_xla/xla

In [15]:
from transformers import pipeline
# Run named-entity recognition with the checkpoint saved by the training run.
# NOTE(review): the checkpoint path is hard-coded to step 1407 under
# /kaggle/working — adjust it if the training configuration changes.
ner = pipeline('token-classification', '/kaggle/working/ner_train/checkpoint-1407')
# Sample sentence to tag.
seq = '双方确定了今后发展中美关系的指导方针。'
# NOTE(review): training tokenized with add_special_tokens=False, whereas the
# printed result shows token index = char start + 1, which suggests the
# pipeline prepends a special token — confirm this mismatch is acceptable.
ner_result = ner(seq)
print(ner_result)

2025-06-14 16:39:12.934549: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1749919152.957840      35 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1749919152.964877      35 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
Device set to use cuda:0


[{'entity': 'B-LOC', 'score': 0.9991559, 'index': 10, 'word': '中', 'start': 9, 'end': 10}, {'entity': 'B-LOC', 'score': 0.9991559, 'index': 11, 'word': '美', 'start': 10, 'end': 11}]
