# End to end medical NER on NCBI-Disease dataset

We load NCBI dataset and perform medical mention detection on it (without the normalization of the mentions) using a token and character level neural network.

Every component (except the CRF) needed to perform the NER tagging is coded in this notebook.
Likewise, every preprocessing step is shown, and can be written by composing very few core functions into a few lines.
Finally, we convert back predictions to character spans to be able to display the annotated documents.

In [1]:
import numpy as np
import pandas as pd

## Data loading

In [2]:
from nlstruct.dataloaders.ncbi_disease import load_ncbi_disease
from sklearn.utils import check_random_state

# A Dataset object is just a boosted dict of pandas DataFrames
dataset = load_ncbi_disease()

# Split into train / test
test_split = 0.2
splits = (
    (["val"] * int(len(dataset['docs']) * test_split)) + 
    (["train"] * (len(dataset['docs']) - int(len(dataset['docs']) * test_split))))
check_random_state(42).shuffle(splits)
dataset['docs']['split'] = splits
train_dataset = dataset.query("split == 'train'", propagate=True)
test_dataset = dataset.query("split == 'val'", propagate=True)
{"train": train_dataset, "test": test_dataset}

Using cache /Users/perceval/Development/data/cache/ncbi_raw_files/4d8c0405832b0f7e


{'train': Dataset(
   (docs):       634 * ('doc_id', 'text', 'split')
   (mentions):  5590 * ('doc_id', 'mention_id', 'category')
   (labels):    5739 * ('doc_id', 'label_id', 'mention_id', 'label')
   (fragments): 5590 * ('doc_id', 'mention_id', 'begin', 'end', 'fragment_id')
 ),
 'test': Dataset(
   (docs):       158 * ('doc_id', 'text', 'split')
   (mentions):  1291 * ('doc_id', 'mention_id', 'category')
   (labels):    1320 * ('doc_id', 'label_id', 'mention_id', 'label')
   (fragments): 1291 * ('doc_id', 'mention_id', 'begin', 'end', 'fragment_id')
 )}

## Training preprocessing

In [3]:
docs, mentions, fragments = train_dataset[["docs", "mentions", "fragments"]].copy()

### Text transformations

In [4]:
from nlstruct.core.text import transform_text, apply_deltas
from nlstruct.core.cache import cached
import re
import string

# Define subs as ("pattern", "replacements") list
subs = [
    (r"(?<=[{}\\])(?![ ])".format(string.punctuation), r" "),
    (r"(?<![ ])(?=[{}\\])".format(string.punctuation), r" "),
    ("(?<=[a-zA-Z])(?=[0-9])", r" "),
    ("(?<=[0-9])(?=[A-Za-z])", r" "),
]
# Clean the text / perform substitutions
# `deltas` contains the character span shifts made by the substitutions
# we will reuse it at the end of the notebook to convert map predictions on input text
docs, deltas = cached(transform_text)(docs, *zip(*subs), return_deltas=True)

# Apply transformations to the spans
fragments = apply_deltas(fragments, deltas, on='doc_id')

Using cache /Users/perceval/Development/data/cache/nlstruct/core/text/transform_text/fb53b7431d13db40
Loading /Users/perceval/Development/data/cache/nlstruct/core/text/transform_text/fb53b7431d13db40/output.pkl... Done


### Problem reformulation
Consider each fragment in a mention as a mention

In [5]:
# Each fragment is a mention
mentions = mentions.merge(fragments)
mentions["mention_id"] = mentions["mention_id"].astype(str) + '-' + mentions["fragment_id"].astype(str)
del mentions["fragment_id"]

### Split docs into sentences

In [6]:
from nlstruct.chunking.spacy_tokenization import sentencize
from nlstruct.core.pandas import merge_with_spans, make_id_from_merged
from nlstruct.core.text import partition_spans

# Sentencize and make new docs from sentences
sentences = sentencize(docs)
[mentions], sentences, old_to_new_doc_mapper = partition_spans([mentions], sentences, new_id_name="doc_id")
docs = sentences.merge(old_to_new_doc_mapper).merge(docs.rename({"doc_id": "_doc_id"}, axis=1)).drop(columns=["_doc_id"])
docs["text"] = docs.apply(lambda row: row["text"][row["begin"]:row["end"]], axis=1)

### Deal with overlapping spans
For now we just select the largest mention in an overlapping group and throw the others away

In [7]:
# Extract overlapping spans
conflicts = merge_with_spans(mentions, mentions, on=["doc_id", ("begin", "end")], how="outer", suffixes=("", "_other"))

# Assign a cluster (overlapping fragments) to each fragment
mentions_cluster_ids = make_id_from_merged(
    conflicts[["doc_id", "mention_id"]], 
    conflicts[["doc_id", "mention_id_other"]], 
    apply_on=[(0, mentions[["doc_id", "mention_id"]])])

mentions = (
    mentions
    .groupby(mentions_cluster_ids, as_index=False, observed=True, group_keys=False)
    .apply(lambda group: group.assign(depth=np.argsort(group["begin"]-group["end"])))
    .query('depth == 0').drop(columns=["depth"]))

### Tokenize documents and mentions

In [8]:
from nlstruct.chunking.spacy_tokenization import spacy_tokenize, SPACY_ATTRIBUTES
from nlstruct.core.text import split_into_spans, encode_as_tag

# Tokenize
tokens = spacy_tokenize(docs, lang="en_core_web_sm", spacy_attributes=["norm_"])
mentions = split_into_spans(mentions, tokens, pos_col="token_idx")

In [9]:
# Encode each mention as a BIO tag on its tokens
tokens, label_categories = encode_as_tag(tokens, mentions, tag_scheme="bio", use_token_idx=True, verbose=1, label_cols=["category"])
tokens.rename({"category": "tag"}, axis=1, inplace=True)

100%|██████████| 5590/5590 [00:05<00:00, 987.20it/s] 


### Compute vocabularies and prepare batching

In [10]:
from nlstruct.core.batcher import Batcher
from nlstruct.core.pandas import factorize_rows, normalize_vocabularies, df_to_csr

In [11]:
# Compute idx for each mention in its doc according to its begin indice
mentions = mentions.groupby('doc_id', as_index=False).apply(lambda group: group.assign(idx=np.argsort(np.argsort(group['begin'])))).reset_index(drop=True)

In [12]:
# Construct charset from token->token_norm
tokens["token_charset_id"] = tokens["token_norm"]
charsets = (
    tokens[['token_norm', 'token_charset_id']]
    .drop_duplicates().astype(str)
    .apply(lambda x: pd.Series({"char": tuple(x["token_norm"]), "token_charset_id": x["token_charset_id"]}, name=x.name), axis=1)
    .nlstruct.flatten("char_idx", tile_index=True)
)
[tokens["doc_id"], mentions["doc_id"]], unique_doc_ids = factorize_rows([tokens["doc_id"], mentions["doc_id"]])
[mentions["mention_id"]], unique_mention_ids = factorize_rows([mentions["mention_id"]])
[charsets["token_charset_id"], tokens["token_charset_id"]], unique_charset_ids = factorize_rows([charsets["token_charset_id"], tokens["token_charset_id"]])

unk = {
    "token_norm": "<unk>",
    "char": "<unk>",
}
[tokens, mentions, charsets], vocabularies = normalize_vocabularies([tokens, mentions, charsets], vocabularies=label_categories, train_vocabularies=True, unk=unk, verbose=1)

Will train vocabulary for token_norm
Will train vocabulary for tag
Will train vocabulary for char
Discovered existing vocabulary (9 entities) for tag
Normalized category, with given vocabulary and no unk


In [13]:
batcher = Batcher({
    "doc": {
        "token_norm": df_to_csr(tokens["doc_id"], tokens["token_idx"], tokens["token_norm"].cat.codes, n_rows=len(unique_doc_ids)),
        "token_charset_id": df_to_csr(tokens["doc_id"], tokens["token_idx"], tokens["token_charset_id"], n_rows=len(unique_doc_ids)),
        "token_tag": df_to_csr(tokens["doc_id"], tokens["token_idx"], tokens["tag"].cat.codes, n_rows=len(unique_doc_ids)),
        "token_mask": df_to_csr(tokens["doc_id"], tokens["token_idx"], n_rows=len(unique_doc_ids)),
        "mention_id": df_to_csr(mentions["doc_id"], mentions["idx"], mentions["mention_id"], n_rows=len(unique_doc_ids)),
        "mention_mask": df_to_csr(mentions["doc_id"], mentions["idx"], n_rows=len(unique_doc_ids)),
    },
    "token_charset": {
        "char": df_to_csr(charsets["token_charset_id"], charsets["char_idx"], charsets["char"].cat.codes),
        "mask": df_to_csr(charsets["token_charset_id"], charsets["char_idx"]),
    },
    "mention": {
        "mention_id": mentions["mention_id"],
        "doc_id": mentions["doc_id"],
        "begin": mentions["begin"],
        "end": mentions["end"],
        "category": mentions["category"].cat.codes,
    }}, 
    masks={"doc": {"token_charset_id": "token_mask", "token_norm": "token_mask", "token_tag": "token_mask", "mention_id": "mention_mask"}, 
           "token_charset": {"char": "mask"}}, 
    foreign_ids="absolute").prepare_for_indexing()

## Neural nets

In [14]:
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from nlstruct.layers.crf import BIODecoder
from nlstruct.core.torch import torch_global as tg

### Character level encoder

In [15]:
class CharNet(torch.nn.Module):
    def __init__(self, n_chars, dim, n_filters=128, kernel_size=3):
        super().__init__()
        self.embeddings = torch.nn.Embedding(n_chars, dim)
        self.cnn = torch.nn.Conv1d(dim, n_filters, kernel_size=kernel_size)
        
    def forward(self, chars, mask):
        # chars shape: n * n_char_per_token * dim
        # mask shape: n * n_char_per_toke
        state = self.embeddings(chars).transpose(1, 2)
        # shape: n * n_char_per_token * dim
        state = self.cnn(state)
        # shape: n * n_filters * n_windows (n_windows ~= n_char_per_token)
        state = F.relu(state).max(dim=2)[0]
        # shape: n * n_filters
        return state

### Token tagger

In [16]:
class NERNet(torch.nn.Module):
    def __init__(self, 
                 n_chars, 
                 n_tokens, 
                 n_labels, 
                 token_dim, 
                 char_dim, 
                 char_hidden_dim, 
                 hidden_dim, 
                 rnn_layers, 
                 dropout, 
                 char_kernel_size):
        super().__init__()
        self.embeddings = torch.nn.Embedding(n_tokens, token_dim) if n_tokens > 0 else None
        self.char_net = CharNet(n_chars, char_dim, char_hidden_dim, kernel_size=char_kernel_size) if n_chars > 0 else None
        assert self.embeddings is not None or self.char_net is not None
        
        self.crf = BIODecoder(n_labels, with_start_end_transitions=False)
        self.lstm = torch.nn.LSTM((token_dim if n_tokens > 0 else 0) + (char_hidden_dim if n_chars > 0 else 0), 
                                  hidden_dim, dropout=dropout, batch_first=True, num_layers=rnn_layers, bidirectional=True)
        self.dropout = torch.nn.Dropout(dropout)
        self.linear1 = torch.nn.Linear(hidden_dim*2, hidden_dim)
        self.linear2tag = torch.nn.Linear(hidden_dim, self.crf.num_tags)
    
    def forward(self, tokens, mask, charsets, charsets_mask, tokens_charset, tags=None, return_loss=False, return_argmax=False, reduction="mean"):
        # Embed the tokens
        state = torch.cat([
            *((self.embeddings(tokens).masked_fill(~mask.unsqueeze(-1), 0),) if self.embeddings is not None else ()),
            *((self.char_net(charsets, charsets_mask)[tokens_charset],) if self.char_net is not None else ()),
        ], dim=-1)
        
        # Run the lstm (first sort the sentences, pack them accorindg to the mask, unpack them and finally reorder them)
        sorter = torch.argsort(-mask.sum(1))
        invsorter = torch.argsort(sorter)
        state = pad_packed_sequence(
            self.lstm(pack_padded_sequence(state[sorter], mask[sorter].sum(1), batch_first=True))[0], 
            batch_first=True)[0].view(*state.shape[:2], -1)[invsorter]
        
        # Compute the tags scores
        state = F.relu(self.linear1(self.dropout(state)))
        state = self.linear2tag(state)
        
        return {
            # Run the linear CRF forward algorithm on the tokens to compute the loglikelihood of the targets
            "loss": -self.crf(state, mask, tags, reduction=reduction) if return_loss else None, 
            # Run the linear CRF Viterbi algorithm to compute the most likely sequence
            "pred": self.crf.decode(state, mask) if return_argmax else None
        }
    
    def tags_to_mentions(self, pred_batch):
        extracted = self.crf.tags_to_spans(pred_batch["doc", "token_tag"], pred_batch["doc", "token_mask"])
        return Batcher({
            "doc": {
                "doc_id": pred_batch["doc", "doc_id"],
                "mention_id": extracted["doc_spans_id"],
                "mention_mask": extracted["doc_spans_mask"],
            },
            "mention": {
                "begin": extracted["span_begin"],
                "end": extracted["span_end"],
                "category": extracted["span_label"],
                "doc_id": extracted["span_doc_id"],
            }},
            masks={"doc": {"mention_id": "mention_mask"}},
            main_table="doc",
            foreign_ids="relative")

## Training

### Train the NER model

In [17]:
from collections import defaultdict

# Define the training metrics
metrics_info = defaultdict(lambda: False)
flt_format = (5, "{:.4f}".format)
metrics_info.update({
    "train_loss": {"goal": 0, "format": flt_format},
    "train_acc": {"goal": 1, "format": flt_format},
    "val_loss": {"goal": 0, "format": flt_format},
    "val_acc": {"goal": 1, "format": flt_format},
    "val_recall": {"goal": 1, "format": flt_format, "name": "val_rec"},
    "val_precision": {"goal": 1, "format": flt_format, "name": "val_prec"},
    "val_f1": {"goal": 1, "format": flt_format, "name": "val_f1"},
    "duration": {"format": flt_format, "name": "   dur(s)"},
})

In [18]:
from torch.optim import Adam

from nlstruct.core.cache import get_cache
from nlstruct.xp_helpers import make_optimizer_and_schedules, run_optimization
from nlstruct.core.random import seed_all
from nlstruct.core.torch import evaluating
from nlstruct.core.scoring import compute_metrics, merge_pred_and_gold

device = torch.device('cpu')
seed = 42
seed_all(seed) # /!\ Super important to enable reproducibility

# Split into train and val
train_val_splits = np.random.choice([0, 1], size=len(batcher), p=[0.8, 0.2])
train_batcher = batcher[train_val_splits == 0]
val_batcher = batcher[train_val_splits == 1]

batch_size = 64
ner_net = NERNet(
    n_chars=len(vocabularies["char"]),
    n_tokens=len(vocabularies["token_norm"]), 
    n_labels=len(vocabularies["category"]), 
    
    token_dim=50,
    char_dim=16,
    char_hidden_dim=64,
    hidden_dim=100,
    dropout=0.5,
    rnn_layers=2,
    char_kernel_size=3,
).to(device)

optim, schedules = make_optimizer_and_schedules(ner_net, Adam, {
    "lr": 1e-3,
}, [".*"], num_iter_per_epoch=(len(train_batcher) + 1) / batch_size)

# To debug the training, we can just comment the "def run_epoch()" and execute the function body manually without changing anything to it
def run_epoch():
    total_train_loss = 0
    total_train_acc = 0
    total_train_size = 0

    #################
    # TRAINING STEP #
    #################
    for batch in train_batcher.set_main("doc").dataloader(batch_size=batch_size, shuffle=True, device=device):
        optim.zero_grad()
        res = ner_net.forward(
            tokens =         batch["doc", "token_norm"],
            mask =           batch["doc", "token_mask"],
            charsets =       batch["token_charset", "char"],
            charsets_mask =  batch["token_charset", "mask"],
            tokens_charset = batch["doc", "token_charset_id"],
            tags =           batch["doc", "token_tag"],
            return_loss=True)

        # Perform optimization step
        loss = res["loss"]
        loss.backward()
        optim.step()
        for schedule_name, schedule in schedules.items():
            schedule.step()

        total_train_loss += loss.item() * len(batch)
        total_train_size += len(batch)

    ###################
    # VALIDATION STEP #
    ###################
    total_val_loss = 0
    total_val_size = 0
    pred_batches = []
    with evaluating(ner_net): # eval mode, no dropout
        with torch.no_grad(): # no gradients -> faster
            for batch in val_batcher.set_main("doc").dataloader(batch_size=batch_size, device=device):
                res = ner_net.forward(
                    tokens =         batch["doc", "token_norm"],
                    mask =           batch["doc", "token_mask"],
                    charsets =       batch["token_charset", "char"],
                    charsets_mask =  batch["token_charset", "mask"],
                    tokens_charset = batch["doc", "token_charset_id"],
                    tags =           batch["doc", "token_tag"],
                    return_argmax=True,
                    return_loss=True,
                )
                total_val_loss += res['loss'].item() * len(batch)
                total_val_size += len(batch)
                pred_batch = ner_net.tags_to_mentions(Batcher({
                    "doc": {
                        "doc_id": batch["doc", "doc_id"],
                        "token_mask": batch["doc", "token_mask"],
                        "token_tag": res["pred"],
                    },
                }))
                pred_batches.append(pred_batch)
    pred = Batcher.concat(pred_batches)

    # Compute precision, recall and f1 on validation set
    val_metrics = compute_metrics(merge_pred_and_gold(
        pred=pd.DataFrame(dict(pred[{"mention": ["doc_id", "begin", "end", "category"]}])), 
        gold=pd.DataFrame(dict(val_batcher.switch_foreign_ids_mode("absolute")[{"mention": ["doc_id", "begin", "end", "category"]}])), 
        span_policy='partial_strict',  # only partially match spans with strict bounds, we could also eval with 'exact' or 'partial'
        on=["doc_id", ("begin", "end"), "category"]), prefix='val_')[["val_recall", "val_precision", "val_f1"]].to_dict()
    return \
    {
        "train_loss": total_train_loss / total_train_size,
        "val_loss": total_val_loss / total_val_size,
        **val_metrics,
    }

state = {"ner_net": ner_net, "optim": optim, "schedules": schedules}  # all we need to restart the training from a given epoch
run_optimization(
    main_score='val_f1',
    metrics_info=metrics_info,
    required_start_score=10.00,
    patience_warmup=5,
    patience_rate=0.01,
    patience=10,
    max_epoch=32,

    state=state, 
    cache=get_cache("end2end_ner", {"seed": seed, "data": batcher, **state}, loader=torch.load, dumper=torch.save),  # where to store the model (main name + hashed parameters)
    epoch_fn=run_epoch,
)

Using cache /Users/perceval/Development/data/cache/end2end_ner/00dabfdf19a3ddf0


epoch | train_loss | val_loss | val_rec | val_prec | [31mval_f1[0m |    dur(s)
    1 |    [32m14.5835[0m |  [32m10.5479[0m |  [32m0.0000[0m |   [32m0.0000[0m | [32m0.0000[0m |   37.6641
    2 |     [32m9.5233[0m |   [32m8.0500[0m |  [31m0.0000[0m |   [31m0.0000[0m | [31m0.0000[0m |   41.4897
    3 |     [32m7.0389[0m |   [32m6.4940[0m |  [32m0.1511[0m |   [32m0.6588[0m | [32m0.2458[0m |   43.1168
    4 |     [32m5.8936[0m |   [32m5.6561[0m |  [32m0.3588[0m |   [31m0.5612[0m | [32m0.4377[0m |   45.4009
    5 |     [32m4.9126[0m |   [32m4.7158[0m |  [32m0.4371[0m |   [31m0.5347[0m | [32m0.4810[0m |   46.4007
    6 |     [32m4.0912[0m |   [32m4.1432[0m |  [32m0.4910[0m |   [31m0.5379[0m | [32m0.5134[0m |   43.6068
    7 |     [32m3.4834[0m |   [32m3.7168[0m |  [32m0.5027[0m |   [31m0.6302[0m | [32m0.5593[0m |   44.2475
    8 |     [32m3

{'train_loss': 0.2830894843448208,
 'val_loss': 3.236642536990542,
 'val_recall': 0.8156474820143885,
 'val_precision': 0.8500468603561387,
 'val_f1': 0.8324919687930245,
 'duration': 48.4348509311676,
 'best_epoch': 32}

## Inference

### Preprocessing and batch encoding

In [19]:
docs = test_dataset["docs"].copy()

# Clean the text / perform substitutions
docs, deltas = transform_text(docs, *zip(*subs), return_deltas=True)

# Tokenize
tokens = spacy_tokenize(docs, lang="en_core_web_sm", spacy_attributes=["norm_"])

def make_batcher(tokens):
    tokens = tokens.copy()
    tokens["token_charset_id"] = tokens["token_norm"]
    charsets = (
    tokens[['token_norm', 'token_charset_id']]
        .drop_duplicates().astype(str)
        .apply(lambda x: pd.Series({"char": tuple(x["token_norm"]), "token_charset_id": x["token_charset_id"]}, name=x.name), axis=1)
        .nlstruct.flatten("char_idx", tile_index=True)
    )
    [tokens["doc_id"]], unique_doc_ids = factorize_rows([tokens["doc_id"]])
    [charsets["token_charset_id"], tokens["token_charset_id"]], unique_charset_ids = factorize_rows([charsets["token_charset_id"], tokens["token_charset_id"]])
    [tokens, charsets] = normalize_vocabularies([tokens, charsets], verbose=1, unk=unk, vocabularies=vocabularies, train_vocabularies=False)[0]

    batcher = Batcher({
        "doc": {
            "token_norm": df_to_csr(tokens["doc_id"], tokens["token_idx"], tokens["token_norm"].cat.codes, n_rows=len(unique_doc_ids)),
            "token_charset_id": df_to_csr(tokens["doc_id"], tokens["token_idx"], tokens["token_charset_id"], n_rows=len(unique_doc_ids)),
            "token_mask": df_to_csr(tokens["doc_id"], tokens["token_idx"], n_rows=len(unique_doc_ids)),
        },
        "token_charset": {
            "char": df_to_csr(charsets["token_charset_id"], charsets["char_idx"], charsets["char"].cat.codes),
            "mask": df_to_csr(charsets["token_charset_id"], charsets["char_idx"]),
        }}, 
        masks={
            "doc": {"token_charset_id": "token_mask", "token_norm": "token_mask"}, 
            "token_charset": {"char": "mask"}}, 
        foreign_ids="absolute",
    ).prepare_for_indexing()
    return batcher, unique_doc_ids

batcher, unique_doc_ids = make_batcher(tokens)

Normalized token_norm, with given vocabulary and unk <unk>
Normalized char, with given vocabulary and unk <unk>


### NER inference

In [20]:
device = torch.device('cpu')
pred_batches = []
with evaluating(ner_net): # eval mode, no dropout
    with torch.no_grad(): # no gradients -> faster
        for batch in batcher.set_main("doc").dataloader(batch_size=128, device=device):
            res = ner_net.forward(
                tokens =         batch["doc", "token_norm"],
                mask =           batch["doc", "token_mask"],
                charsets =       batch["token_charset", "char"],
                charsets_mask =  batch["token_charset", "mask"],
                tokens_charset = batch["doc", "token_charset_id"],
                return_argmax=True,
            )
            pred_batches.append(ner_net.tags_to_mentions(Batcher({
                "doc": {
                    "doc_id": batch["doc", "doc_id"],
                    "token_mask": batch["doc", "token_mask"],
                    "token_tag": res["pred"],
                },
            })))
pred = Batcher.concat(pred_batches)

### Convert back to dataset

In [21]:
from nlstruct.core.text import reverse_deltas
from nlstruct.core.dataset import Dataset

# Convert predicted concat mention batches to mention df
pred_mentions = pd.DataFrame({
    "doc_id": unique_doc_ids.iloc[pred["mention", "doc_id"]],
    "begin": pred["mention", "begin"],
    "end": pred["mention", "end"] - 1, # if token span is 2:4, the last token is at position 3
    "category": np.asarray(vocabularies["category"])[pred["mention", "category"]],
}).groupby("doc_id", as_index=False, observed=True).apply(lambda x: x.assign(mention_id=["T"+str(i+1) for i in np.argsort(np.argsort(x["begin"].values))])).reset_index(drop=True)

# Convert token spans (mentions) to character spans
pred_mentions = pd.merge(pred_mentions, tokens[["doc_id", "token_idx", "begin"]],
                         left_on=['doc_id', 'begin'], right_on=['doc_id', 'token_idx'], suffixes=('_x', '')).drop(columns=["token_idx"])
pred_mentions = pd.merge(pred_mentions, tokens[["doc_id", "token_idx", "end"]],
                         left_on=['doc_id', 'end'], right_on=['doc_id', 'token_idx'], suffixes=('_x', '')).drop(columns=["token_idx"])

# Apply the reverse text transformations on the mentions
pred_mentions = reverse_deltas(pred_mentions.drop(columns=["begin_x", "end_x"]), deltas, on=["doc_id"])
pred_dataset = Dataset(
    docs=test_dataset["docs"],
    mentions=pred_mentions[["doc_id", "mention_id", "category"]],
    fragments=pred_mentions[["doc_id", "mention_id", "begin", "end"]].assign(fragment_id=0),
)

### Evaluate the model on test data

In [22]:
gold_mentions = test_dataset["mentions"].merge(test_dataset["fragments"], on=["doc_id", "mention_id"])
pred_mentions = pred_dataset["mentions"].merge(pred_dataset["fragments"], on=["doc_id", "mention_id"])
val_metrics = {
    **compute_metrics(merge_pred_and_gold(
        pred=pred_mentions, atom_pred_level=["doc_id", "mention_id"],
        gold=gold_mentions, atom_gold_level=["doc_id", "mention_id"],
        span_policy='partial_strict',  # only partially match spans with strict bounds, we could also eval with 'exact' or 'partial'
        on=["doc_id", ("begin", "end"), "category"]), prefix="full/partial_strict/"),
    **compute_metrics(merge_pred_and_gold(
        pred=pred_mentions, atom_pred_level=["doc_id", "mention_id"],
        gold=gold_mentions, atom_gold_level=["doc_id", "mention_id"],
        span_policy='exact',  # only partially match spans with strict bounds, we could also eval with 'exact' or 'partial'
        on=["doc_id", ("begin", "end"), "category"]), prefix="full/exact/"),
    **compute_metrics(merge_pred_and_gold(
        pred=pred_mentions, atom_pred_level=["doc_id", "mention_id"],
        gold=gold_mentions, atom_gold_level=["doc_id", "mention_id"],
        span_policy='partial_strict',  # only partially match spans with strict bounds, we could also eval with 'exact' or 'partial'
        on=["doc_id", ("begin", "end")]), prefix="span_only/partial_strict/"),
    **compute_metrics(merge_pred_and_gold(
        pred=pred_mentions, atom_pred_level=["doc_id", "mention_id"],
        gold=gold_mentions, atom_gold_level=["doc_id", "mention_id"],
        span_policy='exact',  # only partially match spans with strict bounds, we could also eval with 'exact' or 'partial'
        on=["doc_id", ("begin", "end")]), prefix="span_only/exact/"),
}
pd.Series({tuple(name.split("/")): value for name, value in val_metrics.items()}).unstack(2)

Unnamed: 0,Unnamed: 1,f1,gold_count,precision,pred_count,recall,tp
full,exact,0.614442,1291.0,0.526229,1811.0,0.738187,953.0
full,partial_strict,0.679562,1291.0,0.581999,1811.0,0.816421,1054.0
span_only,exact,0.678917,1291.0,0.581447,1811.0,0.815647,1053.0
span_only,partial_strict,0.770471,1291.0,0.659856,1811.0,0.925639,1195.0


### Export / visualize

In [23]:
from nlstruct.exporters.visualizers import render_with_displacy
render_with_displacy(pred_dataset.query('doc_id == "100562"'), label_colname="category")

In [24]:
from nlstruct.exporters.visualizers import render_with_displacy
render_with_displacy(test_dataset.query('doc_id == "100562"'), label_colname="category")