diff --git a/ttok/cli.py b/ttok/cli.py
index 693843e..3da981c 100644
--- a/ttok/cli.py
+++ b/ttok/cli.py
@@ -59,10 +59,7 @@ def cli(prompt, input, truncate, model, encode_tokens, decode_tokens, as_tokens)
         raise click.ClickException("Cannot use --decode with --encode")
     if as_tokens and not decode_tokens and not encode_tokens:
         encode_tokens = True
-    try:
-        encoding = tiktoken.encoding_for_model(model)
-    except KeyError as e:
-        raise click.ClickException(f"Invalid model: {model}") from e
+
     if not prompt and input is None:
         input = sys.stdin
     text = " ".join(prompt)
@@ -73,6 +70,45 @@ def cli(prompt, input, truncate, model, encode_tokens, decode_tokens, as_tokens)
     else:
         text = input_text
 
+    if model.startswith("hf:"):
+        # "hf:" prefix selects a Hugging Face tokenizer instead of tiktoken
+        try:
+            import tokenizers
+        except ImportError:
+            raise click.ClickException("Hugging Face tokenizers is not installed")
+
+        hf_tokenizer = tokenizers.Tokenizer.from_pretrained(model[3:])
+        if decode_tokens:
+            tokens = [int(token) for token in re.findall(r"\d+", text)]
+            if as_tokens:
+                # --tokens: show the individual token strings, not the joined text
+                click.echo(" ".join(hf_tokenizer.id_to_token(t) or "" for t in tokens))
+            else:
+                click.echo(hf_tokenizer.decode(tokens))
+            return
+
+        tokens = hf_tokenizer.encode(text).ids
+        if truncate:
+            tokens = tokens[:truncate]
+
+        if encode_tokens:
+            if as_tokens:
+                # --tokens: token strings, not the re-decoded text
+                click.echo(" ".join(hf_tokenizer.id_to_token(t) or "" for t in tokens))
+            else:
+                click.echo(" ".join(str(t) for t in tokens))
+        elif truncate:
+            click.echo(hf_tokenizer.decode(tokens), nl=False)
+        else:
+            click.echo(len(tokens))
+        return
+
+    # Use tiktoken for OpenAI tokenizers instead
+    try:
+        encoding = tiktoken.encoding_for_model(model)
+    except KeyError as e:
+        raise click.ClickException(f"Invalid model: {model}") from e
+
 if decode_tokens:
     tokens = [int(token) for token in re.findall(r"\d+", text)]
     if as_tokens: