In [1]:
# Automatically re-import edited modules before each cell executes, so changes
# to library code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import sys
from pathlib import Path

# Make the local library directory importable; presumably this is where the
# `nnsplit` package imported below lives (relative path, so it assumes the
# notebook is started from the repository root — TODO confirm).
sys.path.append(str(Path("python_lib")))

In [3]:
import torch
from torch import nn
import numpy as np
import pandas as pd
import pickle
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm
import nnsplit
from nnsplit import train, utils, models

In [4]:
# Root directory for all intermediate artifacts (tokenized corpora, saved models).
cache_dir = Path("cache")
if not cache_dir.is_dir():
    cache_dir.mkdir()

# Prepare data

## German

In [5]:
# German source corpus (a Wikipedia XML dump, judging by the file name),
# capped at 3M paragraphs.
paragraphs = train.xml_to_paragraphs(
    "train_data/dewiki-20180920-corpus.xml",
    max_n_paragraphs=3_000_000,
)

HBox(children=(FloatProgress(value=0.0, max=3000000.0), HTML(value='')))

In [6]:
# SoMaJo-based tokenizer for German; the next cell uses it to split the raw
# paragraphs into sentences of Tokens, which become the training data.
tokenizer = nnsplit.tokenizer.SoMaJoTokenizer("de")

In [7]:
# Tokenize every German paragraph and stream the results to disk as a sequence
# of concatenated pickle records (one per paragraph), not as one pickled list.
# Only `cache_dir` itself is created earlier, so create the per-language
# subdirectory here; otherwise a fresh run fails with FileNotFoundError.
de_data_dir = cache_dir / "de_data"
de_data_dir.mkdir(parents=True, exist_ok=True)

with open(de_data_dir / "tokenized_paragraphs.pkl", "wb") as f:
    for x in tokenizer.split(paragraphs, verbose=True):
        pickle.dump(x, f)

HBox(children=(FloatProgress(value=0.0, max=3000000.0), HTML(value='')))

## English

In [None]:
# English source corpus (a Wikipedia XML dump, judging by the file name),
# capped at 3M paragraphs. Overwrites the German `paragraphs` from above.
paragraphs = train.xml_to_paragraphs(
    "train_data/enwiki-20181001-corpus.xml",
    max_n_paragraphs=3_000_000,
)

In [None]:
# SoMaJo-based tokenizer for English; replaces the German tokenizer and is used
# by the next cell to split the English paragraphs into sentences of Tokens.
tokenizer = nnsplit.tokenizer.SoMaJoTokenizer("en")

In [None]:
# Tokenize every English paragraph and stream the results to disk as a sequence
# of concatenated pickle records (one per paragraph), not as one pickled list.
# Only `cache_dir` itself is created earlier, so create the per-language
# subdirectory here; otherwise a fresh run fails with FileNotFoundError.
en_data_dir = cache_dir / "en_data"
en_data_dir.mkdir(parents=True, exist_ok=True)

with open(en_data_dir / "tokenized_paragraphs.pkl", "wb") as f:
    for x in tokenizer.split(paragraphs, verbose=True):
        pickle.dump(x, f)

# Train model (German)

In [7]:
# Turn the cached German tokenized paragraphs into model inputs (`sentences`)
# and targets (`labels`). Paragraphs it cannot handle are reported as
# "Faulty paragraph" in the output below.
# NOTE(review): the English cell passes a language code ("en") as a second
# argument; this call relies on the function's default — presumably German,
# confirm against train.prepare_tokenized_paragraphs.
sentences, labels = train.prepare_tokenized_paragraphs(cache_dir / "de_data" / "tokenized_paragraphs.pkl")

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Faulty paragraph:
[[Token(text='.', whitespace='')], [Token(text='GLOBAL', whitespace=' '), Token(text='_set_float', whitespace=''), Token(text='Extend', whitespace=''), Token(text=';', whitespace=' '), Token(text=';', whitespace=' '), Token(text='Sprunglabel', whitespace=' '), Token(text='global', whitespace=' '), Token(text='sichtbar', whitespace=' '), Token(text='_set_float', whitespace=''), Token(text='Extend', whitespace=''), Token(text=':', whitespace=' '), Token(text=';', whitespace=' '), Token(text='Sprunglabel', whitespace=' '), Token(text='angeben', whitespace=''), Token(text=',', whitespace=' '), Token(text='das', whitespace=' '), Token(text='ist', whitespace=' '), Token(text='der', whitespace=' '), Token(text='Name', whitespace=' '), Token(text='des', whitespace=' '), Token(text='Unterprogramms', whitespace=''), Token(text=',', whitespace=' '), Token(text=';', whitespace=' '), Token(text='aus', whitespace=' '), Token(text='C', whitespace=' '), Token(text='ohne', whitespace=

In [8]:
# Hold out 10% of the German data for validation; the fixed random_state keeps
# the split identical across runs.
x_train, x_valid, y_train, y_valid = train_test_split(
    sentences,
    labels,
    random_state=1234,
    test_size=0.1,
)

In [18]:
# Seed the global torch and numpy RNGs (same seed as the train/valid split) so
# weight initialisation and whatever randomness train.train draws from these
# generators are repeatable. Some CUDA kernels can still be nondeterministic.
torch.manual_seed(1234)
np.random.seed(1234)

# NOTE: the saved log below lists only epoch 0 of 15, so the stored run
# appears to have been interrupted.
de_model = train.train(x_train, y_train, x_valid, y_valid, n_epochs=15)

epoch,train_loss,valid_loss,time
0,0.03182,0.031392,02:04


In [8]:
# Persist the whole trained model object (not just a state_dict) so the
# Evaluate section can reload it without retraining.
de_model_path = cache_dir / "de_data" / "model.pt"
torch.save(de_model, de_model_path)

In [9]:
# Export the German model to data/de with the project's utils.store_model; the
# Test section reloads it via utils.load_model("data/de").
# (The h5py warning in the output suggests an HDF5 file is written — TODO confirm.)
utils.store_model(de_model, "data/de")

  return h5py.File(h5file)


## Evaluate

In [6]:
# Reload the trained German model saved above (a full pickled model object).
# NOTE(review): torch.load unpickles arbitrary objects, so only load trusted files.
# On PyTorch >= 2.6, where torch.load defaults to weights_only=True, loading a
# whole module needs weights_only=False — confirm against the installed version.
de_model = torch.load(cache_dir / "de_data" / "model.pt")

In [7]:
# Evaluate in half precision on the GPU. nn.Module.cuda()/.half() convert the
# module in place and return it, so `de_model_fp16` is the same object as
# `de_model`; de_model stays an fp16 CUDA model after this cell.
de_model_fp16 = de_model.cuda().half()
train.evaluate(de_model_fp16, x_valid, y_valid)

HBox(children=(FloatProgress(value=0.0, max=1172.0), HTML(value='')))


Target: Tokenize 

F1: 0.9976920067068228
Precision: 0.9979479501565672
Recall: 0.997436194506916



Target: Sentencize 

F1: 0.9584361802317378
Precision: 0.9373894356306759
Recall: 0.9804497366219669





In [11]:
# Dynamic int8 quantization of all LSTM and Linear layers for CPU inference.
# .float().cpu() converts de_model back in place, undoing the fp16/CUDA
# conversion used for the GPU evaluation; quantize_dynamic (inplace=False by
# default) then returns a quantized copy.
quantized_model = torch.quantization.quantize_dynamic(
    de_model.float().cpu(),
    {nn.LSTM, nn.Linear},
    dtype=torch.qint8,
)

In [12]:
# Re-run the same German validation on the int8-quantized CPU model to
# measure the accuracy cost of quantization.
train.evaluate(quantized_model, x_valid, y_valid)

HBox(children=(FloatProgress(value=0.0, max=977.0), HTML(value='')))


Target: Tokenize 

F1: 0.9985806618134485
Precision: 0.9977847594261771
Recall: 0.9993778349483429



Target: Sentencize 

F1: 0.9581120292762628
Precision: 0.9363237215520647
Recall: 0.9809385261100492





# Train model (English)

In [8]:
# Turn the cached English tokenized paragraphs into model inputs (`sentences`)
# and targets (`labels`), replacing the German data.
# NOTE(review): a language code is passed here but not in the German cell;
# presumably the German call relies on a default — confirm the signature.
sentences, labels = train.prepare_tokenized_paragraphs(cache_dir / "en_data" / "tokenized_paragraphs.pkl", "en")

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

In [9]:
# Hold out 10% of the English data for validation; the fixed random_state keeps
# the split identical across runs.
x_train, x_valid, y_train, y_valid = train_test_split(
    sentences,
    labels,
    random_state=1234,
    test_size=0.1,
)

In [7]:
# Seed the global torch and numpy RNGs (same seed as the train/valid split) so
# weight initialisation and whatever randomness train.train draws from these
# generators are repeatable. Some CUDA kernels can still be nondeterministic.
torch.manual_seed(1234)
np.random.seed(1234)

# NOTE: the saved log below stops after epoch 9 of 15, so the stored run
# appears to have been interrupted.
en_model = train.train(x_train, y_train, x_valid, y_valid, n_epochs=15)

epoch,train_loss,valid_loss,time
0,0.040443,0.038892,20:00
1,0.032624,0.031846,20:00
2,0.030703,0.030959,20:01
3,0.030231,0.030827,20:01
4,0.029231,0.030218,20:02
5,0.028824,0.029291,20:02
6,0.029228,0.028684,20:02
7,0.028699,0.027927,20:03
8,0.026893,0.027063,20:03
9,0.027493,0.026111,20:03


In [8]:
# Persist the whole trained model object (not just a state_dict) so the
# Evaluate section can reload it without retraining.
en_model_path = cache_dir / "en_data" / "model.pt"
torch.save(en_model, en_model_path)

In [10]:
# Export the English model to data/en with the project's utils.store_model; the
# Test section reloads it via utils.load_model("data/en").
# (The h5py warning in the output suggests an HDF5 file is written — TODO confirm.)
utils.store_model(en_model, "data/en")

  return h5py.File(h5file)


## Evaluate

In [10]:
# Reload the trained English model saved above (a full pickled model object).
# NOTE(review): torch.load unpickles arbitrary objects, so only load trusted files.
# On PyTorch >= 2.6, where torch.load defaults to weights_only=True, loading a
# whole module needs weights_only=False — confirm against the installed version.
en_model = torch.load(cache_dir / "en_data" / "model.pt")

In [11]:
# Evaluate in half precision on the GPU. nn.Module.cuda()/.half() convert the
# module in place and return it, so `en_model_fp16` is the same object as
# `en_model`; en_model stays an fp16 CUDA model after this cell.
en_model_fp16 = en_model.cuda().half()
train.evaluate(en_model_fp16, x_valid, y_valid)

HBox(children=(FloatProgress(value=0.0, max=1172.0), HTML(value='')))


Target: Tokenize 

F1: 0.9978134907239008
Precision: 0.9981640633909579
Recall: 0.9974631642248826



Target: Sentencize 

F1: 0.9413314661724994
Precision: 0.9120622157846446
Recall: 0.9725415771060246





In [11]:
# Dynamic int8 quantization of all LSTM and Linear layers for CPU inference.
# .float().cpu() converts en_model back in place, undoing the fp16/CUDA
# conversion used for the GPU evaluation; quantize_dynamic (inplace=False by
# default) then returns a quantized copy.
quantized_model = torch.quantization.quantize_dynamic(
    en_model.float().cpu(),
    {nn.LSTM, nn.Linear},
    dtype=torch.qint8,
)

In [12]:
# Re-run the English validation on the int8-quantized CPU model to measure the
# accuracy cost of quantization.
# NOTE(review): the saved output below is digit-for-digit identical to the
# German quantized evaluation (same progress-bar max=977 and same F1/precision/
# recall), so it is most likely stale or copied — re-run before trusting it.
train.evaluate(quantized_model, x_valid, y_valid)

HBox(children=(FloatProgress(value=0.0, max=977.0), HTML(value='')))


Target: Tokenize 

F1: 0.9985806618134485
Precision: 0.9977847594261771
Recall: 0.9993778349483429



Target: Sentencize 

F1: 0.9581120292762628
Precision: 0.9363237215520647
Recall: 0.9809385261100492





# Test

In [7]:
from nnsplit import NNSplit

In [8]:
# Rebuild a splitter from the exported German model (cast to float32 after
# loading), i.e. test the artifact written by utils.store_model, not the
# in-memory model.
de_export = utils.load_model("data/de").float()
splitter = NNSplit(de_export)

In [9]:
# Smoke test: the first sentence has no closing period. The saved output holds
# one list per input text containing one Token list per sentence, and shows the
# text is still split into two sentences.
splitter.split(["Das ist ein Test Das ist noch ein Test."])

[[[Token(text='Das', whitespace=' '),
   Token(text='ist', whitespace=' '),
   Token(text='ein', whitespace=' '),
   Token(text='Test', whitespace=' ')],
  [Token(text='Das', whitespace=' '),
   Token(text='ist', whitespace=' '),
   Token(text='noch', whitespace=' '),
   Token(text='ein', whitespace=' '),
   Token(text='Test', whitespace=''),
   Token(text='.', whitespace='')]]]

In [10]:
# Rebuild a splitter from the exported English model (cast to float32 after
# loading) and run the same smoke test as for German: the first sentence has
# no closing period. Previously the English splitter was built but never used.
splitter = NNSplit(utils.load_model("data/en").float())
splitter.split(["This is a test This is another test."])