In [1]:
from transformers import AutoModelForTokenClassification, AutoTokenizer,DataCollatorForTokenClassification
from transformers import TrainingArguments, Trainer
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader

# entity_index
entites = ['O'] + list({'movie', 'name', 'game', 'address', 'position', \
           'company', 'scene', 'book', 'organization', 'government'})
tags = ['O']
for entity in entites[1:]:
    tags.append('B-' + entity.upper())
    tags.append('I-' + entity.upper())

entity_index = {entity:i for i, entity in enumerate(entites)}
id2lbl = {i:tag for i, tag in enumerate(tags)}
lbl2id = {tag:i for i, tag in enumerate(tags)}

MODEL_NAME = 'google-bert/bert-base-chinese'

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

def init_model():
    return AutoModelForTokenClassification.from_pretrained(MODEL_NAME, num_labels=len(tags),
                                                            id2label=id2lbl,
                                                            label2id=lbl2id).to(DEVICE)

def init_dl():
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    ds = load_dataset('nlhappy/CLUE-NER')
    
    def mapper(item):
        tags = [0] * 512    # 初始值为‘O’
        # 遍历实体列表，所有实体类别标记填入tags
        entites = item['ents']
        for ent in entites:
            indices = ent['indices']  # 实体索引
            label = ent['label']   # 实体名
            tags[indices[0]] = entity_index[label] * 2 - 1
            for idx in indices[1:]:
                tags[idx] = entity_index[label] * 2
    
        input_data = tokenizer(list(item['text']), truncation=True, add_special_tokens=False, max_length=512, 
                               is_split_into_words=True, padding='max_length')
        input_data['labels'] = tags
        return input_data
    
    # 使用自定义回调函数处理数据集记录
    ds1 = ds.map(mapper)
    ds1.set_format('torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels'])
    # for item in ds1['train']:
    #     print(item)
    #     break
    train_dl = DataLoader(ds1['train'], shuffle=True, batch_size=16)
    return train_dl

2025-06-13 13:46:47.729824: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1749822408.183925      19 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1749822408.296024      19 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
from transformers import get_linear_schedule_with_warmup
import torch.optim as optim
from tqdm import tqdm

# 模型参数分组
def init_scheduler(model, train_steps):
    bert_params, classifier_params = [],[]
    for name, params in model.named_parameters():
        if 'bert' in name:
            bert_params.append(params)
        else:
            classifier_params.append(params)
    
    param_groups = [
        {'params': bert_params, 'lr': 1e-5},
        {'params': classifier_params, 'lr': 1e-3, 'weight_decay': 0.1}
    ]
    optimizer = optim.AdamW(param_groups)
    return get_linear_schedule_with_warmup(optimizer, 
                                                num_warmup_steps=100, 
                                                num_training_steps=train_steps), optimizer

In [3]:
# model = init_model()
# device_count = torch.cuda.device_count()
# print('device_count', device_count)
# model = torch.nn.DataParallel(model)
# train_dl = init_dl()
# train_steps = len(train_dl) * 5
# scheduler, optimizer = init_scheduler(model, train_steps)

# for epoch in range(5):
#     model.train()
#     tpbar = tqdm(train_dl)
#     for items in tpbar:
#         items = {k:v.to(DEVICE) for k,v in items.items()}
#         optimizer.zero_grad()
#         outputs = model(**items)
#         loss = outputs.loss.mean()
#         loss.backward()
#         optimizer.step()
#         scheduler.step()
    
#         tpbar.set_description(f'Epoch:{epoch+1} ' + 
#                           f'bert_lr:{scheduler.get_lr()[0]} ' + 
#                           f'classifier_lr:{scheduler.get_lr()[1]} '+
#                           f'Loss:{loss.item():.4f}')


In [4]:
model = init_model()
device_count = torch.cuda.device_count()
print('device_count', device_count)
model = torch.nn.DataParallel(model)
train_dl = init_dl()
train_steps = len(train_dl) * 5
scheduler, optimizer = init_scheduler(model, train_steps)

# 混合精度训练
# 梯度计算缩放器
scaler = torch.GradScaler()

for epoch in range(5):
    model.train()
    tpbar = tqdm(train_dl)
    for items in tpbar:
        items = {k:v.to(DEVICE) for k,v in items.items()}
        optimizer.zero_grad()

        with torch.autocast(device_type=DEVICE):
            outputs = model(**items)
        loss = outputs.loss.mean()

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        # loss.backward()
        # optimizer.step()
        scheduler.step()
    
        tpbar.set_description(f'Epoch:{epoch+1} ' + 
                          f'bert_lr:{scheduler.get_lr()[0]} ' + 
                          f'classifier_lr:{scheduler.get_lr()[1]} '+
                          f'Loss:{loss.item():.4f}')

config.json:   0%|          | 0.00/624 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/412M [00:00<?, ?B/s]

Some weights of BertForTokenClassification were not initialized from the model checkpoint at google-bert/bert-base-chinese and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


device_count 2


tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/110k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/269k [00:00<?, ?B/s]

README.md:   0%|          | 0.00/21.0 [00:00<?, ?B/s]

dataset_infos.json:   0%|          | 0.00/970 [00:00<?, ?B/s]

(…)-00000-of-00001-a33d0e4276aef9b4.parquet:   0%|          | 0.00/1.30M [00:00<?, ?B/s]

(…)-00000-of-00001-07f476b71c5edde6.parquet:   0%|          | 0.00/178k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/10748 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1343 [00:00<?, ? examples/s]

Map:   0%|          | 0/10748 [00:00<?, ? examples/s]

Map:   0%|          | 0/1343 [00:00<?, ? examples/s]

Epoch:1 bert_lr:8.245398773006135e-06 classifier_lr:0.0008245398773006136 Loss:0.0364: 100%|██████████| 672/672 [03:32<00:00,  3.17it/s]
Epoch:2 bert_lr:6.184049079754602e-06 classifier_lr:0.0006184049079754601 Loss:0.0101: 100%|██████████| 672/672 [03:32<00:00,  3.16it/s]
Epoch:3 bert_lr:4.122699386503068e-06 classifier_lr:0.0004122699386503068 Loss:0.0193: 100%|██████████| 672/672 [03:32<00:00,  3.17it/s]
Epoch:4 bert_lr:2.061349693251534e-06 classifier_lr:0.0002061349693251534 Loss:0.0126: 100%|██████████| 672/672 [03:32<00:00,  3.16it/s]
Epoch:5 bert_lr:0.0 classifier_lr:0.0 Loss:0.0189: 100%|██████████| 672/672 [03:32<00:00,  3.17it/s]


In [5]:
torch.save(model.state_dict(), 'bert_text_scheduler.pt')