Here we train a tokeniser using our break-and-stitch strategy. API-wise, a tokeniser is a function of the form `(Text) -> List[Interval[Text]]` (where `Interval` is defined in `scilk.util.intervals`); that is, for any text X it returns an ordered sequence of non-overlapping ranges, each corresponding to an individual token.

In [None]:
from typing import Sequence, Iterable, Iterator, Tuple, List, Mapping, Any, Optional, Callable, TypeVar
from itertools import dropwhile, groupby, chain, starmap
from functools import reduce
from collections import Counter
from math import ceil
import operator as op
import re
import os

os.environ['CUDA_VISIBLE_DEVICES'] = '1'

import numpy as np
from sklearn import metrics
import joblib
from fn import F
from fn.iters import splitby, droplast
from binpacking import to_constant_bin_number
from keras import layers, models

from scilk.corpora import chemdner, corpus
from scilk.collections import common
from scilk.util import preprocessing, binning, segments
from scilk.util.intervals import Interval
from scilk.util.patterns import ptokenise, ptransform
from scilk.util.networks import wrappers, blocks, callbacks

T = TypeVar('T')

#### Parse and encode the corpus

In [None]:
def _parse_flat(abstracts_path, annotations_path):
    # Parse one CHEMDNER split and flatten its abstracts in a single step.
    return corpus.flatten_abstracts(chemdner.parse(abstracts_path, annotations_path))


# The training split is loaded first, then the development split.
chemnder_train = _parse_flat(
    'data/chemdner_corpus/training.abstracts.txt',
    'data/chemdner_corpus/training.annotations.txt',
)
chemnder_dev = _parse_flat(
    'data/chemdner_corpus/development.abstracts.txt',
    'data/chemdner_corpus/development.annotations.txt',
)

# Both splits are concatenated and then transposed into three parallel tuples.
texts, annotations, borders = zip(*(chemnder_train + chemnder_dev))

In [None]:
# Build the character encoder from the concatenation of every text in the corpus.
# The call yields three values:
#   maxcharval   - used further down as `maxcharval+1` for the Embedding input size;
#   charmap      - presumably the character -> code mapping (TODO confirm against
#                  `common.build_charencoder`);
#   text_encoder - a callable that is later applied to each individual text.
maxcharval, charmap, text_encoder = common.build_charencoder(''.join(texts))

In [None]:
# This is an extra-fine primary tokenisation (i.e. the "break" step). The pattern
# matches either a run of word characters or a single character that is neither
# whitespace nor a word character.
# The pattern is written as a raw string so that `\w` and `\s` reach the regex
# engine verbatim; in a plain string literal they are invalid escape sequences,
# which Python reports with a DeprecationWarning (a SyntaxWarning from 3.12 on).
tokeniser = F(ptokenise, [re.compile(r'\w+|[^\s\w]')])
# Apply the tokeniser to every text and materialise the results as a list.
tokens = (F(map, tokeniser) >> list)(texts)

# Find the stitchpoints: `segments.stitchpoints` is called once per text with that
# text's primary tokens and its annotations.
token_stitches = list(starmap(segments.stitchpoints, zip(tokens, annotations)))

# We never stitch words across a basic white-space character, so every stitchpoint
# that touches a plain space is discarded.
def filter_stitches(text, points) -> List[int]:
    """Return the points whose two-character window in `text` contains no space.

    For each point the slice `text[point:point+2]` is inspected; the point is kept
    only when that slice holds no ' ' character. The input order is preserved.
    """
    kept = []
    for point in points:
        window = text[point:point+2]
        if ' ' in window:
            continue
        kept.append(point)
    return kept

# Filter the stitchpoints of each text, pairing texts with their points positionally.
filtered_stitches = list(map(filter_stitches, texts, token_stitches))

In [None]:
# Prepare the dataset for stateful learning using reversible binpacking.
# `chunksize` and `batchsize` are reused below for chunking and for the network's
# input shape.
chunksize = 256
batchsize = 64

# NOTE(review): `encode_annotation` and `reverse` are neither defined nor imported
# in the visible part of this notebook - confirm they are in scope before running.

# Lazily encode each text together with its filtered stitchpoints ...
encoded_pairs = (
    (text_encoder(text), encode_annotation(len(text), points))
    for text, points in zip(texts, filtered_stitches)
)
# ... apply `reverse` to both members of every (text, annotation) pair ...
reversed_pairs = (map(reverse, pair) for pair in encoded_pairs)
# ... then transpose into "all texts" / "all annotations" and wrap each in an array.
encoded_texts, encoded_stitches = map(np.array, zip(*reversed_pairs))

In [None]:
# Separate training and validation data: shuffle the sample indices in place and
# cut them at the `trainsplit` fraction.

trainsplit = 0.9

indices = np.arange(len(texts))
np.random.shuffle(indices)

# Everything before the cut is used for training, the remainder for validation.
split_at = int(len(indices)*trainsplit)
train_split, val_split = indices[:split_at], indices[split_at:]

In [None]:
# Encode the datasets.

# Select the training and validation samples by their shuffled indices.
texts_train, anno_train = encoded_texts[train_split], encoded_stitches[train_split]
texts_val, anno_val = encoded_texts[val_split], encoded_stitches[val_split]

# Pack the texts of each subset into `batchsize` bins, measuring items with `len`.
bins_train = binning.binpack(batchsize, len, texts_train)
bins_val = binning.binpack(batchsize, len, texts_val)

# Merge texts and annotations with the same bins so that they stay aligned.
x_train = binning.merge_bins(texts_train, bins=bins_train)
y_train = binning.merge_bins(anno_train, bins=bins_train)
x_val = binning.merge_bins(texts_val, bins=bins_val)
y_val = binning.merge_bins(anno_val, bins=bins_val)

# Cut the merged data into steps of `chunksize`.
x_batches_train = preprocessing.chunksteps(chunksize, x_train)
y_batches_train = preprocessing.chunksteps(chunksize, y_train)
x_batches_val = preprocessing.chunksteps(chunksize, x_val)
y_batches_val = preprocessing.chunksteps(chunksize, y_val)

In [None]:
# Display the shapes of the chunked training inputs and targets as a sanity check.
tuple(batches.shape for batches in (x_batches_train, y_batches_train))

#### Build and train the network, save the results.

In [None]:
# The input has a fixed batch shape: `batchsize` sequences of `chunksize` characters.
inputs = layers.Input(batch_shape=(batchsize, chunksize))
embeddings = layers.Embedding(maxcharval+1, 32)(inputs)
l_cnn = blocks.cnn([256, 256], 3, [0.3, None])

# The three recurrent layers are configured identically: a 16-unit stateful GRU
# returning full sequences, wrapped in a HalfStatefulBidirectional.
gru_options = dict(stateful=True, dropout=0.3, recurrent_dropout=0.3, return_sequences=True)
l_rnn1, l_rnn2, l_rnn3 = (
    wrappers.HalfStatefulBidirectional(layers.GRU(16, **gru_options))
    for _ in range(3)
)

# Keras' Dense layers now behave like TimeDistributed layers when applied to sequential inputs
output_layer = layers.Dense(1, activation='sigmoid')

# Thread the embeddings through the CNN block and the recurrent stack, in order.
hidden = embeddings
for layer in (l_cnn, l_rnn1, l_rnn2, l_rnn3):
    hidden = layer(hidden)
labels = output_layer(hidden)

model = models.Model(inputs, labels)
model.compile(optimizer='Adam', loss='binary_crossentropy')

In [None]:
# Since our networks have stateful layers, we need to reset them between epochs.
# The same reset callables are registered for both the start and the end of an epoch.
callables = [(lambda _: l_rnn1.reset_states()),
             (lambda _: l_rnn2.reset_states()),
             (lambda _: l_rnn3.reset_states())]
resetter = callbacks.Caller({'on_epoch_begin': callables, 'on_epoch_end': callables})

# Prepare the validator; it scores the model on the validation dataset with the
# metrics below and saves model weights upon performance improvements.
# NOTE(review): the original comment says the F1 score is tracked, but the key
# passed to the Validator is 'precision' - confirm which metric is intended.
scores = {'precision': F(metrics.precision_score, average='binary', labels=[1]),
          'recall': F(metrics.recall_score, average='binary', labels=[1]),
          'f1': F(metrics.f1_score, average='binary', labels=[1])}

# Pure-Python equivalent of `mkdir -p trainlogs`, so the cell does not depend on
# IPython's shell escape and also runs as a plain script.
os.makedirs('trainlogs', exist_ok=True)
# The log stream is left open here because the validator writes to it during
# training; close it once training has finished.
logfile = open('trainlogs/stitches-text.log', 'w')
# Predictions are thresholded at 0.5 and flattened before scoring against the
# flattened validation targets.
validator = callbacks.Validator([np.vstack(x_batches_val)], np.vstack(y_batches_val).flatten(),
                                batchsize, scores, lambda pred: (np.vstack(pred) > 0.5).astype(int).flatten(),
                                'precision', prefix='trainlogs/stitches-text', stream=logfile)

In [None]:
# Train on the stacked chunks. The targets get a trailing axis, and shuffling is
# disabled so that the chunks are fed to the stateful layers in their stored order.
model.fit(
    x=np.vstack(x_batches_train),
    y=np.vstack(y_batches_train)[:, :, None],
    batch_size=batchsize,
    shuffle=False,
    epochs=45,
    callbacks=[resetter, validator],
)