In [None]:
# Mount Google Drive inside the Colab VM (google.colab is only available on Google Colab).
# NOTE(review): the visible cells read and write files through relative or /content paths
# and never reference /content/drive — confirm whether this mount is still needed.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import re
import json
import os
import itertools
from math import log

In [None]:
# Remove Vietnamese diacritics
# Accented lower-case letters grouped by the plain letter they reduce to.
_VN_ACCENT_GROUPS = {
    'a': 'áàảãạăắằẳẵặâấầẩẫậ',
    'e': 'éèẻẽẹêếềểễệ',
    'o': 'óòỏõọôốồổỗộơớờởỡợ',
    'i': 'íìỉĩị',
    'u': 'úùủũụưứừửữự',
    'y': 'ýỳỷỹỵ',
    'd': 'đ',
}
# One translation table for both cases, so the whole string is converted in a single pass.
_VN_ACCENT_TABLE = {
    ord(variant): plain
    for base, accented in _VN_ACCENT_GROUPS.items()
    for ch in accented
    for variant, plain in ((ch, base), (ch.upper(), base.upper()))
}


def remove_vn_accent(s):
    """Return `s` with Vietnamese accented letters replaced by their plain Latin letters.

    Both lower- and upper-case letters are handled and case is preserved
    (e.g. 'Đ' -> 'D', 'ấ' -> 'a'). Every other character is left unchanged.

    Args:
        s: the text to convert.

    Returns:
        A new string of the same length without Vietnamese diacritics.
    """
    return s.translate(_VN_ACCENT_TABLE)

In [None]:
# Download VN syllables
!wget -O vn_syllables.txt "https://gist.githubusercontent.com/nguyenvanhieuvn/8d7c3440590a6db732ef6e05498c1566/raw/0164ccb7094b22a48a52844e7fa748cf2820cec9/all-vietnamese-syllables.txt(G%25C3%25B5%2520d%25E1%25BA%25A5u%2520ki%25E1%25BB%2583u%2520c%25C5%25A9)"
# Download language model
!wget "https://github.com/nguyenvanhieuvn/vn-accent-resoration/raw/master/vn_en_nextwords.txt.zip"
!unzip vn_en_nextwords.txt.zip
# Download test data
!wget "https://github.com/nguyenvanhieuvn/vn-accent-resoration/raw/master/test.txt"


In [None]:
# Patch the downloaded language model with extra bigram counts and save the result
# to nextwords.txt (one JSON object per line, keys 's', 'sum' and 'next').
lm = {}
next_for_repair = {}
output = []
# For each context word, the follow-up words whose count is forced to 200000.
# Note that the entry's 'sum' field is written back unchanged.
extra_next_words = {
    "nơi": ("ĐKHK", "đkhk", "thường", "cư"),
    "thường": ("trú",),
    "nguyên": ("quán",),
}
# Relative path, like every other cell; the with-block closes the file once it is read.
with open('vn_en_nextwords.txt', encoding='utf-8') as lm_file:
    for line in lm_file:
        if not line.strip():
            # tolerate blank lines instead of failing in json.loads
            continue
        data = json.loads(line)
        key = data['s']
        if key == "nơi":
            # alias of the dict that is patched just below
            next_for_repair = data['next']
        for boosted_word in extra_next_words.get(key, ()):
            data['next'][boosted_word] = 200000
        output.append({'s': key, 'sum': data['sum'], 'next': data['next']})

# Add a context entry for the abbreviation 'đkhk', which is appended regardless of the file's content.
output.append({'s': 'đkhk', 'sum': 2000000, 'next': {"thường": 200000, "thưởng": 100}})

with open("nextwords.txt", 'w', encoding='utf-8') as writeout:
    # `entry` rather than `dict`, so the builtin is not shadowed for the rest of the kernel session
    for entry in output:
        writeout.write(json.dumps(entry) + '\n')

In [None]:
# Build the lookup from each unaccented syllable to the set of accented forms it can stand for
map_accents = {}
with open('vn_syllables.txt', encoding='utf-8') as syllable_file:
  for raw_word in syllable_file:
    word = raw_word.strip().lower()
    if not word:
      # skip blank lines so no empty key is created
      continue
    map_accents.setdefault(remove_vn_accent(word), set()).add(word)
# The language-model patch adds the abbreviation 'đkhk' as a context word; register it
# as the candidate for 'dkhk' so the lookup does not depend on the downloaded syllable
# list containing it (no-op if it is already present).
map_accents.setdefault('dkhk', set()).add('đkhk')

In [None]:
# Sanity check: show the accented candidates registered for the unaccented key 'dkhk'
print(map_accents['dkhk'])

{'đkhk'}


In [None]:
# Load the language model: one JSON object per line, indexed by its 's' field
lm = {}
with open('nextwords.txt', encoding='utf-8') as lm_file:
  for line in lm_file:
    data = json.loads(line)
    lm[data['s']] = data
# Number of distinct context words, and the total of their 'sum' fields
vocab_size = len(lm)
total_word = sum(entry['sum'] for entry in lm.values())
print(vocab_size, total_word)

50001 6345849795


In [None]:
# Inspect the entry stored for the manually added context word 'đkhk'
lm['đkhk']

{'next': {'thường': 200000, 'thưởng': 100}, 's': 'đkhk', 'sum': 2000000}

In [None]:
# Bigram probability with add-one smoothing
def get_proba(current_word, next_word):
  """Return the smoothed probability that `next_word` follows `current_word`.

  Reads the module-level `lm`, `vocab_size` and `total_word`.

  - Unknown `current_word`: 1 / total_word.
  - Otherwise: (count + 1) / (sum + vocab_size), where `count` is the recorded
    count of `next_word` after `current_word`, or 0 when none is recorded.
  """
  if current_word in lm:
    entry = lm[current_word]
    # a missing count defaults to 0, so unseen pairs get 1 / denominator
    seen_count = entry['next'].get(next_word, 0)
    return (seen_count + 1) / (entry['sum'] + vocab_size)
  return 1 / total_word

# def get_proba(current_word, next_word):
#   try:
#     return (lm[current_word]['next'][next_word]) / (lm[current_word]['sum'])
#   except:
#     return 1e-30

In [None]:
# Beam search
def beam_search(words, k=3):
  """Find the most likely accented sequences for a list of unaccented words.

  Args:
    words: unaccented words; each is looked up in `map_accents`, and a word
      without an entry is kept as its own only candidate.
    k: beam width, i.e. how many partial sequences survive each step.

  Returns:
    A list of (sequence, score) tuples ordered best first, where `sequence` is a
    list of words and `score` is the sum of log bigram probabilities from
    `get_proba`. Candidates for the first word all start at 0.0 and are not
    pruned. Returns [] when `words` is empty.
  """
  sequences = []
  for idx, word in enumerate(words):
    # Candidates are kept in sets, whose iteration order may change between
    # interpreter runs; sorting makes tie-breaking (and the result) reproducible.
    candidates = sorted(map_accents.get(word, [word]))
    if idx == 0:
      sequences = [([candidate], 0.0) for candidate in candidates]
      continue
    expanded = [
      (prefix + [candidate], score + log(get_proba(prefix[-1], candidate)))
      for prefix, score in sequences
      for candidate in candidates
    ]
    # stable sort: equal scores keep their (now deterministic) generation order
    expanded.sort(key=lambda item: item[1], reverse=True)
    sequences = expanded[:k]
  return sequences

    

In [None]:
# Text preprocessing
_PUNCTUATION_RUN = re.compile(r'[.,~`!@#$%\^&*\(\)\[\]\\|:;\'"]+')
_WHITESPACE_RUN = re.compile(r'\s+')


def preprocess(sentence):
  """Normalise a sentence before accent restoration.

  Lower-cases the text, replaces each run of the listed punctuation characters
  with one space, collapses whitespace runs to single spaces, and trims both ends.
  """
  lowered = sentence.lower()
  without_punctuation = _PUNCTUATION_RUN.sub(' ', lowered)
  return _WHITESPACE_RUN.sub(' ', without_punctuation).strip()

In [None]:
# Try the pipeline on a single sentence
sentence = preprocess("Nởi DKHK thương trù")
# Strip the (deliberately wrong) accents, then let the beam search restore them
_sentence = remove_vn_accent(sentence)
words = _sentence.split()
results = beam_search(words, k=5)
best_guess = ' '.join(results[0][0])
print('INP:', sentence)
print('OUT:', best_guess)
# Compares against the preprocessed input, accents included
print('CMP:', best_guess == sentence)

INP: nởi dkhk thương trù
OUT: nơi đkhk thường trú
CMP: False


In [None]:
# Evaluate on the test set (~5000 sentences): a sentence counts as correct when the
# best beam-search output equals the preprocessed original sentence exactly.
k = 10
with open('test.txt', encoding='utf-8') as test_file:
  sentences = test_file.read().splitlines()
test_size = len(sentences)
print(test_size)
correct = 0
for sent in sentences:
  try:
    sent = preprocess(sent)
    _sent = remove_vn_accent(sent)
    words = _sent.split()
    results = beam_search(words, k)
    # beam_search returns [] for a sentence without words; treat that as an empty prediction
    predicted = ' '.join(results[0][0]) if results else ''
    if predicted == sent:
      correct += 1
  except Exception as exc:
    # Stop at the first failure and report its cause; unlike a bare except,
    # KeyboardInterrupt is not swallowed.
    print('err', sent, repr(exc))
    break

# Accuracy over the whole test set; guarded so an empty file does not divide by zero
print(correct / test_size if test_size else 0.0)

  

4983
0.32691149909692957
