Skip to content

Commit

Permalink
--encode and --decode pair, closes #7
Browse files (browse the repository at this point in the history)
  • Loading branch information
simonw committed Jul 10, 2023
1 parent 665983c commit ccebd84
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 17 deletions.
11 changes: 8 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -127,16 +127,21 @@ Usage: ttok [OPTIONS] [PROMPT]...
cat input.txt | ttok -t 100 -m gpt2
To view tokens:
To view token integers:
cat input.txt | ttok --tokens
cat input.txt | ttok --encode
To convert tokens back to text:
ttok 9906 1917 --decode
Options:
--version Show the version and exit.
-i, --input FILENAME
-t, --truncate INTEGER Truncate to this many tokens
-m, --model TEXT Which model to use
--tokens Output token integers
--encode, --tokens Output token integers
--decode Convert token integers to text
--help Show this message and exit.
```
Expand Down
25 changes: 16 additions & 9 deletions tests/test_ttok.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,36 @@


@pytest.mark.parametrize(
"args,expected,expected_tokens",
"args,expected_length,expected_tokens",
(
(["one"], "1\n", "606"),
(["one", "two"], "2\n", "606 1403"),
(["boo", "hello", "there", "this", "is"], "5\n", "34093 24748 1070 420 374"),
(["one"], 1, "606"),
(["one", "two"], 2, "606 1403"),
(["boo", "hello", "there", "this", "is"], 5, "34093 24748 1070 420 374"),
(
["boo", "hello", "there", "this", "is", "-m", "gpt2"],
"6\n",
6,
"2127 78 23748 612 428 318",
),
),
)
def test_ttok_count_and_tokens(args, expected, expected_tokens):
def test_ttok_count_and_tokens(args, expected_length, expected_tokens):
runner = CliRunner()
with runner.isolated_filesystem():
result = runner.invoke(cli, args)
assert result.exit_code == 0
assert result.output == expected
# Now with --tokens
result2 = runner.invoke(cli, args + ["--tokens"])
assert int(result.output.strip()) == expected_length
# Now with --encode
result2 = runner.invoke(cli, args + ["--encode"])
assert result2.exit_code == 0
assert result2.output.strip() == expected_tokens

# And try round-tripping it through --decode/--encode
as_text = runner.invoke(cli, ["--decode"], input=expected_tokens)
assert as_text.exit_code == 0
as_tokens_again = runner.invoke(cli, ["--encode"], input=as_text.output.strip())
assert as_tokens_again.exit_code == 0
assert as_tokens_again.output.strip() == expected_tokens


@pytest.mark.parametrize("use_stdin", (True, False))
@pytest.mark.parametrize("use_extra_args", (True, False))
Expand Down
28 changes: 23 additions & 5 deletions ttok/cli.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import click
import re
import sys
import tiktoken

Expand All @@ -11,8 +12,13 @@
"-t", "--truncate", "truncate", type=int, help="Truncate to this many tokens"
)
@click.option("-m", "--model", default="gpt-3.5-turbo", help="Which model to use")
@click.option("output_tokens", "--tokens", is_flag=True, help="Output token integers")
def cli(prompt, input, truncate, model, output_tokens):
@click.option(
"encode_tokens", "--encode", "--tokens", is_flag=True, help="Output token integers"
)
@click.option(
"decode_tokens", "--decode", is_flag=True, help="Convert token integers to text"
)
def cli(prompt, input, truncate, model, encode_tokens, decode_tokens):
"""
Count and truncate text based on tokens
Expand All @@ -32,10 +38,16 @@ def cli(prompt, input, truncate, model, output_tokens):
cat input.txt | ttok -t 100 -m gpt2
To view tokens:
To view token integers:
cat input.txt | ttok --tokens
cat input.txt | ttok --encode
To convert tokens back to text:
ttok 9906 1917 --decode
"""
if decode_tokens and encode_tokens:
raise click.ClickException("Cannot use --decode with --encode")
try:
encoding = tiktoken.encoding_for_model(model)
except KeyError as e:
Expand All @@ -49,12 +61,18 @@ def cli(prompt, input, truncate, model, output_tokens):
text = input_text + " " + text
else:
text = input_text

if decode_tokens:
integer_tokens = [int(token) for token in re.findall(r"\d+", text)]
click.echo(encoding.decode(integer_tokens))
return

# Tokenize it
tokens = encoding.encode(text)
if truncate:
tokens = tokens[:truncate]

if output_tokens:
if encode_tokens:
click.echo(" ".join(str(t) for t in tokens))
elif truncate:
click.echo(encoding.decode(tokens), nl=False)
Expand Down

0 comments on commit ccebd84

Please sign in to comment.