Skip to content

Commit

Permalink
--split prototype, refs #3
Browse files (browse the repository at this point in the history)
  • Loading branch information
simonw committed May 19, 2023
1 parent 665983c commit f18656d
Showing 1 changed file with 35 additions and 1 deletion.
36 changes: 35 additions & 1 deletion ttok/cli.py
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
import click
import json
import sys
import tiktoken

Expand All @@ -10,9 +11,13 @@
@click.option(
"-t", "--truncate", "truncate", type=int, help="Truncate to this many tokens"
)
@click.option("--split", is_flag=True, help="Split text based on truncate argument")
@click.option(
"-0", "--null", is_flag=True, help="Output split text with null byte delimiters"
)
@click.option("-m", "--model", default="gpt-3.5-turbo", help="Which model to use")
@click.option("output_tokens", "--tokens", is_flag=True, help="Output token integers")
def cli(prompt, input, truncate, model, output_tokens):
def cli(prompt, input, truncate, split, null, model, output_tokens):
"""
Count and truncate text based on tokens
Expand All @@ -36,6 +41,8 @@ def cli(prompt, input, truncate, model, output_tokens):
cat input.txt | ttok --tokens
"""
if split and not truncate:
raise click.ClickException("Cannot use --split without --truncate")
try:
encoding = tiktoken.encoding_for_model(model)
except KeyError as e:
Expand All @@ -51,6 +58,28 @@ def cli(prompt, input, truncate, model, output_tokens):
text = input_text
# Tokenize it
tokens = encoding.encode(text)

if split:
if null:
# Filter out null byte tokens
null_token = encoding.encode("\0")[0]
tokens = [t for t in tokens if t != null_token]
token_chunks = list(chunks(tokens, truncate))
if null:
click.echo(
"\0".join(encoding.decode(chunk) for chunk in token_chunks) + "\0"
)
else:
if output_tokens:
click.echo(json.dumps(token_chunks, indent=2))
else:
click.echo(
json.dumps(
[encoding.decode(chunk) for chunk in token_chunks], indent=2
)
)
return

if truncate:
tokens = tokens[:truncate]

Expand All @@ -60,3 +89,8 @@ def cli(prompt, input, truncate, model, output_tokens):
click.echo(encoding.decode(tokens), nl=False)
else:
click.echo(len(tokens))


def chunks(sequence, n):
    """Yield successive slices of *sequence*, each at most *n* items long.

    Every slice except possibly the last contains exactly *n* items; the
    last one holds whatever remains. An empty *sequence* yields nothing.

    Args:
        sequence: Any object supporting ``len()`` and slicing (list, tuple,
            str, ...). Each yielded chunk has whatever type slicing returns.
        n: Maximum chunk length; must be at least 1.

    Raises:
        ValueError: If *n* is less than 1. Because this is a generator, the
            error is raised when iteration starts, not when it is called.
    """
    # A negative step would make range() empty and silently drop every item,
    # and a zero step fails with an unrelated-looking range() error, so
    # reject both up front with a message that names the real problem.
    if n < 1:
        raise ValueError(f"chunk size must be a positive integer, got {n!r}")
    for start in range(0, len(sequence), n):
        yield sequence[start : start + n]

0 comments on commit f18656d

Please sign in to comment.