In [5]:
import json
from typing import List

import torch
from torch.utils.data import Dataset
from nltk import WhitespaceTokenizer

from src.NLP.nlp_base import NLPBaseDataset

We need a way to convert the character-level start/end offsets of each labelled span into token-level labels: every token that makes up an annotated word gets that annotation's label, and all other tokens get `'O'`.

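For example:

Current setup (character offsets are Python slice bounds, so `char_end` is exclusive: `sentence[8:18] == 'an example'`):

```python
sentence = 'This is an example'
label    = [{'char_start': 8, 'char_end': 18, 'text': 'an example', 'label': 'EXAMPLE'}]
```

Desired setup:

```python
sentence = ['This', 'is', 'an',      'example']
label    = ['O',    'O',  'EXAMPLE', 'EXAMPLE']
```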

In [27]:
# Toy annotated samples. For every annotation, `text` should equal
# `data[char_start:char_end]` (char_end is exclusive, Python-slice style).
samples = [
    {'data': 'This is an example', 'label': [{'char_start': 8, 'char_end': 18, 'text': 'an example', 'label': 'EXAMPLE'}]},
    {'data': 'This is another example', 'label': [{'char_start': 8, 'char_end': 15, 'text': 'another', 'label': 'STH'}]},
    {
        'data': 'This is another example with multiple annots', 'label': [
            {'char_start': 8, 'char_end': 15, 'text': 'another', 'label': 'STH'},
            {'char_start': 16, 'char_end': 23, 'text': 'example', 'label': 'EXAMPLE'},
            # chars 38-44 are 'annots' (previously mislabelled as 'a third')
            {'char_start': 38, 'char_end': 44, 'text': 'annots', 'label': 'ANNOT'}
        ]
    },
]

# The multi-annotation sample is used for the demos below.
sample = samples[2]

def span_tokenize(text: str):
    """Split ``text`` on whitespace, returning token character spans and strings.

    Basic whitespace-based tokenizer built on NLTK's ``WhitespaceTokenizer``.

    Args:
        text: The raw sentence to tokenize.

    Returns:
        A pair ``(spans, tokenized_text)``: ``spans`` is a list of
        ``(start, end)`` character offsets into ``text`` (``end`` exclusive,
        as used for slicing), and ``tokenized_text`` is the list of the
        corresponding token strings ``text[start:end]``.
    """
    # TODO: Also consider using huggingface tok
    # https://huggingface.co/docs/transformers/main_classes/tokenizer
    spans = list(WhitespaceTokenizer().span_tokenize(text))
    tokenized_text = [text[start:end] for start, end in spans]
    return spans, tokenized_text

# Sanity check: show each token's (start, end) span next to the token strings.
demo_spans, demo_tokens = span_tokenize(sample['data'])
print((demo_spans, demo_tokens))

([(0, 4), (5, 7), (8, 15), (16, 23), (24, 28), (29, 37), (38, 44)], ['This', 'is', 'another', 'example', 'with', 'multiple', 'annots'])


In [36]:
def convert(entry: dict, tokenizer=None) -> dict:
    """Convert a character-span annotated sample into per-token labels.

    Each annotation covers the character range ``[char_start, char_end)`` of
    ``entry['data']``. Every token lying entirely inside that range receives
    the annotation's ``label``; all other tokens are labelled ``'O'``. Tokens
    that only partially overlap an annotation stay ``'O'``, so an annotation
    that does not cover at least one whole token is silently dropped.

    Args:
        entry: Sample of the form ``{'data': str, 'label': [annot, ...]}``,
            where each ``annot`` has ``'char_start'``, ``'char_end'`` and
            ``'label'`` keys (other keys such as ``'text'`` are ignored).
        tokenizer: Optional callable ``text -> (spans, tokens)`` with the same
            contract as ``span_tokenize`` (spans are ``(start, end)`` character
            offsets, end exclusive). Defaults to ``span_tokenize``.

    Returns:
        ``{'data': tokens, 'label': labels}`` with exactly one label per token.

    Raises:
        AssertionError: if two annotations claim the same token.
    """
    if tokenizer is None:
        tokenizer = span_tokenize  # resolved at call time, not definition time
    spans, tokenized_text = tokenizer(entry['data'])
    new_labels = ['O'] * len(tokenized_text)
    for annot in entry['label']:
        label_start = annot['char_start']
        label_end = annot['char_end']
        for i, span in enumerate(spans):
            start, end = span
            # Since start <= end, this is exactly "token fully inside the annotation".
            if label_start <= start and end <= label_end:
                label = annot['label']
                # Overlapping annotations on the same token are treated as a data error.
                assert new_labels[i] == "O", f"{new_labels=} {i=} {span=} {label=}, {annot=}"
                new_labels[i] = label
    return {
        "data": tokenized_text,
        "label": new_labels,
    }

In [37]:
converted_sample = convert(sample)
converted_sample

{'data': ['This', 'is', 'another', 'example', 'with', 'multiple', 'annots'],
 'label': ['O', 'O', 'STH', 'EXAMPLE', 'O', 'O', 'ANNOT']}