In [1]:
from tokenization_baichuan import BaichuanTokenizer

In [8]:
original = BaichuanTokenizer.from_pretrained(".")

The original tokenizer code looks a lot like the original Llama/CodeLlama tokenizer, let's try that.

In [63]:
from transformers.convert_slow_tokenizer import SpmConverter, LlamaConverter, GemmaConverter, _get_prepend_scheme
from tokenizers import decoders, normalizers, pre_tokenizers, processors

In [393]:
class BaichuanConverter(SpmConverter):
    handle_byte_fallback = True

    def vocab(self, proto):
        vocab = [
            (self.original_tokenizer.convert_ids_to_tokens(0), 0.0),
            (self.original_tokenizer.convert_ids_to_tokens(1), 0.0),
            (self.original_tokenizer.convert_ids_to_tokens(2), 0.0),
        ]
        vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
        return vocab

    def unk_id(self, proto):
        unk_id = 0
        return unk_id

    def decoder(self, replacement, add_prefix_space):
        sequence = [
            decoders.Replace("▁", " "),
            decoders.ByteFallback(),
            decoders.Fuse(),
        ]
        return decoders.Sequence(sequence)

    def normalizer(self, proto):
#         return normalizers.Sequence([
#             normalizers.Prepend(prepend="▁"),
#             normalizers.Replace(pattern=" ", content="▁"),
#         ])
        return normalizers.Replace(pattern=" ", content="▁")

    def pre_tokenizer(self, replacement, add_prefix_space):
        return None
    
#         prepend_scheme = "never"
#         return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme, split=False)

    def post_processor(self):
        # the processor is defined in the LlamaTokenizerFast class.
        return None


In [394]:
converter = BaichuanConverter(original_tokenizer)

In [395]:
converted = converter.converted()

In [396]:
from transformers import PreTrainedTokenizerFast

t_fast = PreTrainedTokenizerFast(
    tokenizer_object=converted,
    model_input_names=original.model_input_names,
    model_max_length=32768,
    clean_up_tokenization_spaces=False,
)

In [397]:
original.encode(" {\n")

[133035]

In [398]:
t_fast.encode(" {\n")

[124108, 133081]

In [386]:
original.decode([133081])

'{\n'

In [387]:
original.encode("hello")

[18632]

In [388]:
t_fast.encode("hello")

[18632]

In [389]:
original.decode([18632])

'hello'

In [390]:
t_fast.decode([30109])

' hello'

In [391]:
t_fast.decode([18632])

'hello'

In [392]:
original.decode([30109])

' hello'

Testing on xnli

In [144]:
from datasets import load_dataset
from tqdm import tqdm

In [145]:
xnli = load_dataset("xnli", "all_languages", split="validation")

In [146]:
def verify(lang, text):
    encoded_original = original.encode(text)
    encoded_fast = t_fast.encode(text)
    assert encoded_fast == encoded_original, f"Fast encode error: {lang} - {text}"
    decoded = original.decode(encoded_original)
    decoded_fast = t_fast.decode(encoded_fast, skip_special_tokens=True)
    assert decoded_fast == decoded, f"Fast decode error: {lang} - {text}"

In [383]:
for p in tqdm(xnli["premise"]):
    for lang, text in p.items():
        verify(lang, text)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2490/2490 [00:28<00:00, 86.78it/s]


Testing on codeparrot

In [159]:
ds = load_dataset("codeparrot/github-code", streaming=True, trust_remote_code=True, split="train")

README.md: 0.00B [00:00, ?B/s]

github-code.py: 0.00B [00:00, ?B/s]

In [160]:
skipped = 0
iterator = iter(ds)
for _ in tqdm(range(1000)):
    item = next(iterator)
    code = item["code"]
    lang = item["language"]
    if False and item["size"] > 1000:
        skipped += 1
        continue
    verify(lang, code)

  0%|                                                                                                                                                                                 | 0/1000 [00:02<?, ?it/s]


AssertionError: Fast encode error: JavaScript - 'use strict';

var clear          = require('es5-ext/array/#/clear')
  , eIndexOf       = require('es5-ext/array/#/e-index-of')
  , setPrototypeOf = require('es5-ext/object/set-prototype-of')
  , callable       = require('es5-ext/object/valid-callable')
  , d              = require('d')
  , ee             = require('event-emitter')
  , Symbol         = require('es6-symbol')
  , iterator       = require('es6-iterator/valid-iterable')
  , forOf          = require('es6-iterator/for-of')
  , Iterator       = require('./lib/iterator')
  , isNative       = require('./is-native-implemented')

  , call = Function.prototype.call, defineProperty = Object.defineProperty
  , SetPoly, getValues;

module.exports = SetPoly = function (/*iterable*/) {
	var iterable = arguments[0];
	if (!(this instanceof SetPoly)) return new SetPoly(iterable);
	if (this.__setData__ !== undefined) {
		throw new TypeError(this + " cannot be reinitialized");
	}
	if (iterable != null) iterator(iterable);
	defineProperty(this, '__setData__', d('c', []));
	if (!iterable) return;
	forOf(iterable, function (value) {
		if (eIndexOf.call(this, value) !== -1) return;
		this.push(value);
	}, this.__setData__);
};

if (isNative) {
	if (setPrototypeOf) setPrototypeOf(SetPoly, Set);
	SetPoly.prototype = Object.create(Set.prototype, {
		constructor: d(SetPoly)
	});
}

ee(Object.defineProperties(SetPoly.prototype, {
	add: d(function (value) {
		if (this.has(value)) return this;
		this.emit('_add', this.__setData__.push(value) - 1, value);
		return this;
	}),
	clear: d(function () {
		if (!this.__setData__.length) return;
		clear.call(this.__setData__);
		this.emit('_clear');
	}),
	delete: d(function (value) {
		var index = eIndexOf.call(this.__setData__, value);
		if (index === -1) return false;
		this.__setData__.splice(index, 1);
		this.emit('_delete', index, value);
		return true;
	}),
	entries: d(function () { return new Iterator(this, 'key+value'); }),
	forEach: d(function (cb/*, thisArg*/) {
		var thisArg = arguments[1], iterator, result, value;
		callable(cb);
		iterator = this.values();
		result = iterator._next();
		while (result !== undefined) {
			value = iterator._resolve(result);
			call.call(cb, thisArg, value, value, this);
			result = iterator._next();
		}
	}),
	has: d(function (value) {
		return (eIndexOf.call(this.__setData__, value) !== -1);
	}),
	keys: d(getValues = function () { return this.values(); }),
	size: d.gs(function () { return this.__setData__.length; }),
	values: d(function () { return new Iterator(this); }),
	toString: d(function () { return '[object Set]'; })
}));
defineProperty(SetPoly.prototype, Symbol.iterator, d(getValues));
defineProperty(SetPoly.prototype, Symbol.toStringTag, d('c', 'Set'));


In [164]:
encoded = original.encode(code)

In [165]:
fast_encoded = t_fast.encode(code)

In [166]:
len(encoded), len(fast_encoded)

(784, 801)

In [167]:
for i, (x, y) in enumerate(zip(encoded, fast_encoded)):
    if x != y:
        print(f"Mismatch at {i}: {x} != {y}")
        break

Mismatch at 206: 133035 != 124108


In [168]:
original.decode([133035])

' {\n'

In [169]:
original.decode([124108])

' '

In [172]:
t_fast.decode([133035])

' {\n'

In [173]:
t_fast.decode([124108])

' '

In [212]:
original.encode(" {\n")

[133035]

In [213]:
t_fast.encode(" {\n")

[124108, 133081]