In [40]:
from sklearn.metrics import classification_report, r2_score
from sklearn.model_selection import train_test_split
from sklearn.model_selection import cross_validate
from sklearn.model_selection import GridSearchCV
from sklearn_crfsuite import metrics
import torch
import torch.nn as nn

from torchcrf import CRF
import numpy as np
import json

from torch_model_base import TorchModelBase
from torch_rnn_classifier import TorchRNNDataset, TorchRNNClassifier, TorchRNNModel
import utils

In [41]:
with open('annotations2.jsonl') as jsonl_file:
    # note: after running data-preprocessing.ipynb this file already has token-level labels
    lines = jsonl_file.readlines()
annot = [json.loads(line) for line in lines]

In [42]:
# now get data into format that TorchRNN expects:
X=[] 
y=[]
for j in range(0,len(annot)):
    a = annot[j]['tokens']
    auxX = []
    auxy = []
    if annot[j]['spans']!=[]: # are there annot for this example?
        for i in range(0,len(a)):
            #token_element = (a[i]['text'],a[i]['label'])
            auxX.append(a[i]['text'])
            auxy.append(a[i]['label'])
        X.append(auxX)
        y.append(auxy)
#X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42)
X_train, X_test, y_train, y_test = X[:120], X[120:], y[:120], y[120:]
vocab = sorted({w for seq in X_train for w in seq}) + ["$UNK"]

In [43]:
# reload modules
import torch_rnn_classifier, torch_model_base
import importlib
importlib.reload(torch_model_base)
importlib.reload(torch_rnn_classifier)
from torch_model_base import TorchModelBase
from torch_rnn_classifier import TorchRNNClassifier, TorchRNNModel, TorchRNNDataset

In [44]:
class TorchCRFSequenceLabeler_3(TorchRNNClassifier):

    def __init__(self,             
            vocab,
            hidden_dim=50,
            embedding=None,
            use_embedding=True,
            embed_dim=50,
            rnn_cell_class=nn.LSTM,
            bidirectional=True,
            freeze_embedding=False,
            classifier_activation=nn.ReLU(),
            **base_kwargs):   
        self.vocab = vocab
        self.hidden_dim = hidden_dim
        self.embedding = embedding
        self.use_embedding = use_embedding
        self.embed_dim = embed_dim
        self.rnn_cell_class = rnn_cell_class
        self.bidirectional = bidirectional
        self.freeze_embedding = freeze_embedding
        self.classifier_activation = classifier_activation
        super().__init__(vocab,**base_kwargs)
        self.params += [
            'hidden_dim',
            'embed_dim',
            'embedding',
            'use_embedding',
            'rnn_cell_class',
            'bidirectional',
            'freeze_embedding',
            'classifier_activation']
        self.loss = lambda x:x
        if self.bidirectional:
            self.classifier_dim = self.hidden_dim * 2
        else:
            self.classifier_dim = self.hidden_dim
       # self.classifier_layer = nn.Linear(
       #     self.classifier_dim, self.n_classes_)

       
    def build_graph(self): # uses this build_graph instead of TorchRNNClassifier.build_graph
       # print("here0")
        rnn = TorchRNNModel(
            vocab_size=len(self.vocab),
            embedding=self.embedding,
            use_embedding=self.use_embedding,
            embed_dim=self.embed_dim,
            rnn_cell_class=self.rnn_cell_class,
            hidden_dim=self.hidden_dim,
            bidirectional=self.bidirectional,
            freeze_embedding=self.freeze_embedding)
      #  print("here02")
        model = TorchSequenceLabeler_forCRF_3( # this defines self.model
            rnn=rnn,
            output_dim=self.n_classes_)
      #  print("here002")
        self.embed_dim = rnn.embed_dim
        self.rnn = rnn
        return model

    def build_dataset(self, X, y=None):
        X, seq_lengths = self._prepare_sequences(X) # converts tokens into tokenIds
        if y is None:
            return TorchRNNDataset(X, seq_lengths)
        else:
            # These are the changes from a regular classifier. All
            # concern the fact that our labels are sequences of labels.
            self.classes_ = sorted({x for seq in y for x in seq})
            self.n_classes_ = len(self.classes_)
            class2index = dict(zip(self.classes_, range(self.n_classes_)))
            #class2index = dict(zip(self.classes_, range(2,2+self.n_classes_)))
            #class2index[STOP_TAG]=0    # add start and stop tags (note: stop needs to be 0 as that is default for padding in collate_fn)
            #class2index[START_TAG]=1 
            # `y` is a list of tensors of different length. Our Dataset
            # class will turn it into a padding tensor for processing.
            y = [torch.tensor([class2index[label] for label in seq])
                 for seq in y] # converts labels to indices
            return TorchRNNDataset(X, seq_lengths, y)

    def predict(self, X): # for CRF-RNN X are logits from RNN
       # probs = self.predict_proba(X)
       # return [[self.classes_[i] for i in seq.argmax(axis=1)] for seq in probs] # seq.argmax(axis=1) gives index of col that maximizes softmax prob
        device = "cuda" if torch.cuda.is_available() else "cpu"
        seq_lengths = [len(ex) for ex in X]
        preds = self._predict(X)     
        mask=self.create_mask(seq_lengths).to(device, non_blocking=True) # creates mask matrix (1s are obs used in CRF; 0s are discarded)  
        tag_seq = self.crf.decode(preds,mask=mask) # note: X is (nExs,maxTokLen) and here input must be (nExs,maxTokLen,nDistinctTags); out is optimal seq of tagIds
        return [[self.classes_[i] for i in seq] for seq in tag_seq] 
        # see difference vs TorchRNNClassifier.predict
    
    def score(self, X, y):
        preds = self.predict(X)
        flat_preds = [x for seq in preds for x in seq]
        flat_y = [x for seq in y for x in seq]
        return utils.safe_macro_f1(flat_y, flat_preds)  
    
    def nClasses(self):
        return len(self.classes_)
    
    def classes(self):
        return self.classes_
    
    def create_mask(self, seq_length):
        maxLen=max(seq_length)
        auxLen=len(seq_length)
        auxOne = torch.ones(maxLen)
        auxZero = torch.zeros(maxLen)
        auxOne_l=[1]*maxLen
        auxZero_l=[0]*maxLen
        auxMatrix=[]
        for i in range(auxLen):
            auxRow=auxOne_l[:seq_length[i]]+auxZero_l[seq_length[i]:]
            auxMatrix.append(auxRow)
        return torch.tensor(auxMatrix,dtype=torch.uint8)  

    
    def fit(self, *args):
        """
        Generic optimization method.

        Parameters
        ----------
        *args: list of objects
            We assume that the final element of args give the labels
            and all the preceding elements give the system inputs.
            For regular supervised learning, this is like (X, y), but
            we allow for models that might use multiple data structures
            for their inputs.

        Attributes
        ----------
        model: nn.Module or subclass thereof
            Set by `build_graph`. If `warm_start=True`, then this is
            initialized only by the first call to `fit`.

        optimizer: torch.optimizer.Optimizer
            Set by `build_optimizer`. If `warm_start=True`, then this is
            initialized only by the first call to `fit`.

        errors: list of float
            List of errors. If `warm_start=True`, then this is
            initialized only by the first call to `fit`. Thus, where
            `max_iter=5`, if we call `fit` twice with `warm_start=True`,
            then `errors` will end up with 10 floats in it.

        validation_scores: list
            List of scores. This is filled only if `early_stopping=True`.
            If `warm_start=True`, then this is initialized only by the
            first call to `fit`. Thus, where `max_iter=5`, if we call
            `fit` twice with `warm_start=True`, then `validation_scores`
            will end up with 10 floats in it.

        no_improvement_count: int
            Used to control early stopping and convergence. These values
            are controlled by `_update_no_improvement_count_early_stopping`
            or `_update_no_improvement_count_errors`.  If `warm_start=True`,
            then this is initialized only by the first call to `fit`. Thus,
            in that situation, the values could accumulate across calls to
            `fit`.

        best_error: float
           Used to control convergence. Smaller is assumed to be better.
           If `warm_start=True`, then this is initialized only by the first
           call to `fit`. It will be reset by
           `_update_no_improvement_count_errors` depending on how the
           optimization is proceeding.

        best_score: float
           Used to control early stopping. If `warm_start=True`, then this
           is initialized only by the first call to `fit`. It will be reset
           by `_update_no_improvement_count_early_stopping` depending on how
           the optimization is proceeding. Important: we currently assume
           that larger scores are better. As a result, we will not get the
           correct results for, e.g., a scoring function based in
           `mean_squared_error`. See `self.score` for additional details.

        best_parameters: dict
            This is a PyTorch state dict. It is used if and only if
            `early_stopping=True`. In that case, it is updated whenever
            `best_score` is improved numerically. If the early stopping
            criteria are met, then `self.model` is reset to contain these
            parameters before `fit` exits.

        Returns
        -------
        self

        """
      #  print("here00")
        if self.early_stopping:
            args, dev = self._build_validation_split(
                *args, validation_fraction=self.validation_fraction)
            

        # Dataset:
        dataset = self.build_dataset(*args)
        dataloader = self._build_dataloader(dataset, shuffle=True)

        # Graph:
        if not self.warm_start or not hasattr(self, "model"):
            self.model = self.build_graph()
            # This device move has to happen before the optimizer is built:
            # https://pytorch.org/docs/master/optim.html#constructing-it
            self.model.to(self.device)
            self.optimizer = self.build_optimizer()
            self.errors = []
            self.validation_scores = []
            self.no_improvement_count = 0
            self.best_error = np.inf
            self.best_score = -np.inf
            self.best_parameters = None

        # Make sure the model is where we want it:
        self.model.to(self.device)

        self.model.train()
        self.optimizer.zero_grad()
        
        self.crf = CRF(self.n_classes_,batch_first=True).to(self.device, non_blocking=True)

        for iteration in range(1, self.max_iter+1):

            epoch_error = 0.0

            for batch_num, batch in enumerate(dataloader, start=1):
               # print("batch"+str(batch_num)) 

               # print(batch)
                batch = [x.to(self.device, non_blocking=True) for x in batch]

                X_batch = batch[: -1] # list w/ 2 els: 1st el is tensor (108xmaxLen) w/ tokens for each example in batch; 2nd el is (108x1) with lengths of each example
                y_batch = batch[-1] # list with each element of this batch (108 el in list) with tensor (maxLen x 1) labels converted to ints and w/ len = maxLen of all example sequences # print(y_batch[0].shape)
               # print(X_batch[1].shape)
               # print(y_batch[0])
               
                batch_preds = self.model(*X_batch) # produces logits outputs of lstm
               # print("batch_preds")

               # print("here-model2")
                mask = (self.create_mask(X_batch[1])).to(self.device, non_blocking=True)
                #err = self.loss(batch_preds, y_batch) # batch_preds = (108,12,117); y_batch = (108,117)
                err = -self.crf(batch_preds,y_batch,mask=mask,reduction='mean') 
                # NOTE: self.crf outputs log likelihood so we multiply by (-1) so as to minimize this result

                if self.gradient_accumulation_steps > 1 and \
                  self.loss.reduction == "mean":
                    err /= self.gradient_accumulation_steps

                err.backward()

                epoch_error += err.item()

                if batch_num % self.gradient_accumulation_steps == 0 or \
                  batch_num == len(dataloader):
                    if self.max_grad_norm is not None:
                        torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(), self.max_grad_norm)
                    # print("before")
                    # print(self.model.rnn.rnn.weight_ih_l0) # check if lstm weights are being updated
                    self.optimizer.step()
                    self.optimizer.zero_grad()
                    # print("after")
                    # print(self.model.rnn.rnn.weight_ih_l0) # check if lstm weights are being updated

            # Stopping criteria:

            if self.early_stopping:
                self._update_no_improvement_count_early_stopping(*dev) # here we max macro avg f1 score (on dev = validation set)
                if self.no_improvement_count > self.n_iter_no_change:
                    utils.progress_bar(
                        "Stopping after epoch {}. Validation score did "
                        "not improve by tol={} for more than {} epochs. "
                        "Final error is {}".format(iteration, self.tol,
                            self.n_iter_no_change, epoch_error),
                        verbose=self.display_progress)
                    break

            else:
                self._update_no_improvement_count_errors(epoch_error)
                if self.no_improvement_count > self.n_iter_no_change:
                    utils.progress_bar(
                        "Stopping after epoch {}. Training loss did "
                        "not improve more than tol={}. Final error "
                        "is {}.".format(iteration, self.tol, epoch_error),
                        verbose=self.display_progress)
                    break

            utils.progress_bar(
                "Finished epoch {} of {}; error is {}".format(
                    iteration, self.max_iter, epoch_error),
                verbose=self.display_progress)

        if self.early_stopping:
            self.model.load_state_dict(self.best_parameters)

        return self

In [45]:
class TorchSequenceLabeler_forCRF_3(nn.Module): # no self.hidden_layer or self.classifier_activation as TorchRNNClassifierModel
    def __init__(self, rnn, output_dim):
       # print("here021")
        super().__init__()
        self.rnn = rnn
        self.output_dim = output_dim
        if self.rnn.bidirectional:
            self.classifier_dim = self.rnn.hidden_dim * 2
        else:
            self.classifier_dim = self.rnn.hidden_dim
        self.classifier_layer = nn.Linear(
            self.classifier_dim, self.output_dim)
        

    def forward(self, X, seq_lengths): # X is (noExsInBatch,MaxLen)=(108,117), seq_lengths is the number of tokens in each example in each batch
        # this is the forward method of self.model
       # print("here2")
        outputs, state = self.rnn(X, seq_lengths) # X is (batchSize, maxLen of exs in batch); outputs is (noTokensInEx,hiddDim), state is ((batch_size,1,hiddDim),(batch_size,1,hiddDim)) = (finalHiddState,finalCellState) 
        outputs, seq_length = torch.nn.utils.rnn.pad_packed_sequence(
            outputs, batch_first=True) # outputs is (batchSize,MaxLen of examples in batch,hidden_dim); seq_length is noTokenInEx for each ex in batch
        logits = self.classifier_layer(outputs) # this is an FCL from hidden_dim to output_dim (NoLabelClasses)
       # print(logits.shape)
        # logits are (108,117,12) or (1,11,5) = (batchSize,MaxLen of examples in batch,noLabelClasses) noLabelClasses include Start + End
        return logits  

In [46]:
seq_mod3 = TorchCRFSequenceLabeler_3(
    vocab,
    early_stopping=True,
    eta=0.001)

In [47]:
%time _ = seq_mod3.fit(X_train, y_train)

Stopping after epoch 18. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 106.95479583740234

Wall time: 3.18 s


In [48]:
y_pred = seq_mod3.predict(X_test)

In [49]:
classes = seq_mod3.classes()
print(metrics.flat_f1_score(y_test, y_pred,
                      average='weighted', labels=classes))
sorted_labels = sorted(
    classes,
    key=lambda name: (name[1:], name[0])
)
print(metrics.flat_classification_report(
    y_test, y_pred, labels=sorted_labels, digits=3
))

0.6513730510909207
                     precision    recall  f1-score   support

                  O      0.787     0.921     0.849       643
            KAEUFER      0.000     0.000     0.000        18
DATUM_VERBUECHERUNG      0.000     0.000     0.000        25
      DATUM_VERTRAG      0.000     0.000     0.000        27
         VERKAEUFER      0.000     0.000     0.000        24
   TERRASSENGROESSE      0.000     0.000     0.000         5
        GESAMTPREIS      0.056     0.091     0.069        11
            FLAECHE      0.000     0.000     0.000        15
           IMMO_TYP      0.000     0.000     0.000        19
            QMPREIS      0.000     0.000     0.000        10
                ORT      0.000     0.000     0.000        26
            STRASSE      0.000     0.000     0.000        16

           accuracy                          0.707       839
          macro avg      0.070     0.084     0.076       839
       weighted avg      0.604     0.707     0.651       839

