In [None]:
from typing import List, Dict, Set
from itertools import chain
import re
from collections import defaultdict, Counter


def build_bpe(
        corpus: List[str],
        max_vocab_size: int
) -> List[str]:
    """Build a byte-pair-encoding (BPE) subword vocabulary.

    Every word is split into its characters followed by the word-end marker
    ``BytePairEncoding.WORD_END``. The most frequent adjacent symbol pair
    (weighted by word frequency) is then merged repeatedly until the
    vocabulary is full or no pair is left to merge.

    Note: the corpus is first collapsed into a word -> frequency table
          (``collections.Counter``), so duplicated words are processed
          together, which speeds building up significantly.

    Layout of the returned list:
      * indices 0-4: the special tokens PAD, UNK, CLS, SEP, MSK taken from
        ``BytePairEncoding`` (in that order);
      * then every learned subword (character unigrams and merged pairs),
        sorted by length in descending order, with a deterministic order
        among subwords of equal length;
      * finally the word-end token, appended last (outside the length
        ordering).
    No token appears twice.

    Arguments:
    corpus -- List of words to build vocab
    max_vocab_size -- The maximum size of vocab. If the special tokens, the
                      word-end token and the character unigrams alone already
                      exceed it, no merges are performed and the result is
                      longer than this value.

    Return:
    idx2word -- Subword list (strings)
    """
    # Special tokens
    PAD = BytePairEncoding.PAD_token  # Index of <PAD> must be 0
    UNK = BytePairEncoding.UNK_token  # Index of <UNK> must be 1
    CLS = BytePairEncoding.CLS_token  # Index of <CLS> must be 2
    SEP = BytePairEncoding.SEP_token  # Index of <SEP> must be 3
    MSK = BytePairEncoding.MSK_token  # Index of <MSK> must be 4
    SPECIAL = [PAD, UNK, CLS, SEP, MSK]

    # Appended as the last symbol of every word, so merges can learn
    # word-final subwords (e.g. "e" + WORD_END).
    WORD_END = BytePairEncoding.WORD_END

    def get_stats(vocab):
        """Count every adjacent symbol pair, weighted by word frequency."""
        pairs = defaultdict(int)
        for word, freq in vocab.items():
            symbols = word.split()
            for left, right in zip(symbols, symbols[1:]):
                pairs[left, right] += freq
        return pairs

    def merge_vocab(pair, v_in):
        """Return a copy of ``v_in`` with every occurrence of ``pair`` fused."""
        bigram = re.escape(' '.join(pair))
        # The lookarounds restrict matches to whole symbols: pair ('b', 'c')
        # must not match inside the symbol sequence "ab c".
        pattern = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
        merged = ''.join(pair)
        # A function replacement inserts ``merged`` literally; passing it as a
        # string would make re.sub interpret backslashes in the subword.
        return {pattern.sub(lambda _match: merged, word): freq
                for word, freq in v_in.items()}

    def make_unigram(vocab):
        """Collect the distinct character symbols, in sorted order."""
        unigram = set()
        for word in vocab:
            # The last symbol of every word is WORD_END, which is added to
            # idx2word separately.
            unigram.update(word.split()[:-1])
        # Sorting (instead of returning list(set)) keeps token indices stable
        # across runs; string hashing is randomised per process.
        return sorted(unigram)

    word_freq = Counter(corpus)
    # e.g. "low" -> "l o w " + WORD_END: space-separated symbols per word.
    vocab = {' '.join(word) + ' ' + WORD_END: freq
             for word, freq in word_freq.items()}

    reserved = set(SPECIAL)
    reserved.add(WORD_END)
    # Skip characters that coincide with a reserved token (e.g. a literal
    # WORD_END character inside a word) so idx2word has no duplicates.
    final_vocab = [tok for tok in make_unigram(vocab) if tok not in reserved]
    seen = reserved | set(final_vocab)

    # Slots left for subwords once the special tokens and WORD_END are placed.
    budget = max_vocab_size - len(SPECIAL) - 1
    while len(final_vocab) < budget:
        pairs = get_stats(vocab)
        if not pairs:
            break  # every word is already a single symbol
        best = max(pairs, key=pairs.get)
        vocab = merge_vocab(best, vocab)
        merged = ''.join(best)
        # Different merge paths can rebuild an existing subword, e.g.
        # ('ab', 'c') and ('a', 'bc') both give 'abc'; record it only once.
        if merged not in seen:
            seen.add(merged)
            final_vocab.append(merged)

    # sorted() is stable, so equal-length subwords keep their relative order.
    idx2word = SPECIAL + sorted(final_vocab, key=len, reverse=True) + [WORD_END]
    return idx2word