In [1]:
# Re-import edited modules (e.g. the local nnsplit sources) automatically before every cell execution.
%load_ext autoreload
%autoreload 2

In [2]:
import sys
from pathlib import Path

# Make the local `python_lib` directory importable (presumably where the `nnsplit`
# package imported in the next cell lives). The path is relative to the working directory.
sys.path.append(str(Path("python_lib")))

In [3]:
import torch
from torch import nn
import numpy as np
import pandas as pd
import pickle
from sklearn.model_selection import train_test_split
from sklearn import metrics
from tqdm.auto import tqdm
import nnsplit
from nnsplit import train, utils, models, tokenizer

In [4]:
# Root directory for all intermediate artifacts (tokenized corpora, model weights).
cache_dir = Path("cache")
cache_dir.mkdir(exist_ok=True)

# Later cells write straight into cache/de_data and cache/en_data (pickle streams and
# torch.save). Neither creates missing parent directories, so create them up front;
# otherwise a fresh checkout fails with FileNotFoundError.
for lang_dir in ("de_data", "en_data"):
    (cache_dir / lang_dir).mkdir(parents=True, exist_ok=True)

# Prepare data

## German

In [None]:
# Read at most 3M paragraphs from the German Wikipedia XML dump.
paragraphs = train.xml_to_paragraphs(
    "train_data/dewiki-20180920-corpus.xml",
    max_n_paragraphs=3_000_000,
)

In [None]:
# SoMaJo-based German tokenizer. Its output is pickled in the next cell and later turned
# into training samples by train.prepare_tokenized_paragraphs.
# NOTE(review): this rebinds `tokenizer`, shadowing the `nnsplit.tokenizer` module imported at the top.
tokenizer = nnsplit.tokenizer.SoMaJoTokenizer("de")

In [None]:
# Stream the tokenized German paragraphs to disk as consecutive pickle records
# (one record per paragraph) instead of collecting them in a list first.
with (cache_dir / "de_data" / "tokenized_paragraphs.pkl").open("wb") as handle:
    for tokenized in tokenizer.split(paragraphs, verbose=True):
        pickle.dump(tokenized, handle)

## English

In [None]:
# Read at most 3M paragraphs from the English Wikipedia XML dump.
paragraphs = train.xml_to_paragraphs(
    "train_data/enwiki-20181001-corpus.xml",
    max_n_paragraphs=3_000_000,
)

In [None]:
# SoMaJo-based English tokenizer. Its output is pickled in the next cell and later turned
# into training samples by train.prepare_tokenized_paragraphs.
# NOTE(review): this rebinds `tokenizer`, shadowing the `nnsplit.tokenizer` module imported at the top.
tokenizer = nnsplit.tokenizer.SoMaJoTokenizer("en")

In [None]:
# Stream the tokenized English paragraphs to disk as consecutive pickle records
# (one record per paragraph) instead of collecting them in a list first.
with (cache_dir / "en_data" / "tokenized_paragraphs.pkl").open("wb") as handle:
    for tokenized in tokenizer.split(paragraphs, verbose=True):
        pickle.dump(tokenized, handle)

# Train model (German)

In [5]:
# Turn the pickled German paragraphs into model inputs (`sentences`) and targets (`labels`).
# NOTE(review): the English equivalent passes "en" as a second argument; this call omits it and
# presumably relies on the default being German. Confirm, or pass "de" explicitly.
sentences, labels = train.prepare_tokenized_paragraphs(cache_dir / "de_data" / "tokenized_paragraphs.pkl")

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Faulty paragraph:
[[Token(text='.', whitespace='')], [Token(text='GLOBAL', whitespace=' '), Token(text='_set_float', whitespace=''), Token(text='Extend', whitespace=''), Token(text=';', whitespace=' '), Token(text=';', whitespace=' '), Token(text='Sprunglabel', whitespace=' '), Token(text='global', whitespace=' '), Token(text='sichtbar', whitespace=' '), Token(text='_set_float', whitespace=''), Token(text='Extend', whitespace=''), Token(text=':', whitespace=' '), Token(text=';', whitespace=' '), Token(text='Sprunglabel', whitespace=' '), Token(text='angeben', whitespace=''), Token(text=',', whitespace=' '), Token(text='das', whitespace=' '), Token(text='ist', whitespace=' '), Token(text='der', whitespace=' '), Token(text='Name', whitespace=' '), Token(text='des', whitespace=' '), Token(text='Unterprogramms', whitespace=''), Token(text=',', whitespace=' '), Token(text=';', whitespace=' '), Token(text='aus', whitespace=' '), Token(text='C', whitespace=' '), Token(text='ohne', whitespace=

In [8]:
# Hold out 10% of the German samples for validation; a fixed seed keeps the split reproducible.
x_train, x_valid, y_train, y_valid = train_test_split(
    sentences,
    labels,
    test_size=0.1,
    random_state=1234,
)

In [8]:
# Train the German model for 10 epochs. The captured log reports train/valid loss per epoch.
de_model = train.train(
    x_train, y_train,
    x_valid, y_valid,
    n_epochs=10,
)

epoch,train_loss,valid_loss,time
0,0.011943,0.011528,32:33
1,0.008184,0.008171,32:34
2,0.007726,0.007424,32:33
3,0.006705,0.006933,32:32
4,0.006225,0.006536,32:38
5,0.006061,0.006045,32:39
6,0.005765,0.005611,32:43
7,0.005283,0.005218,32:44
8,0.00442,0.004872,32:45
9,0.004273,0.004806,32:46


In [9]:
# Persist only the weights; the evaluation section rebuilds models.Network() and reloads them.
torch.save(
    obj=de_model.state_dict(),
    f=cache_dir / "de_data" / "model.pt",
)

In [10]:
# Export the trained German model to data/de (read back later via utils.load_model("data/de")).
# NOTE(review): the output shows a warning raised at `h5py.File(h5file)` inside store_model,
# presumably the h5py missing-`mode` deprecation. Consider passing an explicit mode there.
utils.store_model(de_model, "data/de")

  return h5py.File(h5file)


## Evaluate

In [6]:
# Rebuild the German network and restore the weights saved in cache/de_data/model.pt.
# map_location="cpu" lets this cell run on a machine without CUDA even if the checkpoint was
# written from GPU tensors. load_state_dict then copies the values into the model's own parameters.
de_model = models.Network()
de_model.load_state_dict(torch.load(cache_dir / "de_data" / "model.pt", map_location="cpu"))

<All keys matched successfully>

In [9]:
# Module.cuda() / .half() convert the module in place and return it, so de_model itself
# becomes an fp16 CUDA model from here on.
de_model = de_model.cuda().half()
train.evaluate(de_model, x_valid, y_valid)

HBox(children=(FloatProgress(value=0.0, max=586.0), HTML(value='')))


Target: Tokenize 

F1: 0.9991633615400703
Precision: 0.9991012479617996
Recall: 0.9992254828419558



Target: Sentencize 

F1: 0.9788622524288
Precision: 0.9695062375662943
Recall: 0.988400603371718





# Train model (English)

In [5]:
# Turn the pickled English paragraphs into model inputs/targets, subsampled to 2.5M
# (this overwrites the German `sentences` / `labels`).
sentences, labels = train.prepare_tokenized_paragraphs(
    cache_dir / "en_data" / "tokenized_paragraphs.pkl",
    "en",
    subsample=2_500_000,
)

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Faulty paragraph:
[[Token(text='.', whitespace='')], [Token(text='NET', whitespace=' '), Token(text='(', whitespace=''), Token(text='via', whitespace=' '), Token(text='the', whitespace=' '), Token(text='library', whitespace=' '), Token(text=')', whitespace=' '), Token(text='C', whitespace=' '), Token(text='(', whitespace=''), Token(text='via', whitespace=' '), Token(text='the', whitespace=' '), Token(text='library', whitespace=' '), Token(text=')', whitespace=' '), Token(text='C', whitespace=''), Token(text='#', whitespace=' '), Token(text='(', whitespace=''), Token(text='via', whitespace=' '), Token(text='the', whitespace=' '), Token(text='library', whitespace=' '), Token(text=')', whitespace=' '), Token(text='C++', whitespace=' '), Token(text='(', whitespace=''), Token(text='via', whitespace=' '), Token(text='the', whitespace=' '), Token(text='library', whitespace=' '), Token(text='and', whitespace=' '), Token(text=')', whitespace=' '), Token(text='D', whitespace=' '), Token(text='('

Faulty paragraph:
[[Token(text='.', whitespace='')], [Token(text='Burhanpur', whitespace=' '), Token(text='is', whitespace=' '), Token(text='a', whitespace=' '), Token(text='mid-sized', whitespace=' '), Token(text='historical', whitespace=' '), Token(text='city', whitespace=' '), Token(text='in', whitespace=' '), Token(text='the', whitespace=' '), Token(text='Nimar', whitespace=' '), Token(text='region', whitespace=' '), Token(text='of', whitespace=' '), Token(text='Madhya', whitespace=' '), Token(text='Pradesh', whitespace=' '), Token(text='state', whitespace=''), Token(text=',', whitespace=' '), Token(text='India', whitespace=''), Token(text='.', whitespace=' ')], [Token(text='It', whitespace=' '), Token(text='is', whitespace=' '), Token(text='the', whitespace=' '), Token(text='administrative', whitespace=' '), Token(text='seat', whitespace=' '), Token(text='of', whitespace=' '), Token(text='Burhanpur', whitespace=' '), Token(text='District', whitespace=''), Token(text='.', whitespac

In [6]:
# Hold out 10% of the English samples for validation; same seed as the German split.
x_train, x_valid, y_train, y_valid = train_test_split(
    sentences,
    labels,
    test_size=0.1,
    random_state=1234,
)

In [7]:
# Train the English model for 10 epochs. The captured log reports train/valid loss per epoch.
en_model = train.train(
    x_train, y_train,
    x_valid, y_valid,
    n_epochs=10,
)

epoch,train_loss,valid_loss,time
0,0.017559,0.016862,28:39
1,0.012065,0.012335,28:40
2,0.011233,0.01134,28:44
3,0.010577,0.010653,28:45
4,0.009801,0.010193,28:46
5,0.009156,0.009462,28:48
6,0.008747,0.008899,28:49
7,0.008379,0.008362,28:50
8,0.008053,0.007935,28:50
9,0.006977,0.00786,28:50


In [8]:
# Persist only the weights; the evaluation section rebuilds models.Network() and reloads them.
torch.save(
    obj=en_model.state_dict(),
    f=cache_dir / "en_data" / "model.pt",
)

In [9]:
# Export the trained English model to data/en (read back later via utils.load_model("data/en")).
# NOTE(review): the output shows a warning raised at `h5py.File(h5file)` inside store_model,
# presumably the h5py missing-`mode` deprecation. Consider passing an explicit mode there.
utils.store_model(en_model, "data/en")

  return h5py.File(h5file)


## Evaluate

In [10]:
# Rebuild the English network and restore the weights saved in cache/en_data/model.pt.
# map_location="cpu" lets this cell run on a machine without CUDA even if the checkpoint was
# written from GPU tensors. load_state_dict then copies the values into the model's own parameters.
en_model = models.Network()
en_model.load_state_dict(torch.load(cache_dir / "en_data" / "model.pt", map_location="cpu"))

<All keys matched successfully>

In [11]:
# Module.cuda() / .half() convert the module in place and return it, so en_model itself
# becomes an fp16 CUDA model from here on.
en_model = en_model.cuda().half()
train.evaluate(en_model, x_valid, y_valid)

HBox(children=(FloatProgress(value=0.0, max=489.0), HTML(value='')))


Target: Tokenize 

F1: 0.9991483749429025
Precision: 0.9991111376555529
Recall: 0.9991856150060541



Target: Sentencize 

F1: 0.9639170595081528
Precision: 0.9484387271613143
Recall: 0.9799089801564257





In [None]:
# Dynamic int8 quantization targets float CPU models. .float()/.cpu() convert en_model itself
# in place (undoing the fp16/CUDA conversion from evaluation) and return it.
en_model = en_model.float().cpu()
quantized_model = torch.quantization.quantize_dynamic(
    en_model,
    {nn.LSTM, nn.Linear},
    dtype=torch.qint8,
)

# Precision-recall curve

In [13]:
import matplotlib.pyplot as plt

In [14]:
# Small German sample (up to 1000 paragraphs) for per-paragraph precision and the PR curve.
# NOTE(review): this reads the same dump used for training. If xml_to_paragraphs takes paragraphs
# in file order, these overlap the training data and the metrics below are likely optimistic.
# Confirm, and consider drawing from a held-out part of the corpus.
tokenizer = nnsplit.tokenizer.SoMaJoTokenizer("de")
paragraphs = train.xml_to_paragraphs(
    "train_data/dewiki-20180920-corpus.xml",
    max_n_paragraphs=1000,
)

tokenized_ps = [tokenized for tokenized in tokenizer.split(paragraphs, verbose=True)]

HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




In [15]:
# Convert every tokenized paragraph into its text and label array (indexed as label[:, 1] below).
# This overwrites the English training `labels`.
labelled_ps = [utils.label_tokens(paragraph) for paragraph in tokenized_ps]
texts = [text for text, _ in labelled_ps]
labels = [label for _, label in labelled_ps]

In [16]:
# Load the German model exported earlier with utils.store_model(de_model, "data/de").
model = utils.load_model("data/de")

In [56]:
# NOTE(review): relies on NNSplit's private helpers (_get_raw_preds/_average_preds), which may
# change without notice.
splitter = nnsplit.NNSplit(model, stride=250)

preds, all_idx, n_cuts_per_text = splitter._get_raw_preds(texts, batch_size=1024)
# Average the overlapping window predictions per text, then drop the leading padding,
# presumably so each prediction row lines up with the corresponding label row.
avg_preds = [
    averaged[splitter.start_padding:]
    for averaged in splitter._average_preds(texts, preds, all_idx, n_cuts_per_text)
]

In [57]:
# Per-paragraph precision on target column 1 (presumably the sentence-boundary target),
# thresholding predictions at 0.5. zero_division=0 returns 0.0 explicitly, without an
# UndefinedMetricWarning, for paragraphs where the model predicts no positives at all.
raw_scores = np.array([
    metrics.precision_score(label[:, 1], pred[:, 1] > 0.5, zero_division=0)
    for label, pred in zip(labels, avg_preds)
])
# Drop zero-precision paragraphs. `scored_idx` maps each position in `scores` back to its
# index in `texts` / `labels` / `avg_preds`, which the filtering would otherwise lose.
scored_idx = np.flatnonzero(raw_scores > 0)
scores = raw_scores[scored_idx]

  _warn_prf(average, modifier, msg_start, len(result))


In [58]:
# Rank paragraphs from lowest to highest precision.
# NOTE: these are positions within the filtered `scores` array (zero-precision paragraphs were
# dropped), not direct indices into `texts` / `labels`.
scores.argsort()

array([ 57, 151, 817, 335, ..., 329, 330, 318, 475])

In [20]:
# Stack the per-paragraph arrays along the first axis for a corpus-level PR curve.
flat_labels, flat_preds = np.concatenate(labels), np.concatenate(avg_preds)

In [24]:
# Corpus-level precision/recall over all thresholds for target column 1.
precision, recall, thresholds = metrics.precision_recall_curve(
    flat_labels[:, 1],
    flat_preds[:, 1],
)

# Test

In [12]:
from nnsplit import NNSplit

In [13]:
# Load the English model exported earlier with utils.store_model(en_model, "data/en").
model = utils.load_model("data/en")

In [28]:
# Smoke test: split a run-on English sample whose sentences are not separated by punctuation.
# threshold=0.1 presumably lets weaker boundary predictions still produce a split.
splitter = NNSplit(model, threshold=0.1)
splitter.split(
    [
        "Fast, robust sentence splitting with Javascript, Rust and Python bindings Punctuation is not necessary to split sentences correctly sometimes even incorrect case is split correctly."
    ]
)

[[[Token(text='Fast', whitespace=''),
   Token(text=',', whitespace=' '),
   Token(text='robust', whitespace=' '),
   Token(text='sentence', whitespace=' '),
   Token(text='splitting', whitespace=' '),
   Token(text='with', whitespace=' '),
   Token(text='Javascript', whitespace=''),
   Token(text=',', whitespace=' '),
   Token(text='Rust', whitespace=' '),
   Token(text='and', whitespace=' '),
   Token(text='Python', whitespace=' '),
   Token(text='bindings', whitespace=' ')],
  [Token(text='Punctuation', whitespace=' '),
   Token(text='is', whitespace=' '),
   Token(text='not', whitespace=' '),
   Token(text='necessary', whitespace=' '),
   Token(text='to', whitespace=' '),
   Token(text='split', whitespace=' '),
   Token(text='sentences', whitespace=' '),
   Token(text='correctly', whitespace=' ')],
  [Token(text='sometimes', whitespace=' '),
   Token(text='even', whitespace=' '),
   Token(text='incorrect', whitespace=' '),
   Token(text='case', whitespace=' '),
   Token(text='is', 