Skip to content

Commit

Permalink
fix threshold=hp['THRESHOLD'] in train.py
Browse files Browse the repository at this point in the history
  • Loading branch information
taishi-i committed Jun 21, 2020
1 parent 76f94e2 commit 34fc795
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 4 deletions.
3 changes: 2 additions & 1 deletion nagisa/train.py
Expand Up @@ -20,7 +20,7 @@

def fit(train_file, dev_file, test_file, model_name,
dict_file=None, emb_file=None, delimiter='\t', newline='EOS',
layers=1, min_count=3, decay=1, epoch=10, window_size=3,
layers=1, min_count=2, decay=1, epoch=10, window_size=3,
dim_uni=32, dim_bi=16, dim_word=16, dim_ctype=8, dim_tagemb=16,
dim_hidden=100, learning_rate=0.1, dropout_rate=0.3, seed=1234):

Expand Down Expand Up @@ -89,6 +89,7 @@ def fit(train_file, dev_file, test_file, model_name,

# Preprocess
vocabs = prepro.create_vocabs_from_trainset(trainset=hp['TRAINSET'],
threshold=hp['THRESHOLD'],
fn_dictionary=hp['DICTIONARY'],
fn_vocabs=hp['VOCAB'],
delimiter=delimiter,
Expand Down
38 changes: 35 additions & 3 deletions test/nagisa_test.py
Expand Up @@ -98,17 +98,35 @@ def test_tagging(self):
postags = nagisa.decode(words)
self.assertEqual(output, postags)

# test_17
text = 'こんばんは😀'
output = 'こんばんは/感動詞 😀/補助記号'
words = nagisa.tagging(text)
self.assertEqual(output, str(words))

# test_18
text = 'コンバンハ12345'
output = 'コンバンハ/名詞 1/名詞 2/名詞 3/名詞 4/名詞 5/名詞'
words = nagisa.tagging(text)
self.assertEqual(output, str(words))

# test_19
text = '𪗱𪘂𪘚𪚲'
output = '𪗱/補助記号 𪘂/補助記号 𪘚/補助記号 𪚲/補助記号'
words = nagisa.tagging(text)
self.assertEqual(output, str(words))


def test_fit(self):
# test_17
# test_20
nagisa.fit(
train_file="nagisa/data/sample_datasets/sample.train",
dev_file="nagisa/data/sample_datasets/sample.dev",
test_file="nagisa/data/sample_datasets/sample.test",
model_name="sample",
)

# test_18
# test_21
nagisa.fit(
train_file="nagisa/data/sample_datasets/sample.train",
dev_file="nagisa/data/sample_datasets/sample.dev",
Expand All @@ -120,9 +138,23 @@ def test_fit(self):
delimiter="\t"
)

# test_22
nagisa.fit(
train_file="nagisa/data/sample_datasets/sample.train",
dev_file="nagisa/data/sample_datasets/sample.dev",
test_file="nagisa/data/sample_datasets/sample.test",
dict_file="nagisa/data/sample_datasets/sample.dict",
emb_file="nagisa/data/sample_datasets/sample.emb",
model_name="sample",
newline="EOS",
delimiter="\t",
min_count=0
)



def test_mecab_system_eval(self):
# test_19
# test_23
system_file = "nagisa/data/sample_datasets/sample.pred"
answer_file = "nagisa/data/sample_datasets/sample.test"

Expand Down

0 comments on commit 34fc795

Please sign in to comment.