# n元语言模型回退算法

本次作业要求补全本笔记中的n元语言模型的采用Good-Turing折扣的Katz回退算法。

### 预处理

首先创建一些预处理函数。

引入必要的模块，定义些类型别名。

In [1]:
import re
import itertools

from typing import List, Dict, Tuple

Sentence = List[str]
IntSentence = List[int]

Corpus = List[Sentence]
IntCorpus = List[IntSentence]

Gram = Tuple[int]

下面的函数用于将文本正则化并词元化。该函数会将所有英文文本转为小写，去除文本中所有的标点，简单起见将所有连续的数字用一个`N`代替，将形如`let's`的词组拆分为`let`和`'s`两个词。

In [2]:
_splitor_pattern = re.compile(r"[^a-zA-Z']+|(?=')")
_digit_pattern = re.compile(r"\d+")
def normaltokenize(corpus: List[str]) -> Corpus:
    """
    Normalizes and tokenizes the sentences in `corpus`. Turns the letters into
    lower case and removes all the non-alphadigit characters and splits the
    sentence into words and added BOS and EOS marks.

    Args:
        corpus - list of str

    Return:
        list of list of str where each inner list of str represents the word
          sequence in a sentence from the original sentence list
    """

    tokeneds = [ ["<s>"]
               + list(
                   filter(lambda tkn: len(tkn)>0,
                       _splitor_pattern.split(
                           _digit_pattern.sub("N", stc.lower()))))
               + ["</s>"]
                    for stc in corpus
               ]
    return tokeneds

接下来定义两个函数用来从训练语料中构建词表，并将句子中的单词从字符串表示转为整数索引表示。

In [3]:
def extract_vocabulary(corpus: Corpus) -> Dict[str, int]:
    """
    Extracts the vocabulary from `corpus` and returns it as a mapping from the
    word to index. The words will be sorted by the codepoint value.

    Args:
        corpus - list of list of str

    Return:
        dict like {str: int}
    """

    vocabulary = set(itertools.chain.from_iterable(corpus))
    vocabulary = dict(
            map(lambda itm: (itm[1], itm[0]),
                enumerate(
                    sorted(vocabulary))))
    return vocabulary

def words_to_indices(vocabulary: Dict[str, int], sentence: Sentence) -> IntSentence:
    """
    Convert sentence in words to sentence in word indices.

    Args:
        vocabulary - dict like {str: int}
        sentence - list of str

    Return:
        list of int
    """

    return list(map(lambda tkn: vocabulary.get(tkn, len(vocabulary)), sentence))

接下来读入训练数据，将数据预处理。

In [4]:
import functools

with open("data/news.2007.en.shuffled.deduped.train") as f:
    texts = list(map(lambda l: l.strip(), f.readlines()))

print("Loaded training set.")

corpus = normaltokenize(texts)
vocabulary = extract_vocabulary(corpus)
corpus = list(
        map(functools.partial(words_to_indices, vocabulary),
            corpus))

print("Preprocessed training set.")

Loaded training set.
Preprocessed training set.


### 设计模型

参照公式

$$
P_{\text{bo}}(w_k | W_{k-n+1}^{k-1}) = \begin{cases}
    d(W_{k-n+1}^k) \dfrac{C(W_{k-n+1}^k)}{C(W_{k-n+1}^{k-1})} &  C(W_{k-n+1}^k) > 0 \\
    \alpha(W_{k-n+1}^{k-1}) P_{\text{bo}}(w_k | W_{k-n+2}^{k-1}) &  \text{否则} \\
\end{cases}
$$

实现n元语言模型及采用Good-Turing折扣的Katz回退算法。

需要实现的功能包括：

1. 统计各词组（gram）在训练语料中的频数
2. 计算同频词组个数$N_r$
3. 计算$d(W_{k-n+1}^k)$
4. 计算$\alpha(W_{k-n+1}^{k-1})$
5. 根据公式计算回退概率
6. 计算概率对数与困惑度（PPL）

$d$与$\alpha$如何计算可以参考作业文件中的算法说明以及[SRILM](http://www.speech.sri.com/projects/srilm/)的[`ngram-discount(7)`手册页](http://www.speech.sri.com/projects/srilm/manpages/ngram-discount.7.html)。

In [8]:
import math

class NGramModel:
    def __init__(self, vocab_size: int, n: int = 4):
        """
        Constructs `n`-gram model with a `vocab_size`-size vocabulnextry.

        Args:
            vocab_size - int
            n - int
        """

        self.vocab_size: int = vocab_size
        self.n: int = n

        self.frequencies: List[Dict[Gram, int]]\
            = [{} for _ in range(n)]
        self.disfrequencies: List[Dict[Gram, int]]\
            = [{} for _ in range(n)]
        """
         We use a list of length n to store the count of counts.
         The keys of each dictionary are the frequencies of each grams 
         and the values are the count of all grams having the same frequency.
        """
        
        self.ncounts: List[Dict[int,int]]\
            = [{} for _ in range(n)]

        self.discount_threshold:int = 7
        self._d: Dict[Gram, float] = {} # we only need to return a float instead of a tuple
        self._alpha: List[Dict[Gram, float]]\
            = [{} for _ in range(n)]

        self.eps = 1e-10

    def learn(self, corpus: IntCorpus):
        """
        Learns the parameters of the n-gram model.

        Args:
            corpus - list of list of int
        """
        for stc in corpus:
            for i in range(1, len(stc)+1):
                for j in range(min(i, self.n)):
                    # TODO: count the frequencies of the grams
                    
                    gram = tuple(stc[i - j - 1: i])
                    if gram not in self.frequencies[j].keys():
                        self.frequencies[j][gram] = 1
                    else:
                        self.frequencies[j][gram] += 1
        for i in range(1, self.n):
            # TODO: calculates the value of $N_r$
            
            gram_sorted = sorted(self.frequencies[i].items(), key = lambda itm: itm[1])
            gram_grouped = itertools.groupby(gram_sorted, key = lambda itm: itm[1])
            
            for freq, group in gram_grouped:
                count = len(list(group))
                self.ncounts[i][freq] = count

    def d(self, gram: Gram) -> float:
        """
        Calculates the interpolation coefficient.

        Args:
            gram - tuple of int

        Return:
            float
        """

        if gram not in self._d:
            # TODO: calculates the value of $d'$

            theta = self.discount_threshold
            idx = len(gram)-1
            freq = self.frequencies[idx][gram]
            
            if freq > theta:
                self._d[gram] = 1
            else:
                lmbda = self.ncounts[idx][1] / (self.ncounts[idx][1] - (theta + 1) * self.ncounts[idx][theta + 1])
                self._d[gram] = lmbda * (freq + 1) * self.ncounts[idx][freq + 1] / (freq * self.ncounts[idx][freq]) + 1 - lmbda
            
        return self._d[gram]

    def alpha(self, gram: Gram) -> float:
        """
        Calculates the back-off weight alpha(`gram`)

        Args:
            gram - tuple of int

        Return:
            float
        """

        n = len(gram)
        if gram not in self._alpha[n]:
            if gram in self.frequencies[n-1]:
                # TODO: calculates the value of $\alpha$
                
                numerator = 1
                denominator = 0
                for i in range(self.vocab_size + 1):
                    if(tuple(list(gram) + [i]) in self.frequencies[n]):
                        numerator -= self.__getitem__(tuple(list(gram) + [i]))
                    else:
                        denominator += self.__getitem__(tuple(list(gram[1:]) + [i]))
                self._alpha[n][gram] = (1 - numerator)/denominator
                
            else:
                self._alpha[n][gram] = 1.
        return self._alpha[n][gram]

    def __getitem__(self, gram: Gram) -> float:
        """
        Calculates smoothed conditional probability P(`gram[-1]`|`gram[:-1]`).

        Args:
            gram - tuple of int

        Return:
            float
        """

        n = len(gram)-1
        if gram not in self.disfrequencies[n]:
            if n>0:
                # TODO: calculates the smoothed probability value according to the formulae
                
                gram1 = tuple(list(gram)[:-1])
                gram2 = tuple(list(gram)[1:])
                if gram in self.frequencies[n]:
                    self.disfrequencies[n][gram] = self.d(gram) * self.frequencies[n][gram] / self.frequencies[n-1][gram1]
                else:
                    alpha = self.alpha(gram1) 
                    self.disfrequencies[n][gram] = alpha * self.__getitem__(gram2)
                
            else: # uni-gram, n == 0
                self.disfrequencies[n][gram] = \
                self.frequencies[n].get(gram, self.eps)/float(len(self.frequencies[0]))
        return self.disfrequencies[n][gram]

    def log_prob(self, sentence: IntSentence) -> float:
        """
        Calculates the log probability of the given sentence. Assumes that the
        first token is always "<s>".

        Args:
            sentence: list of int

        Return:
            float
        """

        log_prob = 0.
        for i in range(4, len(sentence) + 1):
            # TODO: calculates the log probability
            log_prob += math.log(self.__getitem__(tuple(sentence[i-4: i])))
        log_prob += math.log(self.__getitem__(tuple(sentence[0: 2])))
        log_prob += math.log(self.__getitem__(tuple(sentence[0: 3])))
        
        return log_prob

    def ppl(self, sentence: IntSentence) -> float:
        """
        Calculates the PPL of the given sentence. Assumes that the first token
        is always "<s>".

        Args:
            sentence: list of int

        Return:
            float
        """

        # TODO: calculates the PPL
        return math.exp(-1 * self.log_prob(sentence) / (len(sentence) - 1))

### 训练与测试

现在数据与模型均已齐备，可以训练并测试了。

训练模型：

In [9]:
import pickle as pkl

model = NGramModel(len(vocabulary))
model.learn(corpus)
with open("model.pkl", "wb") as f:
    pkl.dump(vocabulary, f)
    pkl.dump(model, f)

print("Dumped model.")

Dumped model.


在测试集上测试计算困惑度：

In [10]:
with open("model.pkl", "rb") as f:
    vocabulary = pkl.load(f)
    model = pkl.load(f)
print("Loaded model.")

with open("data/news.2007.en.shuffled.deduped.test") as f:
    test_set = list(map(lambda l: l.strip(), f.readlines()))
test_corpus = normaltokenize(test_set)
test_corpus = list(
        map(functools.partial(words_to_indices, vocabulary),
            test_corpus))
ppls = []
for t in test_corpus:
    ppls.append(model.ppl(t))
    print(ppls[-1])
print("Avg: ", sum(ppls)/len(ppls))

Loaded model.
18079.050232736678
819.7713360237044
592.8504480760279
120.09015300262477
17365.898830753187
93.08048525133171
20.37816197083465
457.7118188383788
1668.937631158522
115.5702037088461
394.22456872284727
169.48386540413424
7280.683379928021
3568.4986525394406
1040.0046720849966
550.6566746024754
19298.266324042885
60499.421599851565
156.89266771511913
173.18912226867104
943.4349274193535
2264.31199429678
725.4606749246966
488.96146604927674
162.38716562115496
81.64059481212614
586.7571212443323
10.30420213772118
281.6601330774344
430.1665450005788
426.8215601881857
610.0945012528679
130.39684326420726
354.21120801916305
754.4355350132953
155.65144994689538
1433.7408034008251
747.3955723350838
809.5848352965743
3198.2760151427306
244.07105187312484
1687.250012871627
207.95090575782635
225.08600149495763
2913.763384811634
1868.9372456809563
452.18866458327335
1712.637259742351
127.72777763821362
537.4465288335156
Avg:  3140.748256208222
