40 changes: 0 additions & 40 deletions complex_tokenization/examples/bne.py

This file was deleted.

11 changes: 0 additions & 11 deletions complex_tokenization/examples/boundless_bpe.py

This file was deleted.

24 changes: 0 additions & 24 deletions complex_tokenization/examples/super_bpe.py

This file was deleted.

10 changes: 0 additions & 10 deletions complex_tokenization/examples/utils.py

This file was deleted.

110 changes: 110 additions & 0 deletions complex_tokenization/tokenizer.py
@@ -0,0 +1,110 @@
"""High-level tokenizer API.

Usage:
tokenizer = BPETokenizer()
tokenizer.train(texts, num_merges=100)
merges = tokenizer.get_merges()

With language-specific decomposition:
from complex_tokenization.languages.hebrew.decompose import decompose_cluster
tokenizer = BPETokenizer()
tokenizer.register_script("Hebrew", decompose_cluster)
tokenizer.train(texts, num_merges=100)
"""

from collections.abc import Callable
from functools import reduce

from complex_tokenization.graph import GraphVertex, Node
from complex_tokenization.graphs.settings import GraphSettings
from complex_tokenization.graphs.units import characters, register_script, utf8, utf8_clusters
from complex_tokenization.graphs.words import words
from complex_tokenization.trainer import Trainer

UNIT_FUNCTIONS: dict[str, Callable[[str], GraphVertex]] = {
    "utf8": utf8,
    "utf8_clusters": utf8_clusters,
    "characters": characters,
}
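# Passing a callable directly as Tokenizer(units=...) bypasses this registry;
# the string keys above are just names for the built-in unit builders.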


class Tokenizer:
    def __init__(
        self,
        units: str | Callable[[str], GraphVertex] = "utf8_clusters",
        merge_size: int = 2,
        connected: bool = False,
    ):
        if isinstance(units, str):
            if units not in UNIT_FUNCTIONS:
                raise ValueError(f"Unknown units: {units!r}. Choose from {list(UNIT_FUNCTIONS)}")
            self.units = UNIT_FUNCTIONS[units]
        else:
            self.units = units
        self.merge_size = merge_size
        self.connected = connected
        self.merges: list[tuple[str, ...]] = []

    @staticmethod
    def register_script(script: str, handler: Callable[[str], GraphVertex]):
        # Delegates to a module-level registry, so a registration is visible
        # to every Tokenizer instance in the process.
        register_script(script, handler)

    def add_merges(self, merges: list[tuple[str, ...]]):
        self.merges.extend(merges)

    def _build_graphs(self, texts: list[str]) -> tuple[GraphVertex, ...]:
        return tuple(
            words(text, connected=self.connected, units=self.units)
            for text in texts
        )

    def train(self, texts: list[str], num_merges: int = 100) -> list[tuple[str, ...]]:
        # These are class-level attributes, i.e. global to the process.
        GraphSettings.ONLY_MINIMAL_MERGES = True
        GraphSettings.MAX_MERGE_SIZE = self.merge_size

        graphs = self._build_graphs(texts)
        trainer = Trainer(graphs=graphs)

        # Replay any merges seeded via add_merges() so training continues
        # from them instead of starting from scratch.
        for merge_strs in self.merges:
            nodes = tuple(Node(value=s.encode("utf-8")) for s in merge_strs)
            token = reduce(lambda a, b: a + b, nodes)
            trainer.graph = trainer.graph.merge(token, nodes)
            trainer.merges.append((token, nodes))

        trainer.train(num_merges=num_merges)
        self.merges = trainer.get_merges()
        return self.merges

    def get_merges(self) -> list[tuple[str, ...]]:
        return list(self.merges)


class BPETokenizer(Tokenizer):
    """Classic BPE: pairwise merges within (pre-tokenized) word boundaries."""

    def __init__(self, units="utf8_clusters"):
        super().__init__(units=units, merge_size=2, connected=False)


class BNETokenizer(Tokenizer):
    """Like BPE, but a single merge may combine up to n adjacent units."""

    def __init__(self, n=4, units="utf8_clusters"):
        super().__init__(units=units, merge_size=n, connected=False)


class BoundlessBPETokenizer(Tokenizer):
    """Pairwise merges that may also cross word boundaries."""

    def __init__(self, units="utf8_clusters"):
        super().__init__(units=units, merge_size=2, connected=True)


class SuperBPETokenizer(Tokenizer):
    """Two-phase training: ordinary BPE merges first, then cross-word merges."""

    def __init__(self, units="utf8_clusters", disconnected_merges: int | None = None):
        super().__init__(units=units, merge_size=2, connected=False)
        self._disconnected_merges = disconnected_merges

    def train(self, texts: list[str], num_merges: int = 100) -> list[tuple[str, ...]]:
        # Default to spending half the budget on phase 1; check `is None`
        # rather than truthiness so an explicit 0 is honored.
        if self._disconnected_merges is None:
            disconnected_merges = num_merges // 2
        else:
            disconnected_merges = self._disconnected_merges

        # Phase 1: plain BPE within word boundaries.
        phase1 = BPETokenizer(units=self.units)
        phase1.train(texts, num_merges=disconnected_merges)

        # Phase 2: continue from the phase-1 merges with cross-word merges enabled.
        self.connected = True
        self.add_merges(phase1.merges)
        return super().train(texts, num_merges=num_merges)
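
The new classes replace the deleted example scripts with a single class-based API. A minimal usage sketch (class and method names as introduced above; the corpus and merge budgets are illustrative):

    from complex_tokenization.tokenizer import BPETokenizer, SuperBPETokenizer

    texts = ["the teacher teaches the thick thing"] * 3  # illustrative corpus

    # Plain BPE: pairwise merges, no cross-word ("connected") merges.
    bpe_merges = BPETokenizer().train(texts, num_merges=5)

    # SuperBPE: spends `disconnected_merges` on ordinary BPE (phase 1), then
    # continues with cross-word merges (phase 2) up to `num_merges` in total.
    super_bpe = SuperBPETokenizer(disconnected_merges=5)
    super_merges = super_bpe.train(texts, num_merges=10)
    assert super_merges[:5] == bpe_merges  # phase 1 reproduces plain BPE

    # Previously learned merges can be seeded into a fresh tokenizer;
    # num_merges counts the total, including the seeded ones.
    seeded = BPETokenizer()
    seeded.add_merges(bpe_merges)
    seeded.train(texts, num_merges=10)
    print(seeded.get_merges())

Passing disconnected_merges=5 pins phase 1 explicitly; by default SuperBPETokenizer spends half of num_merges on it.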
28 changes: 11 additions & 17 deletions tests/test_benchmark.py
@@ -4,11 +4,8 @@

import pytest

-from complex_tokenization.examples.bne import train_bne_tokenizer
-from complex_tokenization.examples.boundless_bpe import train_boundless_bpe_tokenizer
-from complex_tokenization.examples.bpe import train_bpe_tokenizer, train_huggingface_tokenizer
-from complex_tokenization.examples.super_bpe import train_super_bpe_tokenizer
-from complex_tokenization.examples.utils import text_dataset
+from complex_tokenization.tokenizer import BNETokenizer, BoundlessBPETokenizer, BPETokenizer, SuperBPETokenizer
+from tests.utils import text_dataset, train_huggingface_tokenizer


@pytest.fixture(scope="module")
@@ -17,45 +14,42 @@ def small_dataset():


class TestBenchmarkSmall:
    """Benchmark with small dataset (10 samples) to ensure correctness and basic perf."""

    def test_bpe_matches_huggingface_merges(self, small_dataset):
-        ours = train_bpe_tokenizer(small_dataset, num_merges=10)
+        ours = BPETokenizer().train(small_dataset, num_merges=10)
        hf = train_huggingface_tokenizer(small_dataset, num_merges=10)
        hf_normalized = [(m[0].replace("Ġ", " "), m[1]) for m in hf]
        assert ours == hf_normalized

    def test_bpe_faster_than_60s(self, small_dataset):
        start = time.perf_counter()
-        train_bpe_tokenizer(small_dataset, num_merges=50)
+        BPETokenizer().train(small_dataset, num_merges=50)
        elapsed = time.perf_counter() - start
        assert elapsed < 60, f"BPE training took {elapsed:.1f}s (limit: 60s)"

    def test_boundless_bpe_faster_than_60s(self, small_dataset):
        start = time.perf_counter()
-        train_boundless_bpe_tokenizer(small_dataset, num_merges=50)
+        BoundlessBPETokenizer().train(small_dataset, num_merges=50)
        elapsed = time.perf_counter() - start
        assert elapsed < 60, f"Boundless BPE training took {elapsed:.1f}s (limit: 60s)"

    def test_super_bpe_faster_than_60s(self, small_dataset):
        start = time.perf_counter()
-        train_super_bpe_tokenizer(small_dataset, num_merges=50)
+        SuperBPETokenizer().train(small_dataset, num_merges=50)
        elapsed = time.perf_counter() - start
        assert elapsed < 60, f"Super BPE training took {elapsed:.1f}s (limit: 60s)"

    def test_bne_faster_than_60s(self, small_dataset):
        start = time.perf_counter()
-        train_bne_tokenizer(small_dataset, n=4, num_merges=50)
+        BNETokenizer(n=4).train(small_dataset, num_merges=50)
        elapsed = time.perf_counter() - start
        assert elapsed < 60, f"BNE training took {elapsed:.1f}s (limit: 60s)"

    def test_all_tokenizers_produce_merges(self, small_dataset):
        """Sanity check that all tokenizer variants produce results."""
        num = 10
-        bpe = train_bpe_tokenizer(small_dataset, num_merges=num)
-        bne = train_bne_tokenizer(small_dataset, n=4, num_merges=num)
-        boundless = train_boundless_bpe_tokenizer(small_dataset, num_merges=num)
-        super_bpe = train_super_bpe_tokenizer(small_dataset, num_merges=num)
+        bpe = BPETokenizer().train(small_dataset, num_merges=num)
+        bne = BNETokenizer(n=4).train(small_dataset, num_merges=num)
+        boundless = BoundlessBPETokenizer().train(small_dataset, num_merges=num)
+        super_bpe = SuperBPETokenizer().train(small_dataset, num_merges=num)

        assert len(bpe) == num
        assert len(bne) == num
65 changes: 65 additions & 0 deletions tests/test_tokenizer_api.py
@@ -0,0 +1,65 @@
"""Test the high-level Tokenizer API."""

import pytest

from complex_tokenization.tokenizer import (
    BNETokenizer,
    BoundlessBPETokenizer,
    BPETokenizer,
    SuperBPETokenizer,
    Tokenizer,
)


class TestTokenizerAPI:
    def test_default_tokenizer(self):
        tok = Tokenizer()
        merges = tok.train(["hello world hello world"], num_merges=3)
        assert len(merges) == 3

    def test_bpe_tokenizer(self):
        tok = BPETokenizer()
        merges = tok.train(["the teacher teaches the thick"], num_merges=2)
        assert all(len(m) == 2 for m in merges)

    def test_bne_tokenizer(self):
        tok = BNETokenizer(n=4)
        merges = tok.train(["the teacher teaches the thick"], num_merges=2)
        assert all(2 <= len(m) <= 4 for m in merges)

    def test_boundless_bpe_tokenizer(self):
        tok = BoundlessBPETokenizer()
        merges = tok.train(["the teacher teaches the thick"], num_merges=2)
        assert all(len(m) == 2 for m in merges)

    def test_super_bpe_tokenizer(self):
        tok = SuperBPETokenizer()
        merges = tok.train(["the teacher teaches the thick"], num_merges=4)
        assert len(merges) == 4

    def test_custom_units(self):
        tok = Tokenizer(units="utf8")
        merges = tok.train(["hello hello"], num_merges=2)
        assert len(merges) == 2

    def test_invalid_units_raises(self):
        with pytest.raises(ValueError, match="Unknown units"):
            Tokenizer(units="invalid")

    def test_callable_units(self):
        from complex_tokenization.graphs.units import utf8
        tok = Tokenizer(units=utf8)
        merges = tok.train(["test test"], num_merges=2)
        assert len(merges) == 2

    def test_get_merges_before_train(self):
        tok = Tokenizer()
        assert tok.get_merges() == []

    def test_super_bpe_phase1_matches_bpe(self):
        texts = ["the teacher teaches the thick thing"] * 3
        bpe = BPETokenizer()
        bpe_merges = bpe.train(texts, num_merges=5)
        super_bpe = SuperBPETokenizer(disconnected_merges=5)
        super_merges = super_bpe.train(texts, num_merges=10)
        assert super_merges[:5] == bpe_merges
8 changes: 4 additions & 4 deletions tests/tokenizers/test_bne.py
@@ -1,12 +1,12 @@
-from complex_tokenization.examples.bne import train_bne_tokenizer
-from complex_tokenization.examples.utils import text_dataset
+from complex_tokenization.tokenizer import BNETokenizer
+from tests.utils import text_dataset


class TestBNE:
    def test_large_train_bne_tokenizer(self):
        """Test training BNE tokenizer with n=4 and expected merges"""
        texts = list(text_dataset(max_samples=10))
-        merges = train_bne_tokenizer(texts, n=4, num_merges=10)
+        tok = BNETokenizer(n=4)
+        merges = tok.train(texts, num_merges=10)

        expected = [
            (' ', 't', 'h', 'e'),
22 changes: 8 additions & 14 deletions tests/tokenizers/test_boundless_bpe.py
@@ -1,31 +1,25 @@
-from complex_tokenization.examples.boundless_bpe import train_boundless_bpe_tokenizer
-from complex_tokenization.examples.bpe import train_bpe_tokenizer
+from complex_tokenization.tokenizer import BoundlessBPETokenizer, BPETokenizer


class TestBoundlessBPE:
    def test_basic_boundless_bpe(self):
        texts = ["the teacher teaches the thick thing"]
-        merges = train_boundless_bpe_tokenizer(texts, num_merges=2)
+        tok = BoundlessBPETokenizer()
+        merges = tok.train(texts, num_merges=2)
        assert len(merges) == 2

    def test_boundless_extends_bpe_with_cross_word_merges(self):
        """BPE exhausts intra-word merges; boundless continues across words."""
        texts = ["ab cd ab cd ab cd"]
-        bpe_merges = train_bpe_tokenizer(texts, num_merges=5)
-        boundless_merges = train_boundless_bpe_tokenizer(texts, num_merges=5)

+        bpe_merges = BPETokenizer().train(texts, num_merges=5)
+        boundless_merges = BoundlessBPETokenizer().train(texts, num_merges=5)

        assert bpe_merges == [
-            ('a', 'b'),
-            (' ', 'c'),
-            (' c', 'd'),
-            (' ', 'ab'),
+            ('a', 'b'), (' ', 'c'), (' c', 'd'), (' ', 'ab'),
        ]
        assert boundless_merges == [
-            ('a', 'b'),
-            (' ', 'c'),
-            (' c', 'd'),
-            (' ', 'ab'),
-            (' cd', ' ab'),
+            ('a', 'b'), (' ', 'c'), (' c', 'd'), (' ', 'ab'), (' cd', ' ab'),
        ]
        assert boundless_merges[:len(bpe_merges)] == bpe_merges
        assert len(boundless_merges) > len(bpe_merges)