
DeepKE-cnSchema使用遇到问题 #531

Closed
512fengbujue opened this issue Jun 14, 2024 · 7 comments
Labels
question Further information is requested

Comments

@512fengbujue

Describe the question

RuntimeError: Error(s) in loading state_dict for BertNer:
size mismatch for classifier.weight: copying a param with shape torch.Size([60, 768]) from checkpoint, the shape in current model is torch.Size([2, 768]).
size mismatch for classifier.bias: copying a param with shape torch.Size([60]) from checkpoint, the shape in current model is torch.Size([2]).
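The error means the checkpoint's classifier head was trained with 60 output labels, while the freshly constructed model only has 2. A minimal sketch of diagnosing this before calling `load_state_dict` (plain Python, with illustrative shape tuples standing in for `torch.Size`; the shapes are taken from the traceback above):

```python
# Sketch: compare checkpoint parameter shapes against the current model's
# shapes to locate a size mismatch before load_state_dict fails.
# Shape tuples below are illustrative stand-ins for torch.Size.
checkpoint_shapes = {"classifier.weight": (60, 768), "classifier.bias": (60,)}
model_shapes = {"classifier.weight": (2, 768), "classifier.bias": (2,)}

mismatches = {
    name: (ckpt_shape, model_shapes.get(name))
    for name, ckpt_shape in checkpoint_shapes.items()
    if model_shapes.get(name) != ckpt_shape
}

# The checkpoint expects 60 labels but the model was built with 2,
# so both classifier parameters mismatch.
print(mismatches)
```

A mismatch like this usually indicates the wrong model version or a label set that differs from the one the checkpoint was trained on.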

Environment (please complete the following information):

  • OS: Windows
  • Python Version: 3.8.19

Screenshots

[screenshot attached]

Additional context

text: '此外网易云平台还上架了一系列歌曲,其中包括田馥甄的《小幸运》等'
nerfp: 'D:\项目\地面推理机\DeepKE\DeepKE-main\example\triple\cnschema\ner_pretrain_model\checkpoints_pretrain_bert'
refp: 'D:\项目\地面推理机\DeepKE\DeepKE-main\example\triple\cnschema\re_pretrain_model\re_bert.pth'

cwd: ???

defaults:

  • hydra/output: custom
  • preprocess
  • train
  • embedding
  • predict
  • model: lm

vocab_size: ???
word_dim: 60
pos_size: ??? # 2 * pos_limit + 2
pos_dim: 10 # ignored when dim_strategy is sum; forced equal to word_dim

dim_strategy: sum # [cat, sum]

# number of relation types

num_relations: 51
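The `dim_strategy` comment in the config above can be illustrated with a small sketch (names and the `cat` layout — word embedding plus two positional embeddings — are assumptions based on the comments, not taken from DeepKE's source):

```python
# Sketch of how dim_strategy affects the token representation size,
# assuming word_dim / pos_dim as in the config above.
word_dim, pos_dim = 60, 10

def embedding_dim(strategy: str) -> int:
    # 'cat' concatenates the word embedding with two positional embeddings
    # (head- and tail-relative); 'sum' adds them element-wise, which is why
    # pos_dim is forced equal to word_dim and its configured value is ignored.
    if strategy == "cat":
        return word_dim + 2 * pos_dim
    if strategy == "sum":
        return word_dim
    raise ValueError(f"unknown dim_strategy: {strategy}")

print(embedding_dim("cat"), embedding_dim("sum"))  # 80 60
```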

@512fengbujue 512fengbujue added the question Further information is requested label Jun 14, 2024
@BeasterYong
Collaborator

BeasterYong commented Jun 15, 2024

Hello, please check that you have the correct version of the NER model; you can re-download it from the link and try again.

@F2023888

Hello, I have run into a different problem.
Logits Label: [58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
Logits Confidence: [0.507036566734314, 0.5059388875961304, 0.5061192512512207, 0.5054084062576294, 0.5053091049194336, 0.5055623650550842, 0.5059460997581482, 0.5060293674468994, 0.5062680840492249, 0.5059909224510193, 0.5056659579277039, 0.5060045719146729, 0.5064235329627991, 0.505791962146759, 0.5063173770904541, 0.5058268904685974, 0.5060592293739319, 0.5063878893852234, 0.5065939426422119, 0.5059806108474731, 0.505949079990387, 0.5061217546463013, 0.5054893493652344, 0.5059810280799866, 0.5060664415359497, 0.506255030632019, 0.5060448050498962, 0.5057430863380432, 0.5060313940048218, 0.5054957866668701, 0.5062739253044128, 0.5064918398857117, 0.5069468021392822, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 
0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447, 0.017459817230701447]
Labels: [('[CLS]', 0.5059388875961304), ('[CLS]', 0.5061192512512207), ('[CLS]', 0.5054084062576294), ('[CLS]', 0.5053091049194336), ('[CLS]', 0.5055623650550842), ('[CLS]', 0.5059460997581482), ('[CLS]', 0.5060293674468994), ('[CLS]', 0.5062680840492249), ('[CLS]', 0.5059909224510193), ('[CLS]', 0.5056659579277039), ('[CLS]', 0.5060045719146729), ('[CLS]', 0.5064235329627991), ('[CLS]', 0.505791962146759), ('[CLS]', 0.5063173770904541), ('[CLS]', 0.5058268904685974), ('[CLS]', 0.5060592293739319), ('[CLS]', 0.5063878893852234), ('[CLS]', 0.5065939426422119), ('[CLS]', 0.5059806108474731), ('[CLS]', 0.505949079990387), ('[CLS]', 0.5061217546463013), ('[CLS]', 0.5054893493652344), ('[CLS]', 0.5059810280799866), ('[CLS]', 0.5060664415359497), ('[CLS]', 0.506255030632019), ('[CLS]', 0.5060448050498962), ('[CLS]', 0.5057430863380432), ('[CLS]', 0.5060313940048218), ('[CLS]', 0.5054957866668701), ('[CLS]', 0.5062739253044128), ('[CLS]', 0.5064918398857117)]
Result: [('此', '[CLS]'), ('外', '[CLS]'), ('网', '[CLS]'), ('易', '[CLS]'), ('云', '[CLS]'), ('平', '[CLS]'), ('台', '[CLS]'), ('还', '[CLS]'), ('上', '[CLS]'), ('架', '[CLS]'), ('了', '[CLS]'), ('一', '[CLS]'), ('系', '[CLS]'), ('列', '[CLS]'), ('歌', '[CLS]'), ('曲', '[CLS]'), (',', '[CLS]'), ('其', '[CLS]'), ('中', '[CLS]'), ('包', '[CLS]'), ('括', '[CLS]'), ('田', '[CLS]'), ('馥', '[CLS]'), ('甄', '[CLS]'), ('的', '[CLS]'), ('《', '[CLS]'), ('小', '[CLS]'), ('幸', '[CLS]'), ('运', '[CLS]'), ('》', '[CLS]'), ('等', '[CLS]')]
It seems the NER task did not actually run: every character is labeled [CLS]. My NER and RE models each run correctly on their own. What could be causing this?

@BeasterYong
Collaborator


Hello, we could not reproduce your problem in our verification. Could you please share your detailed parameter configuration and how you are running it?

@F2023888

predict.py
The only parameter changes were to config.yaml and lm.yaml (setting the path to bert-base-chinese downloaded locally from Hugging Face).

@F2023888

packages in environment at D:\Anaconda_envs\envs\rz:

Name Version Build Channel

absl-py 2.1.0 pypi_0 pypi
aiohttp 3.9.5 pypi_0 pypi
aiosignal 1.3.1 pypi_0 pypi
antlr4-python3-runtime 4.8 pypi_0 pypi
asttokens 2.4.1 pypi_0 pypi
async-timeout 4.0.3 pypi_0 pypi
attrs 23.2.0 pypi_0 pypi
backcall 0.2.0 pypi_0 pypi
blinker 1.8.2 pypi_0 pypi
ca-certificates 2024.3.11 haa95532_0 defaults
cachetools 4.2.4 pypi_0 pypi
certifi 2024.2.2 pypi_0 pypi
charset-normalizer 3.3.2 pypi_0 pypi
click 8.1.7 pypi_0 pypi
colorama 0.4.6 pypi_0 pypi
configparser 7.0.0 pypi_0 pypi
cycler 0.12.1 pypi_0 pypi
datasets 2.13.2 pypi_0 pypi
decorator 5.1.1 pypi_0 pypi
dill 0.3.6 pypi_0 pypi
docker-pycreds 0.4.0 pypi_0 pypi
environs 11.0.0 pypi_0 pypi
executing 2.0.1 pypi_0 pypi
filelock 3.14.0 pypi_0 pypi
flask 3.0.3 pypi_0 pypi
flask-cors 4.0.1 pypi_0 pypi
frozendict 2.4.4 pypi_0 pypi
frozenlist 1.4.1 pypi_0 pypi
fsspec 2024.6.0 pypi_0 pypi
gitdb 4.0.11 pypi_0 pypi
gitpython 3.1.43 pypi_0 pypi
google-auth 1.35.0 pypi_0 pypi
google-auth-oauthlib 0.4.6 pypi_0 pypi
grpcio 1.64.1 pypi_0 pypi
huggingface-hub 0.11.0 pypi_0 pypi
hydra-core 1.0.6 pypi_0 pypi
idna 3.7 pypi_0 pypi
importlib-metadata 7.1.0 pypi_0 pypi
importlib-resources 6.4.0 pypi_0 pypi
interchange 2021.0.4 pypi_0 pypi
ipdb 0.13.11 pypi_0 pypi
ipython 8.12.3 pypi_0 pypi
itsdangerous 2.2.0 pypi_0 pypi
jedi 0.19.1 pypi_0 pypi
jieba 0.42.1 pypi_0 pypi
jinja2 3.1.2 pypi_0 pypi
joblib 1.4.2 pypi_0 pypi
jsonlines 4.0.0 pypi_0 pypi
kiwisolver 1.4.5 pypi_0 pypi
libffi 3.4.4 hd77b12b_1 defaults
lxml 5.2.2 pypi_0 pypi
markdown 3.6 pypi_0 pypi
markupsafe 2.1.5 pypi_0 pypi
marshmallow 3.21.2 pypi_0 pypi
matplotlib 3.4.1 pypi_0 pypi
matplotlib-inline 0.1.7 pypi_0 pypi
monotonic 1.6 pypi_0 pypi
multidict 6.0.5 pypi_0 pypi
multiprocess 0.70.14 pypi_0 pypi
mysql-connector 2.2.9 pypi_0 pypi
neo4j 5.20.0 pypi_0 pypi
nltk 3.8 pypi_0 pypi
numpy 1.21.0 pypi_0 pypi
oauthlib 3.2.2 pypi_0 pypi
omegaconf 2.0.6 pypi_0 pypi
openai 0.28.0 pypi_0 pypi
openssl 3.0.13 h2bbff1b_2 defaults
opt-einsum 3.3.0 pypi_0 pypi
packaging 24.0 pypi_0 pypi
pandas 2.0.3 pypi_0 pypi
pansi 2020.7.3 pypi_0 pypi
parso 0.8.4 pypi_0 pypi
pathtools 0.1.2 pypi_0 pypi
pickleshare 0.7.5 pypi_0 pypi
pillow 10.3.0 pypi_0 pypi
pip 24.0 py38haa95532_0 defaults
promise 2.3 pypi_0 pypi
prompt-toolkit 3.0.47 pypi_0 pypi
protobuf 3.20.1 pypi_0 pypi
psutil 5.9.8 pypi_0 pypi
psycopg2 2.9.9 pypi_0 pypi
pure-eval 0.2.2 pypi_0 pypi
py2neo 2021.2.4 pypi_0 pypi
pyarrow 16.1.0 pypi_0 pypi
pyasn1 0.6.0 pypi_0 pypi
pyasn1-modules 0.4.0 pypi_0 pypi
pygments 2.18.0 pypi_0 pypi
pyhocon 0.3.60 pypi_0 pypi
pyld 2.0.4 pypi_0 pypi
pymysql 1.1.1 pypi_0 pypi
pyparsing 3.1.2 pypi_0 pypi
python 3.8.19 h1aa4202_0 defaults
python-dateutil 2.9.0.post0 pypi_0 pypi
python-dotenv 1.0.1 pypi_0 pypi
pytorch-crf 0.7.2 pypi_0 pypi
pytz 2024.1 pypi_0 pypi
pyyaml 6.0.1 pypi_0 pypi
regex 2024.5.15 pypi_0 pypi
requests 2.32.3 pypi_0 pypi
requests-oauthlib 2.0.0 pypi_0 pypi
rsa 4.9 pypi_0 pypi
safetensors 0.4.3 pypi_0 pypi
scikit-learn 0.24.1 pypi_0 pypi
scipy 1.10.1 pypi_0 pypi
sentry-sdk 2.5.1 pypi_0 pypi
seqeval 1.2.2 pypi_0 pypi
setuptools 69.5.1 py38haa95532_0 defaults
shortuuid 1.0.13 pypi_0 pypi
six 1.16.0 pypi_0 pypi
smmap 5.0.1 pypi_0 pypi
sqlite 3.45.3 h2bbff1b_0 defaults
stack-data 0.6.3 pypi_0 pypi
subprocess32 3.5.4 pypi_0 pypi
tensorboard 2.4.1 pypi_0 pypi
tensorboard-plugin-wit 1.8.1 pypi_0 pypi
tensorboardx 2.5.1 pypi_0 pypi
termcolor 2.4.0 pypi_0 pypi
threadpoolctl 3.5.0 pypi_0 pypi
tokenizers 0.13.3 pypi_0 pypi
tomli 2.0.1 pypi_0 pypi
torch 1.11.0+cu113 pypi_0 pypi
tqdm 4.66.1 pypi_0 pypi
traitlets 5.14.3 pypi_0 pypi
transformers 4.26.0 pypi_0 pypi
typing-extensions 4.12.1 pypi_0 pypi
tzdata 2024.1 pypi_0 pypi
ujson 5.6.0 pypi_0 pypi
urllib3 2.2.1 pypi_0 pypi
vc 14.2 h2eaa2aa_1 defaults
vs2015_runtime 14.29.30133 h43f2093_3 defaults
wandb 0.12.7 pypi_0 pypi
wcwidth 0.2.13 pypi_0 pypi
werkzeug 3.0.3 pypi_0 pypi
wheel 0.43.0 py38haa95532_0 defaults
xxhash 3.4.1 pypi_0 pypi
yarl 1.9.4 pypi_0 pypi
yaspin 2.5.0 pypi_0 pypi
zipp 3.19.0 pypi_0 pypi

This is my conda list.

@F2023888

This is predict.py:
import json
import logging
import os

import hydra
import torch
from hydra import utils
from pyld import jsonld

# DeepKE helpers (load_csv, _handle_relation_data, _lm_serialize, InferNer, LM)
# are imported from the project's own modules; those imports were omitted in this paste.

logger = logging.getLogger(__name__)


def _preprocess_data(data, cfg):
    relation_data = load_csv(os.path.join(cfg.cwd, cfg.data_path, 'relation.csv'), verbose=False)
    rels = _handle_relation_data(relation_data)
    _lm_serialize(data, cfg)
    return data, rels


def get_jsonld(head, rel, tail, url):
    doc = {
        "@id": head,
        url: {"@id": tail},
    }
    context = {
        rel: url
    }
    compacted = jsonld.compact(doc, context)
    logger.info(json.dumps(compacted, indent=2, ensure_ascii=False))


@hydra.main(config_path="conf", config_name='config')
def main(cfg):
    cwd = utils.get_original_cwd()
    cfg.cwd = cwd

    label2word = {}
    with open(os.path.join(cfg.cwd, cfg.data_path, 'type.txt'), 'r', encoding='utf-8') as f:
        data = f.readlines()
        for d in data:
            label2word[d.split(' ')[1].strip('\n')] = d.split(' ')[0]
    logger.info(label2word)

    rel2url = {}
    with open(os.path.join(cfg.cwd, cfg.data_path, 'url.txt'), 'r', encoding='utf-8') as f:
        data = f.readlines()
        for d in data:
            rel2url[d.split(' ')[0]] = d.split(' ')[1].strip('\n')
    logger.info(rel2url)

    model = InferNer(cfg.nerfp)
    text = cfg.text

    print("text")
    logger.info(text)

    result = model.predict(text)
    print("result")
    logger.info(result)

    # Merge BIO-tagged characters into entities.
    temp = ''
    last_type = result[0][1][2:]
    res = {}
    word_len = len(result)

    for i in range(word_len):
        k = result[i][0]
        v = result[i][1]

        if v[0] == 'B':
            if temp != '':
                res[temp] = label2word[result[i - 1][1][2:]]
            temp = k
            last_type = result[i][1][2:]
        elif v[0] == 'I':
            if last_type == result[i][1][2:]:
                temp += k

    if temp != '':
        res[temp] = label2word[result[len(result) - 1][1][2:]]
    logger.info(res)

    entity = []
    entity_label = []
    for k, v in res.items():
        entity.append(k)
        entity_label.append(v)

    entity_len = len(entity_label)

    # Run relation extraction over every ordered entity pair.
    for i in range(entity_len):
        for j in range(i + 1, entity_len):
            instance = dict()
            instance['sentence'] = text.strip()
            instance['head'] = entity[i].strip()
            instance['tail'] = entity[j].strip()
            instance['head_type'] = entity_label[i].strip()
            instance['tail_type'] = entity_label[j].strip()

            data = [instance]
            data, rels = _preprocess_data(data, cfg)

            device = torch.device('cpu')

            model = LM(cfg)
            model.load(cfg.refp, device=device)
            model.to(device)
            model.eval()

            x = dict()
            x['word'], x['lens'] = torch.tensor([data[0]['token2idx']]), torch.tensor([data[0]['seq_len']])
            for key in x.keys():
                x[key] = x[key].to(device)

            with torch.no_grad():
                y_pred = model(x)
                y_pred = torch.softmax(y_pred, dim=-1)[0]
                prob = y_pred.max().item()
                prob_rel = list(rels.keys())[y_pred.argmax().item()]
                logger.info(f"\"{data[0]['head']}\" 和 \"{data[0]['tail']}\" 在句中关系为:\"{prob_rel}\",置信度为{prob:.2f}。")
                get_jsonld(data[0]['head'], prob_rel, data[0]['tail'], rel2url[prob_rel])


if __name__ == '__main__':
    main()
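The BIO merge in the script above can be sketched as a standalone function, which makes it easier to test in isolation (a simplified sketch: `merge_bio` and the sample tag set are illustrative, not from DeepKE; it assumes NER output as (char, tag) pairs with tags like `B-PER` / `I-PER`):

```python
# Simplified sketch of the BIO-tag merge that predict.py performs:
# consecutive B-/I- tagged characters of the same type are joined
# into one entity, mapped to a human-readable label via label2word.
def merge_bio(result, label2word):
    res, temp, last_type = {}, "", ""
    for char, tag in result:
        if tag.startswith("B"):
            if temp:
                res[temp] = label2word[last_type]
            temp, last_type = char, tag[2:]
        elif tag.startswith("I") and tag[2:] == last_type:
            temp += char
    if temp:
        res[temp] = label2word[last_type]
    return res

# Sample tags for the sentence from the issue (tag names are hypothetical).
pairs = [("田", "B-PER"), ("馥", "I-PER"), ("甄", "I-PER"),
         ("小", "B-SONG"), ("幸", "I-SONG"), ("运", "I-SONG")]
print(merge_bio(pairs, {"PER": "人物", "SONG": "歌曲"}))
# {'田馥甄': '人物', '小幸运': '歌曲'}
```

This also shows why an all-`[CLS]` tagging, as in the comment above, yields no entities at all: no tag ever starts with `B` or `I`, so the result dict stays empty.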

@F2023888

Sorry, it was because pytorch_transformers was not installed; the problem is solved now.

@zxlzr zxlzr closed this as completed Jun 16, 2024