In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell executes and edits to local library code take
# effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import sys
from pathlib import Path
# Put the "python_lib" directory (relative to the kernel's working directory)
# on the import path. Presumably this is where the `nnsplit` package imported
# in the next cell lives -- TODO confirm.
sys.path.append("python_lib")

In [3]:
import torch
from torch import nn
import numpy as np
import pandas as pd
import pickle
from sklearn.model_selection import train_test_split
from sklearn import metrics
from tqdm.auto import tqdm
import nnsplit
from nnsplit import train, utils, models, tokenizer

In [4]:
# Root directory for cached intermediate artifacts.
# Later cells write tokenized paragraphs and model weights into the
# per-language sub-directories "de_data" and "en_data". Neither `open(..., "wb")`
# nor `torch.save` creates missing parent directories, so create them up front;
# otherwise a fresh checkout fails with FileNotFoundError on the first write.
cache_dir = Path("cache")
for _lang_dir in ("de_data", "en_data"):
    # parents=True also creates `cache_dir` itself; exist_ok=True keeps the
    # cell safe to re-run.
    (cache_dir / _lang_dir).mkdir(parents=True, exist_ok=True)

# Prepare data

## German

In [None]:
# Extract paragraphs from the dewiki corpus XML file, passing a cap of
# three million via `max_n_paragraphs`.
paragraphs = train.xml_to_paragraphs(
    "train_data/dewiki-20180920-corpus.xml",
    max_n_paragraphs=3_000_000,
)

HBox(children=(FloatProgress(value=0.0, max=3000000.0), HTML(value='')))

In [41]:
# Build the SoMaJo-based tokenizer for German ("de").
# NOTE(review): this rebinds the name `tokenizer`, which the import cell bound
# to the `nnsplit.tokenizer` module. The class is reached through the fully
# qualified `nnsplit.tokenizer` path, so the lookup still works after rebinding.
tokenizer = nnsplit.tokenizer.SoMaJoTokenizer("de")

In [None]:
# Write every item produced by `tokenizer.split` as its own pickle record,
# back to back, into one cache file.
# `open(..., "wb")` does not create missing parent directories, so make sure
# the per-language cache directory exists before writing.
(cache_dir / "de_data").mkdir(parents=True, exist_ok=True)

with open(cache_dir / "de_data" / "tokenized_paragraphs.pkl", "wb") as f:
    for x in tokenizer.split(paragraphs, verbose=True):
        # pickle.dump serializes straight into the file object instead of
        # building an intermediate bytes object first.
        pickle.dump(x, f)

HBox(children=(FloatProgress(value=0.0, max=3000000.0), HTML(value='')))

## English

In [None]:
# Extract paragraphs from the enwiki corpus XML file, passing a cap of
# three million via `max_n_paragraphs`. This rebinds `paragraphs`, replacing
# the German paragraphs loaded earlier.
paragraphs = train.xml_to_paragraphs(
    "train_data/enwiki-20181001-corpus.xml",
    max_n_paragraphs=3_000_000,
)

In [None]:
# Build the SoMaJo-based tokenizer for English ("en"); this replaces the
# German tokenizer bound to the same name earlier in the notebook.
tokenizer = nnsplit.tokenizer.SoMaJoTokenizer("en")

In [None]:
# Write every item produced by `tokenizer.split` as its own pickle record,
# back to back, into one cache file.
# `open(..., "wb")` does not create missing parent directories, so make sure
# the per-language cache directory exists before writing.
(cache_dir / "en_data").mkdir(parents=True, exist_ok=True)

with open(cache_dir / "en_data" / "tokenized_paragraphs.pkl", "wb") as f:
    for x in tokenizer.split(paragraphs, verbose=True):
        # pickle.dump serializes straight into the file object instead of
        # building an intermediate bytes object first.
        pickle.dump(x, f)

# Train model (German)

In [5]:
# Turn the cached German tokenization into the two parallel collections that
# are split and passed to training below.
sentences, labels = train.prepare_tokenized_paragraphs(
    cache_dir / "de_data" / "tokenized_paragraphs.pkl"
)

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

In [6]:
# Hold out 10% of the examples for validation; the fixed `random_state`
# makes the split reproducible.
x_train, x_valid, y_train, y_valid = train_test_split(
    sentences,
    labels,
    test_size=0.1,
    random_state=1234,
)
# Alternative: reload a previously saved split instead of recomputing it.
# x_train, x_valid, y_train, y_valid = torch.load("data.pth")

In [7]:
# Persist the four parts of the split (in this order) to "data.pth" in the
# working directory, so they can be restored with `torch.load("data.pth")`.
torch.save(
    [x_train, x_valid, y_train, y_valid],
    "data.pth",
)

In [8]:
# Train the German model for ten epochs; per-epoch train/validation losses
# and timings are reported in the cell output.
de_model = train.train(
    x_train,
    y_train,
    x_valid,
    y_valid,
    n_epochs=10,
)

epoch,train_loss,valid_loss,time
0,0.011943,0.011528,32:33
1,0.008184,0.008171,32:34
2,0.007726,0.007424,32:33
3,0.006705,0.006933,32:32
4,0.006225,0.006536,32:38
5,0.006061,0.006045,32:39
6,0.005765,0.005611,32:43
7,0.005283,0.005218,32:44
8,0.00442,0.004872,32:45
9,0.004273,0.004806,32:46


In [9]:
# Save only the model's state dict; the "Evaluate" section rebuilds the
# network and restores it from this file.
torch.save(
    de_model.state_dict(),
    cache_dir / "de_data" / "model.pt",
)

In [10]:
# Export the trained German model with the project's `utils.store_model`
# helper to "data/de" -- the same path later cells pass to `utils.load_model`.
utils.store_model(de_model, "data/de")

  return h5py.File(h5file)


## Evaluate

In [11]:
# Rebuild the network and restore the trained German weights from the cache.
de_model = models.Network()
# map_location="cpu" deserializes the tensors onto the CPU regardless of the
# device they were saved from, so the checkpoint also loads on a machine
# without that device. The evaluation cell moves the model to CUDA explicitly.
de_model.load_state_dict(
    torch.load(cache_dir / "de_data" / "model.pt", map_location="cpu")
)

<All keys matched successfully>

In [14]:
# Evaluate on the held-out split with the model moved to CUDA and cast to
# half precision. Assuming `models.Network` is a torch `nn.Module`, `.cuda()`
# and `.half()` modify `de_model` in place -- TODO confirm.
train.evaluate(
    de_model.cuda().half(),
    x_valid,
    y_valid,
)

HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Target: Tokenize 

F1: 0.9983272884529866
Precision: 0.9981672707378787
Recall: 0.9984873574816878



Target: Sentencize 

F1: 0.9606947055137844
Precision: 0.9401852217352896
Recall: 0.9821189441509751





In [None]:
# NOTE(review): same call as the previous evaluation cell; its recorded output
# is incomplete (only the Tokenize F1 is shown).
train.evaluate(
    de_model.cuda().half(),
    x_valid,
    y_valid,
)

HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Target: Tokenize 

F1: 0.998631769380241


In [20]:
# Dynamically quantize the LSTM and Linear layers to int8. The model is first
# cast back to float and moved to the CPU, because the evaluation cells
# converted it to half precision on CUDA.
quantized_model = torch.quantization.quantize_dynamic(
    de_model.float().cpu(),
    {nn.LSTM, nn.Linear},
    dtype=torch.qint8,
)

# Train model (English)

In [5]:
# Turn the cached English tokenization into the two parallel collections used
# for training. Unlike the German call, a second positional argument "en" is
# passed here. The output shows the helper reporting "Faulty paragraph"
# entries.
sentences, labels = train.prepare_tokenized_paragraphs(
    cache_dir / "en_data" / "tokenized_paragraphs.pkl",
    "en",
)

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Faulty paragraph:
[[Token(text='.', whitespace='')], [Token(text='NET', whitespace=' '), Token(text='(', whitespace=''), Token(text='via', whitespace=' '), Token(text='the', whitespace=' '), Token(text='library', whitespace=' '), Token(text=')', whitespace=' '), Token(text='C', whitespace=' '), Token(text='(', whitespace=''), Token(text='via', whitespace=' '), Token(text='the', whitespace=' '), Token(text='library', whitespace=' '), Token(text=')', whitespace=' '), Token(text='C', whitespace=''), Token(text='#', whitespace=' '), Token(text='(', whitespace=''), Token(text='via', whitespace=' '), Token(text='the', whitespace=' '), Token(text='library', whitespace=' '), Token(text=')', whitespace=' '), Token(text='C++', whitespace=' '), Token(text='(', whitespace=''), Token(text='via', whitespace=' '), Token(text='the', whitespace=' '), Token(text='library', whitespace=' '), Token(text='and', whitespace=' '), Token(text=')', whitespace=' '), Token(text='D', whitespace=' '), Token(text='('

Faulty paragraph:
[[Token(text='.', whitespace='')], [Token(text='REG', whitespace=' '), Token(text='files', whitespace=' '), Token(text='(', whitespace=''), Token(text='also', whitespace=' '), Token(text='known', whitespace=' '), Token(text='as', whitespace=' '), Token(text='Registration', whitespace=' '), Token(text='entries', whitespace=''), Token(text=')', whitespace=' '), Token(text='are', whitespace=' '), Token(text='text', whitespace=''), Token(text='-', whitespace=''), Token(text='based', whitespace=' '), Token(text='human', whitespace=''), Token(text='-', whitespace=''), Token(text='readable', whitespace=' '), Token(text='files', whitespace=' '), Token(text='for', whitespace=' '), Token(text='exporting', whitespace=' '), Token(text='and', whitespace=' '), Token(text='importing', whitespace=' '), Token(text='portions', whitespace=' '), Token(text='of', whitespace=' '), Token(text='the', whitespace=' '), Token(text='registry', whitespace=''), Token(text='.', whitespace=' ')], [T

In [6]:
# Hold out 10% of the English examples for validation with a fixed seed.
# This reuses the names of the German split and overwrites it if both
# sections run in the same kernel.
x_train, x_valid, y_train, y_valid = train_test_split(
    sentences,
    labels,
    test_size=0.1,
    random_state=1234,
)

In [7]:
# Train the English model, requesting twenty epochs.
# NOTE(review): the recorded output below lists only epochs 0-9, so it may
# come from an interrupted run or an earlier 10-epoch setting -- confirm.
en_model = train.train(
    x_train,
    y_train,
    x_valid,
    y_valid,
    n_epochs=20,
)

epoch,train_loss,valid_loss,time
0,0.033321,0.034246,23:42
1,0.028566,0.02898,23:32
2,0.0278,0.027489,20:09
3,0.026879,0.027608,19:10
4,0.027294,0.027869,19:10
5,0.02797,0.027195,21:57
6,0.025401,0.026654,19:10
7,0.026775,0.026406,19:09
8,0.025835,0.025802,19:10
9,0.023397,0.025574,19:09


In [8]:
# Save only the model's state dict; the "Evaluate" section rebuilds the
# network and restores it from this file.
torch.save(
    en_model.state_dict(),
    cache_dir / "en_data" / "model.pt",
)

In [6]:
# Export the trained English model with the project's `utils.store_model`
# helper to "data/en" (counterpart of the German export to "data/de").
utils.store_model(en_model, "data/en")

  return h5py.File(h5file)


## Evaluate

In [5]:
# Rebuild the network and restore the trained English weights from the cache.
en_model = models.Network()
# map_location="cpu" deserializes the tensors onto the CPU regardless of the
# device they were saved from, so the checkpoint also loads on a machine
# without that device. The evaluation cell moves the model to CUDA explicitly.
en_model.load_state_dict(
    torch.load(cache_dir / "en_data" / "model.pt", map_location="cpu")
)

<All keys matched successfully>

In [12]:
# Evaluate on the held-out split with the model moved to CUDA and cast to
# half precision. `x_valid` / `y_valid` must come from the English split
# cell; the German section binds the same names.
train.evaluate(
    en_model.cuda().half(),
    x_valid,
    y_valid,
)

HBox(children=(FloatProgress(value=0.0, max=1172.0), HTML(value='')))


Target: Tokenize 

F1: 0.9977046875608959
Precision: 0.9981388983517799
Recall: 0.9972708543868518



Target: Sentencize 

F1: 0.9406308381250802
Precision: 0.9113823973956516
Recall: 0.9718188284629058





In [None]:
# Dynamically quantize the LSTM and Linear layers to int8. The model is first
# cast back to float and moved to the CPU, because the evaluation cell
# converted it to half precision on CUDA. This reuses the name
# `quantized_model` from the German section.
quantized_model = torch.quantization.quantize_dynamic(
    en_model.float().cpu(),
    {nn.LSTM, nn.Linear},
    dtype=torch.qint8,
)

In [13]:
# train.evaluate(quantized_model, x_valid, y_valid)

# Tune mask

In [13]:
import matplotlib.pyplot as plt

In [14]:
# Small German sample for threshold tuning: pass a cap of 1000 paragraphs to
# the extractor, tokenize them, and materialize the result as a list so it can
# be iterated more than once.
tokenizer = nnsplit.tokenizer.SoMaJoTokenizer("de")
paragraphs = train.xml_to_paragraphs(
    "train_data/dewiki-20180920-corpus.xml",
    max_n_paragraphs=1000,
)

tokenized_ps = [*tokenizer.split(paragraphs, verbose=True)]

HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




In [15]:
# Label every tokenized paragraph once, then separate the (text, label) pairs
# into two parallel lists.
labelled_paragraphs = [utils.label_tokens(p) for p in tokenized_ps]

texts = [text for text, _ in labelled_paragraphs]
labels = [label for _, label in labelled_paragraphs]

In [16]:
# Load the model stored under "data/de" -- the path the German section
# exported to with `utils.store_model(de_model, "data/de")`.
model = utils.load_model("data/de")

In [56]:
# Get the splitter's averaged per-text predictions by calling its private
# helpers (`_get_raw_preds`, `_average_preds`) directly instead of the public
# `split` API.
splitter = nnsplit.NNSplit(model, stride=250)

preds, all_idx, n_cuts_per_text = splitter._get_raw_preds(
    texts,
    batch_size=1024,
)
# Drop the first `splitter.start_padding` entries of every averaged
# prediction -- presumably padding the splitter adds in front of each text, so
# the predictions line up with `labels`; TODO confirm.
avg_preds = [
    text_pred[splitter.start_padding:]
    for text_pred in splitter._average_preds(texts, preds, all_idx, n_cuts_per_text)
]

In [57]:
# Per-text precision of column 1, comparing the labels (y_true) with the
# averaged predictions thresholded at 0.5 (y_pred).
# zero_division=0 gives texts with no predicted positives a precision of 0
# explicitly. That is the value the default produces too, but the default also
# emits an UndefinedMetricWarning for each such text.
scores = np.array(
    [
        metrics.precision_score(label[:, 1], pred[:, 1] > 0.5, zero_division=0)
        for label, pred in zip(labels, avg_preds)
    ]
)
# Keep only texts with non-zero precision.
# NOTE(review): after this filter, positions in `scores` no longer correspond
# to indices into `texts` / `labels`.
scores = scores[scores > 0]

  _warn_prf(average, modifier, msg_start, len(result))


In [58]:
# Display the positions of `scores` ordered from lowest to highest score.
# NOTE(review): `scores` was filtered with `scores > 0` in the previous cell,
# so these positions index the filtered array and may not line up with
# `texts` / `labels` -- confirm before using them to look up examples.
np.argsort(scores)

array([ 57, 151, 817, 335, ..., 329, 330, 318, 475])

In [20]:
# Stack the per-text label and prediction arrays along the first axis, so
# metrics can be computed over all texts at once.
flat_labels = np.concatenate(labels, axis=0)
flat_preds = np.concatenate(avg_preds, axis=0)

In [24]:
# Precision/recall at every candidate threshold for column 1 of the labels and
# predictions -- presumably the sentence-boundary target; TODO confirm.
precision, recall, thresholds = metrics.precision_recall_curve(
    flat_labels[:, 1],
    flat_preds[:, 1],
)

# Test

In [62]:
from nnsplit import NNSplit

In [63]:
# Load the model stored under "data/de".
# NOTE(review): identical to the load in the "Tune mask" section; presumably
# repeated so the "Test" section can be run on its own.
model = utils.load_model("data/de")

In [85]:
# Split a German sample that has no punctuation between its two sentences,
# using a stride of 20 and a threshold of 0.1. The result is the cell's
# displayed value.
splitter = NNSplit(model, stride=20, threshold=0.1)
splitter.split(
    ["Ich bin ein Baum er ist ein Baum."]
)

[[[Token(text='Ich', whitespace=' '),
   Token(text='bin', whitespace=' '),
   Token(text='ein', whitespace=' '),
   Token(text='Baum', whitespace=' ')],
  [Token(text='er', whitespace=' '),
   Token(text='ist', whitespace=' '),
   Token(text='ein', whitespace=' '),
   Token(text='Baum', whitespace=''),
   Token(text='.', whitespace='')]]]

In [159]:
# Split an English sample with missing punctuation, using a stride of 20 and a
# threshold of 0.2. The result is the cell's displayed value.
# NOTE(review): the only visible load of `model` in this section reads
# "data/de", while this input is English -- confirm which model was intended.
splitter = NNSplit(model, stride=20, threshold=0.2)
splitter.split(
    ["Fast, robust sentence splitting with bindings for Python, Rust and Javascript Punctuation is not necessary to split sentences correctly sometimes even incorrect case is split correctly."]
)

[[[Token(text='Fast', whitespace=''),
   Token(text=',', whitespace=' '),
   Token(text='robust', whitespace=' '),
   Token(text='sentence', whitespace=' '),
   Token(text='splitting', whitespace=' '),
   Token(text='with', whitespace=' '),
   Token(text='bindings', whitespace=' '),
   Token(text='for', whitespace=' '),
   Token(text='Python', whitespace=''),
   Token(text=',', whitespace=' '),
   Token(text='Rust', whitespace=' '),
   Token(text='and', whitespace=' '),
   Token(text='Javascript', whitespace=' ')],
  [Token(text='Punctuation', whitespace=' '),
   Token(text='is', whitespace=' '),
   Token(text='not', whitespace=' '),
   Token(text='necessary', whitespace=' '),
   Token(text='to', whitespace=' '),
   Token(text='split', whitespace=' '),
   Token(text='sentences', whitespace=' '),
   Token(text='correctly', whitespace=' ')],
  [Token(text='sometimes', whitespace=' '),
   Token(text='even', whitespace=' '),
   Token(text='incorrect', whitespace=' '),
   Token(text='case',