In [2]:
import json
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizerFast, BertForTokenClassification
from transformers import AdamW
import numpy as np

# 定义标签映射
label_map = {
    'address': 1, 'book': 2, 'company': 3, 'game': 4, 'government': 5,
    'movie': 6, 'name': 7, 'organization': 8, 'position': 9, 'scene': 10
}
id2label = {v: k for k, v in label_map.items()}
id2label[0] = 'O'  # 添加O标签表示非实体

class NERDataset(Dataset):
    def __init__(self, file_path, tokenizer, max_len=256):
        self.data = []
        self.tokenizer = tokenizer
        self.max_len = max_len
        
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                example = json.loads(line)
                self.data.append(example)
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        example = self.data[idx]
        text = example['text']
        labels = example['label']
        
        # 初始化标签序列
        label_seq = ['O'] * len(text)
        
        # 填充标签序列
        for entity_type, entities in labels.items():
            for entity, spans in entities.items():
                for start, end in spans:
                    for i in range(start, end):
                        label_seq[i] = entity_type
        
        # 转换为ID
        label_ids = [label_map.get(label, 0) for label in label_seq]
        
        # tokenize
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_len,
            return_tensors='pt'
        )
        
        # 调整label以匹配tokenize后的长度
        token_labels = [-100] * self.max_len  # -100是忽略的标签
        word_ids = encoding.word_ids()
        
        for idx, word_id in enumerate(word_ids):
            if word_id is not None:
                token_labels[idx] = label_ids[word_id]
        
        return {
            'input_ids': encoding['input_ids'].squeeze(),
            'attention_mask': encoding['attention_mask'].squeeze(),
            'labels': torch.tensor(token_labels)
        }

# 初始化tokenizer和模型
tokenizer = BertTokenizerFast.from_pretrained('bert-base-chinese')
model = BertForTokenClassification.from_pretrained(
    'bert-base-chinese',
    num_labels=len(label_map) + 1  # +1 for 'O' label
)

# 准备数据
train_dataset = NERDataset('train.json', tokenizer)
test_dataset = NERDataset('test.json', tokenizer)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16)

# 训练配置
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
optimizer = AdamW(model.parameters(), lr=5e-5)
num_epochs = 5

# 训练循环
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    
    for batch in train_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )
        
        loss = outputs.loss
        total_loss += loss.item()
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    avg_loss = total_loss / len(train_loader)
    print(f'Epoch {epoch+1}, Average Loss: {avg_loss:.4f}')

# 保存模型
model.save_pretrained('./ner_model')
tokenizer.save_pretrained('./ner_model')

# 评估函数
def evaluate(model, test_loader, device):
    model.eval()
    predictions = []
    true_labels = []
    
    with torch.no_grad():
        for batch in test_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels']
            
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs.logits
            pred = torch.argmax(logits, dim=2)
            
            predictions.extend(pred.cpu().numpy())
            true_labels.extend(labels.numpy())
    
    return predictions, true_labels

# 进行预测
predictions, true_labels = evaluate(model, test_loader, device)

# 预测新文本的函数
def predict_text(text, model, tokenizer, device):
    model.eval()
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    
    with torch.no_grad():
        outputs = model(**inputs)
        predictions = torch.argmax(outputs.logits, dim=2)
    
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    label_ids = predictions[0].cpu().numpy()
    
    results = []
    current_entity = None
    current_text = ''
    
    for token, label_id in zip(tokens, label_ids):
        if label_id == 0 or token in ['[CLS]', '[SEP]', '[PAD]']:
            if current_entity:
                results.append((current_text.strip(), current_entity))
                current_entity = None
                current_text = ''
            continue
        
        if id2label[label_id] != current_entity:
            if current_entity:
                results.append((current_text.strip(), current_entity))
                current_text = ''
            current_entity = id2label[label_id]
        
        if token.startswith('##'):
            current_text += token[2:]
        else:
            current_text += ' ' + token if current_text else token
    
    return results

# 示例预测
test_text = "加勒比海盗3：世界尽头在北京环球影城上映"
results = predict_text(test_text, model, tokenizer, device)
print("\n预测结果:")
for text, entity_type in results:
    print(f"文本: {text}, 类型: {entity_type}")

Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-chinese and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch 1, Average Loss: 0.3561
Epoch 2, Average Loss: 0.2074
Epoch 3, Average Loss: 0.1501
Epoch 4, Average Loss: 0.1155
Epoch 5, Average Loss: 0.0941


KeyError: 'label'

In [3]:
import json
import torch
from transformers import BertTokenizerFast, BertForTokenClassification

def extract_scene(text, model, tokenizer, device):
    """
    提取文本中的景点信息，确保完整提取词语
    Args:
        text: 输入文本
        model: 加载的模型
        tokenizer: 分词器
        device: 设备类型（CPU/GPU）
    Returns:
        list: 提取出的景点列表
    """
    model.eval()
    # 对输入文本进行编码
    encoding = tokenizer(
        text,
        return_offsets_mapping=True,
        padding=True,
        truncation=True,
        return_tensors="pt"
    )
    
    # 获取input_ids和attention_mask
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)
    
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        predictions = torch.argmax(outputs.logits, dim=2)
    
    # 获取词片段、偏移映射和预测标签
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    offset_mapping = encoding['offset_mapping'][0].numpy()
    label_ids = predictions[0].cpu().numpy()
    
    scenes = []
    current_scene = ''
    scene_start = None
    
    for idx, (token, label_id, offset) in enumerate(zip(tokens, label_ids, offset_mapping)):
        # 跳过特殊token
        if token in ['[CLS]', '[SEP]', '[PAD]'] or offset[0] == offset[1]:
            continue
        
        # 判断是否是景点标签（scene对应的label_id为10）
        if label_id == 10:  # scene对应的label_id
            if scene_start is None:
                scene_start = offset[0]
            
            # 检查下一个token
            next_is_scene = False
            if idx + 1 < len(label_ids):
                next_is_scene = label_ids[idx + 1] == 10
            
            if not next_is_scene:
                # 提取完整的景点名称
                scene_text = text[scene_start:offset[1]]
                if scene_text.strip():
                    scenes.append(scene_text.strip())
                scene_start = None
        else:
            scene_start = None
    
    return scenes

# 加载保存的模型
def load_model(model_path):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    tokenizer = BertTokenizerFast.from_pretrained(model_path)
    model = BertForTokenClassification.from_pretrained(model_path)
    model.to(device)
    return model, tokenizer, device

def process_file(input_file, output_file, model, tokenizer, device):
    """
    处理整个文件中的文本并保存结果
    """
    results = []
    
    with open(input_file, 'r', encoding='utf-8') as f:
        for line in f:
            data = json.loads(line.strip())
            text = data['text']
            scenes = extract_scene(text, model, tokenizer, device)
            results.append({
                'text': text,
                'scenes': scenes
            })
    
    # 保存结果
    with open(output_file, 'w', encoding='utf-8') as f:
        for item in results:
            json.dump(item, f, ensure_ascii=False)
            f.write('\n')

def main():
    # 加载模型
    model_path = './ner_model'  # 替换为你的模型路径
    model, tokenizer, device = load_model(model_path)
    
    # 测试文本
    test_texts = [
        "我昨天去北京故宫博物院参观，然后去了长城",
        "去年夏天我在西湖边散步，看到了雷峰塔",
        "准备明天去颐和园游玩，听说那里风景不错",
        "第一次到杭州什么攻略都没有做，任由同学带着到几大标志性景点转悠，到灵隐寺前还不知道它的历史文化，进去第一眼是被岩石上雕刻的栩栩如生的佛像吸引，走不久路就看到摸得光滑的“吉祥物”，应该是人人都想讨一个吉利，进去就是一线天了，然而没有指引说明，我们进去都摸不着头脑，里面基本没什么光线，而这一线天又在哪能看到？所以最好还是租个电子讲解器，也好清楚些。到灵隐寺除了拜佛烧香，剩下的就是爬山了，石梯非常多，...",
    ]
    
    # 测试单个文本
    for text in test_texts:
        scenes = extract_scene(text, model, tokenizer, device)
        print(f"\n原文: {text}")
        print(f"提取的景点: {scenes}")
    
    # 处理文件
    process_file(
        input_file='test.json',
        output_file='scenes_results.json',
        model=model,
        tokenizer=tokenizer,
        device=device
    )

if __name__ == "__main__":
    main()


原文: 我昨天去北京故宫博物院参观，然后去了长城
提取的景点: ['北京故宫博物']

原文: 去年夏天我在西湖边散步，看到了雷峰塔
提取的景点: ['雷峰']

原文: 准备明天去颐和园游玩，听说那里风景不错
提取的景点: ['颐和']

原文: 第一次到杭州什么攻略都没有做，任由同学带着到几大标志性景点转悠，到灵隐寺前还不知道它的历史文化，进去第一眼是被岩石上雕刻的栩栩如生的佛像吸引，走不久路就看到摸得光滑的“吉祥物”，应该是人人都想讨一个吉利，进去就是一线天了，然而没有指引说明，我们进去都摸不着头脑，里面基本没什么光线，而这一线天又在哪能看到？所以最好还是租个电子讲解器，也好清楚些。到灵隐寺除了拜佛烧香，剩下的就是爬山了，石梯非常多，...
提取的景点: ['灵隐', '灵隐']


In [None]:
!source /etc/profile.d/clash.sh