Skip to content

Commit

Permalink
new option --total-symbols in learn-bpe
Browse files Browse the repository at this point in the history
redefines "--symbols" to be the number of merge operations,
minus the character vocabulary size, so that "--symbols" becomes
an estimate of the final symbol vocabulary size.

thx @phikoehn
  • Loading branch information
rsennrich committed Jun 28, 2018
1 parent 71b22d1 commit 61ad855
Showing 1 changed file with 18 additions and 2 deletions.
20 changes: 18 additions & 2 deletions subword_nmt/learn_bpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ def create_parser(subparsers=None):
help='Stop if no symbol pair has frequency >= FREQ (default: %(default)s))')
parser.add_argument('--dict-input', action="store_true",
help="If set, input file is interpreted as a dictionary where each line contains a word-count pair")
parser.add_argument(
'--total-symbols', '-t', action="store_true",
help="subtract number of characters from the symbols to be generated (so that '--symbols' becomes an estimate for the total number of symbols needed to encode text).")
parser.add_argument(
'--verbose', '-v', action="store_true",
help="verbose mode.")
Expand Down Expand Up @@ -197,7 +200,7 @@ def prune_stats(stats, big_stats, threshold):
big_stats[item] = freq


def learn_bpe(infile, outfile, num_symbols, min_frequency=2, verbose=False, is_dict=False):
def learn_bpe(infile, outfile, num_symbols, min_frequency=2, verbose=False, is_dict=False, total_symbols=False):
"""Learn num_symbols BPE operations from vocabulary, and write to outfile.
"""

Expand All @@ -211,6 +214,19 @@ def learn_bpe(infile, outfile, num_symbols, min_frequency=2, verbose=False, is_d

stats, indices = get_pair_statistics(sorted_vocab)
big_stats = copy.deepcopy(stats)

if total_symbols:
uniq_char_internal = set()
uniq_char_final = set()
for word in vocab:
for char in word[:-1]:
uniq_char_internal.add(char)
uniq_char_final.add(char[-1])
sys.stderr.write('Number of word-internal characters: {0}\n'.format(len(uniq_char_internal)))
sys.stderr.write('Number of word-final characters: {0}\n'.format(len(uniq_char_final)))
sys.stderr.write('Reducing number of merge operations by {0}\n'.format(len(uniq_char_internal) + len(uniq_char_final)))
num_symbols -= len(uniq_char_internal) + len(uniq_char_final)

# threshold is inspired by Zipfian assumption, but should only affect speed
threshold = max(stats.values()) / 10
for i in range(num_symbols):
Expand Down Expand Up @@ -270,4 +286,4 @@ def learn_bpe(infile, outfile, num_symbols, min_frequency=2, verbose=False, is_d
if args.output.name != '<stdout>':
args.output = codecs.open(args.output.name, 'w', encoding='utf-8')

learn_bpe(args.input, args.output, args.symbols, args.min_frequency, args.verbose, is_dict=args.dict_input)
learn_bpe(args.input, args.output, args.symbols, args.min_frequency, args.verbose, is_dict=args.dict_input, total_symbols=args.total_symbols)

0 comments on commit 61ad855

Please sign in to comment.