Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import pytest

from complex_tokenization.graphs.settings import GraphSettings


@pytest.fixture(autouse=True)
def reset_graph_settings():
    """Snapshot the mutable GraphSettings flags and restore them after each test.

    Tests overwrite these class attributes freely; restoring the snapshot in
    teardown keeps configuration from leaking between tests regardless of
    execution order.
    """
    tracked = ("USE_SINGLETONS", "MAX_MERGE_SIZE", "ONLY_MINIMAL_MERGES")
    snapshot = {name: getattr(GraphSettings, name) for name in tracked}
    yield
    for name in tracked:
        setattr(GraphSettings, name, snapshot[name])


@pytest.fixture(autouse=True)
def clear_singleton_cache():
    """Empty the GraphVertex instance registry before and after each test.

    Clearing on both sides of the yield keeps singleton-identity tests
    independent of whatever objects earlier tests interned.
    NOTE(review): assumes ``GraphVertex._instances`` is the singleton
    cache keyed by vertex value — confirm against the graph module.
    """
    # Local import: keeps the graph machinery out of conftest module scope.
    from complex_tokenization.graph import GraphVertex
    GraphVertex._instances.clear()
    yield
    GraphVertex._instances.clear()
36 changes: 36 additions & 0 deletions tests/test_singletons.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from complex_tokenization.graph import GraphVertex, Node, NodesSequence
from complex_tokenization.graphs.settings import GraphSettings
from complex_tokenization.graphs.units import utf8


class TestSingletons:
    """Identity semantics of graph vertices under the USE_SINGLETONS flag."""

    def test_singletons_off_creates_distinct_objects(self):
        GraphSettings.USE_SINGLETONS = False
        first = Node(value=b'a')
        second = Node(value=b'a')
        # Equal by value, yet each construction yields a fresh object.
        assert first == second
        assert first is not second

    def test_singletons_on_returns_same_object(self):
        GraphSettings.USE_SINGLETONS = True
        first = Node(value=b'a')
        second = Node(value=b'a')
        # Same value -> the very same cached instance.
        assert second is first

    def test_singletons_different_values_different_objects(self):
        GraphSettings.USE_SINGLETONS = True
        node_a = Node(value=b'a')
        node_b = Node(value=b'b')
        assert node_a is not node_b

    def test_singletons_different_classes_not_shared(self):
        GraphSettings.USE_SINGLETONS = True
        inner = Node(value=b'a')
        wrapper = NodesSequence(nodes=(inner,))
        # The cache must not conflate distinct vertex classes.
        assert type(inner) is not type(wrapper)

    def test_singleton_merge_preserves_identity(self):
        GraphSettings.USE_SINGLETONS = True
        graph = utf8("aa")
        assert isinstance(graph, NodesSequence)
        # Both 'a' units must resolve to the very same singleton node.
        assert graph.nodes[0] is graph.nodes[1]
79 changes: 79 additions & 0 deletions tests/test_trainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import pytest

from complex_tokenization.graph import Node, NodesSequence
from complex_tokenization.graphs.settings import GraphSettings
from complex_tokenization.graphs.units import utf8
from complex_tokenization.trainer import Trainer


class TestTrainer:
    """Trainer behavior: constructor validation, merge training, merge readout."""

    def test_trainer_requires_graph_or_graphs(self):
        # Neither input supplied -> constructor must refuse.
        with pytest.raises(ValueError, match="Must provide either graph or graphs"):
            Trainer()

    def test_trainer_rejects_both_graph_and_graphs(self):
        g = utf8("test")
        # Supplying both inputs at once is ambiguous and must be rejected too.
        with pytest.raises(ValueError, match="Must provide either graph or graphs, not both"):
            Trainer(graph=g, graphs=(g,))

    def test_train_single_node_no_merges(self):
        GraphSettings.MAX_MERGE_SIZE = 2
        GraphSettings.ONLY_MINIMAL_MERGES = True
        # A lone node offers no adjacent pair, so training is a no-op.
        trainer = Trainer(graph=Node(value=b'a'))
        trainer.train(num_merges=10)
        assert len(trainer.merges) == 0

    def test_train_stops_when_no_merges_left(self):
        GraphSettings.MAX_MERGE_SIZE = 2
        GraphSettings.ONLY_MINIMAL_MERGES = True
        # "ab" admits exactly one merge; asking for 100 must stop after it.
        trainer = Trainer(graph=utf8("ab"))
        trainer.train(num_merges=100)
        assert len(trainer.merges) == 1

    def test_train_merge_reduces_graph(self):
        GraphSettings.MAX_MERGE_SIZE = 2
        GraphSettings.ONLY_MINIMAL_MERGES = True
        trainer = Trainer(graph=utf8("aaa"))
        trainer.train(num_merges=1)
        assert len(trainer.merges) == 1
        # Partial collapse: the result is still a sequence, not a single node.
        assert isinstance(trainer.graph, NodesSequence)

    def test_train_full_merge_to_single_node(self):
        GraphSettings.MAX_MERGE_SIZE = 2
        GraphSettings.ONLY_MINIMAL_MERGES = True
        trainer = Trainer(graph=utf8("aa"))
        trainer.train(num_merges=1)
        assert len(trainer.merges) == 1
        # Both units merge, so the whole graph degenerates to one Node.
        assert isinstance(trainer.graph, Node)

    def test_get_merges_returns_readable(self):
        GraphSettings.MAX_MERGE_SIZE = 2
        GraphSettings.ONLY_MINIMAL_MERGES = True
        trainer = Trainer(graph=utf8("abab"))
        trainer.train(num_merges=1)
        readable = trainer.get_merges()
        assert len(readable) == 1
        assert readable[0] == ('a', 'b')

    def test_train_with_multiple_graphs(self):
        GraphSettings.MAX_MERGE_SIZE = 2
        GraphSettings.ONLY_MINIMAL_MERGES = True
        # The ('a', 'b') pair appears twice across the corpus, so it wins.
        corpus = (utf8("ab"), utf8("ab"), utf8("cd"))
        trainer = Trainer(graphs=corpus)
        trainer.train(num_merges=1)
        assert trainer.get_merges()[0] == ('a', 'b')

    def test_characters_produce_valid_bytes(self):
        from complex_tokenization.graphs.units import characters
        # ASCII input round-trips to the identical byte string.
        assert bytes(characters("hello")) == b"hello"

    def test_characters_non_ascii_produce_valid_bytes(self):
        from complex_tokenization.graphs.units import characters
        # Multi-byte UTF-8 characters must round-trip as well.
        assert bytes(characters("שלום")) == "שלום".encode("utf-8")