In [158]:
import re
from abc import ABC, abstractmethod
from functools import partial
from transformers import AutoTokenizer, FlaxAutoModel
import jax.numpy as jnp

In [159]:
class Tokeniser(ABC):
    """
    Abstract interface shared by tokenisers.

    A concrete subclass must override ``tokenise``; the base class itself
    cannot be instantiated.
    """

    @abstractmethod
    def tokenise(self, text: str) -> list[str]:
        """
        Turn ``text`` into a sequence of tokens.

        NOTE(review): the annotation promises ``list[str]``, while the
        ``WordTokeniser`` subclass in this notebook returns an array of
        integer ids -- confirm which contract is intended.
        """


In [160]:
class WordTokeniser(Tokeniser):
    """
    Word-level tokeniser that builds its vocabulary while tokenising.

    Every previously unseen token is assigned the next free integer id, so
    ids are allocated in order of first appearance across all calls to
    ``tokenise``.

    Args:
        drop_punctuation: If truthy, tokens are the ``\\w+`` runs found by a
            regular expression (punctuation is discarded). Otherwise the text
            is split on whitespace and punctuation stays attached to words.
        to_lower: If truthy, the text is lower-cased before it is split.
    """

    def __init__(self, drop_punctuation: bool = True, to_lower: bool = True):

        self.drop_punctuation = drop_punctuation
        self.to_lower = to_lower
        if self.drop_punctuation:
            self.pattern = re.compile(r'\b\w+\b')
        else:
            self.pattern = None

        # Maps token string -> integer id; grows as new tokens are seen.
        self.token_to_idx = {}

    @property
    def vocab(self):
        """All known tokens, in order of first appearance."""
        return list(self.token_to_idx)

    @property
    def idx_to_token(self):
        """
        Reverse mapping (integer id -> token string).

        The dictionary is rebuilt on every access, so callers doing many
        lookups should fetch it once and reuse it.
        """
        return {v: k for k, v in self.token_to_idx.items()}

    def tokenise(self, text: str) -> jnp.ndarray:
        """
        Map a string to a 1-D array of integer token ids.

        Tokens that are not yet in the vocabulary are added to it as a side
        effect of this call.

        Args:
            text: The string to tokenise.

        Returns:
            A 1-D integer array with one id per token; empty (but still
            integer-typed) when ``text`` contains no tokens.
        """
        if self.to_lower:
            text = text.lower()

        if self.drop_punctuation:
            str_tokens = self.pattern.findall(text)
        else:
            str_tokens = text.split()

        int_tokens = []
        for t in str_tokens:
            # setdefault only stores the freshly computed id when ``t`` is new;
            # otherwise it returns the id assigned earlier.
            int_tokens.append(self.token_to_idx.setdefault(t, len(self.token_to_idx)))

        # dtype=int selects JAX's default integer type, which is what a
        # non-empty list of Python ints already produced; it additionally
        # stops an empty token list from silently becoming a float array.
        return jnp.array(int_tokens, dtype=int)


In [161]:
# Build a word-level vocabulary from the first lines of the book and print the
# first few (id, token) pairs.
# NOTE: machine-specific absolute path -- adjust to wherever the book lives.
book_path = "/Users/willgilchrist/dev/deeplearning/data/books/timemachine.txt"

# WordTokeniser takes option flags, not a file handle; use its defaults
# (drop punctuation, lower-case).
tokeniser = WordTokeniser()

with open(book_path, "rt") as f:
    idxs = []
    # Iterating the handle lazily means only the first 20 lines are read,
    # instead of loading the whole book with readlines().
    for _, line in zip(range(20), f):
        idxs.extend(tokeniser.tokenise(line))

# The idx_to_token property rebuilds its dict on every access, so fetch it once.
idx_to_token = tokeniser.idx_to_token
tokens = [idx_to_token[i.item()] for i in idxs]

cap = 25
for t, i, _ in zip(tokens, idxs, range(cap)):
    print(f"{i:02}:\t {t}")

00:	 i
01:	 introduction
02:	 the
03:	 time
04:	 traveller
05:	 for
06:	 so
07:	 it
08:	 will
09:	 be
10:	 convenient
11:	 to
12:	 speak
13:	 of
14:	 him
15:	 was
16:	 expounding
17:	 a
18:	 recondite
19:	 matter
11:	 to
20:	 us
21:	 his
22:	 pale
23:	 grey


## Pre-trained embeddings

In [162]:
from transformers import AutoTokenizer
import jax.numpy as jnp

class HFTokeniser:
    """
    Wrap a Hugging Face tokenizer so that long texts can be tokenised.

    The text is split on whitespace and the words are greedily packed into
    chunks whose encoded length stays below ``max_length``. Each chunk is then
    encoded separately and the id sequences are concatenated into one array.

    Args:
        model_name: Name or path handed to ``AutoTokenizer.from_pretrained``.
        max_length: Exclusive upper bound on the encoded length of a chunk;
            also forwarded to the tokenizer as its ``max_length``.
        truncation: Forwarded to the tokenizer on every call.
        add_special_tokens: Forwarded to the tokenizer on every call.
            NOTE(review): presumably each chunk therefore carries its own
            special tokens, which end up inside the concatenated output --
            confirm this is intended.
        *args: Extra positional arguments forwarded on every tokenizer call.
        **kwargs: Extra keyword arguments forwarded on every tokenizer call.
    """

    def __init__(self, model_name, max_length=512, truncation=True, add_special_tokens=True, *args, **kwargs):
        self._model = AutoTokenizer.from_pretrained(model_name)
        self.max_length = max_length
        self.truncation = truncation
        self.add_special_tokens = add_special_tokens
        self.args = args
        self.kwargs = kwargs

    @property
    def vocab(self):
        """All token strings known to the wrapped tokenizer."""
        return list(self.token_to_idx)

    @property
    def token_to_idx(self):
        """Token string -> id mapping of the wrapped tokenizer ({} if it has none)."""
        # Read the attribute once instead of twice per access.
        vocab = self._model.vocab
        return vocab if vocab is not None else {}

    @property
    def idx_to_token(self):
        """Reverse mapping (id -> token string); rebuilt on every access."""
        return {v: k for k, v in self.token_to_idx.items()}

    def _encode(self, text, **extra):
        """Run the wrapped tokenizer on ``text`` with the stored options and return its ``input_ids``."""
        return self._model(
            text,
            *self.args,
            add_special_tokens=self.add_special_tokens,
            truncation=self.truncation,
            max_length=self.max_length,
            **extra,
            **self.kwargs,
        )['input_ids']

    def tokenise(self, text: str) -> jnp.ndarray:
        """
        Encode ``text`` of any length into a single 1-D array of token ids.

        Args:
            text: The string to tokenise.

        Returns:
            The concatenation of the ids of every chunk; an empty integer
            array when ``text`` contains no words.
        """
        words = text.split()
        if not words:
            # Explicit integer dtype: a bare jnp.array([]) would be a float array.
            return jnp.array([], dtype=int)

        chunks = []
        current_chunk = []
        for word in words:
            # Encode the current chunk plus the next word to measure its length.
            candidate = current_chunk + [word]
            if len(self._encode(" ".join(candidate))) < self.max_length:
                current_chunk = candidate
            else:
                # The word does not fit: close the current chunk and start a
                # new one with this word. A single over-long word still forms
                # a chunk of its own.
                if current_chunk:
                    chunks.append(" ".join(current_chunk))
                current_chunk = [word]
        if current_chunk:
            chunks.append(" ".join(current_chunk))

        # Encode each chunk for real. Chunks were sized to stay below
        # max_length, so the truncation setting should only matter for a
        # single word that is over-long on its own.
        tokens_list = [self._encode(chunk, return_tensors='jax')[0] for chunk in chunks]

        # Concatenate the tokens along the sequence dimension
        return jnp.concatenate(tokens_list)

In [166]:
# Tokenise the whole book with the pre-trained 'bert-base-uncased' tokenizer
# and show the resulting ids.
with open("/Users/willgilchrist/dev/deeplearning/data/books/timemachine.txt", "rt") as f:
    tokeniser = HFTokeniser('bert-base-uncased')
    book_text = f.read()

token_ids = tokeniser.tokenise(book_text)
print(token_ids)

[ 101 1045 1012 ... 2158 1012  102]
