# Creating and using a custom tokenizer

The overall process is quite simple:
- Implement `MutableDocument`, `MutableEntity`, and `MutableToken`
  - These are tokenizer-specific wrappers
  - In v1 this would have all been `spacy`-specific objects
- Implement and extend `BaseTokenizer`
- Register the component
- Specify its use in the config

We will do things in steps

### Start with the relevant imports and/or constants

In [1]:
from typing import Any, overload, Iterator, Optional, cast

import re

from medcat2.config.config import Config
from medcat2.tokenizing.tokens import BaseDocument, BaseEntity, BaseToken
from medcat2.tokenizing.tokens import MutableDocument, MutableEntity, MutableToken


# Characters treated as token separators.
# NOTE: every entry must be exactly one character long - the tokenizer's
#       offset bookkeeping adds 1 per separator. Keep this a list (not a
#       string) so that the empty string is never "in" it.
WHITESPACE = [
    ' ',   # space
    '\t',  # tab
    '\n',  # newline
]

### Now we create a Token implementation

In [2]:
# the documents, entities, tokens

class Token:
    """A single token that belongs to a ``Document``.

    The same class serves as both the read-only ("base") token and the
    mutable token, which is why ``base`` simply returns ``self``.

    Args:
        doc: The document this token is part of.
        text: The raw text of the token.
        token_index: The index of the token within the document.
        start_index: The character offset of the token's first character
            within the document text.
    """

    # Lower-cased texts that are considered stop words.
    STOP_TOKENS: set[str] = set()  # nothing for now

    # for base token
    def __init__(self, doc: 'Document', text: str,
                 token_index: int, start_index: int):
        self._doc = doc
        self._text = text
        self._token_index = token_index
        self._start_index = start_index

    @property
    def text(self) -> str:
        """The raw text of the token."""
        return self._text

    @property
    def lower(self) -> str:
        """The lower-cased text of the token."""
        return self._text.lower()

    @property
    def text_versions(self) -> list[str]:
        """The lemma and the lower-cased text (in that order)."""
        return [self.lemma, self.lower]

    @property
    def is_upper(self) -> bool:
        """Whether the token text is upper case."""
        # NOTE: this used to call ``islower`` which inverted the answer
        return self._text.isupper()

    @property
    def is_stop(self) -> bool:
        """Whether the lower-cased text is one of ``STOP_TOKENS``."""
        return self.lower in self.STOP_TOKENS

    @property
    def char_index(self) -> int:
        """The character offset of the token within the document text."""
        return self._start_index

    @property
    def index(self) -> int:
        """The index of the token within the document."""
        return self._token_index

    @property
    def text_with_ws(self) -> str:
        """The token text plus the single whitespace character after it.

        If the token is at the very end of the text or is followed by
        something that is not in ``WHITESPACE``, only the text is returned.
        """
        end_index = self._start_index + len(self._text)
        # slicing (rather than indexing) yields '' past the end of the text
        next_char = self._doc.text[end_index: end_index + 1]
        if next_char in WHITESPACE:
            return self._text + next_char
        return self._text

    @property
    def is_digit(self) -> bool:
        """Whether the token text consists of digits only."""
        return self._text.isdigit()

    # for mutable token

    @property
    def base(self) -> BaseToken:
        """The base token - the very same object."""
        return self  # we implement both in the same class

    # class-level defaults; the setters below override them per instance
    _is_punctuation = False

    @property
    def is_punctuation(self) -> bool:
        """Whether the token has been marked as punctuation."""
        return self._is_punctuation

    @is_punctuation.setter
    def is_punctuation(self, val: bool) -> None:
        self._is_punctuation = val

    _to_skip = False

    @property
    def to_skip(self) -> bool:
        """Whether the token has been marked as one to skip."""
        return self._to_skip

    @to_skip.setter
    def to_skip(self, new_val: bool) -> None:
        self._to_skip = new_val

    @property
    def lemma(self) -> str:
        """The lemma; no lemmatisation is done so this is the norm."""
        return self.norm

    @property
    def tag(self) -> Optional[str]:
        """The tag; this implementation does not tag, so always None."""
        return None

    @property
    def norm(self) -> str:
        """The normalised text."""
        return self.lower

    @norm.setter
    def norm(self, value: str) -> None:
        # NOTE: the assigned value is accepted but ignored, i.e. `norm`
        #       keeps returning the lower-cased text
        pass  # nothing, for now

### Moving on to Entity and Document implementation

The former needs to refer to the latter, so we want to define them together

In [3]:
class Entity:
    """A contiguous run of tokens within a ``Document``.

    ``start_index`` and ``end_index`` are token indices and BOTH are
    inclusive, i.e. ``end_index`` is the index of the last token that is
    part of the entity. This is how ``entity_from_tokens`` of the tokenizer
    in this file builds entities (it passes the index of the last token).
    NOTE(review): confirm the inclusive end against the ``BaseEntity``
    contract.

    The same class serves as both the base and the mutable entity.

    Args:
        doc: The document the entity is part of.
        start_index: Index of the first token of the entity.
        end_index: Index of the last token of the entity (inclusive).
    """

    _ENTITY_INFO_PREFIX = "ENTITY_INFO:"
    # registered entity addon path -> default value for that path
    _ADDON_DEFAULTS: dict[str, Any] = {}

    def __init__(self, doc: 'Document',
                 start_index: int, end_index: int):
        self._doc = doc
        self._start_index = start_index
        self._end_index = end_index
        # defaults
        self.link_candidates: list[str] = []
        self.context_similarity: float = 0.0
        self.confidence: float = 0.0
        self.cui = ''
        self.id = -1  # TODO - what's the default?
        self.detected_name = ''

    # for base entity

    @property
    def start_index(self) -> int:
        """Index of the first token of the entity."""
        return self._start_index

    @property
    def end_index(self) -> int:
        """Index of the last token of the entity (inclusive)."""
        return self._end_index

    @property
    def start_char_index(self) -> int:
        """Character offset at which the first token starts."""
        return self._doc[self._start_index].char_index

    @property
    def end_char_index(self) -> int:
        """Character offset just past the end of the last token."""
        # NOTE: no special case for single-token entities - the last token
        #       is then simply the same as the first one
        last_tkn = self._doc[self._end_index]
        return last_tkn.char_index + len(last_tkn.text)

    @property
    def label(self) -> int:
        """The label; this implementation has none so it is always -1."""
        return -1

    @property
    def text(self) -> str:
        """The document text covered by the entity."""
        return self._doc.text[self.start_char_index: self.end_char_index]

    def _tokens(self) -> list:
        # +1 since the end index is inclusive
        return self._doc._tokens[self._start_index: self._end_index + 1]

    def __iter__(self) -> Iterator[BaseToken]:
        yield from self._tokens()

    def __len__(self) -> int:
        # number of tokens; +1 since the end index is inclusive
        return self._end_index - self._start_index + 1

    def __getitem__(self, index: int) -> BaseToken:
        """Get a token of the entity by its position within the entity."""
        return self._tokens()[index]

    # for mutable entity

    @property
    def base(self) -> BaseEntity:
        """The base entity - the very same object."""
        return self  # we implement both in same class

    def set_addon_data(self, path: str, val: Any) -> None:
        """Set the value at a registered addon path for this entity.

        The values live on the document (keyed by the start and end token
        indices) so that they persist along with it.

        Raises:
            KeyError: If the path has not been registered.
        """
        doc_path = f"{self._ENTITY_INFO_PREFIX}{path}"
        doc_dict = self._doc.get_addon_data(doc_path)
        if not isinstance(doc_dict, dict):
            # first value for this path in this document; create a mapping
            # owned by this document so nothing is shared between documents
            doc_dict = {}
            self._doc.set_addon_data(doc_path, doc_dict)
        doc_dict[(self.start_index, self.end_index)] = val

    def get_addon_data(self, path: str) -> Any:
        """Get the value at a registered addon path for this entity.

        If nothing has been set for this entity, the default value given
        upon registration is returned.

        Raises:
            KeyError: If the path has not been registered.
        """
        doc_dict = self._doc.get_addon_data(f"{self._ENTITY_INFO_PREFIX}{path}")
        key = (self.start_index, self.end_index)
        if isinstance(doc_dict, dict) and key in doc_dict:
            return doc_dict[key]
        # NOTE: the default object itself is returned (not a copy)
        return self._ADDON_DEFAULTS.get(path)

    @classmethod
    def register_addon_path(cls, path: str, def_val: Any = None,
                            force: bool = True) -> None:
        """Register an addon path for entities.

        Args:
            path: The path to register.
            def_val: The value to return for entities that have had
                nothing set at this path.
            force: Passed along to ``Document.register_addon_path``.
        """
        cls._ADDON_DEFAULTS[path] = def_val
        # NOTE: saving with Document for persistence; the per-document
        #       mapping is only created when the first value is set
        Document.register_addon_path(f"{cls._ENTITY_INFO_PREFIX}{path}",
                                     def_val=None, force=force)


class Document:
    """A text along with the tokens and entities found within it.

    The same class serves as both the base and the mutable document.

    Args:
        text: The raw text of the document.
        tokens: The tokens of the document; None (or an empty list) starts
            the document off with a new, empty list of tokens.
    """

    def __init__(self, text: str, tokens: Optional[list[Token]]):
        self.text = text
        self._tokens = tokens or []
        # filled by NER
        self.all_ents: list[MutableEntity] = []
        # filled by Linker
        self.final_ents: list[MutableEntity] = []

    @overload
    def __getitem__(self, index: int) -> BaseToken:
        pass

    @overload
    def __getitem__(self, index: slice) -> BaseEntity:
        pass

    def __getitem__(self, index: Any) -> Any:
        """Get a token (integer index) or an entity (slice of tokens).

        NOTE: the overloads above only describe the types; without this
              implementation any indexing would fail at runtime.

        Raises:
            IndexError: If an integer index is out of range.
            ValueError: If a slice has a step other than 1 or is empty.
        """
        if not isinstance(index, slice):
            return self._tokens[index]
        first, stop, step = index.indices(len(self._tokens))
        if step != 1:
            raise ValueError("Can only slice a document with a step of 1")
        if stop <= first:
            raise ValueError("Cannot create an entity from an empty slice")
        # the slice end is exclusive whereas the entity end is inclusive
        return Entity(self, first, stop - 1)

    def __iter__(self) -> Iterator[Token]:
        yield from self._tokens

    def isupper(self) -> bool:
        """Whether the text of the document is upper case."""
        return self.text.isupper()

    # for mutable document

    @property
    def base(self) -> BaseDocument:
        """The base document - the very same object."""
        return self  # same class again

    def get_tokens(self, start_index: int, end_index: int
                   ) -> list[MutableToken]:
        """Get the tokens that start within a character range.

        Args:
            start_index: The first character offset of the range.
            end_index: The last character offset of the range (inclusive).

        Returns:
            The tokens whose ``char_index`` is within the range.
        """
        return [tkn for tkn in self
                if start_index <= tkn.char_index <= end_index]

    def set_addon_data(self, path: str, val: Any) -> None:
        """Set the value at a registered addon path for this document.

        Raises:
            KeyError: If the path has not been registered.
        """
        if not hasattr(self.__class__, path):
            raise KeyError(f"Path not registered with {self.__class__}: {path}")
        # stored on the instance, shadowing the default kept on the class
        setattr(self, path, val)

    def get_addon_data(self, path: str) -> Any:
        """Get the value at a registered addon path for this document.

        The registered default is returned if nothing has been set.

        Raises:
            KeyError: If the path has not been registered.
        """
        if not hasattr(self.__class__, path):
            raise KeyError(f"Path not registered with {self.__class__}: {path}")
        return getattr(self, path)

    @classmethod
    def register_addon_path(cls, path: str, def_val: Any = None,
                            force: bool = True) -> None:
        """Register an addon path along with its default value.

        NOTE: `force` is currently ignored - registering always overwrites
              an existing default. The default is kept on the class and is
              thus shared by all documents that have not set a value.
        """
        setattr(cls, path, def_val)

### And now the actual tokenizer

In [4]:
class WhitespaceTokenizer:
    """A tokenizer that splits text on the characters in ``WHITESPACE``.

    Every non-empty piece of text between two separators becomes a token.
    """
    # matches any single one of the whitespace characters
    SPLIT_PATTERN = re.compile("|".join(map(re.escape, WHITESPACE)))

    # methods needed for the BaseTokenizer protocol

    def create_entity(self, doc: MutableDocument,
                      token_start_index: int, token_end_index: int,
                      label: str) -> MutableEntity:
        """Create an entity from a range of tokens of the document.

        Args:
            doc: The document to create the entity in.
            token_start_index: Index of the first token.
            token_end_index: Index one past the last token (the two
                indices are used as a slice).
            label: Currently unused by this implementation.

        Raises:
            IndexError: If the range contains no tokens.
        """
        # NOTE: slice the token list directly since `entity_from_tokens`
        #       needs a list of tokens
        tokens = cast(Document, doc)._tokens[token_start_index: token_end_index]
        return self.entity_from_tokens(tokens)

    def entity_from_tokens(self, tokens: list[MutableToken]) -> MutableEntity:
        """Create an entity spanning from the first to the last token given.

        The entity is given the index of the last token as its end index.

        Raises:
            IndexError: If no tokens were given.
        """
        first_tkn, last_tkn = tokens[0], tokens[-1]
        doc = cast(Token, first_tkn)._doc
        return Entity(doc, first_tkn.base.index, last_tkn.base.index)

    def __call__(self, text: str) -> MutableDocument:
        """Tokenize the text into a new document."""
        doc = Document(text, tokens=None)
        char_offset = 0
        for word in self.SPLIT_PATTERN.split(text):
            if word:
                # NOTE: empty strings (from leading or consecutive
                #       whitespace) are not tokens, so the token index is
                #       the position in the token list rather than the
                #       position among the split pieces
                doc._tokens.append(
                    Token(doc, word, len(doc._tokens), char_offset))
            # NOTE: this only works if all the whitespace has length 1
            char_offset += len(word) + 1  # +1 for the splitter
        return doc

    def get_doc_class(self) -> type[MutableDocument]:
        """The document class used by this tokenizer."""
        return Document

    def get_entity_class(self) -> type[MutableEntity]:
        """The entity class used by this tokenizer."""
        return Entity

    # for creation of class based on config
    @classmethod
    def get_init_args(cls, config: Config) -> list[Any]:
        """Positional init arguments; this tokenizer needs none."""
        return []

    @classmethod
    def get_init_kwargs(cls, config: Config) -> dict[str, Any]:
        """Keyword init arguments; this tokenizer needs none."""
        return {}

### Now we need to register the tokenizer

Now that we've got the implementation, we need to register it.

In [5]:
from medcat2.tokenizing.tokenizers import (
    list_available_tokenizers,
    register_tokenizer,
)

# make the implementation available under a name of our choosing
register_tokenizer("whitespace-tokenizer", WhitespaceTokenizer)

# list what is registered now, which should include the one just added
available_tokenizers = list_available_tokenizers()
print("Registered tokenizers:", available_tokenizers)

Registered tokenizers: [('whitespace-tokenizer', 'WhitespaceTokenizer'), ('regex', 'RegexTokenizer'), ('spacy', 'SpacyTokenizer')]


### Now we can see if we can create one through the registry

Normally, this will be automated through setting the tokenizer type in the config.

In [6]:
from medcat2.tokenizing.tokenizers import create_tokenizer
# create an instance by the name the tokenizer was registered under
tokenizer = create_tokenizer("whitespace-tokenizer")
# printing shows which class the instance is of
print("We've got one:", tokenizer)

We've got one: <__main__.WhitespaceTokenizer object at 0x1129dc5b0>


### Now we may want to tokenize a simple sentence

In [8]:
sentence = "The quick brown fox jumped over a lazy dog"
doc = tokenizer(sentence)
print("Doc:", doc)

tokens = list(doc)
print("Tokens:", tokens)
print("Token texts:", [tkn.base.text for tkn in tokens])

# check that the character offset of each token points at the token's
# text within the document
for tkn in doc:
    base_tkn = tkn.base
    start_char_index = base_tkn.char_index
    end_char_index = start_char_index + len(base_tkn.text)
    found_text = doc.base.text[start_char_index: end_char_index]
    print(base_tkn.index, "@", start_char_index, ":", repr(base_tkn.text), 'vs', repr(found_text),
          "\tW whitespace", repr(base_tkn.text_with_ws))

Doc: <__main__.Document object at 0x1129dc400>
Tokens: [<__main__.Token object at 0x1129df640>, <__main__.Token object at 0x1129ddab0>, <__main__.Token object at 0x1129dc8b0>, <__main__.Token object at 0x1129dc8e0>, <__main__.Token object at 0x1129dfbe0>, <__main__.Token object at 0x1129dfc40>, <__main__.Token object at 0x1129de6b0>, <__main__.Token object at 0x1129df820>, <__main__.Token object at 0x1129de050>]
Token texts: ['The', 'quick', 'brown', 'fox', 'jumped', 'over', 'a', 'lazy', 'dog']
0 @ 0 : 'The' vs 'The' 	W whitespace 'The '
1 @ 4 : 'quick' vs 'quick' 	W whitespace 'quick '
2 @ 10 : 'brown' vs 'brown' 	W whitespace 'brown '
3 @ 16 : 'fox' vs 'fox' 	W whitespace 'fox '
4 @ 20 : 'jumped' vs 'jumped' 	W whitespace 'jumped '
5 @ 27 : 'over' vs 'over' 	W whitespace 'over '
6 @ 32 : 'a' vs 'a' 	W whitespace 'a '
7 @ 34 : 'lazy' vs 'lazy' 	W whitespace 'lazy '
8 @ 39 : 'dog' vs 'dog' 	W whitespace 'dog'
