In [33]:
import string
import xml.etree.ElementTree as ET
from collections import OrderedDict, defaultdict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, Embedding, GlobalAveragePooling1D
from trie import Trie
import re
import unicodedata
import tqdm

In [19]:
%load_ext tensorboard

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


In [20]:
#Import dataset
tree = ET.parse('lag1734.xml/lag1734.xml')
root = tree.getroot()
root.attrib

{'id': 'lag1734'}

In [21]:
# Import stopwords
with open('stopwords.txt', 'r', encoding='utf-8') as f:
    stop_words = [line.strip() for line in f.readlines()]
stop_words

['aderton',
 'adertonde',
 'adjö',
 'aldrig',
 'alla',
 'allas',
 'allt',
 'alltid',
 'alltså',
 'andra',
 'andras',
 'annan',
 'annat',
 'artonde',
 'artonn',
 'att',
 'av',
 'bakom',
 'bara',
 'behöva',
 'behövas',
 'behövde',
 'behövt',
 'beslut',
 'beslutat',
 'beslutit',
 'bland',
 'blev',
 'bli',
 'blir',
 'blivit',
 'bort',
 'borta',
 'bra',
 'bäst',
 'bättre',
 'båda',
 'bådas',
 'dag',
 'dagar',
 'dagarna',
 'dagen',
 'de',
 'del',
 'delen',
 'dem',
 'den',
 'denna',
 'deras',
 'dess',
 'dessa',
 'det',
 'detta',
 'dig',
 'din',
 'dina',
 'dit',
 'ditt',
 'dock',
 'dom',
 'du',
 'där',
 'därför',
 'då',
 'e',
 'efter',
 'eftersom',
 'ej',
 'elfte',
 'eller',
 'elva',
 'emot',
 'en',
 'enkel',
 'enkelt',
 'enkla',
 'enligt',
 'ens',
 'er',
 'era',
 'ers',
 'ert',
 'ett',
 'ettusen',
 'fanns',
 'fem',
 'femte',
 'femtio',
 'femtionde',
 'femton',
 'femtonde',
 'fick',
 'fin',
 'finnas',
 'finns',
 'fjorton',
 'fjortonde',
 'fjärde',
 'fler',
 'flera',
 'flesta',
 'fram',
 'framf

In [22]:
stopword_trie = Trie()
stopword_trie.add_multiple(*stop_words)
stop_words_regex = re.compile(r'\b' + stopword_trie.pattern() + r'\b', re.IGNORECASE)


In [23]:
chapters = []
for child in root:
    chapters.append(child)

sentances = []
for chapter in chapters:
    for paragraph in chapter:
        for sentence in paragraph:
            sentance = ''
            for child in sentence:
                sentance += f'{child.text} '
            sentances.append(sentance[:-1])
sentances

['D O M A R E R E G L E R. Någre almennelige Regler , ther en Domare skal sigh aldeles effter rätta .',
 'En Domare skal först besinna , at han en Gudz Befalningsman , och thet Embete han förer , thet hörer Gudh til , och icke honom sielffuom , och therföre hörer Domen , som han afsäger , Gudhi til , efter thet han afsagd warder i Gudz Embete på Gudz wegna , så at thet är wisserliga Gudz Dom , och icke Menniskiors .',
 'Och ty ligger Domaren ther Macht vppå , at han seer sigh wijsligen före , at han icke på Gudz wegna dömer en falskan Dom , med hwilken han dömer sig til en ewigh Fördömelse , effter thet han misbrukat Guds Dom och Befalning til Öffuerwold och Orätt , som til Rätt af Gudhi insatt är .',
 'Men ther han haffuer wilia til at döma Rätt , och ransakar grant effter sitt ytersta Förstånd om Rätten , och kan dock icke för sin Oförståndigheet finna på Rätten , och säger så en falsk Dom , tå haffuer han någor Vrsächt , at han är kommen på then falska Domen , emot sin wilia aff wåd

In [24]:
#Pre-processing
for idx, sentance in enumerate(sentances):
    sentance = unicodedata.normalize('NFKC', sentance)
    sentance = stop_words_regex.sub('', sentance)
    sentances[idx] = sentance

In [25]:
#Tokenization
tfidf = TfidfVectorizer(stop_words = stop_words)
X = tfidf.fit_transform(sentances)
print(*tfidf.get_feature_names_out())



In [26]:
feature_array = np.array(tfidf.get_feature_names_out())
tfidf_sorting = np.argsort(X.toarray().sum(axis=0))[::-1]
n = 100
top_n = feature_array[tfidf_sorting[:n]]
print(top_n)

['cap' 'thet' 'then' 'ther' 'til' 'at' 'tå' 'the' 'af' 'böte' 'daler'
 'sagdt' 'må' 'någor' 'ock' 'skal' '10' 'vare' 'them' 'hafver' '11' 'äro'
 'domaren' 'giör' 'konungens' 'åter' 'lag' 'hafve' 'sägs' 'saken' 'förr'
 'thes' 'gånge' 'gods' 'vil' '12' 'hvar' 'skada' 'tid' '13' 'bör' 'sker'
 'huru' 'rätten' 'hus' 'hofrätten' 'laga' 'annars' 'tijo' 'ware' 'äntå'
 'balken' 'ifrån' 'staden' 'sielf' 'jord' 'skadan' 'alt' 'skola' 'äger'
 'barn' 'landet' 'thertil' 'niute' 'plichte' 'hvad' 'konungen' 'stadgadt'
 '14' 'gifve' 'tage' 'hafva' 'gälde' 'finnes' 'varder' 'måge' 'theras'
 'tiugu' 'öfver' 'fä' 'thy' 'miste' 'up' 'åhr' 'hos' 'hafwer' 'emellan'
 'skiäl' 'therom' 'ske' 'arf' 'fängelse' 'befalningshafvande' 'mål' '16'
 '15' 'theraf' 'mans' 'stånde' 'lof']


In [27]:
feature_array = np.array(tfidf.get_feature_names_out())
tfidf_sorting = np.argsort(X.toarray().sum(axis=0))
n = 1000
top_n = feature_array[tfidf_sorting[:n]]
print(top_n)

['västerbotn' 'södermanland' 'skaraborgs' 'rautalambi' 'nyland' 'kalmare'
 'skåne' 'kymmenegårds' 'bohuslän' 'dahl' 'gestrikeland' 'ångermanland'
 'vermeland' 'ingifve' 'upland' 'nerike' 'kopparbergs' 'västmanland'
 'blekinge' 'halland' 'östergöthland' 'österbotn' 'herjedalen' 'jämteland'
 'gothland' 'helsingeland' 'småland' 'göta' 'tavastehus' 'lagmansdomen'
 'medelpad' 'elfsborgs' 'förändringar' 'riksdag' 'företrädare' 'brukliga'
 'adolph' 'samtelige' 'stadslagens' 'behöringen' 'sorgfällighet'
 'esomoftast' 'ändskap' 'berömmelig' 'brukelig' 'mångahanda' 'våre'
 'vittre' 'gångne' 'gustaf' 'förändrat' 'ändrade' 'hälsosamt' 'sedvana'
 'angelägit' 'förständiga' 'svårigheter' 'förbättra' '1731' '1618' 'daga'
 'öfversedde' 'emellankomna' 'förbättrade' 'önskan' 'högloflige' 'härtils'
 'ehuruväl' 'sinnande' 'svänska' 'nöigt' 'lagfarne' 'utgifvande'
 'utesluta' 'bringas' 'högtärade' 'lärda' 'moederfaders' 'grundval' 'fant'
 'förenad' 'månsons' 'stadfästades' 'efterlevande' 'spa' 'födelse'
 'p

In [28]:
counter = CountVectorizer(stop_words=stop_words)
Y = counter.fit_transform(sentances)
Y_sum = Y.toarray().sum(axis=0)
bpe_words = list(zip(map(lambda x: ' '.join(list(x)) + ' </w>', counter.get_feature_names_out()), Y_sum))
bpe_words

[('1 0 </w>', 58),
 ('1 0 0 </w>', 2),
 ('1 1 </w>', 40),
 ('1 1 3 </w>', 1),
 ('1 2 </w>', 29),
 ('1 2 8 </w>', 1),
 ('1 3 </w>', 35),
 ('1 3 3 </w>', 1),
 ('1 4 </w>', 23),
 ('1 4 4 2 </w>', 1),
 ('1 4 d e </w>', 1),
 ('1 5 </w>', 16),
 ('1 6 </w>', 21),
 ('1 6 0 8 </w>', 2),
 ('1 6 1 8 </w>', 1),
 ('1 6 8 6 </w>', 1),
 ('1 7 </w>', 12),
 ('1 7 3 1 </w>', 1),
 ('1 7 3 4 </w>', 3),
 ('1 7 3 6 </w>', 1),
 ('1 8 </w>', 11),
 ('1 9 </w>', 6),
 ('2 0 </w>', 6),
 ('2 1 </w>', 7),
 ('2 2 </w>', 8),
 ('2 3 </w>', 7),
 ('2 3 5 </w>', 1),
 ('2 4 </w>', 5),
 ('2 4 7 </w>', 1),
 ('2 5 </w>', 11),
 ('2 5 6 </w>', 1),
 ('2 6 </w>', 8),
 ('2 7 </w>', 8),
 ('2 8 </w>', 3),
 ('2 8 9 </w>', 1),
 ('2 9 </w>', 3),
 ('3 0 </w>', 8),
 ('3 0 8 </w>', 1),
 ('3 1 </w>', 3),
 ('3 2 </w>', 4),
 ('3 3 </w>', 2),
 ('3 4 </w>', 3),
 ('3 4 2 </w>', 1),
 ('3 5 </w>', 2),
 ('3 6 </w>', 3),
 ('3 7 </w>', 3),
 ('3 8 </w>', 1),
 ('3 9 </w>', 1),
 ('4 0 </w>', 1),
 ('4 1 </w>', 1),
 ('4 2 </w>', 3),
 ('4 2 3 </w>', 1),


In [29]:
def get_pair_stats(vocab: 'list[tuple[str, int]]'):
    pairs: 'dict[tuple[str], int]' = {}
    for word, frequency in vocab:
        symbols = word.split()

        # count occurrences of pairs
        for i in range(len(symbols) - 1):
            pair = (symbols[i], symbols[i + 1])
            current_frequency = pairs.get(pair, 0)
            pairs[pair] = current_frequency + frequency

    return pairs

pairs = get_pair_stats(bpe_words)
pairs

{('1', '0'): 60,
 ('0', '</w>'): 78,
 ('0', '0'): 4,
 ('1', '1'): 41,
 ('1', '</w>'): 52,
 ('1', '3'): 37,
 ('3', '</w>'): 48,
 ('1', '2'): 30,
 ('2', '</w>'): 46,
 ('2', '8'): 5,
 ('8', '</w>'): 20,
 ('3', '3'): 3,
 ('1', '4'): 25,
 ('4', '</w>'): 34,
 ('4', '4'): 1,
 ('4', '2'): 6,
 ('4', 'd'): 1,
 ('d', 'e'): 3561,
 ('e', '</w>'): 6434,
 ('1', '5'): 16,
 ('5', '</w>'): 30,
 ('1', '6'): 25,
 ('6', '</w>'): 35,
 ('6', '0'): 4,
 ('0', '8'): 3,
 ('6', '1'): 1,
 ('1', '8'): 12,
 ('6', '8'): 1,
 ('8', '6'): 1,
 ('1', '7'): 17,
 ('7', '</w>'): 26,
 ('7', '3'): 5,
 ('3', '1'): 4,
 ('3', '4'): 7,
 ('3', '6'): 4,
 ('1', '9'): 6,
 ('9', '</w>'): 11,
 ('2', '0'): 6,
 ('2', '1'): 7,
 ('2', '2'): 8,
 ('2', '3'): 9,
 ('3', '5'): 3,
 ('2', '4'): 6,
 ('4', '7'): 2,
 ('2', '5'): 12,
 ('5', '6'): 1,
 ('2', '6'): 8,
 ('2', '7'): 8,
 ('8', '9'): 1,
 ('2', '9'): 3,
 ('3', '0'): 9,
 ('3', '2'): 4,
 ('3', '7'): 3,
 ('3', '8'): 1,
 ('3', '9'): 1,
 ('4', '0'): 1,
 ('4', '1'): 1,
 ('5', '3'): 1,
 ('7', '0'): 

In [30]:
def merge_vocab(best_pair: 'tuple[str, str]', vocab_in: 'list[tuple[str, int]]'):

    vocab_out: 'dict[str, int]' = {}

    # re.escape
    # ensures the characters of our input pair will be handled as is and
    # not get mistreated as special characters in the regular expression.
    pattern = r'(?<=\b)' + re.escape(' '.join(best_pair)) + r'( |$)'
    replacement = ''.join(best_pair) + r'\g<1>'

    for word_in, freq in vocab_in:
        # replace most frequent pair in all vocabulary
        word_out = re.sub(pattern, replacement, word_in)
        vocab_out[word_out] = freq

    return [(word, frequency) for word, frequency in vocab_out.items()]

In [31]:
best_pair = max(pairs, key=pairs.get)
print(best_pair)

new_vocab = merge_vocab(best_pair, bpe_words)
new_vocab

('r', '</w>')


[('1 0 </w>', 58),
 ('1 0 0 </w>', 2),
 ('1 1 </w>', 40),
 ('1 1 3 </w>', 1),
 ('1 2 </w>', 29),
 ('1 2 8 </w>', 1),
 ('1 3 </w>', 35),
 ('1 3 3 </w>', 1),
 ('1 4 </w>', 23),
 ('1 4 4 2 </w>', 1),
 ('1 4 d e </w>', 1),
 ('1 5 </w>', 16),
 ('1 6 </w>', 21),
 ('1 6 0 8 </w>', 2),
 ('1 6 1 8 </w>', 1),
 ('1 6 8 6 </w>', 1),
 ('1 7 </w>', 12),
 ('1 7 3 1 </w>', 1),
 ('1 7 3 4 </w>', 3),
 ('1 7 3 6 </w>', 1),
 ('1 8 </w>', 11),
 ('1 9 </w>', 6),
 ('2 0 </w>', 6),
 ('2 1 </w>', 7),
 ('2 2 </w>', 8),
 ('2 3 </w>', 7),
 ('2 3 5 </w>', 1),
 ('2 4 </w>', 5),
 ('2 4 7 </w>', 1),
 ('2 5 </w>', 11),
 ('2 5 6 </w>', 1),
 ('2 6 </w>', 8),
 ('2 7 </w>', 8),
 ('2 8 </w>', 3),
 ('2 8 9 </w>', 1),
 ('2 9 </w>', 3),
 ('3 0 </w>', 8),
 ('3 0 8 </w>', 1),
 ('3 1 </w>', 3),
 ('3 2 </w>', 4),
 ('3 3 </w>', 2),
 ('3 4 </w>', 3),
 ('3 4 2 </w>', 1),
 ('3 5 </w>', 2),
 ('3 6 </w>', 3),
 ('3 7 </w>', 3),
 ('3 8 </w>', 1),
 ('3 9 </w>', 1),
 ('4 0 </w>', 1),
 ('4 1 </w>', 1),
 ('4 2 </w>', 3),
 ('4 2 3 </w>', 1),


In [35]:
bpe_codes = OrderedDict()
num_merges = 5000  # hyperparameter
vocab = bpe_words
for i in tqdm.tqdm(range(num_merges)):
    # print('\niteration', i)
    pair_stats = get_pair_stats(vocab)
    if not pair_stats:
        break

    best_pair = max(pair_stats, key=pair_stats.get)
    bpe_codes[best_pair] = i

    # print('vocabulary: ', vocab)
    # print('best pair:', best_pair)
    vocab = merge_vocab(best_pair, vocab)

print('\nfinal vocabulary: ', vocab)
print('\nbyte pair encoding: ', bpe_codes)


100%|██████████| 5000/5000 [10:04<00:00,  8.27it/s]


final vocabulary:  [('10</w>', 58), ('100</w>', 2), ('11</w>', 40), ('1 13</w>', 1), ('12</w>', 29), ('1 28</w>', 1), ('13</w>', 35), ('1 33</w>', 1), ('14</w>', 23), ('14 42</w>', 1), ('14 de</w>', 1), ('15</w>', 16), ('16</w>', 21), ('1608</w>', 2), ('16 18</w>', 1), ('16 8 6</w>', 1), ('17</w>', 12), ('17 31</w>', 1), ('1734</w>', 3), ('17 36</w>', 1), ('18</w>', 11), ('19</w>', 6), ('20</w>', 6), ('21</w>', 7), ('22</w>', 8), ('23</w>', 7), ('2 35</w>', 1), ('24</w>', 5), ('2 47</w>', 1), ('25</w>', 11), ('2 5 6</w>', 1), ('26</w>', 8), ('27</w>', 8), ('28</w>', 3), ('2 8 9</w>', 1), ('29</w>', 3), ('30</w>', 8), ('3 08</w>', 1), ('31</w>', 3), ('32</w>', 4), ('33</w>', 2), ('34</w>', 3), ('3 42</w>', 1), ('35</w>', 2), ('36</w>', 3), ('37</w>', 3), ('3 8</w>', 1), ('3 9</w>', 1), ('4 0</w>', 1), ('4 1</w>', 1), ('42</w>', 3), ('4 23</w>', 1), ('47</w>', 1), ('5 3</w>', 1), ('6 0</w>', 1), ('6 00</w>', 1), ('7 00</w>', 1), ('8 7</w>', 1), ('acht</w>', 11), ('a cht andes</w>', 1), 




In [36]:
bpe_codes

OrderedDict([(('r', '</w>'), 0),
             (('t', '</w>'), 1),
             (('e', '</w>'), 2),
             (('n', '</w>'), 3),
             (('t', 'h'), 4),
             (('th', 'e'), 5),
             (('a', '</w>'), 6),
             (('s', '</w>'), 7),
             (('e', 'r</w>'), 8),
             (('a', 'r'), 9),
             (('t', 'i'), 10),
             (('e', 'n</w>'), 11),
             (('s', 'k'), 12),
             (('l', '</w>'), 13),
             (('n', 'g'), 14),
             (('s', 't'), 15),
             (('d', '</w>'), 16),
             (('ö', 'r'), 17),
             (('a', 'f'), 18),
             (('n', 'd'), 19),
             (('a', 'l'), 20),
             (('a', 'g'), 21),
             (('f', 'ör'), 22),
             (('e', 'r'), 23),
             (('a', 'd'), 24),
             (('a', 'r</w>'), 25),
             (('r', 'ä'), 26),
             (('e', 'n'), 27),
             (('t', 't'), 28),
             (('å', '</w>'), 29),
             (('the', 't</w>'), 30),
  

In [37]:
string.printable

'0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ \t\n\r\x0b\x0c'

In [38]:
single_char_vocab = list(set((string.printable[:-5]+ 'åäö').lower())) + ['</w>', '[END]']
tokenizer_dict = defaultdict(lambda: 1)
tokenizer_dict['<pad>'] = 0
for idx, char in enumerate(single_char_vocab):
    tokenizer_dict[char] = idx+2
tokenizer_dict.update({''.join(word): idx+len(tokenizer_dict)+2 for word, idx in bpe_codes.items()})
# tokenizer_list
tokenizer_dict


defaultdict(<function __main__.<lambda>()>,
            {'<pad>': 0,
             '1': 2,
             '`': 3,
             'o': 4,
             '{': 5,
             'a': 6,
             '(': 7,
             '-': 8,
             '9': 9,
             "'": 10,
             'ä': 11,
             'm': 12,
             'w': 13,
             '@': 14,
             '\\': 15,
             'å': 16,
             'ö': 17,
             ',': 18,
             'i': 19,
             's': 20,
             'r': 21,
             '"': 22,
             '<': 23,
             'g': 24,
             'j': 25,
             'u': 26,
             ')': 27,
             'n': 28,
             'l': 29,
             '~': 30,
             '8': 31,
             'e': 32,
             '#': 33,
             '+': 34,
             'f': 35,
             'c': 36,
             '_': 37,
             'b': 38,
             '2': 39,
             '$': 40,
             '4': 41,
             '^': 42,
             '|': 43,
             '

### Regex challenges
1. Naive implementation - Slow (45 minutes entire 1734 law 1000 tokens)
2. Trie - No priority
3. \b Word boundaries - '>' counts as a word boundary
4. Tensorflow - Final solution (2 minutes entire 1734 law 5000 tokens)

In [39]:
def tokenize(corpus: list[str], bpe: 'OrderedDict[tuple[str, str], int]'):
    corpus = [sentance.lower() for sentance in corpus]
    words = []
    for sentance in corpus:
        words += [' '.join(list(word)) + ' </w>' for word in sentance.split()]
        words += '[END]'
    str_tensor = tf.constant(words)
    print(len(words))
    for pair in tqdm.tqdm(bpe):
        pattern = r'(\b)' + re.escape(' '.join(pair)) + r'( |$)'
        replacement = r'\1' + ''.join(pair) + r'\2'
        str_tensor = tf.strings.regex_replace(str_tensor, pattern, replacement)  
    tokens = []
    for idx, word in enumerate(str_tensor):
        word = word.numpy().decode('utf-8')
        bpe_tokens = word.split()
        tokenization = [tokenizer_dict[token] for token in bpe_tokens]
        tokens += tokenization
    return tokens

def tokenize_corpus(corpus: 'list[str]', bpe_codes: 'OrderedDict[tuple[str, str], int]'):
    tokens = []
    for i, sentence in enumerate(corpus):
        print(i)
        tokens.append(tokenize(sentence, bpe_codes))
    return tokens
    
tokenized_corpus = [tokenize(sentances, bpe_codes)]
print(*tokenized_corpus, sep='\n')

102880


100%|██████████| 5000/5000 [03:07<00:00, 26.60it/s]


[93, 123, 205, 83, 77, 77, 257, 90, 21, 70, 74, 185, 393, 79, 97, 5069, 4738, 85, 18, 74, 116, 796, 196, 1164, 2109, 1003, 527, 70, 74, 68, 1, 1, 1, 51, 796, 196, 507, 21, 166, 2885, 274, 18, 74, 120, 1923, 425, 261, 18, 74, 107, 1254, 507, 21, 85, 18, 74, 107, 430, 21, 85, 2509, 113, 18, 74, 339, 715, 233, 18, 74, 145, 507, 549, 430, 21, 85, 699, 18, 74, 95, 220, 620, 18, 74, 4439, 113, 18, 74, 107, 3581, 613, 1923, 1254, 1923, 4982, 18, 74, 120, 107, 163, 310, 100, 207, 1923, 18, 74, 2187, 3811, 70, 74, 68, 1, 1, 1, 51, 741, 252, 116, 525, 1970, 16, 74, 18, 74, 120, 3966, 1164, 1529, 20, 1653, 18, 74, 120, 1923, 4982, 2668, 856, 125, 18, 74, 2045, 2668, 113, 32, 163, 864, 507, 21, 296, 330, 18, 74, 1003, 107, 1079, 4231, 2984, 2435, 113, 17, 1936, 1867, 861, 18, 74, 113, 141, 4439, 3826, 70, 74, 68, 1, 1, 1, 51, 116, 1007, 2193, 113, 120, 887, 18, 74, 1197, 102, 393, 2292, 1003, 49, 331, 415, 507, 21, 817, 215, 18, 74, 4, 507, 21, 273, 708, 589, 126, 2112, 215, 18, 74, 3679, 18, 74, 

# Embedding

In [40]:
inverse_vocab = {index: token for token, index in tokenizer_dict.items()}

In [47]:
window_size = 2
vocab_size = len(tokenizer_dict)
positive_skip_grams, _ = tf.keras.preprocessing.sequence.skipgrams(
      tokenized_corpus[0],
      vocabulary_size=vocab_size,
      window_size=window_size,
      negative_samples=0)
positive_skip_grams

[[74, 513],
 [74, 68],
 [51, 278],
 [72, 1],
 [51, 193],
 [261, 2205],
 [1, 51],
 [113, 16],
 [406, 1492],
 [729, 51],
 [70, 68],
 [501, 967],
 [68, 70],
 [1709, 3279],
 [123, 58],
 [668, 6],
 [120, 21],
 [17, 18],
 [1, 507],
 [2515, 3586],
 [21, 359],
 [110, 186],
 [74, 18],
 [74, 1032],
 [507, 3263],
 [1, 74],
 [2795, 74],
 [1, 1],
 [1137, 346],
 [521, 51],
 [68, 1],
 [1, 68],
 [224, 788],
 [70, 68],
 [801, 865],
 [1036, 1054],
 [113, 74],
 [4666, 72],
 [70, 51],
 [74, 3119],
 [74, 1],
 [72, 141],
 [1, 1],
 [675, 756],
 [74, 18],
 [1, 74],
 [3670, 18],
 [74, 18],
 [141, 18],
 [3805, 74],
 [74, 4592],
 [74, 18],
 [149, 17],
 [521, 51],
 [11, 395],
 [1, 1058],
 [51, 1],
 [1289, 1928],
 [74, 70],
 [51, 1],
 [864, 113],
 [18, 74],
 [51, 74],
 [18, 149],
 [1051, 120],
 [51, 1],
 [663, 207],
 [549, 70],
 [51, 1],
 [1, 1],
 [51, 1],
 [114, 44],
 [51, 70],
 [16, 2383],
 [2097, 21],
 [74, 222],
 [4738, 18],
 [507, 21],
 [1, 68],
 [68, 74],
 [66, 16],
 [171, 16],
 [74, 892],
 [1, 1],
 [74, 1],

In [48]:
print([inverse_vocab[t] for t in tokenized_corpus[0]])

['d</w>', 'o</w>', 'm</w>', 'a</w>', 'r</w>', 'r</w>', 'g</w>', 'l</w>', 'r', '.', '</w>', 'nå', 'gr', 'e</w>', 'al', 'mennelige</w>', 'regl', 'er</w>', ',', '</w>', 'ther</w>', 'domare</w>', 'skal</w>', 'sigh</w>', 'aldeles</w>', 'effter</w>', 'rätta</w>', '.', '</w>', '[', 'utsökas</w>', 'utsökas</w>', 'utsökas</w>', ']', 'domare</w>', 'skal</w>', 'fö', 'r', 'st</w>', 'besi', 'nna</w>', ',', '</w>', 'at</w>', 'gudz</w>', 'befalnings', 'man</w>', ',', '</w>', 'thet</w>', 'embete</w>', 'fö', 'r', 'er</w>', ',', '</w>', 'thet</w>', 'hö', 'r', 'er</w>', 'gudh</w>', 'til</w>', ',', '</w>', 'siel', 'ffu', 'om</w>', ',', '</w>', 'ther', 'fö', 're</w>', 'hö', 'r', 'er</w>', 'domen</w>', ',', '</w>', 'af', 'sä', 'ger</w>', ',', '</w>', 'gudhi</w>', 'til</w>', ',', '</w>', 'thet</w>', 'afsagd</w>', 'warder</w>', 'gudz</w>', 'embete</w>', 'gudz</w>', 'wegna</w>', ',', '</w>', 'at</w>', 'thet</w>', 'wi', 'ss', 'er', 'liga</w>', 'gudz</w>', ',', '</w>', 'menniski', 'ors</w>', '.', '</w>', '[', 'u

In [49]:
for target, context in positive_skip_grams[:5]:
  print(f"({target}, {context}): ({inverse_vocab[target]}, {inverse_vocab[context]})")

(74, 513): (</w>, fader</w>)
(74, 68): (</w>, [)
(51, 278): (], sagdt</w>)
(72, 1): (;, utsökas</w>)
(51, 193): (], någor</w>)


In [50]:
# Get target and context words for one positive skip-gram.
target_word, context_word = positive_skip_grams[0]

# Set the number of negative samples per positive context.
num_ns = 4

context_class = tf.reshape(tf.constant(context_word, dtype="int64"), (1, 1))
negative_sampling_candidates, _, _ = tf.random.log_uniform_candidate_sampler(
    true_classes=context_class,  # class that should be sampled as 'positive'
    num_true=1,  # each positive skip-gram has 1 positive context class
    num_sampled=num_ns,  # number of negative context words to sample
    unique=True,  # all the negative samples should be unique
    range_max=vocab_size,  # pick index of the samples from [0, vocab_size]
    # seed=SEED,  # seed for reproducibility
    name="negative_sampling"  # name of this operation
)
print(negative_sampling_candidates)
print([inverse_vocab[index.numpy()] for index in negative_sampling_candidates])

tf.Tensor([  1   6 803 116], shape=(4,), dtype=int64)
['utsökas</w>', 'a', 'lt</w>', 'ther</w>']


In [51]:
# Reduce a dimension so you can use concatenation (in the next step).
squeezed_context_class = tf.squeeze(context_class, 1)

# Concatenate a positive context word with negative sampled words.
context = tf.concat([squeezed_context_class, negative_sampling_candidates], 0)

# Label the first context word as `1` (positive) followed by `num_ns` `0`s (negative).
label = tf.constant([1] + [0]*num_ns, dtype="int64")
target = target_word

In [52]:
print(f"target_index    : {target}")
print(f"target_word     : {inverse_vocab[target_word]}")
print(f"context_indices : {context}")
print(f"context_words   : {[inverse_vocab[c.numpy()] for c in context]}")
print(f"label           : {label}")

target_index    : 74
target_word     : </w>
context_indices : [513   1   6 803 116]
context_words   : ['fader</w>', 'utsökas</w>', 'a', 'lt</w>', 'ther</w>']
label           : [1 0 0 0 0]


In [53]:
print("target  :", target)
print("context :", context)
print("label   :", label)

target  : 74
context : tf.Tensor([513   1   6 803 116], shape=(5,), dtype=int64)
label   : tf.Tensor([1 0 0 0 0], shape=(5,), dtype=int64)


In [54]:
# Generates skip-gram pairs with negative sampling for a list of sequences
# (int-encoded sentences) based on window size, number of negative samples
# and vocabulary size.
def generate_training_data(sequences, window_size, num_ns, vocab_size, seed):
  # Elements of each training example are appended to these lists.
  targets, contexts, labels = [], [], []

  # Build the sampling table for `vocab_size` tokens.
  sampling_table = tf.keras.preprocessing.sequence.make_sampling_table(vocab_size)

  # Iterate over all sequences (sentences) in the dataset.
  for sequence in sequences:

    # Generate positive skip-gram pairs for a sequence (sentence).
    positive_skip_grams, _ = tf.keras.preprocessing.sequence.skipgrams(
          sequence,
          vocabulary_size=vocab_size,
          sampling_table=sampling_table,
          window_size=window_size,
          negative_samples=0)

    # Iterate over each positive skip-gram pair to produce training examples
    # with a positive context word and negative samples.
    for target_word, context_word in tqdm.tqdm(positive_skip_grams):
      context_class = tf.expand_dims(
          tf.constant([context_word], dtype="int64"), 1)
      negative_sampling_candidates, _, _ = tf.random.log_uniform_candidate_sampler(
          true_classes=context_class,
          num_true=1,
          num_sampled=num_ns,
          unique=True,
          range_max=vocab_size,
          seed=seed,
          name="negative_sampling")

      # Build context and label vectors (for one target word)
      context = tf.concat([tf.squeeze(context_class,1), negative_sampling_candidates], 0)
      label = tf.constant([1] + [0]*num_ns, dtype="int64")

      # Append each element from the training example to global lists.
      targets.append(target_word)
      contexts.append(context)
      labels.append(label)

  return targets, contexts, labels

In [55]:
targets, contexts, labels = generate_training_data(tokenized_corpus, window_size=2, num_ns=4, vocab_size=vocab_size, seed=42)
targets, contexts, labels = np.array(targets), np.array(contexts), np.array(labels)
print('\n')
print(f"targets.shape: {targets.shape}")
print(f"contexts.shape: {contexts.shape}")
print(f"labels.shape: {labels.shape}")

100%|██████████| 76044/76044 [00:54<00:00, 1392.82it/s]




targets.shape: (76044,)
contexts.shape: (76044, 5)
labels.shape: (76044, 5)


In [56]:
BATCH_SIZE = 32
BUFFER_SIZE = 10000
dataset = tf.data.Dataset.from_tensor_slices(((targets, contexts), labels))
dataset = dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)
dataset = dataset.cache().prefetch(buffer_size=tf.data.AUTOTUNE)

<_PrefetchDataset element_spec=((TensorSpec(shape=(32,), dtype=tf.int32, name=None), TensorSpec(shape=(32, 5), dtype=tf.int64, name=None)), TensorSpec(shape=(32, 5), dtype=tf.int64, name=None))>


In [58]:
class EmbeddingModel(tf.keras.Model):
  def __init__(self, vocab_size, embedding_dim):
    super(EmbeddingModel, self).__init__()
    self.target_embedding = Embedding(vocab_size,
                                      embedding_dim,
                                      input_length=1,
                                      name="w2v_embedding")
    self.context_embedding = Embedding(vocab_size,
                                       embedding_dim,
                                       input_length=num_ns+1)

  def call(self, pair):
    target, context = pair
    # target: (batch, dummy?)  # The dummy axis doesn't exist in TF2.7+
    # context: (batch, context)
    if len(target.shape) == 2:
      target = tf.squeeze(target, axis=1)
    # target: (batch,)
    word_emb = self.target_embedding(target)
    # word_emb: (batch, embed)
    context_emb = self.context_embedding(context)
    # context_emb: (batch, context, embed)
    dots = tf.einsum('be,bce->bc', word_emb, context_emb)
    # dots: (batch, context)
    return dots

In [59]:
embedding_dim = 32

model = EmbeddingModel(vocab_size, embedding_dim)

tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir='logs')
early_stopping_callback = tf.keras.callbacks.EarlyStopping(monitor='accuracy', min_delta=0.005, patience=2)

model.compile(optimizer='adam', loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True), metrics=['accuracy'])
model.fit(dataset, epochs=100, callbacks=[tensorboard_callback, early_stopping_callback])

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100


<keras.callbacks.History at 0x168f2da72d0>

In [None]:
model.summary()

Model: "embedding_model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 w2v_embedding (Embedding)   multiple                  165792    
                                                                 
 embedding_1 (Embedding)     multiple                  165792    
                                                                 
Total params: 331,584
Trainable params: 331,584
Non-trainable params: 0
_________________________________________________________________


In [60]:
%tensorboard --logdir logs

In [61]:
weights = model.get_layer('w2v_embedding').get_weights()[0]
vocab = list(tokenizer_dict.keys())

In [63]:
import io
out_v = io.open('vectors.tsv', 'w', encoding='utf-8')
out_m = io.open('metadata.tsv', 'w', encoding='utf-8')

for index, word in enumerate(vocab):
  if index == 0:
    continue  # skip 0, it's padding.
  vec = weights[index]
  # out_v.write(word + ',')
  out_v.write('\t'.join([str(x) for x in vec]) + "\n")
  out_m.write(word + "\n")
out_v.close()
out_m.close()

# TODO
* <s>Load stopwords from file</s>
* <s>Remove stopwords from corpus and vocab</s>
* <s>Use entire law corpus</s>
* <s>Improve BPE performance</s>
* Use Wikipedia corpus
* Create embedder from weights