Skip to content

Commit

Permalink
- New option: negative_wc
Browse files Browse the repository at this point in the history
  • Loading branch information
fstahlberg committed Feb 20, 2019
1 parent f049c52 commit 0f78d12
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 13 deletions.
1 change: 1 addition & 0 deletions cam/sgnmt/decode_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,7 @@ def add_predictors(decoder):
args.syntax_nonterminal_ids,
args.syntax_min_terminal_id,
args.syntax_max_terminal_id,
args.negative_wc,
_get_override_args("pred_trg_vocab_size"))
elif pred == "ngramc":
p = NgramCountPredictor(_get_override_args("ngramc_path"),
Expand Down
27 changes: 15 additions & 12 deletions cam/sgnmt/predictors/length.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,66 +245,69 @@ def __init__(self, word=-1,
nonterminal_ids=None,
min_terminal_id=0,
max_terminal_id=30003,
negative_wc=True,
vocab_size=30003):
"""Creates a new word count predictor instance.
Args:
word (int): If this is non-negative we count only the
number of the specified word. If its
negative, count all words
nonterminal_penalty (bool): If true, apply penalty only to tokens in a range
(the range *outside* min/max terminal id)
nonterminal_penalty (bool): If true, apply penalty only to
tokens in a range (the range *outside*
min/max terminal id)
nonterminal_ids: file containing ids of nonterminal tokens
min_terminal_id: lower bound of tokens *not* to penalize,
if nonterminal_penalty selected
max_terminal_id: upper bound of tokens *not* to penalize,
if nonterminal_penalty selected
negative_wc: If true, the score of this predictor is the
negative word count.
vocab_size: upper bound of tokens, used to find nonterminal range
"""
super(WordCountPredictor, self).__init__()
val = 1.0
if negative_wc:
val = -1.0
if nonterminal_penalty:
if nonterminal_ids:
nts = load_external_ids(nonterminal_ids)
else:
min_nt_range = range(0, min_terminal_id)
max_nt_range = range(max_terminal_id + 1, vocab_size)
nts = min_nt_range + max_nt_range
self.posterior = {nt: -1.0 for nt in nts}
self.posterior = {nt: val for nt in nts}
self.posterior[utils.EOS_ID] = 0.0
self.posterior[utils.UNK_ID] = 0.0
self.unk_prob = 0.0
elif word < 0:
self.posterior = {utils.EOS_ID : 0.0}
#self.unk_prob = -1.0
self.unk_prob = 1.0
self.unk_prob = val
else:
self.posterior = {word : -1.0}
self.posterior = {word : val}
self.unk_prob = 0.0

def get_unk_probability(self, posterior):
    """Return the fixed score for out-of-vocabulary tokens.

    The ``posterior`` argument is ignored: the UNK score is a
    constant (``self.unk_prob``) chosen once in the constructor.
    """
    return self.unk_prob

def predict_next(self):
    """Return the token scores for the next position.

    The returned posterior is static: it is fully determined in the
    constructor and does not change as words are consumed.
    """
    return self.posterior

def initialize(self, src_sentence):
    """Empty; this predictor does not depend on the source sentence.

    Note: the original carried a second, dead string literal — a
    leftover from a rendered diff (old + new docstring) — which has
    been removed here.
    """
    pass

def consume(self, word):
    """Empty; this predictor keeps no state about consumed words.

    Note: the original carried a second, dead string literal — a
    leftover from a rendered diff (old + new docstring) — which has
    been removed here.
    """
    pass

def get_state(self):
    """Return the predictor state.

    This predictor is stateless, so the state is the constant
    ``True``.
    """
    return True

def set_state(self, state):
    """Empty; there is no mutable state to restore.

    Note: the original carried a second, dead string literal — a
    leftover from a rendered diff (old + new docstring) — which has
    been removed here.
    """
    pass

def is_equal(self, state1, state2):
Expand Down
7 changes: 6 additions & 1 deletion cam/sgnmt/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,7 @@ def get_parser():
" Options: rules_path, "
"grammar_feature_weights, use_grammar_weights\n"
"* 'wc': Number of words feature.\n"
" Options: wc_word.\n"
" Options: wc_word, negative_wc.\n"
"* 'unkc': Poisson model for number of UNKs.\n"
" Options: unk_count_lambdas, "
"pred_src_vocab_size.\n"
Expand Down Expand Up @@ -941,6 +941,11 @@ def get_parser():
group.add_argument("--wc_word", default=-1, type=int,
help="If negative, the wc predictor counts all "
"words. Otherwise, count only the specific word")
group.add_argument("--negative_wc", default=True, type='bool',
help="If true, wc is the negative word count and thus "
"makes translations shorter. Otherwise, it makes "
"translations longer. Set early_stopping to False if "
"negative_wc=False and wc has a positive weight")
group.add_argument("--wc_nonterminal_penalty", default=False,
action='store_true', help="if true, "
"use syntax_[max|min]_terminal_id to apply penalty to "
Expand Down

0 comments on commit 0f78d12

Please sign in to comment.