In [22]:
import os
import json

import torch

from utils_tools import load_vocabulary, extract_kvpairs_in_bio
from lstm_models import LSTMModel, LSTMCRFModel

In [5]:
# os.path.abspath("") resolves to the current working directory (not a
# "script directory"); its parent is taken as the project root, so this
# notebook is expected to be run from a direct sub-folder of the project.
root_path = os.path.dirname(os.path.abspath(""))
print(root_path)
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load the vocabularies. Paths are built from components instead of a
# "./data/..." fragment so the result has no stray "./" or mixed separators.
data_dir = os.path.join(root_path, "data", "resume-zh")
vocab_char_path = os.path.join(data_dir, "vocab_char.txt")
vocab_bioattr_path = os.path.join(data_dir, "vocab_bioattr.txt")
# Each call yields a pair, unpacked here as (token -> index, index -> token)
# going by the variable names -- confirm against utils_tools.load_vocabulary.
w2i_char, i2w_char = load_vocabulary(vocab_char_path)
w2i_bio, i2w_bio = load_vocabulary(vocab_bioattr_path)

d:\code\github\ner
load vocab from: d:\code\github\ner\./data/resume-zh/vocab_char.txt, containing words: 4295
load vocab from: d:\code\github\ner\./data/resume-zh/vocab_bioattr.txt, containing words: 17


In [7]:
lstm_model_path = os.path.join(root_path, "./ckpt/lstm_model.pt")
crf_model_path = os.path.join(root_path, "./ckpt/lstm_crf_model.pt")

# Load the model. Restoring from a state_dict means the class has to be
# re-instantiated with the same constructor arguments that were used when the
# checkpoint was written, so those values must be kept alongside the weights.
lstm_model = LSTMModel(
    num_embeddings=len(w2i_char),
    output_size=len(w2i_bio),
    embedding_dim=300,
    hidden_size=300,
)
# map_location="cpu" lets a checkpoint saved from a CUDA run load on a
# CPU-only machine; without it torch.load tries to restore tensors onto the
# device they were saved from. The model itself stays on the CPU here.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
lstm_state_dict = torch.load(lstm_model_path, map_location="cpu")
lstm_model.load_state_dict(lstm_state_dict)
# eval() switches off training-only behaviour and returns the module, so the
# cell still displays the model summary.
lstm_model.eval()

LSTMModel(
  (embedding): Embedding(4295, 300)
  (lstm): LSTM(300, 300, batch_first=True, bidirectional=True)
  (linear): Linear(in_features=600, out_features=17, bias=True)
)

In [9]:
# Read the test split (a JSON document) into memory.
test_file = os.path.join(root_path, "./data/resume-zh/test.json")
with open(test_file, encoding="utf-8") as fh:
    test_data = json.loads(fh.read())

# Show the first record to illustrate the record layout.
test_data[0]

{'sentence': ['常', '建', '良', '，', '男', '，'],
 'ner': [{'index': [0, 1, 2], 'type': 'NAME'}],
 'word': [[0], [1, 2], [3], [4], [5]]}

In [24]:
# Pick one test record and show its text together with the gold entities.
index = 10
data = test_data[index]
sentence, ner = data["sentence"], data["ner"]
print("".join(sentence))

# Each entity lists the character positions it covers; rebuild its surface
# string from those positions.
for entity in ner:
    entity_text = "".join(map(sentence.__getitem__, entity["index"]))
    print(entity["type"], entity_text)

现任大股东无锡产业发展集团有限公司董事局董事、无锡威孚高科技集团股份有限公司党委书记。
ORG 大股东无锡产业发展集团有限公司
TITLE 董事局董事
ORG 无锡威孚高科技集团股份有限公司
TITLE 党委书记


In [25]:
# Encode the sentence as character ids. Characters missing from the
# vocabulary fall back to id 0 -- presumably the padding/unknown entry;
# TODO confirm against vocab_char.txt.
char_ids = [w2i_char.get(ch, 0) for ch in sentence]
# Add a leading batch dimension of size 1.
x = torch.tensor(char_ids, dtype=torch.long).unsqueeze(0)

# Inference only: no autograd graph is needed.
with torch.no_grad():
    y_ids = lstm_model.predict(x).squeeze(0).cpu().numpy()
print(y_ids.shape)
print(y_ids)

# Map predicted label ids to BIO tag strings. `y` ends up bound to the tag
# list, and `x` to the input tensor, exactly as before; intermediate values
# now have their own names instead of rebinding `x` / `y`.
y = [i2w_bio[int(tag_id)] for tag_id in y_ids]
print(y)

# Extract K-V pairs from the BIO sequence.
kvpairs = extract_kvpairs_in_bio(y, list(sentence))
print(kvpairs)

(43,)
[ 0  0  9 10 10 10 10 10 10 10 10 10 10 10 10 10 10 15 16 16 16 16  0  9
 10 10 10 10 10 10 10 10 10 10 10 10 10 10 15 16 16 16  0]
['O', 'O', 'B-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'B-TITLE', 'I-TITLE', 'I-TITLE', 'I-TITLE', 'I-TITLE', 'O', 'B-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'B-TITLE', 'I-TITLE', 'I-TITLE', 'I-TITLE', 'O']
{('ORG', '无锡威孚高科技集团股份有限公司'), ('TITLE', '董事局董事'), ('TITLE', '党委书记'), ('ORG', '大股东无锡产业发展集团有限公司')}
