In [1]:
# Reload imported modules automatically before each cell executes, so edits to
# local library code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import copy
import random
import sys
from pathlib import Path
# Put the local `python_lib` directory (resolved against the kernel's working
# directory) on the import path; presumably this is where the nnsplit sources live.
sys.path += ["python_lib"]

In [3]:
import torch
from torch import nn
import numpy as np
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm_notebook as tqdm
import nnsplit
from nnsplit import train, utils, models
# Swap nnsplit.train's module-level `tqdm` for the notebook widget variant imported
# above, so its progress bars (presumably driven by that name) render as widgets.
setattr(train, "tqdm", tqdm)

In [4]:
# Directory for intermediate artifacts (prepared data, saved models), relative to the
# kernel's working directory.
cache_dir = Path("cache")
cache_dir.mkdir(exist_ok=True)

# Seed every RNG the pipeline may touch so shuffling / weight initialisation is
# repeatable across Restart & Run All. This alone does not make cuDNN kernels
# deterministic. np.random.seed goes last because it returns None, so the cell
# displays nothing.
SEED = 1234
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

# Prepare data

In [None]:
# Prepare German training data from the dewiki dump into cache/de_data (presumably
# writing the all_sentences.pt / all_labels.pt files loaded in the next section).
# The trailing `;` suppresses the displayed return value. The previous `_ = ...`
# shadowed IPython's `_` output-history variable and stopped it updating for the session.
train.prepare_data("train_data/dewiki-20180920-corpus.xml", "de", max_n_sentences=10_000_000,
                   data_directory=cache_dir / "de_data");

In [None]:
# Prepare English training data from the enwiki dump into cache/en_data (presumably
# writing the all_sentences.pt / all_labels.pt files loaded in the English section).
# The trailing `;` suppresses the displayed return value. The previous `_ = ...`
# shadowed IPython's `_` output-history variable and stopped it updating for the session.
train.prepare_data("train_data/enwiki-20181001-corpus.xml", "en", max_n_sentences=10_000_000,
                   data_directory=cache_dir / "en_data");

# Train model (German)

In [5]:
# Load the prepared German inputs (x) and labels (y), in that order.
x, y = (torch.load(cache_dir / "de_data" / fname) for fname in ("all_sentences.pt", "all_labels.pt"))

In [6]:
# Hold out 10% for validation; the fixed random_state keeps the split reproducible.
x_train, x_valid, y_train, y_valid = train_test_split(
    x, y,
    test_size=0.1,
    random_state=1234,
)

In [7]:
# Train the German model (about 18 minutes per epoch according to the recorded log).
# NOTE(review): n_epochs=15, but the recorded log stops after epoch 9. The run was
# presumably interrupted or its output truncated; confirm which weights were saved.
de_model = train.train(x_train, y_train, x_valid, y_valid, n_epochs=15)

epoch,train_loss,valid_loss,time
0,0.034266,0.03627,17:59
1,0.028387,0.02738,17:56
2,0.025026,0.026409,17:55
3,0.026419,0.025348,17:55
4,0.023266,0.024745,17:56
5,0.023075,0.02419,17:56
6,0.023747,0.023252,17:57
7,0.021663,0.022639,17:56
8,0.021643,0.022049,17:57
9,0.022001,0.021097,17:57


In [8]:
# Persist the trained German model next to its training data. Derive the path from
# cache_dir (rather than a hardcoded "cache/de_data") so it stays in sync with where
# the data was written. This pickles the whole nn.Module object.
torch.save(de_model, cache_dir / "de_data" / "model.pt")

In [9]:
# Export the German model to data/de; the Test section reloads it from there via
# utils.load_model.
# NOTE(review): the captured output shows a warning raised at `h5py.File(h5file)` inside
# store_model (possibly h5py's default-file-mode warning); confirm before upgrading h5py.
utils.store_model(de_model, "data/de")

  return h5py.File(h5file)


## Evaluate

In [8]:
# Reload the saved German model (e.g. after a kernel restart) from the same
# cache_dir-derived path it was saved to.
# NOTE(review): this unpickles a full nn.Module. Newer torch releases default
# torch.load to weights_only=True and would need weights_only=False here, which is
# acceptable only for trusted files like this one.
de_model = torch.load(cache_dir / "de_data" / "model.pt")

In [9]:
# Evaluate a half-precision GPU *copy*. Module.cuda() and .half() convert in place, so
# calling them on de_model directly would permanently round its fp32 weights to fp16
# (and leave it on the GPU) before the quantization cell below.
train.evaluate(copy.deepcopy(de_model).cuda().half(), x_valid, y_valid)

HBox(children=(FloatProgress(value=0.0, max=977.0), HTML(value='')))


Target: Tokenize 

F1: 0.9985709185426798
Precision: 0.9977352694599613
Recall: 0.9994079685876922



Target: Sentencize 

F1: 0.9581387972891665
Precision: 0.9362705478411474
Recall: 0.9810530203414978





In [11]:
# Dynamically quantize the LSTM and Linear layers to int8 for CPU evaluation.
# Note: `.float().cpu()` converts de_model itself in place before it is handed over.
quantized_model = torch.quantization.quantize_dynamic(
    de_model.float().cpu(),
    {nn.Linear, nn.LSTM},
    dtype=torch.qint8,
)

In [12]:
# Evaluate the int8 dynamically-quantized German model on the same validation split,
# for comparison with the fp16 GPU evaluation above.
train.evaluate(quantized_model, x_valid, y_valid)

HBox(children=(FloatProgress(value=0.0, max=977.0), HTML(value='')))


Target: Tokenize 

F1: 0.9985806618134485
Precision: 0.9977847594261771
Recall: 0.9993778349483429



Target: Sentencize 

F1: 0.9581120292762628
Precision: 0.9363237215520647
Recall: 0.9809385261100492





# Train model (English)

In [10]:
# Load the prepared English inputs (x) and labels (y), in that order.
x, y = (torch.load(cache_dir / "en_data" / fname) for fname in ("all_sentences.pt", "all_labels.pt"))

In [11]:
# Same 90/10 split and random_state as the German model, for comparable results.
x_train, x_valid, y_train, y_valid = train_test_split(
    x, y,
    test_size=0.1,
    random_state=1234,
)

In [12]:
# Train the English model (about 18 minutes per epoch according to the recorded log).
# NOTE(review): n_epochs=15, but the recorded log stops after epoch 9. Confirm whether
# the run was interrupted or the output truncated.
en_model = train.train(x_train, y_train, x_valid, y_valid, n_epochs=15)

epoch,train_loss,valid_loss,time
0,0.045592,0.045185,18:15
1,0.037289,0.036578,18:16
2,0.037222,0.035395,18:17
3,0.034971,0.034768,18:16
4,0.032515,0.034059,18:18
5,0.033572,0.033148,18:17
6,0.032469,0.032019,18:18
7,0.03008,0.031354,18:18
8,0.031937,0.030739,18:18
9,0.029535,0.029577,18:18


In [13]:
# Persist the trained English model next to its training data. Derive the path from
# cache_dir (rather than a hardcoded "cache/en_data") so it stays in sync with where
# the data was written. This pickles the whole nn.Module object.
torch.save(en_model, cache_dir / "en_data" / "model.pt")

In [14]:
# Export the English model to data/en; the Test section reloads it from there via
# utils.load_model.
utils.store_model(en_model, "data/en")

In [15]:
# Evaluate a half-precision GPU *copy* so en_model keeps its fp32 CPU weights
# (Module.cuda() and .half() convert in place).
# NOTE(review): the recorded run hit CUDA OOM with only ~65 MiB allocated by PyTorch and
# ~155 MiB free of 10.75 GiB. The GPU was presumably held by another process (e.g. a
# second kernel); free it before re-running.
train.evaluate(copy.deepcopy(en_model).cuda().half(), x_valid, y_valid)

HBox(children=(FloatProgress(value=0.0, max=977.0), HTML(value='')))




RuntimeError: CUDA out of memory. Tried to allocate 156.00 MiB (GPU 0; 10.75 GiB total capacity; 64.64 MiB already allocated; 155.31 MiB free; 76.00 MiB reserved in total by PyTorch)

# Test

In [None]:
# High-level splitter used below to sanity-check the exported models.
# NOTE(review): this import belongs with the other imports at the top of the notebook.
from nnsplit import NNSplit

In [None]:
# Build a splitter from the exported German model, cast to fp32.
splitter = NNSplit(utils.load_model("data/de").float())

In [None]:
# The first sentence lacks a final period (presumably on purpose), so this checks
# whether a sentence boundary is still detected.
splitter.split(["Das ist ein Test Das ist noch ein Test."])

In [None]:
# Build a splitter from the exported English model, cast to fp32.
splitter = NNSplit(utils.load_model("data/en").float())

In [None]:
# The input contains "etc." mid-sentence, presumably to check that abbreviations are
# not treated as sentence ends.
splitter.split(["He is making test, exercises etc. and examples. This is another test."])

# Tune model

In [63]:
# Quick-iteration subset: the first n German samples only.
# NOTE(review): if these are tensors, `[:n]` is a view that keeps the full loaded
# storage alive; `.clone()` would let it be freed.
n = 1_000_000
x, y = (torch.load(cache_dir / "de_data" / fname)[:n] for fname in ("all_sentences.pt", "all_labels.pt"))

In [73]:
# Fix random_state (as in the full-data splits) so repeated tuning runs use the same
# validation set and their scores are comparable.
x_train, x_valid, y_train, y_valid = train_test_split(x, y, test_size=0.1, random_state=1234)

In [78]:
# One-epoch training on the 1M-sample subset for fast hyperparameter iteration.
model = train.train(x_train, y_train, x_valid, y_valid, n_epochs=1)

epoch,train_loss,valid_loss,time
0,0.035032,0.035142,01:51


In [74]:
# NOTE(review): duplicate of the cell above, and it was executed earlier (In [74] vs
# In [78]). Its much lower losses come from a different notebook/library state. Keep
# one cell and record which configuration produced which loss.
model = train.train(x_train, y_train, x_valid, y_valid, n_epochs=1)

epoch,train_loss,valid_loss,time
0,0.002735,0.002637,01:54


In [46]:
# Dynamically quantize the tuning model's LSTM and Linear layers to int8 for CPU evaluation.
# Note: `.float().cpu()` converts `model` itself in place before it is handed over.
quantized_model = torch.quantization.quantize_dynamic(
    model.float().cpu(),
    {nn.Linear, nn.LSTM},
    dtype=torch.qint8,
)

In [47]:
# Evaluate the int8 quantized tuning model on the subset's validation split.
train.evaluate(quantized_model, x_valid, y_valid)


  0%|          | 0/98 [00:00<?, ?it/s][A
  1%|          | 1/98 [00:00<01:09,  1.41it/s][A
  2%|▏         | 2/98 [00:01<01:09,  1.39it/s][A
  3%|▎         | 3/98 [00:02<01:12,  1.32it/s][A
  4%|▍         | 4/98 [00:03<01:22,  1.14it/s][A
  5%|▌         | 5/98 [00:04<01:28,  1.06it/s][A
  6%|▌         | 6/98 [00:05<01:19,  1.16it/s][A
  7%|▋         | 7/98 [00:06<01:28,  1.03it/s][A
  8%|▊         | 8/98 [00:07<01:27,  1.03it/s][A
  9%|▉         | 9/98 [00:08<01:25,  1.04it/s][A
 10%|█         | 10/98 [00:09<01:27,  1.01it/s][A
 11%|█         | 11/98 [00:10<01:24,  1.03it/s][A
 12%|█▏        | 12/98 [00:11<01:19,  1.09it/s][A
 13%|█▎        | 13/98 [00:12<01:17,  1.10it/s][A
 14%|█▍        | 14/98 [00:13<01:23,  1.01it/s][A
 15%|█▌        | 15/98 [00:14<01:23,  1.00s/it][A
 16%|█▋        | 16/98 [00:15<01:28,  1.08s/it][A
 17%|█▋        | 17/98 [00:16<01:23,  1.03s/it][A
 18%|█▊        | 18/98 [00:17<01:23,  1.04s/it][A
 19%|█▉        | 19/98 [00:18<01:25,  1.08s/it]

Target: Tokenize 

F1: 0.9982945732838129
Precision: 0.9974475067007087
Recall: 0.9991430798056675



Target: Sentencize 

F1: 0.897711766558203
Precision: 0.8391659852820932
Recall: 0.9650393049234588





In [77]:
# Evaluate a half-precision GPU *copy* so `model` keeps its fp32 weights for any later
# quantization (Module.cuda() and .half() convert in place).
train.evaluate(copy.deepcopy(model).cuda().half(), x_valid, y_valid)


  0%|          | 0/98 [00:00<?, ?it/s][A
 11%|█         | 11/98 [00:00<00:00, 109.64it/s][A
 22%|██▏       | 22/98 [00:00<00:00, 107.66it/s][A
 35%|███▍      | 34/98 [00:00<00:00, 109.46it/s][A
 49%|████▉     | 48/98 [00:00<00:00, 116.24it/s][A
 63%|██████▎   | 62/98 [00:00<00:00, 120.74it/s][A
 77%|███████▋  | 75/98 [00:00<00:00, 121.12it/s][A
100%|██████████| 98/98 [00:00<00:00, 122.83it/s][A


Target: Tokenize 

F1: 0.9973287289215611
Precision: 0.9961537390852487
Recall: 0.9985064938947623



Target: Sentencize 

F1: 0.9055118110236221
Precision: 0.8516719990905105
Recall: 0.9666181268548045



