<a href="https://colab.research.google.com/github/pdh93621/Deep-learning/blob/main/BPE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import re, collections

In [None]:
# Number of merge operations to perform
num_merges = 10

In [None]:
# Toy training vocabulary. Each key is a word written as space-separated
# symbols and terminated by the end-of-word marker '</w>'; each value is
# that word's frequency.
dictionary = {
    'l o w </w>': 5,
    'l o w e r </w>': 2,
    'n e w e s t </w>': 6,
    'w i d e s t </w>': 3,
}

In [None]:
# Split each word into symbols, then count the frequency of every adjacent pair
def get_stats(dictionary):
    """Tally adjacent symbol pairs over a frequency-weighted vocabulary.

    Args:
        dictionary: Mapping from a word written as whitespace-separated
            symbols to that word's frequency.

    Returns:
        A ``defaultdict(int)`` mapping each ``(left, right)`` symbol pair to
        the sum of the frequencies of the words it occurs in (counted once
        per occurrence). The counts are also printed.
    """
    pair_counts = collections.defaultdict(int)
    for entry, count in dictionary.items():
        tokens = entry.split()
        # Pair every symbol with its right-hand neighbour.
        for left, right in zip(tokens, tokens[1:]):
            pair_counts[left, right] += count
    print('현재 pair들의 빈도수:', dict(pair_counts))
    return pair_counts

In [None]:
# Merge the chosen (most frequent) pair in every word of the vocabulary
def merge_dictionary(pair, v_in):
    """Return a copy of the vocabulary with every occurrence of ``pair`` merged.

    Args:
        pair: Two-element sequence ``(left, right)`` of symbols to fuse.
        v_in: Mapping from a word written as space-separated symbols to its
            frequency.

    Returns:
        A new dict in which each key has had every standalone ``left right``
        occurrence replaced by the single symbol ``left + right``. If several
        input words collapse to the same key, their frequencies are summed.
    """
    merged_symbol = ''.join(pair)
    # The look-arounds make the pair match only as two whole,
    # whitespace-delimited symbols (never inside a longer symbol).
    pattern = re.compile(r'(?<!\S)' + re.escape(' '.join(pair)) + r'(?!\S)')

    v_out = {}
    for word, freq in v_in.items():
        # A callable replacement is inserted verbatim; a replacement *string*
        # would have its backslashes processed as escape sequences, which
        # corrupts symbols that contain a backslash.
        w_out = pattern.sub(lambda _match: merged_symbol, word)
        # Accumulate rather than overwrite so colliding words keep their mass.
        v_out[w_out] = v_out.get(w_out, 0) + freq

    return v_out

In [None]:
# Merge table filled in by the training loop: symbol pair -> index of the
# iteration in which that pair was merged.
bpe_codes = {}
# Reverse lookup filled in by the training loop: merged symbol -> the pair
# it was built from.
bpe_codes_reverse = {}

In [None]:
# Learn the merges: recount pair frequencies and merge the most frequent
# pair, repeating up to num_merges times.
for i in range(num_merges):
    print(f">> Step {i + 1}")
    pairs = get_stats(dictionary)
    if not pairs:
        # Every word is already a single symbol, so nothing is left to merge;
        # calling max() on the empty mapping would raise ValueError.
        print('no more pairs to merge; stopping early')
        break

    # Ties are broken in favour of the pair that was counted first.
    best = max(pairs, key=pairs.get)
    dictionary = merge_dictionary(best, dictionary)

    # Record the merge order (used as priority when encoding) and the
    # reverse mapping from the merged symbol back to its parts.
    bpe_codes[best] = i
    bpe_codes_reverse[best[0] + best[1]] = best

    print(f'new merge: {best}')
    print(f'dictionary: {dictionary}')

>> Step 1
현재 pair들의 빈도수: {('l', 'o'): 7, ('o', 'w'): 7, ('w', '</w>'): 5, ('w', 'e'): 8, ('e', 'r'): 2, ('r', '</w>'): 2, ('n', 'e'): 6, ('e', 'w'): 6, ('e', 's'): 9, ('s', 't'): 9, ('t', '</w>'): 9, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'e'): 3}
new merge: ('e', 's')
dictionary: {'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w es t </w>': 6, 'w i d es t </w>': 3}
>> Step 2
현재 pair들의 빈도수: {('l', 'o'): 7, ('o', 'w'): 7, ('w', '</w>'): 5, ('w', 'e'): 2, ('e', 'r'): 2, ('r', '</w>'): 2, ('n', 'e'): 6, ('e', 'w'): 6, ('w', 'es'): 6, ('es', 't'): 9, ('t', '</w>'): 9, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'es'): 3}
new merge: ('es', 't')
dictionary: {'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w est </w>': 6, 'w i d est </w>': 3}
>> Step 3
현재 pair들의 빈도수: {('l', 'o'): 7, ('o', 'w'): 7, ('w', '</w>'): 5, ('w', 'e'): 2, ('e', 'r'): 2, ('r', '</w>'): 2, ('n', 'e'): 6, ('e', 'w'): 6, ('w', 'est'): 6, ('est', '</w>'): 9, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'est'): 3}
new merge: ('est', '</w>')
dictionary:

In [None]:
# Show the learned merge table (symbol pair -> merge index)
print(bpe_codes)

{('e', 's'): 0, ('es', 't'): 1, ('est', '</w>'): 2, ('l', 'o'): 3, ('lo', 'w'): 4, ('n', 'e'): 5, ('ne', 'w'): 6, ('new', 'est</w>'): 7, ('low', '</w>'): 8, ('w', 'i'): 9}


## OOV에 대처


In [None]:
# Collect adjacent symbol pairs, e.g.
# ('a', 'p', 'p', 'l', 'e') -> {('a', 'p'), ('p', 'p'), ('p', 'l'), ('l', 'e')}
def get_pairs(word):
    """Return the set of adjacent symbol pairs in a word.

    Args:
        word: Sliceable sequence of symbols, e.g. a tuple of variable-length
            strings or a plain string of characters.

    Returns:
        A set of ``(left, right)`` tuples, one for each pair of neighbouring
        symbols. The set is empty when ``word`` has fewer than two symbols
        (including when it is empty).
    """
    # zip stops at the shorter argument, so empty and one-symbol words
    # naturally produce no pairs.
    return set(zip(word, word[1:]))

In [None]:
# 인코딩: 위에서 만든 BPE로 orig 분해하기
def encode(orig):
    """Encode word based on list of BPE merge operations, which are applied consecutively"""

    word = tuple(orig) + ('</w>',)
    #display(Markdown("__word split into characters:__ <tt>{}</tt>".format(word)))
    print("__word split into characters:__ <tt>{}</tt>".format(word))

    pairs = get_pairs(word)    

    if not pairs:
        return orig

    iteration = 0
    while True:
        iteration += 1
        #display(Markdown("__Iteration {}:__".format(iteration)))
        print("__Iteration {}:__".format(iteration))

        print("bigrams in the word: {}".format(pairs))
        bigram = min(pairs, key = lambda pair: bpe_codes.get(pair, float('inf')))
        print("candidate for merging: {}".format(bigram))
        if bigram not in bpe_codes:
            #display(Markdown("__Candidate not in BPE merges, algorithm stops.__"))
            print("__Candidate not in BPE merges, algorithm stops.__")
            break
        first, second = bigram
        new_word = []
        i = 0
        while i < len(word):
            try:
                j = word.index(first, i)
                new_word.extend(word[i:j])
                i = j
            except:
                new_word.extend(word[i:])
                break

            if word[i] == first and i < len(word)-1 and word[i+1] == second:
                new_word.append(first+second)
                i += 2
            else:
                new_word.append(word[i])
                i += 1
        new_word = tuple(new_word)
        word = new_word
        print("word after merging: {}".format(word))
        if len(word) == 1:
            break
        else:
            pairs = get_pairs(word)

    # 특별 토큰인 </w>는 출력하지 않는다.
    if word[-1] == '</w>':
        word = word[:-1]
    elif word[-1].endswith('</w>'):
        word = word[:-1] + (word[-1].replace('</w>',''),)

    return word



In [None]:
# Try the encoder on a word that is not among the training dictionary's words
encode('hellong')

__word split into characters:__ <tt>('h', 'e', 'l', 'l', 'o', 'n', 'g', '</w>')</tt>
__Iteration 1:__
bigrams in the word: {('n', 'g'), ('e', 'l'), ('l', 'l'), ('g', '</w>'), ('o', 'n'), ('h', 'e'), ('l', 'o')}
candidate for merging: ('l', 'o')
word after merging: ('h', 'e', 'l', 'lo', 'n', 'g', '</w>')
__Iteration 2:__
bigrams in the word: {('n', 'g'), ('e', 'l'), ('g', '</w>'), ('lo', 'n'), ('l', 'lo'), ('h', 'e')}
candidate for merging: ('n', 'g')
__Candidate not in BPE merges, algorithm stops.__


('h', 'e', 'l', 'lo', 'n', 'g')