In [128]:
import re
import ipywidgets as widgets
from IPython.display import display, clear_output
from collections import namedtuple

In [None]:
def tokenize(input_str):
    """Return the maximal runs of word characters in ``input_str``, in order."""
    return [match.group(0) for match in re.finditer(r'\w+', input_str)]

In [None]:
# Plain record for one (possibly merged) token; fields listed explicitly.
Token = namedtuple('Token', ['token', 'style', 'raw_tokens', 'label', 'start', 'end'])

In [147]:
# class-based typing.NamedTuple syntax requires Python >= 3.6
from typing import List, NamedTuple
from functools import partial

class AnnotationResult(NamedTuple):
    """One labelled span of a document, as reported by the annotator."""

    span: tuple  # (start, end) character offsets into the source text
    text: str  # the slice of the source text covered by ``span``
    label: str  # label assigned to the span
        
class Token(NamedTuple):
    """A (possibly merged) span of word tokens in a document being annotated."""

    token: str  # display text; merged tokens are joined with single spaces
    style: str # https://ipywidgets.readthedocs.io/en/stable/examples/Widget%20Styling.html#Predefined-styles
    raw_tokens: list # [str] -- the original tokens this span was built from
    label: str  # annotation label; the annotator creates tokens with None here
    start: int  # character offset where the span starts in the source text
    end: int  # character offset one past the end of the span

    @staticmethod
    def merge(token1, token2):
        """Return one Token covering ``token1`` followed by ``token2``.

        Display texts are joined with a space and raw tokens are concatenated.
        The first truthy style/label wins. The span runs from the smaller
        ``start`` to the larger ``end``.
        """
        return Token(
            token=token1.token + ' ' + token2.token,
            style=token1.style or token2.style,
            raw_tokens=token1.raw_tokens + token2.raw_tokens,
            label=token1.label or token2.label,
            start=min(token1.start, token2.start),
            end=max(token1.end, token2.end))

    @staticmethod
    def merge_all(tokens):
        """Merge a non-empty iterable of tokens, left to right, into one Token.

        The result equals folding :meth:`merge` over ``tokens``. It is built in
        a single pass instead of re-slicing the list at every step (which was
        quadratic). Any iterable is accepted, not only lists. A single token
        is returned unchanged.

        Raises:
            IndexError: if ``tokens`` is empty.
        """
        tokens = list(tokens)
        if not tokens:
            raise IndexError('merge_all() requires at least one token')
        if len(tokens) == 1:
            return tokens[0]
        # A left fold of ``a or b`` yields the first truthy value, else the last one.
        style = next((t.style for t in tokens if t.style), tokens[-1].style)
        label = next((t.label for t in tokens if t.label), tokens[-1].label)
        return Token(
            token=' '.join(t.token for t in tokens),
            style=style,
            raw_tokens=[raw for t in tokens for raw in t.raw_tokens],
            label=label,
            start=min(t.start for t in tokens),
            end=max(t.end for t in tokens))
        

class Annotator:
    """Interactive ipywidgets tool for labelling token spans in a list of texts.

    Each document is split into word tokens rendered as toggle buttons.
    Clicking a "Merge as <label>" button merges every contiguous run of
    toggled tokens into a single labelled span. Prev/Next buttons move
    between documents. Collected spans are exposed through :attr:`results`.
    """

    # Maximum total description characters packed into one row of toggles.
    LINE_MAX_WIDTH = 50

    def __init__(self, docs, scheme):
        """Build the internal token lists and render the first document.

        Args:
            docs: non-empty sequence of texts to annotate.
            scheme: iterable of ``(label, button_style)`` pairs; ``button_style``
                is an ipywidgets predefined button style.

        Raises:
            ValueError: if ``docs`` is empty.
        """
        # Validate before doing any work: render() indexes self.docs[0], and an
        # ``assert`` would be silently stripped under ``python -O``.
        if len(docs) == 0:
            raise ValueError('Should pass at least 1 text to annotate')
        self.mapping = self._create_style_mapping(scheme)
        self.raw_docs = docs
        self.docs = [self._internal_doc(doc) for doc in docs]
        self.current_index = 0
        self.render()

    @classmethod
    def _internal_doc(cls, from_text):
        """Tokenize ``from_text`` into unlabelled Tokens carrying character offsets."""
        original_tokens = cls._tokenize(from_text)
        from_index = 0
        doc = []
        for token in original_tokens:
            # Search from the end of the previous token so repeated words map
            # to successive occurrences rather than the first one.
            start = from_text.find(token, from_index)
            from_index = start + len(token)
            doc.append(Token(
                token=token,
                style='',
                raw_tokens=[token],
                start=start,
                end=start + len(token),
                label=None
            ))
        return doc

    def _on_label(self, doc, new_label, _):
        """Handle a "Merge as <new_label>" click for ``doc``.

        Each contiguous run of toggled tokens is relabelled with ``new_label``,
        restyled, and merged into one Token. ``doc`` is updated in place, so
        ``self.docs`` sees the change, and then the view is re-rendered. The
        last argument is the widget passed by ``on_click`` (unused).
        """
        new_merged = []
        stack = []
        for token, check in zip(doc, self.span_check_togglers):
            if check.value:
                new_token = token._replace(label=new_label, style=self.mapping[new_label])
                stack.append(new_token)
            else:
                # An unselected token ends the current run of selected ones.
                if stack:
                    new_merged.append(Token.merge_all(stack))
                    stack = []
                new_merged.append(token)
        if stack:
            new_merged.append(Token.merge_all(stack))
        doc[:] = new_merged
        self.render()

    def _generate_labelling_for(self, doc):
        """Build the toggle rows and label buttons for ``doc``.

        The toggles are kept on ``self.span_check_togglers`` so that
        :meth:`_on_label` can read which tokens are selected.
        """
        self.span_check_togglers = []

        for merged_token in doc:
            self.span_check_togglers.append(widgets.ToggleButton(
                value=False,
                description=merged_token.token,
                button_style=merged_token.style,
                disabled=False,
                tooltip='Click to select for merging + assigning class'
            ))

        label_buttons = []
        for label, button_style in self.mapping.items():
            label_button = widgets.Button(description=f'Merge as {label}', button_style=button_style, icon='check')
            label_button.on_click(partial(self._on_label, doc, label))
            label_buttons.append(label_button)

        rows = self._break_into_rows(self.span_check_togglers)
        labelling_widget = widgets.VBox([
            *[widgets.HBox(row_with_toggles) for row_with_toggles in rows],
            widgets.HBox(label_buttons)
        ])
        return labelling_widget

    def _break_into_rows(self, items_with_desc):
        """Greedily pack items into rows of at most LINE_MAX_WIDTH description chars.

        An item whose description alone exceeds the limit gets a row of its
        own. Returns a list of non-empty lists of the original items, in order.
        """
        resulting_rows = []
        current_line = []
        current_length = 0
        for item in items_with_desc:
            width = len(item.description)
            # Only close the current row if it already holds something;
            # otherwise an over-long first item would emit an empty row.
            if current_line and current_length + width > self.LINE_MAX_WIDTH:
                resulting_rows.append(current_line)
                current_line = []
                current_length = 0
            current_line.append(item)
            current_length += width
        if current_line:
            resulting_rows.append(current_line)
        return resulting_rows

    def _navigate(self, direction, _):
        """Move ``direction`` documents forward (+) or back (-), wrapping around.

        The last argument is the widget passed by ``on_click`` (unused).
        """
        self.current_index = (self.current_index + direction) % len(self.docs)
        self.render()

    def render(self):
        """Clear the cell output and show the navigation and labelling widgets."""
        clear_output()
        current_doc = self.docs[self.current_index]
        labelling_widget = self._generate_labelling_for(current_doc)
        prev_button = widgets.Button(description='← Prev', button_style='primary')
        prev_button.on_click(partial(self._navigate, -1))
        next_button = widgets.Button(description='Next →', button_style='primary')
        next_button.on_click(partial(self._navigate, 1))
        control_widgets = widgets.HBox([
            prev_button,
            next_button,
        ])
        display(widgets.VBox([
            control_widgets,
            labelling_widget,
        ]))

    @property
    def results(self):
        """Return, per document, the list of AnnotationResults for labelled spans."""
        res = []
        for doc, raw_text in zip(self.docs, self.raw_docs):
            spanning_labels = []
            for token in doc:
                if token.label:
                    start = token.start
                    end = token.end
                    spanning_labels.append(
                        AnnotationResult(label=token.label,
                                         span=(start, end),
                                         text=raw_text[start:end])
                    )
            res.append(spanning_labels)
        return res

    @staticmethod
    def _tokenize(text):
        """Split ``text`` into maximal runs of word characters."""
        return re.findall(r'\w+', text)

    @staticmethod
    def _create_style_mapping(scheme):
        """Map each label in ``scheme`` to its button style (later duplicates win)."""
        return dict(scheme)

In [152]:
# Sample news snippets to annotate; adjacent literals are joined at compile time.
docs = [
    (
        "SALT LAKE CITY — Disney will soon finish acquiring 20th Century Fox, "
        "which could mean some major changes are in store for the Marvel "
        "Cinematic Universe."
    ),
    (
        "Brazil was one of the final countries that had yet to approve the deal, "
        "according to Bloomberg. A source told Bloomberg that Disney was willing "
        "to unload the Fox Sports network to different buyers."
    ),
]

In [155]:
# (label, ipywidgets button_style) pairs, paired up position by position.
scheme = list(zip(
    ["LOC", "PER", "ORG", "MISC"],
    ["info", "success", "warning", "success"],
))

In [156]:
# Launch the interactive annotator on the sample documents.
annotations = Annotator(docs=docs, scheme=scheme)

VBox(children=(HBox(children=(Button(button_style='primary', description='← Prev', style=ButtonStyle()), Butto…

In [157]:
# Last expression of the cell, so the notebook displays the per-document
# lists of AnnotationResult spans labelled so far.
annotations.results

[[AnnotationResult(span=(0, 14), text='SALT LAKE CITY', label='LOC'),
  AnnotationResult(span=(17, 23), text='Disney', label='ORG'),
  AnnotationResult(span=(126, 151), text='Marvel Cinematic Universe', label='MISC')],
 [AnnotationResult(span=(0, 6), text='Brazil', label='LOC'),
  AnnotationResult(span=(85, 94), text='Bloomberg', label='ORG'),
  AnnotationResult(span=(110, 119), text='Bloomberg', label='ORG'),
  AnnotationResult(span=(125, 131), text='Disney', label='ORG'),
  AnnotationResult(span=(158, 168), text='Fox Sports', label='ORG')]]

In [None]:
def annotate(text, labels, tokenize_fn=tokenize):
    """Single-document, closure-based prototype of the labelling widget.

    Renders one toggle button per token of ``text`` plus one "Merge as <label>"
    button per entry of ``labels``. Clicking a label button merges each
    contiguous run of toggled tokens into one labelled Token and repaints.

    Args:
        text: the text to annotate.
        labels: iterable of ``(label, button_style)`` pairs.
        tokenize_fn: callable returning the tokens of ``text``. The default is
            bound to ``tokenize`` when this def runs. NOTE(review): the offset
            computation below assumes the tokens are substrings of ``text``
            returned in order of appearance — confirm for custom tokenizers.

    Returns:
        A zero-argument function returning ``(label, (start, end), text_slice)``
        tuples for the labelled spans at the time it is called.
    """
    mapping = {label: style for label, style in labels}
    # Local copy of Token.merge: combine two adjacent tokens into one span.
    def merge(token1, token2):
        return Token(
            token=token1.token + ' ' + token2.token,
            style=token1.style or token2.style,
            raw_tokens = token1.raw_tokens + token2.raw_tokens,
            label=token1.label or token2.label,
            start=min(token1.start,token2.start),
            end=max(token1.end, token2.end))

    # Left-fold merge over a non-empty list of tokens (IndexError if empty).
    def merge_list(tokens):
        while len(tokens) > 1:
            tokens = [merge(tokens[0], tokens[1])] + tokens[2:]
        return tokens[0]

    
    original_tokens = tokenize_fn(text)
    from_index = 0
    merged_tokens = []
    for token in original_tokens:
        # Search from the end of the previous token so repeated words map to
        # successive occurrences.
        start = text.find(token, from_index)
        from_index = start + len(token)
        merged_tokens.append(Token(
            token=token,
            style='',
            raw_tokens=[token],
            start=start,
            end=start + len(token),
            label=None
        ))
    # Clear the cell output and redraw all widgets from the current merged_tokens.
    def repaint():
        clear_output()
        to_display = []

        for merged in merged_tokens:
            to_display.append(widgets.ToggleButton(
                value=False,
                description=merged.token,
                button_style=merged.style,
                disabled=False,
                tooltip='Click to select for merging'
            ))

        # Click handler shared by all label buttons; ``b`` is the clicked button,
        # used to look up its label in buttons_meta.
        def on_label(b):
            # Rebinding merged_tokens here updates annotate()'s variable, which
            # both the next repaint() and get_results() read.
            nonlocal buttons_meta, merged_tokens
            new_merged = []
            stack = []
            new_label = buttons_meta[b]['label']
            for token, check in zip(merged_tokens, to_display):
                if check.value:
                    new_token = token._replace(label=new_label, style=mapping[new_label])
                    stack.append(new_token)
                else:
                    # An unselected token ends the current run of selected ones.
                    if len(stack):
                        new_merged.append(merge_list(stack))
                        stack = []
                    new_merged.append(token)
            if len(stack):
                new_merged.append(merge_list(stack))
            merged_tokens = new_merged
            repaint()
        
        # Maps each label Button widget to its label; assigned after on_label is
        # defined but before any click can invoke it.
        buttons_meta = {}
        label_buttons = []
        for label, button_style in labels:
            label_button = widgets.Button(description=f'Merge as {label}', button_style=button_style, icon='check')
            buttons_meta[label_button] = {'label': label}
            label_button.on_click(on_label)
            label_buttons.append(label_button)
        labelling_widget = widgets.VBox([
            widgets.HBox(to_display),
            widgets.HBox(label_buttons)
        ])
        display(labelling_widget)
    repaint()
    # Collect (label, (start, end), text slice) for every labelled token.
    def get_results():
        spanning_labels = []
        for token in merged_tokens:
            if token.label:
                start = token.start
                end = token.end
                spanning_labels.append((token.label, (start, end), text[start:end]))
        return spanning_labels
    return get_results
    
# Try the prototype widget on a short sentence with two colour labels.
get_results = annotate(
    'this is not a test',
    labels=[('green', 'success'), ('yellow', 'warning')],
)

In [None]:
# Last expression of the cell: display the (label, (start, end), text) tuples
# recorded so far by the annotate() prototype.
get_results()