In [None]:

import os
os.environ["TRANSFORMERS_NO_TF"] = "1"

import re
from typing import List, Tuple, Union
from transformers import AutoTokenizer, PreTrainedTokenizerBase

In [3]:
MATH_CHARS = set("0123456789+-*/=() ")
MIN_INT = -500
MAX_INT = 500


In [4]:
def is_math_char(ch: str) -> bool:
    # allowed characters INSIDE math spans
    return ch.isdigit() or ch in "+-*/=() "



def split_math_text_spans(text: str) -> List[Tuple[bool, str]]:
    spans = []
    if not text:
        return spans

    # First pass – split by char-class transitions
    raw_spans = []
    current = [text[0]]
    current_is_math = is_math_char(text[0])

    for ch in text[1:]:
        ch_is_math = is_math_char(ch)
        if ch_is_math == current_is_math:
            current.append(ch)
        else:
            raw_spans.append((current_is_math, "".join(current)))
            current = [ch]
            current_is_math = ch_is_math

    raw_spans.append((current_is_math, "".join(current)))

    # Second pass – validate math spans (MUST contain a digit)
    final_spans = []
    for is_math, span in raw_spans:
        if is_math:
            # check if span actually contains math (digit required)
            if any(c.isdigit() for c in span):
                final_spans.append((True, span))
            else:
                # treat span as text if no digits
                final_spans.append((False, span))
        else:
            final_spans.append((False, span))

    return final_spans


MATH_TOKEN_RE = re.compile(r"""
    \s+              |   # whitespace (skip)
    (-?\d+)          |   # integer, possibly negative
    ([+\-*/=()])         # operators / parens
""", re.VERBOSE)


def tokenize_math_expr(expr: str) -> List[str]:
    """
    Turn a math expression string like '-47 * -2 - 35 * -19 = 759'
    into ['-47', '*', '-2', '-', '35', '*', '-19', '=', '759'].
    """
    tokens = []
    pos = 0
    n = len(expr)
    if "--" in expr:
        raise ValueError("Invalid unary sequence '--'")

    while pos < n:
        m = MATH_TOKEN_RE.match(expr, pos)
        if not m:
            # Anything non-math that slipped in: you can choose to raise or
            # fall back to char-level; for now, be strict.
            raise ValueError(f"Unexpected character in math expr: {expr[pos]!r} at position {pos}")
        if m.group(1):  # integer token
            val = int(m.group(1))
            if val < MIN_INT or val > MAX_INT:
                # Not a valid integer token -> treat whole span as TEXT instead
                raise ValueError("Integer out of allowed range")
            tokens.append(m.group(1))

        elif m.group(2):  # operator/parens
            tokens.append(m.group(2))
        # group(0) is full match; could be whitespace, which we ignore
        pos = m.end()
    return tokens


class HybridMathTokenizer:
    """
    Hybrid tokenizer:
      - Uses a base HF tokenizer for natural language.
      - Uses custom integer/operator tokens for math spans.

    Math spans are detected heuristically via MATH_CHARS.
    """

    def __init__(
        self,
        base_tokenizer: Union[str, "PreTrainedTokenizerBase"],
        min_int: int = -500,
        max_int: int = 500,
        add_expr_markers: bool = False,
        extra_special_tokens: List[str] = None,
    ):
        """
        base_tokenizer: HF tokenizer name or instance (e.g. "meta-llama/Llama-3-8B-Instruct").
        min_int, max_int: inclusive integer range for atomic number tokens.
        add_expr_markers: if True, adds <EXPR_START> and <EXPR_END> tokens around math spans.
        extra_special_tokens: optional list of extra tokens (e.g. <FACT_1>, <READY_1>).
        """
        if isinstance(base_tokenizer, str):
            if AutoTokenizer is None:
                raise ImportError("transformers is required to load a tokenizer by name")
            self.base = AutoTokenizer.from_pretrained(base_tokenizer, use_fast=True)
        else:
            self.base = base_tokenizer

        self.min_int = min_int
        self.max_int = max_int
        self.add_expr_markers = add_expr_markers

        # Build list of tokens to add to the base tokenizer
        new_tokens = []

        # Operators / parens
        self.op_tokens = ["+", "-", "*", "/", "=", "(", ")"]
        new_tokens.extend(self.op_tokens)

        # Integers as atomic tokens
        self.int_tokens = [str(i) for i in range(min_int, max_int + 1)]
        new_tokens.extend(self.int_tokens)

        # Optional expression markers and extra specials
        self.expr_start = "<EXPR_START>"
        self.expr_end = "<EXPR_END>"
        self.special_extra = extra_special_tokens or []

        if add_expr_markers:
            new_tokens.extend([self.expr_start, self.expr_end])

        new_tokens.extend(self.special_extra)

        # Add to base tokenizer vocab
        # NOTE: add_tokens will only add tokens that aren't already in vocab
        self.base.add_tokens(new_tokens)

        # For convenience, keep the token->id mapping
        self._update_maps()

    def _update_maps(self):
        self.token_to_id = self.base.get_vocab()
        # Reverse map
        self.id_to_token = {i: t for t, i in self.token_to_id.items()}

    # ------------- public API -------------

    def encode(self, text: str, add_special_tokens: bool = True) -> List[int]:
        """
        Encode text into a single list of token ids, with math spans
        tokenized using atomic integer/operator tokens.
        """
        spans = split_math_text_spans(text)
        all_ids: List[int] = []

        for is_math, span in spans:
            if not span:
                continue

            if is_math:
                # math span: tokenize via our math rules
                try:
                    math_tokens = tokenize_math_expr(span)
                except ValueError:
                    # fallback: this isn't valid math for our domain → treat as text
                    ids = self.base.encode(span, add_special_tokens=False)
                    all_ids.extend(ids)
                    continue

                if self.add_expr_markers:
                    math_tokens = [self.expr_start] + math_tokens + [self.expr_end]
                ids = self.base.convert_tokens_to_ids(math_tokens)
                all_ids.extend(ids)
            else:
                # normal text span: delegate to base tokenizer
                ids = self.base.encode(
                    span,
                    add_special_tokens=False,  # we handle global special tokens outside
                )
                all_ids.extend(ids)

        if add_special_tokens and hasattr(self.base, "bos_token_id") and hasattr(self.base, "eos_token_id"):
            # Very simple: wrap with BOS/EOS if they exist
            bos = [] if self.base.bos_token_id is None else [self.base.bos_token_id]
            eos = [] if self.base.eos_token_id is None else [self.base.eos_token_id]
            return bos + all_ids + eos

        return all_ids

    def decode(self, ids: List[int], skip_special_tokens: bool = True) -> str:
        """
        Decode token ids back to text via the base tokenizer.
        For math tokens, since we used actual string forms as tokens,
        this round-trips cleanly.
        """
        return self.base.decode(ids, skip_special_tokens=skip_special_tokens)

    # Convenience helpers if you want to bypass span splitting:

    def encode_math_only(self, expr: str) -> List[int]:
        """Encode a pure math expression string."""
        tokens = tokenize_math_expr(expr)
        if self.add_expr_markers:
            tokens = [self.expr_start] + tokens + [self.expr_end]
        return self.base.convert_tokens_to_ids(tokens)

    def encode_text_only(self, text: str) -> List[int]:
        """Encode pure natural language through the base tokenizer."""
        return self.base.encode(text, add_special_tokens=False)

In [18]:
from transformers import AutoTokenizer

base = AutoTokenizer.from_pretrained("allenai/OLMoE-1B-7B-0125-SFT", use_fast=True)

tok = HybridMathTokenizer(base_tokenizer=base, min_int=MIN_INT, max_int=MAX_INT, add_expr_markers=True)

s = "-47 * -2 = 79"
ids = tok.encode(s)
print(ids)
print(tok.decode(ids))


tokenizer_config.json: 0.00B [00:00, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/330 [00:00<?, ?B/s]

[50279, 50780, 50733, 11, 50778, 30, 2787, 50781, 50279]
<EXPR_START>-47*-2=79<EXPR_END>


In [23]:
s = "Compare the sizes: $\pi$ ____ $3.14$ (fill in the blank with $=$, $>$, or $<$)."

In [24]:
ids = tok.encode(s)
print(ids)
print(tok.decode(ids))


[50279, 33925, 209, 783, 209, 84, 4219, 27, 209, 1202, 2059, 5, 209, 1713, 209, 5, 50780, 20, 50781, 15, 50780, 1047, 50781, 5, 209, 9, 9337, 209, 249, 209, 783, 209, 22473, 209, 3113, 209, 5, 30, 1366, 209, 5, 31, 1366, 209, 263, 209, 5, 29, 5, 10, 15, 50279]
Compare the sizes: $\pi$ ____ $<EXPR_START>3<EXPR_END>.<EXPR_START>14<EXPR_END>$ (fill in the blank with $=$, $>$, or $<$).


In [8]:
tok.decode([151643])

''

In [9]:
tests = [
    # A. Pure integer arithmetic
    "-47 * -2 = 94",
    "0 + 0 = 0",
    "12 - 5 + 3 - 2",
    "999 + -999",
    "(-3) * (-4)",

    # B. Math glued to text
    "Compute -47 * -2 = 94 please.",
    "What is (12 * 4) - 3?",
    "-47*-2=94 is correct.",
    "Check: 3+4=7. Good?",

    # C. Stray punctuation
    "-47 * -2 = 94.",
    "(-3) * (-4), obviously.",
    "What?",
    "test...",
    "12? no",

    # D. Decimals (should NOT be math)
    "12.5",
    "3.14159",
    "0.0",
    "-.5",
    "5.",

    # E. Garbage mixed
    "12..5",
    "--47",
    "3*-*-2",
    "abc123",
    "123abc",

    # F. Odd whitespace
    "  -47   *   -2    = 94    ",
    "(-3)*(-4)",
    "   (  3 + 4 )  ",

    # G. Multiple expressions
    "The first: -47 * -2 = 94, the second: 12 * 12 = 144.",
    "Compute 3 + 4 = 7 and 8 + 9 = 17.",
    "Edge case: 3-4= -1. Weird!",

    # H. Parentheses hell
    "(((3)))",
    "((3+4)*((2-1)))",
    "(3*(4-(5+6)))",

    # I. False alarms (NOT math)
    "e=mc^2",
    "version1.2.3",
    "file_name-47",
    "token-ids=3",
]

for s in tests:
    print("\n====================")
    print("INPUT:", s)

    try:
        ids = tok.encode(s, add_special_tokens=False)
        print("IDS:", ids)
        decoded = tok.decode(ids)
        print("DECODED:", decoded)
    except Exception as e:
        print("ERROR:", e)



INPUT: -47 * -2 = 94
IDS: [152656, 152118, 9, 152163, 28, 152249, 152657]
DECODED: <EXPR_START>-47*-2=94<EXPR_END>

INPUT: 0 + 0 = 0
IDS: [152656, 15, 10, 15, 28, 15, 152657]
DECODED: <EXPR_START>0+0=0<EXPR_END>

INPUT: 12 - 5 + 3 - 2
IDS: [152656, 152167, 12, 20, 10, 18, 12, 17, 152657]
DECODED: <EXPR_START>12-5+3-2<EXPR_END>

INPUT: 999 + -999
IDS: [152254, 24, 220, 10, 220, 152066, 24]
DECODED: 999 + -999

INPUT: (-3) * (-4)
IDS: [152656, 7, 152162, 8, 9, 7, 152161, 8, 152657]
DECODED: <EXPR_START>(-3)*(-4)<EXPR_END>

INPUT: Compute -47 * -2 = 94 please.
IDS: [46254, 152656, 152118, 9, 152163, 28, 152249, 152657, 30021, 13]
DECODED: Compute<EXPR_START>-47*-2=94<EXPR_END>please.

INPUT: What is (12 * 4) - 3?
IDS: [3838, 220, 285, 152656, 7, 152167, 9, 19, 8, 12, 18, 152657, 30]
DECODED: What is<EXPR_START>(12*4)-3<EXPR_END>?

INPUT: -47*-2=94 is correct.
IDS: [152656, 152118, 9, 152163, 28, 152249, 152657, 285, 220, 19928, 13]
DECODED: <EXPR_START>-47*-2=94<EXPR_END>is correct.

INP