## Simple Language Model Training with Wikipedia Data

In [21]:
import os
import pickle
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
from tqdm.auto import tqdm

In [22]:
class LanguageModel:
    """Simple character-level unigram/bigram language model.

    Typical use: construct with the vocabulary, call :meth:`train` one or
    more times to accumulate raw counts, then call :meth:`normalize` once to
    turn the counts into relative frequencies. Until ``normalize`` has run,
    the getters return raw counts rather than probabilities.
    """

    def __init__(self, chars: list) -> None:
        """Create an untrained model whose tables are keyed by ``chars``."""
        self._unigram = {c: 0 for c in chars}
        self._bigram = {c: {d: 0 for d in chars} for c in chars}
        self.chars = chars

    def get_char_unigram(self, c: str) -> float:
        """Return the unigram value for ``c``.

        This is a raw count before :meth:`normalize` and a relative
        frequency afterwards. Raises ``KeyError`` if ``c`` is not in the
        vocabulary.
        """
        return self._unigram[c]

    def get_char_bigram(self, c: str, d: str) -> float:
        """Return the bigram value for ``c`` immediately followed by ``d``.

        This is a raw count before :meth:`normalize` and, afterwards, the
        relative frequency of ``d`` among the successors of ``c``. Raises
        ``KeyError`` if either character is not in the vocabulary.
        """
        return self._bigram[c][d]

    def train(self, txt: str) -> None:
        """Accumulate unigram and bigram counts from ``txt``.

        ``txt`` is scanned one character at a time, so only single-character
        vocabulary entries can ever be counted. Characters outside the
        vocabulary are skipped, and so is any pair containing one.
        """
        # Unigram counts.
        for c in txt:
            if c in self._unigram:
                self._unigram[c] += 1

        # Bigram counts over adjacent character pairs.
        for c, d in zip(txt, txt[1:]):
            successors = self._bigram.get(c)
            if successors is not None and d in successors:
                successors[d] += 1

    def normalize(self) -> None:
        """Convert the accumulated counts into relative frequencies in place.

        Tables whose total is zero (nothing was counted) are left untouched
        instead of triggering a division by zero.
        """
        # Iterate over the tables themselves rather than ``self.chars`` so
        # that every entry is divided exactly once, even if the vocabulary
        # list contains duplicates or the tables were replaced by ``load``.
        unigram_total = sum(self._unigram.values())
        if unigram_total:
            for c in self._unigram:
                self._unigram[c] /= unigram_total

        for successors in self._bigram.values():
            row_total = sum(successors.values())
            if not row_total:
                continue
            for d in successors:
                successors[d] /= row_total

    def save(self, path):
        """Pickle both tables into the existing directory ``path``.

        Writes ``unigram.pkl`` and ``bigram.pkl``; the directory is not
        created here.
        """
        with open(os.path.join(path, "unigram.pkl"), 'wb') as pkl:
            pickle.dump(self._unigram, pkl)
        with open(os.path.join(path, "bigram.pkl"), 'wb') as pkl:
            pickle.dump(self._bigram, pkl)

    def load(self, path):
        """Replace both tables with the pickles stored in directory ``path``.

        ``self.chars`` is not updated by this method.
        """
        # SECURITY: pickle.load can execute arbitrary code. Only load model
        # files that come from a trusted source.
        with open(os.path.join(path, "unigram.pkl"), 'rb') as pkl:
            self._unigram = pickle.load(pkl)
        with open(os.path.join(path, "bigram.pkl"), 'rb') as pkl:
            self._bigram = pickle.load(pkl)

In [23]:
# Name of the TFDS Wikipedia dataset config to load; the commented-out line
# below is an alternative config.
DATASET_NAME = 'wikipedia/20190301.en'
# DATASET_NAME = 'wikipedia/20190301.uk'

# Load the TRAIN split together with its metadata object, using 'tmp' as the
# TFDS data directory.
dataset, dataset_info = tfds.load(
    DATASET_NAME,
    split=tfds.Split.TRAIN,
    with_info=True,
    data_dir='tmp',
)

In [24]:
# Show the dataset's summary (feature names with their shapes and dtypes).
print(dataset)

<PrefetchDataset shapes: {text: (), title: ()}, types: {text: tf.string, title: tf.string}>


In [25]:
# Number of examples in the 'train' split, as reported by the dataset metadata.
TRAIN_NUM_EXAMPLES = dataset_info.splits['train'].num_examples
print('Total number of articles: ', TRAIN_NUM_EXAMPLES)

Total number of articles:  5824596


In [26]:
# Token -> index mapping: five special tokens (ids 0-4) followed by the
# upper-case letters and the apostrophe (ids 5-31), in this fixed order.
vocab_dict = {
    token: index
    for index, token in enumerate(
        ["<pad>", "<s>", "</s>", "<unk>", "|", *"ETAONIHSRDLUMWCFGYPBVK'XJQZ"]
    )
}

# Tokens in index order (dicts preserve insertion order).
vocab_list = list(vocab_dict)

In [27]:
def change_digit_to_word(x: str) -> str:
    """Spell out every ASCII digit in ``x`` as a lower-case English word.

    Each digit is replaced individually (``"12"`` becomes ``"one two"``).
    Runs of spaces are then collapsed to a single space and leading and
    trailing whitespace is stripped.

    Args:
        x: Input text.

    Returns:
        The text with digits spelled out and spaces normalised.
    """
    digit_words = {
        "0": "zero ",
        "1": "one ",
        "2": "two ",
        "3": "three ",
        "4": "four ",
        "5": "five ",
        "6": "six ",
        "7": "seven ",
        "8": "eight ",
        "9": "nine ",
    }
    x = x.translate(str.maketrans(digit_words))

    # One replace pass only halves each run of spaces ("   " -> "  "), so
    # repeat until no double space is left.
    while "  " in x:
        x = x.replace("  ", " ")

    return x.strip()

In [31]:
# Build the model over every vocabulary entry except the first one ("<pad>").
lm = LanguageModel(chars=vocab_list[1:])

In [32]:
# Train in batches: concatenate `sample_per_corpus` articles into one string
# and feed it to the model, so that a single huge corpus string is never held
# in memory.
#
# NOTE: lm.train() adds raw counts and lm.normalize() rescales the tables in
# place, so re-running this cell without re-creating `lm` first would add
# counts on top of already-normalised values.
#
# NOTE(review): the article text is passed in unchanged. The vocabulary holds
# only upper-case letters, "'" and "|", so lower-case letters, digits and
# spaces are skipped by train(). If that is not intended, normalise the text
# here (e.g. upper-case it, map spaces to "|", apply change_digit_to_word);
# confirm the intended preprocessing.
sample_per_corpus = 1000
corpus = ""
step = 0
for example in tqdm(dataset):
    title = example['title'].numpy().decode('utf-8')
    text = example['text'].numpy().decode('utf-8')
    corpus += title + " " + text
    step += 1
    if step == sample_per_corpus:
        lm.train(corpus)
        step = 0
        corpus = ""

# Train on the final partial batch; without this, the articles left over when
# the dataset size is not a multiple of `sample_per_corpus` are dropped.
if step:
    lm.train(corpus)
    step = 0
    corpus = ""

lm.normalize()

  0%|          | 0/5824596 [00:00<?, ?it/s]

In [35]:
# Persist the model tables as lm/unigram.pkl and lm/bigram.pkl. The directory
# is created first because LanguageModel.save() does not create it.
os.makedirs("lm", exist_ok=True)
lm.save("lm")

In [37]:
# Bundle the saved model directory into lm.zip (shell command).
!zip lm.zip -r lm

  adding: lm/ (stored 0%)
  adding: lm/unigram.pkl (deflated 19%)
  adding: lm/bigram.pkl (deflated 60%)
