In [44]:
import string
import xml.etree.ElementTree as ET
from collections import OrderedDict, defaultdict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, Embedding, GlobalAveragePooling1D
from trie import Trie
import re
import unicodedata
import tqdm

In [19]:
%load_ext tensorboard

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


## Load from xml file

In [2]:
#Import dataset
tree = ET.parse('lag1734.xml/lag1734.xml')
root = tree.getroot()
root.attrib

{'id': 'lag1734'}

In [3]:
chapters = []
for child in root:
    chapters.append(child)

sentances = []
for chapter in chapters:
    for paragraph in chapter:
        for sentence in paragraph:
            sentance = ''
            for child in sentence:
                sentance += f'{child.text} '
            sentances.append(sentance[:-1])
sentances

['D O M A R E R E G L E R. Någre almennelige Regler , ther en Domare skal sigh aldeles effter rätta .',
 'En Domare skal först besinna , at han en Gudz Befalningsman , och thet Embete han förer , thet hörer Gudh til , och icke honom sielffuom , och therföre hörer Domen , som han afsäger , Gudhi til , efter thet han afsagd warder i Gudz Embete på Gudz wegna , så at thet är wisserliga Gudz Dom , och icke Menniskiors .',
 'Och ty ligger Domaren ther Macht vppå , at han seer sigh wijsligen före , at han icke på Gudz wegna dömer en falskan Dom , med hwilken han dömer sig til en ewigh Fördömelse , effter thet han misbrukat Guds Dom och Befalning til Öffuerwold och Orätt , som til Rätt af Gudhi insatt är .',
 'Men ther han haffuer wilia til at döma Rätt , och ransakar grant effter sitt ytersta Förstånd om Rätten , och kan dock icke för sin Oförståndigheet finna på Rätten , och säger så en falsk Dom , tå haffuer han någor Vrsächt , at han är kommen på then falska Domen , emot sin wilia aff wåd

## Load from txt file

In [45]:
with open('wikipedia-sv.xml/wikipedia-sv.txt', 'r', encoding='utf-8') as f:
    words = f.readlines()
    sentances = [word.strip() for word in words]
len(sentances)

193900721

# Stopwords

In [21]:
# Import stopwords
with open('stopwords.txt', 'r', encoding='utf-8') as f:
    stop_words = [line.strip() for line in f.readlines()]
stop_words

['aderton',
 'adertonde',
 'adjö',
 'aldrig',
 'alla',
 'allas',
 'allt',
 'alltid',
 'alltså',
 'andra',
 'andras',
 'annan',
 'annat',
 'artonde',
 'artonn',
 'att',
 'av',
 'bakom',
 'bara',
 'behöva',
 'behövas',
 'behövde',
 'behövt',
 'beslut',
 'beslutat',
 'beslutit',
 'bland',
 'blev',
 'bli',
 'blir',
 'blivit',
 'bort',
 'borta',
 'bra',
 'bäst',
 'bättre',
 'båda',
 'bådas',
 'dag',
 'dagar',
 'dagarna',
 'dagen',
 'de',
 'del',
 'delen',
 'dem',
 'den',
 'denna',
 'deras',
 'dess',
 'dessa',
 'det',
 'detta',
 'dig',
 'din',
 'dina',
 'dit',
 'ditt',
 'dock',
 'dom',
 'du',
 'där',
 'därför',
 'då',
 'e',
 'efter',
 'eftersom',
 'ej',
 'elfte',
 'eller',
 'elva',
 'emot',
 'en',
 'enkel',
 'enkelt',
 'enkla',
 'enligt',
 'ens',
 'er',
 'era',
 'ers',
 'ert',
 'ett',
 'ettusen',
 'fanns',
 'fem',
 'femte',
 'femtio',
 'femtionde',
 'femton',
 'femtonde',
 'fick',
 'fin',
 'finnas',
 'finns',
 'fjorton',
 'fjortonde',
 'fjärde',
 'fler',
 'flera',
 'flesta',
 'fram',
 'framf

In [22]:
stopword_trie = Trie()
stopword_trie.add_multiple(*stop_words)
stop_words_regex = re.compile(r'\b' + stopword_trie.pattern() + r'\b', re.IGNORECASE)


In [4]:
#Pre-processing
for idx, sentance in enumerate(sentances):
    sentance = unicodedata.normalize('NFKC', sentance)
    # sentance = stop_words_regex.sub('', sentance)
    sentances[idx] = sentance

In [25]:
#Tokenization
tfidf = TfidfVectorizer(stop_words = stop_words)
X = tfidf.fit_transform(sentances)
print(*tfidf.get_feature_names_out())



In [26]:
feature_array = np.array(tfidf.get_feature_names_out())
tfidf_sorting = np.argsort(X.toarray().sum(axis=0))[::-1]
n = 100
top_n = feature_array[tfidf_sorting[:n]]
print(top_n)

['cap' 'thet' 'then' 'ther' 'til' 'at' 'tå' 'the' 'af' 'böte' 'daler'
 'sagdt' 'må' 'någor' 'ock' 'skal' '10' 'vare' 'them' 'hafver' '11' 'äro'
 'domaren' 'giör' 'konungens' 'åter' 'lag' 'hafve' 'sägs' 'saken' 'förr'
 'thes' 'gånge' 'gods' 'vil' '12' 'hvar' 'skada' 'tid' '13' 'bör' 'sker'
 'huru' 'rätten' 'hus' 'hofrätten' 'laga' 'annars' 'tijo' 'ware' 'äntå'
 'balken' 'ifrån' 'staden' 'sielf' 'jord' 'skadan' 'alt' 'skola' 'äger'
 'barn' 'landet' 'thertil' 'niute' 'plichte' 'hvad' 'konungen' 'stadgadt'
 '14' 'gifve' 'tage' 'hafva' 'gälde' 'finnes' 'varder' 'måge' 'theras'
 'tiugu' 'öfver' 'fä' 'thy' 'miste' 'up' 'åhr' 'hos' 'hafwer' 'emellan'
 'skiäl' 'therom' 'ske' 'arf' 'fängelse' 'befalningshafvande' 'mål' '16'
 '15' 'theraf' 'mans' 'stånde' 'lof']


In [27]:
feature_array = np.array(tfidf.get_feature_names_out())
tfidf_sorting = np.argsort(X.toarray().sum(axis=0))
n = 1000
top_n = feature_array[tfidf_sorting[:n]]
print(top_n)

['västerbotn' 'södermanland' 'skaraborgs' 'rautalambi' 'nyland' 'kalmare'
 'skåne' 'kymmenegårds' 'bohuslän' 'dahl' 'gestrikeland' 'ångermanland'
 'vermeland' 'ingifve' 'upland' 'nerike' 'kopparbergs' 'västmanland'
 'blekinge' 'halland' 'östergöthland' 'österbotn' 'herjedalen' 'jämteland'
 'gothland' 'helsingeland' 'småland' 'göta' 'tavastehus' 'lagmansdomen'
 'medelpad' 'elfsborgs' 'förändringar' 'riksdag' 'företrädare' 'brukliga'
 'adolph' 'samtelige' 'stadslagens' 'behöringen' 'sorgfällighet'
 'esomoftast' 'ändskap' 'berömmelig' 'brukelig' 'mångahanda' 'våre'
 'vittre' 'gångne' 'gustaf' 'förändrat' 'ändrade' 'hälsosamt' 'sedvana'
 'angelägit' 'förständiga' 'svårigheter' 'förbättra' '1731' '1618' 'daga'
 'öfversedde' 'emellankomna' 'förbättrade' 'önskan' 'högloflige' 'härtils'
 'ehuruväl' 'sinnande' 'svänska' 'nöigt' 'lagfarne' 'utgifvande'
 'utesluta' 'bringas' 'högtärade' 'lärda' 'moederfaders' 'grundval' 'fant'
 'förenad' 'månsons' 'stadfästades' 'efterlevande' 'spa' 'födelse'
 'p

In [5]:
counter = CountVectorizer()
Y = counter.fit_transform(sentances)
Y_sum = Y.toarray().sum(axis=0)
bpe_words = list(zip(map(lambda x: ' '.join(list(x)) + ' </w>', counter.get_feature_names_out()), Y_sum))
bpe_words

[('1 0 </w>', 58),
 ('1 0 0 </w>', 2),
 ('1 1 </w>', 40),
 ('1 1 3 </w>', 1),
 ('1 2 </w>', 29),
 ('1 2 8 </w>', 1),
 ('1 3 </w>', 35),
 ('1 3 3 </w>', 1),
 ('1 4 </w>', 23),
 ('1 4 4 2 </w>', 1),
 ('1 4 d e </w>', 1),
 ('1 5 </w>', 16),
 ('1 6 </w>', 21),
 ('1 6 0 8 </w>', 2),
 ('1 6 1 8 </w>', 1),
 ('1 6 8 6 </w>', 1),
 ('1 7 </w>', 12),
 ('1 7 3 1 </w>', 1),
 ('1 7 3 4 </w>', 3),
 ('1 7 3 6 </w>', 1),
 ('1 8 </w>', 11),
 ('1 9 </w>', 6),
 ('2 0 </w>', 6),
 ('2 1 </w>', 7),
 ('2 2 </w>', 8),
 ('2 3 </w>', 7),
 ('2 3 5 </w>', 1),
 ('2 4 </w>', 5),
 ('2 4 7 </w>', 1),
 ('2 5 </w>', 11),
 ('2 5 6 </w>', 1),
 ('2 6 </w>', 8),
 ('2 7 </w>', 8),
 ('2 8 </w>', 3),
 ('2 8 9 </w>', 1),
 ('2 9 </w>', 3),
 ('3 0 </w>', 8),
 ('3 0 8 </w>', 1),
 ('3 1 </w>', 3),
 ('3 2 </w>', 4),
 ('3 3 </w>', 2),
 ('3 4 </w>', 3),
 ('3 4 2 </w>', 1),
 ('3 5 </w>', 2),
 ('3 6 </w>', 3),
 ('3 7 </w>', 3),
 ('3 8 </w>', 1),
 ('3 9 </w>', 1),
 ('4 0 </w>', 1),
 ('4 1 </w>', 1),
 ('4 2 </w>', 3),
 ('4 2 3 </w>', 1),


In [6]:
def get_pair_stats(vocab: 'list[tuple[str, int]]'):
    pairs: 'dict[tuple[str], int]' = {}
    for word, frequency in vocab:
        symbols = word.split()

        # count occurrences of pairs
        for i in range(len(symbols) - 1):
            pair = (symbols[i], symbols[i + 1])
            current_frequency = pairs.get(pair, 0)
            pairs[pair] = current_frequency + frequency

    return pairs

pairs = get_pair_stats(bpe_words)
pairs

{('1', '0'): 60,
 ('0', '</w>'): 78,
 ('0', '0'): 4,
 ('1', '1'): 41,
 ('1', '</w>'): 52,
 ('1', '3'): 37,
 ('3', '</w>'): 48,
 ('1', '2'): 30,
 ('2', '</w>'): 46,
 ('2', '8'): 5,
 ('8', '</w>'): 20,
 ('3', '3'): 3,
 ('1', '4'): 25,
 ('4', '</w>'): 34,
 ('4', '4'): 1,
 ('4', '2'): 6,
 ('4', 'd'): 1,
 ('d', 'e'): 3829,
 ('e', '</w>'): 7118,
 ('1', '5'): 16,
 ('5', '</w>'): 30,
 ('1', '6'): 25,
 ('6', '</w>'): 35,
 ('6', '0'): 4,
 ('0', '8'): 3,
 ('6', '1'): 1,
 ('1', '8'): 12,
 ('6', '8'): 1,
 ('8', '6'): 1,
 ('1', '7'): 17,
 ('7', '</w>'): 26,
 ('7', '3'): 5,
 ('3', '1'): 4,
 ('3', '4'): 7,
 ('3', '6'): 4,
 ('1', '9'): 6,
 ('9', '</w>'): 11,
 ('2', '0'): 6,
 ('2', '1'): 7,
 ('2', '2'): 8,
 ('2', '3'): 9,
 ('3', '5'): 3,
 ('2', '4'): 6,
 ('4', '7'): 2,
 ('2', '5'): 12,
 ('5', '6'): 1,
 ('2', '6'): 8,
 ('2', '7'): 8,
 ('8', '9'): 1,
 ('2', '9'): 3,
 ('3', '0'): 9,
 ('3', '2'): 4,
 ('3', '7'): 3,
 ('3', '8'): 1,
 ('3', '9'): 1,
 ('4', '0'): 1,
 ('4', '1'): 1,
 ('5', '3'): 1,
 ('7', '0'): 

In [7]:
def merge_vocab(best_pair: 'tuple[str, str]', vocab_in: 'list[tuple[str, int]]'):

    vocab_out: 'dict[str, int]' = {}

    # re.escape
    # ensures the characters of our input pair will be handled as is and
    # not get mistreated as special characters in the regular expression.
    pattern = r'(?<=\b)' + re.escape(' '.join(best_pair)) + r'( |$)'
    replacement = ''.join(best_pair) + r'\g<1>'

    for word_in, freq in vocab_in:
        # replace most frequent pair in all vocabulary
        word_out = re.sub(pattern, replacement, word_in)
        vocab_out[word_out] = freq

    return [(word, frequency) for word, frequency in vocab_out.items()]

In [31]:
best_pair = max(pairs, key=pairs.get)
print(best_pair)

new_vocab = merge_vocab(best_pair, bpe_words)
new_vocab

('r', '</w>')


[('1 0 </w>', 58),
 ('1 0 0 </w>', 2),
 ('1 1 </w>', 40),
 ('1 1 3 </w>', 1),
 ('1 2 </w>', 29),
 ('1 2 8 </w>', 1),
 ('1 3 </w>', 35),
 ('1 3 3 </w>', 1),
 ('1 4 </w>', 23),
 ('1 4 4 2 </w>', 1),
 ('1 4 d e </w>', 1),
 ('1 5 </w>', 16),
 ('1 6 </w>', 21),
 ('1 6 0 8 </w>', 2),
 ('1 6 1 8 </w>', 1),
 ('1 6 8 6 </w>', 1),
 ('1 7 </w>', 12),
 ('1 7 3 1 </w>', 1),
 ('1 7 3 4 </w>', 3),
 ('1 7 3 6 </w>', 1),
 ('1 8 </w>', 11),
 ('1 9 </w>', 6),
 ('2 0 </w>', 6),
 ('2 1 </w>', 7),
 ('2 2 </w>', 8),
 ('2 3 </w>', 7),
 ('2 3 5 </w>', 1),
 ('2 4 </w>', 5),
 ('2 4 7 </w>', 1),
 ('2 5 </w>', 11),
 ('2 5 6 </w>', 1),
 ('2 6 </w>', 8),
 ('2 7 </w>', 8),
 ('2 8 </w>', 3),
 ('2 8 9 </w>', 1),
 ('2 9 </w>', 3),
 ('3 0 </w>', 8),
 ('3 0 8 </w>', 1),
 ('3 1 </w>', 3),
 ('3 2 </w>', 4),
 ('3 3 </w>', 2),
 ('3 4 </w>', 3),
 ('3 4 2 </w>', 1),
 ('3 5 </w>', 2),
 ('3 6 </w>', 3),
 ('3 7 </w>', 3),
 ('3 8 </w>', 1),
 ('3 9 </w>', 1),
 ('4 0 </w>', 1),
 ('4 1 </w>', 1),
 ('4 2 </w>', 3),
 ('4 2 3 </w>', 1),


In [8]:
bpe_codes = OrderedDict()
num_merges = 1000  # hyperparameter
vocab = bpe_words
for i in tqdm.tqdm(range(num_merges)):
    # print('\niteration', i)
    pair_stats = get_pair_stats(vocab)
    if not pair_stats:
        break

    best_pair = max(pair_stats, key=pair_stats.get)
    bpe_codes[best_pair] = i

    # print('vocabulary: ', vocab)
    # print('best pair:', best_pair)
    vocab = merge_vocab(best_pair, vocab)

print('\nfinal vocabulary: ', vocab)
print('\nbyte pair encoding: ', bpe_codes)


100%|██████████| 1000/1000 [00:32<00:00, 30.54it/s]


final vocabulary:  [('10</w>', 58), ('1 0 0</w>', 2), ('11</w>', 40), ('1 13</w>', 1), ('1 2</w>', 29), ('1 2 8 </w>', 1), ('13</w>', 35), ('1 3 3</w>', 1), ('1 4</w>', 23), ('1 4 4 2</w>', 1), ('1 4 de</w>', 1), ('1 5 </w>', 16), ('1 6</w>', 21), ('1 6 0 8 </w>', 2), ('1 6 1 8 </w>', 1), ('1 6 8 6</w>', 1), ('1 7 </w>', 12), ('1 7 3 1</w>', 1), ('1 7 3 4</w>', 3), ('1 7 3 6</w>', 1), ('1 8 </w>', 11), ('1 9 </w>', 6), ('2 0</w>', 6), ('2 1</w>', 7), ('2 2</w>', 8), ('2 3</w>', 7), ('2 3 5 </w>', 1), ('2 4</w>', 5), ('2 4 7 </w>', 1), ('2 5 </w>', 11), ('2 5 6</w>', 1), ('2 6</w>', 8), ('2 7 </w>', 8), ('2 8 </w>', 3), ('2 8 9 </w>', 1), ('2 9 </w>', 3), ('3 0</w>', 8), ('3 0 8 </w>', 1), ('3 1</w>', 3), ('3 2</w>', 4), ('3 3</w>', 2), ('3 4</w>', 3), ('3 4 2</w>', 1), ('3 5 </w>', 2), ('3 6</w>', 3), ('3 7 </w>', 3), ('3 8 </w>', 1), ('3 9 </w>', 1), ('4 0</w>', 1), ('4 1</w>', 1), ('4 2</w>', 3), ('4 2 3</w>', 1), ('4 7 </w>', 1), ('5 3</w>', 1), ('6 0</w>', 1), ('6 0 0</w>', 1), ('




In [11]:
bpe_codes

1000

In [37]:
string.printable

'0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ \t\n\r\x0b\x0c'

In [43]:
single_char_vocab = list(set((string.printable[:-5]+ 'åäö').lower())) + ['</w>', '[END]']
tokenizer_dict = defaultdict(lambda: 1)
tokenizer_dict['<pad>'] = 0
for idx, char in enumerate(single_char_vocab):
    tokenizer_dict[char] = idx+1
tokenizer_dict.update({''.join(word): idx+len(tokenizer_dict)+1 for word, idx in bpe_codes.items()})
# tokenizer_list
tokenizer_dict


defaultdict(<function __main__.<lambda>()>,
            {'<pad>': 0,
             ',': 1,
             ';': 2,
             '1': 3,
             '%': 4,
             '@': 5,
             '+': 6,
             '}': 7,
             '5': 8,
             'x': 9,
             'h': 10,
             ')': 11,
             '>': 12,
             'j': 13,
             ':': 14,
             'n': 15,
             'g': 16,
             '"': 17,
             'e': 18,
             'v': 19,
             '!': 20,
             'f': 21,
             's': 22,
             '-': 23,
             '?': 24,
             '_': 25,
             '~': 26,
             '<': 27,
             '4': 28,
             'p': 29,
             '`': 30,
             'y': 31,
             '2': 32,
             'o': 33,
             'k': 34,
             '$': 35,
             '7': 36,
             '.': 37,
             'q': 38,
             'i': 39,
             'b': 40,
             '[': 41,
             '&': 42,
             '#'

In [32]:
len(tokenizer_dict)

1127

### Regex challenges
1. Naive implementation - Slow (45 minutes entire 1734 law 1000 tokens)
2. Trie - No priority
3. \b Word boundaries - '>' counts as a word boundary
4. Tensorflow - Final solution (2 minutes entire 1734 law 5000 tokens)

In [31]:
def tokenize(corpus: list[str], bpe: 'OrderedDict[tuple[str, str], int]'):
    corpus = [sentance.lower() for sentance in corpus]
    words = []
    for sentance in corpus:
        words += [' '.join(list(word)) + ' </w>' for word in sentance.split()]
        words += ['[END]']
    str_tensor = tf.constant(words)
    print(len(words))
    for pair in tqdm.tqdm(bpe):
        pattern = r'(\b)' + re.escape(' '.join(pair)) + r'( |$)'
        replacement = r'\1' + ''.join(pair) + r'\2'
        str_tensor = tf.strings.regex_replace(str_tensor, pattern, replacement)  
    tokens = []
    token_sentance = []
    for idx, word in enumerate(str_tensor):
        word = word.numpy().decode('utf-8')
        bpe_tokens = word.split()
        tokenization = [tokenizer_dict[token] for token in bpe_tokens]
        token_sentance += tokenization
        if word == '[END]':
            tokens.append(token_sentance)
            token_sentance = []
    return tokens
    
tokenized_corpus = tokenize(sentances, bpe_codes)
print(*tokenized_corpus, sep='\n')

104357


100%|██████████| 1000/1000 [00:22<00:00, 45.43it/s]


[91, 143, 233, 80, 75, 79, 75, 79, 294, 101, 79, 61, 36, 72, 171, 15, 227, 116, 616, 14, 134, 677, 61, 17, 15, 63, 78, 0, 72, 131, 82, 882, 225, 21, 996, 116, 439, 159, 198, 673, 593, 36, 72, 73]
[82, 882, 225, 824, 151, 138, 658, 0, 72, 126, 129, 82, 779, 44, 72, 486, 165, 0, 72, 89, 118, 153, 611, 129, 110, 78, 0, 72, 118, 399, 61, 78, 779, 309, 127, 0, 72, 89, 509, 239, 385, 805, 90, 0, 72, 89, 901, 399, 61, 78, 660, 0, 72, 106, 129, 104, 259, 513, 0, 72, 779, 9, 351, 127, 0, 72, 230, 118, 129, 104, 287, 91, 696, 351, 779, 44, 72, 153, 611, 1005, 72, 779, 44, 72, 47, 17, 15, 177, 0, 72, 229, 72, 126, 118, 68, 75, 213, 352, 109, 473, 779, 44, 72, 370, 0, 72, 89, 509, 616, 357, 178, 128, 83, 36, 72, 73]
[89, 825, 388, 513, 291, 131, 595, 18, 522, 62, 72, 0, 72, 126, 129, 434, 78, 21, 996, 213, 12, 21, 388, 82, 431, 0, 72, 126, 129, 509, 1005, 72, 779, 44, 72, 47, 17, 15, 177, 202, 298, 82, 302, 95, 87, 370, 0, 72, 157, 9, 687, 360, 129, 202, 298, 211, 127, 82, 17, 47, 996, 110, 448, 3

# Embedding

In [33]:
inverse_vocab = {index: token for token, index in tokenizer_dict.items()}

In [34]:
window_size = 2
vocab_size = len([token for token in tokenizer_dict.values() if token != -1])
positive_skip_grams, _ = tf.keras.preprocessing.sequence.skipgrams(
      tokenized_corpus[0],
      vocabulary_size=vocab_size,
      window_size=window_size,
      negative_samples=0)
positive_skip_grams

[[198, 593],
 [116, 227],
 [233, 91],
 [15, 17],
 [72, 61],
 [14, 116],
 [91, 233],
 [79, 101],
 [159, 198],
 [171, 72],
 [15, 171],
 [80, 233],
 [79, 75],
 [79, 294],
 [101, 294],
 [61, 72],
 [198, 673],
 [673, 198],
 [63, 17],
 [72, 15],
 [72, 82],
 [14, 677],
 [616, 116],
 [593, 36],
 [225, 21],
 [439, 116],
 [882, 225],
 [75, 79],
 [143, 80],
 [72, 131],
 [227, 15],
 [72, 171],
 [673, 159],
 [17, 15],
 [79, 101],
 [134, 61],
 [171, 227],
 [63, 15],
 [131, 882],
 [159, 439],
 [63, 78],
 [116, 21],
 [15, 63],
 [143, 233],
 [79, 294],
 [21, 996],
 [439, 159],
 [294, 79],
 [225, 82],
 [171, 15],
 [996, 225],
 [21, 225],
 [82, 72],
 [101, 79],
 [116, 439],
 [198, 159],
 [73, 36],
 [80, 143],
 [61, 79],
 [61, 17],
 [73, 72],
 [233, 80],
 [61, 101],
 [225, 882],
 [72, 36],
 [14, 134],
 [882, 131],
 [75, 294],
 [79, 79],
 [36, 61],
 [677, 14],
 [79, 75],
 [79, 79],
 [36, 673],
 [225, 996],
 [72, 593],
 [134, 677],
 [116, 996],
 [15, 61],
 [134, 616],
 [82, 225],
 [78, 63],
 [198, 439],
 [9

In [36]:
print([inverse_vocab[t] for t in tokenized_corpus[0]])

['d</w>', 'o</w>', 'm</w>', 'a</w>', 'r</w>', 'e</w>', 'r</w>', 'e</w>', 'g</w>', 'l</w>', 'e</w>', 'r', '.', '</w>', 'nå', 'g', 're</w>', 'al', 'men', 'n', 'el', 'ige</w>', 'r', 'e', 'g', 'l', 'er</w>', ',', '</w>', 'ther</w>', 'en</w>', 'domare</w>', 'skal</w>', 's', 'igh</w>', 'al', 'del', 'es</w>', 'ef', 'fter</w>', 'rätta</w>', '.', '</w>', '[END]']


In [37]:
for target, context in positive_skip_grams[:5]:
  print(f"({target}, {context}): ({inverse_vocab[target]}, {inverse_vocab[context]})")

(198, 593): (ef, rätta</w>)
(116, 227): (al, re</w>)
(233, 91): (m</w>, d</w>)
(15, 17): (g, e)
(72, 61): (</w>, r)


In [38]:
# Get target and context words for one positive skip-gram.
target_word, context_word = positive_skip_grams[0]

# Set the number of negative samples per positive context.
num_ns = 4

context_class = tf.reshape(tf.constant(context_word, dtype="int64"), (1, 1))
negative_sampling_candidates, _, _ = tf.random.log_uniform_candidate_sampler(
    true_classes=context_class,  # class that should be sampled as 'positive'
    num_true=1,  # each positive skip-gram has 1 positive context class
    num_sampled=num_ns,  # number of negative context words to sample
    unique=True,  # all the negative samples should be unique
    range_max=vocab_size,  # pick index of the samples from [0, vocab_size]
    # seed=SEED,  # seed for reproducibility
    name="negative_sampling"  # name of this operation
)
print(negative_sampling_candidates)
print([inverse_vocab[index.numpy()] for index in negative_sampling_candidates])

tf.Tensor([663  87 244   6], shape=(4,), dtype=int64)
['förs', 'an</w>', 'da</w>', '}']


In [39]:
# Reduce a dimension so you can use concatenation (in the next step).
squeezed_context_class = tf.squeeze(context_class, 1)

# Concatenate a positive context word with negative sampled words.
context = tf.concat([squeezed_context_class, negative_sampling_candidates], 0)

# Label the first context word as `1` (positive) followed by `num_ns` `0`s (negative).
label = tf.constant([1] + [0]*num_ns, dtype="int64")
target = target_word

In [40]:
print(f"target_index    : {target}")
print(f"target_word     : {inverse_vocab[target_word]}")
print(f"context_indices : {context}")
print(f"context_words   : {[inverse_vocab[c.numpy()] for c in context]}")
print(f"label           : {label}")

target_index    : 198
target_word     : ef
context_indices : [593 663  87 244   6]
context_words   : ['rätta</w>', 'förs', 'an</w>', 'da</w>', '}']
label           : [1 0 0 0 0]


In [53]:
print("target  :", target)
print("context :", context)
print("label   :", label)

target  : 74
context : tf.Tensor([513   1   6 803 116], shape=(5,), dtype=int64)
label   : tf.Tensor([1 0 0 0 0], shape=(5,), dtype=int64)


In [41]:
# Generates skip-gram pairs with negative sampling for a list of sequences
# (int-encoded sentences) based on window size, number of negative samples
# and vocabulary size.
def generate_training_data(sequences, window_size, num_ns, vocab_size, seed):
  # Elements of each training example are appended to these lists.
  targets, contexts, labels = [], [], []

  # Build the sampling table for `vocab_size` tokens.
  sampling_table = tf.keras.preprocessing.sequence.make_sampling_table(vocab_size)

  # Iterate over all sequences (sentences) in the dataset.
  for sequence in tqdm.tqdm(sequences):

    # Generate positive skip-gram pairs for a sequence (sentence).
    positive_skip_grams, _ = tf.keras.preprocessing.sequence.skipgrams(
          sequence,
          vocabulary_size=vocab_size,
          sampling_table=sampling_table,
          window_size=window_size,
          negative_samples=0)

    # Iterate over each positive skip-gram pair to produce training examples
    # with a positive context word and negative samples.
    for target_word, context_word in positive_skip_grams:
      context_class = tf.expand_dims(
          tf.constant([context_word], dtype="int64"), 1)
      negative_sampling_candidates, _, _ = tf.random.log_uniform_candidate_sampler(
          true_classes=context_class,
          num_true=1,
          num_sampled=num_ns,
          unique=True,
          range_max=vocab_size,
          seed=seed,
          name="negative_sampling")

      # Build context and label vectors (for one target word)
      context = tf.concat([tf.squeeze(context_class,1), negative_sampling_candidates], 0)
      label = tf.constant([1] + [0]*num_ns, dtype="int64")

      # Append each element from the training example to global lists.
      targets.append(target_word)
      contexts.append(context)
      labels.append(label)

  return targets, contexts, labels

In [42]:
targets, contexts, labels = generate_training_data(tokenized_corpus, window_size=2, num_ns=4, vocab_size=vocab_size, seed=42)
targets, contexts, labels = np.array(targets), np.array(contexts), np.array(labels)
print('\n')
print(f"targets.shape: {targets.shape}")
print(f"contexts.shape: {contexts.shape}")
print(f"labels.shape: {labels.shape}")

  0%|          | 23/6237 [00:00<00:55, 111.53it/s]


InvalidArgumentError: {{function_node __wrapped__LogUniformCandidateSampler_device_/job:localhost/replica:0/task:0/device:CPU:0}} `true_candidate` out of range [0, 1074), received -1 [Op:LogUniformCandidateSampler] name: negative_sampling

In [19]:
BATCH_SIZE = 32
BUFFER_SIZE = 10000
dataset = tf.data.Dataset.from_tensor_slices(((targets, contexts), labels))
dataset = dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)
dataset = dataset.cache().prefetch(buffer_size=tf.data.AUTOTUNE)

In [20]:
class EmbeddingModel(tf.keras.Model):
  def __init__(self, vocab_size, embedding_dim):
    super(EmbeddingModel, self).__init__()
    self.target_embedding = Embedding(vocab_size,
                                      embedding_dim,
                                      input_length=1,
                                      name="w2v_embedding")
    self.context_embedding = Embedding(vocab_size,
                                       embedding_dim,
                                       input_length=num_ns+1)

  def call(self, pair):
    target, context = pair
    # target: (batch, dummy?)  # The dummy axis doesn't exist in TF2.7+
    # context: (batch, context)
    if len(target.shape) == 2:
      target = tf.squeeze(target, axis=1)
    # target: (batch,)
    word_emb = self.target_embedding(target)
    # word_emb: (batch, embed)
    context_emb = self.context_embedding(context)
    # context_emb: (batch, context, embed)
    dots = tf.einsum('be,bce->bc', word_emb, context_emb)
    # dots: (batch, context)
    return dots

In [21]:
embedding_dim = 32

model = EmbeddingModel(vocab_size, embedding_dim)

tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir='logs')
early_stopping_callback = tf.keras.callbacks.EarlyStopping(monitor='accuracy', min_delta=0.005, patience=2)

model.compile(optimizer='adam', loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True), metrics=['accuracy'])
model.fit(dataset, epochs=100, callbacks=[tensorboard_callback, early_stopping_callback])

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100


<keras.callbacks.History at 0x1d6511edd50>

In [None]:
model.summary()

Model: "embedding_model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 w2v_embedding (Embedding)   multiple                  165792    
                                                                 
 embedding_1 (Embedding)     multiple                  165792    
                                                                 
Total params: 331,584
Trainable params: 331,584
Non-trainable params: 0
_________________________________________________________________


In [60]:
%tensorboard --logdir logs

In [22]:
weights = model.get_layer('w2v_embedding').get_weights()[0]
vocab = list(tokenizer_dict.keys())

In [24]:
import io
out_v = io.open('vectors.tsv', 'w', encoding='utf-8')
out_m = io.open('metadata.tsv', 'w', encoding='utf-8')

out_v.write('\t'.join([str(0) for _ in range(embedding_dim)]) + '\n')
out_m.write('<pad>\t0\n')
out_v.write('\t'.join([str(0) for _ in range(embedding_dim)]) + '\n')
out_m.write('<oov>\t1\n')

for index, (word, token) in enumerate(tokenizer_dict.items()):
  if index == 0:
    continue  # skip 0, it's padding.
  vec = weights[index]
  # out_v.write(word + ',')
  out_v.write('\t'.join([str(x) for x in vec]) + "\n")
  out_m.write(f'{word}\t{token}\n')
out_v.close()
out_m.close()

In [28]:
print(*tokenizer_dict.items(), sep='\n')

('<pad>', 0)
(',', 2)
(';', 3)
('1', 4)
('%', 5)
('@', 6)
('+', 7)
('}', 8)
('5', 9)
('x', 10)
('h', 11)
(')', 12)
('>', 13)
('j', 14)
(':', 15)
('n', 16)
('g', 17)
('"', 18)
('e', 19)
('v', 20)
('!', 21)
('f', 22)
('s', 23)
('-', 24)
('?', 25)
('_', 26)
('~', 27)
('<', 28)
('4', 29)
('p', 30)
('`', 31)
('y', 32)
('2', 33)
('o', 34)
('k', 35)
('$', 36)
('7', 37)
('.', 38)
('q', 39)
('i', 40)
('b', 41)
('[', 42)
('&', 43)
('#', 44)
('m', 45)
('z', 46)
('d', 47)
('(', 48)
('w', 49)
(' ', 50)
('{', 51)
('3', 52)
('/', 53)
('^', 54)
('|', 55)
('u', 56)
(']', 57)
('c', 58)
('a', 59)
('0', 60)
('\\', 61)
('9', 62)
('r', 63)
('å', 64)
('l', 65)
('ö', 66)
('6', 67)
("'", 68)
('*', 69)
('ä', 70)
('8', 71)
('=', 72)
('t', 73)
('</w>', 74)
('[END]', 75)
('r</w>', 77)
('n</w>', 78)
('t</w>', 79)
('er</w>', 80)
('e</w>', 81)
('a</w>', 82)
('th', 83)
('en</w>', 84)
('s</w>', 85)
('om', 86)
('ch', 87)
('ll', 88)
('an</w>', 89)
('och', 90)
('och</w>', 91)
('om</w>', 92)
('d</w>', 93)
('ar', 94)
('ti',

# TODO
* <s>Load stopwords from file</s>
* <s>Remove stopwords from corpus and vocab</s>
* <s>Use entire law corpus</s>
* <s>Improve BPE performance</s>
* Use Wikipedia corpus
* Create embedder from weights

In [68]:
import csv
import pickle
import tensorflow as tf
import tqdm

class TokenizerEmbedder():
    PAD = 0

    def __init__(self, bpe_file: str, vectors_file: str, tokens_file: str):
        self._embedder_table = []
        self._tokens = defaultdict(lambda: 1)

        with open(bpe_file, 'rb') as fd:
            self._bpe = pickle.load(fd)

        with open(vectors_file, encoding='utf-8') as fd:
            rd = csv.reader(fd, delimiter="\t", quotechar='"')
            for row in rd:
                self._embedder_table.append([float(v) for v in row])

        with open(tokens_file, encoding='utf-8') as fd:
            rd = csv.reader(fd, delimiter="\t", quotechar=None)
            for idx, row in enumerate(rd):
                self._tokens[row[0]] = idx+2

        self._inverse_tokens = {index: token for token, index in self._tokens.items()}
                
    def tokenize(self, corpus: list[str]):
        corpus = [sentance.lower() for sentance in corpus]
        words = []
        for sentance in corpus:
            words += [' '.join(list(word)) + ' </w>' for word in sentance.split()]
            words += ['[END]']
        print(words)
        str_tensor = tf.constant(words)
        for pair in tqdm.tqdm(self._bpe):
            pattern = r'(\b)' + re.escape(' '.join(pair)) + r'( |$)'
            replacement = r'\1' + ''.join(pair) + r'\2'
            str_tensor = tf.strings.regex_replace(str_tensor, pattern, replacement)  
        tokens: list[list[int]] = []
        tokens_sentance = []
        for word in str_tensor:
            word = word.numpy().decode('utf-8')
            bpe_tokens = word.split()
            tokenization = [self._tokens[token] for token in bpe_tokens]
            tokens_sentance += tokenization
            if word == '[END]':
                tokens.append(tokens_sentance)
                tokens_sentance = []
        return tokens

    def embed(self, tokens: list[list[int]]) -> list[list[list[float]]]:
        embedding = []
        for sentance in tokens:
            embeddings_sentance = []
            for token in sentance:
                embeddings_sentance.append(self._embedder_table[token-2])
            embedding.append(embeddings_sentance)
        return embedding
    
    def __call__(self, corpus: list[str]):
        tokens = self.tokenize(corpus)
        embeddings = self.embed(tokens)
        return embeddings

    def __getitem__(self, key):
        if isinstance(key, int):
            return (self._inverse_tokens[key], self._embedder_table[key-2])
        elif isinstance(key, str):
            return self._tokens[key]
        else:
            raise TypeError(f'{key} is not a valid indexing type.')
            

te = TokenizerEmbedder('bpe.pckl', 'vectors.tsv', 'metadata.tsv')
print(te(['Victor och Ahmed har gjort en embedder']))
print(te['och'])
print(te[169])


['v i c t o r </w>', 'o c h </w>', 'a h m e d </w>', 'h a r </w>', 'g j o r t </w>', 'e n </w>', 'e m b e d d e r </w>', '[END]']


100%|██████████| 5000/5000 [00:00<00:00, 8965.46it/s] 

[[[0.65323406, -0.27303812, -0.30425933, -1.0413593, -0.7426681, -0.20532407, 0.04982093, 0.08311119, 0.1765727, -0.6403428, -0.22497958, -0.78932726, -0.054533683, 0.033854883, 0.7027275, 0.45232072, 0.49789938, -0.18292224, -0.41658378, -0.34115317, -0.50937414, -0.5524997, 0.45004642, -0.6512374, 1.3526832, 0.64217854, -0.32416457, -0.49196768, -1.0151479, 0.11888651, -0.18898936, 0.74240786], [-0.01636193, 0.031975772, -0.041453518, -0.014624942, 0.010977149, -0.0025534257, -0.042255737, -0.004006099, -0.028320504, 0.01004646, -0.045394193, 0.02495325, -0.043610133, 0.041049752, 0.013031509, -0.041475464, -0.036164366, 0.040030446, 0.01384211, -0.03294321, -0.03575196, 0.04388896, 0.01328193, 0.011511408, -0.015473496, -0.02524804, 0.04039445, -0.015838575, -0.029515004, 0.016001377, -0.0066956766, 0.029326748], [-0.029489195, -0.017436467, 0.043717716, -0.016688216, -0.008654725, -0.016261052, -0.0192454, 0.023500714, 0.03389578, -0.03577707, -0.031029522, -0.02989391, -0.03589612




89