Skip to content

Commit 69041d9

Browse files
add init and runner script
1 parent 633d08a commit 69041d9

File tree

2 files changed

+63
-0
lines changed

2 files changed

+63
-0
lines changed

python/init_datasets.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# -*- coding: UTF-8 -*-
2+
import os
3+
from shutil import copyfile
4+
from dnlp.data_process.process_cws import ProcessCWS
5+
6+
7+
def copy():
8+
src_folder = '../datasets/'
9+
dst_base_folder = 'dnlp/data/'
10+
if not os.path.exists(dst_base_folder):
11+
os.makedirs(dst_base_folder)
12+
pku = 'pku_training.utf8'
13+
copyfile(src_folder + pku, dst_base_folder + pku)
14+
15+
16+
def build_cws_datasets():
17+
files = ('pku_training.utf8',)
18+
base_folder = 'dnlp/data/cws/'
19+
if not os.path.exists(base_folder):
20+
os.makedirs(base_folder)
21+
ProcessCWS(files=files, base_folder=base_folder, name='pku_training')
22+
23+
24+
if __name__ == '__main__':
25+
copy()
26+
build_cws_datasets()

python/runner/cws_ner.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# -*- coding: UTF-8 -*-
2+
import sys
3+
import getopt
4+
from dnlp.config.config import DnnCrfConfig
5+
from dnlp.core.dnn_crf import DnnCrf
6+
7+
8+
def train_cws():
9+
data_path = '../dnlp/data/cws/pku_training.pickle'
10+
config = DnnCrfConfig()
11+
dnncrf = DnnCrf(config=config, data_path=data_path,nn='lstm')
12+
dnncrf.fit()
13+
14+
15+
def test_cws():
16+
sentence = '小明来自南京师范大学'
17+
model_path = '../dnlp/models/cws1.ckpt'
18+
config = DnnCrfConfig()
19+
dnncrf = DnnCrf(config=config, mode='predict', model_path=model_path, nn='lstm')
20+
res = dnncrf.predict(sentence)
21+
print(res)
22+
23+
24+
if __name__ == '__main__':
25+
try:
26+
opts, args = getopt.getopt(sys.argv[1:], 'tp', [])
27+
if len(opts) != 1:
28+
raise Exception('cmd args count is not 1')
29+
opts = opts[0]
30+
if opts[0] == '-t':
31+
train_cws()
32+
elif opts[0] == '-p':
33+
test_cws()
34+
else:
35+
raise Exception('unknown cmd arg')
36+
except Exception as e:
37+
print(e.args[0])

0 commit comments

Comments
 (0)